import PT_Ising
import lib

import numpy as np
import time

from tqdm import tqdm

import generate_dWPE

import itertools
import os

import sys
from pathlib import Path

# ---- experiment parameters ------------------------------------------------
# Scan grid: the main loop visits every (N, T) pair in Nl x Tl.
Nl = [60, 100, 120, 140, 160]
Tl = [50, 100, 300, 500, 1000, 2000, 3000, 4000]

# Passed to solver.run_pt as its last argument and used to normalise the
# ground-state hit counts into probabilities.
R = 20_000

# Number of freshly generated instances per (N, T) configuration.
rep = 3

# Solvers to benchmark; only 'PT' is handled by the main loop.
solvertypel = ['PT']

# Bias values scanned; 0.0 selects the unbiased planted-instance generator.
biasl = [0.0, 6.0, 8.0, 10.0, 12.0]

# Values picked out for display; not referenced in this part of the script
# (the Nshow value is presumably for figure panel (c)).
biasshow = 10
Nshow = 100

# Flag consulted before skipping configurations whose result file exists.
overwrite = 1

# Per-solver fMH value; it appears in the parameter/result file names.
fMHv = {'PT': 0.0}

# Problem-instance ratio, in the text form used in file names.
alphatxt = '0.80'

do_RUN = 1

# Output directory for the per-configuration result files.
figfolder = "fig_run_allGS"

# Create it (including any parents) on first use; report either way.
if Path(figfolder).exists():
    print(f"Folder already exists: {figfolder}")
else:
    Path(figfolder).mkdir(parents=True)
    print(f"Folder created: {figfolder}")


#############################################################

def _tts(p, T, target=0.99):
    """Return the time-to-solution for a per-run success probability ``p``.

    TTS = T * log(1 - target) / log(1 - p): the budget, in units of the
    per-run cost ``T``, needed to see at least one ground state with
    probability ``target``.

    * ``p`` zero or NaN -> NaN (no ground state observed).
    * ``p >= target``   -> ``T``: a single run already suffices. Without this
      cap the formula drops below one run and, at ``p == 1``, divides by
      ``log(0) = -inf`` (RuntimeWarning) and returns ``-0.0``.
    """
    if not p > 0:  # also catches NaN
        return np.nan
    if p >= target:
        return float(T)
    return T * np.log(1 - target) / np.log(1 - p)


# Results: aares[i_bias][i_solver][i_(N,T)] =
#     [p0l, TTSl, p0maxl, TTSminl, p0minl, TTSmaxl],
# each entry a list of length `rep` (all NaN for skipped configurations).
aares = []

