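# NOTE(review): `monkeypatch` is still referenced by
# add_qk_rotation_wrapper_after_function_call_in_forward in this file, but the
# import below is disabled, so that function raises NameError when called.
# from utils import monkeypatch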
import functools
import os

from torch import nn
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
)
from transformers.models.llama.modeling_llama import (
    apply_rotary_pos_emb as pos_emb_llama,
)
from tqdm import tqdm
import torch
import random
from typing import Any, Dict

import datasets
import transformers
from tqdm.auto import tqdm



def get_wikitext2(
    nsamples=128,
    seed=0,
    seqlen=2048,
    model="",
    tokenizer=None,
    eval_mode=False,
    vision=False,
):
    """Load WikiText-2 (raw, v1) as evaluation text or calibration samples.

    Args:
        nsamples: Number of calibration windows to draw (training mode only).
        seed: Seed passed to ``random.seed`` before drawing window offsets.
        seqlen: Length in tokens of every calibration window.
        model: Name or path used to load a tokenizer when ``tokenizer`` is None.
        tokenizer: Callable tokenizer/processor; loaded from ``model`` if None.
        eval_mode: When True, return the tokenized test split instead.
        vision: When True (and no tokenizer is given) load an ``AutoProcessor``
            instead of a slow ``AutoTokenizer``.

    Returns:
        In eval mode, the tokenizer output for the whole test split joined with
        blank lines. Otherwise a list of ``nsamples`` ``(inp, tar)`` tuples,
        where ``inp`` is a ``seqlen``-token slice of the tokenized train split
        and ``tar`` is a copy of it with every position except the last set to
        -100.
    """
    print("get_wikitext2")

    if tokenizer is None and vision:
        tokenizer = transformers.AutoProcessor.from_pretrained(model)
    elif tokenizer is None:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model, use_fast=False
        )

    corpus = datasets.load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1")
    split = corpus["test" if eval_mode else "train"]
    encoded = tokenizer(text="\n\n".join(split["text"]), return_tensors="pt")
    if eval_mode:
        return encoded

    random.seed(seed)
    samples = []
    for _ in range(nsamples):
        # Inclusive upper bound keeps the window inside the token stream.
        start = random.randint(0, encoded.input_ids.shape[1] - seqlen - 1)
        window = encoded.input_ids[:, start : start + seqlen]
        labels = window.clone()
        labels[:, :-1] = -100  # only the final position keeps its token id
        samples.append((window, labels))
    return samples


def get_c4(
    nsamples=128,
    seed=0,
    seqlen=2048,
    model="",
    tokenizer=None,
    eval_mode=False,
    vision=False,
):
    """Load one shard of C4 as evaluation tokens or calibration samples.

    Args:
        nsamples: Number of calibration windows to draw (training mode only).
        seed: Seed passed to ``random.seed`` in training mode (evaluation mode
            always seeds with 0).
        seqlen: Length in tokens of every window.
        model: Name or path used to load a tokenizer when ``tokenizer`` is None.
        tokenizer: Callable tokenizer/processor; loaded from ``model`` if None.
        eval_mode: When True, build the evaluation tensor from the validation
            shard instead of calibration samples from the train shard.
        vision: When True (and no tokenizer is given) load an ``AutoProcessor``
            instead of a slow ``AutoTokenizer``.

    Returns:
        In eval mode, a tensor made by horizontally stacking 256 windows of
        ``seqlen`` tokens. Otherwise a list of ``nsamples`` ``(inp, tar)``
        tuples, where ``tar`` is a copy of ``inp`` with every position except
        the last set to -100.
    """
    print("get_c4")

    if tokenizer is None and vision:
        tokenizer = transformers.AutoProcessor.from_pretrained(model)
    elif tokenizer is None:
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model, use_fast=False
        )

    def _random_window(dataset):
        # Re-draw documents until one tokenizes to more than `seqlen` tokens,
        # then cut a random `seqlen`-token window out of it.
        while True:
            doc_index = random.randint(0, len(dataset) - 1)
            encoded = tokenizer(text=dataset[doc_index]["text"], return_tensors="pt")
            if encoded.input_ids.shape[1] > seqlen:
                break
        start = random.randint(0, encoded.input_ids.shape[1] - seqlen - 1)
        return encoded.input_ids[:, start : start + seqlen]

    if eval_mode:
        valdata = datasets.load_dataset(
            "allenai/c4",
            data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
            split="validation",
        )
        random.seed(0)
        windows = []
        for _ in range(256):
            windows.append(_random_window(valdata))
        return torch.hstack(windows)

    random.seed(seed)
    traindata = datasets.load_dataset(
        "allenai/c4",
        data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
        split="train",
    )
    samples = []
    for _ in range(nsamples):
        window = _random_window(traindata)
        labels = window.clone()
        labels[:, :-1] = -100  # only the final position keeps its token id
        samples.append((window, labels))
    return samples


def get_ptb(
    nsamples=128,
    seed=0,
    seqlen=2048,
    model="",
    tokenizer=None,
    eval_mode=False,
    vision=False,
):
    """Load Penn Treebank as evaluation text or calibration samples.

    Args:
        nsamples: Number of calibration windows to draw (training mode only).
        seed: Seed passed to ``random.seed`` before drawing window offsets.
        seqlen: Length in tokens of every calibration window.
        model: Name or path used to load a tokenizer when ``tokenizer`` is None.
        tokenizer: Callable tokenizer/processor; loaded from ``model`` if None.
        eval_mode: When True, return the tokenized test split instead.
        vision: When True (and no tokenizer is given) load an ``AutoProcessor``
            instead of a slow ``AutoTokenizer``.

    Returns:
        In eval mode, the tokenizer output for the whole test split joined with
        blank lines. Otherwise a list of ``nsamples`` ``(inp, tar)`` tuples,
        where ``inp`` is a ``seqlen``-token slice of the tokenized train split
        and ``tar`` is a copy of it with every position except the last set to
        -100.
    """
    print("get_ptb")

    if tokenizer is None:
        if vision:
            tokenizer = transformers.AutoProcessor.from_pretrained(model)
        else:
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                model, use_fast=False
            )

    if eval_mode:
        # Same loader flags as the train split below, so both branches behave
        # alike with respect to the dataset's loading script.
        testdata = datasets.load_dataset(
            "ptb_text_only", "penn_treebank", split="test", trust_remote_code=True
        )
        testenc = tokenizer(text="\n\n".join(testdata["sentence"]), return_tensors="pt")
        return testenc

    traindata = datasets.load_dataset(
        "ptb_text_only", "penn_treebank", split="train", trust_remote_code=True
    )
    # Pass the text by keyword, as the other loaders in this file do, so a
    # processor whose first positional parameter is not the text still
    # receives it correctly.
    trainenc = tokenizer(text="\n\n".join(traindata["sentence"]), return_tensors="pt")
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100  # only the final position keeps its token id
        trainloader.append((inp, tar))
    return trainloader


def get_alpaca(
    nsamples=128,
    seed=0,
    seqlen=2048,
    model="",
    tokenizer=None,
    eval_mode=False,
    vision=False,
):
    """Build calibration samples from the ``tatsu-lab/alpaca`` dataset.

    Args:
        nsamples: Number of calibration windows to draw.
        seed: Seed passed to ``random.seed`` before drawing window offsets.
        seqlen: Length in tokens of every calibration window.
        model: Name or path used to load a tokenizer when ``tokenizer`` is None.
        tokenizer: Callable tokenizer/processor; loaded from ``model`` if None.
        eval_mode: Must be False; this loader has no evaluation branch.
        vision: When True (and no tokenizer is given) load an ``AutoProcessor``
            instead of a slow ``AutoTokenizer``.

    Returns:
        A list of ``nsamples`` ``(inp, tar)`` tuples, where ``inp`` is a
        ``seqlen``-token slice of the tokenized ``text`` column (rows joined
        with blank lines) and ``tar`` is a copy of it with every position
        except the last set to -100.

    Raises:
        ValueError: If ``eval_mode`` is True.
    """
    print("get_alpaca")

    # Fail before any tokenizer/dataset loading work is done.
    if eval_mode:
        raise ValueError("get_alpaca has no evaluation split; use eval_mode=False")

    if tokenizer is None:
        if vision:
            tokenizer = transformers.AutoProcessor.from_pretrained(model)
        else:
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                model, use_fast=False
            )

    ds = datasets.load_dataset("tatsu-lab/alpaca")
    ds = ds.remove_columns(["input", "output", "instruction"])
    traindata = ds["train"]
    # Pass the text by keyword, as the other loaders in this file do, so a
    # processor whose first positional parameter is not the text still
    # receives it correctly.
    trainenc = tokenizer(text="\n\n".join(traindata["text"]), return_tensors="pt")
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100  # only the final position keeps its token id
        trainloader.append((inp, tar))

    return trainloader


def get_data(
    calib_dataset="wikitext",
    nsamples=128,
    seed=0,
    seqlen=2048,
    model="",
    tokenizer=None,
    eval_mode=False,
    vision=False,
):
    """Dispatch to the dataset loader whose keyword occurs in ``calib_dataset``.

    The name is matched by substring, in this order: ``"wikitext"``, ``"c4"``,
    ``"ptb"``, ``"alpaca"``. All remaining arguments are forwarded unchanged to
    the selected loader, whose return value is returned as-is.

    Raises:
        NotImplementedError: If none of the keywords occurs in
            ``calib_dataset``.
    """
    loader_args = (nsamples, seed, seqlen, model, tokenizer, eval_mode, vision)

    if "wikitext" in calib_dataset:
        return get_wikitext2(*loader_args)
    if "c4" in calib_dataset:
        return get_c4(*loader_args)
    if "ptb" in calib_dataset:
        return get_ptb(*loader_args)
    if "alpaca" in calib_dataset:
        return get_alpaca(*loader_args)
    raise NotImplementedError

@torch.no_grad()
def get_basis(args) -> None:
    """Calibrate per-layer key/value bases and save them to ``args.rotation_path``.

    If ``args.rotation_path`` already exists, only a message is printed.
    Otherwise the causal LM named by ``args.pretrained_model_name_or_path`` is
    loaded and calibration samples are pushed through the decoder layers one
    layer at a time on CUDA. For every layer the second-moment matrix of the
    value states (per KV head) and of the rotary-embedded key states (summed
    over KV heads) is eigendecomposed. The eigenvectors, right-multiplied by a
    block-diagonal matrix (a random orthogonal block followed by an identity
    block), are stored in a dict under ``"U-layer.{i}.self_attn.value"`` and
    ``"U-layer.{i}.self_attn.key_pos"`` and saved with ``torch.save``.

    Args:
        args: Namespace-like object. Attributes read here: ``rotation_path``,
            ``pretrained_model_name_or_path``, ``bf16``,
            ``attn_implementation``, ``calib_seqlen``, ``seed``,
            ``calib_samples``, ``calib_dataset`` and ``high_frac``.

    Raises:
        RuntimeError: If the first decoder layer was not reached for every
            calibration sample.
        NotImplementedError: If the layer call receives no position embeddings
            and the model name contains neither "llama" nor "qwen".
    """
    if os.path.exists(args.rotation_path):
        print(f"Basis rotations already exist at {args.rotation_path}")
        return

    # dirname is "" for a bare file name, and os.makedirs("") raises.
    rotation_dir = os.path.dirname(args.rotation_path)
    if rotation_dir:
        os.makedirs(rotation_dir, exist_ok=True)

    config = AutoConfig.from_pretrained(
        args.pretrained_model_name_or_path,
    )
    config.use_flash = True
    dtype = torch.bfloat16 if args.bf16 else torch.float16

    llm = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        torch_dtype=dtype,
        config=config,
        attn_implementation=args.attn_implementation,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        padding_side="right",
        use_fast=True,
        add_eos_token=False,
        add_bos_token=False,
    )

    llm.eval()
    llm.config.use_cache = False
    seqlen = args.calib_seqlen

    # `seqlen` must be forwarded: the covariance normalisation and the
    # qwen rotary call below both assume samples of exactly this length.
    train_data = get_data(
        calib_dataset=args.calib_dataset,
        nsamples=args.calib_samples,
        seed=args.seed,
        seqlen=seqlen,
        tokenizer=tokenizer,
        eval_mode=False,
        vision=False,
    )
    nbatches = len(train_data)
    layers = llm.model.layers
    llm.model.embed_tokens = llm.model.embed_tokens.to("cuda")
    if hasattr(llm.model, "rotary_emb"):
        llm.model.rotary_emb = llm.model.rotary_emb.to("cuda")

    layers[0] = layers[0].to("cuda")

    dtype = next(iter(llm.parameters())).dtype

    # Inputs of the first decoder layer, one entry per calibration sample.
    inps = [0] * nbatches
    cache = {"i": 0, "attention_mask": None}

    class Catcher(torch.nn.Module):
        """Stand-in for layer 0 that records its inputs and aborts the forward."""

        def __init__(self, module):
            super().__init__()
            self.module = module
            if hasattr(module, "attention_type"):
                self.attention_type = module.attention_type

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            cache["attention_mask"] = kwargs["attention_mask"]
            cache["position_ids"] = kwargs["position_ids"]
            for optional_key in (
                "position_embeddings",
                "cross_attention_states",
                "cross_attention_mask",
                "full_text_row_masked_out_mask",
                "cache_position",
            ):
                cache[optional_key] = kwargs.get(optional_key)
            # Abort the rest of the model forward; caught by the loop below.
            raise ValueError

    layers[0] = Catcher(layers[0])
    for i in range(nbatches):
        batch = train_data[i][0].to("cuda")
        try:
            llm(batch)
        except ValueError:
            pass
    layers[0] = layers[0].module
    layers[0] = layers[0].cpu()

    # The except above also swallows ValueErrors raised before layer 0 was
    # reached; detect that here instead of failing later on a placeholder.
    if cache["i"] != nbatches:
        raise RuntimeError(
            f"Captured first-layer inputs for {cache['i']} of {nbatches} "
            "calibration samples"
        )

    llm.model.embed_tokens = llm.model.embed_tokens.cpu()
    if hasattr(llm.model, "rotary_emb"):
        llm.model.rotary_emb = llm.model.rotary_emb.cpu()
    position_ids = cache["position_ids"]
    position_embeddings = cache["position_embeddings"]
    attention_mask = cache["attention_mask"]

    torch.cuda.empty_cache()
    outs = [0] * nbatches

    basis_dict = {}

    head_dim = llm.model.layers[0].self_attn.head_dim
    kv_heads = llm.config.num_key_value_heads
    nlayers = len(layers)
    cov_device = "cuda"
    model_name = args.pretrained_model_name_or_path.lower()

    # Block-diagonal mixing matrix: random orthogonal block for the first
    # `low_prec` directions, identity for the last `high_prec`. The low block
    # size is derived from the high one so the two always sum to head_dim.
    high_prec = int(head_dim * args.high_frac)
    low_prec = head_dim - high_prec
    rand_matrix = torch.block_diag(
        random_orthogonal_matrix(low_prec, "cuda"),
        torch.eye(high_prec, device="cuda", dtype=torch.float64),
    )

    # Latest q/k/v projection outputs, filled by the forward hooks.
    captured = {}

    def _capture(name):
        def hook(module, inputs, output):
            captured[name] = output

        return hook

    def _as_numpy(tensor):
        # NumPy has no bfloat16, so bf16 tensors are widened to float32 first.
        tensor = tensor.cpu()
        if tensor.dtype == torch.bfloat16:
            tensor = tensor.float()
        return tensor.numpy()

    for i in tqdm(range(nlayers), desc="(Calibration) Layers"):
        layer = layers[i].to("cuda")
        attn = layer.self_attn

        hooks = [
            attn.v_proj.register_forward_hook(_capture("v")),
            attn.k_proj.register_forward_hook(_capture("k")),
            attn.q_proj.register_forward_hook(_capture("q")),
        ]

        H_value = 0.0
        H_key_pos = 0.0
        for j in range(nbatches):
            # 1 sample at a time
            input_shape = inps[j].shape[:-1]
            hidden_shape = (*input_shape, -1, head_dim)
            layer_out = layer(
                inps[j],
                attention_mask=attention_mask,
                position_ids=position_ids,
                position_embeddings=position_embeddings,
            )
            # Decoder layers return either the hidden states or a tuple whose
            # first element is the hidden states, depending on the version.
            outs[j] = layer_out[0] if isinstance(layer_out, tuple) else layer_out

            # reshape to get value states per head
            value_states = captured["v"].view(hidden_shape).transpose(1, 2)

            # rope cos, sin
            if position_embeddings is not None:
                cos, sin = position_embeddings
            elif "llama" in model_name:
                cos, sin = attn.rotary_emb(value_states, position_ids)
            elif "qwen" in model_name:
                cos, sin = attn.rotary_emb(value_states, seqlen)
            else:
                raise NotImplementedError(
                    "No rotary embedding lookup for model "
                    f"{args.pretrained_model_name_or_path!r}"
                )

            if hasattr(attn, "q_norm"):  # qwen3
                # apply norm and reshape to get query, key states per head
                key_states = attn.k_norm(captured["k"].view(hidden_shape)).transpose(1, 2)
                query_states = attn.q_norm(captured["q"].view(hidden_shape)).transpose(1, 2)
            else:
                # reshape to get query, key states per head
                key_states = captured["k"].view(hidden_shape).transpose(1, 2)
                query_states = captured["q"].view(hidden_shape).transpose(1, 2)

            # apply rotary embedding (only the rotated keys are used below)
            _, key_states_pos = pos_emb_llama(query_states, key_states, cos, sin)

            if i == 16:
                # NOTE(review): debug dump with hard-coded, model-specific file
                # names written to the working directory for every sample of
                # layer index 16 — confirm it is still wanted.
                had = hadamard_matrix(head_dim, "cuda").to(dtype)
                key_pos_had = torch.matmul(key_states_pos, had)
                value_had = torch.matmul(value_states, had)

                torch.save(_as_numpy(value_states[0, 0]), "qwen2p5-3b-l17-head0-values.pt")
                torch.save(_as_numpy(value_had[0, 0]), "qwen2p5-3b-l17-head0-values-hadamard.pt")
                torch.save(_as_numpy(key_states_pos[0, 0]), "qwen2p5-3b-l17-head0-keys.pt")
                torch.save(_as_numpy(key_pos_had[0, 0]), "qwen2p5-3b-l17-head0-keys-hadamard.pt")

            H_value += torch.sum(
                value_states.double().mT @ value_states.double(), dim=0
            ).to(
                cov_device
            )  # shape : [num_kv_heads, head_dim, head_dim]

            H_key_pos += torch.sum(
                key_states_pos.double().mT @ key_states_pos.double(),
                dim=0,
            ).to(
                cov_device
            )  # shape : [num_kv_heads, head_dim, head_dim]

        # eigen decomposition of value states and multiplication with random
        # matrix for ResQ projection
        _, evec_value = perform_eigen_decomp(
            H_value / (seqlen * nbatches),
            per_head=True,
            num_heads=kv_heads,
        )
        evec_value = torch.matmul(evec_value, rand_matrix)

        # eigen decomposition of key states after rope embedding and
        # multiplication with random matrix for ResQ projection
        _, evec_k_pos = perform_eigen_decomp(
            H_key_pos.sum(0) / (kv_heads * nbatches * seqlen),
        )
        evec_k_pos = torch.matmul(evec_k_pos, rand_matrix)

        basis_dict["U-layer." + str(i) + ".self_attn.value"] = evec_value.cpu()
        basis_dict["U-layer." + str(i) + ".self_attn.key_pos"] = evec_k_pos.cpu()
        print(evec_k_pos.shape[-1])
        for hook in hooks:
            hook.remove()

        layers[i] = layers[i].cpu()

        torch.cuda.empty_cache()

        # This layer's outputs are the next layer's inputs.
        inps, outs = outs, inps

    torch.save(
        basis_dict,
        args.rotation_path,
    )


def _damped_sorted_eigh(H):
    """Eigendecompose one square matrix after adding 1% mean-diagonal damping.

    Works on a copy, so the tensor passed in is left unchanged. Returns the
    eigenvalues in ascending order and the matching eigenvectors as columns,
    both in float64.
    """
    H = H.clone()
    damp = 0.01 * torch.mean(torch.diag(H))
    idx = torch.arange(H.shape[-1], device=H.device)
    H[idx, idx] = H[idx, idx] + damp
    eigenvalues, eigenvectors = torch.linalg.eigh(H.to(torch.float64))
    order = torch.argsort(eigenvalues)
    return eigenvalues[order], eigenvectors[:, order]


def perform_eigen_decomp(Cov_matrix, per_head=False, num_heads=0, device="cuda"):
    """Return sorted eigenvalues and eigenvectors of a damped covariance matrix.

    Before decomposition, 1% of the mean of the diagonal is added to the
    diagonal. The input tensor itself is never modified.

    Args:
        Cov_matrix: A square matrix, or with ``per_head=True`` a stack of
            square matrices indexed by head along the first dimension.
        per_head: Decompose ``Cov_matrix[h]`` separately for each head and
            stack the results.
        num_heads: Number of heads to process; required when ``per_head``.
        device: Device the decomposition runs on (default ``"cuda"``).

    Returns:
        ``(eigenvalues, eigenvectors)`` in float64, eigenvalues ascending and
        eigenvectors as columns in the same order. With ``per_head=True`` both
        carry a leading head dimension.
    """
    Cov_matrix = Cov_matrix.to(device)
    if not per_head:
        return _damped_sorted_eigh(Cov_matrix)

    assert num_heads != 0  # cannot use per head and not pass num_heads
    per_head_results = [_damped_sorted_eigh(Cov_matrix[hd]) for hd in range(num_heads)]
    eigenvalues = torch.stack([vals for vals, _ in per_head_results])
    eigenvectors = torch.stack([vecs for _, vecs in per_head_results])
    return eigenvalues, eigenvectors


def rotate_ov_proj(layer, v_rot, o_rot):
    """Fold per-head rotations into ``v_proj`` and ``o_proj`` of one layer.

    For KV head ``i`` the output block of ``v_proj`` (weight and bias, if any)
    is right-multiplied by ``v_rot[i]``. In ``o_proj`` the input block of every
    attention head belonging to KV head ``i`` is right-multiplied by
    ``o_rot[i]``. The arithmetic is done in float64 on the device the rotation
    tensors live on; results are written back with each weight's original
    device and dtype.

    Args:
        layer: Object exposing ``self_attn.v_proj`` and ``self_attn.o_proj``
            with ``weight`` (and for ``v_proj`` an optional ``bias``).
        v_rot: Tensor indexed as ``v_rot[kv_head]`` giving a
            ``(head_dim, head_dim)`` matrix; ``head_dim`` is taken from its
            last dimension.
        o_rot: Tensor indexed the same way, applied to ``o_proj``.

    Returns:
        None. The layer's weights are replaced in place.
    """
    head_dim = v_rot.shape[-1]
    attn = layer.self_attn

    # ---- v_proj: rotate each KV head's output block ----
    v_rot = v_rot.to(dtype=torch.float64)
    v_weight = attn.v_proj.weight.data
    v_bias = attn.v_proj.bias
    v_device = v_weight.device
    v_dtype = v_weight.dtype

    weight_t = v_weight.t()  # (in_features, out_features)
    transposed_shape = weight_t.shape
    num_kv_heads = transposed_shape[-1] // head_dim
    temp = weight_t.reshape(-1, num_kv_heads, head_dim).to(
        device=v_rot.device, dtype=torch.float64
    )
    if v_bias is not None:
        bias_shape = v_bias.shape
        # Use .data so the in-place updates below never touch autograd.
        temp_bias = v_bias.data.reshape(num_kv_heads, head_dim).to(
            device=v_rot.device, dtype=torch.float64
        )
    for i in range(num_kv_heads):
        temp[:, i, :] = temp[:, i, :] @ v_rot[i]
        if v_bias is not None:
            temp_bias[i] = temp_bias[i] @ v_rot[i]

    # Transposing back yields a strided view; store a contiguous weight.
    new_v_weight = temp.reshape(transposed_shape).t().contiguous()
    attn.v_proj.weight.data = new_v_weight.to(device=v_device, dtype=v_dtype)
    if v_bias is not None:
        attn.v_proj.bias.data = temp_bias.reshape(bias_shape).to(
            device=v_device, dtype=v_dtype
        )

    # ---- o_proj: rotate the input block of every head in each KV group ----
    o_rot = o_rot.to(dtype=torch.float64)
    o_weight = attn.o_proj.weight.data
    init_shape = o_weight.shape
    o_device = o_weight.device
    o_dtype = o_weight.dtype
    temp = o_weight.reshape(-1, init_shape[-1] // head_dim, head_dim).to(
        device=o_rot.device, dtype=torch.float64
    )
    num_kv_groups = init_shape[1] // (head_dim * num_kv_heads)
    for i in range(num_kv_heads):
        for j in range(num_kv_groups):
            idx = j + num_kv_groups * i  # attention head j of KV head i
            temp[:, idx, :] = temp[:, idx, :] @ o_rot[i]

    attn.o_proj.weight.data = temp.reshape(init_shape).to(device=o_device, dtype=o_dtype)
    return


def add_rotations(model, args):
    """Apply value and/or key rotations to every decoder layer of ``model``.

    ``args.apply_rot`` is tested by substring:

    * ``"resq"``: rotations are read from the dict stored at
      ``args.rotation_path``; otherwise a Hadamard matrix of size ``head_dim``
      is used.
    * ``"value"``: the rotation is folded into each layer's ``v_proj`` and
      ``o_proj`` via ``rotate_ov_proj``.
    * ``"key"``: ``q_rotation`` and ``k_rotation`` tensors (cast to the dtype
      of ``q_proj.weight`` of the first layer) are attached to each layer's
      ``self_attn``.

    Returns:
        The same ``model`` object.
    """
    cfg = model.config
    num_heads = cfg.num_attention_heads  # read but not used below
    model_dim = cfg.hidden_size  # read but not used below
    head_dim = model.model.layers[0].self_attn.head_dim
    num_kv_heads = cfg.num_key_value_heads

    use_resq = "resq" in args.apply_rot
    if use_resq:
        R_dict = torch.load(args.rotation_path)

    if "value" in args.apply_rot:
        decoder_layers = model.model.layers
        for layer_idx, layer in enumerate(tqdm(decoder_layers, desc="Rotating Values")):
            if use_resq:
                entry = f"U-layer.{layer_idx}.self_attn.value"
                v_rotation = R_dict[entry].cuda()
                o_rotation = R_dict[entry].cuda()
            else:
                # Same Hadamard matrix repeated for every KV head.
                base = hadamard_matrix(head_dim, "cuda")
                v_rotation = base.unsqueeze(0).repeat(num_kv_heads, 1, 1)
                o_rotation = base.unsqueeze(0).repeat(num_kv_heads, 1, 1)

            rotate_ov_proj(layer, v_rotation, o_rotation)

    if "key" in args.apply_rot:
        decoder_layers = model.model.layers
        weight_dtype = decoder_layers[0].self_attn.q_proj.weight.dtype
        for layer_idx, layer in enumerate(tqdm(decoder_layers, desc="Rotating Keys")):
            if use_resq:
                entry = f"U-layer.{layer_idx}.self_attn.key_pos"
                q_rotation = R_dict[entry].cuda().to(weight_dtype)
                k_rotation = R_dict[entry].cuda().to(weight_dtype)
            else:
                q_rotation = hadamard_matrix(head_dim, "cuda").to(weight_dtype)
                k_rotation = hadamard_matrix(head_dim, "cuda").to(weight_dtype)
            layer.self_attn.q_rotation = q_rotation
            layer.self_attn.k_rotation = k_rotation

    return model


def random_orthogonal_matrix(size, device):
    """
    Draw a random orthogonal matrix via QR decomposition.

    A ``size x size`` float64 matrix of standard-normal entries is created,
    moved to ``device`` and QR-decomposed. Each column of Q is then multiplied
    by the sign of the matching diagonal entry of R.

    Args:
    size (int): The size of the matrix (size x size).
    device: Device the matrix is moved to before the decomposition.

    Returns:
    torch.Tensor: A float64 matrix of shape (size, size).
    """
    torch.cuda.empty_cache()
    gaussian = torch.randn(size, size, dtype=torch.float64).to(device)
    q_factor, r_factor = torch.linalg.qr(gaussian)
    column_signs = torch.sign(torch.diag(r_factor)).unsqueeze(0)
    q_factor *= column_signs
    return q_factor


def random_hadamard_matrix(size, device):
    """Return ``matmul_hadU`` applied to a random +/-1 diagonal matrix.

    See https://cornell-relaxml.github.io/quip-sharp/ , Section
    "Randomized Hadamard Transformation".

    The diagonal is built as float64 and the result is moved to ``device``.
    """
    # Map random bits {0, 1} to signs {-1, +1}.
    signs = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) * 2 - 1
    sign_matrix = torch.diag(signs)
    return matmul_hadU(sign_matrix).to(device)

def hadamard_matrix(size, device):
    """Return ``matmul_hadU`` applied to a float64 identity, moved to ``device``."""
    identity = torch.eye(size, dtype=torch.float64)
    return matmul_hadU(identity).to(device)

def matmul_hadU(X, transpose=False):
    """Apply a normalised Walsh-Hadamard butterfly along the last dim of ``X``.

    Every length-``n`` row of ``X`` is transformed by ``log2(n)`` add/subtract
    stages and divided by ``sqrt(n)``. For ``X = I`` the result is the
    normalised Hadamard matrix itself. ``X`` is not modified.

    Args:
        X: Tensor whose last dimension ``n`` is a power of two.
        transpose: Accepted for interface compatibility; not used.

    Returns:
        Tensor with the same shape as ``X``.

    Raises:
        RuntimeError: If ``n`` is not a power of two.
    """
    n = X.shape[-1]
    if n < 1 or n & (n - 1):
        raise RuntimeError(
            f"matmul_hadU needs a power-of-two last dimension, got {n}"
        )

    src = X.clone().view(-1, n, 1)
    dst = src.clone()
    while src.shape[1] > 1:
        # Pair up neighbouring entries; write sums and differences to `dst`.
        src = src.view(src.shape[0], src.shape[1] // 2, 2, src.shape[2])
        dst = dst.view(src.shape)
        dst[:, :, 0, :] = src[:, :, 0, :] + src[:, :, 1, :]
        dst[:, :, 1, :] = src[:, :, 0, :] - src[:, :, 1, :]
        dst = dst.view(src.shape[0], src.shape[1], -1)
        src, dst = dst, src  # ping-pong the two buffers
    del dst

    # Python float sqrt is double precision, so float64 inputs keep full
    # accuracy even when n is not a perfect square.
    return src.view(X.shape) / (n ** 0.5)


class QKRotationWrapper(torch.nn.Module):
    """Call a (q, k)-returning function and rotate both of its outputs.

    ``func`` is invoked with whatever arguments ``forward`` receives and must
    return a ``(q, k)`` pair. ``q`` is right-multiplied by ``q_rotation`` and
    ``k`` by ``k_rotation``; each rotation is first converted with ``.to`` to
    match the tensor it multiplies.
    """

    def __init__(self, func, k_rotation, q_rotation):
        super().__init__()
        self.func = func
        self.q_rotation = q_rotation
        self.k_rotation = k_rotation

    def forward(self, *args, **kwargs):
        queries, keys = self.func(*args, **kwargs)
        rotated_queries = queries @ self.q_rotation.to(queries)
        rotated_keys = keys @ self.k_rotation.to(keys)
        return rotated_queries, rotated_keys


def add_qk_rotation_wrapper_after_function_call_in_forward(
    module,
    function_name,
    *args,
    **kwargs,
):
    """
    Wrap the result of ``function_name`` inside ``module.forward`` with a
    ``QKRotationWrapper``.

    Only calls made directly in the forward function are affected; calls made
    by other functions that forward invokes are not. Extra positional and
    keyword arguments are bound onto ``QKRotationWrapper`` with
    ``functools.partial``. The object returned by the patching helper is stored
    on ``module`` as ``"<function_name>_qk_rotation_wrapper"``; the module must
    not already have that attribute.

    NOTE(review): ``monkeypatch`` is not imported in the visible part of this
    file (its import is commented out), so this call presumably fails with
    NameError unless the name is provided elsewhere — confirm.
    """
    wrapper_attr = f"{function_name}_qk_rotation_wrapper"
    assert not hasattr(module, wrapper_attr)

    wrapper_factory = functools.partial(QKRotationWrapper, *args, **kwargs)
    installed = monkeypatch.add_wrapper_after_function_call_in_method(
        module,
        "forward",
        function_name,
        wrapper_factory,
    )
    setattr(module, wrapper_attr, installed)
