import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from collections import defaultdict
from torch.nn import CrossEntropyLoss

import openood.utils.comm as comm
from openood.utils import Config, get_config_default
from openood.datasets.utils import clone_dataloader

from .diversity import subset_batch
from .diversity import ADPDiversityLoss, NCLDiversityLoss
from .diversity import InputGradientDiversityLoss, GradCAMDiversityLoss
from .diversity import DiceDiversityLoss, OEDiversityLoss
from .lr_scheduler import cosine_annealing
from .logitnorm_trainer import LogitNormLoss


def non_diag(a):
    """Drop the main diagonal from every trailing square matrix.

    Args:
        a: Tensor of matrices with shape (..., N, N).

    Returns:
        Tensor with shape (..., N, N - 1) that keeps, row by row, the
        off-diagonal entries of each matrix.
    """
    size = a.shape[-1]
    lead = list(a.shape[:-2])
    # Flatten each matrix and drop its final element, which lies on the diagonal.
    flat = a.reshape(*(lead + [size * size]))[..., :-1]
    # In the remaining N * N - 1 values every diagonal entry is followed by
    # exactly N off-diagonal ones, so rows of length N + 1 start on the diagonal.
    grouped = flat.reshape(*(lead + [size - 1, size + 1]))
    # Cut the leading (diagonal) column and restore one row per matrix row.
    off_diagonal = grouped[..., 1:]
    return off_diagonal.reshape(*(lead + [size, size - 1]))


class AMPOptimizerWrapper:
    """Optimizer facade that optionally routes updates through a gradient scaler.

    With ``disable=True`` every call is forwarded to the wrapped optimizer
    unchanged. Otherwise a ``torch.cuda.amp.GradScaler`` scales the loss
    before the backward pass and drives the optimizer step.
    """

    def __init__(self, optimizer, disable=False):
        """Wrap ``optimizer``; create a gradient scaler unless ``disable`` is set."""
        self.optimizer = optimizer
        # Flat list of the parameters of all param groups, used for clipping.
        self.params = [param
                       for group in optimizer.param_groups
                       for param in group["params"]]
        self.disable = disable
        if not disable:
            self.scaler = torch.cuda.amp.GradScaler()

    @property
    def lr(self):
        """Learning rate of the first parameter group."""
        first_group = self.optimizer.param_groups[0]
        return first_group["lr"]

    def zero_grad(self, set_to_none=False):
        """Reset gradients of the wrapped optimizer."""
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def backward(self, loss):
        """Backpropagate ``loss``, scaling it first when the scaler is in use."""
        to_differentiate = loss if self.disable else self.scaler.scale(loss)
        to_differentiate.backward()

    def clip_grad_norm(self, max_norm):
        """Clip the gradient norm of all wrapped parameters and return the norm.

        When the scaler is in use, gradients are unscaled first so that
        ``max_norm`` applies to the true gradient magnitudes.
        """
        if not self.disable:
            self.scaler.unscale_(self.optimizer)
        return torch.nn.utils.clip_grad_norm_(self.params, max_norm)

    def step(self):
        """Apply one optimizer update (through the scaler when it is in use)."""
        if not self.disable:
            self.scaler.step(self.optimizer)
            self.scaler.update()
            return
        self.optimizer.step()