for bias in biasl:

    # ---- planted-instance parameters for this bias ------------------------
    data = {'bias': bias}
    if bias == 0:
        data['datatype'] = 'unbias'
        D_WPE = 1
        R_WPE = -1
        data['D_WPE'] = D_WPE
        data['R_WPE'] = R_WPE
    else:
        data['datatype'] = 'bias'
        # first GS is close to ferromagnetic
        D1_WPE, R1_WPE = 1, 3
        # second GS is random
        D2_WPE, R2_WPE = 1, -1
        D_WPE = D1_WPE + D2_WPE
        data.update(D1_WPE=D1_WPE, R1_WPE=R1_WPE,
                    D2_WPE=D2_WPE, R2_WPE=R2_WPE, D_WPE=D_WPE)

    ares = []

    for solvertype in solvertypel:

        fMH = fMHv[solvertype]
        # folder holding the tuned parameters for this (solver, bias, fMH)
        folder_name = f"{solvertype}_{bias}_{fMH}"

        res = []

        for N, T in tqdm(itertools.product(Nl, Tl), desc="Iterating combinations"):

            # placeholder row recorded for configurations that are skipped
            void = [np.nan] * rep

            # ---- load the tuned PT parameters ------------------------------
            file_name = f"wishart_{N}_{alphatxt}_{bias}_{T}_{fMH}.txt"
            if not (Path(folder_name) / file_name).exists():
                print(f"{file_name} is missing!")
                res.append([void] * 6)
                continue
            _, opt_params, _ = lib.read_file(folder_name, file_name)

            savefile = figfolder + "/%s_N_%d_T_%d_alpha_%0.2f_bias_%d_fmH_%0.2f_info.txt" % (
                solvertype, N, T, float(alphatxt), bias, fMH)
            # NOTE(review): `bias_%d` truncates the bias, so non-integer biases
            # would share a result file name.

            # Logical `not`, not bitwise `~`: ~0 == -1 and ~1 == -2 are both
            # truthy, so `~overwrite` would always skip existing results.
            if not overwrite and Path(savefile).exists():
                print(f"{savefile} already exists!")
                res.append([void] * 6)
                continue

            # NOTE(review): unpacked as (min_temp, max_temp, splt); confirm this
            # matches the order in which the tuning step wrote the parameters.
            min_temp, max_temp, splt = opt_params

            alpha = float(alphatxt)  # same ratio as encoded in the file names
            M = int(N * alpha)

            p0l, TTSl = [], []
            p0maxl, TTSminl = [], []
            p0minl, TTSmaxl = [], []

            for _ in range(rep):

                # ---- generate a fresh planted instance ---------------------
                # first argument to the generator, presumably an instance
                # index/seed; drawn from the unseeded global NumPy RNG
                i = int(1000 * np.random.rand()) * rep

                if bias == 0:
                    J, H0, gs = generate_dWPE.gen_dWPE(i, N, M, D_WPE, R_WPE)
                else:
                    J, H0, gs = generate_dWPE.gen_dWPE_cluster(
                        i, N, M, D1_WPE, R1_WPE, D2_WPE, R2_WPE, bias=bias)

                # Hide the planted GS with a random gauge transform:
                # J_ij -> x_i J_ij x_j and s -> x * s leave s^T J s unchanged.
                x0 = np.sign(np.random.uniform(-1, 1, N))
                J = J * np.outer(x0, x0)
                # assumes gs is laid out (N, n_gs) -- TODO confirm
                gs = np.array(gs) * x0[:, None]

                prec = 1  # precision for GS energy
                H0 = np.floor(H0 / prec)

                # One GS per row, flipped so spin 0 is +1; this makes s and -s
                # compare equal to the sampled solutions normalised below.
                gs = np.array(gs).T
                gs = (gs / gs[:, [0]]).tolist()

                # ---- run PT --------------------------------------------------
                solver = PT_Ising.PT(N, J, H0, prec)
                energies, solutions = solver.run_pt(min_temp, max_temp, splt, T, R)

                # sampled ground states, normalised the same way as `gs`
                idx = np.where(energies == H0)[0]
                allGS = np.array(solutions.tolist())[idx, :]
                allGS = allGS / allGS[:, [0]]

                # ---- evaluate ------------------------------------------------
                p0 = np.mean(energies == H0)
                TTS = _tts(p0, T)

                # count recurrence of each planted GS among the samples
                count_rec = lib.count_vector_occurrences(allGS, gs)
                print('---------------', count_rec)

                # the per-GS counts must account for every ground-state hit
                if not np.isclose(np.sum(count_rec) / R, p0):
                    sys.exit("Error: Something went wrong")

                if len(count_rec) == 2:
                    # order fixed by the way instances are planted
                    p0max = count_rec[0] / R
                    p0min = count_rec[1] / R
                elif len(count_rec) > 0:
                    p0min = np.min(count_rec) / R
                    p0max = np.max(count_rec) / R
                else:
                    # No counts at all: the check above forces p0 == 0. Without
                    # this branch p0min/p0max would be undefined or stale.
                    p0min = p0max = 0.0

                TTSmax = _tts(p0min, T)
                TTSmin = _tts(p0max, T)

                print(solvertype, T, N, ":", p0, TTS, '|', p0max, TTSmin, '|', p0min, TTSmax)

                p0l.append(p0)
                TTSl.append(TTS)
                p0maxl.append(p0max)
                TTSminl.append(TTSmin)
                p0minl.append(p0min)
                TTSmaxl.append(TTSmax)

            np.savetxt(savefile, [p0l, TTSl, p0maxl, TTSminl, p0minl, TTSmaxl])

            res.append([p0l, TTSl, p0maxl, TTSminl, p0minl, TTSmaxl])

        ares.append(res)

    aares.append(ares)
    