Files
eeg-vgtu/core/nn/data.py
T

24 lines
635 B
Python

import numpy as np
import torch
class TableDataset(torch.utils.data.Dataset):
def __init__(self, x: np.ndarray, y: np.ndarray) -> None:
"""Простой датасет из объектов numpy
Args:
x (np.ndarray): обучающие объекты
y (np.ndarray): метки
"""
super().__init__()
self.map_labels = None
self.x = x
self.y = y
def __len__(self):
return len(self.x)
def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
return torch.FloatTensor(self.x[idx]), torch.FloatTensor(self.y[idx])