import os
import sys

from src.task_generator import *
from src.prompt import *
from src.conventional import *
from src.datagenerator import *
from src.args import parse_args
from src.utils import set_seed
from src.kernel_hsic import kernel_HSIC, linear_HSIC

import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, ListedColormap
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

from tqdm import tqdm


# Two-stop colour ramp for sample points. The endpoint hex codes are the same
# ones plot_samples() uses for class 0 ("#2878B5") and class 1 ("#C82423").
# NOTE(review): point_cmap is not referenced in the visible part of this file.
point_colors = [(0, "#2878B5"), (1, "#C82423")]
point_cmap = LinearSegmentedColormap.from_list("pt_color", point_colors)

# Two-stop colour ramp for the prediction background. The endpoint hex codes
# match the class 0 / class 1 region colours used in plot_samples().
# NOTE(review): bg_cmap is not referenced in the visible part of this file.
bg_colors = [(0, "#8E8BFE"), (1, "#FEA3A2")]
bg_cmap = LinearSegmentedColormap.from_list("bg_color", bg_colors)


# Base directory under which the saved prediction records are looked up.
DATAROOT = "./"

# Path of the saved LLM prediction file for each task, keyed by short task name.
llm_pred_paths = {
    task: os.path.join(
        DATAROOT,
        "data_records/pred_results",
        f"llama-3_binary_{task}_classification_preds.npy",
    )
    for task in ("linear", "circle", "moon")
}


def plot_samples(train_data, train_labels, queries, query_labels, fig_title, fig_name):
    """
    Plot the predicted regions over a query grid together with the in-context
    samples, and save the figure as `<cwd>/plot/figures/paper/<fig_name>.pdf`.

    Args:
        train_data: In-context samples; columns 0 and 1 are used as x / y.
        train_labels: Labels of the in-context samples (compared against 0 and 1).
        queries: Query points laid out as a square grid (side * side rows);
            columns 0 and 1 are used as x / y.
        query_labels: One predicted label per query point, in the order of `queries`.
        fig_title: Currently unused; the title call is disabled.
        fig_name: File name of the saved figure, without extension.

    Raises:
        ValueError: If the number of query points is not a perfect square of
            side >= 2 and therefore cannot be reshaped into a grid.
    """
    queries = np.asarray(queries)

    ### contourf needs 2-D grids; infer the side length instead of assuming 50
    num_queries = queries.shape[0]
    side = int(round(float(np.sqrt(num_queries))))
    if side < 2 or side * side != num_queries:
        raise ValueError(
            f"Expected a square grid of query points, got {num_queries} points."
        )

    fig = plt.figure(figsize=(6, 4))
    try:
        ### plot predictions
        plt.contourf(queries[:, 0].reshape(side, side),
                     queries[:, 1].reshape(side, side),
                     np.asarray(query_labels).reshape(side, side),
                     alpha=0.8, cmap=ListedColormap(["#8E8BFE", "#FEA3A2"]))

        ### plot in-context samples (class 1 first, then class 0)
        point_size = 100
        scatter_specs = (
            (1, "#C82423", "Class=1 (In-Context Samples)"),
            (0, "#2878B5", "Class=0 (In-Context Samples)"),
        )
        for target, color, label in scatter_specs:
            selected = select_data(train_data, train_labels, target_label=target)
            if selected.size == 0:
                # No sample of this class: nothing to draw, and 2-D indexing
                # of an empty selection could fail.
                continue
            plt.scatter(selected[:, 0], selected[:, 1], s=point_size, color=color,
                        label=label, edgecolors="k")

        ### range of ticks: clip the axes to the extent of the query grid
        x_min, y_min = np.min(queries, axis=0)
        x_max, y_max = np.max(queries, axis=0)
        plt.xlim((x_min, x_max))
        plt.ylim((y_min, y_max))

        ### remove ticks
        plt.xticks([])
        plt.yticks([])

        ### legend: explicit handles, so the scatter labels above are not shown
        legend_font = {"weight": "bold", "size": 8}
        legend_elements = [
            Line2D([0], [0], marker='o', color="k", markerfacecolor="#2878B5", lw=0, label="Class 0 (In-Context Samples)"),
            Line2D([0], [0], marker='o', color="k", markerfacecolor="#C82423", lw=0, label="Class 1 (In-Context Samples)"),
            Patch(facecolor="#8E8BFE", edgecolor="#8E8BFE", label="Class 0 (Model Prediction)"),
            Patch(facecolor="#FEA3A2", edgecolor="#FEA3A2", label="Class 1 (Model Prediction)"),
        ]
        plt.legend(handles=legend_elements, prop=legend_font)

        ### save fig
        savepath = os.path.join(os.getcwd(), "plot/figures/paper")
        # exist_ok avoids the check-then-create race of a separate exists() test
        os.makedirs(savepath, exist_ok=True)
        plt.savefig(os.path.join(savepath, fig_name+".pdf"), dpi=300, bbox_inches="tight")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    

