import os
import re
import gc
import time
import json
import pickle
import torch
import transformers
# import lm_eval
import numpy as np
import pandas as pd
from tqdm import tqdm
from enum import Enum
from datasets import load_dataset
from typing import Any, Union, Dict
from transformers import OPTPreTrainedModel, LlamaPreTrainedModel, MistralPreTrainedModel
from optim import Optimizer
import copy
import random
import matplotlib.pyplot as plt
from gsm8k_helper import *


def get_wikitext2(nsamples, seed, seqlen, tokenizer, data_name, dataset_cache_dir=None):
    """Sample ``nsamples`` random text snippets of ``seqlen`` tokens each.

    A random character offset is drawn from the concatenated corpus, a window
    of ``seqlen * 10`` characters is tokenized, the first ``seqlen`` tokens are
    kept and decoded back to text.

    Args:
        nsamples: Number of text snippets to return.
        seed: Seed passed to ``random.seed`` before sampling.
        seqlen: Number of tokens kept per snippet.
        tokenizer: Callable tokenizer returning an object with ``input_ids``
            and providing ``decode``.
        data_name: ``'wiki'`` selects wikitext-2 (train split); any other value
            selects the C4 validation shards.
        dataset_cache_dir: Optional cache directory forwarded to
            ``load_dataset``; ``None`` keeps the library default.

    Returns:
        A list of exactly ``nsamples`` decoded strings.

    Raises:
        ValueError: If the corpus is shorter than ``seqlen + 1`` characters.
        RuntimeError: If too many draws tokenize to fewer than ``seqlen`` tokens.
    """
    import random

    random.seed(seed)

    if data_name == 'wiki':
        traindata = load_dataset(
            "wikitext", "wikitext-2-raw-v1", split="train", cache_dir=dataset_cache_dir
        )
    else:
        data_files = {"train": "en/c4-validation.*.json.gz"}
        traindata = load_dataset(
            "allenai/c4", data_files=data_files, split="train", cache_dir=dataset_cache_dir
        )
    tot_text = "\n\n".join(traindata["text"])

    max_start = len(tot_text) - seqlen - 1
    if max_start < 0:
        raise ValueError(
            f"Corpus has {len(tot_text)} characters, too short for seqlen={seqlen}"
        )

    traindataset = []
    # The previous `for` loop tried to retry short windows with `s = s - 1`,
    # which has no effect on a range-based loop; a bounded while loop actually
    # retries until enough samples are collected.
    attempts_left = 100 * (nsamples + 1)
    while len(traindataset) < nsamples:
        if attempts_left <= 0:
            raise RuntimeError(
                f"Could not draw {nsamples} windows of {seqlen} tokens from the corpus"
            )
        attempts_left -= 1

        i = random.randint(0, max_start)
        j = i + seqlen * 10
        trainenc = tokenizer(tot_text[i:j], return_tensors="pt")
        if trainenc.input_ids.shape[1] < seqlen:
            # Window tokenized to fewer than seqlen tokens; draw again.
            continue
        inp = trainenc.input_ids[:, :seqlen]
        traindataset.append(tokenizer.decode(inp[0].tolist(), skip_special_tokens=True))

    return traindataset


def infer_device() -> torch.device:
    """Return the CUDA device with the most unallocated memory, or the CPU.

    "Free" memory is estimated as the device's total memory minus the bytes
    currently allocated by this process's PyTorch allocator on that device.

    Returns:
        ``torch.device("cpu")`` when CUDA is unavailable or no device is
        found, otherwise ``torch.device("cuda:<index>")`` of the best device.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")

    max_free_memory = -1
    best_device_index = -1
    for i in range(torch.cuda.device_count()):
        # Query by index instead of calling torch.cuda.set_device(), so that
        # probing does not change the process-wide current CUDA device.
        total_memory = torch.cuda.get_device_properties(i).total_memory
        free_memory = total_memory - torch.cuda.memory_allocated(i)
        if free_memory > max_free_memory:
            max_free_memory = free_memory
            best_device_index = i

    if best_device_index == -1:
        return torch.device("cpu")
    return torch.device(f"cuda:{best_device_index}")

def clear_torch_cache() -> None:
    """Run the Python garbage collector, then empty PyTorch's CUDA cache."""
    # Collect first so that tensors freed by the GC can be released from the
    # CUDA caching allocator by the call that follows.
    gc.collect()
    torch.cuda.empty_cache()
    return None

