import numpy as np
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import normalize
from sklearn.utils import shuffle
import matplotlib as mpl
import matplotlib.pyplot as plt
import os
from collections import OrderedDict
import torch
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

# Image preprocessing pipeline. ToTensor produces a float tensor in the [0, 1]
# range; Normalize then applies (x - mean) / std = (x - 0.5) / 0.5 per channel,
# which maps [0, 1] onto [-1, 1].
# NOTE(review): mean/std have three entries, which presumably targets RGB input
# (e.g. CIFAR10) — confirm before applying this to 1-channel images such as MNIST.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)


def process_features(X, y, is_shuffle=True, seed=100):
    """Impute, encode and scale a tabular dataset into a dense feature matrix.

    Categorical columns (pandas ``category`` / ``object`` dtypes) are filled with
    their most frequent value and one-hot encoded. Numeric columns (any integer
    or floating dtype) are mean-imputed and standardised. The two parts are
    concatenated (categorical first), and every row is rescaled to unit L2 norm.
    Targets are mapped to ordinal integer codes.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature table. Columns whose dtype is neither categorical/object nor
        numeric (e.g. ``bool``, datetime) are not selected and are dropped.
    y : pandas.Series
        Targets aligned with the rows of ``X`` (anything exposing ``to_numpy``).
    is_shuffle : bool, default True
        Shuffle ``X`` and ``y`` jointly before processing.
    seed : int, default 100
        ``random_state`` for the joint shuffle.

    Returns
    -------
    X_processed : numpy.ndarray, shape (n_samples, n_output_features)
        Row-normalised feature matrix.
    y_processed : numpy.ndarray of int32, shape (n_samples, 1)
        Ordinal label codes as a 2-D column vector.

    Raises
    ------
    ValueError
        If ``X`` has no categorical or numeric column to process.
    """
    if is_shuffle:
        X, y = shuffle(X, y, random_state=seed)

    categorical_features = X.select_dtypes(include=['category', 'object']).columns.tolist()
    # 'number' matches every numeric dtype (int8..int64, uint*, float32, float64,
    # pandas nullable Int/Float). A list such as ['int', 'float'] resolves to
    # int64/float64 only and would silently drop e.g. float32 pixel columns.
    numerical_features = X.select_dtypes(include='number').columns.tolist()

    if not categorical_features and not numerical_features:
        raise ValueError("X has no categorical or numeric columns to process")

    n_samples = len(X)

    if categorical_features:
        X_cat = SimpleImputer(strategy="most_frequent").fit_transform(X[categorical_features])
        X_cat = OneHotEncoder(handle_unknown="ignore", sparse_output=False).fit_transform(X_cat)
    else:
        # Zero-width block so np.hstack still works when there is no categorical part.
        X_cat = np.empty((n_samples, 0))

    if numerical_features:
        X_num = SimpleImputer(strategy="mean").fit_transform(X[numerical_features])
        X_num = StandardScaler().fit_transform(X_num)
    else:
        X_num = np.empty((n_samples, 0))

    X_processed = np.hstack([X_cat, X_num])
    # Unit L2 norm per row (sklearn.preprocessing.normalize defaults: norm='l2', axis=1).
    X_processed = normalize(X_processed)

    y_processed = OrdinalEncoder(dtype=np.int32).fit_transform(y.to_numpy().reshape((-1, 1)))

    return X_processed, y_processed


