# Import necessary modules
import time
import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm

# Import get_loaders function from data module within the same directory
from .data import get_loaders 

from collections import defaultdict
import fnmatch

from slicegpt import rotate, gpu_utils
from slicegpt.model_adapter import LayerAdapter, ModelAdapter, rot_mask_Linear

from copy import deepcopy
from lib.prune_opt import prune_magnitude, prune_wanda, prune_wanda_subset, prune_magnitude_subset, prune_slice_subset, check_sparsity, find_layers

import wandb
import logging
from accelerate import Accelerator

# Function to evaluate perplexity (ppl) on a specified model and tokenizer
def eval_ppl(args, model, testloader, device=torch.device("cuda:0")):
    """Return the WikiText-2 perplexity of ``model`` on ``testloader``.

    Thin wrapper around ``eval_ppl_wikitext`` that disables gradient tracking
    and always evaluates with a batch size of 1.

    Args:
        args: Accepted for interface compatibility; not read by this function.
        model: Model forwarded unchanged to ``eval_ppl_wikitext``.
        testloader: Tokenised test set forwarded to ``eval_ppl_wikitext``.
        device: Device the evaluation inputs are moved to.

    Returns:
        Whatever ``eval_ppl_wikitext`` returns (the perplexity).
    """
    dataset_name = "wikitext2"
    print(f"evaluating on {dataset_name}")

    # Evaluation only: no autograd bookkeeping is needed.
    with torch.no_grad():
        return eval_ppl_wikitext(model, testloader, 1, device)

# Function to train on the calibration set
def eval_ppl_wikitext_train(args, model_adapter: ModelAdapter, Qs, W_masks, trainloader, optimizer, bs=1, iters=None, device=None, mask_iter=None, dif_Q=True, num=1, accelerator:Accelerator=None):
    """Run one optimisation pass over ``trainloader``.

    For every ``(batch, gt_logits)`` pair the rotations/masks are re-applied
    implicitly, the model is run on ``batch['input_ids']`` and the loss

        next-token cross-entropy + 10 * mean cosine distance to ``gt_logits[0]``

    is back-propagated, followed by a single ``optimizer.step()``.

    Args:
        args: Namespace; only ``args.distribute_model`` is read here.
        model_adapter: Adapter whose ``.model`` is trained.
        Qs, W_masks, dif_Q, num: Forwarded to
            ``rotate.rotate_and_mask_implicit_sequential``.
        trainloader: Iterable of ``(batch, gt_logits)`` pairs; must support
            ``len()``.
        optimizer: Optimiser stepped once per batch.
        bs, iters: Only used to compute and print the ``iters`` banner.
        device: Device the inputs (and, when not distributed, the model) are
            moved to.
        mask_iter: Unused.
        accelerator: Used for ``backward``/``wait_for_everyone``/``unwrap_model``
            when ``args.distribute_model`` is truthy.

    Returns:
        None.
    """
    num_batches = len(trainloader)
    model = model_adapter.model
    model.train()

    # ``iters`` is informational only; the loop below always consumes the
    # whole loader.
    iters = int(num_batches / bs) if iters is None else iters
    print(f"iters {iters}")

    criterion = nn.CrossEntropyLoss()

    for batch, teacher_logits in tqdm(trainloader):

        with torch.cuda.amp.autocast():
            rotate.rotate_and_mask_implicit_sequential(model_adapter, Qs, W_masks, mask_shortcut=True, dif_Q=dif_Q, num=num)

            distributed = args.distribute_model

            # In the non-distributed case the whole model is placed on
            # ``device`` by this function; otherwise it is left where it is.
            if not distributed:
                model.to(device=device)
            inputs = batch['input_ids'].to(device=device)
            teacher_logits = teacher_logits[0].to(device=device)

            lm_logits = model(inputs).logits

            # Next-token prediction: logits at position t are scored against
            # the token at position t + 1.
            pred = lm_logits[:, :-1, :].contiguous()
            target = inputs[:, 1:]
            loss = criterion(pred.reshape(-1, pred.size(-1)), target.reshape(-1))

            # Distillation term against the supplied reference logits.
            # (Weights noted in the original for the alternatives:
            #  l2 -> 0.0001, cos -> 10, js -> 1000.)
            loss = loss + sim(lm_logits, teacher_logits, 'cos') * 10

            wandb.log({"loss": loss})

            # Standard update step.
            optimizer.zero_grad()
            if distributed:
                accelerator.backward(loss)
            else:
                loss.backward()
            optimizer.step()

            if distributed:
                accelerator.wait_for_everyone()
                model = accelerator.unwrap_model(model)

        torch.cuda.empty_cache()

