import lib

import numpy as np
import time

from tqdm import tqdm

import generate_dWPE

import itertools
import os

import sys
from pathlib import Path

# ---------------------------------------------------------------------------
# Sweep parameters: problem sizes N, run lengths T, repetitions per setting,
# solvers to compare and bias values used in the result-file names.
# ---------------------------------------------------------------------------
Nl = [60, 100, 120, 140, 160]
Tl = [50, 100, 300, 500, 1000, 2000, 3000, 4000]

R = 20000  # NOTE(review): not referenced in the visible code

rep = 3  # number of repetitions stored per (N, T) result file

# 'SA' is intentionally left out of the compared solvers.
solvertypel = ['CACm', 'MHCACm', 'AIM', 'CAC', 'PT']

biasl = [0.0, 6.0, 8.0, 10.0, 12.0]

biasshow = 10  # bias value whose curves are drawn in panel (c)
Nshow = 100    # problem size used for panels (a), (b) and (c)

overwrite = 1  # NOTE(review): not referenced in the visible code

# Per-solver fMH value; it is part of every result-file name.
fMHv = {
    'CACm': 0.0,
    'AIM': 0.0,
    'CAC': 0.0,
    'MHCACm': 0.1,
    'MHCAC': 0.1,
    'SA': 0.0,
    'PT': 0.0,
}

# alpha as text; it is parsed with float() and printed with %0.2f in file names.
alphatxt = '0.80'

# Folder holding the per-run "_info.txt" files and the generated figures/table.
figfolder = "fig_run_allGS"

if os.path.exists(figfolder):
    print(f"Folder already exists: {figfolder}")
else:
    os.makedirs(figfolder)
    print(f"Folder created: {figfolder}")

#############################################################
# Load the saved statistics for every bias, solver and (N, T) pair.
#
# aares[ibias][isolver] is a list with one entry per (N, T) pair, in
# itertools.product(Nl, Tl) order. Each entry holds the six rows unpacked
# from the "_info.txt" file (p0l, TTSl, p0maxl, TTSminl, p0minl, TTSmaxl);
# presumably `rep` values per row, since the later reshape assumes so.
# A missing file contributes six rows of `rep` NaNs so the shapes still line up.

aares2 = []

for bias in biasl:

    ares = []

    for solvertype in solvertypel:

        fMH = fMHv[solvertype]
        res = []

        for N, T in tqdm(itertools.product(Nl, Tl), desc="Iterating combinations"):

            savefile = figfolder + "/%s_N_%d_T_%d_alpha_%0.2f_bias_%d_fmH_%0.2f_info.txt" % (solvertype, N, T, float(alphatxt), bias, fMH)

            if not Path(savefile).exists():
                print(f"{savefile} is missing!")
                void = (np.ones(rep) * np.nan).tolist()
                res.append([void, void, void, void, void, void])
            else:
                p0l, TTSl, p0maxl, TTSminl, p0minl, TTSmaxl = np.loadtxt(savefile)
                res.append([p0l, TTSl, p0maxl, TTSminl, p0minl, TTSmaxl])

        ares.append(res)

    aares2.append(ares)

aares = aares2
        
#############################################################
# Figure setup shared by all three panels.

# Problem sizes whose curves are drawn in panel (c).
Nlshow = [Nshow]

# Positions of the highlighted bias value and problem size in their sweep
# lists (indexing [0] raises IndexError if the value is absent).
ibiasshow = np.flatnonzero(np.array(biasl) == biasshow)[0]
iN = np.flatnonzero(np.array(Nl) == Nshow)[0]

import matplotlib.pyplot as plt
plt.style.use('plot_style.txt')

plt.figure(figsize=(12, 4))

# One colour / marker / line style per entry of solvertypel, in the same order.
colors = list('rgbmy')
markers = list('dsox<')
lines = ['-', '--', ':', '-.', '--']

# Common y-axis range for every panel.
ymin = 2 * 10**2
ymax = 10**7

# Accumulators filled per bias: [solver][N] -> [TTS, TTSmin, TTSmax] over Tl,
# and the matching error-bar half-widths.
aaaTTSm = []
aaaTTSv = []

