import json
import logging
import pathlib
from typing import Optional, Dict, Any, Tuple, List

import pandas as pd

from common.constants import (
    MODELS_ORDER,
    OPENAI_LINEAGE, QWEN_LINEAGE, DEEPSEEK_LINEAGE, CLAUDE_LINEAGE,
    GEMINI_LINEAGE, GROK_LINEAGE, GLM_LINEAGE, MINIMAX_LINEAGE,
    KIMI_LINEAGE, SPARK_LINEAGE, DOUBAO_LINEAGE
)

# Configure the root logger when the module is loaded so that every message
# carries a timestamp and a severity level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# Dataset name as passed in by callers -> name shown in the exported table.
# Names that are not listed here are displayed unchanged.
DATASET_NAME_MAP: Dict[str, str] = {
    "MATH-500": "MATH-500",
    "mmlu-redux": "MMLU-Redux",
    "simple-qa": "SimpleQA",
}

# Base model name (without any "(...)" suffix) -> name shown in the exported
# table. Names that are not listed here are displayed unchanged.
MODEL_NAME_MAP: Dict[str, str] = {
    "deepseek-chat": "deepseek-V3",
    "deepseek-reasoner": "deepseek-R1",
    "x1": "spark-X1",
}

def pretty_dataset_name(name):
    """Return the display name for a dataset.

    Looks ``name`` up in ``DATASET_NAME_MAP``; a name without an entry is
    returned unchanged.
    """
    if name in DATASET_NAME_MAP:
        return DATASET_NAME_MAP[name]
    return name

def pretty_model_name(name: str) -> str:
    """Return the display name for a model, keeping any "(...)" suffix.

    The text before the first "(" is stripped of surrounding whitespace and
    translated through ``MODEL_NAME_MAP`` (unknown names pass through as-is).
    Everything from the first "(" onwards, e.g. "(IR)" or "(CoT)", is
    appended unchanged. A value that is not a string is returned untouched.
    """
    if not isinstance(name, str):
        return name

    # partition() yields an empty separator when the name has no "(" at all.
    head, paren, tail = name.partition("(")
    base = head.strip()
    display = MODEL_NAME_MAP.get(base, base)
    if not paren:
        return display
    return f"{display}{paren}{tail}"
# ====================================


class AnalysisConfig:
    """Static lookup tables shared by the analysis helpers in this module."""

    # Sample count configured per dataset name. Presumably the number of
    # entries each model's result file should hold -- TODO confirm.
    VALID_SAMPLES_NUM = {
        "MATH-500": 500,
        "mmlu-redux": 5399,
        "simple-qa": 4326,
    }

    # Ordered (model names, field names) pairs. For a model contained in the
    # first element, the second element names the keys of the ``usage`` dict
    # that hold its input and output token counts.
    TOKEN_USAGE_MAPPING = [
        (
            set().union(
                OPENAI_LINEAGE, QWEN_LINEAGE, DEEPSEEK_LINEAGE,
                MINIMAX_LINEAGE, KIMI_LINEAGE, GLM_LINEAGE,
                SPARK_LINEAGE, DOUBAO_LINEAGE,
            ),
            {"input": "prompt_tokens", "output": "completion_tokens"},
        ),
        (
            set(CLAUDE_LINEAGE),
            {"input": "input_tokens", "output": "output_tokens"},
        ),
        (
            set(GEMINI_LINEAGE),
            {"input": "prompt_tokens", "output": "output_tokens"},
        ),
        (
            set(GROK_LINEAGE),
            {"input": "prompt_tokens", "output": "completion_tokens"},
        ),
    ]


