24 lines
635 B
Python
24 lines
635 B
Python
import numpy as np
|
|
import torch
|
|
|
|
|
|
class TableDataset(torch.utils.data.Dataset):
|
|
|
|
def __init__(self, x: np.ndarray, y: np.ndarray) -> None:
|
|
"""Простой датасет из объектов numpy
|
|
|
|
Args:
|
|
x (np.ndarray): обучающие объекты
|
|
y (np.ndarray): метки
|
|
"""
|
|
super().__init__()
|
|
self.map_labels = None
|
|
self.x = x
|
|
self.y = y
|
|
|
|
def __len__(self):
|
|
return len(self.x)
|
|
|
|
def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
|
|
return torch.FloatTensor(self.x[idx]), torch.FloatTensor(self.y[idx])
|