import json
import numpy as np
import matplotlib.pyplot as plt

import argparse
import os

# Command-line options selecting which evaluation run to load and how to plot it.
# The results folder is resolved from data_dir, data_name, n_mc, n_ctx and n_tar.
argparser = argparse.ArgumentParser()
argparser.add_argument('--data_dir', type=str, default='eval_results/toy/eval_results/', help='Directory containing evaluation results.')
argparser.add_argument('--data_name', type=str, default='gp', help='Name of the dataset: gp, sawtooth, eeg, eeg_forecast or bav.')
argparser.add_argument('--n_mc', type=int, default=128, help='Number of Monte Carlo samples.')
argparser.add_argument('--n_ctx', type=int, default=128, help='Number of context points.')
argparser.add_argument('--n_tar', type=int, default=128, help='Number of target points.')
# Restrict to the modes the plotting code implements. Any other value used to be
# accepted and produced plots without predictions.
argparser.add_argument('--viz_mode', type=str, default='prediction', choices=['prediction', 'samples'], help='Visualization mode: prediction or samples.')
argparser.add_argument('--num_v_samples', type=int, default=4, help='Number of samples to visualize if viz_mode is samples.')
args = argparser.parse_args()

# Translate the short CLI dataset name into the prefix used by the result folders.
data_name = {
    'gp': 'gp_data_hp',
    'sawtooth': 'sawtooth_data',
    'eeg': 'eeg_data',
    'eeg_forecast': 'eeg_forecasting_data',
    'bav': 'bav_data',
}.get(args.data_name)
if data_name is None:
    raise ValueError('Invalid data_name. Choose from gp, sawtooth, eeg, eeg_forecast, bav.')

# Folder naming scheme: <prefix>_<n_mc>per_<n_ctx>con_<n_tar>tar
run_folder = "{}_{}per_{}con_{}tar".format(data_name, args.n_mc, args.n_ctx, args.n_tar)
data_dir = os.path.join(args.data_dir, run_folder)

print("Loading data from " + data_dir)



def _load_json(path):
    """Load and return the parsed JSON document stored at *path*."""
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)


def _first_samples(yhat, k):
    """Return the first *k* Monte Carlo samples of *yhat* as a 2-D array.

    Row ``s`` is sample ``s`` with all remaining axes flattened. Unlike
    ``yhat[:k].squeeze()``, the sample axis is kept when only one sample is
    selected, so callers can always iterate over rows. Like the ``squeeze``
    it replaces, this presumes every non-sample axis except time is a
    singleton (TODO confirm for result files with a batch size > 1).
    """
    samples = yhat[:k]
    return samples.reshape(samples.shape[0], int(np.prod(samples.shape[1:])))


