init commit: загрузка и предобработка данных
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
Записи/OpenBCI GUI
|
||||
Записи/OpenVibe Signals
|
||||
__pycache__
|
||||
.ipynb_checkpoints
|
||||
.pytest_cache
|
||||
runs
|
||||
@@ -0,0 +1,50 @@
|
||||
|
||||
g_offset = nil
|
||||
g_duration = nil
|
||||
|
||||
-- this function is called when the box is initialized
|
||||
function initialize(box)
|
||||
|
||||
dofile(box:get_config("${Path_Data}") .. "/plugins/stimulation/lua-stimulator-stim-codes.lua")
|
||||
|
||||
g_offset = box:get_setting(2)
|
||||
g_duration = box:get_setting(3)
|
||||
end
|
||||
|
||||
-- this function is called when the box is uninitialized
|
||||
function uninitialize(box)
|
||||
|
||||
end
|
||||
|
||||
function wait_until(box, time)
|
||||
while box:get_current_time() < time do
|
||||
box:sleep()
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
function wait_for(box, duration)
|
||||
wait_until(box, box:get_current_time() + duration)
|
||||
end
|
||||
|
||||
|
||||
|
||||
function process(box)
|
||||
-- loops on every received stimulation for a given input
|
||||
while box:keep_processing() do
|
||||
for stimulation = 1, box:get_stimulation_count(1) do
|
||||
|
||||
-- gets the received stimulation
|
||||
identifier, date, duration = box:get_stimulation(1, 1)
|
||||
-- discards it
|
||||
box:remove_stimulation(1, 1)
|
||||
|
||||
-- delay the OVTK_GDF_Left and Right
|
||||
if identifier == OVTK_GDF_Left or identifier == OVTK_GDF_Right then
|
||||
box:send_stimulation(1, OVTK_GDF_Correct, date+g_offset, 0)
|
||||
box:send_stimulation(1, OVTK_GDF_Incorrect, date+g_offset+g_duration, 0)
|
||||
end
|
||||
end
|
||||
box:sleep()
|
||||
end
|
||||
end
|
||||
@@ -0,0 +1,471 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"import pandas as pd\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"from matplotlib import pylab as plt\n",
|
||||
"\n",
|
||||
"from sklearn.model_selection import train_test_split\n",
|
||||
"from sklearn.metrics import balanced_accuracy_score\n",
|
||||
"\n",
|
||||
"import pytorch_lightning as pl\n",
|
||||
"from pytorch_lightning.callbacks import ModelCheckpoint\n",
|
||||
"\n",
|
||||
"from core.io import filter_events, get_ranges, split_epochs_binary\n",
|
||||
"from core.signal_processing import filter_signal, __create_constants, CSPFilter\n",
|
||||
"from core.nn.provider import NetProvider\n",
|
||||
"from core.nn.data import TableDataset\n",
|
||||
"\n",
|
||||
"%matplotlib inline\n",
|
||||
"plt.style.use('bmh')\n",
|
||||
"plt.rcParams.update({'font.size': 14})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Необходимо проверить точность классификации с помощью полносвязной сети:\n",
|
||||
"\n",
|
||||
"1. Подобрать архитектуру (поиск по сетке) - возможно полносвязные слои заменить на свёртку\n",
|
||||
"2. Выбрать электроды\n",
|
||||
"3. Исследовать влияние фильтрации\n",
|
||||
"4. Исследовать влияние использования CSP-фильтра\n",
|
||||
"\n",
|
||||
"> По числовым значениям - **обсудить с Андреем Николаевичем**\n",
|
||||
"\n",
|
||||
"> **Окружение TacticENV**\n",
|
||||
"\n",
|
||||
"> **Сохранять все результаты - ребятам понадабятся графики/таблицы в статье, необзодимые визуализации обсудить с Андреем Николаевичем**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Примеры использования библеотеки `core`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Данные"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Загрузка"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"openvibe_config = json.load(open(\"./config/openVibe.json\", \"r\"))\n",
|
||||
"raw_data = pd.read_csv(\"./Записи/OpenVibe signals/motor-imagery-csp-1-acquisition-[2024.04.18-10.37.21].csv\")\n",
|
||||
"data = filter_events(raw_data, openvibe_config)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Фильтрация"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"constants = __create_constants(low=8.0, high=30.0, sample_rate=250.0, order=5) \n",
|
||||
"filtered_data = filter_signal(df=data, cols=[x for x in data.columns if \"Channel\" in x], constants=constants)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Получение отрезков и разметки"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"x, y = split_epochs_binary(\n",
|
||||
" df=filtered_data, \n",
|
||||
" epoch_duration=4 * 250, \n",
|
||||
" init_pause=int(0.5 * 250), \n",
|
||||
" final_pause=0, \n",
|
||||
" step=50000, # берем только 1 отрезок на действие (обсудить с А.Н.), подробнее см. доку \n",
|
||||
" map_labels={\"left\": 0, \"right\": 1},\n",
|
||||
" cols=[x for x in data.columns if \"Channel\" in x]\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### CSP-фильтр"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(40, 1000)"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"csp_filter = CSPFilter(4)\n",
|
||||
"csp_filter.fit(x, y)\n",
|
||||
"csp_filter.transform(x).shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Нейронки\n",
|
||||
"\n",
|
||||
"> Используем `pytorch-lightning`, если будешь добавлять - придерживайся стиля 🤖\n",
|
||||
"\n",
|
||||
"Ниже пример полного пайплайна обучения (без фильтрации, без сколбзящего окна, без CSP-фильтра) для 3 случайных электродов - `Channel 1`, `Channel 2`, `Channel 3`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"(40, 3000) (40, 2)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"raw_data = pd.read_csv(\"./Записи/OpenVibe signals/motor-imagery-csp-1-acquisition-[2024.04.18-10.37.21].csv\")\n",
|
||||
"data = filter_events(raw_data, openvibe_config)\n",
|
||||
"x, y = split_epochs_binary(\n",
|
||||
" df=data, \n",
|
||||
" epoch_duration=4 * 250, \n",
|
||||
" init_pause=int(0.5 * 250), \n",
|
||||
" final_pause=0, \n",
|
||||
" step=50000, # берем только 1 отрезок на действие (обсудить с А.Н.), подробнее см. доку \n",
|
||||
" map_labels={\"left\": 0, \"right\": 1},\n",
|
||||
" cols=[\"Channel 1\", \"Channel 2\", \"Channel 3\"]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Делаем сигналы плоскими, лейблы в OneHot\n",
|
||||
"x = x.reshape((-1, 3 * 1000))\n",
|
||||
"y = np.eye(2)[y]\n",
|
||||
"print(x.shape, y.shape)\n",
|
||||
"\n",
|
||||
"# бъем train/test\n",
|
||||
"X_train, X_test, y_train, y_test = train_test_split(x, y, train_size=0.8)\n",
|
||||
"\n",
|
||||
"# конвертим в объекты торча\n",
|
||||
"train_loader = torch.utils.data.DataLoader(TableDataset(X_train, y_train), batch_size=32)\n",
|
||||
"val_loader = torch.utils.data.DataLoader(TableDataset(X_test, y_test), batch_size=8)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"net = torch.nn.Sequential(\n",
|
||||
" torch.nn.Linear(3000, 1024),\n",
|
||||
" torch.nn.Sigmoid(),\n",
|
||||
" torch.nn.Linear(1024, 2),\n",
|
||||
" torch.nn.Softmax(dim=1),\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"wrapped_net = NetProvider(\n",
|
||||
" net=net, \n",
|
||||
" lr=1e-4, \n",
|
||||
" criteria=torch.nn.BCELoss(), \n",
|
||||
" metrics=[\n",
|
||||
" lambda pred, labels: balanced_accuracy_score(np.argmax(labels, axis=1), np.argmax(pred, axis=1))\n",
|
||||
" ],\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"GPU available: True (cuda), used: True\n",
|
||||
"TPU available: False, using: 0 TPU cores\n",
|
||||
"IPU available: False, using: 0 IPUs\n",
|
||||
"HPU available: False, using: 0 HPUs\n",
|
||||
"Missing logger folder: runs\\example\\lightning_logs\n",
|
||||
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
|
||||
"\n",
|
||||
" | Name | Type | Params\n",
|
||||
"----------------------------------------\n",
|
||||
"0 | criteria | BCELoss | 0 \n",
|
||||
"1 | net | Sequential | 3.1 M \n",
|
||||
"----------------------------------------\n",
|
||||
"3.1 M Trainable params\n",
|
||||
"0 Non-trainable params\n",
|
||||
"3.1 M Total params\n",
|
||||
"12.300 Total estimated model params size (MB)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "dc2a3deccc5b4690a005b7309e3f9322",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Sanity Checking: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"c:\\Users\\user\\anaconda3\\envs\\TacticENV\\lib\\site-packages\\pytorch_lightning\\trainer\\connectors\\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n",
|
||||
"c:\\Users\\user\\anaconda3\\envs\\TacticENV\\lib\\site-packages\\pytorch_lightning\\trainer\\connectors\\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "97215ced575c46d99b8e3cd20ab809e9",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Training: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "2d57332f35bf49ab8ff89e3f4215180b",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "64ab7eb3058c4e518478349348fcb950",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "7c621ba6368341d68765f570aaa2b8db",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "4fc7906c46b44fc8814452719f69a6cb",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "dcdd8be5bf65421db1af4cf9303da835",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "cb3f5bc1991b4addb12cd99b9213a57b",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "734438becdbc443e87942898e64199a4",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "0baf18f624a341efba714be21cf42c5f",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "d39784d032f84f5cbc06b1694beb1acc",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "b1998c91254e493893112c63a0e0b680",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Validation: | | 0/? [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"`Trainer.fit` stopped: `max_epochs=10` reached.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.set_float32_matmul_precision('medium')\n",
|
||||
"trainer = pl.Trainer(\n",
|
||||
" default_root_dir=\"./runs/example\",\n",
|
||||
" accelerator=\"auto\",\n",
|
||||
" max_epochs=10,\n",
|
||||
" callbacks=[\n",
|
||||
" ModelCheckpoint(save_weights_only=True, monitor=\"val/loss\", mode=\"min\"),\n",
|
||||
" ],\n",
|
||||
" log_every_n_steps=1,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"trainer.fit(wrapped_net, train_loader, val_loader)\n",
|
||||
"__best_path = list(Path(\"./runs/example/lightning_logs/version_0/checkpoints/\").glob(\"*\"))[0]\n",
|
||||
"wrapped_net.load_state_dict(torch.load(str(__best_path))['state_dict'])\n",
|
||||
"del __best_path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"map_actions": {
|
||||
"769": "OVTK_GDF_Left",
|
||||
"768": "OVTK_GDF_Start_Of_Trial",
|
||||
"33281": "OVTK_StimulationId_Train",
|
||||
"786": "OVTK_GDF_Cross_On_Screen",
|
||||
"32769": "OVTK_StimulationId_ExperimentStart",
|
||||
"32775": "OVTK_StimulationId_BaselineStart",
|
||||
"33553": "OVTK_StimulationId_AddedSamplesBegin",
|
||||
"32776": "OVTK_StimulationId_BaselineStop",
|
||||
"1010": "OVTK_GDF_End_Of_Session",
|
||||
"770": "OVTK_GDF_Right",
|
||||
"781": "OVTK_GDF_Feedback_Continuous",
|
||||
"33282": "OVTK_StimulationId_Beep",
|
||||
"800": "OVTK_GDF_End_Of_Trial",
|
||||
"33552": "OVTK_StimulationId_RemovedSamples",
|
||||
"32770": "OVTK_StimulationId_ExperimentStop",
|
||||
"33554": "OVTK_StimulationId_AddedSamplesEnd"
|
||||
},
|
||||
"main_actions": {
|
||||
"left": "769",
|
||||
"right": "770",
|
||||
"stop": "800"
|
||||
}
|
||||
}
|
||||
+135
@@ -0,0 +1,135 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def filter_events(df: pd.DataFrame, config: dict[str, any]) -> pd.DataFrame:
|
||||
"""Получение табличной нотации для экспериментов из файлов OpenVibe (.csv)
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): сырой файл .ov сконвертированный в .csv и загруженный в pandas
|
||||
config (dict[str, any]): словарь с расшифровкой кодов событий OpenVibe
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: фрейм (таблица) в нашей собственной нотации
|
||||
"""
|
||||
|
||||
events = []
|
||||
__temp_event = {}
|
||||
|
||||
new_frame = df.copy()
|
||||
|
||||
__ov_events = df[~df["Event Id"].isna()][["Event Id"]]
|
||||
for idx in range(__ov_events.shape[0]):
|
||||
__row_events = __ov_events.iloc[idx, 0].split(":")
|
||||
|
||||
if config["main_actions"]["left"] in __row_events or config["main_actions"]["right"] in __row_events:
|
||||
__temp_event = {"start": __ov_events.index[idx]}
|
||||
if config["main_actions"]["left"] in __row_events:
|
||||
__temp_event["action"] = "left"
|
||||
if config["main_actions"]["right"] in __row_events:
|
||||
__temp_event["action"] = "right"
|
||||
|
||||
# В записи stop повторяется дважды (как и любые другие действия)
|
||||
if config["main_actions"]["stop"] in __row_events and len(__temp_event) > 0:
|
||||
__temp_event["stop"] = __ov_events.index[idx + 1]
|
||||
events.append(__temp_event.copy())
|
||||
__temp_event = {}
|
||||
|
||||
new_frame = new_frame.drop(columns=["Epoch", "Event Id", "Event Date", "Event Duration"])
|
||||
new_frame["action"] = None
|
||||
|
||||
for event in events:
|
||||
new_frame.loc[event["start"] : event["stop"], "action"] = event["action"]
|
||||
|
||||
new_frame["action"] = new_frame["action"].fillna("relax")
|
||||
return new_frame
|
||||
|
||||
|
||||
def get_ranges(labels: np.ndarray) -> list[list[int, int, int]]:
|
||||
"""Преобразование сырых лейблов временного ряда в отрезки для визуализации
|
||||
|
||||
Args:
|
||||
labels (np.ndarray): лейблы
|
||||
|
||||
Returns:
|
||||
list[list[int, int, int]]: список отрезков для визуализации, каждый отрезок определен как
|
||||
тройка "начало", "конец", "тип действия"
|
||||
"""
|
||||
|
||||
actions = []
|
||||
__temp_action = [0]
|
||||
__prev_action = labels[0]
|
||||
idx = 1
|
||||
|
||||
while True:
|
||||
if idx + 1 > len(labels):
|
||||
__temp_action.extend([idx, __prev_action])
|
||||
actions.append(__temp_action)
|
||||
break
|
||||
|
||||
if labels[idx] != __prev_action:
|
||||
__temp_action.extend([idx - 1, __prev_action])
|
||||
actions.append(__temp_action)
|
||||
__temp_action = [idx]
|
||||
__prev_action = labels[idx]
|
||||
|
||||
idx += 1
|
||||
|
||||
return actions
|
||||
|
||||
|
||||
def split_epochs_binary(
|
||||
df: pd.DataFrame,
|
||||
epoch_duration: int,
|
||||
init_pause: int,
|
||||
final_pause: int,
|
||||
step: int,
|
||||
map_labels: dict[str, int],
|
||||
cols: list[str],
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Формирование бинарного датасета (из отрезков, когда есть действие!)
|
||||
|
||||
Notes:
|
||||
Если нужно использовать только один отрезок на действие - установть step > [отсчеты в отрезке]
|
||||
|
||||
Например, если отрезок действия 4 с (1000 отсчётов для частоты дискретизации 250 Гц), то поставив step=1001
|
||||
будет сформирован только 1 обучающий (валидационный) объект из данного действия.
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): фрейм с данными
|
||||
epoch_duration (int): длительность эпохи (в отсчетах сигнала, t*f)
|
||||
init_pause (int): начальная пауза - сколько отсчетов выбросить с момента предъявления стимула
|
||||
(в отсчетах сигнала, t*f)
|
||||
final_pause (int): конечная пауза - сколько отсчетов выкинуть в конце (в отсчетах сигнала, t*f)
|
||||
step (int): шаг смещения внутри времени прдъявления стимула (в отсчетах сигнала, t*f)
|
||||
map_labels (dict[str, int]): словарь соотносящий действие и его числовое значение,
|
||||
например, {'left': 0, 'right': 1}
|
||||
cols (list[str]): имена колонок, которые будут добавлены в датасет,
|
||||
например: ['Channel 1', 'Channel 2', 'Channel 3']
|
||||
|
||||
Returns:
|
||||
tuple[np.ndarray, np.ndarray]: кортеж с 2 массивами:
|
||||
|
||||
* обучающие данные, размерность [количество образцов X длительность эпохи X количество каналов]
|
||||
|
||||
* метки классов (в соответсвии с map_labels), размерность [количество образцов X 1]
|
||||
|
||||
Raise:
|
||||
ValueError: словарь map_labels содержит информацию не о 2х классах
|
||||
"""
|
||||
|
||||
if len(map_labels) != 2:
|
||||
raise ValueError(f"Словарь map_labels должен содержать 2 класса, получено: {len(map_labels)}!")
|
||||
|
||||
actions = get_ranges(df["action"])
|
||||
x = []
|
||||
y = []
|
||||
|
||||
for action in actions:
|
||||
if action[2] in map_labels.keys():
|
||||
__low_bound = action[0] + init_pause
|
||||
__upper_bound = action[1] - final_pause - epoch_duration + 1
|
||||
for inner_idx in range(__low_bound, __upper_bound, step):
|
||||
x.append(df.loc[inner_idx : inner_idx + epoch_duration - 1, cols].values)
|
||||
y.append(map_labels[action[2]])
|
||||
return np.array(x), np.array(y)
|
||||
@@ -0,0 +1,23 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class TableDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(self, x: np.ndarray, y: np.ndarray) -> None:
|
||||
"""Простой датасет из объектов numpy
|
||||
|
||||
Args:
|
||||
x (np.ndarray): обучающие объекты
|
||||
y (np.ndarray): метки
|
||||
"""
|
||||
super().__init__()
|
||||
self.map_labels = None
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
def __len__(self):
|
||||
return len(self.x)
|
||||
|
||||
def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return torch.FloatTensor(self.x[idx]), torch.FloatTensor(self.y[idx])
|
||||
@@ -0,0 +1,53 @@
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
|
||||
class NetProvider(pl.LightningModule):
|
||||
def __init__(self, net: torch.nn.Module, lr: float, criteria: torch.nn.Module, metrics: list[callable]) -> None:
|
||||
"""Базовая обертка для сетей
|
||||
|
||||
Args:
|
||||
net (torch.nn.Module): сеть - модуль pyTorch
|
||||
lr (float): темп обучения
|
||||
criteria (torch.nn.Module): функция ошибки
|
||||
metrics (list[callable]): список метрик - функций, принимающих (preds: np.ndarray, true_labels: np.ndarray)
|
||||
"""
|
||||
super().__init__()
|
||||
self.lr = lr
|
||||
self.criteria = criteria
|
||||
self.net = net
|
||||
self.metrics = metrics
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.AdamW(self.parameters(), lr=self.lr)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
data, labels = batch
|
||||
preds = self.forward(data)
|
||||
loss = self.criteria(preds, labels)
|
||||
self.log("train/loss", loss, on_step=True, on_epoch=True)
|
||||
for metric in self.metrics:
|
||||
self.log(
|
||||
f"train/{metric.__name__}",
|
||||
metric(preds.cpu().detach().numpy(), labels.cpu().detach().numpy()),
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
data, labels = batch
|
||||
preds = self.forward(data)
|
||||
loss = self.criteria(preds, labels)
|
||||
self.log("val/loss", loss, on_step=True, on_epoch=True)
|
||||
for metric in self.metrics:
|
||||
self.log(
|
||||
f"val/{metric.__name__}",
|
||||
metric(preds.cpu().detach().numpy(), labels.cpu().detach().numpy()),
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
)
|
||||
return loss
|
||||
@@ -0,0 +1,134 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import scipy
|
||||
from scipy.signal import butter, lfilter
|
||||
|
||||
|
||||
def __create_constants(low: float, high: float, sample_rate: float, order: int) -> tuple[float, float]:
|
||||
"""Вычисление констан полосового фильтра (Баттерворта)
|
||||
|
||||
Args:
|
||||
low (float): нижняя чатсота среза
|
||||
high (float): верхняя частота среза
|
||||
sample_rate (float): частота дискретизации
|
||||
order (int): порядок фильтра
|
||||
|
||||
Returns:
|
||||
tuple[float, float]: константы фильтра, необходимые для вычислений
|
||||
"""
|
||||
return butter(order, [low, high], fs=sample_rate, btype="band")
|
||||
|
||||
|
||||
def filter_signal(df: pd.DataFrame, cols: list[str], constants: tuple[float, float]) -> pd.DataFrame:
|
||||
"""Фильтрация датасета (сигналов с датчиков)
|
||||
|
||||
Args:
|
||||
df (pd.DataFrame): датафрейм с данными из OpenVibe
|
||||
cols (list[str]): список колонок, данные в которых нужно отфильтровать
|
||||
constants (tuple[float, float]): константы для вычисления фильтра (см. __create_constants)
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: входной фрейм с данными, с примененным фильтром
|
||||
"""
|
||||
new_df = df.copy()
|
||||
for col in cols:
|
||||
new_df[col] = lfilter(*constants, df[col].values)
|
||||
|
||||
return new_df
|
||||
|
||||
|
||||
class CSPFilter:
|
||||
"""
|
||||
Реализация CSP фильтра в стиле sklearn.
|
||||
Основано на реализации github.co/Hiroaki-K4/total_perspective_vortex/main/srcs/csp.py
|
||||
"""
|
||||
|
||||
def __init__(self, n_components: int) -> None:
|
||||
self.n_components = n_components
|
||||
self.classes = None
|
||||
self.filters = None
|
||||
self.patterns = None
|
||||
self.mean = None
|
||||
self.std = None
|
||||
|
||||
def __get_covariate(self, x: np.ndarray, y: np.ndarray, cls: any) -> np.ndarray:
|
||||
"""Получение ковариаций
|
||||
|
||||
Args:
|
||||
x (np.ndarray): отсчеты сигнала
|
||||
y (np.ndarray): метки класса
|
||||
cls (any): текущий класс
|
||||
|
||||
Returns:
|
||||
np.ndarray: матрица ковариаций
|
||||
"""
|
||||
|
||||
x_class = x[y == cls]
|
||||
_, _, n_channels = x_class.shape
|
||||
# x_class = np.transpose(x_class, [1, 0, 2])
|
||||
x_class = x_class.reshape(n_channels, -1)
|
||||
cov = np.dot(x_class, x_class.T)
|
||||
return cov
|
||||
|
||||
def fit(self, x: np.ndarray, y: np.ndarray) -> None:
|
||||
"""Обучение фильтра
|
||||
|
||||
Args:
|
||||
x (np.ndarray): отсчеты сигнала (числа)
|
||||
y (np.ndarray): метки класса (любые данные)
|
||||
"""
|
||||
self.classes = np.unique(y)
|
||||
n_classes = len(self.classes)
|
||||
if n_classes != 2:
|
||||
raise ValueError(f"Количество классов для CSP должно быть = 2 (получено {n_classes})")
|
||||
|
||||
cov_neg = self.__get_covariate(x, y, self.classes[0])
|
||||
cov_pos = self.__get_covariate(x, y, self.classes[1])
|
||||
|
||||
eig_vals, eig_vecs = scipy.linalg.eigh(cov_neg, cov_pos)
|
||||
for i in range(len(eig_vecs)):
|
||||
eig_vecs[i] = eig_vecs[i] / np.linalg.norm(eig_vecs[i])
|
||||
|
||||
sorted_vals = np.argsort(eig_vals)
|
||||
idxs = np.empty_like(sorted_vals)
|
||||
idxs[0::2] = sorted_vals[len(sorted_vals) // 2 :][::-1]
|
||||
idxs[1::2] = sorted_vals[: len(sorted_vals) // 2]
|
||||
|
||||
eig_vecs = eig_vecs[:, idxs]
|
||||
self.filters = eig_vecs.T
|
||||
self.patterns = np.linalg.inv(eig_vecs)
|
||||
pick_filters = self.filters[:, self.n_components]
|
||||
|
||||
x = np.asarray([np.dot(pick_filters, epoch.T) for epoch in x])
|
||||
x = (x**2).mean(axis=1)
|
||||
self.mean = x.mean(axis=0)
|
||||
self.std = x.std(axis=0)
|
||||
|
||||
def transform(self, x: np.ndarray) -> np.ndarray:
|
||||
"""Применение фильтра
|
||||
|
||||
Args:
|
||||
x (np.ndarray): отсчеты сигнала (числа)
|
||||
|
||||
Returns:
|
||||
np.ndarray: вторичные параметры (фичи)
|
||||
"""
|
||||
pick_filters = self.filters[: self.n_components]
|
||||
x = np.asarray([np.dot(pick_filters, epoch.T) for epoch in x])
|
||||
x = (x**2).mean(axis=1)
|
||||
x -= self.mean
|
||||
x /= self.std
|
||||
return x
|
||||
|
||||
def fit_transform(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
|
||||
"""Обучение с последующим применением ко входным данным
|
||||
|
||||
Args:
|
||||
x (np.ndarray): отсчеты сигнала (числа)
|
||||
y (np.ndarray): метки класса (любые данные)
|
||||
|
||||
Returns:
|
||||
np.ndarray: вторичные параметры (фичи)
|
||||
"""
|
||||
self.fit(x, y)
|
||||
return self.transform(x)
|
||||
+34
@@ -0,0 +1,34 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_ranges(labels: np.ndarray) -> list[list[int, int, int]]:
|
||||
"""Преобразование сырых лейблов временного ряда в отрезки для визуализации
|
||||
|
||||
Args:
|
||||
labels (np.ndarray): лейблы
|
||||
|
||||
Returns:
|
||||
list[list[int, int, int]]: список отрезков для визуализации, каждый отрезок определен как
|
||||
тройка "начало", "конец", "тип действия"
|
||||
"""
|
||||
|
||||
actions = []
|
||||
__temp_action = [0]
|
||||
__prev_action = labels[0]
|
||||
idx = 1
|
||||
|
||||
while True:
|
||||
if idx + 1 > len(labels):
|
||||
__temp_action.extend([idx, __prev_action])
|
||||
actions.append(__temp_action)
|
||||
break
|
||||
|
||||
if labels[idx] != __prev_action:
|
||||
__temp_action.extend([idx - 1, __prev_action])
|
||||
actions.append(__temp_action)
|
||||
__temp_action = [idx]
|
||||
__prev_action = labels[idx]
|
||||
|
||||
idx += 1
|
||||
|
||||
return actions
|
||||
File diff suppressed because one or more lines are too long
Binary file not shown.
@@ -0,0 +1,70 @@
|
||||
|
||||
=================================================
|
||||
Îïèñàíèå ôîðìàòà csv ôàéëà
|
||||
http://openvibe.inria.fr/csv-file-format-description/
|
||||
Êîäû ñîáûòèé:
|
||||
http://openvibe.inria.fr/stimulation-codes/
|
||||
|
||||
|
||||
|
||||
Ôîðìàò âûâîäà:
|
||||
|
||||
Time:250Hz,
|
||||
Epoch,
|
||||
Channel 1,
|
||||
Channel 2,
|
||||
Channel 3,
|
||||
Channel 4,
|
||||
Channel 5,
|
||||
Channel 6,
|
||||
Channel 7,
|
||||
Channel 8,
|
||||
Channel 9,
|
||||
Channel 10,
|
||||
Channel 11,
|
||||
Channel 12,
|
||||
Channel 13,
|
||||
Channel 14,
|
||||
Channel 15,
|
||||
Channel 16,
|
||||
Event Id,
|
||||
Event Date,
|
||||
|
||||
|
||||
|
||||
Ïðèìåðû ôðàãìåíòîâ çàïèñè îòíîñÿùèåñÿ ê ñîáûòèÿì ïîêàçà ñòðåëîê:
|
||||
|
||||
Ëåâî
|
||||
32.9839999469,257,175321.6875000000,161476.5781250000,157384.8906250000,165382.3125000000,182353.1093750000,163875.0781250000,160708.5000000000,162043.2968750000,150507.6562500000,177546.5781250000,156140.7187500000,132061.4531250000,152687.8437500000,153717.4062500000,163318.2500000000,154924.4218750000,
|
||||
770:770,
|
||||
32.9874541578:32.9876184834,
|
||||
0.0000000000:0.0000000000
|
||||
32.9879999468,257,175311.5468750000,161454.5156250000,157379.3906250000,165363.1093750000,182351.3593750000,163863.5000000000,160706.2656250000,162040.2187500000,150504.3281250000,177536.3281250000,156103.6406250000,132053.7187500000,152655.0156250000,153707.7500000000,163314.8906250000,154920.7187500000,,,
|
||||
32.9919999468,257,175305.4843750000,161444.4375000000,157375.6718750000,165352.2968750000,182349.0781250000,163856.1250000000,160703.7343750000,162037.6875000000,150501.0468750000,177530.1250000000,156083.9531250000,132048.3281250000,152644.7812500000,153701.3906250000,163310.8281250000,154917.0468750000,,,
|
||||
32.9959999467,257,175315.4062500000,161463.2812500000,157381.4062500000,165367.3437500000,182351.8906250000,163865.5000000000,160707.1562500000,162041.9375000000,150504.0156250000,177539.1562500000,156111.0312500000,132055.3593750000,152673.0468750000,153708.9375000000,163314.3750000000,154920.6718750000,
|
||||
770,
|
||||
32.9962524336,
|
||||
0.0000000000
|
||||
|
||||
Ïðàâî
|
||||
60.9639999014,476,175538.3437500000,162268.3906250000,157672.1093750000,165778.6562500000,182674.5156250000,165550.2656250000,161696.2656250000,162263.4062500000,151121.2812500000,178085.1250000000,156655.3437500000,132842.8750000000,153574.0000000000,154055.8593750000,163721.6718750000,154574.9375000000,
|
||||
769:769,
|
||||
60.9675073745:60.9675452565,
|
||||
0.0000000000:0.0000000000
|
||||
60.9679999014,476,175526.0312500000,162246.7031250000,157664.8125000000,165758.0937500000,182672.2187500000,165538.8906250000,161693.2500000000,162259.3750000000,151118.2187500000,178074.6406250000,156616.5000000000,132834.6093750000,153542.0781250000,154045.5781250000,163716.9375000000,154570.5625000000,
|
||||
769,
|
||||
60.9707477945,
|
||||
0.0000000000
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,70 @@
|
||||
|
||||
=================================================
|
||||
Îïèñàíèå ôîðìàòà csv ôàéëà
|
||||
http://openvibe.inria.fr/csv-file-format-description/
|
||||
Êîäû ñîáûòèé:
|
||||
http://openvibe.inria.fr/stimulation-codes/
|
||||
|
||||
|
||||
|
||||
Ôîðìàò âûâîäà:
|
||||
|
||||
Time:250Hz,
|
||||
Epoch,
|
||||
Channel 1,
|
||||
Channel 2,
|
||||
Channel 3,
|
||||
Channel 4,
|
||||
Channel 5,
|
||||
Channel 6,
|
||||
Channel 7,
|
||||
Channel 8,
|
||||
Channel 9,
|
||||
Channel 10,
|
||||
Channel 11,
|
||||
Channel 12,
|
||||
Channel 13,
|
||||
Channel 14,
|
||||
Channel 15,
|
||||
Channel 16,
|
||||
Event Id,
|
||||
Event Date,
|
||||
|
||||
|
||||
|
||||
Ïðèìåðû ôðàãìåíòîâ çàïèñè îòíîñÿùèåñÿ ê ñîáûòèÿì ïîêàçà ñòðåëîê:
|
||||
|
||||
Ëåâî
|
||||
32.9839999469,257,175321.6875000000,161476.5781250000,157384.8906250000,165382.3125000000,182353.1093750000,163875.0781250000,160708.5000000000,162043.2968750000,150507.6562500000,177546.5781250000,156140.7187500000,132061.4531250000,152687.8437500000,153717.4062500000,163318.2500000000,154924.4218750000,
|
||||
770:770,
|
||||
32.9874541578:32.9876184834,
|
||||
0.0000000000:0.0000000000
|
||||
32.9879999468,257,175311.5468750000,161454.5156250000,157379.3906250000,165363.1093750000,182351.3593750000,163863.5000000000,160706.2656250000,162040.2187500000,150504.3281250000,177536.3281250000,156103.6406250000,132053.7187500000,152655.0156250000,153707.7500000000,163314.8906250000,154920.7187500000,,,
|
||||
32.9919999468,257,175305.4843750000,161444.4375000000,157375.6718750000,165352.2968750000,182349.0781250000,163856.1250000000,160703.7343750000,162037.6875000000,150501.0468750000,177530.1250000000,156083.9531250000,132048.3281250000,152644.7812500000,153701.3906250000,163310.8281250000,154917.0468750000,,,
|
||||
32.9959999467,257,175315.4062500000,161463.2812500000,157381.4062500000,165367.3437500000,182351.8906250000,163865.5000000000,160707.1562500000,162041.9375000000,150504.0156250000,177539.1562500000,156111.0312500000,132055.3593750000,152673.0468750000,153708.9375000000,163314.3750000000,154920.6718750000,
|
||||
770,
|
||||
32.9962524336,
|
||||
0.0000000000
|
||||
|
||||
Ïðàâî
|
||||
60.9639999014,476,175538.3437500000,162268.3906250000,157672.1093750000,165778.6562500000,182674.5156250000,165550.2656250000,161696.2656250000,162263.4062500000,151121.2812500000,178085.1250000000,156655.3437500000,132842.8750000000,153574.0000000000,154055.8593750000,163721.6718750000,154574.9375000000,
|
||||
769:769,
|
||||
60.9675073745:60.9675452565,
|
||||
0.0000000000:0.0000000000
|
||||
60.9679999014,476,175526.0312500000,162246.7031250000,157664.8125000000,165758.0937500000,182672.2187500000,165538.8906250000,161693.2500000000,162259.3750000000,151118.2187500000,178074.6406250000,156616.5000000000,132834.6093750000,153542.0781250000,154045.5781250000,163716.9375000000,154570.5625000000,
|
||||
769,
|
||||
60.9707477945,
|
||||
0.0000000000
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,4 @@
|
||||
1. В данных существуюет небольшие выбросы по времени предъявления стимула (лево/право) см. pdf. В программе есть вариации по длительности предъявления или это ошибка чтения? >90% стимулов - 5 с. однако, 1-3 стимула (в среднем) длиннее - до 7 с. (см. pdf).
|
||||
2. Какие электроды берем для эксперемента с 3мя? Номера в нотации 'Channel #'
|
||||
3. Если эпохи по 4 с после предъявления сигнала (с учётом паузы в 0,5 с), то через какой шаг смещаемся внутри времени предъявления стимула? Через каждый отсчет - не будет онлайна (не успеем вычислить фильтры и сеть со скоростью 1/250с). Пока сделал через 0,5 с (принцип скользящего окна - когда праый край касается границы предъявления стимула прекращаем скольжение, см. pdf).
|
||||
4. Реализация CSP может отличаться.
|
||||
Reference in New Issue
Block a user