import argparse
import gc
import os
import time
import json
from pathlib import Path

import torch
from scp_dr.cp_model import CPModel
from examples.config import DATA_DIR_CACHE
from examples.data import DatasetLoader, create_freq2can, get_canaries
from examples.metrics import compute_statistics, plot_and_analyze_data
from examples.utils import keep_question_answer, print_gpu_utilization, split_text
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from torch.utils.data import DataLoader
from transformers import GenerationConfig, BatchEncoding
from tqdm import tqdm
from datasets import DatasetDict
import numpy as np
import scipy


def parse_arguments():
    """
    Parse the command-line arguments that configure an evaluation run.

    Arguments are read from ``sys.argv``. ``--dataset_name`` is the only
    mandatory option; every other option is optional and falls back to the
    default given below (``None`` when no default is listed).

    Returns:
        argparse.Namespace: the parsed options, one attribute per flag.
    """
    parser = argparse.ArgumentParser(description="Evaluation script for scp_dr models")
    parser.add_argument("--attack_type", type=str, help="attack type (simple / complex / canary / accuracy)")
    parser.add_argument("--model_checkpoint1", type=str, help="Path to the first model checkpoint")
    parser.add_argument("--model_checkpoint2", type=str, help="Path to the second model checkpoint")
    parser.add_argument("--model_ref_checkpoint1", type=str, help="Path to the reference model checkpoint")
    parser.add_argument("--base_model_checkpoint", type=str, help="Path to the base model checkpoint for base smoothing")
    parser.add_argument("--base_model_const_logits", type=str, help="Path to the base model const logits for base smoothing")
    # The option is required, so a default value could never take effect.
    parser.add_argument("--dataset_name", type=str, required=True, help="Name of the dataset")
    parser.add_argument("--n_test_samples", type=int, default=500, help="Number of test samples")
    parser.add_argument("--sample_start_index", type=int, default=0, help="sample start index")
    parser.add_argument("--output_dir", type=str, default="./eval", help="Directory to save evaluation results")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for evaluation")
    parser.add_argument("--grid_size", type=int, default=10, help="Grid size for grid search in CPModel")
    parser.add_argument("--fixed_coef", type=float, default=None, help="fixed coef for the first model")
    parser.add_argument("--k", type=int, default=10, help="Base smoothing parameter k")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose output during evaluation")
    parser.add_argument("--use_relative_probs", action="store_true", help="use relative probabilities when aggregating")
    parser.add_argument("--use_minimum", action="store_true", help="use minimum as the aggregation function instead of weighted average")
    parser.add_argument("--remove_sep", action="store_true", help="remove [SEP] token from the input")
    return parser.parse_args()


def init_tokenizer(model_checkpoint):
    """
    Load the tokenizer stored with ``model_checkpoint``.

    The tokenizer is created with left-side padding and with remote code
    trusted. Its vocabulary size is printed before it is returned.
    """
    tokenizer_options = {"padding_side": "left", "trust_remote_code": True}
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, **tokenizer_options)
    # No special tokens are registered here on purpose: doing so would be wrong
    # for a checkpoint that was trained without them. The disabled call was:
    # tokenizer.add_special_tokens({"sep_token": "[SEP]", "pad_token": "[PAD]"})
    vocab_size = len(tokenizer)
    print("Tokenizer vocab size:", vocab_size)
    return tokenizer

