Analysing Iris Data with Spark and sklearn
First we download the free Iris dataset. We rename the file to iris.csv and load it into Spark to inspect it.
scala> val schema = "sepal_length DOUBLE, sepal_width DOUBLE, petal_length DOUBLE, petal_width DOUBLE, species STRING"
scala> val df = spark.read.schema(schema).csv("iris.csv")
As the queries below show, this is a small dataset.
scala> df.agg(countDistinct("species")).show
+--------------+
|count(species)|
+--------------+
| 3|
+--------------+
scala> df.groupBy("species").count.show
+---------------+-----+
| species|count|
+---------------+-----+
| Iris-virginica| 50|
| Iris-setosa| 50|
|Iris-versicolor| 50|
+---------------+-----+
The "species" column is the label and the other four columns are the features. With only three species, this is a simple multi-class classification problem.
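If you don't have a Spark shell handy, the same distribution check can be done in Python. A minimal sketch using the copy of Iris bundled with scikit-learn (note: its labels are the integers 0/1/2 rather than the species strings found in iris.csv):

```python
from collections import Counter
from sklearn.datasets import load_iris

# Load the bundled Iris data: 150 rows, 4 feature columns, 3 classes
iris = load_iris()
print(iris.data.shape)       # (150, 4)
print(Counter(iris.target))  # 50 samples per class, matching the Spark groupBy count
```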
So we write a short Python script for this job, with sklearn as the ML library.
import joblib
import pandas as pd
from sklearn import svm

# Read the training data; the renamed file has no header row, so supply column names
iris_data = pd.read_csv('./iris.csv', sep=',',
                        names=["sepal_length", "sepal_width", "petal_length",
                               "petal_width", "species"])
# Separate the target variable (the class labels) from the features
iris_label = iris_data.pop('species')
# Use SVC, the support vector classifier from the SVM family
classifier = svm.SVC(gamma='auto')
# Train the model
classifier.fit(iris_data, iris_label)
# Save the trained model locally
model_filename = 'model.joblib'
joblib.dump(classifier, model_filename)
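The point of dumping the model to model.joblib is that it can be loaded back later for predictions. A minimal sketch of that round trip, using sklearn's bundled copy of Iris instead of iris.csv so it is self-contained (the 4-value sample is simply the first row of the dataset):

```python
import joblib
from sklearn import svm
from sklearn.datasets import load_iris

# Train and save a model as in the script above
iris = load_iris()
classifier = svm.SVC(gamma='auto')
classifier.fit(iris.data, iris.target)
joblib.dump(classifier, 'model.joblib')

# Later (possibly in another process): load the model and classify a flower
model = joblib.load('model.joblib')
sample = [[5.1, 3.5, 1.4, 0.2]]  # sepal_length, sepal_width, petal_length, petal_width
prediction = model.predict(sample)
print(iris.target_names[prediction[0]])  # setosa
```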
As you can see, with only a few lines of code we can train a model for this classification job.
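One caveat: the script trains on all 150 rows and never measures accuracy. Holding out a test set is a quick sanity check. A sketch, again using sklearn's bundled copy of Iris (the test_size and random_state values are arbitrary choices, not from the original script):

```python
from sklearn import svm
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
# Hold out 30% of the rows for testing; fix random_state for reproducibility
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=42)

classifier = svm.SVC(gamma='auto')
classifier.fit(X_train, y_train)
print(f"test accuracy: {classifier.score(X_test, y_test):.3f}")
```

On a dataset this small and well-separated, accuracy on the held-out rows is typically well above 90%.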