import argparse
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Set

import numpy as np
import pandas as pd

from src.schema import CounterfactualDatabase
from src.utils import normalize_answer


def get_valid_answers_for_dataset(dataset_name: str) -> Set[str]:
    """
    Return the canonical answer labels accepted for a dataset.

    Any dataset whose name contains 'breast_cancer' (case-insensitive) is
    answered with RECURRENCE / NO RECURRENCE; every other dataset, including
    the heart-disease and diabetes/Pima ones, is answered with YES / NO.

    Args:
        dataset_name: Name of the dataset (e.g., 'breast_cancer_recurrence', 'heart_disease')

    Returns:
        A new set of valid answers for this dataset
    """
    lowered_name = dataset_name.lower()
    if 'breast_cancer' not in lowered_name:
        return {"YES", "NO"}
    return {"RECURRENCE", "NO RECURRENCE"}


def _distribution_rows(model_name: str, answer_type: str, scenario: str,
                       counts: Dict[str, int]) -> List[dict]:
    """Build one result row per answer in ``counts``, with its percentage of the counts' total."""
    total = sum(counts.values())
    return [
        {
            'model': model_name,
            'answer_type': answer_type,
            'scenario': scenario,
            'answer': answer,
            'count': count,
            'percentage': (count / total * 100) if total > 0 else 0,
        }
        for answer, count in counts.items()
    ]


def analyze_answer_distribution(db: CounterfactualDatabase, dataset_name: str) -> pd.DataFrame:
    """
    Analyze distribution of answers across models.

    Records are grouped by the model named on the original question's
    reference response; records without that model info are skipped. Each
    response's ``answer`` is normalized with ``normalize_answer`` against the
    dataset's valid answers, and any answer it cannot map (returns None for)
    is counted exactly once under the answer "FAILED".

    Returns DataFrame with columns:
    - model: Model name
    - answer_type: original/counterfactual
    - scenario: all/answer_first/answer_last/other ('all' aggregates every scenario)
    - answer: The normalized answer (e.g. YES/NO/RECURRENCE) or FAILED
    - count: Number of occurrences
    - percentage: Percentage of the responses for that model+answer_type+scenario
    The columns are present even when no records were analyzed.
    """
    valid_answers = get_valid_answers_for_dataset(dataset_name)
    answer_types = ('original', 'counterfactual')

    records_by_model = defaultdict(list)
    for record in db.records:
        reference = record.original_question.reference_response
        if reference and reference.model_info:
            records_by_model[reference.model_info.model].append(record)

    results = []
    for model_name, records in records_by_model.items():
        print(f"  Analyzing {model_name}: {len(records)} records")

        # answer_type -> answer -> count (all scenarios combined)
        overall = {answer_type: defaultdict(int) for answer_type in answer_types}
        # answer_type -> scenario -> answer -> count
        by_scenario = {
            answer_type: defaultdict(lambda: defaultdict(int)) for answer_type in answer_types
        }

        for record in records:
            # The scenario is taken from the original question's answer_first flag and
            # also applied to the counterfactual — presumably both share the prompt
            # layout; verify if counterfactuals can differ.
            if record.original_question.answer_first is True:
                scenario = 'answer_first'
            elif record.original_question.answer_first is False:
                scenario = 'answer_last'
            else:
                scenario = 'other'

            for answer_type, question in (('original', record.original_question),
                                          ('counterfactual', record.counterfactual)):
                response = question.reference_response
                if not response:
                    continue
                normalized = normalize_answer(response.answer, valid_answers)
                if normalized is None:
                    normalized = "FAILED"
                # Failures live in the same counter as valid answers, so each
                # response contributes to each table exactly once.
                overall[answer_type][normalized] += 1
                by_scenario[answer_type][scenario][normalized] += 1

        for answer_type in answer_types:
            results.extend(_distribution_rows(model_name, answer_type, 'all', overall[answer_type]))
            for scenario, counts in by_scenario[answer_type].items():
                results.extend(_distribution_rows(model_name, answer_type, scenario, counts))

    return pd.DataFrame(
        results,
        columns=['model', 'answer_type', 'scenario', 'answer', 'count', 'percentage'],
    )


