from dataclasses import dataclass, asdict
from typing import Optional, List, Dict, Any, Type
from tqdm import tqdm
import pandas as pd
import pyarrow.parquet as pq
import pyarrow as pa
from pathlib import Path
from src.templates.heart_disease import HeartDisease
from src.templates.pima_diabetes import PimaDiabetes
from src.templates.breast_cancer_recurrence import BreastCancerRecurrence
from src.templates.multiple_choice_dataset import MultipleChoiceDataset
from src.templates.trait import Trait
from src.templates.income import IncomeDataset
from src.templates.attrition import AttritionDataset
from src.templates.moral_machines import MoralMachines
from src.templates.bank_marketing import BankMarketing
from src.templates.bbq_dataset import BBQDataset
import random

@dataclass
class ModelInfo:
    """Optional model-related settings; every field defaults to ``None``.

    Used in this file as the type of ``Response.model_info`` and of the
    ``generator_model_info`` / ``explanation_source_model_info`` fields on
    ``CounterfactualInfo``.
    """
    model: Optional[str] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    # NOTE(review): meaning / allowed values of `thinking` are not visible in this file — confirm.
    thinking: Optional[str] = None
    seed: Optional[int] = None
    # Free-form dict; its schema is not defined in this file.
    additional_params: Optional[dict] = None


# Both names are plain aliases of ModelInfo (the same class object, not
# subclasses), so isinstance checks and field sets are identical.
ReferenceModelInfo = ModelInfo
PredictorModelInfo = ModelInfo


@dataclass
class Response:
    """Fields recorded for a single model response; all default to ``None``.

    ``FaithfulnessRecord.to_flat_dict`` writes each field as its own column
    and expands ``model_info`` into one column per ``ModelInfo`` field.
    """
    # NOTE(review): `cot` presumably holds chain-of-thought text — confirm.
    cot: Optional[str] = None
    raw_response: Optional[str] = None
    parsed_response: Optional[Dict[str, Any]] = None
    answer: Optional[str] = None
    model_info: Optional[ModelInfo] = None
    # NOTE(review): presumably parallel lists (one answer per predictor name);
    # nothing in this file enforces equal lengths — verify against callers.
    predictor_answers: Optional[List[Optional[str]]] = None
    predictor_names: Optional[List[str]] = None
    # Token counts; all optional.
    input_tokens: Optional[int] = None
    reasoning_tokens: Optional[int] = None
    output_tokens: Optional[int] = None


@dataclass
class OriginalQuestion:
    """The original (non-counterfactual) question half of a ``FaithfulnessRecord``.

    ``dataset``, ``question``, ``question_prompt`` and ``question_idx`` are
    required; the remaining fields default to ``None``.
    ``FaithfulnessRecord.to_flat_dict`` stores these fields under
    ``original_<field>`` columns, with ``reference_response`` expanded into
    ``original_reference_response_<field>`` columns.
    """
    # NOTE(review): presumably one of the keys of
    # CounterfactualDatabase.dataset_class_map — confirm against callers.
    dataset: str
    question: str
    question_prompt: str
    question_idx: int
    ground_truth: Optional[str] = None
    answer_first: Optional[bool] = None
    description: Optional[str] = None
    question_options: Optional[dict] = None
    reference_response: Optional[Response] = None


@dataclass
class CounterfactualInfo:
    """The counterfactual half of a ``FaithfulnessRecord``.

    ``generator_model``, ``generator_method``, ``question`` and
    ``question_prompt`` are required; everything else defaults to ``None``.
    ``FaithfulnessRecord.to_flat_dict`` stores these fields under
    ``counterfactual_<field>`` columns; the ``ModelInfo`` and ``Response``
    fields are expanded into one column per nested field.
    """
    generator_model: str
    generator_method: str
    question: str
    question_prompt: str
    generator_model_info: Optional[ModelInfo] = None
    generator_model_cot: Optional[str] = None
    generator_model_raw: Optional[str] = None
    # When left as None, CounterfactualDatabase.add_record assigns an index.
    question_idx: Optional[int] = None
    ground_truth: Optional[str] = None
    description: Optional[str] = None
    # NOTE(review): presumably one coherence judgement from the generator itself
    # and one from a separate scoring model — confirm with the producing code.
    coherence_scored_by_generator: Optional[bool] = None
    coherence_explanation_by_generator: Optional[str] = None
    coherence_external_scoring_model: Optional[str] = None
    coherence_scored_by_external_model: Optional[bool] = None
    coherence_explanation_by_external_model: Optional[str] = None
    # NOTE(review): what the distance is measured between is not shown here.
    hamming_distance: Optional[int] = None
    question_options: Optional[dict] = None

    reference_response: Optional[Response] = None

    # Prompt text in two variants (with / without explanation).
    prompt_with_explanation: Optional[str] = None
    prompt_without_explanation: Optional[str] = None

    # Responses recorded for each of the two prompt variants.
    predictor_response_with_explanation: Optional[Response] = None
    predictor_response_without_explanation: Optional[Response] = None

    predictor_counterfactual_testability_score: Optional[float] = None
    predictor_counterfactual_testability_cot: Optional[str] = None

    # NOTE(review): presumably set when the explanation came from a model other
    # than the reference model, with that model's settings alongside — confirm.
    is_cross_model_explanation: Optional[bool] = None
    explanation_source_model_info: Optional[ModelInfo] = None

