import math, random, itertools, heapq

import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm

class problem_base:
    """Abstract interface for the planning problems used by the samplers.

    Subclasses supply the state space (start states, predecessors, goal
    test) and the goal-conditioned forward transition distribution.
    """

    def __init__(self, beta):
        # Stored for subclasses; this base class never reads it.
        self.beta = beta

    def predecessors(self, z):
        """Return the states that can transition into ``z``.

        The samplers in this file read the state from element 0 of each
        returned item, so items must be indexable (e.g. ``(state, cost)``).
        """
        raise NotImplementedError

    def get_forward_probs(self, x, g, a=None):
        """Return ``(ys, cs)``: successor states of ``x`` and their weights.

        ``g`` is the goal the transition distribution is conditioned on.
        ``a`` is an optional next-action hint; the samplers in this file
        always pass it as a third positional argument, so every
        implementation has to accept it (it may be ignored).
        """
        raise NotImplementedError

    def state_prior(self, x):
        """Return the prior weight of ``x`` as a start state (constant 1 here)."""
        return 1

    def sample_start_state(self):
        """Return a randomly drawn start state."""
        raise NotImplementedError

    def done(self, x, g):
        """Return whether state ``x`` terminates a trajectory towards goal ``g``."""
        raise NotImplementedError


class classical_problem(problem_base):
    """Deterministic shortest-path problem with a softmax forward policy.

    Subclasses implement ``neighbors`` and ``heuristic``.  The forward
    policy weights each neighbour ``y`` of the current state by
    ``exp(-beta * d(y))``, where ``d(y)`` is the distance reported by an
    ``incremental_a_star`` search rooted at the goal.
    """

    def __init__(self, beta):
        super().__init__(beta)
        # One primed incremental A* coroutine per goal, created lazily.
        self.ias = {}

    def neighbors(self, z):
        """Return an iterable of ``(state, cost)`` pairs adjacent to ``z``."""
        raise NotImplementedError

    def heuristic(self, x, end):
        """Return the A* heuristic estimate of the cost from ``x`` to ``end``."""
        raise NotImplementedError

    def predecessors(self, z):
        """Return ``neighbors(z)``; predecessors and successors coincide here."""
        return self.neighbors(z)

    def get_forward_probs(self, x, g, a=None):
        """Return ``(ys, cs)``: the neighbours of ``x`` and their probabilities.

        ``a`` is accepted for interface compatibility and ignored.

        Neighbours the search reports as unreachable (negative distance
        sentinel) get probability 0 instead of the inflated weight
        ``exp(+beta)`` the raw sentinel would produce.  If no neighbour can
        reach the goal the distribution falls back to uniform.
        """
        if g not in self.ias:
            self.ias[g] = incremental_a_star([g], self)
            next(self.ias[g])  # prime the coroutine so it accepts send()
        ias = self.ias[g]

        ys = [y for y, _ in self.neighbors(x)]
        if not ys:
            return ys, []

        dists = [ias.send(y) for y in ys]
        # None marks an unreachable neighbour (search returned its -1 sentinel).
        logits = [None if d < 0 else -self.beta * d for d in dists]
        finite = [l for l in logits if l is not None]
        if not finite:
            return ys, [1 / len(ys)] * len(ys)

        # Shift by the best logit before exponentiating: the ratios are
        # unchanged, but the normaliser can no longer underflow to zero.
        top = max(finite)
        weights = [0.0 if l is None else math.exp(l - top) for l in logits]
        total = sum(weights)
        return ys, [w / total for w in weights]

    def done(self, x, g):
        """Return whether ``x`` is the goal ``g``."""
        return x == g


