from typing import Optional, Iterable
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from edl.model import UncertaintyOptions, EdlLossOptions
from torch.utils.data import DataLoader
from tqdm import tqdm
import openood.utils.comm as comm
from openood.utils import Config
from .base_trainer import BaseTrainer
from edl import EDL
from torch.distributions import Dirichlet
from .lr_scheduler import cosine_annealing


def dirichlet_kl_divergence(alphas, target_alphas, precision=None, target_precision=None,
                            epsilon=1e-8):
    """
    Compute the forward KL divergence KL(target || model) between a target Dirichlet
    distribution and a model Dirichlet distribution, from the concentration (alpha)
    parameters of each.

    :param alphas: Tensor containing concentration parameters of the model. Expected shape is batchsize X num_classes.
    :param target_alphas: Tensor containing target concentration parameters. Expected shape is batchsize X num_classes.
    :param precision: Optional. Precision (sum of alphas) of the model, shape batchsize X 1.
        Computed from ``alphas`` when None.
    :param target_precision: Optional. Precision of the target, shape batchsize X 1.
        Computed from ``target_alphas`` when None.
    :param epsilon: Smoothing factor added inside lgamma/digamma for numerical stability. Default value is 1e-8
    :return: Tensor of shape (batchsize,) with the per-example KL divergence
        (a 0-d tensor when batchsize is 1, because of the final squeeze).
    :raises AssertionError: if the per-class term contains non-finite values.
    """
    # Compare against None explicitly: a (batch, 1) tensor has no single truth value, so
    # `if not precision` raised for any caller that actually supplied one.
    if precision is None:
        precision = torch.sum(alphas, dim=1, keepdim=True)
    if target_precision is None:
        target_precision = torch.sum(target_alphas, dim=1, keepdim=True)

    precision_term = torch.lgamma(target_precision) - torch.lgamma(precision)
    if not torch.all(torch.isfinite(precision_term)).item():
        # Best-effort diagnostic (the computation deliberately continues), but make it visible.
        import warnings
        warnings.warn('dirichlet_kl_divergence: non-finite precision term encountered',
                      RuntimeWarning)

    alphas_term = torch.sum(torch.lgamma(alphas + epsilon) - torch.lgamma(target_alphas + epsilon)
                            + (target_alphas - alphas) * (torch.digamma(target_alphas + epsilon)
                                                          - torch.digamma(target_precision + epsilon)), dim=1,
                            keepdim=True)
    assert torch.all(torch.isfinite(alphas_term)).item()

    return torch.squeeze(precision_term + alphas_term)


def dirichlet_reverse_kl_divergence(alphas, target_alphas, precision=None, target_precision=None,
                                    epsilon=1e-8):
    """
    Compute the reverse KL divergence KL(model || target) between a model Dirichlet
    distribution and a target Dirichlet distribution.

    Implemented by calling ``dirichlet_kl_divergence`` with the roles of model and
    target (and their precisions) swapped.

    :param alphas: Tensor containing concentration parameters of the model. Expected shape is batchsize X num_classes.
    :param target_alphas: Tensor containing target concentration parameters. Expected shape is batchsize X num_classes.
    :param precision: Optional. Precision of the model, shape batchsize X 1.
    :param target_precision: Optional. Precision of the target, shape batchsize X 1.
    :param epsilon: Smoothing factor for numerical stability. Default value is 1e-8
    :return: per-example reverse KL divergences (see ``dirichlet_kl_divergence`` for the shape).
    """
    swapped = {
        'alphas': target_alphas,
        'target_alphas': alphas,
        'precision': target_precision,
        'target_precision': precision,
    }
    return dirichlet_kl_divergence(epsilon=epsilon, **swapped)


