54 lines
1.9 KiB
Python
54 lines
1.9 KiB
Python
from collections.abc import Callable

import pytorch_lightning as pl
import torch
|
|
|
|
|
|
class NetProvider(pl.LightningModule):
    """Generic LightningModule wrapper around an arbitrary network.

    Optimizes the wrapped network with AdamW, computes the loss with
    ``criteria`` and logs the loss plus every function in ``metrics`` under
    ``train/...`` and ``val/...`` keys (both per step and per epoch).
    """

    def __init__(
        self,
        net: torch.nn.Module,
        lr: float,
        criteria: torch.nn.Module,
        metrics: list[Callable[..., float]],
    ) -> None:
        """Base wrapper for networks.

        Args:
            net (torch.nn.Module): the network - a PyTorch module.
            lr (float): learning rate.
            criteria (torch.nn.Module): loss function, called as
                ``criteria(preds, labels)``.
            metrics (list[Callable]): list of metrics - functions accepting
                ``(preds: np.ndarray, true_labels: np.ndarray)``; each one's
                result is logged under ``<stage>/<metric name>``.
        """
        super().__init__()
        self.lr = lr
        self.criteria = criteria
        self.net = net
        self.metrics = metrics

    def forward(self, x):
        """Run the wrapped network on ``x``."""
        return self.net(x)

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters using ``self.lr``."""
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        """Compute and log loss/metrics for one training batch; return the loss."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Compute and log loss/metrics for one validation batch; return the loss."""
        return self._shared_step(batch, "val")

    def _shared_step(self, batch, stage: str):
        """Forward a ``(data, labels)`` batch, log loss and metrics under ``stage``.

        Args:
            batch: a 2-item sequence ``(data, labels)``.
            stage: logging prefix, e.g. ``"train"`` or ``"val"``.

        Returns:
            The loss tensor produced by ``self.criteria``.
        """
        data, labels = batch
        # self(...) rather than self.forward(...) so nn.Module hooks run.
        preds = self(data)
        loss = self.criteria(preds, labels)
        self.log(f"{stage}/loss", loss, on_step=True, on_epoch=True)

        if self.metrics:
            # Convert once per step instead of once per metric: each
            # conversion may be a device->host copy. detach() first so the
            # copy is not recorded by autograd.
            preds_np = preds.detach().cpu().numpy()
            labels_np = labels.detach().cpu().numpy()
            for metric in self.metrics:
                self.log(
                    f"{stage}/{self._metric_name(metric)}",
                    metric(preds_np, labels_np),
                    on_step=True,
                    on_epoch=True,
                )
        return loss

    @staticmethod
    def _metric_name(metric) -> str:
        """Return a logging name for ``metric``.

        Plain functions use ``__name__``; objects without one (e.g.
        ``functools.partial`` or callable class instances) fall back to the
        wrapped ``func``'s name, then to the object's class name.
        """
        name = getattr(metric, "__name__", None)
        if name is None:
            inner = getattr(metric, "func", None)
            name = getattr(inner, "__name__", None) or type(metric).__name__
        return name
|