def incremental_a_star(starts, problem):
    """Coroutine answering repeated shortest-distance queries from ``starts``.

    Usage: create the generator, prime it with ``next()``, then call
    ``send(end)`` for each query.  Each ``send`` returns the cost of the
    cheapest path from any start to ``end`` following
    ``problem.neighbors``, or ``-1`` if ``end`` cannot be reached.  Search
    state (costs and open list) is kept between queries, so later queries
    reuse earlier work.

    ``problem`` must provide ``neighbors(x)`` yielding ``(state, cost)``
    pairs and ``heuristic(x, end)``.  Results are exact when the heuristic
    is consistent and costs are non-negative.
    """
    g = {start: 0 for start in starts}
    # States whose g value is final: the starts and every state that has
    # been popped from the open list.  A state that was merely discovered
    # (present in g but never popped) may still hold a tentative,
    # non-optimal cost and must not be answered from g directly.
    settled = set(starts)
    counter = itertools.count()  # tie-breaker so heap entries never compare states

    h_cache = {}
    def get_h(x, end):
        if (x, end) not in h_cache:
            h_cache[(x, end)] = problem.heuristic(x, end)
        return h_cache[(x, end)]

    S_open = [(0, next(counter), start) for start in starts]
    dist = None
    while True:
        # Get a new goal from the caller
        end = yield dist
        if end in settled:
            dist = g[end]
            continue

        # Recompute f's for new end
        S_open = [(g[x] + get_h(x, end), next(counter), x) for _, _, x in S_open]
        heapq.heapify(S_open)

        while S_open:
            f_x, _, x = heapq.heappop(S_open)
            settled.add(x)
            if x == end:
                dist = g[x]
                # Keep the goal on the open list: it has not been expanded
                # yet and later queries may need to search through it.
                heapq.heappush(S_open, (f_x, next(counter), x))
                break
            for y, cost in problem.neighbors(x):
                new = g[x] + cost
                if new < g.get(y, float('+inf')):
                    g[y] = new
                    heapq.heappush(S_open, (new + get_h(y, end), next(counter), y))
        else:
            dist = -1  # unreachable


def compute_bdpt_cache(g, problem, num_samples=10, depth=10):
    """Sample forward paths towards ``g`` and index them by their last state.

    Each of the ``num_samples`` walks begins at
    ``problem.sample_start_state()`` and follows
    ``problem.get_forward_probs``.  A walk ends when ``problem.done``
    holds, or, before each step, with probability ``1 / depth``.

    Returns a dict mapping the final state of a walk to a list of
    ``(path, weight)`` tuples, where ``weight`` is the product of the
    continue/stop probabilities of the termination draws along the walk.
    """
    table = {}
    for _ in range(num_samples):
        state = problem.sample_start_state()
        trail = [state]
        w = 1.
        while not problem.done(state, g):
            # Random early stop; its probability is folded into the weight.
            if random.random() < 1 / depth:
                w *= 1 / depth
                break
            w *= 1 - 1 / depth

            options, probs = problem.get_forward_probs(state, g)
            state = random.choices(options, weights=probs)[0]
            trail.append(state)
        # Both exits (goal reached or early stop) file the walk under the
        # state it ended in.
        table.setdefault(state, []).append((trail[:], w))
    return table


def sample_likelihood_rejection(
    x, g, problem, num_samples=1, save_paths=False, next_action=None, size_principle=True,
    bdpt_cache=None, depth=None, alpha=None  # ignored, just for compatibility with sample_likelihood
):
    """Score how often complete forward rollouts towards ``g`` visit ``x``.

    Each rollout starts at ``problem.sample_start_state()`` and follows
    ``problem.get_forward_probs`` until ``problem.done``.  ``next_action``
    is forwarded to ``get_forward_probs`` only at states equal to ``x``.
    A rollout scores the number of times it visits ``x``, divided by the
    rollout length when ``size_principle`` is true; rollouts that miss
    ``x`` score 0.

    Returns the mean score, or ``(scores, paths)`` when ``save_paths`` is
    true, where each path is ``(prefix up to and including the first visit
    of x, remainder)`` or ``([], rollout)`` for a miss.
    """
    scores = []
    recorded = [] if save_paths else None
    for _ in range(num_samples):
        state = problem.sample_start_state()
        trail = [state]
        while not problem.done(state, g):
            hint = next_action if state == x else None
            options, probs = problem.get_forward_probs(state, g, hint)
            state = random.choices(options, weights=probs)[0]
            trail.append(state)

        visits = trail.count(x)
        if visits:
            scores.append(visits * (1 / len(trail) if size_principle else 1))
        else:
            scores.append(0)

        if not save_paths:
            continue
        if visits:
            cut = trail.index(x) + 1
            recorded.append((trail[:cut], trail[cut:]))
        else:
            recorded.append(([], trail))

    if save_paths:
        return scores, recorded
    return sum(scores) / len(scores)