class DirichletKLLoss:
    """
    KL-divergence loss between a model Dirichlet and a fixed target Dirichlet.

    Can be applied to any model which returns logits: ``__call__`` converts the logits to
    concentration parameters with ``exp`` before computing the loss.
    """

    def __init__(self, target_concentration=1e3, concentration=1.0, reverse=True):
        """
        :param target_concentration: extra concentration added to the labelled class
            of the target Dirichlet (only when labels are provided).
        :param concentration: the 'base' concentration given to every class of the
            target Dirichlet.
        :param reverse: if True use the reverse KL, KL(model || target); otherwise the
            forward KL, KL(target || model).
        """
        self.target_concentration = target_concentration
        self.concentration = concentration
        self.reverse = reverse

    def __call__(self, logits, labels, reduction='mean'):
        """Exponentiate ``logits`` into concentrations and return ``self.forward(...)``."""
        alphas = torch.exp(logits)
        return self.forward(alphas, labels, reduction=reduction)

    def forward(self, alphas, labels, reduction='mean'):
        """
        :param alphas: model concentration parameters, (batch, num_classes).
        :param labels: integer class labels of shape (batch,), or None for a flat target.
        :param reduction: 'mean' for the batch mean, 'none' for per-example losses.
        :raises NotImplementedError: for any other ``reduction``.
        """
        # Validate before doing the comparatively expensive KL computation.
        if reduction not in ('mean', 'none'):
            raise NotImplementedError(f'Unsupported reduction: {reduction!r}')
        loss = self.compute_loss(alphas, labels)
        return torch.mean(loss) if reduction == 'mean' else loss

    def compute_loss(self, alphas, labels: Optional[torch.Tensor] = None):
        """
        Per-example KL divergence between the model and the target Dirichlet.

        The target concentrations are ``self.concentration`` for every class, plus
        ``self.target_concentration`` on the labelled class when ``labels`` is given.

        :param alphas: the alpha parameter outputs from the model, (batch, num_classes).
        :param labels: optional integer labels of shape (batch,) indicating the correct class.
        :return: a tensor of per-example losses.
        """
        # TODO: concentration is currently fixed; consider a per-example setup.
        target_alphas = torch.ones_like(alphas) * self.concentration
        if labels is not None:
            # zeros_like already inherits alphas' device.
            target_alphas += torch.zeros_like(alphas).scatter_(1, labels[:, None], self.target_concentration)

        if self.reverse:
            return dirichlet_reverse_kl_divergence(alphas=alphas, target_alphas=target_alphas)
        return dirichlet_kl_divergence(alphas=alphas, target_alphas=target_alphas)


class PriorNetMixedLoss:
    """
    Weighted sum of several Dirichlet KL losses (e.g. an in-distribution and an OOD term).

    The total is divided by the largest ``target_concentration`` among the component
    losses, so the loss magnitude does not scale with the chosen target concentration.
    """

    def __init__(self, losses, mixing_params: Optional[Iterable[float]]):
        """
        :param losses: list/tuple of callables ``loss(logits, labels)``, each exposing a
            ``target_concentration`` attribute.
        :param mixing_params: per-loss weights (list, tuple or ndarray) with the same length
            as ``losses``; None gives every loss a weight of 1.0.
        :raises TypeError: if ``losses`` or ``mixing_params`` has an unsupported type.
        :raises ValueError: if ``mixing_params`` and ``losses`` differ in length.
        """
        if not isinstance(losses, (list, tuple)):
            raise TypeError(f'losses must be a list or tuple, got {type(losses).__name__}')
        self.losses = losses

        if mixing_params is None:
            self.mixing_params = [1.] * len(self.losses)
            return
        if not isinstance(mixing_params, (list, tuple, np.ndarray)):
            raise TypeError('mixing_params must be a list, tuple or numpy array, '
                            f'got {type(mixing_params).__name__}')
        if len(mixing_params) != len(losses):
            raise ValueError(f'expected {len(losses)} mixing params, got {len(mixing_params)}')
        self.mixing_params = mixing_params

    def __call__(self, logits_list, labels_list):
        return self.forward(logits_list, labels_list)

    def forward(self, logits_list, labels_list):
        """
        :param logits_list: one logits tensor per component loss (indexed in step with ``losses``).
        :param labels_list: one labels entry (or None) per component loss.
        :return: scalar tensor: weighted sum of the component losses, normalised by the
            largest positive target concentration (left unnormalised if none is positive).
        """
        weighted = [loss(logits_list[i], labels_list[i]) * self.mixing_params[i]
                    for i, loss in enumerate(self.losses)]
        total = torch.sum(torch.stack(weighted, dim=0))

        target_concentration = max([0.0] + [loss.target_concentration for loss in self.losses])
        # Normalize by target concentration, so that loss magnitude is constant wrt lr and other losses.
        # Skip the division when no loss has a positive target (it would produce inf/nan).
        if target_concentration > 0:
            return total / target_concentration
        return total


