import argparse
import os
import dill
import numpy as np
import pandas as pd
import torch
import tqdm
from glob import glob
from collections import defaultdict
from math import log, exp
import copy
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, GPT2Tokenizer
import wandb
from stopit import threading_timeoutable as timeoutable

# Rayuela / training imports
from lrayuela.fsa.fsa import State

from lrayuela.base.semiring import Entropy
from training.train_rnn import train_rnn, get_logits
from training.train_toyformer import train_transformer

# Counting sampling functions
from counting_sampling import sample_symbol, to_rayuela, sample_arc, sample_state, sample_vanilla

# Constants
# NOTE(review): RANKS is not referenced in the visible part of this file —
# presumably candidate values for the --rank option; confirm.
RANKS = [1, 2, 4, 6, 8, 10, 12, 16]
# Fallback list installed as `output_dims` by update_config() when
# --output_dims is left at its [-1] default and --output_size is unset.
OUTPUT_DIMS = [2, 4, 6, 8, 10, 12, 16]
# Length every tokenized sequence is padded/truncated to in construct_dataset().
MAX_LENGTH = 512
# NOTE(review): not referenced in the visible part of this file — confirm use.
MAX_EXPECTED_LENGTH = 85

# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")


def get_arcs(A):
    """Return all non-zero-weight arcs of automaton ``A``.

    Every (state, symbol) pair from ``A.Q`` x ``A.Sigma`` is queried through
    ``A.a_out_arcs``; each yielded ``(target, weight)`` pair whose
    ``weight.value`` is not 0 becomes a ``(state, symbol, target, weight)``
    tuple in the returned list.
    """
    return [
        (source, label, destination, arc_weight)
        for source in A.Q
        for label in A.Sigma
        for destination, arc_weight in A.a_out_arcs(source, label)
        if not (arc_weight.value == 0)
    ]

#############################
# Configuration and Setup   #
#############################

def parse_args() -> argparse.Namespace:
    """Build the command-line parser for an experiment run and parse ``sys.argv``.

    The single positional argument is ``save_dir``; every other option is
    declared in ``option_table`` below, in the order it is registered.

    Returns:
        The populated ``argparse.Namespace``.
    """
    def flag():
        # Boolean switch that is off unless given on the command line.
        return {"action": "store_true", "default": False}

    def opt(kind, default, **extra):
        # Typed option with a default value.
        return {"type": kind, "default": default, **extra}

    option_table = [
        # Model / automaton saving flags
        ("--save_automaton", flag()),
        ("--save_dataset", flag()),
        ("--save_model", flag()),
        # Training toggles
        ("--train_rnn", flag()),
        ("--train_transformer", flag()),
        ("--load_run_id", opt(int, None)),
        ("--dataset_size", opt(int, 5000)),
        ("--num_val", opt(int, 5000)),
        ("--output_dims", opt(int, [-1], nargs="+")),
        # Neural LM hyperparameters
        ("--rnn_type", opt(str, "LSTM")),
        ("--batch_size", opt(int, 32)),
        ("--num_epochs", opt(int, 2)),
        ("--rnn_learning_rate", opt(float, 0.001)),
        ("--transformer_learning_rate", opt(float, 0.001)),
        ("--num_layers", opt(int, 4)),
        ("--embedding_size", opt(int, 64)),
        ("--output_size", opt(int, None)),
        ("--rnn_n_inner", opt(int, 64)),
        ("--transformer_n_inner", opt(int, 128)),
        ("--optim_type", opt(str, "Adam")),
        ("--clip_grads", opt(float, 0.25)),
        ("--dropout", opt(float, 0.2)),
        ("--direction", opt(str, "left2right")),
        ("--tie_weights", flag()),
        ("--init_range", opt(float, 0.1)),
        ("--num_heads", opt(int, 4)),
        ("--tie_word_embeddings", flag()),
        # PFSA parameters
        ("--automaton_name", opt(
            str, None,
            help="Lookup a specific automaton from an external register")),
        ("--num_states", opt(int, 20)),
        ("--num_symbols", opt(int, 10)),
        ("--accept_prob", opt(float, 0.2)),
        ("--rank", opt(int, -1)),
        ("--seed", opt(int, 42)),
        ("--target_state", opt(int, None)),
        ("--target_transition", opt(int, None)),
        ("--target_symbol", opt(int, None)),
        # Intervention
        ("--intervention_type", opt(
            str, "arc", help="arc, symbol, state, or vanilla")),
        ("--start_intervention_range", opt(int, 0)),
        ("--end_intervention_range", opt(int, 1000)),
        ("--intervention_step", opt(int, 100)),
        ("--use_dupl_symbol", flag()),
        ("--delete_arcs", flag()),
        # W&B
        ("--wandb_name", opt(str, None)),
    ]

    cli = argparse.ArgumentParser(description="Train a transformer")
    cli.add_argument("save_dir", type=str, help="Directory to save results/models")
    for option, settings in option_table:
        cli.add_argument(option, **settings)
    return cli.parse_args()


def init_wandb(args: argparse.Namespace):
    """Start a Weights & Biases run from the parsed command-line arguments.

    ``args.wandb_name`` is used as the project, the whole namespace is passed
    as the run config, and the state/symbol counts are attached as tags.
    """
    run_tags = [f"Q={args.num_states}", f"|Σ|={args.num_symbols}"]
    wandb.init(project=args.wandb_name, config=args, entity="TODO", tags=run_tags)


