import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset, concatenate_datasets
import functools
import copy
# from resnet_new import *
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import argparse
import os

import random
import torch
import numpy as np
import re
#from utils import *
import math
import numpy as np
import cv2
import matplotlib.pyplot as plt
#from datautils import *

from torch.utils.data import DataLoader
from transformers import default_data_collator
from tqdm import tqdm

from sklearn.cluster import SpectralClustering
import numpy as np
import random

# ---- Command-line interface ----
parser = argparse.ArgumentParser(description='PyTorch Switch-base-8 SLT')
parser.add_argument('--model_dir', type=str, required=True, help="Name of the model")
parser.add_argument('--seed', type=int, default=0, help="Seed used to initialize p-rng")
parser.add_argument('--teal_path', type=str, required=True, help="Path to TEAL")
parser.add_argument('--cur_dir', type=str, required=True, help="Path to the current directory")
# Only these two threshold schedules are implemented further down; rejecting anything else
# here avoids a NameError much later, when the thresholds dict is looked up.
parser.add_argument('--teal_sparsity_type', type=str, default="greedy", choices=["greedy", "uniform"],
                    help="Type of sparsity for TEAL")
# Projection names the analysis knows how to locate inside a decoder layer.
parser.add_argument('--layer_name', type=str, default="up_proj",
                    choices=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
                    help="Name of the layer to be pruned")
parser.add_argument('--layers', type=int, nargs='+', default=None, help="Layers to be pruned")
args = parser.parse_args()

# ---- Reproducibility: seed every RNG the script touches ----
torch.manual_seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
torch.backends.cudnn.deterministic = True
# Autotuned cuDNN kernel selection can vary between runs; disable it alongside `deterministic`.
torch.backends.cudnn.benchmark = False

def display_block(new_mask, show_separation = True, square_size = 4096, save=True, filename=None):
    """Plot a (samples x features) block mask as a square heat map and optionally save it.

    The mask is resized to ``square_size`` x ``square_size`` pixels with nearest-neighbour
    interpolation. When ``show_separation`` is True, red lines mark the boundaries between
    blocks. Their positions are read from the module-level globals ``num_blocks``,
    ``features_per_cluster``, ``samples_per_cluster``, ``cluster_ordering`` and
    ``membership``, which the analysis loop sets before calling this function.

    Args:
        new_mask: 2-D tensor to display (rows are plotted as "Samples", columns as "Features").
        show_separation: draw the block boundary lines.
        square_size: side length, in pixels, of the rendered image.
        save: write the figure to ``filename`` (+ "_separations" if ``show_separation``) + ".png".
        filename: output path prefix; required when ``save`` is True.

    Raises:
        ValueError: if ``save`` is True and ``filename`` is None.
    """
    # Validate before doing any plotting work.
    if save and filename is None:
        raise ValueError("display_block: filename is required when save=True")

    image = new_mask.type(torch.float32).cpu().numpy()
    rescaled_image = cv2.resize(image, (square_size, square_size), interpolation=cv2.INTER_NEAREST)
    plt.imshow(rescaled_image, interpolation="nearest")
    plt.xlabel("Features", fontsize=72)
    plt.ylabel("Samples", fontsize=72)
    plt.xticks([])
    plt.yticks([])
    cbar = plt.colorbar()
    cbar.ax.tick_params(labelsize=48)

    if show_separation:
        # Scale boundaries by square_size (not a fixed 4096) so they line up with the resized
        # image. Accumulate exact counts and truncate once per line so rounding does not drift.
        feature_scale = square_size / len(membership)
        boundary = 0
        for i in range(num_blocks - 1):
            boundary += features_per_cluster[cluster_ordering[i]]
            plt.axvline(x=int(boundary * feature_scale), color="r")
        sample_scale = square_size / np.sum(samples_per_cluster)
        boundary = 0
        for i in range(num_blocks):
            boundary += samples_per_cluster[cluster_ordering[i]]
            plt.axhline(y=int(boundary * sample_scale), color="r")

    # Save before show(): with interactive backends show() can release the figure, so a
    # savefig() issued afterwards may write an empty image. Only save when asked to.
    if save:
        plt.savefig(filename + ("_separations" if show_separation else "") + ".png")
    plt.show()

def rsetattr(obj, attr, val):
    """Set a possibly dotted attribute path, e.g. ``rsetattr(m, "a.b.c", v)`` sets ``m.a.b.c = v``.

    Every component except the last is resolved with ``rgetattr``; the final component
    is then assigned on that parent object. Returns whatever ``setattr`` returns (None).
    """
    parent_path, _sep, leaf = attr.rpartition('.')
    # No dot in the path -> the attribute lives directly on ``obj``.
    target = rgetattr(obj, parent_path) if parent_path else obj
    return setattr(target, leaf, val)

