from util.interface import CausalModel, PromptGenerator
from pathlib import Path
import json
from util.scenario import Scenario
from random import sample as sample_without_replacement

# Upper bound on how many stored prompts are cycled through for a single
# combination of factor values before the cursor wraps back to the first one.
N_prompts_for_same_value: int = 10

class PromptGeneratorV1(PromptGenerator):
    """Serve pre-generated prompts for a causal model's scenario.

    Prompts are read from two JSON files inside the ``config`` directory that
    sits in the parent of this module's directory:

    * ``samples/<prefix>.json`` -- used when the model's first non-root
      variable is NOT among the supplied variables;
    * ``samples_with_results/<prefix>.json`` -- used when it is.

    ``<prefix>`` is the part of ``model.scenario`` before the first ``%``.
    Both files must hold a ``"compositions"`` list whose entries carry
    ``"factor_value"`` and ``"samples"`` keys; ``samples/<prefix>.json`` must
    also hold a ``"roots"`` list naming the factors whose values form the
    lookup key.

    Successive calls for the same combination of factor values cycle through
    that combination's samples, visiting at most ``N_prompts_for_same_value``
    of them before wrapping.
    """

    def __init__(self, model: CausalModel):
        """Load both sample files for ``model.scenario``.

        Args:
            model: The causal model; its ``scenario`` string selects the files
                and its ``non_roots`` are consulted by ``generate_prompt``.

        Raises:
            OSError: if either JSON file cannot be opened.
            KeyError: if a file lacks the expected keys.
        """
        super().__init__()
        self.model = model
        self.config_path = Path(__file__).resolve().parent.parent / 'config'
        # Not used inside this class; kept as an attribute (presumably read by
        # callers -- verify before removing).
        self.scenario_index = Scenario.get_index(self.model.scenario)
        json_prefix = self.model.scenario.split("%")[0]

        data = self._load_json(self.config_path / "samples" / f"{json_prefix}.json")
        # Names of the root factors; their values (in this order) form the key.
        self.factors = data["roots"]
        self.samples, self.indexs = self._index_compositions(data)

        data_full = self._load_json(
            self.config_path / "samples_with_results" / f"{json_prefix}.json"
        )
        self.samples_full, self.indexs_full = self._index_compositions(data_full)

    @staticmethod
    def _load_json(path: Path):
        """Read and parse the UTF-8 encoded JSON file at *path*."""
        # JSON is UTF-8 by specification; don't depend on the platform codec.
        with open(path, "r", encoding="utf-8") as json_file:
            return json.load(json_file)

    @staticmethod
    def _index_compositions(data):
        """Map each composition's factor-value tuple to its samples and a zero cursor.

        Returns:
            A ``(samples, cursors)`` pair of dicts sharing the same tuple keys.
        """
        samples = {}
        cursors = {}
        for composition in data["compositions"]:
            key = tuple(composition["factor_value"])
            samples[key] = composition["samples"]
            cursors[key] = 0
        return samples, cursors

    def generate_prompt(self, scenario, variables, **kwargs):
        """Return the next stored prompt for the root-factor values in *variables*.

        Args:
            scenario: Not used here; accepted for interface compatibility.
            variables: Mapping from variable name to value; must contain every
                name in ``self.factors``.
            **kwargs: Ignored.

        Returns:
            A sample from ``samples_with_results`` if the model's first
            non-root variable is present in *variables*, otherwise a sample
            from ``samples``.

        Raises:
            KeyError: if a factor is missing from *variables* or the value
                combination has no stored entry.
            IndexError: if the combination's sample list is empty.
        """
        key = tuple(variables[name] for name in self.factors)
        if self.model.non_roots[0] in variables:
            samples, cursors = self.samples_full, self.indexs_full
        else:
            samples, cursors = self.samples, self.indexs
        index = cursors[key]
        candidates = samples[key]
        prompt = candidates[index]
        # Wrap at the configured cap, but never past the end of a shorter
        # sample list -- otherwise the next call would raise IndexError.
        cursors[key] = (index + 1) % min(N_prompts_for_same_value, len(candidates))
        return prompt