135 lines
4.9 KiB
Python
135 lines
4.9 KiB
Python
import numpy as np
|
|
import pandas as pd
|
|
import scipy
|
|
from scipy.signal import butter, lfilter
|
|
|
|
|
|
def __create_constants(low: float, high: float, sample_rate: float, order: int) -> tuple[float, float]:
|
|
"""Вычисление констан полосового фильтра (Баттерворта)
|
|
|
|
Args:
|
|
low (float): нижняя чатсота среза
|
|
high (float): верхняя частота среза
|
|
sample_rate (float): частота дискретизации
|
|
order (int): порядок фильтра
|
|
|
|
Returns:
|
|
tuple[float, float]: константы фильтра, необходимые для вычислений
|
|
"""
|
|
return butter(order, [low, high], fs=sample_rate, btype="band")
|
|
|
|
|
|
def filter_signal(df: pd.DataFrame, cols: list[str], constants: tuple[float, float]) -> pd.DataFrame:
|
|
"""Фильтрация датасета (сигналов с датчиков)
|
|
|
|
Args:
|
|
df (pd.DataFrame): датафрейм с данными из OpenVibe
|
|
cols (list[str]): список колонок, данные в которых нужно отфильтровать
|
|
constants (tuple[float, float]): константы для вычисления фильтра (см. __create_constants)
|
|
|
|
Returns:
|
|
pd.DataFrame: входной фрейм с данными, с примененным фильтром
|
|
"""
|
|
new_df = df.copy()
|
|
for col in cols:
|
|
new_df[col] = lfilter(*constants, df[col].values)
|
|
|
|
return new_df
|
|
|
|
|
|
class CSPFilter:
|
|
"""
|
|
Реализация CSP фильтра в стиле sklearn.
|
|
Основано на реализации github.co/Hiroaki-K4/total_perspective_vortex/main/srcs/csp.py
|
|
"""
|
|
|
|
def __init__(self, n_components: int) -> None:
|
|
self.n_components = n_components
|
|
self.classes = None
|
|
self.filters = None
|
|
self.patterns = None
|
|
self.mean = None
|
|
self.std = None
|
|
|
|
def __get_covariate(self, x: np.ndarray, y: np.ndarray, cls: any) -> np.ndarray:
|
|
"""Получение ковариаций
|
|
|
|
Args:
|
|
x (np.ndarray): отсчеты сигнала
|
|
y (np.ndarray): метки класса
|
|
cls (any): текущий класс
|
|
|
|
Returns:
|
|
np.ndarray: матрица ковариаций
|
|
"""
|
|
|
|
x_class = x[y == cls]
|
|
_, _, n_channels = x_class.shape
|
|
# x_class = np.transpose(x_class, [1, 0, 2])
|
|
x_class = x_class.reshape(n_channels, -1)
|
|
cov = np.dot(x_class, x_class.T)
|
|
return cov
|
|
|
|
def fit(self, x: np.ndarray, y: np.ndarray) -> None:
|
|
"""Обучение фильтра
|
|
|
|
Args:
|
|
x (np.ndarray): отсчеты сигнала (числа)
|
|
y (np.ndarray): метки класса (любые данные)
|
|
"""
|
|
self.classes = np.unique(y)
|
|
n_classes = len(self.classes)
|
|
if n_classes != 2:
|
|
raise ValueError(f"Количество классов для CSP должно быть = 2 (получено {n_classes})")
|
|
|
|
cov_neg = self.__get_covariate(x, y, self.classes[0])
|
|
cov_pos = self.__get_covariate(x, y, self.classes[1])
|
|
|
|
eig_vals, eig_vecs = scipy.linalg.eigh(cov_neg, cov_pos)
|
|
for i in range(len(eig_vecs)):
|
|
eig_vecs[i] = eig_vecs[i] / np.linalg.norm(eig_vecs[i])
|
|
|
|
sorted_vals = np.argsort(eig_vals)
|
|
idxs = np.empty_like(sorted_vals)
|
|
idxs[0::2] = sorted_vals[len(sorted_vals) // 2 :][::-1]
|
|
idxs[1::2] = sorted_vals[: len(sorted_vals) // 2]
|
|
|
|
eig_vecs = eig_vecs[:, idxs]
|
|
self.filters = eig_vecs.T
|
|
self.patterns = np.linalg.inv(eig_vecs)
|
|
pick_filters = self.filters[:, self.n_components]
|
|
|
|
x = np.asarray([np.dot(pick_filters, epoch.T) for epoch in x])
|
|
x = (x**2).mean(axis=1)
|
|
self.mean = x.mean(axis=0)
|
|
self.std = x.std(axis=0)
|
|
|
|
def transform(self, x: np.ndarray) -> np.ndarray:
|
|
"""Применение фильтра
|
|
|
|
Args:
|
|
x (np.ndarray): отсчеты сигнала (числа)
|
|
|
|
Returns:
|
|
np.ndarray: вторичные параметры (фичи)
|
|
"""
|
|
pick_filters = self.filters[: self.n_components]
|
|
x = np.asarray([np.dot(pick_filters, epoch.T) for epoch in x])
|
|
x = (x**2).mean(axis=1)
|
|
x -= self.mean
|
|
x /= self.std
|
|
return x
|
|
|
|
def fit_transform(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
|
|
"""Обучение с последующим применением ко входным данным
|
|
|
|
Args:
|
|
x (np.ndarray): отсчеты сигнала (числа)
|
|
y (np.ndarray): метки класса (любые данные)
|
|
|
|
Returns:
|
|
np.ndarray: вторичные параметры (фичи)
|
|
"""
|
|
self.fit(x, y)
|
|
return self.transform(x)
|