import numpy as np
from gymnasium import Env
from gymnasium.spaces import Discrete, MultiDiscrete
import torch
from tree_exp.utils import solve_contract, correct_delta
from config import DTYPE, DEVICE


class Node:
    """One decision point of a finite contract tree.

    Attributes:
        rewards: per-outcome reward, indexed by outcome.
        F: outcome matrix; ``F[action]`` is the probability vector over
            outcomes used when ``action`` is taken.
        costs: per-action cost, indexed by action.
        children: optional sequence of child nodes indexed by outcome
            (entries may be ``None``); ``None`` means this node is a leaf.
        n_actions, n_outcomes: the two dimensions of ``F``.
    """

    def __init__(self, rewards, F, costs, children=None):
        self.rewards, self.F, self.costs = rewards, F, costs
        self.children = children
        # F must be 2-D: rows are actions, columns are outcomes.
        self.n_actions, self.n_outcomes = F.shape

    def step(self, action):
        """Take ``action`` once and sample its outcome.

        The outcome is drawn with the global NumPy RNG from ``F[action]``.

        Returns:
            ``(cost, outcome, reward, next_node)``: the action's cost, the
            sampled outcome index, that outcome's reward, and the child for
            that outcome (``None`` when this node has no children).
        """
        action_cost = self.costs[action].item()
        probabilities = self.F[action]
        outcome_ids = np.arange(self.n_outcomes)
        sampled = np.random.choice(outcome_ids, p=probabilities).item()
        outcome_reward = self.rewards[sampled].item()
        if self.children is None:
            successor = None
        else:
            successor = self.children[sampled]
        return action_cost, sampled, outcome_reward, successor


