import argparse
import os
import json
from counting_sampling import sample_symbol, to_rayuela, sample_arc, sample_state, sample_vanilla
import torch
import dill


def tensor_to_item(obj):
    """Convert a torch tensor into plain Python data.

    A 0-dimensional tensor becomes a Python scalar (via ``.item()``); any other
    tensor becomes a (nested) list (via ``.tolist()``). Inputs that are not
    tensors produce ``None``.
    """
    if not isinstance(obj, torch.Tensor):
        return None
    return obj.item() if obj.ndim == 0 else obj.tolist()


class TorchItemEncoder(json.JSONEncoder):
    """JSON encoder that additionally knows how to serialize torch tensors.

    0-dimensional tensors are emitted as scalars and all other tensors as
    (nested) lists. Every other non-JSON-native object is handed to the base
    class, which raises ``TypeError``, instead of being silently written as
    ``null``.
    """

    def default(self, obj):
        if isinstance(obj, torch.Tensor):
            # Scalars via .item(), everything else via .tolist().
            return obj.item() if obj.ndim == 0 else obj.tolist()
        # Delegate so unsupported types fail loudly rather than becoming null.
        return super().default(obj)

# File name used for each per-sample field in the non-JSON layout built by
# get_output_files.
out_file_name_map = dict(
    text="main.tok",
    arcs="arcs.txt",
    log_probs="log_probs.txt",
)

def get_output_files(args):
    """Create the per-experiment output directories and return their file paths.

    Directory layout::

        <output_dir>/<experiment>/<dataset>/<target>/<intervention_count>/

    * ``experiment`` is ``args.automaton_name`` or, for a sampled automaton,
      ``samp_st{num_states}_sym{num_symbols}``.
    * ``dataset`` is one of ``train``, ``test``, ``machine``.
    * ``target`` is ``arc_<i>`` / ``state_<i>`` / ``symbol_<i>`` for an explicit
      target, otherwise ``<intervention_type>_seed<seed>``.

    Args:
        args: Parsed CLI namespace (see ``parse_args``).

    Returns:
        For ``args.output_type == "json"``: ``{"train", "test", "machine"}`` ->
        path. Otherwise ``{"machine": path}`` plus ``(field, split)`` -> path
        for field in arcs/text/log_probs and split in train/test.
    """
    if args.automaton_name:
        experiment = f"{args.automaton_name}"
    else:
        experiment = f"samp_st{args.num_states}_sym{args.num_symbols}"

    # Compare against None so that index 0 is a usable target.
    if args.target_transition is not None:
        target = f"arc_{args.target_transition}"
    elif args.target_state is not None:
        target = f"state_{args.target_state}"
    elif args.target_symbol is not None:
        target = f"symbol_{args.target_symbol}"
    else:
        # No explicit target: presumably the sampler picks one from the seed
        # (TODO confirm), so the seed keeps runs from colliding.
        target = f"{args.intervention_type}_seed{args.seed}"

    def split_dir(dataset):
        # Directory for one dataset split under the shared layout.
        return os.path.join(args.output_dir, experiment, dataset, target,
                            str(args.intervention_count))

    for subf in ("train", "test", "machine"):
        os.makedirs(split_dir(subf), exist_ok=True)

    machine_path = os.path.join(split_dir("machine"), "machine.pkl")
    if args.output_type == "json":
        return {
            "train": os.path.join(split_dir("train"), "main.json"),
            "test": os.path.join(split_dir("test"), "main.json"),
            "machine": machine_path,
        }

    files = {"machine": machine_path}
    for dfile in ("arcs", "text", "log_probs"):
        for dset in ("train", "test"):
            files[(dfile, dset)] = os.path.join(split_dir(dset), out_file_name_map[dfile])
    return files


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the data-generation command line.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing no-argument calls are unchanged.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Arguments for data generation")
    parser.add_argument("--dataset_size", type=int, default=100)
    parser.add_argument("--num_val", type=int, default=100)
    parser.add_argument("--num_test", type=int, default=100)
    parser.add_argument("--validation_num_occurrences", type=int, default=100)
    parser.add_argument("--automaton_name", type=str,
                        help="Lookup a specific automaton from an external register")
    parser.add_argument("--num_states", type=int, default=20)
    parser.add_argument("--num_symbols", type=int, default=10)
    parser.add_argument("--accept_prob", type=float, default=0.2)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--target_state", type=int, default=None)
    parser.add_argument("--target_transition", type=int, default=None)
    parser.add_argument("--target_symbol", type=int, default=None)
    parser.add_argument("--intervention_type", type=str, default="arc",
                        choices=["arc", "symbol", "state", "vanilla"],
                        help="arc, symbol, state, or vanilla")
    parser.add_argument("--intervention_count", type=int, default=100)
    parser.add_argument("--output_dir", type=str, default="data")
    parser.add_argument("--output_prefix", type=str, default=None)
    parser.add_argument("--output_type", type=str, default="text",
                        choices=["text", "json"])
    # store_true with default=True alone can never be switched off; the
    # companion flag writes False into the same destination.
    parser.add_argument("--save_automaton", action="store_true", default=True)
    parser.add_argument("--no_save_automaton", dest="save_automaton",
                        action="store_false",
                        help="Skip writing the dill-serialized machine.pkl")
    parser.add_argument("--training_output", type=str, default="data/train/main.tok")
    parser.add_argument("--validation_output", type=str, default="data/val/main.tok")
    parser.add_argument("--test_output", type=str, default="data/test/main.tok")
    return parser.parse_args(argv)


