from lbrayuela.automata.finite_automaton import FiniteAutomatonContainer, FiniteAutomatonTransition, WeightedFiniteAutomatonContainer
from lbrayuela.automata.semiring import Semiring

import torch


def mv(p):
    """Return ``log(p)`` wrapped in a one-element tensor.

    Args:
        p: Scalar weight to convert; ``0`` produces ``-inf``.

    Returns:
        A tensor of shape ``(1,)`` containing ``log(p)``. It is moved to the
        CUDA device when ``torch.cuda.is_available()`` reports one, and left
        on the CPU otherwise.
    """
    log_weight = torch.log(torch.tensor([p]))
    return log_weight.cuda() if torch.cuda.is_available() else log_weight


def trans(M, q, a, r):
    """Create the transition ``q --a--> r``, register it on ``M`` and return it.

    Args:
        M: Automaton container exposing ``add_transition``.
        q: Index of the source state.
        a: Index of the symbol read by the transition.
        r: Index of the target state.

    Returns:
        The ``FiniteAutomatonTransition`` that was added, so callers can use
        it as a key (e.g. to attach a weight).
    """
    transition = FiniteAutomatonTransition(q, a, r)
    M.add_transition(transition)
    return transition


def make_eta_parity_automaton(eta=0.05, accept_prob=0.2):
    """Build a 4-state weighted automaton over ``{a, b}`` parameterised by ``eta``.

    This builder was originally named ``make_parity_automaton``. A later
    definition with that same name in this module replaced it at import time,
    which made this function unreachable. It now carries a distinct name so
    that it can actually be called; ``make_parity_automaton`` keeps resolving
    to the later definition, exactly as before.

    Arcs (weights are stored as plain numbers, not as ``mv`` log-tensors):
        q0 --a / eta-------------------> r0
        r0 --a / 1 - accept_prob-------> r0
        q0 --b / 1 - eta - accept_prob-> q1
        q1 --b / 1---------------------> q2
        q2 --b / 1 - accept_prob-------> q1

    Args:
        eta: Weight of the ``q0 --a--> r0`` arc.
        accept_prob: Accept weight given to ``r0``, ``q0`` and ``q2``.

    Returns:
        The populated ``WeightedFiniteAutomatonContainer``.
    """
    M = WeightedFiniteAutomatonContainer(num_states=4, alphabet_size=2, semiring=Semiring)

    # State indices and their printable names.
    q0, q1, q2, r0 = 0, 1, 2, 3
    M.state_map = {q0: "q0", q1: "q1", q2: "q2", r0: "r0"}

    # Symbol indices and their printable names.
    a, b = 0, 1
    M.symbol_map = {a: "a", b: "b"}

    # Branch reading only a's: q0 -> r0, then loop on r0.
    M._transitions[trans(M, q0, a, r0)] = eta
    M._transitions[trans(M, r0, a, r0)] = 1 - accept_prob
    M.set_accept_weight(r0, accept_prob)

    # Branch reading only b's: q0 -> q1, then alternate between q1 and q2.
    M._transitions[trans(M, q0, b, q1)] = 1 - eta - accept_prob
    M.set_accept_weight(q0, accept_prob)

    # q1 has no accept weight, so its single outgoing arc carries weight 1.
    M._transitions[trans(M, q1, b, q2)] = 1

    M._transitions[trans(M, q2, b, q1)] = 1 - accept_prob
    M.set_accept_weight(q2, accept_prob)

    return M


def make_starfree_automaton():
    """Build a weighted automaton with a single looping, accepting state.

    The ``start`` state loops on symbol "0" with weight ``log(0.95)`` and has
    accept weight ``log(0.05)``.

    NOTE(review): the container is created with ``num_states=2`` and
    ``alphabet_size=2`` although only one state and one symbol are mapped and
    used here -- confirm that this is intended.

    Returns:
        The populated ``WeightedFiniteAutomatonContainer``.
    """
    M = WeightedFiniteAutomatonContainer(num_states=2, alphabet_size=2, semiring=Semiring)

    start = 0  # the only state
    zero = 0   # the only symbol, printed as "0"

    M.state_map = {start: "start"}
    M.symbol_map = {zero: "0"}

    # Self-loop on the single symbol.
    loop = trans(M, start, zero, start)
    M._transitions[loop] = mv(0.95)

    # The single state is also the accepting one.
    M._accept_weights[start] = mv(0.05)

    return M


