Zatím jsme se zabývali jen regresními úlohami. Učení s učitelem ale zahrnuje dvě hlavní skupiny úloh - regresní úlohy a klasifikační úlohy.
Zatímco u regresních úloh je výstupem modelu spojitá hodnota (float), v klasifikačních úlohách představuje výstup modelu indikátor třídy (label).
Držme se našeho rybího trhu a ukažme si to na příkladu. Úloha predikovat váhu ryby byla regresní úloha, predikovali jsme spojitou hodnotu. Pokud budeme chtít predikovat druh ryby (Perch - okoun, Roach - plotice, Pike - štika, ...), jedná se o predikci kategorické hodnoty, tedy o klasifikaci.
Klasifikační úlohy mají trochu jiné vlastnosti a logiku, než úlohy regresní, proto existují modely přímo určené na takové úlohy. Říká se jim klasifikátory.
Zkusíme se ale nejdřív podívat na úlohu klasifikace z pohledu, který už známe, tedy z pohledu krajiny.

# načeteme si data
import pandas as pd
import numpy as np
np.random.seed(2020) # nastavení náhodného klasifikátoru
data = pd.read_csv("static/fish_data.csv", index_col=0)
data
Úkol 1:¶
Nejčastějším druhem ryby je Perch (okoun). Naším cílem je vytvořit klasifikátor, který pro zadané míry (váha, různé délky a šířky) vrátí informaci, zda se jedná o okouna nebo jiný druh. (Máme tedy pro jednoduchost jen dvě třídy, Perch a ostatní.)
Uměla bys tuto úlohu napasovat na krajinu? Co by mohly být souřadnice a co nadmořská výška?
Pokud ses úspěšně poprala s předchozím dotazem, můžeš na klasifikaci použít některý z regresních modelů (ano, asi to nebude ideální, když jde o klasifikaci, ale zkusme nejdříve to, co již umíme). Co ale bude hodnota odezvy a jak ji budeme interpretovat?
Klasifikační modely¶
Přinášíme opět nějakou základní nabídku klasifikačních modelů:
- n_estimators, integer, optional (default=100)
- C, float, optional (default=1.0)
- kernel,string, optional (default=’rbf’)
Úkol 2:¶
Vyberete si jeden model a zkuste natrénovat na ryby.
Nejprve připravíme data obdobně jako v minulé hodině. Jako sloupeček odezvy použijeme True
pro okouny a False
pro ostatní ryby. Sloupeček Species
pak už nebudeme potřebovat, stejně tak můžeme vypustit sloupeček ID
.
# připravme data
y = data["Species"] == "Perch"
y = y.astype(int)
X = data.drop(columns=["ID", "Species"])
Dalším krokem je rozdělení na trénovací a validační data. Nezapomeňme na stratifikaci.
# rozdělme na trénovací a validační množinu
from sklearn.model_selection import train_test_split
X_train_raw, X_test_raw, y_train, y_test = train_test_split(X, y, stratify=y)
Data přeškálujeme.
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train_raw)
X_test = scaler.transform(X_test_raw)
Jako model zvolíme rozhodovací strom. Neboj se zkusit jiný klasifikátor dle své volby.
# vezměme klasifikátor
# můžeš změnit
from sklearn.tree import DecisionTreeClassifier
model = DecisionTreeClassifier()
# natrénujte
model.fit(X_train, y_train);
Máme natrénovaný model, jdeme se podívat, jak funguje na validačních datech.
# ohodnoťme validační množinu
pred = model.predict(X_test)
print("Skutečná třída: Predikce:")
for true, predicted in zip(y_test, pred):
print(f"{true:<15} {predicted:<10} {'OK' if true == predicted else 'X'}")
print(f"Počet chyb: {sum(y_test != pred)}")
Skutečná třída: Predikce:
0 0 OK
0 0 OK
0 0 OK
0 0 OK
0 0 OK
0 0 OK
1 1 OK
0 0 OK
1 1 OK
0 0 OK
0 0 OK
1 0 X
0 0 OK
1 0 X
0 0 OK
1 0 X
0 0 OK
0 0 OK
0 0 OK
1 1 OK
0 0 OK
0 1 X
0 1 X
0 0 OK
0 0 OK
1 1 OK
1 0 X
1 1 OK
1 1 OK
1 1 OK
0 0 OK
Počet chyb: 6
Úkol 3:¶
- Asi je jasné, že regresní metriky se nám na klasifikační úlohy moc nehodí. Co bys použila jako metriku pro klasifikační úlohu?
Úkol 4:¶
- Jedna z možností je porovnávat procento úspěšně klasifikovaných vzorů. V našem případě, to bude:
print(f"Úspěšnost: {100*sum(y_test == pred)/len(y_test):.2f} %")
Úspěšnost: 80.65 %
Úspěšnost není úplně špatná, poznat druh ryby podle rozměrů není jendoduchá úloha.
Představ si ale, že budeme mít v datovou množinu se 100 rybami, 95 z nich bude okounů (typu Perch). Bude ti klasifikátor, který bude mít toto procento úspěšnosti (stejné jako vyšlo nám), připadat dobrý nebo ne? Proč?
Úkol 5:¶
Nejprve projdeme klasifikační metriky. Pokud studuješ sama, nastuduj si kapitolu o klasifikačních metrikách a pak se vrať k tomuto cvičení.
Vyber si metriku pro naši úlohu a zkus najít, co nejlepší klasifikátor. Pak si načti testovací množinu a podívej se, jaké tvůj klasifikátor dává výsledky.
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
# zkus naučit různé modely a vyber nejlepší
models = {}
# KNeigbors
for N in 1, 3, 5, 7:
models[("nearest neighbors", N)] = KNeighborsClassifier(n_neighbors=N, weights="distance")
# tree
for d in range(3, 20):
models[("tree", d)] = DecisionTreeClassifier(max_depth=d, class_weight='balanced')
# random forest
for N in range(1, 100):
models[("random forest", N)] = RandomForestClassifier(n_estimators=N, class_weight='balanced')
# SVC
for C in range(-2, 10):
models[("SVC", 10**C)] = SVC(C=10**C, class_weight='balanced')
Vytvořili jsme si slušnou zásobu modelů, uložili jsme je do slovníku. Každý model máme pro různé hodnoty příslušného hyper-parametru.
models
{('nearest neighbors',
1): KNeighborsClassifier(n_neighbors=1, weights='distance'),
('nearest neighbors',
3): KNeighborsClassifier(n_neighbors=3, weights='distance'),
('nearest neighbors', 5): KNeighborsClassifier(weights='distance'),
('nearest neighbors',
7): KNeighborsClassifier(n_neighbors=7, weights='distance'),
('tree', 3): DecisionTreeClassifier(class_weight='balanced', max_depth=3),
('tree', 4): DecisionTreeClassifier(class_weight='balanced', max_depth=4),
('tree', 5): DecisionTreeClassifier(class_weight='balanced', max_depth=5),
('tree', 6): DecisionTreeClassifier(class_weight='balanced', max_depth=6),
('tree', 7): DecisionTreeClassifier(class_weight='balanced', max_depth=7),
('tree', 8): DecisionTreeClassifier(class_weight='balanced', max_depth=8),
('tree', 9): DecisionTreeClassifier(class_weight='balanced', max_depth=9),
('tree', 10): DecisionTreeClassifier(class_weight='balanced', max_depth=10),
('tree', 11): DecisionTreeClassifier(class_weight='balanced', max_depth=11),
('tree', 12): DecisionTreeClassifier(class_weight='balanced', max_depth=12),
('tree', 13): DecisionTreeClassifier(class_weight='balanced', max_depth=13),
('tree', 14): DecisionTreeClassifier(class_weight='balanced', max_depth=14),
('tree', 15): DecisionTreeClassifier(class_weight='balanced', max_depth=15),
('tree', 16): DecisionTreeClassifier(class_weight='balanced', max_depth=16),
('tree', 17): DecisionTreeClassifier(class_weight='balanced', max_depth=17),
('tree', 18): DecisionTreeClassifier(class_weight='balanced', max_depth=18),
('tree', 19): DecisionTreeClassifier(class_weight='balanced', max_depth=19),
('random forest',
1): RandomForestClassifier(class_weight='balanced', n_estimators=1),
('random forest',
2): RandomForestClassifier(class_weight='balanced', n_estimators=2),
('random forest',
3): RandomForestClassifier(class_weight='balanced', n_estimators=3),
('random forest',
4): RandomForestClassifier(class_weight='balanced', n_estimators=4),
('random forest',
5): RandomForestClassifier(class_weight='balanced', n_estimators=5),
('random forest',
6): RandomForestClassifier(class_weight='balanced', n_estimators=6),
('random forest',
7): RandomForestClassifier(class_weight='balanced', n_estimators=7),
('random forest',
8): RandomForestClassifier(class_weight='balanced', n_estimators=8),
('random forest',
9): RandomForestClassifier(class_weight='balanced', n_estimators=9),
('random forest',
10): RandomForestClassifier(class_weight='balanced', n_estimators=10),
('random forest',
11): RandomForestClassifier(class_weight='balanced', n_estimators=11),
('random forest',
12): RandomForestClassifier(class_weight='balanced', n_estimators=12),
('random forest',
13): RandomForestClassifier(class_weight='balanced', n_estimators=13),
('random forest',
14): RandomForestClassifier(class_weight='balanced', n_estimators=14),
('random forest',
15): RandomForestClassifier(class_weight='balanced', n_estimators=15),
('random forest',
16): RandomForestClassifier(class_weight='balanced', n_estimators=16),
('random forest',
17): RandomForestClassifier(class_weight='balanced', n_estimators=17),
('random forest',
18): RandomForestClassifier(class_weight='balanced', n_estimators=18),
('random forest',
19): RandomForestClassifier(class_weight='balanced', n_estimators=19),
('random forest',
20): RandomForestClassifier(class_weight='balanced', n_estimators=20),
('random forest',
21): RandomForestClassifier(class_weight='balanced', n_estimators=21),
('random forest',
22): RandomForestClassifier(class_weight='balanced', n_estimators=22),
('random forest',
23): RandomForestClassifier(class_weight='balanced', n_estimators=23),
('random forest',
24): RandomForestClassifier(class_weight='balanced', n_estimators=24),
('random forest',
25): RandomForestClassifier(class_weight='balanced', n_estimators=25),
('random forest',
26): RandomForestClassifier(class_weight='balanced', n_estimators=26),
('random forest',
27): RandomForestClassifier(class_weight='balanced', n_estimators=27),
('random forest',
28): RandomForestClassifier(class_weight='balanced', n_estimators=28),
('random forest',
29): RandomForestClassifier(class_weight='balanced', n_estimators=29),
('random forest',
30): RandomForestClassifier(class_weight='balanced', n_estimators=30),
('random forest',
31): RandomForestClassifier(class_weight='balanced', n_estimators=31),
('random forest',
32): RandomForestClassifier(class_weight='balanced', n_estimators=32),
('random forest',
33): RandomForestClassifier(class_weight='balanced', n_estimators=33),
('random forest',
34): RandomForestClassifier(class_weight='balanced', n_estimators=34),
('random forest',
35): RandomForestClassifier(class_weight='balanced', n_estimators=35),
('random forest',
36): RandomForestClassifier(class_weight='balanced', n_estimators=36),
('random forest',
37): RandomForestClassifier(class_weight='balanced', n_estimators=37),
('random forest',
38): RandomForestClassifier(class_weight='balanced', n_estimators=38),
('random forest',
39): RandomForestClassifier(class_weight='balanced', n_estimators=39),
('random forest',
40): RandomForestClassifier(class_weight='balanced', n_estimators=40),
('random forest',
41): RandomForestClassifier(class_weight='balanced', n_estimators=41),
('random forest',
42): RandomForestClassifier(class_weight='balanced', n_estimators=42),
('random forest',
43): RandomForestClassifier(class_weight='balanced', n_estimators=43),
('random forest',
44): RandomForestClassifier(class_weight='balanced', n_estimators=44),
('random forest',
45): RandomForestClassifier(class_weight='balanced', n_estimators=45),
('random forest',
46): RandomForestClassifier(class_weight='balanced', n_estimators=46),
('random forest',
47): RandomForestClassifier(class_weight='balanced', n_estimators=47),
('random forest',
48): RandomForestClassifier(class_weight='balanced', n_estimators=48),
('random forest',
49): RandomForestClassifier(class_weight='balanced', n_estimators=49),
('random forest',
50): RandomForestClassifier(class_weight='balanced', n_estimators=50),
('random forest',
51): RandomForestClassifier(class_weight='balanced', n_estimators=51),
('random forest',
52): RandomForestClassifier(class_weight='balanced', n_estimators=52),
('random forest',
53): RandomForestClassifier(class_weight='balanced', n_estimators=53),
('random forest',
54): RandomForestClassifier(class_weight='balanced', n_estimators=54),
('random forest',
55): RandomForestClassifier(class_weight='balanced', n_estimators=55),
('random forest',
56): RandomForestClassifier(class_weight='balanced', n_estimators=56),
('random forest',
57): RandomForestClassifier(class_weight='balanced', n_estimators=57),
('random forest',
58): RandomForestClassifier(class_weight='balanced', n_estimators=58),
('random forest',
59): RandomForestClassifier(class_weight='balanced', n_estimators=59),
('random forest',
60): RandomForestClassifier(class_weight='balanced', n_estimators=60),
('random forest',
61): RandomForestClassifier(class_weight='balanced', n_estimators=61),
('random forest',
62): RandomForestClassifier(class_weight='balanced', n_estimators=62),
('random forest',
63): RandomForestClassifier(class_weight='balanced', n_estimators=63),
('random forest',
64): RandomForestClassifier(class_weight='balanced', n_estimators=64),
('random forest',
65): RandomForestClassifier(class_weight='balanced', n_estimators=65),
('random forest',
66): RandomForestClassifier(class_weight='balanced', n_estimators=66),
('random forest',
67): RandomForestClassifier(class_weight='balanced', n_estimators=67),
('random forest',
68): RandomForestClassifier(class_weight='balanced', n_estimators=68),
('random forest',
69): RandomForestClassifier(class_weight='balanced', n_estimators=69),
('random forest',
70): RandomForestClassifier(class_weight='balanced', n_estimators=70),
('random forest',
71): RandomForestClassifier(class_weight='balanced', n_estimators=71),
('random forest',
72): RandomForestClassifier(class_weight='balanced', n_estimators=72),
('random forest',
73): RandomForestClassifier(class_weight='balanced', n_estimators=73),
('random forest',
74): RandomForestClassifier(class_weight='balanced', n_estimators=74),
('random forest',
75): RandomForestClassifier(class_weight='balanced', n_estimators=75),
('random forest',
76): RandomForestClassifier(class_weight='balanced', n_estimators=76),
('random forest',
77): RandomForestClassifier(class_weight='balanced', n_estimators=77),
('random forest',
78): RandomForestClassifier(class_weight='balanced', n_estimators=78),
('random forest',
79): RandomForestClassifier(class_weight='balanced', n_estimators=79),
('random forest',
80): RandomForestClassifier(class_weight='balanced', n_estimators=80),
('random forest',
81): RandomForestClassifier(class_weight='balanced', n_estimators=81),
('random forest',
82): RandomForestClassifier(class_weight='balanced', n_estimators=82),
('random forest',
83): RandomForestClassifier(class_weight='balanced', n_estimators=83),
('random forest',
84): RandomForestClassifier(class_weight='balanced', n_estimators=84),
('random forest',
85): RandomForestClassifier(class_weight='balanced', n_estimators=85),
('random forest',
86): RandomForestClassifier(class_weight='balanced', n_estimators=86),
('random forest',
87): RandomForestClassifier(class_weight='balanced', n_estimators=87),
('random forest',
88): RandomForestClassifier(class_weight='balanced', n_estimators=88),
('random forest',
89): RandomForestClassifier(class_weight='balanced', n_estimators=89),
('random forest',
90): RandomForestClassifier(class_weight='balanced', n_estimators=90),
('random forest',
91): RandomForestClassifier(class_weight='balanced', n_estimators=91),
('random forest',
92): RandomForestClassifier(class_weight='balanced', n_estimators=92),
('random forest',
93): RandomForestClassifier(class_weight='balanced', n_estimators=93),
('random forest',
94): RandomForestClassifier(class_weight='balanced', n_estimators=94),
('random forest',
95): RandomForestClassifier(class_weight='balanced', n_estimators=95),
('random forest',
96): RandomForestClassifier(class_weight='balanced', n_estimators=96),
('random forest',
97): RandomForestClassifier(class_weight='balanced', n_estimators=97),
('random forest',
98): RandomForestClassifier(class_weight='balanced', n_estimators=98),
('random forest',
99): RandomForestClassifier(class_weight='balanced', n_estimators=99),
('SVC', 0.01): SVC(C=0.01, class_weight='balanced'),
('SVC', 0.1): SVC(C=0.1, class_weight='balanced'),
('SVC', 1): SVC(C=1, class_weight='balanced'),
('SVC', 10): SVC(C=10, class_weight='balanced'),
('SVC', 100): SVC(C=100, class_weight='balanced'),
('SVC', 1000): SVC(C=1000, class_weight='balanced'),
('SVC', 10000): SVC(C=10000, class_weight='balanced'),
('SVC', 100000): SVC(C=100000, class_weight='balanced'),
('SVC', 1000000): SVC(C=1000000, class_weight='balanced'),
('SVC', 10000000): SVC(C=10000000, class_weight='balanced'),
('SVC', 100000000): SVC(C=100000000, class_weight='balanced'),
('SVC', 1000000000): SVC(C=1000000000, class_weight='balanced')}
Obdobně jako v minulé hodině vytvoříme funci, která ohodnotí model a vrátí hodnoty vybrané metriky na trénovací a validační množině. Hodnoty vrací ve slovníku (což nám pak umožní snadnější vytvoření dataframu s výsledky).
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
def train_and_eval(X_train, X_test, y_train, y_test, model):
model.fit(X_train, y_train)
y_pred_test = model.predict(X_test)
y_pred_train = model.predict(X_train)
return {"train": f1_score(y_train, y_pred_train), # metriku můžeš vyměnit za nějakou svojí
"test": f1_score(y_test, y_pred_test)
}
results = []
for name, model in models.items():
res = train_and_eval(X_train, X_test, y_train, y_test, model)
res["model"] = name[0]
res["param"] = name[1]
results.append(res)
df_results = pd.DataFrame(results)
df_results
Závislost úspěsnosti modelu (dle zvolené metriky) na hodnotě příslušného hyperparametru si zobrazíme v grafu.
import seaborn as sns
import matplotlib.pyplot as plt
def zobraz_model(model_name, ax, logx=False):
sns.lineplot(x="param", y="train", data=df_results[df_results["model"]==model_name], label="train", ax=ax)
sns.lineplot(x="param", y="test", data=df_results[df_results["model"]==model_name], label="test", ax=ax)
ax.set_title(model_name.capitalize())
if logx:
ax.set(xscale="log")
fig, axs = plt.subplots(ncols=4, figsize=(16,4))
zobraz_model("nearest neighbors", axs[0])
zobraz_model("tree", axs[1])
zobraz_model("random forest", axs[2])
zobraz_model("SVC", axs[3], logx=True)