def get_uv(u, s, v, k):
    """Build two rank-``k`` factors from SVD components.

    Keeps the first ``k`` columns of ``u``, the first ``k`` entries of ``s``
    and the first ``k`` rows of ``v``, and splits ``sqrt(s)`` between both
    sides.

    Returns:
        A tuple ``((u_k @ sqrt(S_k)).T, (sqrt(S_k) @ v_k).T)`` where ``S_k`` is
        the diagonal matrix built from the truncated ``s``.
    """
    top_u = u[:, :k]
    top_v = v[:k, :]
    root_sigma = torch.diag(torch.sqrt(s[:k]))

    # Everything is moved onto the device that holds the singular values.
    if top_u.device != root_sigma.device:
        print('svd u s device: ', top_u.device, root_sigma.device)
        top_u = top_u.to(root_sigma.device)
    if root_sigma.device != top_v.device:
        print('svd s v device: ', root_sigma.device, top_v.device)
        top_v = top_v.to(root_sigma.device)

    left = (top_u @ root_sigma).T
    right = (root_sigma @ top_v).T

    # Drop the intermediates before asking the allocator to release memory.
    top_u = root_sigma = top_v = None
    clear_torch_cache()
    return left, right

def get_calib_data_fisher(data_name, tokenizer, model_name, batch_size, nsamples, seqlen=4096, seed=3):
    """Build (or load from cache) batched calibration token ids.

    The result is cached under ``/data`` in a file whose name encodes every
    argument except the tokenizer. On a cache miss, ``nsamples`` random
    windows are drawn from the corpus, tokenized, cut to at most ``seqlen``
    tokens and grouped ``batch_size`` at a time.

    Args:
        data_name: ``"c4"`` or ``"wikitext2"``.
        tokenizer: Callable tokenizer returning an object with ``input_ids``.
        model_name: Used only to build the cache file name.
        batch_size: Number of samples per emitted batch.
        nsamples: Number of windows to draw.
        seqlen: Maximum number of tokens kept per window.
        seed: Seed passed to ``random.seed``.

    Returns:
        A list of ``{"batch_input_ids": [tensor, ...]}`` dicts. A batch is only
        emitted when it is full, so trailing samples are dropped when
        ``nsamples`` is not a multiple of ``batch_size``.

    Raises:
        NotImplementedError: For any other ``data_name``.
    """
    cache_file = (
        f"/data/{data_name}_{model_name}_{batch_size}_{nsamples}_{seqlen}_{seed}.pt"
    )
    random.seed(seed)
    if os.path.exists(cache_file):
        cached = torch.load(cache_file)
        print(f"[Calib data] Load from {cache_file}")
        return cached

    if data_name == "c4":
        corpus = load_dataset(
            "allenai/c4",
            data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
            revision="607bd4c8450a42878aa9ddc051a65a055450ef87",
            split="train",
        )
    elif data_name == "wikitext2":
        corpus = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    else:
        raise NotImplementedError
    tot_text = "\n\n".join(corpus["text"])
    print(f"len tot_text = {len(tot_text)}")

    batches = []
    pending = []
    for cnt in tqdm(range(1, nsamples + 1)):
        start = random.randint(0, len(tot_text) - seqlen - 1)
        # Take ten characters per wanted token, then cut to seqlen tokens.
        window = tot_text[start:start + seqlen * 10]
        token_ids = tokenizer(window, return_tensors="pt").input_ids
        pending.append(token_ids[:, :seqlen])
        if cnt % batch_size == 0:
            batches.append({"batch_input_ids": pending})
            pending = []

    torch.save(batches, cache_file)
    print(f"{cache_file} saved.")
    return batches