def make_parity_automaton():
    """Build a two-state weighted automaton tracking the parity of "1" symbols.

    Reading "0" keeps the current state and reading "1" switches between
    ``q_even`` and ``q_odd``. All weights are log-tensors produced by ``mv``.
    ``q_even`` gets accept weight ``log(0.34)`` and ``q_odd`` gets ``log(0)``.

    Returns:
        The populated ``WeightedFiniteAutomatonContainer``.
    """
    M = WeightedFiniteAutomatonContainer(num_states=2, alphabet_size=2, semiring=Semiring)

    # State indices.
    q_even, q_odd = 0, 1
    M.state_map = {q_even: "q_even", q_odd: "q_odd"}

    # Symbol indices: a is the input symbol "0", b is the input symbol "1".
    a, b = 0, 1
    M.symbol_map = {a: "0", b: "1"}

    # (source, symbol, target, weight before taking the log)
    arcs = (
        (q_even, a, q_even, 0.33),  # reading 0 doesn't change parity
        (q_even, b, q_odd, 0.33),   # reading 1 flips parity
        (q_odd, a, q_odd, 0.5),
        (q_odd, b, q_even, 0.5),
    )
    for source, symbol, target, weight in arcs:
        M._transitions[trans(M, source, symbol, target)] = mv(weight)

    # q_odd is given log(0), so only q_even carries a usable accept weight.
    for state, weight in ((q_even, 0.34), (q_odd, 0)):
        M._accept_weights[state] = mv(weight)

    return M


def make_canonical_parity_weights(eta=0.05, accept_prob=0.2):
    """Build the weighted 5-state, 3-symbol canonical parity automaton.

    Weights are stored as plain numbers (not ``mv`` log-tensors).

    NOTE(review): the arc weights leaving q0, q2 and r1 already add up to 1
    before ``accept_prob`` is attached as their accept weight -- confirm
    whether the weights are meant to be normalised elsewhere.

    Args:
        eta: Weight of the ``q0 --a--> r1`` arc; the ``b`` and ``c`` arcs out
            of ``q0`` each get ``(1 - eta) / 2``.
        accept_prob: Accept weight given to ``q0``, ``q2`` and ``r1``.

    Returns:
        The populated ``WeightedFiniteAutomatonContainer``.
    """
    M = WeightedFiniteAutomatonContainer(num_states=5, alphabet_size=3, semiring=Semiring)

    # State indices: q0 (initial and accepting), q2 and r1 (accepting).
    q0, q1, q2, q3, r1 = 0, 1, 2, 3, 4
    M.state_map = {q0: "q0", q1: "q1", q2: "q2", q3: "q3", r1: "r1"}

    # Symbol indices.
    a, b, c = 0, 1, 2
    M.symbol_map = {a: "a", b: "b", c: "c"}

    # (source, symbol, target, weight), in insertion order.
    arcs = (
        # Out of q0.
        (q0, b, q1, (1 - eta) / 2),
        (q0, c, q3, (1 - eta) / 2),
        (q0, a, r1, eta),
        # Out of q3.
        (q3, b, q1, 0.5),
        (q3, c, q3, 0.5),
        # Parity branch: b alternates q1/q2, c loops in place.
        (q1, b, q2, 0.5),
        (q2, b, q1, 0.5),
        (q1, c, q1, 0.5),
        (q2, c, q2, 0.5),
        # r1 self-loop on a (the {a}* branch).
        (r1, a, r1, 1),
    )
    for source, symbol, target, weight in arcs:
        M._transitions[trans(M, source, symbol, target)] = weight

    # Final weights for the accepting states.
    for state in (q0, q2, r1):
        M.set_accept_weight(state, accept_prob)

    return M


def make_canonical_parity_automaton():
    """Build the unweighted 5-state, 3-symbol canonical parity automaton.

    Besides ``state_map`` and ``symbol_map``, the returned container also gets
    ``state_map_rev``, the name-to-index inverse of ``state_map``.

    Returns:
        The populated ``FiniteAutomatonContainer`` with ``q0``, ``q2`` and
        ``r1`` registered as accept states.
    """
    M = FiniteAutomatonContainer(num_states=5, alphabet_size=3)

    # State indices: q0 (initial and accepting), q2 and r1 (accepting).
    q0, q1, q2, q3, r1 = 0, 1, 2, 3, 4
    state_map = {q0: "q0", q1: "q1", q2: "q2", q3: "q3", r1: "r1"}
    M.state_map = state_map
    M.state_map_rev = {name: index for index, name in state_map.items()}

    # Symbol indices.
    a, b, c = 0, 1, 2
    M.symbol_map = {a: "a", b: "b", c: "c"}

    # (source, symbol, target), in insertion order.
    arcs = (
        # Out of q0.
        (q0, b, q1),
        (q0, c, q3),
        (q0, a, r1),
        # Out of q3.
        (q3, b, q1),
        (q3, c, q3),
        # Parity branch.
        (q1, b, q2),
        (q2, b, q1),
        (q1, c, q1),
        (q2, c, q2),
        # r1 branch.
        (r1, a, r1),
    )
    for source, symbol, target in arcs:
        trans(M, source, symbol, target)

    # Accepting states.
    for state in (q0, q2, r1):
        M.add_accept_state(state)

    return M


