import json
import shutil
from itertools import islice
import time
from typing import Tuple, Union

import torch
# Block at import time until the CUDA runtime initializes, printing a message
# and sleeping 1 second after each failed attempt. Presumably this waits out
# transient failures (e.g. a GPU still held by another job) — confirm.
# NOTE(review): there is no retry limit, so on a machine where CUDA can never
# initialize (no GPU/driver) this loops forever. `torch._C._cuda_init` is an
# underscore-prefixed (non-public) PyTorch API.
while True:
    try:
        torch._C._cuda_init()
        break
    except Exception:
        # Any failure (not KeyboardInterrupt, which is not an Exception) retries.
        print("CUDA init failed, retrying...")
        time.sleep(1)
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama import LlamaTokenizerFast

from baselines.ft import FTHyperParams, apply_ft_to_model
from baselines.memit import MEMITHyperParams, apply_memit_to_model
from baselines.rome import ROMEHyperParams, apply_rome_to_model
from dsets import (
    AttributeSnippets,
    MENDQADataset,
    UNKEDataset,
    AKEWDataset,
    MquakeDataset,
    get_tfidf_vectorizer,
)
from util import nethook
from util.globals import *

from PTE import PTEHyperParams, apply_pte_to_model
from experiments.py.eval_utils_unke import compute_unke
from experiments.py.eval_utils_akew import compute_akew
from experiments.py.eval_utils_mquake import compute_mquake

from baselines.memit_ARE import MEMITAREHyperParams, apply_memit_ARE_to_model
from baselines.lora import LORAHyperParams, apply_lora_to_model
from baselines.alphaEdit import AlphaEditHyperParams, apply_AlphaEdit_to_model
from baselines.unke import unkeHyperParams, apply_unke_to_model

# Editing algorithms selectable via --alg_name.
# Each entry maps the CLI name to (hyperparameter class, function applying the edit).
ALG_DICT = {
    "MEMIT": (MEMITHyperParams, apply_memit_to_model),
    "ROME": (ROMEHyperParams, apply_rome_to_model),
    "FT": (FTHyperParams, apply_ft_to_model),
    "PTE": (PTEHyperParams, apply_pte_to_model),
    "MEMIT_ARE": (MEMITAREHyperParams, apply_memit_ARE_to_model),
    "LORA": (LORAHyperParams, apply_lora_to_model),
    "AlphaEdit": (AlphaEditHyperParams, apply_AlphaEdit_to_model),
    "UNKE": (unkeHyperParams, apply_unke_to_model),
}

# Evaluation datasets selectable via --ds_name.
# Each entry maps the CLI name to (dataset class, per-record evaluation function).
DS_DICT = {
    "mquake": (MquakeDataset, compute_mquake),
    "unke": (UNKEDataset, compute_unke),
    "akew": (AKEWDataset, compute_akew),
}


def _resolve_run_dir(
    dir_name: str, continue_from_run: Union[str, None]
) -> Tuple:
    """Pick the directory this run writes its results into.

    Returns ``(run_dir, continue_from_run)``. If ``continue_from_run`` names an
    existing directory under ``RESULTS_DIR / dir_name`` it is reused as-is.
    Otherwise a fresh ``run_NNN`` directory is created, numbered one past the
    largest numeric ``*_<id>`` suffix already present, and the second element
    is ``None`` (i.e. this is a new run, not a resumed one).
    """
    alg_dir = RESULTS_DIR / dir_name
    if continue_from_run is not None:
        run_dir = alg_dir / continue_from_run
        if run_dir.exists():
            return run_dir, continue_from_run

    run_id = 0
    if alg_dir.exists():
        # Parse the entry's own name (not the whole path), and use isdecimal()
        # so every accepted suffix is one int() can actually parse.
        id_list = [
            int(suffix)
            for suffix in (x.name.rpartition("_")[2] for x in alg_dir.iterdir())
            if suffix.isdecimal()
        ]
        if id_list:
            run_id = max(id_list) + 1
    run_dir = alg_dir / f"run_{str(run_id).zfill(3)}"
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir, None


def _load_model_and_tokenizer(model_name: Union[str, Tuple]) -> Tuple:
    """Return ``(model, tokenizer, model_name_str)``.

    ``model_name`` is either a Hugging Face model name/path (loaded with
    ``device_map='auto'``) or an already-loaded ``(model, tokenizer)`` pair,
    in which case the name is read from the model config.
    """
    if not isinstance(model_name, str):
        model, tok = model_name
        return model, tok, model.config._name_or_path

    print("Instantiating model")
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto')
    if 'llama' in model.name_or_path.lower():
        tok = LlamaTokenizerFast.from_pretrained(model_name, add_bos_token=False)
        tok.padding_side = 'right'
    else:
        tok = AutoTokenizer.from_pretrained(model_name)
    tok.pad_token = tok.eos_token
    return model, tok, model_name