def l1_norm_elementwise(W):
    """Return the sum of the absolute values of every entry of ``W``.

    This equals the sum over rows of each row's L1 norm, computed in a single
    vectorised reduction instead of one kernel launch per row.

    Args:
        W: Tensor of any shape.

    Returns:
        A 0-dim tensor with ``W``'s dtype and device (also for empty input,
        where the result is a zero tensor).
    """
    return W.abs().sum()

def train_layer_by_layer_l1(args, model_adapter: ModelAdapter, Qs, trainloader, testloader, optimizer_type, bs=1, iters=None, device=None, rot_path='l1_rot.pt', lamda=0.2):
    """Rotate the model with saved rotations and magnitude-prune it with
    per-matrix sparsity levels derived from L1 statistics.

    Steps:
      1. Freeze the rotations passed in, then replace them with the ones
         stored at ``rot_path`` and apply them via ``rotate.rotate_sequential``.
      2. For every linear module found in every layer compute
         ``ratio = numel / sum(|W|)`` (the inverse mean absolute weight).
      3. Map each ratio linearly to a sparsity level so that the matrix with
         the mean ratio gets ``args.sparsity`` and the full min..max range of
         ratios spans ``2 * lamda``.
      4. Zero every weight whose magnitude is at or below the corresponding
         per-matrix quantile.
      5. Evaluate perplexity on ``testloader``, log it, and print the result
         of ``check_sparsity``.

    The optimisation that produced the checkpoint was disabled in the original
    source: it trained each rotation with ``optimizer_type`` to minimise the
    mean ``l1_norm_elementwise`` of the rotated weights of that rotation's
    layer subset and saved the result to ``'l1_rot.pt'``.

    Args:
        args: Namespace; ``args.sparsity`` is the target mean sparsity.
        model_adapter: Adapter providing ``.model`` and ``.get_layers()``.
        Qs: Rotations; they are frozen and then superseded by the checkpoint.
        trainloader, optimizer_type, bs, iters, device: Unused by the active
            code path (kept for interface compatibility).
        testloader: Passed to ``gpu_utils.evaluate_ppl``.
        rot_path: Path of the saved rotations (defaults to the previously
            hard-coded ``'l1_rot.pt'``).
        lamda: Half-width of the sparsity spread around ``args.sparsity``
            (defaults to the previously hard-coded ``0.2``).

    Returns:
        None. Weights are modified in place.
    """
    model = model_adapter.model
    layers = model_adapter.get_layers()

    for Q in Qs:
        Q.requires_grad = False

    # SECURITY NOTE: torch.load unpickles the file; only point ``rot_path`` at
    # checkpoints from a trusted source.
    Qs = torch.load(rot_path)
    rotate.rotate_sequential(model_adapter, Qs)

    # ---- Pass 1: per-matrix statistics -----------------------------------
    sparsity_ratios = []
    ratio_total = 0
    ratio_count = 0
    # Seeded from the data rather than fixed sentinels, so that ratios above
    # an arbitrary constant are not silently capped.
    min_a = None
    max_a = None
    for layer_adapter in layers:
        layer_ratios = {}
        subset = find_layers(layer_adapter.layer, layers=[rot_mask_Linear, nn.Linear])

        for name in subset:
            W = subset[name].weight.data
            ratio = (1 / (l1_norm_elementwise(W) / W.numel())).to(dtype=torch.float64)
            layer_ratios[name] = ratio
            ratio_total += ratio
            ratio_count += 1
            if min_a is None or ratio < min_a:
                min_a = ratio
            if max_a is None or ratio > max_a:
                max_a = ratio

        sparsity_ratios.append(layer_ratios)

    mean_a = ratio_total / ratio_count if ratio_count else 0
    spread = max_a - min_a if ratio_count else 0

    # ---- Pass 2: per-matrix magnitude pruning ----------------------------
    for i, layer_adapter in enumerate(layers):
        subset = find_layers(layer_adapter.layer, layers=[rot_mask_Linear, nn.Linear])

        for name in subset:
            W = subset[name].weight.data
            W_metric = torch.abs(W)

            if spread > 0:
                sparsity = 2 * lamda * (sparsity_ratios[i][name] - min_a) / spread
                sparsity = sparsity - 2 * lamda * (mean_a - min_a) / spread + args.sparsity
            else:
                # All matrices have the same ratio (or there is only one):
                # fall back to the uniform target instead of dividing by zero.
                sparsity = args.sparsity

            # Clamp the cut position into the valid index range. Without
            # this, a sparsity >= 1 raised IndexError (previously dropping
            # into a debugger) and a negative sparsity silently indexed from
            # the end of the sorted magnitudes.
            cut = int(W.numel() * sparsity)
            cut = min(max(cut, 0), W.numel() - 1)

            thresh = torch.sort(W_metric.flatten().cuda())[0][cut].cpu()
            W[W_metric <= thresh] = 0

    dataset_ppl = gpu_utils.evaluate_ppl(model, model.config.pad_token_id, testloader)
    logging.info(f'After training Q: {dataset_ppl:.4f}')
    wandb.log({"ppl": dataset_ppl})

    print(check_sparsity(model))
        
        
