import os.path

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


def plot_rank_heatmap(lora_state_dict, new_state_dict, model_args, data_args, config, kind):
    """Draw the per-layer LoRA ranks as a heatmap and save it as an SVG.

    Every key of ``lora_state_dict`` containing ``'lora_A'`` becomes one cell.
    Its value is ``new_state_dict[key].shape[0]``, or 0 when the key is absent
    from ``new_state_dict``. The cells are arranged into a grid of
    ``config.num_hidden_layers`` rows by modules-per-layer columns. The grid
    is plotted transposed: layers run along the x-axis and adapted weights
    along the y-axis.

    Args:
        lora_state_dict: mapping whose ``lora_A`` keys define the cells.
            NOTE(review): the keys are assumed to be ordered layer by layer,
            with the same modules in the same order in every layer. Confirm
            this against the caller.
        new_state_dict: mapping of tensors; ``shape[0]`` of each ``lora_A``
            entry is plotted as its rank.
        model_args: must provide ``model_name_or_path`` and ``lora_r``.
        data_args: must provide ``task_name``.
        config: must provide ``num_hidden_layers``.
        kind: suffix appended to the output filename.

    Raises:
        ValueError: if there are no ``lora_A`` keys, or their count is not a
            positive multiple of ``config.num_hidden_layers``.
    """
    num_layers = config.num_hidden_layers
    lora_a_keys = [k for k in lora_state_dict.keys() if 'lora_A' in k]
    if num_layers <= 0 or not lora_a_keys or len(lora_a_keys) % num_layers:
        raise ValueError(
            f"cannot split {len(lora_a_keys)} lora_A keys evenly across "
            f"{num_layers} hidden layers"
        )
    # Derived rather than hard-coded, so target-module sets other than six
    # per layer can be plotted too.
    modules_per_layer = len(lora_a_keys) // num_layers

    # A key missing from new_state_dict is plotted as rank 0.
    ranks = np.array([
        new_state_dict[k].shape[0] if k in new_state_dict else 0
        for k in lora_a_keys
    ])
    heatmap = ranks.reshape(num_layers, modules_per_layer)
    # The first layer's keys name the modules; the label is the
    # second-to-last dot-separated component of the key.
    ticks = [k.split('.')[-2] for k in lora_a_keys[:modules_per_layer]]

    # Shorten "org/model" to "model". Any remaining '/' is replaced with '_' so
    # the name cannot point savefig at a non-existent subdirectory.
    model_name = model_args.model_name_or_path.split('/', 1)[-1].replace('/', '_')
    out_dir = "plots/adaptive_ranks/"

    fig = plt.figure(figsize=(10, 5))
    try:
        # Ranks are integers; "d" keeps e.g. 128 from rendering as 1.3e+02.
        sns.heatmap(heatmap.T, annot=True, fmt="d", linewidths=0.5)
        plt.xlabel("Layer")
        plt.ylabel("Weights")
        # Centre the tick labels on the cells; layers are labelled from 1.
        plt.xticks(np.arange(.5, num_layers + .5), np.arange(1, num_layers + 1))
        plt.yticks(np.arange(.5, modules_per_layer + .5), ticks, rotation=90)
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(f"{out_dir}adaptive_rank_{model_name}_{model_args.lora_r}_{data_args.task_name}_pca_{kind}.svg")
    finally:
        # Release the figure; otherwise repeated calls accumulate open figures.
        plt.close(fig)