from __future__ import annotations

import numpy as np
import gymnasium as gym
from ppo.policy_sb3 import train_policy, test_policy
from refinement.utils import CacheStates, train_model
from refinement.goal import Goal, ModifiedGoal, NewGoal
from refinement.avoid import Avoid
import matplotlib.pyplot as plt
from scipy.spatial import ConvexHull

class Node:
    """Graph vertex that wraps a goal and records its outgoing edges.

    Outgoing edges are stored in ``children``, a dict keyed by ``id(child)``.
    Each value is a record with the keys ``"child"``, ``"reach_probability"``,
    ``"policy"`` and ``"avoid"``.

    The node acts as its own iterator: ``__iter__`` returns ``self`` and the
    cursor ``idx`` is never rewound, so every edge record is yielded at most
    once over the lifetime of the node. ``__next__`` re-reads ``children`` on
    each step, which means edges added while a ``for`` loop is in progress
    are still visited by that loop.
    """

    def __init__(self, goal:np.ndarray, splittable:bool = True, final:bool=False, name:str = ""):
        # NOTE(review): ``goal`` is annotated as np.ndarray, yet
        # ``sample_state`` calls ``goal.sample_state()`` -- the annotation
        # looks inaccurate; confirm the intended type.
        self.name = name
        self.goal = goal
        self.final = final
        self.splittable = splittable
        # id(child) -> edge record (see ``add_child`` for the layout).
        self.children = {}
        # Iteration cursor: number of edge records already yielded.
        self.idx = 0

    def sample_state(self):
        """Delegate to ``self.goal.sample_state()`` and return its result."""
        return self.goal.sample_state()

    def __iter__(self):
        """Return the node itself; the iteration cursor is NOT reset."""
        return self

    def add_child(self, child:Node, avoid=None):
        """Register ``child`` as a successor with a fresh edge record.

        The record starts with a reach probability of 0 and no policy.
        Adding the same child object again overwrites its record.
        """
        record = {
            "child": child,
            "reach_probability": 0,
            "policy": None,
            "avoid": avoid,
        }
        self.children[id(child)] = record

    def __next__(self):
        """Yield the next not-yet-visited edge record, in insertion order."""
        ordered_keys = list(self.children)
        if self.idx != len(ordered_keys):
            position = self.idx
            # Advance before the lookup so the cursor moves even if it fails.
            self.idx = position + 1
            return self.children[ordered_keys[position]]
        raise StopIteration

    def remove_child(self, child:Node):
        """Drop the edge to ``child``; raises ``KeyError`` if it is absent."""
        del self.children[id(child)]

    def print_graph(self):
        """Placeholder; currently does nothing."""
        pass

def split_goal(goal:Goal, cached_states:CacheStates):
    """Create a ``ModifiedGoal`` mirroring ``goal`` with a fitted hull.

    The new goal copies ``x``, ``y``, ``height`` and ``width`` from ``goal``,
    is flagged ``reachable=True``, and carries as ``hull`` whatever
    ``train_model(cached_states)`` returns.
    """
    # Fit first, so the model is trained before any goal attribute is read.
    fitted_hull = train_model(cached_states)

    return ModifiedGoal(
        x=goal.x,
        y=goal.y,
        height=goal.height,
        width=goal.width,
        hull=fitted_hull,
        reachable=True,
    )

