import json
import random as pyrandom
from collections import defaultdict
from datetime import datetime
import string
import numpy as np
from stopit import threading_timeoutable as timeoutable
from tqdm import tqdm

from rayuela.base.semiring import Real
from rayuela.base.state import State
from rayuela.base.symbol import Sym
from rayuela.fsa import random, sampler
from rayuela.fsa.fsa import FSA
from rayuela.fsa.pathsum import Pathsum, Strategy
from rayuela.occ_semiring.arc_interventions import get_arcs, sample_from_machine_with_arcs_interventions
from rayuela.occ_semiring.fft_conv import compute_powers
from rayuela.occ_semiring.occ_semiring import OccurrenceWeight
from rayuela.occ_semiring.sampler import OccSampler, get_num_occs, get_num_occs_old, lift_for_occ_sampling


def lift(machine, R, tgt_symbol, N):
    """Copy ``machine`` into a new ``FSA(R)`` with every weight lifted.

    Each initial, final and arc weight is passed through
    ``OccurrenceWeight.lift_weight(tgt_symbol, weight, N, symbol=...)``.
    Arc weights are lifted with the arc's own symbol; initial and final
    weights are lifted with ``symbol=None``.

    Args:
        machine: source automaton; its ``I``, ``F``, ``Q`` and ``arcs(q)``
            are read.
        R: semiring handed to the ``FSA`` constructor of the result.
        tgt_symbol: target symbol forwarded to ``lift_weight``.
        N: size parameter forwarded to ``lift_weight``.

    Returns:
        The newly built automaton.
    """

    def _lift_weight(weight, symbol=None):
        return OccurrenceWeight.lift_weight(tgt_symbol, weight, N, symbol=symbol)

    lifted = FSA(R)

    for state, weight in machine.I:
        lifted.set_I(state, _lift_weight(weight))

    for state, weight in machine.F:
        lifted.set_F(state, _lift_weight(weight))

    for src in machine.Q:
        for sym, dst, weight in machine.arcs(src):
            lifted.add_arc(src, sym, dst, _lift_weight(weight, symbol=sym))

    return lifted


def sample_tgt(A):
    """Return the ``.value`` of one symbol drawn at random from ``A.Sigma``.

    Returns None when the alphabet is empty.
    """
    alphabet = list(A.Sigma)
    if not alphabet:
        return None
    (chosen,) = pyrandom.sample(alphabet, 1)
    return chosen.value


def get_random_machine(
    tgt_symbol,
    Sigma="abc",
    N=10,
    num_states=4,
    seed=422,
    lift_machine=True,
    min_size=1,
    A=None,
    loop_while_none=True,
    force_number_of_states=False,
):
    """Sample a random PFSA (unless ``A`` is given) and optionally lift it.

    Sampling (only when ``A`` is None):
        ``random.random_pfsa(Sigma=Sigma, num_states=num_states, seed=seed)``
        is called. If it returns None and ``loop_while_none`` is set, the
        seed is incremented and sampling is retried. If
        ``force_number_of_states`` is set, the seed is incremented until the
        sample has exactly ``num_states`` states. Neither retry is bounded.

    Lifting:
        The class attributes ``OccurrenceWeight.zero`` / ``OccurrenceWeight.one``
        are (re)set for ``N``; this is a module-global side effect. The
        machine is lifted with :func:`lift`, and backward weights are
        computed with ``Pathsum(...).lehmann_bw_scc()``.

    Args:
        tgt_symbol: symbol whose occurrences are tracked; if None, one is
            drawn with :func:`sample_tgt`.
        Sigma, num_states, seed: forwarded to ``random.random_pfsa``.
        N: size parameter for ``OccurrenceWeight``.
        lift_machine: if False (and ``A`` is None), return the sampled
            machine without lifting it.
        min_size: samples with fewer states yield a 5-tuple of Nones.
        A: an existing machine. When given, sampling, ``lift_machine`` and
            ``min_size`` are all skipped and ``A`` is lifted directly.
        loop_while_none: retry when the generator returns None; otherwise
            return a 5-tuple of Nones.
        force_number_of_states: resample until ``len(A.Q) == num_states``.

    Returns:
        The raw machine when ``A`` is None and ``lift_machine`` is False.
        Otherwise ``(Al, Z, W, start_state, tgt_symbol)``, where ``Al`` is the
        lifted machine, ``W`` the backward weights, ``start_state`` the
        first initial state of ``Al``, and ``Z = W[start_state]``.
        ``(None, None, None, None, None)`` when no usable machine or target
        symbol is available. Note that this includes the ``min_size`` case,
        even when ``lift_machine`` is False.
    """
    if A is None:
        # Iterative retry rather than recursion: each retry only bumps the
        # seed, and the recursion previously dropped/corrupted arguments.
        while True:
            A = random.random_pfsa(Sigma=Sigma, num_states=num_states, seed=seed)

            if A is None:
                if not loop_while_none:
                    return None, None, None, None, None
                print("Machine is not suitable,", seed + 1)
                seed += 1
                continue

            if force_number_of_states and len(A.Q) != num_states:
                seed += 1
                continue

            break

        if len(A.Q) < min_size:
            return None, None, None, None, None

        if not lift_machine:
            return A

    OccurrenceWeight.zero = OccurrenceWeight.get_zero(N)
    OccurrenceWeight.one = OccurrenceWeight.get_one(N)

    if tgt_symbol is None:
        tgt_symbol = sample_tgt(A)
        if tgt_symbol is None:
            # Empty alphabet: nothing to track.
            return None, None, None, None, None

    Al = lift(A, OccurrenceWeight, tgt_symbol, N)

    start = datetime.now()
    print("Starting bw-calc .. : ", start)

    W = Pathsum(Al).lehmann_bw_scc()
    start_state = list(Al.λ.keys())[0]
    Z = W[start_state]

    end = datetime.now()
    print(f"-- It took {end-start} to calc bw probs w. lehmann")

    return Al, Z, W, start_state, tgt_symbol