def load_model(path, tokenizer):
    """
    Load a causal LM from ``path`` in half precision.

    If ``path`` contains an ``adapter_config.json`` the directory is treated as
    a PEFT adapter: the base model named in that config is loaded, its token
    embeddings are resized to ``len(tokenizer)``, and the adapter is attached.
    Otherwise ``path`` is loaded directly as a full model. In both cases the
    returned model has its token embeddings resized to ``len(tokenizer)``.
    """
    path = Path(path)
    adapter_config_file = path / "adapter_config.json"

    if not adapter_config_file.exists():
        # Plain checkpoint: the directory holds the full model.
        print(f"Loading model from {path}")
        model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        ).half()
    else:
        print(f"Loading PEFT model from {path}")

        with open(adapter_config_file) as config_handle:
            adapter_cfg = json.load(config_handle)

        # The adapter config records which base model it was trained on.
        backbone_name = adapter_cfg["base_model_name_or_path"]
        print("base_model:", backbone_name)
        backbone = AutoModelForCausalLM.from_pretrained(
            backbone_name,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        ).half()
        # Resize before attaching the adapter.
        backbone.resize_token_embeddings(len(tokenizer))
        model = PeftModel.from_pretrained(backbone, path)

    model.resize_token_embeddings(len(tokenizer))
    return model

def load_models(args, tokenizer):
    """
    Load the model to evaluate: a single model, or a CPModel combining two.

    When ``args.model_checkpoint2`` is set, both checkpoints are loaded and
    wrapped in a ``CPModel``. The smoothing source is chosen as follows:

    * ``args.base_model_checkpoint == "default"``: uniform smoothing is enabled;
    * any other ``args.base_model_checkpoint``: that checkpoint is loaded as
      the base model;
    * otherwise, if ``args.base_model_const_logits`` is set: a file whose name
      ends with ``const_logit_vec.pt`` is loaded from that directory.

    Without a second checkpoint, only ``args.model_checkpoint1`` is loaded.

    Returns:
        tuple: ``(model, model_name)`` where ``model_name`` is
        ``"<dataset_name>_cp_model"`` or ``"<dataset_name>_single_model"``.

    Raises:
        FileNotFoundError: if ``args.base_model_const_logits`` contains no
            file ending with ``const_logit_vec.pt``.
    """
    if not args.model_checkpoint2:
        model = load_model(args.model_checkpoint1, tokenizer)
        model_name = f"{args.dataset_name}_single_model"
        return model, model_name

    model1 = load_model(args.model_checkpoint1, tokenizer)
    model2 = load_model(args.model_checkpoint2, tokenizer)
    base_model = None
    uniform_smooth = False
    base_model_const_logits = None
    if args.base_model_checkpoint:
        if args.base_model_checkpoint == "default":
            uniform_smooth = True
        else:
            base_model = load_model(args.base_model_checkpoint, tokenizer)
    elif args.base_model_const_logits:
        logits_dir = args.base_model_const_logits
        print(f"Loading base model const logits from {logits_dir}")
        # os.listdir returns entries in arbitrary order, so search for the
        # expected file instead of trusting whichever entry comes first.
        # Sorting keeps the choice deterministic if several files match.
        dir_entries = os.listdir(logits_dir)
        candidates = sorted(name for name in dir_entries if name.endswith("const_logit_vec.pt"))
        if not candidates:
            raise FileNotFoundError(
                f"Expected const_logit_vec.pt in {logits_dir}, but found {dir_entries}"
            )
        # NOTE(review): torch.load unpickles the file; only point this at
        # trusted artifacts.
        base_model_const_logits = torch.load(os.path.join(logits_dir, candidates[0]))

    model_name = f"{args.dataset_name}_cp_model"
    cp_model = CPModel(
        model1=model1,
        model2=model2,
        base_model=base_model,
        base_model_const_logits=base_model_const_logits,
        smooth_k=args.k,
        grid_size=args.grid_size,
        verbose=args.verbose,
        fixed_coef=args.fixed_coef,
        use_relative_probs=args.use_relative_probs,
        use_minimum=args.use_minimum,
        uniform_smooth=uniform_smooth,
    )
    return cp_model, model_name

def remove_sep(dataset):
    """
    Return ``dataset`` with every "[SEP]" in the "text" field replaced by a space.

    The replacement is applied through ``dataset.map``; the mapped function
    returns only the updated "text" entry.
    """
    def _without_sep(row):
        cleaned = row["text"].replace("[SEP]", " ")
        return {"text": cleaned}

    return dataset.map(_without_sep)