@dataclass
class MatchInfo:
    """Optional integer match values attached to a ``FaithfulnessRecord``.

    ``FaithfulnessRecord.to_flat_dict`` stores each field under a
    ``match_<field>`` column (e.g. ``match_match_delta``).
    """
    match_with_explanation: Optional[int] = None
    match_without_explanation: Optional[int] = None
    # NOTE(review): presumably with_explanation minus without_explanation;
    # the computation is not in this file — confirm.
    match_delta: Optional[int] = None


@dataclass
class FaithfulnessRecord:
    """One original question paired with one counterfactual and optional match data."""
    original_question: OriginalQuestion
    counterfactual: CounterfactualInfo
    match_info: Optional[MatchInfo] = None

    def to_flat_dict(self):
        """Flatten the record into a single-level dict of prefixed columns.

        Column naming:
          * ``original_<field>`` for the original question,
          * ``counterfactual_<field>`` for the counterfactual,
          * ``match_<field>`` for the match info.

        Nested ``Response`` values are expanded to ``<prefix>_<response_field>``
        and their ``model_info`` to ``<prefix>_model_info_<model_field>``. The
        counterfactual's two ``ModelInfo`` fields are expanded the same way.
        A nested value that is ``None`` (not a dict after ``asdict``) is kept
        as a single column holding ``None``. Any other dict-valued field
        (e.g. ``question_options``) is stored as-is.

        Returns:
            dict: column name -> value, in field-declaration order.
        """
        columns = {}

        def spread(prefix, mapping):
            # One column per entry of `mapping`, named "<prefix>_<key>".
            for key, item in mapping.items():
                columns[f"{prefix}_{key}"] = item

        def spread_response(prefix, response):
            # Like `spread`, but a dict-valued model_info is expanded one level further.
            for key, item in response.items():
                if key == 'model_info' and isinstance(item, dict):
                    spread(f"{prefix}_model_info", item)
                else:
                    columns[f"{prefix}_{key}"] = item

        if self.original_question is not None:
            for name, value in asdict(self.original_question).items():
                if name == 'reference_response' and isinstance(value, dict):
                    spread_response("original_reference_response", value)
                else:
                    columns[f"original_{name}"] = value

        if self.counterfactual is not None:
            model_info_names = ('generator_model_info', 'explanation_source_model_info')
            response_names = (
                'reference_response',
                'predictor_response_with_explanation',
                'predictor_response_without_explanation',
            )
            for name, value in asdict(self.counterfactual).items():
                column = f"counterfactual_{name}"
                if not isinstance(value, dict):
                    # Scalars and None placeholders are stored under a single column.
                    columns[column] = value
                elif name in model_info_names:
                    spread(column, value)
                elif name in response_names:
                    spread_response(column, value)
                else:
                    columns[column] = value

        if self.match_info is not None:
            spread("match", asdict(self.match_info))

        return columns


