import json
import shutil
import pickle
from itertools import islice
from time import time
from typing import Tuple, Union
from sentence_transformers import SentenceTransformer, util

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM, LlamaTokenizer

from baselines.ft import FTHyperParams, apply_ft_to_model
from baselines.mend import MENDHyperParams, MendRewriteExecutor
from baselines.lora import LoRAHyperParams, apply_lora_to_model
from baselines.ike import IKEHyperParams, apply_ike_to_model
from dsets import (
    AttributeSnippets,
    CounterFactDataset,
    MENDQADataset,
    MultiCounterFactDataset,
    KGDataset,
    EVOKEMainDataset,
    get_tfidf_vectorizer,
)
from experiments.py.eval_utils_counterfact import compute_rewrite_quality_counterfact
from experiments.py.eval_utils_zsre import compute_rewrite_quality_zsre
from memit_lti import MEMITLTIHyperParams, apply_memit_lti_to_model
from rome_lti import ROMELTIHyperParams, apply_rome_lti_to_model
from glame import GLAMEHyperParams, apply_glame_to_model, GNN
from util import nethook
from util.globals import *

# One MEND executor is built when the module is imported; its bound
# ``apply_to_model`` method is the single callable registered for "MEND" below.
_MEND_EXECUTOR = MendRewriteExecutor()

# Editing-algorithm registry: name -> (hyper-parameter class, apply function).
ALG_DICT = {
    "IKE": (IKEHyperParams, apply_ike_to_model),
    "LORA": (LoRAHyperParams, apply_lora_to_model),
    "MEMIT-LTI": (MEMITLTIHyperParams, apply_memit_lti_to_model),
    "ROME-LTI": (ROMELTIHyperParams, apply_rome_lti_to_model),
    "FT": (FTHyperParams, apply_ft_to_model),
    "MEND": (MENDHyperParams, _MEND_EXECUTOR.apply_to_model),
    "GLAME": (GLAMEHyperParams, apply_glame_to_model),
}

# Evaluation-dataset registry: name -> (dataset class, per-record metric function).
# CounterFact and MultiCounterFact share the same metric function.
DS_DICT = {
    "mcf": (MultiCounterFactDataset, compute_rewrite_quality_counterfact),
    "cf": (CounterFactDataset, compute_rewrite_quality_counterfact),
    "zsre": (MENDQADataset, compute_rewrite_quality_zsre),
}


