import torch
from torch_scatter import scatter


class SupervisedTrainer:
    """Run train/eval epochs for a model fitted to ``data.radius`` with squared error.

    Each batch object must provide ``to(device)``, ``radius``, ``batch`` and
    ``num_graphs``.  ``data.batch`` is used as the group index for
    ``scatter(..., reduce='mean')`` (presumably the node-to-graph assignment
    of a batched graph -- confirm against the data pipeline).
    """

    def __init__(self, accum):
        """Store the accumulation factor and initialise bookkeeping fields.

        Args:
            accum: number of consecutive batches whose gradients are
                accumulated before one optimizer step.
        """
        # best_loss / patience are only initialised here; nothing in this
        # class reads or updates them (presumably early-stopping state owned
        # by the caller -- confirm).
        self.best_loss = 1.e8
        self.patience = 0
        self.accum = accum

    def _window_size(self, batch_index, dataloader):
        """Return how many batches share the accumulation window of ``batch_index``."""
        if self.accum == 1:
            # No accumulation: avoid requiring ``len(dataloader)`` at all.
            return 1
        window_start = (batch_index // self.accum) * self.accum
        remaining = len(dataloader) - window_start
        # Only a genuinely short final window changes the divisor.  If the
        # reported length does not match what is iterated, fall back to the
        # configured window size instead of dividing by zero or a negative.
        return remaining if 0 < remaining < self.accum else self.accum

    def train(self, dataloader, model, optimizer, device):
        """Train ``model`` for one pass over ``dataloader``.

        Gradients are accumulated over windows of ``self.accum`` batches.
        Each batch loss is divided by the number of batches actually present
        in its window, so a shorter final window is averaged over its own
        size rather than being scaled down by ``self.accum``.

        Args:
            dataloader: iterable of batches; must support ``len()`` when
                ``self.accum > 1``.
            model: callable module returning predictions for a batch.
            optimizer: optimizer stepped once per accumulation window.
            device: target passed to ``data.to``.

        Returns:
            The mean of the per-batch losses weighted by ``data.num_graphs``
            (a detached tensor).

        Raises:
            ZeroDivisionError: if ``dataloader`` yields no batches.
            RuntimeError: from ``clip_grad_norm_`` when the total gradient
                norm is non-finite (``error_if_nonfinite=True``).
        """
        model.train()

        train_losses = 0.
        num_graphs = 0
        optimizer.zero_grad()
        for i, data in enumerate(dataloader):
            data = data.to(device)

            pred = model(data)
            # Mean squared error within each ``data.batch`` group, then the
            # mean over groups.
            loss = scatter((pred - data.radius) ** 2, data.batch, dim=0, reduce='mean').mean()

            # Weight by graph count so the epoch figure is a per-graph mean.
            train_losses += loss.detach() * data.num_graphs
            num_graphs += data.num_graphs

            loss = loss / self._window_size(i, dataloader)
            loss.backward()
            # Step at the end of each full window and after the last batch.
            if (i + 1) % self.accum == 0 or (i + 1) == len(dataloader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, error_if_nonfinite=True)
                optimizer.step()
                optimizer.zero_grad()

        return train_losses / num_graphs

    @torch.no_grad()
    def eval(self, dataloader, model, device):
        """Evaluate ``model`` on ``dataloader`` without tracking gradients.

        Args:
            dataloader: iterable of batches.
            model: callable module returning predictions for a batch.
            device: target passed to ``data.to``.

        Returns:
            The sum of the per-group mean squared errors divided by the total
            ``data.num_graphs`` (a tensor).

        Raises:
            ZeroDivisionError: if ``dataloader`` yields no batches.
        """
        model.eval()

        val_losses = 0.
        num_graphs = 0
        for data in dataloader:
            data = data.to(device)
            pred = model(data)
            loss = scatter((pred - data.radius) ** 2, data.batch, dim=0, reduce='mean')

            val_losses += loss.sum()
            num_graphs += data.num_graphs

        return val_losses / num_graphs


class ProbsTrainer:
    """Run train/eval epochs for a model whose forward pass returns three loss terms.

    ``model(data)`` must return a 3-tuple ``(l1, l2, l3)`` of tensors.  Each
    batch object must provide ``to(device)`` and ``num_graphs``.  What the
    three terms represent is defined by the model, not by this class.
    """

    def __init__(self, accum):
        """Store the accumulation factor and initialise bookkeeping fields.

        Args:
            accum: number of consecutive batches whose gradients are
                accumulated before one optimizer step.
        """
        # best_loss / patience are only initialised here; nothing in this
        # class reads or updates them (presumably early-stopping state owned
        # by the caller -- confirm).
        self.best_loss = 1.e8
        self.patience = 0
        self.accum = accum

    def _window_size(self, batch_index, dataloader):
        """Return how many batches share the accumulation window of ``batch_index``."""
        if self.accum == 1:
            # No accumulation: avoid requiring ``len(dataloader)`` at all.
            return 1
        window_start = (batch_index // self.accum) * self.accum
        remaining = len(dataloader) - window_start
        # Only a genuinely short final window changes the divisor.  If the
        # reported length does not match what is iterated, fall back to the
        # configured window size instead of dividing by zero or a negative.
        return remaining if 0 < remaining < self.accum else self.accum

    def train(self, dataloader, model, optimizer, device):
        """Train ``model`` for one pass over ``dataloader``.

        The batch loss is the mean of ``l1 + l2 + l3``.  Gradients are
        accumulated over windows of ``self.accum`` batches, and each batch
        loss is divided by the number of batches actually present in its
        window, so a shorter final window is averaged over its own size
        rather than being scaled down by ``self.accum``.

        Args:
            dataloader: iterable of batches; must support ``len()`` when
                ``self.accum > 1``.
            model: callable module returning ``(l1, l2, l3)``.
            optimizer: optimizer stepped once per accumulation window.
            device: target passed to ``data.to``.

        Returns:
            The mean of the per-batch losses weighted by ``data.num_graphs``
            (a detached tensor).

        Raises:
            ZeroDivisionError: if ``dataloader`` yields no batches.
            RuntimeError: from ``clip_grad_norm_`` when the total gradient
                norm is non-finite (``error_if_nonfinite=True``).
        """
        model.train()

        train_losses = 0.
        num_graphs = 0
        optimizer.zero_grad()
        for i, data in enumerate(dataloader):
            data = data.to(device)

            l1, l2, l3 = model(data)
            loss = (l1 + l2 + l3).mean()
            # Weight by graph count so the epoch figure is a per-graph mean.
            train_losses += loss.detach() * data.num_graphs
            num_graphs += data.num_graphs

            loss = loss / self._window_size(i, dataloader)
            loss.backward()
            # Step at the end of each full window and after the last batch.
            if (i + 1) % self.accum == 0 or (i + 1) == len(dataloader):
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, error_if_nonfinite=True)
                optimizer.step()
                optimizer.zero_grad()

        return train_losses / num_graphs

    @torch.no_grad()
    def eval(self, dataloader, model, device):
        """Evaluate ``model`` on ``dataloader`` without tracking gradients.

        Args:
            dataloader: iterable of batches.
            model: callable module returning ``(l1, l2, l3)``.
            device: target passed to ``data.to``.

        Returns:
            A tuple of three Python floats ``(open, tran, total)`` where
            ``open`` is the summed ``l1 + l2`` divided by the total
            ``data.num_graphs``, ``tran`` is the summed ``l3`` divided by the
            same count, and ``total`` is ``open + tran``.

        Raises:
            ZeroDivisionError: if ``dataloader`` yields no batches.
        """
        model.eval()

        open_losses = 0.
        tran_losses = 0.
        num_graphs = 0
        for data in dataloader:
            data = data.to(device)
            l1, l2, l3 = model(data)
            open_losses += (l1 + l2).sum()
            tran_losses += l3.sum()
            num_graphs += data.num_graphs

        # Convert each metric to a Python float once and reuse it for the total.
        open_loss = (open_losses / num_graphs).item()
        tran_loss = (tran_losses / num_graphs).item()
        return open_loss, tran_loss, open_loss + tran_loss
