from collections import defaultdict
from datetime import datetime
import warnings
#warnings.filterwarnings("error")

import numpy as np
from scipy.signal import fftconvolve
from scipy.special import logsumexp
from tqdm import trange

from rayuela.base.semiring import Real
from rayuela.base.state import State
from rayuela.base.symbol import EOS, Sym
from rayuela.fsa.fsa import FSA
from rayuela.fsa.pathsum import Pathsum, Strategy
from rayuela.fsa.sampler import Sampler
from rayuela.occ_semiring.occ_semiring import OccurrenceWeight
from stopit import threading_timeoutable as timeoutable

from .fft_conv import compute_powers


def get_num_occs_old(Z, K, N, k=0, m=0, Z_povs=None):
    """Legacy sampler for the number of target occurrences in the next string.

    Builds the powers ``Z**0 .. Z**power`` with ``power = K - k - 1`` by
    repeated multiplication (starting from ``OccurrenceWeight.one``). It then
    draws ``n`` in ``0..N`` with probability proportional to
    ``Z.x[1][n] * Z**power .x[1][N - n - m]``. The probability is zero where
    that index would be negative.

    The draw always uses a fixed seed of 0, so identical inputs give identical
    results.

    Args:
        Z: Occurrence weight whose ``x[1]`` is indexed by occurrence count.
        K: Total number of strings.
        N: Upper bound on ``n``; also the offset for the ``Z**power`` index.
        k: Index of the current string.
        m: Occurrences already used; subtracted from the ``Z**power`` index.
        Z_povs: Ignored. The powers are always recomputed and returned.

    Returns:
        ``(num, Z_povs)`` on success. Returns ``(None, None)`` when numpy
        rejects the distribution, e.g. all-zero or NaN weights.
    """
    power = K - k - 1

    # Z_povs[i] == Z**i, accumulated left-to-right as in ``Z_pov * Z``.
    Z_povs = {0: OccurrenceWeight.one}
    for pov_idx in range(1, power + 1):
        Z_povs[pov_idx] = Z_povs[pov_idx - 1] * Z
    Z_pov = Z_povs[power]

    probs = np.zeros(N + 1)
    for n in range(N + 1):
        pov_key = N - n - m
        if pov_key >= 0:
            probs[n] = Z.x[1][n] * Z_pov.x[1][pov_key]

    n_probs = probs / np.sum(probs)
    try:
        # A private RandomState(0) reproduces the draw of the former
        # ``np.random.seed(0); np.random.choice(...)`` without resetting
        # numpy's global RNG for the rest of the program.
        num = np.random.RandomState(0).choice(len(n_probs), p=n_probs)
    except ValueError:
        # Raised for NaN / non-normalizable probabilities.
        return None, None
    return num, Z_povs


def logsumexp_normalize(log_probs, is_log=True, clip_min=-1e100, clip_max=1e100):
    """Turn unnormalized (log-)weights into a probability vector.

    Args:
        log_probs: Array-like (or scalar) of log-weights. When ``is_log`` is
            False, these are plain non-negative weights instead.
        is_log: When False, the weights are first floored at 1e-100 and then
            logged. Zero weights therefore become tiny, not exactly zero,
            probabilities.
        clip_min: Lower bound applied to the log-weights before normalizing.
            ``-inf`` is mapped onto it.
        clip_max: Upper bound applied to the log-weights before normalizing.
            ``+inf`` is mapped onto it.

    Returns:
        ``exp(x - logsumexp(x))`` for the clipped log-weights ``x``, divided
        once more by its sum. If every input is zero (or ``-inf``), all clipped
        entries are equal and the result is uniform. A scalar input always
        yields 1.0.
    """
    logs = log_probs if is_log else np.log(np.clip(log_probs, 1e-100, None))

    # Bound extreme log-weights so logsumexp never sees +/-inf.
    bounded = np.clip(logs, clip_min, clip_max)
    probs = np.exp(bounded - logsumexp(bounded))

    # Guard against floating-point drift so the result sums to 1.
    probs /= np.sum(probs)
    return probs


