import heapq
import pickle
from typing import List, Tuple, Optional

import networkx as nx
import torch

from core.data import AbstractState, Operator
import numpy as np


class AbstractMDP:

    def __init__(self, states: List[AbstractState], action_names: List[str], rounding_rule: float = 0.0):
        self.states = states
        self.transitions = dict()
        self.adjacency_graph = dict()
        self._action_names = action_names
        self.actions = set()
        self.rounding_rule = rounding_rule
        self.q_func = None
        self.sample_prob = None
        self.reset_visit_counts()

    def add_transition(self, state: AbstractState, action: int, next_state: AbstractState, prob: float, reward: float,
                       step: float):

        if prob > 1 - self.rounding_rule:
            prob = 1
        elif prob < self.rounding_rule:
            prob = 0
        if prob > 0:
            self.transitions[(state, action, next_state)] = prob, reward, step
            self.actions.add(action)

            if state not in self.adjacency_graph:
                self.adjacency_graph[state] = {}

            if action not in self.adjacency_graph[state]:
                self.adjacency_graph[state][action] = []

            self.adjacency_graph[state][action].append((next_state, prob, reward, step))

    def print_states(self, filename):
        avg_rewards = dict()
        for (s, a, s_prime), (prob, reward, step) in self.transitions.items():
            if s not in avg_rewards:
                avg_rewards[s] = 0
            avg_rewards[s] += reward

        with open(filename, "w") as f:
            for state in self.states:
                factor_str = " ".join([str(factor) for factor in state.factors])
                if state in avg_rewards:
                    reward = avg_rewards[state]
                else:
                    reward = 0
                f.write(f"State={state}, reward={reward}\nfactors: {factor_str}\n")

    def reset_visit_counts(self):
        self.visit_counts = {s: {} for s in self.states}
        for s in self.states:
            for i, c in enumerate(s.action_counts):
                if s.initiation[i] == 1:
                    self.visit_counts[s][i] = max(int(c), 1)
                else:
                    self.visit_counts[s][i] = 0

    def increment_visit_count(self, state, action):
        assert self.visit_counts[state][action] != 0
        self.visit_counts[state][action] += 1

    def traverse_graph(self, abstract_state, deterministic=False) -> dict[AbstractState, tuple[float, list]]:
        if abstract_state not in self.adjacency_graph:
            return {}

        queue = []
        heapq.heappush(queue, (0, abstract_state, None, None))
        parent = {}
        negloglikelihoods = {}
        while len(queue) > 0:
            total_nll, state, parent_state, parent_action = heapq.heappop(queue)
            if state in negloglikelihoods:
                continue

            negloglikelihoods[state] = total_nll
            parent[state] = (parent_state, parent_action)

            if state in self.adjacency_graph:
                for action in self.adjacency_graph[state]:
                    max_prob = 0
                    max_likely_state = None
                    for outcome in self.adjacency_graph[state][action]:
                        next_state, prob, _, _ = outcome
                        if deterministic:
                            if prob > max_prob:
                                max_prob = prob
                                max_likely_state = next_state
                        else:
                            nll_child = -np.log(prob) + total_nll

                            if next_state in negloglikelihoods:
                                continue
                            else:
                                heapq.heappush(queue, (nll_child, next_state, state, action))
                    if deterministic and max_likely_state is not None:
                        nll_child = -np.log(max_prob) + total_nll
                        if max_likely_state not in negloglikelihoods:
                            heapq.heappush(queue, (nll_child, max_likely_state, state, action))

        paths = {}
        for state, nll in negloglikelihoods.items():
            paths[state] = []
            parent_state, parent_action = parent[state]
            while parent_state is not None:
                paths[state].insert(0, (parent_state, parent_action))
                parent_state, parent_action = parent[parent_state]
            paths[state] = (nll, paths[state])
        return paths

    def __str__(self):
        rep = list()

        for state in self.states:
            rep.append("State at {}".format(state))
        rep.append("")
        for (s, a, s_prime), (prob, reward, step) in self.transitions.items():
            rep.append("Executing {} at {} leads to {} w.p {} and r={}, t={}".format(
                self._action_names[a],
                s,
                s_prime,
                round(prob, 2),
                round(reward, 2),
                round(step, 2)
            ))
        return '\n'.join(rep)

    def render(self):
        G = nx.DiGraph()
        for i, state in enumerate(self.states):
            G.add_node(str(state))

        positions = dict()

        for (s, a, s_prime), (prob, reward, step) in self.transitions.items():
            G.add_edge(str(s), str(s_prime),
                       label="{}, p={}".format(self._action_names[a], round(prob, 2)))

            if s not in positions and s.info:
                positions[str(s)] = np.mean(np.array([(x["position"][0],
                                                       -x["position"][2]) for x in s.info]), axis=0)
            if s_prime not in positions and s_prime.info:
                positions[str(s_prime)] = np.mean(np.array([(x["position"][0],
                                                             -x["position"][2]) for x in s_prime.info]), axis=0)

        if len(positions) == 0:
            positions = nx.spring_layout(G)
        nx.draw(
            G,
            positions,
            with_labels=True,
            font_size=8,
            # node_size=1000,
            node_color="lightblue",
            edge_color="black"
        )

        edge_labels = nx.get_edge_attributes(G, "label")
        nx.draw_networkx_edge_labels(G, positions, edge_labels=edge_labels)

        # Draw nodes
        import matplotlib.pyplot as plt
        plt.show()

    def save(self, path):
        with open(path, "wb") as f:
            pickle.dump(self, f)

    def to_pddl(self, probabilistic=False, include_reward=False):
        predicates = set()
        operators = []
        transitions = {}
        for (s, a, s_prime), (prob, reward, _) in self.transitions.items():
            for p in s.factors:
                p_str = str(p)
                if p_str not in predicates:
                    predicates.add(f"({p_str})")
            for p in s_prime.factors:
                p_str = str(p)
                if p_str not in predicates:
                    predicates.add(f"({p_str})")

            if (s, a) not in transitions:
                transitions[(s, a)] = []
            transitions[(s, a)].append((s_prime, prob, reward))
        for i, (s, a) in enumerate(transitions):
            effects = []
            probs = []
            rewards = []
            for s_prime, prob, reward in transitions[(s, a)]:
                effects.append(s_prime)
                probs.append(prob)
                rewards.append(reward)
            if not probabilistic:
                idx = np.argmax(probs)
                effects = [effects[idx]]
                probs = None
            if not include_reward:
                rewards = None
            op = Operator(self._action_names[a]+f"_{i}", s, effects, probs, rewards)
            operators.append(op)

        pddl = "(define (domain abstract_mdp)\n"
        pddl += "    (:requirements :strips"
        if probabilistic:
            pddl += " :probabilistic-effects"
        if include_reward:
            pddl += " :rewards"
        pddl += ")\n"
        pddl += "\n    (:predicates\n        "
        pddl += "\n        ".join(predicates) + ")"
        for op in operators:
            pddl += "\n\n" + op.to_pddl()
        pddl += "\n)"
        return pddl

    def compute_q_values(self, gamma=0.99, epsilon=1e-6, count_based=False, count_and_reward=False, use_previous=False):
        if use_previous:
            q_func = self.q_func
        else:
            q_func = dict()
            for state in self.states:
                q_func[state] = np.zeros(len(self._action_names))

        sas_dict, is_terminal = self._get_sas_dict()
        it = 0
        while True:
            delta = 0
            for state in self.states:
                q_values = np.zeros(len(self._action_names))
                for action in self.actions:
                    if state.initiation[action]:
                        q_values[action] = self._compute_q_value(state, action, q_func, sas_dict,
                                                                 is_terminal, gamma, count_based,
                                                                 count_and_reward=count_and_reward)
                delta = max(delta, np.abs(q_func[state] - q_values).max())
                q_func[state] = q_values
            if delta < epsilon:
                break
            it += 1
        self.q_func = q_func
        self.v_func = {state: np.max(q_values) for state, q_values in q_func.items()}

    def _compute_q_value(self, state, action, q_func, sas_dict,
                         is_terminal, gamma, count_based=False,
                         count_and_reward=False):
        q_value = 0
        if is_terminal[state]:
            return q_value

        if (state, action) not in sas_dict:
            return q_value

        for s_prime, prob, r, _ in sas_dict[(state, action)]:
            if count_based:
                reward = 1 / np.sqrt(self.visit_counts[state][action])
            elif count_and_reward:
                reward = r + 1 / np.sqrt(self.visit_counts[state][action])
            else:
                reward = r

            if is_terminal[s_prime]:
                target = reward
            else:
                # target = reward + (np.power(gamma, steps)) * np.max(q_func[s_prime])
                target = reward + gamma * np.max(q_func[s_prime])
            q_value += prob * target
        return q_value

    def _get_sas_dict(self):
        sas_dict = dict()
        is_terminal = dict()
        for state in self.states:
            is_terminal[state] = True
        for (s, a, s_prime), (prob, reward, steps) in self.transitions.items():
            if (s, a) not in sas_dict:
                sas_dict[(s, a)] = []
            sas_dict[(s, a)].append((s_prime, prob, reward, steps))
            is_terminal[s] = False
        return sas_dict, is_terminal

    def get_factor_groundings(self, init_vec: Optional[Tuple[int, ...]] = None, max_sample: int = 100):
        factors = {}
        for s in self.states:
            if init_vec is not None and s.initiation != init_vec:
                continue

            for f in s.factors:
                if f.factor not in factors:
                    factors[f.factor] = {}

                if f not in factors[f.factor]:
                    factors[f.factor][f] = []

                n = s.data.shape[0]
                n_sample = min(n, max_sample)
                r = torch.randperm(n)[:n_sample]
                factors[f.factor][f].append(s.data[..., f.factor.variables][r])
        for f in factors:
            for fval in factors[f]:
                factors[f][fval] = torch.cat(factors[f][fval])
        return factors

    def get_factor_marginals(self, init_vec: Optional[Tuple[int, ...]] = None, max_sample: int = 100):
        factor_groundings = self.get_factor_groundings(init_vec=init_vec, max_sample=max_sample)
        factor_marginals = {}
        for f in factor_groundings:
            factor_marginals[f] = {}
            N = sum([len(factor_groundings[f][fval]) for fval in factor_groundings[f]])
            logN = torch.log(torch.tensor(N, dtype=torch.float32))
            for fval in factor_groundings[f]:
                factor_marginals[f][fval] = torch.log(torch.tensor(factor_groundings[f][fval].shape[0])) - logN
        return factor_marginals

    def get_grounding_prob(self, x, init_vec: Optional[Tuple[int, ...]] = None):
        if x.ndim == 1:
            x = x.unsqueeze(0)
        factor_groundings = self.get_factor_groundings(init_vec=init_vec)
        factor_marginals = {}
        factor_likelihoods = {}
        for f in factor_groundings:
            factor_marginals[f] = {}
            factor_likelihoods[f] = {}
            N = sum([len(factor_groundings[f][fval]) for fval in factor_groundings[f]])
            logN = torch.log(torch.tensor(N, dtype=torch.float32))
            for fval in factor_groundings[f]:
                factor_marginals[f][fval] = torch.log(torch.tensor(factor_groundings[f][fval].shape[0])) - logN
                k = min(10, factor_groundings[f][fval].shape[0])
                factor_likelihoods[f][fval] = (-torch.cdist(x[..., f.variables], factor_groundings[f][fval])).topk(k=k, dim=-1)[0].mean(dim=-1)
            Z = torch.logsumexp(torch.cat([x.clone() for x in factor_likelihoods[f].values()]), 0)
            for fval in factor_likelihoods[f]:
                factor_likelihoods[f][fval] -= Z
        factor_posteriors = {}
        for f in factor_groundings:
            factor_posteriors[f] = {}
            for fval in factor_groundings[f]:
                # factor_posteriors[f][fval] = factor_marginals[f][fval] + factor_likelihoods[f][fval]
                factor_posteriors[f][fval] = factor_likelihoods[f][fval]

        state_logl = torch.zeros(len(self.states))
        all_zero = True
        for i, s in enumerate(self.states):
            if init_vec is not None and s.initiation != init_vec:
                state_logl[i] = -1e12
                continue
            all_zero = False

            for f in s.factors:
                state_logl[i] += factor_posteriors[f.factor][f][0]

        if all_zero:
            return None

        state_logl -= torch.logsumexp(state_logl, 0)
        return state_logl.exp()