def analyze_model_scaling(df: pd.DataFrame) -> pd.DataFrame:
    """
    Analyze how metrics change with model size.

    Args:
        df: Answer-distribution frame as produced by ``analyze_answer_distribution``
            (columns model, answer_type, scenario, answer, count, percentage).

    Returns DataFrame with one row per (model, scenario) present in ``df``:
    - model: Model name
    - model_size_b: Model size in billions parsed from a 'Qwen3-<N>B' name,
      0.0 when the name does not follow that pattern
    - scenario: all/answer_first/answer_last
    - failure_rate: Percentage of FAILED predictions (original and
      counterfactual rows combined)
    - answer_diversity: Number of distinct non-FAILED answers
    - most_common_answer: Answer of the single largest non-FAILED row ('N/A' if none)
    - most_common_pct: That row's percentage within its own answer_type
    Rows are sorted by model size, then model name, then scenario.
    """
    columns = ['model', 'model_size_b', 'scenario', 'failure_rate',
               'answer_diversity', 'most_common_answer', 'most_common_pct']

    if df.empty:
        # Nothing to aggregate; an empty frame may not even carry the expected columns.
        return pd.DataFrame(columns=columns)

    def get_model_size(model_name: str) -> float:
        """Extract size in billions from a 'Qwen3-<N>B' model name, else 0.0."""
        # Non-numeric variants (e.g. 'Qwen3-Next-80B') are treated as unknown size
        # instead of raising ValueError.
        match = re.search(r'Qwen3-(\d+(?:\.\d+)?)B', model_name)
        return float(match.group(1)) if match else 0.0

    results = []
    scenarios = ['all', 'answer_first', 'answer_last']

    for model in df['model'].unique():
        for scenario in scenarios:
            model_df = df[(df['model'] == model) & (df['scenario'] == scenario)]

            if len(model_df) == 0:
                continue

            failed_df = model_df[model_df['answer'] == 'FAILED']
            total_count = model_df['count'].sum()
            failure_count = failed_df['count'].sum()
            failure_rate = (failure_count / total_count * 100) if total_count > 0 else 0

            valid_df = model_df[model_df['answer'] != 'FAILED']
            answer_diversity = valid_df['answer'].nunique()

            # Picks one row, which may come from either answer_type.
            most_common = valid_df.nlargest(1, 'count')
            most_common_answer = most_common['answer'].values[0] if len(most_common) > 0 else 'N/A'
            most_common_pct = most_common['percentage'].values[0] if len(most_common) > 0 else 0

            results.append({
                'model': model,
                'model_size_b': get_model_size(model),
                'scenario': scenario,
                'failure_rate': failure_rate,
                'answer_diversity': answer_diversity,
                'most_common_answer': most_common_answer,
                'most_common_pct': most_common_pct
            })

    # Explicit columns keep the schema (and sort keys) valid even with no results.
    results_df = pd.DataFrame(results, columns=columns)
    return results_df.sort_values(['model_size_b', 'model', 'scenario'])


def print_summary_statistics(df: pd.DataFrame, dataset_name: str):
    """
    Print overall and per-model summary statistics for one dataset.

    Args:
        df: Answer-distribution frame from ``analyze_answer_distribution``
            (columns model, answer_type, scenario, answer, count, percentage).
        dataset_name: Name shown in the header.

    Only rows with scenario 'all' are counted: the per-scenario rows describe
    the same predictions again, so summing every row would double-count them.
    Per-model totals and failure rates combine original and counterfactual
    answers; the answer breakdown lists original questions only.
    """
    print("\n" + "="*80)
    print(f"SUMMARY STATISTICS: {dataset_name}")
    print("="*80)

    if df.empty:
        # No analyzed records; the frame may not even have columns.
        print("\nTotal predictions: 0")
        return

    overall_df = df[df['scenario'] == 'all']

    total_records = overall_df['count'].sum()
    print(f"\nTotal predictions: {total_records:,}")

    print("\n" + "-"*80)
    print("BY MODEL:")
    print("-"*80)

    for model in sorted(overall_df['model'].unique()):
        model_df = overall_df[overall_df['model'] == model]
        total = model_df['count'].sum()
        failed = model_df[model_df['answer'] == 'FAILED']['count'].sum()
        failure_rate = (failed / total * 100) if total > 0 else 0

        print(f"\n{model}:")
        print(f"  Total: {total:,}")
        print(f"  Failed: {failed:,} ({failure_rate:.2f}%)")

        original_valid = model_df[
            (model_df['answer'] != 'FAILED') & (model_df['answer_type'] == 'original')
        ]
        if len(original_valid) > 0:
            print("  Answer distribution (original questions):")
            for _, row in original_valid.sort_values('count', ascending=False).iterrows():
                print(f"    {row['answer']}: {row['count']:,} ({row['percentage']:.1f}%)")


