import torch
import torch.nn as nn
import torch.nn.functional as F
from ..SubNets.FeatureNets import BERTEncoder
from ..SubNets.transformers_encoder.transformer import TransformerEncoder
from data import benchmarks
import json
def _build_relation_matrix(relations, categories, modality, default_dim=100):
    """Stack one relation vector per category into a float32 matrix.

    The row width is the length of the first non-empty vector found for any of
    ``categories``; if no category has one, ``default_dim`` is used. Categories
    that are missing from ``relations`` or map to an empty/null value get an
    all-zero row and a printed warning naming ``modality``.

    Raises:
        ValueError: if two non-empty vectors have different lengths (a ragged
            matrix cannot be turned into a tensor).
    """
    vectors = [relations.get(category) for category in categories]
    # Infer the width from the data instead of assuming it, so zero rows always
    # match the real vectors.
    width = next((len(vector) for vector in vectors if vector), default_dim)

    rows = []
    for category, vector in zip(categories, vectors):
        if not vector:
            print(f"Warning: {category} has no {modality} relations, using zero vector.")
            rows.append([0.0] * width)
            continue
        if len(vector) != width:
            raise ValueError(
                f"{modality} relations for {category!r} have length {len(vector)}, "
                f"expected {width}"
            )
        rows.append(vector)
    return torch.tensor(rows, dtype=torch.float32)


def load_intent_concept_relations(config, default_dim=100):
    """Load per-intent concept-relation matrices for the text, audio and video modalities.

    ``config.weight_path['paths']`` must hold at least three JSON file paths
    (text, audio, video, in that order). Each file is parsed as a mapping from
    intent label to a list of numbers. For every label in
    ``benchmarks[config.dataset]['intent_labels']`` the mapped vector becomes
    one row of that modality's matrix. A label that is missing or maps to an
    empty list gets an all-zero row of the same width, and a warning is printed.

    Args:
        config: object exposing ``weight_path`` (a dict with a ``'paths'``
            sequence) and ``dataset`` (a key into ``benchmarks``).
        default_dim: width of the zero rows when a modality has no non-empty
            vector to infer the width from (defaults to the previous
            hard-coded 100).

    Returns:
        Tuple ``(text, audio, video)`` of float32 tensors, each of shape
        ``(num_intent_labels, width)``.

    Raises:
        IndexError: if fewer than three paths are configured.
        ValueError: if the non-empty vectors of one modality differ in length.
    """
    paths = config.weight_path['paths']
    modalities = ('text', 'audio', 'video')

    loaded = []
    for index, _modality in enumerate(modalities):
        # JSON is UTF-8 by specification; don't depend on the locale encoding.
        with open(paths[index], 'r', encoding='utf-8') as f:
            loaded.append(json.load(f))

    categories = benchmarks[config.dataset]['intent_labels']

    return tuple(
        _build_relation_matrix(relations, categories, modality, default_dim)
        for relations, modality in zip(loaded, modalities)
    )

class ffn(nn.Module):
    """Linear projection followed by dropout.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        dropout: dropout probability applied after the projection.
    """

    def __init__(self, input_dim, output_dim, dropout=0.1):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable so
        # existing checkpoints still load.
        self.linear1 = nn.Linear(input_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        projected = self.linear1(x)
        return self.dropout(projected)
    

class ConceptWeightPredictor(nn.Module):
    """Score each concept with a shared MLP and squash the scores into [-1, 1].

    The MLP maps the last input dimension (``feature_dim``) to one scalar. The
    trailing singleton axis is dropped, and the resulting scores are
    layer-normalised over their last axis (size ``num_concepts``, so the input's
    second-to-last dimension must equal ``num_concepts``). Finally they pass
    through ``tanh``.

    Args:
        feature_dim: size of the last input dimension.
        num_concepts: size of the axis normalised by ``LayerNorm``.
        hidden_dim: width of the first hidden layer; the second is half of it.
    """

    def __init__(self, feature_dim=1024, num_concepts=100, hidden_dim=512):
        super().__init__()
        half_hidden = hidden_dim // 2
        # Layer order (and hence Sequential indices / state_dict keys) is fixed.
        self.weight_net = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, half_hidden),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(half_hidden, 1),
        )
        self.layer_norm = nn.LayerNorm(num_concepts)
        self.activation = nn.Tanh()

        def _xavier_init(module):
            # Xavier-uniform weights and zero biases for every Linear layer.
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.zeros_(module.bias)

        self.apply(_xavier_init)

    def forward(self, features):
        # [..., num_concepts, feature_dim] -> [..., num_concepts, 1] -> [..., num_concepts]
        raw_scores = self.weight_net(features)
        scores = raw_scores.squeeze(-1)
        normalized = self.layer_norm(scores)
        return self.activation(normalized)


