Skip to article frontmatterSkip to article content

Klasifikace

Zatím jsme se zabývali jen regresními úlohami. Učení s učitelem ale zahrnuje dvě hlavní skupiny úloh - regresní úlohy a klasifikační úlohy.

Zatímco u regresních úloh je výstupem modelu spojitá hodnota (float), v klasifikačních úlohách představuje výstup modelu indikátor třídy (label).

Držme se našeho rybího trhu a ukažme si to na příkladu. Úloha predikovat váhu ryby byla regresní úloha, predikovali jsme spojitou hodnotu. Pokud budeme chtít predikovat druh ryby (Perch - okoun, Roach - plotice, Pike - štika, ...), jedná se o predikci kategorické hodnoty, tedy o klasifikaci.

Klasifikační úlohy mají trochu jiné vlastnosti a logiku, než úlohy regresní, proto existují modely přímo určené na takové úlohy. Říká se jim klasifikátory.

Zkusíme se ale nejdřív podívat na úlohu klasifikace z pohledu, který už známe, tedy z pohledu krajiny.

data
# načeteme si data 
import pandas as pd 
import numpy as np 
np.random.seed(2020)  # nastavení náhodného klasifikátoru

data = pd.read_csv("static/fish_data.csv", index_col=0)
data
Loading...

Úkol 1:

Nejčastějším druhem ryby je Perch (okoun). Naším cílem je vytvořit klasifikátor, který pro zadané míry (váha, různé délky a šířky) vrátí informaci, zda se jedná o okouna nebo jiný druh. (Máme tedy pro jednoduchost jen dvě třídy, Perch a ostatní.)

  • Uměla bys tuto úlohu napasovat na krajinu? Co by mohly být souřadnice a co nadmořská výška?

  • Pokud ses úspěšně poprala s předchozím dotazem, můžeš na klasifikaci použít některý z regresních modelů (ano, asi to nebude ideální, když jde o klasifikaci, ale zkusme nejdříve to, co již umíme). Co ale bude hodnota odezvy a jak ji budeme interpretovat?

Klasifikační modely

Přinášíme opět nějakou základní nabídku klasifikačních modelů:

Úkol 2:

Vyberete si jeden model a zkuste natrénovat na ryby.

Nejprve připravíme data obdobně jako v minulé hodině. Jako sloupeček odezvy použijeme True pro okouny a False pro ostatní ryby. Sloupeček Species pak už nebudeme potřebovat, stejně tak můžeme vypustit sloupeček ID.

# připravme data
y = data["Species"] == "Perch"
y = y.astype(int)
X = data.drop(columns=["ID", "Species"])

Dalším krokem je rozdělení na trénovací a validační data. Nezapomeňme na stratifikaci.

# rozdělme na trénovací a validační množinu
from sklearn.model_selection import train_test_split 
X_train_raw, X_test_raw, y_train, y_test =  train_test_split(X, y, stratify=y)

Data přeškálujeme.

from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)
X_test = scaler.transform(X_test_raw)

Jako model zvolíme rozhodovací strom. Neboj se zkusit jiný klasifikátor dle své volby.

# vezměme klasifikátor 
# můžeš změnit 
from sklearn.tree import DecisionTreeClassifier
model = DecisionTreeClassifier()
# natrénujte
model.fit(X_train, y_train);

Máme natrénovaný model, jdeme se podívat, jak funguje na validačních datech.

# ohodnoťme validační množinu 
pred = model.predict(X_test) 
print("Skutečná třída:  Predikce:")
for true, predicted in zip(y_test, pred):
    print(f"{true:<15}  {predicted:<10} {'OK' if true == predicted else 'X'}")

print(f"Počet chyb: {sum(y_test != pred)}")
Skutečná třída:  Predikce:
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
1                1          OK
0                0          OK
1                1          OK
0                0          OK
0                0          OK
1                0          X
0                0          OK
1                0          X
0                0          OK
1                0          X
0                0          OK
0                0          OK
0                0          OK
1                1          OK
0                0          OK
0                1          X
0                1          X
0                0          OK
0                0          OK
1                1          OK
1                0          X
1                1          OK
1                1          OK
1                1          OK
0                0          OK
Počet chyb: 6