class DEEDTrainer:
    """Trainer for a multi-model network with classification and diversity losses.

    The wrapped network must expose ``num_models`` and ``discriminator``
    attributes and return a 5-tuple when called with ``return_ensemble=True``.
    Each model is trained with a classification loss; optional diversity
    losses are computed either on in-distribution (ID) batches or on an
    unlabeled out-of-distribution (OOD) batch. When the network has a
    discriminator, it is trained with its own optimizer.
    """

    # Weight of the previous value in the moving average of per-step metrics.
    METRIC_SMOOTHING = 0.8
    # Classification losses selectable through ``config.loss.cls_loss``.
    CLS_LOSSES = {
        "ce": CrossEntropyLoss,
        "logitnorm": LogitNormLoss
    }

    def __init__(self, net, train_loader, config):
        """Set up devices, loaders, optimizers and losses.

        Args:
            net: Network to train, optionally wrapped in DistributedDataParallel.
            train_loader: Either a single labeled loader or a pair
                ``(labeled_loader, unlabeled_loader)``.
            config: Configuration with ``trainer``, ``optimizer`` and ``loss`` sections.

        Raises:
            KeyError: If ``amp_dtype`` or ``cls_loss`` names an unknown option.
            ValueError: If the DICE loss is requested without a discriminator,
                or OOD losses are requested without an unlabeled loader.
        """
        # Parse args.
        self.net = net
        self.base_net = self.net.module if isinstance(self.net, torch.nn.parallel.DistributedDataParallel) else self.net
        if isinstance(train_loader, DataLoader):
            # A bare loader must not be tuple-unpacked: unpacking would iterate it,
            # fetching batches, and a loader with exactly two batches would be
            # mistaken for a (labeled, unlabeled) pair.
            self.train_loader, self.train_unlabeled_loader = train_loader, None
        else:
            try:
                self.train_loader, self.train_unlabeled_loader = train_loader
            except (TypeError, ValueError):
                self.train_loader, self.train_unlabeled_loader = train_loader, None
        self.config = config

        # Setup devices.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.amp_dtype = {
            "fp16": torch.float16,
            "bfp16": torch.bfloat16,
            "none": None
        }[get_config_default(config.trainer, "amp_dtype", "none")]
        self.use_amp = self.amp_dtype is not None
        if not self.use_amp:
            # Workaround for PyTorch checks.
            self.amp_dtype = torch.bfloat16 if self.device == "cpu" else torch.float16

        # Setup loaders and optimizers.
        if not self.config.trainer.same_batch:
            # Use different batches for different models.
            self.train_loaders = [self.train_loader] + [clone_dataloader(self.train_loader)
                                                        for _ in range(self.base_net.num_models - 1)]

        self.model_parameters = [p for name, p in self.base_net.named_parameters() if not name.startswith("discriminator.")]
        self.discriminator_parameters = [p for name, p in self.base_net.named_parameters() if name.startswith("discriminator.")]

        self.optimizer = torch.optim.SGD(
            self.model_parameters,
            config.optimizer.lr,
            momentum=config.optimizer.momentum,
            weight_decay=config.optimizer.weight_decay,
            nesterov=True
        )
        # The schedule is defined per training step over all epochs.
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer,
            lr_lambda=lambda step: cosine_annealing(
                step,
                config.optimizer.num_epochs * len(self.train_loader),
                1,
                1e-6 / config.optimizer.lr,
            ),
        )
        self.optimizer = AMPOptimizerWrapper(self.optimizer, disable=not self.use_amp)

        if self.base_net.discriminator is not None:
            assert self.discriminator_parameters
            self.discriminator_optimizer = torch.optim.RMSprop(
                self.discriminator_parameters,
                config.optimizer.discriminator_lr,
                momentum=config.optimizer.momentum,
                weight_decay=config.optimizer.weight_decay
            )
            self.discriminator_optimizer = AMPOptimizerWrapper(self.discriminator_optimizer, disable=not self.use_amp)

        # Setup losses.
        self.cls_loss = self.CLS_LOSSES[get_config_default(self.config.loss, "cls_loss", "ce")]()

        # Each entry maps a loss name to (weight, loss module).
        need_input_gradients = False
        losses = {}
        if get_config_default(config.loss, "inpgrad", 0.0) > 0:
            losses["inpgrad"] = (config.loss.inpgrad, InputGradientDiversityLoss(mode=config.loss.inpgrad_mode))
            need_input_gradients = True
        if get_config_default(config.loss, "ncl", 0.0) > 0:
            losses["ncl"] = (config.loss.ncl, NCLDiversityLoss())
        if get_config_default(config.loss, "adp", 0.0) > 0:
            losses["adp"] = (config.loss.adp, ADPDiversityLoss(alpha=config.loss.adp_alpha, beta=config.loss.adp_beta))
        if get_config_default(config.loss, "gradcam", 0.0) > 0:
            losses["gradcam"] = (config.loss.gradcam, GradCAMDiversityLoss(
                mode=config.loss.gradcam_mode,
                similarity=config.loss.similarity,
                activation=config.loss.activation,
                padding=config.loss.padding))
        # The losses collected above are applied either to OOD or to ID data.
        self.diversify_ood = get_config_default(config.loss, "diversify_ood", False)
        if self.diversify_ood:
            self.id_losses = {}
            self.ood_losses = losses
            self.need_id_input_gradients = False
            self.need_ood_input_gradients = need_input_gradients
        else:
            self.id_losses = losses
            self.ood_losses = {}
            self.need_id_input_gradients = need_input_gradients
            self.need_ood_input_gradients = False
        # DICE is always an ID loss and OE is always an OOD loss.
        if get_config_default(config.loss, "dice", 0.0) > 0:
            if self.base_net.discriminator is None:
                raise ValueError("Need discriminator for the DICE loss.")
            self.id_losses["dice"] = (config.loss.dice, DiceDiversityLoss(self.base_net.discriminator))
        if get_config_default(config.loss, "oe", 0.0) > 0:
            self.ood_losses["oe"] = (config.loss.oe, OEDiversityLoss())
        if self.ood_losses and (self.train_unlabeled_loader is None):
            raise ValueError(f"Need OOD sample to apply {self.ood_losses.keys()} losses")

    def train_epoch(self, epoch_idx):
        """Train for one pass over the labeled loader.

        Gradients are accumulated over ``config.trainer.gradient_accumulation``
        consecutive steps; clipping and the optimizer update happen once per
        accumulation window (and at the last step of the epoch). The learning
        rate scheduler is advanced on every step.

        Args:
            epoch_idx: Index of the current epoch. The DICE ramp-up uses
                ``epoch_idx - 1``, so indices presumably start at 1.

        Returns:
            Tuple ``(net, metrics)``. ``metrics`` holds the smoothed per-step
            metrics averaged over processes, plus ``epoch_idx`` and ``lr``.
        """
        metrics = defaultdict(float)
        self.net.train()
        if self.config.trainer.same_batch:
            train_dataiter = iter(self.train_loader)
        else:
            train_dataiters = [iter(train_loader) for train_loader in self.train_loaders]

        if self.train_unlabeled_loader is not None:
            unlabeled_dataiter = iter(self.train_unlabeled_loader)

        num_steps = len(self.train_loader)
        accumulation = self.config.trainer.gradient_accumulation
        grad_clip = get_config_default(self.config.optimizer, "grad_clip", 0.0)
        # Whether gradients accumulated since the last update are waiting for a step.
        model_grads_pending = False
        discriminator_grads_pending = False

        pbar = tqdm(range(1, num_steps + 1),
                    position=0,
                    leave=True,
                    disable=not comm.is_main_process())
        for step_idx in pbar:
            if self.config.trainer.same_batch:
                batch = next(train_dataiter)
                data = [batch['data'].to(self.device, non_blocking=True)] * self.base_net.num_models
                targets = [batch['label'].to(self.device, non_blocking=True)] * self.base_net.num_models
            else:
                batches = [next(i) for i in train_dataiters]
                data = [b['data'].to(self.device, non_blocking=True) for b in batches]
                targets = [b['label'].to(self.device, non_blocking=True) for b in batches]

            batch_metrics = {}

            # Diversity losses run during the first/last epochs and periodically in between.
            apply_diversity = \
                epoch_idx <= self.config.trainer.diversity_loss_first_epochs or \
                epoch_idx >= self.config.optimizer.num_epochs - self.config.trainer.diversity_loss_last_epochs or \
                step_idx % self.config.trainer.diversity_loss_period == 0
            apply_diversity = apply_diversity and (self.id_losses or self.ood_losses)
            batch_metrics["apply_diversity"] = float(bool(apply_diversity))

            ood_data = None
            id_data = None
            if apply_diversity:
                # Either use a data batch or an unlabeled OOD batch.
                if self.ood_losses:
                    try:
                        diversity_batch = next(unlabeled_dataiter)
                    except StopIteration:
                        # Restart the unlabeled loader when it is exhausted.
                        unlabeled_dataiter = iter(self.train_unlabeled_loader)
                        diversity_batch = next(unlabeled_dataiter)
                    ood_data = diversity_batch['data'].to(self.device).requires_grad_(self.need_ood_input_gradients)
                if self.id_losses:
                    if self.config.trainer.same_batch:
                        id_data = data[0]
                        id_target = targets[0]
                    else:
                        # Pick the batch of a random model for the shared diversity pass.
                        batch_idx = random.randrange(len(batches))
                        id_data = data[batch_idx]
                        id_target = targets[batch_idx]

            data = torch.stack(data).requires_grad_(self.need_id_input_gradients)  # (N, B, C, H, W).
            batch_size = data.shape[1]

            # Forward pass.
            with torch.autocast(device_type=self.device, enabled=self.use_amp, dtype=self.amp_dtype):
                # Simple forward and optionally forward OOD data.
                if ood_data is not None:
                    # Forward OOD images along with ID data for better OE training.
                    ood_expanded = ood_data[None].expand(self.base_net.num_models, -1, -1, -1, -1)  # NBCHW.
                    images = torch.cat([data, ood_expanded], 1)  # NBCHW.
                else:
                    images = data
                _, logits_classifier, inp, features, feature_maps = self.net(images, return_ensemble=True)
                # (N, B, C), (?), (N, B, F), (N, B, C, H, W).
                logits_cls = logits_classifier[:, :batch_size]  # Truncate OOD logits if any.
                logits_ood = logits_classifier[:, batch_size:]  # Might be empty if OOD is unavailable.

                # Forward diversity data.
                if self.config.trainer.same_batch:
                    # Reuse outputs if there is no need for diversity batch forwarding.
                    id_logits_classifier, id_inp, id_features, id_feature_maps = (
                        logits_classifier, inp, features, feature_maps
                    )
                elif (apply_diversity and self.id_losses):
                    id_images = id_data[None].expand(self.base_net.num_models, -1, -1, -1, -1).requires_grad_(self.need_id_input_gradients)
                    self.net.eval()
                    id_agg_logits, id_logits_classifier, id_inp, id_features, id_feature_maps = self.net(id_images,
                                                                                                         return_ensemble=True)
                    self.net.train()
                    if id_target is None:
                        id_target = id_agg_logits.argmax(1)  # (B).
                        assert id_target.ndim == 1

                # Compute losses.
                total_loss = 0
                discriminator_loss = 0
                weight_ce = get_config_default(self.config.loss, "ce", 1.0)
                if weight_ce > 0:
                    loss_ce = sum([self.cls_loss(logits_cls[i], targets[i])
                                   for i in range(self.base_net.num_models)]) / self.base_net.num_models
                    batch_metrics["loss_ce"] = float(loss_ce)
                    loss = loss_ce * weight_ce
                    total_loss = total_loss + loss
                    if get_config_default(self.config.loss, "adversarial", 0) > 0:
                        loss_adv = self._adversarial_loss(images, targets,
                                                          loss_ce, self.config.loss.adversarial,
                                                          batch_stop=batch_size)
                        loss = loss_adv * weight_ce
                        total_loss = total_loss + loss
                        batch_metrics["loss_adv"] = float(loss_adv)

                if apply_diversity:
                    assert (not self.id_losses) or (id_features is None) or (id_features[0].ndim == 2)  # N x (B, D).
                    assert (not self.id_losses) or (id_feature_maps is None) or (id_feature_maps[0].ndim == 4)  # N x (B, C, H, W).
                    self.net.eval()
                    div_loss = 0
                    for name, (weight, loss) in self.id_losses.items():
                        if (name == "dice") and (self.config.loss.dice_ramp_up_epochs > 0):
                            # Linearly ramp the DICE weight up over the first epochs.
                            dice_scale = min((epoch_idx - 1) / self.config.loss.dice_ramp_up_epochs, 1)
                            weight = weight * dice_scale
                            batch_metrics["dice_scale"] = dice_scale
                        loss_diversity = loss(id_inp, id_target, id_logits_classifier,
                                              id_features, id_feature_maps, batch_stop=batch_size)
                        div_loss = div_loss + weight * loss_diversity
                        batch_metrics[f"loss_{name}"] = float(loss_diversity)
                    for name, (weight, loss) in self.ood_losses.items():
                        # OOD samples occupy batch positions from batch_size onwards.
                        loss_diversity = loss(inp, None, logits_classifier,
                                              features, feature_maps, batch_start=batch_size)
                        div_loss = div_loss + weight * loss_diversity
                        batch_metrics[f"loss_{name}"] = float(loss_diversity)
                    batch_metrics["div_loss"] = float(div_loss)
                    total_loss = total_loss + div_loss
                    self.net.train()
                    if self.id_losses and (self.base_net.discriminator is not None):
                        # Features are detached, so this loss cannot reach the models through them.
                        discriminator_loss, disc_metrics = self.base_net.discriminator.loss(
                            [subset_batch(f, 0, stop=batch_size).detach() for f in id_features],
                            id_target)
                        batch_metrics.update(disc_metrics)

            # Record the total loss on every step that has one, not only on diversity steps.
            if isinstance(total_loss, torch.Tensor):
                batch_metrics["loss"] = float(total_loss)

            # step_idx is 1-based: update after every `accumulation` steps and at the epoch end.
            window_end = step_idx % accumulation == 0 or step_idx == num_steps

            pbar_message = f'Epoch: {epoch_idx:03d}'
            # Backward and step.
            if isinstance(total_loss, torch.Tensor):
                pbar_message = pbar_message + f', loss: {total_loss.item():.4f}'
                if not model_grads_pending:
                    # First contribution of the window: drop gradients of the previous update.
                    self.optimizer.zero_grad(set_to_none=True)
                self._backward_model_loss(total_loss / accumulation,
                                          protect_discriminator=discriminator_grads_pending)
                model_grads_pending = True
            if window_end and model_grads_pending:
                self._clip_and_step(self.optimizer, grad_clip)
                model_grads_pending = False
            self.scheduler.step()
            if isinstance(discriminator_loss, torch.Tensor):
                if not discriminator_grads_pending:
                    # Also removes what the model loss wrote into the discriminator gradients.
                    self.discriminator_optimizer.zero_grad(set_to_none=True)
                self.discriminator_optimizer.backward(discriminator_loss / accumulation)
                discriminator_grads_pending = True
            if window_end and discriminator_grads_pending:
                self._clip_and_step(self.discriminator_optimizer, grad_clip)
                discriminator_grads_pending = False

            with torch.no_grad():
                for k, v in batch_metrics.items():
                    metrics[k] = metrics[k] * self.METRIC_SMOOTHING + v * (1 - self.METRIC_SMOOTHING)

            pbar.set_description(pbar_message)

        metrics = {k: self.save_metrics(v) for k, v in metrics.items()}
        metrics.update({
            'epoch_idx': epoch_idx,
            'lr': self.optimizer.lr
        })

        return self.net, metrics

    def _backward_model_loss(self, loss, protect_discriminator):
        """Backpropagate the model loss without disturbing pending discriminator gradients.

        Args:
            loss: Loss tensor to backpropagate through the model optimizer wrapper.
            protect_discriminator: When True, discriminator gradients that are
                being accumulated are set aside during the backward pass and
                restored afterwards, so anything this loss writes into them is
                discarded.
        """
        if not protect_discriminator:
            self.optimizer.backward(loss)
            return
        # NOTE(review): assumes ``.grad`` tensors are not views of shared buffers
        # that backward overwrites in place — confirm for DDP bucket views.
        stashed = [p.grad for p in self.discriminator_parameters]
        for p in self.discriminator_parameters:
            p.grad = None
        try:
            self.optimizer.backward(loss)
        finally:
            for p, grad in zip(self.discriminator_parameters, stashed):
                p.grad = grad

    @staticmethod
    def _clip_and_step(optimizer, grad_clip):
        """Clip accumulated gradients if ``grad_clip`` is positive, then update.

        Clipping is done once, right before the step, because the AMP scaler
        allows gradients to be unscaled only once per update.
        """
        if grad_clip > 0:
            optimizer.clip_grad_norm(grad_clip)
        optimizer.step()

    def save_metrics(self, loss_avg):
        """Gather ``loss_avg`` with ``comm.gather`` and return the mean of the results."""
        all_loss = comm.gather(loss_avg)
        return np.mean(list(all_loss))

    def _adversarial_loss(self, inputs, target, ce_loss, eps,
                          batch_start=None, batch_stop=None):
        """Compute cross-entropy on inputs perturbed along the sign of the loss gradient.

        Args:
            inputs: Network inputs with shape (N, B, C, H, W).
            target: Per-model targets, indexable by model index.
            ce_loss: Loss whose gradient w.r.t. ``inputs`` gives the perturbation direction.
            eps: Perturbation step size.
            batch_start: Optional first batch index to keep.
            batch_stop: Optional batch index to stop at.

        Returns:
            Cross-entropy on the perturbed inputs, averaged over the models.
        """
        # NOTE(review): torch.autograd.grad needs ``inputs`` to require grad; the
        # caller enables that only for input-gradient losses — confirm the
        # network enables it otherwise.
        image_grad = torch.autograd.grad(
            ce_loss, inputs, retain_graph=True, create_graph=True
        )
        inputs_adv = inputs + eps * torch.sign(image_grad[0])  # NBCHW.
        assert inputs_adv.ndim == 5
        inputs_adv = subset_batch(inputs_adv, 1, batch_start, batch_stop)
        _, logits_classifier, _, _, _ = self.net(inputs_adv, return_ensemble=True)
        loss_adv = sum([F.cross_entropy(logits_classifier[i], subset_batch(target[i], 0, batch_start, batch_stop))
                        for i in range(self.base_net.num_models)]) / self.base_net.num_models
        return loss_adv