def parity_free_automaton_highacc():
    """Build the two-branch (star-free + parity) automaton with accept weight 0.5.

    From ``start``, symbol "0" enters the ``free`` loop and symbol "1" enters
    the parity pair ``even``/``odd``. Inside the parity pair, "1" keeps the
    state and "2" switches it. All weights are log-tensors produced by ``mv``.
    Accept weights are set on ``even`` and ``free`` only.

    Returns:
        The populated ``WeightedFiniteAutomatonContainer``.
    """
    M = WeightedFiniteAutomatonContainer(num_states=4, alphabet_size=3, semiring=Semiring)

    # State indices: entry state, star-free state, then the two parity states.
    start, free, q_even, q_odd = 0, 1, 2, 3
    M.state_map = {start: "start", free: "free", q_even: "even", q_odd: "odd"}

    # Symbol indices.
    a, b, c = 0, 1, 2
    M.symbol_map = {a: "0", b: "1", c: "2"}

    acc_prob = 0.5
    # What is left after accepting, shared by the two arcs out of q_even.
    half_prob = (1 - acc_prob) / 2

    # (source, symbol, target, weight before taking the log)
    arcs = (
        # Start transitions.
        (start, a, free, 0.5),
        (start, b, q_even, 0.5),
        # Parity transitions.
        (q_even, b, q_even, half_prob),
        (q_even, c, q_odd, half_prob),
        (q_odd, b, q_odd, 0.5),
        (q_odd, c, q_even, 0.5),
        # Star-free transition.
        (free, a, free, 1 - acc_prob),
    )
    for source, symbol, target, weight in arcs:
        M._transitions[trans(M, source, symbol, target)] = mv(weight)

    # Accept from the even parity state and from the star-free state;
    # q_odd is deliberately left without an accept weight.
    for state in (q_even, free):
        M._accept_weights[state] = mv(acc_prob)

    return M


def parity_free_automaton():
    """Build the two-branch (star-free + parity) weighted automaton.

    From ``start``, symbol "0" enters the ``free`` loop and symbol "1" enters
    the parity pair ``even``/``odd``. Inside the parity pair, "1" keeps the
    state and "2" switches it. All weights are log-tensors produced by ``mv``.
    Accept weights: ``log(0.1)`` on ``even`` and ``log(0.05)`` on ``free``.

    Returns:
        The populated ``WeightedFiniteAutomatonContainer``.
    """
    M = WeightedFiniteAutomatonContainer(num_states=4, alphabet_size=3, semiring=Semiring)

    # State indices: entry state, star-free state, then the two parity states.
    start, free, q_even, q_odd = 0, 1, 2, 3
    M.state_map = {start: "start", free: "free", q_even: "even", q_odd: "odd"}

    # Symbol indices.
    a, b, c = 0, 1, 2
    M.symbol_map = {a: "0", b: "1", c: "2"}

    # (source, symbol, target, weight before taking the log)
    arcs = (
        # Start transitions.
        (start, a, free, 0.5),
        (start, b, q_even, 0.5),
        # Parity transitions.
        (q_even, b, q_even, 0.45),
        (q_even, c, q_odd, 0.45),
        (q_odd, b, q_odd, 0.5),
        (q_odd, c, q_even, 0.5),
        # Star-free transition.
        (free, a, free, 0.95),
    )
    for source, symbol, target, weight in arcs:
        M._transitions[trans(M, source, symbol, target)] = mv(weight)

    # Accept from the even parity state and from the star-free state;
    # q_odd is deliberately left without an accept weight.
    for state, weight in ((q_even, 0.1), (free, 0.05)):
        M._accept_weights[state] = mv(weight)

    return M