def construct_jsons(samples, delim=" "):
    """Turn raw sampler outputs into flat, JSON-ready records.

    Each input sample must provide ``"sampled_string"``, ``"log_probs"``,
    ``"transitions"`` and ``"states"``.

    Args:
        samples: Iterable of sample dicts as described above.
        delim: Separator used to join the symbols into ``"text"``.

    Returns:
        A list with one dict per sample, with keys ``text``, ``log_probs``,
        ``arcs``, ``states``, ``perplexity``, ``length`` and
        ``sampled_string_raw``.
    """
    dataset = []
    for samp in samples:
        symbols = samp["sampled_string"]
        log_probs = samp["log_probs"]
        if isinstance(log_probs, torch.Tensor):
            # Assumes a 1-D tensor of per-step log-probs — TODO confirm.
            log_probs = log_probs.tolist()

        dataset.append(
            {
                "text": delim.join(str(c) for c in symbols),
                "log_probs": log_probs,
                "arcs": samp["transitions"],
                "states": samp["states"],
                # NOTE: despite the key name this is the summed log-probability,
                # not exp(-mean) perplexity; the key is kept for compatibility.
                "perplexity": sum(log_probs),
                # Number of symbols, not characters of the delimited string.
                "length": len(symbols),
                "sampled_string_raw": symbols,
            }
        )
    return dataset


def get_num_tgts(samples, arcs=False, states=False, tgt_arc=None, tgt_state=None):
    """
    Count how many samples in 'samples' contain the specified arc or state.

    Each sample is counted at most once, even when both ``arcs`` and ``states``
    are enabled and it matches both.

    If arcs=True, 'tgt_arc' should be a (src_idx, tgt_idx, symbol_value) tuple;
    it is looked up with ``in`` against the sample's "transitions".
    If states=True, 'tgt_state' should be a State or integer representing the
    target state; it is compared with ``==`` against the sample's "states".
    With neither flag set the result is 0.
    """
    def _matches(samp):
        # True if this sample contains any of the requested targets.
        if arcs and tgt_arc in samp["transitions"]:
            return True
        return states and any(s == tgt_state for s in samp["states"])

    return sum(1 for samp in samples if _matches(samp))


def _sample_test_set(sampler, intervention_type, num_test, tgt_arc_orig,
                     tgt_state_orig, min_count, max_tries=5):
    """Sample the test split from the original machine.

    For "arc"/"state" interventions with a known target, the sample is redrawn
    up to ``max_tries`` times until at least ``min_count`` samples contain the
    target (a heuristic; the semiring could guarantee this instead). In every
    other case a single ``sampler.sample_original(num_test)`` draw is returned.

    Raises:
        ValueError: if no attempt reaches ``min_count`` targets.
    """
    if intervention_type == "arc" and tgt_arc_orig is not None:
        what = "arcs"
        sample_kwargs = {"tgt_arc": tgt_arc_orig}
        count_kwargs = {
            "arcs": True,
            "tgt_arc": (tgt_arc_orig.state_from, tgt_arc_orig.state_to, tgt_arc_orig.symbol),
        }
    elif intervention_type == "state" and tgt_state_orig is not None:
        what = "states"
        sample_kwargs = {"tgt_state": tgt_state_orig}
        count_kwargs = {"states": True, "tgt_state": tgt_state_orig}
    else:
        # TODO: symbol interventions are not checked for target coverage.
        return sampler.sample_original(num_test)

    for _ in range(max_tries):
        candidate = sampler.sample_original(num_test, **sample_kwargs)
        # Check each draw before giving up, so no draw is discarded unchecked.
        if get_num_tgts(candidate, **count_kwargs) >= min_count:
            return candidate
    raise ValueError(f"Could not find enough target {what} in the test set after multiple tries.")