def get_num_occs_log(Z, K, N, k=0, m=0, Z_povs=None):
    """Log-space sampler for the number of target occurrences in the next string.

    Draws ``n`` in ``0..N`` with log-probability
    ``log Z.x[1][n] + Z_povs[K - k - 1].x[1][N - n - m]``. Entries whose
    ``Z_povs`` index would be negative, or whose ``Z.x[1][n]`` is not
    positive, get ``-inf``. The entries of ``Z_povs[...].x[1]`` are used as
    log-values directly (assumes the caller stored them in log space — TODO
    confirm).

    The draw uses a fixed seed of 0, so identical inputs give identical
    results.

    Args:
        Z: Occurrence weight whose ``x[1]`` is indexed by occurrence count.
        K: Total number of strings.
        N: Upper bound on ``n``.
        k: Index of the current string.
        m: Occurrences already used.
        Z_povs: Mapping from power to (log-space) occurrence weight. Required.

    Returns:
        ``(num, Z_povs)`` on success, ``(None, None)`` if numpy rejects the
        probabilities.

    Raises:
        AssertionError: if ``Z_povs`` is None.
    """
    assert Z_povs is not None, "Z_povs must be provided"
    log_Z_pov = Z_povs[K - k - 1]

    log_probs = np.full(N + 1, -np.inf)
    for n in range(N + 1):
        pov_key = N - n - m
        if pov_key < 0:
            continue
        z_n = Z.x[1][n]
        log_Z_n = np.log(z_n) if z_n > 0 else -np.inf
        log_probs[n] = log_Z_n + log_Z_pov.x[1][pov_key]

    probs = logsumexp_normalize(log_probs)

    try:
        # A private RandomState(0) gives the same draw as the former global
        # ``np.random.seed(0)`` without clobbering numpy's global RNG state.
        num = np.random.RandomState(0).choice(len(probs), p=probs)
    except ValueError as e:
        print(f"Error during random choice: {e}")
        return None, None

    return num, Z_povs


def get_num_occs(Z, K, N, k=0, m=0, Z_povs=None):
    """Sample how many target-symbol occurrences the next string should contain.

    For string ``k`` of ``K``, with ``m`` occurrences already used, this draws
    ``n`` in ``0..N`` with probability proportional to
    ``Z.x[1][n] * Z_povs[K - k - 1].x[1][N - n - m]``. The probability is zero
    where the second index would be negative. Presumably the current string
    takes ``n`` occurrences and the remaining ``K - k - 1`` strings share the
    rest — confirm against the caller's intent.

    The draw uses numpy's global (unseeded) RNG.

    Args:
        Z: Occurrence weight whose ``x[1]`` is indexed by occurrence count.
        K: Total number of strings.
        N: Upper bound on ``n``.
        k: Index of the current string.
        m: Occurrences already used.
        Z_povs: Mapping from power to occurrence weight (non-log). Required.

    Returns:
        ``(num, Z_povs)`` on success, ``(None, None)`` if numpy rejects the
        probabilities.

    Raises:
        AssertionError: if ``Z_povs`` is None.
    """
    assert Z_povs is not None, "Z_povs must be provided"
    Z_pov = Z_povs[K - k - 1]

    # Normalize both occurrence vectors once (loop-invariant). This only
    # rescales them, since ``probs`` is renormalized below, but it keeps the
    # products in [0, 1]. Normalize the whole vector: normalizing a single
    # scalar always returns 1.0, which would silently drop the Z_pov factor.
    Z_norm = logsumexp_normalize(Z.x[1], is_log=False)
    Z_pov_norm = logsumexp_normalize(Z_pov.x[1], is_log=False)

    probs = np.zeros(N + 1)
    for n in range(N + 1):
        pov_key = N - n - m
        if pov_key >= 0:
            probs[n] = Z_norm[n] * Z_pov_norm[pov_key]

    # The products are not normalized; renormalize before sampling.
    probs_norm = logsumexp_normalize(probs, is_log=False)

    try:
        num = np.random.choice(len(probs_norm), p=probs_norm)
    except ValueError as e:
        print(f"Error during random choice: {e}")
        return None, None

    return num, Z_povs


def get_weight_prob(weight, symbol, target, beta_q, n, N):
    """Build the occurrence-lifted weight for one arc.

    The arc weight is multiplied by ``beta_q``, the backward weight of the
    arc's destination. The entry of the product's occurrence vector (``x[1]``)
    at the remaining budget is then read out. The remaining budget is
    ``N - n`` for ordinary symbols and ``N - n - 1`` when the arc emits
    ``target``, because that arc consumes one occurrence itself.

    Args:
        weight: Occurrence weight of the arc.
        symbol: Arc label; compared to ``target`` via ``symbol.value``.
        target: Value of the target symbol.
        beta_q: Backward weight of the arc's destination state.
        n: Occurrences already emitted.
        N: Total occurrence budget.

    Returns:
        A fresh ``weight.zero.copy()``. Its ``x[0]`` and one slot of ``x[1]``
        both hold the extracted value: ``x[1][1]`` for the target symbol,
        ``x[1][0]`` otherwise. When the remaining budget is negative, the
        weight is left at zero.
    """
    emits_target = symbol.value == target
    idx = N - n - 1 if emits_target else N - n

    new_weight = weight.zero.copy()
    if idx < 0:
        # Budget already exhausted. Indexing with a negative idx would wrap
        # around to the last entry and give this arc spurious mass.
        return new_weight

    beta_qa_j = (weight * beta_q).x[1][idx]
    new_weight.x[0] = beta_qa_j
    if emits_target:
        new_weight.x[1][1] = beta_qa_j
    else:
        new_weight.x[1][0] = beta_qa_j
    return new_weight


