import os
import logging
import sys
import random
import numpy as np
import torch
import pandas as pd
import pathlib

from joint_opt.strategies.parallel_strategy import ParallelisationStrategy
from joint_opt.optimise_parallelism.bo import OOM_ADD_FACTOR


# Usage: python <script> CASE QUANTITY SEED
#   CASE     -- 'bert8' or 'qwen8' (selects the results folder)
#   QUANTITY -- 'time' (throughput surrogate) or 'mem' (memory surrogate)
#   SEED     -- which `seed_<SEED>` run of each algorithm provides the training data
if len(sys.argv) != 4:
    sys.exit(f'usage: {sys.argv[0]} CASE {{time,mem}} SEED')
_, case, quantity, seed = sys.argv

# Surrogate algorithms to plot; 'xgb' is only available for throughput ('time').
if quantity == 'time':
    algs = ['pm_et-m52_ucb', 'bo-m52_ucb', 'cost', 'xgb']
elif quantity == 'mem':
    algs = ['pm_et-m52_ucb', 'bo-m52_ucb', 'cost']
else:
    # Without this, `algs` would be unbound and the loop below would fail
    # with an unhelpful NameError.
    raise ValueError(f"quantity must be 'time' or 'mem', got {quantity!r}")

# Imported once at module scope (not per figure) so `plt` is bound for the
# legend code after this loop even if every surrogate is filtered out below.
import matplotlib.pyplot as plt
import pickle as pkl

# Maximum number of points in the evaluation set of each plot.
N_TEST = 400

# Per-case results folder and the shared x/y axis limits of the plots.
if case == 'bert8':
    results_root = 'bert_bz256_8gpus'
    lim = (0., 2.9) if quantity == 'time' else (-1., 0.)
elif case == 'qwen8':
    results_root = 'qwen2_bz64_8gpus'
    lim = (0.1, 0.5) if quantity == 'time' else (-1., 0.6)
else:
    raise ValueError(f"case must be 'bert8' or 'qwen8', got {case!r}")
# Model keys to plot. None is the key used for single-model surrogates
# ('cost', 'xgb'); other values are matched against the keys of a multi-model
# pickle. The key is also used as the number of leading training points
# highlighted in orange.
plot_k = {None, 5, 10, 20, 30, 40}


def _unseen_rows(df, n=None):
    """Select rows of `df` whose strategy has not been seen yet, and embed them.

    Only the first `n` rows are considered (all rows when `n` is None). The
    selected strategies are added to the module-level `seen_strats`, so later
    calls skip them. Returns the selected row indices and their inputs,
    log-transformed and divided by the module-level `x_upper_bound`.
    """
    n_rows = len(df) if n is None else min(n, len(df))
    strat_strs = df['strat_str']
    idxs_use = np.array(
        [i for i in range(n_rows) if strat_strs[i] not in seen_strats], dtype=int
    )
    # NOTE: eval() executes arbitrary code; the CSVs must come from trusted runs.
    embeddings = np.array([eval(emb) for emb in df['strat_emb']])[idxs_use]
    seen_strats.update(strat_strs[i] for i in idxs_use)
    x = ParallelisationStrategy.log_transform(torch.tensor(embeddings).double())
    return idxs_use, x / x_upper_bound[..., :]


def read_throughput_data(df, n=None):
    """Load not-yet-seen throughput samples from a `ran_queries.csv` frame.

    Returns `(x, y, y_var)` as double tensors: `y` is the `score_mean` column
    shaped (-1, 1) and `y_var` the squared `score_std`. Rows whose mean is NaN
    are dropped.
    """
    idxs_use, x = _unseen_rows(df, n)
    y = torch.tensor(df['score_mean'].to_numpy()[idxs_use]).reshape(-1, 1).double()
    y_std = torch.tensor(df['score_std'].to_numpy()[idxs_use]).reshape_as(y).double()
    keep = ~y.isnan().reshape(-1)
    return x[keep], y[keep], (y_std ** 2)[keep]


def read_mem_data(df, n=None):
    """Load not-yet-seen peak-memory samples from a `ran_queries.csv` frame.

    A `raw_mem` entry that contains 'inf', or that evaluates to values
    including NaN, is assigned `MAX_MEM * (1 + OOM_ADD_FACTOR)`. Otherwise the
    entry's maximum is used. Targets are rescaled to `(mem - MAX_MEM) / MAX_MEM`,
    so 0 means exactly `MAX_MEM`. Every point gets a fixed variance of 1e-6.
    """
    oom_mem = MAX_MEM * (1. + OOM_ADD_FACTOR)

    def peak(raw):
        if 'inf' in raw:
            return oom_mem
        values = np.array(eval(raw))  # NOTE: eval() -- trusted input only
        return oom_mem if np.isnan(values).any() else np.max(values)

    idxs_use, x = _unseen_rows(df, n)
    raw_mem = df['raw_mem'].to_numpy()  # positional access, independent of the index
    mem = np.array([peak(raw_mem[i]) for i in idxs_use], dtype=float)
    y = (torch.tensor(mem).reshape(-1, 1).double() - MAX_MEM) / MAX_MEM
    return x, y, torch.full_like(y, 1e-6)