def print_failure_examples(db: CounterfactualDatabase, dataset_name: str, output_path: Path, max_examples: int = 5):
    """
    Report responses whose answer could not be normalized, for debugging.

    A response counts as a failure when ``normalize_answer`` returns None for
    its ``answer`` field; both the original question and the counterfactual of
    every record are checked. Records are grouped by the model named on the
    original question's reference response; records without that model info
    are skipped.

    Every failure, with its full raw LLM output and parsed-response fields
    (values cut to 200 characters), is written to
    ``output_path / f"{dataset_name}_failures.txt"`` (opened with 'w', so an
    existing file is overwritten). The console shows at most ``max_examples``
    failures per model, with the question cut to 200 characters and the raw
    response to 300.

    Args:
        db: Loaded counterfactual database to scan.
        dataset_name: Dataset name; selects the valid answer set and names the output file.
        output_path: Directory that receives the failure report. This function
            does not create it, so it must already exist.
        max_examples: Maximum failures printed to the console per model (the file gets all).
    """
    print("\n" + "="*80)
    print(f"FAILURE EXAMPLES: {dataset_name}")
    print("="*80)

    valid_answers = get_valid_answers_for_dataset(dataset_name)

    failure_file = output_path / f"{dataset_name}_failures.txt"

    # Group records by the model that produced the original question's reference response.
    records_by_model = defaultdict(list)
    for record in db.records:
        if record.original_question.reference_response and record.original_question.reference_response.model_info:
            model_name = record.original_question.reference_response.model_info.model
            records_by_model[model_name].append(record)

    with open(failure_file, 'w', encoding='utf-8') as f:
        f.write("="*80 + "\n")
        f.write(f"FAILURE EXAMPLES: {dataset_name}\n")
        f.write("="*80 + "\n\n")

        for model_name in sorted(records_by_model.keys()):
            records = records_by_model[model_name]

            # (question_type, record, raw answer) for each response that failed normalization.
            failed_cases = []
            for record in records:
                if record.original_question.reference_response:
                    answer = record.original_question.reference_response.answer
                    if normalize_answer(answer, valid_answers) is None:
                        failed_cases.append(('original', record, answer))

                if record.counterfactual.reference_response:
                    answer = record.counterfactual.reference_response.answer
                    if normalize_answer(answer, valid_answers) is None:
                        failed_cases.append(('counterfactual', record, answer))

            if failed_cases:
                # File section: every failure for this model, untruncated raw output.
                f.write("-"*80 + "\n")
                f.write(f"MODEL: {model_name}\n")
                f.write(f"Total failures: {len(failed_cases)}\n")
                f.write("-"*80 + "\n\n")

                for i, (question_type, record, raw_answer) in enumerate(failed_cases, 1):
                    f.write(f"Failure {i} ({question_type}):\n")
                    # NOTE(review): the original question and ground truth are shown even for
                    # counterfactual failures — confirm this is intended.
                    f.write(f"  Question: {record.original_question.question}\n")
                    f.write(f"  Ground Truth: {record.original_question.ground_truth}\n")
                    f.write(f"  Raw Output (answer field): '{raw_answer}'\n")
                    # Re-normalizes the raw answer; only answers that normalized to None were
                    # collected, so this presumably always shows 'FAILED'.
                    normalized = normalize_answer(raw_answer, valid_answers)
                    f.write(f"  Normalized: '{normalized if normalized else 'FAILED'}'\n")

                    if question_type == 'original':
                        resp_obj = record.original_question.reference_response
                    else:
                        resp_obj = record.counterfactual.reference_response

                    if resp_obj:
                        f.write(f"  Raw Response (full LLM output):\n")
                        # Only the first line of a multi-line response gets the 4-space indent.
                        f.write(f"    {resp_obj.raw_response}\n")

                        if resp_obj.parsed_response:
                            f.write(f"  Parsed Response (dict):\n")
                            for key, val in resp_obj.parsed_response.items():
                                # Long field values are truncated to 200 characters.
                                val_str = str(val)[:200] if len(str(val)) > 200 else str(val)
                                f.write(f"    {key}: {val_str}\n")
                    f.write("\n")

                f.write("\n\n")

                # Console section: same model, but capped at max_examples and truncated.
                print(f"\n{'-'*80}")
                print(f"MODEL: {model_name}")
                print(f"Total failures: {len(failed_cases)}")
                print(f"{'-'*80}")

                for i, (question_type, record, raw_answer) in enumerate(failed_cases[:max_examples], 1):
                    print(f"\nFailure Example {i} ({question_type}):")
                    print(f"  Question: {record.original_question.question[:200]}...")
                    print(f"  Ground Truth: {record.original_question.ground_truth}")
                    print(f"  Raw Output (answer field): '{raw_answer}'")
                    normalized = normalize_answer(raw_answer, valid_answers)
                    print(f"  Normalized: '{normalized if normalized else 'FAILED'}'")

                    if question_type == 'original':
                        resp_obj = record.original_question.reference_response
                    else:
                        resp_obj = record.counterfactual.reference_response

                    if resp_obj:
                        raw_resp = resp_obj.raw_response
                        if raw_resp and len(raw_resp) > 300:
                            print(f"  Raw Response: {raw_resp[:300]}...")
                        else:
                            print(f"  Raw Response: {raw_resp}")

                        # Shows which keys the response parser produced, to diagnose missing 'answer'.
                        if resp_obj.parsed_response:
                            print(f"  Parser extracted keys: {list(resp_obj.parsed_response.keys())}")
                            if 'answer' in resp_obj.parsed_response:
                                print(f"  Parser found answer: '{resp_obj.parsed_response['answer']}'")
                            else:
                                print(f"  Parser DID NOT find 'answer' key!")

                if len(failed_cases) > max_examples:
                    print(f"\n  ... and {len(failed_cases) - max_examples} more failures")
            else:
                f.write(f"{model_name}: No failures! ✓\n\n")
                print(f"\n{model_name}: No failures! ✓")

    print(f"\n✓ All failures saved to: {failure_file}")