def _write_datasets(datasets, output_type):
    """Write the train/test/validation records held in ``datasets``.

    "json": one JSON object per line at ``datasets["<split>_out"]``.
    Otherwise: for each field in arcs/text/log_probs, one ``str(value)`` per
    line; "text" goes to ``<split>_out`` and the other fields go to
    ``<split>_folder/<field>.txt``.
    """
    splits = ("train", "test", "validation")
    if output_type == "json":
        for dset in splits:
            with open(datasets[f"{dset}_out"], "w", encoding="utf-8") as f:
                for d in datasets[dset]:
                    f.write(json.dumps(d, cls=TorchItemEncoder) + "\n")
        return

    for dfile in ("arcs", "text", "log_probs"):
        for dset in splits:
            if dfile == "text":
                out_file = datasets[f"{dset}_out"]
            else:
                out_file = os.path.join(datasets[f"{dset}_folder"], f"{dfile}.txt")
            with open(out_file, "w", encoding="utf-8") as f:
                for d in datasets[dset]:
                    f.write(str(d[dfile]) + "\n")


def main(args=None):
    """Generate train/validation/test data for an intervened automaton.

    Training data comes from the intervention-specific sampling function and
    validation data from ``sampler.sample_interventions``. Test data is drawn
    from the original machine (``sampler.sample_original``). Everything is
    written to ``--training_output`` / ``--validation_output`` /
    ``--test_output``. Unless disabled, the automaton (converted with
    ``to_rayuela``) is dill-pickled to ``machine.pkl`` next to the training
    output.

    Args:
        args: Pre-parsed namespace; parsed from the command line when None.

    Raises:
        ValueError: on an unknown output or intervention type, or when the
            test set cannot be made to contain the intervention target.
    """
    if args is None:
        args = parse_args()
    print(args)

    if args.output_type not in ("text", "json"):
        raise ValueError("wrong output type")

    intervention_type = args.intervention_type
    if intervention_type not in ("arc", "symbol", "state", "vanilla"):
        raise ValueError("Intervention must be 'arc', 'symbol', 'state' or 'vanilla'")

    sample_args = {
        "seed": args.seed,
        "num_states": args.num_states,
        "accept_prob": args.accept_prob,
        "name": args.automaton_name,
        "alphabet_size": args.num_symbols,
    }
    if intervention_type == "symbol":
        sample_function = sample_symbol
        sample_args["tgt_symbol"] = args.target_symbol
    elif intervention_type == "arc":
        sample_function = sample_arc
        sample_args["tgt_transition"] = args.target_transition
    elif intervention_type == "state":
        sample_function = sample_state
        sample_args["tgt_state"] = args.target_state
    else:
        sample_function = sample_vanilla

    # Keep this call order: these steps presumably share the sampler's RNG
    # state, so reordering could change the data generated for a given seed.
    train_samples, sampler = sample_function(args.dataset_size, args.intervention_count, **sample_args)[:2]
    val_samples, _val_counts = sampler.sample_interventions(
        args.num_val, args.validation_num_occurrences, pbar=True)
    pdfa = to_rayuela(sampler)

    tgt_arc_orig = sampler.A_l.target_transition if intervention_type == "arc" else None
    tgt_state_orig = sampler.A_l.target_state if intervention_type == "state" else None

    # Require the test set to exercise the intervened arc/state.
    test_min_count = 10
    test_samples = _sample_test_set(sampler, intervention_type, args.num_test,
                                    tgt_arc_orig, tgt_state_orig, min_count=test_min_count)

    datasets = {
        "train": construct_jsons(train_samples),
        "test": construct_jsons(test_samples),
        "validation": construct_jsons(val_samples),
    }
    split_outputs = {
        "train": args.training_output,
        "validation": args.validation_output,
        "test": args.test_output,
    }
    for split, out_path in split_outputs.items():
        folder = os.path.dirname(out_path)
        datasets[f"{split}_out"] = out_path
        datasets[f"{split}_folder"] = folder
        # A bare file name has no folder; os.makedirs("") would raise.
        if folder:
            os.makedirs(folder, exist_ok=True)

    if args.save_automaton:
        machine_out = os.path.join(datasets["train_folder"], "machine.pkl")
        with open(machine_out, "wb") as f:
            dill.dump(pdfa, f)

    _write_datasets(datasets, args.output_type)


if __name__ == "__main__":
    # Script entry point: main() parses the CLI arguments itself when given None.
    main(args=None)
