#!/usr/bin/env python3
import argparse
import os
import json
import torch
import dill
import tqdm
from counting_sampling_v2 import to_rayuela


def tensor_to_item(obj):
    """Convert a torch tensor into a plain Python value.

    A zero-dimensional tensor is unwrapped with ``item()``; any other tensor
    is converted with ``tolist()``. Anything that is not a ``torch.Tensor``
    yields ``None``.
    """
    if not isinstance(obj, torch.Tensor):
        return None
    # ndim is 0 (falsy) only for scalar tensors.
    return obj.tolist() if obj.ndim else obj.item()


class TorchItemEncoder(json.JSONEncoder):
    """JSON encoder that falls back to ``tensor_to_item`` for unknown objects."""

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        The stock ``JSONEncoder.default`` is tried first; when it raises
        ``TypeError`` the object is handed to ``tensor_to_item`` instead.
        """
        # NOTE(review): tensor_to_item yields None for non-tensors, so any
        # other unsupported object ends up serialized as JSON null rather
        # than raising -- confirm this is intended.
        try:
            encoded = super().default(obj)
        except TypeError:
            encoded = tensor_to_item(obj)
        return encoded


# File name used for each per-sample field when writing text output.
out_file_name_map = dict(
    text="main.tok",
    arcs="arcs.txt",
    log_probs="log_probs.txt",
)


def get_output_files(args):
    """Create the output directory tree and return the paths of the output files.

    Directories follow the layout
    ``<output_dir>/<experiment>/<dataset>/<target>/<intervention_count>/``
    where ``dataset`` is one of ``train``, ``test`` or ``machine``. All three
    directories are created if they do not exist yet.

    Args:
        args: Namespace providing ``output_dir``, ``automaton_name``,
            ``num_states``, ``num_symbols``, ``target_transition``,
            ``target_state``, ``target_symbol``, ``intervention_count`` and
            ``output_type``.
            NOTE(review): the argument parser visible in this file defines no
            ``--output_dir`` option, so callers must set ``args.output_dir``
            themselves.

    Returns:
        A dict that always contains ``"machine"`` (path of ``machine.pkl``).
        For ``output_type == "json"`` it also contains ``"train"`` and
        ``"test"`` (paths of ``main.json``). Otherwise it contains one entry
        per ``(field, dataset)`` tuple for the fields ``arcs``, ``text`` and
        ``log_probs`` and the datasets ``train`` and ``test``, named according
        to ``out_file_name_map``.
    """
    if args.automaton_name:
        output_exp = f"{args.automaton_name}"
    else:
        output_exp = f"samp_st{args.num_states}_sym{args.num_symbols}"

    # Compare against None explicitly: 0 is a valid index and must not be
    # mistaken for "no target".
    if args.target_transition is not None:
        interv_tgt = f"arc_{args.target_transition}"
    elif args.target_state is not None:
        interv_tgt = f"state_{args.target_state}"
    elif args.target_symbol is not None:
        interv_tgt = f"symbol_{args.target_symbol}"
    else:
        interv_tgt = "vanilla"

    def dataset_dir(dataset):
        # Build the path from components instead of str.format so that a
        # partially filled template can never raise KeyError and braces in
        # output_dir are harmless.
        return os.path.join(
            args.output_dir,
            output_exp,
            dataset,
            interv_tgt,
            str(args.intervention_count),
        )

    dirs = {name: dataset_dir(name) for name in ("train", "test", "machine")}
    for path in dirs.values():
        os.makedirs(path, exist_ok=True)

    machine_path = os.path.join(dirs["machine"], "machine.pkl")
    if args.output_type == "json":
        return {
            "train": os.path.join(dirs["train"], "main.json"),
            "test": os.path.join(dirs["test"], "main.json"),
            "machine": machine_path,
        }

    files = {"machine": machine_path}
    for dfile in ["arcs", "text", "log_probs"]:
        for dset in ["train", "test"]:
            files[(dfile, dset)] = os.path.join(dirs[dset], out_file_name_map[dfile])
    return files


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options for sampling and preparing data.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Arguments for sampling and preparing data")
    parser.add_argument("--input_dir", type=str, default="data")
    parser.add_argument("--dataset_size", type=int, default=100)
    parser.add_argument("--num_val", type=int, default=100)
    parser.add_argument("--num_test", type=int, default=100)
    parser.add_argument("--automaton_name", type=str, default=None,
                        help="Lookup a specific automaton from an external register")
    parser.add_argument("--num_states", type=int, default=20)
    parser.add_argument("--num_symbols", type=int, default=10)
    parser.add_argument("--sampling_seed", type=int, default=None,
                        help="Seed for sampling process")
    parser.add_argument("--seed", type=int, default=42,
                        help="General seed (for backwards compatibility)")
    parser.add_argument("--target_state", type=int, default=None)
    parser.add_argument("--target_transition", type=int, default=None)
    parser.add_argument("--target_symbol", type=int, default=None)
    parser.add_argument("--intervention_type", type=str, default="arc")
    parser.add_argument("--intervention_count", type=int, default=100)
    parser.add_argument("--output_type", type=str, default="text",
                        choices=["text", "json"])
    # store_true with default=True could never be turned off; keep the flag
    # for compatibility and add an explicit negative switch.
    parser.add_argument("--save_automaton", dest="save_automaton",
                        action="store_true", default=True)
    parser.add_argument("--no_save_automaton", dest="save_automaton",
                        action="store_false",
                        help="Do not write machine.pkl")
    parser.add_argument("--training_output", type=str, default="data/train/main.tok")
    parser.add_argument("--validation_output", type=str, default="data/val/main.tok")
    parser.add_argument("--test_output", type=str, default="data/test/main.tok")
    # The hyphenated spelling is the historical one; the underscore alias
    # matches every other option of this parser.
    parser.add_argument("--validation-num-occurrences", "--validation_num_occurrences",
                        dest="validation_num_occurrences", type=int, default=None)
    return parser.parse_args(argv)


def construct_jsons(samples, delim=" "):
    """Turn raw sampler outputs into a list of flat record dicts.

    Each sample must provide the keys ``sampled_string``, ``log_probs``,
    ``transitions`` and ``states``. The resulting record holds:

    * ``text`` -- the symbols of ``sampled_string`` stringified and joined
      with ``delim``
    * ``log_probs`` -- the sample's log-probs (a tensor is converted to a list)
    * ``arcs`` / ``states`` -- the sample's ``transitions`` / ``states``
    * ``perplexity`` -- the sum over the sample's original ``log_probs``
    * ``length`` -- the character length of ``text`` (delimiters included)
    * ``sampled_string_raw`` -- the untouched ``sampled_string``
    """

    def to_record(sample):
        symbols = sample["sampled_string"]
        text = delim.join(map(str, symbols))
        raw_log_probs = sample["log_probs"]
        if isinstance(raw_log_probs, torch.Tensor):
            plain_log_probs = raw_log_probs.tolist()
        else:
            plain_log_probs = raw_log_probs
        return {
            "text": text,
            "log_probs": plain_log_probs,
            "arcs": sample["transitions"],
            "states": sample["states"],
            # Summed over the original container, not the converted list.
            "perplexity": sum(raw_log_probs),
            "length": len(text),
            "sampled_string_raw": symbols,
        }

    return [to_record(sample) for sample in samples]


def get_num_tgts(samples, arcs=False, states=False, tgt_arc=None, tgt_state=None):
    """
    Tally the samples in 'samples' that contain the requested arc and/or state.
    With arcs=True, a sample scores one point when 'tgt_arc' is a member of its
    "transitions"; with states=True, it scores one point when any entry of its
    "states" equals 'tgt_state'. When both flags are set, a sample matching
    both contributes two points.
    """
    total = 0
    for sample in samples:
        arc_seq, state_seq = sample["transitions"], sample["states"]
        arc_hit = arcs and tgt_arc in arc_seq
        state_hit = states and any(st == tgt_state for st in state_seq)
        total += bool(arc_hit) + bool(state_hit)
    return total


def _sample_test_with_targets(sampler, test_size, sample_kwargs, count_kwargs,
                              label, min_count=5, max_tries=5):
    """Draw a test set that contains at least ``min_count`` target hits.

    Calls ``sampler.sample_original(test_size, **sample_kwargs)`` up to
    ``max_tries`` times and returns the first draw for which
    ``get_num_tgts(draw, **count_kwargs)`` reaches ``min_count``.

    Raises:
        ValueError: if no draw satisfies the requirement.
    """
    for _ in range(max_tries):
        candidate = sampler.sample_original(test_size, **sample_kwargs)
        if get_num_tgts(candidate, **count_kwargs) >= min_count:
            return candidate
    raise ValueError(
        f"Could not find enough target {label} in the test set after multiple tries."
    )


def main(args=None):
    """Sample train/validation/test data from a stored sampler and write it out.

    Reads ``sampler.pkl``, ``sampling_params.json`` and ``seed_info.json``
    from ``args.input_dir``, seeds the sampler, draws the three splits and
    writes them to ``args.training_output``, ``args.validation_output`` and
    ``args.test_output`` (plus ``arcs.txt`` / ``log_probs.txt`` next to each
    of them in text mode). ``machine.pkl`` (optional) and ``seed_info.json``
    are written next to the training output.

    Args:
        args: Parsed options; when ``None`` they are obtained from
            ``parse_args()``.

    Raises:
        ValueError: on an unsupported ``output_type`` or ``intervention_type``,
            or when no test set with enough target arcs/states could be drawn.
    """
    if args is None:
        args = parse_args()
    print(args)

    # Explicit check instead of assert so it survives ``python -O``.
    if args.output_type not in ("text", "json"):
        raise ValueError("wrong output type")

    # Load the sampler from the previous step.
    # NOTE(security): dill.load can execute arbitrary code embedded in the
    # file; only point --input_dir at output you trust.
    with open(os.path.join(args.input_dir, "sampler.pkl"), "rb") as f:
        sampler = dill.load(f)

    # Resize the sampler unless its semiring's string form mentions "Overflow".
    if "Overflow" not in str(sampler.A_l.semiring()):
        sampler.resize(args.intervention_count)

    with open(os.path.join(args.input_dir, "sampling_params.json"), "r") as f:
        sampling_params = json.load(f)

    with open(os.path.join(args.input_dir, "seed_info.json"), "r") as f:
        seed_info = json.load(f)

    # --sampling_seed wins over the general --seed; may be None if both are.
    sampling_seed = args.sampling_seed if args.sampling_seed is not None else args.seed
    sampler.set_seed(sampling_seed)
    seed_info["sampling_seed"] = sampling_seed

    intervention_type = sampling_params["intervention_type"]
    # Fail early: any other value would leave the test split undefined below.
    if intervention_type not in ("vanilla", "symbol", "arc", "state"):
        raise ValueError(f"Unsupported intervention_type: {intervention_type!r}")

    # TRAIN / VALIDATION SAMPLING
    print(f"Generating {args.dataset_size} training samples...")
    if intervention_type == "vanilla":
        train_samples = sampler.sample_original(num_samples=args.dataset_size)
        val_samples = sampler.sample_original(num_samples=args.num_val)
    else:
        validation_num_occurrences = args.validation_num_occurrences
        if validation_num_occurrences is None:
            # Scale the intervention count by the validation/train size ratio.
            validation_num_occurrences = int(
                args.intervention_count * args.num_val / args.dataset_size
            )
        train_samples, _ = sampler.sample_interventions(
            args.dataset_size, args.intervention_count, pbar=True
        )

        print(f"Generating {args.num_val} validation samples...")
        val_samples, _ = sampler.sample_interventions(
            args.num_val, validation_num_occurrences, pbar=True
        )

    # Convert to rayuela format
    pdfa = to_rayuela(sampler)

    tgt_arc_orig = sampler.A_l.target_transition if intervention_type == "arc" else None
    tgt_state_orig = sampler.A_l.target_state if intervention_type == "state" else None

    # TEST DATA SAMPLING
    print(f"Generating {args.num_test} test samples...")
    min_test_targets = 10
    if tgt_arc_orig is not None:
        test_samples = _sample_test_with_targets(
            sampler,
            args.num_test,
            sample_kwargs={"tgt_arc": tgt_arc_orig},
            count_kwargs={
                "arcs": True,
                "tgt_arc": (
                    tgt_arc_orig.state_from,
                    tgt_arc_orig.state_to,
                    tgt_arc_orig.symbol,
                ),
            },
            label="arcs",
            min_count=min_test_targets,
        )
    elif tgt_state_orig is not None:
        test_samples = _sample_test_with_targets(
            sampler,
            args.num_test,
            sample_kwargs={"tgt_state": tgt_state_orig},
            count_kwargs={"states": True, "tgt_state": tgt_state_orig},
            label="states",
            min_count=min_test_targets,
        )
    else:
        # vanilla / symbol runs, or no target recorded on the sampler.
        test_samples = sampler.sample_original(args.num_test)

    # Construct the dataset objects
    datasets = {
        "train": construct_jsons(train_samples),
        "test": construct_jsons(test_samples),
        "validation": construct_jsons(val_samples),
    }
    out_paths = {
        "train": args.training_output,
        "validation": args.validation_output,
        "test": args.test_output,
    }
    # A bare file name has no directory part; fall back to the current
    # directory instead of "" (which makedirs rejects).
    out_dirs = {split: os.path.dirname(path) or "." for split, path in out_paths.items()}
    for folder in out_dirs.values():
        os.makedirs(folder, exist_ok=True)

    if args.save_automaton:
        with open(os.path.join(out_dirs["train"], "machine.pkl"), "wb") as f:
            dill.dump(pdfa, f)

    # Save the final seed information
    with open(os.path.join(out_dirs["train"], "seed_info.json"), "w") as f:
        json.dump(seed_info, f)

    # Write the datasets to files
    if args.output_type == "json":
        # One JSON object per line.
        for split in ("train", "validation", "test"):
            with open(out_paths[split], "w") as f:
                for record in datasets[split]:
                    f.write(json.dumps(record, cls=TorchItemEncoder) + "\n")
    else:
        # "text" goes to the configured output path; the other fields are
        # written as <field>.txt in the same folder.
        for field in ("arcs", "text", "log_probs"):
            for split in ("train", "test", "validation"):
                if field == "text":
                    out_file = out_paths[split]
                else:
                    out_file = os.path.join(out_dirs[split], f"{field}.txt")
                with open(out_file, "w") as f:
                    for record in datasets[split]:
                        f.write(str(record[field]) + "\n")

    print("Data sampling and preparation complete. Outputs saved to:")
    print(f"  Training: {args.training_output}")
    print(f"  Validation: {args.validation_output}")
    print(f"  Test: {args.test_output}")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()