def create_dirs():
    """Create the per-run output directories under ``wandb.config.save_dir``.

    ``results/<run>`` is always created; the automaton, dataset and model
    directories are created only when the matching ``save_*`` flag is set.
    """
    cfg = wandb.config
    run_name = wandb.run.name
    root = cfg.save_dir

    wanted = ["results"]
    if cfg.save_automaton:
        wanted.append("pfsas")
    if cfg.save_dataset:
        wanted.append("data")
    if cfg.save_model:
        wanted.extend(["rnns", "transformers", "tokenizers"])

    for kind in wanted:
        os.makedirs(f"{root}/{kind}/{run_name}", exist_ok=True)


def update_config():
    """Resolve derived entries of ``wandb.config`` in place.

    * ``output_dims`` becomes ``[output_size]`` when ``output_size`` is set,
      or ``OUTPUT_DIMS`` when ``output_dims`` still starts with -1.
    * When ``load_run_id`` is set, ``load_run_name`` is set to the run
      directory at that index among the sorted run directories that contain
      a ``results.csv``.
    """
    cfg = wandb.config

    dims_left_at_default = cfg.output_dims[0] == -1
    if cfg.output_size is not None:
        resolved_dims = [cfg.output_size]
    elif dims_left_at_default:
        resolved_dims = OUTPUT_DIMS
    else:
        resolved_dims = None
    if resolved_dims is not None:
        cfg.update({"output_dims": resolved_dims}, allow_val_change=True)

    if cfg.load_run_id is None:
        return

    # NOTE(review): log_and_save_metrics() writes results_<occ_count>.csv, not
    # results.csv — confirm this pattern matches the files actually on disk.
    pattern = f"{cfg.save_dir}/results/*/results.csv"
    run_names = [path.split("/")[-2] for path in glob(pattern)]
    run_names.sort()
    cfg.update({"load_run_name": run_names[cfg.load_run_id]}, allow_val_change=True)


#############################
# Data and Automata Helpers #
#############################

def construct_tokenizer(automaton) -> GPT2Tokenizer:
    """Derive a tokenizer for the automaton's alphabet from the GPT-2 tokenizer.

    The new tokenizer is trained on an empty corpus with ``vocab_size=0``; the
    stringified alphabet symbols (sorted), the extra symbol ``"X"`` and the
    pad/bos/eos markers are registered as special tokens.

    Returns:
        The new tokenizer with its pad, bos and eos tokens assigned.
    """
    pad, bos, eos = "<|pad|>", "<|bos|>", "<|eos|>"

    gpt2 = AutoTokenizer.from_pretrained("gpt2")
    alphabet = sorted(str(sym.value) for sym in automaton.Sigma)
    special_tokens = [*alphabet, "X", pad, bos, eos]

    tokenizer = gpt2.train_new_from_iterator(
        [[]],
        vocab_size=0,
        min_frequency=1,
        new_special_tokens=special_tokens,
    )
    tokenizer.pad_token = pad
    tokenizer.bos_token = bos
    tokenizer.eos_token = eos
    return tokenizer


def construct_dataset(tokenizer: GPT2Tokenizer, train_samples, test_samples):
    """Turn sampled strings into a tokenized train/test ``DatasetDict``.

    Args:
        tokenizer: Tokenizer providing ``bos_token``/``eos_token`` and used to
            encode the text.
        train_samples: Samples indexed by ``"sampled_string"``; dict samples
            also contribute their ``"transitions"``, other samples get ``[]``.
        test_samples: Samples with ``"sampled_string"``, ``"log_probs"``,
            ``"transitions"`` and ``"states"`` entries.

    Returns:
        ``(dataset, average_test_length, average_test_perplexity)`` where the
        last value is the negated mean of the per-sample ``log_probs`` sums.
    """
    def symbols_as_text(sample):
        return "".join(str(ch) for ch in sample["sampled_string"])

    # NOTE(review): non-dict samples are still indexed with "sampled_string";
    # confirm which non-dict sample types are expected here.
    train_columns = {"symbols": [], "transitions": []}
    for sample in train_samples:
        text = symbols_as_text(sample)
        train_columns["transitions"].append(
            sample["transitions"] if isinstance(sample, dict) else []
        )
        train_columns["symbols"].append(text)

    # The "average_test_perplexity" column holds each sample's summed log_probs.
    test_columns = {
        "symbols": [],
        "log_probs": [],
        "arcs": [],
        "states": [],
        "average_test_perplexity": [],
    }
    for sample in test_samples:
        test_columns["symbols"].append(symbols_as_text(sample))
        test_columns["log_probs"].append(sample["log_probs"])
        test_columns["arcs"].append(sample["transitions"])
        test_columns["states"].append(sample["states"])
        test_columns["average_test_perplexity"].append(sum(sample["log_probs"]))

    average_test_length = np.mean([len(text) for text in test_columns["symbols"]])
    splits = DatasetDict({
        "train": Dataset.from_dict(train_columns),
        "test": Dataset.from_dict(test_columns),
    })
    average_test_perplexity = -np.mean(test_columns["average_test_perplexity"])

    def add_text(row):
        # Wrap the symbol string in BOS/EOS markers.
        return {"text": tokenizer.bos_token + "".join(row["symbols"]) + tokenizer.eos_token}

    def tokenize(batch):
        return tokenizer(
            batch["text"],
            truncation=True,
            padding="max_length",
            max_length=MAX_LENGTH,
            return_tensors="pt",
        )

    dataset = splits.map(add_text)
    dataset = dataset.map(tokenize, batched=True)
    return dataset, average_test_length, average_test_perplexity


