import os
import json
import pickle
import torch
from torch.utils.data import Dataset
import numpy as np


def load_data_list(data_root, fold, split):
    """Load the sample indices that make up one dataset split.

    The test split is read from ``test_set.json``; every other split is read
    from ``fold_{fold}_{split}.json``. Both files live directly under
    ``data_root``. A split file may contain either a JSON object with an
    ``'indices'`` key or a plain JSON list.

    Args:
        data_root: Directory containing the split JSON files.
        fold: Fold identifier interpolated into the per-fold file name.
        split: Split name, e.g. ``'train'``, ``'val'``, ``'cal'`` or ``'test'``.

    Returns:
        A list with the entries read from the split file. It is empty when
        the ``'cal'`` file is missing (a warning is printed instead).

    Raises:
        FileNotFoundError: If a non-``'cal'`` split file does not exist.
        ValueError: If the JSON payload is neither of the supported formats.
    """
    # Previously this list was only assigned for split == 'test', so every
    # other split crashed with UnboundLocalError before reading anything.
    # NOTE(review): the 'cal' warning below mentions a "merge", which suggests
    # some split may once have combined several files - confirm whether e.g.
    # 'train' should also load 'cal'.
    splits_to_load = [split]

    data_list = []
    for s in splits_to_load:
        # The test set is shared across folds; other splits are per-fold.
        json_file = 'test_set.json' if s == 'test' else f'fold_{fold}_{s}.json'
        json_path = os.path.join(data_root, json_file)

        if not os.path.exists(json_path):
            if s == 'cal':
                # A missing calibration split is tolerated (best effort).
                print(f"Warning: Calibration file {json_file} not found. Skipping merge.")
                continue
            raise FileNotFoundError(f"Split file not found: {json_path}")

        with open(json_path, 'r') as f:
            data = json.load(f)

        if isinstance(data, dict) and 'indices' in data:  # AVE format
            data_list.extend(data['indices'])
        elif isinstance(data, list):
            data_list.extend(data)
        else:
            raise ValueError(f"Unknown JSON format in {json_file}")

    return data_list

def compute_class_weights(labels_list, num_classes=None):
    """Compute inverse-class-frequency weights for a collection of labels.

    Each class that occurs in ``labels_list`` gets the weight
    ``total_samples / (num_classes * count + 1e-6)``. Classes that never
    occur keep the default weight of 1.

    Args:
        labels_list: Sized collection of class labels, used as tensor indices.
        num_classes: Length of the returned tensor. When ``None`` it is
            taken as the largest label present plus one.

    Returns:
        A 1-D tensor with ``num_classes`` entries, or ``None`` when
        ``labels_list`` is empty.
    """
    sample_total = len(labels_list)
    if sample_total == 0:
        return None

    present, counts = np.unique(np.array(labels_list), return_counts=True)

    if num_classes is None:
        num_classes = max(present) + 1

    # The small epsilon mirrors a divide-by-zero guard; np.unique never
    # reports a zero count, so it only perturbs the result marginally.
    per_class = sample_total / (num_classes * counts + 1e-6)

    # Start from all ones so that unseen classes stay neutral.
    weights = torch.ones(num_classes)
    for position, cls_idx in enumerate(present):
        weights[int(cls_idx)] = float(per_class[position])

    return weights

class AVEDataset(Dataset):
    """Dataset serving visual/audio feature pairs read from a pickle file.

    The pickle is loaded once into ``self.all_data``; ``self.indices`` holds
    the entries of the requested split and is used to index ``all_data``.
    Each sample must provide ``'visual'``, ``'audio'`` and ``'label'`` keys.

    Args:
        data_root: Directory containing the feature pickle and split files.
        fold: Fold identifier forwarded to ``load_data_list``.
        split: Split name. ``'train'`` additionally enables class-weight
            computation and a small additive Gaussian jitter on the features.
        feature_pkl: File name of the feature pickle inside ``data_root``.
        num_classes: Number of classes used for the class-weight tensor.
            Defaults to 28, the value that was previously hard-coded.

    Raises:
        FileNotFoundError: If the feature pickle does not exist.
    """

    # Standard deviation of the Gaussian jitter added to training features.
    _TRAIN_NOISE_STD = 1e-4

    def __init__(self, data_root, fold=1, split='train',
                 feature_pkl='ave_pooled_features.pkl', num_classes=28):
        self.data_root = data_root
        self.split = split

        pkl_path = os.path.join(data_root, feature_pkl)
        if not os.path.exists(pkl_path):
            raise FileNotFoundError(f"AVE PKL not found: {pkl_path}")
        # SECURITY: pickle.load can execute arbitrary code while unpickling;
        # only point this at feature files from a trusted source.
        with open(pkl_path, 'rb') as f:
            self.all_data = pickle.load(f)

        self.indices = load_data_list(data_root, fold, split)

        # Class weights are only computed for the training split.
        self.class_weights = None
        if split == 'train':
            labels = [self.all_data[i]['label'] for i in self.indices]
            self.class_weights = compute_class_weights(labels, num_classes=num_classes)
            print(f"[AVE] Class Weights: {self.class_weights}")

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the sample at split position ``idx``.

        Returns:
            A dict with ``'view1'`` (visual features, float32), ``'view2'``
            (audio features, float32), ``'target'`` (label, int64) and
            ``'index'`` (the entry of ``self.indices`` used for the lookup).
        """
        global_idx = self.indices[idx]
        sample = self.all_data[global_idx]

        # torch.tensor copies its input, so the in-place jitter below never
        # modifies the features cached in self.all_data.
        visual = torch.tensor(sample['visual'], dtype=torch.float32)
        audio = torch.tensor(sample['audio'], dtype=torch.float32)

        if self.split == 'train':
            visual += torch.randn_like(visual) * self._TRAIN_NOISE_STD
            audio += torch.randn_like(audio) * self._TRAIN_NOISE_STD

        label = torch.tensor(sample['label'], dtype=torch.long)
        return {'view1': visual, 'view2': audio, 'target': label, 'index': global_idx}

