import time

import torch
import torch.nn as nn

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

from model_utils import *
from mistral_utils import get_mistral, mistral_fuse_rms_single_layer, mistral_fuse_rotation_single_layer, replace_mistral_layer

from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR

from transformers import AutoTokenizer
import copy

# wandb is optional: it is only needed when ``--log_wandb`` is passed.
# Any failure while importing it simply disables logging, but interpreter-level
# signals (KeyboardInterrupt / SystemExit) are no longer swallowed.
try:
    import wandb
    has_wandb = True
except Exception:  # best-effort optional dependency
    has_wandb = False

import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import gc
import random

# Projections converted to sparse linears, in the order they are processed and serialized.
_SPARSE_PROJ_NAMES = (
    "self_attn.q_proj",
    "self_attn.k_proj",
    "self_attn.v_proj",
    "self_attn.o_proj",
    "mlp.up_proj",
    "mlp.gate_proj",
    "mlp.down_proj",
)

# Projection whose recorded input statistics drive each projection's spectrum threshold:
# k/v reuse the q_proj statistics and gate reuses the up_proj statistics (presumably
# because they consume the same input tensor).
_STATS_SOURCE = {
    "self_attn.q_proj": "self_attn.q_proj",
    "self_attn.k_proj": "self_attn.q_proj",
    "self_attn.v_proj": "self_attn.q_proj",
    "self_attn.o_proj": "self_attn.o_proj",
    "mlp.up_proj": "mlp.up_proj",
    "mlp.gate_proj": "mlp.up_proj",
    "mlp.down_proj": "mlp.down_proj",
}


def _run_layer_batched(layer, inps, dest, num_samples, batch_size, attention_mask,
                       position_embeddings, progress=False):
    """Forward ``inps[:num_samples]`` through ``layer`` in chunks and store the results in ``dest``."""
    starts = range(0, num_samples, batch_size)
    if progress:
        starts = tqdm(starts)
    for start in starts:
        stop = start + batch_size
        # The decoder layer returns a tuple; element 0 is the hidden-state output.
        dest[start:stop] = layer(
            inps[start:stop],
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )[0]


@torch.no_grad()
def _spectrum_threshold(sorted_values, sparsity_ratio):
    """Return element ``max(1, int(numel * sparsity_ratio)) - 1`` of ``sorted_values`` as a float.

    Assumes ``sorted_values`` is the 1-D tensor produced by ``normalize_and_sort``
    (sort order: presumably ascending — TODO confirm).
    """
    k = max(1, int(sorted_values.numel() * sparsity_ratio))
    return sorted_values[k - 1].item()


@torch.no_grad()
def _reconstruction_error(target, output):
    """Mean over the remaining dims of the L2 norm (taken over dim 1) of ``target - output``."""
    return torch.norm(target - output, dim=1).mean().detach().clone()


def _convert_to_sparse_layer(layer):
    """Replace every projection in ``_SPARSE_PROJ_NAMES`` with ``convert_to_sparse_linear(proj)``."""
    for name in _SPARSE_PROJ_NAMES:
        parent_name, child_name = name.split(".")
        parent = layer.get_submodule(parent_name)
        setattr(parent, child_name, convert_to_sparse_linear(getattr(parent, child_name)))


def _set_layer_spectrum_threshold(layer, sparsity_dict=None):
    """Copy each projection's ``spectrum_threshold`` from ``sparsity_dict`` onto its sparsifier."""
    if sparsity_dict is None:
        return
    for name in _SPARSE_PROJ_NAMES:
        layer.get_submodule(name).sparsifier.spectrum_threshold = sparsity_dict[name]["spectrum_threshold"]