###############################
# Helper for Test-Set Filtering
###############################
def get_num_tgts(samples, arcs=False, states=False, tgt_arc=None, tgt_state=None):
    """Count the samples that contain the target arc and/or target state.

    With ``arcs`` set, a sample scores one when ``tgt_arc`` is among its
    ``"transitions"``; with ``states`` set, it scores one when any entry of
    its ``"states"`` equals ``tgt_state``. When both flags are set a single
    sample can therefore contribute two to the total.
    """
    total = 0
    for sample in samples:
        path = sample["transitions"]
        visited = sample["states"]
        has_arc = bool(arcs) and tgt_arc in path
        has_state = bool(states) and any(state == tgt_state for state in visited)
        total += has_arc + has_state
    return total


#############################
# Extra Metrics Calculation #
#############################

def expected_symbol_freq(pdfa, tgt_symbol, arcs):
    """Ratio of the path mass on arcs labelled ``tgt_symbol`` to the mass on all other arcs.

    Each arc contributes ``forward[src] / pathsum * weight * backward[dst]``.

    Args:
        pdfa: Automaton exposing ``forward()``, ``backward()`` and ``pathsum()``.
        tgt_symbol: Value compared against each arc symbol's ``.value``.
        arcs: ``(source, symbol, target, weight)`` tuples, e.g. from ``get_arcs``.

    Returns:
        Target-symbol mass divided by the mass of the remaining arcs, or
        ``0.0`` when the remaining arcs carry no mass.

    Raises:
        Whatever dividing a forward weight by the pathsum raises (for example
        ``ZeroDivisionError``); the error propagates instead of opening a
        debugger and continuing with un-normalised weights.
    """
    fw_weights = pdfa.forward()
    bw_weights = pdfa.backward()
    pathsum = pdfa.pathsum()
    fw_weights = {state: weight / pathsum for state, weight in fw_weights.items()}

    symbol_mass = []
    other_mass = []
    for arc in arcs:
        source, symbol, target, weight = arc[:4]
        contribution = fw_weights[source].value * weight.value * bw_weights[target].value
        if symbol.value == tgt_symbol:
            symbol_mass.append(contribution)
        else:
            other_mass.append(contribution)

    other_total = sum(other_mass)
    return sum(symbol_mass) / other_total if other_total > 0 else 0.0