def _model_pred_fn(model):
    """Wrap a callable `model(x) -> tensor` as `pred_fn(x) -> (mean, None)`."""
    def pred_fn(x):
        return model(x).reshape(-1).detach().cpu().numpy(), None
    return pred_fn


def _gp_pred_fn(gp):
    """Wrap a GP as `pred_fn(x) -> (posterior mean, posterior std)` NumPy arrays.

    Assumes `gp(x)` in eval mode returns an object with `.mean` and `.variance`
    tensors (presumably a GPyTorch model -- confirm).
    """
    def pred_fn(x):
        gp.eval()
        posterior = gp(x)
        mean = posterior.mean.reshape(-1).detach().cpu().numpy()
        std = posterior.variance.reshape(-1).detach().cpu().numpy() ** 0.5
        return mean, std
    return pred_fn


def _xgb_pred_fn(bst, y_max):
    """Wrap a trained booster as `pred_fn(x) -> (y_max * prediction, None)`."""
    import xgboost as xgb

    def pred_fn(x):
        # Hand DMatrix a NumPy array: torch tensors are not guaranteed to be
        # an accepted input type.
        if torch.is_tensor(x):
            x = x.detach().cpu().numpy()
        return (y_max * bst.predict(xgb.DMatrix(x))).reshape(-1), None
    return pred_fn