def _optimize_layer_rotation(model, layer, inps, outs, attention_mask, position_embeddings,
                             batch_size, dev, dtype, R1_permute_indices, R2_permute_indices):
    """Learn this layer's R1 / R2 rotations with ``RotatorOptimizer`` from calibration activations.

    Side effect: ``outs[:args.nsamples]`` is overwritten with the layer's outputs for ``inps``
    (computed before any rotation is fused).

    Returns:
        ``(R1, R2_list)``. Each is ``None`` unless the matching ``args.rotate_R1`` /
        ``args.rotate_R2`` flag is ``"True"``. Otherwise it is cast to ``dtype`` and has its
        columns permuted by the given permutation indices.
    """
    for param in layer.parameters():
        param.requires_grad_(False)
    subset = find_layers(layer)
    gpts = {name: LayerWrapper(module) for name, module in subset.items()}

    def add_batch(name):
        def hook(_, inp, out):
            gpts[name].update_hessian_mean(inp[0].data, out.data)
        return hook

    handles = [
        subset[name].register_forward_hook(add_batch(name))
        for name in ("self_attn.q_proj", "self_attn.o_proj", "mlp.up_proj")
    ]
    _run_layer_batched(layer, inps, outs, args.nsamples, batch_size, attention_mask,
                       position_embeddings, progress=True)
    # Statistics are complete; detach the hooks before anything else can run the layer.
    for handle in handles:
        handle.remove()

    hidden_size = model.config.hidden_size
    hessian = torch.zeros((hidden_size, hidden_size), dtype=torch.float32, device=dev)
    for name in ("self_attn.q_proj", "mlp.up_proj"):
        hessian += gpts[name].H
    # Eigenvectors of the summed statistics initialise the R1 parameter below.
    _, eigenvectors = torch.linalg.eigh(hessian)
    del hessian

    weight_dict_list = [find_layers(layer)]
    hessian_dict_list = [{name: gpts[name].H for name in subset}]
    del gpts
    torch.cuda.empty_cache()

    R_model = RotatorOptimizer(weight_dict_list, hidden_size, num_key_value_heads=model.config.num_key_value_heads, head_dim=model.config.head_dim, device=dev, positive=True, hessian_dict_list=hessian_dict_list, dtype=torch.float32, with_weight=False)
    R_model = R_model.to(dev)
    R_model.A_list[0].data = eigenvectors

    # Snapshots are detached copies so later optimizer steps cannot mutate the stored "best".
    best_R1 = R_model.get_orthogonal_matrix().detach().clone()
    best_R2_list = [R2.detach().clone() for R2 in R_model.get_orthogonal_matrix_R2_list_list()[0]]

    scaler = GradScaler()
    opt_R1 = torch.optim.Adam(R_model.parameters_R1(), lr=0.000001, maximize=False)
    opt_R2 = torch.optim.Adam(R_model.parameters_R2(), lr=0.000001, maximize=False)

    def warmup_lr(epoch):
        return epoch / 5 if epoch < 5 else 1

    warmup_scheduler_R1 = LambdaLR(opt_R1, lr_lambda=warmup_lr)
    warmup_scheduler_R2 = LambdaLR(opt_R2, lr_lambda=warmup_lr)
    cosine_scheduler_R1 = CosineAnnealingLR(opt_R1, T_max=10)
    cosine_scheduler_R2 = CosineAnnealingLR(opt_R2, T_max=10)

    best_h1_h3_loss = float('inf')
    best_h2_loss = float('inf')

    for epoch in tqdm(range(args.rotation_epoch)):
        with torch.enable_grad():
            # R1 update on the h1/h3 objective.
            opt_R1.zero_grad()
            h1_h3_loss = R_model.forward_h1_h3()
            # Snapshot BEFORE stepping: the stored rotation must be the one that produced this loss.
            if h1_h3_loss.item() < best_h1_h3_loss:
                best_h1_h3_loss = h1_h3_loss.item()
                best_R1 = R_model.get_orthogonal_matrix().detach().clone()
            scaler.scale(h1_h3_loss).backward()
            scaler.unscale_(opt_R1)
            torch.nn.utils.clip_grad_norm_(R_model.parameters_R1(), 1.0)
            scaler.step(opt_R1)
            scaler.update()

            # R2 update on the h2 objective.
            opt_R2.zero_grad()
            h2_loss = R_model.forward_h2()
            if h2_loss.item() < best_h2_loss:
                best_h2_loss = h2_loss.item()
                best_R2_list = [R2.detach().clone() for R2 in R_model.get_orthogonal_matrix_R2_list_list()[0]]
            scaler.scale(h2_loss).backward()
            scaler.unscale_(opt_R2)
            torch.nn.utils.clip_grad_norm_(R_model.parameters_R2(), 1.0)
            scaler.step(opt_R2)
            scaler.update()

            if epoch == 0:
                tqdm.write(f"Initial loss R1: {h1_h3_loss.item()} R2: {h2_loss.item()}")

        tqdm.write(f"R1: {h1_h3_loss.item()} R2: {h2_loss.item()}")
        # Warm-up for the first 5 epochs, constant for epochs 5..10, cosine afterwards.
        if epoch < 5:
            warmup_scheduler_R1.step()
            warmup_scheduler_R2.step()
        elif epoch > 10:
            cosine_scheduler_R1.step()
            cosine_scheduler_R2.step()
        tqdm.write(f"Epoch {epoch+1}, Learning Rate R1: {opt_R1.param_groups[0]['lr']}, Learning Rate R2: {opt_R2.param_groups[0]['lr']}")

    del weight_dict_list
    torch.cuda.empty_cache()

    R1 = None
    if args.rotate_R1 == "True":
        R1 = best_R1.to(dtype)[:, R1_permute_indices]
    R2_list = None
    if args.rotate_R2 == "True":
        R2_list = [R2.to(dtype)[:, R2_permute_indices] for R2 in best_R2_list]
    return R1, R2_list