def rgetattr(obj, attr, *args):
    """Resolve a dotted attribute path such as ``"model.layers.0.mlp"`` on ``obj``.

    Each path component is looked up in turn with ``getattr``. Any extra positional
    argument is forwarded to every ``getattr`` call as its default, so a missing
    component yields that default (and lookups continue on the default).
    """
    current = obj
    for component in attr.split('.'):
        current = getattr(current, component, *args)
    return current

class CatcherExit(Exception):
    """Raised by ``Catcher.forward`` to abort a forward pass once the input has been recorded."""

class Catcher(nn.Module):
    """Stand-in submodule that records its input and then aborts the forward pass.

    The wrapped module is kept as ``self.module`` so it can be retrieved later. When
    called, a detached CPU copy of the positional input is appended to the module-level
    list ``inps`` (which must exist before the first call) and ``CatcherExit`` is raised,
    so no layer after this one runs. Keyword arguments are accepted and ignored.
    """

    def __init__(self, module):
        super(Catcher, self).__init__()
        self.module = module

    def forward(self, inp, **kwargs):
        snapshot = inp.detach().cpu()
        inps.append(snapshot)
        raise CatcherExit()

from TEAL.utils.utils import get_tokenizer, get_sparse_model, SparsifyFn, get_layer_greedy_sparsities
import os
import torch

# Activation-sparsity levels for which per-module TEAL thresholds are computed and cached.
sparsities = [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]
threshold_file_name = "_".join([args.teal_sparsity_type, "thresholds", args.model_dir.split("/")[-1]])+".pt"

teal_results_dir = os.path.join(args.cur_dir, "teal_results")
# makedirs also creates a missing parent directory and tolerates an existing one.
os.makedirs(teal_results_dir, exist_ok=True)
threshold_path = os.path.join(teal_results_dir, threshold_file_name)

# Fail fast: the analysis below only knows how to read `uniform_thresholds` or `greedy_thresholds`.
if args.teal_sparsity_type not in ("uniform", "greedy"):
    raise ValueError(
        f"Unsupported --teal_sparsity_type {args.teal_sparsity_type!r}; expected 'uniform' or 'greedy'."
    )

# Both branches build: thresholds[sparsify_module_name][sparsity] -> module.get_threshold(),
# cached on disk so the (expensive) TEAL model is only loaded once per model/type.
if args.teal_sparsity_type == "uniform":
    if os.path.exists(threshold_path):
        uniform_thresholds = torch.load(threshold_path)
        print("Loaded uniform thresholds from existing file.")
    else:
        model = get_sparse_model(args.model_dir, device="auto", histogram_path=os.path.join(args.teal_path, "histograms"))
        uniform_thresholds = {}
        for name, module in model.named_modules():
            if "SparsifyFn" in str(type(module)):
                # Every SparsifyFn module gets the same target sparsity.
                uniform_thresholds[name] = {}
                for sp in sparsities:
                    module.set_threshold(sparsity = sp)
                    uniform_thresholds[name][sp] = module.get_threshold()
        torch.save(uniform_thresholds, threshold_path)
        # The TEAL model is only needed to derive thresholds; release it before the
        # dense model is loaded further down.
        del model
        torch.cuda.empty_cache()
else:  # "greedy"
    if os.path.exists(threshold_path):
        greedy_thresholds = torch.load(threshold_path)
        print("Loaded greedy thresholds from existing file.")
    else:
        model = get_sparse_model(args.model_dir, device="auto", histogram_path=os.path.join(args.teal_path, "histograms"))
        greedy_thresholds = {
            name: {} for name, module in model.named_modules() if "SparsifyFn" in str(type(module))
        }
        # Presumably one sparsity entry per decoder layer; use the model's own layer count
        # instead of a hard-coded 32 (32 is kept only as a fallback if no config is exposed).
        num_layers = getattr(getattr(model, "config", None), "num_hidden_layers", 32)
        lookup_dir = os.path.join(args.teal_path, "lookup")
        for sp in sparsities:
            greedy_sps = get_layer_greedy_sparsities([sp]*num_layers, lookup_dir)
            model.set_sparsities(greedy_sps)
            for name, module in model.named_modules():
                if "SparsifyFn" in str(type(module)):
                    greedy_thresholds[name][sp] = module.get_threshold()
        torch.save(greedy_thresholds, threshold_path)
        del model
        torch.cuda.empty_cache()