def compute_extra_metrics(
    intervention_type, cfg, pdfa, sampler, samples, arcs,
    tgt_symbol=None, tgt_arc_orig=None, tgt_state_orig=None
):
    """Collect intervention-specific statistics about the automaton and the samples.

    Args:
        intervention_type: ``"symbol"``, ``"arc"`` or ``"state"``; any other
            value is treated as the vanilla case.
        cfg: Config object; only ``cfg.dataset_size`` is read.
        pdfa: Automaton providing ``forward()``, ``backward()``, ``lift()``
            and ``Q``.
        sampler: Object whose optional ``counts`` attribute is summed into the
            occurrence count (0 when the attribute is missing).
        samples: Training samples; dicts are read through
            ``"sampled_string"``, anything else through index 0.
        arcs: ``(source, symbol, target, weight)`` tuples of the automaton.
        tgt_symbol: Target symbol for the symbol intervention.
        tgt_arc_orig: Target arc with ``state_from``/``state_to``/``symbol``
            attributes for the arc intervention.
        tgt_state_orig: Target state (a ``State`` or a value wrapped into one)
            for the state intervention.

    Returns:
        A dict of metrics. It is empty when the target arc is not among
        ``arcs`` or the target state is not in the forward weights.
    """
    if intervention_type == "symbol":
        exp_freq = expected_symbol_freq(pdfa, tgt_symbol, arcs)
        needle = str(tgt_symbol)
        # Stringify every symbol before joining so that non-str symbols are
        # counted instead of making str.join raise.
        symbol_occs = sum(
            (
                "".join(str(c) for c in samp["sampled_string"])
                if isinstance(samp, dict)
                else str(samp[0])
            ).count(needle)
            for samp in samples
        )
        return {
            "symbol_occs": symbol_occs,
            "dataset_size": cfg.dataset_size,
            "expected_symbol_freq": exp_freq,
            "real_num_samples": len(samples),
        }

    if intervention_type == "arc":
        forward = pdfa.forward()
        bward = pdfa.backward()
        total_fw_prob = sum(x.value for x in forward.values())
        # Lift into the Entropy semiring: each weight w is paired with -log(w).
        epdfa = pdfa.lift(Entropy, lambda w: Entropy(float(w), -log(float(w))))
        total_pathsum_arcs = sum(
            forward[arc[0]].value * arc[3].value * bward[arc[2]].value
            for arc in arcs
        )

        # Locate the automaton arc matching the sampler's target arc.
        tgt_arc_found = next(
            (
                arc for arc in arcs
                if arc[0].idx == tgt_arc_orig.state_from
                and arc[2].idx == tgt_arc_orig.state_to
                and arc[1].value == tgt_arc_orig.symbol
            ),
            None,
        )
        if tgt_arc_found is None:
            return {}

        src_state = tgt_arc_found[0]
        tgt_state_local = tgt_arc_found[2]
        state_fw_prob = forward[src_state].value
        state_fw_prob_norm = state_fw_prob / total_fw_prob if total_fw_prob > 0 else 0
        src_local_entropy = sum(a[-1].value[1] for a in epdfa.arcs(src_state))
        tgt_local_entropy = sum(a[-1].value[1] for a in epdfa.arcs(tgt_state_local))
        pathsum_arc = state_fw_prob * tgt_arc_found[3].value * bward[tgt_state_local].value
        pathsum_arc_norm = pathsum_arc / total_pathsum_arcs if total_pathsum_arcs > 0 else 0
        return {
            "arc_occs": sum(sampler.counts) if hasattr(sampler, "counts") else 0,
            "arc": str(tgt_arc_found),
            "arc_weight": tgt_arc_found[3].value,
            "dataset_size": cfg.dataset_size,
            "state_fw_prob": state_fw_prob,
            "state_fw_prob_norm": state_fw_prob_norm,
            "src_local_entropy": src_local_entropy,
            "tgt_local_entropy": tgt_local_entropy,
            "pathsum_arc": pathsum_arc,
            "pathsum_arc_norm": pathsum_arc_norm,
            "real_num_samples": len(samples),
        }

    if intervention_type == "state":
        forward = pdfa.forward()
        bward = pdfa.backward()
        total_fw_prob = sum(x.value for x in forward.values())
        epdfa = pdfa.lift(Entropy, lambda w: Entropy(float(w), -log(float(w))))
        fw_entropy = epdfa.forward()
        total_pathsum_states = sum(
            forward[s].value * bward[s].value for s in pdfa.Q
        )

        # Wrap the target in a State only if it is not one already.
        if isinstance(tgt_state_orig, State):
            tgt_state_obj = tgt_state_orig
        else:
            tgt_state_obj = State(tgt_state_orig)

        if tgt_state_obj not in forward:
            # The automaton has no forward weight for that state: nothing to report.
            return {}

        state_fw_prob = forward[tgt_state_obj].value
        state_fw_prob_norm = state_fw_prob / total_fw_prob if total_fw_prob > 0 else 0
        local_entropy = sum(a[-1].value[1] for a in epdfa.arcs(tgt_state_obj))
        pathsum_state = state_fw_prob * bward[tgt_state_obj].value
        pathsum_state_norm = (
            pathsum_state / total_pathsum_states if total_pathsum_states > 0 else 0
        )
        return {
            "state_occs": sum(sampler.counts) if hasattr(sampler, "counts") else 0,
            "state": str(tgt_state_obj),
            "dataset_size": cfg.dataset_size,
            "state_fw_prob": state_fw_prob,
            "state_fw_prob_norm": state_fw_prob_norm,
            "fw_entropy": fw_entropy[tgt_state_obj].value[1],
            "local_entropy": local_entropy,
            "pathsum_state": pathsum_state,
            "pathsum_state_norm": pathsum_state_norm,
            "real_num_samples": len(samples),
        }

    # vanilla
    return {"real_num_samples": len(samples)}


#############################
# Training and Evaluation   #
#############################

def train_instance(pdfa, dataset, tokenizer, identifier=None, tgt_arc=None, tgt_symbol=None):
    """Train the models enabled in ``wandb.config`` and gather their results.

    The RNN is trained first when ``train_rnn`` is set, then the transformer
    when ``train_transformer`` is set. Result keys are prefixed with ``rnn_``
    or ``transformer_``, and weights are saved when ``save_model`` is set.
    ``identifier`` is forwarded to ``train_rnn`` only; ``tgt_arc`` and
    ``tgt_symbol`` are accepted but not used.

    Returns:
        ``(results, model)`` where ``model`` is the transformer model if one
        was trained, otherwise the RNN if one was trained, otherwise ``None``.
    """
    cfg = wandb.config
    run_name = wandb.run.name
    results = {}

    def checkpoint_dir(family):
        # Create and return the directory where this run's weights are stored.
        path = f"{cfg.save_dir}/{family}/{run_name}/rank_{cfg.rank}/out_size_{cfg.output_size}"
        os.makedirs(path, exist_ok=True)
        return path

    if cfg.train_rnn:
        rnn_overrides = {"learning_rate": cfg.rnn_learning_rate, "n_inner": cfg.rnn_n_inner}
        cfg.update(rnn_overrides, allow_val_change=True)
        rnn_metrics, rnn_model = train_rnn(pdfa, dataset, tokenizer, identifier=identifier)
        for metric, value in rnn_metrics.items():
            results[f"rnn_{metric}"] = value
        if cfg.save_model:
            target_dir = checkpoint_dir("rnns")
            torch.save(rnn_model.state_dict(), f"{target_dir}/weights.pt")

    if cfg.train_transformer:
        transformer_overrides = {
            "learning_rate": cfg.transformer_learning_rate,
            "n_inner": cfg.transformer_n_inner,
        }
        cfg.update(transformer_overrides, allow_val_change=True)
        transformer_metrics, trainer = train_transformer(pdfa, dataset, tokenizer)
        for metric, value in transformer_metrics.items():
            results[f"transformer_{metric}"] = value
        if cfg.save_model:
            trainer.save_model(checkpoint_dir("transformers"))
        return results, trainer.model

    # No transformer was trained: hand back the RNN, if there is one.
    return results, (rnn_model if cfg.train_rnn else None)


