import torch
import torch.nn as nn
from typing import Union, List, Optional, Iterable
from torch.optim import Optimizer
from src.verify.trainer.callbacks import Callback
from src.verify.trainer.metrics import BaseMetric
from src.verify.trainer.trainers import BaseTrainer


class ForwardTrainer(BaseTrainer):
    """Trainer that runs one forward pass of ``net`` per batch and scores it.

    Each batch is split into network inputs and loss targets with the
    inherited ``select_batch`` helper (defined outside this file). The
    selected inputs are unpacked into ``net`` as positional arguments, and
    the selected labels are unpacked into ``loss_fn`` after the network
    output, i.e. ``loss_fn(net(*inputs), *labels)``.

    Args:
        net: Model to train; forwarded to ``BaseTrainer``.
        loss_fn: Callable invoked as ``loss_fn(output, *labels)``. If it is
            an ``nn.Module`` it is moved to ``self.device``.
        optimizer: Optimizer stepped by ``train_on_batch``. May be left as
            ``None`` when the trainer is only used for evaluation or
            prediction.
        input_indices: Passed to ``select_batch`` to pick the network inputs
            out of a batch.
        label_indices: Passed to ``select_batch`` to pick the loss targets
            out of a batch.
        callbacks: Forwarded to ``BaseTrainer``.
        batch_metrics: Forwarded to ``BaseTrainer``.
        epoch_metrics: Forwarded to ``BaseTrainer``.
        device: Forwarded to ``BaseTrainer``.
        parallel_dim: Forwarded to ``BaseTrainer``.
    """

    def __init__(self, net, loss_fn,
                 optimizer: Optional[Optimizer] = None,
                 input_indices: Union[int, Iterable[int], None] = 0,
                 label_indices: Union[int, Iterable[int], None] = 1,
                 callbacks: Optional[List[Callback]] = None,
                 batch_metrics: Optional[List[BaseMetric]] = None,
                 epoch_metrics: Optional[List[BaseMetric]] = None,
                 device: Union[str, torch.device, list, None] = None,
                 parallel_dim: Optional[int] = None,
                 ):
        super().__init__(net, callbacks, batch_metrics, epoch_metrics, device, parallel_dim)
        self.loss_fn = loss_fn
        self.input_indices = input_indices
        self.label_indices = label_indices
        self.optimizer = optimizer

        # A module-based loss may own parameters/buffers, so it has to live
        # on the same device as the rest of the trainer.
        # NOTE(review): relies on BaseTrainer.__init__ having set self.device.
        if isinstance(self.loss_fn, nn.Module):
            self.loss_fn.to(self.device)

    def _forward_and_loss(self, batch):
        """Run the network on ``batch`` and compute the loss.

        Returns:
            Tuple ``(loss, output, labels)`` where ``labels`` is whatever
            ``select_batch`` returned for ``label_indices``.
        """
        inputs = self.select_batch(batch, self.input_indices)
        labels = self.select_batch(batch, self.label_indices)

        output = self.net(*inputs)
        loss = self.loss_fn(output, *labels)
        return loss, output, labels

    def train_on_batch(self, batch):
        """Perform one optimization step on ``batch``.

        Returns:
            Dict with keys ``'loss'``, ``'output'`` and ``'target'`` (the
            first selected label).

        Raises:
            RuntimeError: If the trainer was built without an optimizer.
        """
        # Fail before doing any work instead of crashing on
        # ``None.zero_grad()`` after the forward pass has already run.
        if self.optimizer is None:
            raise RuntimeError(
                "ForwardTrainer.train_on_batch requires an optimizer, "
                "but optimizer is None."
            )

        loss, output, labels = self._forward_and_loss(batch)

        self.optimizer.zero_grad()
        loss.backward()

        # Hook that runs after gradients exist but before they are applied.
        self.callbacks.on_backward_end(None)

        self.optimizer.step()

        return {'loss': loss,
                'output': output,
                'target': labels[0]}

    def eval_on_batch(self, batch):
        """Compute output and loss for ``batch`` without updating weights.

        NOTE(review): gradient tracking is not disabled here; presumably the
        caller or base class manages ``torch.no_grad()`` / eval mode - confirm.

        Returns:
            Dict with keys ``'loss'``, ``'output'`` and ``'target'`` (the
            first selected label).
        """
        loss, output, labels = self._forward_and_loss(batch)
        return {'loss': loss,
                'output': output,
                'target': labels[0]}

    def predict_on_batch(self, batch):
        """Return the raw network output for the inputs selected from ``batch``."""
        inputs = self.select_batch(batch, self.input_indices)
        return self.net(*inputs)