def check_results_integrity(root_dir: pathlib.Path) -> bool:
    """Walk ``root_dir/processed`` and log anything that looks incomplete.

    For every dataset that has an entry in ``AnalysisConfig.VALID_SAMPLES_NUM``
    and every model directory holding a ``processed_results.json``:

    * the file is parsed; an unreadable or non-list file is logged and the
      scan continues with the next model,
    * the number of list entries is compared with the configured sample
      count and a mismatch is logged,
    * when ``root_dir/raw/<dataset>`` exists, a missing raw directory for the
      model is logged.

    A summary of how many base models and lineages were seen is logged at
    the end.

    Args:
        root_dir: Directory expected to contain ``processed/`` and,
            optionally, ``raw/``.

    Returns:
        ``False`` when ``root_dir/processed`` does not exist, ``True``
        otherwise. Per-model findings are only logged and never change the
        return value.
    """
    processed_results_dir = root_dir / "processed"
    raw_results_dir = root_dir / "raw"

    if not processed_results_dir.exists():
        logging.error("Processed results directory not found: %s", processed_results_dir)
        return False
    if not raw_results_dir.exists():
        logging.warning("Raw results directory not found: %s", raw_results_dir)

    # Built once here rather than once per model directory.
    lineages = {
        "OPENAI_LINEAGE": OPENAI_LINEAGE,
        "QWEN_LINEAGE": QWEN_LINEAGE,
        "DEEPSEEK_LINEAGE": DEEPSEEK_LINEAGE,
        "CLAUDE_LINEAGE": CLAUDE_LINEAGE,
        "GEMINI_LINEAGE": GEMINI_LINEAGE,
        "GROK_LINEAGE": GROK_LINEAGE,
        "GLM_LINEAGE": GLM_LINEAGE,
        "MINIMAX_LINEAGE": MINIMAX_LINEAGE,
        "KIMI_LINEAGE": KIMI_LINEAGE,
        "SPARK_LINEAGE": SPARK_LINEAGE,
        "DOUBAO_LINEAGE": DOUBAO_LINEAGE,
    }
    unique_models = set()
    manufacturers = set()

    for dataset_dir in processed_results_dir.iterdir():
        if not dataset_dir.is_dir():
            continue

        dataset_name = dataset_dir.name
        valid_samples_num = AnalysisConfig.VALID_SAMPLES_NUM.get(dataset_name)
        if not valid_samples_num:
            # Datasets without a configured sample count are not checked.
            continue

        raw_dataset_dir = raw_results_dir / dataset_name
        raw_dataset_available = raw_dataset_dir.exists()

        for model_dir in dataset_dir.iterdir():
            if not model_dir.is_dir():
                continue

            json_file = model_dir / 'processed_results.json'
            if not json_file.exists():
                continue

            try:
                with json_file.open('r', encoding='utf-8') as f:
                    results = json.load(f)
            except (OSError, ValueError) as exc:
                # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
                logging.warning("Cannot read %s: %s", json_file, exc)
            else:
                # NOTE(review): assumes one list entry per sample -- confirm
                # against whatever writes processed_results.json.
                if not isinstance(results, list):
                    logging.warning("%s does not contain a JSON list", json_file)
                elif len(results) != valid_samples_num:
                    logging.warning(
                        "%s/%s: found %d result entries, expected %d",
                        dataset_name, model_dir.name, len(results), valid_samples_num,
                    )

            if raw_dataset_available and not (raw_dataset_dir / model_dir.name).is_dir():
                logging.warning(
                    "No raw results directory for %s/%s", dataset_name, model_dir.name
                )

            # Directory names look like "<base model>_<...>"; keep the base part.
            base_model_name = model_dir.name.split('_')[0].strip()
            unique_models.add(base_model_name)
            for lineage_name, lineage_set in lineages.items():
                if base_model_name in lineage_set:
                    manufacturers.add(lineage_name)

    logging.info(
        "Integrity check covered %d base model(s) from %d lineage(s)",
        len(unique_models), len(manufacturers),
    )
    return True

def _extract_token_usage(usage: Dict[str, Any], original_model_name: str) -> Tuple[int, int]:
    if not usage:
        return 0, 0

    config = next(
        (cfg for lineage, cfg in AnalysisConfig.TOKEN_USAGE_MAPPING if original_model_name in lineage),
        None
    )
    if config is None:
        return 0, 0

    input_tokens = usage.get(config["input"], 0) or 0
    output_tokens = usage.get(config["output"], 0) or 0

    if original_model_name in GEMINI_LINEAGE:
        output_tokens += usage.get("thoughts_tokens", 0) or 0
    if original_model_name in GROK_LINEAGE:
        output_tokens += usage.get("reasoning_tokens", 0) or 0

    return int(input_tokens), int(output_tokens)


def _get_model_name(result: Dict[str, Any]) -> str:
    model_name = result.get('model_name', 'unknown')
    if result.get('enable_intrinsic_reasoning', False):
        return f"{model_name}(IR)"
    if result.get('cot_reasoning', False):
        return f"{model_name}(CoT)"
    return model_name


def _parse_result_item(content: Dict[str, Any], question_hash: str, dataset_name_raw: str) -> Optional[Dict[str, Any]]:
    if not isinstance(content, dict) or 'evaluation' not in content or 'result' not in content:
        return None

    result_data = content['result']
    model_name_with_suffix = _get_model_name(result_data)            
    model_name_pretty = pretty_model_name(model_name_with_suffix)       
    dataset_pretty = pretty_dataset_name(dataset_name_raw)           

    is_correct = bool(content['evaluation'].get('correct', False))
    usage = result_data.get('usage', {})
    input_tokens, output_tokens = _extract_token_usage(usage, result_data.get('model_name', 'unknown'))

    if result_data.get('enable_intrinsic_reasoning', False):
        reasoning_type = "IR"
    elif result_data.get('cot_reasoning', False):
        reasoning_type = "w/ COT"
    else:
        reasoning_type = "w/o COT"

    return {
        "dataset": dataset_pretty,                            
        "question_hash": question_hash,
        "model_name": model_name_pretty,                  
        "is_correct": is_correct,
        "error": 0 if is_correct else 1,
        "reasoning_type": reasoning_type,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "original_model_name": result_data.get('model_name', 'unknown'),  
    }


