import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

import utils

class LINEAR(nn.Module):
    """Single fully-connected layer mapping ``input_dim`` features to ``nclass`` outputs.

    Args:
        input_dim: width of the last dimension of the input.
        nclass: number of output units.
        bias: whether the underlying ``nn.Linear`` learns an additive bias.
    """

    def __init__(self, input_dim, nclass, bias=True):
        super().__init__()
        self.fc = nn.Linear(input_dim, nclass, bias)

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        return self.fc(x)


class LINEAR_TO_COS_SIM(nn.Module):
    """Score each input row by cosine similarity against a set of weight vectors.

    Args:
        weights: a ``(num_classes, dim)`` tensor, or an iterable of ``(dim,)``
            tensors (stacked on each forward call).

    ``forward(x)`` expects ``x`` of shape ``(batch, dim)`` and returns a
    ``(batch, num_classes)`` tensor whose entry ``[i, j]`` is
    ``cos(x[i], weights[j])``.
    """

    def __init__(self, weights):
        # Zero-argument super(): naming a different class here (as LINEAR)
        # raises TypeError because self is not an instance of it.
        super().__init__()
        self.weights = weights
        # Module form of cosine similarity; the functional form cannot be
        # called with only ``dim`` because it needs both operands.
        self.cos = nn.CosineSimilarity(dim=-1)

    def forward(self, x):
        """Return the ``(batch, num_classes)`` cosine-similarity matrix."""
        weights = self.weights
        if not torch.is_tensor(weights):
            weights = torch.stack(list(weights))
        # (B, 1, D) vs (1, C, D) broadcasts to (B, C, D); similarity is taken
        # over the feature dim, replacing the per-pair Python loops.
        return self.cos(x.unsqueeze(1), weights.unsqueeze(0))

class SemanticFusion(nn.Module):
    """Single-head self-attention over a set of embeddings, mean-pooled over dim 1.

    Args:
        d_e: input embedding width.
        d_a: width of the query/key/value projections and of the output.
    """

    def __init__(self, d_e: int, d_a: int):
        super().__init__()
        self.WQ = nn.Linear(d_e, d_a, bias=False)
        self.WK = nn.Linear(d_e, d_a, bias=False)
        self.WV = nn.Linear(d_e, d_a, bias=False)
        self.scale = math.sqrt(d_a)

    def forward(self, Fy: torch.Tensor) -> torch.Tensor:
        """Attend within ``Fy`` and average the attended vectors over dim 1.

        Assumes ``Fy`` is ``(batch, m, d_e)`` (TODO confirm with callers);
        the result is then ``(batch, d_a)``.
        """
        queries, keys, values = self.WQ(Fy), self.WK(Fy), self.WV(Fy)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / self.scale
        weights = scores.softmax(dim=-1)
        return torch.matmul(weights, values).mean(dim=1)


class VMFMixtureRouter(nn.Module):
    """Mixture of von Mises-Fisher components over ``d_sphere``-dim vectors.

    ``forward`` returns per-component responsibilities and the per-sample
    mixture log-density.

    Args:
        d_sphere: dimensionality of the input vectors.
        K: number of mixture components.
    """

    def __init__(self, d_sphere, K):
        super().__init__()
        self.d = d_sphere
        self.K = K
        # Mean directions; re-normalised on every forward pass.
        self.mu = nn.Parameter(torch.randn(K, d_sphere))
        # Despite the name, forward() uses softplus(log_kappa), not exp().
        self.log_kappa = nn.Parameter(torch.zeros(K))
        self.pi_logits = nn.Parameter(torch.zeros(K))

    def _logC(self, kappa):
        """Log normaliser of a vMF density for each concentration in ``kappa``.

        log C_d(k) = v*log(k) - (d/2)*log(2*pi) - log I_v(k), with v = d/2 - 1.
        """
        order = self.d / 2.0 - 1.0
        special = torch.special
        if hasattr(special, "ive"):
            # ive is exponentially scaled: I_v(k) = ive(v, k) * exp(|k|).
            scaled = special.ive(order, kappa)
            log_bessel = torch.log(scaled.clamp_min(1e-30)) + kappa.abs()
        elif hasattr(special, "iv"):
            log_bessel = torch.log(special.iv(order, kappa).clamp_min(1e-30))
        else:
            # NOTE(review): when this torch build exposes neither special.ive
            # nor special.iv, the normaliser is zero and its kappa-dependence
            # drops out of log_comp in forward() — confirm which path runs.
            return torch.zeros_like(kappa)

        log_kappa_term = order * torch.log(kappa.clamp_min(1e-8))
        log_two_pi_term = (self.d / 2.0) * math.log(2.0 * math.pi)
        return log_kappa_term - log_two_pi_term - log_bessel

    def forward(self, Fprime):
        """Route ``Fprime`` of shape ``(B, d)`` to the ``K`` components.

        Returns:
            ``(r, log_p)``: ``r`` is ``(B, K)`` softmax responsibilities and
            ``log_p`` is ``(B,)`` logsumexp of the component log-scores.
        """
        _, dim = Fprime.shape
        assert dim == self.d, f"Expected d={self.d}, got {dim}"

        directions = utils.l2_normalize(self.mu, dim=-1)
        kappa = F.softplus(self.log_kappa) + 1e-6
        log_pi = torch.log(F.softmax(self.pi_logits, dim=-1).clamp_min(1e-12)).unsqueeze(0)
        log_norm = self._logC(kappa).unsqueeze(0)
        alignment = Fprime @ directions.t()

        log_comp = log_pi + log_norm + alignment * kappa.unsqueeze(0)
        return F.softmax(log_comp, dim=-1), torch.logsumexp(log_comp, dim=-1)