# Collapse the repetitions of every (bias, solver, N, T) setting into a mean and
# an error-bar half-width, and draw the panel (c) curves for the highlighted
# bias (ibiasshow) and problem sizes (Nlshow).
#
# Result layout: aaaTTSm[ibias][isolver][iN] == [TTS, TTSmin, TTSmax], each an
# array over Tl; aaaTTSv has the same layout holding the matching half-widths.
for ibias, bias in enumerate(biasl):
    
    ares = aares[ibias]
    
    aaTTSm = []
    aaTTSv = []
   
    # zip stops at the shortest input, so colors/lines/markers need at least
    # len(solvertypel) entries or solvers are silently dropped.
    for res,solvertype,color,line,marker in zip(ares,solvertypel,colors,lines,markers):
        
        # Axes: [N index, T index, quantity, repetition]. The quantity axis
        # follows the file row order p0l, TTSl, p0maxl, TTSminl, p0minl, TTSmaxl,
        # so 1 = TTS, 3 = TTSmin, 5 = TTSmax below.
        res = np.reshape(res,[len(Nl),len(Tl),6,rep])
        
        aTTSm = []
        aTTSv = []

        for i, N in enumerate(Nl):
            
            # Mean over repetitions (NaNs ignored): one value per T.
            TTS = np.nanmean(res[i,:,1,:],1)
            TTSmin = np.nanmean(res[i,:,3,:],1)
            TTSmax = np.nanmean(res[i,:,5,:],1)
            
            # 1.96 * std / sqrt(rep): presumably a normal-approximation 95%
            # half-width. NOTE(review): np.nanstd uses ddof=0 here, and the
            # divisor is rep even when some repetitions are NaN; confirm intended.
            TTSv = np.nanstd(res[i,:,1,:],1)*1.96/np.sqrt(rep)
            TTSminv = np.nanstd(res[i,:,3,:],1)*1.96/np.sqrt(rep)
            TTSmaxv = np.nanstd(res[i,:,5,:],1)*1.96/np.sqrt(rep)
            
            aTTSm.append([TTS,TTSmin,TTSmax])
            aTTSv.append([TTSv,TTSminv,TTSmaxv])
            
            if ibias==ibiasshow and N in Nlshow: #bias to show
                    
      
                # Panel (c): mean TTSmin versus T with a shaded error band.
                plt.subplot(1,3,3)
                
                plt.plot(Tl,TTSmin,label='%s, N=%d' %(solvertype,N),linestyle=line,color=color,marker=marker)
                plt.fill_between(Tl,TTSmin-TTSminv,TTSmin+TTSminv,color=color, interpolate=True, alpha=0.3)

            
        aaTTSm.append(aTTSm)
        aaTTSv.append(aTTSv)
        
    aaaTTSm.append(aaTTSm)
    aaaTTSv.append(aaTTSv)
    
# Finish panel (c): mean TTSmin versus T for the highlighted bias and N.
ax = plt.subplot(1, 3, 3)

ax.legend(ncols=2)

ax.set_xlabel('T')
ax.set_ylabel(r'$\mathrm{TTS}_{\mathrm{min}}$ (easy ground-state)')

ax.set_yscale('symlog')

ax.set_xlim(min(Tl), max(Tl))
ax.set_ylim(ymin, ymax)

# Draw the full box around the axes.
for side in ("top", "right"):
    ax.spines[side].set_visible(True)

ax.grid(True)

ax.set_title('c')

def convert_res(aaaTTSm, aaaTTSv, iN):
    """For every (bias, solver), pick the smallest mean TTS over the T sweep.

    Parameters
    ----------
    aaaTTSm, aaaTTSv :
        Nested lists (or arrays) indexed ``[bias][solver][N][q][T]`` holding
        the mean TTS and its error-bar half-width, where ``q`` is
        0 = TTS (avg), 1 = TTSmin, 2 = TTSmax. The inputs are not modified.
    iN :
        Index into the N axis to evaluate.

    Returns
    -------
    matm_avg, matv_avg, matm_min, matv_min, matm_max, matv_max :
        Float arrays of shape ``(n_bias, n_solver)``. ``matm_*`` is the
        minimum of the mean over T (negative means are treated as missing),
        and ``matv_*`` is the error bar at that same T. Entries whose whole
        T sweep is missing are NaN.
    """
    # Convert once (np.array copies, so the caller's data is left untouched);
    # float dtype so NaN can be stored even if the input is integer.
    means = np.array(aaaTTSm, dtype=float)
    errs = np.array(aaaTTSv, dtype=float)
    # Negative TTS values are not meaningful here; treat them as missing.
    means[means < 0] = np.nan

    n_bias, n_solver = means.shape[:2]

    def _best_over_T(q):
        """Return (min mean over T, error at that T) arrays for quantity q."""
        best_m = np.full((n_bias, n_solver), np.nan)
        best_v = np.full((n_bias, n_solver), np.nan)
        for i in range(n_bias):
            for j in range(n_solver):
                curve = means[i, j, iN, q, :]
                if np.isnan(curve).all():
                    continue  # nanargmin raises ValueError on an all-NaN slice
                k = np.nanargmin(curve)
                best_m[i, j] = curve[k]
                best_v[i, j] = errs[i, j, iN, q, k]
        return best_m, best_v

    matm_avg, matv_avg = _best_over_T(0)
    matm_min, matv_min = _best_over_T(1)
    matm_max, matv_max = _best_over_T(2)

    return matm_avg, matv_avg, matm_min, matv_min, matm_max, matv_max

