from itertools import product
from util.scenario import Scenario, RulesDict

class VideoGenerator:
    """
    Interface for generating videos from text prompts. Subclasses override the methods
        below; the base implementations only raise NotImplementedError.
    """
    def generate_video_from_prompt(self, prompt, save_path, **kwargs):
        """
        Generate a video from `prompt` and save it to `save_path`. Returns nothing.
        Example:
        prompt = "A bird is flying."
        save_path = "./output/bird_fly.mp4"
        """
        raise NotImplementedError

    def generate_video_from_multi_prompts(self, prompts_file, save_paths, **kwargs):
        """
        Generate one video per prompt and save each one to the matching entry of
            `save_paths`. Returns nothing.
        The prompts are loaded from the file `prompts_file` (a JSON list of strings);
            the video for the i-th prompt is saved to save_paths[i].
        Example:
        prompts_file = "prompts.json"  # ["A bird is flying.", "A cat is jumping."]
        save_paths = ["./output/bird_fly.mp4", "./output/cat_jump.mp4"]
        """
        raise NotImplementedError
    

class PromptGenerator:
    """
    Interface for turning a scenario plus variable assignments into a video prompt.
        Subclasses override generate_prompt; the base implementation only raises
        NotImplementedError.
    """
    def generate_prompt(self, scenario, variables, **kwargs) -> str:
        """
        Return the prompt used to generate the video.
        scenario is a string: the basic prompt without any intervention.
        variables is a dictionary whose keys are variable names and whose values are in
            {True, False}.
        Example:
        scenario = "throwing things into the water"
        variables = {
            "high speed": True,
            "streamlined shape": True,
            "large size": False,
            "spin present": True,
            "dense object": True,
            "wind present": False,
            "background objects present": True,
            "sunlight present": False
        }
        """
        raise NotImplementedError
    

class VideoAsker:
    """
    Interface for querying which variables are present in a generated video.
        Subclasses override ask_for_variables; the base implementation only raises
        NotImplementedError.
    """
    def ask_for_variables(self, video_path, variable_list, **kwargs) -> dict:
        """
        Return a dictionary that says, for each variable in `variable_list`, whether it
            is present in the video saved at `video_path`. Values are in {True, False}.
        Example:
        video_path = "./output/throw_rock.mp4"
        variable_list = ["high speed", "streamlined shape", "large size", "spin present",
                         "dense object", "wind present", "background objects present",
                         "sunlight present", "splash", "object sink", "disturbed surface",
                         "reflection affected", "water sparkling"]
        A possible return value marks every variable as present:
        {variable: True for variable in variable_list}
        """
        raise NotImplementedError