def _collect_normalized_sorted_inputs(layer, subset, inps, outs, attention_mask,
                                      position_embeddings, batch_size, dev, num_samples=16):
    """Record projection inputs on the first ``num_samples`` samples and return their normalized, sorted statistics.

    Side effect: ``outs[:num_samples]`` is overwritten with the layer's outputs.
    """
    name_list = ["self_attn.q_proj", "self_attn.o_proj", "mlp.up_proj", "mlp.down_proj"]
    gpts = {name: LayerWrapper(subset[name]) for name in name_list}

    def add_batch(name):
        def hook(_, inp, out):
            gpts[name].set_input(inp[0].data, out.data)
        return hook

    handles = [subset[name].register_forward_hook(add_batch(name)) for name in name_list]
    # Keep the activation statistics on the same device as the rest of the pipeline.
    input_dataset = ActivationDataset(name_list, dev)
    for start in tqdm(range(0, num_samples, batch_size)):
        stop = start + batch_size
        outs[start:stop] = layer(inps[start:stop], attention_mask=attention_mask, position_embeddings=position_embeddings)[0]
        input_dict = {}
        for name in name_list:
            input_dict[name] = gpts[name].input.squeeze(0)
            gpts[name].reset_input()
        input_dataset.update(input_dict)
        del input_dict
        torch.cuda.empty_cache()
    for handle in handles:
        handle.remove()

    input_dataset.concat_lists_to_tensors()
    input_dataset.normalize_and_sort()
    result = {
        key: value.detach().clone() if isinstance(value, torch.Tensor) else value
        for key, value in input_dataset.input_normalized_sorted_dict.items()
    }
    del gpts, input_dataset
    torch.cuda.empty_cache()
    return result


