from typing import List, Dict, Any, Optional, Tuple, Iterator
import random
import math
from algorithm.constraint import TupleInSetConstraint
from common.metric import Metrics, MeanMetric
from algorithm.env import Env
import numpy as np
from multiprocessing import Pool
from copy import deepcopy
import torch
import ipdb
class MCTSNode:
    """One node of the Monte-Carlo search tree.

    A node wraps an environment snapshot (``env``) plus search statistics.
    When a network ``nn`` is given and ``env`` is not terminal, the node calls
    ``nn.policy_value_fn`` once on ``env.get_state_representation()`` and keeps
    the returned policy restricted and renormalised over the still-unassigned
    variables. Policy entries are indexed by a variable's *position* in
    ``env.domains``.

    Attributes (as set here):
        env: environment snapshot for this node.
        parent: parent node, ``None`` for the root.
        children: nodes created by ``expand``.
        visits: visit counter; nothing in this class increments it.
        rewards: read by ``best_child`` and ``__str__``.
            NOTE(review): nothing in this class updates it -- confirm which
            component is supposed to.
        success: accumulated success score; nothing in this class updates it.
        policy / value: network outputs, or ``None`` without a network.
        prior_policy: prior of the action that produced this node (set by the
            parent's ``expand``).
        untried_actions: actions from ``env.get_actions(policy)`` that
            ``expand`` has not consumed yet.
    """

    def __init__(self, env: "Env", parent=None, nn = None):
        self.env = env
        self.parent = parent
        self.children: List[MCTSNode] = []
        self.visits = 0
        self.rewards = 0
        self.action_info = None  # not set anywhere in this class
        self.nn = nn
        self.policy = None
        self.value = None
        self.prior_policy: float = 0.0
        self.success = 0

        if self.nn and not self.env.is_terminal():
            state = self.env.get_state_representation()
            raw_policy, self.value = self.nn.policy_value_fn(state)
            self.policy = self._masked_policy(raw_policy)
        self.untried_actions = self.env.get_actions(self.policy)

    def _masked_policy(self, raw_policy):
        """Restrict ``raw_policy`` to unassigned variables and renormalise it.

        Entries are indexed by a variable's position in ``env.domains``. If the
        unassigned variables carry no positive probability mass (sum <= 0 or
        NaN), a uniform distribution over them is returned instead of dividing
        by zero. With no unassigned variables an all-zero vector of length
        ``len(env.domains)`` is returned.
        """
        unassigned = self.env.get_unassigned_vars()
        if not unassigned:
            return np.zeros(len(self.env.domains))
        positions = list(self.env.domains.keys())
        var_indices = [positions.index(v) for v in unassigned]
        allowed = set(var_indices)  # O(1) membership for the rebuild below
        total = sum(raw_policy[i] for i in var_indices)
        if total > 0:
            return np.array([raw_policy[i] / total if i in allowed else 0.0
                             for i in range(len(raw_policy))])
        uniform = 1.0 / len(allowed)
        return np.array([uniform if i in allowed else 0.0
                         for i in range(len(raw_policy))])

    def is_fully_expanded(self) -> bool:
        """Return True once every action in ``untried_actions`` has been consumed."""
        return len(self.untried_actions) == 0

    def best_child(self, c_param: float = 1.4) -> 'MCTSNode':
        """Select the child minimising ``mean_rewards - c_param * exploration``.

        ``exploration`` is ``sqrt(ln(self.visits) / child.visits)``. Because the
        score is *minimised*, ``rewards`` is effectively treated as a cost.
        NOTE(review): confirm that minimisation is intended. A child that has
        never been visited is selected first (score ``-inf``).

        Args:
            c_param: exploration weight (default 1.4).

        Raises:
            ValueError: if the node has no children.
        """
        if not self.children:
            raise ValueError("best_child() called on a node with no children")
        log_visits = math.log(self.visits) if self.visits > 0 else 0.0
        weights = []
        for child in self.children:
            if child.visits == 0:
                weights.append(-math.inf)
                continue
            weights.append(child.rewards / child.visits
                           - c_param * math.sqrt(log_visits / child.visits))
        return self.children[int(np.argmin(weights))]

    def expand(self) -> 'MCTSNode':
        """Create one child by applying the next accepted untried action.

        Actions are popped from the end of ``untried_actions`` and tried on a
        copy of ``env`` until ``apply_action`` returns True. The child's
        ``prior_policy`` is the parent policy entry at the variable's position
        in ``env.domains``. Without a policy it is
        ``1 / len(remaining untried actions)``, or 0.0 when none remain.

        NOTE(review): rejected attempts all run against the same env copy; this
        assumes ``apply_action`` leaves the env unchanged when it returns False.

        Returns:
            The new child, or ``None`` if every remaining action was rejected.
        """
        new_env = self.env.copy()
        while self.untried_actions:
            var, value = self.untried_actions.pop()
            if new_env.apply_action(var, value):
                break
        else:
            return None
        child = MCTSNode(new_env, parent=self, nn=self.nn)
        if self.policy is not None:
            # The policy is indexed by position in domains, not by the variable key.
            child.prior_policy = self.policy[list(self.env.domains.keys()).index(var)]
        elif self.untried_actions:
            child.prior_policy = 1.0 / len(self.untried_actions)
        else:
            child.prior_policy = 0.0
        self.children.append(child)
        return child

    def __str__(self) -> str:
        return f"Assignments={self.env.assignments}, Visits={self.visits}, Rewards={self.rewards:.2f}"

