import torch
import json
import os

def _relation_matrix(relations, categories, modality, default_width=100):
    """Build one row per category from a modality's relation mapping.

    Categories that are missing from *relations*, or map to an empty/None value,
    get a zero row and a warning is printed. The zero row's width is taken from
    the first non-empty row (in *categories* order) so it matches the real
    vectors; *default_width* is used only when no category has data at all.

    Args:
        relations: Mapping of category name -> list of numbers (parsed JSON).
        categories: Category names, in the row order of the result.
        modality: Modality label used in warning/error messages.
        default_width: Zero-row width when the modality has no data at all.

    Returns:
        A list of equal-length lists, one per category.

    Raises:
        ValueError: if the non-empty rows do not all have the same length.
    """
    rows = [relations.get(category) or [] for category in categories]
    # Derive the width from real data so fallback rows never make the matrix ragged.
    width = next((len(row) for row in rows if row), default_width)

    matrix = []
    for category, row in zip(categories, rows):
        if not row:
            print(f"Warning: {category} has no {modality} relations, using zero vector.")
            row = [0.0] * width
        elif len(row) != width:
            raise ValueError(
                f"{modality} relations for {category} have length {len(row)}, "
                f"expected {width}"
            )
        matrix.append(row)
    return matrix


def load_intent_concept_relations(config, relations_dir='concepts_relevants'):
    """Load per-intent concept relation matrices for text, audio and video.

    Reads ``text_concepts_relevants.json``, ``audio_concepts_relevants.json`` and
    ``video_concepts_relevants.json`` from *relations_dir*. Each file maps an
    intent category name to a list of numbers (presumably one relevance score
    per concept -- TODO confirm against whatever produces these files). Rows of
    each returned matrix follow the fixed category order below. A category with
    no data gets a zero row whose width matches that modality's other rows.

    Args:
        config: Unused; kept so existing callers keep working.
        relations_dir: Directory holding the three JSON files. Defaults to the
            relative path ``concepts_relevants`` (resolved against the CWD).

    Returns:
        Tuple ``(text, audio, video)`` of float32 tensors, each of shape
        ``(20, width)``, where ``width`` is that modality's vector length
        (100 if the modality has no data at all).

    Raises:
        FileNotFoundError: if one of the JSON files is missing.
        ValueError: if a modality's non-empty vectors differ in length.
    """
    categories = [
        "Complain", "Inform", "Praise", "Apologise", "Thank", "Advise", "Criticize",
        "Arrange", "Introduce", "Care", "Comfort", "Leave", "Prevent", "Taunt",
        "Greet", "Agree", "Flaunt", "Oppose", "Joke", "Ask for help",
    ]
    modalities = ('text', 'audio', 'video')

    # Read every file first so a missing file fails before any matrix work.
    relations_by_modality = {}
    for modality in modalities:
        path = os.path.join(relations_dir, f'{modality}_concepts_relevants.json')
        with open(path, 'r', encoding='utf-8') as f:
            relations_by_modality[modality] = json.load(f)

    return tuple(
        torch.tensor(
            _relation_matrix(relations_by_modality[modality], categories, modality),
            dtype=torch.float32,
        )
        for modality in modalities
    )


def load_concept_features(concepts_path='concepts_features'):
    """Load ``concept_A.pt``, ``concept_T.pt`` and ``concept_V.pt`` from a directory.

    Every path is checked for existence before anything is loaded, so a missing
    file is reported without partially loading the others. Each file is loaded
    with ``map_location='cpu'``.

    Args:
        concepts_path: Directory containing the three ``.pt`` files.

    Returns:
        Tuple ``(concept_A, concept_T, concept_V)`` holding whatever object
        ``torch.load`` returns for each file.

    Raises:
        FileNotFoundError: for the first missing file, checked in A, T, V order.
    """
    # NOTE(review): depending on the torch version / weights_only default,
    # torch.load may unpickle arbitrary objects -- only load trusted files.
    stems = ('concept_A', 'concept_T', 'concept_V')
    file_paths = [os.path.join(concepts_path, f'{stem}.pt') for stem in stems]

    for stem, file_path in zip(stems, file_paths):
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"{stem}.pt not exist: {file_path}")

    loaded = tuple(torch.load(file_path, map_location='cpu') for file_path in file_paths)
    return loaded