Úkol 3:

  • Asi je jasné, že regresní metriky se nám na klasifikační úlohy moc nehodí. Co bys použila jako metriku pro klasifikační úlohu?

Úkol 4:

  • Jedna z možností je porovnávat procento úspěšně klasifikovaných vzorů. V našem případě, to bude:
print(f"Úspěšnost: {100*sum(y_test == pred)/len(y_test):.2f} %")
Úspěšnost: 80.65 %

Úspěšnost není úplně špatná, poznat druh ryby podle rozměrů není jendoduchá úloha.

Představ si ale, že budeme mít v datovou množinu se 100 rybami, 95 z nich bude okounů (typu Perch). Bude ti klasifikátor, který bude mít toto procento úspěšnosti (stejné jako vyšlo nám), připadat dobrý nebo ne? Proč?

Úkol 5:

Nejprve projdeme klasifikační metriky. Pokud studuješ sama, nastuduj si kapitolu o klasifikačních metrikách a pak se vrať k tomuto cvičení.

Vyber si metriku pro naši úlohu a zkus najít, co nejlepší klasifikátor. Pak si načti testovací množinu a podívej se, jaké tvůj klasifikátor dává výsledky.

from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier 
from sklearn.svm import SVC

# zkus naučit různé modely a vyber nejlepší
models = {}

# KNeigbors 
for N in 1, 3, 5, 7:
    models[("nearest neighbors", N)] = KNeighborsClassifier(n_neighbors=N, weights="distance")
    
# tree
for d in range(3, 20):
    models[("tree", d)] = DecisionTreeClassifier(max_depth=d, class_weight='balanced')
    
# random forest
for N in range(1, 100):
    models[("random forest", N)] = RandomForestClassifier(n_estimators=N, class_weight='balanced')
    
# SVC
for C in range(-2, 10):
    models[("SVC", 10**C)] = SVC(C=10**C, class_weight='balanced')

Vytvořili jsme si slušnou zásobu modelů, uložili jsme je do slovníku. Každý model máme pro různé hodnoty příslušného hyper-parametru.

