import time

import torch
import torch.nn as nn

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

from model_utils import *
from mistral_utils import get_mistral, mistral_fuse_rms_single_layer, mistral_fuse_rotation_single_layer, replace_mistral_layer

from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR

from gpu_utils import distribute_model
from lm_eval_utils import *
from transformers import AutoTokenizer
import copy

# wandb is an optional dependency: record whether it is usable instead of
# failing at import time.
try:
    import wandb
except Exception:
    # Deliberately broader than ImportError so that a present-but-broken
    # installation is also treated as "not available", while
    # KeyboardInterrupt / SystemExit are no longer swallowed.
    has_wandb = False
else:
    has_wandb = True

import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import gc
import yaml
# (sub-module, projection) attribute pairs of a decoder layer that are swapped
# for sparse linears, listed in the order they are processed.
_MISTRAL_SPARSE_PROJECTIONS = (
    ("self_attn", "q_proj"),
    ("self_attn", "k_proj"),
    ("self_attn", "v_proj"),
    ("self_attn", "o_proj"),
    ("mlp", "up_proj"),
    ("mlp", "gate_proj"),
    ("mlp", "down_proj"),
)


def _convert_to_sparse_layer(layer):
    """Replace each listed projection of ``layer`` with ``convert_to_sparse_linear(proj)``."""
    for parent_name, proj_name in _MISTRAL_SPARSE_PROJECTIONS:
        parent = getattr(layer, parent_name)
        setattr(parent, proj_name, convert_to_sparse_linear(getattr(parent, proj_name)))


def _set_layer_spectrum_threshold(layer, sparsity_dict=None):
    """Copy ``spectrum_threshold`` from ``sparsity_dict`` onto each projection's sparsifier.

    ``sparsity_dict`` is keyed by ``"<sub-module>.<projection>"`` (for example
    ``"self_attn.q_proj"``); nothing happens when it is ``None``.
    """
    if sparsity_dict is None:
        return
    for parent_name, proj_name in _MISTRAL_SPARSE_PROJECTIONS:
        projection = getattr(getattr(layer, parent_name), proj_name)
        threshold = sparsity_dict[f"{parent_name}.{proj_name}"]["spectrum_threshold"]
        projection.sparsifier.spectrum_threshold = threshold


@torch.no_grad()
def mistral_sequential(model, dev):
    """Fuse norms/rotations and install sparse projections, one decoder layer at a time.

    Each layer of ``model.model.layers`` is moved to ``dev``, processed, and
    moved back to the CPU before the next one is touched.

    The function reads the module-level ``args`` namespace (it is not passed
    in), using these fields:

    * ``rotate``, ``rotate_R1``, ``rotate_R2`` -- the string ``"True"`` enables
      the corresponding step.
    * ``rotation_dir`` -- directory holding ``layer_{i}_R1.pt`` and
      ``layer_{i}_R2_list.pt``.
    * ``sparsity`` -- ``0`` disables sparsification; otherwise it is the key
      looked up in each layer's YAML config.
    * ``greedy_search_dir`` -- directory holding
      ``layer_{i}_mixed_sparsity_config_dict.yaml``; the string ``"None"``
      disables sparsification.

    Args:
        model: Model exposing ``config.use_cache`` and ``model.layers``.
        dev: Device the layer currently being processed is moved to.

    Returns:
        dict: Always empty; nothing in this function populates it.
    """
    print("Starting...")

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    quantizers = {}
    try:
        for i in tqdm(range(len(layers))):
            layer = layers[i].to(dev)

            mistral_fuse_rms_single_layer(layer)

            if args.rotate == "True" and args.rotation_dir is not None:
                use_r1 = args.rotate_R1 == "True"
                use_r2 = args.rotate_R2 == "True"
                # NOTE: torch.load unpickles the file, so rotation_dir must
                # only point at trusted checkpoints.
                R1 = (
                    torch.load(os.path.join(args.rotation_dir, f"layer_{i}_R1.pt"))
                    if use_r1
                    else None
                )
                R2_list = (
                    torch.load(os.path.join(args.rotation_dir, f"layer_{i}_R2_list.pt"))
                    if use_r2
                    else None
                )
                mistral_fuse_rotation_single_layer(layer, R1, R2_list)
                if use_r1:
                    # Store the fused layer back first, then re-fetch slot i
                    # after replace_mistral_layer has been applied to it.
                    layers[i] = layer
                    replace_mistral_layer(model, i, R1)
                    layer = layers[i].to(dev)
                torch.cuda.empty_cache()

            if args.sparsity != 0 and args.greedy_search_dir != "None":
                config_path = os.path.join(
                    args.greedy_search_dir,
                    f"layer_{i}_mixed_sparsity_config_dict.yaml",
                )
                with open(config_path, 'r', encoding='utf-8') as f:
                    mixed_sparsity_config_dict = yaml.safe_load(f)
                sparsity_dict = mixed_sparsity_config_dict[args.sparsity]

                _convert_to_sparse_layer(layer)
                _set_layer_spectrum_threshold(layer, sparsity_dict)

            layers[i] = layer.cpu()
            del layer
            torch.cuda.empty_cache()
    finally:
        # Restore the caller's setting even when a layer fails to process.
        model.config.use_cache = use_cache

    return quantizers


