import json
import os
import numpy as np
from collections import defaultdict
import pandas as pd
from irrCAC.table import CAC

def create_contingency_table_2_raters(user1_scores, user2_scores, possible_values=None):
    """
    Cross-tabulate two raters' scores into a square contingency table.

    Rows correspond to user 1's ratings and columns to user 2's ratings.
    Every value in ``possible_values`` gets a row and a column, even when
    neither rater ever used it.

    Args:
        user1_scores: List/array of ratings from user 1
        user2_scores: List/array of ratings from user 2
        possible_values: List of all possible rating values (e.g., [0, 1, 2]).
            When None, the sorted union of the values seen in both inputs
            is used.

    Returns:
        pandas DataFrame of pair counts, indexed by user 1's rating with
        user 2's rating as columns
    """
    categories = possible_values
    if categories is None:
        # Fall back to whatever values the two raters actually used.
        categories = sorted(set(user1_scores).union(user2_scores))

    # Categoricals carry the full category list, so unused ratings survive
    # the cross-tabulation as all-zero rows/columns.
    row_ratings, col_ratings = (
        pd.Categorical(scores, categories=categories)
        for scores in (user1_scores, user2_scores)
    )

    # dropna=False is what keeps the unused categories in the table.
    return pd.crosstab(row_ratings, col_ratings, margins=False, dropna=False)


def gwet_ac1(ratings, possible_values=(0, 1, 2)):
    """
    Compute Gwet's agreement coefficient for two raters.

    The two rating sequences are cross-tabulated into a contingency table
    that contains every value in ``possible_values``, and the table is
    handed to irrCAC's ``CAC(...).gwet()``.

    Args:
        ratings: Indexable of rating sequences; ``ratings[0]`` is rater 1 and
            ``ratings[1]`` is rater 2. Any further entries are ignored.
        possible_values: All rating values on the scale. Defaults to the
            0/1/2 scale that was previously hard-coded here.

    Returns:
        The ``['est']['coefficient_value']`` entry of irrCAC's result.
    """
    table = create_contingency_table_2_raters(
        ratings[0],
        ratings[1],
        list(possible_values),
    )
    gwet = CAC(table).gwet()
    return gwet['est']['coefficient_value']


def detect_data_format(data):
    """
    Classify an annotation list as env_data or conv_data.

    Only the first item is inspected. Each format is recognised by an id
    key, a per-state mapping key, and a value key that must appear in the
    first entry of that mapping:

        env:  'item_id'         + 'state_scores'  containing 'score'
        conv: 'conversation_id' + 'state_ratings' containing 'consistency_rating'

    Args:
        data: List of annotation items

    Returns:
        'env' or 'conv', or None when data is empty or matches neither format
    """
    if not data:
        return None

    sample = data[0]

    # Checked in order, so 'env' wins when an item happens to match both.
    signatures = (
        ('env', 'item_id', 'state_scores', 'score'),
        ('conv', 'conversation_id', 'state_ratings', 'consistency_rating'),
    )

    for label, id_key, states_key, value_key in signatures:
        if id_key not in sample or states_key not in sample:
            continue

        states = sample[states_key]
        if not states:
            # Nothing to inspect; try the next format.
            continue

        first_state = next(iter(states.values()))
        if value_key in first_state:
            return label

    return None

