import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator

def hpo_plot(insens, all, basic=None, name=None):
    """Plot mean F1 score against learning rate, one curve per batch size.

    Every supplied score set is reshaped to ``(5, -1)`` and the results are
    stacked row-wise. The stack is then averaged column-wise, ignoring NaNs.
    The first ``len(lrs)`` averaged columns become the "Batch Size 1024"
    curve and the remaining columns the "Batch Size 2048" curve. The
    stacked data, the mean/std vectors and the per-batch-size dict are
    printed. The figure is saved to ``./hpo_plot_{name}.pdf``.

    Args:
        insens: F1 scores for the "insens" set; reshapeable to ``(5, -1)``.
        all: F1 scores for the "all" set; reshapeable to ``(5, -1)``.
            The name shadows the builtin but is kept for keyword callers.
        basic: Optional F1 scores for the "basic" set. ``None`` or an empty
            sequence means the set is skipped.
        name: Suffix for the output file. ``None`` yields
            ``./hpo_plot_None.pdf``.
    """
    blocks = [
        np.array(insens).reshape(5, -1),
        np.array(all).reshape(5, -1),
    ]
    # `basic` used to default to a shared mutable list; accept None or empty.
    if basic is not None and len(basic) > 0:
        blocks.insert(0, np.array(basic).reshape(5, -1))
    data = np.concatenate(blocks, axis=0)

    print(data)
    avg_data = np.nanmean(data, axis=0)
    std_data = np.nanstd(data, axis=0)
    print(f'avg_data {avg_data}')
    print(f'std_data {std_data}')

    lrs = [0.01, 0.001, 0.0001, 0.00001]
    n_lrs = len(lrs)
    # Each reshaped row is laid out as [1024 x lrs, 2048 x lrs].
    d = {
        '1024': avg_data[:n_lrs],
        '2048': avg_data[n_lrs:],
    }
    print(d)

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        for key, scores in d.items():
            ax.plot(lrs, scores, marker='o', label=f'Batch Size {key}')
        ax.set_xscale('log')
        ax.set_xlabel('Learning Rate')
        ax.set_ylabel('F1 Score')
        ax.legend()

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        fig.savefig(f'./hpo_plot_{name}.pdf')
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)

def prepare_hpo():
    """Render the MLP hyper-parameter search figure via ``hpo_plot``.

    Each row holds four F1 scores, one per learning rate in ``hpo_plot``'s
    ``lrs`` order. Rows alternate batch size 1024 / 2048 for horizons
    1, 2, 3, 5 and 10. No "basic" results exist yet, so an empty list is
    passed for that set.
    """
    model_name = 'MLP'

    insens_scores = [
        [36.919, 39.691, 45.543, 35.032],  # Horizon 1, Batch Size 1024
        [38.18, 41.783, 41.175, 40.181],   # Horizon 1, Batch Size 2048
        [40.791, 47.165, 48.549, 47.191],  # Horizon 2, Batch Size 1024
        [42.809, 47.19, 47.977, 46.939],   # Horizon 2, Batch Size 2048
        [45.01, 45.886, 46.576, 46.115],   # Horizon 3, Batch Size 1024
        [45.263, 48.33, 45.485, 45.83],    # Horizon 3, Batch Size 2048
        [42.059, 41.745, 44.124, 41.986],  # Horizon 5, Batch Size 1024
        [42.823, 44.714, 40.878, 41.208],  # Horizon 5, Batch Size 2048
        [33.126, 34.858, 34.536, 34.815],  # Horizon 10, Batch Size 1024
        [36.241, 34.757, 34.847, 34.631],  # Horizon 10, Batch Size 2048
    ]

    all_scores = [
        [41.455, 46.684, 43.598, 42.746],  # Horizon 1, Batch Size 1024
        [40.245, 45.977, 44.668, 41.568],  # Horizon 1, Batch Size 2048
        [45.803, 49.750, 49.713, 49.724],  # Horizon 2, Batch Size 1024
        [46.884, 49.102, 50.363, 49.836],  # Horizon 2, Batch Size 2048
        [44.939, 49.230, 50.605, 48.043],  # Horizon 3, Batch Size 1024
        [44.880, 47.205, 48.092, 47.781],  # Horizon 3, Batch Size 2048
        [40.874, 46.063, 44.137, 41.347],  # Horizon 5, Batch Size 1024
        [42.021, 46.266, 45.342, 44.647],  # Horizon 5, Batch Size 2048
        [32.365, 37.545, 37.414, 34.336],  # Horizon 10, Batch Size 1024
        [34.333, 36.852, 35.867, 35.886],  # Horizon 10, Batch Size 2048
    ]

    # Placeholder until "basic" feature results are available.
    basic_scores = []

    hpo_plot(insens_scores, all_scores, basic_scores, model_name)

