import random
from typing import Literal, Optional

import pytorch_lightning as pl
import torch
import torch.nn as nn
from pytorch_lightning.utilities.types import *
from torch.nn.functional import pad

from mind_the_pad.model_analysis.visualize_intermediate_results import get_relu_outputs_intermediate
from mind_the_pad.train_mnist.model_full_padding import build_model as full_build_model
from mind_the_pad.train_mnist.model_same_padding import build_model as same_build_model
from mind_the_pad.train_mnist.model_valid_padding import build_model as valid_build_model


def build_model_by_padding_size_mode(padding_size: Literal['valid', 'same', 'full'],
                                     padding_mode: Literal['zeros', 'reflect', 'replicate', 'circular', ''],
                                     output_classes: int):
    """Build the network variant that matches the requested padding size.

    Args:
        padding_size: Which model builder to use: 'valid', 'same' or 'full'.
        padding_mode: Padding mode forwarded to the 'same' and 'full'
            builders. It is validated but otherwise ignored for 'valid'.
        output_classes: Forwarded unchanged to the selected builder.

    Returns:
        Whatever the selected ``build_model`` function returns.

    Raises:
        AssertionError: If ``padding_size`` or ``padding_mode`` is not one of
            the accepted values (the offending value is the error argument).
    """
    assert padding_size in ('valid', 'same', 'full'), padding_size
    assert padding_mode in ('zeros', 'reflect', 'replicate', 'circular', ''), padding_mode

    if padding_size == 'valid':
        # The 'valid' builder is always given an empty padding mode,
        # regardless of what the caller asked for.
        return valid_build_model('', output_classes)

    builder = same_build_model if padding_size == 'same' else full_build_model
    return builder(padding_mode, output_classes)