def _greedy_sparsity_search(layer, layer_idx, full, names, sorted_inputs, sparsity_dict,
                            run_layer, reference, hidden_size):
    """Greedily raise per-projection sparsity and save the best config per level to YAML.

    For each of the 60 levels, ``args.greedy_search_step_num`` steps are taken. Each step
    tries increasing one projection's sparsity (capped below 0.8), measures the
    reconstruction error of ``run_layer(buffer)`` against ``reference``, and keeps the
    lowest-error trial. The search stops early if every projection is already at the cap.

    Args:
        run_layer: callable filling its tensor argument with the layer's outputs for all
            calibration samples.
        reference: dense-threshold outputs of the layer, used as the reconstruction target.

    Raises:
        ValueError: if ``args.greedy_search_step_num < 1``.
    """
    import yaml

    step_num = args.greedy_search_step_num
    if step_num < 1:
        raise ValueError("--greedy_search_step_num must be >= 1")

    weight_ratio_dict = {name: full[name].weight.numel() / (hidden_size ** 2) for name in names}
    weight_ratio_sum = 0
    for name in names:
        sparsity_dict[name]["sparsity"] = 0
        sparsity_dict[name]["spectrum_threshold"] = _spectrum_threshold(
            sorted_inputs[_STATS_SOURCE[name]], sparsity_dict[name]["sparsity"]
        )
        weight_ratio_sum += weight_ratio_dict[name]

    layer_ratio_step = 1
    # Per-projection sparsity increment so that one level raises layer sparsity by ~1%.
    increase = {
        name: float(layer_ratio_step) / 100 * weight_ratio_sum / weight_ratio_dict[name] / step_num
        for name in names
    }

    trial_output = torch.zeros_like(reference)  # reused: every trial overwrites all rows
    configs = {}
    saturated = False
    for level in tqdm(range(0, 60, layer_ratio_step)):
        best_config = None
        for _ in range(step_num):
            best_config = None
            best_error = None
            for name in names:
                delta = increase[name]
                entry = sparsity_dict[name]
                entry["sparsity"] += delta
                if entry["sparsity"] >= 0.8:
                    entry["sparsity"] -= delta
                    continue
                sorted_input = sorted_inputs[_STATS_SOURCE[name]]
                entry["spectrum_threshold"] = _spectrum_threshold(sorted_input, entry["sparsity"])
                _set_layer_spectrum_threshold(layer, sparsity_dict)
                run_layer(trial_output)
                error = _reconstruction_error(reference, trial_output)
                if best_config is None or error <= best_error:
                    best_config = copy.deepcopy(sparsity_dict)
                    best_error = error
                # Undo the trial increment before trying the next projection.
                entry["sparsity"] -= delta
                entry["spectrum_threshold"] = _spectrum_threshold(sorted_input, entry["sparsity"])
            if best_config is None:
                # Every projection is at the cap: no trial was possible, nothing more to search.
                saturated = True
                break
            sparsity_dict = copy.deepcopy(best_config)
        if saturated:
            tqdm.write(f"Layer {layer_idx}: all projections reached the sparsity cap; stopping at level {level}.")
            break
        configs[level + 1] = copy.deepcopy(best_config)

    config_path = os.path.join(args.greedy_search_dir, f"layer_{layer_idx}_mixed_sparsity_config_dict.yaml")
    with open(config_path, 'w', encoding='utf-8') as f:
        yaml.dump(configs, f, sort_keys=False, allow_unicode=True, default_flow_style=False, indent=2)