def alpha_plot(a1, a2, a4, feature):
    """Plot each model's average F1 score against the alpha value.

    ``a1``, ``a2`` and ``a4`` hold per-model scores for alpha = 1e-5, 2e-5
    and 4e-5 respectively. Each has one row per model, in the order of
    ``models`` below, and every row is averaged across its columns. The
    resulting table is printed and the figure is saved to
    ``./alpha_plot_{feature}.pdf``.

    Args:
        a1: Scores for alpha = 1e-5 (2-D, rows = models).
        a2: Scores for alpha = 2e-5 (2-D, rows = models).
        a4: Scores for alpha = 4e-5 (2-D, rows = models).
        feature: Tag used in the output filename.

    Raises:
        ValueError: if the number of rows differs from the number of models.
    """
    models = [
        "MLP",
        "LSTM",
        "CNN1",
        "CTABL",
        "DEEPLOB",
        "DAIN",
        "CNNLSTM",
        "CNN2",
        "TRANSLOB",
        "TLONBoF",
        "BINCTABL",
        "DEEPLOBATT",
        "DLA",
    ]
    alphas = [0.00001, 0.00002, 0.00004]

    # One column per alpha: the mean score of each model's row.
    final_data = np.column_stack([np.mean(a, axis=1) for a in (a1, a2, a4)])
    print(final_data)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(final_data) != len(models):
        raise ValueError(
            f'expected {len(models)} rows (one per model), got {len(final_data)}'
        )

    fig, ax = plt.subplots(figsize=(12, 12))
    cmap = plt.get_cmap('tab20')
    try:
        for idx, (model, row) in enumerate(zip(models, final_data)):
            # Spread colors over the colormap; equals i/13 for the 13 models.
            ax.plot(alphas, row, color=cmap(idx / len(models)), marker='o',
                    label=model)

        ax.set_xlabel('Alpha')
        ax.set_ylabel('Average F1 Score')
        ax.legend(fontsize='small', loc='upper right')

        ax.yaxis.set_major_locator(MultipleLocator(1))
        ax.yaxis.set_minor_locator(MultipleLocator(0.5))

        fig.tight_layout()

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        fig.savefig(f'./alpha_plot_{feature}.pdf')
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)


