from typing import Optional, List, Tuple
import itertools
import random

class ObservationNode:
    """One node in a tree of observation subsets for a single problem.

    A node is created with the tuple of observations it conditions on
    (``prev_observation_strs``). Calling ``attribute_new_observations`` on an
    unexpanded node stores the newly derived observations and expands the node
    into one child per combination (sizes ``0..max_observation_k``) of those
    observations, in shuffled order. The ``collect_*`` methods walk the tree
    and return nested lists whose nesting mirrors the tree shape.
    """

    def __init__(
        self,
        problem: str,
        layer_num: int,
        prev_observation_strs: Tuple[str, ...] = (),
        num_observations_to_generate: int = 10,
        max_observation_k: int = 2,
    ) -> None:
        """Create an unexpanded node.

        Args:
            problem: The problem text; passed unchanged to every descendant.
            layer_num: Depth of this node in the tree (the root is 0).
            prev_observation_strs: Observations this node was built from.
                A layer-0 node must have none.
            num_observations_to_generate: Stored and propagated to children;
                not read elsewhere in this class (presumably consumed by the
                code that generates observations — verify against callers).
            max_observation_k: Largest combination size used on expansion.
        """
        self.problem = problem
        self.layer_num = layer_num

        self.prev_observation_strs = prev_observation_strs
        # Both stay None until expansion; has_expanded() checks they agree.
        self.derived_observations: Optional[List[str]] = None
        self.next_observation_nodes: Optional[List[ObservationNode]] = None

        self.num_observations_to_generate = num_observations_to_generate
        self.max_observation_k = max_observation_k

        # Replaced by the log dict handed to attribute_new_observations().
        self.log: dict = {}

        assert self.layer_num >= 0
        if self.layer_num == 0:
            assert len(prev_observation_strs) == 0

    def has_expanded(self) -> bool:
        """Return True once attribute_new_observations() has expanded this node."""
        assert (self.derived_observations is None) == (self.next_observation_nodes is None)
        return self.derived_observations is not None

    def collect_logs(self):
        """Gather expansion logs from the deepest expanded layer.

        Returns:
            ``None`` if this node has not been expanded; this node's own
            ``log`` (returned by reference) if none of its children have been
            expanded; otherwise a list with one entry per child, each being
            that child's ``collect_logs()`` result.
        """
        if not self.has_expanded():
            return None

        children_logs = [child.collect_logs() for child in self.next_observation_nodes]
        no_child_expanded = all(log is None for log in children_logs)
        # Siblings are expected to sit at the same expansion depth.
        assert no_child_expanded or \
            all(isinstance(log, list) for log in children_logs) or \
            all(isinstance(log, dict) for log in children_logs)

        return self.log if no_child_expanded else children_logs

    def collect_highest_level_problem_obs(self) -> list:
        """Return ``(problem, prev_observation_strs)`` for every unexpanded leaf.

        An unexpanded node yields ``[(problem, prev_observation_strs)]``; an
        expanded node yields the list of its children's results. The nesting
        matches what ``attribute_new_observations`` expects as input.
        """
        if not self.has_expanded():
            return [(self.problem, self.prev_observation_strs)]
        return [child.collect_highest_level_problem_obs() for child in self.next_observation_nodes]

    def collect_all_problem_obs(self, repeated_elements: Optional[set] = None) -> list:
        """Collect ``(problem, prev_observation_strs)`` for every node, post-order.

        Args:
            repeated_elements: Set of ``prev_observation_strs`` already emitted.
                It is shared and mutated through the recursion so each
                observation tuple is reported at most once (first occurrence
                in post-order wins).

        Returns:
            A nested list: one entry per child (that child's own result list),
            followed by this node's ``(problem, prev_observation_strs)`` tuple
            if its observations were not seen before.
        """
        if repeated_elements is None:
            repeated_elements = set()

        if self.has_expanded():
            children_problem_obs = [
                child.collect_all_problem_obs(repeated_elements)
                for child in self.next_observation_nodes
            ]
        else:
            children_problem_obs = []

        if self.prev_observation_strs not in repeated_elements:
            repeated_elements.add(self.prev_observation_strs)
            children_problem_obs.append((self.problem, self.prev_observation_strs))

        return children_problem_obs

    def split_into_observation_combos(self, observations: Tuple[str, ...]) -> List[Tuple[str, ...]]:
        """Return all combinations of ``observations`` of size 0..max_observation_k, shuffled.

        ``observations`` may be any sized sequence (tuple or list). The empty
        combination is always included. Shuffling uses the module-level
        ``random`` generator.
        """
        max_k = min(self.max_observation_k, len(observations))

        observation_combos = list(itertools.chain.from_iterable(
            itertools.combinations(observations, k) for k in range(max_k + 1)
        ))

        random.shuffle(observation_combos)
        return observation_combos

    def attribute_new_observations(self, observation_lists: list, observation_logs: list):
        """Hand newly derived observations (and their logs) to the frontier nodes.

        Args:
            observation_lists: For an expanded node, one entry per child,
                forwarded recursively. For an unexpanded node, a single-element
                list whose only item is the iterable of derived observations.
            observation_logs: Same nesting as ``observation_lists``; at an
                unexpanded node its single item must be a dict.

        Expanding stores the observations in ``derived_observations``, stores
        the log dict in ``log``, and creates one child per combination
        returned by ``split_into_observation_combos``.
        """
        assert len(observation_lists) == len(observation_logs)
        if self.has_expanded():
            assert len(self.next_observation_nodes) == len(observation_lists)
            for child, child_observations, child_logs in zip(
                self.next_observation_nodes, observation_lists, observation_logs
            ):
                child.attribute_new_observations(child_observations, child_logs)
            return

        assert len(observation_lists) == 1
        assert isinstance(observation_logs[0], dict)
        self.log = observation_logs[0]

        # Materialize once and build combinations from the stored list, so a
        # one-shot iterator is not consumed twice.
        self.derived_observations = list(observation_lists[0])
        observation_combos = self.split_into_observation_combos(self.derived_observations)

        # Assign the (initially empty) list before creating children, so both
        # expansion fields are set while create_node_like() runs.
        self.next_observation_nodes = []
        for combo in observation_combos:
            self.next_observation_nodes.append(self.create_node_like(combo))

    def create_node_like(self, observation_combo: Tuple[str, ...]) -> "ObservationNode":
        """Build a child one layer deeper that shares this node's settings."""
        return ObservationNode(
            self.problem,
            self.layer_num + 1,
            observation_combo,
            num_observations_to_generate=self.num_observations_to_generate,
            max_observation_k=self.max_observation_k,
        )
