import asyncio
import json
import os
from dataclasses import dataclass, asdict
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Dict

from src.reference_answer_generation.reference_answer_generator import ReferenceAnswerGenerator
from src.schema import CounterfactualDatabase
from src.utils import LLMConfig, cleanup_after_model, get_model_name


@dataclass
class ExperimentConfig:
    """Settings for one multi-LLM reference-answer experiment.

    Attributes:
        llm_configs: One configuration per model to run, in run order.
        input_parquet: Path of the parquet file holding the input records.
        output_folder: Parent folder under which run folders are created.
    """
    llm_configs: List[LLMConfig]
    input_parquet: str
    output_folder: str = "experiments/scaling_laws"

    def to_dict(self):
        """Return a plain-dict view of the config, suitable for ``json.dump``.

        Each LLM config is converted with ``dataclasses.asdict``. This assumes
        ``LLMConfig`` is a dataclass whose fields are JSON-friendly (TODO
        confirm).
        """
        serialized_llms = list(map(asdict, self.llm_configs))
        return dict(
            llm_configs=serialized_llms,
            input_parquet=self.input_parquet,
            output_folder=self.output_folder,
        )


class MultiLLMExperimentRunner:
    """Generate reference answers for one input parquet with several LLMs.

    Every run writes into its own ``run_<YYYYmmdd_HHMMSS>`` folder under
    ``config.output_folder``. The folder holds the serialized experiment
    config and one combined parquet with the responses of every model that
    succeeded.
    """

    def __init__(self, config: ExperimentConfig, max_batch_size: int = 100):
        """Create the run folder and persist the experiment config into it.

        Args:
            config: Models, input parquet and output folder for the experiment.
            max_batch_size: Passed through to
                ``ReferenceAnswerGenerator.process_parquet`` for every model.
        """
        self.config = config
        self.max_batch_size = max_batch_size

        # One folder per run, named by start time. Resolution is one second,
        # and because of exist_ok two runs started in the same second share a
        # folder. parents=True also creates config.output_folder if missing.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.run_folder = Path(config.output_folder) / f"run_{timestamp}"
        self.run_folder.mkdir(parents=True, exist_ok=True)

        config_path = self.run_folder / "experiment_config.json"
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config.to_dict(), f, indent=2)

        print(f"Experiment folder: {self.run_folder}")

    async def run(self):
        """Run every configured model over the input parquet.

        After each successful model the combined multi-model parquet is
        rewritten, so results of finished models are kept even if a later
        model fails. A model whose generation fails is reported and skipped.

        Raises:
            FileNotFoundError: If ``config.input_parquet`` does not exist.
        """
        banner = "=" * 80
        print(banner)
        print("MULTI-LLM REFERENCE ANSWER EXPERIMENT")
        print(banner)
        print(f"Input: {self.config.input_parquet}")
        print(f"Models: {len(self.config.llm_configs)}")
        print(banner)

        base_parquet = Path(self.config.input_parquet)
        if not base_parquet.exists():
            raise FileNotFoundError(f"Input parquet not found: {base_parquet}")

        db = CounterfactualDatabase.load_parquet(base_parquet)
        print(f"Loaded {len(db.records)} records from {base_parquet}")

        # The output file is named after the single source dataset, or
        # "combined" when records come from several datasets (or none).
        unique_datasets = {r.original_question.dataset for r in db.records}
        if len(unique_datasets) == 1:
            dataset_name = next(iter(unique_datasets))
            dataset_class = db.dataset_class_map[dataset_name]
            print(f"Detected dataset: {dataset_name}")
            print(f"Using dataset class: {dataset_class.__name__}")
        else:
            dataset_name = "combined"
            print(f"Detected dataset: {dataset_name} ({len(unique_datasets)} dataset types)")
            for ds in sorted(unique_datasets):
                print(f"  - {ds}")
        print(banner)

        output_parquet = self.run_folder / f"{dataset_name}_multi_model_responses.parquet"

        # NOTE(review): keyed by get_model_name(); two configs that map to the
        # same name would overwrite each other here. Confirm that names are
        # unique (e.g. that reasoning variants get distinct names).
        model_databases = {}
        total_models = len(self.config.llm_configs)

        for llm_idx, llm_config in enumerate(self.config.llm_configs, 1):
            model_name = get_model_name(llm_config)

            print(f"\n{'-'*80}")
            print(f"Model {llm_idx}/{total_models}: {model_name}")
            print(f"Full name: {llm_config.model_name}")
            reasoning = getattr(llm_config, 'enable_reasoning', None)
            if reasoning is not None:
                print(f"Reasoning mode: {reasoning}")
            print(f"{'-'*80}")

            enhanced_db = await self._generate_model_responses(
                llm_config=llm_config,
                db=db
            )

            if enhanced_db is None:
                print(f"✗ Failed {model_name}")
                continue

            model_databases[model_name] = enhanced_db
            print(f"✓ Completed {model_name}")
            # Checkpoint after every success so finished models are not lost.
            self._save_multi_model_parquet(model_databases, output_parquet)

        if model_databases:
            # The parquet was already rewritten after the last successful
            # model, so there is nothing left to save here.
            print(f"\n✓ Saved multi-model responses to: {output_parquet}")
        else:
            print("\n✗ No successful model runs")

        print("\n" + banner)
        print("✓ GENERATION COMPLETE")
        print(banner)
        print(f"Results saved to: {self.run_folder}")
        print("\nTo analyze results, use MultiLLMAnalyzer class")

    async def _generate_model_responses(
        self,
        llm_config: LLMConfig,
        db: CounterfactualDatabase,
    ) -> Optional[CounterfactualDatabase]:
        """Generate reference responses for every record in ``db`` with one model.

        ``db`` goes through temporary parquet files because
        ``ReferenceAnswerGenerator.process_parquet`` takes file paths. The
        temporary files are removed on every path, including failures.
        ``cleanup_after_model`` is always called with the generator, or with
        ``None`` if the generator was never constructed.

        Args:
            llm_config: Configuration of the model to run.
            db: Database whose records should be answered.

        Returns:
            The database returned by ``process_parquet``, or ``None`` if any
            step raised. The error and its traceback are printed.
        """
        import tempfile
        import traceback

        def make_temp_parquet() -> str:
            # mkstemp creates the file securely; close the descriptor right
            # away because only the path is handed on.
            fd, path = tempfile.mkstemp(suffix='.parquet')
            os.close(fd)
            return path

        generator = None
        temp_paths: List[str] = []
        try:
            input_path = make_temp_parquet()
            temp_paths.append(input_path)
            output_path = make_temp_parquet()
            temp_paths.append(output_path)

            db.save_parquet(input_path)

            generator = ReferenceAnswerGenerator(config=llm_config)
            return await generator.process_parquet(
                input_path, output_path, max_batch_size=self.max_batch_size
            )

        except Exception as e:
            # Best-effort per model: report and let the caller skip this model.
            print(f"❌ ERROR: {e}")
            traceback.print_exc()
            return None

        finally:
            # Remove temp files on success and failure alike. A failed unlink
            # must not discard a successfully generated result.
            for path in temp_paths:
                try:
                    os.unlink(path)
                except OSError:
                    pass
            cleanup_after_model(generator)

    def _save_multi_model_parquet(
        self,
        model_databases: Dict[str, CounterfactualDatabase],
        output_path: Path
    ):
        """Concatenate the records of every model database and write one parquet.

        Records are appended in ``model_databases`` insertion order (model run
        order). No per-model tag is added here. NOTE(review): models are
        presumably distinguished by the ModelInfo stored in each record's
        response; confirm.

        Args:
            model_databases: Mapping of model name to database with responses.
            output_path: Destination parquet file (overwritten on each call).
        """
        print("\n" + "="*80)
        print("CREATING MULTI-MODEL PARQUET")
        print("="*80)

        combined_db = CounterfactualDatabase()

        for model_name, db in model_databases.items():
            print(f"  Adding {len(db.records)} records from: {model_name}")
            combined_db.records.extend(db.records)

        print(f"✓ Combined {len(combined_db.records)} total records ({len(model_databases)} models)")

        combined_db.save_parquet(output_path)
        print(f"✓ Saved to {output_path}")