import CAC_Ising

import DASTuneADAM as DASTuner

import lib

import numpy as np
import time
    
import os

import generate_dWPE

import matplotlib.pyplot as plt

# Root directory of the GSET data files read by tune_gset.
# An earlier layout used './../../Data'; that assignment was immediately
# overwritten, so only the path below was ever in effect.
datapath = './../Data_GSET/GSET'

def tune_gset(folder_name, instance, hyperparams, PARAM_NAMES, x, flags, tunerparams):
    """Tune CAC solver parameters on one GSET instance and save the outcome.

    The function loads the coupling matrix and the reference value of the
    selected instance, runs the DAS tuner over the parameters named in
    ``PARAM_NAMES``, re-evaluates the final parameter vector with 2000
    replicas, and writes a text summary plus a PNG of the parameter
    trajectory into ``folder_name``.

    Args:
        folder_name: Directory receiving ``gset_{N}_{id}_{T}_{fMH}.txt`` and
            ``gset_{N}_{id}_{T}_{fMH}.png``. Created if it does not exist.
        instance: Dict with ``'N'`` (problem size; must be a key of the
            size table below: 800, 1000 or 2000) and ``'id'`` (zero-based
            instance index; the data file with index ``id + 1`` is read).
        hyperparams: Dict; ``'T'`` and ``'fMH'`` are used in the output file
            names and the whole dict is forwarded to ``solver.init``.
        PARAM_NAMES: Names of the solver attributes being tuned, one per
            row of the parameter matrix handed to ``sample``.
        x: Initial parameter vector given to the tuner. Every entry except
            the one named ``"T"`` is exponentiated before being set on the
            solver. ``None`` selects a zero vector with an identity window.
        flags: Dict; ``'pt_device'`` and ``'solvertype'`` are forwarded to
            ``CAC_Ising.CAC``. Other keys (e.g. ``'savetraj'``) are not used
            by this function.
        tunerparams: Dict; ``'nsamp_max'`` is passed to ``tuner.optimize``
            as ``tot_samp_max`` and ``'R'`` to ``DASTuner.Sampler``.

    Returns:
        None. Results are written to disk and printed.

    Raises:
        KeyError: If ``instance['N']`` is not in the size table or a
            required dict key is missing.
        OSError: If a data file cannot be read or an output cannot be
            written.
    """
    # `count` lives at module level (as in the original implementation) and
    # is shared with the nested `sample` closure.
    global count

    N = instance['N']
    i0 = instance['id']
    # File-name component per problem size (GSET_{N}_{n}_...).
    ni = {800: 21, 1000: 4, 2000: 21}
    n = ni[N]

    T = hyperparams['T']
    fMH = hyperparams['fMH']

    pt_device = flags['pt_device']
    solvertype = flags['solvertype']

    nsamp_max = tunerparams['nsamp_max']
    R = tunerparams['R']

    D = len(PARAM_NAMES)

    # Starting point and initial window of the tuner.
    if x is None:
        x_init = np.zeros(D)
        L_init = np.diag(np.ones(D))
    else:
        x_init = np.array(x)
        L_init = np.diag(np.ones(len(x)) * 0.5)

    debug = True

    def log(*args):
        """Print only when debugging output is enabled."""
        if debug:
            print(*args)

    def read_table(path):
        """Parse a text file of single-space-separated numbers into a 2-D array."""
        # The context manager guarantees the handle is closed after parsing.
        with open(path, 'r') as fh:
            return np.array([[float(num) for num in line.split(' ')] for line in fh])

    def LoadInstance(i):
        """Return the symmetric N x N coupling matrix of instance ``i``."""
        # Instance ids are zero-based, file names are one-based.
        path = datapath + f'/GSET_{N}/GSET_{N}_{n}_{i + 1}'
        wmat = read_table(path)

        # Each row is (row index, column index, weight) with one-based indices.
        # An explicit loop keeps "last row wins" for repeated index pairs.
        w = np.zeros((N, N))
        for row in wmat:
            w[int(row[0] - 1), int(row[1] - 1)] = row[2]

        return w + w.T

    def LoadOptimalC(N, i, w):
        """Return the target energy derived from the stored reference value.

        Reads row ``i`` of the ``_SOL`` file (presumably the best known cut
        value -- confirm against the data set) and converts it to
        ``(-4*C0 - sum(w)) / 2``.
        """
        n = ni[N]
        path = datapath + f'/GSET_{N}/GSET_{N}_{n}_SOL'
        Cv = read_table(path)
        # .item() extracts the single value of the row without relying on the
        # deprecated implicit array-to-scalar conversion.
        C0 = float(Cv[i].item())
        return (-4 * C0 - np.sum(w)) / 2

    # Load the problem once; every solver instance receives its own copy of J
    # so repeated sampling does not re-parse the data files.
    prec = 1
    J_instance = LoadInstance(i0)
    H0_target = np.floor(LoadOptimalC(N, i0, J_instance) / prec)
    eps0 = np.mean(np.abs(J_instance))

    def gen_problem():
        """Build a fresh solver for the instance; returns ``(solver, H0)``."""
        solver = CAC_Ising.CAC(pt_device, N, J=J_instance.copy(), H0=H0_target,
                               solvertype=solvertype, precGS=prec)
        solver.eps = eps0
        return solver, H0_target

    # Row index of the parameter named "T" (-1 when it is not tuned); the
    # last occurrence wins if the name is repeated.
    T_index = -1
    for idx, param in enumerate(PARAM_NAMES):
        if param == "T":
            T_index = idx

    count = 0

    def sample(x, seed, fitness_beta=-1):
        """Run the solver for a batch of parameter columns and score each one.

        Args:
            x: Array of shape (number of parameters, number of replicas);
                column r holds the parameters of replica r.
            seed: Accepted for the tuner's calling convention; not used.
            fitness_beta: Scoring mode.
                -1: ``(E_opt <= E0) / T`` (success indicator divided by T,
                    or by 1 when "T" is not tuned);
                 0: ``E_opt == E0``;
                otherwise: ``E_opt <= fitness_beta * E0``.

        Returns:
            One score per replica.
        """
        global count
        n_rep = x.shape[1]

        solver, E0 = gen_problem()

        # NOTE(review): "T" is handed to the solver as-is, whereas all other
        # parameters are exponentiated (and the summary file exponentiates
        # every entry, including T) -- confirm this asymmetry is intended.
        T_vec = x[T_index, :] if T_index >= 0 else None

        for idx, param_name in enumerate(PARAM_NAMES):
            if T_index == idx:
                setattr(solver, param_name, T_vec)
            else:
                setattr(solver, param_name, np.exp(x[idx, :]))

        solver.init(n_rep, PARAM_NAMES, hyperparams)

        if count % 20 == 0:
            log("solving...")

        tstart = time.time()
        Ps, E_opt = solver.traj(E0)

        if T_vec is None:
            T_vec = np.ones(n_rep)

        # Progress line: solver-reported Ps, mean relative energy gap, wall time.
        print(Ps, np.average((E_opt - E0)/(-E0)), time.time() - tstart)

        if fitness_beta == -1:
            out = (E_opt <= E0)/T_vec
        elif fitness_beta == 0:
            out = E_opt == E0
        else:
            out = E_opt <= fitness_beta*E0

        print(np.min(E_opt), E0)

        count += 1
        return out

    #####################################################
    # Run the DAS tuner.

    fit_est_beta = 0.01

    tuner = DASTuner.Sampler(sample, D, R)
    tuner.fit_est_beta = fit_est_beta
    tuner.curv_est_beta = fit_est_beta
    tuner.grad_est_beta = fit_est_beta/D

    tuner.init_window(x_init, L_init)
    tuner.dt0 = 2.0

    _tot_samp_rec, x_rec, _L_rec = tuner.optimize(tot_samp_max=nsamp_max, R_end=10.0)

    # The last recorded point is taken as the tuned parameter vector.
    param_out = x_rec[-1]

    #####################################################
    # Evaluate the tuned parameters with a large replica batch.

    log("opt found ", param_out)
    log("evaluating...")
    R_eval = 2000
    count = 0
    N_inst = 1
    evalist = []
    f_eval = 0
    for _ in range(N_inst):
        eva = np.average(sample(np.outer(param_out, np.ones(R_eval)), range(R_eval), fitness_beta=0))
        f_eval += eva
        evalist.append(eva)
    f_eval = f_eval/N_inst

    print("f_eval", f_eval)
    log("f_eval", f_eval)
    log("L", tuner.L)

    # Diagnostic values of the tuner; collected but currently not returned.
    info = {}
    info["L"] = tuner.L
    info["curv_est"] = tuner.curv_est

    #####################################################
    # Write the summary: mean score, per-run scores, exp of the tuned vector.

    # Create the output folder first so a long tuning run is not lost to a
    # missing directory at the final write.
    if folder_name:
        os.makedirs(folder_name, exist_ok=True)

    file_name = folder_name + f"/gset_{N}_{i0}_{T}_{fMH}.txt"
    with open(file_name, 'w') as f:
        f.write(' '.join(map(str, [f_eval])) + '\n')
        f.write(' '.join(map(str, evalist)) + '\n')
        f.write(' '.join(map(str, np.exp(param_out))) + '\n')

    #####################################################

    def save_to_file_and_plot(folder_name, plot_file_name, PARAM_NAMES, x_rec):
        """Plot exp(x_rec) per parameter on a log scale and save it as an image."""
        plot_file_path = os.path.join(folder_name, plot_file_name)

        # Exponentiate the whole record once instead of once per parameter.
        trajectory = np.exp(x_rec)

        plt.figure()
        for idx, PARAM in enumerate(PARAM_NAMES):
            plt.plot(trajectory[:, idx], label=PARAM)
        plt.xlabel('steps')
        plt.ylabel('parameters')

        plt.legend()

        plt.yscale('log')

        ax = plt.gca()
        ax.spines["top"].set_visible(True)
        ax.spines["right"].set_visible(True)
        ax.grid(True)

        plt.savefig(plot_file_path)
        plt.close()

        print(f"Figure saved to file: {plot_file_path}")

    file_name = f"gset_{N}_{i0}_{T}_{fMH}.png"
    save_to_file_and_plot(folder_name, file_name, PARAM_NAMES, x_rec)
    