def process_annotation_files(folder_path, data_type=None):
    """
    Process annotation files and calculate overall metrics.
    Automatically detects data format or uses specified type.

    Args:
        folder_path: Path to folder containing user1.json and user2.json
        data_type: 'env' or 'conv' to force a specific format, None for auto-detect

    Returns:
        Dictionary containing overall results

    Raises:
        FileNotFoundError: If user1.json or user2.json is missing; the
            original error is attached as the cause.
        ValueError: If the two files hold a different number of items, the
            format cannot be auto-detected, or data_type is not 'env'/'conv'.
    """
    user1_path = os.path.join(folder_path, 'user1.json')
    user2_path = os.path.join(folder_path, 'user2.json')

    try:
        # JSON is read explicitly as UTF-8 so the result does not depend on
        # the platform's default encoding.
        with open(user1_path, 'r', encoding='utf-8') as f:
            user1_data = json.load(f)

        with open(user2_path, 'r', encoding='utf-8') as f:
            user2_data = json.load(f)
    except FileNotFoundError as e:
        # Chain the original error so the traceback keeps the real cause.
        raise FileNotFoundError(f"Could not find annotation files: {e}") from e

    print(f"Loaded {len(user1_data)} items from user1.json")
    print(f"Loaded {len(user2_data)} items from user2.json")

    if len(user1_data) != len(user2_data):
        raise ValueError("Files have different lengths!")

    # Auto-detect data format if not specified (based on user1's data only)
    if data_type is None:
        data_type = detect_data_format(user1_data)
        if data_type is None:
            raise ValueError("Could not detect data format")

    print(f"Detected data format: {data_type}")

    # Process based on data type
    if data_type == 'env':
        return process_env_data(user1_data, user2_data)
    if data_type == 'conv':
        return process_conv_data(user1_data, user2_data)
    raise ValueError(f"Unknown data type: {data_type}")

def process_env_data(user1_data, user2_data):
    """
    Process env_data format (item_id, state_scores with score).

    Items are matched across the two annotators by 'item_id'. Within each
    matched item, a state contributes one rating pair only when both
    annotators gave it a non-None 'score'. Pairs are collected in the order
    items and states appear in user1's data, so the output is deterministic
    from run to run.

    Args:
        user1_data: List of item dicts from annotator 1
        user2_data: List of item dicts from annotator 2

    Returns:
        The results dictionary produced by calculate_metrics for data type 'env'
    """
    # Create dictionaries for easier lookup
    user1_dict = {item['item_id']: item for item in user1_data}
    user2_dict = {item['item_id']: item for item in user2_data}

    # Collect all annotation pairs
    all_user1_scores = []
    all_user2_scores = []
    all_user1_rescaled = []
    all_user2_rescaled = []
    annotation_details = []

    for item_id, user1_item in user1_dict.items():
        if item_id not in user2_dict:
            print(f"Warning: item_id {item_id} not found in user2 data")
            continue

        user2_states = user2_dict[item_id]['state_scores']

        # Only states rated by both users can form a pair, so walking user1's
        # states (in file order) covers every candidate without going through
        # an unordered set of keys.
        for state_key, user1_state in user1_item['state_scores'].items():
            user1_score = user1_state.get('score')
            user2_score = user2_states.get(state_key, {}).get('score')

            if user1_score is None or user2_score is None:
                continue

            # Original scores (the rest of this file treats them as a 0-2 scale)
            all_user1_scores.append(user1_score)
            all_user2_scores.append(user2_score)

            # Rescaled scores: halved, which maps a 0-2 scale onto [0, 1]
            user1_rescaled = user1_score / 2.0
            user2_rescaled = user2_score / 2.0
            all_user1_rescaled.append(user1_rescaled)
            all_user2_rescaled.append(user2_rescaled)

            annotation_details.append({
                'item_id': item_id,
                'state_key': state_key,
                'user1_original': user1_score,
                'user2_original': user2_score,
                'user1_rescaled': user1_rescaled,
                'user2_rescaled': user2_rescaled
            })

    # Report the reverse mismatch too, so dropped items are never silent.
    for item_id in user2_dict:
        if item_id not in user1_dict:
            print(f"Warning: item_id {item_id} not found in user1 data")

    return calculate_metrics(all_user1_scores, all_user2_scores, all_user1_rescaled, all_user2_rescaled, annotation_details, 'env')

