import utils
import torch
import model_utils
import data_utils
import transformers
import quant_utils
import eval_utils
import os

import numpy as np
import torch.nn as nn
from tqdm import tqdm
import gc

#! The implementations below are adapted from TEAL's GitHub repository:
#! https://github.com/FasterDecoding/TEAL

def interp(x, xp, fp):
    """Piecewise-linear interpolation of ``fp`` (sampled at ``xp``) at ``x``.

    ``xp`` is searched with ``torch.searchsorted``, so it is expected to be
    sorted in ascending order. Queries that fall outside ``[xp[0], xp[-1]]``
    are not clamped: they are extrapolated along the first / last segment.
    """
    # Right-hand knot of the segment used for every query, restricted to
    # [1, len(xp) - 1] so that a left-hand neighbour always exists.
    hi = torch.searchsorted(xp, x).clamp(1, len(xp) - 1)
    lo = hi - 1

    x0, x1 = xp[lo], xp[hi]
    y0, y1 = fp[lo], fp[hi]

    # Relative position of each query inside its segment.
    weight = (x - x0) / (x1 - x0)
    return y0 + weight * (y1 - y0)

#! The class to get distribution and get thresholds.
class Distribution:
    """Empirical 1-D distribution backed by a histogram stored on disk.

    The constructor reads ``{file_path}/histograms.pt`` and expects a mapping
    with the keys ``'centers'`` (bin centers) and ``'counts'`` (per-bin
    counts). ``pdf``, ``cdf`` and ``icdf`` are all derived from these two
    tensors.
    """

    def __init__(self, file_path):
        self.file_path = file_path

        saved = torch.load(f"{self.file_path}/histograms.pt")
        self.bin_centers = saved['centers']
        self.counts = saved['counts']

        # Pre-computed once; reused by cdf() and icdf().
        self.total_count = self.counts.sum()
        self.cumulative_counts = torch.cumsum(self.counts, dim=0)

    #! The one used by authors, with Gaussian kernel smoothing.
    def pdf(self, x, bandwidth=None):
        """Gaussian-kernel density estimate evaluated at ``x``.

        Each bin center acts as a sample weighted by its count. ``x`` may be
        a Python number or a tensor. When ``bandwidth`` is omitted it is set
        to ``1.06 * std(interior bin centers) * (total_count - 2) ** (-1/5)``.
        """
        if bandwidth is None:
            #! 0.15 was used in the issue, in Github, 1.06 is used.
            spread = torch.std(self.bin_centers[1:-1])
            bandwidth = 1.06 * spread * (self.total_count - 2) ** (-1 / 5)

        # Column of centers against a row of queries -> (bins, queries) grid.
        centers = self.bin_centers.unsqueeze(1)
        x = torch.tensor([x]) if isinstance(x, (float, int)) else x.unsqueeze(0)

        norm = bandwidth * torch.sqrt(torch.tensor(2 * torch.pi))
        z = (x - centers) / bandwidth
        kernel = torch.exp(-0.5 * z ** 2) / norm

        weighted = kernel * self.counts.unsqueeze(1)
        return weighted.sum(dim=0) / self.total_count

    #! Below is the running average version, the author provided this as a simpler version
    # def pdf(self, x, window_size=50):
    #     # Convert input to tensor if it's a single number
    #     if isinstance(x, (float, int)):
    #         x = torch.tensor([x])
    #     # Compute moving average of counts
    #     smoothed_counts = torch.nn.functional.avg_pool1d(
    #         self.counts.unsqueeze(0).unsqueeze(0),
    #         kernel_size=window_size,
    #         stride=1,
    #         padding=window_size // 2
    #     ).squeeze(0).squeeze(0)
    #     # Interpolate smoothed counts at input x
    #     pdf = interp(x, self.bin_centers, smoothed_counts)
    #     return pdf

    def cdf(self, x):
        """Cumulative probability at ``x``, interpolated between bin centers."""
        fractions = self.cumulative_counts / self.total_count
        return interp(x, self.bin_centers, fractions)

    def icdf(self, q):
        """Return the value at quantile ``q`` (inverse of the cumulative counts).

        Quantiles that land before the first bin or past the last one return
        the first / last bin center; everything else is linearly interpolated
        between the two neighbouring bin centers.
        """
        wanted = q * self.total_count
        pos = torch.searchsorted(self.cumulative_counts, wanted)

        # Guard clauses for the two ends of the histogram.
        if pos == 0:
            return self.bin_centers[0]
        if pos == len(self.bin_centers):
            return self.bin_centers[-1]

        count_below = self.cumulative_counts[pos - 1]
        count_above = self.cumulative_counts[pos]
        value_below = self.bin_centers[pos - 1]
        value_above = self.bin_centers[pos]

        fraction = (wanted - count_below) / (count_above - count_below)
        return value_below + fraction * (value_above - value_below)