model_id = args.model_dir

# Tokenizer from local files only: pad on the right and reuse EOS as the pad token.
tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

# Dense causal LM in float32, placed with device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    local_files_only=True,
    device_map="auto",
    torch_dtype=torch.float,
)

new_dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
def tokenize_function(examples):
    """Tokenize the ``text`` column of a batch passed by ``datasets.map(batched=True)``."""
    return tokenizer(examples["text"])

# Tokenize every split returned by load_dataset and drop the raw text column.
# No new split is created here; the splits come from the dataset itself.
tokenized_datasets = new_dataset.map(
    tokenize_function,
    remove_columns=["text"],
    batched=True,
    num_proc=4,
)

# Length, in tokens, of each contiguous chunk produced by group_texts.
block_size = 1024
def group_texts(examples, chunk_size=None):
    """Concatenate a batch of tokenized texts and re-cut it into fixed-size chunks.

    Every column in ``examples`` (e.g. ``input_ids``, ``attention_mask``) is flattened into
    one long sequence and split into consecutive chunks of ``chunk_size`` tokens. The
    trailing remainder that does not fill a whole chunk is dropped (padding could be used
    instead if the model supported it). If ``input_ids`` is present, a ``labels`` column
    is added as a copy of it.

    Args:
        examples: mapping of column name -> list of per-example token lists, i.e. the
            batch format passed by ``Dataset.map(batched=True)``.
        chunk_size: chunk length; defaults to the module-level ``block_size``.

    Returns:
        dict mapping each column name (plus ``labels``) to a list of chunks.
    """
    from itertools import chain

    size = block_size if chunk_size is None else chunk_size
    # chain.from_iterable is linear; sum(lists, []) re-copies the accumulator on every step.
    concatenated_examples = {k: list(chain.from_iterable(v)) for k, v in examples.items()}
    # Columns are parallel, so the first column's length applies to all; an empty batch yields 0.
    total_length = len(next(iter(concatenated_examples.values()), []))
    total_length = (total_length // size) * size
    result = {
        k: [t[i : i + size] for i in range(0, total_length, size)]
        for k, t in concatenated_examples.items()
    }
    if "input_ids" in result:
        result["labels"] = result["input_ids"].copy()
    return result

# Re-cut every tokenized split into contiguous chunks (see group_texts); each chunk also
# carries a `labels` copy of its input_ids.
lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
    num_proc=4,
)

valloader = DataLoader(
    lm_datasets["validation"],
    batch_size=8,
    shuffle=False,  # Do NOT shuffle during evaluation
    # default_data_collator only stacks features into tensors; it does NOT pad. That is
    # sufficient here because group_texts emits equal-length chunks.
    collate_fn=default_data_collator,
)

os.makedirs(os.path.join(teal_results_dir, "Block_Imgs"), exist_ok=True)

# Explicit check rather than `assert`, which is stripped under `python -O`.
if args.layer_name not in ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]:
    raise ValueError(f"Unsupported --layer_name {args.layer_name!r}")
# q/k/v/o projections are looked up under `self_attn`; gate/up/down under `mlp`.
block_name = "self_attn" if args.layer_name.split("_")[0] in ['q','k','v','o'] else "mlp"

num_hidden_layers = model.config.num_hidden_layers
if args.layers is None:
    args.layers = list(range(num_hidden_layers))
# Reject bad indices up front instead of failing deep inside attribute lookup.
invalid_layers = [lay for lay in args.layers if not 0 <= lay < num_hidden_layers]
if invalid_layers:
    raise ValueError(f"--layers out of range [0, {num_hidden_layers}): {invalid_layers}")

def _capture_layer_inputs(model, layer_name, dataloader):
    """Return every input that reaches the submodule at dotted path ``layer_name``.

    The submodule is temporarily replaced by ``Catcher``. Catcher appends its input
    (detached, on CPU) to the module-level list ``inps`` and raises ``CatcherExit``, so the
    rest of each forward pass is skipped. The original submodule is restored in ``finally``,
    which lets the same model be reused for the next layer instead of reloading it.

    Returns:
        2-D CPU tensor: the captured batches stacked along dim 0 and flattened to
        ``(-1, last_dim)``.
    """
    global inps  # Catcher.forward appends to this module-level list.
    original_module = rgetattr(model, layer_name)
    rsetattr(model, layer_name, Catcher(original_module))
    inps = []
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader):
                batch = {k: v.to(model.device) for k, v in batch.items()}
                try:
                    _ = model(**batch)
                except CatcherExit:
                    pass  # Expected: the Catcher aborts the forward pass after recording.
    finally:
        rsetattr(model, layer_name, original_module)
    captured = torch.vstack(inps)
    return captured.contiguous().view(-1, captured.shape[-1])