matm_avg, matv_avg, matm_min, matv_min, matm_max, matv_max = convert_res(aaaTTSm,aaaTTSv,iN)


def _format_bias_panel(ylabel, title, **legend_kwargs):
    """Apply the axis formatting shared by panels (a) and (b) to the current axes.

    Sets the bias x-label, the given y-label, a symlog y-scale, a two-column
    legend (extra keyword arguments are forwarded to ``plt.legend``), the
    common y/x limits, a full axes box, the panel title and a grid.
    Returns the formatted axes.
    """
    plt.xlabel(r'bias $b$')
    plt.ylabel(ylabel)
    plt.yscale('symlog')
    plt.legend(ncols=2, **legend_kwargs)
    plt.ylim(ymin, ymax)
    plt.xlim(np.min(biasl), np.max(biasl))
    ax = plt.gca()
    ax.spines["top"].set_visible(True)
    ax.spines["right"].set_visible(True)
    ax.set_title(title)
    plt.grid(True)
    return ax


# Panel (a): for N = Nshow, the smallest mean TTSmin over T at each bias.
plt.subplot(1, 3, 1)

for count, (solvertype, color, line) in enumerate(zip(solvertypel, colors, lines)):
    plt.plot(biasl, matm_min[:, count], marker=markers[count], color=color,
             linestyle=line, label=solvertype)
    # The AIM band omits the last bias point because its error bar there is
    # confusing; every other solver gets the full band.
    band = slice(0, -1) if solvertype == 'AIM' else slice(None)
    plt.fill_between(biasl[band],
                     matm_min[band, count] - matv_min[band, count],
                     matm_min[band, count] + matv_min[band, count],
                     color=color, interpolate=True, alpha=0.3)

ax = _format_bias_panel(r'$\mathrm{TTS}_{\mathrm{min}}*$ (easy ground-state)', 'a')

# Panel (b): same as (a) for the mean TTSmax.
plt.subplot(1, 3, 2)

for count, (solvertype, color, line) in enumerate(zip(solvertypel, colors, lines)):
    plt.plot(biasl, matm_max[:, count], marker=markers[count], color=color,
             linestyle=line, label=solvertype)
    plt.fill_between(biasl,
                     matm_max[:, count] - matv_max[:, count],
                     matm_max[:, count] + matv_max[:, count],
                     color=color, interpolate=True, alpha=0.3)

ax = _format_bias_panel(r'$\mathrm{TTS}_{\mathrm{max}}*$ (hard ground-state)', 'b',
                        loc='lower left')

plt.tight_layout()

# Save the three-panel figure in every output format.
for ext in ('png', 'eps', 'pdf'):
    plt.savefig(figfolder + '/allGS_wishart.' + ext)

# export results of table 2
if True:
    # For every solver, write one LaTeX table row with the best (over T) mean
    # TTS, TTSmin and TTSmax for N = 100 and N = 140, in units of 10^3,
    # at bias biasl[ib]. Solvers after the first block are separated by \midrule.

    ib = 4  # index into biasl, i.e. bias = 12.0

    Nshow = 140  # left at the last table N, as before

    table_Nl = [100, 140]

    # convert_res only depends on the N index here, so evaluate it once per N
    # rather than once per (solver, column, N).
    best = {}
    for N in table_Nl:
        iN = np.where(np.array(Nl) == N)[0][0]
        matm_avg, matv_avg, matm_min, matv_min, matm_max, matv_max = convert_res(aaaTTSm, aaaTTSv, iN)
        best[N] = {'avg': matm_avg, 'min': matm_min, 'max': matm_max}

    ptypel = ['avg', 'min', 'max']

    # Per-column cell format; kept byte-identical to the previous output.
    cell_fmt = {'avg': ' %0.2f & ', 'min': '%0.2f&', 'max': '%0.2f &'}

    rows = []
    for i, solvertype in enumerate(solvertypel):
        row = ''
        if solvertype == 'MHCACm':
            # Escaped backslash: '\m' is an invalid escape sequence.
            row += '\\midrule \n'
        row += solvertype + ' & '
        for ptype in ptypel:
            for N in table_Nl:
                value = best[N][ptype][ib, i]
                if np.isnan(value):
                    row += 'NA &'
                else:
                    row += cell_fmt[ptype] % (value / 1000)
        rows.append(row + '\n')
    txt = ''.join(rows)

    with open(figfolder + '/table2.txt', 'w', encoding='utf-8') as file:
        file.write(txt)