import time

import torch
import torch.nn as nn

# Permit TensorFloat-32 kernels for CUDA matmuls and for cuDNN ops
# (see the PyTorch CUDA notes on TF32 for the precision trade-off).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

from model_utils import *
from llama_utils import get_llama, llama_fuse_rms_single_layer, llama_fuse_rotation_single_layer, replace_llama_layer

from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR

from transformers import AutoTokenizer

# wandb is an optional dependency: remember whether it is importable so the
# entry point can refuse --log_wandb cleanly when it is missing.
try:
    import wandb
    has_wandb = True
except ImportError:
    # Only a failed import means "not installed"; a bare ``except`` would also
    # swallow KeyboardInterrupt / SystemExit raised while importing.
    has_wandb = False

import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import gc

@torch.no_grad()
def llama_sequential(model, dataloader, dev):
    """Process the decoder layers one at a time on ``dev``.

    The function first captures the hidden states that reach decoder layer 0
    for every calibration batch, then walks ``model.model.layers`` in order:

    * layers with index ``< args.start_layer_id`` are only run forward so that
      their outputs become the inputs of the next layer;
    * layers in ``[args.start_layer_id, args.end_layer_id]`` get their RMS norm
      fused (``llama_fuse_rms_single_layer``), optionally have rotation
      matrices optimised and fused (when ``args.rotate == "True"``), and have
      the inputs of selected sub-modules collected into an
      ``ActivationDataset`` that is saved to
      ``args.input_normalized_sorted_dir``;
    * the loop stops at the first layer with index ``> args.end_layer_id``.

    Configuration is read from the module-level ``args`` namespace
    (``nsamples``, ``start_layer_id``, ``end_layer_id``, ``rotate``,
    ``rotate_R1``, ``rotate_R2``, ``rotate_lr``, ``rotation_epoch``,
    ``initialize_with_pca``, ``rotation_dir``, ``input_normalized_sorted_dir``),
    so ``args`` must be defined before this function is called. The boolean
    switches are strings compared against the literal ``"True"``.

    Side effects: every layer that was run is replaced by ``None`` in
    ``model.model.layers``; rotation matrices and activation dictionaries are
    written with ``torch.save``; ``model.config.use_cache`` is disabled for the
    duration of the call and restored before returning.

    Args:
        model: causal LM exposing ``config``, ``seqlen`` and
            ``model.{layers, embed_tokens, norm, rotary_emb}``.
        dataloader: iterable of calibration batches; ``batch[0]`` is fed to
            ``model``.
        dev: device on which the active layer and activations are kept.

    Returns:
        The ``quantizers`` dict. Nothing in this function populates it, so it
        is always empty.
    """
    print("Starting...")

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    model.model.rotary_emb = model.model.rotary_emb.to(dev)
    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    model.model.norm = model.model.norm.to(dev)
    layers[0] = layers[0].to(dev)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {"i": 0, "attention_mask": None}

    class Catcher(nn.Module):
        """Stand-in for layer 0 that records its inputs and aborts the forward pass."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            cache["attention_mask"] = kwargs["attention_mask"]
            cache["position_embeddings"] = kwargs["position_embeddings"]
            # Nothing past the first layer is needed; unwind immediately.
            raise ValueError

    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].to(dev))
        except ValueError:
            # Raised on purpose by Catcher once the inputs are recorded.
            pass
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    model.model.norm = model.model.norm.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache["attention_mask"]
    position_embeddings = cache["position_embeddings"]

    print("Ready.")
    batch_size = 16
    quantizers = {}
    for i in range(len(layers)):
        print(i)

        # Same precedence as the if/elif chain this replaces: the "skip" test
        # (i < start) wins over the "stop" test (i > end). Checking before the
        # device transfer avoids leaving an unused layer resident on ``dev``.
        if i >= args.start_layer_id and i > args.end_layer_id:
            break

        layer = layers[i].to(dev)

        if i < args.start_layer_id:
            # Layer precedes the requested range: just propagate activations.
            for j in tqdm(range(0, args.nsamples, batch_size)):
                batch_inp = inps[j:j+batch_size]
                outs[j:j+batch_size] = layer(batch_inp, attention_mask=attention_mask, position_embeddings=position_embeddings)[0]
            # Swap and release only after *all* batches went through the layer;
            # doing this per batch would mix buffers and drop ``layer`` early.
            inps, outs = outs, inps
            layers[i] = None
            del layer
            torch.cuda.empty_cache()
            continue

        llama_fuse_rms_single_layer(layer)
        if args.rotate == "True":
            for param in layer.parameters():
                param.requires_grad_(False)
            full = find_layers(layer)
            subset = {name: full[name] for name in full}
            gpts = {name: LayerWrapper(subset[name]) for name in subset}

            def add_batch(name):
                def tmp(_, inp, out):
                    gpts[name].update_hessian_mean(inp[0].data, out.data)
                return tmp

            # Only these sub-modules receive a hook in this phase.
            name_list = ["self_attn.q_proj", "self_attn.o_proj", "mlp.up_proj"]
            handles = [
                subset[name].register_forward_hook(add_batch(name))
                for name in name_list
            ]
            for j in tqdm(range(0, args.nsamples, batch_size)):
                batch_inp = inps[j:j+batch_size]
                outs[j:j+batch_size] = layer(batch_inp, attention_mask=attention_mask, position_embeddings=position_embeddings)[0]

            weight_dict_list = [find_layers(layer)]
            hessian_dict = {name: gpts[name].H for name in subset}
            hessian_dict_list = [hessian_dict]
            R_model = RotatorOptimizer(weight_dict_list, model.config.hidden_size, num_key_value_heads=model.config.num_key_value_heads, head_dim=model.config.head_dim, device=dev, positive=True, hessian_dict_list=hessian_dict_list, dtype=torch.float32, with_weight=False)
            R_model = R_model.to(dev)
            R_model.compute_cov_norm_squared()

            h1_cov_norm_squared = R_model.h1_cov_norm_squared.item()
            h2_cov_norm_squared = R_model.h2_cov_norm_squared.item()
            h3_cov_norm_squared = R_model.h3_cov_norm_squared.item()

            if args.initialize_with_pca == "True":
                R_model.initialize_with_pca()

            scaler = GradScaler()
            # Both optimisers maximise their objective (maximize=True).
            opt_R1 = torch.optim.Adam(R_model.parameters_R1(), lr=args.rotate_lr, maximize=True)
            opt_R2 = torch.optim.Adam(R_model.parameters_R2(), lr=args.rotate_lr, maximize=True)

            best_R1 = R_model.get_orthogonal_matrix()
            best_R2_list = R_model.get_orthogonal_matrix_R2_list_list()[0]
            best_h1_h3_loss = -float('inf')
            best_h2_loss = -float('inf')
            # Defined up front so the ratio reports below cannot hit an unbound
            # name when no epoch runs (rotation_epoch == 0) or when a loss is
            # NaN and never beats the -inf starting point.
            best_h1_loss = float('nan')
            best_h3_loss = float('nan')

            for epoch in tqdm(range(args.rotation_epoch)):
                # The function runs under no_grad; re-enable autograd locally.
                with torch.enable_grad():
                    opt_R1.zero_grad()
                    h1_loss = R_model.forward_h1()
                    h3_loss = R_model.forward_h3()
                    h1_h3_loss = h1_loss + h3_loss
                    scaler.scale(h1_h3_loss).backward()
                    # Unscale before clipping so the threshold applies to the
                    # true gradient norm.
                    scaler.unscale_(opt_R1)
                    torch.nn.utils.clip_grad_norm_(R_model.parameters_R1(), 1.0)
                    scaler.step(opt_R1)
                    scaler.update()

                    # NOTE(review): the snapshot is taken after the optimiser
                    # step, while the loss was measured before it - confirm
                    # this pairing is intended.
                    if h1_h3_loss.item() > best_h1_h3_loss:
                        best_h1_h3_loss = h1_h3_loss.item()
                        best_R1 = R_model.get_orthogonal_matrix()
                        best_h1_loss = h1_loss.item()
                        best_h3_loss = h3_loss.item()

                    opt_R2.zero_grad()
                    h2_loss = R_model.forward_h2()
                    scaler.scale(h2_loss).backward()
                    scaler.unscale_(opt_R2)
                    torch.nn.utils.clip_grad_norm_(R_model.parameters_R2(), 1.0)
                    scaler.step(opt_R2)
                    scaler.update()

                    if h2_loss.item() > best_h2_loss:
                        best_h2_loss = h2_loss.item()
                        best_R2_list = R_model.get_orthogonal_matrix_R2_list_list()[0]

                    if epoch == 0:
                        tqdm.write(f"Initial loss R1: {h1_h3_loss.item()} R2: {h2_loss.item()}")
                        tqdm.write(f"Initial ratio h1: {abs(best_h1_loss) / h1_cov_norm_squared} h2: {abs(best_h2_loss) / h2_cov_norm_squared} h3: {abs(best_h3_loss) / h3_cov_norm_squared}")

                tqdm.write(f"R1: {h1_h3_loss.item()} R2: {h2_loss.item()}")

            print(f"End ratio h1: {abs(best_h1_loss) / h1_cov_norm_squared} h2: {abs(best_h2_loss) / h2_cov_norm_squared} h3: {abs(best_h3_loss) / h3_cov_norm_squared}")

            del weight_dict_list
            for h in handles:
                h.remove()
            torch.cuda.empty_cache()

            # Cast the selected rotations back to the model's parameter dtype.
            R1 = best_R1.to(dtype)

            R2_list = best_R2_list
            for k in range(len(R2_list)):
                R2_list[k] = R2_list[k].to(dtype)

            R1_path = os.path.join(args.rotation_dir, f"layer_{i}_R1.pt")
            R2_list_path = os.path.join(args.rotation_dir, f"layer_{i}_R2_list.pt")
            # A rotation that is switched off is neither saved nor fused.
            if args.rotate_R1 == "True":
                torch.save(R1, R1_path)
            else:
                R1 = None
            if args.rotate_R2 == "True":
                torch.save(R2_list, R2_list_path)
            else:
                R2_list = None

            llama_fuse_rotation_single_layer(layer, R1, R2_list)

            if args.rotate_R1 == "True":
                layers[i] = layer
                replace_llama_layer(model, i, R1)

                # replace_llama_layer may have installed a new module at index
                # i; fetch it again and make sure it lives on ``dev``.
                layer = layers[i].to(dev)

            del gpts
            torch.cuda.empty_cache()

        # Activation capture: record the inputs of the selected sub-modules.
        name_list = ["self_attn.q_proj", "self_attn.o_proj", "mlp.up_proj", "mlp.down_proj"]
        full = find_layers(layer)
        subset = {name: full[name] for name in full}
        gpts = {name: LayerWrapper(subset[name]) for name in name_list}

        def add_batch(name):
            def tmp(_, inp, out):
                gpts[name].set_input(inp[0].data, out.data)
            return tmp

        handles = [
            subset[name].register_forward_hook(add_batch(name))
            for name in name_list
        ]
        # NOTE(review): the dataset is given the default CUDA device rather
        # than ``dev`` - confirm this is intended when dev is not cuda:0.
        input_dataset = ActivationDataset(name_list, torch.device("cuda"))
        for j in tqdm(range(0, args.nsamples, batch_size)):
            batch_inp = inps[j:j+batch_size]
            outs[j:j+batch_size] = layer(batch_inp, attention_mask=attention_mask, position_embeddings=position_embeddings)[0]
            input_dict = {}
            for name in name_list:
                input_dict[name] = gpts[name].input.squeeze(0)
                gpts[name].reset_input()
            input_dataset.update(input_dict)
            del input_dict
            torch.cuda.empty_cache()
        for h in handles:
            h.remove()
        # Drop the processed layer from the model to free device memory.
        layers[i] = None
        del layer
        torch.cuda.empty_cache()

        input_dataset.concat_lists_to_tensors()
        input_dataset.normalize_and_sort()

        torch.save(input_dataset.input_normalized_sorted_dict,  os.path.join(args.input_normalized_sorted_dir, f"layer_{i}_input_normalized_sorted.pt"))

        del gpts
        del input_dataset
        torch.cuda.empty_cache()

        # This layer's outputs are the next layer's inputs.
        inps, outs = outs, inps

    model.config.use_cache = use_cache

    return quantizers



if __name__ == "__main__":
    # Script entry point: parse the CLI, load the model and calibration data,
    # run the layer-by-layer pass, then print a short sparsity report.

    import argparse
    from data_utils import *

    parser = argparse.ArgumentParser()

    parser.add_argument("model", type=str, help="LlaMA model to load")
    parser.add_argument(
        "dataset",
        type=str,
        choices=["wikitext2", "ptb", "c4", "alpaca"],
        help="Where to extract calibration data from.",
    )
    parser.add_argument("--seed", type=int, default=0, help="Seed for sampling the calibration data.")
    parser.add_argument("--nsamples", type=int, default=16, help="Number of calibration data samples.")
    parser.add_argument("--wbits", type=int, default=16, help="Whether to quantize as well.")
    parser.add_argument("--log_wandb", action="store_true", help="Whether to log to wandb.")
    parser.add_argument("--rotation_dir", type=str, default="", help="Path to saved rotation matrix.")
    parser.add_argument("--rotation_epoch", type=int, default=0)
    parser.add_argument("--input_normalized_sorted_dir", type=str, default="None")
    # The switches below are strings; downstream code compares them to "True".
    parser.add_argument("--rotate", type=str, default="True")
    parser.add_argument("--rotate_R1", type=str, default="True")
    parser.add_argument("--rotate_R2", type=str, default="True")
    parser.add_argument("--rotate_lr", type=float, default=0.000001)
    parser.add_argument("--initialize_with_pca", type=str, default="True")
    parser.add_argument("--start_layer_id", type=int, default=0)
    parser.add_argument("--end_layer_id", type=int, default=0)

    # ``args`` must stay a module-level name: llama_sequential reads it as a
    # global.
    args = parser.parse_args()

    if args.log_wandb:
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``, which strips assert statements.
        if not has_wandb:
            raise ImportError("wandb not installed try `pip install wandb`")
        wandb.init(config=args)

    model = get_llama(args.model)
    model.eval()

    for param in model.parameters():
        param.requires_grad_(False)

    # Only ``dataloader`` is consumed below; ``testloader`` is not used here.
    dataloader, testloader = get_loaders(
        args.dataset, nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen
    )

    DEV = torch.device("cuda:0")
    tick = time.time()
    llama_sequential(model, dataloader, DEV)
    torch.cuda.empty_cache()
    # Print the fraction of exactly-zero entries for each parameter, stopping
    # after the first one whose name contains 'down_proj'.
    for n, p in model.named_parameters():
        print(n, torch.mean((p == 0).float()))
        if 'down_proj' in n:
            break
    # Elapsed wall-clock seconds since just before llama_sequential started.
    print(time.time() - tick)