def compute_decomposed_kl(test_data, model, pdfa, tokenizer, intervention_type,
                          tgt_arc=None, tgt_symbol=None, tgt_state=None):
    """Break the automaton-vs-model divergence down by transition, state and symbol.

    For each test sample the model's next-token distribution is computed once.
    For the j-th arc ``(src, tgt, sym)`` of the sample:

    * ``log(p_arc / q_model)`` is recorded under ``"trans"`` (keyed by the arc
      tuple) and ``"symbols"`` (keyed by ``sym``), where ``p_arc`` is
      ``exp(log_probs[j])`` and ``q_model`` the model probability of the
      symbol's token at position ``j``;
    * the local KL at state ``tgt`` (summed over its outgoing symbols plus the
      EOS/accept term, against the model distribution at position ``j + 1``)
      is recorded under ``"states"`` (keyed by ``tgt``).

    Args:
        test_data: Samples with ``"input_ids"``, ``"attention_mask"``,
            ``"arcs"`` and ``"log_probs"`` entries.
        model: Model queried for logits (called directly when
            ``wandb.config.train_transformer`` is set, else via ``get_logits``).
        pdfa: Automaton providing ``Sigma``, ``Q``, ``δ`` and ``ρ``.
        tokenizer: Tokenizer used to map symbols and EOS to token ids.
        intervention_type: ``"symbol"``, ``"arc"``, ``"state"`` or vanilla.
        tgt_arc: Target arc (``state_from``/``state_to``/``symbol``) for ``"arc"``.
        tgt_symbol: Target symbol for ``"symbol"``.
        tgt_state: Target state index for ``"state"``.

    Returns:
        ``(mean intervention KL, mean rest KL, decomposed_kl)`` where
        ``decomposed_kl`` maps ``"trans"``/``"states"``/``"symbols"`` to
        defaultdicts of per-occurrence values.
    """
    def get_model_probs(model, samp):
        input_ids = torch.tensor(samp["input_ids"]).unsqueeze(0).to(DEVICE)
        # NOTE(review): unlike input_ids, the mask gets no batch dimension —
        # confirm both model types accept it in this form.
        attn_mask = torch.tensor(samp["attention_mask"]).to(DEVICE)
        # Evaluation only: no autograd graph is needed for these forward passes.
        with torch.no_grad():
            if wandb.config.train_transformer:
                logits = model(input_ids=input_ids, attention_mask=attn_mask).logits
            else:
                logits = get_logits(model, input_ids, attn_mask)
            return logits.softmax(dim=-1)[0]

    # Map every automaton symbol to the id of its first token.
    symbol_dict = {}
    for sym in pdfa.Sigma:
        symbol_dict[sym] = tokenizer.encode(str(sym))[0]

    def get_local_kl(probs, state_idx, cur_idx):
        """KL between the automaton's and the model's next-step distribution at a state."""
        # The automaton state the current arc leads into.
        pdfa_state = next(q for q in pdfa.Q if q.idx.idx == state_idx)

        # probs[cur_idx] predicts the symbol of the current arc, so the
        # distribution *at* its target state is one position further on.
        model_probs = probs[cur_idx + 1]

        # Transitions leaving the state; every other symbol has zero
        # probability and so contributes nothing to the sum.
        transitions = pdfa.δ[pdfa_state]
        trans_sym = {symb.value.value for symb in transitions.keys()}

        local_kl = 0
        for symb, token_id in symbol_dict.items():
            if symb.value.value not in trans_sym:
                continue

            # transitions[symb] maps target state -> weight; add up the weights.
            pdfa_prob = 0
            for prob in transitions[symb].values():
                pdfa_prob += prob.value

            m_prob = model_probs[token_id].item()
            local_kl += pdfa_prob * log(pdfa_prob / m_prob)

        # Stopping at this state corresponds to the model emitting EOS.
        accept_prob = pdfa.ρ[pdfa_state].value
        if accept_prob:
            eos_prob = model_probs[tokenizer.eos_token_id].item()
            local_kl += accept_prob * log(accept_prob / eos_prob)

        return local_kl

    decomposed_kl = {"trans": defaultdict(list), "states": defaultdict(list), "symbols": defaultdict(list)}
    for samp in tqdm.tqdm(test_data, desc="Decomposing KL", total=len(test_data)):
        probs = get_model_probs(model, samp)

        for jdx, arc in enumerate(samp["arcs"]):
            src, tgt, sym = arc
            arc_prob = exp(samp["log_probs"][jdx])
            token = tokenizer.encode(str(sym))[0]

            try:
                model_prob = probs[jdx][token].item()
            except Exception:
                # Possibly an out-of-bounds index for the last token if we sample too long, skip
                print("THIS SHOULD NOT HAPPEN")
                print(f"Shape of sample['arcs']: {len(samp['arcs'])} -- jdx: {jdx}")
                break

            # Todo: reconsider for not states
            loc_kl = log(arc_prob / model_prob)
            decomposed_kl["trans"][tuple(arc)].append(loc_kl)

            try:
                state_kl = get_local_kl(probs, tgt, jdx)
            except Exception:
                print("TODO: fix, OUT OF BOUNDS ... model does not accept this length...")
                break

            decomposed_kl["states"][tgt].append(state_kl)
            decomposed_kl["symbols"][sym].append(loc_kl)

        # NOTE: the EOS step itself is not recorded as a separate entry; the
        # accept probability only enters through the per-state local KL above.

    # Intervention-specific grouping
    if intervention_type == "symbol":
        intervention_kl = decomposed_kl["symbols"][tgt_symbol]
        rest_kl = np.mean([np.mean(v) for k, v in decomposed_kl["symbols"].items() if k != tgt_symbol])
    elif intervention_type == "arc":
        lookup_arc = (tgt_arc.state_from, tgt_arc.state_to, tgt_arc.symbol)
        intervention_kl = decomposed_kl["trans"][lookup_arc]
        rest_kl = np.mean([np.mean(v) for k, v in decomposed_kl["trans"].items() if k != lookup_arc])
    elif intervention_type == "state":
        intervention_kl = decomposed_kl["states"][tgt_state]
        rest_kl = np.mean([np.mean(v) for k, v in decomposed_kl["states"].items() if k != tgt_state])
    else:  # "vanilla"
        # For vanilla, treat all symbols as "intervention"
        intervention_kl = [np.mean(v) for v in decomposed_kl["symbols"].values()]
        rest_kl = intervention_kl

    return np.mean(intervention_kl), np.mean(rest_kl), decomposed_kl