@dataclass(eq=False)
class CSDConfig:
    """Hyper-parameters consumed by ``CSD``.

    A dataclass so it can be built with keywords, e.g.
    ``CSDConfig(d_e=..., d_a=..., K=...)``. The size fields default to
    ``None`` so building an empty config and assigning attributes afterwards
    still works. ``eq=False`` keeps identity equality and hashing.
    """

    d_e: Optional[int] = None  # input width passed to SemanticFusion
    d_a: Optional[int] = None  # SemanticFusion output width / router dimension
    K: Optional[int] = None  # number of router components
    kmeans_iters: int = 30  # forwarded to utils.spherical_kmeans_init
    kmeans_seed: int = 0  # forwarded to utils.spherical_kmeans_init

class CSD(nn.Module):
    """Fuse a set of embeddings, project to the unit sphere and route with a vMF mixture.

    Args:
        cfg: object exposing ``d_e``, ``d_a``, ``K``, ``kmeans_iters`` and
            ``kmeans_seed``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.fusion = SemanticFusion(cfg.d_e, cfg.d_a)
        self.router = VMFMixtureRouter(cfg.d_a, cfg.K)
        self._inited = False

    @torch.no_grad()
    def initialize_with_spherical_kmeans(self, Fy_all):
        """Seed ``router.mu`` from spherical k-means on the fused, normalised ``Fy_all``.

        Runs in eval mode without gradients and restores the previous
        train/eval mode afterwards, even if clustering raises.
        """
        was_training = self.training
        self.eval()
        try:
            Fhat = self.fusion(Fy_all)
            Fprime = utils.l2_normalize(Fhat, dim=-1)
            mu0 = utils.spherical_kmeans_init(
                Fprime, K=self.cfg.K, iters=self.cfg.kmeans_iters, seed=self.cfg.kmeans_seed
            )
            self.router.mu.copy_(mu0)
        finally:
            # eval() is only needed for the clustering pass; don't leave the
            # module stuck in eval mode for subsequent training.
            self.train(was_training)
        self._inited = True

    def forward(self, Fy):
        """Return ``(Fhat, Fprime, r, log_p)``.

        The four values are: the fused features, their L2-normalised version,
        the router responsibilities, and the per-sample mixture log-density.
        """
        Fhat = self.fusion(Fy)
        Fprime = utils.l2_normalize(Fhat, dim=-1)
        r, log_p = self.router(Fprime)
        return Fhat, Fprime, r, log_p

    def nll(self, log_p):
        """Mean negative log-likelihood of ``log_p``."""
        return (-log_p).mean()


class BottleneckAdapter(nn.Module):
    """Down-project, ReLU, dropout, then up-project back to width ``d``.

    Args:
        d: input and output width.
        bottleneck: hidden width.
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, d, bottleneck, dropout):
        super().__init__()
        self.down = nn.Linear(d, bottleneck)
        self.up = nn.Linear(bottleneck, d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the adapter output, which has the same last-dim width as ``x``."""
        hidden = F.relu(self.down(x))
        hidden = self.dropout(hidden)
        return self.up(hidden)


class GlobalTextEncoder(nn.Module):
    """Map ``d_in``-wide inputs to a ``d_latent``-wide non-negative code (Linear + ReLU)."""

    def __init__(self, d_in, d_latent):
        super().__init__()
        layers = [nn.Linear(d_in, d_latent), nn.ReLU()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` along its last dimension."""
        encoded = self.net(x)
        return encoded


class InvarianceProjector(nn.Module):
    """Split a latent code into two linear projections named causal and non-causal.

    Args:
        d_latent: input width.
        d_causal: width of the first projection (``Pc``).
        d_noncausal: width of the second projection (``Pn``).
    """

    def __init__(self, d_latent, d_causal, d_noncausal):
        super().__init__()
        self.Pc = nn.Linear(d_latent, d_causal)
        self.Pn = nn.Linear(d_latent, d_noncausal)

    def forward(self, z):
        """Return the pair ``(Pc(z), Pn(z))``."""
        causal = self.Pc(z)
        noncausal = self.Pn(z)
        return causal, noncausal
    
class WeightDecoder(nn.Module):
    """Two-layer MLP (``d_in`` -> ``hidden`` -> ``d_out``) with a ReLU in between."""

    def __init__(self, d_in, d_out, hidden):
        super().__init__()
        stages = [
            nn.Linear(d_in, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d_out),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, z):
        """Decode ``z`` along its last dimension."""
        decoded = self.net(z)
        return decoded

@dataclass(eq=False)
class ACSConfig:
    """Hyper-parameters consumed by ``ACS``.

    A dataclass so it can be built with keywords. The size fields default to
    ``None`` so building an empty config and assigning attributes afterwards
    still works. ``eq=False`` keeps identity equality and hashing.
    """

    d_e: Optional[int] = None  # input width of GlobalTextEncoder
    d_latent: Optional[int] = None  # encoder output / adapter width
    d_causal: Optional[int] = None  # width of InvarianceProjector.Pc
    d_noncausal: Optional[int] = None  # width of InvarianceProjector.Pn
    d_w: Optional[int] = None  # WeightDecoder output width
    K: Optional[int] = None  # number of bottleneck adapters
    alpha: float = 0.8  # blend: alpha * h + (1 - alpha) * adapter(h)
    adapter_bottleneck_ratio: float = 0.25  # bottleneck = max(1, int(d_latent * ratio))
    adapter_dropout: float = 0.0
    decoder_hidden: int = 4096  # WeightDecoder hidden width


class ACS(nn.Module):
    """Adapter-routed text encoder that decodes per-source weight predictions.

    ``forward`` runs the following pipeline:

    1. ``ET`` encodes ``Fy``.
    2. ``K`` bottleneck adapters are blended with that encoding and mixed
       using routing weights ``r``.
    3. ``inv_proj`` splits the result into causal / non-causal codes.
    4. ``DW`` decodes the concatenated codes.
    """

    def __init__(self, cfg: ACSConfig):
        super().__init__()
        self.cfg = cfg
        self.ET = GlobalTextEncoder(cfg.d_e, cfg.d_latent)
        bottleneck = max(1, int(cfg.d_latent * cfg.adapter_bottleneck_ratio))
        self.adapters = nn.ModuleList([
            BottleneckAdapter(cfg.d_latent, bottleneck, dropout=cfg.adapter_dropout)
            for _ in range(cfg.K)
        ])

        self.inv_proj = InvarianceProjector(cfg.d_latent, cfg.d_causal, cfg.d_noncausal)
        self.DW = WeightDecoder(cfg.d_causal + cfg.d_noncausal, cfg.d_w, hidden=cfg.decoder_hidden)

    def _route_adapters(self, h, r):
        """Blend every adapter's output with ``h`` and mix them with weights ``r``.

        Args:
            h: 3-D encoded features (assumed ``(B, m, d_latent)`` — TODO confirm).
            r: routing weights that broadcast as ``(B, 1, K, 1)`` after the
                unsqueezes below; presumably ``(B, K)`` router output.

        Returns:
            A tensor with the same shape as ``h``.

        Raises:
            ValueError: if ``h`` is not 3-D.
        """
        if h.dim() != 3:
            raise ValueError(f"expected a 3-D tensor, got shape {tuple(h.shape)}")
        alpha = self.cfg.alpha

        # (B, m, K, d): every adapter applied to every source vector.
        A = torch.stack([ad(h) for ad in self.adapters], dim=2)
        base = alpha * h.unsqueeze(2) + (1.0 - alpha) * A
        r_exp = r.unsqueeze(1).unsqueeze(-1)
        return (base * r_exp).sum(dim=2)

    def _invariance_from_parts(self, zc, zn):
        """Compute the invariance loss from already-projected codes.

        Returns:
            ``(loss, w1)``. ``w1`` is ``DW(cat[zc, zn])`` — the unswapped
            decoding — returned so callers need not decode it again.
        """
        m = zn.shape[1]
        # Shuffle the non-causal codes across dim 1 (the source axis); the
        # decoded weights should not change if only the non-causal part moves.
        perm = torch.randperm(m, device=zn.device)
        zn_swapped = zn[:, perm, :]
        w1 = self.DW(torch.cat([zc, zn], dim=-1))
        w2 = self.DW(torch.cat([zc, zn_swapped], dim=-1))
        return utils.cosine_distance(w1, w2).mean(), w1

    def invariance_loss(self, z):
        """Mean ``utils.cosine_distance`` between decodings with original and shuffled non-causal codes.

        Raises:
            ValueError: if ``z`` is not 3-D.
        """
        if z.dim() != 3:
            raise ValueError(f"expected a 3-D tensor, got shape {tuple(z.shape)}")
        zc, zn = self.inv_proj(z)
        loss, _ = self._invariance_from_parts(zc, zn)
        return loss

    def forward(self, Fy, r):
        """Predict weights from ``Fy`` under routing ``r``.

        Returns:
            dict with ``"w_pred_sources"`` (``DW`` output per source),
            ``"w_pred_mean"`` (its mean over dim 1) and ``"inv_loss"``.
        """
        h = self.ET(Fy)
        z = self._route_adapters(h, r)
        # Project once and reuse: the unswapped decoding inside the
        # invariance loss is exactly the per-source prediction.
        zc, zn = self.inv_proj(z)
        inv, w_sources = self._invariance_from_parts(zc, zn)
        w_mean = w_sources.mean(dim=1)
        return {"w_pred_sources": w_sources, "w_pred_mean": w_mean, "inv_loss": inv}


class ATT_AUTOENCODER_INV(nn.Module):
    """Attribute autoencoder whose encoder can also consume an "inverse" input.

    ``encode`` works in three steps:

    1. Slice each row of ``x`` into an attribute block (``att_dim`` wide), a
       description embedding (``wordemb_dim`` wide) and a stack of GPT
       embeddings (``wordemb_dim`` each).
    2. Fuse them according to ``opt.factual_branch``.
    3. When ``flag`` is False, concatenate features built from ``x_inv``
       with the factual features along dim 0, before the shared merge layer.

    Reads ``opt.attention_dim``, ``opt.conclude_inv``, ``opt.inv_merge`` and
    ``opt.factual_branch``.
    """
    def __init__(self, opt, input_dim, embed_dim, att_dim=312, output_dim=None, wordemb_dim=512):
        super(ATT_AUTOENCODER_INV, self).__init__()
        self.opt = opt
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.att_dim = att_dim
        self.output_dim = output_dim
        self.wordemb_dim = wordemb_dim
        self.attention_dim = self.opt.attention_dim

        # Replaces the scalar stored above: [0] is the encoder output width,
        # [1] the decoder input width (both embed_dim).
        self.embed_dim = [embed_dim, embed_dim]
        if output_dim is None:
            self.output_dim = input_dim
        # NOTE(review): encoder_merge is constructed but encode() in this class
        # only uses encoder_merge1 and encoder — confirm it is still needed.
        self.encoder_merge = nn.Sequential(
            nn.Linear(self.att_dim+self.wordemb_dim, self.embed_dim[0]),
            nn.ReLU(inplace=True)
        )
        # Merge layer used by both fusion branches of encode(): attribute
        # block plus an attention_dim-wide fused description.
        self.encoder_merge1 = nn.Sequential(
            nn.Linear(self.att_dim+self.attention_dim, self.embed_dim[0]),
            nn.ReLU(inplace=True)
        )
        # Used by encode() when opt.factual_branch is neither 'attention' nor 'mean'.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, self.embed_dim[0]),
            nn.ReLU(inplace=True)
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.embed_dim[1], 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, self.output_dim)
        )
        # One score per GPT embedding for scfa_pooling.
        self.attention_fc = nn.Linear(self.wordemb_dim, 1)
        # Learnable mix between description and pooled GPT embedding in the
        # 'mean' branch (initialised to 0.2).
        self.beta = nn.Parameter(torch.tensor(0.2))
        # NOTE(review): inv_attention only exists when opt.conclude_inv, yet
        # encode() calls it in the 'mean' branch and in the 'attention' branch
        # when flag is False; those paths raise AttributeError otherwise.
        if self.opt.conclude_inv:
            self.inv_attention = nn.Linear(self.wordemb_dim, self.attention_dim)

        # Projections for cross_attention().
        self.q_proj = nn.Linear(self.wordemb_dim, self.attention_dim)
        self.k_proj = nn.Linear(self.wordemb_dim, self.attention_dim)
        self.v_proj = nn.Linear(self.wordemb_dim, self.attention_dim)
        self.out_proj = nn.Linear(self.attention_dim, self.attention_dim)

        # Separate projections for cross_attention_inv(), which is not called
        # anywhere in this class.
        if not self.opt.inv_merge:
            self.q_proj_inv = nn.Linear(self.wordemb_dim, self.attention_dim)
            self.k_proj_inv = nn.Linear(self.wordemb_dim, self.attention_dim)
            self.v_proj_inv = nn.Linear(self.wordemb_dim, self.attention_dim)
            self.out_proj_inv = nn.Linear(self.attention_dim, self.attention_dim)

    def scfa_pooling(self, gpt_emb):
        """Attention-pool embeddings into one ``wordemb_dim`` vector per sample.

        The input is permuted with (0, 2, 1) before scoring, and attention_fc
        expects wordemb_dim features, so callers pass (B, wordemb_dim, N);
        the softmax runs over the N axis and the result is (B, wordemb_dim).
        """
        gpt_emb = gpt_emb.permute(0, 2, 1)
        attn_weights = self.attention_fc(gpt_emb).squeeze(-1)
        attn_weights = F.softmax(attn_weights, dim=1)
        weighted_gpt_emb = torch.sum(gpt_emb * attn_weights.unsqueeze(-1), dim=1)
        return weighted_gpt_emb

    def cross_attention(self, query, context):
        """Scaled dot-product attention from ``query`` to ``context``.

        The output goes through out_proj and is averaged over query positions.
        torch.bmm requires both inputs to be 3-D (B, len, wordemb_dim);
        returns (B, attention_dim).
        """
        q = self.q_proj(query)
        k = self.k_proj(context)
        v = self.v_proj(context)
        attn_scores = torch.bmm(q, k.transpose(1, 2))
        scale_factor = math.sqrt(self.attention_dim)
        attn_scores = attn_scores / scale_factor
        attn_weights = F.softmax(attn_scores, dim=-1)
        output = torch.bmm(attn_weights, v)
        output = self.out_proj(output)
        output = output.mean(dim=1)
        return output
    
    def cross_attention_inv(self, query, context):
        """Same as cross_attention but with the *_inv projections.

        Those projections only exist when opt.inv_merge is False.
        """
        q = self.q_proj_inv(query)
        k = self.k_proj_inv(context)
        v = self.v_proj_inv(context)
        attn_scores = torch.bmm(q, k.transpose(1, 2))
        scale_factor = math.sqrt(self.attention_dim)
        attn_scores = attn_scores / scale_factor
        attn_weights = F.softmax(attn_scores, dim=-1)
        output = torch.bmm(attn_weights, v)
        output = self.out_proj_inv(output)
        output = output.mean(dim=1)
        return output

    def encode(self, x, x_inv=None, flag=False):
        """Encode ``x`` (and, when ``flag`` is False, ``x_inv``) into the latent space.

        When flag is False the inverse rows are concatenated after the factual
        rows along dim 0, so the output has twice as many rows.
        """

        # Raw input is kept on the module as attribute_f.
        self.attribute_f = x

        att_emb = self.attribute_f[:, :self.att_dim]
        desc_emb = self.attribute_f[:, self.att_dim:self.att_dim+self.wordemb_dim]
        # Remaining columns reshaped to (B, n_gpt, wordemb_dim).
        gpt_emb = self.attribute_f[:, self.att_dim+self.wordemb_dim:].view(x.shape[0], -1, self.wordemb_dim)

        # NOTE(review): x_inv defaults to None (forward() passes None unless
        # given), in which case the slicing below raises TypeError.
        self.attribute_inv = x_inv

        if self.opt.inv_merge:
            # desc_emb_inv is 2-D: (B, wordemb_dim).
            att_emb_inv = self.attribute_inv[:, :self.att_dim]
            desc_emb_inv = self.attribute_inv[:, self.att_dim:self.att_dim+self.wordemb_dim]
        else:
            # desc_emb_inv is 3-D: (B, -1, wordemb_dim).
            att_emb_inv = self.attribute_inv[:, :self.att_dim]
            desc_emb_inv = self.attribute_inv[:, self.att_dim:].view(x.shape[0], -1, self.wordemb_dim)
        

        if self.opt.factual_branch == 'attention':
            # NOTE(review): the query is desc_emb_inv, not desc_emb — confirm
            # intended. cross_attention uses torch.bmm, which needs a 3-D
            # query, so this looks like it requires opt.inv_merge False.
            fused_context = self.cross_attention(desc_emb_inv, gpt_emb)
            x = torch.cat([att_emb, fused_context], dim=1)
            if not flag:
                # NOTE(review): with opt.inv_merge False, desc_emb_inv is 3-D
                # here while att_emb_inv is 2-D, so the cat below looks like
                # it would fail — verify.
                desc_emb_inv = self.inv_attention(desc_emb_inv)
                x_inv = torch.cat([att_emb_inv, desc_emb_inv], dim=1)
                x = torch.cat([x, x_inv], dim=0)

            return self.encoder_merge1(x)

        elif self.opt.factual_branch == 'mean':
            # scfa_pooling permutes back, so the pooled result is (B, wordemb_dim).
            gpt_emb_t = gpt_emb.permute(0, 2, 1)
            gpt_emb_mean = self.scfa_pooling(gpt_emb_t)

            desc_emb = self.beta*desc_emb + (1-self.beta)*gpt_emb_mean
            desc_emb = self.inv_attention(desc_emb)
            x = torch.cat([att_emb, desc_emb], dim=1)
            
            if not flag:
                # NOTE(review): x is att_dim+attention_dim wide, but x_inv is
                # built from the unprojected desc_emb_inv; stacking along dim 0
                # needs equal widths — confirm the projection is not missing.
                x_inv = torch.cat([att_emb_inv, desc_emb_inv], dim=1)
                x = torch.cat([x, x_inv], dim=0)
            
            return self.encoder_merge1(x)
        
        # Any other factual_branch: plain encoder on the unmodified input.
        return self.encoder(x)

    def decode(self, x):
        """Map latent codes to output_dim via the decoder MLP."""
        return self.decoder(x)

    def forward(self, x, x_inv=None, flag=True):
        """Encode then decode; with flag True (the default) only the factual rows are returned."""
        z = self.encode(x, x_inv, flag)
        return self.decode(z)