def get_preds(task_mode:str) -> np.ndarray:
    """
    Load the saved LLM predictions for one task.

    The prediction file must already exist on disk.

    Args:
        task_mode (str): Short task name, one of ["linear", "circle", "moon"].

    Returns:
        np.ndarray: The "predictions" entry of the saved record, as an array.
    """
    pred_file = llm_pred_paths[task_mode]
    assert os.path.exists(pred_file), "The data file does not exist. Please run `run_llm.py` first."

    # .item() unwraps the stored object (presumably a dict saved with np.save);
    # allow_pickle is required for that, so only load trusted files.
    record = np.load(file=pred_file, allow_pickle=True).item()
    predictions = record["predictions"]

    return np.array(predictions)
    


def select_data(data, labels, target_label):
    """
    Return the rows of `data` whose label equals `target_label`.

    Args:
        data: Array of samples, one per row.
        labels: One label per row of `data`.
        target_label: The label value to keep.

    Returns:
        np.ndarray: The matching rows, in their original order. When nothing
        matches, the result has zero rows but keeps the trailing dimensions of
        `data` (e.g. shape (0, 2)), so callers can still index columns.
    """
    data = np.asarray(data)
    num_rows = data.shape[0]

    # Only the first `num_rows` labels are considered.
    mask = np.asarray(labels)[:num_rows] == target_label
    if mask.ndim > 1:
        # Column-vector labels such as shape (n, 1): reduce to one flag per row.
        mask = mask.reshape(mask.shape[0], -1).all(axis=1)

    return data[mask]


def get_probability(num_methods:int, mode:str="standard") -> List[float]:
    """
    Get the probability for randomly sampling methods.

    Args:
        num_methods (int): The number of candidates.
        mode (str, optional): The type of probability. One of
            "standard" / "ml-only" (fixed five-entry tables),
            "uniform" (equal weights) or "random" (normalized random weights).

    Raises:
        ValueError: If `mode` is not recognized, if `num_methods` is not
            positive for "uniform" / "random", or if a fixed table does not
            have exactly `num_methods` entries.

    Returns:
        List[float]: A list of probability. Note that the "random" mode returns
        a NumPy array rather than a plain list.
    """
    fixed_tables = {
        "standard": [0.673, 0.000, 0.044, 0.003, 0.280],
        "ml-only": [0.8, 0.04, 0.02, 0.01, 0.13],
    }

    if mode in fixed_tables:
        probs = list(fixed_tables[mode])
        # A length mismatch would otherwise only surface later as an opaque
        # size error when the probabilities are used for sampling.
        if len(probs) != num_methods:
            raise ValueError(
                f"Mode '{mode}' defines {len(probs)} probabilities "
                f"but num_methods is {num_methods}."
            )
        return probs

    if mode not in ("uniform", "random"):
        raise ValueError("Unrecognized mode.")

    if num_methods < 1:
        raise ValueError("num_methods must be a positive integer.")

    if mode == "uniform":
        return [1./num_methods] * num_methods

    # mode == "random": draw positive weights and normalize them to sum to 1
    rand_numbers = np.random.rand(num_methods)
    return rand_numbers / rand_numbers.sum()


def apply_method(model_kw:str, data:np.ndarray, labels:np.ndarray, query:np.ndarray, args):
    """
    Apply a conventional machine-learning method selected by its key word.

    Args:
        model_kw (str): Model key word. One of "decision_tree", "mlp", "knn",
            "svm", "linear_regression".
        data (np.ndarray): In-context data.
        labels (np.ndarray): Labels of in-context data.
        query (np.ndarray): Query data.
        args: Hyperparameters; only `args.seed` is read, and only for "mlp".

    Raises:
        ValueError: If `model_kw` is not one of the supported key words.

    Returns:
        The predictions returned by the selected method for `query`.
    """
    # Listed locally so the error message does not depend on any global name.
    supported = ("decision_tree", "mlp", "knn", "svm", "linear_regression")
    if model_kw not in supported:
        raise ValueError(f"Unrecognized model. Please select from {list(supported)}.")

    if model_kw == "decision_tree":
        return decisiontree(data=data, labels=labels, queries=query)
    if model_kw == "mlp":
        # mlp is the only method that receives the random seed
        return mlp(data=data, labels=labels, queries=query, randseed=args.seed)
    if model_kw == "knn":
        return knn(data=data, labels=labels, queries=query)
    if model_kw == "svm":
        return svm(data=data, labels=labels, queries=query)
    return linear_regression(data=data, labels=labels, queries=query)