#! The histogram-building code below is copied from TEAL's repository.
class ActivationModule(torch.nn.Module):
    """Wrap a linear layer to either record or threshold-prune its input.

    In grab mode every tensor passed to ``forward`` is copied to the CPU so a
    histogram of activation values can be built afterwards
    (``find_histogram`` / ``save_histogram``). Otherwise the input is
    magnitude-pruned against ``self.threshold`` before it reaches the wrapped
    layer, and running sparsity statistics are kept in ``self.act_stats``.

    Args:
        module: The original layer. It is called unchanged on the (possibly
            pruned) input.
        name: Attribute name of the wrapped projection; selects the histogram
            directory. Only the names in ``_HISTOGRAM_GROUPS`` are accepted.
        grab_mode: ``True`` to record activations, ``False`` to prune.
        exp_name: Experiment name used as a sub-directory of ``./histogram``.

    Raises:
        ValueError: If ``name`` is not one of the supported projections.
    """

    #? We don't want to put stress on storage,
    #? so we only store 4 cases instead of the full 7 cases.
    # NOTE(review): projections mapped to the same directory overwrite each
    # other's files - presumably acceptable because they see the same input;
    # confirm.
    _HISTOGRAM_GROUPS = {
        'q_proj': 'attn_h1',
        'k_proj': 'attn_h1',
        'v_proj': 'attn_h1',
        'o_proj': 'attn_h2',
        'up_proj': 'mlp_h1',
        'gate_proj': 'mlp_h1',
        'down_proj': 'mlp_h2',
    }

    def __init__(self, module:torch.nn.Linear, name, grab_mode = False, exp_name='dobby'):
        super().__init__()
        self.name = name
        self.module = module #! Original Linear layer, do not touch

        self.grab_mode = grab_mode #! This is the mode to turn on/off the activation grab
        self.exp_name = exp_name

        group = self._HISTOGRAM_GROUPS.get(name)
        if group is None:
            raise ValueError('Not support other models now.')
        self.file_path = f'./histogram/{exp_name}/{group}'

        self.activations = []
        self.histograms = None
        self.distrs = None
        self.sparsity_level = 0.0
        # Threshold 0.0 keeps every non-zero value, i.e. pruning is a no-op
        # until set_threshold() is called. Without this default, forward() in
        # prune mode raised AttributeError.
        self.threshold = 0.0

        # store is to store stuff like position_ids in attn (for convenience, is bad code)
        self.store = {}

        # Per-call sparsity statistics collected by forward() in prune mode.
        self.act_stats = {
            'min': float('inf'),
            'max': 0.0,
            'total': 0.0,
            'count': 0,
            'moving_avg': 0.0
        }

    def grab_activations(self, x):
        """Append a detached float CPU copy of ``x``, flattened to (-1, hidden)."""
        #! In the current settings, the training calibration data has the shape of [1, seqlen (2048), hidden]
        #! Only the last dimension is needed, so any leading shape (batch > 1, 2-D input, ...) works.
        hdim = x.shape[-1]
        self.activations.append(x.detach().reshape(-1, hdim).cpu().float())

    def save_activations(self):
        """Merge the recorded activations and write them to ``activations.pt``."""
        self.activations = self.combine_activations()
        # Create the directory first, exactly like save_histogram() does.
        os.makedirs(self.file_path, exist_ok=True)
        torch.save(self.activations, f"{self.file_path}/activations.pt")

    def load_activations(self):
        """Replace the recorded activations with the contents of ``activations.pt``."""
        self.activations = torch.load(f"{self.file_path}/activations.pt")

    # NOTE: Outlier values are not resolved individually - each tail collapses into one bin.
    def find_histogram(self, num_bins=10000, outlier_threshold=0.01):
        """Build (once) and return the histogram of the recorded activations.

        The lowest and highest ``outlier_threshold`` fraction of the sorted
        values each fall into a single outer bin; the range in between is
        split into ``num_bins - 2`` equal-width bins, for ``num_bins`` bins in
        total.

        Returns:
            dict with ``'counts'`` and ``'centers'`` (float CPU tensors). The
            result is cached: later calls return the same dict and ignore
            their arguments.
        """
        if self.histograms is not None:
            return self.histograms

        # for fine-grained analysis, do not combine activations
        self.activations = self.combine_activations()

        torch.cuda.empty_cache()

        # Sort on the GPU when one is available, otherwise fall back to CPU
        # instead of crashing on the hard-coded 'cuda' device.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        acts = self.activations.flatten().detach().to(device)
        acts = torch.sort(acts)[0]

        n_outliers = int(outlier_threshold * len(acts))
        lower_bound = acts[n_outliers].item()
        # acts[-0] is acts[0]; when the outlier count rounds down to zero the
        # upper bound must be the maximum, not the minimum.
        upper_bound = (acts[-n_outliers] if n_outliers > 0 else acts[-1]).item()

        acts = acts.cpu()

        main_bins = torch.linspace(lower_bound, upper_bound, num_bins - 1, dtype=acts.dtype)
        bins = torch.cat([acts[:1], main_bins, acts[-1:]])

        counts, _ = torch.histogram(acts, bins=bins)

        bin_centers = (bins[:-1] + bins[1:]) / 2

        # Publish only after everything succeeded, so a failed run is not
        # mistaken for a cached (empty) histogram on the next call.
        self.histograms = {
            'counts': counts.float().cpu(),
            'centers': bin_centers.float().cpu(),
        }
        return self.histograms

    def save_histogram(self):
        """Write ``self.histograms`` to ``histograms.pt``, creating the directory."""
        os.makedirs(self.file_path, exist_ok=True)
        torch.save(self.histograms, f"{self.file_path}/histograms.pt")

    def combine_activations(self):
        """Return all recorded activations as one tensor, concatenated on dim 0."""
        # Already merged (e.g. after save_activations / load_activations):
        # torch.cat would reject a bare tensor, so hand it back unchanged.
        if isinstance(self.activations, torch.Tensor):
            return self.activations
        combined_activations = torch.cat(self.activations, dim=0)
        return combined_activations

    def set_threshold(self, sparsity):
        """Load the saved histogram and derive the pruning threshold.

        The threshold is the ``0.5 + sparsity / 2`` quantile of the stored
        distribution; a sparsity of exactly 0.0 gives a threshold of 0.0.
        """
        self.distrs = Distribution(self.file_path)
        self.threshold = self.distrs.icdf(0.5 + sparsity/2).item() if sparsity != 0.0 else 0.0
        self.sparsity_level = sparsity

    def th_prune(self, x):
        """Zero every element of ``x`` whose magnitude is not above the threshold."""
        #! Original implementation, not for per hidden.
        return x.abs().gt(self.threshold) * x

    def forward(self, x):
        """Record or prune ``x`` (depending on ``grab_mode``), then call the wrapped layer."""
        if self.grab_mode:
            self.grab_activations(x)
        else:
            x = self.th_prune(x)
            # reshape instead of view: the pruned tensor follows the memory
            # layout of the input, which is not guaranteed to be contiguous.
            x_flat = x.reshape(-1, x.size(-1))
            # Fraction of zeros in every row of the flattened input.
            sparsity_tmp = (1-x_flat.count_nonzero(dim=-1)/x.size(-1))

            self.act_stats['max'] = max(self.act_stats['max'], torch.max(sparsity_tmp).item())
            self.act_stats['min'] = min(self.act_stats['min'], torch.min(sparsity_tmp).item())

            self.act_stats['total'] += torch.mean(sparsity_tmp).item()
            self.act_stats['count'] += 1
            self.act_stats['moving_avg'] = self.act_stats['total'] / self.act_stats['count']

        x = self.module(x)
        return x