def _case_result_path(run_dir, num_edits: int, case_id):
    """Path of the JSON result file for one record of a run."""
    # Built with pathlib rather than str.format on a path string, so braces in
    # the results directory cannot break the template.
    return run_dir / f"{num_edits}_edits-case_{case_id}.json"


def _build_metrics(ds_name, record, eval_result, num_edits, group_id, exec_time):
    """Assemble the JSON-serializable result dict written for one record.

    For ``cf``/``mquake`` the raw evaluation result is stored under ``post``.
    For the other datasets each entry of ``record["eval_data"]`` is extended
    with ``[-1, generation]`` and the neighborhood/MMLU results are copied
    over; those keys are assumed to be produced by the dataset's evaluation
    function (not visible here).
    """
    if ds_name in ['cf', 'mquake']:
        return {
            "case_id": record["case_id"],
            "num_edits": num_edits,
            "requested_rewrite": record["requested_rewrite"],
            "time": exec_time,
            "post": eval_result,
        }
    return {
        "case_id": record["case_id"],
        "group_id": group_id,
        "num_edits": num_edits,
        "text": record["text"],
        # Each eval_data entry is assumed to be a list; -1 is inserted before
        # the generated output (its meaning is defined by downstream tooling).
        "generation_result": [
            qa + [-1, g]
            for qa, g in zip(record['eval_data'], eval_result['generation_result'])
        ],
        "neighborhood_result": eval_result['neighborhood_result'],
        "mmlu_result": eval_result['mmlu_result'],
    }


def _restore_original_weights(model, ori_weight_or_model, run_dir, group_id, num_edits):
    """Undo a chunk's edit and return the model to use for the next chunk.

    If the algorithm returned a dict of original parameter tensors, they are
    copied back into ``model`` in place and ``model`` is returned. Otherwise
    (presumably the LoRA path — confirm) ``ori_weight_or_model`` itself is
    returned as the model for the next chunk.

    When ``num_edits == 256`` the edited state is first saved to
    ``run_dir / f"group_{group_id}_delta.pth"``: the post-edit values of the
    changed parameters (dict path — full weights, not differences) or the
    PEFT adapter state dict.
    """
    delta_path = run_dir / f'group_{group_id}_delta.pth'
    if isinstance(ori_weight_or_model, dict):
        if num_edits == 256:
            torch.save(
                {
                    k: nethook.get_parameter(model, k).cpu().detach().clone()
                    for k in ori_weight_or_model
                },
                delta_path,
            )
        with torch.no_grad():
            for k, v in ori_weight_or_model.items():
                param = nethook.get_parameter(model, k)
                # Copy to the parameter's own device: with device_map='auto'
                # (or on CPU-only machines) it is not necessarily "cuda".
                param[...] = v.to(param.device)
        return model

    if num_edits == 256:
        # Imported lazily so peft is only required when an adapter is saved.
        from peft.utils import get_peft_model_state_dict
        # NOTE(review): reads the adapter from `model`, not the edited model
        # returned by the algorithm; presumably the adapter was injected into
        # `model` in place — verify against apply_lora_to_model.
        torch.save(get_peft_model_state_dict(model), delta_path)
    return ori_weight_or_model


