import argparse
import json
from collections import defaultdict
import dill
import numpy as np
from neural_networks.data import load_output_vocabulary_from_file
import pickle
from math import log, exp
import torch


def _str_to_bool(value):
    """Convert a command-line token such as ``True``/``false``/``0`` to a bool."""
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    # The empty string is kept falsy, as it was with ``type=bool``.
    if token in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace with the attributes ``model_logprobs``,
        ``automaton``, ``arcs``, ``vocab_file`` (``None`` when the option is
        not given) and ``debug`` (a bool, ``False`` by default).
    """
    parser = argparse.ArgumentParser(description="evaluate model")
    parser.add_argument("--model_logprobs", type=str)
    parser.add_argument("--automaton")
    parser.add_argument("--arcs", type=str)
    parser.add_argument("--vocab_file", type=str)
    # ``type=bool`` would turn every non-empty string (including "False") into
    # True, so the value is parsed explicitly.  ``nargs="?"`` additionally lets
    # a bare ``--debug`` enable debugging while ``--debug True`` keeps working.
    parser.add_argument(
        "--debug", type=_str_to_bool, nargs="?", const=True, default=False
    )
    return parser.parse_args()


def get_machine_probs(input, machine):
    """Placeholder: ignore both arguments and return a new empty list."""
    return []


def load_arcs(path):
    """Read one record per line from the text file at *path*.

    Every line is evaluated as a Python expression and the resulting objects
    are returned as a list, in file order.
    """
    # NOTE(security): eval() executes whatever expression a line contains, so
    # this must only be used on trusted files.  If the records are plain
    # literals, ast.literal_eval would be a safer choice - confirm before
    # switching.
    records = []
    with open(path) as handle:
        for raw_line in handle:
            parsed = eval(raw_line)
            records.append(parsed)
    return records


def load_vocab(vocab_file):
    """Return the vocabulary that ``load_output_vocabulary_from_file`` builds from *vocab_file*."""
    return load_output_vocabulary_from_file(vocab_file)


def load_model_logprobs(path, vocab):
    """Load stored negative log-probabilities and convert them to probabilities.

    The object saved at *path* is read with ``torch.load`` and iterated; each
    element ``x`` is mapped to ``exp(-x)``.  *vocab* is accepted but not used.

    Returns:
        A list with one converted entry per element of the loaded object.
    """
    # NOTE(security): weights_only=False permits arbitrary unpickling, so the
    # file must come from a trusted source.
    stored = torch.load(path, weights_only=False)
    return [np.exp(-entry) for entry in stored]


def calculate_weighted_decompositions(arcs, pdfa, modelprobs, vocab):
    """Compare a PDFA with a trained model along a set of recorded paths.

    Args:
        arcs: sequence of records; each record is a sequence of
            ``(src, tgt, sym)`` triples (source state index, target state
            index, integer symbol value).  An empty record is treated as a
            path that stops in the state with index 0.
        pdfa: automaton object exposing ``Sigma``, ``Q`` (states with an
            ``idx`` attribute), ``δ[state][symbol] -> {target: weight}``,
            ``ρ[state]`` and ``forward()``; weights expose ``.value``.
        modelprobs: indexed as ``modelprobs[record][step][token_id]``; the
            innermost elements must support ``.item()``.  Step ``len(record)``
            is read as the distribution used for the end-of-sequence token.
        vocab: object exposing ``_first._string_list`` (symbol strings, the
            position being the token id) and ``eos_index``.

    Returns:
        dict with the keys
        ``total_kl`` (model NLL minus PDFA NLL, divided by the token count),
        ``old_empirical_kl`` (sum of per-step KL terms divided by the PDFA
        probability mass that was visited),
        ``total_tokens``,
        ``state_contributions``, ``symbol_contributions`` and
        ``transition_contributions`` (weighted KL contributions divided by
        the number of visits of the state / symbol / transition).

    Raises:
        IndexError: if ``modelprobs`` holds no distribution for a transition
            of a record.
    """
    eos_token = "<EOS>"

    # Integer symbol value (as stored in ``arcs``) -> PDFA symbol object.
    symbol_val_to_obj = {int(str(sym)): sym for sym in pdfa.Sigma}
    # PDFA symbol object -> token id in the model distributions.
    symbol_dict = {}
    for sym in pdfa.Sigma:
        try:
            symbol_dict[sym] = vocab._first._string_list.index(str(sym))
        except ValueError:
            # The vocabulary may not cover every symbol of the machine.
            pass
    symbol_dict[eos_token] = vocab.eos_index

    # State index -> state object.  The first state with a given index wins,
    # which matches a linear "first match" scan over ``pdfa.Q``.
    state_by_idx = {}
    for state in pdfa.Q:
        state_by_idx.setdefault(state.idx, state)

    state_contributions = defaultdict(float)
    symbol_contributions = defaultdict(float)
    transition_contributions = defaultdict(float)

    # NOTE(review): never incremented, so the result always reports 0 for
    # 'total_tokens'.  Kept as-is because the intended meaning is not clear.
    total_tokens = 0

    def symbol_mass(pdfa_state, symbol):
        """Total PDFA probability of reading *symbol* in *pdfa_state*."""
        return sum(prob.value for prob in pdfa.δ[pdfa_state][symbol].values())

    def get_local_kl(pdfa_state, model_probs):
        """Calculate KL divergence for all outgoing transitions from a state"""
        local_kl = 0.0
        local_contributions = {}

        for symbol, next_states in pdfa.δ[pdfa_state].items():
            # Symbols without a token id cannot be compared with the model.
            if symbol not in symbol_dict:
                continue

            # PDFA probability of the symbol, summed over all target states.
            pdfa_prob = sum(prob.value for prob in next_states.values())
            model_prob = model_probs[symbol_dict[symbol]].item()

            # log() is undefined when either probability is zero.
            if pdfa_prob <= 0 or model_prob <= 0:
                continue

            kl_contrib = pdfa_prob * log(pdfa_prob / model_prob)
            local_kl += kl_contrib
            local_contributions[str(symbol)] = kl_contrib

        # The acceptance probability of the state is compared with the
        # model's end-of-sequence probability.
        accept_prob = pdfa.ρ[pdfa_state].value
        if accept_prob > 0:
            eos_model_prob = model_probs[vocab.eos_index].item()
            if eos_model_prob > 0:
                kl_contrib = accept_prob * log(accept_prob / eos_model_prob)
                local_kl += kl_contrib
                local_contributions[eos_token] = kl_contrib

        return local_kl, local_contributions

    # Forward path sums, normalised so that they sum to one.
    pathsums = {state: weight.value for state, weight in pdfa.forward().items()}
    pathsum_total = sum(pathsums.values())
    pathsums = {state: value / pathsum_total for state, value in pathsums.items()}

    # ------------------------------------------------------------------
    # Pass 1: per-step weights.  Entry 0 is 1.0; every processed transition
    # appends (PDFA probability of the prefix including that transition)
    # divided by the normalised path sum of the transition's source state.
    # ------------------------------------------------------------------
    by_state_weights = []
    for record in arcs:
        prefix_prob = 1.0
        seq_state_weights = [1.0]
        for src, _tgt, sym in record:
            pdfa_state = state_by_idx.get(src)
            if pdfa_state is None:
                continue
            prefix_prob = prefix_prob * symbol_mass(pdfa_state, symbol_val_to_obj[sym])
            seq_state_weights.append(prefix_prob / pathsums[pdfa_state])
        by_state_weights.append(seq_state_weights)

    # ------------------------------------------------------------------
    # Pass 2: weighted KL decomposition by state, symbol and transition.
    # ------------------------------------------------------------------
    for arc_idx, (record, weights) in enumerate(zip(arcs, by_state_weights)):
        for trans_idx, (src, tgt, sym) in enumerate(record):
            pdfa_state = state_by_idx.get(src)
            if pdfa_state is None:
                continue

            try:
                model_probs = modelprobs[arc_idx][trans_idx]
            except IndexError as err:
                # Fail loudly instead of dropping into a debugger and then
                # continuing with a stale or undefined distribution.
                raise IndexError(
                    f"no model distribution for record {arc_idx}, step {trans_idx}"
                ) from err

            state_kl, local_contributions = get_local_kl(pdfa_state, model_probs)

            weight = weights[trans_idx]
            state_contributions[src] += weight * state_kl
            for sym_str, contrib in local_contributions.items():
                symbol_contributions[sym_str] += weight * contrib

            # Contribution of the specific transition that was taken.
            symbol = symbol_val_to_obj[sym]
            sym_str = str(sym)
            transitions = pdfa.δ[pdfa_state][symbol]
            transition_prob = 0.0
            for target_state, prob in transitions.items():
                if target_state.idx == tgt:
                    transition_prob = prob.value
                    break

            if transition_prob > 0 and sym_str in local_contributions:
                total_symbol_prob = sum(prob.value for prob in transitions.values())
                if total_symbol_prob > 0:
                    # Share of the symbol's contribution that belongs to
                    # this particular target state.
                    transition_portion = transition_prob / total_symbol_prob
                    transition_contrib = local_contributions[sym_str] * transition_portion
                    transition_key = f"{src}-{sym_str}->{tgt}"
                    transition_contributions[transition_key] += weight * transition_contrib

        # End of sequence.  The EOS distribution and the EOS weight both live
        # at position len(record); guarding on that exact position avoids an
        # off-by-one IndexError and also covers empty records (position 0).
        eos_pos = len(record)
        final_state_idx = record[-1][1] if record else 0
        final_state = state_by_idx.get(final_state_idx)
        if final_state is not None and eos_pos < len(modelprobs[arc_idx]):
            state_kl, local_contributions = get_local_kl(
                final_state, modelprobs[arc_idx][eos_pos]
            )
            if eos_pos < len(weights):
                weight = weights[eos_pos]
                state_contributions[final_state_idx] += weight * state_kl
                if eos_token in local_contributions:
                    transition_key = f"{final_state_idx}-{eos_token}->None"
                    eos_contrib = weight * local_contributions[eos_token]
                    transition_contributions[transition_key] += eos_contrib
                    symbol_contributions[eos_token] += eos_contrib

    # ------------------------------------------------------------------
    # Visit counts used to turn the summed contributions into means.
    # ------------------------------------------------------------------
    visit_counts = defaultdict(int)
    visit_counts_symbols = defaultdict(int)
    visit_counts_arcs = defaultdict(int)
    transition_key = "{src}-{sym}->{tgt}"

    for record in arcs:
        for src, tgt, sym in record:
            visit_counts[src] += 1
            symb_str = str(symbol_val_to_obj[sym])
            visit_counts_symbols[symb_str] += 1
            visit_counts_arcs[transition_key.format(src=src, sym=symb_str, tgt=tgt)] += 1
        # EOS leaves the last target state; an empty record stops in state 0
        # (relying on the loop variable here would be unbound or stale).
        final_state_idx = record[-1][1] if record else 0
        visit_counts_symbols[eos_token] += 1
        visit_counts_arcs[
            transition_key.format(src=final_state_idx, sym=eos_token, tgt=None)
        ] += 1

    mean_state_kl = {
        state: contribution / visit_counts[state]
        for state, contribution in state_contributions.items()
        if visit_counts[state] > 0
    }
    mean_symbol_kl = {
        symbol: contribution / visit_counts_symbols[symbol]
        for symbol, contribution in symbol_contributions.items()
        if visit_counts_symbols[symbol] > 0
    }
    mean_transition_kl = {
        transition: contribution / visit_counts_arcs[transition]
        for transition, contribution in transition_contributions.items()
        if visit_counts_arcs[transition] > 0
    }

    # ------------------------------------------------------------------
    # Empirical KL(pdfa || model): per-step KL terms of the symbols that were
    # actually taken, normalised by the PDFA probability mass visited.
    # ------------------------------------------------------------------
    kl_sum = 0.0
    total_prob_mass = 0.0

    for arc_idx, record in enumerate(arcs):
        for trans_idx, (src, _tgt, sym) in enumerate(record):
            pdfa_state = state_by_idx.get(src)
            if pdfa_state is None:
                continue

            symbol = symbol_val_to_obj[sym]
            pdfa_prob = symbol_mass(pdfa_state, symbol)
            model_probs = modelprobs[arc_idx][trans_idx]
            model_prob = model_probs[symbol_dict[symbol]].item()

            if pdfa_prob <= 0 or model_prob <= 0:
                continue

            kl_sum += pdfa_prob * log(pdfa_prob / model_prob)
            total_prob_mass += pdfa_prob

        eos_pos = len(record)
        final_state = state_by_idx.get(record[-1][1] if record else 0)
        if final_state is not None and eos_pos < len(modelprobs[arc_idx]):
            accept_prob = pdfa.ρ[final_state].value
            eos_model_prob = modelprobs[arc_idx][eos_pos][vocab.eos_index].item()
            if accept_prob > 0 and eos_model_prob > 0:
                kl_sum += accept_prob * log(accept_prob / eos_model_prob)
                total_prob_mass += accept_prob

    if total_prob_mass > 0:
        empirical_kl = kl_sum / total_prob_mass
    else:
        empirical_kl = float('inf')

    # ------------------------------------------------------------------
    # Per-token difference of negative log-likelihoods (model minus PDFA).
    # ------------------------------------------------------------------
    model_nll_sum = 0.0  # Sum of -log p_model(w)
    pdfa_nll_sum = 0.0   # Sum of -log p_pdfa(w)
    total_token_count = 0  # Sum of |w| + 1 over the strings that are counted

    for arc_idx, record in enumerate(arcs):
        string_model_logprob = 0.0
        string_pdfa_logprob = 0.0
        string_length = len(record) + 1  # +1 for EOS token

        for trans_idx, (src, _tgt, sym) in enumerate(record):
            pdfa_state = state_by_idx.get(src)
            if pdfa_state is None:
                continue

            symbol = symbol_val_to_obj[sym]
            pdfa_prob = symbol_mass(pdfa_state, symbol)
            model_probs = modelprobs[arc_idx][trans_idx]
            model_prob = model_probs[symbol_dict[symbol]].item()

            if pdfa_prob <= 0 or model_prob <= 0:
                continue

            string_model_logprob += log(model_prob)
            string_pdfa_logprob += log(pdfa_prob)

        eos_pos = len(record)
        final_state = state_by_idx.get(record[-1][1] if record else 0)
        if final_state is not None and eos_pos < len(modelprobs[arc_idx]):
            accept_prob = pdfa.ρ[final_state].value
            eos_model_prob = modelprobs[arc_idx][eos_pos][vocab.eos_index].item()

            if accept_prob <= 0 or eos_model_prob <= 0:
                # NOTE(review): this skips the whole string (its tokens and
                # its length), not only the EOS term - confirm that this is
                # intended.  Behaviour is left unchanged.
                continue

            string_model_logprob += log(eos_model_prob)
            string_pdfa_logprob += log(accept_prob)

        model_nll_sum -= string_model_logprob
        pdfa_nll_sum -= string_pdfa_logprob
        total_token_count += string_length

    if total_token_count > 0:
        nll_diff_per_token = (model_nll_sum - pdfa_nll_sum) / total_token_count
    else:
        nll_diff_per_token = float('inf')

    return {
        'total_kl': nll_diff_per_token,
        'old_empirical_kl': empirical_kl,
        'total_tokens': total_tokens,
        'state_contributions': mean_state_kl,
        'symbol_contributions': mean_symbol_kl,
        'transition_contributions': mean_transition_kl,
    }


def load_automaton(path):
    """Load and return the object serialised with dill in the file at *path*.

    The file is opened in binary mode and closed again as soon as loading
    finishes (or fails).
    """
    # NOTE(security): dill, like pickle, can execute arbitrary code while
    # loading; only use this on trusted files.
    with open(path, "rb") as handle:
        return dill.load(handle)


def main(args=None):
    """Run the decomposition and write it next to the model log-probabilities.

    Args:
        args: namespace with the attributes ``model_logprobs``, ``automaton``,
            ``arcs``, ``vocab_file`` and ``debug``.  When ``None`` the
            command line is parsed with ``parse_args``.

    The result is written as a single JSON line to ``decomposed_kls.json`` in
    the directory that contains ``args.model_logprobs``.
    """
    # Local import so that this function does not depend on a change to the
    # module-level import block.
    import os

    if args is None:
        args = parse_args()

    # os.path keeps a bare file name in the current directory; joining the
    # parts by hand with "/" produced "/decomposed_kls.json" (the filesystem
    # root) in that case.
    out_file = os.path.join(
        os.path.dirname(args.model_logprobs), "decomposed_kls.json"
    )

    vocab = load_vocab(args.vocab_file)
    model_probs = load_model_logprobs(args.model_logprobs, vocab)
    arcs = load_arcs(args.arcs)
    machine = load_automaton(args.automaton)

    decomp_all = calculate_weighted_decompositions(
        arcs, machine, model_probs, vocab
    )

    if args.debug:
        print(decomp_all)
        print("symbol_sum", sum(decomp_all["symbol_contributions"].values()))
        print("state_sum", sum(decomp_all["state_contributions"].values()))
        print("transition_sum", sum(decomp_all["transition_contributions"].values()))

    with open(out_file, "w") as out:
        out.write(json.dumps(decomp_all) + "\n")


if __name__ == "__main__":
    # Set to True to bypass the command line and run against the fixed
    # experiment paths below (for testing).
    use_fixed_paths = False

    fixed_args = None
    if use_fixed_paths:
        fixed_args = argparse.Namespace(
            model_logprobs="/dtu/p1/vestsn/wfsa_refactor/intervention_sampling_clean/experiments_mix/models/parity_free_only_free/transformer_9550/eval/token-negative-log-probabilities.pt",
            automaton="/dtu/p1/vestsn/wfsa_refactor/intervention_sampling_clean/experiments_mix/data/datasets/parity_free_only_free/train/9550/machine.pkl",
            arcs="/dtu/p1/vestsn/wfsa_refactor/intervention_sampling_clean/experiments_mix/data/datasets/parity_free_only_free/test/arcs.txt",
            vocab_file="/dtu/p1/vestsn/wfsa_refactor/intervention_sampling_clean/experiments_mix/data/datasets/parity_free_only_free/test/main.vocab",
            debug=True,
        )

    # main(None) parses the command line itself.
    main(fixed_args)

    