# (substring of model_name, cache path template), highest precedence first.
# The order reproduces which assignment won in the former if/elif sequence.
_FISHER_CACHE_TEMPLATES = (
    ("Llama-2-13b", "/data/13b/{model_name}_calib_fisher_info_wiki_2048_4096.pt"),
    ("Llama-2-7b", "/data/7b/{model_name}_calib_fisher_info_wiki_2048_4096.pt"),
    ("Llama-3-8b", "/data/8b/llama-3-8B_calib_fisher_info_wiki_2048_4096.pt"),
    ("opt-13b", "/data/opt_13b/{model_name}_calib_fisher_info_wiki_2048_2048.pt"),
    ("opt-6.7b", "/data/opt_6dot7b/{model_name}_calib_fisher_info_wiki_2048_2048.pt"),
)


def _fisher_cache_path(model_name):
    """Return the Fisher-info cache path for ``model_name`` or raise ValueError."""
    for marker, template in _FISHER_CACHE_TEMPLATES:
        if marker in model_name:
            return template.format(model_name=model_name)
    known = ", ".join(marker for marker, _ in _FISHER_CACHE_TEMPLATES)
    raise ValueError(
        f"No Fisher cache location configured for model_name={model_name!r}; "
        f"expected it to contain one of: {known}"
    )


def _iter_fisher_targets(model, target_modules, skipped_layers):
    """Yield ``(name, module)`` for target modules outside the skipped layers."""
    for name, module in model.named_modules():
        parts = name.split(".")
        suffix = parts[-1]
        if suffix not in target_modules:
            continue
        # For 'fc1'/'fc2' the layer index is the path component right before
        # the module name; for every other target it is one level further up.
        layer_idx = int(parts[-2]) if suffix in ['fc1', 'fc2'] else int(parts[-3])
        if layer_idx in skipped_layers:
            continue
        yield name, module


def calib_fisher_info(model, target_modules, skipped_layers, calib_loader, model_name, use_cache=True):
    """Attach a ``fisher_info`` tensor to every targeted module of ``model``.

    The value is ``sqrt(mean over batches of grad(weight)**2)``, accumulated in
    float32. Results are cached on disk per model; when the cache file exists
    and ``use_cache`` is true it is loaded instead of recomputed.

    Args:
        model: Model exposing ``gradient_checkpointing_enable``, ``half``,
            ``named_modules`` and a forward accepting ``input_ids``/``labels``.
        target_modules: Module-name suffixes to process.
        skipped_layers: Layer indices to leave untouched.
        calib_loader: Sized iterable of batches with an ``"input_ids"`` tensor.
        model_name: Selects the cache file location.
        use_cache: When false, recompute even if a cache file exists (the
            cache file is still overwritten afterwards).

    Raises:
        ValueError: If ``model_name`` matches no known cache location, or if
            ``calib_loader`` is empty when recomputation is required.
    """
    # Resolve the cache path first so an unknown model fails fast with a clear
    # message (it used to surface later as an UnboundLocalError).
    cache_file = _fisher_cache_path(model_name)

    model.gradient_checkpointing_enable()
    model.half()

    print(f"[Fisher] Search cache_file={cache_file}")

    if os.path.exists(cache_file) and use_cache:
        print(f"[Fisher] File {cache_file} exist.")
        print(f"[Fisher] Load cache_file={cache_file}")
        # NOTE: torch.load unpickles the file; only load trusted cache files.
        all_fisher_info = torch.load(cache_file, map_location="cpu")
        for name, module in _iter_fisher_targets(model, target_modules, skipped_layers):
            module.fisher_info = all_fisher_info[name].to(module.weight.device)
        return
    model.eval()

    print(f"[Fisher] No cache_file={cache_file}")
    print(f"[Fisher] Create fisher info list...")

    num_batches = len(calib_loader)
    if num_batches == 0:
        raise ValueError("calib_loader is empty; cannot estimate Fisher information")

    targets = list(_iter_fisher_targets(model, target_modules, skipped_layers))
    for _, module in targets:
        module.fisher_info = 0

    for batch in tqdm(calib_loader):
        input_ids = batch["input_ids"][:, :-1].to(model.device)
        labels = batch["input_ids"][:, 1:].to(model.device)
        out = model(input_ids=input_ids, labels=labels)
        out[0].backward()
        for _, module in targets:
            # Accumulate squared gradients in float32; the model runs in half.
            module.fisher_info += module.weight.grad.detach().to(torch.float32).pow(2)

        model.zero_grad()

    all_fisher_info = {}
    for name, module in targets:
        module.fisher_info = module.fisher_info.div(num_batches).sqrt()
        # Removes every forward hook registered on the module.
        module._forward_hooks.clear()
        all_fisher_info[name] = module.fisher_info

    # Create the cache directory so the finished computation is not lost to a
    # missing-directory error at save time.
    cache_dir = os.path.dirname(cache_file)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    torch.save(all_fisher_info, cache_file)
    print(f"[Fisher] Save the fisher info list to:  {cache_file}")

