import clip
import torch
import torch.nn as nn
import torch.nn.functional as F

class LearnablePrompt(nn.Module):
    """Frozen CLIP text encoder with shared learnable context tokens.

    Each prompt is fed to CLIP's text transformer as
    ``[SOS] [ctx_1 .. ctx_n] <prompt tokens> [EOT] <padding>`` (CoOp-style),
    where the ``n`` context embeddings in ``learnable_tokens`` are the only
    trainable parameters; every CLIP parameter is frozen. The prompt words are
    kept intact, so "real" vs. "fake" prompts still differ after the context is
    inserted.

    Args:
        model_name: model name passed to ``clip.load``.
        num_learnable_tokens: number ``n`` of learnable context embeddings.
        device: device used for the CLIP model, the tokenized prompts and the
            learnable context.
    """

    def __init__(self, model_name="ViT-L/14",
                 num_learnable_tokens=5, device="cuda"):
        super().__init__()
        self.device = device
        # Prompt set describing genuine faces.
        self.pos_prompts = [
            "a real human face",
            "a bonafide face with expressive eyes",
            "a genuine face with natural mouth",
        ]
        # Prompt set describing fake / spoofed faces.
        self.neg_prompts = [
            "a fake human face",
            "a spoof face with dull eyes",
            "a forged face with unnatural mouth",
        ]
        self.num_learnable_tokens = num_learnable_tokens
        self.clip_model, _ = clip.load(model_name, device=device, jit=False)
        # Run CLIP in fp32 (its weights may be loaded in half precision) so it
        # matches the fp32 learnable context below.
        self.clip_model.float()
        for p in self.clip_model.parameters():
            p.requires_grad = False

        # Width of CLIP's token embeddings; the context tokens live in the same space.
        self.text_dim = self.clip_model.token_embedding.embedding_dim
        self.learnable_tokens = nn.Parameter(
            0.01 * torch.randn(1, num_learnable_tokens, self.text_dim, device=device)
        )

    def encode_prompts(self, prompt_list):
        """Encode ``prompt_list`` and fuse the results into one unit vector.

        The learnable context is inserted right after the SOS token and the
        prompt tokens are shifted right by ``n``; the tail of the (padded)
        sequence is dropped so the length stays CLIP's context length. The
        feature of each prompt is taken at its EOT token, as in CLIP's
        ``encode_text``, and passed through ``text_projection``.

        Args:
            prompt_list: non-empty list of prompt strings.

        Returns:
            Tensor of shape ``(1, D)``: the L2-normalised mean of the
            L2-normalised per-prompt features.

        Raises:
            ValueError: if a prompt plus the ``n`` context tokens no longer
                fits in CLIP's context length.
        """
        tokens = clip.tokenize(prompt_list).to(self.device)  # (P, L) token ids
        num_prompts, context_length = tokens.shape
        n_ctx = self.num_learnable_tokens

        # EOT has the largest id in CLIP's vocabulary, so argmax locates it;
        # inserting n_ctx context tokens shifts it right by n_ctx.
        eot_idx = tokens.argmax(dim=-1) + n_ctx
        if int(eot_idx.max()) >= context_length:
            raise ValueError(
                "prompt is too long to fit %d learnable tokens in a context of %d"
                % (n_ctx, context_length)
            )

        with torch.no_grad():
            base_embedding = self.clip_model.token_embedding(tokens).float()
        prefix = base_embedding[:, :1, :]                              # SOS
        suffix = base_embedding[:, 1:context_length - n_ctx, :]        # words, EOT, padding
        ctx = self.learnable_tokens.expand(num_prompts, -1, -1)        # shared context
        x = torch.cat([prefix, ctx, suffix], dim=1)                    # (P, L, D)

        x = x + self.clip_model.positional_embedding[:x.size(1), :].float()
        x = x.permute(1, 0, 2)  # CLIP's transformer expects (L, N, D)
        x = self.clip_model.transformer(x)
        x = x.permute(1, 0, 2)
        x = self.clip_model.ln_final(x)

        rows = torch.arange(num_prompts, device=x.device)
        feats = x[rows, eot_idx] @ self.clip_model.text_projection     # (P, D)
        feats = feats / feats.norm(dim=-1, keepdim=True)
        fused = feats.mean(dim=0, keepdim=True)                        # (1, D)
        return fused / fused.norm(dim=-1, keepdim=True)

    def forward(self):
        """Return the fused ``(positive, negative)`` prompt features, each ``(1, D)``."""
        pos_feature = self.encode_prompts(self.pos_prompts)
        neg_feature = self.encode_prompts(self.neg_prompts)
        return pos_feature, neg_feature