def main(
    alg_name: str,
    model_name: Union[str, Tuple],
    hparams_fname: str,
    ds_name: str,
    dataset_size_limit: int,
    continue_from_run: str,
    skip_generation_tests: bool,
    generation_test_interval: int,
    conserve_memory: bool,
    dir_name: str,
    eval_type: str,
    num_edits: int = 1,
    use_cache: bool = False,
):
    """Edit a model chunk-by-chunk with ``alg_name`` and evaluate it on ``ds_name``.

    The dataset is split into chunks of ``num_edits`` records. For each chunk:
    1. The edit is applied in place (``copy=False``).
    2. Every record is evaluated with the dataset's evaluation function.
    3. One JSON file per record is written to the run directory.
    4. The original weights are restored before the next chunk.

    Chunks whose result files all exist are skipped, so an interrupted run can
    be resumed via ``continue_from_run``.

    Args:
        alg_name: Key into ``ALG_DICT``.
        model_name: Hugging Face model name/path, or an already-loaded
            ``(model, tokenizer)`` pair.
        hparams_fname: Hyperparameter file under ``HPARAMS_DIR / alg_name``;
            ignored when resuming (the run's own ``params.json`` is used).
        ds_name: Key into ``DS_DICT``.
        dataset_size_limit: Passed to the dataset class as ``size``.
        continue_from_run: Existing run directory name to resume, or ``None``.
        skip_generation_tests: If true, attribute snippets and the TF-IDF
            vectorizer are not loaded and ``None`` is passed in their place.
        generation_test_interval: Generation inputs are passed only for
            records whose ``case_id`` is a multiple of this value; a value
            <= 0 (the CLI documents -1) disables them.
        conserve_memory: Ask the algorithm to keep original weights on CPU.
        dir_name: Subdirectory of ``RESULTS_DIR`` holding the runs.
        eval_type: Passed to the dataset class.
        num_edits: Number of records edited together in one chunk.
        use_cache: Pass a k/v cache path template to the algorithms that are
            given ``cache_template`` (see below).
    """
    params_class, apply_algo = ALG_DICT[alg_name]

    run_dir, continue_from_run = _resolve_run_dir(dir_name, continue_from_run)
    print(f"Results will be stored at {run_dir}")

    # A resumed run re-reads its own saved params; a new run copies them in.
    params_path = (
        run_dir / "params.json"
        if continue_from_run is not None
        else HPARAMS_DIR / alg_name / hparams_fname
    )
    hparams = params_class.from_json(params_path)
    if not (run_dir / "params.json").exists():
        shutil.copyfile(params_path, run_dir / "params.json")
    print(f"Executing {alg_name} with parameters {hparams}")

    model, tok, model_name = _load_model_and_tokenizer(model_name)

    print("Loading dataset, attribute snippets, tf-idf data")
    snips = AttributeSnippets(DATA_DIR) if not skip_generation_tests else None
    vec = get_tfidf_vectorizer(DATA_DIR) if not skip_generation_tests else None

    ds_class, ds_eval_method = DS_DICT[ds_name]
    ds = ds_class(DATA_DIR, tok=tok, size=dataset_size_limit, eval_type=eval_type)

    cache_template = None
    if use_cache:
        cache_template = (
            KV_DIR
            / f"{model_name.replace('/', '_')}_{alg_name}"
            / f"{ds_name}_layer_{{}}_clamp_{{}}_case_{{}}.npz"
        )
        print(f"Will load cache from {cache_template}")

    # Loop-invariant keyword arguments for apply_algo.
    args_conserve_memory = (
        dict(return_orig_weights_device="cpu") if conserve_memory else dict()
    )
    # Algorithms whose name contains one of these substrings receive
    # `cache_template` (so e.g. "MEMIT" also matches "MEMIT_ARE").
    etc_args = (
        dict(cache_template=cache_template)
        if any(alg in alg_name for alg in ["ROME", "MEMIT", "PTE", "FT", "MEMIT_ARE", "AlphaEdit", "UNKE"])
        else dict()
    )
    gen_test_vars = [snips, vec]

    # group_id is the chunk index; it advances even for skipped chunks so a
    # resumed run keeps the same numbering (and does not overwrite
    # group_<id>_delta.pth files of earlier chunks).
    for group_id, record_chunks in enumerate(chunks(ds, num_edits)):
        if all(
            _case_result_path(run_dir, num_edits, record["case_id"]).exists()
            for record in record_chunks
        ):
            continue

        start = time.time()
        edited_model, ori_weight_or_model = apply_algo(
            model,
            tok,
            [
                {"case_id": record["case_id"], **record}
                for record in record_chunks
            ],
            hparams,
            copy=False,
            return_orig_weights=True,
            **args_conserve_memory,
            **etc_args,
        )
        exec_time = time.time() - start
        print("Execution took", exec_time)

        start = time.time()
        for record in record_chunks:
            out_file = _case_result_path(run_dir, num_edits, record["case_id"])
            if out_file.exists():
                print(f"Skipping {out_file}; already exists")
                continue

            # The sign must be checked explicitly: `case_id % -1 == 0` is
            # always true, and `% 0` raises ZeroDivisionError.
            run_generation = (
                generation_test_interval > 0
                and record["case_id"] % generation_test_interval == 0
            )
            eval_result = ds_eval_method(
                edited_model,
                tok,
                record,
                *(gen_test_vars if run_generation else [None, None]),
            )
            metrics = _build_metrics(
                ds_name, record, eval_result, num_edits, group_id, exec_time
            )
            with open(out_file, "w") as f:
                json.dump(metrics, f, indent=1)

        model = _restore_original_weights(
            model, ori_weight_or_model, run_dir, group_id, num_edits
        )
        print("Evaluation took", time.time() - start)

