import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import wandb
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.manifold import TSNE
from itertools import cycle
from torch.autograd import Variable
from optim.pytorchtools import EarlyStopping


# Module-level plotting defaults: switch to the 'ggplot' style sheet and set
# explicit font sizes for text, axes titles/labels, legends and figure titles.
# These run at import time and update matplotlib's global rcParams.
plt.style.use('ggplot')
matplotlib.rcParams.update({
    'font.size': 12,
    'axes.titlesize': 14,
    'axes.labelsize': 12,
    'legend.fontsize': 10,
    'figure.titlesize': 16
})

def kl_divergence(logits_p, logits_q):
    """Return KL(p || q) for two categorical distributions given as logits.

    Both arguments are interpreted as unnormalised log-probabilities over the
    last dimension. The divergence is computed per distribution and an extra
    axis is inserted at position 1 of the result.
    """
    dist_p = torch.distributions.Categorical(logits=logits_p)
    dist_q = torch.distributions.Categorical(logits=logits_q)
    return torch.distributions.kl_divergence(dist_p, dist_q).unsqueeze(1)

def exp_rampup(rampup_length):
    """Build an exponential ramp-up schedule (https://arxiv.org/abs/1610.02242).

    The returned callable maps an epoch index to a weight: it evaluates
    ``exp(-5 * (1 - epoch / rampup_length) ** 2)`` while
    ``epoch < rampup_length`` (with the epoch clipped to be non-negative) and
    returns ``1.0`` otherwise.
    """
    def ramp(epoch):
        # Anything that is not strictly inside the ramp-up window is at full weight.
        if not epoch < rampup_length:
            return 1.0
        clipped = np.clip(epoch, 0.0, rampup_length)
        remaining = 1.0 - clipped / rampup_length
        return float(np.exp(-5.0 * remaining * remaining))

    return ramp

def mse_with_softmax(logit1, logit2):
    """Return the batch-mean L2 distance between two softmax distributions.

    Each logit tensor is soft-maxed over dimension 1; the Euclidean norm of the
    difference is taken over that same dimension and the resulting per-sample
    distances are averaged into a scalar tensor. Both inputs must have the
    same size.
    """
    assert logit1.size() == logit2.size(), "logit1 / logit2 size mismatch"
    prob_gap = F.softmax(logit1, dim=1) - F.softmax(logit2, dim=1)
    return torch.norm(prob_gap, p=2, dim=1).mean()

def perturbation(X, method, std=0.01, mean=0.):
    """Apply a simple random perturbation to ``X``.

    Only ``method == 'noise'`` is implemented: Gaussian noise with the given
    ``std`` and ``mean`` is added element-wise. For any other method the input
    is returned unchanged.
    """
    if method != 'noise':
        return X
    gaussian = torch.randn_like(X) * std + mean
    return X + gaussian

def build_M_inv(T, alpha=1.0, device='cpu', dtype=torch.float32, add_diagonal_flag=True):
    """Return the inverse of the (optionally regularised) tridiagonal difference matrix.

    A ``T x T`` matrix ``S`` is built with ``2`` on the main diagonal and
    ``-1`` on the first sub- and super-diagonals. When ``add_diagonal_flag`` is
    true, ``alpha * I`` is added before inversion; otherwise ``alpha`` is
    ignored and ``S`` itself is inverted.

    The result is used by the Sobolev-metric perturbation to rescale
    adversarial gradients.
    """
    S = 2.0 * torch.eye(T, dtype=dtype, device=device)

    # Fill both off-diagonals in one shot: positions (k, k+1) and (k+1, k).
    k = torch.arange(max(T - 1, 0), device=device)
    S[k, k + 1] = -1.0
    S[k + 1, k] = -1.0

    if add_diagonal_flag:
        S = S + alpha * torch.eye(T, dtype=dtype, device=device)

    return torch.linalg.inv(S)