def sample_likelihood(
    x, g, problem, num_samples=1, depth=20, alpha=1.0, save_paths=False,
    next_action=None, bdpt_cache=None, size_principle=True
):
    """Estimate the weight of ``x`` lying on a trajectory towards goal ``g``.

    For every sample a forward rollout from ``x`` to the goal is drawn with
    ``problem.get_forward_probs`` (its first step is conditioned on
    ``next_action`` when given).  The part of the trajectory before ``x``
    is then sampled backwards through ``problem.predecessors``:

    * at each backward step the walk stops with probability ``1 / depth``
      and treats the current state as the start state, weighted by
      ``problem.state_prior``;
    * otherwise it moves to a predecessor drawn with weights
      ``exp(alpha * p)``, where ``p`` is the forward probability of that
      predecessor stepping into the current state, and corrects the
      running weight for that proposal;
    * if ``bdpt_cache`` (as built by ``compute_bdpt_cache``) holds forward
      paths ending in the current state, those are also used as prefixes
      and mixed 50/50 with the backward walk.

    When ``size_principle`` is true each contribution is divided by the
    number of states on the full trajectory.

    Returns ``sum(contributions) / num_samples``, or
    ``(contributions, paths)`` when ``save_paths`` is true, where each path
    is ``(prefix ending in x, forward states after x)`` and the two lists
    are aligned element by element.
    """
    if bdpt_cache is not None:
        cache_size = sum(len(c) for c in bdpt_cache.values())

    samples = []
    if save_paths:
        paths = []

    for _ in range(num_samples):
        # sample forward trace
        t_next = 0
        x_ = x
        if save_paths:
            pi_s = []  # states visited by the backward walk, x first
            pi_g = []  # states visited by the forward rollout after x
        if next_action is not None:
            t_next += 1
            ys, cs = problem.get_forward_probs(x_, g, next_action)
            x_ = random.choices(ys, weights=cs)[0]
            if save_paths:
                pi_g.append(x_)
        while not problem.done(x_, g):
            t_next += 1
            ys, cs = problem.get_forward_probs(x_, g)
            x_ = random.choices(ys, weights=cs)[0]
            if save_paths:
                pi_g.append(x_)

        # sample reverse traces
        def bdpt_samples(x_, p_pi, t_prev):
            """Contributions from cached forward paths that end in ``x_``."""
            if bdpt_cache is None or x_ not in bdpt_cache:
                return None

            entries = bdpt_cache[x_]
            samples_this_time = []
            for pi, pi_w in entries:
                samples_this_time.append(
                    (1 / pi_w)
                    * (1 / (cache_size / len(entries)))
                    * p_pi / ((len(pi) + t_prev + t_next) if size_principle else 1)
                )
                if save_paths:
                    # ``pi`` ends in x_, and pi_s does not contain x_ yet
                    # (rr_samples appends it later), so the whole cached
                    # path is kept.  This matches the len(pi) term above.
                    paths.append((pi + pi_s[::-1], pi_g))
            return [s / len(samples_this_time) for s in samples_this_time]

        def rr_samples(x_, p_pi, t_prev, q):
            """One backward step that stops with probability ``q``."""
            t_prev += 1

            if save_paths:
                pi_s.append(x_)

            if random.random() < q:
                p_pi /= q
                if save_paths:
                    paths.append((pi_s[::-1], pi_g))
                return [problem.state_prior(x_) * p_pi / ((t_prev + t_next) if size_principle else 1)]
            p_pi /= 1 - q

            ns = [n[0] for n in problem.predecessors(x_) if not problem.done(n[0], g)]
            if not ns:
                # Dead end: the walk cannot be extended, so it contributes
                # zero.  A list is returned because callers iterate over
                # the result; the path is recorded so that samples and
                # paths stay aligned.
                if save_paths:
                    paths.append((pi_s[::-1], pi_g))
                return [0]
            ps = []
            for n in ns:
                ys, cs = problem.get_forward_probs(n, g)
                ps.append(cs[ys.index(x_)])
            ps_ = [math.exp(p * alpha) for p in ps]
            x_i = random.choices(list(range(len(ps))), weights=ps_)[0]
            x_ = ns[x_i]
            # Forward probability of the step divided by the probability
            # with which the predecessor was proposed.
            p_pi *= ps[x_i] / (ps_[x_i] / sum(ps_))

            return mis_samples(x_, p_pi, t_prev)

        def mis_samples(x_, p_pi, t_prev):
            """Mix cached-prefix and backward-walk contributions equally."""
            s_bdpt = bdpt_samples(x_, p_pi, t_prev)
            s_rr_1 = rr_samples(x_, p_pi, t_prev, 1 / depth)

            if s_bdpt is None:
                return s_rr_1
            return [s * 0.5 for s in s_bdpt] + [s * 0.5 for s in s_rr_1]

        samples.extend(mis_samples(x, 1., 0))

    if save_paths:
        return samples, paths
    else:
        return sum(samples) / num_samples
