from typing import Callable

import numpy as np
import pytorch_lightning as pl
import torch


class NetProvider(pl.LightningModule):
    """Generic LightningModule wrapper around an arbitrary PyTorch network.

    Trains with AdamW and logs the loss plus any user-supplied metrics under
    ``train/...`` and ``val/...`` keys.
    """

    def __init__(
        self,
        net: torch.nn.Module,
        lr: float,
        criteria: torch.nn.Module,
        metrics: list[Callable[[np.ndarray, np.ndarray], float]],
    ) -> None:
        """Base wrapper for networks.

        Args:
            net (torch.nn.Module): the network; ``forward`` delegates to it.
            lr (float): learning rate passed to AdamW.
            criteria (torch.nn.Module): loss function, called as
                ``criteria(preds, labels)``.
            metrics (list[Callable]): metric functions, each called as
                ``metric(preds: np.ndarray, true_labels: np.ndarray)`` and
                returning a value accepted by ``self.log`` (e.g. a float).
        """
        super().__init__()
        self.lr = lr
        self.criteria = criteria
        self.net = net
        self.metrics = metrics

    def forward(self, x):
        """Run the wrapped network on ``x``."""
        return self.net(x)

    def configure_optimizers(self):
        """Create an AdamW optimizer over all module parameters."""
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/metrics for one batch; return the loss."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss/metrics for one batch; return the loss."""
        return self._shared_step(batch, "val")

    @staticmethod
    def _metric_name(metric: Callable) -> str:
        """Return a log-friendly name for a metric callable."""
        name = getattr(metric, "__name__", None)
        if name is not None:
            return name
        # functools.partial and similar wrappers have no __name__ of their own
        # but expose the wrapped function as ``.func``.
        inner = getattr(metric, "func", None)
        inner_name = getattr(inner, "__name__", None)
        if inner_name is not None:
            return inner_name
        # Callable class instances: fall back to the class name.
        return type(metric).__name__

    def _shared_step(self, batch, stage: str) -> torch.Tensor:
        """Forward pass, loss and metric logging shared by train/val steps.

        Args:
            batch: a ``(data, labels)`` pair.
            stage: log-key prefix, e.g. ``"train"`` or ``"val"``.

        Returns:
            The loss tensor produced by ``self.criteria``.
        """
        data, labels = batch
        preds = self.forward(data)
        loss = self.criteria(preds, labels)
        self.log(f"{stage}/loss", loss, on_step=True, on_epoch=True)

        if self.metrics:
            # Convert to numpy once per step instead of once per metric; detach
            # first so the host copy is not tracked by autograd.
            preds_np = preds.detach().cpu().numpy()
            labels_np = labels.detach().cpu().numpy()
            for metric in self.metrics:
                self.log(
                    f"{stage}/{self._metric_name(metric)}",
                    metric(preds_np, labels_np),
                    on_step=True,
                    on_epoch=True,
                )
        return loss