import random
from dataclasses import dataclass

from logit_lens_compositionality.prompting import InContextQuery
from logit_lens_compositionality.tasks import CompositionalTask, Task, TaskT
from scripts.experiment_utils import Experiment, SlurmJob, Step, Sweep, step


@step(cacheable=True, version="001")
def generate_dataset(task_name: TaskT, seed: int = 0) -> list[CompositionalTask]:
    """Build the dataset for ``task_name`` and shuffle it deterministically.

    Args:
        task_name: Task identifier passed to ``Task(...)``; its ``build_dataset()``
            supplies the items.
        seed: Seed for the shuffle. The same seed always yields the same order.

    Returns:
        The list returned by ``build_dataset()``, shuffled in place.
    """
    dataset = Task(task_name).build_dataset()
    # A private generator seeded with ``seed`` yields exactly the permutation that
    # ``random.seed(seed); random.shuffle(dataset)`` produced, so cached results stay
    # valid. Unlike that form, it does not reset the process-global RNG state as a
    # side effect.
    random.Random(seed).shuffle(dataset)
    return dataset


def _count_remaining_candidates(
    dataset: list[CompositionalTask],
    query: CompositionalTask,
    context: list[CompositionalTask],
    limit: int,
) -> int:
    """Count distinct items that could still join ``context`` for ``query``, stopping at ``limit``."""
    found: list[CompositionalTask] = []
    for candidate in dataset:
        # Same eligibility rule as the sampler, plus de-duplication: equal items
        # can only occupy one context slot because the sampler rejects `in context`.
        if (
            candidate not in context
            and candidate not in found
            and candidate != query
            and not CompositionalTask.overlap(candidate, query)
        ):
            found.append(candidate)
            if len(found) >= limit:
                break
    return len(found)


@step(cacheable=True, version="001")
def generate_in_context_queries(
    dataset: list[CompositionalTask],
    icl_examples: int = 10,
    seed: int = 0,
) -> list[InContextQuery]:
    """Pair every dataset item with randomly drawn in-context examples.

    For each ``query`` in ``dataset``, examples are drawn uniformly from ``dataset``
    with rejection until ``icl_examples`` have been collected. A draw is rejected
    when any of these holds:

    - it is already in the context;
    - it equals the query;
    - ``CompositionalTask.overlap(example, query)`` is truthy.

    Args:
        dataset: Items used both as queries and as the pool of demonstrations.
        icl_examples: Number of demonstrations collected per query.
        seed: Seed for a private RNG. The draws match the sequence produced by
            ``random.seed(seed)`` followed by ``random.choice`` calls.

    Returns:
        One ``InContextQuery`` per dataset item, in dataset order.

    Raises:
        ValueError: If a query cannot be given ``icl_examples`` distinct eligible
            examples. Previously this case looped forever.
    """
    rng = random.Random(seed)
    # Consecutive rejected draws tolerated before checking whether the current
    # query can be satisfied at all. This check consumes no randomness, so
    # successful runs draw exactly the same examples as before.
    stall_limit = max(100, 10 * len(dataset))
    in_context_queries = []
    for query_index, query in enumerate(dataset):
        context: list[CompositionalTask] = []
        rejected = 0
        while len(context) < icl_examples:
            example = rng.choice(dataset)
            if example not in context and example != query and not CompositionalTask.overlap(example, query):
                context.append(example)
                rejected = 0
                continue
            rejected += 1
            if rejected % stall_limit == 0:
                needed = icl_examples - len(context)
                available = _count_remaining_candidates(dataset, query, context, needed)
                if available < needed:
                    raise ValueError(
                        f"Cannot build {icl_examples} in-context examples for query #{query_index}: "
                        f"only {len(context) + available} distinct eligible examples exist in the dataset."
                    )
        in_context_queries.append(InContextQuery(context=context, query=query))
    return in_context_queries