def prepare_alpha():
    """Render the alpha-sensitivity figure for the 'basic' feature set.

    Each matrix has 13 rows, one per model in the order ``alpha_plot``
    labels them, and five score columns per model. The columns are
    presumably the five prediction horizons; confirm against the
    experiment logs.
    """
    # alpha = 1e-5
    alpha_1x = np.array([
        [37.160, 43.472, 46.296, 45.630, 35.976],
        [47.929, 48.284, 50.695, 50.484, 43.474],
        [37.980, 41.931, 47.301, 49.060, 41.873],
        [45.024, 45.697, 47.326, 44.944, 44.785],
        [48.541, 47.775, 50.258, 47.477, 38.535],
        [46.037, 49.413, 49.427, 46.764, 40.413],
        [46.630, 44.118, 50.118, 46.004, 46.529],
        [45.830, 49.907, 50.326, 50.272, 41.971],
        [50.066, 50.895, 50.509, 49.692, 46.628],
        [44.520, 47.214, 43.228, 49.409, 44.268],
        [44.151, 44.391, 45.379, 45.584, 44.094],
        [46.917, 47.997, 51.728, 49.676, 44.128],
        [47.439, 49.996, 50.136, 47.377, 36.444],
    ])

    # alpha = 2e-5
    alpha_2x = np.array([
        [39.506, 45.029, 45.268, 45.681, 40.506],
        [47.494, 49.760, 51.036, 48.519, 49.913],
        [44.695, 49.182, 49.303, 48.000, 45.299],
        [39.514, 45.893, 46.913, 45.670, 43.667],
        [47.815, 45.593, 45.579, 45.070, 38.657],
        [44.611, 45.840, 48.874, 46.530, 39.390],
        [37.710, 45.054, 43.430, 45.737, 46.422],
        [47.251, 48.248, 50.136, 47.565, 43.717],
        [46.230, 47.916, 50.823, 47.853, 45.845],
        [44.739, 49.428, 46.901, 45.756, 43.810],
        [45.184, 47.356, 45.865, 46.731, 44.234],
        [46.818, 49.818, 50.525, 48.279, 46.580],
        [34.711, 41.011, 41.816, 46.870, 36.906],
    ])

    # alpha = 4e-5
    alpha_4x = np.array([
        [42.179, 42.623, 44.683, 43.670, 39.676],
        [50.533, 49.607, 45.793, 47.074, 44.035],
        [44.180, 47.410, 44.294, 41.485, 43.375],
        [40.567, 40.847, 40.584, 41.336, 39.884],
        [49.497, 47.882, 46.313, 46.244, 43.379],
        [46.920, 44.002, 44.659, 46.027, 41.730],
        [48.993, 47.387, 46.306, 46.046, 43.793],
        [47.072, 47.272, 45.851, 45.195, 44.296],
        [47.431, 44.890, 45.193, 44.260, 45.675],
        [41.419, 42.129, 40.174, 41.930, 40.319],
        [44.335, 43.396, 42.707, 42.020, 43.010],
        [49.245, 48.119, 46.887, 47.283, 45.175],
        [47.439, 47.194, 42.998, 44.598, 44.183],
    ])

    feature_set = 'basic'
    alpha_plot(alpha_1x, alpha_2x, alpha_4x, feature_set)

def level_plot(l2, l5, l10=None, dataset='FI'):
    """Plot each model's average F1 score against the level setting.

    ``l2``, ``l5`` and the optional ``l10`` hold per-model scores for
    levels 2, 5 and 10. Each has one row per model, in the order of
    ``models`` below, and every row is averaged across its columns. The
    resulting table is printed and the figure is saved to
    ``./level_{dataset}.pdf``.

    Args:
        l2: Scores for level 2 (2-D, rows = models).
        l5: Scores for level 5 (2-D, rows = models).
        l10: Optional scores for level 10. ``None`` or an empty sequence
            means only levels 2 and 5 are plotted.
        dataset: Tag used in the output filename. The plotted levels are
            derived from the data supplied, not from this tag. Any tag is
            therefore accepted, and the x values always match the columns.
            For the original pairings ('FI' with ``l10``, 'CHF' without)
            the output is unchanged.

    Raises:
        ValueError: if the number of rows differs from the number of models.
    """
    models = [
        "MLP",
        "LSTM",
        "CNN1",
        "CTABL",
        "DEEPLOB",
        "DAIN",
        "CNNLSTM",
        "CNN2",
        "TRANSLOB",
        "TLONBoF",
        "BINCTABL",
        "DEEPLOBATT",
        "DLA",
    ]

    score_sets = [l2, l5]
    levels = [2, 5]
    # `if l10:` raises on a multi-row NumPy array, so test presence explicitly.
    if l10 is not None and len(l10) > 0:
        score_sets.append(l10)
        levels.append(10)

    # One column per level: the mean score of each model's row.
    final_data = np.column_stack([np.mean(s, axis=1) for s in score_sets])
    print(final_data)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(final_data) != len(models):
        raise ValueError(
            f'expected {len(models)} rows (one per model), got {len(final_data)}'
        )

    fig, ax = plt.subplots(figsize=(12, 12))
    cmap = plt.get_cmap('tab20')
    try:
        for idx, (model, row) in enumerate(zip(models, final_data)):
            # Spread colors over the colormap; equals i/13 for the 13 models.
            ax.plot(levels, row, color=cmap(idx / len(models)), marker='o',
                    label=model)

        ax.set_xlabel('Level')
        ax.set_ylabel('Average F1 Score')
        ax.legend(fontsize='small', loc='upper right')

        ax.yaxis.set_major_locator(MultipleLocator(1))
        ax.yaxis.set_minor_locator(MultipleLocator(0.5))

        fig.tight_layout()

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        fig.savefig(f'./level_{dataset}.pdf')
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)

