import CAC_Ising

import lib

import numpy as np
import time

from tqdm import tqdm

import generate_dWPE

import itertools
import os

import sys
from pathlib import Path

# ---------------------------------------------------------------------------
# Sweep configuration
# ---------------------------------------------------------------------------

# Problem sizes N and trajectory lengths T; every (N, T) pair is run.
Nl = [60, 100, 120, 140, 160]
Tl = [50, 100, 300, 500, 1000, 2000, 3000, 4000]

# Number of parallel runs R handed to the solver per instance.
R = 20_000

# Number of independently generated instances per (N, T) pair.
rep = 3

# Solver variants to benchmark.
solvertypel = ['CACm', 'MHCACm', 'AIM', 'CAC']

# Planting bias values; 0.0 selects the plain generator, anything else the
# two-ground-state cluster generator.
biasl = [0.0, 6.0, 8.0, 10.0, 12.0]

# Values selected for display (not referenced by the sweep below).
biasshow = 10
Nshow = 100  # used for panel (c)

# Overwrite flag for existing result files (1 = on).
overwrite = 1

# fMH value per solver: forwarded to the MH solvers through their
# hyperparameters and embedded in parameter-file names.
fMHv = {
    'CACm': 0.0,
    'AIM': 0.0,
    'CAC': 0.0,
    'MHCACm': 0.1,
    'MHCAC': 0.1,
}

# Ratio alpha = M/N, written exactly as it appears in parameter-file names.
alphatxt = '0.80'

# Run flag (not referenced by the sweep below).
do_RUN = 1

# Directory that receives the per-run result files.
figfolder = "fig_run_allGS"

# Make sure the directory exists, reporting which case applied.
if os.path.exists(figfolder):
    print(f"Folder already exists: {figfolder}")
else:
    os.makedirs(figfolder)
    print(f"Folder created: {figfolder}")


#############################################################

# Main sweep: bias -> solver type -> (N, T).
# For every combination the tuned solver parameters are loaded from
# "<solvertype>_<bias>_<fMH>/wishart_<N>_<alphatxt>_<bias>_<T>_<fMH>.txt",
# `rep` planted instances are generated and solved, and the success
# probabilities / time-to-solution (TTS) are written to `figfolder`.
# Results are also gathered in aares[i_bias][i_solver][i_combo], where i_combo
# follows itertools.product(Nl, Tl) order and each entry is
# [p0l, TTSl, p0maxl, TTSminl, p0minl, TTSmaxl] (lists of length `rep`,
# NaN-filled when the combination is skipped).

# Names of the tuned parameters for each solver, in the row order of the
# optimal-parameter array returned by lib.read_file.
SOLVER_PARAM_NAMES = {
    'MHCACm': ["beta", "kappa", "lamb", "xi", "gamma", "a"],
    'MHCAC': ["beta", "kappa", "lamb", "xi", "a"],
    'CACm': ["beta", "lamb", "xi", "gamma", "a"],
    'CAC': ["beta", "lamb", "xi", "a"],
    'AIM': ["beta", "lamb", "gamma"],
}

# Device string handed to CAC_Ising.CAC (presumably a PyTorch device; "cpu"
# to run without a GPU).
pt_device = "cuda"

# Debug toggles.
LOAD_WISHART_INSTANCE = False  # True: solve stored Wishart instance 0 instead of planting
GAUGE_OBFUSCATE = True  # hide the planted ground states with a random gauge transform

# Instance density M/N, kept consistent with the `alphatxt` used in file names.
alpha = float(alphatxt)


def _tts(p, t_run, target=0.99):
    """Return the time-to-solution for reaching `target` success probability.

    TTS = t_run * log(1 - target) / log(1 - p), where p is the single-run
    success probability and t_run the length of one run.  Edge cases are
    handled explicitly: p <= 0 gives +inf (the bare formula yields -inf there,
    since log(1 - target) < 0 and log(1) == +0.0) and p >= 1 gives 0.
    """
    if p <= 0:
        return np.inf
    if p >= 1:
        return 0.0
    return t_run * np.log(1 - target) / np.log(1 - p)


def _gs_probabilities(count_rec, n_runs):
    """Return (p0min, p0max) from per-ground-state hit counts.

    `count_rec` is the output of lib.count_vector_occurrences for the listed
    ground states (assumed to follow the order of the gs list - TODO confirm).
    With exactly two ground states the values are assigned by planting order
    (index 1 -> p0min, index 0 -> p0max), not by magnitude.  With no ground
    states both probabilities are 0.
    """
    if len(count_rec) == 0:
        return 0.0, 0.0
    if len(count_rec) == 2:  # due to the way we are planting
        return count_rec[1] / n_runs, count_rec[0] / n_runs
    return np.min(count_rec) / n_runs, np.max(count_rec) / n_runs


aares = []