def clamp_and_bound_perturbation(r_adv, eps_min, eps_max):
    """Sanitise an adversarial perturbation and replace out-of-range samples.

    Steps:
      1) NaN / +Inf / -Inf entries are set to 0.
      2) For every sample (first dimension) whose L2 norm is below ``eps_min``
         or above ``eps_max``, the whole sample is replaced by uniform noise
         drawn from ``[-0.5, 0.5)``.

    The input tensor is NOT modified; a new tensor is returned. Any rank >= 1
    is accepted (the first dimension is treated as the batch).

    NOTE(review): the replacement noise is not rescaled, so its own norm is not
    guaranteed to lie in ``[eps_min, eps_max]``.

    Args:
        r_adv: Perturbation tensor, batch-first.
        eps_min: Lower bound on the per-sample L2 norm.
        eps_max: Upper bound on the per-sample L2 norm.

    Returns:
        Tensor with the same shape and dtype as ``r_adv``.
    """
    # Out-of-place scrub so the caller's tensor is left untouched.
    r_adv = torch.nan_to_num(r_adv, nan=0.0, posinf=0.0, neginf=0.0)

    batch_size = r_adv.size(0)
    # flatten() copes with non-contiguous inputs and with an empty batch.
    flat = r_adv.flatten(start_dim=1) if r_adv.dim() > 1 else r_adv.reshape(-1, 1)
    norms = torch.norm(flat, p=2, dim=1)

    out_of_range = (norms < eps_min) | (norms > eps_max)
    # One flag per sample, broadcast over however many trailing dims exist.
    mask = out_of_range.view((batch_size,) + (1,) * (r_adv.dim() - 1))

    rand_perturb = torch.rand_like(r_adv) - 0.5
    return torch.where(mask, rand_perturb, r_adv)