class HelperState(Enum):
    """Operating mode of a helper-wrapped model.

    ``KEY`` is not a mode itself: its ``label`` is the attribute name under
    which the current state is stored on a model (see ``set_helper_state``).
    """

    KEY = 10000

    Collecting = 0
    Inference = 1

    Invalid = 9999


# Human-readable labels attached to the members after class creation.
# ``Invalid`` intentionally gets no label, as before.
HelperState.KEY.label = "HelperState"
HelperState.Collecting.label = "Helper-Data-Collection"  # hook forward() to collect data
HelperState.Inference.label = "Helper-Ready-Inference"    # with updated forward()


class HelperCollectState(Enum):
    """Phase marker used while a helper is collecting data.

    As with ``KEY`` members elsewhere in this file, ``KEY`` carries the label
    naming this state; ``Pre``, ``Post`` and ``End`` are the actual phases.
    """

    KEY = 10001

    Pre = 0
    Post = 1
    End = 2

    Invalid = 9999


# Attach a human-readable label to each member; ``Invalid`` gets none.
for _state, _text in (
    (HelperCollectState.KEY, "HelperCollectState"),
    (HelperCollectState.Pre, "HelperCollectState-Pre"),
    (HelperCollectState.Post, "HelperCollectState-Post"),
    (HelperCollectState.End, "HelperCollectState-End"),
):
    _state.label = _text
del _state, _text

def set_helper_state(model, state: HelperState) -> None:
    """Store ``state`` on ``model`` under the attribute named by ``HelperState.KEY.label``."""
    attribute_name = HelperState.KEY.label
    setattr(model, attribute_name, state)


# Pretrained-model base classes listed as supported (LLaMA and OPT).
# Tuple form, suitable as the second argument of isinstance().
# NOTE(review): MistralPreTrainedModel is imported above but not listed here — confirm intent.
HELPER_SUPPORT_MODEL_LIST = (LlamaPreTrainedModel, OPTPreTrainedModel)
# The same classes as a typing.Union, for use in annotations.
HELPER_SUPPORT_MODEL_TYPES = Union[LlamaPreTrainedModel, OPTPreTrainedModel]


# https://pypi.org/project/lm-eval/0.0.1/
# Maps a task name to the key looked up in that task's result dict by
# calculate_avg_accuracy(). There is no entry for any 'mmlu' task here.
TASK_METRIC_MAP = {
    "piqa": "acc_norm,none",
    "arc_challenge": "acc_norm,none",
    "arc_easy": "acc_norm,none",
    "hellaswag": "acc_norm,none",
    "winogrande": "acc,none",
    "boolq": "acc,none",
    'wsc': 'acc,none',
    "openbookqa": "acc_norm,none"
}

