
import PT_Ising

import DASTuneADAM as DASTuner

import lib

import generate_dWPE

import matplotlib.pyplot as plt

from multiprocessing import Pool

import numpy as np

from tqdm import tqdm
import time

import itertools
import lib

import os



def solve_for_k(k, x, solver, E0, T):
    """Run one parallel-tempering solve for column ``k`` of ``x``.

    Rows 0, 1 and 2 of the log-scale parameter matrix ``x`` are exponentiated
    (column ``k``) and handed positionally to ``solver.run_pt``, followed by
    ``T`` and a trailing ``1``. Kept at module level so it can be used with
    ``multiprocessing.Pool.starmap``.

    Returns:
        ``(hit, first_energy)`` where ``hit`` is the element-wise comparison
        ``energies == E0`` and ``first_energy`` is ``energies[0]``.
    """
    p0, p1, p2 = (np.exp(x[row, k]) for row in (0, 1, 2))
    energies, _solutions = solver.run_pt(p0, p1, p2, T, 1)

    # Element-wise optimality check against the known ground-state energy.
    hit = energies == E0
    return hit, energies[0]


def tune_wishart(folder_name,instance,hyperparams,PARAM_NAMES,flags,tunerparams,data):
    """Tune parallel-tempering parameters on random Wishart instances.

    A ``DASTuner.Sampler`` optimizes log-scale solver parameters. Every tuner
    call draws a fresh instance through ``gen_problem``. The final parameters
    are then evaluated on ``N_inst`` fresh instances. The result is written
    with ``lib.save_to_file``, and a plot of the parameter trajectory is saved
    next to it.

    Args:
        folder_name: Output directory for the ``.txt`` result and ``.png`` plot.
        instance: Dict with ``'N'`` (problem size) and ``'alphatxt'`` (alpha as
            text; also used in file names).
        hyperparams: Dict with ``'T'`` (passed to ``run_pt``) and optional
            ``'fMH'`` (only used in output file names, default 0.0).
        PARAM_NAMES: Names of the tuned parameters, one per row of the tuner's
            parameter matrix; each is set on the solver via ``setattr``.
        flags: Currently unused by this function.
        tunerparams: Dict with ``'nsamp_max'`` (tuner sample budget) and
            ``'R'`` (columns per tuner call).
        data: Instance-generation settings: ``'datatype'`` (``'load'``,
            ``'unbias'`` or ``'bias'``), the matching WPE parameters, and
            ``'bias'`` (used in output file names).

    Raises:
        ValueError: If ``data['datatype']`` is not a supported value.
    """
    global count, fit_est

    N = instance['N']
    alphatxt = instance['alphatxt']

    nsamp_max = tunerparams['nsamp_max']
    R = tunerparams['R']

    T = hyperparams['T']

    def gen_problem():
        """Draw a random instance and build a ``PT_Ising.PT`` solver for it.

        Returns:
            ``(solver, H0)``. ``H0`` is the reference (ground-state) energy
            used to score solutions; for generated instances it is floored to
            the precision ``prec``.
        """
        datatype = data['datatype']

        if datatype == 'load':
            # Pick one of the stored instances (indices 0..9) at random.
            i = int(10 * np.random.rand())
            J, _eps0, H0 = lib.load_wishart(N, alphatxt, i)
            prec = 1

        elif datatype in ('unbias', 'bias'):
            i = int(100000 * np.random.rand())
            alpha = float(alphatxt)
            M = int(N * alpha)

            if datatype == 'unbias':
                J, H0, _gs = generate_dWPE.gen_dWPE(i, N, M, data['D_WPE'], data['R_WPE'])
            else:
                J, H0, _gs = generate_dWPE.gen_dWPE_cluster(i, N, M, data['D1_WPE'], data['R1_WPE'], data['D2_WPE'], data['R2_WPE'], bias=data['bias'])
                J_bias = np.ones((N, N))
                J_bias = J_bias - np.diag(np.diag(J_bias))
                # Uniform off-diagonal bias term, currently disabled
                # (coefficient 0). It has no numerical effect, but it may
                # promote J to a float ndarray, so it is kept as-is.
                J = J + (0) * J_bias * 10**(-4)

            prec = 1  # precision for GS energy (alternative tried: 1e-6)
            H0 = np.floor(H0 / prec)

        else:
            raise ValueError(f"unknown data['datatype']: {datatype!r}")

        solver = PT_Ising.PT(N, J, H0, prec)
        return solver, H0

    # Initialize the search window around the solver's default temperatures.
    solver, E0 = gen_problem()
    max_temp, min_temp = solver.get_default_temp()
    splt = 0.25

    # NOTE(review): x_init is ordered [min_temp, max_temp, splt], but the
    # caller passes PARAM_NAMES = ["max_temp", "min_temp", "splt"]. Rows are
    # passed positionally to run_pt and also setattr'd / plotted under
    # PARAM_NAMES. Confirm which order run_pt expects; the labels for rows
    # 0 and 1 may be swapped.
    x_init = np.log([min_temp, max_temp, splt])
    L_init = np.diag(np.ones(len(x_init)) * 0.5)

    debug = True

    def log(*args):
        """Print *args* when ``debug`` is enabled."""
        if debug:
            print(*args)

    # Row of the parameter matrix holding a per-column "T" value, or -1 if
    # "T" is not tuned (last match wins).
    T_index = -1
    for idx, param in enumerate(PARAM_NAMES):
        if param == "T":
            T_index = idx

    count = 0  # module-level call counter (declared global above)

    def sample(x, seed, fitness_beta=-1):
        """Evaluate every column of the log-parameter matrix *x* on one new instance.

        Args:
            x: 2-D array of log-scale parameters, one column per candidate.
                Rows 0..2 are exponentiated and passed to ``run_pt``.
            seed: Unused; kept for the tuner's callback signature.
            fitness_beta: Selects the fitness:
                -1 -> ``(E_opt <= E0) / T_vec`` (used while tuning);
                 0 -> ``E_opt == E0`` (used for final evaluation);
                otherwise -> ``E_opt <= fitness_beta * E0``.

        Returns:
            Per-column fitness array.
        """
        global count, fit_est
        R = x.shape[1]

        solver, E0 = gen_problem()

        T_vec = None
        if T_index >= 0:
            T_vec = x[T_index, :]

        for idx, param_name in enumerate(PARAM_NAMES):
            if idx == T_index:
                setattr(solver, param_name, T_vec)
            else:
                setattr(solver, param_name, np.exp(x[idx, :]))

        if count % 20 == 0:
            log("solving...")

        tstart = time.time()

        # Serial evaluation; set use_pool to True to fan out over processes.
        use_pool = False
        if not use_pool:
            Ps = []
            E_opt = []
            for k in range(R):
                energies, _solutions = solver.run_pt(np.exp(x[0, k]), np.exp(x[1, k]), np.exp(x[2, k]), T, 1)
                isSolved = energies == E0
                # Accumulate per column (previously overwritten each iteration).
                Ps.append(np.mean(isSolved[0]))
                E_opt.append(energies[0])
        else:
            with Pool(processes=8) as pool:
                results = pool.starmap(solve_for_k, [(k, x, solver, E0, T) for k in range(R)])
            Ps = [np.mean(solved[0]) for solved, _ in results]
            E_opt = [energy for _, energy in results]

        # Make the comparisons below element-wise even if E0 is a plain
        # Python float (list == float would yield a single False).
        E_opt = np.asarray(E_opt)

        if T_vec is None:
            T_vec = np.ones(R)

        # success rate, mean relative energy gap, wall time
        print(np.mean(Ps), np.average((E_opt - E0) / (-E0)), time.time() - tstart)

        if fitness_beta == -1:
            out = (E_opt <= E0) / T_vec
        elif fitness_beta == 0:
            out = E_opt == E0
        else:
            out = E_opt <= fitness_beta * E0

        if count % 20 == 0:
            log(E0)
            grad_norm = np.linalg.norm(tuner.grad_est)
            log("f ", tuner.fit_est, "c ", tuner.curv_est, "g", grad_norm, "r1 ", tuner.curv_est / tuner.fit_est, "r2 ", (tuner.curv_est - grad_norm) / tuner.fit_est)

        count += 1
        return out

    #####################################################
    # Tune with the DAS tuner.

    D = len(PARAM_NAMES)
    fit_est_beta = 0.01

    tuner = DASTuner.Sampler(sample, D, R)
    tuner.fit_est_beta = fit_est_beta
    tuner.curv_est_beta = fit_est_beta
    tuner.grad_est_beta = fit_est_beta / D

    tuner.init_window(x_init, L_init)
    tuner.dt0 = 0.5

    tot_samp_rec, x_rec, L_rec = tuner.optimize(tot_samp_max=nsamp_max, R_end=10.0)

    param_out = x_rec[-1]

    #####################################################
    # Evaluate the tuned parameters on fresh instances.

    log("opt found ", param_out)
    log("evaluating...")
    R_eval = 400
    N_inst = 15
    count = 0
    evalist = []
    for _ in range(N_inst):
        eva = np.average(sample(np.outer(param_out, np.ones(R_eval)), range(R_eval), fitness_beta=0))
        evalist.append(eva)
    f_eval = sum(evalist) / N_inst

    print("f_eval", f_eval)
    log("f_eval", f_eval)
    log("L", tuner.L)

    #####################################################
    # Persist results.

    bias = data['bias']
    fMH = hyperparams.get('fMH', 0.0)
    file_name = f"wishart_{N}_{alphatxt}_{bias}_{T}_{fMH}.txt"
    lib.save_to_file(folder_name, file_name, f_eval, evalist, param_out)

    def save_to_file_and_plot(folder_name, plot_file_name, PARAM_NAMES, x_rec):
        """Plot exp(x_rec) per parameter on a log y-axis and save it as an image."""
        plot_file_path = os.path.join(folder_name, plot_file_name)

        trajectory = np.exp(x_rec)  # one row per recorded tuner step
        plt.figure()
        for idx, PARAM in enumerate(PARAM_NAMES):
            plt.plot(trajectory[:, idx], label=PARAM)
        plt.xlabel('steps')
        plt.ylabel('parameters')
        plt.legend()
        plt.yscale('log')

        ax = plt.gca()
        ax.spines["top"].set_visible(True)
        ax.spines["right"].set_visible(True)
        ax.grid(True)

        plt.savefig(plot_file_path)
        plt.close()

        print(f"Figure saved to file: {plot_file_path}")

    plot_file_name = f"wishart_{N}_{alphatxt}_{bias}_{T}_{fMH}.png"
    save_to_file_and_plot(folder_name, plot_file_name, PARAM_NAMES, x_rec)
    