class CausalModel:
    """
    Data structure holding the root variables, the non-root variables, and the rules
        that determine each non-root variable.
    RulesDict is a dictionary keyed by non-root variable. The rule for each non-root
        variable is a tuple of dictionaries: the dictionaries in the tuple are combined
        with "or", and the entries inside one dictionary are combined with "and". A
        non-root variable is True when at least one of its dictionaries fully matches
        the current variable values, and False otherwise. See the example below.
    Example:
    roots = ["high speed", "large size", "wind present"]
    non_roots = ["splash", "disturbed surface"]
    rules = {
        "splash": ({"high speed": True, "large size": True},),
        "disturbed surface": ({"splash": True}, {"wind present": True})
    }
    scenario = "throwing things into the water"
    The rules read as follows. Splash is true only when both high speed and large size
        are true, so rules["splash"] has a single term: a dictionary with "high speed"
        and "large size" as keys and True as values. Disturbed surface, in contrast, is
        true when splash is true or when wind present is true, so
        rules["disturbed surface"] has two terms.
    """
    def __init__(self, roots: list[str] | None = None, non_roots: list[str] | None = None,
                 rules: RulesDict | None = None, scenario: str | None = None):
        """
        Build a causal model from explicit roots/non_roots/rules. When roots is None, the
            data is loaded with Scenario.get_basic_data(scenario=scenario) instead.
        Raises ValueError if both roots and scenario are None, if the rules are
            inconsistent with the variables (see _check), or if the rules contain a cycle.
        """
        if roots is None:
            if scenario is None:
                raise ValueError("roots or scenario should not be None.")
            roots, non_roots, rules = Scenario.get_basic_data(scenario=scenario)
        self.roots = roots
        self.non_roots = [] if non_roots is None else non_roots
        self.rules = {} if rules is None else rules
        self.scenario = scenario
        # Validate before sorting: the topological sort indexes self.rules[non_root]
        # directly, so a missing rule would otherwise escape as a bare KeyError instead
        # of the descriptive ValueError raised by _check.
        self.variables = list(self.roots) + list(self.non_roots)
        self._check()
        # Reorder non-roots so each one comes after the non-roots its rule reads;
        # _get_full_table and get_non_root_value evaluate them in this order.
        self.non_roots = self.topo_sorted_non_roots()
        self.variables = self.roots + self.non_roots

    def _check(self) -> None:
        """
        Check that the rules are consistent with the declared variables.
        Raises ValueError if a non-root has no rule, if a root has a rule, or if a rule
            references a variable that is neither a root nor a non-root.
        Cycles are not checked here; topo_sorted_non_roots raises ValueError for them.
        """
        # 1. every non-root must have a rule
        for non_root in self.non_roots:
            if non_root not in self.rules:
                raise ValueError(f"non_root {non_root} is not in rules.")
        # 2. no root may have a rule
        for head in self.rules:
            if head in self.roots:
                raise ValueError(f"root {head} is in rules.")
        # 3. every variable a rule reads must be a declared variable
        known_variables = set(self.variables)
        for head, rule in self.rules.items():
            for rule_term in rule:
                for parent in rule_term:
                    if parent not in known_variables:
                        raise ValueError(f"parent {parent} in rules[{head}] is not in variables.")

    def _transform_rules(self) -> RulesDict:
        """
        Not implemented yet; always raises NotImplementedError.
        TODO: transform the rules from the pa(Y) -> Y form (each non-root keyed to the
            terms over its parents) to the X -> Y form.
        """
        raise NotImplementedError("TODO: transform pa(Y)->Y to X->Y")

    def topo_sorted_non_roots(self) -> list[str]:
        """
        Return the non-root variables in topological order of the rules: every non-root
            appears after all non-roots its rule reads.
        The order is the DFS post-order obtained by visiting self.non_roots in list
            order and, within each rule, the parents in term order.
        Raises ValueError if the rules among non-roots contain a cycle (including a
            non-root whose rule reads itself).
        """
        index_of = {non_root: index for index, non_root in enumerate(self.non_roots)}
        unvisited, in_progress, done = 0, 1, 2
        state = [unvisited] * len(self.non_roots)
        order = []

        def non_root_parents(ind):
            # Lazily yield the indices of the non-root parents of self.non_roots[ind].
            for rule_term in self.rules[self.non_roots[ind]]:
                for parent in rule_term:
                    if parent in index_of:
                        yield index_of[parent]

        # Iterative DFS (no recursion-depth limit); each stack frame holds a node and a
        # generator over its remaining parents, so resuming it continues where it paused.
        for start in range(len(self.non_roots)):
            if state[start] != unvisited:
                continue
            state[start] = in_progress
            stack = [(start, non_root_parents(start))]
            while stack:
                ind, parents = stack[-1]
                for parent_ind in parents:
                    if state[parent_ind] == in_progress:
                        # Reaching a node still on the stack means a back edge: a cycle.
                        raise ValueError(
                            f"rules contain a cycle involving non_root "
                            f"{self.non_roots[parent_ind]}.")
                    if state[parent_ind] == unvisited:
                        state[parent_ind] = in_progress
                        stack.append((parent_ind, non_root_parents(parent_ind)))
                        break
                else:
                    # All parents are finished, so this node can be emitted.
                    stack.pop()
                    state[ind] = done
                    order.append(self.non_roots[ind])
        return order

    def _get_full_table(self):
        """
        Enumerate every assignment of the root variables and derive the non-roots.
        Returns a list with one row per root assignment, in itertools.product order
            (False before True, last root varying fastest). Each row is a list of bools
            aligned with self.variables: root values first, then non-root values.
        """
        # See ../test/test_sampler.py -> test_generate_all_sample_table for an example.
        var_index = {variable: index for index, variable in enumerate(self.variables)}
        res = []
        for root_values in product([False, True], repeat=len(self.roots)):
            row = list(root_values) + [False] * len(self.non_roots)
            for non_root in self.non_roots:
                # Y is True if any of its rule terms holds. Non-roots are topologically
                # sorted, so every parent value in `row` is already final here.
                row[var_index[non_root]] = any(
                    all(row[var_index[parent]] == expected_value
                        for parent, expected_value in rule_term.items())
                    for rule_term in self.rules[non_root])
            res.append(row)
        return res

    def get_non_root_value(self, root_values_dict: dict[str, bool]) -> dict[str, bool]:
        """
        Derive the value of every non-root variable from the given root values.
        root_values_dict maps root variables to bools; it is copied, not modified.
        Returns a dict mapping each non-root (in topological order) to its bool value.
        May raise KeyError if a rule that gets evaluated reads a root that is missing
            from root_values_dict.
        """
        values = dict(root_values_dict)
        for non_root in self.non_roots:
            # Parents that are non-roots were computed earlier thanks to the topo order.
            values[non_root] = any(
                all(values[parent] == expected_value
                    for parent, expected_value in rule_term.items())
                for rule_term in self.rules[non_root])
        return {non_root: values[non_root] for non_root in self.non_roots}