@torch.no_grad()
def mistral_sequential(model, dataloader, dev):
    """Run the per-layer rotation + activation-sparsity calibration pipeline over ``model``.

    Calibration inputs to the first decoder layer are captured from ``dataloader``. Each
    decoder layer is then processed in turn on ``dev``:

    1. RMSNorm is fused (``mistral_fuse_rms_single_layer``).
    2. If ``args.rotate == "True"``, R1/R2 are obtained and fused into the layer. They are
       loaded from ``args.rotation_dir``, or learned with ``RotatorOptimizer`` when that
       option is unset (``""`` or ``"None"``). If R1 is used, ``replace_mistral_layer``
       is also applied.
    3. Normalized/sorted projection-input statistics are loaded from
       ``args.input_normalized_sorted_dir``, or recomputed from the first 16 samples.
    4. Projections become sparse linears with sparsity-0 thresholds. The layer's outputs
       for ALL calibration samples become the next layer's inputs.
    5. If ``args.greedy_search_dir`` is set, a greedy mixed-sparsity search writes its
       configs to YAML.

    Each processed layer is replaced by ``None`` in ``model.model.layers``, so the model is
    not usable for inference afterwards. All options come from the module-level ``args``
    namespace set in ``__main__``.

    Args:
        model: Mistral model exposing ``seqlen``, ``config`` and
            ``model.{embed_tokens, rotary_emb, norm, layers}``.
        dataloader: iterable of batches whose first element is a token-id tensor.
        dev: device to run on.

    Returns:
        ``quantizers``: always an empty dict (kept for interface compatibility).

    Raises:
        RuntimeError: if ``dataloader`` yields no batches.
    """
    print("Starting...")

    R1_permute_indices = create_max_variance_permutation_index(model.config.hidden_size, 32)
    R2_permute_indices = create_max_variance_permutation_index(model.config.head_dim, 32)

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    model.model.rotary_emb = model.model.rotary_emb.to(dev)
    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    model.model.norm = model.model.norm.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {"i": 0, "attention_mask": None}

    class Catcher(nn.Module):
        """Records the first layer's inputs/kwargs, then aborts the forward pass."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            cache["attention_mask"] = kwargs["attention_mask"]
            cache["position_embeddings"] = kwargs["position_embeddings"]
            raise ValueError

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].to(dev))
        except ValueError:
            pass
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    model.model.norm = model.model.norm.cpu()
    torch.cuda.empty_cache()

    if cache["i"] == 0:
        raise RuntimeError("dataloader yielded no calibration batches")

    outs = torch.zeros_like(inps)
    batch_size = 16
    attention_mask = cache["attention_mask"]
    position_embeddings = cache["position_embeddings"]

    print("Ready.")

    quantizers = {}
    for i in range(len(layers)):
        print(i)

        layer = layers[i].to(dev)
        mistral_fuse_rms_single_layer(layer)

        full = find_layers(layer)
        names = list(full.keys())

        if args.rotate == "True":
            # ``--rotation_dir`` defaults to ""; treat that (and the "None" sentinel used by the
            # other path options) as "not provided" and learn the rotations, instead of trying
            # to load them from the current working directory.
            if args.rotation_dir not in (None, "", "None"):
                R1_path = os.path.join(args.rotation_dir, f"layer_{i}_R1.pt")
                R2_list_path = os.path.join(args.rotation_dir, f"layer_{i}_R2_list.pt")
                R1 = torch.load(R1_path) if args.rotate_R1 == "True" else None
                R2_list = torch.load(R2_list_path) if args.rotate_R2 == "True" else None
            else:
                R1, R2_list = _optimize_layer_rotation(
                    model, layer, inps, outs, attention_mask, position_embeddings,
                    batch_size, dev, dtype, R1_permute_indices, R2_permute_indices,
                )

            mistral_fuse_rotation_single_layer(layer, R1, R2_list)

            if args.rotate_R1 == "True":
                R1 = R1.to(dev)
                layers[i] = layer
                replace_mistral_layer(model, i, R1)
                layer = layers[i].to(dev)

        if args.input_normalized_sorted_dir != "None":
            input_normalized_sorted_dict_path = os.path.join(args.input_normalized_sorted_dir, f"layer_{i}_input_normalized_sorted.pt")
            input_normalized_sorted_dict = torch.load(input_normalized_sorted_dict_path)
        else:
            full = find_layers(layer)
            subset = {n: full[n] for n in names}
            input_normalized_sorted_dict = _collect_normalized_sorted_inputs(
                layer, subset, inps, outs, attention_mask, position_embeddings, batch_size, dev,
            )

        sparsity_dict = {name: {"sparsity": 0, "spectrum_threshold": 0} for name in _SPARSE_PROJ_NAMES}
        for name in _SPARSE_PROJ_NAMES:
            sparsity_dict[name]["spectrum_threshold"] = _spectrum_threshold(
                input_normalized_sorted_dict[_STATS_SOURCE[name]], sparsity_dict[name]["sparsity"]
            )
        gc.collect()
        torch.cuda.empty_cache()

        _convert_to_sparse_layer(layer)
        _set_layer_spectrum_threshold(layer, sparsity_dict)

        # Always refresh ``outs`` for every calibration sample with the finished layer.
        # Without this, the next layer would see stale/zero activations whenever greedy
        # search is disabled; with it, ``outs`` is also the greedy-search reference.
        _run_layer_batched(layer, inps, outs, args.nsamples, batch_size, attention_mask,
                           position_embeddings, progress=True)

        if args.greedy_search_dir != "None":
            _greedy_sparsity_search(
                layer, i, full, names, input_normalized_sorted_dict, sparsity_dict,
                lambda dest: _run_layer_batched(layer, inps, dest, args.nsamples, batch_size,
                                                attention_mask, position_embeddings),
                outs, model.config.hidden_size,
            )

        del sparsity_dict, input_normalized_sorted_dict
        layers[i] = None
        del layer
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    model.config.use_cache = use_cache

    return quantizers


@torch.no_grad()
def mistral_eval_ppl(model, testenc, dev,  dataset: str, log_wandb: bool = False):
    """Compute, print, and optionally log the perplexity of ``model`` on a tokenized corpus.

    The token stream is cut into ``model.seqlen``-long windows; any trailing remainder that
    does not fill a window is dropped. Decoder layers are moved to ``dev`` one at a time and
    returned to CPU afterwards. The final norm and ``lm_head`` are moved to ``dev`` for the
    loss computation.

    Args:
        model: Mistral-style causal LM exposing ``seqlen``, ``config``, ``lm_head`` and
            ``model.{embed_tokens, rotary_emb, layers, norm}``.
        testenc: object with an ``input_ids`` tensor (assumed shape (1, num_tokens) — TODO
            confirm; windows are sliced along dim 1).
        dev: device used for computation.
        dataset: name used as the wandb key prefix.
        log_wandb: when True, also log ``{dataset}/perplexity`` to wandb.

    Each window's mean cross-entropy (over ``seqlen - 1`` predicted tokens) is multiplied by
    ``seqlen``. The summed result is divided by ``num_windows * seqlen`` before ``exp``.
    """
    print("Evaluating ...")

    token_ids = testenc.input_ids
    seqlen = model.seqlen
    n_windows = token_ids.numel() // seqlen

    saved_use_cache = model.config.use_cache
    model.config.use_cache = False
    decoder = model.model
    blocks = decoder.layers

    decoder.embed_tokens = decoder.embed_tokens.to(dev)
    decoder.rotary_emb = decoder.rotary_emb.to(dev)
    blocks[0] = blocks[0].to(dev)

    param_dtype = next(iter(model.parameters())).dtype
    hidden = torch.zeros(
        (n_windows, seqlen, model.config.hidden_size), dtype=param_dtype, device=dev
    )
    captured = {"i": 0, "attention_mask": None}

    class _InputCatcher(nn.Module):
        # Stores the first block's input and kwargs, then aborts the forward pass.
        def __init__(self, wrapped):
            super().__init__()
            self.module = wrapped

        def forward(self, inp, **kwargs):
            hidden[captured["i"]] = inp
            captured["i"] += 1
            captured["attention_mask"] = kwargs["attention_mask"]
            captured["position_embeddings"] = kwargs["position_embeddings"]
            raise ValueError

    blocks[0] = _InputCatcher(blocks[0])
    for w in range(n_windows):
        window = token_ids[:, (w * seqlen):((w + 1) * seqlen)].to(dev)
        try:
            model(window)
        except ValueError:
            pass  # expected: raised by _InputCatcher once the inputs are recorded
    blocks[0] = blocks[0].module.cpu()

    decoder.embed_tokens = decoder.embed_tokens.cpu()
    torch.cuda.empty_cache()

    out_buf = torch.zeros_like(hidden)
    attention_mask = captured["attention_mask"]
    position_embeddings = captured["position_embeddings"]

    # Push all windows through one block at a time, ping-ponging between the two buffers.
    for idx in range(len(blocks)):
        print(idx)
        block = blocks[idx].to(dev)
        for w in range(n_windows):
            out_buf[w] = block(
                hidden[w].unsqueeze(0),
                attention_mask=attention_mask,
                position_embeddings=position_embeddings,
            )[0]
        blocks[idx] = block.cpu()
        del block
        torch.cuda.empty_cache()
        hidden, out_buf = out_buf, hidden

    has_norm = decoder.norm is not None
    if has_norm:
        decoder.norm = decoder.norm.to(dev)
    model.lm_head = model.lm_head.to(dev)

    token_ids = token_ids.to(dev)
    criterion = nn.CrossEntropyLoss()
    window_nlls = []
    for w in range(n_windows):
        states = hidden[w].unsqueeze(0)
        if has_norm:
            states = decoder.norm(states)
        logits = model.lm_head(states)
        # Position t predicts token t+1 within the window.
        predictions = logits[:, :-1, :].contiguous()
        targets = token_ids[:, (w * seqlen):((w + 1) * seqlen)][:, 1:]
        window_loss = criterion(
            predictions.view(-1, predictions.size(-1)), targets.view(-1)
        )
        window_nlls.append(window_loss.float() * seqlen)
    ppl = torch.exp(torch.stack(window_nlls).sum() / (n_windows * seqlen))
    print(f"Perplexity: {ppl.item():3f}")
    if log_wandb:
        wandb.log({f"{dataset}/perplexity": ppl.item()})

    model.config.use_cache = saved_use_cache

if __name__ == "__main__":
    # Script entry point. NOTE: ``args`` must stay a module-level global, because
    # ``mistral_sequential`` reads its options from it directly.
    import argparse
    from data_utils import *

    parser = argparse.ArgumentParser()

    # Positional arguments.
    parser.add_argument("model", type=str, help="mistral model to load")
    parser.add_argument(
        "dataset",
        type=str,
        choices=["wikitext2", "ptb", "c4", "alpaca"],
        help="Where to extract calibration data from.",
    )

    # Calibration / logging options.
    parser.add_argument("--seed", type=int, default=0, help="Seed for sampling the calibration data.")
    parser.add_argument("--nsamples", type=int, default=128, help="Number of calibration data samples.")
    parser.add_argument("--log_wandb", action="store_true", help="Whether to log to wandb.")

    # Rotation options; booleans are passed as the strings "True"/"False".
    for flag in ("--rotate", "--rotate_R1", "--rotate_R2"):
        parser.add_argument(flag, type=str, default="True")
    parser.add_argument("--rotation_dir", type=str, default="", help="Path to saved rotation matrix.")
    parser.add_argument("--rotation_epoch", type=int, default=100)
    parser.add_argument("--distribute_model", type=str, default="False")

    # Sparsity-statistics / greedy-search options; the string "None" disables them.
    for flag in ("--input_normalized_sorted_dir", "--greedy_search_dir"):
        parser.add_argument(flag, type=str, default="None")
    parser.add_argument("--greedy_search_step_num", type=int, default=1)

    args = parser.parse_args()

    # Optional W&B logging.
    if args.log_wandb:
        assert has_wandb, "wandb not installed try `pip install wandb`"
        wandb.init(config=args)

    model = get_mistral(args.model)
    model.eval()
    model.requires_grad_(False)  # freeze every parameter

    dataloader, testloader = get_loaders(
        args.dataset,
        nsamples=args.nsamples,
        seed=args.seed,
        model=args.model,
        seqlen=model.seqlen,
    )

    DEV = torch.device("cuda:0")
    start_time = time.time()
    mistral_sequential(model, dataloader, DEV)
    torch.cuda.empty_cache()

    # Report the fraction of exactly-zero entries per parameter, stopping at the first down_proj.
    for param_name, param in model.named_parameters():
        print(param_name, torch.mean((param == 0).float()))
        if 'down_proj' in param_name:
            break
    print(time.time() - start_time)