#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Generate evaluation result statistics tables

import os
import json
import re
import pandas as pd
import numpy as np
from collections import Counter
from typing import List, Dict, Union, Tuple
import logging
from tqdm import tqdm

from dataset import BaseDataset, SNLIDataset, MTBenchDataset, SummEvalDataset

# Configure logging
# basicConfig sets up the root logger: records at INFO level and above are emitted,
# each formatted as "<timestamp> - <logger name> - <level> - <message>".
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-wide logger used by EvaluationSummary for progress, warning and error messages.
logger = logging.getLogger("EvaluationSummary")

class EvaluationSummary:
    """Evaluation results summary class"""

    def __init__(self, 
                base_path: str = os.path.dirname(os.path.abspath(__file__)), 
                raw_results_path: str = None,
                metrics_output_path: str = None):
        """
        Initialize evaluation results summary class
        
        Args:
            base_path: Base path
            raw_results_path: Raw evaluation results path
            metrics_output_path: Metrics output path
        """
        self.base_path = base_path
        
        # Set paths
        if raw_results_path is None:
            self.raw_results_path = os.path.join(base_path, "evaluation_results", "raw")
        else:
            self.raw_results_path = raw_results_path
            
        if metrics_output_path is None:
            self.metrics_output_path = os.path.join(base_path, "evaluation_results", "metrics")
        else:
            self.metrics_output_path = metrics_output_path
        
        # Ensure output directory exists
        os.makedirs(self.metrics_output_path, exist_ok=True)
        
        # Supported datasets
        self.dataset_mapping = {
            "snli": lambda: SNLIDataset(file_path=os.path.join(base_path, "dataset", "test", "snli_test.jsonl"), name="snli"),
            "multinli": lambda: SNLIDataset(file_path=os.path.join(base_path, "dataset", "test", "multinli_test.jsonl"), name="multinli"),
            "mtbench": lambda: MTBenchDataset(file_path=os.path.join(base_path, "dataset", "test", "mt_bench_test.jsonl"), name="mtbench"),
            "summeval": lambda: SummEvalDataset(file_path=os.path.join(base_path, "dataset", "test", "summeval_test.jsonl"), name="summeval")
        }

    def parse_qwen_lora_filename(self, filename: str) -> Dict[str, str]:
        """
        Parse QwenLora model filename to extract epsilon, alpha, and epoch parameters
        
        Args:
            filename: Filename (e.g. qwen2.5-7b_eps_0.1_alpha_0.2_epoch_1.json)
            
        Returns:
            Dictionary containing parameters {'epsilon': '0.1', 'alpha': '0.2', 'epoch': '1'}
        """
        params = {}
        
        # Extract model_name (e.g. qwen2.5-7b)
        # Find index of _eps then extract substring before it
        model_name_match = filename.find('_eps')
        if model_name_match != -1:
            model_name = filename[:model_name_match]
            params['model_name'] = model_name
        else:
            params['model_name'] = 'unknown'

        # Extract epsilon (e.g. eps_0_1)
        eps_match = re.search(r'eps_(\d+\.\d+)', filename)
        if eps_match:
            eps_str = eps_match.group(1).replace('_', '.')
            params['epsilon'] = eps_str
        else:
            params['epsilon'] = 'unknown'
        
        # Extract alpha (e.g. alpha_0_2)
        alpha_match = re.search(r'alpha_(\d+\.\d+)', filename)
        if alpha_match:
            alpha_str = alpha_match.group(1).replace('_', '.')
            params['alpha'] = alpha_str
        else:
            params['alpha'] = 'unknown'
        
        # Extract epoch (e.g. epoch_1)
        epoch_match = re.search(r'epoch_(\d+)', filename)
        if epoch_match:
            params['epoch'] = epoch_match.group(1)
        else:
            params['epoch'] = 'unknown'  # Default value
            
        return params

    def calculate_accuracy(self, model_predictions: Dict[str, Dict[str, float]], 
                          dataset: BaseDataset) -> float:
        """
        Calculate accuracy
        
        Args:
            model_predictions: Model prediction results {sample ID: {label: probability}}
            dataset: Dataset object
            
        Returns:
            Accuracy
        """
        correct = 0
        total = 0
        
        for item in dataset:
            id_key = str(item[dataset.id_key])
            if id_key in model_predictions:
                gold_label = dataset.get_gold_label(item)
                if gold_label == "-" or gold_label == "":  # Skip samples without clear labels
                    continue
                    
                pred_dist = model_predictions[id_key]
                # Find label with highest probability
                max_prob_label = max(pred_dist.items(), key=lambda x: x[1])[0]
                if max_prob_label.lower() == gold_label.lower():
                    correct += 1
                total += 1
                
        return correct / total if total > 0 else 0.0

    def calculate_kl_divergence(self, pred_dist: Dict[str, float], 
                                true_dist: Dict[str, float]) -> float:
        """
        Calculate KL divergence
        
        Args:
            pred_dist: Predicted distribution {label: probability}
            true_dist: True distribution {label: probability}
            
        Returns:
            KL divergence value
        """
        all_labels = set(true_dist.keys()).union(set(pred_dist.keys()))
        epsilon = 1e-10  # Prevent division by zero
        
        # Ensure all labels exist in both distributions
        p_dist = {label: true_dist.get(label, epsilon) for label in all_labels}
        q_dist = {label: pred_dist.get(label, epsilon) for label in all_labels}
        
        # Normalize distributions
        p_sum = sum(p_dist.values())
        if p_sum > 0:
            p_dist = {k: v/p_sum for k, v in p_dist.items()}
            
        q_sum = sum(q_dist.values())
        if q_sum > 0:
            q_dist = {k: v/q_sum for k, v in q_dist.items()}
        
        # Calculate KL divergence
        kl = 0.0
        for label in all_labels:
            p = p_dist[label]
            q = q_dist[label]
            if p > epsilon:  # Only calculate when p>0
                kl += p * np.log(p / q)
                
        return kl

    def calculate_js_divergence(self, pred_dist: Dict[str, float], 
                               true_dist: Dict[str, float]) -> float:
        """
        Calculate JS divergence
        
        Args:
            pred_dist: Predicted distribution {label: probability}
            true_dist: True distribution {label: probability}
            
        Returns:
            JS divergence value
        """
        all_labels = set(true_dist.keys()).union(set(pred_dist.keys()))
        epsilon = 1e-10  # Prevent division by zero
        
        # Ensure all labels exist in both distributions
        p_dist = {label: true_dist.get(label, epsilon) for label in all_labels}
        q_dist = {label: pred_dist.get(label, epsilon) for label in all_labels}
        
        # Normalize distributions
        p_sum = sum(p_dist.values())
        if p_sum > 0:
            p_dist = {k: v/p_sum for k, v in p_dist.items()}
            
        q_sum = sum(q_dist.values())
        if q_sum > 0:
            q_dist = {k: v/q_sum for k, v in q_dist.items()}
        
        # Calculate intermediate distribution M = (P + Q) / 2
        m_dist = {label: (p_dist[label] + q_dist[label]) / 2 for label in all_labels}
        
        # Calculate KL(P || M) and KL(Q || M)
        kl_p_m = 0.0
        kl_q_m = 0.0
        
        for label in all_labels:
            p = p_dist[label]
            q = q_dist[label]
            m = m_dist[label]
            
            if p > epsilon:
                kl_p_m += p * np.log(p / m)
            if q > epsilon:
                kl_q_m += q * np.log(q / m)
                
        # JS divergence = 0.5 * (KL(P || M) + KL(Q || M))
        return 0.5 * (kl_p_m + kl_q_m)

    def calculate_metrics(self, model_predictions: Dict[str, Dict[str, float]], 
                         dataset: BaseDataset) -> Dict[str, float]:
        """
        Calculate all metrics
        
        Args:
            model_predictions: Model prediction results {sample ID: {label: probability}}
            dataset: Dataset object
            
        Returns:
            Dictionary with three metrics
        """
        # Calculate accuracy
        accuracy = self.calculate_accuracy(model_predictions, dataset)
        
        # Calculate KL divergence and JS divergence
        kl_values = []
        js_values = []
        
        for item in dataset:
            id_key = str(item[dataset.id_key])
            if id_key in model_predictions:
                true_dist = dataset.get_label(item)  # Get human annotation distribution
                pred_dist = model_predictions[id_key]  # Get model prediction distribution
                
                # Convert to dictionary form (if not)
                if isinstance(true_dist, list):
                    # Convert list to distribution
                    counter = Counter(true_dist)
                    total = len(true_dist)
                    true_dist = {k: v/total for k, v in counter.items()}
                
                # Calculate divergences
                kl = self.calculate_kl_divergence(pred_dist, true_dist)
                js = self.calculate_js_divergence(pred_dist, true_dist)
                
                kl_values.append(kl)
                js_values.append(js)
        
        # Calculate averages
        avg_kl = np.mean(kl_values) if kl_values else float('nan')
        avg_js = np.mean(js_values) if js_values else float('nan')
        
        return {
            'accuracy': accuracy,
            'kl_divergence': avg_kl,
            'js_divergence': avg_js
        }

    def process_dataset(self, dataset_name: str) -> None:
        """
        Process single dataset, generate evaluation results table
        
        Args:
            dataset_name: Dataset name
        """
        logger.info(f"Processing dataset: {dataset_name}")
        
        # Load dataset
        if dataset_name not in self.dataset_mapping:
            logger.error(f"Unknown dataset: {dataset_name}")
            return
            
        dataset = self.dataset_mapping[dataset_name]()
        
        # Create output directory
        dataset_metrics_dir = os.path.join(self.metrics_output_path, dataset_name)
        os.makedirs(dataset_metrics_dir, exist_ok=True)
        
        # Find all model prediction results for this dataset
        dataset_raw_dir = os.path.join(self.raw_results_path, dataset_name)
        if not os.path.exists(dataset_raw_dir):
            logger.warning(f"No raw results found for dataset: {dataset_name}")
            return
            
        # Collect results for all models
        results = []  # For storing main experiment results
        
        # Iterate through all model type directories
        for model_type in os.listdir(dataset_raw_dir):
            model_type_dir = os.path.join(dataset_raw_dir, model_type)
            if not os.path.isdir(model_type_dir):
                continue
                
            # Iterate through all prediction files under this model type
            for result_file in os.listdir(model_type_dir):
                if not result_file.endswith('.json'):
                    continue
                    
                file_path = os.path.join(model_type_dir, result_file)
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        predictions = json.load(f)
                except Exception as e:
                    logger.error(f"Failed to load {file_path}: {e}")
                    continue
                
                # Extract model information
                model_name = result_file.replace('.json', '')
                
                # Check if it's QwenLora model
                if model_type == 'qwenlora':
                    # Parse filename to get parameters
                    params = self.parse_qwen_lora_filename(result_file)
                    model_info = {
                        'model_type': 'qwenlora',
                        'model_name': params['model_name'],
                        'epsilon': params['epsilon'],
                        'alpha': params['alpha'],
                        'epoch': params['epoch']
                    }
                else:
                    model_info = {
                        'model_type': model_type,
                        'model_name': model_name
                    }
                
                # Calculate main experiment metrics
                logger.info(f"Calculating metrics for {model_type}/{model_name}")
                metrics = self.calculate_metrics(predictions, dataset)
                
                # Merge model info and metrics
                result_entry = {**model_info, **metrics}
                results.append(result_entry)
        
        # Convert results to DataFrame and save
        if results:
            # Reorganize DataFrame columns based on model type
            if all('model_type' in r for r in results):
                # Separate different types of models
                openai_results = [r for r in results if r.get('model_type') == 'openai']
                qwen_results = [r for r in results if r.get('model_type') == 'qwen']
                lora_results = [r for r in results if r.get('model_type') == 'qwenlora']
                
                # Create final DataFrame
                final_results = []
                
                # Process OpenAI and Qwen results
                for r in openai_results + qwen_results:
                    entry = {
                        'model_name': r.get('model_name', ''),
                        'accuracy': r.get('accuracy', float('nan')),
                        'kl_divergence': r.get('kl_divergence', float('nan')),
                        'js_divergence': r.get('js_divergence', float('nan'))
                    }
                    final_results.append(entry)
                
                # First group by model name, then sort by parameters
                # 1. Group by model name
                model_groups = {}
                for r in lora_results:
                    model_name = r.get('model_name', 'unknown')
                    if model_name not in model_groups:
                        model_groups[model_name] = []
                    model_groups[model_name].append(r)
                
                # 2. Sort by parameters within each model group
                lora_results_sorted = []
                for model_name in sorted(model_groups.keys()):  # Sort by model name alphabetically
                    group = model_groups[model_name]
                    # Within each model group, sort by epsilon, alpha, epoch
                    sorted_group = sorted(group, key=lambda r: (
                        float(r.get('epsilon', '0')) if r.get('epsilon', '0') != 'unknown' else 0,
                        float(r.get('alpha', '0')) if r.get('alpha', '0') != 'unknown' else 0,
                        int(r.get('epoch', '0')) if r.get('epoch', '0') != 'unknown' else 0
                    ))
                    lora_results_sorted.extend(sorted_group)
                
                # Process sorted LoRA results
                for r in lora_results_sorted:
                    entry = {
                        'model_name': r.get('model_name', ''),
                        'epsilon': r.get('epsilon', ''),
                        'alpha': r.get('alpha', ''),
                        'epoch': r.get('epoch', ''),
                        'accuracy': r.get('accuracy', float('nan')),
                        'kl_divergence': r.get('kl_divergence', float('nan')),
                        'js_divergence': r.get('js_divergence', float('nan'))
                    }
                    final_results.append(entry)
                
                df = pd.DataFrame(final_results)
            else:
                df = pd.DataFrame(results)
            
            # Save as CSV
            csv_path = os.path.join(dataset_metrics_dir, "evaluation_results.csv")
            df.to_csv(csv_path, index=False)
            logger.info(f"Saved evaluation results to {csv_path}")
        else:
            logger.warning(f"No results found for dataset: {dataset_name}")

    def process_all_datasets(self) -> None:
        """
        Process all datasets
        """
        # Find all dataset folders under raw directory
        if os.path.exists(self.raw_results_path):
            for item in os.listdir(self.raw_results_path):
                if os.path.isdir(os.path.join(self.raw_results_path, item)) and item in self.dataset_mapping:
                    self.process_dataset(item)
        else:
            logger.error(f"Raw results path does not exist: {self.raw_results_path}")


if __name__ == "__main__":
    import argparse

    # Command line: one or more dataset names, or 'all' for every dataset with raw results
    cli = argparse.ArgumentParser(description="Generate evaluation result statistics tables")
    cli.add_argument(
        '--datasets',
        type=str,
        nargs='+',
        choices=['snli', 'multinli', 'mtbench', 'summeval', 'all'],
        default=['snli'],
        help='Datasets to process',
    )
    requested = cli.parse_args().datasets

    # Summary object with default paths (relative to this file)
    summary = EvaluationSummary()

    # 'all' anywhere in the list takes precedence over individually named datasets
    if 'all' not in requested:
        for requested_name in requested:
            summary.process_dataset(requested_name)
    else:
        summary.process_all_datasets()

    print("Evaluation result summary completed!")