models
{('nearest neighbors', 1): KNeighborsClassifier(n_neighbors=1, weights='distance'), ('nearest neighbors', 3): KNeighborsClassifier(n_neighbors=3, weights='distance'), ('nearest neighbors', 5): KNeighborsClassifier(weights='distance'), ('nearest neighbors', 7): KNeighborsClassifier(n_neighbors=7, weights='distance'), ('tree', 3): DecisionTreeClassifier(class_weight='balanced', max_depth=3), ('tree', 4): DecisionTreeClassifier(class_weight='balanced', max_depth=4), ('tree', 5): DecisionTreeClassifier(class_weight='balanced', max_depth=5), ('tree', 6): DecisionTreeClassifier(class_weight='balanced', max_depth=6), ('tree', 7): DecisionTreeClassifier(class_weight='balanced', max_depth=7), ('tree', 8): DecisionTreeClassifier(class_weight='balanced', max_depth=8), ('tree', 9): DecisionTreeClassifier(class_weight='balanced', max_depth=9), ('tree', 10): DecisionTreeClassifier(class_weight='balanced', max_depth=10), ('tree', 11): DecisionTreeClassifier(class_weight='balanced', max_depth=11), ('tree', 12): DecisionTreeClassifier(class_weight='balanced', max_depth=12), ('tree', 13): DecisionTreeClassifier(class_weight='balanced', max_depth=13), ('tree', 14): DecisionTreeClassifier(class_weight='balanced', max_depth=14), ('tree', 15): DecisionTreeClassifier(class_weight='balanced', max_depth=15), ('tree', 16): DecisionTreeClassifier(class_weight='balanced', max_depth=16), ('tree', 17): DecisionTreeClassifier(class_weight='balanced', max_depth=17), ('tree', 18): DecisionTreeClassifier(class_weight='balanced', max_depth=18), ('tree', 19): DecisionTreeClassifier(class_weight='balanced', max_depth=19), ('random forest', 1): RandomForestClassifier(class_weight='balanced', n_estimators=1), ('random forest', 2): RandomForestClassifier(class_weight='balanced', n_estimators=2), ('random forest', 3): RandomForestClassifier(class_weight='balanced', n_estimators=3), ('random forest', 4): RandomForestClassifier(class_weight='balanced', n_estimators=4), ('random forest', 5): RandomForestClassifier(class_weight='balanced', n_estimators=5), ('random forest', 6): RandomForestClassifier(class_weight='balanced', n_estimators=6), ('random forest', 7): RandomForestClassifier(class_weight='balanced', n_estimators=7), ('random forest', 8): RandomForestClassifier(class_weight='balanced', n_estimators=8), ('random forest', 9): RandomForestClassifier(class_weight='balanced', n_estimators=9), ('random forest', 10): RandomForestClassifier(class_weight='balanced', n_estimators=10), ('random forest', 11): RandomForestClassifier(class_weight='balanced', n_estimators=11), ('random forest', 12): RandomForestClassifier(class_weight='balanced', n_estimators=12), ('random forest', 13): RandomForestClassifier(class_weight='balanced', n_estimators=13), ('random forest', 14): RandomForestClassifier(class_weight='balanced', n_estimators=14), ('random forest', 15): RandomForestClassifier(class_weight='balanced', n_estimators=15), ('random forest', 16): RandomForestClassifier(class_weight='balanced', n_estimators=16), ('random forest', 17): RandomForestClassifier(class_weight='balanced', n_estimators=17), ('random forest', 18): RandomForestClassifier(class_weight='balanced', n_estimators=18), ('random forest', 19): RandomForestClassifier(class_weight='balanced', n_estimators=19), ('random forest', 20): RandomForestClassifier(class_weight='balanced', n_estimators=20), ('random forest', 21): RandomForestClassifier(class_weight='balanced', n_estimators=21), ('random forest', 22): RandomForestClassifier(class_weight='balanced', n_estimators=22), ('random forest', 23): RandomForestClassifier(class_weight='balanced', n_estimators=23), ('random forest', 24): RandomForestClassifier(class_weight='balanced', n_estimators=24), ('random forest', 25): RandomForestClassifier(class_weight='balanced', n_estimators=25), ('random forest', 26): RandomForestClassifier(class_weight='balanced', n_estimators=26), ('random forest', 27): RandomForestClassifier(class_weight='balanced', n_estimators=27), ('random forest', 28): RandomForestClassifier(class_weight='balanced', n_estimators=28), ('random forest', 29): RandomForestClassifier(class_weight='balanced', n_estimators=29), ('random forest', 30): RandomForestClassifier(class_weight='balanced', n_estimators=30), ('random forest', 31): RandomForestClassifier(class_weight='balanced', n_estimators=31), ('random forest', 32): RandomForestClassifier(class_weight='balanced', n_estimators=32), ('random forest', 33): RandomForestClassifier(class_weight='balanced', n_estimators=33), ('random forest', 34): RandomForestClassifier(class_weight='balanced', n_estimators=34), ('random forest', 35): RandomForestClassifier(class_weight='balanced', n_estimators=35), ('random forest', 36): RandomForestClassifier(class_weight='balanced', n_estimators=36), ('random forest', 37): RandomForestClassifier(class_weight='balanced', n_estimators=37), ('random forest', 38): RandomForestClassifier(class_weight='balanced', n_estimators=38), ('random forest', 39): RandomForestClassifier(class_weight='balanced', n_estimators=39), ('random forest', 40): RandomForestClassifier(class_weight='balanced', n_estimators=40), ('random forest', 41): RandomForestClassifier(class_weight='balanced', n_estimators=41), ('random forest', 42): RandomForestClassifier(class_weight='balanced', n_estimators=42), ('random forest', 43): RandomForestClassifier(class_weight='balanced', n_estimators=43), ('random forest', 44): RandomForestClassifier(class_weight='balanced', n_estimators=44), ('random forest', 45): RandomForestClassifier(class_weight='balanced', n_estimators=45), ('random forest', 46): RandomForestClassifier(class_weight='balanced', n_estimators=46), ('random forest', 47): RandomForestClassifier(class_weight='balanced', n_estimators=47), ('random forest', 48): RandomForestClassifier(class_weight='balanced', n_estimators=48), ('random forest', 49): RandomForestClassifier(class_weight='balanced', n_estimators=49), ('random forest', 50): RandomForestClassifier(class_weight='balanced', n_estimators=50), ('random forest', 51): RandomForestClassifier(class_weight='balanced', n_estimators=51), ('random forest', 52): RandomForestClassifier(class_weight='balanced', n_estimators=52), ('random forest', 53): RandomForestClassifier(class_weight='balanced', n_estimators=53), ('random forest', 54): RandomForestClassifier(class_weight='balanced', n_estimators=54), ('random forest', 55): RandomForestClassifier(class_weight='balanced', n_estimators=55), ('random forest', 56): RandomForestClassifier(class_weight='balanced', n_estimators=56), ('random forest', 57): RandomForestClassifier(class_weight='balanced', n_estimators=57), ('random forest', 58): RandomForestClassifier(class_weight='balanced', n_estimators=58), ('random forest', 59): RandomForestClassifier(class_weight='balanced', n_estimators=59), ('random forest', 60): RandomForestClassifier(class_weight='balanced', n_estimators=60), ('random forest', 61): RandomForestClassifier(class_weight='balanced', n_estimators=61), ('random forest', 62): RandomForestClassifier(class_weight='balanced', n_estimators=62), ('random forest', 63): RandomForestClassifier(class_weight='balanced', n_estimators=63), ('random forest', 64): RandomForestClassifier(class_weight='balanced', n_estimators=64), ('random forest', 65): RandomForestClassifier(class_weight='balanced', n_estimators=65), ('random forest', 66): RandomForestClassifier(class_weight='balanced', n_estimators=66), ('random forest', 67): RandomForestClassifier(class_weight='balanced', n_estimators=67), ('random forest', 68): RandomForestClassifier(class_weight='balanced', n_estimators=68), ('random forest', 69): RandomForestClassifier(class_weight='balanced', n_estimators=69), ('random forest', 70): RandomForestClassifier(class_weight='balanced', n_estimators=70), ('random forest', 71): RandomForestClassifier(class_weight='balanced', n_estimators=71), ('random forest', 72): RandomForestClassifier(class_weight='balanced', n_estimators=72), ('random forest', 73): RandomForestClassifier(class_weight='balanced', n_estimators=73), ('random forest', 74): RandomForestClassifier(class_weight='balanced', n_estimators=74), ('random forest', 75): RandomForestClassifier(class_weight='balanced', n_estimators=75), ('random forest', 76): RandomForestClassifier(class_weight='balanced', n_estimators=76), ('random forest', 77): RandomForestClassifier(class_weight='balanced', n_estimators=77), ('random forest', 78): RandomForestClassifier(class_weight='balanced', n_estimators=78), ('random forest', 79): RandomForestClassifier(class_weight='balanced', n_estimators=79), ('random forest', 80): RandomForestClassifier(class_weight='balanced', n_estimators=80), ('random forest', 81): RandomForestClassifier(class_weight='balanced', n_estimators=81), ('random forest', 82): RandomForestClassifier(class_weight='balanced', n_estimators=82), ('random forest', 83): RandomForestClassifier(class_weight='balanced', n_estimators=83), ('random forest', 84): RandomForestClassifier(class_weight='balanced', n_estimators=84), ('random forest', 85): RandomForestClassifier(class_weight='balanced', n_estimators=85), ('random forest', 86): RandomForestClassifier(class_weight='balanced', n_estimators=86), ('random forest', 87): RandomForestClassifier(class_weight='balanced', n_estimators=87), ('random forest', 88): RandomForestClassifier(class_weight='balanced', n_estimators=88), ('random forest', 89): RandomForestClassifier(class_weight='balanced', n_estimators=89), ('random forest', 90): RandomForestClassifier(class_weight='balanced', n_estimators=90), ('random forest', 91): RandomForestClassifier(class_weight='balanced', n_estimators=91), ('random forest', 92): RandomForestClassifier(class_weight='balanced', n_estimators=92), ('random forest', 93): RandomForestClassifier(class_weight='balanced', n_estimators=93), ('random forest', 94): RandomForestClassifier(class_weight='balanced', n_estimators=94), ('random forest', 95): RandomForestClassifier(class_weight='balanced', n_estimators=95), ('random forest', 96): RandomForestClassifier(class_weight='balanced', n_estimators=96), ('random forest', 97): RandomForestClassifier(class_weight='balanced', n_estimators=97), ('random forest', 98): RandomForestClassifier(class_weight='balanced', n_estimators=98), ('random forest', 99): RandomForestClassifier(class_weight='balanced', n_estimators=99), ('SVC', 0.01): SVC(C=0.01, class_weight='balanced'), ('SVC', 0.1): SVC(C=0.1, class_weight='balanced'), ('SVC', 1): SVC(C=1, class_weight='balanced'), ('SVC', 10): SVC(C=10, class_weight='balanced'), ('SVC', 100): SVC(C=100, class_weight='balanced'), ('SVC', 1000): SVC(C=1000, class_weight='balanced'), ('SVC', 10000): SVC(C=10000, class_weight='balanced'), ('SVC', 100000): SVC(C=100000, class_weight='balanced'), ('SVC', 1000000): SVC(C=1000000, class_weight='balanced'), ('SVC', 10000000): SVC(C=10000000, class_weight='balanced'), ('SVC', 100000000): SVC(C=100000000, class_weight='balanced'), ('SVC', 1000000000): SVC(C=1000000000, class_weight='balanced')}