def _cdf_scores(abs_acts, num_bins=100000):
    """Map every entry of ``abs_acts`` to its empirical CDF value over the whole tensor.

    Values are min-max binned into ``num_bins`` buckets. An entry's score is the fraction
    of all entries that fall in the same or a lower bucket, so scores lie in (0, 1].
    """
    lo = abs_acts.min()
    hi = abs_acts.max()
    bins = torch.floor((abs_acts - lo) * num_bins / ((hi + 1e-8) - lo)).to(dtype=torch.int64)
    # In float32, `hi + 1e-8` can round back to `hi`, putting the maximum in bucket
    # `num_bins`; clamp so every index is valid (also guards the constant-input case).
    bins = bins.clamp_(0, num_bins)
    counts = torch.bincount(bins.flatten(), minlength=num_bins + 1)
    # Cumulate in int64 and divide in float64 to avoid float32 rounding on large counts.
    cdf = (counts.cumsum(0).to(torch.float64) / counts.sum()).to(torch.float32)
    return cdf[bins]


# ---- Analysis switches (constant for the whole run) ----
MASKED = True                  # zero out activations whose |value| <= the TEAL threshold
GREEDY = (args.teal_sparsity_type == "greedy")
NORMALIZED = True              # unit-normalize feature columns (only used with COSINE_SIM)
COSINE_SIM = True              ### False calculates distances
NORMALIZED_FOR_SCORES = False  # CDF scores from normalized (True) or masked raw (False) acts
NUM_BINS = 100000              # histogram resolution of the CDF scores
SEED = 0
NUM_BLOCKS = 20                # number of spectral clusters (feature blocks)
DISPLAY_MODE = 1               # 1: CDF scores, 2: normalized activations, 3: masked activations
plt.rcParams['figure.figsize'] = [20, 20]