Úkol 6:¶
Vyber si model, který se na validační množině jeví jako nejlepší. Vyzkoušej jej na testovací data.
# načtení data
test_data = pd.read_csv("static/fish_data_test.csv", index_col=0)
y_real_test = test_data["Species"] == "Perch"
y_real_test = y_real_test.astype(int)
X_real_test = test_data.drop(columns=["ID", "Species"])
X_real_test = scaler.transform(X_real_test)
# predikce
model = models[("SVC", 10**4)]
test_pred = model.predict(X_real_test)
# zkus přidat zvolenou metriku
print(f"Skutečná třída: Predikce:")
for true, predicted in zip(y_real_test, test_pred):
print(f"{true:<15} {predicted:<10} {'OK' if true == predicted else 'X'}")
print(f"Počet chyb: {sum(y_real_test != test_pred)}")
print(f"Úspěšnost: {100*sum(y_real_test == test_pred)/len(y_real_test):.2f} %")
Skutečná třída: Predikce:
0 0 OK
0 0 OK
0 0 OK
0 0 OK
0 0 OK
0 0 OK
0 0 OK
0 0 OK
0 1 X
0 0 OK
0 0 OK
0 0 OK
0 0 OK
0 0 OK
0 0 OK
0 0 OK
0 0 OK
0 0 OK
0 0 OK
1 0 X
1 1 OK
1 1 OK
1 1 OK
1 1 OK
1 1 OK
1 1 OK
1 1 OK
1 1 OK
1 1 OK
1 1 OK
1 1 OK
1 1 OK
1 1 OK
0 0 OK
0 0 OK
0 1 X
Počet chyb: 3
Úspěšnost: 91.67 %