import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from architectures.linear_sequential import linear_sequential
from architectures.convolution_linear_sequential import convolution_linear_sequential
from architectures.vgg_sequential import vgg16_bn
from architectures.resnet_sequential import resnet18

def mixing(data, index, lam):
    """Blend ``data`` with a re-ordered copy of itself.

    Computes ``lam * data + (1 - lam) * data[index]`` (the standard mixup
    interpolation when ``index`` is a permutation of the batch).
    """
    partner = data[index]
    keep_weight, partner_weight = lam, 1 - lam
    return keep_weight * data + partner_weight * partner

def cosine_annealing(step, total_steps, lr_max, lr_min):
    """Half-cosine interpolation from ``lr_max`` (step 0) to ``lr_min`` (step == total_steps)."""
    progress = step / total_steps
    cosine_factor = 0.5 * (1 + np.cos(progress * np.pi))
    return lr_min + (lr_max - lr_min) * cosine_factor

def prepare_mixup(input, alpha=10.0, beta=1.0, use_cuda=True):
    """Sample a mixup permutation and mixing coefficient for a batch.

    Args:
        input: Batch tensor; only its first dimension (batch size) and its
            device are used.
        alpha: First parameter of the Beta(alpha, beta) distribution that
            ``lam`` is drawn from. If ``alpha <= 0`` no mixing is requested
            and ``lam`` is 1.
        beta: Second parameter of the Beta distribution.
        use_cuda: Unused; kept for backward compatibility. The permutation is
            always created on ``input.device``.

    Returns:
        ``(index, lam)``: a random permutation of ``range(batch_size)`` on
        ``input.device`` and the mixing coefficient (a Python float when
        sampled, the int 1 when mixing is disabled).
    """
    if alpha > 0:
        lam = float(np.random.beta(alpha, beta))
    else:
        lam = 1

    batch_size = input.size(0)
    # Create the permutation directly on the input's device so indexing needs no copy.
    index = torch.randperm(batch_size, device=input.device)

    return index, lam