def window(seq, n=2):
    """Yield a sliding window of width ``n`` over the iterable ``seq``.

    ``s -> (s0, s1, ..., s[n-1]), (s1, s2, ..., sn), ...``

    Each window is a tuple. Nothing is yielded when ``seq`` has fewer than
    ``n`` items.

    Raises:
        ValueError: if ``n < 1``. Because this is a generator, the error is
            raised when iteration starts, not at call time.
    """
    if n < 1:
        raise ValueError(f"window width must be >= 1, got {n}")
    it = iter(seq)
    result = tuple(islice(it, n))
    if len(result) == n:
        yield result
    # If seq was shorter than n, islice already exhausted `it` and this loop
    # does not run.
    for elem in it:
        result = result[1:] + (elem,)
        yield result


def chunks(arr, n):
    """Yield successive ``n``-sized slices of the sequence ``arr``.

    The last slice may be shorter than ``n``. ``arr`` must support ``len()``
    and slicing; each chunk is ``arr[i : i + n]``, so its type is whatever
    ``arr``'s slicing returns.

    Raises:
        ValueError: if ``n < 1``. Because this is a generator, the error is
            raised when iteration starts. Without this check ``n == 0`` fails
            with an opaque ``range()`` error and a negative ``n`` silently
            yields nothing.
    """
    if n < 1:
        raise ValueError(f"chunk size must be >= 1, got {n}")
    for start in range(0, len(arr), n):
        yield arr[start : start + n]


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--alg_name",
        # Derived from the registry so the CLI cannot drift from what main() accepts.
        choices=list(ALG_DICT),
        help="Editing algorithm to use. Results are saved in results/<alg_name>/<run_id>, "
        "where a new run_id is generated on each run. "
        "If continuing from previous run, specify the run_id in --continue_from_run.",
        required=True,
    )
    parser.add_argument(
        "--model_name",
        help="Model to edit.",
        required=True,
    )
    parser.add_argument(
        "--hparams_fname",
        type=str,
        help="Name of hyperparameters file, located in the hparams/<alg_name> folder.",
        required=True,
    )
    parser.add_argument(
        "--ds_name",
        # Required: the old default ("mcf") was not a valid choice, and argparse
        # does not validate defaults, so omitting the flag crashed in main()
        # only after the model had been loaded.
        choices=list(DS_DICT),
        required=True,
        help="Dataset to perform evaluations on: UnKE (unke), AKEW (akew), or MQuAKE (mquake).",
    )
    parser.add_argument(
        "--continue_from_run",
        type=str,
        default=None,
        help="If continuing from previous run, set to run_id. Otherwise, leave as None.",
    )
    parser.add_argument(
        "--dataset_size_limit",
        type=int,
        default=None,
        help="Truncate the dataset to its first n records.",
    )
    parser.add_argument(
        "--skip_generation_tests",
        dest="skip_generation_tests",
        action="store_true",
        help="Only run fast probability-based tests without slow generation tests. "
        "Useful for quick debugging and hyperparameter sweeps.",
    )
    parser.add_argument(
        "--generation_test_interval",
        type=int,
        default=1,
        help="One generation test is performed every [flag_value] iterations. If -1, generation tests are skipped.",
    )
    parser.add_argument(
        "--conserve_memory",
        dest="conserve_memory",
        action="store_true",
        help="Reduce memory usage during evaluation at the cost of a minor slowdown. "
        "Backs up model weights on CPU instead of GPU.",
    )
    parser.add_argument(
        "--num_edits",
        type=int,
        default=1,
        help="Number of rewrites to perform simultaneously.",
    )
    parser.add_argument(
        "--use_cache",
        dest="use_cache",
        action="store_true",
        help="Use cached k/v pairs",
    )
    parser.add_argument(
        "--only_evaluate_mmlu",
        dest="only_evaluate_mmlu",
        action="store_true",
        help="Currently unused: this flag is parsed but not passed to main().",
    )
    parser.add_argument(
        "--eval_type",
        choices=["QA", "completion", "QA_completion"],
        type=str,
        default=None,
    )
    args = parser.parse_args()

    # Fail fast with a usage message instead of deep inside the chunking loop.
    if args.num_edits < 1:
        parser.error("--num_edits must be a positive integer")

    main(
        args.alg_name,
        args.model_name,
        args.hparams_fname,
        args.ds_name,
        args.dataset_size_limit,
        args.continue_from_run,
        args.skip_generation_tests,
        args.generation_test_interval,
        args.conserve_memory,
        dir_name=args.alg_name,
        num_edits=args.num_edits,
        use_cache=args.use_cache,
        eval_type=args.eval_type
    )