def get_model_generations(model, q_and_a, tokenizer, batch_size, max_new_tokens=256,
                          force_labels=False, is_ids=False):
    """
    Run ``model.generate`` over ``q_and_a`` and collect generations or scores.

    Generation is configured with ``do_sample=False`` and ``num_beams=1``;
    EOS and PAD ids are taken from ``model.config``. Inputs are moved to
    ``"cuda"``.

    Args:
        model: object exposing ``config.eos_token_id``, ``config.pad_token_id``
            and ``generate``. With ``force_labels=True`` its ``generate`` must
            accept a ``force_labels`` keyword (presumably the CPModel-style
            wrapper used in this project; confirm against the model class).
        q_and_a: with ``is_ids=False``, a collection wrapped in a DataLoader
            using ``keep_question_answer`` as collate function; batches are
            read through their "question" (and, when forcing, "answer") keys.
            With ``is_ids=True``, an iterable of dicts whose "question" entry
            is an already tokenized encoding; it is iterated as-is, so
            ``batch_size`` does not apply.
        tokenizer: tokenizer used to encode prompts/labels and decode outputs.
        batch_size: DataLoader batch size (text mode only).
        max_new_tokens: generation length limit.
        force_labels: when True, tokenized answers (plus EOS) are passed to
            ``generate`` and the per-batch ``outputs.scores`` are collected.
        is_ids: when True, inputs and outputs are token-id tensors.

    Returns:
        list: with ``force_labels=True``, one ``outputs.scores.cpu()`` entry
        per batch. Otherwise one dict per example: ``{"question", "answer",
        "full"}`` holding id tensors when ``is_ids`` is True, or
        ``{"question", "answer"}`` holding strings when it is False.
    """
    if is_ids:
        dataloader = q_and_a
    else:
        dataloader = DataLoader(
            q_and_a,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=True,
            collate_fn=keep_question_answer,
        )
    # NOTE(review): temperature is kept from the original configuration; with
    # do_sample=False it is presumably not used by greedy decoding — confirm.
    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        eos_token_id=model.config.eos_token_id,
        pad_token_id=model.config.pad_token_id,
        do_sample=False,
        num_beams=1,
        temperature=0.7,
    )
    res = []
    log_probs = []
    with torch.no_grad():
        for batch in dataloader:

            if is_ids:
                batch_texts = batch_tensors = batch["question"].to("cuda")
            else:
                batch_texts = batch["question"]
                batch_tensors = tokenizer(batch_texts, return_tensors="pt", padding=True).to("cuda")
            if force_labels:
                labels_texts = batch["answer"]
                labels = tokenizer(labels_texts, return_tensors="pt", padding=True, add_special_tokens=False)["input_ids"].to("cuda")
                # Size the EOS column from the labels themselves: the last
                # batch of a DataLoader can hold fewer than batch_size rows,
                # and torch.cat requires matching row counts.
                eos_column = torch.full((labels.shape[0], 1), tokenizer.eos_token_id, device=labels.device, dtype=labels.dtype)
                labels = torch.cat([labels, eos_column], dim=-1)
                # One tensor per label position, each holding that position's
                # token id for every row of the batch.
                labels = list(torch.unbind(labels, dim=-1))
                outputs = model.generate(
                    **batch_tensors,
                    generation_config=generation_config,
                    return_dict_in_generate=True,
                    output_logits=True,
                    force_labels=labels
                )
            else:
                outputs = model.generate(
                    **batch_tensors,
                    generation_config=generation_config,
                    return_dict_in_generate=True,
                    output_logits=True,
                )

            if force_labels:
                log_probs.append(outputs.scores.cpu())
            else:
                input_ids = batch_tensors["input_ids"]
                gen_ids = outputs.sequences
                # The returned sequences must start with the prompt ids.
                gen_ids_start = gen_ids[:, :input_ids.shape[1]]
                assert torch.all(gen_ids_start == input_ids)
                if is_ids:
                    for i in range(len(input_ids)):
                        answer_ids = gen_ids[i, input_ids.shape[1]:]
                        res.append({
                            "question": input_ids[i],
                            "answer": answer_ids,
                            "full": gen_ids[i]
                        })
                else:
                    generations = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
                    for gen, input_text in zip(generations, batch_texts):
                        # The decoded text is compared against the prompt with
                        # "[SEP]" replaced by a space.
                        input_text_no_sep = input_text.replace("[SEP]", " ")
                        assert(gen.startswith(input_text_no_sep))
                        res.append({"question":input_text, "answer":gen[len(input_text_no_sep):]})

            # Drop references and free cached GPU memory between batches.
            del outputs, batch_texts, batch_tensors
            torch.cuda.empty_cache()
            gc.collect()

    if force_labels:
        return log_probs
    else:
        return res