def prob_hybrid(args, models:List[str], data:np.ndarray, labels:np.ndarray, queries:np.ndarray, probList:List[float], prob_type:str="standard"):
    """
    Simulate the behavior of LLMs by answering every query with a randomly
    sampled conventional method.

    Args:
        args: Hyperparameters, forwarded to `apply_method`.
        models (List[str]): Key words of the candidate methods.
        data (np.ndarray): In-context data.
        labels (np.ndarray): Labels of in-context data.
        queries (np.ndarray): Query data, one query per row.
        probList (List[float]): Sampling probabilities, one per model. When it
            is None, the probabilities come from `get_probability(prob_type)`.
        prob_type (str): The type of probability, used only if `probList` is None.

    Returns:
        list: One prediction per query, in query order.
    """
    sampling_probs = (
        probList
        if probList is not None
        else get_probability(num_methods=len(models), mode=prob_type)
    )

    predictions = []
    for point in tqdm(queries):
        # Draw the method for this query, then run it on the single point.
        chosen = np.random.choice(models, p=sampling_probs)
        single_query = np.reshape(point, (1, -1))
        result = apply_method(model_kw=chosen, data=data, labels=labels, query=single_query, args=args)
        predictions.append(result[0])

    return predictions


def hsic_hybrid(data:np.ndarray, labels:np.ndarray, queries:np.ndarray, models:List[str], args):
    """
    Simulate the LLM with a mixture of conventional methods whose sampling
    probabilities are the normalized `kernel_HSIC` scores between each method's
    predictions and the saved LLM predictions.

    Args:
        data (np.ndarray): In-context data.
        labels (np.ndarray): Labels of in-context data.
        queries (np.ndarray): Query data.
        models (List[str]): Key words of the candidate methods.
        args: Hyperparameters; `args.task_mode` selects the LLM prediction file.

    Raises:
        FileNotFoundError: If `args.task_mode` has no associated prediction file.

    Returns:
        list: One prediction per query, as produced by `prob_hybrid`.
    """
    ### 1. Predictions of every conventional method on the full query set.
    conventional_preds = [
        apply_method(model_kw=name, data=data, labels=labels, query=queries, args=args)
        for name in models
    ]

    ### 2. Locate and load the predictions of the llm.
    pred_files = {
        "linear_classification": "llama-3_binary_linear_classification_preds.npy",
        "circle_classification": "llama-3_binary_circle_classification_preds.npy",
        "moon_classification": "llama-3_binary_moon_classification_preds.npy",
    }
    if args.task_mode not in pred_files:
        raise FileNotFoundError("Cannot find such a file.")
    file_path = os.path.join(DATAROOT, "data_records/pred_results", pred_files[args.task_mode])
    llm_pred = np.load(file=file_path, allow_pickle=True).item()["predictions"]

    ### 3. Score each method against the llm; both sides go in as column vectors.
    hsic_values = [
        kernel_HSIC(np.reshape(method_pred, (-1, 1)), np.reshape(llm_pred, (-1, 1)))
        for method_pred in conventional_preds
    ]
    print(hsic_values)

    ### 4. Normalize the scores into sampling probabilities and simulate.
    weights = np.array(hsic_values)
    probs = weights / weights.sum()

    return prob_hybrid(args=args, models=models, data=data, labels=labels, queries=queries, probList=probs)
        


if __name__ == "__main__":
    # Script entry point: build a toy binary task, simulate the LLM with a
    # hybrid of conventional methods over a query grid, and save the plot.

    args = parse_args()
    set_seed(args.seed)

    ### generate in-context data for the requested task
    if "linear" in args.task_mode:
        data, labels = generate_linear_task(num_classes=2, mode="train", num_samples=128, randseed=args.seed)
    elif "circle" in args.task_mode:
        data, labels = generate_circle_task(mode="train", noise=args.circle_task_noise, num_samples=128, randseed=args.seed)
    elif "moon" in args.task_mode:
        data, labels = generate_moon_task(mode="train", num_samples=128, randseed=args.seed)
    else:
        raise ValueError("Unrecognized task mode.")

    ### generate grid data
    queries = generate_grid_data(data)

    ### validate the run configuration before the expensive simulation, so a
    ### mistyped option fails immediately instead of after every query has run
    fig_names = {
        "standard": "standard_hybrid_ml_methods",
        "ml-only": "ml_only_hybrid_ml_methods",
        "uniform": "uniform_hybrid_ml_methods",
        "random": "random_hybrid_ml_methods",
    }
    if args.prob_type not in fig_names:
        raise ValueError("Unrecognized probability type.")
    if args.hybrid_type not in ("prob", "hsic"):
        raise ValueError("Unrecognized hybrid type.")
    fig_name = f"{args.hybrid_type}_{fig_names[args.prob_type]}"

    ### candidate methods; "decision_tree" is left out of this run, which is
    ### what the "_wodecisiontree" suffix of the figure name refers to
    # NOTE(review): the "standard" and "ml-only" tables in get_probability have
    # five entries while this list has four; np.random.choice needs equal
    # lengths, so those two prob types cannot work with hybrid_type "prob" here.
    models = ["knn", "svm", "mlp", "linear_regression"]

    if args.hybrid_type == "prob":
        preds = prob_hybrid(args=args, models=models, data=data, labels=labels, queries=queries, probList=None, prob_type=args.prob_type)
    else:
        preds = hsic_hybrid(data=data, labels=labels, queries=queries, models=models, args=args)

    plot_samples(data, labels, queries, preds, fig_title="", fig_name=fig_name+"_wodecisiontree")