import torch
import torch.nn as nn
from torch import Tensor
import pytorch_lightning as pl
import torchvision

from convexrobust.utils import torch_utils as TU

from collections import defaultdict
from abc import abstractmethod
from enum import Enum
from typing import Dict


class Norm(Enum):
    """Perturbation norms for which a certified radius can be stored.

    The member values are 1, 2 and infinity, i.e. the order ``p`` of the
    l_p norm each member is named after.
    """

    L1 = 1
    L2 = 2
    # Infinity is a float, so this member's value is not an int like the others.
    LInf = float("inf")


class Certificate:
    """Certified radii for one input, stored per perturbation norm.

    ``radius`` maps a ``Norm`` member to the radius certified in that norm.
    It is a ``defaultdict(int)``, so reading a norm that was never set
    returns ``0`` (and, as with any defaultdict, inserts that key).

    The ``from_*`` constructors take a radius in one norm and derive the
    other two by multiplying with ``TU.norm_ball_conversion_factor``, which
    is always called as ``(target norm order, source norm order, dim)``.
    NOTE(review): the helper presumably returns the factor that shrinks a
    ball of the source norm into one of the target norm in ``dim``
    dimensions -- confirm against ``torch_utils``.
    """

    radius: Dict[Norm, float]

    def __init__(self, radius: Dict[Norm, float]):
        """Store a copy of ``radius`` with a default of 0 for missing norms."""
        self.radius = defaultdict(int, radius)

    @classmethod
    def zero(cls):
        """Return a certificate with no radii set; every lookup yields 0."""
        return cls({})

    @classmethod
    def from_l1(cls, radius, dim):
        """Build a certificate from an l1 radius for ``dim``-dimensional input."""
        l2_radius = radius * TU.norm_ball_conversion_factor(2, 1, dim)
        linf_radius = radius * TU.norm_ball_conversion_factor(float('inf'), 1, dim)
        return cls({Norm.L1: radius, Norm.L2: l2_radius, Norm.LInf: linf_radius})

    @classmethod
    def from_l2(cls, radius, dim):
        """Build a certificate from an l2 radius for ``dim``-dimensional input."""
        l1_radius = radius * TU.norm_ball_conversion_factor(1, 2, dim)
        linf_radius = radius * TU.norm_ball_conversion_factor(float('inf'), 2, dim)
        return cls({Norm.L1: l1_radius, Norm.L2: radius, Norm.LInf: linf_radius})

    @classmethod
    def from_linf(cls, radius, dim):
        """Build a certificate from an l-infinity radius for ``dim``-dimensional input."""
        l1_radius = radius * TU.norm_ball_conversion_factor(1, float('inf'), dim)
        l2_radius = radius * TU.norm_ball_conversion_factor(2, float('inf'), dim)
        return cls({Norm.L1: l1_radius, Norm.L2: l2_radius, Norm.LInf: radius})