class Mirror:
    """Thin wrapper that runs an ``MCTS`` search with mirroring switched off.

    The underlying search object is built eagerly in the constructor. ``run``
    executes it with ``mirror=False``, so the nested search does not spawn
    further mirrors.
    """

    def __init__(self, env, max_iterations, nn):
        # All state lives on the wrapped search instance.
        self.mcts = MCTS(env, max_iterations=max_iterations, nn=nn)

    def run(self):
        """Run the wrapped search without mirroring and return its result."""
        searcher = self.mcts
        return searcher.run(mirror=False)

class MCTS:
    """Monte-Carlo tree search over an environment of variables and domains.

    Each ``run`` iteration:
    - selects a node with ``tree_policy``;
    - expands it by one action with ``MCTSNode.expand``;
    - scores it with a greedy rollout (``simulate``);
    - backpropagates the score.

    ``run`` stops at the first terminal, valid node it reaches.
    """

    def __init__(self, env, max_iterations: int = 10000, nn = None):
        """Store the search configuration.

        Args:
            env: environment the root node is built from in ``run``.
            max_iterations: upper bound on search iterations per ``run``.
            nn: optional network passed to the root ``MCTSNode``.
        """
        self.max_iterations = max_iterations
        self.verbose = True  # not read anywhere in this class
        self.node_count = 0
        self.max_depth = 0
        self.nn = nn
        self.env = env

    def run(self, mirror=True) -> Optional[Dict]:
        """Search for a complete, valid assignment.

        When ``mirror`` is true, each iteration also runs a secondary ``Mirror``
        search on a copy of the scored node's environment. That search gets a
        tenth of ``max_iterations`` (rounded down), and its result is discarded.

        Returns:
            The ``assignments`` of the first terminal, valid node reached.
            If none is found within ``max_iterations``, the dict produced by
            ``get_training_data`` for the last selected path.
        """
        root = MCTSNode(self.env, nn=self.nn)
        self.node_count = 1
        self.max_depth = 0

        # Defaults so the training-data fallback is well defined even when no
        # iteration reaches simulation (e.g. max_iterations == 0).
        path = [root]
        reward = 0

        for _ in range(self.max_iterations):
            node, path = self.tree_policy(root)
            if not node.env.is_terminal():
                if node.env.done:
                    node.env.done = False
                    continue
                child = node.expand()
                if child is not None:
                    node = child
                    self.node_count += 1
                    self.max_depth = max(self.max_depth, len(node.env.assignments))
                # If child is None, every untried action was rejected: score the
                # dead-end node itself instead of dereferencing None.

            reward = self.simulate(node)

            if mirror:
                # range() in the nested search needs an int budget.
                Mirror(node.env.copy(), self.max_iterations // 10, self.nn).run()

            self.backpropagate(node, reward)

            if node.env.is_terminal() and node.env.is_valid():
                return node.env.assignments

        return self.get_training_data(path, reward)

    def get_training_data(self, path, reward):
        """Turn a selection path into (state, action, reward, next_state, done) rows.

        One row is emitted for each consecutive pair on ``path`` whose source is
        non-terminal and whose target has at least one assignment. If the last
        node is non-terminal and has children, one more row towards its
        ``best_child`` is appended. ``action`` is the position, in ``domains``
        order, of the last key inserted into the target's assignments.

        NOTE(review): ``reward`` is unused. Path rows are scored with
        ``compute_y`` of the *source* node, while the trailing row uses the
        *target* node -- confirm which is intended.

        Returns:
            Dict of parallel lists keyed by 'state', 'action', 'reward',
            'next_state' and 'done'.
        """
        training_data = {'state': [], 'action': [], 'reward': [], 'next_state': [], 'done': []}

        for node, next_node in zip(path, path[1:]):
            if node.env.is_terminal():
                continue
            state = node.env.get_state_representation()
            if next_node.env.assignments:
                self._append_transition(training_data, state, node, next_node, node)

        last_node = path[-1]
        if not last_node.env.is_terminal() and last_node.children:
            state = last_node.env.get_state_representation()
            best_child = last_node.best_child()
            if best_child.env.assignments:
                self._append_transition(training_data, state, last_node, best_child, best_child)

        return training_data

    def _append_transition(self, training_data, state, src, dst, scored):
        """Append one src -> dst row to ``training_data``; its reward is ``compute_y(scored)``."""
        var = list(dst.env.assignments.keys())[-1]
        action = list(src.env.domains.keys()).index(var)
        next_state = dst.env.get_state_representation()
        done = dst.env.is_terminal()
        transition_reward = self.compute_y(scored)

        training_data['state'].append(state)
        training_data['action'].append(action)
        training_data['reward'].append(transition_reward)
        training_data['next_state'].append(next_state)
        training_data['done'].append(done)

    def _create_env(self, domains: Dict[Any, Any], constraints: List[Tuple["TupleInSetConstraint", List[Any]]],
                    vconstraints: Dict[Any, List[Tuple["TupleInSetConstraint", List[Any]]]]) -> "Env":
        """Build an ``Env`` with an empty assignment for the given problem."""
        return Env({}, domains, constraints, vconstraints)

    def tree_policy(self, node: "MCTSNode") -> Tuple["MCTSNode", List["MCTSNode"]]:
        """Descend from ``node`` to the node the next iteration should work on.

        The walk moves to ``best_child`` while the current node is non-terminal,
        fully expanded and has at least one child.

        Returns:
            ``(selected, path)``, where ``path`` lists the visited nodes from
            ``node`` to ``selected`` inclusive.
        """
        current = node
        path = [current]
        while not current.env.is_terminal():
            # Stop at nodes with untried actions (expanded next). Also stop at
            # fully-expanded nodes without children: best_child has nothing to pick.
            if not current.is_fully_expanded() or not current.children:
                break
            current = current.best_child()
            path.append(current)
        return current, path

    def simulate(self, node):
        """Greedily complete a copy of ``node.env`` and score the result.

        Variable order: the smallest domain comes first. Ties go to the variable
        with more entries in ``vconstraints``. Each variable gets the first value
        in its domain that ``apply_action`` accepts.

        Returns:
            0 as soon as some variable has no accepted value. Otherwise, once
            every variable is assigned, 1 if ``is_valid()`` else 0.
        """
        env = node.env.copy()
        while len(env.assignments) < len(env.domains):
            var = min(env.get_unassigned_vars(),
                      key=lambda v: (len(env.domains[v]), -sum(1 for c, _ in env.vconstraints[v])))
            for value in env.domains[var]:
                if env.apply_action(var, value):
                    break
            else:
                return 0

        return 1 if env.is_valid() else 0

    def backpropagate(self, node: "MCTSNode", reward: float):
        """Add one visit and ``reward`` success to ``node`` and all its ancestors.

        NOTE(review): ``rewards`` is not updated here, although ``best_child``
        and ``get_best_solution`` read it -- confirm this is intended.
        """
        current = node
        while current is not None:
            current.visits += 1
            current.success += reward
            current = current.parent

    def get_best_solution(self, root: "MCTSNode") -> Optional[Dict]:
        """Follow the child with the highest ``rewards / visits`` until a solution.

        Unvisited children score 0, and the first child wins ties. The walk is
        iterative, so deep trees do not hit the recursion limit.

        Returns:
            The assignments of the first terminal, valid node on the walk, or
            ``None`` if a node without children is reached first.
        """
        def mean_reward(child):
            return child.rewards / child.visits if child.visits > 0 else 0

        node = root
        while not (node.env.is_terminal() and node.env.is_valid()):
            if not node.children:
                return None
            node = max(node.children, key=mean_reward)
        return node.env.assignments

    def compute_y(self, node):
        """Return the target ``0.5 * gap + 0.5 * success_rate`` for ``node``.

        ``gap`` is ``node.env.get_gap()`` and ``success_rate`` is
        ``success / visits``. An unvisited node gets a success rate of 0.0
        instead of raising ZeroDivisionError.
        """
        r = node.success / node.visits if node.visits else 0.0
        g = node.env.get_gap()
        return 0.5 * g + 0.5 * r

    def count_untried_actions(self, node: "MCTSNode") -> int:
        """Return the total number of untried actions in the subtree at ``node``.

        Uses an explicit stack, so deep trees do not hit the recursion limit.
        """
        count = 0
        stack = [node]
        while stack:
            current = stack.pop()
            count += len(current.untried_actions)
            stack.extend(current.children)
        return count