class MnistPadding(pl.LightningModule):
    """Lightning module wrapping a classifier built for a given padding size/mode.

    The network comes from ``build_model_by_padding_size_mode``; optionally a
    ``BatchNorm2d`` layer is inserted after every ``Conv2d``. The loss is
    cross-entropy and the optimizer is Adam.

    Accuracy bookkeeping: ``self.accuracy`` accumulates over training and test
    epochs, ``self.val_accuracy`` over validation epochs. They are separate
    metrics so that a validation run starting in the middle of a training
    epoch neither resets nor contaminates the training accuracy.
    """

    def __init__(self, padding_type: str, padding_mode: str, lr=1e-4, random_pad_input: int = 0,
                 output_classes=26, batch_norm: bool = False, bn_affine: bool = True):
        """Build the network, the loss and the metrics.

        Args:
            padding_type: Padding size passed to the model factory
                ('valid', 'same' or 'full').
            padding_mode: Padding mode passed to the model factory; also used
                for the optional random input padding in ``forward``.
            lr: Learning rate for Adam.
            random_pad_input: If > 0, ``forward`` pads the input by this many
                pixels on both sides of one randomly chosen spatial axis.
            output_classes: Number of classes passed to the model factory.
            batch_norm: If True, insert ``BatchNorm2d`` after each ``Conv2d``.
            bn_affine: ``affine`` flag of the inserted ``BatchNorm2d`` layers.
        """
        super().__init__()
        self.save_hyperparameters()
        self.lr = lr
        self.random_pad_input = random_pad_input
        self.padding_type = padding_type
        self.batch_norm = batch_norm
        self.bn_affine = bn_affine

        self.padding_mode = padding_mode
        self.model = build_model_by_padding_size_mode(padding_type, padding_mode, output_classes)
        if batch_norm:
            # Rebuild the network as a Sequential with BatchNorm2d after every conv.
            # NOTE(review): modules() walks sub-modules recursively, so this presumably
            # relies on the built model being a flat container - TODO confirm.
            layers = []
            it_modules = self.model.modules()
            next(it_modules)  # first module is the full network itself
            for module in it_modules:
                layers.append(module)
                if isinstance(module, nn.Conv2d):
                    layers.append(nn.BatchNorm2d(module.out_channels, affine=bn_affine))
            print(layers)
            self.model = nn.Sequential(*layers)
        self.ce = nn.CrossEntropyLoss()
        # Training/test accuracy (attribute name kept for backward compatibility).
        self.accuracy = pl.metrics.Accuracy()
        # Validation accuracy lives in its own metric so that validation cannot
        # reset or pollute the running training accuracy.
        self.val_accuracy = pl.metrics.Accuracy()
        # torch.nn.functional.pad calls zero padding 'constant'.
        # NOTE(review): an empty padding_mode is passed through unchanged and is
        # presumably not a valid mode for pad() - only matters if random_pad_input > 0.
        self._pad_mode = 'constant' if self.padding_mode == 'zeros' else self.padding_mode

    def forward(self, x) -> torch.Tensor:
        """Run the network, optionally padding the input along one random axis first.

        When ``random_pad_input > 0`` the input is padded by that amount on
        both sides of either the last dimension or the second-to-last
        dimension, chosen at random on every call (there is no check of the
        train/eval mode).
        """
        if self.random_pad_input > 0:
            amount = self.random_pad_input
            if random.random() > 0.5:
                padding = (amount, amount, 0, 0)  # pad the last dimension
            else:
                padding = (0, 0, amount, amount)  # pad the second-to-last dimension
            x = pad(x, padding, mode=self._pad_mode)
        return self.model(x)

    def on_train_epoch_start(self) -> None:
        """Start every training epoch with a fresh training accuracy."""
        self.accuracy.reset()

    def on_validation_epoch_start(self) -> None:
        """Start every validation epoch with a fresh validation accuracy."""
        self.val_accuracy.reset()

    def on_test_epoch_start(self) -> None:
        """Start every test epoch with a fresh accuracy (test reuses ``self.accuracy``)."""
        self.accuracy.reset()

    def _shared_step(self, batch, accuracy) -> dict:
        """Compute loss and predictions for one batch and update ``accuracy``.

        Labels are shifted down by one before use (presumably the dataset
        yields 1-based labels - TODO confirm). The shift is done out of place
        so the caller's batch tensor is left untouched.

        Returns:
            dict with keys ``loss``, ``y_pred`` and ``X``.
        """
        X, y = batch
        y = y - 1
        y_pred = self(X)
        loss_value = self.ce(y_pred, y)
        accuracy.update(y_pred.argmax(dim=-1), y)
        return dict(loss=loss_value, y_pred=y_pred, X=X)

    def training_step(self, batch, batch_idx) -> dict:
        """Run one training batch; updates ``self.accuracy``."""
        return self._shared_step(batch, self.accuracy)

    def on_train_batch_end(self, outputs: dict, batch: tuple, batch_idx: int, dataloader_idx: int) -> None:
        """Log the per-step training loss (shown in the progress bar)."""
        self.log('train/loss', outputs['loss'].item(), prog_bar=True, on_step=True)

    def on_validation_batch_end(
        self, outputs, batch: tuple, batch_idx: int, dataloader_idx: int
    ) -> None:
        """Log the per-step validation loss."""
        self.log('val/loss', outputs['loss'].item(), on_step=True)

    def training_epoch_end(self, outputs) -> None:
        """Log epoch-level training loss/accuracy and intermediate ReLU images.

        The first image of the first batch of the epoch is passed to
        ``get_relu_outputs_intermediate``; every 4-D output is averaged over
        its channel dimension and sent to the logger as an image.
        """
        mean_loss = torch.FloatTensor([o['loss'] for o in outputs]).mean().item()
        self.log('train/epoch_accuracy', self.accuracy.compute(), on_epoch=True, on_step=False)
        self.log('train/epoch_loss', mean_loss, on_epoch=True, on_step=False)
        first_image = outputs[0]['X'][0].unsqueeze(0)
        relu_outputs = get_relu_outputs_intermediate(self, first_image).values()
        for i, relu_output in enumerate(relu_outputs):
            if len(relu_output.shape) == 4:  # BCHW
                # Drop the batch dimension and average over channels -> 1 x H x W.
                relu_output = relu_output.squeeze(0).mean(dim=0, keepdim=True)
                # NOTE(review): add_image assumes a TensorBoard-style experiment object.
                self.logger.experiment.add_image(f'relu_outputs/{i}', relu_output, self.current_epoch)

    def validation_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
        """Run one validation batch; updates ``self.val_accuracy``."""
        return self._shared_step(batch, self.val_accuracy)

    def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
        """Log the mean validation loss and the validation accuracy of the epoch."""
        mean_val_loss = torch.FloatTensor([o['loss'] for o in outputs]).mean().item()
        self.log('val/epoch_loss', mean_val_loss)
        self.log('val/epoch_accuracy', self.val_accuracy.compute())

    def test_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
        """Run one test batch; same computation as a training step."""
        return self.training_step(batch, batch_idx)

    def configure_optimizers(self):
        """Optimize all parameters with Adam at learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), self.lr)