Obdobně jako v minulé hodině vytvoříme funci, která ohodnotí model a vrátí hodnoty vybrané metriky na trénovací a validační množině. Hodnoty vrací ve slovníku (což nám pak umožní snadnější vytvoření dataframu s výsledky).

from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix

def train_and_eval(X_train, X_test, y_train, y_test, model):
    model.fit(X_train, y_train)
    y_pred_test = model.predict(X_test)
    y_pred_train = model.predict(X_train)
    return {"train": f1_score(y_train, y_pred_train), # metriku můžeš vyměnit za nějakou svojí
            "test": f1_score(y_test, y_pred_test)
           }
results = []
for name, model in models.items():
    res = train_and_eval(X_train, X_test, y_train, y_test, model)
    res["model"] = name[0]
    res["param"] = name[1]
    results.append(res)
    
df_results = pd.DataFrame(results)
df_results
Loading...

Závislost úspěsnosti modelu (dle zvolené metriky) na hodnotě příslušného hyperparametru si zobrazíme v grafu.

import seaborn as sns
import matplotlib.pyplot as plt

def zobraz_model(model_name, ax, logx=False):
    sns.lineplot(x="param", y="train", data=df_results[df_results["model"]==model_name], label="train", ax=ax)
    sns.lineplot(x="param", y="test", data=df_results[df_results["model"]==model_name], label="test", ax=ax)
    ax.set_title(model_name.capitalize())
    if logx:
        ax.set(xscale="log")
    