def import_data(name):
    """Fetch one of the supported benchmark datasets from OpenML.

    Parameters
    ----------
    name : str
        Case-sensitive short name: 'mnist', 'fashion', 'mushroom', 'adult',
        'covertype', 'isolet', 'letter', 'Magic' or 'shuttle'.

    Returns
    -------
    X, y
        Whatever ``fetch_openml(..., return_X_y=True)`` returns. For 'mnist'
        and 'fashion', ``X`` is additionally cast to float32 and divided by 255.

    Raises
    ------
    RuntimeError
        If ``name`` is not one of the supported short names.
    """
    # (short name, OpenML dataset name, OpenML version, rescale by 1/255?)
    catalogue = (
        ('mnist', 'mnist_784', 1, True),
        ('fashion', 'fashion-MNIST', 1, True),
        ('mushroom', 'mushroom', 1, False),
        ('adult', 'adult', 2, False),
        ('covertype', 'covertype', 3, False),
        ('isolet', 'isolet', 1, False),
        ('letter', 'letter', 1, False),
        ('Magic', 'MagicTelescope', 1, False),
        ('shuttle', 'shuttle', 1, False),
    )

    for short_name, openml_name, openml_version, rescale in catalogue:
        if name == short_name:
            break
    else:
        raise RuntimeError('Dataset does not exist')

    features, targets = fetch_openml(openml_name, version=openml_version, return_X_y=True)
    if rescale:
        features = features.astype(np.float32) / 255.0

    return features, targets


# Named dash patterns as matplotlib ``(offset, (on, off, ...))`` tuples.
# Each family follows a loosely / normal / densely scheme: the "off" gaps shrink
# from 10 to 5 to 1.
_LINESTYLE_DASHES = {
    'solid': (0, ()),
    'loosely dotted': (0, (1, 10)),
    'dotted': (0, (1, 5)),
    'densely dotted': (0, (1, 1)),
    'loosely dashed': (0, (5, 10)),
    'dashed': (0, (5, 5)),
    'densely dashed': (0, (5, 1)),
    'loosely dashdotted': (0, (3, 10, 1, 10)),
    'dashdotted': (0, (3, 5, 1, 5)),
    'densely dashdotted': (0, (3, 1, 1, 1)),
    'loosely dashdotdotted': (0, (3, 10, 1, 10, 1, 10)),
    'dashdotdotted': (0, (3, 5, 1, 5, 1, 5)),
    'densely dashdotdotted': (0, (3, 1, 1, 1, 1, 1)),
}


def linestyle2dashes(style):
    """Return the matplotlib dash tuple for a named line style.

    Parameters
    ----------
    style : str
        One of the keys of ``_LINESTYLE_DASHES`` (e.g. 'solid', 'dotted',
        'densely dashed').

    Returns
    -------
    tuple
        ``(offset, dash_sequence)`` suitable for the ``linestyle`` argument of
        matplotlib plotting functions.

    Raises
    ------
    ValueError
        If ``style`` is not a known name. Silently returning ``None`` would make
        matplotlib fall back to its default style and hide typos.
    """
    try:
        return _LINESTYLE_DASHES[style]
    except KeyError:
        known = ", ".join(repr(k) for k in _LINESTYLE_DASHES)
        raise ValueError(f"Unknown line style {style!r}; expected one of {known}") from None


def _apply_plot_style():
    """Set the global matplotlib rcParams used for the result figures.

    NOTE: this mutates ``mpl.rcParams`` process-wide, so figures created later
    in the same session inherit these settings.
    """
    mpl.rcParams["axes.linewidth"] = 0.75
    mpl.rcParams["grid.linewidth"] = 0.75
    mpl.rcParams["lines.linewidth"] = 1
    mpl.rcParams["patch.linewidth"] = 1.5
    mpl.rcParams["xtick.major.size"] = 3
    mpl.rcParams["ytick.major.size"] = 3

    # Type 42 (TrueType) fonts embed editable text in PDF/PS output.
    mpl.rcParams["pdf.fonttype"] = 42
    mpl.rcParams["ps.fonttype"] = 42
    mpl.rcParams["font.size"] = 14
    mpl.rcParams["axes.titlesize"] = "large"
    mpl.rcParams["axes.labelsize"] = "medium"
    mpl.rcParams["xtick.labelsize"] = "medium"
    mpl.rcParams["ytick.labelsize"] = "medium"
    mpl.rcParams["legend.fontsize"] = "large"

    mpl.rcParams["text.usetex"] = False