class TreeContractEnv(Env):
    """Gymnasium environment that walks down a finite tree of ``Node`` objects.

    An observation is an integer vector of length ``depth``: position ``d``
    holds the outcome sampled at depth ``d`` and ``-1`` marks depths not yet
    reached. Each ``step`` samples an outcome at the current node and moves
    to the child for that outcome.
    """

    def __init__(self, root: Node, depth: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.root = root
        self.depth = depth
        self.n_actions, self.n_outcomes = root.n_actions, root.n_outcomes
        # NOTE(review): unreached positions hold -1, which lies outside this
        # MultiDiscrete range [0, n_outcomes] -- confirm how this space is used.
        self.state_space = MultiDiscrete([self.n_outcomes + 1] * self.depth)
        self.action_space = Discrete(self.n_actions)
        self.reset()
        self.root_state, self.state_node_map = self._create_state_node_map(self.root)
        self.n_states = len(self.state_node_map)
        self.seed = None

    @staticmethod
    def _to_tensor(array):
        """Convert a NumPy array to a tensor with the project's DTYPE and DEVICE."""
        return torch.from_numpy(array).to(DTYPE).to(DEVICE)

    def _create_state_node_map(self, node, state=None, state_node_map=None, current_depth=0):
        """Recursively assign every node its state vector and index it.

        A node's state is the sequence of outcome indices on the path from the
        root, padded with ``-1``. The array is also stored as ``node.state``.

        Returns:
            ``(state, state_node_map)``: the state of ``node`` and the dict
            mapping ``tuple(state)`` to node for the whole subtree.
        """
        if state is None:
            state = np.full((self.depth,), -1, dtype=int)
        if state_node_map is None:
            state_node_map = {}
        state_node_map[tuple(state)] = node
        node.state = state
        if node.children is not None:
            for i, child in enumerate(node.children):
                if child is not None:
                    new_state = np.copy(state)
                    new_state[current_depth] = i
                    self._create_state_node_map(child, new_state, state_node_map, current_depth=current_depth + 1)
        return state, state_node_map

    def get_node_info(self, states):
        """Gather model data and solved quantities for one or more states.

        Args:
            states: a single state vector or a batch of them, given as a
                ``torch.Tensor``, NumPy array or nested sequence. Tensors are
                moved to CPU and cast to int before lookup.

        Returns:
            Tensors ``(F, costs, rewards, actions, contracts, utility_p,
            utility_a)`` converted to ``DTYPE`` on ``DEVICE`` and batched along
            the first axis. Rows for states that are not in the tree stay zero.

        Reads ``opt_action``, ``opt_contract``, ``utility_principal`` and
        ``utility_agent`` from each node; in this module those are set by
        ``solve_tree``.
        """
        if isinstance(states, torch.Tensor):
            states = states.detach().cpu().int().numpy()
        states = np.asarray(states)
        if states.ndim == 1:
            states = states[np.newaxis, :]
        n_states = states.shape[0]
        F = np.zeros((n_states, self.n_actions, self.n_outcomes))
        costs = np.zeros((n_states, self.n_actions))
        rewards = np.zeros((n_states, self.n_outcomes))
        actions = np.zeros((n_states,), dtype=int)
        contracts = np.zeros((n_states, self.n_outcomes))
        utility_p = np.zeros((n_states,))
        utility_a = np.zeros((n_states,))
        for i, state in enumerate(states):
            node = self.state_node_map.get(tuple(state))
            if node is None:
                continue
            F[i] = node.F
            costs[i] = node.costs
            rewards[i] = node.rewards
            actions[i] = node.opt_action
            contracts[i] = node.opt_contract
            utility_p[i] = node.utility_principal
            utility_a[i] = node.utility_agent
        to_tensor = self._to_tensor
        return (
            to_tensor(F),
            to_tensor(costs),
            to_tensor(rewards),
            to_tensor(actions),
            to_tensor(contracts),
            to_tensor(utility_p),
            to_tensor(utility_a),
        )

    def get_root_state(self):
        """Return the root state as a ``(1, depth)`` tensor."""
        return self._to_tensor(self.root_state).unsqueeze(0)

    def get_all_states(self):
        """Return every state in the tree as a ``(n_states, depth)`` tensor."""
        # np.stack needs a real sequence; a dict_keys view raises TypeError
        # on NumPy >= 1.24.
        return self._to_tensor(np.stack(list(self.state_node_map.keys())))

    def step(self, action):
        """Take ``action`` at the current node and move to the sampled child.

        Returns:
            ``(state, reward, terminated, truncated, info)``. ``state`` is a
            copy of the updated state vector. ``terminated`` is true when the
            sampled outcome has no child node. ``truncated`` is true once
            ``depth`` steps have been taken. ``info`` holds ``'outcome'`` and
            ``'cost'``.
        """
        assert self.current_depth < self.depth and self.current_node is not None, "step is called after a terminal node"
        cost, outcome, reward, new_node = self.current_node.step(action)
        info = {'outcome': outcome, 'cost': cost}

        self.state[self.current_depth] = outcome
        self.current_node = new_node
        self.current_depth += 1

        terminated = self.current_node is None
        truncated = self.current_depth == self.depth

        # Hand out a copy: self.state is mutated in place by later steps, so
        # returning it directly would rewrite observations the caller stored.
        return self.state.copy(), reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Return to the root node; ``seed`` (if given) seeds the sampling RNG."""
        super().reset(seed=seed)
        if seed is not None:
            # Node.step samples from the global NumPy RNG, so seed it as well;
            # otherwise the seed passed here would have no effect on episodes.
            np.random.seed(seed)
        self.state = np.full((self.depth,), -1)
        self.current_depth, self.current_node = 0, self.root
        return self.state.copy(), {}

    def render(self):
        pass

    def close(self):
        pass


def get_random_tree_env(n_actions=2, n_outcomes=2, node_prob=1, max_depth=2, seed=42):
    """Build a ``TreeContractEnv`` over a randomly generated contract tree.

    The global NumPy RNG is seeded with ``seed`` before generation.

    All nodes share one outcome matrix ``F`` of shape
    ``(n_actions, n_outcomes)``. Every entry starts at 0.1, and ``F[i, i]`` is
    raised so that row ``i`` sums to 1. Rows with no diagonal entry
    (actions ``i >= n_outcomes``) are set to the uniform distribution, so
    every row is a valid probability vector.

    Each node draws its own data:
      - costs: uniform in [0, 1), shifted so the minimum is 0, sorted
        descending over actions;
      - rewards: uniform in [0, 2), shifted so the minimum is 0, sorted
        descending over outcomes.
    The root is depth 1. Below depth ``max_depth``, each outcome gets a child
    with probability ``node_prob``.

    Args:
        n_actions: number of actions per node.
        n_outcomes: number of outcomes per node (at most 10).
        node_prob: probability that an outcome spawns a child node.
        max_depth: maximum depth of the generated tree.
        seed: seed for the global NumPy RNG.

    Returns:
        A ``TreeContractEnv``. Its depth is the deepest level actually
        generated, and its ``seed`` attribute is set to ``seed``.

    Raises:
        ValueError: if ``n_outcomes > 10``. The 0.1 off-diagonal mass alone
            would then reach 1, forcing negative diagonal probabilities.
    """
    if n_outcomes > 10:
        raise ValueError(f"n_outcomes must be at most 10, got {n_outcomes}")
    np.random.seed(seed)
    base_F = np.full((n_actions, n_outcomes), 0.1)
    n_diag = min(n_actions, n_outcomes)
    base_F[np.arange(n_diag), np.arange(n_diag)] += 1 - 0.1 * n_outcomes
    if n_actions > n_outcomes:
        # These rows have no diagonal entry and would otherwise sum to
        # 0.1 * n_outcomes (< 1), which np.random.choice rejects.
        base_F[n_outcomes:] = 1.0 / n_outcomes

    def _get_node(depth=1):
        """Recursively create a node and its subtree; return (node, deepest level)."""
        F = base_F  # shared by every node
        costs = np.random.uniform(0, 1, (n_actions,))
        costs -= costs.min()
        costs = np.sort(costs)[::-1]
        rewards = np.random.uniform(0, 2, (n_outcomes,))
        rewards -= rewards.min()
        rewards = np.sort(rewards)[::-1]

        children = [None] * n_outcomes
        max_depth_so_far = depth
        if depth < max_depth:
            for i in range(n_outcomes):
                if np.random.random() <= node_prob:
                    children[i], child_depth = _get_node(depth + 1)
                    max_depth_so_far = max(max_depth_so_far, child_depth)
        root = Node(rewards, F, costs, children)
        return root, max_depth_so_far

    root, depth = _get_node()
    env = TreeContractEnv(root, depth)
    env.seed = seed
    return env


def solve_tree(root: Node, gamma=1, delta=0):
    """Solve the contract problem on every node of a tree, bottom-up.

    Children are solved first, recursively. Their principal and agent
    utilities become this node's continuation values. For each action, the
    contract returned by ``solve_contract`` is evaluated, and the action with
    the highest principal utility is kept. Ties keep the lowest action index
    because the comparison is strict.

    Sets these attributes on every node of the subtree:
      - ``q_values_p``: per action, expected reward plus ``gamma`` times the
        expected child principal utility;
      - ``q_values_a``: per action, ``gamma`` times the expected child agent
        utility, minus the action's cost;
      - ``opt_action``, ``opt_contract``: the selected action and its contract;
      - ``utility_principal``, ``utility_agent``: utilities under that choice;
      - ``actions_count``: per action, the number of nodes in the subtree
        whose optimal action is that action.

    Args:
        root: subtree root; modified in place.
        gamma: discount applied to the children's utilities.
        delta: forwarded through ``correct_delta`` to ``solve_contract``. Its
            meaning is defined there (TODO: document).

    Raises:
        RuntimeError: if no action yields a principal utility above ``-inf``
            (for example, when every candidate utility is NaN).
    """
    # np.NINF was removed in NumPy 2.0; -np.inf works on all versions.
    root.utility_principal = root.utility_agent = -np.inf
    # Clear results from any previous solve so stale values cannot be reused.
    root.opt_action = root.opt_contract = None
    # NOTE(review): child utilities are stored in float32 here, which
    # truncates precision -- confirm this is intended.
    root.q_values_p = np.zeros((root.n_outcomes,), dtype=np.float32)
    root.q_values_a = np.zeros((root.n_outcomes,), dtype=np.float32)
    root.actions_count = np.zeros((root.n_actions,), dtype=np.float32)

    if root.children is not None:
        for i, child in enumerate(root.children):
            if child is not None:
                solve_tree(child, gamma, delta)
                root.q_values_p[i] = child.utility_principal
                root.q_values_a[i] = child.utility_agent
                root.actions_count += child.actions_count
    # The q-values start as per-outcome continuation values and are replaced
    # below by per-action values: each is weighted by F[action] and summed
    # over outcomes.
    expected_rewards = (root.F * np.expand_dims(root.rewards, axis=0)).sum(-1)
    root.q_values_p = gamma * (root.F * np.expand_dims(root.q_values_p, axis=0)).sum(-1) + expected_rewards
    root.q_values_a = gamma * (root.F * np.expand_dims(root.q_values_a, axis=0)).sum(-1) - root.costs

    for action in range(root.n_actions):
        corrected_delta = correct_delta(delta, action, root.q_values_p)
        # NOTE(review): `.x` suggests a scipy.optimize result object -- confirm.
        contract = solve_contract(root.F, root.q_values_a, action, delta=corrected_delta).x
        expected_payment = (root.F[action] * contract).sum()
        utility_principal = -expected_payment + root.q_values_p[action]
        if utility_principal > root.utility_principal:
            root.opt_action = action
            root.opt_contract = contract
            root.utility_principal = utility_principal
            root.utility_agent = expected_payment + root.q_values_a[action]

    if root.opt_action is None:
        # Indexing with None would silently increment every entry below.
        raise RuntimeError("solve_tree: no action produced a finite principal utility")
    root.actions_count[root.opt_action] += 1