def main(
    alg_name: str,
    model_name: Union[str, Tuple],
    hparams_fname: str,
    ds_name: str,
    dataset_size_limit: int,
    continue_from_run: str,
    skip_generation_tests: bool,
    generation_test_interval: int,
    conserve_memory: bool,
    dir_name: str,
    num_edits: int = 1,
    use_cache: bool = False,
):
    """Apply an editing algorithm to a model and evaluate every requested edit.

    Records of the ``ds_name`` dataset are zipped with the matching ``KGDataset``
    entries and grouped into batches of ``num_edits``. Each batch is applied with
    the algorithm registered under ``alg_name`` in ``ALG_DICT``. Every case is then
    scored with the dataset's metric function from ``DS_DICT``, and the result is
    written to ``<run_dir>/<num_edits>_edits-case_<case_id>.json``. Batches whose
    result files all exist are skipped, so an interrupted run can be resumed with
    ``continue_from_run``. After each batch, the original weights returned by the
    algorithm are copied back into the model.

    Args:
        alg_name: Key into ``ALG_DICT``.
        model_name: Model id/path to load, or an already-loaded
            ``(model, tokenizer)`` tuple.
        hparams_fname: Hyper-parameter file under ``HPARAMS_DIR/<alg_name>``.
        ds_name: Key into ``DS_DICT``.
        dataset_size_limit: Passed as ``size`` to the dataset classes.
        continue_from_run: Name of an existing run directory under
            ``RESULTS_DIR/dir_name`` to resume. ``None`` (or a directory that does
            not exist) starts a new ``run_XXX`` directory.
        skip_generation_tests: If true, attribute snippets and the tf-idf
            vectorizer are not loaded, so the metric function receives ``None``
            for both.
        generation_test_interval: Generation inputs are passed only for records
            whose ``case_id`` is a multiple of this value. Values <= 0 disable
            generation tests.
        conserve_memory: If true, ``return_orig_weights_device="cpu"`` is passed
            to the algorithm.
        dir_name: Sub-directory of ``RESULTS_DIR`` that holds the runs.
        num_edits: Number of records edited simultaneously per batch.
        use_cache: If true, a k/v cache file template is passed to algorithms
            whose name contains ROME, MEMIT or GLAME.

    Raises:
        ValueError: If ``num_edits > 1`` is requested for the ``"cf"`` dataset.
    """
    # Set algorithm-specific variables
    params_class, apply_algo = ALG_DICT[alg_name]

    # Determine run directory.
    # Create a new dir if not continuing from a previous run OR that run doesn't exist.
    if (
        continue_from_run is None
        or not (run_dir := RESULTS_DIR / dir_name / continue_from_run).exists()
    ):
        continue_from_run = None
    if continue_from_run is None:
        alg_dir = RESULTS_DIR / dir_name
        if alg_dir.exists():
            id_list = [
                int(str(x).split("_")[-1])
                for x in alg_dir.iterdir()
                if str(x).split("_")[-1].isnumeric()
            ]
            run_id = 0 if not id_list else max(id_list) + 1
        else:
            run_id = 0
        run_dir = RESULTS_DIR / dir_name / f"run_{str(run_id).zfill(3)}"
        run_dir.mkdir(parents=True, exist_ok=True)
    print(f"Results will be stored at {run_dir}")

    # Get run hyperparameters; when resuming, reuse the copy saved in the run dir.
    params_path = (
        run_dir / "params.json"
        if continue_from_run is not None
        else HPARAMS_DIR / alg_name / hparams_fname
    )

    hparams = params_class.from_json(params_path)
    if not (run_dir / "params.json").exists():
        shutil.copyfile(params_path, run_dir / "params.json")
    print(f"Executing {alg_name} with parameters {hparams}")

    # Instantiate the vanilla model, or unpack a caller-supplied (model, tokenizer).
    if isinstance(model_name, str):
        print("Instantiating model")
        if "llama" in model_name.lower():
            model = LlamaForCausalLM.from_pretrained(model_name).cuda()
            tok = LlamaTokenizer.from_pretrained(model_name)
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_name, device_map="auto")
            tok = AutoTokenizer.from_pretrained(model_name)
        tok.pad_token = tok.eos_token
    else:
        model, tok = model_name
        model_name = model.config._name_or_path

    # Load data
    print("Loading dataset, attribute snippets, tf-idf data")
    snips = AttributeSnippets(DATA_DIR) if not skip_generation_tests else None
    vec = get_tfidf_vectorizer(DATA_DIR) if not skip_generation_tests else None

    # Raise rather than assert so the check survives ``python -O``.
    if num_edits > 1 and ds_name == "cf":
        raise ValueError(f"{ds_name} does not support multiple edits")

    ds_class, ds_eval_method = DS_DICT[ds_name]
    ds = ds_class(DATA_DIR, size=dataset_size_limit, tok=tok)
    # Paired element-wise with ``ds`` via zip() below.
    graphds = KGDataset(DATA_DIR, size=dataset_size_limit,
                        tok=tok, ds_name=ds_name)

    gnn_model = None
    if alg_name == "GLAME":
        # Some configs (e.g. GPT-2's) expose the hidden width as ``n_embd``;
        # otherwise fall back to ``hidden_size``. Probe and read the same name.
        n_embed = (
            model.config.n_embd
            if hasattr(model.config, "n_embd")
            else model.config.hidden_size
        )
        print("Initializing RGCN")
        gnn_model = GNN(
            num_rels=1000,
            num_nodes=1000,
            h_dim=int(n_embed * hparams.gnn_dim_factor),
            out_dim=n_embed,
            num_bases=100,
            num_basis=100,
            num_hidden_layers=2,
            dropout=hparams.gnn_feat_drop,
            self_loop=True,
            skip_connect=False,
            encoder_name='uvrgcn',
            opn='sub',
            use_cuda=True,
            analysis=False)

    # Defined for every algorithm so later references never hit an unbound name.
    sentence_model = None
    stored_data = None
    if alg_name == "IKE":
        sentence_model = SentenceTransformer(
            hparams.sentence_model_name).to("cuda")

        safe_model_name = hparams.sentence_model_name.rsplit('/', 1)[-1]
        # NOTE: pickle can execute code on load; only use trusted embedding files.
        with open(f'{EMBEDDING_DIR}/{safe_model_name}_{type(ds).__name__}.pkl', "rb") as fIn:
            stored_data = pickle.load(fIn)

    # Get cache templates
    cache_template = None
    if use_cache:
        tmp_ds_name = ds_name
        if ds_name == "cf-one-hop":
            tmp_ds_name = "cf"
        cache_template = (
            KV_DIR
            / f"{model_name.replace('/', '_')}_{alg_name}"
            / f"{tmp_ds_name}_layer_{{}}_clamp_{{}}_case_{{}}.npz"
        )
        print(f"Will load cache from {cache_template}")

    # Loop-invariant: formatted with (num_edits, case_id) for each result file.
    case_result_template = str(run_dir / "{}_edits-case_{}.json")

    # Materialized because ``chunks`` relies on len() and slicing.
    zipped_lists = list(zip(ds, graphds))

    # Each chunk holds up to ``num_edits`` (record, graph) pairs.
    for chunk in chunks(zipped_lists, num_edits):
        record_chunks = [record for record, _ in chunk]
        graph_chunks = [graph for _, graph in chunk]

        # Skip the chunk when every case in it already has a result file.
        if all(
            Path(case_result_template.format(num_edits, record["case_id"])).exists()
            for record in record_chunks
        ):
            continue

        case_ids = [record["case_id"] for record in record_chunks]

        # Algorithm-specific keyword arguments.
        args_conserve_memory = (
            dict(return_orig_weights_device="cpu") if conserve_memory else dict()
        )
        etc_args = (
            dict(cache_template=cache_template)
            if any(alg in alg_name for alg in ["ROME", "MEMIT", "GLAME"])
            else dict()
        )
        glame_args = (
            dict(
                graph_input=[graph_case['triples'] for graph_case in graph_chunks],
                gnn_model=gnn_model,
            )
            if "GLAME" in alg_name
            else dict()
        )
        icl_args = (
            dict(dataset_name=type(ds).__name__)
            if any(alg in alg_name for alg in ["LORA", "FT"])
            else dict()
        )
        ike_args = (
            dict(sentence_model=sentence_model, stored_data=stored_data)
            if "IKE" in alg_name
            else dict()
        )

        start = time()
        edited_model, weights_copy = apply_algo(
            model,
            tok,
            [
                {"case_id": record["case_id"], **record["requested_rewrite"]}
                for record in record_chunks
            ],
            hparams,
            copy=False,
            return_orig_weights=True,
            **glame_args,
            **icl_args,
            **ike_args,
            **args_conserve_memory,
            **etc_args,
        )
        exec_time = time() - start
        print("Execution took", exec_time)

        # Evaluate new model
        start = time()
        gen_test_vars = [snips, vec]
        for record in record_chunks:
            out_file = Path(case_result_template.format(
                num_edits, record["case_id"]))
            if out_file.exists():
                print(f"Skipping {out_file}; already exists")
                continue

            if "IKE" in alg_name:
                # For IKE the first value returned by ``apply_algo`` is used as a
                # text context (``process_IKE_record`` takes ``context: str``).
                # NOTE(review): that same value is still passed to
                # ``ds_eval_method`` as the model below - confirm this is intended.
                context = edited_model
                record = process_IKE_record(record=record, context=context)

            # Generation inputs only for every ``generation_test_interval``-th case.
            # A non-positive interval disables them; this also avoids ``% 0`` and
            # ``% -1`` (which is 0 for every case_id).
            run_generation = (
                generation_test_interval > 0
                and record["case_id"] % generation_test_interval == 0
            )

            metrics = {
                "case_id": record["case_id"],
                "grouped_case_ids": case_ids,
                "num_edits": num_edits,
                "requested_rewrite": record["requested_rewrite"],
                "time": exec_time,
                "post": ds_eval_method(
                    edited_model,
                    tok,
                    record,
                    *(gen_test_vars if run_generation else [None, None]),
                ),
            }

            # Dump metrics in .json
            with open(out_file, "w") as f:
                json.dump(metrics, f, indent=1)

        # Restore original weights so the next chunk edits the unmodified model.
        with torch.no_grad():
            for k, v in weights_copy.items():
                param = nethook.get_parameter(model, k)
                # Copy back onto the parameter's own device. A model loaded with
                # device_map="auto" may not live (entirely) on the default "cuda".
                param[...] = v.to(param.device)

        print("Evaluation took", time() - start)