def split_batchencoding_at_index(batch: BatchEncoding, k: int) -> tuple[BatchEncoding, BatchEncoding]:
    """
    Cut every tensor field of ``batch`` in two along its second dimension.

    Tensor values with at least two dimensions are sliced into ``[:, :k]``
    (first result) and ``[:, k:]`` (second result). Any other value is placed
    unchanged in both results.
    """
    def _take(value, columns):
        # Only tensors with a second dimension can be sliced along it.
        if torch.is_tensor(value) and value.dim() >= 2:
            return value[:, columns]
        return value

    head = {field: _take(value, slice(None, k)) for field, value in batch.items()}
    tail = {field: _take(value, slice(k, None)) for field, value in batch.items()}
    return BatchEncoding(head), BatchEncoding(tail)

def split_batch_tensors(batch_tensors, index_after_sep=0, split_index=None, tokenizer=None):
    """
    Split a single-example encoding into a "question" and an "answer" part.

    When ``split_index`` is not given it is derived from the position of the
    "[SEP]" token id: the cut is placed ``index_after_sep`` tokens after the
    separator, so the separator itself always stays in the question. The
    encoding must hold exactly one example and, when the separator is
    searched for, exactly one separator token.

    Returns:
        dict: ``{"question": <left part>, "answer": <right part>}``.
    """
    assert len(batch_tensors["input_ids"]) == 1
    ids_row = batch_tensors["input_ids"][0]

    if split_index is None:
        # Locate the single '[SEP]' token and cut relative to it.
        sep_token_id = tokenizer.convert_tokens_to_ids("[SEP]")
        sep_positions = (ids_row == sep_token_id).nonzero(as_tuple=True)[0]
        assert(len(sep_positions) == 1)
        sep_position = sep_positions[0].item()
        split_index = sep_position + 1 + index_after_sep

    question_part, answer_part = split_batchencoding_at_index(batch_tensors, split_index)
    return {"question": question_part, "answer": answer_part}

def split_text(example, index_after_sep=0, split_index=None, tokenizer=None):
    """
    Split ``example["text"]`` into a question prefix and an answer suffix.

    When ``split_index`` is not given, the text must contain exactly one
    "[SEP]" and the cut is placed after it:

    * without a tokenizer, ``index_after_sep`` characters after the separator;
    * with a tokenizer, ``index_after_sep`` tokens after the separator, the
      character offset being the summed length of those token strings.

    Args:
        example: mapping with a "text" entry.
        index_after_sep: characters (no tokenizer) or tokens (tokenizer) to
            keep in the question after the separator.
        split_index: explicit character index of the cut; skips the search.
        tokenizer: optional object with a ``tokenize(text)`` method.

    Returns:
        Without a tokenizer: ``{"question": ..., "answer": ...}`` where the
        answer has surrounding whitespace stripped. With a tokenizer: a tuple
        ``(result, token_split_index)``; ``token_split_index`` is ``None``
        when an explicit ``split_index`` was supplied, because no token
        position is computed in that case.
    """
    text = example["text"]
    # Defined up front so the tokenizer return path is valid even when the
    # caller supplies split_index and the separator search is skipped.
    token_split_index = None
    if split_index is None:
        # Splitting the text on '[SEP]'
        sep = "[SEP]"
        assert(text.count(sep) == 1)
        split_index = text.index(sep) + len(sep)
        if tokenizer is None:
            split_index += index_after_sep
        else:
            tokenized_text = tokenizer.tokenize(text)
            sep_ind = tokenized_text.index(sep)
            first_token = sep_ind + 1
            last_token = min(first_token + index_after_sep, len(tokenized_text))
            # NOTE(review): assumes each token string has the same length as
            # the text it covers (no prefix markers) — confirm for the
            # tokenizer in use.
            split_index += sum(len(tokenized_text[i]) for i in range(first_token, last_token))
            token_split_index = first_token + index_after_sep

    res = {
        "question": text[:split_index],
        "answer": text[split_index:].strip(),
    }

    if tokenizer is None:
        return res
    return res, token_split_index

