# Plot script for LoCoOp evaluation results
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
import os
import pandas as pd
import torch
import torch.nn.functional as F
import argparse


def plot_distribution(in_dataset, out_dataset, id_scores, ood_scores, score_name, output_dir):
    """Plot overlaid KDEs of ID vs. OOD scores and save the figure as a PNG.

    Args:
        in_dataset: Name of the in-distribution dataset (used in the legend and title).
        out_dataset: Name of the OOD dataset (used in the legend, title and file name).
        id_scores: Numeric scores of the in-distribution samples.
        ood_scores: Numeric scores of the out-of-distribution samples.
        score_name: Optional score label. When not None the file is saved as
            ``{out_dataset}_{score_name}.png``, otherwise as ``{out_dataset}.png``.
        output_dir: Directory for the PNG; created if it does not exist.

    Returns:
        The path of the written PNG file.
    """
    sns.set(style="white", palette="muted")
    palette = ['#A8BAE3', '#55AB83']

    # Scores are negated before plotting.
    # NOTE(review): presumably the stored scores use the opposite sign
    # convention from what should be displayed; confirm against the evaluator.
    data = {
        f"ID ({in_dataset})": -np.asarray(id_scores, dtype=float),
        f"OOD ({out_dataset})": -np.asarray(ood_scores, dtype=float),
    }

    # Wide-form dict input: each key becomes a hue level, so the legend labels
    # come from the dict keys above (no extra `label=` needed).
    sns.displot(data, kind="kde", palette=palette, fill=True, alpha=0.8)
    # displot is a figure-level function that creates its own figure; grab it
    # so we can title, save and close exactly that figure.
    fig = plt.gcf()
    fig.suptitle(f"Score Distribution\nID: {in_dataset}, OOD: {out_dataset}", fontsize=14, y=1.1)

    os.makedirs(output_dir, exist_ok=True)
    if score_name is not None:
        file_name = f"{out_dataset}_{score_name}.png"
    else:
        file_name = f"{out_dataset}.png"
    out_path = os.path.join(output_dir, file_name)

    try:
        fig.savefig(out_path, bbox_inches='tight')
    finally:
        # Called once per OOD dataset: close the figure so open figures do
        # not accumulate across calls.
        plt.close(fig)
    return out_path


def main(argv=None):
    """Load per-dataset scores from an ``.npz`` file and plot ID-vs-OOD densities.

    Each requested score name (a top-level key of the ``.npz``) must hold a
    pickled dict mapping dataset name -> scores, and that dict must contain the
    in-distribution dataset. One plot is written to ``--output-dir`` per
    (score, OOD dataset) pair.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]`` when None.
    """
    parser = argparse.ArgumentParser(description="Plot LoCoOp evaluation results")
    parser.add_argument("--input-file", type=str, required=True, help="Path to scores.npz file")
    parser.add_argument("--output-dir", type=str, required=True, help="Output directory for plots")
    parser.add_argument("--in-dataset", type=str, default="ImageNet",
                        help="Key of the in-distribution dataset inside each score dict")
    parser.add_argument("--scores", type=str, nargs="+", default=["MCM"],
                        help="Score names (top-level .npz keys) to plot, e.g. MCM GL-MCM")
    args = parser.parse_args(argv)

    # allow_pickle=True is needed because each entry is a pickled dict stored
    # as a 0-d object array. Unpickling can execute arbitrary code, so only
    # load score files you produced yourself.
    with np.load(args.input_file, allow_pickle=True) as npz:
        print(f"Loaded {args.input_file}; available scores: {list(npz.files)}")
        all_scores = {}
        for score_name in args.scores:
            if score_name not in npz.files:
                parser.error(
                    f"score {score_name!r} not found in {args.input_file} "
                    f"(available: {list(npz.files)})"
                )
            # .item() unwraps the 0-d object array into the stored dict; do it
            # while the archive is still open.
            all_scores[score_name] = npz[score_name].item()

    os.makedirs(args.output_dir, exist_ok=True)

    for score_name, per_dataset in all_scores.items():
        if args.in_dataset not in per_dataset:
            parser.error(
                f"in-distribution dataset {args.in_dataset!r} not found under score "
                f"{score_name!r} (available: {list(per_dataset)})"
            )
        in_scores = per_dataset[args.in_dataset]
        for dataset, out_scores in per_dataset.items():
            if dataset == args.in_dataset:
                continue
            plot_distribution(args.in_dataset, dataset, in_scores, out_scores,
                              score_name, args.output_dir)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()