def _print_scaling_summary(scaling_df: pd.DataFrame) -> None:
    """Print the scaling table for the 'all' scenario, then one per answer-position scenario."""
    print("\n" + "-"*80)
    print("MODEL SCALING SUMMARY:")
    print("-"*80)

    if scaling_df.empty:
        # Nothing to tabulate; also avoids a KeyError on a column-less frame.
        return

    overall_df = scaling_df[scaling_df['scenario'] == 'all']
    if len(overall_df) > 0:
        print("\nOVERALL (all scenarios combined):")
        print(overall_df.to_string(index=False))

    for scenario in ['answer_first', 'answer_last']:
        scenario_df = scaling_df[scaling_df['scenario'] == scenario]
        if len(scenario_df) > 0:
            print(f"\n{scenario.upper().replace('_', ' ')}:")
            print(scenario_df.to_string(index=False))


def _analyze_dataset(parquet_file: Path, output_path: Path) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Analyze one ``*_multi_model_responses.parquet`` file.

    Prints summaries and writes the dataset's failure report, answer
    distribution CSV and scaling CSV into ``output_path``.

    Returns:
        (distribution, scaling) frames, each tagged with a ``dataset`` column.
    """
    dataset_name = parquet_file.stem.replace('_multi_model_responses', '')

    print(f"\n{'='*80}")
    print(f"ANALYZING: {dataset_name}")
    print('='*80)

    db = CounterfactualDatabase.load_parquet(str(parquet_file))
    print(f"Loaded {len(db.records)} records")

    print("\nAnalyzing answer distributions...")
    dist_df = analyze_answer_distribution(db, dataset_name)
    dist_df['dataset'] = dataset_name

    print_summary_statistics(dist_df, dataset_name)

    print_failure_examples(db, dataset_name, output_path, max_examples=5)

    output_csv = output_path / f"{dataset_name}_answer_distribution.csv"
    dist_df.to_csv(output_csv, index=False)
    print(f"\n✓ Saved detailed distribution to: {output_csv}")

    print("\nAnalyzing model scaling trends...")
    scaling_df = analyze_model_scaling(dist_df)
    scaling_df['dataset'] = dataset_name

    scaling_csv = output_path / f"{dataset_name}_scaling_analysis.csv"
    scaling_df.to_csv(scaling_csv, index=False)
    print(f"✓ Saved scaling analysis to: {scaling_csv}")

    _print_scaling_summary(scaling_df)

    return dist_df, scaling_df


def main(argv: Optional[List[str]] = None) -> int:
    """
    Command-line entry point: analyze every ``*_multi_model_responses.parquet``
    file in an experiment folder and write per-dataset and combined CSVs.

    Args:
        argv: Arguments to parse; ``None`` (the default) parses ``sys.argv``.

    Returns:
        Process exit status: 0 on success, 1 when the experiment folder is
        missing or contains no matching parquet files.
    """
    parser = argparse.ArgumentParser(
        description="Analyze scaling laws experiment results",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument(
        'experiment_folder',
        type=str,
        help='Path to experiment folder (e.g., experiments/scaling_laws/qwen3/run_20241111_123456)'
    )

    parser.add_argument(
        '--output',
        type=str,
        default=None,
        help='Output folder for analysis results (default: same as experiment folder)'
    )

    args = parser.parse_args(argv)

    experiment_path = Path(args.experiment_folder)
    if not experiment_path.exists():
        print(f"Error: Experiment folder not found: {experiment_path}")
        return 1

    output_path = Path(args.output) if args.output else experiment_path
    output_path.mkdir(parents=True, exist_ok=True)

    print("="*80)
    print("SCALING LAWS ANALYSIS")
    print("="*80)
    print(f"Experiment folder: {experiment_path}")
    print(f"Output folder: {output_path}")
    print("="*80)

    # Sorted so datasets are processed and concatenated in a deterministic order
    # (glob order depends on the filesystem).
    parquet_files = sorted(experiment_path.glob("*_multi_model_responses.parquet"))

    if not parquet_files:
        print("\nError: No parquet files found in experiment folder")
        return 1

    print(f"\nFound {len(parquet_files)} dataset(s):")
    for pf in parquet_files:
        print(f"  - {pf.name}")

    all_distributions = []
    all_scaling = []
    for parquet_file in parquet_files:
        dist_df, scaling_df = _analyze_dataset(parquet_file, output_path)
        all_distributions.append(dist_df)
        all_scaling.append(scaling_df)

    # parquet_files is non-empty here, so both lists hold at least one frame.
    combined_dist = pd.concat(all_distributions, ignore_index=True)
    combined_dist_csv = output_path / "combined_answer_distribution.csv"
    combined_dist.to_csv(combined_dist_csv, index=False)
    print(f"\n✓ Saved combined distribution to: {combined_dist_csv}")

    combined_scaling = pd.concat(all_scaling, ignore_index=True)
    combined_scaling_csv = output_path / "combined_scaling_analysis.csv"
    combined_scaling.to_csv(combined_scaling_csv, index=False)
    print(f"✓ Saved combined scaling analysis to: {combined_scaling_csv}")

    print("\n" + "="*80)
    print("✓ ANALYSIS COMPLETE")
    print("="*80)
    print(f"\nAll results saved to: {output_path}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shell callers see failures as a non-zero exit.
    raise SystemExit(main())