class ATT_AUTOENCODER(nn.Module):
    """Attribute autoencoder that fuses attribute, description and GPT view embeddings.

    Each input row is laid out as ``[att (att_dim) | desc (wordemb_dim) |
    view_num GPT embeddings (wordemb_dim each)]``. When ``opt.conclude_inv``
    is set, an inverse block with the same layout follows.

    Reads ``opt.attention_dim``, ``opt.conclude_inv``, ``opt.view_num`` and
    ``opt.factual_branch``.
    """

    def __init__(self, opt, input_dim, embed_dim, att_dim=312, output_dim=None, wordemb_dim=512):
        super(ATT_AUTOENCODER, self).__init__()
        self.opt = opt
        self.input_dim = input_dim
        self.embed_dim = embed_dim
        self.att_dim = att_dim
        self.output_dim = output_dim
        self.wordemb_dim = wordemb_dim
        self.attention_dim = self.opt.attention_dim

        # [0] encoder output width, [1] decoder input width.
        self.embed_dim = [embed_dim, embed_dim]
        if output_dim is None:
            self.output_dim = input_dim
        # Construction order is kept as-is (including the layers that are
        # replaced when conclude_inv) so seeded initialisation is unchanged.
        self.encoder_merge = nn.Sequential(
            nn.Linear(self.att_dim+self.wordemb_dim, self.embed_dim[0]),
            nn.ReLU(inplace=True)
        )
        self.encoder_merge1 = nn.Sequential(
            nn.Linear(self.att_dim+self.attention_dim, self.embed_dim[0]),
            nn.ReLU(inplace=True)
        )
        if self.opt.conclude_inv:
            # Factual and inverse features are concatenated along dim 1.
            self.encoder_merge = nn.Sequential(
                nn.Linear((self.att_dim+self.wordemb_dim)*2, self.embed_dim[0]),
                nn.ReLU(inplace=True)
            )
            self.encoder_merge1 = nn.Sequential(
                nn.Linear((self.att_dim+self.attention_dim)*2, self.embed_dim[0]),
                nn.ReLU(inplace=True)
            )
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, self.embed_dim[0]),
            nn.ReLU(inplace=True)
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.embed_dim[1], 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, self.output_dim)
        )
        self.attention_fc = nn.Linear(self.wordemb_dim, 1)
        # Mix between description and pooled GPT embedding ('mean' branch).
        self.beta = nn.Parameter(torch.tensor(0.2))

        if self.opt.conclude_inv:
            self.inv_attention = nn.Linear(self.wordemb_dim, self.attention_dim)
        self.query_layer = nn.Linear(self.wordemb_dim, self.attention_dim)
        self.query_layer_att = nn.Linear(self.att_dim, self.attention_dim)
        self.key_layer = nn.Linear(self.wordemb_dim, self.attention_dim)
        self.value_layer = nn.Linear(self.wordemb_dim, self.attention_dim)

    def scfa_pooling(self, gpt_emb):
        """Attention-pool ``(B, wordemb_dim, N)`` input into ``(B, wordemb_dim)``.

        The input is permuted back to ``(B, N, wordemb_dim)`` before scoring;
        the softmax runs over N.
        """
        gpt_emb = gpt_emb.permute(0, 2, 1)
        attn_weights = self.attention_fc(gpt_emb).squeeze(-1)
        attn_weights = F.softmax(attn_weights, dim=1)
        weighted_gpt_emb = torch.sum(gpt_emb * attn_weights.unsqueeze(-1), dim=1)
        return weighted_gpt_emb

    def compute_cosed(self, q, k):
        """Score = cosine similarity times Euclidean distance, over the last dim."""
        cs = F.cosine_similarity(q, k, dim=-1)
        ed = torch.norm(q - k, p=2, dim=-1)
        cosed = cs * ed
        return cosed

    def encode(self, x):
        """Encode rows of ``x`` into the ``embed_dim`` latent space.

        Side effect: stores the factual slice on ``self.attribute_f`` and,
        when ``opt.conclude_inv``, the inverse slice on ``self.attribute_inv``.
        """
        original_dim = self.att_dim + (self.opt.view_num + 1) * self.wordemb_dim

        self.attribute_f = x[:, :original_dim]
        if self.opt.conclude_inv:
            self.attribute_inv = x[:, original_dim:]

        att_emb = self.attribute_f[:, :self.att_dim]
        desc_emb = self.attribute_f[:, self.att_dim:self.att_dim + self.wordemb_dim]
        # (B, view_num, wordemb_dim)
        gpt_emb = self.attribute_f[:, self.att_dim + self.wordemb_dim:].view(x.shape[0], -1, self.wordemb_dim)

        if self.opt.conclude_inv:
            att_emb_inv = self.attribute_inv[:, :self.att_dim]
            desc_emb_inv = self.attribute_inv[:, self.att_dim:self.att_dim + self.wordemb_dim]

        if self.opt.factual_branch == 'attention':
            desc_emb_expanded = desc_emb.unsqueeze(1).expand(-1, self.opt.view_num, -1)
            q = self.query_layer(desc_emb_expanded)
            k = self.key_layer(gpt_emb)
            v = self.value_layer(gpt_emb)
            # One score per view: (B, view_num), normalised over views.
            attn_scores = self.compute_cosed(q, k)
            attn_weights = torch.softmax(attn_scores, dim=-1)
            # Per-sample weighted sum of views: (B, 1, V) @ (B, V, A) -> (B, A).
            # A plain matmul of the 2-D weights with 3-D v broadcasts to
            # (B, B, A) and mixes attention weights across the batch.
            fused_context = torch.bmm(attn_weights.unsqueeze(1), v).squeeze(1)
            x = torch.cat([att_emb, fused_context], dim=1)
            if self.opt.conclude_inv:
                desc_emb_inv = self.inv_attention(desc_emb_inv)
                x_inv = torch.cat([att_emb_inv, desc_emb_inv], dim=1)
                x = torch.cat([x, x_inv], dim=1)
            return self.encoder_merge1(x)

        elif self.opt.factual_branch == 'mean':
            gpt_emb = gpt_emb.permute(0, 2, 1)
            gpt_emb = self.scfa_pooling(gpt_emb)

            x = torch.cat([att_emb, self.beta * desc_emb + (1 - self.beta) * gpt_emb], dim=1)

            if self.opt.conclude_inv:
                x_inv = torch.cat([att_emb_inv, desc_emb_inv], dim=1)
                x = torch.cat([x, x_inv], dim=1)

            return self.encoder_merge(x)

        # Any other factual_branch: plain encoder on the unmodified input.
        return self.encoder(x)

    def decode(self, x):
        """Map latent codes to ``output_dim`` via the decoder MLP."""
        return self.decoder(x)

    def forward(self, x, x_inv=None, flag=False):
        """Encode then decode ``x``; ``x_inv`` and ``flag`` are accepted but unused."""
        z = self.encode(x)
        return self.decode(z)