#! Grab mode is on by default because the wrappers are expected to be added first, before the activation-distribution collection pass.
def add_actprofile(module, name='', layers=None, grab_mode = True, exp_name='dobby'):
    """Recursively wrap matching sub-layers of ``module`` in ``ActivationModule``.

    Every attribute of ``module`` whose exact type is listed in ``layers`` is
    replaced in place by an ``ActivationModule`` named after the attribute.
    The function then recurses into the children of ``module``. Existing
    ``ActivationModule`` instances are never descended into.

    Args:
        module: Root module to process; modified in place.
        name: Dotted path of ``module``; only extended and forwarded to the
            recursive calls.
        layers: Exact types to wrap. ``None`` selects the default
            ``(torch.nn.Linear, ActivationModule)``.
        grab_mode: Passed to every created ``ActivationModule``.
        exp_name: Passed to every created ``ActivationModule``.

    Raises:
        ValueError: If an attribute is a ``torch.nn.Sequential`` or
            ``torch.nn.ModuleList``; wrapping inside containers is not
            supported by this implementation.
    """
    if isinstance(module, ActivationModule):
        return
    # Resolve the default here rather than in the signature so that no
    # mutable object is shared between calls.
    if layers is None:
        layers = (torch.nn.Linear, ActivationModule)

    unsupported_containers = (torch.nn.Sequential, torch.nn.ModuleList)
    for attr in dir(module):
        tmp = getattr(module, attr)
        # Exact type match on purpose: subclasses of the listed types are
        # left untouched.
        if type(tmp) in layers:
            setattr(module, attr, ActivationModule(tmp, name=attr, grab_mode=grab_mode, exp_name=exp_name))
        if type(tmp) in unsupported_containers:
            raise ValueError('Should not enter this case in this implementation.')

    for name1, child in module.named_children():
        add_actprofile(child, name + '.' + name1 if name != '' else name1, layers, grab_mode, exp_name)