def get_weight_prob_alt(weight, symbol, target, beta_q, n, N):
    """Variant of ``get_weight_prob`` that scales by the arc's total weight.

    Instead of multiplying the full weights, this multiplies the arc's total
    ``weight.x[0]`` by ``beta_q.x[1][idx]``. Here ``idx`` is the remaining
    budget: ``N - n``, or ``N - n - 1`` when the arc emits ``target``.

    Args:
        weight: Occurrence weight of the arc.
        symbol: Arc label; compared to ``target`` via ``symbol.value``.
        target: Value of the target symbol.
        beta_q: Backward weight of the arc's destination state.
        n: Occurrences already emitted.
        N: Total occurrence budget.

    Returns:
        A fresh ``weight.zero.copy()``. Its ``x[0]`` and ``x[1][1]`` (target)
        or ``x[1][0]`` (other symbols) hold the product. When the budget is
        negative, the weight stays zero.
    """
    # Compare on ``.value`` in both places. The index used to be chosen with
    # ``symbol == target`` while the slot used ``symbol.value == target``.
    emits_target = symbol.value == target
    idx = N - n - 1 if emits_target else N - n

    new_weight = weight.zero.copy()
    if idx < 0:
        # Negative indices would wrap around to the end of beta_q.x[1].
        return new_weight

    beta_qa_j = weight.x[0] * beta_q.x[1][idx]
    new_weight.x[0] = beta_qa_j
    if emits_target:
        new_weight.x[1][1] = beta_qa_j
    else:
        new_weight.x[1][0] = beta_qa_j
    return new_weight


def lift_for_occ_sampling(machine, n, N, target_symbol, pathsums, mutate=False):
    """Re-weight ``machine``'s arcs for occurrence-constrained sampling.

    Each arc ``(q, a, j, w)`` is first lifted with ``get_weight_prob``, using
    ``pathsums[j]`` as the destination's backward weight. Then the lifted
    totals (``x[0]``) of each state's outgoing arcs are normalized with
    ``logsumexp_normalize`` so they sum to one. The stored weight carries the
    probability in ``x[0]`` and in ``x[1][1]`` (target) or ``x[1][0]``
    (other symbols).

    Final weights are copied unchanged and are not part of the per-state
    normalization.

    Args:
        machine: Source FSA.
        n: Target occurrences emitted so far (``_ancestral`` passes
            ``seen_symbols``).
        N: Total occurrence budget.
        target_symbol: Value of the target symbol.
        pathsums: Mapping from state to backward weight.
        mutate: If True, rewrite ``machine`` in place with ``set_arc``.
            Otherwise build a new ``FSA(OccurrenceWeight)`` with copied
            initial and final weights.

    Returns:
        The re-weighted FSA: ``machine`` itself when ``mutate`` is True.
    """
    if mutate:
        A = machine
    else:
        A = FSA(OccurrenceWeight)
        for q, w in machine.I:
            A.set_I(q, w.copy())
        for q, w in machine.F:
            A.set_F(q, w.copy())

    for q in machine.Q:
        # Materialize first: with mutate=True we rewrite these arcs below.
        arcs = list(machine.arcs(q))
        if not arcs:
            continue

        # Lift and normalize in one pass so the normalization always sees the
        # lifted weights. Previously, with mutate=False, it read the original
        # weights and added a second weight to every arc.
        lifted = [
            get_weight_prob(w, a, target_symbol, pathsums[j], n, N)
            for a, j, w in arcs
        ]
        probs = logsumexp_normalize([lw.x[0] for lw in lifted], is_log=False)

        for (a, j, _), prob in zip(arcs, probs):
            new_weight = OccurrenceWeight.zero.copy()
            new_weight.x[0] = prob
            if a.value == target_symbol:
                new_weight.x[1][1] = prob
            else:
                new_weight.x[1][0] = prob

            if mutate:
                A.set_arc(q, a, j, new_weight)
            else:
                A.add_arc(q, a, j, new_weight)
    return A