def calculate_avg_accuracy(task_names: list, results: dict) -> float:
    """Average the per-task accuracies, counting all MMLU tasks as one task.

    Args:
        task_names: Names of the evaluated tasks (a sized, iterable collection
            of strings).
        results: Mapping from task name to that task's result dict; the metric
            is read with the key given by ``TASK_METRIC_MAP[task]``.

    Returns:
        Without MMLU tasks: the sum of accuracies divided by
        ``len(task_names)``. With MMLU tasks: the MMLU accuracies are first
        combined into one average weighted by each task's number of test
        questions, and that average counts as a single task.
    """
    n_tasks = len(task_names)
    acc_cumul = sum(
        result.get(TASK_METRIC_MAP[task])
        for task, result in results.items()
        if 'mmlu' not in task
    )

    mmlu_task_names = [task_name for task_name in task_names if 'mmlu' in task_name]
    if not mmlu_task_names:
        return acc_cumul / n_tasks

    # The module-level `import lm_eval` is commented out, so the name used to
    # be undefined here; import it lazily, only when MMLU tasks are present.
    import lm_eval.tasks

    questions_per_mmlu_task = {
        task_name: lm_eval.tasks.get_task_dict([task_name])[task_name].dataset["test"].num_rows
        for task_name in mmlu_task_names
    }

    # Calculate average accuracy for mmlu tasks, weighted by number of questions in each task
    acc_mmlu = sum(
        result.get(TASK_METRIC_MAP[task]) * questions_per_mmlu_task[task]
        for task, result in results.items()
        if 'mmlu' in task
    )
    acc_mmlu_avg = acc_mmlu / sum(questions_per_mmlu_task.values())

    return (acc_cumul + acc_mmlu_avg) / (n_tasks - len(questions_per_mmlu_task) + 1)


def easy_dump(obj, dest, label):
    """Pickle ``obj`` to ``<dest>/<label>.pkl``; dicts are also written as JSON.

    Args:
        obj: Object to serialize.
        dest: Existing directory to write into.
        label: File name stem shared by both output files.
    """
    pickle_path = os.path.join(dest, f"{label}.pkl")
    with open(pickle_path, "wb") as pickle_file:
        pickle.dump(obj, pickle_file)

    if not isinstance(obj, dict):
        return

    # Dicts additionally get a human-readable copy next to the pickle.
    json_path = os.path.join(dest, f"{label}.json")
    json_text = json.dumps(obj, indent=4)
    with open(json_path, "w") as json_file:
        json_file.write(json_text)
            
def make_run_dir(outdir: Union[str, os.PathLike], desc: str) -> str:
    """Create and return the next numbered run directory inside ``outdir``.

    Existing sub-directories whose names start with digits are scanned; the
    new directory is named ``<max_id + 1, zero-padded to 5>-<desc>``, starting
    at ``00000`` when none exist (or when ``outdir`` itself does not exist).

    Raises:
        FileExistsError: If the computed directory already exists.
    """
    taken_ids = []
    if os.path.isdir(outdir):  # a missing outdir simply means "no previous runs"
        for entry in os.listdir(outdir):
            if not os.path.isdir(os.path.join(outdir, entry)):
                continue
            leading_digits = re.match(r'^\d+', entry)
            if leading_digits is not None:
                taken_ids.append(int(leading_digits.group()))

    next_id = max(taken_ids, default=-1) + 1
    run_dir = os.path.join(outdir, f'{next_id:05d}-{desc}')
    os.makedirs(run_dir, exist_ok=False)  # never reuse an existing directory
    return run_dir

class EasyDict(dict):
    """A ``dict`` whose items can also be read, written and deleted as attributes."""

    def __getattr__(self, name: str) -> Any:
        # Only called when normal attribute lookup fails, so real dict
        # attributes and methods keep working.
        if name in self:
            return self[name]
        raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self.__setitem__(name, value)

    def __delattr__(self, name: str) -> None:
        self.__delitem__(name)