for alg in algs:
    print('=' * 20, '  ', alg, '  ', '=' * 20)

    prefix = f'{results_root}-{quantity}/{alg}'
    data_folder = f'./{results_root}/results/{alg}/seed_{seed}'
    sample_data = f'{data_folder}/ran_queries.csv'

    # NOTE: unpickling executes arbitrary code; only load trusted result files.
    with open(f'{data_folder}/learned_surrogate.pkl', 'rb') as fh:
        models = pkl.load(fh)

    # Maps plot-file suffix -> (model key k, pred_fn). pred_fn(x) returns
    # (predicted values, predicted stds or None) as NumPy arrays.
    all_pred_fns = {}
    if alg == 'cost':
        x_upper_bound = models['x_upper_bound']
        MAX_MEM = models['gpu_max_GB']
        model = models['throughput'] if quantity == 'time' else models['mem']
        all_pred_fns['main'] = (None, _model_pred_fn(model))
    elif alg.startswith(('bo', 'pm')):
        # One GP per key. The normalisation constants left in scope are those
        # of the last entry.
        for k, d in models.items():
            x_upper_bound = d['x_upper_bound']
            MAX_MEM = d['gpu_max_GB']
            gp = d['throughput'] if quantity == 'time' else d['mem']
            all_pred_fns[f's{k}'] = (k, _gp_pred_fn(gp))
    elif alg == 'xgb':
        # Explicit check instead of `assert`, which is stripped under -O.
        if quantity != 'time':
            raise ValueError("the 'xgb' surrogate only supports quantity='time'")
        import xgboost as xgb
        x_upper_bound = models['x_upper_bound']
        dtrain = xgb.DMatrix(models['x_train'], models['y_train'])
        bst = xgb.train(models['xgb_params'], dtrain)
        all_pred_fns['main'] = (None, _xgb_pred_fn(bst, models['y_max']))
    else:
        raise ValueError(f'unknown surrogate algorithm {alg!r}')

    # Strategies already placed in the training/evaluation sets. The read_*
    # helpers skip and extend this, so strategies from earlier files are not
    # repeated in later ones.
    seen_strats = set()
    read_data = read_throughput_data if quantity == 'time' else read_mem_data

    train_x, train_y, train_y_var = read_data(pd.read_csv(sample_data))

    # Evaluation set: the training points, then unseen strategies from other
    # runs' query logs (at most 4 pm_et logs, listed first), truncated to
    # N_TEST points. Paths are sorted because rglob order depends on the
    # filesystem. Paths with a component starting with '_' are skipped.
    results_dir = pathlib.Path(f'./{results_root}/results')
    csv_paths = sorted(
        str(path)
        for path in results_dir.rglob('ran_queries.csv')
        if path.exists() and not any(part.startswith('_') for part in path.parts)
    )
    pm_paths = [p for p in csv_paths if 'pm_et' in p][:4]
    other_paths = [p for p in csv_paths if 'pm_et' not in p]

    test_x, test_y, test_y_var = train_x, train_y, train_y_var
    for csv_path in pm_paths + other_paths:
        print(csv_path)
        xs, ys, yvar = read_data(pd.read_csv(csv_path))
        test_x = torch.cat([test_x, xs], dim=0)
        test_y = torch.cat([test_y, ys], dim=0)
        test_y_var = torch.cat([test_y_var, yvar], dim=0)
        if test_x.shape[0] > N_TEST:
            break
    test_x, test_y, test_y_var = test_x[:N_TEST], test_y[:N_TEST], test_y_var[:N_TEST]

    print(f'{train_x.shape=} {train_y.shape=} {test_x.shape=} {test_y.shape=}')

    for case_suffix, (k, pred_fn) in all_pred_fns.items():
        if k not in plot_k:
            continue

        plt.figure(figsize=(2.3, 2.3), dpi=200)

        # Blue: every evaluation point. The x error bar is the measurement std;
        # the y error bar is the predictive std when the surrogate provides one.
        y_test, yerr_test = pred_fn(test_x)
        plt.errorbar(
            test_y.reshape(-1).detach().cpu().numpy(),
            y_test,
            xerr=test_y_var.reshape(-1).detach().cpu().numpy() ** 0.5,
            yerr=yerr_test,
            fmt='.',
            color='blue',
            alpha=0.3,
            markersize=3,
            capsize=0.5,
            elinewidth=0.5,
            markeredgewidth=0.5,
        )

        # Orange: the first k training points (all of them when k is None).
        y_train, _ = pred_fn(train_x[:k])
        plt.errorbar(
            train_y[:k].reshape(-1).detach().cpu().numpy(),
            y_train,
            fmt='o',
            color='orange',
            alpha=1,
            markersize=3,
            zorder=6,
        )

        if quantity == 'time':
            # Mark the point the surrogate predicts best (red triangle) and the
            # truly best point (magenta star), each with a dotted drop line.
            best_pred = int(np.argmax(y_test))
            best_actual = int(torch.argmax(test_y))
            for idx, marker, color, zorder in (
                (best_pred, '^', 'red', 11),
                (best_actual, '*', 'magenta', 10),
            ):
                actual, predicted = float(test_y[idx]), float(y_test[idx])
                plt.plot(
                    [actual], [predicted], marker,
                    color=color, alpha=1, markersize=6, zorder=zorder,
                )
                plt.plot([actual, actual], [-10, predicted], ':', color=color, alpha=0.3)

        os.makedirs(f'./graphs_surrogates/{prefix}', exist_ok=True)
        fname = f'./graphs_surrogates/{prefix}/{case_suffix}.pdf'
        print(fname)
        if quantity == 'time':
            # Raw strings: '\m' is an invalid escape (SyntaxWarning on 3.12+).
            plt.xlabel(r'Actual $\mathcal{R}$ ($s^{-1}$)')
            if alg != 'xgb':
                plt.ylabel(r'Pred. $\mathcal{R}$ ($s^{-1}$)')
                plt.plot([0., 5.], [0., 5.], color='gray', linestyle='--', alpha=0.3)
            else:
                plt.ylabel('Relative ranking')
        else:
            plt.xlabel('Fraction mem.')
            plt.ylabel('Pred. fraction mem.')
            plt.axhline(0., color='gray', linestyle='--', alpha=0.3)
            plt.axvline(0., color='gray', linestyle='--', alpha=0.3)
        plt.xlim(lim)
        if alg != 'xgb':
            plt.ylim(lim)
        else:
            # xgb predictions are on their own scale, so fit y to them.
            plt.ylim((y_test.min() - 0.1, y_test.max() + 0.1))
        plt.tight_layout()
        plt.savefig(fname)
        plt.close()
        
# Stand-alone legend shared by all surrogate plots: draw dummy artists on a
# throw-away figure, then render only their legend into its own PDF.
# Imported here so this section does not depend on an earlier figure having
# been drawn.
import matplotlib.pyplot as plt

os.makedirs('./graphs_surrogates', exist_ok=True)

fig, ax = plt.subplots()
for marker, color, label in (
    ('.', 'blue', 'Sample PCs'),
    ('o', 'orange', 'Trialed PCs'),
    ('^', 'red', 'Predicted optimal PC'),
    ('*', 'magenta', 'Actual optimal PC'),
):
    ax.plot([1., 2.], [1., 2.], marker, color=color, label=label)
handles, labels = ax.get_legend_handles_labels()
plt.close(fig)

fig_leg = plt.figure(figsize=(8, 0.7), dpi=150)
ax_leg = fig_leg.add_subplot(111)
legend = ax_leg.legend(handles, labels, loc='center', ncol=10)
ax_leg.axis('off')
fig_leg.canvas.draw()  # the legend's window extent is only known after a draw
# Convert the legend's extent from display pixels to inches, so only the
# legend is saved.
bbox = legend.get_window_extent().transformed(fig_leg.dpi_scale_trans.inverted())
fig_leg.savefig('./graphs_surrogates/legend.pdf', bbox_inches=bbox, pad_inches=0)
plt.close(fig_leg)