for bias in biasl:

    if bias == 0:
        D_WPE = 1  # 3
        R_WPE = -1  # 6
    else:
        # first GS is close to ferromagnetic
        D1_WPE = 1
        R1_WPE = 3

        # second GS is random
        D2_WPE = 1
        R2_WPE = -1

        D_WPE = D1_WPE + D2_WPE

    ares = []

    for solvertype in solvertypel:

        res = []

        for N, T in tqdm(itertools.product(Nl, Tl),
                         total=len(Nl) * len(Tl),
                         desc="Iterating combinations"):

            # configure; an unknown solver type raises KeyError instead of
            # silently reusing the previous solver's settings
            fMH = fMHv[solvertype]
            PARAM_NAMES = SOLVER_PARAM_NAMES[solvertype]
            hyperparams = {'T': T}
            if solvertype == 'MHCACm':
                hyperparams.update(doa=1, dosampling=0, fMH=fMH)
            elif solvertype == 'MHCAC':
                hyperparams.update(doa=0, dosampling=0, fMH=fMH)

            # load optimal parameters
            folder_name = f"{solvertype}_{bias}_{fMH}"
            file_name = f"wishart_{N}_{alphatxt}_{bias}_{T}_{fMH}.txt"

            if not (Path(folder_name) / file_name).exists():
                print(f"{file_name} is missing!")
                void = (np.ones(rep) * np.nan).tolist()
                res.append([void] * 6)
                continue

            p0, opt_params, pvec = lib.read_file(folder_name, file_name)

            # log-parameters replicated over the R parallel runs:
            # x[idx, r] = log(opt_params[idx]) (assumes opt_params is 1-D)
            x = np.log(opt_params)
            x = np.expand_dims(x, 1)
            x = np.tile(x, [1, R])

            p0l = []
            TTSl = []
            p0minl = []
            TTSmaxl = []
            p0maxl = []
            TTSminl = []

            savefile = figfolder + "/%s_N_%d_T_%d_alpha_%0.2f_bias_%d_fmH_%0.2f_info.txt" % (
                solvertype, N, T, float(alphatxt), bias, fMH)

            # unless overwriting is requested, keep results that already exist
            if not overwrite and Path(savefile).exists():
                print(f"{savefile} already exists!")
                void = (np.ones(rep) * np.nan).tolist()
                res.append([void] * 6)
                continue

            for k in range(rep):

                if LOAD_WISHART_INSTANCE:
                    i = 0
                    J, eps0, H0 = lib.load_wishart(N, alphatxt, i)
                    prec = 1

                else:
                    i = int(1000 * np.random.rand()) * rep

                    M = int(N * alpha)

                    if bias == 0:
                        J, H0, gs = generate_dWPE.gen_dWPE(i, N, M, D_WPE, R_WPE)
                    else:
                        J, H0, gs = generate_dWPE.gen_dWPE_cluster(
                            i, N, M, D1_WPE, R1_WPE, D2_WPE, R2_WPE, bias=bias)
                        # all-to-all coupling perturbation; currently disabled (scale 0)
                        J_bias = np.ones((N, N)) - np.eye(N)
                        J = J + (0) * J_bias * 10**(-4)

                    if GAUGE_OBFUSCATE:
                        # random +/-1 gauge; np.where avoids a 0 factor that
                        # np.sign would produce for an exact 0.0 draw
                        x0 = np.where(np.random.uniform(-1, 1, N) < 0, -1.0, 1.0)
                        J = J * (np.expand_dims(x0, 0) * np.expand_dims(x0, 1))
                        gs = np.array(gs) * np.expand_dims(x0, 1)

                    eps0 = np.mean(np.abs(J))

                    prec = 10**(-6)  # precision for GS energy
                    H0 = np.floor(H0 / prec)

                    # one row per ground state, each divided by its first spin
                    # (assuming +/-1 spins this fixes spin 0 to +1)
                    gs = np.array(gs).T
                    gs = gs / np.expand_dims(gs[:, 0], 1)
                    gs = gs.tolist()

                # run
                solver = CAC_Ising.CAC(pt_device, N, J=J, H0=H0,
                                       solvertype=solvertype, precGS=prec)
                solver.eps = eps0

                for idx, param_name in enumerate(PARAM_NAMES):
                    setattr(solver, param_name, np.exp(x[idx, :]))

                solver.init(R, PARAM_NAMES, hyperparams)

                tstart = time.time()
                Ps, E_opt = solver.traj(H0)
                p0 = np.mean(E_opt == H0)
                TTS = _tts(p0, T)

                # count recurrence of each ground state
                count_rec = lib.count_vector_occurrences(solver.allGS, gs)
                print('---------------', count_rec)

                # sanity check: runs reaching H0 must be accounted for by the
                # listed ground states (tolerance below one count)
                if not np.isclose(np.sum(count_rec) / R, p0, rtol=0, atol=0.5 / R):
                    sys.exit("Error: Something went wrong")

                p0min, p0max = _gs_probabilities(count_rec, R)
                TTSmax = _tts(p0min, T)
                TTSmin = _tts(p0max, T)

                print(solvertype, T, N, ":", p0, TTS, '|', p0max, TTSmin, '|', p0min, TTSmax)

                p0l.append(p0)
                TTSl.append(TTS)
                p0maxl.append(p0max)
                TTSminl.append(TTSmin)
                p0minl.append(p0min)
                TTSmaxl.append(TTSmax)

            np.savetxt(savefile, [p0l, TTSl, p0maxl, TTSminl, p0minl, TTSmaxl])

            res.append([p0l, TTSl, p0maxl, TTSminl, p0minl, TTSmaxl])

        ares.append(res)

    aares.append(ares)
    