import os
import argparse
import transformers
import torch
from contextlib import ExitStack

from transformers.models.llama import modeling_llama
from transformers.models.qwen2 import modeling_qwen2
from transformers.models.gemma3 import modeling_gemma3
from lxt.efficient import monkey_patch

from .dataset_processors import get_dataset_processor
from .utils import build_results_filepath, write_config, get_experiment_dir, partition_dataset, set_seeds
from .patches import (
    IxG_map,
    get_llama_eager,
    get_llama_flash,
    get_qwen_eager,
    get_qwen_flash,
    get_gemma_eager,
    get_gemma_flash,
)
from .cad_baseline import CAD


def parse_arguments(argv=None):
    """Build the CLI parser for the generation script and parse the arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``, exactly like a plain script call.

    Returns:
        argparse.Namespace holding every option below.
    """
    parser = argparse.ArgumentParser(description="Generate greedy and relevance responses.")
    parser.add_argument(
        "--generation_methods",
        type=str,
        default=["relevance"],
        nargs="+",
        choices=["relevance", "greedy", "random", "context_aware"],
        help="Methods to run and save. At least one method must be specified.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="meta-llama/Llama-3.2-1B-Instruct",
        help="Name of the model.",
    )
    parser.add_argument(
        "--entropy",
        type=float,
        default=0.0,
        help="Entropy threshold to determine if attribution is used at given generation step.",
    )
    parser.add_argument(
        "--explanation_method",
        choices=["attnLRP", "IxG", "attnLRP_gradcam"],
        type=str,
        default="attnLRP",
        help="Explanation method to use.",
    )
    parser.add_argument(
        "--attn_implementation",
        choices=["flash", "eager"],
        type=str,
        default="flash",
        help="Attention backend. 'eager' is honored only together with --task_type heads.",
    )
    ##### if using attention heads instead of context
    parser.add_argument("--task_type", type=str, choices=["context", "heads", "heads_context"], default="context")
    parser.add_argument(
        "--heads_type", type=str, choices=["parametric", "in_context", "retrieval", "task"], default="parametric"
    )
    parser.add_argument(
        "--heads_path",
        type=str,
        default="multi_token_instruction_following/heads",
        help="Path to the heads file.",
    )
    parser.add_argument("--top_heads_num", type=int, default=20, help="Number of top heads to use. The maximum is 100.")
    parser.add_argument("--clamp", action="store_true", help="Clamp relevance scores to positive only.")
    #####

    parser.add_argument(
        "--alpha",
        type=float,
        # Float default so the recorded config matches values parsed from the CLI.
        default=0.0,
        help="Weighting factor between model logits and relevance scores when selecting the next token. "
        "Formula: alpha * logits + (1 - alpha) * relevance. "
        "Set to 0 for fully relevance-guided selection. Range: [0, 1); "
        "alpha = 1 would be pure greedy decoding, so use --generation_methods greedy instead.",
    )

    parser.add_argument("--p", type=float, default=0.05, help="p in min_p strategy and 1-p in top_p strategy.")
    parser.add_argument("--rel_diff_threshold", type=float, default=0.0)
    parser.add_argument(
        "--agg_method",
        type=str,
        choices=["sum", "max"],
        default="sum",
        help="Aggregation method for relevance to compare.",
    )
    parser.add_argument(
        "--relevance_role",
        type=str,
        choices=["system", "user", "full"],
        default="system",
        help="Role to focus for relevance.",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="input_constraint_ifeval_augmented_fixed.jsonl",
        help="Path to the dataset file.",
    )
    parser.add_argument(
        "--filtering_strategy",
        type=str,
        choices=["min_p", "top_p", "min_p_scaled"],
        default="min_p",
        help="Strategy for top token selection to evaluate.",
    )
    parser.add_argument("--n_logits", type=int, default=5, help="Number of top logits to evaluate.")
    parser.add_argument("--response_len", type=int, default=1024, help="Response length.")
    parser.add_argument(
        "--cut_off_len",
        type=int,
        help="Cut off length in tokens after which tokens are greedily selected for efficiency.",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        help="The parameter for repetition penalty. 1.0 means no penalty.",
    )
    parser.add_argument(
        "--repetition_penalty_ignore_prompt",
        action="store_true",
        help="Exclude the prompt tokens from the repetition penalty calculation.",
    )
    parser.add_argument(
        "--date_block",
        type=str,
        choices=["remove", "keep", "keep_no_relevance"],
        default="remove",
        help="How to handle date block inserted in the system prompt. Applicable only for Llama-3.",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility.")
    parser.add_argument(
        "--results_dir", type=str, default="results/lxt_token_selection", help="Directory to save results."
    )
    parser.add_argument(
        "--overwrite", action="store_true", help="Overwrite the results directory if it already exists."
    )
    parser.add_argument(
        "--partition",
        type=int,
        help="Partition index for distributed processing. Expected values are integers starting from 0.",
    )
    parser.add_argument("--num_partitions", type=int, help="Total number of partitions to divide the dataset into.")
    return parser.parse_args(argv)


def _validate_args(args):
    """Reject CLI values that ``main`` cannot run with, before any model work starts.

    Raises:
        ValueError: ``alpha`` outside [0, 1), ``p`` outside (0, 1), or no generation method.
        NotImplementedError: "context_aware" requested together with another method.
    """
    # Explicit raises instead of `assert`: assertions are stripped under `python -O`.
    if not 0 <= args.alpha < 1:
        raise ValueError(
            f"Alpha must be in the range [0, 1), but is {args.alpha}. "
            "For greedy decoding (value 1), use --generation_methods greedy."
        )
    if not 0 < args.p < 1:
        raise ValueError(f"p must be in the range (0, 1), but is {args.p}.")
    if not args.generation_methods:
        raise ValueError("At least one method must be specified in --generation_methods.")
    methods = args.generation_methods
    if "context_aware" in methods and any(method != "context_aware" for method in methods):
        # TODO: support loading the CAD baseline alongside the LXT-patched model.
        raise NotImplementedError(
            "Combining 'context_aware' with other generation methods is not implemented yet; "
            "run 'context_aware' on its own."
        )


def _resolve_partition(partition, num_partitions):
    """Return the partition index from the CLI, else from ``SLURM_ARRAY_TASK_ID``, else None.

    Raises:
        ValueError: a partition is known but ``num_partitions`` was not given.
    """
    if partition is None:
        env_partition = os.environ.get("SLURM_ARRAY_TASK_ID")
        if env_partition is not None:
            partition = int(env_partition)
    if partition is not None and num_partitions is None:
        raise ValueError("Need to specify num_partitions when specifying partition")
    return partition


def _resolve_model_family(model_name):
    """Map a checkpoint name to (family tag, transformers modeling module, CausalLM class name).

    Raises:
        NotImplementedError: the name matches none of the supported families.
    """
    if "Llama-3" in model_name:
        return "llama3", modeling_llama, "LlamaForCausalLM"
    if "Qwen" in model_name:
        return "qwen2", modeling_qwen2, "Qwen2ForCausalLM"
    if "gemma-3" in model_name:
        return "gemma3", modeling_gemma3, "Gemma3ForCausalLM"
    raise NotImplementedError(f"Model {model_name} is not implemented.")


def _select_explanation_patch(explanation_method, task_type, attn_implementation, model_family):
    """Pick the patch map that ``monkey_patch`` applies to the modeling module.

    Returns:
        ``IxG_map`` for IxG; ``None`` for attnLRP outside ``--task_type heads``; otherwise the
        head-level attnLRP patch for the model family and attention backend.

    Raises:
        ValueError: unsupported explanation method, or no head-level patch for the family.
    """
    if explanation_method == "IxG":
        return IxG_map
    if explanation_method != "attnLRP":
        raise ValueError(f"Unsupported explanation method: {explanation_method}")
    if task_type != "heads":
        # NOTE(review): None is forwarded to lxt's monkey_patch unchanged; presumably that
        # selects lxt's default attnLRP patch map — confirm.
        return None
    use_eager = attn_implementation == "eager"
    # Store the getters (not their results) so only the requested patch is built.
    patch_getters = {
        "llama3": get_llama_eager if use_eager else get_llama_flash,
        "qwen2": get_qwen_eager if use_eager else get_qwen_flash,
        "gemma3": get_gemma_eager if use_eager else get_gemma_flash,
    }
    getter = patch_getters.get(model_family)
    patch = getter() if getter is not None else None
    if patch is None:
        raise ValueError(f"attnLRP not yet supported for model family: {model_family}")
    return patch


def main():
    """Generate responses for one dataset partition and write one results file per method.

    Flow: parse and validate CLI args, resolve the partition (CLI or SLURM array id), pick the
    model family and explanation patch, monkey-patch the modeling module, and load the model
    (or the CAD baseline when only "context_aware" is requested). Then, for every example,
    generate a response per prepared chat and write results only when every chat succeeded.

    Raises:
        ValueError: invalid CLI values, partition without num_partitions, unsupported
            explanation method / family combination.
        NotImplementedError: unsupported model name, or "context_aware" mixed with other methods.
    """
    args = parse_arguments()
    set_seeds(args.seed)
    _validate_args(args)

    partition = _resolve_partition(args.partition, args.num_partitions)
    dataset_path = args.dataset_path

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model_family, modeling_module, model_cls_name = _resolve_model_family(args.model_name)

    # Rewritten in place so the resolved path is what write_config records.
    args.heads_path = os.path.join(args.heads_path, model_family, f"{args.heads_type}.txt")
    print(f"Using heads from: {args.heads_path}")

    explanation_method_patch = _select_explanation_patch(
        args.explanation_method, args.task_type, args.attn_implementation, model_family
    )
    monkey_patch(
        modeling_module,
        explanation_method_patch,
        verbose=True,
    )
    # Class lookup happens after patching, matching the original ordering.
    model_cls = getattr(modeling_module, model_cls_name)

    if "context_aware" in args.generation_methods:
        # _validate_args guarantees "context_aware" is the only requested method here.
        cad = CAD(model_name=args.model_name, device=device)
        model = cad.model
        tokenizer = cad.tokenizer
    else:
        model = model_cls.from_pretrained(
            args.model_name,
            torch_dtype=torch.bfloat16,
            device_map=device,
            # Eager attention is used only for head-level attribution with --attn_implementation eager.
            attn_implementation=(
                "eager" if args.task_type == "heads" and args.attn_implementation == "eager" else "flash_attention_2"
            ),
        )
        tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name)
        cad = None  # processors receive None when the CAD baseline is not used

    # Optional gradient checkpointing to save memory (2x forward pass). train() is called
    # presumably because HF only applies checkpointing in training mode — confirm.
    model.train()
    model.gradient_checkpointing_enable()

    # Freeze all weights; gradients are presumably still needed w.r.t. inputs for attribution.
    for param in model.parameters():
        param.requires_grad = False

    processor = get_dataset_processor(dataset_path, args, tokenizer, model, model_family, device)
    dataset = processor.load_dataset(dataset_path, args.results_dir)
    dataset = partition_dataset(dataset, partition, args.num_partitions)

    args_dict = vars(args)
    print(f"Args: {args_dict}")
    experiment_dir = get_experiment_dir(args)
    write_config(args_dict, experiment_dir)

    failed_instructions = 0
    with ExitStack() as stack:
        # Each file is registered with the stack as soon as it is opened, so a failure while
        # opening a later file still closes the earlier ones.
        open_files = {
            method: stack.enter_context(
                open(
                    build_results_filepath(method, experiment_dir, partition, args.overwrite),
                    "w",
                    encoding="utf-8",
                )
            )
            for method in args.generation_methods
        }
        for example_idx, example in enumerate(dataset):
            print(f"Processing example {example_idx} in partition {partition}")

            chats_info = processor.prepare_chats(example)
            if chats_info is None:
                continue
            chats, extra_info = chats_info
            responses = []
            for chat in chats:
                response = processor.process_chat_request(chat, extra_info, cad)
                if response is None:
                    # One failed chat drops the whole example; it is counted once.
                    failed_instructions += 1
                    break
                responses.append(response)
            else:
                # for/else: runs only when no chat failed (the loop did not break).
                processor.write_results(open_files, example, responses)

    print(f"Failed instructions: {failed_instructions}")


# Script entry point: run main() only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