@torch.no_grad()
def mistral_eval_ppl(model, testenc, dev, dataset: str, log_wandb: bool = False):
    """Evaluate perplexity while keeping only one decoder layer on ``dev`` at a time.

    ``testenc.input_ids`` is cut into ``numel // model.seqlen`` consecutive,
    non-overlapping windows of ``model.seqlen`` tokens (a trailing remainder is
    ignored). The hidden states entering the first decoder layer are captured
    for every window, pushed through the layers one by one, and finally through
    ``model.model.norm`` (when not ``None``) and ``model.lm_head``.

    Side effects visible in this function: ``model.model.rotary_emb``,
    ``model.model.norm`` and ``model.lm_head`` are moved to ``dev`` and left
    there; the embedding and every decoder layer end up on the CPU.

    Args:
        model: Causal LM exposing ``seqlen``, ``config``, ``model.layers``,
            ``model.embed_tokens``, ``model.rotary_emb``, ``model.norm`` and
            ``lm_head``.
        testenc: Object whose ``input_ids`` is a token tensor indexed as
            ``[:, start:stop]`` (assumed shape ``(1, total_tokens)`` -- confirm
            against the data loader).
        dev: Device used for computation.
        dataset: Name used as the prefix of the wandb metric key.
        log_wandb: When true, log the perplexity through ``wandb.log``.

    Returns:
        float: The perplexity that is also printed.

    Raises:
        ValueError: If ``testenc`` holds fewer than ``model.seqlen`` tokens.
    """
    print("Evaluating ...")

    testenc = testenc.input_ids
    seqlen = model.seqlen
    nsamples = testenc.numel() // seqlen
    if nsamples == 0:
        raise ValueError(
            f"test data has {testenc.numel()} tokens, fewer than one window of {seqlen}"
        )

    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        layers = model.model.layers

        model.model.embed_tokens = model.model.embed_tokens.to(dev)
        model.model.rotary_emb = model.model.rotary_emb.to(dev)
        layers[0] = layers[0].to(dev)

        dtype = next(iter(model.parameters())).dtype
        inps = torch.zeros(
            (nsamples, seqlen, model.config.hidden_size), dtype=dtype, device=dev
        )
        cache = {"i": 0, "attention_mask": None, "position_embeddings": None}

        class _InputCaptured(Exception):
            """Private signal: the first layer's input has been recorded."""

        class Catcher(nn.Module):
            """Stand-in for layer 0 that records its input and aborts the forward pass."""

            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, inp, **kwargs):
                inps[cache["i"]] = inp
                cache["i"] += 1
                cache["attention_mask"] = kwargs["attention_mask"]
                cache["position_embeddings"] = kwargs["position_embeddings"]
                # A dedicated exception type keeps genuine ValueErrors raised
                # by the model from being mistaken for this signal.
                raise _InputCaptured

        layers[0] = Catcher(layers[0])
        try:
            for i in range(nsamples):
                batch = testenc[:, (i * seqlen) : ((i + 1) * seqlen)].to(dev)
                try:
                    model(batch)
                except _InputCaptured:
                    pass
        finally:
            # Always unwrap, so a failure cannot leave the Catcher installed.
            layers[0] = layers[0].module

        layers[0] = layers[0].cpu()
        model.model.embed_tokens = model.model.embed_tokens.cpu()
        torch.cuda.empty_cache()

        outs = torch.zeros_like(inps)
        attention_mask = cache["attention_mask"]
        position_embeddings = cache["position_embeddings"]

        for i in range(len(layers)):
            print(i)

            layer = layers[i].to(dev)

            for j in range(nsamples):
                outs[j] = layer(
                    inps[j].unsqueeze(0),
                    attention_mask=attention_mask,
                    position_embeddings=position_embeddings,
                )[0]

            layers[i] = layer.cpu()
            del layer
            torch.cuda.empty_cache()
            # This layer's outputs become the next layer's inputs; the old
            # input buffer is reused for the next outputs.
            inps, outs = outs, inps

        if model.model.norm is not None:
            model.model.norm = model.model.norm.to(dev)
        model.lm_head = model.lm_head.to(dev)

        testenc = testenc.to(dev)
        loss_fct = nn.CrossEntropyLoss()
        nlls = []
        for i in range(nsamples):
            hidden_states = inps[i].unsqueeze(0)
            if model.model.norm is not None:
                hidden_states = model.model.norm(hidden_states)
            lm_logits = model.lm_head(hidden_states)
            # Position t predicts token t + 1, so drop the last logit and the
            # first label.
            shift_logits = lm_logits[:, :-1, :].contiguous()
            shift_labels = testenc[:, (i * seqlen) : ((i + 1) * seqlen)][:, 1:]
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            # The mean loss over the seqlen - 1 shifted positions is scaled by
            # seqlen, matching the normalisation below.
            nlls.append(loss.float() * seqlen)

        ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen))
        ppl_value = ppl.item()
        # NOTE(review): ":3f" is a minimum field width of 3 with the default
        # six decimals, not ".3f"; kept unchanged so the log line stays the same.
        print(f"Perplexity: {ppl_value:3f}")
        if log_wandb:
            wandb.log({f"{dataset}/perplexity": ppl_value})
    finally:
        # Restore the caller's setting on success and on failure alike.
        model.config.use_cache = use_cache

    return ppl_value

