import os
import numpy as np
import matplotlib.pyplot as plt

from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from typing import Dict

ROOTPATH = "./"

def load_data(filename:str):
    """Load an object stored with ``np.save`` and return it unwrapped.

    ``np.load`` hands back a 0-d object array for pickled Python objects
    (such as a dict); ``.item()`` extracts the stored object itself.

    Args:
        filename: Path of the file to read.

    Returns:
        The Python object held inside the loaded array.
    """
    # NOTE(review): allow_pickle=True runs arbitrary pickle code on load —
    # only point this at result files you produced yourself.
    stored = np.load(filename, allow_pickle=True)
    return stored.item()


def select_data(data, labels, target_label):
    """Return the rows of ``data`` whose label equals ``target_label``.

    Args:
        data: Array of samples indexed along axis 0 (e.g. ``(n, 2)`` points).
        labels: Labels aligned with the rows of ``data``. Only the first
            ``data.shape[0]`` entries are used; ``(n, 1)`` column vectors are
            accepted and flattened.
        target_label: Label value to keep.

    Returns:
        ``np.ndarray`` with the matching rows in their original order. When no
        row matches, the result has shape ``(0, *data.shape[1:])``, so column
        indexing such as ``result[:, 0]`` still works (the previous
        ``np.array([])`` of shape ``(0,)`` made that raise IndexError).
    """
    data = np.asarray(data)
    # Flatten so column-vector labels work, and ignore labels beyond the
    # number of rows, matching the original per-row loop.
    label_arr = np.asarray(labels).reshape(-1)[: data.shape[0]]
    mask = label_arr == target_label
    return data[mask]