def construct_pdfa(num_states, seed, num_symbols, force_number_of_states=False):
    """
    Constructs a random PDFA with the given parameters.

    The machine is obtained from ``get_random_machine(..., lift_machine=False)``
    over the first ``num_symbols`` lowercase ASCII letters.

    Args:
        num_states: number of states requested from the random generator.
        seed: seed of the first sampling attempt.
        num_symbols: alphabet size. Values above 26 are silently capped by
            slicing ``string.ascii_lowercase``.
        force_number_of_states: when True, resample with ``seed + 1``,
            ``seed + 2``, ... until the result has exactly ``num_states``
            states. This may loop indefinitely if the generator never
            produces that many states.

    Returns:
        Whatever ``get_random_machine`` returns with ``lift_machine=False``.
        This is normally the sampled machine, or a 5-tuple of Nones if the
        sample had fewer than ``min_size`` (default 1) states.
    """
    sigma = string.ascii_lowercase[:num_symbols]

    while True:
        pdfa = get_random_machine(
            tgt_symbol=None,
            Sigma=sigma,
            N=None,
            num_states=num_states,
            seed=seed,
            lift_machine=False,
        )
        # ``getattr`` because the too-small case yields a tuple, not a
        # machine; such results never satisfy the state-count requirement.
        if not force_number_of_states or len(getattr(pdfa, "Q", ())) == num_states:
            break
        seed += 1

    assert pdfa is not None

    return pdfa