def process_IKE_record(record: dict, context: str) -> dict:
    """Prefix every evaluation prompt in ``record`` with the IKE ``context``.

    ``context`` is prepended to ``record["requested_rewrite"]["prompt"]`` and to
    every entry of ``record["paraphrase_prompts"]`` and
    ``record["neighborhood_prompts"]``. A list entry may be a plain string or a
    dict with a ``"prompt"`` key. For a dict entry, a shallow copy with the
    prefixed prompt replaces it, so other keys (e.g. a target) are kept. A prompt
    list that is absent from ``record`` is left untouched.

    Args:
        record: Dataset record to update. It is modified in place.
        context: In-context text to prepend to each prompt.

    Returns:
        The same ``record`` object, after modification.
    """

    def _prefixed(prompt):
        # Dict-shaped prompts keep their other fields; only "prompt" is prefixed.
        if isinstance(prompt, dict):
            return {**prompt, "prompt": context + prompt["prompt"]}
        return context + prompt

    record['requested_rewrite']['prompt'] = context + \
        record['requested_rewrite']['prompt']

    for case in ('paraphrase_prompts', 'neighborhood_prompts'):
        if case not in record:
            continue
        # Rebuild the list: rebinding a loop variable would not change ``record``.
        record[case] = [_prefixed(prompt) for prompt in record[case]]

    print("IKE record processed")

    return record


