Epilepsy Detection Using EEG Data


In this example we'll use the cesium library to compare various techniques for epilepsy detection using a classic EEG time series dataset from Andrzejak et al. The raw data are separated into five classes: Z, O, N, F, and S. We will consider a three-class classification problem: distinguishing normal (Z, O), interictal (N, F), and ictal (S) signals.
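The three-way grouping can also be written as a single dictionary lookup instead of a chain of boolean masks. Below is a minimal numpy sketch. GROUPS and group_labels are hypothetical names used only for illustration; they are not part of cesium.

```python
import numpy as np

# Hypothetical mapping from the original set labels to the three groups used in this post
GROUPS = {"Z": "Normal", "O": "Normal",
          "N": "Interictal", "F": "Interictal",
          "S": "Ictal"}

def group_labels(labels):
    """Return an array of group names, one per input label.

    An unexpected label raises KeyError instead of passing through unchanged.
    """
    return np.array([GROUPS[label] for label in labels], dtype="U16")

# Toy labels standing in for eeg["classes"]
raw = np.array(["Z", "O", "N", "F", "S", "S"])
grouped = group_labels(raw)
print(grouped)
```

One design point differs from the mask-based version: a mislabeled or unexpected class fails loudly rather than silently surviving as a fourth class.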

The overall workflow consists of three steps: first, we "featurize" the time series by selecting some set of mathematical functions to apply to each; next, we build some classification models which use these features to distinguish between classes; finally, we validate our models by generating predictions for some unseen holdout set and comparing them to the true class labels.
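To make the three steps concrete before bringing in cesium, here is a self-contained numpy sketch on hypothetical synthetic data. It featurizes each series with a few summary statistics, fits a model on a training subset, and scores a holdout subset. A nearest-centroid classifier is swapped in purely for brevity; it is not one of the models used later in this post.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical synthetic data: 60 series of 256 samples, two classes differing in amplitude
n_series, n_samples = 60, 256
labels = np.repeat(np.array(["quiet", "loud"]), n_series // 2)
scales = np.where(labels == "quiet", 1.0, 5.0)
series = rng.normal(size=(n_series, n_samples)) * scales[:, None]

# Step 1: featurize by applying a fixed set of functions to each series
def featurize(x):
    return np.column_stack([x.mean(axis=1),
                            x.std(axis=1),
                            np.abs(np.diff(x, axis=1)).sum(axis=1)])

X = featurize(series)

# Step 2: build a model on a training subset (nearest centroid, chosen only for brevity)
idx = rng.permutation(n_series)
train, test = idx[:45], idx[45:]
classes = np.unique(labels[train])
centroids = np.array([X[train][labels[train] == c].mean(axis=0) for c in classes])

def predict(features):
    # Distance from every sample to every class centroid; pick the closest
    dists = np.linalg.norm(features[:, None, :] - centroids[None, :, :], axis=2)
    return classes[np.argmin(dists, axis=1)]

# Step 3: validate by comparing predictions on the held-out subset to the true labels
accuracy = np.mean(predict(X[test]) == labels[test])
print(f"holdout accuracy = {accuracy:.2%}")
```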

First, we'll load the data and inspect a representative time series from each class:

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import seaborn; seaborn.set()

from cesium import datasets

eeg = datasets.fetch_andrzejak()

# Group together classes (Z, O), (N, F), (S) as normal, interictal, ictal
eeg["classes"] = eeg["classes"].astype('U16') #  allocate memory for longer class names
eeg["classes"][np.logical_or(eeg["classes"]=="Z", eeg["classes"]=="O")] = "Normal"
eeg["classes"][np.logical_or(eeg["classes"]=="N", eeg["classes"]=="F")] = "Interictal"
eeg["classes"][eeg["classes"]=="S"] = "Ictal"

fig, ax = plt.subplots(1, len(np.unique(eeg["classes"])), sharey=True)
for label, subplot in zip(np.unique(eeg["classes"]), ax):
    i = np.where(eeg["classes"] == label)[0][0]
    subplot.plot(eeg["times"][i], eeg["measurements"][i])
    subplot.set(xlabel="time (s)", ylabel="signal", title=label)

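[Figure: one representative time series per class (Ictal, Interictal, Normal), plotted as signal against time (s)]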

Featurization

Once the data is loaded, we can generate features for each time series using the cesium.featurize module. The featurize module includes many built-in choices of features which can be applied for any type of time series data; here we've chosen a few generic features that do not have any special biological significance.

If Celery is running, the time series will automatically be split among the available workers and featurized in parallel; setting use_celery=False will cause the time series to be featurized serially.
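Each series is featurized independently, so the work parallelizes trivially. The sketch below only illustrates that idea with Python's standard concurrent.futures. It is not how cesium dispatches work, which goes through Celery workers. featurize_one and featurize_all are hypothetical helpers.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np

def featurize_one(series):
    """Hypothetical per-series feature function: returns (mean, std)."""
    return float(np.mean(series)), float(np.std(series))

def featurize_all(all_series, parallel=True, max_workers=4):
    """Apply featurize_one to every series, optionally with a thread pool.

    This mimics the idea behind use_celery=True/False, not its implementation.
    """
    if not parallel:
        return [featurize_one(s) for s in all_series]
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map preserves input order, so results line up with the input series
        return list(pool.map(featurize_one, all_series))

data = [np.arange(10.0) * k for k in range(1, 6)]
print(featurize_all(data, parallel=True))
```

Threads keep the sketch simple. For CPU-bound pure-Python feature code, a ProcessPoolExecutor (or a task queue such as Celery) is usually the better fit.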

from cesium import featurize
features_to_use = ['amplitude',
                   'percent_beyond_1_std',
                   'maximum',
                   'max_slope',
                   'median',
                   'median_absolute_deviation',
                   'percent_close_to_median',
                   'minimum',
                   'skew',
                   'std',
                   'weighted_average']
fset_cesium = featurize.featurize_time_series(times=eeg["times"],
                                              values=eeg["measurements"],
                                              errors=None,
                                              features_to_use=features_to_use,
                                              targets=eeg["classes"], use_celery=True)
print(fset_cesium)
<xarray.Dataset>
Dimensions:                    (channel: 1, name: 500)
Coordinates:
  * channel                    (channel) int64 0
  * name                       (name) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 ...
    target                     (name) object 'Normal' 'Normal' 'Normal' ...
Data variables:
    minimum                    (name, channel) float64 -146.0 -254.0 -146.0 ...
    amplitude                  (name, channel) float64 143.5 211.5 165.0 ...
    median_absolute_deviation  (name, channel) float64 28.0 32.0 31.0 31.0 ...
    percent_beyond_1_std       (name, channel) float64 0.1626 0.1455 0.1523 ...
    maximum                    (name, channel) float64 141.0 169.0 184.0 ...
    median                     (name, channel) float64 -4.0 -51.0 13.0 -4.0 ...
    percent_close_to_median    (name, channel) float64 0.505 0.6405 0.516 ...
    max_slope                  (name, channel) float64 1.111e+04 2.065e+04 ...
    skew                       (name, channel) float64 0.0328 -0.09271 ...
    weighted_average           (name, channel) float64 -4.132 -52.44 12.71 ...
    std                        (name, channel) float64 40.41 48.81 47.14 ...

The output of featurize_time_series is an xarray.Dataset which contains all the feature information needed to train a machine learning model: feature values are stored as data variables, and the time series index/class label are stored as coordinates (a channel coordinate will also be used later for multi-channel data).
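For intuition about a few of the built-in features, here are plain numpy definitions suggested by their names. They are a sketch, not cesium's source. The amplitude definition is consistent with the printed output: for the first series, minimum -146.0 and maximum 141.0 give (141.0 + 146.0) / 2 = 143.5. Likewise, weighted_average (-4.132) matches the unweighted mean computed later, as one would expect when errors=None. percent_close_to_median is omitted because its window is not stated here.

```python
import numpy as np

def amplitude(m):
    # Half the peak-to-peak range, e.g. (141.0 - (-146.0)) / 2 = 143.5 for the first EEG series
    return (np.max(m) - np.min(m)) / 2

def median_absolute_deviation(m):
    # Median distance of the samples from their median
    return np.median(np.abs(m - np.median(m)))

def percent_beyond_1_std(m):
    # Fraction of samples more than one standard deviation from the mean (assumed unweighted)
    return np.mean(np.abs(m - np.mean(m)) > np.std(m))

def max_slope(t, m):
    # Largest absolute finite-difference slope between consecutive samples
    return np.max(np.abs(np.diff(m) / np.diff(t)))

t = np.linspace(0.0, 1.0, 5)
m = np.array([0.0, 2.0, -2.0, 4.0, 1.0])
print(amplitude(m), median_absolute_deviation(m), percent_beyond_1_std(m), max_slope(t, m))
```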

Custom feature functions

Custom feature functions not built into cesium may be passed in using the custom_functions keyword, either as a dictionary {feature_name: function}, or as a dask graph. Functions should take three arrays times, measurements, errors as inputs; details can be found in the cesium.featurize documentation. Here we'll compute five standard features for EEG analysis provided by Guo et al. (2012):

import numpy as np
import scipy.stats

def mean_signal(t, m, e):
    return np.mean(m)

def std_signal(t, m, e):
    return np.std(m)

def mean_square_signal(t, m, e):
    return np.mean(m ** 2)

def abs_diffs_signal(t, m, e):
    return np.sum(np.abs(np.diff(m)))

def skew_signal(t, m, e):
    return scipy.stats.skew(m)

Now we'll pass the desired feature functions as a dictionary via the custom_functions keyword argument.

guo_features = {
    'mean': mean_signal,
    'std': std_signal,
    'mean2': mean_square_signal,
    'abs_diffs': abs_diffs_signal,
    'skew': skew_signal
}

fset_guo = featurize.featurize_time_series(times=eeg["times"], values=eeg["measurements"],
                                           errors=None, targets=eeg["classes"], 
                                           features_to_use=list(guo_features.keys()),
                                           custom_functions=guo_features,
                                           use_celery=True)
print(fset_guo)
<xarray.Dataset>
Dimensions:    (channel: 1, name: 500)
Coordinates:
  * channel    (channel) int64 0
  * name       (name) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 ...
    target     (name) object 'Normal' 'Normal' 'Normal' 'Normal' 'Normal' ...
Data variables:
    abs_diffs  (name, channel) float64 4.695e+04 6.112e+04 5.127e+04 ...
    mean       (name, channel) float64 -4.132 -52.44 12.71 -3.992 -18.0 ...
    mean2      (name, channel) float64 1.65e+03 5.133e+03 2.384e+03 ...
    skew       (name, channel) float64 0.0328 -0.09271 -0.0041 0.06368 ...
    std        (name, channel) float64 40.41 48.81 47.14 47.07 44.91 45.02 ...
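Conceptually, passing custom_functions amounts to calling each (t, m, e) function on every series and collecting one value per feature. The sketch below is a hypothetical stand-in (featurize_with_dict is not a cesium function) applied to a three-sample signal. It also checks a useful identity between the Guo features: mean2 = std² + mean², since np.std is the population standard deviation. The printed values agree: for the first series, 40.41² + 4.132² ≈ 1650, matching mean2 = 1.65e+03.

```python
import numpy as np

def mean_signal(t, m, e):
    return np.mean(m)

def std_signal(t, m, e):
    return np.std(m)

def mean_square_signal(t, m, e):
    return np.mean(m ** 2)

def abs_diffs_signal(t, m, e):
    return np.sum(np.abs(np.diff(m)))

def featurize_with_dict(times, values, functions):
    """Hypothetical helper: apply each (t, m, e) function to every series.

    Returns a dict mapping feature name -> 1-D array with one entry per series.
    """
    out = {name: np.empty(len(values)) for name in functions}
    for i, (t, m) in enumerate(zip(times, values)):
        for name, func in functions.items():
            out[name][i] = func(t, m, None)  # errors=None, as in the post
    return out

funcs = {"mean": mean_signal, "std": std_signal,
         "mean2": mean_square_signal, "abs_diffs": abs_diffs_signal}
m = np.array([1.0, -1.0, 3.0])
feats = featurize_with_dict([np.arange(3.0)], [m], funcs)
print(feats)
```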

Multi-channel time series

The EEG time series considered here consist of univariate signal measurements along a uniform time grid. But featurize_time_series also accepts multi-channel data; to demonstrate this, we will decompose each signal into five frequency bands using a discrete wavelet transform as suggested by Subasi (2005), and then featurize each band separately using the five functions from above.
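To show what the decomposition produces, here is a from-scratch multilevel Haar transform in numpy; Haar is the wavelet pywt calls 'db1'. As with pywt.wavedec, it returns the coarsest approximation followed by detail bands from coarse to fine. It is a sketch under simplifying assumptions. The signal length must be divisible by 2**level, whereas pywt extends the signal at the boundaries to handle other lengths. Sign conventions for the detail coefficients may also differ between libraries.

```python
import numpy as np

def haar_wavedec(x, level):
    """Multilevel Haar decomposition, coarsest first: [cA_L, cD_L, ..., cD_1].

    Assumes len(x) is divisible by 2**level, so no boundary extension is needed.
    """
    x = np.asarray(x, dtype=float)
    if len(x) % (2 ** level) != 0:
        raise ValueError("signal length must be divisible by 2**level")
    details = []
    approx = x
    for _ in range(level):
        even, odd = approx[0::2], approx[1::2]
        details.append((even - odd) / np.sqrt(2))  # high-pass (detail) coefficients
        approx = (even + odd) / np.sqrt(2)         # low-pass (approximation) coefficients
    return [approx] + details[::-1]

rng = np.random.default_rng(0)
signal = rng.normal(size=64)
bands = haar_wavedec(signal, level=4)  # 5 bands, like n_channels = 5 below
print([len(b) for b in bands])
```

Because the Haar transform is orthonormal, the total energy (sum of squares) across the five bands equals that of the original signal. This makes it a handy sanity check for any implementation.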

import pywt

n_channels = 5
eeg["dwts"] = [pywt.wavedec(m, pywt.Wavelet('db1'), level=n_channels-1)
               for m in eeg["measurements"]]
fset_dwt = featurize.featurize_time_series(times=None, values=eeg["dwts"], errors=None,
                                           features_to_use=list(guo_features.keys()),
                                           targets=eeg["classes"],
                                           custom_functions=guo_features)
print(fset_dwt)
<xarray.Dataset>
Dimensions:    (channel: 5, name: 500)
Coordinates:
  * channel    (channel) int64 0 1 2 3 4
  * name       (name) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 ...
    target     (name) object 'Normal' 'Normal' 'Normal' 'Normal' 'Normal' ...
Data variables:
    abs_diffs  (name, channel) float64 2.513e+04 1.806e+04 3.241e+04 ...
    skew       (name, channel) float64 -0.0433 0.06578 0.2999 0.1239 0.1179 ...
    mean2      (name, channel) float64 1.294e+04 5.362e+03 2.321e+03 664.4 ...
    mean       (name, channel) float64 -17.08 -6.067 -0.9793 0.1546 0.03555 ...
    std        (name, channel) float64 112.5 72.97 48.17 25.77 10.15 119.8 ...

The output featureset has the same form as before, except now the channel coordinate is used to index the features by the corresponding frequency band. The functions in cesium.build_model and cesium.predict all accept featuresets from single- or multi-channel data, so no additional steps are required to train models or make predictions for multichannel featuresets using the cesium library.

Model Building

Model building in cesium is handled by the build_model_from_featureset function in the cesium.build_model submodule. The featureset output by featurize_time_series contains both the feature and target information needed to train a model; build_model_from_featureset is simply a wrapper that calls the fit method of a given scikit-learn model with the appropriate inputs. In the case of multichannel features, it also handles reshaping the featureset into a (rectangular) form that is compatible with scikit-learn.
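Reshaping multi-channel features into a rectangular form essentially means giving each (feature, channel) pair its own column. The numpy sketch below makes an assumption for illustration: the features are stacked as (n_series, n_features, n_channels), and columns are named with a hypothetical "<feature>_<channel>" convention that is not necessarily cesium's.

```python
import numpy as np

def flatten_channels(features, feature_names):
    """Turn an array of shape (n_series, n_features, n_channels) into a 2-D matrix.

    Column j * n_channels + c holds feature j from channel c.
    """
    n_series, n_features, n_channels = features.shape
    matrix = features.reshape(n_series, n_features * n_channels)
    names = [f"{f}_{c}" for f in feature_names for c in range(n_channels)]
    return matrix, names

feats = np.arange(2 * 3 * 5, dtype=float).reshape(2, 3, 5)  # 2 series, 3 features, 5 bands
X, cols = flatten_channels(feats, ["mean", "std", "skew"])
print(X.shape, cols[:6])
```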

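For this example, we'll test a random forest classifier for the built-in cesium features and, following Guo et al. (2012), a nearest neighbors classifier for the other two featuresets. Guo et al. suggest 3 neighbors; here the number of neighbors is chosen from 1 to 4 via params_to_optimize, just as the number of trees is chosen for the random forest.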
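The following assumes params_to_optimize makes build_model_from_featureset run a cross-validated grid search, which is an assumption about cesium's internals. Under that assumption, choosing n_neighbors works roughly like this from-scratch k-nearest neighbors sketch on hypothetical two-cluster data. knn_predict and select_k are illustrative helpers, not library functions.

```python
import numpy as np

def knn_predict(X_train, y_train, X_query, k):
    """Majority vote among the k nearest training points (Euclidean distance)."""
    dists = np.linalg.norm(X_query[:, None, :] - X_train[None, :, :], axis=2)
    nearest = np.argsort(dists, axis=1)[:, :k]
    preds = []
    for row in y_train[nearest]:
        values, counts = np.unique(row, return_counts=True)
        preds.append(values[np.argmax(counts)])  # ties go to the alphabetically first label
    return np.array(preds)

def select_k(X, y, candidates, n_folds=5, seed=0):
    """Pick k by mean accuracy over n_folds cross-validation folds."""
    folds = np.array_split(np.random.default_rng(seed).permutation(len(y)), n_folds)
    scores = {}
    for k in candidates:
        accs = []
        for i, fold in enumerate(folds):
            train = np.concatenate([f for j, f in enumerate(folds) if j != i])
            accs.append(np.mean(knn_predict(X[train], y[train], X[fold], k) == y[fold]))
        scores[k] = np.mean(accs)
    return max(scores, key=scores.get), scores  # ties go to the smallest k tried

rng = np.random.default_rng(1)
X = np.vstack([rng.normal(0, 1, (30, 2)), rng.normal(4, 1, (30, 2))])
y = np.array(["a"] * 30 + ["b"] * 30)
best_k, scores = select_k(X, y, candidates=[1, 2, 3, 4])
print(best_k, scores)
```

Even values of k can produce tied votes, which is one reason odd values such as 3 are a common default.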

from cesium.build_model import build_model_from_featureset
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split

train, test = train_test_split(np.arange(len(eeg["classes"])), random_state=0)

rfc_param_grid = {'n_estimators': [8, 16, 32, 64, 128, 256, 512, 1024]}
model_cesium = build_model_from_featureset(fset_cesium.isel(name=train),
                                          RandomForestClassifier(max_features='sqrt',
                                                                 random_state=0),
                                          params_to_optimize=rfc_param_grid)
knn_param_grid = {'n_neighbors': [1, 2, 3, 4]}
model_guo = build_model_from_featureset(fset_guo.isel(name=train),
                                        KNeighborsClassifier(),
                                        params_to_optimize=knn_param_grid)
model_dwt = build_model_from_featureset(fset_dwt.isel(name=train),
                                        KNeighborsClassifier(),
                                        params_to_optimize=knn_param_grid)
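With no test_size argument, train_test_split holds out 25% of the samples, rounded up. The 500 EEG series therefore become 375 training and 125 test series, and only the training subset is passed to build_model_from_featureset. The numpy sketch below reproduces the sizes, but not scikit-learn's exact shuffle. holdout_split is a hypothetical helper.

```python
import math

import numpy as np

def holdout_split(n, test_size=0.25, seed=0):
    """Shuffle indices 0..n-1 and hold out ceil(test_size * n) of them for testing."""
    perm = np.random.default_rng(seed).permutation(n)
    n_test = math.ceil(test_size * n)
    return perm[n_test:], perm[:n_test]  # (train, test)

train_idx, test_idx = holdout_split(500)
print(len(train_idx), len(test_idx))
```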

Prediction

Making predictions for new time series based on these models follows the same pattern: first the time series are featurized using featurize_time_series, and then predictions are made from these features using cesium.predict.model_predictions. Here we predict on the full featuresets and then score the training and test subsets separately.

from sklearn.metrics import accuracy_score
from cesium.predict import model_predictions

preds_cesium = model_predictions(fset_cesium, model_cesium, return_probs=False)
preds_guo = model_predictions(fset_guo, model_guo, return_probs=False)
preds_dwt = model_predictions(fset_dwt, model_dwt, return_probs=False)

print("Built-in cesium features: training accuracy={:.2%}, test accuracy={:.2%}".format(
          accuracy_score(eeg["classes"][train], preds_cesium[train]),
          accuracy_score(eeg["classes"][test], preds_cesium[test])))
print("Guo et al. features: training accuracy={:.2%}, test accuracy={:.2%}".format(
          accuracy_score(eeg["classes"][train], preds_guo[train]),
          accuracy_score(eeg["classes"][test], preds_guo[test])))
print("Wavelet transform features: training accuracy={:.2%}, test accuracy={:.2%}".format(
          accuracy_score(eeg["classes"][train], preds_dwt[train]),
          accuracy_score(eeg["classes"][test], preds_dwt[test])))
Built-in cesium features: training accuracy=100.00%, test accuracy=83.20%
Guo et al. features: training accuracy=90.93%, test accuracy=84.80%
Wavelet transform features: training accuracy=100.00%, test accuracy=95.20%
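With 125 held-out series, the reported test accuracies correspond to whole counts: 104, 106 and 119 correct, respectively. A single accuracy figure hides which classes get confused with which, and a confusion matrix shows that directly. The sketch below uses hypothetical toy labels, not the predictions above. confusion_matrix here is a small numpy helper written for illustration, not the scikit-learn function of the same name.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, labels):
    """Rows: true class, columns: predicted class, entries: counts."""
    index = {label: i for i, label in enumerate(labels)}
    cm = np.zeros((len(labels), len(labels)), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[index[t], index[p]] += 1
    return cm

# Reported test accuracies expressed as counts out of the 125 held-out series
for name, acc in [("cesium", 0.832), ("Guo", 0.848), ("DWT", 0.952)]:
    print(name, round(acc * 125), "of 125 correct")

# Hypothetical toy labels to show the matrix layout
y_true = ["Normal", "Normal", "Interictal", "Ictal", "Ictal"]
y_pred = ["Normal", "Interictal", "Interictal", "Ictal", "Normal"]
cm = confusion_matrix(y_true, y_pred, ["Normal", "Interictal", "Ictal"])
print(cm)
print("accuracy:", np.trace(cm) / cm.sum())  # diagonal = correct predictions
```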

The workflow presented here is intentionally simplistic. It omits many important steps, such as feature selection, more thorough hyperparameter tuning, and cross-validated performance estimates, all of which may be incorporated just as they would be in any other scikit-learn analysis. But with essentially three function calls (featurize_time_series, build_model_from_featureset, and model_predictions), we are able to build a model from a set of time series and make predictions on new, unlabeled data. In upcoming posts we'll introduce the web frontend for cesium and describe how the same analysis can be performed in a browser with no setup or coding required.


Written by Cesium Developers in misc on Fri 08 July 2016. Tags: example,