def add_goal(goal, trajectories):
    """Build a ``NewGoal`` around the midpoints of the successful trajectories.

    Args:
        goal: Existing goal. Its ``lower_bound``, ``upper_bound`` and
            ``index_range`` are copied to the new goal; ``index_range`` also
            selects which columns of each midpoint state feed the hull.
        trajectories: Sequence of ``(states, reached)`` pairs. Only pairs
            whose second element is truthy are used.

    Returns:
        A ``NewGoal`` whose ``hull`` is the ``scipy.spatial.ConvexHull`` of
        the selected midpoint states.

    Raises:
        ValueError: If no trajectory is flagged as successful, so there is
            nothing to build a hull from.

    NOTE(review): ``ConvexHull`` itself rejects point sets that are too small
    or degenerate for the selected dimensionality; that error is not handled
    here and propagates to the caller.
    """
    successful = [item for item in trajectories if item[1]]
    # The state halfway through each successful trajectory becomes a hull point.
    midpoints = [states[len(states) // 2] for states, _ in successful]

    if not midpoints:
        # Without this guard the column slice below fails with an opaque
        # "too many indices" IndexError on an empty array.
        raise ValueError(
            "add_goal needs at least one successful trajectory to build a hull"
        )

    points = np.array(midpoints)[:, goal.index_range]
    hull = ConvexHull(points)
    return NewGoal(
        lower_bound=goal.lower_bound,
        upper_bound=goal.upper_bound,
        index_range=goal.index_range,
        hull=hull,
    )
    
    

def add_avoid_region(avoid, trajectories: list, k: int = 1):
    """Grow avoid regions from the states that preceded each violation.

    Every consecutive state pair of every trajectory is passed to
    ``avoid.check_region``. When it returns a region, the ``k`` states ending
    at the first state of the pair (fewer near the trajectory start) are
    collected for that region. Regions that gathered at least three points
    are then extended via ``region.extend_region(points)``.

    Args:
        avoid: Object exposing ``list_of_regions`` (hashable regions) and
            ``check_region(state, next_state)``, which returns a region or
            ``None``.
        trajectories: Sequence of trajectories; only element ``[0]`` of each
            (the state sequence) is read.
        k: Number of trailing states to record per violation.

    Returns:
        Number of violating transitions divided by the number of
        trajectories, or ``0.0`` when ``trajectories`` is empty.
    """
    points_by_region = {region: [] for region in avoid.list_of_regions}
    violations = 0

    for trajectory in trajectories:
        states = trajectory[0]
        for i in range(len(states) - 1):
            region = avoid.check_region(states[i], states[i + 1])
            if region is None:
                continue
            # Keep the window of k states that ends at state i (inclusive).
            points_by_region[region].extend(states[max(i - k + 1, 0):i + 1])
            violations += 1

    for region, points in points_by_region.items():
        # Regions with fewer than three collected points are left untouched.
        if len(points) >= 3:
            region.extend_region(points)

    n_trajectories = len(trajectories)
    # Guard against division by zero when no trajectories were supplied.
    return violations / n_trajectories if n_trajectories else 0.0
    


def depth_first_traversal(head: Node, train_env: gym.Env, test_env: gym.Env, minimum_reach: float = 0.9, n_episodes: int = 3000, n_episodes_test: int = 3000, path: str = ""):
    """Run ``explore`` from ``head`` and log edge results to a text file.

    Args:
        head: Root node of the traversal.
        train_env: Environment forwarded to ``explore`` for training.
        test_env: Environment forwarded to ``explore`` for evaluation.
        minimum_reach: Reach-probability threshold forwarded to ``explore``.
        n_episodes: Training episode budget forwarded to ``explore``.
        n_episodes_test: Evaluation episode budget forwarded to ``explore``.
        path: Directory prefix; results go to ``path + "/result.txt"``.
            With the default ``""`` this resolves to ``/result.txt``.

    Returns:
        None. The outcome is recorded on the nodes and in the result file.
    """
    edges = []
    # The context manager guarantees the log is flushed and closed even if
    # training or evaluation raises part-way through the traversal.
    with open(path + "/result.txt", "w") as file:
        explore(head, train_env, test_env, minimum_reach, edges, n_episodes, n_episodes_test, file, policies = [])

def explore(parent: Node, train_env: gym.Env, test_env: gym.Env, minimum_reach: float = 0.9, edges: list | None = None, n_episodes: int = 3000, n_episodes_test: int = 3000, file = None, policies: list | None = None):
    """Train and evaluate a policy for each outgoing edge of ``parent``, depth first.

    For every edge not yet listed in ``edges`` a policy is trained with
    ``train_policy`` and evaluated with ``test_policy``. If the measured reach
    is below ``minimum_reach`` and the child is splittable, an intermediate
    node (named ``<child>_new``, goal built by ``add_goal``) is inserted as
    ``parent -> new -> child`` and its hull is pickled to
    ``<parent>_<child>_hull.pkl`` in the working directory. The result is then
    stored on the edge record and the child is explored recursively.

    Args:
        parent: Node whose outgoing edges are processed. Iterating it yields
            edge records; children added during the loop are expected to be
            yielded by the same loop.
        train_env: Environment passed to ``train_policy``.
        test_env: Environment passed to ``test_policy``.
        minimum_reach: Reach-probability threshold for a realised edge.
        edges: Names (``"<parent>_<child>"``) of edges already evaluated.
            Mutated in place and shared with recursive calls. Defaults to a
            fresh list per top-level call.
        n_episodes: Episode budget passed to ``train_policy``.
        n_episodes_test: Episode budget passed to ``test_policy``.
        file: Stream for the per-edge result line; ``None`` prints to stdout.
        policies: ``(policy, node)`` pairs accumulated along the current
            path and passed to ``train_policy``. Not mutated; each recursion
            receives an extended copy.

    Returns:
        ``False`` if ``parent`` is final or if a recursive call returned a
        truthy value; ``True`` as soon as an edge into a final child has
        reach >= ``minimum_reach``; ``None`` if the loop ends otherwise.
    """
    import pickle  # only needed when a hull is dumped

    # Fresh containers per call; mutable defaults would leak state between
    # unrelated traversals.
    if edges is None:
        edges = []
    if policies is None:
        policies = []

    if parent.final:
        return False

    for edge in parent:
        target = edge['child']
        edge_key = parent.name + "_" + target.name

        if edge_key in edges:
            # Already evaluated: reuse the value stored on this edge record
            # rather than whatever a previous loop iteration left behind.
            reach = edge['reach_probability']
        else:
            print(f"Evaluating edge ({parent.name}, {target.name})")
            policy = train_policy(train_env, parent, target, edge['avoid'], n_episodes, minimum_reach, policies)
            reach, cached_states, trajectories = test_policy(policy, test_env, parent, target, edge['avoid'], n_episodes_test)

            print(f"Edge ({parent.name}, {target.name}) reach probability: {reach}")
            print(f"Edge ({parent.name}, {target.name}) reach probability: {reach}", file = file)

            if reach < minimum_reach and target.splittable:
                print(f"Edge ({parent.name}, {target.name}) not realised: {reach}")

                # Refine by inserting an intermediate goal derived from the
                # successful trajectories. The module's split_goal and
                # add_avoid_region refinements are not used on this path.
                new_goal = add_goal(target.goal, trajectories)
                with open(f"{parent.name}_{target.name}_hull.pkl", "wb") as hull_file:
                    pickle.dump(new_goal.hull, hull_file)

                goal_node = Node(
                    goal = new_goal,
                    splittable = False,
                    final = False,
                    name = target.name + "_new"
                )

                # Added while ``parent`` is being iterated, so the new edge
                # is picked up later by this same loop.
                parent.add_child(goal_node)
                goal_node.add_child(target, avoid = edge['avoid'])

            new_policies = policies.copy()
            new_policies.append((policy, target))
            parent.children[id(target)]['reach_probability'] = reach
            parent.children[id(target)]['policy'] = policy
            edges.append(edge_key)

            # Release the cached states before descending into the subtree.
            del cached_states
            status = explore(target, train_env, test_env, minimum_reach, edges, n_episodes, n_episodes_test, file, new_policies)

            if status:
                return False

        if target.final and reach >= minimum_reach:
            return True
            