def accuracy_attack(model, data, tokenizer, batch_size):
    """
    Report the mean of ``model.accuracy`` after a forced-label pass over ``data``.

    ``model.accuracy`` is reset to an empty list, every example is split into
    question/answer, and generation is run with the answers forced. Nothing in
    this function fills ``model.accuracy``; it is presumably populated by the
    model during generation (confirm against the model class).

    Returns:
        dict: ``{"accuracy": <mean of model.accuracy>}``; also printed.
    """
    pairs = []
    for example in data:
        pairs.append(split_text(example))

    model.accuracy = []
    # The returned scores are not needed here, only the side effect on the model.
    get_model_generations(
        model,
        pairs,
        tokenizer,
        batch_size,
        max_new_tokens=256,
        force_labels=True,
    )

    mean_accuracy = np.mean(model.accuracy)
    res = {"accuracy": mean_accuracy}
    print(res)
    return res

def _collect_canary_log_probs(model, canaries, tokenizer, batch_size):
    """Score every canary with forced labels and return all scores as one flat float64 array."""
    # split_index=0 puts the whole canary text in the answer, with an empty question.
    q_and_a = [split_text(canary, split_index=0) for canary in canaries]
    per_batch = get_model_generations(model, q_and_a, tokenizer, batch_size, max_new_tokens=256, force_labels=True)
    # Flatten batch by batch: the final batch may hold fewer rows than the
    # others, and a list of unequal tensors cannot be reduced as one array.
    # NOTE(review): assumes each per-batch entry converts to a NumPy array of
    # per-canary scores — confirm against the model's generate output.
    return np.concatenate([np.ravel(np.asarray(lp)).astype(np.float64) for lp in per_batch])


def canary_attack(model, data_dict, tokenizer, batch_size):
    """
    Estimate canary exposure for each canary set in ``data_dict``.

    A reference distribution is built from 100 freshly generated canaries
    (fixed seed): their scores give a mean and standard deviation. Each
    canary score is standardized against that reference and converted to
    ``-log2(Phi(-z))``, with ``Phi`` the standard normal CDF.

    Args:
        model: model passed to ``get_model_generations`` with forced labels.
        data_dict: mapping from a key to a collection of canaries.
        tokenizer: tokenizer forwarded to ``get_model_generations``.
        batch_size: batch size forwarded to ``get_model_generations``.

    Returns:
        dict: for each key, ``{"mean", "median", "95%", "max"}`` of the
        exposures, as plain floats.
    """
    # Imported here so the submodule is loaded explicitly; a bare
    # `import scipy` is not guaranteed to expose `scipy.stats`.
    from scipy.stats import norm

    random_canaries = get_canaries(100, can_len=3, seed=27412586)
    log_probs_ref = _collect_canary_log_probs(model, random_canaries, tokenizer, batch_size)
    ref_std = np.std(log_probs_ref)
    ref_mean = np.mean(log_probs_ref)
    print("ref mean and std:", ref_mean, ref_std)

    res = {}
    for key, data in data_dict.items():
        log_probs = _collect_canary_log_probs(model, data, tokenizer, batch_size)
        error = (log_probs - ref_mean) / ref_std
        # logcdf avoids the underflow of cdf(-error) to 0 for large errors,
        # which would turn the exposure into inf.
        exposures = -norm.logcdf(-error) / np.log(2)
        print(f"canary exposures for key {key}:")
        res[key] = {"mean" : float(np.mean(exposures)),
                    "median" : float(np.median(exposures)),
                    "95%" : float(np.percentile(exposures, 95)),
                    "max" : float(np.max(exposures))}
        print(res[key])
    return res

