from argparse import ArgumentParser
import numpy as np
np.random.seed(101)
import random
random.seed(101)
import os
import torch
torch.use_deterministic_algorithms(True)
import torch.nn as nn
import math
from torch.linalg import svdvals
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoModel


def parse_args():
    """Parse the script's command-line options.

    Returns:
        argparse.Namespace: with ``model_name`` (str, default
        ``'roberta-base'``), ``threshold`` (float, default ``0.9``) and
        ``budget`` (int, required).
    """
    option_specs = (
        ('--model_name', {'type': str, 'default': 'roberta-base'}),
        ('--threshold', {'type': float, 'default': 0.9}),
        ('--budget', {'type': int, 'required': True,
                      'help': "Rank budget as in adalora for decomposition"}),
    )
    cli = ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def initialize_matrix(mat, seed):
    """Fill ``mat`` in place with Kaiming-uniform values drawn under ``seed``.

    The global torch RNG is reseeded with ``seed`` before drawing, so two
    calls with the same seed and the same tensor shape produce identical
    contents.

    Args:
        mat: tensor to overwrite in place.
        seed: value passed to ``torch.manual_seed``.
    """
    # Slope argument handed to kaiming_uniform_ as ``a``.
    negative_slope = math.sqrt(5)
    torch.manual_seed(seed)
    nn.init.kaiming_uniform_(mat, a=negative_slope)


def _layer_index(name):
    """Return the first purely numeric dotted component of ``name``, or -1.

    E.g. ``'encoder.layer.3.output.dense.weight'`` -> ``3``.
    """
    for part in name.split('.'):
        if part.isdigit():
            return int(part)
    return -1


def _matches_param_type(name, param_type):
    """Tell whether the parameter called ``name`` belongs to ``param_type``.

    ``'output.dense'`` is also a substring of ``'attention.output.dense'``,
    so names containing ``'attention'`` are excluded for that type.
    """
    if param_type not in name:
        return False
    return not (param_type == 'output.dense' and 'attention' in name)


def main():
    """Allocate a global rank budget across weight matrices and plot it.

    Loads ``args.model_name``, computes the singular values of every selected
    weight matrix, keeps the ``args.budget`` largest singular values over all
    matrices, and counts how many of them fall into each
    (weight type, layer) cell. The counts are drawn as a heatmap and saved to
    ``plots/ranks/<model_name>/ranks_budget_<budget>.png``.

    Raises:
        ValueError: if no weight matrix matches the selected types, or none of
            the matching parameter names carries a layer index.
    """
    args = parse_args()
    model = AutoModel.from_pretrained(args.model_name)

    out_dir = os.path.join('plots', 'ranks', args.model_name)
    os.makedirs(out_dir, exist_ok=True)

    trainable_params = ['query', 'key', 'value', 'attention.output', 'intermediate', 'output.dense']
    trainable_state_dict = {
        name: param
        for name, param in model.named_parameters()
        if 'weight' in name
        and 'LayerNorm' not in name
        and any(param_type in name for param_type in trainable_params)
    }

    print("SVD on original weights")
    # Parallel flat arrays: for every singular value, the heatmap row (weight
    # type) and the layer it came from. Tagging at creation time avoids
    # re-deriving ownership from substring matches on names later on.
    sval_chunks, row_chunks, layer_chunks = [], [], []
    with torch.no_grad():
        for row, param_type in enumerate(trainable_params):
            group = [(name, param) for name, param in trainable_state_dict.items()
                     if _matches_param_type(name, param_type)]
            if not group:
                # Nothing of this type in the model; its heatmap row stays 0.
                continue
            # Batched SVD over all layers of this weight type: (n_mats, r).
            sval = svdvals(torch.stack([param for _, param in group])).numpy()
            layers = np.array([_layer_index(name) for name, _ in group], dtype=int)
            # ravel() is row-major, so each matrix contributes r consecutive
            # values; np.repeat lays the layer tags out in the same order.
            sval_chunks.append(sval.ravel())
            layer_chunks.append(np.repeat(layers, sval.shape[-1]))
            row_chunks.append(np.full(sval.size, row, dtype=int))

    if not sval_chunks:
        raise ValueError(f"No matching weight matrices found in model '{args.model_name}'")

    all_svals = np.concatenate(sval_chunks)
    row_ids = np.concatenate(row_chunks)
    layer_ids = np.concatenate(layer_chunks)

    # Derive the depth from the parameter names instead of assuming 12 layers.
    n_layers = int(layer_ids.max()) + 1
    if n_layers == 0:
        raise ValueError(f"No layer index found in parameter names of model '{args.model_name}'")

    # Indices of the `budget` largest singular values across all matrices.
    top = np.argsort(all_svals)[::-1][:args.budget]
    # Values whose parameter name has no layer index cannot be placed.
    top = top[layer_ids[top] >= 0]

    rank_matrix = np.zeros((len(trainable_params), n_layers), dtype=int)
    # add.at accumulates correctly when the same cell is hit repeatedly.
    np.add.at(rank_matrix, (row_ids[top], layer_ids[top]), 1)

    plt.close()
    plt.figure(figsize=(10, 5))
    sns.heatmap(rank_matrix, annot=True, linewidths=.5, cmap='magma', fmt="d")
    # Heatmap cells are centred at half-integers, hence the .5 offset.
    plt.xticks(np.arange(.5, n_layers + .5, 1), list(range(n_layers)))
    plt.xlabel("Layer")
    plt.yticks(np.arange(.5, len(trainable_params) + .5, 1), trainable_params, rotation=0)
    plt.ylabel("Weights")
    plt.title(f"Ranks of low rank decompositions for budget={args.budget}")
    plt.tight_layout()
    plt.gcf().savefig(os.path.join(out_dir, f'ranks_budget_{args.budget}.png'))


if __name__ == '__main__':
    main()