for lay_num in args.layers:
    layer_name = '.'.join(["model.layers", str(lay_num), block_name, args.layer_name])
    sparsity_module_name = '.'.join(["model.layers", str(lay_num), block_name, "sparse_fns", args.layer_name.split("_")[0]])
    # Inputs to the target projection for every validation token, as a 2-D tensor.
    inps = _capture_layer_inputs(model, layer_name, valloader)

    for SPARSITY in [0.0]+sparsities:
        # ---- 1. Mask activations at the TEAL threshold for this sparsity level ----
        if MASKED:
            if SPARSITY == 0.0:
                threshold = 0.0
            elif GREEDY:
                threshold = greedy_thresholds[sparsity_module_name][SPARSITY]
            else:
                threshold = uniform_thresholds[sparsity_module_name][SPARSITY]
            new_acts = inps*(inps.abs()>threshold)
        else:
            threshold = 0.0  # keeps activation_sparsity below well-defined without masking
            new_acts = inps
        # NOTE: despite its name this is the fraction of entries that SURVIVE the threshold
        # (a density); kept unchanged because it is embedded in the output file names.
        activation_sparsity = (inps.abs()>threshold).sum() / inps.numel()

        # ---- 2. Feature x feature similarity matrix ----
        if NORMALIZED and COSINE_SIM:
            normalized = new_acts*torch.rsqrt((new_acts**2).sum(dim=0)+1e-9).view(1,-1)
        else:
            normalized = new_acts
        if COSINE_SIM:
            # Accumulate normalized.T @ normalized in 1000-row chunks to bound device memory.
            transpose = normalized.T
            feature_similarity = torch.zeros(normalized.shape[1], normalized.shape[1]).to(model.device)
            for i in range(math.ceil(normalized.shape[0]/1000)):
                if i % 10 == 0:
                    print(i)
                rows = slice(1000*i, 1000*(i+1))
                feature_similarity += transpose[:, rows].to(model.device) @ normalized[rows, :].to(model.device)
        else:
            # Euclidean distances between feature columns via ||a||^2 + ||b||^2 - 2 a.b,
            # turned into a similarity by subtracting from the largest distance.
            col_norms_squared = (new_acts ** 2).sum(dim=0)
            G = new_acts.T @ new_acts  # shape: (d, d)
            D_squared = (
                col_norms_squared.unsqueeze(1)   # (d, 1)
                + col_norms_squared.unsqueeze(0) # (1, d)
                - 2 * G
            ).clamp(min=0.0)
            feature_similarity = torch.sqrt(D_squared)
            feature_similarity = feature_similarity.max() - feature_similarity

        # ---- 3. Spectral clustering of features (affinity shifted to be non-negative) ----
        torch.manual_seed(SEED)
        random.seed(SEED)
        np.random.seed(SEED)
        torch.backends.cudnn.deterministic = True
        sc = SpectralClustering(
            n_clusters=NUM_BLOCKS,
            affinity='precomputed',
            random_state=SEED,
        )
        labels = sc.fit_predict((feature_similarity-feature_similarity.min()).cpu().numpy())
        num_features_per_cluster = [(labels==i).sum() for i in range(NUM_BLOCKS)]
        print(num_features_per_cluster)

        # ---- 4. CDF score of every (sample, feature) entry ----
        abs_for_scores = normalized.abs() if NORMALIZED_FOR_SCORES else new_acts.abs()
        scoring = _cdf_scores(abs_for_scores, NUM_BINS)

        # ---- 5. Assign samples to blocks; order samples/features for display ----
        # display_block reads membership, num_blocks, features_per_cluster,
        # samples_per_cluster and cluster_ordering as module-level globals.
        membership = np.array(labels)
        num_blocks = len(np.unique(membership))
        features_per_cluster = [(membership==i).sum() for i in range(num_blocks)]
        print("Number of features per cluster are: ")
        print(features_per_cluster)
        membership_tensor = torch.from_numpy(membership)
        cluster_wise_features = []
        sample_membership = []
        for i in range(num_blocks):
            features_cluster = torch.where(membership_tensor == i)[0]
            cluster_wise_features.append(features_cluster)
            # Average score of this block's features, per sample.
            sample_membership.append(scoring[:, features_cluster].sum(dim=1)/features_per_cluster[i])
        # A sample belongs to the block whose features it activates most, on average.
        ideal_cluster = torch.vstack(sample_membership).argmax(dim=0)
        samples_per_cluster = [(ideal_cluster==i).sum().item() for i in range(num_blocks)]
        print("Number of samples assigned to each cluster are: ")
        print(samples_per_cluster)
        # Blocks with the most samples first.
        cluster_ordering = torch.Tensor(samples_per_cluster).argsort(descending=True).tolist()

        cluster_wise_features_ordered = []
        samples_ranked = []
        for i in cluster_ordering:
            sample_ids_of_cluster = torch.where(ideal_cluster==i)[0]
            block_scores = scoring[sample_ids_of_cluster][:, cluster_wise_features[i]]
            # Order this block's features by total score over the block's samples, descending.
            feature_order = block_scores.sum(dim=0).argsort(descending=True)
            features_ordered = cluster_wise_features[i][feature_order]
            cluster_wise_features_ordered.append(features_ordered)
            # Rank samples with THIS block's ordered features: cluster_wise_features_ordered
            # is indexed by position in cluster_ordering, not by cluster id. Primary key is
            # the position of the sample's strongest feature; ties are broken by the
            # block's total score.
            # NOTE(review): the ascending sort puts samples with a LOWER total first on
            # ties. Confirm that is the intended order.
            cluster_sample_mask = block_scores[:, feature_order]
            rank = torch.max(cluster_sample_mask, dim=1)[1]*len(features_ordered) + cluster_sample_mask.sum(dim=1)
            samples_ranked.append(sample_ids_of_cluster[torch.sort(rank)[1]])

        samples_sequence = torch.hstack(samples_ranked)
        features_sequence = torch.hstack(cluster_wise_features_ordered)

        # ---- 6. Render the reordered matrix with and without block separators ----
        if DISPLAY_MODE == 1:
            display_source = scoring
        elif DISPLAY_MODE == 2:
            display_source = normalized
        elif DISPLAY_MODE == 3:
            display_source = new_acts
        else:
            raise ValueError(f"Unsupported DISPLAY_MODE {DISPLAY_MODE!r}; expected 1, 2 or 3")
        new_mask = display_source[samples_sequence][:, features_sequence]

        fig_filename = os.path.join(teal_results_dir, "Block_Imgs", "_".join([args.model_dir.split("/")[-1], layer_name, args.teal_sparsity_type, "sparsity", str(SPARSITY), str(activation_sparsity)]))
        for show_sep in (False, True):
            display_block(new_mask, show_separation=show_sep, save=True, filename=fig_filename)
            plt.close()
