import os

from src.task_generator import *
from src.prompt import *
from src.conventional import *
from src.datagenerator import *
from src.args import parse_args
from src.utils import set_seed

import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, ListedColormap
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from tqdm import tqdm


# Two-stop colormaps running from the class-0 colour (position 0) to the
# class-1 colour (position 1): saturated tones for sample points, lighter
# tones for background regions. NOTE(review): plot_samples hard-codes the same
# hex colours via ListedColormap and does not reference these colormaps.
point_colors = [(0, "#2878B5"), (1, "#C82423")]
bg_colors = [(0, "#8E8BFE"), (1, "#FEA3A2")]
point_cmap, bg_cmap = (
    LinearSegmentedColormap.from_list(cmap_name, color_stops)
    for cmap_name, color_stops in (("pt_color", point_colors), ("bg_color", bg_colors))
)


def select_data(data, labels, target_label):
    """Return the rows of ``data`` whose label equals ``target_label``.

    Args:
        data: array-like whose first axis indexes samples.
        labels: per-sample labels aligned with the rows of ``data``; a flat
            sequence or an ``(n, 1)`` column vector is accepted.
        target_label: label value to keep.

    Returns:
        A new array with the matching rows in their original order. When no
        row matches, the result is empty but keeps ``data``'s trailing
        dimensions (e.g. shape ``(0, 2)`` for 2-D points), so callers can
        still index it as ``result[:, 0]`` without an IndexError.
    """
    data = np.asarray(data)
    # Flatten so (n,) and (n, 1) label arrays both give a 1-D mask, and only
    # consider as many labels as there are rows of data.
    flat_labels = np.ravel(labels)[: data.shape[0]]
    # Boolean-mask indexing copies the selected rows and preserves the
    # trailing shape even when nothing is selected.
    return data[flat_labels == target_label]


