import pickle
import numpy as np

def _print_split_fields(split_data):
    """Print each key of a split dict with its type and, when available, its shape or length."""
    print("Keys:", list(split_data.keys()))
    print("\nValue types:")
    for key, value in split_data.items():
        print(f"{key}: {type(value)}")
        if hasattr(value, 'shape'):
            print(f"Shape: {value.shape}")
        elif isinstance(value, list):
            print(f"Length: {len(value)}")


def _print_classification_distribution(cls_labels):
    """Print the count and percentage of every distinct classification label."""
    cls_labels = np.asarray(cls_labels).ravel()
    unique_labels, counts = np.unique(cls_labels, return_counts=True)
    print("\nClassification labels (All classes):")
    # Per the original author's note: 0 = negative, 1 = neutral, 2 = positive
    # (not verifiable from this file -- confirm against the dataset docs).
    # An empty label array yields no rows, so the division below never sees 0.
    for label, count in zip(unique_labels, counts):
        print(f"Label {label}: {count} samples ({count / cls_labels.size * 100:.2f}%)")


def _print_regression_stats(reg_labels):
    """Print min/max/mean/std of a 1-D regression label array (skipped when empty)."""
    print("\nRegression labels:")
    if reg_labels.size == 0:
        # min()/max() raise on an empty array.
        print("No regression labels.")
        return
    print(f"Min: {reg_labels.min():.3f}")
    print(f"Max: {reg_labels.max():.3f}")
    print(f"Mean: {reg_labels.mean():.3f}")
    print(f"Std: {reg_labels.std():.3f}")


def _print_binary_distribution(reg_labels):
    """Print the positive/negative split of the non-zero regression labels."""
    # Boolean mask instead of a Python index loop; this also avoids indexing
    # with an empty float array (IndexError) when every label is zero.
    non_zero = reg_labels[reg_labels != 0]
    positive = int(np.sum(non_zero > 0))
    negative = int(np.sum(non_zero < 0))
    total = non_zero.size

    print("\nBinary sentiment distribution (excluding zeros):")
    if total == 0:
        # Ratios are undefined without any non-zero label.
        print(f"Total non-zero samples: {total}")
        return
    pos_ratio = positive / total * 100
    neg_ratio = negative / total * 100
    print(f"Positive samples (>0): {positive} ({pos_ratio:.2f}%)")
    print(f"Negative samples (<0): {negative} ({neg_ratio:.2f}%)")
    print(f"Total non-zero samples: {total}")
    print(f"Majority class ratio: {max(pos_ratio, neg_ratio):.2f}%")


def _print_seven_class_distribution(reg_labels):
    """Print the distribution over integer classes -3..3 obtained by clipping then rounding."""
    # np.round rounds exact halves to the nearest even value (e.g. 2.5 -> 2).
    rounded = np.round(np.clip(reg_labels, -3.0, 3.0))
    all_possible_labels = np.arange(-3, 4)
    total = reg_labels.size

    print("\n7-class distribution (after clipping to [-3,3] and rounding):")
    if total == 0:
        print("No regression labels.")
        return
    max_ratio = 0.0
    for label in all_possible_labels:
        count = int(np.sum(rounded == label))
        ratio = count / total * 100
        print(f"Label {int(label)}: {count} samples ({ratio:.2f}%)")
        max_ratio = max(max_ratio, ratio)
    print(f"Majority class ratio: {max_ratio:.2f}%")


def _print_samples(split_data, num_samples):
    """Print the first `num_samples` records of a split (skipped if required keys are absent)."""
    sample_keys = ('raw_text', 'id', 'classification_labels', 'regression_labels', 'annotations')
    missing = [key for key in sample_keys if key not in split_data]
    if missing:
        print(f"\nSkipping sample preview; missing keys: {missing}")
        return
    # Flatten so a (N, 1) label array still yields scalars that accept ':.3f'.
    reg_labels = np.asarray(split_data['regression_labels'], dtype=float).ravel()
    print(f"\nFirst {num_samples} samples:")
    for i in range(min(num_samples, len(split_data['raw_text']))):
        print(f"\nSample {i+1}:")
        print(f"Raw text: {split_data['raw_text'][i]}")
        print(f"ID: {split_data['id'][i]}")
        print(f"Classification label: {split_data['classification_labels'][i]}")
        print(f"Regression label: {reg_labels[i]:.3f}")
        print(f"Annotation: {split_data['annotations'][i]}")


def inspect_pkl_structure(pkl_path, num_samples=5):
    """Load a pickled dataset and print its structure, label statistics and sample rows.

    The pickle is expected to be a dict mapping split names to dicts. Splits
    that carry 'classification_labels' and 'regression_labels' get label
    statistics; splits that also carry 'raw_text', 'id' and 'annotations'
    get a preview of their first rows. Anything else is reported and skipped.

    Args:
        pkl_path: Path to the pickle file.
        num_samples: Maximum number of rows to preview per split.
    """
    print(f"\nLoading {pkl_path}...")
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only run this on pickles from a trusted source.
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f)

    print("\n=== Top Level Structure ===")
    if not isinstance(data, dict):
        print(f"Data type: {type(data)}")
        return

    print("Data type: dictionary")
    print("Keys:", list(data.keys()))

    print("\n=== Structure for Each Split ===")
    for split_name, split_data in data.items():
        print(f"\n{split_name} split:")
        if not isinstance(split_data, dict):
            # Only dict-shaped splits carry the fields analysed below.
            print(f"Type: {type(split_data)} (not a dict; skipping)")
            continue
        _print_split_fields(split_data)

        label_keys = ('classification_labels', 'regression_labels')
        missing = [key for key in label_keys if key not in split_data]
        if missing:
            print(f"\nSkipping label distribution; missing keys: {missing}")
        else:
            print("\nLabel Distribution:")
            _print_classification_distribution(split_data['classification_labels'])
            # Normalise once: accepts lists and (N, 1) arrays as well as (N,).
            reg_labels = np.asarray(split_data['regression_labels'], dtype=float).ravel()
            _print_regression_stats(reg_labels)
            _print_binary_distribution(reg_labels)
            _print_seven_class_distribution(reg_labels)

        _print_samples(split_data, num_samples)

        
if __name__ == "__main__":
    import argparse

    # Defaults reproduce the original hard-coded invocation when no CLI args are given.
    parser = argparse.ArgumentParser(
        description="Inspect the structure and label statistics of a dataset pickle."
    )
    parser.add_argument(
        "pkl_path",
        nargs="?",
        default="MOSEI/aligned_50.pkl",
        help="Path to the pickle file (default: MOSEI/aligned_50.pkl).",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=5,
        help="Number of rows to preview per split (default: 5).",
    )
    args = parser.parse_args()
    inspect_pkl_structure(args.pkl_path, num_samples=args.num_samples)