import torch
import torch.nn as nn
import numpy as np
from typing import Union, List, Optional, Iterable
from torch.utils.data import Dataset, DataLoader
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from trainer.callbacks import CallbackList, Callback
from trainer.metrics import BaseMetric, MetricsSummary
from trainer.trainers import BaseTrainer


class ForwardTrainer(BaseTrainer):
    """Trainer for a single forward pass followed by a loss computation.

    For every batch, the network inputs and the loss labels are picked out of
    the batch with ``self.select_batch`` (inherited from ``BaseTrainer``) using
    ``input_indices`` and ``label_indices``; then
    ``output = net(*inputs)`` and ``loss = loss_fn(output, *labels)``.

    Args:
        net: Model called as ``net(*inputs)``.
        loss_fn: Callable invoked as ``loss_fn(output, *labels)``. If it is an
            ``nn.Module`` it is moved to ``self.device`` on construction.
        optimizer: Optimizer zeroed and stepped by :meth:`train_on_batch`.
            May be ``None`` when the trainer is only used for evaluation or
            prediction; calling :meth:`train_on_batch` without one raises
            ``RuntimeError``.
        input_indices: Index (or indices) of the network inputs within a batch.
            How an int vs. an iterable is interpreted is up to
            ``BaseTrainer.select_batch``.
        label_indices: Index (or indices) of the labels within a batch.
        callbacks, batch_metrics, epoch_metrics, device, parallel_dim:
            Forwarded unchanged to ``BaseTrainer.__init__``.
    """

    def __init__(self, net, loss_fn,
                 optimizer: Optional[Optimizer] = None,
                 input_indices: Union[int, Iterable[int]] = 0,
                 label_indices: Union[int, Iterable[int]] = 1,
                 callbacks: Optional[List[Callback]] = None,
                 batch_metrics: Optional[List[BaseMetric]] = None,
                 epoch_metrics: Optional[List[BaseMetric]] = None,
                 device: Union[str, torch.device, list, None] = None,
                 parallel_dim: Optional[int] = None,
                 ):
        super().__init__(net, callbacks, batch_metrics, epoch_metrics, device, parallel_dim)
        self.loss_fn = loss_fn
        self.input_indices = input_indices
        self.label_indices = label_indices
        self.optimizer = optimizer

        # Loss modules may own parameters/buffers (e.g. class weights); keep
        # them on the same device as the net. Plain functions need no move.
        if isinstance(self.loss_fn, nn.Module):
            self.loss_fn.to(self.device)

    def _forward_and_loss(self, batch):
        """Select inputs/labels from ``batch``, run the net, and compute the loss.

        Returns:
            A ``(loss, output, labels)`` tuple, where ``labels`` is whatever
            ``select_batch`` returned for ``self.label_indices``.
        """
        inputs = self.select_batch(batch, self.input_indices)
        labels = self.select_batch(batch, self.label_indices)
        output = self.net(*inputs)
        loss = self.loss_fn(output, *labels)
        return loss, output, labels

    def train_on_batch(self, batch):
        """Run one optimisation step on ``batch``.

        Returns:
            dict with ``'loss'`` (the loss tensor), ``'output'`` (the net
            output) and ``'target'`` (the first selected label).

        Raises:
            RuntimeError: if the trainer was constructed without an optimizer.
        """
        # Fail fast, before spending a forward pass, with a clear message
        # instead of an AttributeError on ``None.zero_grad``.
        if self.optimizer is None:
            raise RuntimeError(
                "ForwardTrainer.train_on_batch requires an optimizer; "
                "pass `optimizer=` to the constructor.")

        loss, output, labels = self._forward_and_loss(batch)

        self.optimizer.zero_grad()
        loss.backward()

        # Hook runs after gradients are populated and before the optimizer
        # step, so callbacks can inspect or modify gradients.
        self.callbacks.on_backward_end(None)

        self.optimizer.step()

        # Only the first selected label is exposed as the metric target.
        return {'loss': loss,
                'output': output,
                'target': labels[0]}

    def eval_on_batch(self, batch):
        """Compute loss and output on ``batch`` without updating parameters.

        Returns:
            dict with ``'loss'``, ``'output'`` and ``'target'`` (first selected
            label), mirroring :meth:`train_on_batch`.
        """
        # NOTE(review): no torch.no_grad() here; presumably the calling loop in
        # BaseTrainer disables autograd during evaluation — confirm.
        loss, output, labels = self._forward_and_loss(batch)
        return {'loss': loss,
                'output': output,
                'target': labels[0]}

    def predict_on_batch(self, batch):
        """Return the raw net output for ``batch``; labels are not selected."""
        inputs = self.select_batch(batch, self.input_indices)
        return self.net(*inputs)