def plot_samples(train_data, train_labels, queries, query_labels, fig_title, fig_name):
    """Plot predicted decision regions over a query grid plus the in-context samples.

    Predicted labels are drawn as filled contours over the query grid, and the
    in-context (training) points of class 1 and class 0 are scattered on top.
    The figure is saved to ``<cwd>/plot/figures/paper/<fig_name>.pdf`` and
    always closed afterwards, so repeated calls do not accumulate figures.

    Args:
        train_data: 2-D array of in-context samples; columns 0 and 1 are
            plotted as x and y.
        train_labels: per-sample labels for ``train_data``; samples labelled
            1 and 0 are plotted.
        queries: ``(m, 2)`` array of grid points. ``m`` must be a perfect
            square because the points are reshaped into a square grid
            (assumes the ordering produced by generate_grid_data — TODO
            confirm it is meshgrid row-major).
        query_labels: predicted label for every query point; anything that
            converts to an array of ``m`` values.
        fig_title: figure title; currently not drawn.
        fig_name: output file name without the ``.pdf`` extension.

    Raises:
        ValueError: if the number of query points is not a perfect square.
    """
    queries = np.asarray(queries)
    num_queries = queries.shape[0]
    # Infer the grid side instead of hard-coding 50x50 so other grid
    # resolutions also work; a non-square count cannot be reshaped.
    side = int(round(np.sqrt(num_queries)))
    if side * side != num_queries:
        raise ValueError(f"Expected a square grid of query points, got {num_queries} points.")

    fig = plt.figure(figsize=(6, 4))
    try:
        ### plot predictions
        plt.contourf(queries[:, 0].reshape(side, side),
                     queries[:, 1].reshape(side, side),
                     np.asarray(query_labels).reshape(side, side),
                     alpha=0.8, cmap=ListedColormap(["#8E8BFE", "#FEA3A2"]))

        ### plot in-context samples
        POINTSIZE = 100
        train_pos = select_data(train_data, train_labels, target_label=1)
        # A class may have no in-context samples; skip it instead of indexing
        # an empty selection with [:, 0].
        if len(train_pos) > 0:
            plt.scatter(train_pos[:, 0], train_pos[:, 1], s=POINTSIZE, color="#C82423", label="Class=1 (In-Context Samples)", edgecolors="k")
        train_neg = select_data(train_data, train_labels, target_label=0)
        if len(train_neg) > 0:
            plt.scatter(train_neg[:, 0], train_neg[:, 1], s=POINTSIZE, color="#2878B5", label="Class=0 (In-Context Samples)", edgecolors="k")

        ### range of ticks: clip the axes to the extent of the query grid
        x_min, y_min = np.min(queries, axis=0)
        x_max, y_max = np.max(queries, axis=0)
        plt.xlim((x_min, x_max))
        plt.ylim((y_min, y_max))

        ### remove ticks
        plt.xticks([])
        plt.yticks([])

        ### legend (explicit handles so both samples and regions are listed)
        legend_font = {"weight": "bold", "size": 8}
        legend_elements = [Line2D([0], [0], marker='o', color="k", markerfacecolor="#2878B5", lw=0, label="Class 0 (In-Context Samples)"),
                           Line2D([0], [0], marker='o', color="k", markerfacecolor="#C82423", lw=0, label="Class 1 (In-Context Samples)"),
                           Patch(facecolor="#8E8BFE", edgecolor="#8E8BFE", label="Class 0 (Model Prediction)"),
                           Patch(facecolor="#FEA3A2", edgecolor="#FEA3A2", label="Class 1 (Model Prediction)")]
        plt.legend(handles=legend_elements, prop=legend_font)

        ### save fig
        savepath = os.path.join(os.getcwd(), "plot/figures/paper")
        os.makedirs(savepath, exist_ok=True)
        plt.savefig(os.path.join(savepath, fig_name + ".pdf"), dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even on error; the caller plots once per model.
        plt.close(fig)


def main():
    """Fit each conventional classifier on a 2-D binary task and plot its decision regions.

    The task is picked by ``args.task_mode`` (a mode containing "linear",
    "circle" or "moon") and has 128 training samples. A grid of query points
    is built with ``generate_grid_data``. Each model predicts labels for that
    grid, and the result is saved via ``plot_samples``.

    With ``args.pred_mode == "all"`` all queries go to the model in one call.
    With ``"single"`` the model is called once per query point, and the figure
    name gets a ``_pw`` suffix.

    Raises:
        ValueError: for an unrecognised ``task_mode`` or ``pred_mode``.
    """
    args = parse_args()
    set_seed(args.seed)

    # Validate up front: an unknown mode would otherwise leave pred_labels
    # unassigned and the script would die with a NameError mid-run.
    if args.pred_mode not in ("all", "single"):
        raise ValueError(f"Unrecognized prediction mode {args.pred_mode!r}. Please select from ['all', 'single'].")

    if "linear" in args.task_mode:
        data, labels = generate_linear_task(num_classes=2, mode="train", num_samples=128, randseed=args.seed)
    elif "circle" in args.task_mode:
        data, labels = generate_circle_task(mode="train", noise=args.circle_task_noise, num_samples=128, randseed=args.seed)
    elif "moon" in args.task_mode:
        data, labels = generate_moon_task(mode="train", num_samples=128, randseed=args.seed)
    else:
        raise ValueError("Unrecognized task mode.")

    ### generate grid data
    queries = generate_grid_data(data)

    # Display title and prediction function per model, in plotting order.
    # Each callable takes a query array and returns the predicted labels.
    predictors = {
        "decision_tree": ("Decision Tree",
                          lambda q: decisiontree(data=data, labels=labels, queries=q, seed=args.seed)),
        "knn": ("K-NN",
                lambda q: knn(data=data, labels=labels, queries=q)),
        "svm": ("SVM (RBF Kernel)",
                lambda q: svm(data=data, labels=labels, queries=q, seed=args.seed)),
        "linear_regression": ("Linear Regression",
                              lambda q: linear_regression(data=data, labels=labels, queries=q)),
        "mlp": ("MLP",
                lambda q: mlp(data=data, labels=labels, queries=q, randseed=args.seed)),
    }

    for title, predict in predictors.values():
        print(f"Evaluating with {title}")
        if args.pred_mode == "all":
            pred_labels = predict(queries)
            fig_name = f"{title}_binary_{args.task_mode}"
        else:
            # Point-wise mode: the model function receives the full training
            # set and a single (1, d) query per call.
            pred_labels = [predict(np.reshape(subquery, (1, -1))) for subquery in tqdm(queries)]
            fig_name = f"{title}_binary_{args.task_mode}_pw"
        plot_samples(data, labels, queries, pred_labels, fig_title=title, fig_name=fig_name)


if __name__ == "__main__":
    main()
    

    
    