def complex_attack(model, data, tokenizer, batch_size, max_length=1000):
    """
    Measure how many tokens after "[SEP]" the model reproduces verbatim.

    For token offset ``i = 0, 1, ...`` each still-successful example is cut
    ``i`` tokens after the separator, the model generates one token, and that
    token is compared with the true next token. An example that misses once
    is dropped from all later offsets. The loop stops when no example is left
    to query or after ``max_length`` offsets.

    Args:
        model: model passed to ``get_model_generations``.
        data: sized collection of examples with a "text" entry.
        tokenizer: tokenizer used to encode the examples.
        batch_size: forwarded to ``get_model_generations``.
        max_length: maximum number of token offsets to try.

    Returns:
        dict with:
            "success_rates": per offset, the fraction of all examples still
                correct at that offset;
            "first_token_success_rate": the rate at offset 0 (0.0 if no
                offset was evaluated);
            "max_extract_length": largest offset count with a non-zero rate
                (0 if nothing was extracted);
            "avg_extract_length": sum over offsets of (offset + 1) times the
                drop in success rate after that offset;
            "single_token_success_rates": each rate divided by the previous
                one (0.0 when the previous rate is 0).
    """
    total = len(data)
    # Tokenize every example once; the encodings do not depend on the offset.
    encoded = [
        tokenizer([example['text']], return_tensors="pt", padding=True).to("cuda")
        for example in data
    ]

    success_rates = []
    fail_inds = set()
    for i in range(max_length):
        q_and_a = []
        for e_index, batch_tensors in enumerate(encoded):
            if e_index in fail_inds:
                continue
            example_splited = split_batch_tensors(batch_tensors, index_after_sep=i, tokenizer=tokenizer)
            answer_ids = example_splited['answer']["input_ids"]
            # Nothing left to predict for this example at this offset.
            if len(answer_ids) == 0 or len(answer_ids[0]) == 0:
                continue
            if len(q_and_a) == 0:
                print("next token id:", answer_ids[0][0].item())
            q_and_a.append((e_index, example_splited))

        if len(q_and_a) == 0:
            break

        q_and_a_for_model = [qa[-1] for qa in q_and_a]
        model_q_and_a = get_model_generations(model, q_and_a_for_model, tokenizer, batch_size, max_new_tokens=1, is_ids=True)

        success_count = 0
        for j, (e_index, example_splited) in enumerate(q_and_a):
            if len(model_q_and_a[j]['answer']) == 0:
                fail_inds.add(e_index)
                continue
            ground_truth = example_splited['answer']["input_ids"][0][0].item()
            extracted_token = model_q_and_a[j]['answer'][0].item()
            if ground_truth == extracted_token:
                success_count += 1
            else:
                fail_inds.add(e_index)
        success_rate = success_count / total
        success_rates.append(success_rate)
        print(f"token index {i}. success rate: {success_rate}")

    # default=0 covers the case where no offset had any success.
    max_extract_length = max(
        (i + 1 for i, rate in enumerate(success_rates) if rate > 0),
        default=0,
    )
    avg_extract_length = 0
    prev_rate = 0
    for i in range(len(success_rates) - 1, -1, -1):
        rate = success_rates[i]
        avg_extract_length += (i + 1) * (rate - prev_rate)
        prev_rate = rate

    single_token_success_rates = []
    for i, rate in enumerate(success_rates):
        if i == 0:
            single_token_success_rates.append(rate)
        else:
            previous = success_rates[i - 1]
            # A zero previous rate leaves no example to condition on.
            single_token_success_rates.append(rate / previous if previous > 0 else 0.0)

    first_token_success_rate = success_rates[0] if success_rates else 0.0
    print("max extract length", max_extract_length)
    print("average extract length:", avg_extract_length)
    return {"max_extract_length" : max_extract_length, "avg_extract_length" : avg_extract_length,
            "first_token_success_rate" : first_token_success_rate, "success_rates" : success_rates,
            "single_token_success_rates" : single_token_success_rates}

