init commit: загрузка и предобработка данных

This commit is contained in:
xausssr
2024-05-11 12:29:10 +03:00
commit 0a55543ae2
22 changed files with 1986 additions and 0 deletions
View File
+23
View File
@@ -0,0 +1,23 @@
import numpy as np
import torch
class TableDataset(torch.utils.data.Dataset):
def __init__(self, x: np.ndarray, y: np.ndarray) -> None:
"""Простой датасет из объектов numpy
Args:
x (np.ndarray): обучающие объекты
y (np.ndarray): метки
"""
super().__init__()
self.map_labels = None
self.x = x
self.y = y
def __len__(self):
return len(self.x)
def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
return torch.FloatTensor(self.x[idx]), torch.FloatTensor(self.y[idx])
+53
View File
@@ -0,0 +1,53 @@
import pytorch_lightning as pl
import torch
class NetProvider(pl.LightningModule):
def __init__(self, net: torch.nn.Module, lr: float, criteria: torch.nn.Module, metrics: list[callable]) -> None:
"""Базовая обертка для сетей
Args:
net (torch.nn.Module): сеть - модуль pyTorch
lr (float): темп обучения
criteria (torch.nn.Module): функция ошибки
metrics (list[callable]): список метрик - функций, принимающих (preds: np.ndarray, true_labels: np.ndarray)
"""
super().__init__()
self.lr = lr
self.criteria = criteria
self.net = net
self.metrics = metrics
def forward(self, x):
return self.net(x)
def configure_optimizers(self):
return torch.optim.AdamW(self.parameters(), lr=self.lr)
def training_step(self, batch, batch_idx):
data, labels = batch
preds = self.forward(data)
loss = self.criteria(preds, labels)
self.log("train/loss", loss, on_step=True, on_epoch=True)
for metric in self.metrics:
self.log(
f"train/{metric.__name__}",
metric(preds.cpu().detach().numpy(), labels.cpu().detach().numpy()),
on_step=True,
on_epoch=True,
)
return loss
def validation_step(self, batch, batch_idx):
data, labels = batch
preds = self.forward(data)
loss = self.criteria(preds, labels)
self.log("val/loss", loss, on_step=True, on_epoch=True)
for metric in self.metrics:
self.log(
f"val/{metric.__name__}",
metric(preds.cpu().detach().numpy(), labels.cpu().detach().numpy()),
on_step=True,
on_epoch=True,
)
return loss