import matplotlib as mpl
mpl.use('Agg')
mpl.rcParams.update({'font.size': 22})
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import scipy
from jax import random
import importlib
from model import Y2RegressionInterpolated
# Plot, for each m (= d) and model, the CRPS of each inference method relative
# to the mean of the best PVI seeds, as a function of the mixing weight alpha.
import os

ns = [1000]
seed = [0, 1, 2, 3, 4]
alphas = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
models = ['Y2Regression']
g = 20

# Number of best (lowest-score) seeds kept per method and alpha.
TOP_K = 3

# Method name -> (run directory template relative to the experiment directory,
# legend label). A label of None means the method is loaded and printed but not
# plotted (NLE plotting is disabled).
_METHOD_RUNS = {
    'PVI': ('PVI_{s}_100', 'PVI-CRPS'),
    'NPE': ('Torch_{s}_1_1', 'NPE'),
    'NLE': ('Torch3_{s}_1_1', None),
    'NRE': ('Torch4_{s}_1_1', 'NRE'),
}


def _read_eval_score(path):
    """Return the first number on the first line of the eval file at *path*.

    Returns None when the file is missing or unreadable, or when its first
    line does not start with a number. Incomplete runs are therefore skipped
    instead of aborting the whole script.
    """
    try:
        with open(path, 'r') as f:
            return float(f.readline().split()[0])
    except (OSError, ValueError, IndexError):
        return None


def _collect_scores(model, alpha, n, g, d):
    """Return {method: [score of every seed whose eval file could be read]}."""
    run_dir = f'result/{model}_{alpha}_{n}_{g}_{d}'
    scores = {}
    for method, (template, _) in _METHOD_RUNS.items():
        values = (
            _read_eval_score(f'{run_dir}/{template.format(s=s)}/eval')
            for s in seed
        )
        scores[method] = [v for v in values if v is not None]
    return scores


def _best(values, k=TOP_K):
    """Return the k smallest values in ascending order (fewer if unavailable)."""
    return np.sort(np.asarray(values, dtype=float))[:k]


def _plot_ratios(frame, model, g, d):
    """Plot ratio vs. alpha per method and save to figure2/<model>_pll_alpha_<d>_<g>.pdf."""
    # Only the d == 1 panel carries the legend and y-axis, hence it is wider.
    show_legend = d == 1
    fig = plt.figure(figsize=(6, 4) if show_legend else (5, 4))
    ax = sns.lineplot(data=frame, x='alpha', y='value', hue='type',
                      err_style='bars', legend=show_legend)
    ax.set_title(f'g={g} m={d}')
    ax.set_ylim([0.9, 1.5])
    if show_legend:
        ax.set_ylabel('Ratio')
        ax.get_legend().set_title(None)
    else:
        ax.set_ylabel('')
        ax.set_yticks([])
    ax.set_xlabel('$\\alpha$')
    fig.tight_layout()
    os.makedirs('figure2', exist_ok=True)
    fig.savefig(f'figure2/{model}_pll_alpha_{d}_{g}.pdf')
    # Close (not just clear) so figures do not accumulate across iterations.
    plt.close(fig)


for d in [1, 10]:
    for model in models:
        for n in ns:
            data = []
            for alpha in alphas:
                scores = _collect_scores(model, alpha, n, g, d)
                print(scores['PVI'], scores['NPE'], scores['NLE'], scores['NRE'])
                if not scores['PVI']:
                    # Every ratio is normalized by the PVI baseline; without
                    # it the ratios are undefined, so this alpha is skipped.
                    print(f'warning: no PVI results for {model} '
                          f'alpha={alpha} n={n} d={d}; skipping')
                    continue
                target = np.mean(_best(scores['PVI']))
                for method, (_, label) in _METHOD_RUNS.items():
                    if label is None:
                        continue
                    for value in _best(scores[method]):
                        data.append({'N': n, 'alpha': alpha,
                                     'value': value / target,
                                     'type': label, 'model': model})

            print(data)
            if not data:
                print(f'warning: nothing to plot for {model} n={n} d={d}')
                continue
            _plot_ratios(pd.DataFrame(data), model, g, d)