def window(seq, n=2):
    """Yield a sliding window of width ``n`` over ``seq`` as tuples.

    ``s -> (s0, s1, ..., s[n-1]), (s1, s2, ..., sn), ...``

    For an ordinary iterable with fewer than ``n`` items, nothing is yielded.
    """
    source = iter(seq)
    frame = tuple(islice(source, n))
    if len(frame) == n:
        yield frame
    for item in source:
        # Drop the oldest element and append the newest one.
        frame = (*frame[1:], item)
        yield frame


def chunks(arr, n):
    """Yield consecutive slices of ``arr``, each of length ``n``.

    The last slice is shorter when ``len(arr)`` is not a multiple of ``n``.
    ``arr`` must support ``len()`` and slicing.
    """
    yield from (arr[start:start + n] for start in range(0, len(arr), n))


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--alg_name",
        # Derived from the registry so every accepted name can be dispatched by main().
        choices=list(ALG_DICT),
        default="GLAME",
        help="Editing algorithm to use. Results are saved in results/<alg_name>/<run_id>, "
        "where a new run_id is generated on each run. "
        "If continuing from previous run, specify the run_id in --continue_from_run.",
        required=True,
    )
    parser.add_argument(
        "--model_name",
        default="gpt2-xl",
        help="Model to edit.",
        required=True,
    )
    parser.add_argument(
        "--hparams_fname",
        type=str,
        default="gpt2-xl.json",
        help="Name of hyperparameters file, located in the hparams/<alg_name> folder.",
        required=True,
    )
    parser.add_argument(
        "--ds_name",
        choices=list(DS_DICT),
        default="mcf",
        help="Dataset to perform evaluations on. Either CounterFact (cf), MultiCounterFact (mcf), or zsRE (zsre).",
    )
    parser.add_argument(
        "--continue_from_run",
        type=str,
        default=None,
        help="If continuing from previous run, set to run_id. Otherwise, leave as None.",
    )
    parser.add_argument(
        "--dataset_size_limit",
        type=int,
        default=None,
        help="Truncate CounterFact to first n records.",
    )
    parser.add_argument(
        "--skip_generation_tests",
        dest="skip_generation_tests",
        action="store_true",
        help="Only run fast probability-based tests without slow generation tests. "
        "Useful for quick debugging and hyperparameter sweeps.",
    )
    # Skipping is the default (see set_defaults below); this flag is the only way
    # to turn generation tests on.
    parser.add_argument(
        "--run_generation_tests",
        dest="skip_generation_tests",
        action="store_false",
        help="Run the slow generation tests (they are skipped by default).",
    )
    parser.add_argument(
        "--generation_test_interval",
        type=int,
        default=1,
        help="One generation test is performed every [flag_value] iterations. If -1, generation tests are skipped.",
    )
    parser.add_argument(
        "--conserve_memory",
        dest="conserve_memory",
        action="store_true",
        help="Reduce memory usage during evaluation at the cost of a minor slowdown. "
        "Backs up model weights on CPU instead of GPU.",
    )
    parser.add_argument(
        "--num_edits",
        type=int,
        default=1,
        help="Number of rewrites to perform simultaneously.",
    )
    parser.add_argument(
        "--use_cache",
        dest="use_cache",
        action="store_true",
        help="Use cached k/v pairs",
    )
    # Must come after the add_argument calls: it also overwrites the default of
    # both actions sharing the ``skip_generation_tests`` dest.
    parser.set_defaults(skip_generation_tests=True, conserve_memory=False)
    args = parser.parse_args()

    main(
        args.alg_name,
        args.model_name,
        args.hparams_fname,
        args.ds_name,
        args.dataset_size_limit,
        args.continue_from_run,
        args.skip_generation_tests,
        args.generation_test_interval,
        args.conserve_memory,
        dir_name=args.alg_name,
        num_edits=args.num_edits,
        use_cache=args.use_cache,
    )
