import argparse
import os
from math import ceil

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from code_demeanor.logger import logger
from code_demeanor.reading.scanner import (
    AttentionScanner,
    LayerWiseScanner,
    MicrosaccadesScanner,
    ScanningType,
)
from code_demeanor.utils import (
    get_dataset,
    get_device,
    load_list_jsonl,
    read_yaml_config,
    save_list_jsonl,
    save_tensor_jsonl,
    set_seed,
)

# Fixed seed handed to the project's ``set_seed`` helper at import time so
# repeated runs of this script are reproducible.
_SEED = 42

# Compute device chosen once by the project's ``get_device`` helper; the model
# loading and scanner construction in this script both read it.
DEVICE = get_device()
set_seed(_SEED)


_TRUE_STRINGS = frozenset({"1", "true", "t", "yes", "y", "on"})
_FALSE_STRINGS = frozenset({"0", "false", "f", "no", "n", "off"})


def _str2bool(value):
    """Parse a boolean command-line value.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is True for every
    non-empty string, so ``--use_layer_intervention False`` would *enable* the
    flag. This parser maps the usual spellings explicitly and rejects the rest.

    Args:
        value: Raw string from the command line (a ``bool`` is passed through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in _TRUE_STRINGS:
        return True
    if normalized in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(
        f"Expected a boolean (true/false, yes/no, 1/0), got {value!r}."
    )


def _format_ids(ids):
    """Join layer/head indices with ``_`` for use in output file names.

    ``None`` (option not given anywhere) is rendered as ``"None"``, the same way
    an unset ``seq_length`` appears in the file name. A bare int (e.g. a scalar
    in the YAML config) is treated as a one-element list.
    """
    if ids is None:
        return "None"
    if isinstance(ids, int):
        ids = [ids]
    return "_".join(map(str, ids))


def _dataset_cache_suffix(max_samples, test_samples_path):
    """Return the suffix used for the cached ``{train,test}_{inputs,labels}`` files.

    The suffix must change whenever ``max_samples`` changes, otherwise a capped
    dataset cached under one setting is silently reused under another. Names
    for the previously-correct cases are kept so existing caches stay valid:

    * separate test file  -> ``str(max_samples)``
    * no test file, no cap -> ``"all"``
    * no test file, capped -> ``"all_max{max_samples}"``
    """
    if test_samples_path is not None:
        return str(max_samples)
    if max_samples is None:
        return "all"
    return f"all_max{max_samples}"


def _output_file_stem(
    model_name,
    *,
    hidden_layers,
    heads,
    seq_length,
    use_layer_intervention,
    use_attention_intervention,
    use_microsaccade_intervention,
    shadow_run_flops,
    use_gaussian_noise,
    use_random_noise,
):
    """Build the run-specific stem shared by all four scan output files.

    The stem encodes the model, layers, heads, sequence length and intervention
    flags. It is prefixed with ``FLOPS`` for shadow runs and suffixed with
    ``_gaussian`` (which takes precedence) or ``_random`` for the noise variants.
    """
    stem = (
        f"scan_{model_name.replace('/', '_')}"
        f"_layers_{_format_ids(hidden_layers)}"
        f"_heads_{_format_ids(heads)}"
        f"_seqLen_{seq_length}"
        f"_useLayerIntervention_{use_layer_intervention}"
        f"_useAttentionIntervention_{use_attention_intervention}"
        f"_useMicrosaccadeIntervention_{use_microsaccade_intervention}"
    )
    if shadow_run_flops:
        stem = "FLOPS" + stem
    if use_gaussian_noise:
        stem += "_gaussian"
    elif use_random_noise:
        stem += "_random"
    return stem


def _build_parser():
    """Build the command-line parser for the scan script.

    Boolean options use :func:`_str2bool` instead of ``type=bool`` so that an
    explicit ``False``/``no``/``0`` on the command line really disables them.
    """
    parser = argparse.ArgumentParser(description="Load and process a dataset.")
    parser.add_argument(
        "--samples_path",
        type=str,
        help="Path to secure samples JSONL file.",
        default="artifacts/backdoor/sleeper_agent.jsonl",
    )
    parser.add_argument(
        "--test_samples_path",
        type=str,
        help="Path to secure test samples JSONL file.",
        default=None,
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="Model Name",
        default="gpt2",
    )
    parser.add_argument(
        "--hidden_layers",
        type=int,
        nargs="+",
        help="Hidden layers to analyze",
    )
    parser.add_argument(
        "--heads",
        type=int,
        nargs="+",
        help="Attention heads to analyze",
    )
    parser.add_argument(
        "--seq_length",
        type=int,
        help="Sequence length",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help="Batch size",
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        help="Maximum number of samples to process.",
        default=None,
    )
    parser.add_argument(
        "--combine_causal_effects",
        type=_str2bool,
        help="Whether to combine causal effects.",
        default=True,
    )
    parser.add_argument(
        "--results_dir",
        type=str,
        help="Directory to save the results.",
        default="results/experiment",
    )
    parser.add_argument(
        "--use_layer_intervention",
        type=_str2bool,
        help="Whether to only intervene on the specified layers.",
        default=False,
    )
    parser.add_argument(
        "--use_attention_intervention",
        type=_str2bool,
        help="Whether to only intervene on the specified attention heads.",
        default=False,
    )
    parser.add_argument(
        "--use_microsaccade_intervention",
        type=_str2bool,
        help="Whether to only use microsaccade intervention.",
        default=False,
    )
    parser.add_argument(
        "--shadow_run_flops",
        type=_str2bool,
        help="Whether to do a shadow run to estimate FLOPS.",
        default=False,
    )
    parser.add_argument(
        "--use_gaussian_noise",
        type=_str2bool,
        help="Whether to add gaussian noise to the embeddings.",
        default=False,
    )
    parser.add_argument(
        "--use_random_noise",
        type=_str2bool,
        help="Whether to add random noise to the embeddings.",
        default=False,
    )
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to the YAML configuration file.",
        default="examples/probe_all_scan_backdoor_sleeper_config.yaml",
    )
    return parser


def _load_dataset_splits(
    results_dir, suffix, *, samples_path, test_samples_path, max_samples
):
    """Return ``(train_inputs, train_labels, test_inputs, test_labels)``.

    The splits are produced by ``get_dataset`` and cached as JSONL files in
    ``results_dir`` under ``suffix``. When any of the four cache files is
    missing, all four are regenerated. The splits are always read back from the
    cache files, so fresh and cached runs see identical data.
    """
    names = ("train_inputs", "train_labels", "test_inputs", "test_labels")
    paths = [os.path.join(results_dir, f"{name}_{suffix}.jsonl") for name in names]

    if not all(os.path.exists(path) for path in paths):
        train_inputs, train_labels, test_inputs, test_labels = get_dataset(
            samples_path, test_samples_path=test_samples_path, max_samples=max_samples
        )
        splits = (train_inputs, train_labels, test_inputs, test_labels)
        for split, path in zip(splits, paths):
            save_list_jsonl(split, path)

    return tuple(load_list_jsonl(path) for path in paths)


def _load_model_and_tokenizer(model_name):
    """Load ``model_name`` and its tokenizer, ready for scanning.

    On CUDA the model is loaded with 4-bit NF4 bitsandbytes quantization;
    otherwise it is loaded unquantized and moved to ``DEVICE``. The model is
    configured to return hidden states and attention weights. The tokenizer's
    pad token is set to its EOS token.

    Returns:
        ``(model, tokenizer)``.
    """
    if DEVICE == "cuda":
        bnb_config = BitsAndBytesConfig(
            # Load the model with 4-bit quantization
            load_in_4bit=True,
            # Use double quantization
            bnb_4bit_use_double_quant=True,
            # Use 4-bit Normal Float for storing the base model weights in GPU memory
            bnb_4bit_quant_type="nf4",
            # De-quantize the weights to 16-bit (Brain) float before the forward/backward pass
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    else:
        bnb_config = None

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        output_hidden_states=True,
        output_attentions=True,
        attn_implementation="eager",  # <- crucial for output_attentions
    )
    if bnb_config is None:
        # Quantized models are placed on the GPU by from_pretrained itself, and
        # calling .to() on a 4-bit bitsandbytes model raises with older
        # bitsandbytes releases -- so only move unquantized models explicitly.
        model = model.to(DEVICE)
    tokenizer.pad_token_id = tokenizer.eos_token_id
    return model, tokenizer


def _build_scanners(
    model,
    tokenizer,
    *,
    hidden_layers,
    heads,
    batch_size,
    combine_causal_effects,
    shadow_run_flops,
    use_random_noise,
    use_gaussian_noise,
):
    """Construct the token-, layer- and microsaccade-wise scanners over ``model``.

    All three are always constructed; the caller decides which ones to run.

    Returns:
        ``(token_scanner, layer_scanner, microsaccade_scanner)``.
    """
    token_scanner = AttentionScanner(
        model,
        tokenizer,
        hidden_layers=hidden_layers,
        scanning_type=ScanningType.TOKEN_WISE,
        heads=heads,
        device=DEVICE,
        batch_size=batch_size,
        combine_causal_effects=combine_causal_effects,
        shadow_run=shadow_run_flops,
    )
    layer_scanner = LayerWiseScanner(
        model,
        tokenizer,
        scanning_type=ScanningType.LAYER_WISE,
        device=DEVICE,
        batch_size=batch_size,
        combine_causal_effects=combine_causal_effects,
        shadow_run=shadow_run_flops,
    )
    microsaccade_scanner = MicrosaccadesScanner(
        model,
        tokenizer,
        scanning_type=ScanningType.MICROSACCADES,
        device=DEVICE,
        batch_size=batch_size,
        combine_causal_effects=combine_causal_effects,
        shadow_run=shadow_run_flops,
        random_positional_encoding=use_random_noise,
        gaussian_positional_encoding=use_gaussian_noise,
    )
    return token_scanner, layer_scanner, microsaccade_scanner


def _scan_features(batch_inputs, extractors):
    """Run each feature extractor on one batch and concatenate the results.

    Args:
        batch_inputs: One slice of the input samples.
        extractors: Callables ``batch_inputs -> tensor`` applied in order. Each
            result is moved to the CPU before the next extractor runs.

    Returns:
        The lone tensor when there is a single extractor, otherwise all results
        concatenated along dim 1 (in extractor order).

    Raises:
        ValueError: If ``extractors`` is empty (there would be nothing to save).
    """
    if not extractors:
        raise ValueError("No scanner enabled; nothing to extract.")
    parts = [extract(batch_inputs).cpu() for extract in extractors]
    return parts[0] if len(parts) == 1 else torch.cat(parts, dim=1)


def _scan_split(
    inputs, labels, *, extractors, batch_size, data_path, labels_path, desc
):
    """Scan one split in batches of ``batch_size``, appending to JSONL files.

    For every batch, the concatenated features are appended to ``data_path`` and
    the matching labels to ``labels_path``, so the two files stay row-aligned.
    """
    num_batches = ceil(len(inputs) / batch_size)
    with tqdm(
        total=num_batches,
        desc=desc,
        dynamic_ncols=True,  # adapt to terminal width
        mininterval=0.0,  # render immediately on first update
        leave=True,
    ) as pbar:
        for batch_idx in range(num_batches):
            start = batch_idx * batch_size
            stop = start + batch_size
            batch_inputs = inputs[start:stop]
            batch_labels = labels[start:stop]

            features = _scan_features(batch_inputs, extractors)
            save_tensor_jsonl(features, data_path, append=True)
            save_tensor_jsonl(
                torch.tensor(batch_labels),
                labels_path,
                append=True,
            )

            pbar.update(1)  # ensure the bar updates every batch
            pbar.set_postfix({"Batch": batch_idx + 1, "Size": len(batch_inputs)})


def main():
    """Scan the train/test splits with the enabled scanners and save the features.

    Every key present in the YAML file at ``--config_path`` overrides the
    matching command-line option; the CLI value is only a fallback. The raw
    dataset splits are cached in ``results_dir``. Scanning is skipped when all
    four output files for this configuration already exist.
    """
    parser = _build_parser()
    args = parser.parse_args()
    # NOTE(review): `or {}` assumes read_yaml_config may return None for an
    # empty file (as yaml.safe_load does) -- confirm in code_demeanor.utils.
    config = read_yaml_config(args.config_path) or {}

    def setting(key):
        # Config wins over the CLI; every key equals the argparse dest name.
        return config.get(key, getattr(args, key))

    samples_path = setting("samples_path")
    test_samples_path = setting("test_samples_path")
    model_name = setting("model_name")
    hidden_layers = setting("hidden_layers")
    # NOTE(review): seq_length only appears in the output file names below; it
    # is not passed to any scanner here -- confirm that is intended.
    seq_length = setting("seq_length")
    heads = setting("heads")
    batch_size = setting("batch_size")
    max_samples = setting("max_samples")
    combine_causal_effects = setting("combine_causal_effects")
    use_layer_intervention = setting("use_layer_intervention")
    use_attention_intervention = setting("use_attention_intervention")
    use_microsaccade_intervention = setting("use_microsaccade_intervention")
    shadow_run_flops = setting("shadow_run_flops")
    use_gaussian_noise = setting("use_gaussian_noise")
    use_random_noise = setting("use_random_noise")
    results_dir = setting("results_dir")

    os.makedirs(results_dir, exist_ok=True)

    stem = _output_file_stem(
        model_name,
        hidden_layers=hidden_layers,
        heads=heads,
        seq_length=seq_length,
        use_layer_intervention=use_layer_intervention,
        use_attention_intervention=use_attention_intervention,
        use_microsaccade_intervention=use_microsaccade_intervention,
        shadow_run_flops=shadow_run_flops,
        use_gaussian_noise=use_gaussian_noise,
        use_random_noise=use_random_noise,
    )
    train_data_path = os.path.join(results_dir, f"train_data_{stem}.jsonl")
    train_labels_path = os.path.join(results_dir, f"train_labels_{stem}.jsonl")
    test_data_path = os.path.join(results_dir, f"test_data_{stem}.jsonl")
    test_labels_path = os.path.join(results_dir, f"test_labels_{stem}.jsonl")
    output_paths = (train_data_path, train_labels_path, test_data_path, test_labels_path)

    suffix = _dataset_cache_suffix(max_samples, test_samples_path)
    train_inputs, train_labels, test_inputs, test_labels = _load_dataset_splits(
        results_dir,
        suffix,
        samples_path=samples_path,
        test_samples_path=test_samples_path,
        max_samples=max_samples,
    )

    if all(os.path.exists(path) for path in output_paths):
        logger.info(f"All scan outputs already exist in {results_dir}; nothing to do.")
        return

    # Validate before loading the (expensive) model so misconfiguration fails fast.
    if not (
        use_microsaccade_intervention
        or use_layer_intervention
        or use_attention_intervention
    ):
        parser.error(
            "At least one of use_microsaccade_intervention, use_layer_intervention "
            "or use_attention_intervention must be enabled (CLI or config)."
        )
    if not isinstance(batch_size, int) or batch_size <= 0:
        parser.error(f"batch_size must be a positive integer, got {batch_size!r}.")

    # Outputs are written with append=True; leftovers from an interrupted run
    # would otherwise be extended with duplicate, misaligned rows.
    for path in output_paths:
        if os.path.exists(path):
            logger.warning(f"Removing incomplete output from a previous run: {path}")
            os.remove(path)

    logger.info("Starting the scanning process...")
    model, tokenizer = _load_model_and_tokenizer(model_name)
    token_scanner, layer_scanner, microsaccade_scanner = _build_scanners(
        model,
        tokenizer,
        hidden_layers=hidden_layers,
        heads=heads,
        batch_size=batch_size,
        combine_causal_effects=combine_causal_effects,
        shadow_run_flops=shadow_run_flops,
        use_random_noise=use_random_noise,
        use_gaussian_noise=use_gaussian_noise,
    )
    logger.info("lets go")

    # Feature order is fixed: microsaccade -> layer -> attention. The first two
    # contribute the second element of their scan() result (causal effects),
    # the token scanner the first element (stats).
    extractors = []
    if use_microsaccade_intervention:
        extractors.append(lambda batch: microsaccade_scanner.scan(batch)[1])
    if use_layer_intervention:
        extractors.append(lambda batch: layer_scanner.scan(batch)[1])
    if use_attention_intervention:
        extractors.append(lambda batch: token_scanner.scan(batch)[0])

    _scan_split(
        train_inputs,
        train_labels,
        extractors=extractors,
        batch_size=batch_size,
        data_path=train_data_path,
        labels_path=train_labels_path,
        desc="Processing batches",
    )
    # Process the test set
    _scan_split(
        test_inputs,
        test_labels,
        extractors=extractors,
        batch_size=batch_size,
        data_path=test_data_path,
        labels_path=test_labels_path,
        desc="Processing test batches",
    )


if __name__ == "__main__":
    main()