def parity_free_only_free():
    """Build the star-free/parity automaton with only the star-free branch wired.

    The state and symbol maps list all four states and three symbols, but the
    only arcs are ``start --0--> free`` (weight ``log(1)``) and the ``free``
    self-loop on "0" (weight ``log(0.95)``). ``free`` is the only state with
    an accept weight (``log(0.05)``).

    Returns:
        The populated ``WeightedFiniteAutomatonContainer``.
    """
    M = WeightedFiniteAutomatonContainer(num_states=4, alphabet_size=3, semiring=Semiring)

    # State indices: entry state, star-free state, then the two parity states
    # (the parity states are named but receive no arcs here).
    start, free, q_even, q_odd = 0, 1, 2, 3
    M.state_map = {start: "start", free: "free", q_even: "even", q_odd: "odd"}

    # Symbol indices (only a is used on an arc).
    a, b, c = 0, 1, 2
    M.symbol_map = {a: "0", b: "1", c: "2"}

    # (source, symbol, target, weight before taking the log)
    arcs = (
        (start, a, free, 1),    # start transition
        (free, a, free, 0.95),  # star-free self-loop
    )
    for source, symbol, target, weight in arcs:
        M._transitions[trans(M, source, symbol, target)] = mv(weight)

    # Accept from the star-free state.
    M._accept_weights[free] = mv(0.05)

    return M


def parity_free_only_parity():
    """Build the star-free/parity automaton with only the parity branch wired.

    The state and symbol maps list all four states and three symbols, but
    ``free`` receives no arcs. ``start --1--> even`` has weight ``log(1)``;
    inside the parity pair, "1" keeps the state and "2" switches it. Accept
    weights: ``log(0.1)`` on ``even`` and ``log(0)`` on ``odd``.

    Returns:
        The populated ``WeightedFiniteAutomatonContainer``.
    """
    M = WeightedFiniteAutomatonContainer(num_states=4, alphabet_size=3, semiring=Semiring)

    # State indices: entry state, star-free state (unused on arcs), parity pair.
    start, free, q_even, q_odd = 0, 1, 2, 3
    M.state_map = {start: "start", free: "free", q_even: "even", q_odd: "odd"}

    # Symbol indices (a is never used on an arc).
    a, b, c = 0, 1, 2
    M.symbol_map = {a: "0", b: "1", c: "2"}

    # (source, symbol, target, weight before taking the log)
    arcs = (
        # Start transition.
        (start, b, q_even, 1),
        # Parity transitions.
        (q_even, b, q_even, 0.45),
        (q_even, c, q_odd, 0.45),
        (q_odd, b, q_odd, 0.5),
        (q_odd, c, q_even, 0.5),
    )
    for source, symbol, target, weight in arcs:
        M._transitions[trans(M, source, symbol, target)] = mv(weight)

    # q_odd is given log(0), so only q_even carries a usable accept weight.
    for state, weight in ((q_even, 0.1), (q_odd, 0)):
        M._accept_weights[state] = mv(weight)

    return M





def supersimple_automaton():
    """Build a 3-state weighted automaton that reads symbols in doubled pairs.

    From ``start``, symbol "0" goes to state ``a`` and symbol "1" goes to
    state ``b``; each of those returns to ``start`` on the same symbol with
    weight ``log(1)``. With ``mu = 10`` and ``fmu = 1 - 1/(mu + 1)``, the two
    arcs out of ``start`` each get ``log(0.5 * fmu)`` and ``start`` gets
    accept weight ``log(1 - fmu)``.

    Returns:
        The populated ``WeightedFiniteAutomatonContainer``.
    """
    M = WeightedFiniteAutomatonContainer(num_states=3, alphabet_size=2, semiring=Semiring)

    # State indices.
    start, sta, stb = 0, 1, 2
    M.state_map = {start: "start", sta: "a", stb: "b"}

    # Symbol indices.
    a, b = 0, 1
    M.symbol_map = {a: "0", b: "1"}

    mu = 10
    fmu = 1 - 1/(mu+1)
    leave_start = 0.5 * fmu  # weight of each of the two arcs out of start

    # (source, symbol, target, weight before taking the log)
    arcs = (
        # Out of start.
        (start, a, sta, leave_start),
        (start, b, stb, leave_start),
        # Back to start on the same symbol.
        (sta, a, start, 1),
        (stb, b, start, 1),
    )
    for source, symbol, target, weight in arcs:
        M._transitions[trans(M, source, symbol, target)] = mv(weight)

    # start is the only state with an accept weight.
    M._accept_weights[start] = mv(1 - fmu)

    return M