def prepare_level():
    """Render the level-sensitivity figure for the 'CHF' dataset.

    Each matrix has 13 rows, one per model in the order ``level_plot``
    labels them, and five score columns per model. No level-10 results are
    provided, so only levels 2 and 5 are plotted.

    NOTE(review): ``level5_scores`` nearly matches ``a2`` in
    ``prepare_alpha`` but differs in several cells (e.g. row 2, last
    column: 44.913 vs 49.913). Confirm which table is current.
    """
    level2_scores = np.array([
        [39.284, 43.894, 47.373, 45.938, 39.194],
        [45.546, 49.411, 50.455, 48.496, 44.446],
        [41.827, 46.347, 49.585, 46.486, 44.066],
        [42.843, 45.112, 46.767, 46.201, 43.261],
        [42.716, 47.772, 48.706, 47.881, 41.84],
        [42.421, 48.569, 49.276, 46.43, 40.659],
        [43.736, 46.467, 49.515, 47.546, 45.443],
        [40.458, 47.429, 47.416, 48.031, 46.229],
        [49.064, 50.553, 50.659, 48.375, 46.35],
        [44.182, 45.799, 46.596, 46.306, 43.447],
        [44.27, 44.04, 45.977, 46.177, 43.517],
        [46.533, 48.812, 49.501, 48.163, 41.297],
        [45.275, 49.681, 49.407, 47.184, 38.708],
    ])

    level5_scores = np.array([
        [39.506, 45.029, 45.268, 45.681, 40.506],
        [47.494, 49.760, 51.036, 48.519, 44.913],
        [44.615, 49.182, 49.303, 48.000, 45.299],
        [39.757, 45.039, 46.313, 45.650, 43.467],
        [47.815, 45.593, 45.579, 45.070, 38.657],
        [44.611, 48.540, 48.874, 46.530, 39.390],
        [37.710, 45.054, 43.400, 45.737, 46.222],
        [47.251, 48.248, 50.136, 47.565, 43.717],
        [46.320, 49.216, 50.823, 47.855, 45.845],
        [44.738, 47.355, 46.801, 46.836, 44.310],
        [45.134, 46.116, 45.965, 45.781, 44.234],
        [46.818, 48.918, 50.525, 48.729, 46.580],
        [34.711, 41.011, 41.816, 46.870, 36.906],
    ])

    # Level-10 results are not available for this dataset.
    level10_scores = None
    dataset_tag = 'CHF'

    level_plot(level2_scores, level5_scores, level10_scores, dataset_tag)

if __name__ == '__main__':
    # Only the level figure is regenerated when run as a script;
    # prepare_hpo() and prepare_alpha() must be called explicitly.
    prepare_level()