class ConceptFusionModel(nn.Module):
    """Multimodal intent classifier built around learnable per-modality concept banks.

    Pipeline, as implemented in ``forward``:

    * Text goes through ``bertmodel`` (a ``BERTEncoder``) and ``enc_t``, then is
      mean-pooled over dim 1.
    * Audio and video go through ``enc_a`` / ``enc_v`` and are mean-pooled over
      dim 1. They are then projected to ``args.text_feat_dim`` by
      ``align_a`` / ``align_v``.
    * Each pooled vector is compared with its own concept bank (cosine
      similarity per concept). The similarities are scaled by weights from a
      ``ConceptWeightPredictor`` and fed to ``score_classifier``.
    * The concept-wise feature/concept products are gated by ``sigmoid`` of
      those weights and concatenated across modalities. They are mixed by
      ``tsfm``, mean-pooled and fed to ``classifier``.
    * The logits of the two heads are summed.

    Attributes read from ``args`` here: ``text_backbone_path``, ``device``,
    ``t_concept_num``, ``a_concept_num``, ``v_concept_num``, ``text_feat_dim``,
    ``audio_feat_dim``, ``video_feat_dim``, ``num_head``, ``t_num_layers``,
    ``a_num_layers``, ``v_num_layers``, ``concept_num_layers``,
    ``attn_dropout``, ``relu_dropout``, ``embed_dropout``, ``use_mask``,
    ``final_dropout`` and ``num_labels``. ``load_intent_concept_relations``
    additionally reads ``weight_path`` and ``dataset``.
    """

    def __init__(self, args):
        super(ConceptFusionModel, self).__init__()
        # NOTE(review): prints the backbone path on every construction; looks
        # like leftover debug output.
        print(args.text_backbone_path)
        # Only the BERT backbone is moved to args.device here; presumably the
        # caller moves the whole model afterwards (TODO confirm).
        self.bertmodel = BERTEncoder.from_pretrained(args.text_backbone_path).to(args.device)
        # Concept banks: one learnable (num_concepts, text_feat_dim) matrix per
        # modality, drawn from N(0, 1). All live in the text feature space.
        # Creation order (T, V, A) fixes the RNG draws; reordering these lines
        # changes the initialisation.
        self.concept_T = nn.Parameter(torch.randn(args.t_concept_num, args.text_feat_dim))
        self.concept_V = nn.Parameter(torch.randn(args.v_concept_num, args.text_feat_dim))
        self.concept_A = nn.Parameter(torch.randn(args.a_concept_num, args.text_feat_dim))
        # Project pooled audio / video features into the text feature space.
        self.align_a = ffn(args.audio_feat_dim, args.text_feat_dim)
        self.align_v = ffn(args.video_feat_dim, args.text_feat_dim)

        # Per-modality sequence encoders. enc_a / enc_v run at the raw
        # audio/video widths, i.e. before the align_* projections.
        self.enc_t = TransformerEncoder(
            embed_dim=args.text_feat_dim,
            num_heads=args.num_head,
            layers=args.t_num_layers,
            attn_dropout=args.attn_dropout,
            relu_dropout=args.relu_dropout,
            embed_dropout=args.embed_dropout,
            attn_mask=args.use_mask
        )

        self.enc_v = TransformerEncoder(
            embed_dim=args.video_feat_dim,
            num_heads=args.num_head,
            layers=args.v_num_layers,
            attn_dropout=args.attn_dropout,
            relu_dropout=args.relu_dropout,
            embed_dropout=args.embed_dropout,
            attn_mask=args.use_mask
        )
        self.enc_a = TransformerEncoder(
            embed_dim=args.audio_feat_dim,
            num_heads=args.num_head,
            layers=args.a_num_layers,
            attn_dropout=args.attn_dropout,
            relu_dropout=args.relu_dropout,
            embed_dropout=args.embed_dropout,
            attn_mask=args.use_mask
        )
        # Mixes the concatenated (text + audio + video) concept tokens.
        self.tsfm = TransformerEncoder(
            embed_dim=args.text_feat_dim,
            num_heads=args.num_head,
            layers=args.concept_num_layers,
            attn_dropout=args.attn_dropout,
            relu_dropout=args.relu_dropout,
            embed_dropout=args.embed_dropout,
            attn_mask=False
        )
        # NOTE(review): presumably disables positional embeddings inside
        # TransformerEncoder so concept tokens are treated as an unordered set.
        # This depends on TransformerEncoder internals; confirm.
        self.tsfm.embed_positions = None

        self.final_dropout = nn.Dropout(args.final_dropout)
        # Head over the pooled output of tsfm.
        self.classifier = nn.Linear(args.text_feat_dim, args.num_labels)
        self.t_concept_num = args.t_concept_num
        self.a_concept_num = args.a_concept_num
        self.v_concept_num = args.v_concept_num
        self.total_concept_num = self.t_concept_num + self.a_concept_num + self.v_concept_num
        # Head over the concatenated, re-weighted per-concept similarity scores.
        self.score_classifier = nn.Linear(self.total_concept_num, args.num_labels)

        # One weight predictor per modality. Each maps every concept's
        # text_feat_dim-wide product vector to a scalar in [-1, 1].
        self.text_weight_predictor = ConceptWeightPredictor(args.text_feat_dim, self.t_concept_num)
        self.audio_weight_predictor = ConceptWeightPredictor(args.text_feat_dim, self.a_concept_num)
        self.video_weight_predictor = ConceptWeightPredictor(args.text_feat_dim, self.v_concept_num)

        text_relations, audio_relations, video_relations = load_intent_concept_relations(args)

        # Intent-to-concept relation matrices (one row per intent label).
        # Buffers follow .to() and are saved in state_dict, but are not trained.
        # forward() does not read them; presumably an external loss compares
        # them with the predicted concept weights (TODO confirm).
        self.register_buffer('target_text_weights', text_relations)
        self.register_buffer('target_audio_weights', audio_relations)
        self.register_buffer('target_video_weights', video_relations)

    def forward(self, t_feats, a_feats, v_feats, labels=None, return_weights=False):
        """Compute intent logits for one batch of text, audio and video inputs.

        Args:
            t_feats: input for ``self.bertmodel``. Its expected format is
                defined by ``BERTEncoder``, which is not visible in this file.
            a_feats: audio features. The encoder is built with
                ``embed_dim=audio_feat_dim``; presumably the layout is
                (batch, seq, audio_feat_dim) — TODO confirm.
            v_feats: video features, with the same layout assumption as
                ``a_feats`` (width ``video_feat_dim``).
            labels: accepted but not used by this method.
            return_weights: when True, the raw per-concept similarity scores
                are appended to the returned tuple.

        Returns:
            ``(logits, concept_weights_t, concept_weights_a, concept_weights_v)``,
            or that tuple followed by ``(score_t, score_a, score_v)`` when
            ``return_weights`` is True.
        """
        t_feats = self.bertmodel(t_feats).last_hidden_state
        # transpose(0, 1) presumably switches to the sequence-first layout that
        # TransformerEncoder expects, and back. The result is mean-pooled over
        # dim 1 with no padding mask applied here.
        t_feats = self.enc_t(t_feats.transpose(0,1)).transpose(0,1).mean(dim=1)
        a_feats = self.enc_a(a_feats.transpose(0,1)).transpose(0,1).mean(dim=1)
        v_feats = self.enc_v(v_feats.transpose(0,1)).transpose(0,1).mean(dim=1)
        # Project pooled audio/video into the text feature space.
        a_feats = self.align_a(a_feats)
        v_feats = self.align_v(v_feats)


        # Softmax over dim 0 runs across the concatenated T/A/V banks. For each
        # feature column, the values over all concepts of all modalities sum to
        # 1 before the banks are split back per modality.
        total_concepts = torch.cat((self.concept_T, self.concept_A, self.concept_V), dim=0)
        total_concepts = torch.softmax(total_concepts, dim=0)
        concept_T = total_concepts[:self.t_concept_num, :]
        concept_A = total_concepts[self.t_concept_num:self.t_concept_num + self.a_concept_num, :]
        concept_V = total_concepts[self.t_concept_num + self.a_concept_num:, :]

        # Element-wise products of the L2-normalised pooled feature and each
        # L2-normalised concept, broadcast over concepts. This gives
        # (batch, num_concepts, text_feat_dim) if the pooled features are
        # (batch, text_feat_dim).
        text_concepts = F.normalize(t_feats.unsqueeze(1), dim=2) * F.normalize(concept_T.unsqueeze(0), dim=2)
        audio_concepts = F.normalize(a_feats.unsqueeze(1), dim=2) * F.normalize(concept_A.unsqueeze(0), dim=2)
        video_concepts = F.normalize(v_feats.unsqueeze(1), dim=2) * F.normalize(concept_V.unsqueeze(0), dim=2)

        # Summing the unit-vector products gives the cosine similarity to each concept.
        score_t = text_concepts.sum(dim=2)
        score_a = audio_concepts.sum(dim=2)
        score_v = video_concepts.sum(dim=2)


        # Per-concept weights in [-1, 1] (ConceptWeightPredictor ends with Tanh).
        concept_weights_t = self.text_weight_predictor(text_concepts)
        concept_weights_a = self.audio_weight_predictor(audio_concepts)
        concept_weights_v = self.video_weight_predictor(video_concepts)


        weighted_score_t = score_t * concept_weights_t
        weighted_score_a = score_a * concept_weights_a
        weighted_score_v = score_v * concept_weights_v


        # Head 1: linear classifier over all re-weighted concept scores.
        total_score = torch.cat((weighted_score_t, weighted_score_a, weighted_score_v), dim=1)
        logits = self.score_classifier(total_score)

        # Gate each concept's product vector with sigmoid(weight), a value in (0, 1).
        text_concepts = text_concepts * torch.sigmoid(concept_weights_t).unsqueeze(2)
        audio_concepts = audio_concepts * torch.sigmoid(concept_weights_a).unsqueeze(2)
        video_concepts = video_concepts * torch.sigmoid(concept_weights_v).unsqueeze(2)
        total_concepts = torch.cat((text_concepts, audio_concepts, video_concepts), dim=1)

        # Head 2: tsfm gets the same transpose(0, 1) treatment as the encoders
        # above. Its output is mean-pooled over the concept axis.
        total_concepts = total_concepts.transpose(0, 1)
        total_concepts = self.tsfm(total_concepts)
        total_concepts = total_concepts.transpose(0, 1)
        total_concepts = total_concepts.mean(dim=1)

        # Final logits are the sum of both heads.
        logits += self.classifier(self.final_dropout(total_concepts))

        if return_weights:
            return logits, concept_weights_t, concept_weights_a, concept_weights_v, score_t, score_a, score_v
        else:
            return logits, concept_weights_t, concept_weights_a, concept_weights_v
    def _get_main_params(self):
        """Return the parameters that are in neither the concept group nor the weight-predictor group.

        A parameter is excluded if it belongs to one of the three weight
        predictors, or if its qualified name contains the substring
        ``'concept'`` (the same test ``_get_concept_params`` uses). The result
        follows ``self.parameters()`` order. Presumably it is used to build a
        separate optimizer parameter group — TODO confirm with the caller.
        """
        main_params = []

        # Identity (id) sets, so membership tests don't compare tensor values.
        weight_predictor_param_ids = set()
        for param in self.text_weight_predictor.parameters():
            weight_predictor_param_ids.add(id(param))
        for param in self.audio_weight_predictor.parameters():
            weight_predictor_param_ids.add(id(param))
        for param in self.video_weight_predictor.parameters():
            weight_predictor_param_ids.add(id(param))

        concept_param_ids = set()
        for name, param in self.named_parameters():
            if 'concept' in name:
                concept_param_ids.add(id(param))

        for param in self.parameters():
            if id(param) not in weight_predictor_param_ids and id(param) not in concept_param_ids:
                main_params.append(param)

        return main_params

    def _get_concept_params(self):
        """Return every parameter whose qualified name contains ``'concept'``.

        This is a substring match. It covers ``concept_T``/``concept_A``/``concept_V``,
        but would also pick up any other parameter with that substring in its name.
        """
        concept_params = []
        for name, param in self.named_parameters():
            if 'concept' in name:
                concept_params.append(param)
        return concept_params

    def _get_weight_predictor_params(self):
        """Return the parameters of the text, audio and video weight predictors, in that order."""
        return (list(self.text_weight_predictor.parameters()) +
                list(self.audio_weight_predictor.parameters()) +
                list(self.video_weight_predictor.parameters()))