def compute_accuracy(logits, target):
    """Return the fraction (a Python float) of rows whose arg-max logit equals ``target``."""
    predicted = logits.max(dim=1).indices
    hits = (predicted == target).sum().item()
    return hits / target.size(0)


def mixing(data, index, lam):
    """Blend every row of ``data`` with the row picked by ``index``.

    Returns ``lam * data + (1 - lam) * data[index]``.
    """
    partner = data[index]
    keep_weight, partner_weight = lam, 1 - lam
    return keep_weight * data + partner_weight * partner



def prepare_mixup(batch, alpha=10.0, beta=1.0, use_cuda=True):
    """Draw a mixup weight and a random row pairing for ``batch``.

    Nothing is mixed here; callers combine the result with ``mixing``.

    :param batch: mapping with a ``'data'`` tensor whose first dimension is the batch size.
    :param alpha: first Beta-distribution parameter; ``alpha <= 0`` disables mixing (``lam = 1.0``).
    :param beta: second Beta-distribution parameter (only used when ``alpha > 0``).
    :param use_cuda: if True, the permutation is moved to the current CUDA device.
    :return: ``(index, lam)`` - a random permutation of ``range(batch_size)`` and the
        mixing weight as a Python float.
    """
    lam = float(np.random.beta(alpha, beta)) if alpha > 0 else 1.0

    index = torch.randperm(batch['data'].size(0))
    if use_cuda:
        index = index.cuda()
    return index, lam