class BaseCertifiable(pl.LightningModule):
    """Lightning base class for two-class classifiers that can certify inputs.

    Subclasses implement ``forward`` (raw logits) and ``certify``. This class
    supplies the class-balance offset, prediction, the loss computation shared
    by training and validation, and per-epoch logging of losses, accuracies
    and example images.

    Logit conventions, as implemented by ``predict``:

    * ``single_logit=True``: one logit per sample; a strictly positive logit
      predicts label 0 and a non-positive logit predicts label 1.
    * ``single_logit=False``: two logits per sample; the argmax is the label.

    Logging goes through ``self.logger.experiment`` using ``add_scalars``,
    ``add_scalar``, ``add_image``, ``add_text`` and ``flush``.
    NOTE(review): this is the TensorBoard ``SummaryWriter`` interface, so a
    TensorBoard logger is presumably required -- confirm with the trainer setup.
    """

    def __init__(self, datamodule=None, single_logit=False, custom_loss=None):
        """Set up the loss function and the class-balance offset.

        Args:
            datamodule: Stored on the instance as-is; not used in this class.
            single_logit: Whether ``forward`` emits one logit per sample
                instead of two (see the class docstring for the sign
                convention).
            custom_loss: Optional callable ``loss(logits, target)``. When
                omitted, ``BCEWithLogitsLoss`` is used in single-logit mode
                and ``CrossEntropyLoss`` otherwise.
        """
        super().__init__()

        self.datamodule = datamodule
        self.single_logit = single_logit
        if custom_loss is not None:
            self.loss_func = custom_loss
        elif single_logit:
            self.loss_func = nn.BCEWithLogitsLoss()
        else:
            self.loss_func = nn.CrossEntropyLoss()

        # Scalar offset applied by forward_balanced. It is a Parameter so it
        # moves and is saved with the module, but it is excluded from
        # gradient-based training.
        self.class_balance = nn.Parameter(torch.tensor(0.0), requires_grad=False)

    @abstractmethod
    def forward(self, x: Tensor) -> Tensor:  # type:ignore
        """Return the raw logits for the batch ``x``; implemented by subclasses."""
        pass

    def forward_balanced(self, x: Tensor) -> Tensor:
        """Return ``forward(x)`` shifted by the class-balance offset.

        In single-logit mode ``class_balance`` is added to the logit. In
        two-logit mode ``+0.5 * class_balance`` is added to the first logit
        and ``-0.5 * class_balance`` to the second.
        """
        logits = self.forward(x)
        if self.single_logit:
            return logits + self.class_balance

        # Build the offsets on the logits' own device rather than a global
        # default device, so the addition cannot fail with a device mismatch.
        # The (1, 2) row broadcasts over the batch dimension.
        offsets = torch.tensor([[0.5, -0.5]], device=logits.device)
        return logits + offsets * self.class_balance

    def predict(self, x: Tensor) -> Tensor:
        """Return the predicted label index for every sample in ``x``."""
        balanced = self.forward_balanced(x)
        if self.single_logit:
            # Positive logit -> label 0, non-positive logit -> label 1.
            return (balanced <= 0).long()
        return balanced.argmax(dim=1)

    @abstractmethod
    def certify(self, x: Tensor) -> Certificate:
        """Return the robustness certificate for ``x``; implemented by subclasses."""
        pass

    def extra_loss(self, signal, target):
        """Return an additional loss term added to the classification loss.

        The default contributes zero. ``compute_losses`` passes the
        unmodified input batch and the target (already cast to float in
        single-logit mode).
        """
        return torch.tensor(0.0)

    def training_signal_modify(self, signal):
        """Return the input batch actually trained on; identity by default."""
        return signal

    def training_step(self, batch, batch_idx):
        """Compute the training losses for one batch.

        On the very first batch of the first epoch this also logs the module
        attributes and, when dimension 1 of the input is larger than 2,
        example images.
        NOTE(review): the ``shape[1] > 2`` test presumably separates image
        batches from low-dimensional toy data -- confirm.
        """
        assert self.training

        if self.current_epoch == 0 and batch_idx == 0:
            if batch[0].shape[1] > 2:
                self.log_images(batch, 'train')
            self.log_attributes()

        return self.compute_losses(batch, 'loss')

    def validation_step(self, batch, batch_idx):
        """Compute the validation losses for one batch without gradients.

        Example images are logged on the first batch of the first epoch under
        the same ``shape[1] > 2`` condition as in ``training_step``.
        """
        assert not self.training

        if self.current_epoch == 0 and batch_idx == 0:
            if batch[0].shape[1] > 2:
                self.log_images(batch, 'val')

        with torch.no_grad():
            return self.compute_losses(batch, 'val_loss')

    def compute_losses(self, batch, loss_string):
        """Run the model on a batch and collect losses and predictions.

        Args:
            batch: Sequence whose first two items are the input batch and
                the targets.
            loss_string: Key under which the total loss is stored in the
                result and reported through ``self.log`` (``'loss'`` for
                training, ``'val_loss'`` for validation).

        Returns:
            Dict with the original ``target``, detached ``preds`` and
            ``preds_tilde``, detached ``classification_loss``,
            ``classification_loss_unmodified`` and ``extra_loss``, and the
            total loss under both ``'loss'`` and ``loss_string``. Only the
            total loss keeps its autograd graph.
        """
        signal, target = batch[0], batch[1]

        # Predict on both the original and the (potentially modified) signal.
        # Only the modified prediction is trained on; the unmodified one is
        # kept purely for logging.
        signal_tilde = self.training_signal_modify(signal)

        preds = self.forward(signal)
        preds_tilde = self.forward(signal_tilde)

        results = {}
        results['target'] = target
        results['preds'] = preds.detach()
        results['preds_tilde'] = preds_tilde.detach()

        if self.single_logit:
            # BCEWithLogitsLoss expects floating-point targets.
            target = target.float()

        # NOTE: a stability-training variant was prototyped here and is
        # disabled: cross-entropy on preds plus 6.0 * KL(preds || preds_tilde),
        # both taken as categorical distributions over the logits.

        # In single-logit mode a positive logit means label 0, which is the
        # opposite sign convention of BCEWithLogitsLoss, so the logits are
        # negated before being handed to the loss.
        sign = -1 if self.single_logit else 1

        classification_loss = self.loss_func(sign * preds_tilde, target)
        results['classification_loss'] = classification_loss

        results['classification_loss_unmodified'] = self.loss_func(
            sign * preds, target
        ).detach()

        extra_loss = self.extra_loss(signal, target)
        results['extra_loss'] = extra_loss.detach()
        results['loss'] = classification_loss + extra_loss
        # Keep the graph only on the total loss; the logged copy is detached.
        results['classification_loss'] = classification_loss.detach()
        results[loss_string] = results['loss']

        # For the model checkpointing
        self.log(loss_string, results[loss_string])

        return results

    def training_epoch_end(self, outputs):
        """Log the aggregated training metrics and flush the experiment writer."""
        self.log_outputs(outputs, 'Train')

        self.logger.experiment.flush()

    def validation_epoch_end(self, outputs):
        """Log the aggregated validation metrics.

        Each step output is modified in place: ``'val_loss'`` is moved to
        ``'loss'`` so validation results use the same keys as training.
        """
        for step_out in outputs:
            step_out['loss'] = step_out.pop('val_loss')

        self.log_outputs(outputs, 'Valid')

    def log_outputs(self, outputs, type_string):
        """Write epoch-level losses, accuracies and prediction spread.

        Args:
            outputs: List of per-step dicts as returned by ``compute_losses``.
            type_string: Suffix of the scalar tags, e.g. ``'Train'``.

        Accuracies are reported overall and separately for labels 0 and 1,
        for both the modified and the unmodified predictions. If a label does
        not occur in ``outputs``, its accuracy is the mean of an empty tensor,
        i.e. NaN.
        """
        experiment = self.logger.experiment
        epoch = self.current_epoch

        losses = self.calc_means(outputs, 'loss')
        experiment.add_scalars(f'Loss/{type_string}', losses, epoch)

        # Can't average accuracies b/c uneven number of samples from each class in batch
        preds = torch.cat([step_out['preds'] for step_out in outputs])
        preds_tilde = torch.cat([step_out['preds_tilde'] for step_out in outputs])
        target = torch.cat([step_out['target'] for step_out in outputs])

        if self.single_logit:
            # Expand the single logit z to the pair (z, -z) so that the argmax
            # follows the "positive logit -> label 0" convention.
            preds = torch.stack([preds, -preds], dim=1)
            preds_tilde = torch.stack([preds_tilde, -preds_tilde], dim=1)

        accs = {'acc_composite': self.compute_accuracy(preds_tilde, target)}
        accs_um = {'acc_composite_unmodified': self.compute_accuracy(preds, target)}
        for label in [0, 1]:
            selected = target == label
            accs[f'acc_{label}'] = self.compute_accuracy(
                preds_tilde[selected], target[selected]
            )
            accs_um[f'acc_{label}_unmodified'] = self.compute_accuracy(
                preds[selected], target[selected]
            )

        experiment.add_scalars(f'Accuracies/{type_string}', accs, epoch)
        experiment.add_scalars(f'Accuracies Unmodified/{type_string}', accs_um, epoch)

        experiment.add_scalar(f'Pred std/{type_string}', preds.std(), epoch)
        experiment.add_scalar(f'Pred modified std/{type_string}', preds_tilde.std(), epoch)

    def calc_means(self, outputs, in_key):
        """Average across steps every entry whose key contains ``in_key``.

        The candidate keys are taken from the first element of ``outputs``.
        Returns a dict mapping each matching key to the mean tensor.
        """
        return {
            key: torch.stack([step_out[key] for step_out in outputs]).mean()
            for key in outputs[0]
            if in_key in key
        }

    def compute_accuracy(self, pred, target):
        """Return the fraction of rows of ``pred`` whose argmax equals ``target``."""
        pred_class = torch.max(pred, 1).indices
        return (pred_class == target).float().mean()

    def log_images(self, batch, stage):
        """Log image grids of the raw and the modified inputs of a batch.

        The top-left 4x4 patch of every channel is overwritten with the
        sample's target value so the label is visible in the grid. The
        modified inputs are clamped to [0, 1] first. Works on clones, so
        ``batch`` itself is left untouched.
        """
        experiment = self.logger.experiment

        signal, target = batch[0].clone(), batch[1].clone()
        signal_modify = self.training_signal_modify(signal).clamp(0, 1)

        channel_n, tag_size = signal.shape[1], 4

        for i in range(signal.shape[0]):
            # Create the tag on the images' own device rather than a global
            # default device, so the in-place write cannot hit a mismatch.
            tag = torch.ones(channel_n, tag_size, tag_size, device=signal.device) * target[i]
            signal[i][:, 0:tag_size, 0:tag_size] = tag
            signal_modify[i][:, 0:tag_size, 0:tag_size] = tag

        grid = torchvision.utils.make_grid(signal)
        grid_modify = torchvision.utils.make_grid(signal_modify)

        experiment.add_image(f'{stage}/Raw', grid)
        experiment.add_image(f'{stage}/Modify', grid_modify)

    def log_attributes(self):
        """Log the instance's public attributes as text.

        Only entries of ``self.__dict__`` whose name does not start with an
        underscore are included. If one of them is called ``noise``, its own
        ``__dict__`` is logged as well.
        """
        experiment = self.logger.experiment

        attributes = {
            name: value
            for name, value in self.__dict__.items()
            if not name.startswith('_')
        }
        experiment.add_text('attributes', str(attributes), 0)
        if 'noise' in attributes:
            noise_attributes = self.noise.__dict__
            experiment.add_text('noise_attributes', str(noise_attributes), 0)