@timeoutable(default=None)
def _construct_large_scc_pdfa(num_states, seed, num_symbols, scc_size=10, debug=False):
    """Build one large machine by chaining randomly generated PDFAs.

    Component machines are generated with :func:`construct_pdfa` until their
    state counts sum to at least ``num_states``. They are then merged into
    the first non-empty component, which is mutated in place. For each later
    component:
      * its states are renumbered past the current maximum state index;
      * an arc labelled with a randomly chosen symbol links the merged
        machine's (first) final state to the component's (first) initial
        state. The arc carries the old final weight;
      * the component's arcs are copied in;
      * the component's (first) final state becomes the merged machine's
        only final state.

    Args:
        num_states: minimum total number of states over all components.
        seed: base seed; each component uses ``seed + <states so far> + 1``.
        num_symbols: alphabet size forwarded to :func:`construct_pdfa`.
        scc_size: state count requested for each component; must be >= 5.
        debug: when True, also return the component list and copies of the
            merged machine taken at the start of each merge iteration.

    Returns:
        The merged machine, or ``(merged, pdfas, interim_pdfas)`` when
        ``debug`` is set. The ``timeoutable`` wrapper returns None instead
        if a ``timeout=`` keyword is supplied and expires.

    Raises:
        ValueError: if ``int(scc_size) < 5``.
    """
    # Sizes are drawn from N(scc_size, 0). With a zero standard deviation,
    # every draw equals scc_size, so the ``samp_size < 5`` rejection loop
    # below could never exit for scc_size < 5. Fail fast instead of hanging.
    if int(scc_size) < 5:
        raise ValueError(f"scc_size must be at least 5, got {scc_size!r}")

    pdfas = []
    interim_pdfas = []
    total_states = 0

    while total_states < num_states:
        # The commented-out std (scc_size / 2) suggests a Gaussian size was
        # intended; with std 0 the size is deterministic.
        samp_size = 0
        while samp_size < 5:
            samp_size = int(np.random.normal(scc_size, 0))  # scc_size / 2))
        pdfa = construct_pdfa(samp_size, seed + total_states + 1, num_symbols)
        pdfas.append(pdfa)
        total_states += len(pdfa.Q)
        print(f"Constructed {len(pdfas)} pdfas with {total_states} states")

    print(f"Done: Constructed {len(pdfas)} pdfas with {total_states} states")

    # Drop leading empty components; the first non-empty one becomes the
    # merged machine (note: it is the same object as pdfas[0]).
    new_pdfa = pdfas[0]
    while len(new_pdfa.Q) == 0:
        pdfas = pdfas[1:]
        new_pdfa = pdfas[0]

    for idx, pdfa in enumerate(pdfas):
        interim_pdfas.append(new_pdfa.copy())

        # The first component is the merge target itself.
        if idx == 0:
            continue

        # Offset mapping this component's states past every state already in
        # the merged machine. NOTE(review): state indices may need to be
        # contiguous; this only avoids collisions when component indices
        # are non-negative.
        seen_states = max([s.idx for s in new_pdfa.Q])
        print(f"Seen states: {seen_states}")

        def nS(state):
            return State(state.idx + seen_states + 1)

        start_state = list(pdfa.λ.keys())[0]
        symbols = list(pdfa.Sigma.union(new_pdfa.Sigma))
        if not symbols:
            continue
        # NOTE(review): Sigma elements look like Sym objects already
        # (sample_tgt reads ``.value`` from them); confirm Sym(Sym) is intended.
        random_symbol = Sym(pyrandom.choice(symbols))

        new_start_state = nS(start_state)
        end_weight = list(new_pdfa.ρ.values())[0]
        end_state = list(new_pdfa.ρ.keys())[0]

        print(f"Mapping end to start: Connecting {end_state} to {new_start_state} with {random_symbol} and {end_weight}")
        # The old final weight becomes the weight of the linking arc.
        new_pdfa.add_arc(end_state, random_symbol, new_start_state, end_weight)

        # Copy this component's arcs into the merged machine, renumbered.
        for i, a, j, w in get_arcs(pdfa):
            new_pdfa.add_arc(nS(i), a, nS(j), w)

        # TODO: add more connections from earlier pdfa to the new pdfa

        # The component's final state becomes the merged machine's only final state.
        first_final = list(pdfa.ρ.keys())[0]
        new_end_state = nS(first_final)
        end_prob = pdfa.ρ[first_final]
        # Bind ``zero`` now; ``lambda: pdfa.R.zero`` would late-bind the loop
        # variable and read whichever component ``pdfa`` last referred to.
        new_pdfa.ρ = defaultdict(lambda zero=pdfa.R.zero: zero)
        new_pdfa.ρ[new_end_state] = end_prob
        print(f"New pdfa with {len(new_pdfa.Q)} states")
        assert len(new_pdfa.Q) != 0

    if debug:
        return new_pdfa, pdfas, interim_pdfas

    return new_pdfa


def construct_large_scc_pdfa(num_states, seed, num_symbols, scc_size=10, timeout=None):
    """Build a chained PDFA, retrying with a new seed when an attempt times out.

    Calls ``_construct_large_scc_pdfa`` with seeds ``seed``, ``seed + 1``, ...
    until it returns a machine. That function's ``timeoutable`` wrapper
    returns None only when a ``timeout`` is supplied and expires. Without a
    timeout, the first attempt is always the last.

    Args:
        num_states, seed, num_symbols, scc_size: forwarded to
            ``_construct_large_scc_pdfa``.
        timeout: optional per-attempt time limit in seconds, forwarded to the
            stopit ``timeoutable`` wrapper. ``None`` (default) means no limit,
            matching the previous behavior.

    Returns:
        The constructed machine.
    """
    A = None
    counter = 0
    while A is None:
        A = _construct_large_scc_pdfa(
            num_states,
            seed + counter,
            num_symbols,
            scc_size=scc_size,
            timeout=timeout,
        )
        counter += 1
        print(f"Counter: {counter}")
    return A


@timeoutable(default="timeout")
def sample_occurrence_strings(A, Z, K, N, tgt_symbol, W):
    """Draw samples from ``A`` with the occurrence sampler.

    ``A`` is prepared with ``lift_for_occ_sampling`` and sampled with an
    ``OccSampler`` seeded with 0. ``K``, ``Z``, ``N``, ``W`` (as ``beta``) and
    ``tgt_symbol`` are forwarded to ``OccSampler.sample``.

    Returns:
        The list of samples in the order produced, truncated at the first
        ``None`` the sampler yields. If called with a ``timeout=`` keyword
        that expires, the stopit wrapper returns the string ``"timeout"``.
    """
    Al = lift_for_occ_sampling(A, 0, N, tgt_symbol, W)
    A_sampler = OccSampler(Al, seed=0)
    A_sample = A_sampler.sample(K, Z=Z, N=N, beta=W, tgt_symbol=tgt_symbol)

    samps = []
    for samp in A_sample:
        if samp is None:
            # A None entry ends the usable samples.
            break
        samps.append(samp)
    return samps


def check_if_tgt_symbol_in_machine(A, tgt):
    """Return True if ``tgt`` equals the ``.value`` of some symbol in ``A.Sigma``."""
    return tgt in (symbol.value for symbol in A.Sigma)