class ModifiedEvidentialNet(nn.Module):
    """Evidential classifier built around a configurable encoder.

    The encoder (``self.sequential``) maps inputs to ``output_dim`` logits. A
    non-negative activation selected by ``clf_type`` turns the logits into
    evidence, and ``alpha = evidence + lamb2`` is the tensor returned by
    default. With ``compute_loss=True`` the forward pass stores the training
    objective in ``self.grad_loss`` (and its components in ``self.loss_*``);
    :meth:`step` back-propagates it through the optimizer owned by the model.
    """

    def __init__(self,
                 input_dims,  # Input dimension. list of ints
                 output_dim,  # Output dimension (number of classes). int
                 hidden_dims=None,  # Hidden dimensions. list of ints; None -> [64, 64, 64]
                 kernel_dim=None,  # Kernel dimension if conv architecture. int
                 architecture='linear',  # Encoder architecture: 'linear', 'conv', 'vgg' or 'resnet'. str
                 k_lipschitz=None,  # Lipschitz constant. float or None (if no lipschitz)
                 batch_size=64,  # Batch size, stored as self.batch_size. int
                 lr=1e-3,  # Learning rate. float
                 loss='IEDL',  # 'MEDL', 'CE', 'UCE', 'EDL', 'IEDL', 'MSE-softmax' or 'CE-softmax'. str
                 clf_type='softplus',  # Evidence activation: 'softplus', 'exp' or 'relu'. str
                 fisher_c=1.0,  # Weight of the Fisher log-determinant term (IEDL). float
                 kl_c=-1.0,  # KL weight; -1 anneals it as min(1, epoch / 10). float
                 lamb1=1.0,  # Weight of non-target evidence in the MEDL denominator. float
                 lamb2=1.0,  # Constant added to evidence to form alpha; also the KL target. float
                 mix=False,  # Enable mix augmentation
                 mix_inter=False,  # Enable inter-class mixing
                 mix_inter_alpha=1.0,  # Alpha parameter for inter-class mixing
                 mix_inter_beta=1.0,  # Beta parameter for inter-class mixing
                 mix_noise=False,  # Enable noise mixing
                 noise_mix_alpha=1.0,  # Alpha parameter for noise mixing
                 noise_mix_beta=1.0,  # Beta parameter for noise mixing
                 noise_mix_ratio=1.0,  # Fraction of the batch appended as noise-mixed samples
                 optimizer_type='adam',  # Optimizer type: 'adam' or 'sgd'
                 use_cosine_annealing=False,  # Whether to use cosine annealing scheduler
                 num_epochs=None,  # Number of epochs for cosine annealing
                 train_loader_len=None,  # Length of train loader for cosine annealing
                 cosine_lr_min_ratio=5e-6,  # Minimum lr ratio for cosine annealing (lr_min / lr_max)
                 use_sample_wise_kl_weight=False,  # Whether to use sample-wise KL loss weighting
                 kl_start_epoch=100,  # Currently unused: the epoch gate on the KL term is disabled
                 seed=123):  # Random seed for init. int
        super().__init__()

        # torch.manual_seed seeds the CPU generator and every CUDA device.
        # Module parameters are normally initialised with the CPU generator, so
        # seeding only CUDA (as before) did not make initialisation reproducible.
        torch.manual_seed(seed)
        # NOTE(review): set_default_tensor_type is deprecated in recent PyTorch
        # releases; kept so the float32/CPU default behaviour is unchanged.
        torch.set_default_tensor_type(torch.FloatTensor)

        # Avoid a shared mutable default list.
        if hidden_dims is None:
            hidden_dims = [64, 64, 64]

        # Architecture parameters
        self.input_dims = input_dims
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims
        self.kernel_dim = kernel_dim
        self.k_lipschitz = k_lipschitz
        self.num_classes = output_dim  # Used to build one-hot / uniform targets

        # Training parameters
        self.batch_size = batch_size
        self.lr = lr
        self.loss = loss
        self.clf_type = clf_type

        # Optimizer / scheduler parameters
        self.optimizer_type = optimizer_type
        self.use_cosine_annealing = use_cosine_annealing
        self.num_epochs = num_epochs
        self.train_loader_len = train_loader_len
        self.cosine_lr_min_ratio = cosine_lr_min_ratio

        # Loss hyper-parameters
        self.target_con = 1.0
        self.kl_c = kl_c
        self.fisher_c = fisher_c
        self.lamb1 = lamb1
        self.lamb2 = lamb2
        self.prior = 0

        # Sample-wise KL loss weighting
        self.use_sample_wise_kl_weight = use_sample_wise_kl_weight
        self.kl_start_epoch = kl_start_epoch

        # Mix parameters
        self.mix = mix
        self.mix_inter = mix_inter
        self.mix_inter_alpha = mix_inter_alpha
        self.mix_inter_beta = mix_inter_beta
        self.mix_noise = mix_noise
        self.noise_mix_alpha = noise_mix_alpha
        self.noise_mix_beta = noise_mix_beta
        self.noise_mix_ratio = noise_mix_ratio

        # Most recent loss components; overwritten by forward(compute_loss=True).
        self.loss_mse = torch.tensor(0.0)
        self.loss_ce = torch.tensor(0.0)
        self.loss_uce = torch.tensor(0.0)
        self.loss_var = torch.tensor(0.0)
        self.loss_kl = torch.tensor(0.0)
        self.loss_fisher = torch.tensor(0.0)

        # Encoder selection
        if architecture == 'linear':
            self.sequential = linear_sequential(input_dims=self.input_dims,
                                                hidden_dims=self.hidden_dims,
                                                output_dim=self.output_dim,
                                                k_lipschitz=self.k_lipschitz)
        elif architecture == 'conv':
            assert len(input_dims) == 3
            self.sequential = convolution_linear_sequential(input_dims=self.input_dims,
                                                            linear_hidden_dims=self.hidden_dims,
                                                            conv_hidden_dims=[64, 64, 64],
                                                            output_dim=self.output_dim,
                                                            kernel_dim=self.kernel_dim,
                                                            k_lipschitz=self.k_lipschitz)
        elif architecture == 'vgg':
            assert len(input_dims) == 3
            self.sequential = vgg16_bn(output_dim=self.output_dim, k_lipschitz=self.k_lipschitz)
        elif architecture == 'resnet':
            assert len(input_dims) == 3
            self.sequential = resnet18(output_dim=self.output_dim, k_lipschitz=self.k_lipschitz)
        else:
            raise NotImplementedError

        self.softmax = nn.Softmax(dim=-1)

        # Optimizer
        if self.optimizer_type == 'sgd':
            self.optimizer = torch.optim.SGD(self.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
        else:  # default to adam
            self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

        # Scheduler
        if self.use_cosine_annealing and self.num_epochs is not None and self.train_loader_len is not None:
            # Per-optimizer-step multiplier decaying from 1 to cosine_lr_min_ratio
            # over num_epochs * train_loader_len steps; advanced in step().
            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                self.optimizer,
                lr_lambda=lambda step: cosine_annealing(
                    step,
                    self.num_epochs * self.train_loader_len,
                    1,
                    self.cosine_lr_min_ratio,
                ),
            )
        elif architecture == 'conv':
            # NOTE(review): not stepped by this class; presumably the training
            # loop steps it once per epoch — confirm against the caller.
            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=15, gamma=0.1)
        else:
            self.scheduler = None

    def forward(self, input, labels_=None, return_output='alpha', compute_loss=False, epoch=10.):
        """Run the encoder and optionally compute the training objective.

        Args:
            input: Batch of inputs for ``self.sequential``.
            labels_: Integer class indices, one per row of ``input``; required
                when ``compute_loss`` is True.
            return_output: 'alpha' (default), 'evidence', 'soft' (softmax of
                the logits) or 'hard' (argmax class index).
            compute_loss: If True, store the objective in ``self.grad_loss``.
            epoch: Used to anneal the KL weight when ``kl_c == -1``.

        Returns:
            The tensor selected by ``return_output``. In training mode with
            ``mix`` enabled and labels given, mixed samples are appended to
            the batch, so the output has more rows than ``input``.

        Raises:
            AssertionError: If ``compute_loss`` is requested without labels or
                ``return_output`` is unknown.
            ValueError: For an unsupported ``clf_type`` or ``loss``.
        """
        assert not (labels_ is None and compute_loss)

        if labels_ is not None:
            labels_ = labels_.to(input.device)
            labels_1hot = torch.zeros(
                input.shape[0], self.num_classes, device=input.device
            ).scatter_(-1, labels_.unsqueeze(-1), 1)
        else:
            labels_1hot = None

        # Mixing builds mixed targets, so it needs labels; without them the
        # batch is left untouched (previously this crashed on None targets).
        if self.mix and self.training and labels_1hot is not None:
            input, labels_1hot = self._mix_batch(input, labels_1hot)

        logits, features = self._encode(input)
        evidence = self._evidence_from_logits(logits)
        alpha = evidence + self.lamb2

        if compute_loss:
            self._compute_objective(logits, evidence, alpha, labels_1hot, labels_, features, epoch)

        if return_output == 'hard':
            return self.predict(logits)
        if return_output == 'soft':
            return self.softmax(logits)
        if return_output == 'alpha':
            return alpha
        if return_output == 'evidence':
            return evidence
        raise AssertionError(f"Unknown return_output: {return_output}")

    def _mix_batch(self, input, labels_1hot):
        # Append inter-class and/or noise-mixed samples (and soft targets) after the originals.
        extra_inputs = []
        extra_targets = []

        if self.mix_inter:
            index, lam = prepare_mixup(input, self.mix_inter_alpha, self.mix_inter_beta)
            extra_inputs.append(mixing(input, index, lam))
            extra_targets.append(mixing(labels_1hot, index, lam))

        if self.mix_noise:
            noise = torch.randn_like(input)
            uniform_target = torch.ones_like(labels_1hot) / self.num_classes
            noise_index, lam = prepare_mixup(input, self.noise_mix_alpha, self.noise_mix_beta)
            noisy_input = lam * input + (1 - lam) * noise[noise_index]
            noisy_target = lam * labels_1hot + (1 - lam) * uniform_target[noise_index]
            n_keep = int(input.shape[0] * self.noise_mix_ratio)
            extra_inputs.append(noisy_input[:n_keep])
            extra_targets.append(noisy_target[:n_keep])

        if not extra_inputs:
            return input, labels_1hot
        return (torch.cat([input, *extra_inputs], dim=0),
                torch.cat([labels_1hot, *extra_targets], dim=0))

    def _encode(self, input):
        # Return (logits, features); features is None unless sample-wise KL weighting needs them.
        if self.use_sample_wise_kl_weight:
            try:
                logits, features = self.sequential(input, return_features=True)
            except TypeError:
                # Encoder does not accept ``return_features``.
                logits, features = self.sequential(input), None
            return logits, features
        return self.sequential(input), None

    def _evidence_from_logits(self, logits):
        # Map logits to non-negative evidence according to self.clf_type.
        if self.clf_type == 'softplus':
            return F.softplus(logits)
        if self.clf_type == 'exp':
            # Clamp before exponentiating to keep the evidence bounded.
            return torch.exp(torch.clamp(logits, -10, 10))
        if self.clf_type == 'relu':
            return F.relu(logits)
        raise ValueError(f"Unsupported clf_type: {self.clf_type}")

    def _compute_objective(self, logits, evidence, alpha, labels_1hot, labels_, features, epoch):
        # Compute the selected loss plus the KL regulariser and store it in self.grad_loss.
        if self.loss == 'MEDL':
            self.loss_mse = self.compute_mse(labels_1hot, evidence)
            grad_loss = self.loss_mse
        elif self.loss == 'CE':
            self.loss_ce = self.compute_ce(labels_1hot, evidence)
            grad_loss = self.loss_ce
        elif self.loss == 'UCE':
            self.loss_uce = self.compute_expected_ce(labels_1hot, alpha)
            grad_loss = self.loss_uce
        elif self.loss == 'EDL':
            self.loss_mse, self.loss_var = self.compute_vanilla_mse(labels_1hot, alpha)
            grad_loss = self.loss_mse + self.loss_var
        elif self.loss == 'IEDL':
            self.loss_mse, self.loss_var, self.loss_fisher = self.compute_fisher_mse(labels_1hot, alpha)
            grad_loss = self.loss_mse + self.loss_var + self.fisher_c * self.loss_fisher
        elif self.loss == 'MSE-softmax':
            prob = torch.softmax(logits, dim=1)
            grad_loss = F.mse_loss(prob, labels_1hot)
        elif self.loss == 'CE-softmax':
            labels = torch.argmax(labels_1hot, dim=1)
            grad_loss = F.cross_entropy(logits, labels)
        else:
            # Previously an unknown name silently reused a stale self.grad_loss.
            raise ValueError(f"Unsupported loss: {self.loss}")

        # KL(Dir(kl_alpha) || Dir(lamb2, ..., lamb2)) with target-class evidence removed.
        kl_alpha = evidence * (1 - labels_1hot) + self.lamb2
        if self.use_sample_wise_kl_weight and features is not None and labels_ is not None:
            # Use the dominant (soft) target so rows appended by mixing also get a
            # label; for un-mixed rows this equals labels_.
            weight_labels = torch.argmax(labels_1hot, dim=-1)
            sample_weights = self.compute_sample_wise_weights(logits, weight_labels, features)
            self.loss_kl = self.compute_kl_loss_weighted(kl_alpha, self.lamb2, sample_weights)
        else:
            self.loss_kl = self.compute_kl_loss(kl_alpha, self.lamb2)

        if self.kl_c == -1:
            # Linear warm-up of the KL weight over the first 10 epochs.
            kl_weight = np.minimum(1.0, epoch / 10.)
        else:
            kl_weight = self.kl_c
        # Out-of-place add: ``+=`` would also modify the tensor that
        # self.loss_mse / loss_ce / loss_uce alias for single-term losses.
        self.grad_loss = grad_loss + kl_weight * self.loss_kl

    def compute_ce(self, labels_1hot, evidence):
        """Cross-entropy against the expected probabilities (evidence + lamb2) / sum(alpha)."""
        num_classes = evidence.shape[-1]
        prob = (evidence + self.lamb2) / \
               (torch.sum(evidence, dim=-1, keepdim=True) + self.lamb2 * num_classes)
        ce_loss = - (labels_1hot * torch.log(prob + 1e-7)).sum(-1)
        return ce_loss.mean()

    def compute_expected_ce(self, labels_1hot, alpha):
        r"""Expected cross-entropy under Dir(alpha), averaged over the batch.

        Per sample: \sum_{j=1}^K y_{ij} (\psi(S_i) - \psi(\alpha_{ij})), with S_i = \sum_j \alpha_{ij}.
        """
        S = torch.sum(alpha, dim=1, keepdim=True)
        per_sample = torch.sum(labels_1hot * (torch.digamma(S) - torch.digamma(alpha)), dim=1, keepdim=True)
        return torch.mean(per_sample)

    def compute_mse(self, labels_1hot, evidence):
        """MEDL squared error between targets and a lamb1-weighted expected probability.

        p_j = (e_j + lamb2) / (e_j + lamb1 * (sum(e) - e_j) + lamb2 * K).
        """
        num_classes = evidence.shape[-1]

        gap = labels_1hot - (evidence + self.lamb2) / \
              (evidence + self.lamb1 * (torch.sum(evidence, dim=-1, keepdim=True) - evidence) + self.lamb2 * num_classes)

        loss_mse = gap.pow(2).sum(-1)

        return loss_mse.mean()

    def compute_vanilla_mse(self, labels_1hot, alpha):
        """EDL loss: squared error of alpha / S plus the Dirichlet variance term."""
        S = torch.sum(alpha, dim=-1, keepdim=True)
        loss_mse = (labels_1hot - alpha / S).pow(2).sum(-1).mean()
        loss_var = (alpha * (S - alpha) / (S * S * (S + 1))).sum(-1).mean()

        return loss_mse, loss_var

    def compute_fisher_mse(self, labels_1hot, alpha):
        """IEDL loss terms weighted by the trigamma function psi_1(alpha).

        Returns ``(loss_mse, loss_var, loss_det_fisher)``; the last term is the
        negative mean log-determinant of diag(psi_1(alpha)) - psi_1(S) * 1 1^T
        (expanded with the matrix-determinant lemma).
        """
        S = torch.sum(alpha, dim=-1, keepdim=True)

        gamma1_alpha = torch.polygamma(1, alpha)
        gamma1_S = torch.polygamma(1, S)

        gap = labels_1hot - alpha / S

        loss_mse = (gap.pow(2) * gamma1_alpha).sum(-1).mean()
        loss_var = (alpha * (S - alpha) * gamma1_alpha / (S * S * (S + 1))).sum(-1).mean()
        loss_det_fisher = - (torch.log(gamma1_alpha).sum(-1) + torch.log(1.0 - (gamma1_S / gamma1_alpha).sum(-1))).mean()

        return loss_mse, loss_var, loss_det_fisher

    @staticmethod
    def _kl_per_sample(alphas, target_concentration, epsilon):
        # Per-row KL(Dir(alphas) || Dir(target_concentration * 1)) with non-finite terms zeroed.
        target_alphas = torch.ones_like(alphas) * target_concentration

        alp0 = torch.sum(alphas, dim=-1, keepdim=True)
        target_alp0 = torch.sum(target_alphas, dim=-1, keepdim=True)

        alp0_term = torch.lgamma(alp0 + epsilon) - torch.lgamma(target_alp0 + epsilon)
        alp0_term = torch.where(torch.isfinite(alp0_term), alp0_term, torch.zeros_like(alp0_term))

        alphas_term = torch.sum(torch.lgamma(target_alphas + epsilon) - torch.lgamma(alphas + epsilon)
                                + (alphas - target_alphas) * (torch.digamma(alphas + epsilon) -
                                                              torch.digamma(alp0 + epsilon)), dim=-1, keepdim=True)
        alphas_term = torch.where(torch.isfinite(alphas_term), alphas_term, torch.zeros_like(alphas_term))
        # Both terms are finite by construction after torch.where, so no
        # isfinite assertion (which would force a device sync) is needed.

        return (alp0_term + alphas_term).squeeze(-1)

    def compute_kl_loss(self, alphas, target_concentration, epsilon=1e-8):
        """Batch mean of KL(Dir(alphas) || Dir(target_concentration, ..., target_concentration)).

        Non-finite per-sample terms are replaced by zero before averaging.
        """
        return self._kl_per_sample(alphas, target_concentration, epsilon).mean()

    def _last_linear_weight(self):
        # Locate the weight [num_classes, feature_dim] of the encoder's final nn.Linear.
        if hasattr(self.sequential, 'linear'):  # ResNet-style head
            return self.sequential.linear.weight
        if hasattr(self.sequential, 'classifier'):  # VGG-style head
            search_space = self.sequential.classifier.modules()
        else:
            search_space = self.sequential.modules()
        for module in reversed(list(search_space)):
            if isinstance(module, torch.nn.Linear):
                return module.weight
        raise ValueError("Can't find the weight matrix")

    def compute_sample_wise_weights(self, logits, labels, features):
        """
        Compute sample-wise weights for the KL loss based on a last-layer margin.

        For sample i with label y and feature h, the margin is
        ``w_y^T h - w_j^T h`` where j is the highest-logit wrong class; the
        weight is ``sigmoid(-margin)``, so small or negative margins get
        weights closer to 1.

        Args:
            logits: Model logits [batch_size, num_classes]
            labels: True labels [batch_size] (integer class indices)
            features: Last-layer features [batch_size, feature_dim]

        Returns:
            weights: Sample-wise weights [batch_size], not tracked by autograd.

        Raises:
            ValueError: If the final linear layer cannot be found or the batch
                sizes of ``logits`` and ``labels`` differ.
        """
        if labels.shape[0] != logits.shape[0]:
            raise ValueError(
                f"labels has {labels.shape[0]} rows but logits has {logits.shape[0]}")
        W = self._last_linear_weight()

        with torch.no_grad():
            target = labels.to(logits.device).long().reshape(-1, 1)
            # w_c^T h for every sample and class: [batch_size, num_classes]
            projections = features @ W.t()
            true_proj = projections.gather(1, target).squeeze(1)
            # Most competitive wrong class: mask the true class with -inf, then argmax.
            wrong_logits = logits.scatter(1, target, float('-inf'))
            wrong_idx = wrong_logits.argmax(dim=1, keepdim=True)
            wrong_proj = projections.gather(1, wrong_idx).squeeze(1)
            margins = true_proj - wrong_proj
            weights = torch.sigmoid(-margins)
        return weights

    def compute_kl_loss_weighted(self, alphas, target_concentration, sample_weights, epsilon=1e-8):
        """
        Compute the KL loss as a weighted mean over samples.

        Args:
            alphas: Alpha parameters [batch_size, num_classes]
            target_concentration: Target concentration parameter
            sample_weights: Sample-wise weights [batch_size]
            epsilon: Small value for numerical stability

        Returns:
            sum_i(w_i * KL_i) / sum_i(w_i), with non-finite KL terms zeroed.
        """
        kl_per_sample = self._kl_per_sample(alphas, target_concentration, epsilon)
        return (kl_per_sample * sample_weights).sum() / sample_weights.sum()

    def step(self):
        """Back-propagate ``self.grad_loss`` and take one optimizer step.

        The cosine-annealing LambdaLR schedule is defined per optimizer step,
        so it is advanced here. Any other scheduler (e.g. the StepLR built for
        the 'conv' architecture) is not stepped by this method, and a missing
        scheduler no longer raises.
        """
        self.optimizer.zero_grad()
        self.grad_loss.backward()
        self.optimizer.step()
        scheduler = getattr(self, 'scheduler', None)
        if self.use_cosine_annealing and isinstance(scheduler, torch.optim.lr_scheduler.LambdaLR):
            scheduler.step()

    def predict(self, p):
        """Return the index of the largest entry along the last dimension."""
        output_pred = torch.max(p, dim=-1)[1]
        return output_pred