import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Beta, MultivariateNormal
import numpy as np
import pandas as pd
import os
from utils import compute_metric
from models.evi_clm import Evi_CLM
import math
from metrics.cas import concept_alignment_score
import torch.nn.init as init
def initialize_weights(module):
    """Initialise an ``nn.Linear`` layer in place; ignore every other module.

    Weights are drawn with Kaiming-normal (fan-in, ReLU gain) and the bias,
    when the layer has one, is set to zero. Designed to be handed to
    ``nn.Module.apply``.
    """
    if not isinstance(module, nn.Linear):
        return
    init.kaiming_normal_(module.weight, a=0, mode='fan_in', nonlinearity='relu')
    bias = module.bias
    if bias is not None:
        init.constant_(bias, 0)

class EC_CEM(Evi_CLM):
    def __init__(
        self,
        n_concepts,
        n_tasks,
        emb_size=16,
        embedding_activation="leakyrelu",
        c_extractor_arch="resnet34",
        optimizer="adam",
        learning_rate=0.01,
        weight_decay=4e-05,
        momentum=0.9,
        train_with_c_gt=False,
        concept_weight=None,
        interven_prob=0.25,
        concept_loss_weight=1,
        eps=1e-7,
        n_points=128,      
        icdf_iter=20,      
        test_interven=False,
        num_mc_samples=15  
    ):
        """Build the EC-CEM specific heads on top of the ``Evi_CLM`` base.

        Args:
            n_concepts: Number of concepts. Forwarded to the base class and
                used to size ``covariance_a`` and ``c2y_model``.
            n_tasks: Number of outputs of the label predictor ``c2y_model``.
            emb_size: Per-concept embedding width. Forwarded to the base
                class; ``c2y_model`` consumes ``n_concepts * emb_size``
                features.
            embedding_activation: Forwarded positionally to the base class.
            c_extractor_arch: Forwarded positionally to the base class.
            optimizer: Forwarded positionally to the base class.
            learning_rate: Forwarded positionally to the base class.
            weight_decay: Forwarded positionally to the base class.
            momentum: Forwarded positionally to the base class.
            train_with_c_gt: Forwarded positionally to the base class.
            concept_weight: Forwarded positionally to the base class.
            interven_prob: Bernoulli probability used by
                ``_after_interventions`` to pick which entries are replaced.
            concept_loss_weight: Stored on the instance; not read by the
                methods visible in this file.
            eps: Small constant used for clamping in the Beta pdf/cdf/icdf
                helpers and in ``forward``.
            n_points: Number of Gauss-Legendre nodes used by ``beta_cdf``.
            icdf_iter: Maximum number of Newton iterations in ``beta_icdf``.
            test_interven: Stored on the instance; not read by the methods
                visible in this file.
            num_mc_samples: Number of Monte-Carlo draws taken in ``forward``.
        """
        super().__init__(
            n_concepts,
            emb_size,
            embedding_activation,
            c_extractor_arch,
            optimizer,
            learning_rate,
            weight_decay,
            momentum,
            train_with_c_gt,
            concept_weight,
        )
        # Presumably the Lightning hook that records the constructor
        # arguments for checkpointing -- confirm against Evi_CLM's base class.
        self.save_hyperparameters()
        self.n_concepts = n_concepts
        self.n_tasks = n_tasks
        self.emb_size = emb_size
        self.interven_prob = interven_prob
        self.concept_loss_weight = concept_loss_weight
        self.eps = eps
        self.n_points = n_points
        self.icdf_iter = icdf_iter
        self.test_interven = test_interven
        self.num_mc_samples = num_mc_samples  

        # Stored but not referenced by the code visible in this file.
        self.hidden_size = 128
        # Heads read by generate_covariance(): ``covariance_w`` yields one
        # scalar per sample, ``covariance_a`` one value per concept.
        # NOTE(review): the hard-coded 512 assumes pre_concept_model emits
        # 512 features (forward() feeds its output here) -- confirm.
        self.covariance_w = nn.Sequential(
            nn.Linear(512, 1),
            # nn.BatchNorm1d(1),
            nn.Dropout(0.2),
        )
        self.covariance_a = nn.Sequential(
            nn.Linear(512, self.n_concepts),
            nn.BatchNorm1d(self.n_concepts),
            nn.Dropout(0.2),
        )
        # Only the nn.Linear sub-layers are affected; initialize_weights
        # ignores every other module type.
        self.covariance_w.apply(initialize_weights)
        self.covariance_a.apply(initialize_weights)
        
        
        # Linear label predictor over the flattened concept embeddings.
        self.c2y_model = nn.Sequential(
            nn.Linear(n_concepts * emb_size, out_features=n_tasks)
        )
        self.loss_task = nn.CrossEntropyLoss()

        # Quadrature rule on [0, 1], registered as buffers so it follows the
        # module across devices; consumed by beta_cdf().
        points, weights = self._get_gauss_legendre_points_weights(n_points)
        self.register_buffer('gl_points', torch.tensor(points, dtype=torch.float32))
        self.register_buffer('gl_weights', torch.tensor(weights, dtype=torch.float32))
        print("this is EC-CEM model")

    def _get_gauss_legendre_points_weights(self, n):
        points, weights = np.polynomial.legendre.leggauss(n)
        # 映射到[0, 1]
        points = 0.5 * (points + 1)
        weights = 0.5 * weights
        return points, weights

    def beta_pdf(self, x, alpha, beta):
        x = torch.clamp(x, self.eps, 1 - self.eps)
        log_beta = torch.lgamma(alpha) + torch.lgamma(beta) - torch.lgamma(alpha + beta)
        log_pdf = (alpha - 1) * torch.log(x) + \
                  (beta - 1) * torch.log(1 - x) - log_beta
        pdf = torch.exp(log_pdf)
        return pdf

    def beta_cdf(self, x, alpha, beta):
            
        points = self.gl_points.to(x.device) if hasattr(self, 'gl_points') else torch.linspace(0, 1, steps=100).to(x.device)
        weights = self.gl_weights.to(x.device) if hasattr(self, 'gl_weights') else torch.ones_like(points) * (1/points.numel())
        
        x_exp = x.unsqueeze(-1)
        alpha_exp = alpha.unsqueeze(-1)
        beta_exp = beta.unsqueeze(-1)
        
        t = points  
        x_mapped = t * x_exp  
        weights_scaled = weights * x_exp  
        
        # 计算 PDF
        log_beta_val = torch.lgamma(alpha_exp) + torch.lgamma(beta_exp) - torch.lgamma(alpha_exp + beta_exp)
        log_pdf = (alpha_exp - 1) * torch.log(x_mapped + self.eps) + \
                (beta_exp - 1) * torch.log(1 - x_mapped + self.eps) - log_beta_val
        pdf = torch.exp(log_pdf)
        
        integral = torch.sum(pdf * weights_scaled, dim=-1)
        
        return integral

    def beta_icdf(self, u, alpha, beta):
        u = torch.clamp(u, self.eps, 1 - self.eps)
        alpha = torch.clamp(alpha, min=self.eps)
        beta = torch.clamp(beta, min=self.eps)

        x = alpha / (alpha + beta)

        tol = 1e-6
        for _ in range(self.icdf_iter):
            cdf = self.beta_cdf(x, alpha, beta)
            pdf = self.beta_pdf(x, alpha, beta)

            delta = (cdf - u) / torch.clamp(pdf, min=self.eps)
            delta = torch.clamp(delta, -0.1, 0.1)  

            x_new = x - delta
            x_new = torch.clamp(x_new, self.eps, 1 - self.eps)

            if torch.max(torch.abs(x_new - x)) < tol:
                x = x_new
                break

            x = x_new

        return x

    def forward(self, x, c, train=False, test=False):
        """Predict Beta concept marginals, couple them through a Gaussian copula, and predict the task.

        Steps performed by the code below:
          1. ``pre_concept_model(x)`` yields the shared features ``pre_c``;
             ``generate_covariance`` turns them into a Cholesky factor ``L``
             and a correlation matrix.
          2. Each entry of ``concept_context_generators`` produces a context
             vector; ``alpha_gen`` / ``beta_gen`` map it to Beta parameters
             via ``relu(.) + 1`` (so both are >= 1).
          3. ``num_mc_samples`` times: draw from the Beta marginals, map the
             draw to a standard normal through ``beta_cdf`` and the normal
             inverse CDF, multiply by ``L``, and map back through the normal
             CDF and ``beta_icdf``.
          4. Mix the first / second half of each context vector with the
             resulting concept probability and feed the flattened result to
             ``c2y_model``.

        Args:
            x: Model input passed to ``pre_concept_model``.
            c: Accepted for interface compatibility; not read in this method.
            train: When true, logits are produced for every Monte-Carlo
                sample.
            test: When true (and ``train`` is false), the Beta mean is used
                in place of the sampled probabilities.

        Returns:
            A 7-tuple ``((alpha, beta), y, c_probs_samples, c_prob,
            c_pred_flat, z, Sigma)``. In train mode ``y`` has one set of
            logits per Monte-Carlo sample, ``c_prob`` is the stack of
            correlated samples and ``z`` is the stack of correlated normals.
            In the other two modes ``c_prob`` is a single probability per
            concept and ``z`` is a single ``z_correlated`` tensor (in the
            validation branch that is the one left over from the last
            Monte-Carlo iteration).
        """
        pre_c = self.pre_concept_model(x)  
        L, Sigma = self.generate_covariance(pre_c)  
        contexts = []
        alpha_list = []
        beta_list = []
        for context_gen in self.concept_context_generators:
            context = context_gen(pre_c)
            contexts.append(context)
            # relu(.) + 1 keeps both Beta parameters at or above 1.
            # Note: these locals shadow the attribute names only inside
            # this loop body; self.alpha_gen / self.beta_gen are untouched.
            alpha_gen = F.relu(self.alpha_gen(context)) + 1
            beta_gen = F.relu(self.beta_gen(context)) + 1
            alpha_list.append(alpha_gen)
            beta_list.append(beta_gen)

        # Presumably [B, n_concepts] after the squeeze, i.e. alpha_gen /
        # beta_gen emit one value per context -- TODO confirm.
        alpha = torch.stack(alpha_list, dim=1).squeeze(-1)  
        beta = torch.stack(beta_list, dim=1).squeeze(-1)   

        dist = torch.distributions.Beta(alpha, beta)
        c_probs_samples = []
        c_probs_correlated_samples = []
        z_samples = []

        # NOTE: this sampling loop runs in every mode, including the test
        # branch, where only c_probs_samples from it is returned.
        for _ in range(self.num_mc_samples):
            # Reparameterised draw from the Beta marginals.
            c_probs = dist.rsample()  
            c_probs_clamped = c_probs.clamp(self.eps, 1 - self.eps)
            
            # Copula transform: Beta sample -> uniform -> standard normal.
            u = self.beta_cdf(c_probs_clamped, alpha, beta)  
            u = torch.clamp(u, self.eps, 1 - self.eps)
            z = torch.distributions.Normal(0, 1).icdf(u)  
            
            # Apply the correlation structure via the Cholesky factor.
            z_correlated = torch.bmm(L, z.unsqueeze(-1)).squeeze(-1)  
            
            # Map back to probability space.
            u_correlated = torch.distributions.Normal(0, 1).cdf(z_correlated)
            u_correlated = torch.clamp(u_correlated, self.eps, 1 - self.eps)
            c_probs_correlated = self.beta_icdf(u_correlated, alpha, beta)
            
            c_probs_samples.append(c_probs)
            c_probs_correlated_samples.append(c_probs_correlated)
            z_samples.append(z_correlated)
        
        # Stack all Monte-Carlo samples along a new leading axis.
        c_probs_samples = torch.stack(c_probs_samples, dim=0)  # [num_mc_samples, B, n_concepts]
        c_probs_correlated_samples = torch.stack(c_probs_correlated_samples, dim=0)
        z_samples = torch.stack(z_samples, dim=0)

        # Build the concept representations: the first emb_size entries of a
        # context are weighted by p, the remaining entries by (1 - p).
        contexts = torch.stack(contexts, dim=1)  # [B, n_concepts, 2*emb_size]
        contexts_pos = contexts[:, :, :self.emb_size]   
        contexts_neg = contexts[:, :, self.emb_size:]   

        if train:
            # Training: keep every Monte-Carlo sample and predict per sample.
            c_probs_mix_unsq = c_probs_correlated_samples.unsqueeze(-1)  # [num_mc_samples, B, n_concepts, 1]
            c_pred_samples = (contexts_pos.unsqueeze(0) * c_probs_mix_unsq + 
                            contexts_neg.unsqueeze(0) * (1 - c_probs_mix_unsq))  # [num_mc_samples, B, n_concepts, emb_size]
            c_pred_flat = c_pred_samples.view(-1, self.n_concepts * self.emb_size)
            y_logits = self.c2y_model(c_pred_flat)  # [num_mc_samples * B, n_tasks]
            y = y_logits.reshape(self.num_mc_samples, -1, self.n_tasks)  # [num_mc_samples, B, n_tasks]

            return (alpha, beta), y, c_probs_samples, c_probs_correlated_samples, c_pred_flat, z_samples, Sigma
        elif(train == False and test == False):     
            # Validation: average the correlated samples over the MC axis.
            c_prob_final =  c_probs_correlated_samples.mean(dim=0) # [B, n_concepts]
            c_probs_mix_unsq = c_prob_final.unsqueeze(-1)  # [B, n_concepts, 1]
            c_pred_samples = (contexts_pos * c_probs_mix_unsq +
                            contexts_neg * (1 - c_probs_mix_unsq))  # presumably [B, n_concepts, emb_size] (no MC axis here)
            c_pred_flat = c_pred_samples.view(-1, self.n_concepts * self.emb_size)
            y = self.c2y_model(c_pred_flat)  # [B, n_tasks]
            # z_correlated here is the value left by the last MC iteration.
            return (alpha, beta), y, c_probs_samples, c_prob_final, c_pred_flat, z_correlated, Sigma

        else:   
            # Test: push the Beta mean (no sampling) through the same
            # copula transform used in the loop above.
            c_probs =  alpha / (alpha + beta) # [B, n_concepts]
            c_probs_clamped = c_probs.clamp(self.eps, 1 - self.eps)
            
            u = self.beta_cdf(c_probs_clamped, alpha, beta)  # [B, n_concepts]
            u = torch.clamp(u, self.eps, 1 - self.eps)
            z = torch.distributions.Normal(0, 1).icdf(u)  # [B, n_concepts]
            
            z_correlated = torch.bmm(L, z.unsqueeze(-1)).squeeze(-1)  # [B, n_concepts]
            
            u_correlated = torch.distributions.Normal(0, 1).cdf(z_correlated)
            u_correlated = torch.clamp(u_correlated, self.eps, 1 - self.eps)
            c_probs_correlated = self.beta_icdf(u_correlated, alpha, beta)
            c_prob_final = c_probs_correlated
            
            c_probs_mix_unsq = c_prob_final.unsqueeze(-1)  # [B, n_concepts, 1]
            c_pred_samples = (contexts_pos * c_probs_mix_unsq +
                            contexts_neg * (1 - c_probs_mix_unsq))  # presumably [B, n_concepts, emb_size] (no MC axis here)
            c_pred_flat = c_pred_samples.view(-1, self.n_concepts * self.emb_size)
            y = self.c2y_model(c_pred_flat)  # [B, n_tasks]
            return (alpha, beta), y, c_probs_samples, c_prob_final, c_pred_flat, z_correlated, Sigma


    def generate_covariance(self, h):

        w = F.softplus(self.covariance_w(h)) + 1.0  
        a = torch.tanh(self.covariance_a(h))
        
        I = torch.eye(self.n_concepts, device=h.device)
        I = I.unsqueeze(0).expand(h.size(0), -1, -1)
        
        aa_T = torch.bmm(a.unsqueeze(2), a.unsqueeze(1))
        
        eps = 1e-4
        
        Sigma = w.unsqueeze(-1) * I + aa_T + eps * I
        
        std = torch.sqrt(torch.diagonal(Sigma, dim1=1, dim2=2))
        std = torch.clamp(std, min=eps)  
        std_outer = torch.bmm(std.unsqueeze(2), std.unsqueeze(1))
        
        Corr = Sigma / std_outer
        
        Corr = 0.99 * Corr + 0.01 * I  
        
        try:
            L = torch.linalg.cholesky(Corr)
        except:
            Corr = Corr + 0.1 * I
            L = torch.linalg.cholesky(Corr)
        
        return L, Corr
    def copula_mc(self, Sigma_q, epsilon=1e-4):

        # E[log c_R(u)] = 0.5 * (tr(R) - d - log(det(R)))
        # res = -0.5 * logdet_term
        sign, logdet = torch.linalg.slogdet(Sigma_q)
        res = -0.5 * logdet          # [B]
        return res.mean() / self.n_concepts  

    
    def _run_step(self, batch, train, test=False):
        """Run one train / validation / test step and assemble loss and metrics.

        Compared with the previous version this drops values that were
        computed but never used (``final_prob`` and the test-only
        ``c_probs`` / ``c_probs_before`` / ``predict_label``) and builds the
        expanded labels and flattened logits once instead of twice. Losses,
        printed text and the returned dictionary are unchanged.

        Args:
            batch: 5-tuple ``(x, y, c, soft_c, sample_id)``.
            train: Forwarded to ``forward``; when true, losses and task
                metrics are computed over every Monte-Carlo sample.
            test: Forwarded to ``forward``; when true, the concept alignment
                score of the predicted embeddings is also computed.

        Returns:
            ``(loss, result)`` where ``loss`` is the tensor to optimise and
            ``result`` is a dict of detached losses and metrics.
        """
        x, y, c, soft_c, sample_id = batch

        (gamma, y_logits, c_probs_samples, c_probs_correlated_samples,
         c_pred_samples, z_samples, Sigma) = self.forward(x, c, train, test)

        if train:
            # forward() returns one set of logits per Monte-Carlo sample, so
            # the labels are repeated along that axis and both are flattened
            # in the same (sample-major) order.
            num_samples = c_probs_correlated_samples.shape[0]
            flat_logits = y_logits.reshape(-1, self.n_tasks)
            flat_targets = y.expand(num_samples, -1).reshape(-1)
            metric_c_probs = c_probs_correlated_samples.mean(dim=0)
        else:
            flat_logits = y_logits
            flat_targets = y
            metric_c_probs = c_probs_correlated_samples

        task_loss = self.loss_task(flat_logits, flat_targets)
        task_loss_scalar = task_loss.detach()

        concept_labels = c if self.train_with_c_gt else soft_c
        origin_concept_loss = self.loss_concept(gamma, concept_labels)
        origin_concept_loss_scalar = origin_concept_loss.detach()
        kl_beta = self.kl_loss(gamma, c)
        kl_beta_scalar = kl_beta.detach()

        kl_copula = self.copula_mc(Sigma)
        kl_copula_scalar = kl_copula.detach()

        # Loss schedule (the original note marks these values as "CUB"):
        # the marginal KL weight ramps linearly to 1 over the first 20
        # epochs; the copula term is switched on after epoch 50.
        lambda_marginal = min(1.0, self.current_epoch / 20)
        lambda_1 = 1
        lambda_copula = 1.0 if self.current_epoch > 50 else 0.0

        loss = self.total_loss_func(
            task_loss,
            origin_concept_loss,
            kl_beta,
            kl_copula,
            lambda_marginal,
            lambda_1,
            lambda_copula
        )

        # The alignment score is only evaluated in test mode; otherwise the
        # two values are reported as 0.
        concept_auc = 0
        task_auc = 0
        if test:
            with torch.no_grad():
                c_pred_reshaped = c_pred_samples.view(c_pred_samples.size(0), self.n_concepts, self.emb_size)
                concept_auc, task_auc = concept_alignment_score(
                    c_vec=c_pred_reshaped.detach().cpu().numpy(),
                    c_test=c.cpu().numpy(),
                    y_test=y.cpu().numpy(),
                    step=50,
                    force_alignment=False,
                    progress_bar=False
                )

        (c_acc, c_auc, c_f1), (y_acc, y_f1) = compute_metric(
            metric_c_probs,
            flat_logits,
            c,
            flat_targets
        )

        if train:
            print(
                f"Task Loss: {task_loss.item():.4f}, "
                f"Origin Concept Loss: {origin_concept_loss_scalar.item():.4f}, "
                f"KL Beta: {kl_beta_scalar.item():.4f}, "
                f"KL Copula: {kl_copula_scalar.item():.4f}, "
                f"Concept_embed_AUC: {concept_auc:.4f}, "
                f"Task_embed_AUC: {task_auc:.4f}"
            )
        else:
            print(
                f"Task Loss: {task_loss.item():.4f}, "
                f"Origin Concept Loss: {origin_concept_loss_scalar.item():.4f}, "
                f"Concept_embed_AUC: {concept_auc:.4f}, "
                f"Task_embed_AUC: {task_auc:.4f}"
            )

        result = {
            "c_acc": c_acc,
            "c_auc": c_auc,
            "c_f1": c_f1,
            "y_acc": y_acc,
            "y_f1": y_f1,
            "origin_concept_loss": origin_concept_loss_scalar,
            "kl_beta_loss": kl_beta_scalar,
            "kl_copula_loss": kl_copula_scalar,
            "task_loss": task_loss_scalar,
            "loss": loss.detach(),
            "Concept_embed_AUC": concept_auc,
            "Task_embed_AUC": task_auc
        }
        return loss, result
    def _after_interventions(self, prob, c_true):
        mask = torch.bernoulli(torch.ones_like(prob) * self.interven_prob).to(prob.device)
        interven_idxs = mask.float()
        return prob * (1 - interven_idxs) + interven_idxs * c_true

    

    def total_loss_func(self, task_loss, origin_concept_loss, kl_beta,  kl_copula, lambda_marginal, lambda_1, lambda_copula=0):

        return origin_concept_loss + lambda_marginal * (kl_beta) + lambda_1* task_loss + lambda_copula * kl_copula