init commit: загрузка и предобработка данных
This commit is contained in:
@@ -0,0 +1,23 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class TableDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(self, x: np.ndarray, y: np.ndarray) -> None:
|
||||
"""Простой датасет из объектов numpy
|
||||
|
||||
Args:
|
||||
x (np.ndarray): обучающие объекты
|
||||
y (np.ndarray): метки
|
||||
"""
|
||||
super().__init__()
|
||||
self.map_labels = None
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
def __len__(self):
|
||||
return len(self.x)
|
||||
|
||||
def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return torch.FloatTensor(self.x[idx]), torch.FloatTensor(self.y[idx])
|
||||
@@ -0,0 +1,53 @@
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
|
||||
class NetProvider(pl.LightningModule):
|
||||
def __init__(self, net: torch.nn.Module, lr: float, criteria: torch.nn.Module, metrics: list[callable]) -> None:
|
||||
"""Базовая обертка для сетей
|
||||
|
||||
Args:
|
||||
net (torch.nn.Module): сеть - модуль pyTorch
|
||||
lr (float): темп обучения
|
||||
criteria (torch.nn.Module): функция ошибки
|
||||
metrics (list[callable]): список метрик - функций, принимающих (preds: np.ndarray, true_labels: np.ndarray)
|
||||
"""
|
||||
super().__init__()
|
||||
self.lr = lr
|
||||
self.criteria = criteria
|
||||
self.net = net
|
||||
self.metrics = metrics
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
def configure_optimizers(self):
|
||||
return torch.optim.AdamW(self.parameters(), lr=self.lr)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
data, labels = batch
|
||||
preds = self.forward(data)
|
||||
loss = self.criteria(preds, labels)
|
||||
self.log("train/loss", loss, on_step=True, on_epoch=True)
|
||||
for metric in self.metrics:
|
||||
self.log(
|
||||
f"train/{metric.__name__}",
|
||||
metric(preds.cpu().detach().numpy(), labels.cpu().detach().numpy()),
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
data, labels = batch
|
||||
preds = self.forward(data)
|
||||
loss = self.criteria(preds, labels)
|
||||
self.log("val/loss", loss, on_step=True, on_epoch=True)
|
||||
for metric in self.metrics:
|
||||
self.log(
|
||||
f"val/{metric.__name__}",
|
||||
metric(preds.cpu().detach().numpy(), labels.cpu().detach().numpy()),
|
||||
on_step=True,
|
||||
on_epoch=True,
|
||||
)
|
||||
return loss
|
||||
Reference in New Issue
Block a user