def simplev2_automaton():
    """Build a 4-state, 4-symbol weighted automaton.

    State "1" loops on "0"/"1" and leaves on "2" (to state "2") or "3" (to
    state "3"); states "2" and "3" both move to state "4" on "0" or "1";
    state "4" loops on "0"/"1" and is the only state with an accept weight.
    All weights are log-tensors produced by ``mv``, derived from ``p = 0.9``.

    Returns:
        The populated ``WeightedFiniteAutomatonContainer``.
    """
    M = WeightedFiniteAutomatonContainer(num_states=4, alphabet_size=4, semiring=Semiring)

    # State indices (printed as "1".."4").
    stone, sttwo, stthree, stfour = 0, 1, 2, 3
    M.state_map = {stone: "1", sttwo: "2", stthree: "3", stfour: "4"}

    # Symbol indices (printed as "0".."3").
    a, b, c, d = 0, 1, 2, 3
    M.symbol_map = {a: "0", b: "1", c: "2", d: "3"}

    p = 0.9
    stay = 0.5 * p         # each self-loop arc on states "1" and "4"
    leave = 0.5 * (1 - p)  # each arc leaving state "1"

    # (source, symbol, target, weight before taking the log)
    arcs = (
        (stone, a, stone, stay),
        (stone, b, stone, stay),
        (stone, c, sttwo, leave),
        (stone, d, stthree, leave),
        (sttwo, a, stfour, 0.5),
        (sttwo, b, stfour, 0.5),
        (stthree, a, stfour, 0.5),
        (stthree, b, stfour, 0.5),
        (stfour, a, stfour, stay),
        (stfour, b, stfour, stay),
    )
    for source, symbol, target, weight in arcs:
        M._transitions[trans(M, source, symbol, target)] = mv(weight)

    M._accept_weights[stfour] = mv(1 - p)

    return M


# Registry of ready-built automata keyed by machine name. Every builder is
# called once, in the order listed, when this module is imported.
AUTOMATA_REGISTER = {
    machine_name: build()
    for machine_name, build in (
        ("parity", make_parity_automaton),
        ("starfree", make_starfree_automaton),
        ("canonical_parity", make_canonical_parity_automaton),
        ("parity_free", parity_free_automaton),
        ("parity_free_hp", parity_free_automaton_highacc),
        ("parity_free_only_free", parity_free_only_free),
        ("parity_free_only_parity", parity_free_only_parity),
        ("supersimple", supersimple_automaton),
        ("simple", simplev2_automaton),
    )
}

def get_arc_str(src, tgt, sym, machine_name):
    """Format an arc of a registered automaton as ``"<src>-<sym>-><tgt>"``.

    Args:
        src: Index of the source state.
        tgt: Index of the target state.
        sym: Index of the symbol on the arc.
        machine_name: Key of the automaton in ``AUTOMATA_REGISTER``.

    Returns:
        The arc rendered with the names from the automaton's ``state_map``
        and ``symbol_map``.

    Raises:
        KeyError: If the machine name, a state index or the symbol index is
            not present in the corresponding mapping.
    """
    machine = AUTOMATA_REGISTER[machine_name]
    src_name = machine.state_map[src]
    tgt_name = machine.state_map[tgt]
    sym_name = machine.symbol_map[sym]
    return f"{src_name}-{sym_name}->{tgt_name}"


if __name__ == "__main__":
    # Demo: build the unweighted canonical parity automaton and dump its
    # states, symbols and transitions using the human-readable names.
    M = make_canonical_parity_automaton()
    print(M)

    divider = "-----------------"

    print("States:")
    for state_idx in M.states():
        print(f"State {state_idx}: {M.state_map[state_idx]}")

    print(divider)
    print("Symbols:")
    for symbol_idx in range(M.alphabet_size()):
        print(f"Symbol {symbol_idx}: {M.symbol_map[symbol_idx]}")

    print(divider)
    # Sort by (source, symbol, target) so the listing order is stable.
    ordered = sorted(M.transitions(), key=lambda t: (t.state_from, t.symbol, t.state_to))
    print("Transitions:")
    for idx, arc in enumerate(ordered):
        source_name = M.state_map[arc.state_from]
        target_name = M.state_map[arc.state_to]
        symbol_name = M.symbol_map[arc.symbol]
        print(f"Transition {idx}: {source_name} --{symbol_name}--> {target_name}")

    # Drop into the debugger so the automaton can be inspected interactively.
    breakpoint()