from pydantic import BaseModel
from typing import List, Dict, Any, Optional, Union, Literal, Callable
import pandas as pd
import json
import functools
from functools import cached_property

from typing import Literal

from radar.data import perturb

# Kinds of data artifact a task may be tagged with (see TaskInstance.artifact_types).
# "clean" presumably marks a task with no injected artifact — TODO confirm
# against the perturbation code.
ArtifactType = Literal[
    "bad-values", "inconsistent-formatting", "inconsistent-commonsense-logic",
    "missingness", "outliers", "clean",
]

# How many columns an artifact spans (see TaskInstance.artifact_scope).
# NOTE(review): the exact distinction between "naive" and "connected"
# multi-column scopes is defined elsewhere; "clean" presumably means no
# artifact — confirm.
ArtifactScope = Literal[
    "single-column",
    "naive-multi-column",
    "connected-multi-column",
    "clean",
]


class TaskInstance(BaseModel):
    """One benchmark task: a query over a table into which artifacts are injected.

    After pydantic validation, ``model_post_init`` converts ``clean_data`` into
    pandas DataFrame(s) and applies ``perturbation_spec`` (through
    ``perturb.apply_transform_spec``) to the first clean table, producing the
    perturbed table that the prompt helpers expose.
    """

    task_id: str
    query: str
    artifact_types: List[ArtifactType]
    artifact_scope: ArtifactScope
    query_cols: List[str]
    artifact_reasoning_cols: List[str]
    # Either one table (a list of row dicts) or several tables (a list of such lists).
    clean_data: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]
    perturbation_spec: perturb.TableDeltaSpec
    base_data_num_tokens: int
    base_data_token_bucket: int
    expected_type: Optional[Literal["string", "number", "date", "boolean"]] = None
    answer: Optional[Any] = None
    perturbation_note: Optional[str] = None
    # Derived in model_post_init. The leading underscore makes pydantic treat
    # these as private attributes (not validated as fields), which is why a
    # plain DataFrame type is acceptable here.
    _clean_data_df: Optional[Union[pd.DataFrame, List[pd.DataFrame]]] = None
    _perturbed_data_df: Optional[pd.DataFrame] = None

    @classmethod
    def from_json(cls, path: str) -> "TaskInstance":
        """Load and validate a task instance from the JSON file at ``path``.

        The file is decoded as UTF-8 and closed as soon as it has been parsed.
        """
        with open(path, encoding="utf-8") as fh:
            data = json.load(fh)
        return cls(**data)

    def model_post_init(self, __context: Any) -> None:
        """Materialize the clean DataFrame(s) and derive the perturbed table.

        When ``clean_data`` holds several tables, every table is converted but
        only the first one is perturbed. An empty ``clean_data`` is treated as a
        single (empty) table rather than raising ``IndexError``.
        """
        tables = self.clean_data
        # Check for emptiness before peeking at the first element.
        is_multi_table = (
            isinstance(tables, list)
            and len(tables) > 0
            and isinstance(tables[0], list)
        )
        if is_multi_table:
            self._clean_data_df = [pd.DataFrame(rows) for rows in tables]
            base_df = self._clean_data_df[0]
        else:
            self._clean_data_df = pd.DataFrame(tables)
            base_df = self._clean_data_df
        self._perturbed_data_df = perturb.apply_transform_spec(
            base_df, self.perturbation_spec
        )

    def get_answer_perturbed(self) -> Any:
        """Apply the callable from ``get_callable`` to the perturbed table.

        Returns ``None`` when ``get_callable`` returns ``None``.
        """
        # NOTE(review): ``get_callable`` is not defined on this class in this
        # file; presumably a subclass or mixin provides it — confirm, otherwise
        # this call raises AttributeError.
        func = self.get_callable()
        if func is None:
            return None
        return func(self._perturbed_data_df)

    @functools.cached_property
    def num_rows(self) -> int:
        """Number of rows in the perturbed table (computed once, then cached)."""
        return len(self._perturbed_data_df)

    @functools.cached_property
    def num_cols(self) -> int:
        """Number of columns in the perturbed table (computed once, then cached)."""
        return len(self._perturbed_data_df.columns)

    @property
    def clean_data_df(self) -> Union[pd.DataFrame, List[pd.DataFrame]]:
        """The clean table, or the list of clean tables for multi-table tasks."""
        return self._clean_data_df

    @property
    def perturbed_data_df(self) -> pd.DataFrame:
        """The table produced by applying ``perturbation_spec`` to the first clean table."""
        return self._perturbed_data_df

    def get_prompt_info(self) -> Dict[str, Any]:
        """Return the perturbed table serialized as CSV (no index) and the query.

        Raises:
            ValueError: if the perturbed table is not available.
        """
        df = self.perturbed_data_df
        if df is None:
            raise ValueError("Perturbed data is not available.")
        return {
            "table": df.to_csv(index=False),
            "question": self.query,
        }

    def get_prompt_info_codegen_agent(self) -> Dict[str, Any]:
        """Return the perturbed table as a DataFrame (not serialized) and the query."""
        return {
            "table": self.perturbed_data_df,
            "question": self.query,
        }