class MetaWeightGenerator(nn.Module):
    """Predict softmax weights over fused feature streams.

    ``visual``, ``audio`` and ``kv`` are concatenated along the last axis and
    passed through an MLP ending in a softmax, so the returned weights are
    non-negative and sum to 1 along the last axis.

    Args:
        visual_dim: last-axis size of ``visual``.
        audio_dim: last-axis size of ``audio``.
        kv_dim: last-axis size of ``kv``.
        hidden_dim: width of the first hidden layer.
        num_weights: number of output weights (one per fused stream).

    The defaults reproduce the original fixed 512*3 -> 512 -> 256 -> 128 -> 3
    network with the same layer layout, so existing ``state_dict`` checkpoints
    still load.
    """

    def __init__(self, visual_dim=512, audio_dim=512, kv_dim=512,
                 hidden_dim=512, num_weights=3):
        super(MetaWeightGenerator, self).__init__()
        self.meta_net = nn.Sequential(
            nn.Linear(visual_dim + audio_dim + kv_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, num_weights),
            nn.Softmax(dim=-1),
        )

    def forward(self, visual, audio, kv):
        """Return weights of shape ``(..., num_weights)`` for the given inputs."""
        combined = torch.cat((visual, audio, kv), dim=-1)
        return self.meta_net(combined)


class TAVFD(nn.Module):
    """Fuse audio, visual and text/global features into a per-step score.

    Audio and visual features are projected to 512-d; the positive prompt
    feature from ``LearnablePrompt`` and a global feature are projected to
    512-d and merged by ``tg_proj``. ``MetaWeightGenerator`` plus a learned
    task bias produce three weights that scale the audio, visual and merged
    text/global streams before the output MLP.

    Args:
        feature_dim: last-axis size of the visual and audio inputs.
        gl_dim: last-axis size of the global features.
        hidden_dim: accepted for interface compatibility; not used here.
        num_heads: accepted for interface compatibility; not used here.
        tm_weights: optional 3-element initial task bias for the
            (audio, visual, text/global) weights; defaults to [0.1, 0.1, -0.1].
    """

    def __init__(self, feature_dim=1024, gl_dim=768, hidden_dim=1024, num_heads=8, tm_weights=None):
        super(TAVFD, self).__init__()
        self.feature_dim = feature_dim
        # 768 is the text-feature width of the default CLIP ViT-L/14 prompt encoder.
        self.text_proj = nn.Linear(768, 512)
        self.gl_proj = nn.Linear(gl_dim, 512)
        self.audio_proj = nn.Linear(feature_dim, 512)
        self.visual_proj = nn.Linear(feature_dim, 512)
        # Input is [projected text ; projected global] = 512 + 512, which does
        # not depend on feature_dim.
        self.tg_proj = nn.Linear(512 + 512, 512)
        self.output_fc = nn.Sequential(
            nn.Linear(512 * 3, 512),  # weighted audio + visual + text/global
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.LayerNorm(128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )
        self.text_feature = LearnablePrompt()
        # Default dims (512 per stream) match the projections above.
        self.meta_weight_generator = MetaWeightGenerator()
        # A single learned 3-vector added to the meta weights as a task prior.
        self.task_embedding = nn.Embedding(1, 3)
        init_bias = tm_weights if tm_weights is not None else [0.1, 0.1, -0.1]
        with torch.no_grad():
            self.task_embedding.weight[0] = torch.as_tensor(init_bias, dtype=torch.float32)

    def forward(self, visual_features, audio_features, gl_features=None):
        """Score each time step.

        Args:
            visual_features: ``(B, T, feature_dim)``.
            audio_features: assumed ``(B, T, feature_dim)`` (matching visual).
            gl_features: ``(B, gl_dim)``, ``(B, 1, gl_dim)`` or
                ``(B, T, gl_dim)``; broadcast over the T steps. Required.

        Returns:
            ``(output, tx_features, gl_features, tx_features_neg)`` where
            ``output`` is ``(B, T, 1)`` and the others are the projected
            ``(B, T, 512)`` positive-text, global and negative-text features.

        Raises:
            ValueError: if ``gl_features`` is None.
        """
        if gl_features is None:
            raise ValueError("TAVFD.forward requires gl_features")
        B, T, _ = visual_features.shape

        # Prompt features are (1, D); project once and broadcast to (B, T, 512).
        tx_pos, tx_neg = self.text_feature()
        tx_features = self.text_proj(tx_pos.float()).unsqueeze(0).expand(B, T, -1)
        tx_features_neg = self.text_proj(tx_neg.float()).unsqueeze(0).expand(B, T, -1)

        audio_features = self.audio_proj(audio_features)
        visual_features = self.visual_proj(visual_features)

        gl_features = self.gl_proj(gl_features)
        if gl_features.dim() == 2:
            gl_features = gl_features.unsqueeze(1)
        gl_features = gl_features.expand(-1, T, -1)

        tg = self.tg_proj(torch.cat([tx_features, gl_features], dim=-1))

        weights = self.meta_weight_generator(visual_features, audio_features, tg)
        task_embed = self.task_embedding.weight[0].view(1, 1, 3)
        # NOTE(review): meta_weight_generator already ends in a softmax, so this
        # second softmax acts on probabilities in [0, 1] plus the task bias,
        # which pulls the weights toward uniform — confirm this is intended.
        weights = F.softmax(weights + task_embed, dim=-1)

        fusion_input = torch.cat([
            weights[..., 0:1] * audio_features,
            weights[..., 1:2] * visual_features,
            weights[..., 2:3] * tg,
        ], dim=-1)

        output = self.output_fc(fusion_input)
        return output, tx_features, gl_features, tx_features_neg