import random
import copy
import math
from collections import defaultdict

from simple_rl.agents.AgentClass import Agent

class Node:
    """
    Search-tree node that stores the visit statistics of one state.

    Attributes:
        state: the state this node represents
        actions: the actions available at this state (the object passed in)
        untried_actions: private list copy of ``actions`` not yet expanded
        children: action -> child Node
        parents: set of Nodes from which this node has been reached
        N: number of times the node has been visited
        Nsa: action -> number of times the action was taken from this state
        Qsa: action -> estimated value of taking the action from this state
    """

    def __init__(self, state, parent=None, actions=None):
        """
        :param state: state represented by this node
        :param parent: optional Node recorded as the first parent
        :param actions: iterable of available actions; defaults to no actions
        """
        # A literal ``[]`` default would be one list shared by every node
        # created without ``actions``; build a fresh list per node instead.
        if actions is None:
            actions = []
        self.state = state
        self.actions = actions
        # list() rather than .copy() so tuples and other iterables work too.
        self.untried_actions = list(actions)
        self.children = dict()
        self.parents = set()
        if parent is not None:
            self.parents.add(parent)
        self.N = 0
        self.Nsa = dict()
        self.Qsa = dict()

class UMCTS(Agent):
    def __init__(
            self,
            actions,
            gamma=0.9,
            num_simulations=1,
            exploration_constant=1.4,
            max_depth=8,
            r_max=1.,
            name="UMCTS"):
        """
        Store the search hyper-parameters and start with an empty tree.
        :param actions: actions, forwarded to the Agent base class
        :param gamma: discount factor
        :param num_simulations: number of simulations run per call to act
        :param exploration_constant: weight of the exploration term in select_uct_action
        :param max_depth: depth at which simulate and rollout stop
        :param r_max: reward scale used in the value cap and the task distance
        :param name: agent name, forwarded to the Agent base class
        """
        super().__init__(name=name, actions=actions)

        # Search hyper-parameters.
        self.gamma = gamma
        self.r_max = r_max
        self.num_simulations = num_simulations
        self.exploration_constant = exploration_constant
        self.max_depth = max_depth

        # Search state: current root, state -> Node table, and the
        # U[state][action] table mirroring the nodes' Q estimates.
        self.root = None
        self.node_dict = {}
        self.U = defaultdict(lambda: defaultdict(float))

        # Cross-task bookkeeping used by init_task / set_task.
        self.trans = []  # Transition models of the tasks stored so far
        self.tmp_trans = None  # Transition model of the task being initialised
        self.dis = 1000000  # Task-distance term; large until init_task caps it
        self.I = []  # List of upper bounds (not read by the methods visible in this file)

    def re_init(self):
        """
        Re-initialization for multiple instances.

        Runs __init__ again with the current hyper-parameters, which also
        clears the search tree, the U table and the stored transitions.
        :return: None
        """
        self.__init__(
            actions=self.actions,
            gamma=self.gamma,
            num_simulations=self.num_simulations,
            exploration_constant=self.exploration_constant,
            max_depth=self.max_depth,
            # Forward r_max as well; otherwise a custom value silently falls
            # back to the constructor default on every re-initialization.
            r_max=self.r_max,
            name=self.name,
        )
        

    def reset(self):
        """
        Drop the search tree and the U table (called between instances).

        Hyper-parameters and the stored task transitions are left untouched.
        :return: None
        """
        self.U = defaultdict(lambda: defaultdict(float))
        self.node_dict = {}
        self.root = None

    def end_of_episode(self):
        """
        Hook called between episodes within the same MDP; does nothing.
        :return: None
        """
        # The tree is deliberately kept across episodes. Call self.reset()
        # here instead if every episode should start from an empty tree.
        return None

    def act(self, state, reward, env):
        """
        Run the configured number of simulations from ``state`` and pick an action.
        :param state: current state
        :param reward: last reward; accepted but not used by this method
        :param env: environment, deep-copied inside the simulations
        :return: action chosen by select_action for the root node
        """
        # Reuse the node already stored for this state, or register a new
        # one (every action in self.actions is assumed to be available).
        root = self.node_dict.get(state)
        if root is None:
            root = Node(state, actions=self.actions)
            self.node_dict[state] = root
        self.root = root

        for _ in range(self.num_simulations):
            self.simulate(self.root, env, depth=0)

        # Final decision uses the Q estimates gathered at the root.
        return self.select_action(self.root)

    def simulate(self, node, env, depth):
        """
        Run one simulation from ``node`` and update its statistics.

        While the node still has untried actions one of them is expanded and
        evaluated with a random rollout; otherwise an action is chosen by
        select_uct_action and the simulation recurses into the child.
        :param node: Node to start the simulation from
        :param env: environment to step; it is deep-copied, never stepped directly
        :param depth: current depth
        :return: value estimate of this simulation (0 at a terminal state or
                 once max_depth is reached)
        """
        if node.state.is_terminal() or depth >= self.max_depth:
            return 0

        expanding = bool(node.untried_actions)
        if expanding:
            # Expansion: take a random action that has not been tried yet.
            action = random.choice(node.untried_actions)
            node.untried_actions.remove(action)
        else:
            # Selection: every action has been tried at least once.
            action = self.select_uct_action(node)

        # Step a private copy so the caller's environment is left untouched.
        env_copy = copy.deepcopy(env)
        reward, next_state = env_copy.execute_agent_action(action)

        # Nodes are shared per state through node_dict, so a state reached
        # again only gains an additional parent.
        if next_state in self.node_dict:
            child_node = self.node_dict[next_state]
            child_node.parents.add(node)
        else:
            child_node = Node(next_state, parent=node, actions=self.actions)
            self.node_dict[next_state] = child_node
        node.children[action] = child_node

        if expanding:
            node.Nsa[action] = 0
            node.Qsa[action] = 0
            value = reward + self.gamma*self.rollout(child_node, env_copy, depth + 1)
        else:
            value = reward + self.gamma*self.simulate(child_node, env_copy, depth + 1)

        # Backpropagation: incremental mean of the sampled values.
        node.N += 1
        visits = node.Nsa.get(action, 0) + 1
        node.Nsa[action] = visits
        old_q = node.Qsa.get(action, 0.0)
        node.Qsa[action] = old_q + (value - old_q) / visits

        # Keep the U table in step with the node's estimate.
        self.U[node.state][action] = node.Qsa[action]

        return value

    def rollout(self, node, env, depth):
        """
        Perform a random rollout from the given node.

        Uniformly random actions are executed on a deep copy of ``env`` until
        a terminal state is reached or the remaining depth budget
        (max_depth - depth) is used up.
        :param node: Node to start rollout from
        :param env: environment to use for simulation; it is deep-copied
        :param depth: current depth
        :return: discounted sum of the rewards collected during the rollout
                 (0 if the node is terminal or the depth limit is reached)
        """
        if node.state.is_terminal() or depth >= self.max_depth:
            return 0

        env_copy = copy.deepcopy(env)
        state = node.state
        total_reward = 0
        discount = 1

        # Remaining step budget (a fixed constant such as 5 would also work).
        # A single counter replaces the former depth/steps pair, of which
        # the steps counter was never advanced.
        rollout_limit = self.max_depth - depth
        steps = 0
        while not state.is_terminal() and steps < rollout_limit:
            action = random.choice(self.actions)
            reward, state = env_copy.execute_agent_action(action)
            total_reward += reward*discount
            discount *= self.gamma
            steps += 1
        return total_reward


    def select_uct_action(self, node):
        """
        Select an action with a UCT score that is capped by a second bound.

        For a visited action the score is
            Q + exploration_constant * sqrt(log(N) / Nsa)
        limited from above by
            Q + dis/(1-gamma) + (2*r_max/(1-gamma)) * sqrt(log(2/0.9) / Nsa).
        An action without visits is scored as Q + dis/(1-gamma).
        Ties go to the action listed first in node.actions.
        :param node: Node from which to select action
        :return: action to take (None if node.actions is empty)
        """
        total_N = node.N

        # Terms that do not depend on the action, computed once.
        dis_bonus = (1/(1-self.gamma)) * self.dis
        cap_scale = (2*self.r_max)/(1-self.gamma)
        # NOTE(review): 0.9 is hard-coded; it looks like a confidence
        # parameter of the bound -- confirm before tuning.
        cap_log = math.log(2/0.9)

        best_value = float('-inf')
        best_action = None
        for action in node.actions:
            Q = node.Qsa.get(action, 0)
            Nsa = node.Nsa.get(action, 0)
            if Nsa == 0:
                uct_value = Q + dis_bonus
            else:
                uct_value = Q + self.exploration_constant * math.sqrt(math.log(total_N) / Nsa)
                cap = Q + dis_bonus + cap_scale * math.sqrt(cap_log / Nsa)
                uct_value = min(uct_value, cap)

            if uct_value > best_value:
                best_value = uct_value
                best_action = action
        return best_action

    def select_action(self, node):
        """
        Select the action with the highest estimated Q value.

        Actions without a recorded estimate count as Q = 0. Ties go to the
        action listed first in node.actions.
        :param node: Node to select action from
        :return: action to take (None if node.actions is empty)
        """
        best_value = float('-inf')
        best_action = None
        for action in node.actions:
            Q = node.Qsa.get(action, 0)
            if Q > best_value:
                best_value = Q
                best_action = action
        return best_action

    def update_U_matrix(self):
        """
        Copy every Q estimate held by the nodes in node_dict into the U table.
        :return: None
        """
        for state, node in self.node_dict.items():
            for action, q_value in node.Qsa.items():
                self.U[state][action] = q_value

    def get_U_matrix(self):
        """
        Refresh the U table from the current nodes and return it.
        :return: the U table itself (not a copy), with U[state][action] = Q value
        """
        self.update_U_matrix()
        u_matrix = self.U
        return u_matrix
    
    def init_task(self,mdp):
        """
        Measure how far the new task is from the tasks stored so far.

        The transition model of ``mdp`` is compared with every model in
        self.trans. For each (state, action) pair present in both models and
        each pair of outcomes leading to the same next state, the distance
        grows by
            |r - r'| + (gamma*r_max/(1-gamma)) * |p - p'|
        and the total is divided by len(states)*len(actions). self.dis
        becomes the smallest such distance, capped at r_max/(1-gamma); with
        no stored task only the cap is applied to the current value.

        A transition model is used as a mapping from (state, action) to an
        iterable of (next_state, reward, probability) triples.
        :param mdp: task providing get_states, get_actions and get_transtions
        :return: the updated self.dis
        """
        states = mdp.get_states()
        actions = mdp.get_actions()
        num = len(states)*len(actions)
        # NOTE(review): "get_transtions" is spelled this way at the call site
        # in this file -- presumably it matches the mdp class; verify.
        transitions = mdp.get_transtions()
        # Remembered so that set_task can store this model later.
        self.tmp_trans = transitions
        self.distances = []

        # Weight of a probability mismatch; the same for every pair.
        prob_weight = (self.gamma*self.r_max)/(1-self.gamma)

        for trans in self.trans:
            dis = 0
            for state in states:
                for action in actions:
                    key = (state, action)
                    # Only pairs known to both models can be compared.
                    if key not in trans or key not in transitions:
                        continue
                    new_outcomes = transitions[key]
                    for s_p, r, prob in trans[key]:
                        for s_p_, r_, prob_ in new_outcomes:
                            if s_p == s_p_:
                                dis += abs(r - r_) + prob_weight * abs(prob - prob_)
            self.distances.append(dis/num)

        if self.distances:
            self.dis = min(self.distances)
        self.dis = min(self.dis, self.r_max/(1-self.gamma))
        return self.dis


        

    def set_task(self):
        """
        Store the transition model captured by the last init_task call.

        Nothing is stored while no model has been captured yet: a None entry
        in self.trans would make the next init_task fail on its membership
        tests.
        :return: None
        """
        if self.tmp_trans is None:
            return
        self.trans.append(self.tmp_trans)