class Model_SemiMean(nn.Module):
    """Semi-supervised trainer combining an EMA model with adversarial perturbations.

    The wrapper holds a trainable ``model`` and an ``ema_model`` whose
    parameters track an exponential moving average of ``model``. During
    training, an adversarial perturbation is generated for unlabeled inputs
    and a consistency loss is computed between the EMA model's output on the
    perturbed input and the trainable model's output on the clean input.

    Two perturbation generators are available, selected by ``opt.use_flag``:
    ``"sobolev"`` (:meth:`virtual_adversarial_new_metric`) when true and
    ``"new1"`` (:meth:`virtual_adversarial_new1`) otherwise.

    NOTE(review): this class subclasses ``nn.Module`` and defines ``train``
    with a training-loop signature, which shadows ``nn.Module.train(mode)``.
    ``nn.Module.eval()`` is implemented as ``self.train(False)``, so calling
    ``.eval()`` on this wrapper will fail; switch modes on ``self.model`` /
    ``self.ema_model`` directly instead.
    """

    def __init__(self, model, ema_model, opt, ip=1, xi=1.0, eps_min=0.1, eps_max=5.0, factor=3):
        """Store the networks and hyper-parameters and make sure a wandb run exists.

        Args:
            model: Trainable network.
            ema_model: Network updated as an exponential moving average of ``model``.
            opt: Options object; reads ``usp_weight``, ``weight_rampup``,
                ``use_flag`` and (optionally) ``diag_alpha``. ``vars(opt)`` is
                passed to wandb as the run config.
            ip: Number of gradient iterations used to build the perturbation.
            xi: Scale applied to the probe direction in the ``new1`` method.
            eps_min: Lower L2-norm bound passed to ``clamp_and_bound_perturbation``.
            eps_max: Upper L2-norm bound passed to ``clamp_and_bound_perturbation``.
            factor: Stored on the instance; not used by any method of this class.
        """
        super().__init__()
        self.model = model
        self.ema_model = ema_model
        self.ce_loss = nn.CrossEntropyLoss()
        self.global_step = 0

        # Unsupervised weight and its ramp-up schedule.
        # NOTE(review): both are stored here but never applied in train().
        self.usp_weight = opt.usp_weight
        self.rampup = exp_rampup(opt.weight_rampup)

        # Adversarial perturbation-related
        self.ip = ip
        self.xi = xi
        self.eps_min = eps_min
        self.eps_max = eps_max
        self.factor = factor

        # If use_flag=True => "sobolev"; otherwise "new1"
        self.use_flag = opt.use_flag
        self.current_method = "sobolev" if self.use_flag else "new1"

        # Lazily built inverse difference matrix for the Sobolev metric.
        self.M_inv = None
        self.M_inv_size = None
        self.alpha = getattr(opt, "diag_alpha", 0.01)

        # Start a wandb run only if none is active yet.
        if not wandb.run:
            wandb.init(project="ProjectName", name="RunName", config=vars(opt))

    def forward(self, x):
        """Delegate to the trainable network so the wrapper can be called like a model."""
        return self.model(x)

    def update_ema(self, model, ema_model, alpha, global_step):
        """Update ``ema_model`` parameters in place towards ``model``.

        The effective decay is ``min(1 - 1 / (global_step + 1), alpha)``, so
        early steps follow ``model`` closely and later steps use ``alpha``.
        Only parameters are averaged; buffers are left untouched.
        """
        alpha = min(1 - 1 / (global_step + 1), alpha)
        # no_grad lets us update leaf parameters in place without autograd tracking.
        with torch.no_grad():
            for ema_param, param in zip(ema_model.parameters(), model.parameters()):
                ema_param.mul_(alpha).add_(param, alpha=1 - alpha)

    @staticmethod
    def _mean(values):
        """Return the float32 mean of a list of Python floats, or 0 when the list is empty."""
        if not values:
            return 0
        return float(torch.tensor(values).mean().item())

    def _model_device(self):
        """Return the device of the trainable model's first parameter.

        Falls back to CUDA when available (CPU otherwise) for a model that has
        no parameters.
        """
        first_param = next(self.model.parameters(), None)
        if first_param is not None:
            return first_param.device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _get_M_inv(self, x):
        """Return the cached inverse matrix for ``x``, rebuilding it when stale.

        The cache is invalidated when the length (``x.size(1)``), device or
        dtype of ``x`` differs from the cached matrix.
        """
        length = x.size(1)
        stale = (
            self.M_inv is None
            or self.M_inv_size != length
            or self.M_inv.device != x.device
            or self.M_inv.dtype != x.dtype
        )
        if stale:
            self.M_inv = build_M_inv(length, alpha=self.alpha, device=x.device, dtype=x.dtype)
            self.M_inv_size = length
        return self.M_inv

    def _input_gradient(self, x, d, scale):
        """Return the gradient of the consistency loss w.r.t. the probe ``d``.

        The loss compares the model output on ``x`` with the output on
        ``x + scale * d``. Only the gradient with respect to ``d`` is
        computed, so the model's parameter ``.grad`` fields are not touched.
        """
        d = d.detach().requires_grad_()
        # The clean prediction is a constant for this gradient, so skip its graph.
        with torch.no_grad():
            y = self.model(x)
        y_hat = self.model(x + scale * d)
        lds_loss = mse_with_softmax(y, y_hat)
        (grad,) = torch.autograd.grad(lds_loss, d)
        return grad.detach()

    def virtual_adversarial_new1(self, x):
        """Build a perturbation from L2-normalised gradient iterations.

        Starts from a random, per-sample L2-normalised direction, runs
        ``self.ip`` iterations of "take the loss gradient at ``x + xi * d``
        and normalise it", then passes the result through
        ``clamp_and_bound_perturbation``.

        Returns:
            The perturbation ``r_adv`` (same shape as ``x``), not ``x + r_adv``.
        """
        batch_size = x.size(0)

        # Initial random direction with unit L2 norm per sample.
        d = torch.rand_like(x) - 0.5
        d = F.normalize(d.reshape(batch_size, -1), p=2, dim=1).view_as(d)

        for _ in range(self.ip):
            d = self._input_gradient(x, d, self.xi)
            d = F.normalize(d.reshape(batch_size, -1), p=2, dim=1).view_as(d)

        # No fixed scaling factor is applied; the norm bounds are enforced here.
        return clamp_and_bound_perturbation(d, self.eps_min, self.eps_max)

    def virtual_adversarial_new_metric(self, x):
        """Build a perturbation scaled by the Sobolev-metric quadratic form.

        Runs ``self.ip`` gradient iterations (un-normalised, probing at
        ``x + d``), then computes ``r_adv = g / <g, M_inv g>`` per sample and
        passes it through ``clamp_and_bound_perturbation``.

        The einsum patterns require ``x`` to be 3-D with the matrix applied
        along dimension 1.

        Returns:
            The perturbation ``r_adv`` (same shape as ``x``), not ``x + r_adv``.
        """
        M_inv = self._get_M_inv(x)

        # Initial perturbation
        d = torch.rand_like(x) - 0.5
        for _ in range(self.ip):
            d = self._input_gradient(x, d, 1.0)

        g = d
        M_inv_g = torch.einsum("ij,bjd->bid", M_inv, g)
        # Clamp guards against division by zero when the gradient vanishes.
        denom = torch.einsum("bjd,bjd->b", g, M_inv_g).clamp(min=1e-12).view(-1, 1, 1)
        r_adv = g / denom

        return clamp_and_bound_perturbation(r_adv, self.eps_min, self.eps_max)

    def get_adversarial_perturbations(self, x):
        """Return the perturbed input ``x + r_adv`` using the method chosen by ``self.use_flag``."""
        if self.use_flag:
            return x + self.virtual_adversarial_new_metric(x)
        else:
            return x + self.virtual_adversarial_new1(x)

    def _evaluate(self, loader, device):
        """Return the mean per-batch accuracy (in percent) of ``self.model`` on ``loader``.

        Runs under ``torch.no_grad()``. The caller is responsible for putting
        the model in eval mode.
        """
        batch_accs = []
        with torch.no_grad():
            for inputs, labels in loader:
                inputs, labels = inputs.to(device), labels.to(device)
                preds = self.model(inputs).argmax(dim=-1)
                hits = preds.eq(labels).sum()
                batch_accs.append((100.0 * hits / len(labels)).item())
        return self._mean(batch_accs)

    def train(self, tot_epochs, train_loader, train_loader_label, val_loader, test_loader, opt):
        """Run the f-VAT training loop.

        Each step combines a supervised cross-entropy loss on a labeled batch
        with a consistency loss on an unlabeled batch, steps Adam, and then
        updates the EMA model. After every epoch the model is evaluated on the
        validation set; whenever validation accuracy improves, the test set is
        evaluated too. Batches are moved to the device of the model's
        parameters.

        Args:
            tot_epochs (int): Number of epochs to train.
            train_loader: Unlabeled data iterator yielding ``(aug1, aug2, target)``;
                only ``aug1`` and ``target`` are used (the target only for
                monitoring accuracy).
            train_loader_label: Labeled data iterator yielding ``(x, target)``.
            val_loader: Validation data iterator.
            test_loader: Test data iterator.
            opt: Parsed command-line options; reads ``patience``, ``model_name``,
                ``ckpt_dir``, ``learning_rate`` and ``ema_decay``.

        Returns:
            Tuple[float, float, int]: best test accuracy, best validation accuracy and epoch index.
        """
        patience = opt.patience
        ckpt_name = f'backbone_best_{opt.model_name}_{self.use_flag}.tar'
        ckpt_path = os.path.join(opt.ckpt_dir, ckpt_name)

        # NOTE(review): EarlyStopping is project-local; it is fed validation
        # accuracy below, so it is presumably "higher is better" — confirm.
        early_stopping = EarlyStopping(
            patience,
            verbose=True,
            checkpoint_pth=ckpt_path
        )
        optimizer = torch.optim.Adam(self.model.parameters(), lr=opt.learning_rate)
        device = self._model_device()

        best_val_acc, best_test_acc = 0, 0
        best_epoch = 0

        for epoch in range(tot_epochs):
            self.model.train()
            self.ema_model.train()

            acc_label, acc_unlabel = [], []

            # cycle() repeats the labeled loader so every unlabeled batch gets a partner.
            # NOTE(review): itertools.cycle replays cached batches after the first
            # pass, so labeled batches are not reshuffled within an epoch.
            for data_labeled, data_unlabel in zip(cycle(train_loader_label), train_loader):
                self.global_step += 1
                x, targets = data_labeled
                aug1, _, targetAug = data_unlabel

                x, targets = x.to(device), targets.to(device)
                aug1, targetAug = aug1.to(device), targetAug.to(device)

                # ========== 1) Supervised part ==========
                out = self.model(x)
                sup_loss = self.ce_loss(out, targets)
                correct = out.argmax(dim=-1).eq(targets).sum()
                acc_label.append((100.0 * correct / len(targets)).item())

                # ========== 2) Unsupervised part (VAT) ==========
                if self.current_method == "sobolev":
                    r_adv = self.virtual_adversarial_new_metric(aug1)
                else:
                    r_adv = self.virtual_adversarial_new1(aug1)

                aug1_hat = aug1 + r_adv
                with torch.no_grad():
                    out_ema = self.ema_model(aug1_hat)
                out_aug = self.model(aug1)
                unsup_loss = mse_with_softmax(out_ema, out_aug)

                correct_u = out_aug.argmax(dim=-1).eq(targetAug).sum()
                acc_unlabel.append((100.0 * correct_u / len(targetAug)).item())

                # Total loss
                # NOTE(review): self.usp_weight / self.rampup are not applied here;
                # the unsupervised term always has weight 1. Confirm this is intended.
                total_loss = sup_loss + unsup_loss
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

                # ========== 3) EMA update ==========
                self.update_ema(self.model, self.ema_model, opt.ema_decay, self.global_step)

                # ========== Log to wandb ==========
                wandb.log({
                    'sup_loss': sup_loss.item(),
                    'unsup_loss': unsup_loss.item(),
                    'total_loss': total_loss.item()
                }, step=self.global_step)

            # End of epoch: average the per-batch training accuracies.
            avg_acc_label = self._mean(acc_label)
            avg_acc_unlabel = self._mean(acc_unlabel)

            # ========== Validation set accuracy ==========
            self.model.eval()
            mean_val_acc = self._evaluate(val_loader, device)

            # If the validation set improves, evaluate on the test set
            if mean_val_acc > best_val_acc:
                best_val_acc = mean_val_acc
                best_epoch = epoch
                best_test_acc = self._evaluate(test_loader, device)

            # Early stopping
            early_stopping(mean_val_acc, self.model)
            if early_stopping.early_stop:
                print("Early stopping triggered.")
                break

            # Print info for this epoch
            print(f"Epoch[{epoch+1}/{tot_epochs}] "
                  f"LabelAcc={avg_acc_label:.2f}, UnLabelAcc={avg_acc_unlabel:.2f}, "
                  f"ValAcc={mean_val_acc:.2f}, BestValAcc={best_val_acc:.2f}, "
                  f"BestTestAcc={best_test_acc:.2f}")

            # Log to wandb
            wandb.log({
                'epoch_label_acc': avg_acc_label,
                'epoch_unlabel_acc': avg_acc_unlabel,
                'val_acc': mean_val_acc,
                'best_test_acc': best_test_acc
            }, step=self.global_step)

        return best_test_acc, best_val_acc, best_epoch
