import numpy as np

from tqdm import tqdm
import time

import itertools
import lib

import os

####################################################
# DATA: instance sizes, run lengths and dataset selection

# Problem sizes N that are loaded and plotted (x-axis of each panel).
Nl = [60, 100, 120, 140]
# Run lengths T; for each N the plot keeps the minimum mean TTS over these.
T = [50, 100, 300, 500, 1000, 2000, 3000, 4000]

# alpha value exactly as it is spelled inside the result file names.
alphatxt = '0.80'

# Datasets to plot, one subplot each: unbiased first, then biased instances.
to_plot = ['unbias', 'bias']

####################################################
# SOLVER

# Solvers compared in each panel (an earlier variant used 'SA' in the last slot).
solvertypel = ['MHCACm', 'CACm', 'AIM', 'CAC', 'PT']

####################################################

# fMH value per solver; it is embedded in the result folder and file names.
# Only the MH variants use a nonzero value.
fMHv = {
    'CACm': 0.0,
    'AIM': 0.0,
    'CAC': 0.0,
    'MHCACm': 0.1,
    'MHCAC': 0.1,
    'SA': 0.0,
    'PT': 0.0,
}
    
import matplotlib.pyplot as plt
plt.style.use('plot_style.txt')

# Output directory for the generated figure.
figfolder = "fig_tune_wishart"

# Create the output directory on first use and report which case applied.
if os.path.exists(figfolder):
    print(f"Folder already exists: {figfolder}")
else:
    os.makedirs(figfolder)
    print(f"Folder created: {figfolder}")

# One wide figure; the main loop fills it with two side-by-side subplots.
plt.figure(figsize=(12, 4))

# Per-solver plot styles, paired with solvertypel by position.
colors = ['r', 'g', 'b', 'm', 'y']
markers = ['d', 's', 'o', 'x', '<']
lines = ['-', '--', ':', '-.', '--']

####################################################
# Load results, reduce them to the best TTS per N, and plot one panel per dataset.


def _dataset_params(datatype):
    """Return the instance-generation parameters for a dataset label.

    Args:
        datatype: dataset label, one of the entries of ``to_plot``.

    Returns:
        A dict with the WPE parameters. For 'unbias' it includes
        'bias' = 0.0, and for 'bias' it includes 'bias' = 12.0. Any other
        label yields only {'datatype': datatype}.
    """
    data = {'datatype': datatype}
    if datatype == 'unbias':
        data['D_WPE'] = 1
        data['R_WPE'] = -1
        data['bias'] = 0.0
    elif datatype == 'bias':
        # first GS is close to ferromagnetic
        data['D1_WPE'] = 1
        data['R1_WPE'] = 3
        # second GS is random
        data['D2_WPE'] = 1
        data['R2_WPE'] = -1
        data['bias'] = 12.0
        data['D_WPE'] = data['D1_WPE'] + data['D2_WPE']
    return data


def _load_tts(solvertype, N, bias, fMH):
    """Read one solver's results for size ``N`` at every run length in ``T``.

    Returns:
        A tuple ``(p0mat, TTSmat, TTSmatv)`` of arrays of length len(T).
        They hold p0, the mean TTS, and its 1.96*std/sqrt(n) half-width
        (normal-approximation 95% interval on the mean). Entries stay 0
        for run lengths whose file reports no ``opt_params``.
    """
    p0mat = np.zeros(len(T))
    TTSmat = np.zeros(len(T))
    TTSmatv = np.zeros(len(T))

    # The folder depends only on solver, bias and fMH, so build it once.
    folder_name_ = f"{solvertype}_{bias}_{fMH}"
    for iT, cT in enumerate(T):
        file_name = f"wishart_{N}_{alphatxt}_{bias}_{cT}_{fMH}.txt"
        p0, opt_params, pvec = lib.read_file(folder_name_, file_name)

        if solvertype == 'AIM' and N == 100:
            print('AIM', bias, p0)

        # pvec presumably holds bootstrap resamples of the success
        # probability (per the original '#bootstrap' note) -- TODO confirm in lib.
        if len(opt_params) > 0:
            # TTS = cT * ln(1 - 0.99) / ln(1 - p): run length times the number
            # of independent runs needed for a 99% chance of success.
            # NOTE(review): p == 1 gives TTS 0 (later treated as missing data)
            # and p == 0 gives inf. A common convention instead uses TTS = cT
            # for p >= 0.99. It is left unchanged to keep the plotted values.
            tts_samples = np.log(1 - 0.99) / np.log(1 - np.array(pvec)) * cT
            p0mat[iT] = p0
            TTSmat[iT] = np.mean(tts_samples)
            TTSmatv[iT] = np.std(tts_samples) * 1.96 / np.sqrt(len(pvec))

            # Report the current run length, not the whole T list.
            print("T:", cT, "p0:", p0, "TTS:", TTSmat[iT])

    return p0mat, TTSmat, TTSmatv