class AUTOENCODER(nn.Module):
    """Fully-connected autoencoder built from one of three depth presets.

    Args:
        opt: options object; ``opt.conclude_inv`` and ``opt.concatenation``
            are read.
        input_dim: encoder input width.
        embed_dim: latent width.
        output_dim: decoder output width; defaults to ``input_dim``.
        num_layers: architecture preset, one of 2, 3 or 4.

    Raises:
        ValueError: if ``num_layers`` is not 2, 3 or 4. Without this check,
            no ``encoder``/``decoder`` would be created and every later call
            would fail.
    """

    def __init__(self, opt, input_dim, embed_dim, output_dim=None, num_layers=2):
        super(AUTOENCODER, self).__init__()
        self.opt = opt
        self.input_dim = input_dim
        self.output_dim = output_dim
        if output_dim is None:
            self.output_dim = input_dim
        self.embed_dim = [embed_dim, embed_dim]
        if num_layers == 2:
            self.encoder = nn.Sequential(
                nn.Linear(self.input_dim, self.embed_dim[0]),
                nn.ReLU(inplace=True)
            )

            self.decoder = nn.Sequential(
                nn.Linear(self.embed_dim[1], self.output_dim)
            )

            # Extra head used by decode_inv(); only built for this preset.
            if self.opt.conclude_inv and not self.opt.concatenation:
                self.decoder_inv = nn.Sequential(
                    nn.Linear(self.embed_dim[1], self.output_dim)
                )
        elif num_layers == 3:
            self.encoder = nn.Sequential(
                nn.Linear(self.input_dim, self.embed_dim[0]),
                nn.ReLU(inplace=True)
            )

            self.decoder = nn.Sequential(
                nn.Linear(self.embed_dim[1], 4096),
                nn.ReLU(inplace=True),
                nn.Linear(4096, self.output_dim)
            )
        elif num_layers == 4:
            self.encoder = nn.Sequential(
                nn.Linear(self.input_dim, self.embed_dim[0]),
                nn.ReLU(inplace=True),
                nn.Linear(self.embed_dim[0], self.embed_dim[0]),
                nn.ReLU(inplace=True)
            )

            self.decoder = nn.Sequential(
                nn.Linear(self.embed_dim[1], 1000),
                nn.ReLU(inplace=True),
                nn.Linear(1000, self.output_dim)
            )
        else:
            raise ValueError(f"num_layers must be 2, 3 or 4, got {num_layers}")

    def encode(self, x, x_inv=None):
        """Encode ``x``.

        With ``opt.concatenation``, ``x_inv`` (when given) is first appended
        to ``x`` along the batch dimension. With ``x_inv=None`` only ``x`` is
        encoded, so ``forward(x)`` works in that mode too.
        """
        if self.opt.concatenation and x_inv is not None:
            x = torch.cat((x, x_inv), dim=0)
        return self.encoder(x)

    def decode(self, x):
        """Map latent codes to ``output_dim``."""
        return self.decoder(x)

    def decode_inv(self, x):
        """Decode with the inverse head.

        The head exists only for ``num_layers == 2`` with ``opt.conclude_inv``
        set and ``opt.concatenation`` unset.
        """
        return self.decoder_inv(x)

    def forward(self, x):
        """Reconstruct ``x`` through the encoder and decoder."""
        z = self.encode(x)
        return self.decode(z)