def aggregate_decomposed_kl(decomposed_kl, train_dataset):
    """Summarise the decomposed KL values and count occurrences in the training set.

    Args:
        decomposed_kl: Mapping of group name to ``{item: [values]}``.
        train_dataset: Object whose ``"transitions"`` entry is an iterable of
            per-sample lists of ``(src, tgt, sym)`` transitions.

    Returns:
        ``(means, counts, counts_in_train)``: the first two are keyed by
        ``"<group>_<item>"``; the third holds per-symbol, per-transition
        (keyed by the stringified tuple) and per-target-state counts.
    """
    mean_kl = {}
    n_observations = {}
    for group, per_item in decomposed_kl.items():
        if not per_item:
            continue
        for item, values in per_item.items():
            label = f"{group}_{item}"
            mean_kl[label] = np.mean(values)
            n_observations[label] = len(values)

    symbol_counts, transition_counts, state_counts = {}, {}, {}
    for path in train_dataset["transitions"]:
        for step in path:
            _, destination, symbol = step
            symbol_counts[symbol] = symbol_counts.get(symbol, 0) + 1
            step_key = str(tuple(step))
            transition_counts[step_key] = transition_counts.get(step_key, 0) + 1
            state_counts[destination] = state_counts.get(destination, 0) + 1

    counts_in_train = {
        "symbols": symbol_counts,
        "trans": transition_counts,
        "states": state_counts,
    }
    return mean_kl, n_observations, counts_in_train


def log_and_save_metrics(metrics, decomposed_data, occ_count, extra_tables=None):
    """Persist one step's metrics to disk and log them to Weights & Biases.

    ``metrics`` is updated in place with the wandb config and the run name,
    then written as ``results_<occ_count>.csv`` and ``.pkl`` under the run's
    results directory. The same row is logged as a table, together with the
    ``decomposed_data`` entries (under ``decomp_kl/``) and any ``extra_tables``.

    Returns:
        The single-row DataFrame built from ``metrics``.
    """
    cfg = wandb.config
    run_name = wandb.run.name

    metrics.update(cfg)
    metrics["wandb_run_name"] = run_name

    out_dir = f"{cfg.save_dir}/results/{run_name}"
    os.makedirs(out_dir, exist_ok=True)
    stem = f"{out_dir}/results_{occ_count}"

    frame = pd.DataFrame([metrics])
    frame.to_csv(f"{stem}.csv", index=False)
    with open(f"{stem}.pkl", "wb") as handle:
        dill.dump(metrics, handle)

    payload = {"results": wandb.Table(dataframe=frame)}
    payload.update({f"decomp_kl/{name}": value for name, value in decomposed_data.items()})
    if extra_tables:
        payload.update(extra_tables)

    wandb.log(payload, step=occ_count)
    return frame


#############################
# Main Experiment Loop      #
#############################

