from argparse import ArgumentParser
import torch
import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

@torch.no_grad()
def compute_subspace_similarity(A, B):
    """Score the overlap between the row sets ``A`` and ``B``.

    The score is the norm of ``A @ B.T`` taken over all of its entries,
    divided by the sum of the per-row L2 norms of ``A``. Gradient tracking
    is disabled for the computation.

    Args:
        A: tensor whose rows are compared against the rows of ``B``.
        B: tensor with the same trailing dimension as ``A``.

    Returns:
        A zero-dimensional tensor holding the similarity score.
    """
    cross_products = torch.matmul(A, B.T)
    # Normalise by the total length of A's rows so the score does not grow
    # simply because A's rows have a larger scale.
    total_row_norm = A.norm(dim=-1).sum()
    return cross_products.norm() / total_row_norm


def save_heatmap(heatmap, args, name):
    """Draw ``heatmap`` with seaborn and write it to disk.

    The figure is saved as
    ``plots/subspaces/<args.model_name>/<args.task_name>/<name>``; the
    directory is created if it does not exist yet.

    Args:
        heatmap: 2-D array of values; the colour scale is fixed to [0, 1].
        args: object exposing ``model_name`` and ``task_name`` attributes.
        name: output file name. If it contains ``"weights"`` the x-axis is
            labelled as pretrained singular vectors, otherwise as PCA
            components.
    """
    fig = plt.figure(figsize=(15, 5))
    n_rows, n_cols = heatmap.shape
    sns.heatmap(heatmap, vmax=1.0, vmin=0.0)

    # Ticks sit at cell centres (offset .5) and are labelled starting from 1.
    plt.xticks(np.arange(.5, n_cols+.5, 1), range(1, n_cols + 1), rotation='horizontal')
    plt.yticks(np.arange(.5, n_rows+.5, 1), range(1, n_rows + 1), rotation='horizontal')

    x_label = "Pretrained singular vectors" if "weights" in name else "PCA components"
    plt.xlabel(x_label)
    plt.ylabel("Learned subspace")
    plt.tight_layout()

    out_dir = f"plots/subspaces/{args.model_name}/{args.task_name}"
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f"{out_dir}/{name}")
    plt.close()


def _load_lora_a(path, device):
    """Load a checkpoint onto ``device`` and keep only the 'lora_A' entries."""
    # NOTE: torch.load unpickles the file, so only point this at trusted
    # checkpoints. map_location lets files saved on CUDA load on CPU-only hosts.
    state = torch.load(path, map_location=device)
    # .to(device) is a no-op after map_location but guarantees every tensor
    # that is later multiplied together lives on the same device.
    return {key: value.to(device) for key, value in state.items() if 'lora_A' in key}


def _similarity_grid(trained, basis, n_comps):
    """Return an (r, n_comps) array of similarities between row prefixes.

    Entry (i, j) compares the first i+1 rows of ``trained`` with the first
    j+1 rows of ``basis``.
    """
    n_rows = trained.shape[0]
    return np.array([
        [
            compute_subspace_similarity(trained[:i], basis[:j]).cpu().numpy()
            for j in range(1, n_comps + 1)
        ]
        for i in range(1, n_rows + 1)
    ])


def main(args):
    """Plot similarity heatmaps for every trained 'lora_A' weight.

    For each 'lora_A' entry in the checkpoint at ``args.trained_path`` two
    heatmaps are saved through ``save_heatmap``: one against the cached PCA
    components and one against the cached weight directions, both read from
    ``args.cache_dir``.

    Args:
        args: parsed command-line arguments providing ``cache_dir``,
            ``model_name``, ``task_name`` and ``trained_path``.

    Raises:
        KeyError: if a trained 'lora_A' key is missing from either cache file.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # All three tensor collections are placed on the same device; mixing CPU
    # and CUDA tensors would make the matrix products below fail.
    principal_components = _load_lora_a(
        os.path.join(args.cache_dir, f"{args.model_name}_{args.task_name}_r_64_pca.bin"), device
    )
    weight_directions = _load_lora_a(
        os.path.join(args.cache_dir, f"{args.model_name}_r_64_weights.bin"), device
    )
    trained_weights = _load_lora_a(args.trained_path, device)

    # The second path component identifies the run in the output file names.
    # Fall back to the file stem when the path has no '/' at all instead of
    # raising IndexError.
    path_parts = args.trained_path.split('/')
    if len(path_parts) > 1:
        ident = path_parts[1]
    else:
        ident = os.path.splitext(os.path.basename(args.trained_path))[0]

    for name, trained in trained_weights.items():
        pcs = principal_components[name]
        w_dir = weight_directions[name]
        # Both heatmaps use the number of PCA components as their width.
        n_comps = pcs.shape[0]

        heatmap_pca = _similarity_grid(trained, pcs, n_comps)
        heatmap_weights = _similarity_grid(trained, w_dir, n_comps)

        save_heatmap(heatmap_pca, args, f"{name}_{ident}_pca.svg")
        save_heatmap(heatmap_weights, args, f"{name}_{ident}_weights.svg")


if __name__ == '__main__':
    # Command-line entry point: declare the options in a table, register
    # them in order, then hand the parsed namespace to main().
    cli_options = (
        ('--trained_path', {'type': str, 'required': True}),
        ('--task_name', {'choices': ['rte', 'mrpc', 'cola', 'stsb'], 'required': True}),
        ('--model_name', {'type': str, 'default': 'roberta-base'}),
        ('--cache_dir', {'type': str, 'default': "/system/user/publicdata/llm"}),
    )
    cli = ArgumentParser()
    for flag, spec in cli_options:
        cli.add_argument(flag, **spec)
    cli_args = cli.parse_args()
    main(cli_args)