def _best_tts(res):
    """Reduce per-run-length results to the best (minimum) mean TTS for each N.

    Non-positive TTS entries are treated as missing and overwritten with NaN
    in place, in both the TTS array and its error array.

    Returns:
        A tuple ``(best, best_err)`` of arrays with one entry per element of
        ``res``. An entry is NaN when no run length had data.
    """
    best, best_err = [], []
    for _p0mat, TTSmat, TTSmatv in res:
        # Build the mask before writing NaNs. A mask recomputed after the
        # first assignment would be all False (NaN <= 0 is False), so the
        # error array would never be masked.
        missing = TTSmat <= 0
        TTSmat[missing] = np.nan
        TTSmatv[missing] = np.nan

        if np.all(np.isnan(TTSmat)):
            best.append(np.nan)
            best_err.append(np.nan)
        else:
            imin = np.nanargmin(TTSmat)
            best.append(TTSmat[imin])
            best_err.append(TTSmatv[imin])
    return np.array(best), np.array(best_err)


def _sa_reference(bias):
    """Return ``(N_SA, TTS_SA)`` simulated-annealing reference points (in sweeps)."""
    if bias == 0.0:
        # Hard-coded unbiased results: success probabilities at 1000 sweeps.
        # Recorded settings not used in this computation:
        # initial_temp_list = [30, 50, 50], final_temp_list = [10, 30, 30].
        Ps_list = [0.5735, 0.07949999999999999, 0.00475]
        N_SA = [50, 100, 150]
        T_list = [1000, 1000, 1000]
        TTS_SA = np.log(1 - 0.99) / np.log(1 - np.array(Ps_list)) * np.array(T_list)
        return N_SA, TTS_SA

    # Biased case: column 0 is N and column 1 is TTS, in a precomputed file.
    SAres = np.loadtxt(f'SA_bias_TTS/bias_{bias}.txt')
    return SAres[:, 0], SAres[:, 1]


for ib, biastxt in enumerate(to_plot):

    data = _dataset_params(biastxt)
    bias = data.get('bias', 0.0)

    # res_sol[s][n] = (p0mat, TTSmat, TTSmatv) for solver solvertypel[s] and size Nl[n].
    # fMH stays a module-level name: the savefig call below reads its last value.
    res_sol = []
    for solvertype in solvertypel:
        fMH = fMHv[solvertype]
        res_sol.append([_load_tts(solvertype, N, bias, fMH) for N in Nl])

    ax = plt.subplot(1, 2, ib + 1)

    N_SA, TTS_SA = _sa_reference(bias)

    # Only affects figures created after this point; the 12x4 figure is unchanged.
    plt.rcParams['figure.figsize'] = [6, 4]  # width: 6 inches, height: 4 inches

    for res, solvertype, color, line, marker in zip(res_sol, solvertypel, colors, lines, markers):
        TTS_CAC, TTS_CACv = _best_tts(res)

        plt.plot(Nl, TTS_CAC, linestyle=line, marker=marker, color=color, label='%s' % solvertype)
        # Shaded band: +/- the error half-width at the selected run length.
        plt.fill_between(Nl, TTS_CAC - TTS_CACv, TTS_CAC + TTS_CACv, color=color, interpolate=True, alpha=0.3)

    plt.plot(N_SA, TTS_SA, 'd-k', label='SA (in sweeps)')

    plt.yscale('symlog')
    plt.xlabel(r'N')
    plt.ylabel('TTS (any ground-state)')
    plt.legend(ncols=2)

    plt.xlim(np.min(Nl), np.max(Nl))
    plt.ylim(10**2, 10**7)

    ax.spines["top"].set_visible(True)
    ax.spines["right"].set_visible(True)

    plt.grid(True)

    if ib == 0:
        plt.title('a (unbiased Wishart)')
    if ib == 1:
        plt.title('b (biased Wishart, b=%d)' % bias)

# NOTE(review): fMH here is the value of the LAST solver in solvertypel
# (e.g. 'PT' -> 0.0 with the current list), not the MH solvers' value.
# It is kept as-is so the output path does not change.
plt.savefig(figfolder + f'/optimization_T_wishart_fMH_{fMH}.pdf')