def generate_all_results_raw(root_dir: pathlib.Path, output_dir: pathlib.Path) -> pathlib.Path:
    """Collect every processed result into ``output_dir/all_results_raw.csv``.

    Reads ``root_dir/processed/<dataset>/<model>/processed_results.json``,
    turns each ``{question_hash: content}`` entry into one row via
    ``_parse_result_item`` and writes the rows as CSV. Rows are ordered by
    reasoning suffix of the model name ("(IR)", then "(CoT)", then none),
    then by ``MODELS_ORDER`` rank of the original model name, then by dataset
    and question hash. Files that cannot be read, or whose top-level JSON
    value is not a list, are logged and skipped.

    Args:
        root_dir: Directory containing the ``processed/`` tree.
        output_dir: Directory to write into; created if missing.

    Returns:
        The path ``output_dir/all_results_raw.csv``. When
        ``root_dir/processed`` does not exist the path is returned without
        writing a file; when no rows were found a header-only CSV is written.
    """
    processed_results_dir = root_dir / "processed"
    output_dir.mkdir(parents=True, exist_ok=True)
    out_path = output_dir / "all_results_raw.csv"

    if not processed_results_dir.exists():
        logging.warning(
            "Processed results directory not found: %s; no CSV written",
            processed_results_dir,
        )
        return out_path

    records: List[Dict[str, Any]] = []

    for dataset_dir in sorted(processed_results_dir.iterdir()):
        if not dataset_dir.is_dir():
            continue

        dataset_name_raw = dataset_dir.name
        for model_dir in sorted(dataset_dir.iterdir()):
            if not model_dir.is_dir():
                continue
            json_file = model_dir / "processed_results.json"
            if not json_file.exists():
                continue

            try:
                with json_file.open("r", encoding="utf-8") as f:
                    results = json.load(f)
            except (OSError, ValueError) as exc:
                # Best effort: one bad file must not stop the export, but it
                # should not vanish silently either.
                logging.warning("Skipping unreadable results file %s: %s", json_file, exc)
                continue

            if not isinstance(results, list):
                logging.warning("Skipping %s: top-level JSON value is not a list", json_file)
                continue

            for item in results:
                # item: {question_hash: content}
                if not isinstance(item, dict):
                    continue
                for question_hash, content in item.items():
                    rec = _parse_result_item(content, question_hash, dataset_name_raw)
                    if rec:
                        records.append(rec)

    if not records:
        df = pd.DataFrame(columns=[
            "dataset", "question_hash", "model_name", "is_correct", "error",
            "reasoning_type", "input_tokens", "output_tokens", "original_model_name"
        ])
    else:
        def sort_key(record: Dict[str, Any]):
            # Compare case-insensitively: the suffix is written as "(CoT)",
            # so an exact match against "(COT)" would never succeed.
            name = str(record["model_name"]).upper()
            if name.endswith("(IR)"):
                suffix_priority = 0
            elif name.endswith("(COT)"):
                suffix_priority = 1
            else:
                suffix_priority = 2
            # Models without a configured rank sort after all ranked ones.
            base_rank = MODELS_ORDER.get(record["original_model_name"], 9999)
            return (suffix_priority, base_rank, record["dataset"], record["question_hash"])

        records.sort(key=sort_key)
        df = pd.DataFrame.from_records(records)

    df.to_csv(out_path, index=False)
    logging.info("Wrote %d row(s) to %s", len(records), out_path)
    return out_path


if __name__ == "__main__":
    import click

    @click.command()
    @click.option("--root-dir", required=True, type=click.Path(path_type=pathlib.Path, exists=True),
                  help="Root directory containing model results (expects processed/ and optionally raw/).")
    @click.option("--output-dir", required=True, type=click.Path(path_type=pathlib.Path),
                  help="Directory to save all_results_raw.csv")
    def main(root_dir: pathlib.Path, output_dir: pathlib.Path):
        """Check the results under ROOT_DIR, then export all_results_raw.csv."""
        # Do not ignore a failed check: exit with an error instead of
        # finishing "successfully" without having written any CSV.
        if not check_results_integrity(root_dir):
            raise click.ClickException(
                f"Results integrity check failed for {root_dir}; "
                "expected a 'processed' sub-directory."
            )
        out_path = generate_all_results_raw(root_dir, output_dir)
        logging.info("Aggregated results available at %s", out_path)

    main()
