import math
import time
import tqdm
import torch
import torch.nn as nn
from vlmq.quantization.utils import utils, quant_utils, model_utils
import logging
import functools

import vlmq.quantization.vlmq.modeling_qwen2_vl_utils as modeling_qwen2_vl_utils
import vlmq.quantization.vlmq.modeling_qwen2_utils as modeling_qwen2_utils

# Turn off TF32 for CUDA matmuls and cuDNN kernels so they are not run with
# reduced-precision TF32 arithmetic.
# NOTE(review): presumably this is to keep the Hessian accumulation and the
# Cholesky factorisations in VLMQ.fasterquant numerically stable -- confirm.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

class VLMQ:
    '''
    Calibration state and weight quantizer for a single linear layer.

    `add_batch` accumulates two (columns x columns) statistics, both
    normalised by the number of samples seen so far:
      * `H`    -- average of 2 * X @ X.T, where X is the (optionally
                  token-weighted) input reaching the layer;
      * `dXXT` -- average of 2 * (X_fp - X) @ X.T, where X_fp is the
                  matching full-precision input taken from `fp_inp`.
    `fasterquant` then quantizes the weight column by column (GPTQ-style),
    propagating the rounding error through the inverse of `H` and adding a
    correction term built from `dXXT`.

    The caller is expected to assign `self.quantizer` (an object providing
    `ready()`, `find_params(W)` and `quantize(w)`) and to fill `self.fp_inp`
    (and optionally `self.grad`) before running calibration.
    '''

    def __init__(self, layer):
        '''
        Args:
            layer: module with a 2-D `weight` parameter; the weight is
                overwritten in place by `fasterquant`.
        '''
        self.layer = layer
        self.dev = self.layer.weight.device
        # Only the shape is needed here, so read it directly instead of
        # cloning the whole weight tensor.
        weight_shape = layer.weight.shape
        self.rows = weight_shape[0]
        self.columns = weight_shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.dXXT = torch.zeros((self.columns, self.columns), device=self.dev)
        self.n_samples = 0
        # One full-precision input per upcoming `add_batch` call, consumed
        # from the front of the list.
        self.fp_inp = []
        # TODO: after careful verification, we found the grad is collected in the reversed order with samples
        # One per-token weight per upcoming `add_batch` call (may stay empty,
        # in which case inputs are used unweighted).
        self.grad = []

    def add_batch(self, inp, out):
        '''
        Fold one calibration batch into `H` and `dXXT`.

        Consumes (deletes) the first entry of `self.fp_inp` and, when
        present, the first entry of `self.grad`.

        Args:
            inp: layer input, 2-D (tokens, columns) or 3-D
                (batch, tokens, columns). It is never modified.
            out: layer output; unused, kept for the forward-hook signature.
        '''
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if len(inp.shape) == 3:
            inp = inp.reshape((-1, inp.shape[-1]))

        # (columns, tokens) layout from here on.
        inp = inp.t()
        fp_inp = self.fp_inp[0]

        has_grad = len(self.grad) != 0
        if has_grad:
            weight = self.grad[0].expand_as(inp).float()
            # `inp` is still a view of the activation handed to the hook (and
            # `fp_inp` is a cached tensor), so weight them out-of-place: an
            # in-place multiply would silently rescale tensors that other
            # modules may still read. Casting back to the original dtype
            # keeps the rounding an in-place multiply would have produced.
            inp = (inp * weight).to(inp.dtype)
            fp_inp = (fp_inp * weight).to(fp_inp.dtype)

        # Running average: rescale what has been accumulated so far, then add
        # the new batch with the updated normalisation.
        self.H *= self.n_samples / (self.n_samples + tmp)
        self.dXXT *= self.n_samples / (self.n_samples + tmp)
        self.n_samples += tmp
        scale = math.sqrt(2 / self.n_samples)
        inp = scale * inp.float()
        self.H += inp.matmul(inp.t())
        dX = fp_inp.float() * scale - inp
        self.dXXT += dX.matmul(inp.t())

        del self.fp_inp[0]
        if has_grad:
            del self.grad[0]

    def fasterquant(
            self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False, static_groups=False, alpha=0.25
    ):
        '''
        Quantize `self.layer.weight` in place using the accumulated statistics.

        `self.H` and `self.dXXT` are consumed (deleted) by this call, and
        `self.quantizer` must have been assigned beforehand.

        Args:
            blocksize: number of columns processed per block before the
                accumulated error is pushed to the remaining columns.
            percdamp: fraction of the mean diagonal of `H` added to its
                diagonal before factorisation.
            groupsize: -1 for a single set of quantizer parameters; otherwise
                parameters are (re)computed every `groupsize` columns.
            actorder: process columns in order of decreasing diag(H).
            static_groups: pre-compute one quantizer per group from the
                original column order instead of refitting during the sweep.
            alpha: scalar multiplier applied to the `dXXT` correction term.

        Raises:
            ValueError: if the quantized weight contains NaN.
        '''
        W = self.layer.weight.data.clone()
        W = W.float()

        if not self.quantizer.ready():
            self.quantizer.find_params(W)

        H = self.H
        del self.H
        # Columns that never received any input: neutralise them so the
        # factorisation below stays well defined.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0
        self.dXXT[:, dead] = 0

        if static_groups:
            import copy
            groups = []
            for i in range(0, self.columns, groupsize):
                quantizer = copy.deepcopy(self.quantizer)
                quantizer.find_params(W[:, i:(i + groupsize)])
                groups.append(quantizer)

        if actorder:
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]
            self.dXXT = self.dXXT[perm][:, perm]
            invperm = torch.argsort(perm)

        Q = torch.zeros_like(W)

        # Dampen H, then take the upper Cholesky factor of its inverse.
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        Hinv = torch.linalg.cholesky(H)
        Hinv = torch.cholesky_inverse(Hinv)
        Hinv = torch.linalg.cholesky(Hinv, upper=True)

        # scale it by alpha due to collection of dXXT and H
        P = alpha * ((self.dXXT @ Hinv.T).triu_(diagonal=1)) @ Hinv
        del self.dXXT

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]
            P1 = P[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if groupsize != -1:
                    if not static_groups:
                        if (i1 + i) % groupsize == 0:
                            self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)])
                    else:
                        idx = i1 + i
                        if actorder:
                            # Groups were fitted in the original column order.
                            idx = perm[idx]
                        self.quantizer = groups[idx // groupsize]

                q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                Q1[:, i] = q

                err1 = (w - q) / d
                # Push this column's rounding error (and the dXXT correction)
                # onto the not-yet-quantized columns of the block.
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - w.unsqueeze(1).matmul(P1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1

            # Same update for every column to the right of the block.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) - W1.matmul(P[i1:i2, i2:])

        # Synchronising is only meaningful (and only possible) with CUDA.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        if actorder:
            Q = Q[:, invperm]

        self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
        if torch.any(torch.isnan(self.layer.weight.data)):
            logging.warning('NaN in weights')
            import pprint
            # pprint takes ONE object (further positionals are stream/indent),
            # so bundle the diagnostics; getattr keeps a missing attribute
            # from masking the ValueError below.
            pprint.pprint({
                attr: getattr(self.quantizer, attr, None)
                for attr in ('bits', 'scale', 'zero_point')
            })
            raise ValueError('NaN in weights')

    def free(self):
        '''Drop references to the accumulated statistics and release cached memory.'''
        self.H = None
        self.Losses = None
        self.Trace = None
        self.dXXT = None
        torch.cuda.empty_cache()
        utils.cleanup_memory(verbos=False)


@torch.no_grad()
def vlmq_fwrd(model, dataloader, dev, args):
    '''
    Quantize the linear weights of every decoder layer with VLMQ, layer by layer.

    Adapted from the GPTQ repo.
    TODO: Make this function general to support both OPT and LLaMA models

    For each layer the function (1) optionally caches output gradients of the
    linear sub-modules, (2) caches their full-precision inputs, (3) collects
    calibration statistics and quantizes the sub-modules group by group, and
    (4) recomputes the layer outputs, which become the next layer's inputs.

    Args:
        model: model exposing `model.model.layers`, `embed_tokens`, `norm`,
            `rotary_emb` and a `config` with `use_cache` / `hidden_size`.
        dataloader: mapping with the keys 'input_ids', 'inputs_embeds',
            'vision_mask' and 'image_grid_thw'.
        dev: device on which calibration is run.
        args: namespace of options (model, n_samples, seqlen, grad_*, w_*,
            int8_down_proj, random_drop_ratio, percdamp, act_order,
            static_groups, residual_alpha, ...).

    Returns:
        dict mapping 'model.layers.<i>.<sub-module name>' to the quantizer
        object used for that weight.
    '''
    logging.info('-----VLMQ Quantization-----')

    input_ids = dataloader['input_ids']
    inputs_embeds = dataloader['inputs_embeds']
    vision_mask = dataloader['vision_mask']
    image_grid_thw = dataloader['image_grid_thw']
    del dataloader
    utils.cleanup_memory(verbos=True)

    # Disabled for calibration and restored before returning.
    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    # Only what is needed to reach the first decoder layer goes to `dev`.
    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    model.model.norm = model.model.norm.to(dev)
    model.model.rotary_emb = model.model.rotary_emb.to(dev)
    layers[0] = layers[0].to(dev)

    if args.model == 'llava_onevision':
        # TODO: convert dtype to avoid nan in gradients
        model=model.to(torch.bfloat16)
        inputs_embeds = inputs_embeds.to(torch.bfloat16)

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.n_samples, args.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )

    cache = {'i': 0, 'kwargs': []}

    class _StopForward(Exception):
        '''Raised by the catcher modules to abort the forward pass at layer 0.'''

    # A dedicated exception (rather than ValueError) is used so that a genuine
    # ValueError raised by the model is not silently swallowed by the capture
    # loops below.
    if args.model == 'llava_onevision':
        pass
    else:
        # caching kwargs
        class KwargsCatcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module
            def forward(self, inp, **kwargs):
                # kwargs.keys()
                # dict_keys(['attention_mask', 'position_ids', 'past_key_value', 'output_attentions', 'use_cache', 'cache_position', 'position_embeddings'])
                # These two are dropped because the remaining kwargs are
                # forwarded to `model(...)` in the second pass; tolerate
                # their absence instead of raising KeyError.
                kwargs.pop('position_embeddings', None)
                kwargs.pop('past_key_value', None)
                cache['kwargs'].append(kwargs)
                raise _StopForward

        layers[0] = KwargsCatcher(layers[0])
        for i in range(args.n_samples):
            try:
                model(input_ids=input_ids[i:i+1].to(dev), image_grid_thw=image_grid_thw[i:i+1].to(dev))
            except _StopForward:
                pass
        layers[0] = layers[0].module

    # caching inputs to the 1st layer
    class InputsCatcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['kwargs'].append(kwargs)
            raise _StopForward

    # `kwargs` now holds the first-pass (model-level) kwargs; the cache is
    # reset so the second pass records the full per-layer kwargs.
    kwargs = cache['kwargs']
    cache['kwargs'] = []
    # kwargs.pop('position_embeddings')
    # kwargs.pop('past_key_value')
    layers[0] = InputsCatcher(layers[0])
    for i in range(args.n_samples):
        try:
            if args.model == 'llava_onevision':
                model(inputs_embeds=inputs_embeds[i:i+1].to(dev))
            else:
                model(inputs_embeds=inputs_embeds[i:i+1].to(dev), **kwargs[i])
        except _StopForward:
            pass
    layers[0] = layers[0].module

    layers[0] = layers[0].cpu()
    model.model.embed_tokens = model.model.embed_tokens.cpu()
    model.model.norm = model.model.norm.cpu()
    model.model.rotary_emb = model.model.rotary_emb.cpu()
    del inputs_embeds
    torch.cuda.empty_cache()
    # From here on `kwargs[j]` are the kwargs layer 0 received for sample j;
    # they are reused for every layer.
    kwargs = cache['kwargs']

    outs = torch.zeros_like(inps)

    quantizers = {}
    # Sub-modules are quantized group by group, in this order.
    sequential = [
        ['self_attn.k_proj.module', 'self_attn.v_proj.module', 'self_attn.q_proj.module'],
        ['self_attn.o_proj.module'],
        ['mlp.up_proj.module', 'mlp.gate_proj.module'],
        ['mlp.down_proj.module']
    ]

    fp_inputs_cache = model_utils.FPInputsCache(sequential)
    out_grad_cache = model_utils.OutputGradCache(
        sequential,
        grad_acton=args.grad_acton,
        grad_norm=args.grad_norm,
        grad_clip=args.grad_clip,
        grad_temperature=args.grad_temperature,
        grad_clip_times=args.grad_clip_times,
        grad_clip_high_only=args.grad_clip_high_only
        )

    # `fp_inps` follows the full-precision activation path, `inps` the path
    # through the already-quantized layers.
    fp_inps = inps.clone()

    start_quant_time = time.time()
    for i in range(len(layers)):
        print(f'\nLayer {i}:', flush=True, end=' ')
        layer = layers[i].to(dev)
        full = quant_utils.find_qlayers(layer, layers=[torch.nn.Linear])

        # TODO: record activate MAE loss start
        err = (fp_inps - inps).abs().mean()
        logging.info(f'{i}: Input activate MAE loss: {err.item()}')
        # record activate MAE loss done

        # NOTE(review): presumably switches activation quantization off while
        # the reference activations / gradients are gathered; it is switched
        # back on with the returned config further down.
        bits_config = quant_utils.disable_act_quant(layer)

        # cache output gradients start
        if i >= args.grad_start_idx:  # skip 1st layer
            bsz, n_batch = cache_grad_bsz(args.n_samples, args)
            print(f'Cache output gradients: batch size {bsz}, n_batch {n_batch}')
            tmp_fp_outs = torch.zeros_like(inps)
            # Defined up front so the `del` below is valid even when
            # n_batch == 0.
            tmp_q_outs = None

            # prepare tmp fp outs
            for bch_idx in range(n_batch):
                for smp_idx in range(bsz):
                    j = bch_idx * bsz + smp_idx
                    if args.grad_from == 'block_out':
                        tmp_fp_outs[j] = layer(fp_inps[j].unsqueeze(0), **kwargs[j])[0]

                    elif args.grad_from == 'attn_out':
                        tmp_fp_outs[j] = layer.self_attn(
                            layer.input_layernorm(fp_inps[j].unsqueeze(0)),
                            **kwargs[j])[0] + fp_inps[j].unsqueeze(0)

                    elif args.grad_from == 'o_proj_in':
                        if args.model == 'llava_onevision':
                            tmp_fp_outs[j] = modeling_qwen2_utils.attn_forward(
                                layer,
                                fp_inps[j].unsqueeze(0),
                                kwargs[j]['position_embeddings'],
                                kwargs[j]['past_key_value'],
                                kwargs[j]['attention_mask'])[0]
                        else:
                            tmp_fp_outs[j] = modeling_qwen2_vl_utils.attn_forward(
                                layer,
                                fp_inps[j].unsqueeze(0),
                                kwargs[j]['position_embeddings'],
                                kwargs[j]['past_key_value'],
                                kwargs[j]['attention_mask'])[0]
                    else:
                        raise ValueError(f'Unsupported grad_from: {args.grad_from}')

            # NOTE(review): presumably registers hooks on the sub-modules in
            # `full` that record gradients during `loss.backward()` below --
            # see model_utils.OutputGradCache.
            out_grad_cache.add_hook(full)
            # The function runs under no_grad; autograd is re-enabled only
            # for this backward pass.
            with torch.enable_grad():
                for bch_idx in range(n_batch):
                    layer.zero_grad()
                    tmp_q_outs = torch.zeros(
                        (bsz, args.seqlen, model.config.hidden_size), dtype=dtype, device=dev
                    )
                    for smp_idx in range(bsz):
                        j = bch_idx * bsz + smp_idx
                        if args.grad_from == 'block_out':
                            tmp_q_outs[smp_idx] = layer(inps[j].unsqueeze(0), **kwargs[j])[0]

                        elif args.grad_from == 'attn_out':
                            tmp_q_outs[smp_idx] = layer.self_attn(
                                layer.input_layernorm(inps[j].unsqueeze(0)),
                                **kwargs[j])[0] + inps[j].unsqueeze(0)

                        elif args.grad_from == 'o_proj_in':
                            if args.model == 'llava_onevision':
                                tmp_q_outs[smp_idx] = modeling_qwen2_utils.attn_forward(
                                    layer,
                                    inps[j].unsqueeze(0),
                                    kwargs[j]['position_embeddings'],
                                    kwargs[j]['past_key_value'],
                                    kwargs[j]['attention_mask'])[0]
                            else:
                                tmp_q_outs[smp_idx] = modeling_qwen2_vl_utils.attn_forward(
                                    layer,
                                    inps[j].unsqueeze(0),
                                    kwargs[j]['position_embeddings'],
                                    kwargs[j]['past_key_value'],
                                    kwargs[j]['attention_mask'])[0]
                        else:
                            raise ValueError(f'Unsupported grad_from: {args.grad_from}')

                    # Squared error against the full-precision reference,
                    # summed over the hidden dim and averaged over the rest.
                    loss = (tmp_fp_outs[bch_idx*bsz : (bch_idx+1)*bsz] - tmp_q_outs).pow(2).sum(-1).mean()
                    loss.backward()

            out_grad_cache.clear_hook()
            del tmp_fp_outs, tmp_q_outs
            utils.cleanup_memory(verbos=False)
        # cache output gradients ends

        # cache fp inps starts
        # Advances the full-precision path through this (still unquantized)
        # layer while the hooks record each sub-module's input.
        fp_inputs_cache.add_hook(full)
        for j in range(args.n_samples):
            fp_inps[j] = layer(fp_inps[j].unsqueeze(0), **kwargs[j])[0]
        fp_inputs_cache.clear_hook()
        # cache fp inps done

        quant_utils.enable_act_quant(layer, bits_config)

        for names in sequential:
            subset = {n: full[n] for n in names}

            vlmq = {}
            for name in subset:
                print(f'{name}', end='  ', flush=True)
                layer_weight_bits = args.w_bits
                layer_weight_sym = not (args.w_asym)
                if 'lm_head' in name:
                    layer_weight_bits = 16
                    continue
                if args.int8_down_proj and 'down_proj' in name:
                    layer_weight_bits = 8
                vlmq[name] = VLMQ(subset[name])
                vlmq[name].quantizer = quant_utils.WeightQuantizer()
                vlmq[name].quantizer.configure(
                    layer_weight_bits, perchannel=True, sym=layer_weight_sym, mse=args.w_clip
                )
                # Same list object as the cache: VLMQ.add_batch consumes it.
                vlmq[name].fp_inp = fp_inputs_cache.fp_cache[name]
                if len(out_grad_cache.grad_cache[name]) == 0:
                    vlmq[name].grad = []
                else:
                    if args.random_drop_ratio < 0:
                        assert len(out_grad_cache.grad_cache[name]) == args.n_samples
                        # Gradients are cached in reversed order within each
                        # backward batch; restore the sample order.
                        vlmq[name].grad = reverse_lst(out_grad_cache.grad_cache[name], segment_len=bsz)
                    else:
                        # Replace the cached gradients by random scores: a
                        # random subset of the positions flagged in
                        # `vision_mask` gets 0.01, everything else 1, and the
                        # scores are rescaled to average 1.
                        vlmq[name].grad = []
                        for m in range(args.n_samples):
                            cur_mask = vision_mask[m,:]
                            # squeeze(-1) (not squeeze()) keeps a 1-D index
                            # tensor even when exactly one entry is True.
                            true_idx = torch.nonzero(cur_mask, as_tuple=False).squeeze(-1)
                            num_true = true_idx.numel()
                            k = int(num_true * args.random_drop_ratio + 0.5)
                            perm = true_idx[torch.randperm(num_true, device=cur_mask.device)[:k]]
                            new_mask = torch.zeros_like(cur_mask, dtype=torch.bool)
                            new_mask[perm] = True
                            cur_score = torch.ones_like(cur_mask, dtype=torch.float32)
                            cur_score[new_mask] = 0.01
                            cur_score = cur_score.numel() * cur_score / cur_score.sum()
                            vlmq[name].grad.append(cur_score)


            def add_batch(name):
                # Factory so each hook is bound to its own `name`.
                def tmp(_, inp, out):
                    vlmq[name].add_batch(inp[0].data, out.data)
                return tmp

            handles = []
            for name in subset:
                handles.append(subset[name].register_forward_hook(add_batch(name)))
            # Calibration pass: the forward hooks feed VLMQ.add_batch.
            for j in range(args.n_samples):
                outs[j] = layer(inps[j].unsqueeze(0), **kwargs[j])[0]
            for h in handles:
                h.remove()

            for name in subset:
                layer_w_groupsize = args.w_groupsize
                vlmq[name].fasterquant(
                    percdamp=args.percdamp, groupsize=layer_w_groupsize, actorder=args.act_order,
                    static_groups=args.static_groups, alpha=args.residual_alpha
                )
                quantizers['model.layers.%d.%s' % (i, name)] = vlmq[name].quantizer
                vlmq[name].free()

        # Outputs of the fully quantized layer feed the next layer.
        for j in range(args.n_samples):
            outs[j] = layer(inps[j].unsqueeze(0), **kwargs[j])[0]

        fp_inputs_cache.clear_cache()
        out_grad_cache.clear_cache()
        layers[i] = layer.cpu()
        del layer
        del vlmq
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    logging.info(f"Calib time cost: {(time.time()-start_quant_time)/3600.0} hours\n")
    model.config.use_cache = use_cache
    utils.cleanup_memory(verbos=True)
    logging.info('-----VLMQ Quantization Done-----\n')

    return quantizers


def cache_grad_bsz(n_samples, args):
    '''
    Pick how the calibration samples are grouped for gradient caching.

    The preferred batch size depends on the model and on `args.grad_from`:
      * "72B" in `args.model_args`          -> 2 samples per batch
      * `grad_from == 'block_out'`          -> 1 sample per batch
      * llava_onevision with 'o_proj_in'    -> two batches of n_samples / 2
      * otherwise                           -> a single batch with all samples
    When the preferred batch size does not divide `n_samples` (or would be
    zero), the function falls back to one sample per batch so that
    bsz * n_batch == n_samples always holds and no sample is dropped.

    Args:
        n_samples: total number of calibration samples.
        args: namespace providing `model_args`, `grad_from` and `model`.

    Returns:
        (bsz, n_batch)
    '''
    is_large_model = "72B" in args.model_args
    is_block_out = args.grad_from == 'block_out'
    is_o_proj_in = args.grad_from == 'o_proj_in'
    is_llava = args.model == 'llava_onevision'

    if is_large_model:
        bsz = 2
    elif is_block_out:
        bsz = 1
    elif is_llava and is_o_proj_in:
        bsz = n_samples // 2  # 7B llava-ov: 90%*80G
    else:
        bsz = n_samples

    # A batch size that does not evenly split the samples would silently
    # skip the remainder; one sample per batch is always valid.
    if bsz <= 0 or n_samples % bsz != 0:
        bsz = 1
    return bsz, n_samples // bsz
    
    
from typing import List, Any

def reverse_lst(lst: List[Any], segment_len: int) -> List[Any]:
    '''
    Reverse the order of items inside each consecutive segment of a list.

    The list is cut into chunks of `segment_len` items; every chunk is
    reversed while the chunks themselves keep their order. The result is a
    new flat list of the same length; `lst` is not modified.

    Args:
        lst: input list; its length must be a multiple of `segment_len`.
        segment_len: positive chunk size.

    Returns:
        Flat list with each chunk reversed, e.g.
        reverse_lst([1, 2, 3, 4], 2) -> [2, 1, 4, 3].
    '''
    assert segment_len > 0, "segment_len must be positive"
    assert len(lst) % segment_len == 0, "Length of list must be divisible by segment_len"
    results = []
    for start in range(0, len(lst), segment_len):
        results.extend(reversed(lst[start:start + segment_len]))
    return results