def train_layer_by_layer(args, model_adapter: ModelAdapter, Qs, trainloader, testloader, optimizer_type, bs=1, iters=None, device=None):
    """Train the rotations in ``Qs`` one at a time against a dense reference.

    A deep copy of the incoming model (``model_dense``) serves as the
    reference. For each rotation ``Qs[index]`` in order:

      * a snapshot ``model_cp`` of the current model is taken and only that
        rotation is made trainable;
      * for 20 epochs, every batch restores the model from the snapshot,
        wraps one module of both the working and the dense model in a
        "catcher" that records its input and aborts the forward pass, applies
        the rotation and mask implicitly, and minimises the L2 distance
        between the two recorded inputs (divided by 100);
      * every 200 batches the pruning mask ``W_mask`` is recomputed with the
        method selected by ``args.prune_method``;
      * afterwards the rotation and mask are applied for real to a fresh copy
        of the snapshot, the rotation is frozen again, and perplexity on
        ``testloader`` is logged.

    Args:
        args: Namespace; ``lr``, ``prune_method`` and ``sparsity`` are read.
        model_adapter: Adapter whose ``_model`` is repeatedly replaced.
        Qs: Sequence of rotation tensors, trained in place.
        trainloader: Iterable of batches providing ``batch['input_ids']``.
        testloader: Passed to ``gpu_utils.evaluate_ppl``.
        optimizer_type: Callable building an optimiser from a list of
            parameter groups.
        bs, iters: Unused (only referenced in commented-out code).
        device: Device the working model and inputs are moved to.

    Returns:
        None.
    """
    # nsamples = len(trainloader)
    model = model_adapter.model
    
    # Reference copy taken before this function modifies the model.
    # NOTE(review): model_dense is never moved to `device` in this function;
    # presumably the incoming model already lives there - confirm.
    model_dense = deepcopy(model)
    
    # if iters is None:
    #     iters = int(nsamples / bs)
    
    # Loop through each Q
    for Q in Qs:
        Q.requires_grad = False
        
    for index, Q in enumerate(Qs):
        model_cp = deepcopy(model) # snapshot with the already-trained rotations fixed
        Q.requires_grad = True
        epochs = 20
        optimizer = optimizer_type([{'params':Q,'lr':args.lr,'stiefel':True}])
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=epochs)
        
        # next_layer_index = rotate.get_next_layer(model, index) # next_layer needs to capture its input
        
        # One-element lists used as mutable cells the catcher classes write into.
        sparse_logits, dense_logits = [0], [0]
        
        # Wrapper that stores the input of the wrapped module (working model)
        # and aborts the forward pass by raising ValueError.
        class sparse_Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module
            def forward(self, inp, **kwargs):
                sparse_logits[0] = inp
                raise ValueError
            
        # Same as sparse_Catcher, but for the dense reference model.
        class dense_Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module
            def forward(self, inp, **kwargs):
                dense_logits[0] = inp
                raise ValueError
        
        for t in range(epochs):
            k = 0
            for batch in tqdm(trainloader):
                # Start every batch from the untouched snapshot.
                model_adapter._model = deepcopy(model_cp)
                model = model_adapter.model
                
                if k % 200 == 0: # refresh the mask
                    subset = rotate.rotate_one(model_adapter, Qs, index) # subset denotes the layers to be trained
                    # NOTE(review): W_mask is only assigned in the branches
                    # below; any other args.prune_method leaves it unbound.
                    if args.prune_method == 'wanda':
                        W_mask = prune_wanda_subset(args=args, model_adapter=model_adapter, subset=subset, dataloader=trainloader, sparsity_ratio=args.sparsity, device=device)
                    elif args.prune_method == 'magnitude':
                        W_mask = prune_magnitude_subset(args=args, model_adapter=model_adapter, subset=subset, dataloader=trainloader, sparsity_ratio=args.sparsity, device=device)
                    elif args.prune_method == 'slice':
                        W_mask = prune_slice_subset(args=args, model_adapter=model_adapter, subset=subset, dataloader=trainloader, sparsity_ratio=args.sparsity, device=device, last=(index==len(Qs)-1))
                    if t == 0:
                        dataset_ppl = gpu_utils.evaluate_ppl(model, model.config.pad_token_id, testloader)
                        logging.info(f'Before training {index}-th Q: {dataset_ppl:.4f}')
                    # Restore the snapshot again after the mask computation.
                    model_adapter._model = deepcopy(model_cp)
                    model = model_adapter.model
                    
                k += 1
                
                # Install the catchers. Odd indices hook fc2, even indices
                # hook self_attn.out_proj of decoder layer index//2, and the
                # final index hooks lm_head.
                # NOTE(review): the attribute paths look OPT-specific - confirm.
                if index % 2: # catch at mlp output
                    model_dense.model.decoder.layers[index//2].fc2 = dense_Catcher(model_dense.model.decoder.layers[index//2].fc2)
                    model.model.decoder.layers[index//2].fc2 = sparse_Catcher(model.model.decoder.layers[index//2].fc2)
                elif index < len(Qs) - 1: # catch at attn output
                    model_dense.model.decoder.layers[index//2].self_attn.out_proj = dense_Catcher(model_dense.model.decoder.layers[index//2].self_attn.out_proj)
                    model.model.decoder.layers[index//2].self_attn.out_proj = sparse_Catcher(model.model.decoder.layers[index//2].self_attn.out_proj)
                else:
                    model_dense.lm_head = dense_Catcher(model_dense.lm_head)
                    model.lm_head = sparse_Catcher(model.lm_head)
                
                # if index < len(Qs) - 3:
                #     model_dense.model.decoder.layers[index//2 + 1] = dense_Catcher(model_dense.model.decoder.layers[index//2 + 1])
                #     model.model.decoder.layers[index//2 + 1] = sparse_Catcher(model.model.decoder.layers[index//2 + 1])
                # else:
                #     model_dense.lm_head = dense_Catcher(model_dense.lm_head)
                #     model.lm_head = sparse_Catcher(model.lm_head)
                
                # with torch.cuda.amp.autocast():
                rotate.rotate_and_mask_one_implicit(model_adapter, Qs, index, W_mask)
                
                model.to(device=device)
                inputs = batch['input_ids'].to(device=device)
                
                # The catchers end each forward pass with ValueError once the
                # hooked input has been recorded; that exception is expected.
                # if index < len(Qs) - 1:
                try:
                    model(inputs)
                except ValueError:
                    pass
                try:
                    with torch.no_grad():
                        model_dense(inputs)
                except ValueError:
                    pass
                # else:
                #     sparse_logits[0] = model(inputs).logits
                #     with torch.no_grad():
                #         dense_logits[0] = model_dense(inputs).logits
                    
                with torch.cuda.amp.autocast():
                    # Distance between the inputs captured from the working
                    # model and from the dense reference.
                    loss = sim(sparse_logits[0], dense_logits[0], 'l2') / 100
                    # loss = sim(sparse_logits[0], dense_logits[0], 'cos')
                    # loss = sim(sparse_logits[0], dense_logits[0], 'cos') + sim(sparse_logits[0], dense_logits[0], 'l2') / 1000
                    # loss = sim(sparse_logits[0], dense_logits[0], 'js') * 1000
                    
                    # if index % 2: 
                    #     loss = loss * 3

                    wandb.log({"loss": loss})
                    
                    # if index == 1:
                    #     import pdb; pdb.set_trace()

                    # Update the single trainable rotation.
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    # torch.cuda.empty_cache()
                
                # Remove the catchers again by restoring the wrapped modules.
                if index % 2: # catch at mlp output
                    model_dense.model.decoder.layers[index//2].fc2 = model_dense.model.decoder.layers[index//2].fc2.module
                    model.model.decoder.layers[index//2].fc2 = model.model.decoder.layers[index//2].fc2.module
                elif index < len(Qs) - 1: # catch at attn output
                    model_dense.model.decoder.layers[index//2].self_attn.out_proj = model_dense.model.decoder.layers[index//2].self_attn.out_proj.module
                    model.model.decoder.layers[index//2].self_attn.out_proj = model.model.decoder.layers[index//2].self_attn.out_proj.module
                else:
                    model_dense.lm_head = model_dense.lm_head.module
                    model.lm_head = model.lm_head.module
                    
                # if index < len(Qs) - 3:
                #     model_dense.model.decoder.layers[index//2 + 1] = model_dense.model.decoder.layers[index//2 + 1].module
                #     model.model.decoder.layers[index//2 + 1] = model.model.decoder.layers[index//2 + 1].module
                # else:
                #     model_dense.lm_head = model_dense.lm_head.module
                #     model.lm_head = model.lm_head.module
            
            # dataset_ppl = gpu_utils.evaluate_ppl(model, model.config.pad_token_id, testloader)
            # logging.info(f'After epoch {t} in {index} Q: {dataset_ppl:.4f}')
            # wandb.log({"ppl": dataset_ppl})
            # Learning-rate schedule advances once per epoch.
            scheduler.step()
        
        # fix rotation and mask
        model_adapter._model = deepcopy(model_cp)
        model = model_adapter.model
        subset = rotate.rotate_and_mask_one(model_adapter, Qs, index, W_mask)
        # prune_slice_subset(args=args, model_adapter=model_adapter, subset=subset, dataloader=trainloader, sparsity_ratio=args.sparsity, device=device)
        
        Q.requires_grad = False
        
        dataset_ppl = gpu_utils.evaluate_ppl(model, model.config.pad_token_id, testloader)
        logging.info(f'After training {index}-th Q: {dataset_ppl:.4f}')
        wandb.log({"ppl": dataset_ppl})
    
        print(check_sparsity(model))

def sim(a, b, type):
    """Return a scalar dissimilarity between tensors ``a`` and ``b``.

    Args:
        a, b: Tensors of matching shape.
        type: Metric name, one of

            * ``'l2'``  - norm of ``a - b`` over all elements;
            * ``'cos'`` - mean of ``1 - cosine_similarity`` along the last dim;
            * ``'js'``  - average of the two directed KL divergences between
              the softmax distributions over the last dim (a symmetrised KL,
              using ``KLDivLoss(reduction='mean')``, i.e. averaged over every
              element).

    Returns:
        A 0-dim tensor.

    Raises:
        ValueError: If ``type`` is not one of the names above. Previously the
            function fell through and returned ``None``, which only surfaced
            later as an unrelated error at the call site.
    """
    if type == "l2":
        return torch.norm(a - b)

    if type == 'cos':
        cos = torch.nn.CosineSimilarity(dim=-1)
        return (1 - cos(a, b)).mean()

    if type == 'js':
        log_p = F.log_softmax(a, dim=-1)
        log_q = F.log_softmax(b, dim=-1)
        kl = torch.nn.KLDivLoss(reduction='mean', log_target=True)
        return (kl(log_p, log_q) + kl(log_q, log_p)) / 2

    raise ValueError(f"unknown similarity type {type!r}; expected 'l2', 'cos' or 'js'")

def ppl_train(model, model_dense, trainloader, optimizer, bs=2, size=128, device=None):
    """Distil ``model`` towards ``model_dense`` on a random subset of samples.

    ``size`` entries are drawn without replacement from the first element of
    every item in ``trainloader``. They are processed in chunks of ``bs``;
    for each chunk the ``'js'`` dissimilarity between the logits of the two
    models, scaled by ``model.seqlen`` and the chunk size, is back-propagated
    and one optimiser step is taken.

    Args:
        model: Model being trained; must expose ``seqlen`` and return an
            object with ``.logits``.
        model_dense: Reference model, run without gradient tracking.
        trainloader: Sized iterable whose items are indexable (``item[0]`` is
            used).
        optimizer: Optimiser stepped once per chunk.
        bs: Chunk size.
        size: Number of samples drawn.
        device: Device the inputs are moved to.

    Returns:
        None.
    """
    total = len(trainloader)

    # Collect the first element of every loader item and draw the subset.
    first_items = [entry[0] for entry in trainloader]
    picked = np.random.choice(range(total), size, replace=False)
    pool = torch.from_numpy(np.array(first_items)[picked])

    print(f"nsamples {size}")

    for start in range(0, size, bs):
        stop = min(start + bs, size)
        rows = stop - start

        with torch.cuda.amp.autocast():
            inputs = pool[start:stop].to(device)
            inputs = inputs.reshape(rows, model.seqlen)

            lm_logits = model(inputs).logits
            # The reference model only provides targets.
            with torch.no_grad():
                gt_logits = model_dense(inputs).logits

            loss = (sim(lm_logits, gt_logits, 'js') * model.seqlen * rows).float()

            # Report the loss of the very first chunk only.
            if start == 0:
                print(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        torch.cuda.empty_cache()

# Function to evaluate perplexity (ppl) specifically on the wikitext dataset
def eval_ppl_wikitext(model, testenc, bs=1, device=None):
    """Compute perplexity over a tokenised corpus in ``model.seqlen`` windows.

    ``testenc.input_ids`` is cut into ``numel // model.seqlen``
    non-overlapping windows (trailing tokens that do not fill a window are
    ignored). Windows are evaluated ``bs`` at a time. Each batch contributes
    ``mean next-token cross-entropy * model.seqlen * rows`` to the total
    negative log-likelihood, and the perplexity is
    ``exp(total / (nsamples * model.seqlen))``.

    The whole evaluation runs under ``torch.no_grad()``, so a direct call does
    not keep autograd graphs alive in the accumulated losses.

    Args:
        model: Callable returning an object with ``.logits``; must expose
            ``seqlen``.
        testenc: Object with an ``input_ids`` tensor, indexed as
            ``[:, start:stop]``.
        bs: Number of windows per forward pass.
        device: Device the inputs are moved to.

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: If the corpus is shorter than one window.
    """
    testenc = testenc.input_ids
    seqlen = model.seqlen

    nsamples = testenc.numel() // seqlen
    print(f"nsamples {nsamples}")
    if nsamples == 0:
        raise ValueError(
            f"need at least {seqlen} tokens to evaluate perplexity, got {testenc.numel()}"
        )

    # Built once; the criterion holds no per-batch state.
    loss_fct = nn.CrossEntropyLoss()
    nlls = []

    with torch.no_grad():
        for i in range(0, nsamples, bs):
            if i % 50 == 0:
                print(f"sample {i}")

            # Last batch may be smaller than bs.
            j = min(i + bs, nsamples)

            inputs = testenc[:, (i * seqlen):(j * seqlen)].to(device)
            inputs = inputs.reshape(j - i, seqlen)

            lm_logits = model(inputs).logits

            # Logits at position t are scored against the token at t + 1.
            shift_logits = lm_logits[:, :-1, :].contiguous()
            shift_labels = inputs[:, 1:]

            loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))

            # Convert the batch mean back into a summed negative log-likelihood.
            nlls.append(loss.float() * seqlen * (j - i))

    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen))

    # Release cached GPU memory held by the evaluation.
    torch.cuda.empty_cache()

    return ppl.item()


def eval_zero_shot(
    model_name,
    model,
    tokenizer,
    task_list=("boolq", "rte", "hellaswag", "winogrande", "arc_challenge", "arc_easy", "openbookqa"),
    num_fewshot=0,
    use_accelerate=False,
    add_special_tokens=False,
):
    """Run ``lm_eval``'s ``evaluator.simple_evaluate`` on a set of tasks.

    Args:
        model_name: Model identifier; embedded in ``model_args`` as
            ``pretrained=...``. If it contains ``"70b"`` or ``"65b"`` the
            evaluation is limited to 2000 examples per task.
        model: Already-loaded model, passed as ``pretrained_model``.
        tokenizer: Tokenizer passed through to the evaluator.
        task_list: ``fnmatch`` patterns matched against ``tasks.ALL_TASKS``.
            The default is an immutable tuple, so it cannot be modified
            between calls.
        num_fewshot: Number of few-shot examples.
        use_accelerate: When truthy, ``use_accelerate=True`` is appended to
            ``model_args``.
        add_special_tokens: Passed through to the evaluator.

    Returns:
        The object returned by ``evaluator.simple_evaluate``.

    Note:
        Several keyword arguments used below (``pretrained_model``,
        ``tokenizer``, ``add_special_tokens``) presumably require a patched
        ``lm_eval`` - confirm against the installed version.
    """
    # Imported lazily so the module can be used without lm_eval installed.
    from lm_eval import tasks, evaluator

    def pattern_match(patterns, source_list):
        """Return the sorted, de-duplicated names matching any pattern."""
        matched = set()
        for pattern in patterns:
            matched.update(fnmatch.filter(source_list, pattern))
        # Sorted so the task order does not depend on set/hash ordering.
        return sorted(matched)

    task_names = pattern_match(task_list, tasks.ALL_TASKS)

    model_args = f"pretrained={model_name},cache_dir=./llm_weights"
    if use_accelerate:
        model_args += ",use_accelerate=True"

    # Cap the number of examples for the largest models.
    limit = 2000 if ("70b" in model_name or "65b" in model_name) else None

    results = evaluator.simple_evaluate(
        model="hf-causal-experimental",
        model_args=model_args,
        tasks=task_names,
        num_fewshot=num_fewshot,
        batch_size=None,
        device=None,
        no_cache=True,
        limit=limit,
        description_dict={},
        decontamination_ngrams_path=None,
        check_integrity=False,
        pretrained_model=model,
        tokenizer=tokenizer,
        add_special_tokens=add_special_tokens,
    )

    return results