if __name__ == "__main__":
    # Script entry point: parse options, transform the model layer by layer,
    # report the fraction of zero weights, then run perplexity and
    # zero-/few-shot evaluation.

    import argparse
    # The star import is only legal at module scope, which is one reason this
    # block is not wrapped in a main() function.
    from data_utils import *

    parser = argparse.ArgumentParser()

    parser.add_argument("model", type=str, help="mistral model to load")
    parser.add_argument("--seed", type=int, default=0, help="Seed for sampling the calibration data.")
    parser.add_argument("--nsamples", type=int, default=128, help="Number of calibration data samples.")
    parser.add_argument("--log_wandb", action="store_true", help="Whether to log to wandb.")
    # The switches below are compared against the literal string "True".
    parser.add_argument("--rotate", type=str, default="True")
    parser.add_argument("--rotate_R1", type=str, default="True")
    parser.add_argument("--rotate_R2", type=str, default="True")
    parser.add_argument("--rotation_dir", type=str, default="", help="Path to saved rotation matrix.")
    parser.add_argument("--rotation_epoch", type=int, default=100)
    parser.add_argument("--distribute_model", type=str, default="False")
    # The literal string "None" means "no greedy-search directory".
    parser.add_argument("--greedy_search_dir", type=str, default="None")
    parser.add_argument("--sparsity", type=int, default=0)

    # `args` must remain a module-level name: mistral_sequential reads it as a
    # global rather than receiving it as a parameter.
    args = parser.parse_args()

    # init W&B logging
    if args.log_wandb:
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if not has_wandb:
            raise ImportError("wandb not installed try `pip install wandb`")
        wandb.init(config=args)

    model = get_mistral(args.model)
    model.eval()

    for param in model.parameters():
        param.requires_grad_(False)

    DEV = torch.device("cuda:0")
    tick = time.time()
    mistral_sequential(model, DEV)
    torch.cuda.empty_cache()
    # Print the fraction of exactly-zero entries per parameter, stopping after
    # the first parameter whose name contains 'down_proj'.
    for n, p in model.named_parameters():
        print(n, torch.mean((p == 0).float()))
        if 'down_proj' in n:
            break
    print(time.time() - tick)

    for dataset in ["wikitext2"]:
        # Only the second value returned by get_loaders is used here.
        _, testloader = get_loaders(
            dataset, seed=args.seed, model=args.model, seqlen=model.seqlen
        )
        print("Dataset:", dataset)
        mistral_eval_ppl(model, testloader, DEV, dataset, args.log_wandb)

    if args.distribute_model == "True":
        distribute_model(model)
    else:
        model.to(DEV)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    evaluate_zero_shot(model, tokenizer, batch_size=16, num_fewshot=0)
    evaluate_few_shot_mmlu(model, tokenizer, batch_size=16, num_fewshot=5)