@timeoutable(default=None)
def main():
    """Run the intervention experiment configured in ``wandb.config``.

    For every occurrence count in the schedule this samples training data,
    builds the dataset, calls ``train_instance``, computes the decomposed KL
    and extra metrics, and logs/saves the results for that step. The sampler,
    test set and tokenizer are created on the first step and reused
    afterwards; the automaton hash is compared against the first step's hash
    on every iteration.

    The schedule for the "arc", "symbol" and "state" interventions is
    ``range(start_intervention_range, end_intervention_range + 1,
    intervention_step)`` plus one extra mid-point between the first two
    counts. "vanilla" runs a single step with count 0.

    The function is wrapped by ``timeoutable``, so it is invoked with a
    ``timeout`` keyword (see the script entry point).

    Raises:
        ValueError: If the intervention type is unknown, the schedule is
            empty, the automaton hash changes between steps, or a test set
            with enough target arcs/states cannot be sampled.
    """
    cfg = wandb.config
    intervention_type = cfg.intervention_type
    if intervention_type not in ["arc", "symbol", "state", "vanilla"]:
        raise ValueError("Intervention must be 'arc', 'symbol', 'state' or 'vanilla'")

    # Setup logging name for interventions
    log_type = intervention_type
    if cfg.use_dupl_symbol:
        log_type = "dupl_symbol"
        if cfg.delete_arcs:
            log_type = f"{log_type}_del_arcs"
    cfg.update({"intervention": log_type})

    sampler = None
    test_samples = None
    tokenizer = None
    first_sha = None

    # Decide on the steps for intervention
    start, end, interval = cfg.start_intervention_range, cfg.end_intervention_range, cfg.intervention_step
    if intervention_type in ["arc", "symbol", "state"]:
        occ_steps = list(range(start, end + 1, interval))
        if len(occ_steps) > 1:
            # Higher granularity at the front: add the mid-point between the
            # first two counts. It is computed relative to the first count so
            # the schedule stays increasing when the range does not start at
            # 0, and it is skipped when no integer lies strictly in between
            # (which would otherwise repeat the first step).
            first, second = occ_steps[0], occ_steps[1]
            midpoint = first + (second - first) // 2
            if first < midpoint < second:
                occ_steps.insert(1, midpoint)
    else:
        occ_steps = [0]

    if not occ_steps:
        # Fail early instead of hitting an unbound result frame at the end.
        raise ValueError(
            f"No intervention steps: range({start}, {end} + 1, {interval}) is empty."
        )

    # Helpers for ensuring enough target arcs / states in the test set.
    # Each draws at most `max_tries` candidate test sets and returns the first
    # one in which get_num_tgts reports at least `min_count` targets.
    def ensure_test_tgts_for_arc(test_size, tgt_arc_orig, min_count=5, max_tries=5):
        arc_key = (tgt_arc_orig.state_from, tgt_arc_orig.state_to, tgt_arc_orig.symbol)
        for _ in range(max_tries):
            potential_test = sampler.sample_original(test_size, tgt_arc=tgt_arc_orig)
            if get_num_tgts(potential_test, arcs=True, tgt_arc=arc_key) >= min_count:
                return potential_test
        raise ValueError("Could not find enough target arcs in the test set after multiple tries.")

    def ensure_test_tgts_for_state(test_size, tgt_state_orig, min_count=5, max_tries=5):
        for _ in range(max_tries):
            potential_test = sampler.sample_original(test_size, tgt_state=tgt_state_orig)
            if get_num_tgts(potential_test, states=True, tgt_state=tgt_state_orig) >= min_count:
                return potential_test
        raise ValueError("Could not find enough target states in the test set after multiple tries.")

    for occ_count in occ_steps:
        print(f"Starting training for intervention={intervention_type}, occ_count={occ_count}")
        # Sample training data for interventions
        if sampler is None:
            print("Setting up sampler/machine")
            # Keyword arguments shared by every sampling function.
            common_kwargs = dict(
                seed=cfg.seed,
                num_states=cfg.num_states,
                accept_prob=cfg.accept_prob,
                name=cfg.automaton_name,
                alphabet_size=cfg.num_symbols,
            )
            if intervention_type == "symbol":
                train_samples, sampler, _counts, stats = sample_symbol(
                    cfg.dataset_size, occ_count, tgt_symbol=cfg.target_symbol, **common_kwargs
                )
            elif intervention_type == "arc":
                train_samples, sampler, _counts, stats = sample_arc(
                    cfg.dataset_size, occ_count, tgt_transition=cfg.target_transition, **common_kwargs
                )
            elif intervention_type == "state":
                train_samples, sampler, _counts, stats = sample_state(
                    cfg.dataset_size, occ_count, tgt_state=cfg.target_state, **common_kwargs
                )
                # Sanity check: transitions entering the target state must
                # occur exactly occ_count times in the training samples.
                count_in_train = sum(
                    1
                    for samp in train_samples
                    for _src, tgt, _symb in samp["transitions"]
                    if tgt == cfg.target_state
                )
                assert count_in_train == occ_count

            else:  # vanilla
                train_samples, sampler, stats = sample_vanilla(
                    cfg.dataset_size, occ_count, **common_kwargs
                )
        else:
            # Reuse the existing sampler for the new occurrence count.
            sampler.resize(occ_count)
            train_samples, _counts = sampler.sample_interventions(cfg.dataset_size, occ_count)

        # `stats` comes from the initial sampling call and is re-logged at
        # every step.
        if stats is not None:
            for k, v in stats.items():
                wandb.log({f"stats/{k}": v}, step=occ_count)

        # Convert the sampler to a PDFa
        pdfa = to_rayuela(sampler)

        # The automaton must stay the same across all steps of the run.
        pdfa_hash = pdfa.hash_from_values()
        if first_sha is None:
            first_sha = pdfa_hash
        elif pdfa_hash != first_sha:
            raise ValueError("The underlying automaton changed across samples!")

        # Extract "orig" references from the sampler
        tgt_symbol_orig = sampler.tgt_symbol if intervention_type == "symbol" else None
        tgt_arc_orig = sampler.A_l.target_transition if intervention_type == "arc" else None
        tgt_state_orig = sampler.A_l.target_state if intervention_type == "state" else None

        # Build the test set once, with enough arcs/states if needed
        if test_samples is None:
            test_size = cfg.num_val
            if intervention_type == "symbol" or intervention_type == "vanilla":
                test_samples = sampler.sample_original(test_size)
            elif intervention_type == "arc":
                if tgt_arc_orig is not None:
                    test_samples = ensure_test_tgts_for_arc(test_size, tgt_arc_orig, min_count=5)
                else:
                    test_samples = sampler.sample_original(test_size)
            elif intervention_type == "state":
                if tgt_state_orig is not None:
                    test_samples = ensure_test_tgts_for_state(test_size, tgt_state_orig, min_count=5)
                else:
                    test_samples = sampler.sample_original(test_size)

                test_samp_found = sum(
                    1
                    for samp in test_samples
                    for _, tgt, _sym in samp["transitions"]
                    if tgt == tgt_state_orig
                )

                # NOTE(review): the helper above only guarantees min_count=5
                # targets as counted by get_num_tgts, while this check needs
                # more than 10 transitions into the target state -- confirm
                # the two thresholds are meant to differ.
                assert test_samp_found > 10

        # Currently disabled (never runs): would drop test samples whose
        # sampled string also occurs in the training set.
        if False:
            train_strings = {tr_samp["sampled_string"] for tr_samp in train_samples}
            new_test_samples = [
                tst_samp for tst_samp in test_samples
                if tst_samp["sampled_string"] not in train_strings
            ]
            print(f"Went from {len(test_samples)} test samples to {len(new_test_samples)} samples")
            test_samples = new_test_samples

        if tokenizer is None:
            tokenizer = construct_tokenizer(pdfa)
            if cfg.save_automaton:
                tokenizer.save_pretrained(f"{cfg.save_dir}/tokenizers/{wandb.run.name}")

        dataset, avg_test_len, avg_test_ppl = construct_dataset(tokenizer, train_samples, test_samples)

        # Train
        training_metrics, trained_model = train_instance(pdfa, dataset, tokenizer,
                                                         identifier=occ_count,
                                                         tgt_symbol=tgt_symbol_orig,
                                                         tgt_arc=tgt_arc_orig)

        # Decomposed KL
        kl_intervention, kl_rest, decomposed_kl = compute_decomposed_kl(
            dataset["test"], trained_model, pdfa, tokenizer, intervention_type,
            tgt_arc=tgt_arc_orig, tgt_symbol=tgt_symbol_orig, tgt_state=tgt_state_orig
        )
        training_metrics["decomp_KL_intervention"] = kl_intervention
        training_metrics["decomp_KL_rest"] = kl_rest

        print(f"Decomposed kl : {kl_intervention}")

        # Aggregate decomposed KL
        averaged_kls, averaged_kls_counts, counts_in_train = aggregate_decomposed_kl(
            decomposed_kl, dataset["train"].to_dict()
        )
        training_metrics["all_decomp_kl"] = averaged_kls
        training_metrics["all_decomp_kl_counts"] = averaged_kls_counts
        training_metrics["counts_in_train"] = counts_in_train

        # Extra metrics
        arcs = get_arcs(pdfa)
        extra_metrics = compute_extra_metrics(
            intervention_type, cfg, pdfa, sampler, train_samples, arcs,
            tgt_symbol=tgt_symbol_orig, tgt_arc_orig=tgt_arc_orig, tgt_state_orig=tgt_state_orig
        )
        training_metrics.update(extra_metrics)

        # Automaton-level reference quantities
        pdfa_entropy = pdfa.calc_entropy()
        training_metrics["pdfa_entropy"] = pdfa_entropy
        training_metrics["pdfa_expected_length"] = pdfa.calc_expected_length()
        training_metrics["average_test_length"] = avg_test_len
        # Differences between each model's loss and the automaton entropy /
        # the empirical test value returned by construct_dataset.
        if "rnn_loss" in training_metrics:
            training_metrics["rnn_KL_divergence"] = training_metrics["rnn_loss"] - pdfa_entropy
            training_metrics["rnn_KL_divergence_empirical"] = training_metrics["rnn_loss"] - avg_test_ppl
        if "transformer_loss" in training_metrics:
            training_metrics["transformer_KL_divergence"] = (
                training_metrics["transformer_loss"] - pdfa_entropy
            )
            training_metrics["transformer_KL_divergence_empirical"] = (
                training_metrics["transformer_loss"] - avg_test_ppl
            )

        wandb.config.update(training_metrics, allow_val_change=True)

        df = log_and_save_metrics(
            training_metrics,
            {
                "target_decomp_kl/intervention": kl_intervention,
                "target_decomp_kl/rest": kl_rest
            },
            occ_count
        )

        # Drop this step's large objects before emptying the CUDA cache so the
        # memory they held can actually be released.
        del dataset, pdfa, trained_model
        torch.cuda.empty_cache()

    # Save final results. Only the last step's frame is written here; the
    # per-step results are saved by log_and_save_metrics.
    final_dir = f"{cfg.save_dir}/{wandb.run.name}"
    os.makedirs(final_dir, exist_ok=True)
    df.to_csv(f"{final_dir}/final_results.csv")
    wandb.finish()


if __name__ == "__main__":
    # Script entry point: parse CLI arguments, set up wandb, the output
    # directories and the config, validate the arguments, then run the
    # experiment with a one-hour (3600 s) timeout.
    args = parse_args()
    init_wandb(args)
    create_dirs()
    update_config()
    # Explicit check rather than `assert`, which is skipped when Python runs
    # with -O and would let an invalid configuration through.
    if args.embedding_size % args.num_heads != 0:
        raise ValueError(
            f"The embedding size ({args.embedding_size}) must be divisible by the number of heads ({args.num_heads})."
        )
    main(timeout=3600)