def process_conv_data(user1_data, user2_data):
    """
    Process conv_data format (conversation_id, state_ratings with consistency_rating).

    Conversations are matched across the two annotators by
    'conversation_id'. Within each matched conversation, a state contributes
    one rating pair only when both annotators gave it a non-None
    'consistency_rating'. Pairs are collected in the order conversations and
    states appear in user1's data, so the output is deterministic from run
    to run.

    Args:
        user1_data: List of conversation dicts from annotator 1
        user2_data: List of conversation dicts from annotator 2

    Returns:
        The results dictionary produced by calculate_metrics for data type 'conv'
    """
    # Create dictionaries for easier lookup
    user1_dict = {item['conversation_id']: item for item in user1_data}
    user2_dict = {item['conversation_id']: item for item in user2_data}

    # Collect all annotation pairs
    all_user1_scores = []
    all_user2_scores = []
    all_user1_rescaled = []
    all_user2_rescaled = []
    annotation_details = []

    for conv_id, user1_item in user1_dict.items():
        if conv_id not in user2_dict:
            print(f"Warning: conversation_id {conv_id} not found in user2 data")
            continue

        user2_states = user2_dict[conv_id]['state_ratings']

        # Only states rated by both users can form a pair, so walking user1's
        # states (in file order) covers every candidate without going through
        # an unordered set of keys.
        for state_key, user1_state in user1_item['state_ratings'].items():
            user1_rating = user1_state.get('consistency_rating')
            user2_rating = user2_states.get(state_key, {}).get('consistency_rating')

            if user1_rating is None or user2_rating is None:
                continue

            # Original ratings (assumed to share the 0-2 scale of env_data)
            all_user1_scores.append(user1_rating)
            all_user2_scores.append(user2_rating)

            # Rescaled ratings: halved, which maps a 0-2 scale onto [0, 1]
            user1_rescaled = user1_rating / 2.0
            user2_rescaled = user2_rating / 2.0
            all_user1_rescaled.append(user1_rescaled)
            all_user2_rescaled.append(user2_rescaled)

            annotation_details.append({
                'conversation_id': conv_id,
                'state_key': state_key,
                'user1_original': user1_rating,
                'user2_original': user2_rating,
                'user1_rescaled': user1_rescaled,
                'user2_rescaled': user2_rescaled
            })

    # Report the reverse mismatch too, so dropped conversations are never silent.
    for conv_id in user2_dict:
        if conv_id not in user1_dict:
            print(f"Warning: conversation_id {conv_id} not found in user1 data")

    return calculate_metrics(all_user1_scores, all_user2_scores, all_user1_rescaled, all_user2_rescaled, annotation_details, 'conv')

def _summarize_scores(scores):
    """Return mean, std and a value -> count histogram for one rater's scores."""
    values, counts = np.unique(scores, return_counts=True)
    return {
        'mean': float(np.mean(scores)),
        'std': float(np.std(scores)),
        'distribution': {int(k): int(v) for k, v in zip(values, counts)}
    }


def calculate_metrics(all_user1_scores, all_user2_scores, all_user1_rescaled, all_user2_rescaled, annotation_details, data_type):
    """
    Calculate overall metrics for the annotations.

    Args:
        all_user1_scores: Original ratings from user 1, one per paired annotation
        all_user2_scores: Original ratings from user 2, aligned with user 1's
        all_user1_rescaled: User 1's ratings after rescaling
        all_user2_rescaled: User 2's ratings after rescaling
        annotation_details: Per-pair detail records, passed through unchanged
        data_type: Format label stored in the results (e.g. 'env' or 'conv')

    Returns:
        Dictionary with keys 'data_type', 'total_annotations',
        'rescaled_scores', 'agreement_metrics', 'score_distribution' and
        'annotation_details'. Averages, rates and counts computed here are
        plain Python floats/ints; 'gwet_ac1' is whatever gwet_ac1 returns.

    Raises:
        ValueError: If there are no paired annotations to evaluate.
    """
    total_annotations = len(all_user1_scores)
    if total_annotations == 0:
        # Without this guard the means become NaN and the agreement rate is 0/0.
        raise ValueError(
            "No paired annotations to evaluate: the two annotators share no rated states"
        )

    # Convert to numpy arrays
    user1_scores = np.array(all_user1_scores)
    user2_scores = np.array(all_user2_scores)
    user1_rescaled = np.array(all_user1_rescaled)
    user2_rescaled = np.array(all_user2_rescaled)

    # Average scores (rescaled)
    avg_user1_rescaled = float(np.mean(user1_rescaled))
    avg_user2_rescaled = float(np.mean(user2_rescaled))
    overall_avg_rescaled = float(np.mean(np.concatenate([user1_rescaled, user2_rescaled])))

    # Calculate Gwet's AC1 using original scores
    ratings_matrix = np.array([user1_scores, user2_scores])
    ac1_coefficient = gwet_ac1(ratings_matrix)

    # Mean absolute gap between the two raters on the rescaled scale
    mean_absolute_difference = float(np.mean(np.abs(user1_rescaled - user2_rescaled)))

    # Count exact agreements on the original scale
    exact_agreements = int(np.sum(user1_scores == user2_scores))
    agreement_rate = exact_agreements / total_annotations

    return {
        'data_type': data_type,
        'total_annotations': total_annotations,
        'rescaled_scores': {
            'avg_user1': avg_user1_rescaled,
            'avg_user2': avg_user2_rescaled,
            'overall_average': overall_avg_rescaled
        },
        'agreement_metrics': {
            'gwet_ac1': ac1_coefficient,
            'exact_agreement_rate': agreement_rate,
            'exact_agreements': exact_agreements,
            'mean_absolute_difference': mean_absolute_difference
        },
        'score_distribution': {
            'user1_original': _summarize_scores(user1_scores),
            'user2_original': _summarize_scores(user2_scores)
        },
        'annotation_details': annotation_details
    }