def disable_act_grab(model):
    """Switch every ``ActivationModule`` in the decoder layers to prune mode.

    Also makes sure each wrapper's ``file_path`` carries the ``/Layer{idx}``
    suffix of its decoder layer, which is the directory layout produced by
    ``act_catcher`` during collection. The suffix is added only when it is
    not already part of the path.
    """
    for layer_idx, layer in enumerate(model.model.layers):
        suffix = f'/Layer{layer_idx}'
        for sub in layer.modules():
            if not isinstance(sub, ActivationModule):
                continue
            sub.grab_mode = False
            if suffix not in sub.file_path:
                sub.file_path += suffix #! To align with the collection phase

def set_model_sparsity(model, sparsity=0.0):
    """Call ``set_threshold(sparsity)`` on every ``ActivationModule`` of the model.

    Wrappers are visited decoder layer by decoder layer, in module order.
    """
    wrappers = (
        sub
        for layer in model.model.layers
        for sub in layer.modules()
        if isinstance(sub, ActivationModule)
    )
    for wrapper in wrappers:
        wrapper.set_threshold(sparsity) #! This sets the threshold


def main():
    """Command-line entry point: wrap, calibrate and evaluate the model.

    Steps, each driven by the parsed arguments:
      1. Load the model and add activation wrappers (TEAL-style
         ``ActivationModule`` or ``quant_utils`` pruning wrappers).
      2. Optionally load a calibrated checkpoint.
      3. Optionally collect activation histograms on C4 samples.
      4. Configure the activation sparsity.
      5. Optionally report perplexity and / or run lm_eval tasks.
    """
    args = utils.parser_gen()

    transformers.set_seed(args.seed)
    model = model_utils.get_model(args.model, args.hf_token)
    model.eval()

    #! Both branches add activation-pruning wrappers to the linear layers in the model.
    if args.act_teal:
        add_actprofile(model.model.layers, exp_name=args.wandb_name)
    else:
        quant_utils.add_actprune(model)

    if args.load_ckpt:
        load_path = os.path.join(args.load_pmodel_path, args.model, args.wandb_name)
        print("Load prunned model from ", load_path)
        checkpoint_file = os.path.join(load_path, "calibrated_model.pt")
        save_dict = torch.load(checkpoint_file)
        model.load_state_dict(save_dict["model"], strict=False)

    if args.act_distr_catch:
        #! Collect the activation distributions on calibration samples.
        num_calib_samples = 32
        trainloader = data_utils.get_loaders(
            'c4',
            nsamples=num_calib_samples,
            seed=args.seed,
            model=args.model,
            seqlen=model.seqlen,
            eval_mode=False,
        )
        act_catcher(model, trainloader, utils.DEV)

    #! At this point, the calibration data should be already gathered and merged as the histogram.

    if args.act_teal:
        #! First turn off the grab mode, then set the per-module thresholds from the sparsity.
        disable_act_grab(model)
        set_model_sparsity(model, args.a_sparsity)
    else:
        # Add Input Sparsity
        players = quant_utils.find_layers(model.model.layers, layers=[quant_utils.ActPruneWrapper])
        for key in players:
            players[key].pruner.configure(
                sparsity=args.a_sparsity,
                annealing=args.enable_ap_anneal,
                annealer=args.nsamples-1,
            )

    if args.lm_ppl:
        # Evaluating on dataset
        testloader = data_utils.get_loaders(
            args.eval_dataset,
            seed=args.seed,
            model=args.model,
            seqlen=model.seqlen,
            hf_token=args.hf_token,
            eval_mode=True,
        )
        dataset_ppl = eval_utils.evaluator(model, testloader, utils.DEV, args)
        print(dataset_ppl)

        #! Below is the code to check the per layer sparsity ratio
        # if args.act_teal:
        #     idx = 0
        #     for layer in model.model.layers:
        #         print(f'Layer: {idx}')
        #         for name, m in layer.named_modules():
        #             if isinstance(m, ActivationModule):
        #                 minspa = m.act_stats['min']
        #                 maxspa = m.act_stats['max']
        #                 avgspa = m.act_stats['moving_avg']
        #                 print(f'{name}: min {minspa}, max {maxspa}, avg: {round(avgspa,3)}')
        #         idx+=1

    # Nothing left to do unless the lm_eval benchmark was requested.
    if not args.lm_eval:
        return

    # Import lm_eval utils lazily; they are only needed from here on.
    import lm_eval
    from lm_eval import utils as lm_eval_utils
    from lm_eval.api.registry import ALL_TASKS
    from lm_eval.models.huggingface import HFLM

    if args.distribute:
        utils.distribute_model(model)
    else:
        model.to(utils.DEV)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model,
        use_fast=False,
        use_auth_token=args.hf_token,
        cache_dir=args.hf_cache_path,
    )
    hflm = HFLM(pretrained=model, tokenizer=tokenizer, batch_size=args.lm_eval_batch_size)

    # commenting out this line as it will include two lambda sub-tasks
    # task_names = lm_eval_utils.pattern_match(args.tasks, ALL_TASKS)
    evaluation = lm_eval.simple_evaluate(hflm, tasks=args.tasks, batch_size=args.lm_eval_batch_size)
    results = evaluation['results']

    # Prefer the normalised accuracy when a task reports it.
    metric_vals = {}
    for task, result in results.items():
        metric_vals[task] = round(result.get('acc_norm,none', result['acc,none']), 4)
    metric_vals['acc_avg'] = round(sum(metric_vals.values()) / len(metric_vals), 4)
    print(metric_vals)

                        


