from tensorboard.backend.event_processing import event_accumulator
from argparse import ArgumentParser
import numpy as np
np.random.seed(101)
import random
random.seed(101)
import os
import glob
import torch
torch.use_deterministic_algorithms(True)
import torch.nn as nn
import math
from itertools import chain
from torch.linalg import svdvals


def parse_args(argv=None):
    """Parse the command-line options for the LoRA-delta analysis.

    Options:
        --task    top-level run directory (default ``cola``).
        --method  ``lora`` or ``lora_ortho`` (required).
        --r       LoRA rank; with ``--alpha`` it selects the
                  ``r_{r}_alpha_{alpha}`` setting directory (default 8).
        --alpha   LoRA alpha (default 16).

    Args:
        argv: optional list of argument strings; ``None`` (the default) reads
            ``sys.argv[1:]`` exactly as before.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = ArgumentParser()
    # Previously `required=True` made the 'cola' default unreachable; the
    # option is now optional so the default actually applies.
    parser.add_argument('--task', default='cola')
    parser.add_argument('--method', required=True, choices=['lora', 'lora_ortho'])
    parser.add_argument('--r', type=int, default=8)
    parser.add_argument('--alpha', type=int, default=16)
    return parser.parse_args(argv)


def _load_scalars(path):
    """Load every scalar series recorded in a TensorBoard log.

    Args:
        path: location handed to ``EventAccumulator`` (presumably an event
            file or a log directory — confirm against callers).

    Returns:
        dict mapping each scalar tag (visited in sorted order) to a
        ``(steps, values)`` pair of NumPy arrays. A tag with no events maps
        to two empty arrays instead of being omitted.
    """
    event_acc = event_accumulator.EventAccumulator(path)
    event_acc.Reload()
    data = {}
    for tag in sorted(event_acc.Tags()['scalars']):
        events = event_acc.Scalars(tag)
        # Build each array once per tag rather than once per event.
        steps = np.asarray([event.step for event in events])
        values = np.asarray([event.value for event in events])
        data[tag] = (steps, values)
    return data


def initialize_matrix(mat, seed):
    """Overwrite ``mat`` in place with seeded Kaiming-uniform values.

    Side effect: torch's global RNG is reseeded with ``seed`` first, so a given
    (shape, seed) pair always produces the same values, and any later random
    draws in the process continue from this seed.

    NOTE(review): ``a=sqrt(5)`` presumably mirrors the LoRA ``A`` default
    initialisation used during training — confirm.

    Args:
        mat: tensor to fill in place.
        seed: integer passed to ``torch.manual_seed``.
    """
    negative_slope = math.sqrt(5)
    torch.manual_seed(seed)
    nn.init.kaiming_uniform_(mat, a=negative_slope)


def _find_checkpoints(task, method, setting):
    """Return the sorted ``pytorch_model.bin`` paths under ``task/method/setting/*/*/``."""
    # Match the filename exactly (not as a substring of the path) and sort so
    # the row order of the stacked deltas does not depend on filesystem order.
    pattern = os.path.join(task, method, setting, '*', '*', 'pytorch_model.bin')
    return sorted(glob.glob(pattern))


def _seed_from_path(path):
    """Return the run seed: the directory name two levels above the checkpoint file."""
    return int(os.path.normpath(path).split(os.sep)[-3])


def _lora_delta(params, seed, device):
    """Return ``(delta, layout)`` for the LoRA tensors of one checkpoint.

    ``delta`` is the flattened concatenation (in state-dict order) of every
    tensor whose key contains ``'lora'``, minus its reconstructed initial
    value. That value is zeros for keys containing ``'lora_B'`` and
    ``initialize_matrix(..., seed)`` for the rest. ``layout`` lists
    ``(key, numel)`` for each of those tensors.

    Raises:
        ValueError: if the checkpoint contains no LoRA tensors.
    """
    lora = [(key, weight) for key, weight in params.items() if 'lora' in key]
    if not lora:
        raise ValueError("checkpoint contains no LoRA parameters")
    initial, final, layout = [], [], []
    for key, weight in lora:
        init = torch.zeros(weight.shape, device=device)
        if 'lora_B' not in key:
            # NOTE(review): this reseeds for every matrix, so all non-B LoRA
            # tensors of the same shape get identical initial values. Confirm
            # this matches how the runs were actually initialised.
            initialize_matrix(init, seed)
        initial.append(init.flatten())
        final.append(weight.flatten())
        layout.append((key, init.numel()))
    return torch.cat(final) - torch.cat(initial), layout


def main():
    """Analyse how fine-tuned LoRA weights moved away from their initialisation.

    For every checkpoint of the selected task / method / ``r_{r}_alpha_{alpha}``
    setting, compute the flattened LoRA update (final minus reconstructed
    initial weights). Stack the updates into a ``(num_runs, num_lora_params)``
    matrix and take its thin SVD.

    The squared singular values and the rows of ``Vh`` are the nonzero
    eigenvalues and eigenvectors of ``delta.T @ delta``. Computing the thin
    SVD directly avoids materialising that
    ``num_lora_params x num_lora_params`` matrix.

    Returns:
        ``(svals, Vh, names)``: singular values, right singular vectors (one
        per row), and the parameter name of every column of the stacked deltas.

    Raises:
        FileNotFoundError: if no checkpoint matches.
        ValueError: if a checkpoint has no LoRA tensors, or if checkpoints
            disagree on their LoRA parameter layout.
    """
    args = parse_args()
    device = 'cpu'
    print(f"Method: {args.method}")
    setting = f"r_{args.r}_alpha_{args.alpha}"
    files = _find_checkpoints(args.task, args.method, setting)
    if not files:
        raise FileNotFoundError(
            f"no pytorch_model.bin found under "
            f"{os.path.join(args.task, args.method, setting)}")

    deltas, layout = [], None
    for path in files:
        # torch.load unpickles: only point this at trusted checkpoints.
        params = torch.load(path, map_location=device)
        delta, file_layout = _lora_delta(params, _seed_from_path(path), device)
        if layout is None:
            layout = file_layout
        elif file_layout != layout:
            raise ValueError(f"{path}: LoRA parameter layout differs from {files[0]}")
        deltas.append(delta)

    delta = torch.stack(deltas)
    _, svals, Vh = torch.linalg.svd(delta, full_matrices=False)
    names = list(chain.from_iterable([key] * numel for key, numel in layout))
    print(f"Singular values of LoRA deltas over {len(files)} runs: {svals.tolist()}")
    return svals, Vh, names




if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()