class OccSampler(Sampler):
    """Ancestral sampler conditioned on the number of target-symbol occurrences.

    For each string, ``sample`` first draws an occurrence count with
    ``get_num_occs``. It then draws the string by ancestral sampling over
    ``self.A``, suppressing the target symbol once the count is reached and
    only allowing termination after that.
    """

    # Target symbols emitted so far. Set to 0 in __init__ and accumulated
    # across every string this instance draws.
    seen_symbols = None

    def __init__(self, fsa, T=1, seed=None):
        self.T = T
        # Push weights unless the machine is already probabilistic.
        self.A = fsa.push() if not fsa.probabilistic else fsa
        self.rng = np.random.default_rng(seed)
        self.seen_symbols = 0
        self.A = self.A.copy()
        # Pristine copy; re-lifted each time the target symbol is emitted.
        self.A_orig = self.A.copy()

    def sample(
        self,
        K: int = 1,
        to_string: bool = True,
        sep: str = " ",
        transform=lambda _, p: p,
        lm: bool = False,
        Z=None,
        N=None,
        beta=None,
        tgt_symbol=None,
    ):
        """Lazily draw ``K`` strings.

        ``to_string``, ``sep``, ``tgt_symbol``, ``N`` and ``beta`` are
        forwarded to ``_ancestral``. ``transform`` and ``lm`` are unused.

        Yields:
            ``(sample, num_occ)`` per string. A bare ``None`` is yielded when
            no occurrence count could be drawn; that string is skipped.
            ``(None, num_occ)`` is yielded when ancestral sampling times out
            (300 s).
        """
        z_povs = compute_powers(Z, K + 1)
        for i in trange(K):
            # get_num_occs returns z_povs unchanged on success and None on
            # failure. Discard it so a failure cannot break later iterations.
            num_occ, _ = get_num_occs(
                Z, K, N, k=i, m=self.seen_symbols, Z_povs=z_povs
            )
            if num_occ is None:
                yield None
                # _ancestral would raise TypeError on `num_sym >= None`.
                continue

            samp = self._ancestral(
                to_string, sep, num_occ, tgt_symbol, N, beta, timeout=300
            )
            if samp == "timeout":
                yield None, num_occ
                # Don't also yield the "timeout" sentinel as a sample.
                continue
            yield samp, num_occ

    def _draw(self, options):
        """Return one key of ``options``, drawn with its value as probability."""
        p = np.asarray(list(w for w in options.values()))
        return list(options.keys())[self.rng.choice(len(p), p=p)]

    @timeoutable(default="timeout")
    def _ancestral(
        self,
        to_string: bool,
        sep: str = " ",
        num_occ=None,
        tgt_symbol=None,
        N=None,
        beta=None,
    ):
        """Draw one string by ancestral sampling, aiming for ``num_occ`` targets.

        State ``0`` is treated as the stop sentinel. The ``(EOS, 0)`` option is
        offered only once ``num_occ`` targets have been emitted, and from then
        on target arcs get zero weight. Each emitted target increments
        ``self.seen_symbols`` and re-lifts ``self.A`` from ``self.A_orig``
        using ``beta``. ``self.A`` is not reset between calls.

        The stopit decorator accepts a ``timeout`` keyword (``sample`` passes
        300) and returns ``"timeout"`` when it expires.

        Returns:
            The tokens joined with ``sep`` if ``to_string``, else the token
            list. A warning is printed if the final target count differs from
            ``num_occ``.
        """
        y = []
        # Tokens in ``y`` equal to tgt_symbol, maintained incrementally
        # instead of rescanning ``y`` at every step.
        num_sym = 0
        q = self._draw({p: w.value for p, w in self.A.I})

        while q != 0:
            options = {(a, q_next): w for a, q_next, w in self.A.arcs(q)}

            if num_sym >= num_occ:
                # Count reached: forbid further target symbols.
                options = {
                    (a, q_next): (self.A.R.zero if a.value == tgt_symbol else w)
                    for (a, q_next), w in options.items()
                }

            # Termination is only allowed once the requested count is reached.
            if self.A.ρ[q] != self.A.R.zero and num_sym >= num_occ:
                options[(EOS, 0)] = self.A.ρ[q]  # Signals the end of generation

            # Nulling target arcs / withholding EOS leaves options unnormalized.
            # If nothing has mass left, stop here.
            if sum(w.x[0] for w in options.values()) == 0:
                q = 0
                continue

            option_keys = list(options)
            new_vals = logsumexp_normalize(
                [options[key].x[0] for key in option_keys], is_log=False
            )
            a, q = self._draw(dict(zip(option_keys, new_vals)))

            if a != EOS:
                token = str(a)
                y.append(token)
                if token == tgt_symbol:
                    num_sym += 1

            if a.value == tgt_symbol:
                self.seen_symbols += 1
                # Re-lift a pristine copy for the updated occurrence count.
                self.A = lift_for_occ_sampling(
                    self.A_orig.copy(), self.seen_symbols, N, tgt_symbol, beta
                )

        if to_string:
            y = sep.join(y)

        # Count from the token list: works for to_string=False and any ``sep``.
        if num_sym != num_occ:
            print(
                f"Not correct number of symbol, saw {num_sym} instead of {num_occ}, tgt {tgt_symbol}, sampled {y}. Likely due to shape of machine."
            )
        return y