def simple_attack(model, data, tokenizer, batch_size, remove_sep=False,
                  id_length=10, initial_skip=4):
    """
    Measure how many characters of the ID after "[SEP]" the model reproduces.

    For position ``i = 0 .. id_length - 1`` each still-successful example is
    cut ``initial_skip + i`` characters after the separator, the model
    generates one token, and the first character of the generated text is
    compared with the first character of the remaining (stripped) answer. An
    example that misses once is dropped from all later positions.

    Args:
        model: model passed to ``get_model_generations``.
        data: sized collection of examples with a "text" entry.
        tokenizer: tokenizer forwarded to ``get_model_generations``.
        batch_size: forwarded to ``get_model_generations``.
        remove_sep: when True, "[SEP]" in the prompt is replaced by a space.
        id_length: number of ID characters to attack (default 10).
        initial_skip: characters skipped after the separator before the ID
            starts (default 4, i.e. "ID: ").

    Returns:
        dict with "full_success_rate" (rate at the last position),
        "avg_extract_length" and "first_digit_success_rate". All rates are
        fractions of ``len(data)``; they are 0.0 when ``data`` is empty.
    """
    total = len(data)
    success_rates = []
    fail_inds = set()
    for i in range(id_length):
        q_and_a = []
        for e_index, example in enumerate(data):
            if e_index in fail_inds:
                continue
            example = split_text(example, index_after_sep=initial_skip + i)
            if remove_sep:
                example['question'] = example['question'].replace("[SEP]", " ")
            if len(q_and_a) == 0:
                print(example)
            q_and_a.append((e_index, example))
        success_count = 0

        if len(q_and_a) > 0:
            q_and_a_for_model = [qa[1] for qa in q_and_a]
            model_q_and_a = get_model_generations(model, q_and_a_for_model, tokenizer, batch_size, max_new_tokens=1)
            for j, (e_index, example) in enumerate(q_and_a):
                # Slicing yields "" instead of raising when a string is empty.
                ground_truth = example['answer'][:1]  # curr digit
                extracted_digit = model_q_and_a[j]['answer'][:1]
                # An exhausted answer has nothing left to extract: count it
                # as a miss rather than matching two empty strings.
                if ground_truth and ground_truth == extracted_digit:
                    success_count += 1
                else:
                    fail_inds.add(e_index)
        success_rate = success_count / total if total else 0.0
        success_rates.append(success_rate)
        print(f"digit index {i}. success rate: {success_rate}")

    full_success_rate = success_rates[-1] if success_rates else 0.0
    first_digit_success_rate = success_rates[0] if success_rates else 0.0
    avg_extract_length = 0
    prev_rate = 0
    for i in range(len(success_rates) - 1, -1, -1):
        rate = success_rates[i]
        avg_extract_length += (i + 1) * (rate - prev_rate)
        prev_rate = rate
    print("full id success rate:", full_success_rate)
    print("average extract length:", avg_extract_length)
    return {"full_success_rate" : full_success_rate, "avg_extract_length" : avg_extract_length,
            "first_digit_success_rate" : first_digit_success_rate}