DEFAULT_PAD_TOKEN = "[PAD]"


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add special tokens and grow the model's embeddings to match.

    The rows created for the new tokens, in both the input and the output
    embedding matrices, are set to the mean of the rows that existed before.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if added <= 0:
        return

    # Fetch both weight tensors before modifying either of them.
    input_weights = model.get_input_embeddings().weight.data
    output_weights = model.get_output_embeddings().weight.data

    for weights in (input_weights, output_weights):
        # The new rows are the last `added` rows; average only the old ones.
        weights[-added:] = weights[:-added].mean(dim=0, keepdim=True)

# Maps each MMLU subject name (the stem of its "<subject>_test.csv" file, as
# looked up by eval_mmlu) to the list of subcategories it belongs to.
subcategories = {
    "abstract_algebra": ["math"],
    "anatomy": ["health"],
    "astronomy": ["physics"],
    "business_ethics": ["business"],
    "clinical_knowledge": ["health"],
    "college_biology": ["biology"],
    "college_chemistry": ["chemistry"],
    "college_computer_science": ["computer science"],
    "college_mathematics": ["math"],
    "college_medicine": ["health"],
    "college_physics": ["physics"],
    "computer_security": ["computer science"],
    "conceptual_physics": ["physics"],
    "econometrics": ["economics"],
    "electrical_engineering": ["engineering"],
    "elementary_mathematics": ["math"],
    "formal_logic": ["philosophy"],
    "global_facts": ["other"],
    "high_school_biology": ["biology"],
    "high_school_chemistry": ["chemistry"],
    "high_school_computer_science": ["computer science"],
    "high_school_european_history": ["history"],
    "high_school_geography": ["geography"],
    "high_school_government_and_politics": ["politics"],
    "high_school_macroeconomics": ["economics"],
    "high_school_mathematics": ["math"],
    "high_school_microeconomics": ["economics"],
    "high_school_physics": ["physics"],
    "high_school_psychology": ["psychology"],
    "high_school_statistics": ["math"],
    "high_school_us_history": ["history"],
    "high_school_world_history": ["history"],
    "human_aging": ["health"],
    "human_sexuality": ["culture"],
    "international_law": ["law"],
    "jurisprudence": ["law"],
    "logical_fallacies": ["philosophy"],
    "machine_learning": ["computer science"],
    "management": ["business"],
    "marketing": ["business"],
    "medical_genetics": ["health"],
    "miscellaneous": ["other"],
    "moral_disputes": ["philosophy"],
    "moral_scenarios": ["philosophy"],
    "nutrition": ["health"],
    "philosophy": ["philosophy"],
    "prehistory": ["history"],
    "professional_accounting": ["other"],
    "professional_law": ["law"],
    "professional_medicine": ["health"],
    "professional_psychology": ["psychology"],
    "public_relations": ["politics"],
    "security_studies": ["politics"],
    "sociology": ["culture"],
    "us_foreign_policy": ["politics"],
    "virology": ["health"],
    "world_religions": ["philosophy"],
}

# Groups the subcategory names used above into top-level reporting categories.
categories = {
    "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"],
    "humanities": ["history", "philosophy", "law"],
    "social sciences": ["politics", "culture", "economics", "geography", "psychology"],
    "other (business, health, misc.)": ["other", "business", "health"],
}

# Answer letters, in option-column order, used when formatting questions.
choices = ["A", "B", "C", "D"]


def format_subject(subject):
    """Turn ``"a_b_c"`` into ``" a b c"`` (each word prefixed by a space)."""
    words = subject.split("_")
    return "".join(" " + word for word in words)

def format_example(df, idx, include_answer=True):
    """Render row ``idx`` of ``df`` as a multiple-choice prompt.

    Column 0 is the question, the last column is the answer, and the columns
    in between are the options, lettered from ``choices``.

    Args:
        df: DataFrame with one question per row.
        idx: Positional row index.
        include_answer: When true, append the answer followed by a blank line;
            otherwise the prompt ends right after ``"Answer:"``.
    """
    num_options = df.shape[1] - 2

    text = df.iloc[idx, 0]
    option_lines = [
        "\n{}. {}".format(choices[col], df.iloc[idx, col + 1])
        for col in range(num_options)
    ]
    text += "".join(option_lines)
    text += "\nAnswer:"
    if include_answer:
        text += " {}\n\n".format(df.iloc[idx, num_options + 1])
    return text