def _interpret_agreement(coefficient):
    """Map an agreement coefficient to the qualitative label used in the report."""
    if coefficient < 0:
        return "Poor (worse than chance)"

    # Each label applies to coefficients strictly below its upper bound.
    bands = (
        (0.2, "Slight"),
        (0.4, "Fair"),
        (0.6, "Moderate"),
        (0.8, "Substantial"),
    )
    for upper_bound, label in bands:
        if coefficient < upper_bound:
            return label
    return "Almost perfect"


def print_results(results):
    """
    Print formatted results.

    Args:
        results: Results dictionary with 'data_type', 'total_annotations',
            'rescaled_scores', 'agreement_metrics' and 'score_distribution'
            entries, as read below.

    A 'gwet_ac1' value of None or NaN is reported as "Could not calculate".
    NaN needs the explicit check because it fails every threshold comparison
    and would otherwise be labelled "Almost perfect".
    """
    rescaled = results['rescaled_scores']
    agreement = results['agreement_metrics']
    distributions = results['score_distribution']

    print("\n" + "="*80)
    print(f"OVERALL ANNOTATION ANALYSIS RESULTS ({results['data_type'].upper()}_DATA)")
    print("="*80)

    print(f"\nTotal annotations: {results['total_annotations']}")

    print("\nRescaled Scores [0,1]:")
    print("-" * 40)
    print(f"Average score (User 1): {rescaled['avg_user1']:.3f}")
    print(f"Average score (User 2): {rescaled['avg_user2']:.3f}")
    print(f"Overall average: {rescaled['overall_average']:.3f}")

    print("\nInter-Annotator Agreement:")
    print("-" * 40)
    ac1_val = agreement['gwet_ac1']
    if ac1_val is None or np.isnan(ac1_val):
        print("Gwet's AC1 coefficient: Could not calculate")
    else:
        print(f"Gwet's AC1 coefficient: {ac1_val:.3f}")
        print(f"Agreement interpretation: {_interpret_agreement(ac1_val)}")

    print(f"Exact agreement rate: {agreement['exact_agreement_rate']:.3f}")
    print(f"Exact agreements: {agreement['exact_agreements']}/{results['total_annotations']}")
    print(f"Mean absolute difference (rescaled): {agreement['mean_absolute_difference']:.3f}")

    print("\nScore Distributions (Original [0,2]):")
    print("-" * 40)
    print("User 1:", distributions['user1_original']['distribution'])
    print("User 2:", distributions['user2_original']['distribution'])

# Example usage
if __name__ == "__main__":
    # The annotation output directory is read from config.json in the
    # current working directory.
    with open("config.json") as config_file:
        config = json.load(config_file)

    output_dir = config["annotation_config"]["output_dir"]

    # Each sub-folder is expected to hold user1.json and user2.json.
    for sub_folder in ("env_data", "conv_data"):
        folder_path = os.path.join(output_dir, sub_folder)

        print(f"\n\n\n\n\nCOMPUTING META EVALUATION SCORES FOR {folder_path}")

        # A failure in one folder is reported and must not stop the other.
        try:
            # The data format (env vs conv) is auto-detected from the files.
            print_results(process_annotation_files(folder_path))
        except Exception as e:
            print(f"Error processing files: {e}")