@torch.no_grad()
def act_catcher(model, trainenc, dev):
    """Run calibration data through the decoder layer by layer and save histograms.

    The inputs of the first decoder layer are captured for every calibration
    batch. The layers are then executed one at a time on ``dev``; after each
    layer has processed all samples, every ``ActivationModule`` inside it
    builds and saves its histogram and drops its recorded activations.

    Args:
        model: Model exposing ``model.layers``, ``model.embed_tokens``,
            ``model.rotary_emb``, ``config`` and ``seqlen``.
        trainenc: Sized iterable of calibration batches; element ``[0]`` of
            each batch is moved to ``dev`` and fed to the model.
        dev: Device on which the layers are executed.
    """
    # print(model)
    print('-----Teal Threshold Calibration Start-----')
    model.eval()

    # Disable the KV cache for the duration of the calibration; restored at the end.
    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers
    # Only the pieces needed to reach the first decoder layer are moved to the device.
    model.model.embed_tokens = model.model.embed_tokens.to(dev)
    model.model.rotary_emb = model.model.rotary_emb.to(dev)
    layers[0] = layers[0].to(dev)
    dtype = next(iter(model.parameters())).dtype

    nsamples = len(trainenc)

    # One slot per calibration batch for the hidden states entering layer 0.
    inps = torch.zeros(
        (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {"i": 0, "attention_mask": None}

    class Catcher(nn.Module):
        """Stand-in for layer 0 that records its input and aborts the forward pass."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps[cache["i"]] = inp
            cache["i"] += 1
            # Overwritten on every call, so the values of the last batch win.
            cache["attention_mask"] = kwargs["attention_mask"]
            cache['position_embeddings'] = kwargs['position_embeddings']
            # Stop the forward pass here; the rest of the model is not needed.
            raise ValueError

    layers[0] = Catcher(layers[0])

    for batch in trainenc:
        try:
            model(batch[0].to(dev))
        except ValueError:
            # Expected: raised by Catcher.forward once the input is stored.
            # NOTE(review): any other ValueError raised inside the model is
            # silently swallowed here as well.
            pass

    # Unwrap layer 0 again and move everything back to the CPU.
    layers[0] = layers[0].module
    layers[0] = layers[0].cpu()

    model.model.embed_tokens = model.model.embed_tokens.cpu()
    model.model.rotary_emb = model.model.rotary_emb.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    # The mask and position embeddings captured last are reused for every sample.
    attention_mask = cache["attention_mask"]
    position_embeddings = cache['position_embeddings']

    for i in tqdm(range(len(layers))):
        layer = layers[i].to(dev)

        # Samples are fed one at a time, each with a leading batch dimension of 1.
        for j in range(nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_embeddings=position_embeddings)[0]

        layers[i] = layer.cpu()
        # This layer's outputs are the next layer's inputs; the old input buffer is reused.
        inps, outs = outs, inps

        #! Go through each activation module wrapping a linear layer to build its histogram, then drop the recorded activations to release memory.
        for m in layer.modules():
            if isinstance(m, ActivationModule):
                # NOTE(review): the suffix is appended unconditionally, so calling
                # act_catcher twice on the same model would append it twice.
                m.file_path += f'/Layer{i}'
                m.find_histogram()
                m.save_histogram()
                # Removes the attribute itself, not just its contents.
                del m.activations
        del layer
        gc.collect()
        torch.cuda.empty_cache()
        utils.cleanup_memory(verbos=False)


    model.config.use_cache = use_cache
    utils.cleanup_memory(verbos=True)
    print('-----Teal Threshold Calibration Done-----')


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()