from tensorboard.backend.event_processing import event_accumulator
from argparse import ArgumentParser
import numpy as np
np.random.seed(101)
import random
random.seed(101)
import os
import glob
import torch
torch.use_deterministic_algorithms(True)
import torch.nn as nn
import math
from itertools import chain
from torch.linalg import svdvals, matrix_rank
import matplotlib.pyplot as plt
import seaborn as sns
import re


def parse_args():
    """Parse the command-line options of the plotting script.

    Options:
        --task          required string.
        --method        required, one of ``lora`` / ``lora_ortho``.
        --r             int, defaults to 8.
        --alpha         int, defaults to 16.
        --deltas        boolean switch.
        --lora-weights  boolean switch (exposed as ``lora_weights``).

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = ArgumentParser()
    # NOTE(review): ``--task`` is required, so its default is never used.
    cli.add_argument('--task', required=True, default='cola')
    cli.add_argument('--method', required=True, choices=['lora', 'lora_ortho'])

    # Integer hyper-parameters and their fallbacks.
    for flag, fallback in (('--r', 8), ('--alpha', 16)):
        cli.add_argument(flag, type=int, default=fallback)

    # Boolean switches, off unless given.
    for switch in ('--deltas', '--lora-weights'):
        cli.add_argument(switch, action='store_true')

    namespace = cli.parse_args()
    return namespace


def _load_scalars(path):
    """Read every scalar series from a TensorBoard event log.

    Args:
        path: Location handed to ``EventAccumulator`` (an event file or a
            directory containing event files).

    Returns:
        A dict mapping each scalar tag to a ``(steps, values)`` pair of NumPy
        arrays, with tags visited in sorted order. A tag that reports no
        events maps to a pair of empty arrays.
    """
    event_acc = event_accumulator.EventAccumulator(path)
    event_acc.Reload()
    data = {}

    for tag in sorted(event_acc.Tags()['scalars']):
        events = event_acc.Scalars(tag)
        # Build the arrays once per tag; converting inside the per-event loop
        # re-creates both arrays on every event, which is quadratic.
        steps = np.asarray([event.step for event in events])
        values = np.asarray([event.value for event in events])
        data[tag] = (steps, values)
    return data


def initialize_matrix(mat, seed):
    """Overwrite ``mat`` in place with seeded Kaiming-uniform values.

    The global torch RNG is re-seeded with ``seed`` first, so equal seeds and
    shapes always yield the same values. Nothing is returned.

    Args:
        mat: Tensor to fill in place.
        seed: Value passed to ``torch.manual_seed``.
    """
    negative_slope = math.sqrt(5)
    torch.manual_seed(seed)
    nn.init.kaiming_uniform_(mat, a=negative_slope)


def _spectrum_stats(matrices, max_svals=None):
    """Compute singular values, ranks and effective ranks, one row per matrix.

    Every matrix is decomposed on its own, so matrices of different shapes
    can be mixed freely and the output rows keep the order of ``matrices``.

    Args:
        matrices: Non-empty list of 2-D tensors.
        max_svals: Optional cap on how many leading singular values to keep.

    Returns:
        ``(svals, ranks, eranks)`` where ``svals`` has one row per matrix,
        ``ranks`` holds ``torch.linalg.matrix_rank`` of each matrix, and
        ``eranks`` is the exponential of the entropy of each row of ``svals``
        normalised to sum to one.
    """
    per_matrix = [svdvals(m) for m in matrices]
    # Keep a common number of singular values so the rows can be stacked.
    width = min(s.shape[-1] for s in per_matrix)
    if max_svals is not None:
        width = min(width, max_svals)
    svals = torch.stack([s[:width] for s in per_matrix])
    ranks = torch.stack([matrix_rank(m) for m in matrices])
    normed_svals = svals / svals.abs().sum(-1, keepdim=True)
    eranks = torch.exp(torch.distributions.Categorical(probs=normed_svals).entropy())
    return svals, ranks, eranks


def _save_spectrum_plots(svals, ranks, eranks, ticks, kind, seed, out_dir, max_rank):
    """Write the singular-value heatmap and the rank curves to ``out_dir``."""
    # Heatmap: one column per matrix, one row per singular value.
    fig = plt.figure(figsize=(10, 5))
    sns.heatmap(svals.numpy().T)
    plt.xticks(np.arange(.5, len(svals) + .5, 1), ticks, rotation=90)
    plt.title(f"Singular Values of {kind} Seed: {seed}")
    plt.tight_layout()
    fig.savefig(os.path.join(out_dir, f'{kind}_{seed}_svals.png'))
    plt.close(fig)

    # Rank and effective rank of every matrix.
    fig = plt.figure(figsize=(10, 5))
    plt.plot(eranks.numpy(), label="erank")
    plt.plot(ranks.numpy(), label="rank")
    plt.legend(loc='lower right')
    plt.xticks(np.arange(len(svals)), ticks, rotation=90)
    plt.title(f"Effective ranks of {kind} Seed: {seed}")
    plt.grid()
    # The effective rank cannot exceed the number of singular values kept.
    plt.ylim([0., max_rank + .1])
    plt.tight_layout()
    fig.savefig(os.path.join(out_dir, f'{kind}_{seed}_eranks_and_ranks.png'))
    plt.close(fig)


def main():
    """Plot singular-value spectra and ranks of the weights in saved checkpoints.

    Checkpoints are searched under ``<task>/<method>/r_<r>_alpha_<alpha>/*/*/``
    and plots are written to ``plots/<task>/<method>/<setting>/<mode>/``.
    ``<mode>`` is ``deltas`` (``--deltas``: each selected weight minus its
    re-created initial value), ``lora_weights`` (``--lora-weights``: the
    selected weights themselves) or ``orig_weights`` (neither flag: the
    ``weight`` entries whose names are derived from the ``lora_*`` keys). If
    both flags are given, ``--deltas`` wins.

    Rows of the plotted data follow checkpoint key order, so they line up
    with the tick labels, and each matrix is decomposed individually, so
    differently shaped matrices need no reshaping.

    Raises:
        ValueError: If a checkpoint has no keys matching the requested type.
    """
    args = parse_args()
    device = 'cpu'
    print(f"Method: {args.method}")
    setting = f"r_{args.r}_alpha_{args.alpha}"
    pattern = os.path.join(args.task, args.method, setting, '*', '*', '**')
    files = [f for f in glob.glob(pattern) if 'pytorch_model.bin' in f]

    if args.deltas:
        out_subdir = "deltas"
    elif args.lora_weights:
        out_subdir = "lora_weights"
    else:
        out_subdir = "orig_weights"
    out_dir = os.path.join('plots', args.task, args.method, setting, out_subdir)
    os.makedirs(out_dir, exist_ok=True)

    use_lora = args.deltas or args.lora_weights
    kinds = ['lora_A', 'lora_B'] if use_lora else ['orig_weight']

    for f in files:
        # The seed is the directory three levels up from the checkpoint file.
        seed = int(os.path.normpath(f).split(os.sep)[-3])
        # NOTE(review): only seed 0 is processed; every other seed is skipped.
        if seed:
            continue
        # NOTE(security): torch.load unpickles the file; only use it on
        # checkpoints from a trusted source.
        params = torch.load(f, map_location=device)

        for kind in kinds:
            if kind == 'orig_weight':
                lora_keys = [k for k in params if 'lora' in k]
                # Map each lora key to its base weight name and drop
                # duplicates, keeping first-occurrence order.
                names = list(dict.fromkeys(re.sub('lora_.', 'weight', k) for k in lora_keys))
            else:
                names = [k for k in params if kind in k]
            if not names:
                raise ValueError(f"no '{kind}' parameters found in checkpoint {f}")

            if args.deltas:
                print("SVD on deltas")
                matrices = []
                for name in names:
                    weight = params[name]
                    # 'lora_B' entries are compared against zeros; all others
                    # against a seeded Kaiming-uniform matrix.
                    init = torch.zeros(weight.shape).to(device)
                    if 'lora_B' not in name:
                        initialize_matrix(init, seed)
                    matrices.append(weight - init)
                svals, ranks, eranks = _spectrum_stats(matrices)
            elif args.lora_weights:
                print("SVD on lora weights")
                svals, ranks, eranks = _spectrum_stats([params[name] for name in names])
            else:
                print("SVD on original weights")
                # Keep only the leading ``r`` singular values of each weight.
                svals, ranks, eranks = _spectrum_stats(
                    [params[name] for name in names], max_svals=args.r)

            ticks = [f"layer_{k.split('.')[3]}_{k.split('.')[-2]}" for k in names]
            _save_spectrum_plots(svals, ranks, eranks, ticks, kind, seed, out_dir, args.r)


if __name__ == '__main__':
    main()