def plot_results(horizon, filename, environments, algorithms):
    """Plot mean cumulative-penalty curves with standard-error bars and save a PDF.

    One subplot is drawn per environment. For each algorithm, the file
    ``./Results/<filename>/<env_name>/<alg_name>`` is read with
    ``np.loadtxt(..., delimiter=",")``. Rows are treated as rounds and columns
    as independent runs. The curve is the per-round mean over runs; error bars,
    drawn at about 10 evenly spaced rounds, are ``std / sqrt(n_runs)``. A shared
    legend is placed below the subplots. The figure is saved to
    ``./Plots/<filename>.pdf`` and then shown.

    Parameters
    ----------
    horizon : int
        Number of rounds. Must equal the number of rows in every results file.
    filename : str
        Experiment name, used as the results sub-directory and the PDF stem.
    environments : sequence
        Items whose first three entries are ``(env_name, K, d)``. ``K`` and
        ``d`` only appear in the subplot titles.
    algorithms : iterable of str
        Keys of ``alg_map`` below. An unknown key raises ``KeyError``.
    """
    _apply_plot_style()

    alg_map = {
        "epsilon": ["orange", "solid", r"$\varepsilon$-Greedy"],
        "DeepFPL": ["purple", "solid", "DeepFPL"],
        "NeuralUCB": ["green", "solid", "NeuralUCB"],
        "NeuralTS": ["black", "solid", "NeuralTS"],
        "DeepFP": ["blue", "solid", "DeepFP (ours)"],
    }
    algorithms = [(alg, *alg_map[alg]) for alg in algorithms]

    step = np.arange(1, horizon + 1)
    # Error bars at ~10 evenly spaced rounds. max(1, ...) keeps the stride
    # non-zero when horizon < 10; np.arange rejects a zero step.
    sparse_indices = np.arange(0, len(step), max(1, len(step) // 10))

    n_envs = len(environments)
    # squeeze=False always returns a 2-D array of Axes, so one environment
    # needs no special casing.
    fig, axs = plt.subplots(1, n_envs, figsize=(n_envs * 5 + 2, 3), squeeze=False)
    handles, labels = [], []

    for fig_idx, (env_def, ax) in enumerate(zip(environments, axs[0])):
        env_name, K, d = env_def[0], env_def[1], env_def[2]
        res_dir = os.path.join(".", "Results", f"{filename}", env_name)

        for alg_name, alg_color, alg_line, alg_label in algorithms:
            fname = os.path.join(res_dir, alg_name)
            # ndmin=2 keeps a single-run file (one value per line) 2-D, with
            # shape (horizon, 1), so the axis=1 reductions below still work.
            cum_regret = np.loadtxt(fname, delimiter=",", ndmin=2)
            mean_regret = cum_regret.mean(axis=1)
            std_regret = cum_regret.std(axis=1) / np.sqrt(cum_regret.shape[1])

            ax.plot(step, mean_regret, color=alg_color,
                    linestyle=linestyle2dashes(alg_line), label=alg_label)
            ax.errorbar(step[sparse_indices], mean_regret[sparse_indices],
                        yerr=std_regret[sparse_indices], fmt='o', color=alg_color,
                        capsize=3, capthick=1, linewidth=1, markersize=1)

        # Collect this subplot's labelled artists once, after all curves are drawn.
        ax_handles, ax_labels = ax.get_legend_handles_labels()
        handles.extend(ax_handles)
        labels.extend(ax_labels)

        ax.set_title(rf"{env_name} ($K={K}, d={d}$)")
        ax.set_xlabel(r"Round $t$")
        ax.set_xlim(0, horizon)
        ax.grid(True)

        if fig_idx == 0:
            ax.set_ylabel("Cumulative Penalty")

    # One legend entry per label, kept in first-seen order.
    unique = OrderedDict(zip(labels, handles))
    if unique:
        fig.legend(handles=list(unique.values()), labels=list(unique.keys()),
                   loc='center', frameon=True, bbox_to_anchor=(0.5, -0.2),
                   fancybox=True, ncol=len(unique), fontsize=15)

    plot_dir = os.path.join(".", "Plots")
    os.makedirs(plot_dir, exist_ok=True)

    fig_name = f"{filename}.pdf"
    fname = os.path.join(plot_dir, fig_name)
    fig.savefig(fname, format="pdf", dpi=1200, bbox_inches="tight")
    plt.show()