fig, axs = plt.subplots(ncols=4, figsize=(16,4))
zobraz_model("nearest neighbors", axs[0])
zobraz_model("tree", axs[1])
zobraz_model("random forest", axs[2])
zobraz_model("SVC", axs[3], logx=True)
<Figure size 1600x400 with 4 Axes>

Úkol 6:

Vyber si model, který se na validační množině jeví jako nejlepší. Vyzkoušej jej na testovací data.

# načtení data
test_data = pd.read_csv("static/fish_data_test.csv", index_col=0)
y_real_test = test_data["Species"] == "Perch"
y_real_test = y_real_test.astype(int)
X_real_test = test_data.drop(columns=["ID", "Species"])
X_real_test = scaler.transform(X_real_test)
# predikce
model = models[("SVC", 10**4)]
test_pred = model.predict(X_real_test)
# zkus přidat zvolenou metriku
print(f"Skutečná třída:  Predikce:")
for true, predicted in zip(y_real_test, test_pred):
    print(f"{true:<15}  {predicted:<10} {'OK' if true == predicted else 'X'}")

print(f"Počet chyb: {sum(y_real_test != test_pred)}")
print(f"Úspěšnost: {100*sum(y_real_test == test_pred)/len(y_real_test):.2f} %")
Skutečná třída:  Predikce:
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                1          X
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
0                0          OK
1                0          X
1                1          OK
1                1          OK
1                1          OK
1                1          OK
1                1          OK
1                1          OK
1                1          OK
1                1          OK
1                1          OK
1                1          OK
1                1          OK
1                1          OK
1                1          OK
0                0          OK
0                0          OK
0                1          X
Počet chyb: 3
Úspěšnost: 91.67 %