class JOINT_AUTOENCODER(nn.Module):
    """Two autoencoders (``ae1`` for attributes, ``ae2`` for weights) cross-decoded.

    ``forward`` encodes both inputs and decodes each latent with both
    decoders. When ``opt.concatenation`` is set, inverse inputs are passed
    through to the encoders.
    """

    def __init__(self, opt, autoencoder1, autoencoder2):
        super().__init__()
        self.ae1 = autoencoder1
        self.ae2 = autoencoder2
        self.opt = opt

    def encode1(self, x, x_inv=None, flag=False):
        """Encode with ``ae1``; ``x_inv`` and ``flag`` are forwarded only in concatenation mode."""
        if not self.opt.concatenation:
            return self.ae1.encode(x)
        return self.ae1.encode(x, x_inv, flag)

    def encode2(self, x, x_inv=None):
        """Encode with ``ae2``; ``x_inv`` is forwarded only in concatenation mode."""
        if not self.opt.concatenation:
            return self.ae2.encode(x)
        return self.ae2.encode(x, x_inv)

    def decode1(self, x):
        """Decode with ``ae1``."""
        return self.ae1.decode(x)

    def decode2(self, x):
        """Decode with ``ae2``."""
        return self.ae2.decode(x)

    def decode_inv(self, x):
        """Decode with ``ae2``'s inverse head."""
        return self.ae2.decode_inv(x)

    def forward(self, x):
        """Return the cross-reconstructions and both latents.

        ``x`` is ``(att_in, weight_in)``, or, with ``opt.concatenation``,
        ``(att_in, weight_in, att_in_inv, weight_in_inv)``.

        Returns a tuple:
        ``(att_from_att, att_from_weight, weight_from_weight, weight_from_att,
        [weight_from_att_inv,] latent_att, latent_weight)``. The bracketed
        item is present only when ``opt.conclude_inv`` is set and
        ``opt.concatenation`` is not.
        """
        if self.opt.concatenation:
            att_in, weight_in, att_in_inv, weight_in_inv = x
            latent_att = self.encode1(att_in, att_in_inv)
            latent_weight = self.encode2(weight_in, weight_in_inv)
        else:
            att_in, weight_in = x
            latent_att = self.encode1(att_in)
            latent_weight = self.encode2(weight_in)

        reconstructions = (
            self.decode1(latent_att),
            self.decode1(latent_weight),
            self.decode2(latent_weight),
            self.decode2(latent_att),
        )

        if self.opt.conclude_inv and not self.opt.concatenation:
            return (*reconstructions, self.decode_inv(latent_att), latent_att, latent_weight)
        return (*reconstructions, latent_att, latent_weight)

    def predict(self, x, x_inv=None, flag=True):
        """Map attribute input ``x`` to the weight space: ``decode2(encode1(x))``."""
        latent_att = self.encode1(x, x_inv, flag)
        return self.decode2(latent_att)
