init commit: загрузка и предобработка данных

This commit is contained in:
xausssr
2024-05-11 12:29:10 +03:00
commit 0a55543ae2
22 changed files with 1986 additions and 0 deletions
+6
View File
@@ -0,0 +1,6 @@
Записи/OpenBCI GUI
Записи/OpenVibe Signals
__pycache__
.ipynb_checkpoints
.pytest_cache
runs
+50
View File
@@ -0,0 +1,50 @@
g_offset = nil
g_duration = nil
-- this function is called when the box is initialized
function initialize(box)
dofile(box:get_config("${Path_Data}") .. "/plugins/stimulation/lua-stimulator-stim-codes.lua")
g_offset = box:get_setting(2)
g_duration = box:get_setting(3)
end
-- this function is called when the box is uninitialized
function uninitialize(box)
end
function wait_until(box, time)
while box:get_current_time() < time do
box:sleep()
end
end
function wait_for(box, duration)
wait_until(box, box:get_current_time() + duration)
end
function process(box)
-- loops on every received stimulation for a given input
while box:keep_processing() do
for stimulation = 1, box:get_stimulation_count(1) do
-- gets the received stimulation
identifier, date, duration = box:get_stimulation(1, 1)
-- discards it
box:remove_stimulation(1, 1)
-- delay the OVTK_GDF_Left and Right
if identifier == OVTK_GDF_Left or identifier == OVTK_GDF_Right then
box:send_stimulation(1, OVTK_GDF_Correct, date+g_offset, 0)
box:send_stimulation(1, OVTK_GDF_Incorrect, date+g_offset+g_duration, 0)
end
end
box:sleep()
end
end
+471
View File
@@ -0,0 +1,471 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import pandas as pd\n",
"import numpy as np\n",
"import torch\n",
"\n",
"from pathlib import Path\n",
"\n",
"from matplotlib import pylab as plt\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import balanced_accuracy_score\n",
"\n",
"import pytorch_lightning as pl\n",
"from pytorch_lightning.callbacks import ModelCheckpoint\n",
"\n",
"from core.io import filter_events, get_ranges, split_epochs_binary\n",
"from core.signal_processing import filter_signal, __create_constants, CSPFilter\n",
"from core.nn.provider import NetProvider\n",
"from core.nn.data import TableDataset\n",
"\n",
"%matplotlib inline\n",
"plt.style.use('bmh')\n",
"plt.rcParams.update({'font.size': 14})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Необходимо проверить точность классификации с помощью полносвязной сети:\n",
"\n",
"1. Подобрать архитектуру (поиск по сетке) - возможно полносвязные слои заменить на свёртку\n",
"2. Выбрать электроды\n",
"3. Исследовать влияние фильтрации\n",
"4. Исследовать влияние использования CSP-фильтра\n",
"\n",
"> По числовым значениям - **обсудить с Андреем Николаевичем**\n",
"\n",
"> **Окружение TacticENV**\n",
"\n",
"> **Сохранять все результаты - ребятам понадабятся графики/таблицы в статье, необзодимые визуализации обсудить с Андреем Николаевичем**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Примеры использования библеотеки `core`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Данные"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Загрузка"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"openvibe_config = json.load(open(\"./config/openVibe.json\", \"r\"))\n",
"raw_data = pd.read_csv(\"./Записи/OpenVibe signals/motor-imagery-csp-1-acquisition-[2024.04.18-10.37.21].csv\")\n",
"data = filter_events(raw_data, openvibe_config)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Фильтрация"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"constants = __create_constants(low=8.0, high=30.0, sample_rate=250.0, order=5) \n",
"filtered_data = filter_signal(df=data, cols=[x for x in data.columns if \"Channel\" in x], constants=constants)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Получение отрезков и разметки"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"x, y = split_epochs_binary(\n",
" df=filtered_data, \n",
" epoch_duration=4 * 250, \n",
" init_pause=int(0.5 * 250), \n",
" final_pause=0, \n",
" step=50000, # берем только 1 отрезок на действие (обсудить с А.Н.), подробнее см. доку \n",
" map_labels={\"left\": 0, \"right\": 1},\n",
" cols=[x for x in data.columns if \"Channel\" in x]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### CSP-фильтр"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(40, 1000)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"csp_filter = CSPFilter(4)\n",
"csp_filter.fit(x, y)\n",
"csp_filter.transform(x).shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Нейронки\n",
"\n",
"> Используем `pytorch-lightning`, если будешь добавлять - придерживайся стиля 🤖\n",
"\n",
"Ниже пример полного пайплайна обучения (без фильтрации, без сколбзящего окна, без CSP-фильтра) для 3 случайных электродов - `Channel 1`, `Channel 2`, `Channel 3`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(40, 3000) (40, 2)\n"
]
}
],
"source": [
"raw_data = pd.read_csv(\"./Записи/OpenVibe signals/motor-imagery-csp-1-acquisition-[2024.04.18-10.37.21].csv\")\n",
"data = filter_events(raw_data, openvibe_config)\n",
"x, y = split_epochs_binary(\n",
" df=data, \n",
" epoch_duration=4 * 250, \n",
" init_pause=int(0.5 * 250), \n",
" final_pause=0, \n",
" step=50000, # берем только 1 отрезок на действие (обсудить с А.Н.), подробнее см. доку \n",
" map_labels={\"left\": 0, \"right\": 1},\n",
" cols=[\"Channel 1\", \"Channel 2\", \"Channel 3\"]\n",
")\n",
"\n",
"# Делаем сигналы плоскими, лейблы в OneHot\n",
"x = x.reshape((-1, 3 * 1000))\n",
"y = np.eye(2)[y]\n",
"print(x.shape, y.shape)\n",
"\n",
"# бъем train/test\n",
"X_train, X_test, y_train, y_test = train_test_split(x, y, train_size=0.8)\n",
"\n",
"# конвертим в объекты торча\n",
"train_loader = torch.utils.data.DataLoader(TableDataset(X_train, y_train), batch_size=32)\n",
"val_loader = torch.utils.data.DataLoader(TableDataset(X_test, y_test), batch_size=8)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"net = torch.nn.Sequential(\n",
" torch.nn.Linear(3000, 1024),\n",
" torch.nn.Sigmoid(),\n",
" torch.nn.Linear(1024, 2),\n",
" torch.nn.Softmax(dim=1),\n",
")\n",
"\n",
"wrapped_net = NetProvider(\n",
" net=net, \n",
" lr=1e-4, \n",
" criteria=torch.nn.BCELoss(), \n",
" metrics=[\n",
" lambda pred, labels: balanced_accuracy_score(np.argmax(labels, axis=1), np.argmax(pred, axis=1))\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"Missing logger folder: runs\\example\\lightning_logs\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"\n",
" | Name | Type | Params\n",
"----------------------------------------\n",
"0 | criteria | BCELoss | 0 \n",
"1 | net | Sequential | 3.1 M \n",
"----------------------------------------\n",
"3.1 M Trainable params\n",
"0 Non-trainable params\n",
"3.1 M Total params\n",
"12.300 Total estimated model params size (MB)\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dc2a3deccc5b4690a005b7309e3f9322",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Sanity Checking: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\user\\anaconda3\\envs\\TacticENV\\lib\\site-packages\\pytorch_lightning\\trainer\\connectors\\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n",
"c:\\Users\\user\\anaconda3\\envs\\TacticENV\\lib\\site-packages\\pytorch_lightning\\trainer\\connectors\\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "97215ced575c46d99b8e3cd20ab809e9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2d57332f35bf49ab8ff89e3f4215180b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "64ab7eb3058c4e518478349348fcb950",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7c621ba6368341d68765f570aaa2b8db",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4fc7906c46b44fc8814452719f69a6cb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dcdd8be5bf65421db1af4cf9303da835",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cb3f5bc1991b4addb12cd99b9213a57b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "734438becdbc443e87942898e64199a4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0baf18f624a341efba714be21cf42c5f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d39784d032f84f5cbc06b1694beb1acc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b1998c91254e493893112c63a0e0b680",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Validation: | | 0/? [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"`Trainer.fit` stopped: `max_epochs=10` reached.\n"
]
}
],
"source": [
"torch.set_float32_matmul_precision('medium')\n",
"trainer = pl.Trainer(\n",
" default_root_dir=\"./runs/example\",\n",
" accelerator=\"auto\",\n",
" max_epochs=10,\n",
" callbacks=[\n",
" ModelCheckpoint(save_weights_only=True, monitor=\"val/loss\", mode=\"min\"),\n",
" ],\n",
" log_every_n_steps=1,\n",
")\n",
"\n",
"trainer.fit(wrapped_net, train_loader, val_loader)\n",
"__best_path = list(Path(\"./runs/example/lightning_logs/version_0/checkpoints/\").glob(\"*\"))[0]\n",
"wrapped_net.load_state_dict(torch.load(str(__best_path))['state_dict'])\n",
"del __best_path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
+25
View File
@@ -0,0 +1,25 @@
{
"map_actions": {
"769": "OVTK_GDF_Left",
"768": "OVTK_GDF_Start_Of_Trial",
"33281": "OVTK_StimulationId_Train",
"786": "OVTK_GDF_Cross_On_Screen",
"32769": "OVTK_StimulationId_ExperimentStart",
"32775": "OVTK_StimulationId_BaselineStart",
"33553": "OVTK_StimulationId_AddedSamplesBegin",
"32776": "OVTK_StimulationId_BaselineStop",
"1010": "OVTK_GDF_End_Of_Session",
"770": "OVTK_GDF_Right",
"781": "OVTK_GDF_Feedback_Continuous",
"33282": "OVTK_StimulationId_Beep",
"800": "OVTK_GDF_End_Of_Trial",
"33552": "OVTK_StimulationId_RemovedSamples",
"32770": "OVTK_StimulationId_ExperimentStop",
"33554": "OVTK_StimulationId_AddedSamplesEnd"
},
"main_actions": {
"left": "769",
"right": "770",
"stop": "800"
}
}
View File
+135
View File
@@ -0,0 +1,135 @@
import numpy as np
import pandas as pd
def filter_events(df: pd.DataFrame, config: dict[str, any]) -> pd.DataFrame:
"""Получение табличной нотации для экспериментов из файлов OpenVibe (.csv)
Args:
df (pd.DataFrame): сырой файл .ov сконвертированный в .csv и загруженный в pandas
config (dict[str, any]): словарь с расшифровкой кодов событий OpenVibe
Returns:
pd.DataFrame: фрейм (таблица) в нашей собственной нотации
"""
events = []
__temp_event = {}
new_frame = df.copy()
__ov_events = df[~df["Event Id"].isna()][["Event Id"]]
for idx in range(__ov_events.shape[0]):
__row_events = __ov_events.iloc[idx, 0].split(":")
if config["main_actions"]["left"] in __row_events or config["main_actions"]["right"] in __row_events:
__temp_event = {"start": __ov_events.index[idx]}
if config["main_actions"]["left"] in __row_events:
__temp_event["action"] = "left"
if config["main_actions"]["right"] in __row_events:
__temp_event["action"] = "right"
# В записи stop повторяется дважды (как и любые другие действия)
if config["main_actions"]["stop"] in __row_events and len(__temp_event) > 0:
__temp_event["stop"] = __ov_events.index[idx + 1]
events.append(__temp_event.copy())
__temp_event = {}
new_frame = new_frame.drop(columns=["Epoch", "Event Id", "Event Date", "Event Duration"])
new_frame["action"] = None
for event in events:
new_frame.loc[event["start"] : event["stop"], "action"] = event["action"]
new_frame["action"] = new_frame["action"].fillna("relax")
return new_frame
def get_ranges(labels: np.ndarray) -> list[list[int, int, int]]:
"""Преобразование сырых лейблов временного ряда в отрезки для визуализации
Args:
labels (np.ndarray): лейблы
Returns:
list[list[int, int, int]]: список отрезков для визуализации, каждый отрезок определен как
тройка "начало", "конец", "тип действия"
"""
actions = []
__temp_action = [0]
__prev_action = labels[0]
idx = 1
while True:
if idx + 1 > len(labels):
__temp_action.extend([idx, __prev_action])
actions.append(__temp_action)
break
if labels[idx] != __prev_action:
__temp_action.extend([idx - 1, __prev_action])
actions.append(__temp_action)
__temp_action = [idx]
__prev_action = labels[idx]
idx += 1
return actions
def split_epochs_binary(
df: pd.DataFrame,
epoch_duration: int,
init_pause: int,
final_pause: int,
step: int,
map_labels: dict[str, int],
cols: list[str],
) -> tuple[np.ndarray, np.ndarray]:
"""Формирование бинарного датасета (из отрезков, когда есть действие!)
Notes:
Если нужно использовать только один отрезок на действие - установть step > [отсчеты в отрезке]
Например, если отрезок действия 4 с (1000 отсчётов для частоты дискретизации 250 Гц), то поставив step=1001
будет сформирован только 1 обучающий (валидационный) объект из данного действия.
Args:
df (pd.DataFrame): фрейм с данными
epoch_duration (int): длительность эпохи (в отсчетах сигнала, t*f)
init_pause (int): начальная пауза - сколько отсчетов выбросить с момента предъявления стимула
(в отсчетах сигнала, t*f)
final_pause (int): конечная пауза - сколько отсчетов выкинуть в конце (в отсчетах сигнала, t*f)
step (int): шаг смещения внутри времени прдъявления стимула (в отсчетах сигнала, t*f)
map_labels (dict[str, int]): словарь соотносящий действие и его числовое значение,
например, {'left': 0, 'right': 1}
cols (list[str]): имена колонок, которые будут добавлены в датасет,
например: ['Channel 1', 'Channel 2', 'Channel 3']
Returns:
tuple[np.ndarray, np.ndarray]: кортеж с 2 массивами:
* обучающие данные, размерность [количество образцов X длительность эпохи X количество каналов]
* метки классов (в соответсвии с map_labels), размерность [количество образцов X 1]
Raise:
ValueError: словарь map_labels содержит информацию не о 2х классах
"""
if len(map_labels) != 2:
raise ValueError(f"Словарь map_labels должен содержать 2 класса, получено: {len(map_labels)}!")
actions = get_ranges(df["action"])
x = []
y = []
for action in actions:
if action[2] in map_labels.keys():
__low_bound = action[0] + init_pause
__upper_bound = action[1] - final_pause - epoch_duration + 1
for inner_idx in range(__low_bound, __upper_bound, step):
x.append(df.loc[inner_idx : inner_idx + epoch_duration - 1, cols].values)
y.append(map_labels[action[2]])
return np.array(x), np.array(y)
View File
+23
View File
@@ -0,0 +1,23 @@
import numpy as np
import torch
class TableDataset(torch.utils.data.Dataset):
def __init__(self, x: np.ndarray, y: np.ndarray) -> None:
"""Простой датасет из объектов numpy
Args:
x (np.ndarray): обучающие объекты
y (np.ndarray): метки
"""
super().__init__()
self.map_labels = None
self.x = x
self.y = y
def __len__(self):
return len(self.x)
def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
return torch.FloatTensor(self.x[idx]), torch.FloatTensor(self.y[idx])
+53
View File
@@ -0,0 +1,53 @@
import pytorch_lightning as pl
import torch
class NetProvider(pl.LightningModule):
def __init__(self, net: torch.nn.Module, lr: float, criteria: torch.nn.Module, metrics: list[callable]) -> None:
"""Базовая обертка для сетей
Args:
net (torch.nn.Module): сеть - модуль pyTorch
lr (float): темп обучения
criteria (torch.nn.Module): функция ошибки
metrics (list[callable]): список метрик - функций, принимающих (preds: np.ndarray, true_labels: np.ndarray)
"""
super().__init__()
self.lr = lr
self.criteria = criteria
self.net = net
self.metrics = metrics
def forward(self, x):
return self.net(x)
def configure_optimizers(self):
return torch.optim.AdamW(self.parameters(), lr=self.lr)
def training_step(self, batch, batch_idx):
data, labels = batch
preds = self.forward(data)
loss = self.criteria(preds, labels)
self.log("train/loss", loss, on_step=True, on_epoch=True)
for metric in self.metrics:
self.log(
f"train/{metric.__name__}",
metric(preds.cpu().detach().numpy(), labels.cpu().detach().numpy()),
on_step=True,
on_epoch=True,
)
return loss
def validation_step(self, batch, batch_idx):
data, labels = batch
preds = self.forward(data)
loss = self.criteria(preds, labels)
self.log("val/loss", loss, on_step=True, on_epoch=True)
for metric in self.metrics:
self.log(
f"val/{metric.__name__}",
metric(preds.cpu().detach().numpy(), labels.cpu().detach().numpy()),
on_step=True,
on_epoch=True,
)
return loss
+134
View File
@@ -0,0 +1,134 @@
import numpy as np
import pandas as pd
import scipy
from scipy.signal import butter, lfilter
def __create_constants(low: float, high: float, sample_rate: float, order: int) -> tuple[float, float]:
"""Вычисление констан полосового фильтра (Баттерворта)
Args:
low (float): нижняя чатсота среза
high (float): верхняя частота среза
sample_rate (float): частота дискретизации
order (int): порядок фильтра
Returns:
tuple[float, float]: константы фильтра, необходимые для вычислений
"""
return butter(order, [low, high], fs=sample_rate, btype="band")
def filter_signal(df: pd.DataFrame, cols: list[str], constants: tuple[float, float]) -> pd.DataFrame:
"""Фильтрация датасета (сигналов с датчиков)
Args:
df (pd.DataFrame): датафрейм с данными из OpenVibe
cols (list[str]): список колонок, данные в которых нужно отфильтровать
constants (tuple[float, float]): константы для вычисления фильтра (см. __create_constants)
Returns:
pd.DataFrame: входной фрейм с данными, с примененным фильтром
"""
new_df = df.copy()
for col in cols:
new_df[col] = lfilter(*constants, df[col].values)
return new_df
class CSPFilter:
"""
Реализация CSP фильтра в стиле sklearn.
Основано на реализации github.co/Hiroaki-K4/total_perspective_vortex/main/srcs/csp.py
"""
def __init__(self, n_components: int) -> None:
self.n_components = n_components
self.classes = None
self.filters = None
self.patterns = None
self.mean = None
self.std = None
def __get_covariate(self, x: np.ndarray, y: np.ndarray, cls: any) -> np.ndarray:
"""Получение ковариаций
Args:
x (np.ndarray): отсчеты сигнала
y (np.ndarray): метки класса
cls (any): текущий класс
Returns:
np.ndarray: матрица ковариаций
"""
x_class = x[y == cls]
_, _, n_channels = x_class.shape
# x_class = np.transpose(x_class, [1, 0, 2])
x_class = x_class.reshape(n_channels, -1)
cov = np.dot(x_class, x_class.T)
return cov
def fit(self, x: np.ndarray, y: np.ndarray) -> None:
"""Обучение фильтра
Args:
x (np.ndarray): отсчеты сигнала (числа)
y (np.ndarray): метки класса (любые данные)
"""
self.classes = np.unique(y)
n_classes = len(self.classes)
if n_classes != 2:
raise ValueError(f"Количество классов для CSP должно быть = 2 (получено {n_classes})")
cov_neg = self.__get_covariate(x, y, self.classes[0])
cov_pos = self.__get_covariate(x, y, self.classes[1])
eig_vals, eig_vecs = scipy.linalg.eigh(cov_neg, cov_pos)
for i in range(len(eig_vecs)):
eig_vecs[i] = eig_vecs[i] / np.linalg.norm(eig_vecs[i])
sorted_vals = np.argsort(eig_vals)
idxs = np.empty_like(sorted_vals)
idxs[0::2] = sorted_vals[len(sorted_vals) // 2 :][::-1]
idxs[1::2] = sorted_vals[: len(sorted_vals) // 2]
eig_vecs = eig_vecs[:, idxs]
self.filters = eig_vecs.T
self.patterns = np.linalg.inv(eig_vecs)
pick_filters = self.filters[:, self.n_components]
x = np.asarray([np.dot(pick_filters, epoch.T) for epoch in x])
x = (x**2).mean(axis=1)
self.mean = x.mean(axis=0)
self.std = x.std(axis=0)
def transform(self, x: np.ndarray) -> np.ndarray:
"""Применение фильтра
Args:
x (np.ndarray): отсчеты сигнала (числа)
Returns:
np.ndarray: вторичные параметры (фичи)
"""
pick_filters = self.filters[: self.n_components]
x = np.asarray([np.dot(pick_filters, epoch.T) for epoch in x])
x = (x**2).mean(axis=1)
x -= self.mean
x /= self.std
return x
def fit_transform(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""Обучение с последующим применением ко входным данным
Args:
x (np.ndarray): отсчеты сигнала (числа)
y (np.ndarray): метки класса (любые данные)
Returns:
np.ndarray: вторичные параметры (фичи)
"""
self.fit(x, y)
return self.transform(x)
+34
View File
@@ -0,0 +1,34 @@
import numpy as np
def get_ranges(labels: np.ndarray) -> list[list[int, int, int]]:
"""Преобразование сырых лейблов временного ряда в отрезки для визуализации
Args:
labels (np.ndarray): лейблы
Returns:
list[list[int, int, int]]: список отрезков для визуализации, каждый отрезок определен как
тройка "начало", "конец", "тип действия"
"""
actions = []
__temp_action = [0]
__prev_action = labels[0]
idx = 1
while True:
if idx + 1 > len(labels):
__temp_action.extend([idx, __prev_action])
actions.append(__temp_action)
break
if labels[idx] != __prev_action:
__temp_action.extend([idx - 1, __prev_action])
actions.append(__temp_action)
__temp_action = [idx]
__prev_action = labels[idx]
idx += 1
return actions
+911
View File
File diff suppressed because one or more lines are too long
Binary file not shown.
+70
View File
@@ -0,0 +1,70 @@
=================================================
Îïèñàíèå ôîðìàòà csv ôàéëà
http://openvibe.inria.fr/csv-file-format-description/
Êîäû ñîáûòèé:
http://openvibe.inria.fr/stimulation-codes/
Ôîðìàò âûâîäà:
Time:250Hz,
Epoch,
Channel 1,
Channel 2,
Channel 3,
Channel 4,
Channel 5,
Channel 6,
Channel 7,
Channel 8,
Channel 9,
Channel 10,
Channel 11,
Channel 12,
Channel 13,
Channel 14,
Channel 15,
Channel 16,
Event Id,
Event Date,
Ïðèìåðû ôðàãìåíòîâ çàïèñè îòíîñÿùèåñÿ ê ñîáûòèÿì ïîêàçà ñòðåëîê:
Ëåâî
32.9839999469,257,175321.6875000000,161476.5781250000,157384.8906250000,165382.3125000000,182353.1093750000,163875.0781250000,160708.5000000000,162043.2968750000,150507.6562500000,177546.5781250000,156140.7187500000,132061.4531250000,152687.8437500000,153717.4062500000,163318.2500000000,154924.4218750000,
770:770,
32.9874541578:32.9876184834,
0.0000000000:0.0000000000
32.9879999468,257,175311.5468750000,161454.5156250000,157379.3906250000,165363.1093750000,182351.3593750000,163863.5000000000,160706.2656250000,162040.2187500000,150504.3281250000,177536.3281250000,156103.6406250000,132053.7187500000,152655.0156250000,153707.7500000000,163314.8906250000,154920.7187500000,,,
32.9919999468,257,175305.4843750000,161444.4375000000,157375.6718750000,165352.2968750000,182349.0781250000,163856.1250000000,160703.7343750000,162037.6875000000,150501.0468750000,177530.1250000000,156083.9531250000,132048.3281250000,152644.7812500000,153701.3906250000,163310.8281250000,154917.0468750000,,,
32.9959999467,257,175315.4062500000,161463.2812500000,157381.4062500000,165367.3437500000,182351.8906250000,163865.5000000000,160707.1562500000,162041.9375000000,150504.0156250000,177539.1562500000,156111.0312500000,132055.3593750000,152673.0468750000,153708.9375000000,163314.3750000000,154920.6718750000,
770,
32.9962524336,
0.0000000000
Ïðàâî
60.9639999014,476,175538.3437500000,162268.3906250000,157672.1093750000,165778.6562500000,182674.5156250000,165550.2656250000,161696.2656250000,162263.4062500000,151121.2812500000,178085.1250000000,156655.3437500000,132842.8750000000,153574.0000000000,154055.8593750000,163721.6718750000,154574.9375000000,
769:769,
60.9675073745:60.9675452565,
0.0000000000:0.0000000000
60.9679999014,476,175526.0312500000,162246.7031250000,157664.8125000000,165758.0937500000,182672.2187500000,165538.8906250000,161693.2500000000,162259.3750000000,151118.2187500000,178074.6406250000,156616.5000000000,132834.6093750000,153542.0781250000,154045.5781250000,163716.9375000000,154570.5625000000,
769,
60.9707477945,
0.0000000000
Binary file not shown.
Binary file not shown.
Binary file not shown.
+70
View File
@@ -0,0 +1,70 @@
=================================================
Îïèñàíèå ôîðìàòà csv ôàéëà
http://openvibe.inria.fr/csv-file-format-description/
Êîäû ñîáûòèé:
http://openvibe.inria.fr/stimulation-codes/
Ôîðìàò âûâîäà:
Time:250Hz,
Epoch,
Channel 1,
Channel 2,
Channel 3,
Channel 4,
Channel 5,
Channel 6,
Channel 7,
Channel 8,
Channel 9,
Channel 10,
Channel 11,
Channel 12,
Channel 13,
Channel 14,
Channel 15,
Channel 16,
Event Id,
Event Date,
Ïðèìåðû ôðàãìåíòîâ çàïèñè îòíîñÿùèåñÿ ê ñîáûòèÿì ïîêàçà ñòðåëîê:
Ëåâî
32.9839999469,257,175321.6875000000,161476.5781250000,157384.8906250000,165382.3125000000,182353.1093750000,163875.0781250000,160708.5000000000,162043.2968750000,150507.6562500000,177546.5781250000,156140.7187500000,132061.4531250000,152687.8437500000,153717.4062500000,163318.2500000000,154924.4218750000,
770:770,
32.9874541578:32.9876184834,
0.0000000000:0.0000000000
32.9879999468,257,175311.5468750000,161454.5156250000,157379.3906250000,165363.1093750000,182351.3593750000,163863.5000000000,160706.2656250000,162040.2187500000,150504.3281250000,177536.3281250000,156103.6406250000,132053.7187500000,152655.0156250000,153707.7500000000,163314.8906250000,154920.7187500000,,,
32.9919999468,257,175305.4843750000,161444.4375000000,157375.6718750000,165352.2968750000,182349.0781250000,163856.1250000000,160703.7343750000,162037.6875000000,150501.0468750000,177530.1250000000,156083.9531250000,132048.3281250000,152644.7812500000,153701.3906250000,163310.8281250000,154917.0468750000,,,
32.9959999467,257,175315.4062500000,161463.2812500000,157381.4062500000,165367.3437500000,182351.8906250000,163865.5000000000,160707.1562500000,162041.9375000000,150504.0156250000,177539.1562500000,156111.0312500000,132055.3593750000,152673.0468750000,153708.9375000000,163314.3750000000,154920.6718750000,
770,
32.9962524336,
0.0000000000
Ïðàâî
60.9639999014,476,175538.3437500000,162268.3906250000,157672.1093750000,165778.6562500000,182674.5156250000,165550.2656250000,161696.2656250000,162263.4062500000,151121.2812500000,178085.1250000000,156655.3437500000,132842.8750000000,153574.0000000000,154055.8593750000,163721.6718750000,154574.9375000000,
769:769,
60.9675073745:60.9675452565,
0.0000000000:0.0000000000
60.9679999014,476,175526.0312500000,162246.7031250000,157664.8125000000,165758.0937500000,182672.2187500000,165538.8906250000,161693.2500000000,162259.3750000000,151118.2187500000,178074.6406250000,156616.5000000000,132834.6093750000,153542.0781250000,154045.5781250000,163716.9375000000,154570.5625000000,
769,
60.9707477945,
0.0000000000
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,4 @@
1. В данных существуюет небольшие выбросы по времени предъявления стимула (лево/право) см. pdf. В программе есть вариации по длительности предъявления или это ошибка чтения? >90% стимулов - 5 с. однако, 1-3 стимула (в среднем) длиннее - до 7 с. (см. pdf).
2. Какие электроды берем для эксперемента с 3мя? Номера в нотации 'Channel #'
3. Если эпохи по 4 с после предъявления сигнала (с учётом паузы в 0,5 с), то через какой шаг смещаемся внутри времени предъявления стимула? Через каждый отсчет - не будет онлайна (не успеем вычислить фильтры и сеть со скоростью 1/250с). Пока сделал через 0,5 с (принцип скользящего окна - когда праый край касается границы предъявления стимула прекращаем скольжение, см. pdf).
4. Реализация CSP может отличаться.