if data_name == 'gp_data_hp' or data_name == 'sawtooth_data':
    # Print the evaluation-run folder name. It is taken from data_dir directly.
    # The old split('/')[3] broke for any --data_dir of a different depth or
    # with OS-specific separators.
    print(os.path.basename(os.path.normpath(data_dir)))

    # (row label, evaluation results) in plotting order. Keeping each label
    # next to its data stops the rows being mislabelled (row 0 is K4).
    runs = [
        ('K4', _load_json(os.path.join(data_dir, 'fast_buf_np_K4_M16/evaluation_metrics.json'))),
        ('K16', _load_json(os.path.join(data_dir, 'fast_buf_np_K16/evaluation_metrics.json'))),
    ]

    fig, axs = plt.subplots(2, 5, figsize=(30, 6))

    for i, (run_label, data) in enumerate(runs):
        for batch_idx in range(5):
            pred = data['predictions'][batch_idx]
            # [-1] keeps the last entry along the leading axis of each stored array.
            xc = np.array(pred['xc'])[-1].squeeze()  # [Tx]
            yc = np.array(pred['yc'])[-1].squeeze()  # [Tx]
            xt = np.array(pred['xt'])[-1].squeeze()  # [Tt]
            yt = np.array(pred['yt'])[-1].squeeze()  # [Tt]
            yhat = np.array(pred['predictions'])  # presumably [n_mc, 1, T, Dy] -- TODO confirm
            ax = axs[i][batch_idx]

            # Sort targets by x so line plots are drawn left to right.
            sort_indices = np.argsort(xt)
            xt = xt[sort_indices]
            yt = yt[sort_indices]

            if args.viz_mode == 'prediction':
                # Median and inter-quartile range over the Monte Carlo axis.
                yhat_median = np.median(yhat, axis=0).squeeze()[sort_indices]
                yhat_q1 = np.percentile(yhat, 25, axis=0).squeeze()[sort_indices]
                yhat_q3 = np.percentile(yhat, 75, axis=0).squeeze()[sort_indices]
                ax.plot(xt, yhat_median, color='red', label='Predicted Median', zorder=2)
                ax.fill_between(xt, yhat_q1, yhat_q3, color='red', alpha=0.3, label='Predicted IQR', zorder=2)

            elif args.viz_mode == 'samples':
                for s, yhat_s in enumerate(_first_samples(yhat, args.num_v_samples)):
                    # Label only the first sample so the legend has a single entry.
                    label_kw = {'label': 'joint samples'} if s == 0 else {}
                    ax.plot(xt, yhat_s[sort_indices], color='red', alpha=0.5, zorder=1, **label_kw)

            ax.scatter(xc, yc, color='black', label='Context Points', s=25, alpha=1, marker='X', zorder=3)
            ax.scatter(xt, yt, color='green', label='True Target Points', s=20, alpha=0.3, zorder=0)

        # Row label: which model variant this row shows.
        axs[i][0].set_ylabel(run_label, fontsize=16)

    # One figure-level legend, built from the first subplot (added once, not per row).
    handles, labels = axs[0][0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center', ncol=4)
    fig.suptitle(f"Model Predictions with Fast-buffer NP - n_ctx={args.n_ctx}, n_tar={args.n_tar}, n_mc={args.n_mc}")
    plt.xlabel('x')
    plt.ylabel('y')
    plt.show()

elif data_name == 'bav_data':
    dataK16 = _load_json(os.path.join(data_dir, 'fast_buf_np_K16/evaluation_metrics.json'))

    # Examples are read starting from index 5 of the stored predictions.
    example_offset = 5
    # With a single column, plt.subplots returns a 1-D array of Axes: index as axs[row].
    fig, axs = plt.subplots(4, 1, figsize=(25, 30))

    for row in range(4):
        pred = dataK16['predictions'][row + example_offset]
        xc = np.array(pred['xc']).squeeze()  # [Tx]
        yc = np.array(pred['yc']).squeeze()  # [Tx]
        xt = np.array(pred['xt']).squeeze()  # [Tt]
        yt = np.array(pred['yt']).squeeze()  # [Tt]
        yhat = np.array(pred['predictions'])  # presumably [n_mc, 1, T, Dy] -- TODO confirm
        ax = axs[row]

        # Points are plotted against their position in the sequence
        # (context first, then targets), not against their x values.
        xc_index = np.arange(len(xc))
        xt_index = np.arange(len(xt)) + len(xc)

        if args.viz_mode == 'prediction':
            # Median and inter-quartile range over the Monte Carlo axis.
            yhat_median = np.median(yhat, axis=0).squeeze()
            yhat_q1 = np.percentile(yhat, 25, axis=0).squeeze()
            yhat_q3 = np.percentile(yhat, 75, axis=0).squeeze()
            ax.plot(xt_index, yhat_median, color='red', label='Predicted Median', zorder=2)
            ax.fill_between(xt_index, yhat_q1, yhat_q3, color='red', alpha=0.3, label='Predicted IQR', zorder=2)

        elif args.viz_mode == 'samples':
            # Samples use the same index axis as the predictions and targets.
            # The previous version indexed axs as 2-D and crashed here.
            for s, yhat_s in enumerate(_first_samples(yhat, args.num_v_samples)):
                label_kw = {'label': 'joint samples'} if s == 0 else {}
                ax.plot(xt_index, yhat_s, color='red', alpha=0.5, zorder=1, **label_kw)

        # NOTE(review): context points are only drawn when xc is still
        # multi-dimensional after squeeze; confirm this condition is intended.
        if xc.ndim > 1:
            ax.scatter(xc_index, yc, color='black', label='Context Points', s=25, alpha=1, marker='X', zorder=3)
        ax.scatter(xt_index, yt, color='green', label='True Target Points', s=20, alpha=0.3, zorder=0)

    handles, labels = axs[0].get_legend_handles_labels()
    axs[0].set_ylabel('K16', fontsize=16)
    fig.legend(handles, labels, loc='lower center', ncol=4)
    fig.suptitle(f"Model Predictions with Fast-buffer NP - n_ctx={args.n_ctx}, n_tar={args.n_tar}, n_mc={args.n_mc}")
    fig.tight_layout()
    plt.xlabel('x')
    plt.ylabel('y')
    plt.savefig(f'bav_example_K16_{args.n_mc}_{args.n_tar}_{args.n_ctx}_{args.viz_mode}.png', dpi=300)
    plt.show()

elif data_name == 'eeg_data' or data_name == 'eeg_forecasting_data':
    K8dir = os.path.join(data_dir, 'fast_buf_np_KT/evaluation_metrics.json')
    # Evaluation-run folder name, derived from data_dir rather than a '/'-split index.
    print(os.path.basename(os.path.normpath(data_dir)))

    dataK8 = _load_json(K8dir)

    n_channels = 7  # grid rows: one per channel index of the last axis of yt / yhat
    n_examples = 5  # grid columns: one per stored prediction
    fig, axs = plt.subplots(n_channels, n_examples, figsize=(30, 20))

    for batch_idx in range(n_examples):
        pred = dataK8['predictions'][batch_idx]
        xc = np.array(pred['xc']).squeeze()  # [Tx]
        yc = np.array(pred['yc']).squeeze()  # [Tx, channels]
        xt = np.array(pred['xt']).squeeze()  # [Tt]
        yt = np.array(pred['yt']).squeeze()  # [Tt, channels]
        yhat = np.array(pred['predictions'])  # [n_mc, 1, T, Dy] (4-D, indexed below)

        # Sort targets (and predictions along the time axis) by x.
        sort_indices = np.argsort(xt)
        xt = xt[sort_indices]
        yt = yt[sort_indices]
        yhat = yhat[:, :, sort_indices, :]

        for ch in range(n_channels):
            ax = axs[ch, batch_idx]
            ax.set_ylabel(f'Ch {ch+1}', fontsize=12)
            yhat_ch = yhat[:, :, :, ch]

            if args.viz_mode == 'prediction':
                # Median and inter-quartile range over the Monte Carlo axis.
                yhat_median = np.median(yhat_ch, axis=0).squeeze()
                yhat_q1 = np.percentile(yhat_ch, 25, axis=0).squeeze()
                yhat_q3 = np.percentile(yhat_ch, 75, axis=0).squeeze()
                ax.plot(xt, yhat_median, color='red', label='Predicted Median', zorder=2)
                ax.fill_between(xt, yhat_q1, yhat_q3, color='red', alpha=0.3, label='Predicted IQR', zorder=2)

            elif args.viz_mode == 'samples':
                for s, yhat_s in enumerate(_first_samples(yhat_ch, args.num_v_samples)):
                    label_kw = {'label': 'joint samples'} if s == 0 else {}
                    ax.plot(xt, yhat_s, color='red', alpha=0.5, zorder=1, **label_kw)

            ax.plot(xt, yt[:, ch], color='blue', label='True Target Points', zorder=0)
            ax.scatter(xc, yc[:, ch], color='black', label='Context Points', s=25, alpha=1, marker='X', zorder=3)

    # Add the figure legend once, after all subplots are drawn.
    handles, labels = axs[0, 0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='lower center', ncol=4)
    fig.suptitle(f"Model Predictions with Fast-buffer NP - n_ctx={args.n_ctx}, n_tar={args.n_tar}, n_mc={args.n_mc}")

    plt.savefig('eeg_example.png', dpi=300)
    plt.show()