def gen_prompt(train_df, subject, k=-1):
    """Build the few-shot header: an instruction line plus ``k`` solved examples.

    Args:
        train_df: DataFrame holding the solved example questions.
        subject: Subject name; underscores are rendered as spaces.
        k: Number of leading rows of ``train_df`` to include; ``-1`` means all.
    """
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        format_subject(subject)
    )
    num_examples = train_df.shape[0] if k == -1 else k
    solved_examples = [format_example(train_df, row) for row in range(num_examples)]
    return header + "".join(solved_examples)


@torch.no_grad()
def evaluate_subject(subject, model, tokenizer, ntrain, dev_df, test_df, max_length=2048):
    """Evaluate ``model`` on one subject's multiple-choice test questions.

    For every test row a few-shot prompt is built, the model's logits at the
    last position are read for the tokens of "A".."D", and the most probable
    letter is compared with the label in the last column of ``test_df``.

    Args:
        subject: Subject name, used in the prompt header and progress bar.
        model: Causal LM returning ``.logits`` and exposing ``.device``.
        tokenizer: Tokenizer matching ``model``.
        ntrain: Number of few-shot examples to start with; ``-1`` means all
            rows of ``dev_df``.
        dev_df: DataFrame of solved examples for the few-shot header.
        test_df: DataFrame of questions to evaluate.
        max_length: Prompts longer than this many tokens get fewer few-shot
            examples. Defaults to the previously hard-coded 2048.

    Returns:
        ``(cors, acc, all_probs)``: a boolean array of per-question
        correctness, its mean, and the per-question probabilities over A-D.
    """
    cors = []
    all_probs = []

    # Token ids of the answer letters do not depend on the question, so look
    # them up once instead of re-tokenizing four strings per test row.
    choice_token_ids = [tokenizer(letter).input_ids[-1] for letter in ("A", "B", "C", "D")]
    label_col = test_df.shape[1] - 1

    for i in tqdm(range(test_df.shape[0]), desc=subject):
        # get prompt and make sure it fits
        k = ntrain
        if k == -1:
            # gen_prompt treats -1 as "all rows"; make that explicit so the
            # shrink loop below can count down from the real number.
            k = dev_df.shape[0]
        prompt_end = format_example(test_df, i, include_answer=False)
        prompt = gen_prompt(dev_df, subject, k) + prompt_end

        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

        # Drop few-shot examples until the prompt fits. Stop at zero examples:
        # going below zero used to wrap to "all examples" (k == -1) and then
        # loop forever on a prompt that could never get shorter.
        while input_ids.shape[-1] > max_length and k > 0:
            k -= 1
            prompt = gen_prompt(dev_df, subject, k) + prompt_end
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(
                model.device
            )

        label = test_df.iloc[i, label_col]

        logits = model(input_ids=input_ids).logits[0, -1]

        # Softmax over the four answer-letter logits only, in float32 on CPU.
        choice_logits = logits[choice_token_ids].float().cpu()
        probs = torch.nn.functional.softmax(choice_logits, dim=0).numpy()
        pred = choices[int(np.argmax(probs))]

        cors.append(pred == label)
        all_probs.append(probs)

    acc = np.mean(cors)
    cors = np.array(cors)

    all_probs = np.array(all_probs)
    print("Average accuracy {:.3f} - {}".format(acc, subject))

    return cors, acc, all_probs