@dataclass
class GenerateDataExperiment(Experiment):
    """Experiment that builds a shuffled task dataset and in-context queries over it."""

    # Task whose dataset is generated.
    task_name: TaskT
    # Number of in-context demonstrations attached to each query.
    icl_examples: int = 10
    # Seed shared by the dataset shuffle and the demonstration sampling.
    seed: int = 0

    @property
    def step_dict(self) -> dict[str, Step]:
        """Wire the generation steps.

        Two steps are registered:

        - ``"dataset"``: the output of ``generate_dataset``.
        - ``"in_context_queries"``: the output of ``generate_in_context_queries``,
          which takes the dataset step as its ``dataset`` input.

        Per the return annotation, the values are ``Step`` objects. The decorated
        functions presumably return lazy step handles rather than computed data
        (TODO: confirm against ``experiment_utils.step``).
        """
        steps = {}
        steps["dataset"] = generate_dataset(task_name=self.task_name, seed=self.seed)
        steps["in_context_queries"] = generate_in_context_queries(
            dataset=steps["dataset"],
            icl_examples=self.icl_examples,
            seed=self.seed,
        )
        return steps

    @property
    def slurm_job(self) -> SlurmJob | None:
        """Slurm resources for this experiment: one CPU-only node on the ``batch`` partition."""
        # NOTE(review): units of time_min / mem_per_node are defined by SlurmJob
        # (presumably minutes / GB) — confirm there.
        return SlurmJob(
            partition="batch",
            time_min=60,
            num_nodes=1,
            mem_per_node=32,
            cpus_per_node=2,
            gpus_per_node=0,
        )

    def results(self) -> dict:
        """Summarise the generated dataset with one example item and its size.

        Returns:
            A dict with the following keys:

            - ``x``, ``Fx``, ``Gx``, ``GFx``, ``FGx``: the corresponding fields of one
              dataset item, selected with a fixed seed (0).
            - ``dataset_size``: the number of items in the dataset.

            For an empty dataset, the item fields are ``None`` and ``dataset_size``
            is 0, instead of ``random.sample`` raising ``ValueError``.
        """
        dataset = self.step_result("dataset")
        if not dataset:
            return {
                "x": None,
                "Fx": None,
                "Gx": None,
                "GFx": None,
                "FGx": None,
                "dataset_size": 0,
            }
        # A private RNG seeded with 0 picks the same item that
        # ``random.seed(0); random.sample(dataset, 1)`` did. It avoids resetting
        # the process-global RNG as a side effect of reporting results.
        random_sample = random.Random(0).sample(dataset, 1)[0]
        return {
            "x": random_sample.x,
            "Fx": random_sample.Fx,
            "Gx": random_sample.Gx,
            "GFx": random_sample.GFx,
            "FGx": random_sample.FGx,
            "dataset_size": len(dataset),
        }


@dataclass
class GenerateDataSweep(Sweep[GenerateDataExperiment]):
    """Sweep that runs ``GenerateDataExperiment`` once for each task in ``tasks``.

    Every experiment shares the sweep's ``icl_examples`` and ``seed``.
    """

    # Task names to generate data for, in the order their experiments are created.
    tasks: list[TaskT]
    # Demonstrations per query, forwarded to every experiment.
    icl_examples: int = 10
    # Seed forwarded to every experiment.
    seed: int = 0

    @property
    def experiments(self) -> list[GenerateDataExperiment]:
        """Return one experiment per entry of ``tasks``, preserving their order."""
        shared_settings = {"icl_examples": self.icl_examples, "seed": self.seed}
        built: list[GenerateDataExperiment] = []
        for name in self.tasks:
            built.append(GenerateDataExperiment(task_name=name, **shared_settings))
        return built


def _main() -> None:
    """Command-line entry point: hand control to the GenerateDataExperiment CLI."""
    GenerateDataExperiment.cli()


if __name__ == "__main__":
    _main()