class CounterfactualDatabase:
    """In-memory list of ``FaithfulnessRecord`` objects with Parquet persistence.

    Each record becomes one row via ``FaithfulnessRecord.to_flat_dict`` when
    saved; on load the nested dataclasses are rebuilt from the column-name
    prefixes (``original_``, ``counterfactual_``, ``match_``).
    """

    # Dataset name -> template class (classes are imported at the top of the file).
    dataset_class_map = {
        'heart_disease': HeartDisease,
        'pima_diabetes': PimaDiabetes,
        'breast_cancer_recurrence': BreastCancerRecurrence,
        'trait': Trait,
        'multiple_choice_dataset': MultipleChoiceDataset,
        'income': IncomeDataset,
        'attrition': AttritionDataset,
        'moral_machines': MoralMachines,
        'bank_marketing': BankMarketing,
        'bbq': BBQDataset,
    }

    # CounterfactualInfo fields that hold a Response and are stored expanded.
    _COUNTERFACTUAL_RESPONSE_FIELDS = (
        'reference_response',
        'predictor_response_with_explanation',
        'predictor_response_without_explanation',
    )

    # CounterfactualInfo fields that hold a ModelInfo and are stored expanded.
    _COUNTERFACTUAL_MODEL_INFO_FIELDS = (
        'generator_model_info',
        'explanation_source_model_info',
    )

    def __init__(self):
        self.records: List[FaithfulnessRecord] = []

    def add_record(self, record: FaithfulnessRecord) -> None:
        """Append ``record``, assigning a counterfactual index if it has none.

        When ``record.counterfactual.question_idx`` is ``None`` it is set to
        one more than the largest counterfactual index already stored, or to
        ``100000001`` when no stored record has an index. A record that
        already carries an index is appended unchanged.
        """
        if record.counterfactual.question_idx is None:
            # Only scan the existing records when an index must be generated.
            highest = max(
                (
                    existing.counterfactual.question_idx
                    for existing in self.records
                    if existing.counterfactual.question_idx is not None
                ),
                default=100000000,
            )
            record.counterfactual.question_idx = highest + 1
        self.records.append(record)

    def to_dataframe(self) -> pd.DataFrame:
        """Return a DataFrame with one flattened row per stored record."""
        return pd.DataFrame([r.to_flat_dict() for r in self.records])

    def save_parquet(self, path: str | Path) -> None:
        """Save the entire database to a Parquet file at ``path``."""
        df = self.to_dataframe()
        pq.write_table(pa.Table.from_pandas(df), path)

    @classmethod
    def load_parquet(cls, path: str | Path) -> "CounterfactualDatabase":
        """Load a CounterfactualDatabase from a Parquet file.

        Every row is turned back into a ``FaithfulnessRecord`` by grouping
        columns on their name prefix. The dataclasses are taken from this
        module's own definitions.

        NOTE(review): values are passed through exactly as pandas returns
        them, so entries that were ``None`` when saved may come back as NaN
        (and ints as floats) depending on the column dtype — confirm whether
        callers need normalisation.

        Raises:
            TypeError: if a prefixed column does not correspond to a field of
                the dataclass being rebuilt (raised by the dataclass
                constructor).
        """
        rows = pq.read_table(path).to_pandas().to_dict('records')

        db = cls()
        db.records = [
            cls._record_from_row(row)
            for row in tqdm(rows, desc="Loading records")
        ]
        return db

    @staticmethod
    def _strip_prefix(row: Dict[str, Any], prefix: str) -> Dict[str, Any]:
        """Return ``{key minus prefix: value}`` for every row key starting with ``prefix``."""
        cut = len(prefix)
        return {key[cut:]: value for key, value in row.items() if key.startswith(prefix)}

    @classmethod
    def _model_info_from_row(cls, row: Dict[str, Any], prefix: str) -> Optional[ModelInfo]:
        """Build a ModelInfo from the ``<prefix>_*`` columns, or None if there are none."""
        fields = cls._strip_prefix(row, f"{prefix}_")
        return ModelInfo(**fields) if fields else None

    @classmethod
    def _response_from_row(cls, row: Dict[str, Any], prefix: str) -> Optional[Response]:
        """Build a Response from the ``<prefix>_*`` columns, or None if there are none.

        ``<prefix>_model_info_*`` columns feed the nested ModelInfo; the bare
        ``<prefix>_model_info`` column (written when the model info was None)
        is ignored.
        """
        response_fields = {
            key: value
            for key, value in cls._strip_prefix(row, f"{prefix}_").items()
            if key != 'model_info' and not key.startswith('model_info_')
        }
        if not response_fields:
            return None
        response_fields['model_info'] = cls._model_info_from_row(row, f"{prefix}_model_info")
        return Response(**response_fields)

    @classmethod
    def _record_from_row(cls, row: Dict[str, Any]) -> FaithfulnessRecord:
        """Rebuild one FaithfulnessRecord from a flattened row dict."""
        # Original question: plain fields first, then the expanded response.
        original_fields = {
            key: value
            for key, value in cls._strip_prefix(row, "original_").items()
            if not key.startswith("reference_response_")
        }
        original_fields['reference_response'] = cls._response_from_row(
            row, "original_reference_response"
        )

        # Counterfactual: skip the columns that belong to expanded nested
        # objects; those are rebuilt explicitly below.
        nested_prefixes = tuple(
            f"{name}_"
            for name in cls._COUNTERFACTUAL_MODEL_INFO_FIELDS + cls._COUNTERFACTUAL_RESPONSE_FIELDS
        )
        counterfactual_fields = {
            key: value
            for key, value in cls._strip_prefix(row, "counterfactual_").items()
            if not key.startswith(nested_prefixes)
        }
        for name in cls._COUNTERFACTUAL_MODEL_INFO_FIELDS:
            counterfactual_fields[name] = cls._model_info_from_row(row, f"counterfactual_{name}")
        for name in cls._COUNTERFACTUAL_RESPONSE_FIELDS:
            counterfactual_fields[name] = cls._response_from_row(row, f"counterfactual_{name}")

        match_fields = cls._strip_prefix(row, "match_")

        return FaithfulnessRecord(
            original_question=OriginalQuestion(**original_fields),
            counterfactual=CounterfactualInfo(**counterfactual_fields),
            match_info=MatchInfo(**match_fields) if match_fields else None,
        )