if __name__ == "__main__":

    # Sweep grid: instance bias strengths, problem sizes and PT "T" values.
    biasl = [0.0, 6.0, 8.0, 10.0, 12.0]
    Nl = [60, 100, 120, 140, 160]
    T = [50, 100, 300, 500, 1000, 2000, 3000, 4000]

    pt_device = 'cuda'
    # pt_device = 'cpu'

    alphatxt = '0.80'

    # Per-solver fMH value; only used here to build output folder names.
    fMHv = {'PT': 0.0}

    ####################################################
    # SOLVER

    solvertypel = ['PT']

    # tqdm cannot infer a length from itertools.product, so pass it explicitly.
    total = len(Nl) * len(T)

    for bias in biasl:

        if bias == 0:
            # Unbiased instances (generate_dWPE.gen_dWPE).
            data = {
                'datatype': 'unbias',
                'D_WPE': 1,
                'R_WPE': -1,
                'bias': 0.0,
            }
        else:
            # Biased instances (generate_dWPE.gen_dWPE_cluster).
            data = {
                'datatype': 'bias',
                # first GS is close to ferromagnetic
                'D1_WPE': 1,
                'R1_WPE': 3,
                # second GS is random
                'D2_WPE': 1,
                'R2_WPE': -1,
                'bias': bias,
            }
            data['D_WPE'] = data['D1_WPE'] + data['D2_WPE']

        ####################################################

        for solvertype in solvertypel:

            fMH = fMHv[solvertype]
            folder_name = f"{solvertype}_{bias}_{fMH}"

            if os.path.exists(folder_name):
                print(f"Folder already exists: {folder_name}")
            else:
                os.makedirs(folder_name)
                print(f"Folder created: {folder_name}")

            # For each N, visit T values from largest to smallest.
            combos = itertools.product(Nl, np.flip(T))
            for n_spins, t_val in tqdm(combos, total=total, desc="Iterating combinations"):

                if solvertype != 'PT':
                    continue

                PARAM_NAMES = ["max_temp", "min_temp", "splt"]
                hyperparams = {'T': t_val}
                instance = {'alphatxt': alphatxt, 'N': n_spins}
                flags = {'savetraj': 1, 'pt_device': pt_device, 'solvertype': solvertype}
                tunerparams = {'nsamp_max': 100000, 'R': 200}

                tune_wishart(folder_name, instance, hyperparams, PARAM_NAMES, flags, tunerparams, data)