class EDLTrainer(BaseTrainer):
    def __init__(self, net: nn.Module, train_loader: DataLoader,
                 config: Config) -> None:
        """Set up the EDL helper, optimizer and learning-rate scheduler.

        :param net: network to train.
        :param train_loader: training loader; its length sizes the cosine LR schedule.
        :param config: experiment config. Reads ``trainer.trainer_args`` (evi_fn, prior_fn,
            loss_fn, method), ``network.num_classes``, ``optimizer.*`` and ``output_dir``.
        :raises ValueError: if ``config.optimizer.name`` is not sgd / adam / lbfgs / adamw.
        """
        super(EDLTrainer, self).__init__(net, train_loader, config)

        trainer_args = config.trainer.trainer_args
        self.num_classes = config.network.num_classes
        self.trainer_args = trainer_args
        self.edl_utils = EDL(K=self.num_classes, evi_fn=trainer_args.evi_fn, prior_fn=trainer_args.prior_fn,
                             loss_fn=trainer_args.loss_fn, ngd=False)

        opt_cfg = config.optimizer
        if opt_cfg.name == 'sgd':
            self.optimizer = torch.optim.SGD(
                net.parameters(),
                opt_cfg.lr,
                momentum=opt_cfg.momentum,
                weight_decay=opt_cfg.weight_decay,
                nesterov=True,
            )
        elif opt_cfg.name == 'adam':
            self.optimizer = torch.optim.Adam(
                net.parameters(),
                opt_cfg.lr,
                weight_decay=opt_cfg.weight_decay,
            )
        elif opt_cfg.name == 'lbfgs':
            # NOTE(review): torch.optim.LBFGS.step() requires a closure that re-evaluates the
            # loss, but train_epoch calls optimizer.step() without one - TODO confirm this path.
            self.optimizer = torch.optim.LBFGS(
                net.parameters(),
                opt_cfg.lr,
                history_size=10,
                line_search_fn='strong_wolfe'
            )
        elif opt_cfg.name == 'adamw':
            self.optimizer = torch.optim.AdamW(net.parameters(),
                                               lr=opt_cfg.lr,
                                               weight_decay=opt_cfg.weight_decay)
        else:
            # Fail here rather than with an AttributeError when the scheduler is built.
            raise ValueError(f'Unsupported optimizer: {opt_cfg.name!r}')
        self.clip_grad_norm = opt_cfg.get('clip_grad_norm', False)

        accumulation_steps = opt_cfg.accumulation_steps
        if self.trainer_args.method == 'daedl':
            # DAEDL uses a plain exponential decay instead of cosine annealing.
            self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 0.95 ** epoch)
        else:
            # train_epoch steps every `accumulation_steps` batches AND on the last (possibly
            # partial) group, i.e. ceil(len / accumulation_steps) optimizer steps per epoch.
            steps_per_epoch = max((len(train_loader) + accumulation_steps - 1) // accumulation_steps, 1)
            total_steps = max(opt_cfg.num_epochs, 1) * steps_per_epoch
            min_lr_ratio = 5e-6 / opt_cfg.lr
            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer,
                lr_lambda=lambda step: cosine_annealing(step, total_steps, 1, min_lr_ratio),
            )

        self.output_dir = config.output_dir
        print(self.output_dir)
        self.writer = SummaryWriter(log_dir=self.output_dir)

    def train_epoch(self, epoch_idx):
        """Train ``self.net`` for one epoch and collect Dirichlet statistics.

        For every batch: optionally append mixup / noise-mixup rows, run the network, compute the
        loss selected by ``trainer_args.method`` and back-propagate it scaled by
        ``1 / optimizer.accumulation_steps``. The optimizer (and, unless ``optimizer.const_lr``,
        the per-batch scheduler) steps every ``accumulation_steps`` batches and on the last batch.

        :param epoch_idx: current epoch index; drives the KL annealing weight ``min(1, epoch_idx / 10)``.
        :return: ``(self.net, metrics)``; ``metrics`` has keys epoch_idx, lr, loss, alpha, vacuity,
            alpha0_mean, alpha0_std and acc.
        """
        from torch.distributions.kl import kl_divergence as kl_div

        self.net.train()
        args = self.trainer_args
        method = args.method
        opt_cfg = self.config.optimizer
        accumulation_steps = opt_cfg.accumulation_steps
        # Set from the optimizer config in __init__; a falsy value disables clipping.
        clip_grad_norm = getattr(self, 'clip_grad_norm', False)
        num_classes = int(self.config.network.num_classes)

        train_dataiter = iter(self.train_loader)
        num_batches = len(train_dataiter)

        loss_avg = 0.0
        alpha0_list = []    # per-row sum_k alpha_k for every row trained on (incl. augmented rows)
        alpha_sum = 0.0     # detached Python float: a tensor here would keep every batch's graph alive
        alpha_rows = 0      # number of rows contributing to alpha_sum
        total_correct = 0
        total_samples = 0

        pbar = tqdm(range(1, num_batches + 1),
                    desc=f'Epoch {epoch_idx:03d}: Loss=0.0000 Acc=0.00%',
                    position=0,
                    leave=True,
                    disable=not comm.is_main_process())

        for train_step in pbar:
            update = train_step % accumulation_steps == 0 or train_step == num_batches
            batch = next(train_dataiter)
            data = batch['data'].cuda()
            label = batch['label'].cuda()
            num_samples = label.shape[0]

            # convert target to one-hot
            target = F.one_hot(label, num_classes=num_classes)
            one_hot_target = target.clone()

            # VRM: append mixup (and optionally noise-mixup) rows to the batch.
            if args.get('mixup', False):
                data, target, label, one_hot_target = self._apply_mixup(batch, data, target)

            # forward
            logits = self.net(data)
            # Accuracy on the clean rows only, so weighting it by num_samples is consistent.
            batch_acc = compute_accuracy(logits[:num_samples], label[:num_samples])
            total_correct += batch_acc * num_samples
            total_samples += num_samples

            evidence = self.edl_utils.logits_to_evidence(logits)
            alpha = self.edl_utils.logits_to_alpha(logits)
            priors = self.edl_utils.get_priors().to(alpha.device)
            # Evidence on the target class(es) is removed before the KL regulariser.
            alpha_for_kl = evidence * (1 - one_hot_target) + priors
            # KL annealing weight: ramps linearly to 1 over the first 10 epochs.
            regr = np.minimum(1.0, epoch_idx / 10.)

            if method == 'redl':
                lamb2 = float(args.prior_fn)
                loss_mse = self.compute_mse(target, evidence, lamb1=1.0, lamb2=lamb2)
                loss_kl = self.compute_kl_loss(alpha_for_kl, lamb2)
                loss = loss_mse + args.kl_c * regr * loss_kl
            elif method == 'iedl':
                fisher_c = args.get('fisher_c', 0.05)
                loss_mse, loss_var, loss_fisher = self.compute_fisher_mse(target, alpha)
                loss_kl = self.compute_kl_loss(alpha_for_kl, 1.0)
                loss = loss_mse + loss_var + fisher_c * loss_fisher + args.kl_c * regr * loss_kl
            elif method == 'daedl':
                alpha = 1e-6 + torch.exp(logits)
                alpha0 = alpha.sum(1).reshape(-1, 1)
                alpha_tilde = alpha * (1 - target) + target
                expected_mse = torch.sum((target - alpha / alpha0) ** 2) + torch.sum(
                    ((alpha * (alpha0 - alpha))) / ((alpha0 ** 2) * (alpha0 + 1)))
                kl_regularizer = kl_div(Dirichlet(1e-6 + alpha_tilde),
                                        Dirichlet(torch.ones_like(alpha_tilde))).sum()
                loss = expected_mse + args.kl_c * kl_regularizer
            elif method == 'red':
                loss = self.red_loss(logits, label, epoch_idx)
            elif method == 'hedl':
                features = self.net.feature
                loss, alpha = self.hedl_loss(self.num_classes, features, self.net.fc.weight, target, epoch_idx,
                                             self.num_classes, 10, logits.device)
            elif method == 'priornet':
                concentration = args.get('pn_concentration', 1)
                reverse = args.get('pn_reverse_KL', False)
                id_criterion = DirichletKLLoss(target_concentration=args.get('pn_target_concentration', 200),
                                               concentration=concentration,
                                               reverse=reverse)
                ood_criterion = DirichletKLLoss(target_concentration=0.0,
                                                concentration=concentration,
                                                reverse=reverse)
                criterion = PriorNetMixedLoss([id_criterion, ood_criterion],
                                              mixing_params=[1.0, args.get('pn_ood_kl_weight', 1.0)])
                # Uniform-noise inputs act as the out-of-distribution batch.
                ood_outputs = self.net(torch.rand_like(data))
                # `label` has one entry per row of `logits`, also when mixup rows were appended.
                loss = criterion((logits, ood_outputs), (label, None))
            elif method == 'postnet':
                alpha = self.net(data, return_output='alpha')
                alpha_0 = alpha.sum(1).unsqueeze(-1).repeat(1, self.num_classes)
                entropy_reg = Dirichlet(alpha).entropy()
                loss = torch.mean(
                    target * (torch.digamma(alpha_0) - torch.digamma(alpha))) - args.entropy_reg * torch.mean(
                    entropy_reg)
            else:
                loss = self.edl_utils.get_edl_loss(logits, target)
                loss_kl = self.compute_kl_loss(alpha_for_kl, float(args.prior_fn))    # TODO: change to target prior
                loss = loss + regr * args.kl_c * loss_kl

            alpha_detached = alpha.detach()
            alpha0_list.append(torch.sum(alpha_detached, dim=-1).cpu().numpy().tolist())
            alpha_sum += alpha_detached.sum().item()
            alpha_rows += alpha_detached.shape[0]

            loss_value = loss.item()
            (loss / accumulation_steps).backward()
            if update:
                if clip_grad_norm:
                    torch.nn.utils.clip_grad_norm_(self.net.parameters(), float(clip_grad_norm))
                self.optimizer.step()
                self.optimizer.zero_grad()
                if not opt_cfg.const_lr and opt_cfg.get('schedule_by_batch', True):
                    self.scheduler.step()

            # Exponential moving average of the unscaled per-batch loss.
            loss_avg = loss_avg * 0.8 + loss_value * 0.2
            current_lr = self.optimizer.param_groups[0]['lr']
            accuracy = total_correct / total_samples * 100
            pbar.set_description(f'Epoch {epoch_idx:03d}: Loss={loss_avg:.4f} Acc={accuracy:.2f}% LR={current_lr:.6f}')

        if opt_cfg.get('schedule_by_epoch', False):
            self.scheduler.step()

        alpha0_all = np.concatenate(alpha0_list)
        metrics = dict()
        metrics['epoch_idx'] = epoch_idx
        metrics['lr'] = self.scheduler.get_last_lr()[0]
        metrics['loss'] = self.save_metrics(loss_avg)
        # Mean Dirichlet strength per trained row (all alphas summed / number of rows).
        metrics['alpha'] = alpha_sum / alpha_rows
        metrics['vacuity'] = self.num_classes / metrics['alpha']
        metrics['alpha0_mean'] = alpha0_all.mean()
        metrics['alpha0_std'] = alpha0_all.std()
        metrics['acc'] = total_correct / total_samples
        return self.net, metrics

    def _apply_mixup(self, batch, data, target):
        """Append mixup rows (and, if enabled, noise-mixup rows) to a batch.

        :param batch: raw batch dict; only used by ``prepare_mixup`` for the batch size.
        :param data: input tensor already on the training device.
        :param target: one-hot targets for ``data``.
        :return: ``(data, target, label, one_hot_target)`` with augmented rows appended;
            ``label`` is the arg-max of the (soft) targets and ``one_hot_target`` marks
            every class with non-zero target mass.
        """
        args = self.trainer_args
        index, lam = prepare_mixup(batch, args.mix_alpha, args.mix_beta)
        mixed_data = mixing(data, index, lam)
        mixed_target = mixing(target, index, lam)

        if args.get('mix_noise', False):
            noise_mix_alpha = args.noise_mix_alpha
            noise_mix_beta = args.noise_mix_beta
            # RNG draw order (randn_like before the second prepare_mixup) is kept as before.
            noise = torch.randn_like(data)
            uniform_label = torch.ones_like(target) / self.num_classes
            noise_index, noise_lam = prepare_mixup(batch, noise_mix_alpha, noise_mix_beta)
            noisy_data = noise_lam * data + (1 - noise_lam) * noise[noise_index]
            noisy_target = noise_lam * target + (1 - noise_lam) * uniform_label[noise_index]
            n_noise = int(data.shape[0] * args.get('noise_mix_ratio', 1.0))
            mixed_data = torch.cat((mixed_data, noisy_data[:n_noise]), dim=0)
            mixed_target = torch.cat((mixed_target, noisy_target[:n_noise]), dim=0)

        data = torch.cat((data, mixed_data), dim=0)
        target = torch.cat((target, mixed_target), dim=0)
        label = torch.argmax(target, dim=1)
        one_hot_target = (target > 0).float()
        return data, target, label, one_hot_target

    def hedl_loss(self, W, features, weight, target, epoch_num, num_classes, annealing_step, device=None, outputs = None):
        """

        :param logits:
        :param target: one-hot
        :param epoch_idx:
        :return:
        """

        def get_henn_fc(features, weight, W, num_classes, device):
            mask = (weight > 0).float()
            item_num = torch.sum(mask, axis=0)
            item_num = item_num.unsqueeze(0).expand(num_classes, -1)
            weight = mask / torch.max(torch.ones(item_num.shape).to(device), item_num)

            henn_fc = torch.nn.Linear(len(features[1]), num_classes)
            henn_fc.weight.data = weight.to(device)
            henn_fc.bias.data = torch.full((num_classes,), W / num_classes).to(device)
            features = features.to(device)

            return henn_fc(features)

        def hedl_loss(func, y, features, weight, W, num_classes, epoch_num, annealing_step, device=None, outputs=None):
            y = y.to(device)
            alpha = get_henn_fc(features, weight, W, num_classes, device)
            if outputs is not None:
                outputs.data = alpha.data.clone().detach()
                S = torch.sum(outputs, dim=1, keepdim=True).to(device)
                A = torch.sum(y * (func(S) - func(outputs)), dim=1, keepdim=True)
                return A
            S = torch.sum(alpha, dim=1, keepdim=True).to(device)

            A = torch.sum(y * (func(S) - func(alpha)), dim=1, keepdim=True)
            return A, alpha

        loss, alpha = hedl_loss(
                torch.digamma, target, features, weight, W, num_classes, epoch_num, annealing_step, device, outputs
            )
        loss = torch.mean(
            loss
        )
        return loss, alpha

    def red_loss(self, logits, label, epoch_idx):
        """
        RED loss: EDL loss + annealed KL regulariser - vacuity-weighted true-class log-evidence.

        Reference implementation:
        https://github.com/pandeydeep9/EvidentialResearch2023/blob/8683edd20682d1b96e0ff16901c0abafd26abddf/MainCifar10Experiments/utilFiles/losses.py#L122

        :param logits: raw network outputs.
        :param label: integer class labels; converted to one-hot here.
        :param epoch_idx: current epoch; the KL weight is min(1, epoch_idx / 10).
        :return: scalar loss tensor.

        Side effect: calls ``self.edl_utils.set_uncertainty_fn(UncertaintyOptions.edl_vacuity)``;
        presumably this setting persists after the call - confirm no caller relies on another one.
        """
        edl = self.edl_utils
        y = F.one_hot(label, num_classes=int(self.config.network.num_classes))
        kl_weight = np.minimum(1.0, epoch_idx / 10.)

        evidence = edl.logits_to_evidence(logits)
        loss = edl.get_edl_loss(logits, y)
        alpha = edl.logits_to_alpha(logits)
        # Evidence on the true class is removed before the KL regulariser.
        kl_alphas = evidence * (1 - y) + edl.get_priors().to(alpha.device)
        loss += self.trainer_args.kl_c * kl_weight * self.compute_kl_loss(kl_alphas, 1.0)

        # Vacuity regularisation: reward true-class log-evidence, weighted by (detached) vacuity.
        edl.set_uncertainty_fn(UncertaintyOptions.edl_vacuity)
        vacuity = edl.get_uncertainty(logits).detach()
        true_log_evidence = torch.sum(torch.log(evidence + 1e-5) * y, dim=-1, keepdim=True)
        # NOTE(review): true_log_evidence is (batch, 1); if get_uncertainty returns (batch,), the
        # product below broadcasts to (batch, batch) - confirm the helper's output shape.
        loss -= torch.sum(vacuity * true_log_evidence) / true_log_evidence.shape[0]
        return loss

    def save_metrics(self, loss_avg):
        """Gather ``loss_avg`` from every process via ``comm.gather`` and return the mean."""
        gathered = comm.gather(loss_avg)
        return np.mean(list(gathered))

    def compute_fisher_mse(self, labels_1hot_, evi_alp_):
        evi_alp0_ = torch.sum(evi_alp_, dim=-1, keepdim=True)

        gamma1_alp = torch.polygamma(1, evi_alp_)
        gamma1_alp0 = torch.polygamma(1, evi_alp0_)

        gap = labels_1hot_ - evi_alp_ / evi_alp0_

        loss_mse_ = (gap.pow(2) * gamma1_alp).sum(-1).mean()

        loss_var_ = (evi_alp_ * (evi_alp0_ - evi_alp_) * gamma1_alp / (evi_alp0_ * evi_alp0_ * (evi_alp0_ + 1))).sum(-1).mean()

        loss_det_fisher_ = - (torch.log(gamma1_alp).sum(-1) + torch.log(1.0 - (gamma1_alp0 / gamma1_alp).sum(-1))).mean()

        return loss_mse_, loss_var_, loss_det_fisher_

    def compute_mse(self, labels_1hot, evidence, lamb1, lamb2):

        num_classes = evidence.shape[-1]
        gap = labels_1hot - (evidence + lamb2) / \
              (evidence + lamb1 * (
                          torch.sum(evidence, dim=-1, keepdim=True) - evidence) + lamb2 * num_classes)

        loss_mse = gap.pow(2).sum(-1)

        return loss_mse.mean()

    def compute_kl_loss(self, alphas, target_concentration, epsilon=1e-8):
        target_alphas = torch.ones_like(alphas) * target_concentration

        alp0 = torch.sum(alphas, dim=-1, keepdim=True)
        target_alp0 = torch.sum(target_alphas, dim=-1, keepdim=True)

        alp0_term = torch.lgamma(alp0 + epsilon) - torch.lgamma(target_alp0 + epsilon)
        alp0_term = torch.where(torch.isfinite(alp0_term), alp0_term, torch.zeros_like(alp0_term))
        assert torch.all(torch.isfinite(alp0_term)).item()

        alphas_term = torch.sum(torch.lgamma(target_alphas + epsilon) - torch.lgamma(alphas + epsilon)
                                + (alphas - target_alphas) * (torch.digamma(alphas + epsilon) -
                                                              torch.digamma(alp0 + epsilon)), dim=-1, keepdim=True)
        alphas_term = torch.where(torch.isfinite(alphas_term), alphas_term, torch.zeros_like(alphas_term))
        assert torch.all(torch.isfinite(alphas_term)).item()

        loss = torch.squeeze(alp0_term + alphas_term).mean()

        return loss