def evaluate_attack(attack_type, train_dataset, validation, model, tokenizer, 
                    eval_dir, batch_size, model_checkpoint1, remove_sep=False):
    """
    Run one attack on the train and validation splits and save the results.

    Each split's result is written to ``<eval_dir>/<split>.json`` (indented,
    keys sorted). Supported attack types are "simple", "complex", "canary"
    and "accuracy"; any other value prints a message and returns at once.

    For the canary attack the split data is not used: the train split reads
    canaries from ``<model_checkpoint1>/canaries_datasets`` and the
    validation split builds them with ``create_freq2can``.
    """
    supported_attacks = ("simple", "complex", "canary", "accuracy")
    splits = (("train", train_dataset), ("validation", validation))
    shared_kwargs = dict(model=model, tokenizer=tokenizer, batch_size=batch_size)

    for name, data in splits:
        file_name = os.path.join(eval_dir, f"{name}.json")
        print(f"attacking {name} set with the {attack_type} attack...")
        start_time = time.time()

        if attack_type not in supported_attacks:
            print(f"attack type {attack_type} not supported. exiting")
            return

        if attack_type == "simple":
            attack_res = simple_attack(data=data, remove_sep=remove_sep, **shared_kwargs)
        elif attack_type == "complex":
            attack_res = complex_attack(data=data, **shared_kwargs)
        elif attack_type == "canary":
            if name != "train":
                canaries, _ = create_freq2can(100, can_len=3, can_freqs=[0], seed=239832673)
            else:
                canaries_path = os.path.join(model_checkpoint1, "canaries_datasets")
                canaries = DatasetDict.load_from_disk(canaries_path)
            attack_res = canary_attack(data_dict=canaries, **shared_kwargs)
        else:
            attack_res = accuracy_attack(data=data, **shared_kwargs)

        with open(file_name, "w") as out_file:
            json.dump(attack_res, out_file, indent=4, sort_keys=True)
        elapsed = time.time() - start_time
        print(f"{attack_type} attack of {name} completed in {elapsed:.2f} seconds.")

        # Release the result and cached GPU memory before the next split.
        del attack_res
        torch.cuda.empty_cache()
        gc.collect()

    print("Evaluation completed. Results saved in:", eval_dir)


def main():
    """
    Entry point: parse options, load tokenizer, data and model, run the attack.

    Results go to a sub-directory of ``--output_dir`` whose name encodes the
    attack type, dataset name, probability type, aggregation type and
    coefficient mode.
    """
    args = parse_arguments()
    print("Evaluation configuration:", args)

    # Build the name of the results directory from the run settings.
    if args.use_minimum:
        avg_type = "minimum"
    else:
        avg_type = "geometric"

    if args.use_relative_probs:
        prob_type = "relative"
    else:
        prob_type = "raw"

    if args.fixed_coef is None:
        coef_type = "adaptive"
    else:
        # The decimal point is dropped from the coefficient.
        coef_type = "fixed_" + str(args.fixed_coef).replace(".", "")

    run_name = f"{args.attack_type}_attack_res_{args.dataset_name}_{prob_type}_{avg_type}_{coef_type}"
    eval_dir = os.path.join(args.output_dir, run_name)
    os.makedirs(eval_dir, exist_ok=True)

    # The tokenizer comes from the first checkpoint.
    tokenizer = init_tokenizer(args.model_checkpoint1)

    # Only the train and validation splits are used; the first value is ignored.
    dataset_loader = DatasetLoader()
    _, train_dataset, validation_dataset = dataset_loader.load_or_create_datasets(
        dataset_name=args.dataset_name,
        ntrain=args.n_test_samples,
        start_index=args.sample_start_index,
    )

    print("first training example:", train_dataset[0])
    print("first validation example:", validation_dataset[0])

    # The model name returned alongside the model is not needed here.
    model, _ = load_models(args, tokenizer)

    evaluate_attack(
        args.attack_type,
        train_dataset=train_dataset,
        validation=validation_dataset,
        model=model,
        tokenizer=tokenizer,
        eval_dir=eval_dir,
        batch_size=args.batch_size,
        model_checkpoint1=args.model_checkpoint1,
        remove_sep=args.remove_sep,
    )


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