def plot_decision_boundary(filepath:str, fig_title:str, fig_name:str):
    """Plot predicted decision regions with the in-context samples and save as PDF.

    The figure is written to ``<ROOTPATH>/figures/paper/<fig_name>.pdf`` and
    closed afterwards so repeated calls do not accumulate open figures.

    Args:
        filepath: File readable by ``load_data`` holding a dict with keys
            ``"train_data"``, ``"train_labels"``, ``"query_data"`` and
            ``"predictions"``.
        fig_title: Figure title. Currently unused (title drawing is disabled).
        fig_name: Output file name without the ``.pdf`` extension.

    Raises:
        ValueError: if the number of query points is zero or not a perfect
            square, so the points cannot be laid out on a square grid.
    """
    # read data
    data_dict = load_data(filepath)

    # split data
    train_data = data_dict["train_data"]
    train_labels = data_dict["train_labels"]
    query_data = np.asarray(data_dict["query_data"])
    query_pred = np.asarray(data_dict["predictions"])

    # The query points are reshaped into a square grid for contourf (previously
    # hard-coded to 50 x 50); infer the side length from the data instead.
    # Assumes the points come from a flattened square meshgrid — TODO confirm.
    n_query = query_data.shape[0]
    side = int(round(np.sqrt(n_query)))
    if n_query == 0 or side * side != n_query:
        raise ValueError(
            f"expected a non-empty square grid of query points, got {n_query} points"
        )
    grid_shape = (side, side)

    fig = plt.figure(figsize=(6, 4))
    try:
        ### plot predictions
        # Explicit levels put the class boundary at 0.5 and pin class 0 / class 1
        # to the first / second colour even if every prediction is the same class
        # (assumes predictions are 0/1).
        plt.contourf(query_data[:, 0].reshape(grid_shape),
                     query_data[:, 1].reshape(grid_shape),
                     query_pred.reshape(grid_shape),
                     levels=[-0.5, 0.5, 1.5],
                     alpha=0.8, cmap=ListedColormap(["#8E8BFE", "#FEA3A2"]))

        ### plot in-context samples (class 1 first, then class 0)
        POINTSIZE = 100
        for label, color in ((1, "#C82423"), (0, "#2878B5")):
            points = select_data(train_data, train_labels, target_label=label)
            if len(points) == 0:
                # No in-context sample of this class: nothing to draw.
                continue
            plt.scatter(points[:, 0], points[:, 1], s=POINTSIZE, color=color,
                        label=f"Class={label} (In-Context Samples)", edgecolors="k")

        ### crop axes to the query grid
        x_min, y_min = np.min(query_data, axis=0)
        x_max, y_max = np.max(query_data, axis=0)
        plt.xlim((x_min, x_max))
        plt.ylim((y_min, y_max))

        ### remove ticks
        plt.xticks([])
        plt.yticks([])

        ### legend (custom handles so it shows both samples and predicted regions)
        legend_font = {"weight": "bold", "size": 8}
        legend_elements = [
            Line2D([0], [0], marker='o', color="k", markerfacecolor="#2878B5", lw=0, label="Class 0 (In-Context Samples)"),
            Line2D([0], [0], marker='o', color="k", markerfacecolor="#C82423", lw=0, label="Class 1 (In-Context Samples)"),
            Patch(facecolor="#8E8BFE", edgecolor="#8E8BFE", label="Class 0 (Model Prediction)"),
            Patch(facecolor="#FEA3A2", edgecolor="#FEA3A2", label="Class 1 (Model Prediction)"),
        ]
        plt.legend(handles=legend_elements, prop=legend_font)

        ### title drawing is currently disabled (fig_title is accepted but unused):
        # plt.title(fig_title, fontweight="bold", fontsize=28)

        ### save fig
        savepath = os.path.join(ROOTPATH, "figures/paper")
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(savepath, exist_ok=True)
        plt.savefig(os.path.join(savepath, fig_name + ".pdf"), dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even if plotting/saving fails.
        plt.close(fig)


def confidence_plotter(filepath:str, figname:str):
    """Plot model confidence over the query grid with the in-context samples.

    The figure, with a colorbar, is written to
    ``<ROOTPATH>/figures/paper/<figname>.pdf`` and then closed.

    Args:
        filepath: File readable by ``load_data`` holding a dict with keys
            ``"train_data"``, ``"train_labels"``, ``"confidence"`` and
            ``"query_data"``.
        figname: Output file name without the ``.pdf`` extension.

    Raises:
        ValueError: if the number of query points is zero or not a perfect
            square, so the points cannot be laid out on a square grid.
    """
    data_dict = load_data(filename=filepath)

    train_data = data_dict["train_data"]
    train_labels = data_dict["train_labels"]
    # Confidence values are clipped into [0, 1] before plotting.
    conf_score = np.clip(data_dict["confidence"], a_min=0.0, a_max=1.0)
    query_data = np.asarray(data_dict["query_data"])

    # Infer the square grid side instead of hard-coding 50 x 50.
    # Assumes the points come from a flattened square meshgrid — TODO confirm.
    n_query = query_data.shape[0]
    side = int(round(np.sqrt(n_query)))
    if n_query == 0 or side * side != n_query:
        raise ValueError(
            f"expected a non-empty square grid of query points, got {n_query} points"
        )
    grid_shape = (side, side)

    fig = plt.figure(figsize=(6, 4))
    try:
        ### plot confidence
        # NOTE(review): levels start at 0.5, so confidences below 0.5 are left
        # unfilled by contourf — confirm that values < 0.5 cannot occur.
        cs = plt.contourf(query_data[:, 0].reshape(grid_shape),
                          query_data[:, 1].reshape(grid_shape),
                          np.asarray(conf_score).reshape(grid_shape),
                          levels=np.round(np.linspace(0.5, 1.0, 10), 2),
                          alpha=0.8, cmap=plt.get_cmap("Spectral_r"))
        plt.colorbar(cs)

        ### plot in-context samples (class 1 first, then class 0)
        POINTSIZE = 100
        for label, color in ((1, "#C82423"), (0, "#2878B5")):
            points = select_data(train_data, train_labels, target_label=label)
            if len(points) == 0:
                # No in-context sample of this class: nothing to draw.
                continue
            plt.scatter(points[:, 0], points[:, 1], s=POINTSIZE, color=color,
                        label=f"Class={label} (In-Context Samples)", edgecolors="k")

        ### crop axes to the query grid, as plot_decision_boundary does
        x_min, y_min = np.min(query_data, axis=0)
        x_max, y_max = np.max(query_data, axis=0)
        plt.xlim((x_min, x_max))
        plt.ylim((y_min, y_max))

        plt.xticks([])
        plt.yticks([])

        ### save fig
        savepath = os.path.join(ROOTPATH, "figures/paper")
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(savepath, exist_ok=True)
        plt.savefig(os.path.join(savepath, figname + ".pdf"), dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even if plotting/saving fails.
        plt.close(fig)


if __name__ == "__main__":
    # Experiment configuration: set `filepath` to the saved results file
    # (read via load_data) before running.
    model_name = "llama-3-8b"
    prompt_mode = "pred"
    task_mode = "binary_linear_classification"
    exp_name = None
    filepath = "PATH_TO_THE_FILE"

    # Output name: <model>_<prompt>_<task>[_<exp>]
    name_parts = [model_name, prompt_mode, task_mode]
    if exp_name is not None:
        name_parts.append(f"{exp_name}")
    fig_name = "_".join(name_parts)
    fig_title = f"{model_name}+{prompt_mode}"

    plot_decision_boundary(filepath=filepath, fig_title=fig_title, fig_name=fig_name)