def eval_mmlu(model, tokenizer, ntrain, data_dir):
    """Run the MMLU evaluation over every subject found in ``data_dir``.

    Subjects are discovered from ``<data_dir>/test/*_test.csv``; the matching
    few-shot examples are read from ``<data_dir>/dev/<subject>_dev.csv``.

    Args:
        model: Causal LM passed through to ``evaluate_subject``.
        tokenizer: Tokenizer matching ``model``.
        ntrain: Number of few-shot examples taken from each dev file.
        data_dir: Directory containing the ``dev`` and ``test`` sub-folders.

    Returns:
        Dict with ``"subcategories"`` and ``"categories"`` (name -> accuracy,
        only for groups that had at least one evaluated subject),
        ``"weighted_accuracy"`` (accuracy over all questions) and
        ``"cost_time"`` (seconds spent).

    Raises:
        ValueError: If no ``*_test.csv`` file is found.
        KeyError: If a discovered subject is missing from ``subcategories``.
    """
    test_dir = os.path.join(data_dir, "test")
    subjects = sorted(
        [
            f.split("_test.csv")[0]
            for f in os.listdir(test_dir)
            if "_test.csv" in f
        ]
    )
    if not subjects:
        raise ValueError(f"No '*_test.csv' files found in {test_dir}")

    all_cors = []
    subcat_cors = {
        subcat: [] for subcat_lists in subcategories.values() for subcat in subcat_lists
    }
    cat_cors = {cat: [] for cat in categories}

    start_time = time.time()
    for subject in subjects:
        dev_df = pd.read_csv(
            os.path.join(data_dir, "dev", subject + "_dev.csv"), header=None
        )[: ntrain]
        test_df = pd.read_csv(
            os.path.join(data_dir, "test", subject + "_test.csv"), header=None
        )

        cors, _, _ = evaluate_subject(subject, model, tokenizer, ntrain, dev_df, test_df)
        for subcat in subcategories[subject]:
            subcat_cors[subcat].append(cors)
            for cat, members in categories.items():
                if subcat in members:
                    cat_cors[cat].append(cors)
        all_cors.append(cors)

    results = {"subcategories": {}, "categories": {}}
    for group_key, grouped_cors in (("subcategories", subcat_cors), ("categories", cat_cors)):
        for group_name, cors_list in grouped_cors.items():
            if not cors_list:
                # No subject of this group was present in data_dir;
                # np.concatenate([]) would raise, so the group is left out.
                continue
            group_acc = np.mean(np.concatenate(cors_list))
            results[group_key][group_name] = group_acc
            print("Average accuracy {:.3f} - {}".format(group_acc, group_name))

    weighted_acc = np.mean(np.concatenate(all_cors))
    results["weighted_accuracy"] = weighted_acc
    print("Average accuracy: {:.3f}".format(weighted_acc))

    end_time = time.time()
    results["cost_time"] = end_time - start_time

    return results

def eval_ppl(model, tokenizer, max_length=2048, stride=None):
    """Compute perplexity of ``model`` on the wikitext-2 (raw) test split.

    The test text is joined with blank lines, tokenized once, and scored in
    windows of ``max_length`` tokens that advance by ``stride`` tokens. The
    result is ``exp`` of the plain mean of the per-window losses (windows are
    not weighted by their number of tokens).

    Args:
        model: Causal LM accepting ``labels`` and returning ``.loss``.
        tokenizer: Tokenizer matching ``model``.
        max_length: Window length in tokens. Defaults to the previously
            hard-coded 2048.
        stride: Step between window starts; ``None`` means ``max_length``
            (non-overlapping windows, the previous behaviour).

    Returns:
        A scalar tensor holding the perplexity.
    """
    model.eval()
    if stride is None:
        stride = max_length
    test = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"]
    nlls = []
    encodings = tokenizer("\n\n".join(test), return_tensors="pt")
    seq_len = encodings.input_ids.size(1)
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        if end_loc - begin_loc < 2:
            # A one-token window has no next-token target once the model
            # shifts the labels; its loss is NaN and would poison the mean.
            break
        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(model.device)
        target_ids = input_ids.clone()
        # Mask tokens already scored by a previous (overlapping) window.
        target_ids[:, :-trg_len] = -100
        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            # loss is calculated using CrossEntropyLoss which averages over valid labels
            # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
            # to the left by 1.
            neg_log_likelihood = outputs.loss
        nlls.append(neg_log_likelihood)
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break
    ppl = torch.exp(torch.stack(nlls).mean())
    return ppl
