import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import inst_utils
import model
import instance_setup
import os
import sys
from datetime import datetime
import shutil


# ---------------------------------------------------------------------------
# Evaluation configuration.
#
# The values below are defaults; positional command-line arguments override
# them in this order:
#   argv[1] T_in      argv[2] D_int     argv[3] T_D      argv[4] SEED
#   argv[5] group_name_in  (instance group used in the input path)
#   argv[6] model_id
#   argv[7] group_name_test (instance group that is evaluated)
#   argv[8] out_name  (sub-folder of the evaluation output path)
# ---------------------------------------------------------------------------
device = "cpu"

T_in = 4
D_int = 3
T_D = 3

SEED = 22

use_DAS = True

model_id = 0

group_name = "SN_1T_N_100"

# Both groups fall back to group_name so that the script also runs when fewer
# than eight arguments are given; without these defaults the names are
# undefined and the paths below raise NameError.
group_name_in = group_name
group_name_test = group_name

out_name = "eval1"

snap_name = "snap_used"

if len(sys.argv) > 1:
    T_in = int(sys.argv[1])

if len(sys.argv) > 2:
    D_int = int(sys.argv[2])

if len(sys.argv) > 3:
    T_D = int(sys.argv[3])

if len(sys.argv) > 4:
    SEED = int(sys.argv[4])

if len(sys.argv) > 5:
    group_name_in = sys.argv[5]

if len(sys.argv) > 6:
    model_id = int(sys.argv[6])

if len(sys.argv) > 7:
    group_name_test = sys.argv[7]

if len(sys.argv) > 8:
    out_name = sys.argv[8]

# Some model ids pin hyper-parameters, overriding whatever was set above.
if model_id == -1:
    T_in = 4
    D_int = 1

if model_id == -2:
    T_in = 5
    D_int = 1

if model_id == 1:
    D_int = 1

instance_setup.set_instance_group(group_name_test)

model.model_id = model_id

# Shared with model.get_P here and with model.run_alg inside get_rew, which
# adds further keys (is_MC, norm, aux_spin) before each run.
hyper_params = {
    "T_in": T_in,
    "D_int": D_int,
    "T_D": T_D,
    "is_MIS": False,
    "fix_aux": False,
}


now = datetime.now()
run_time_str = now.strftime("%Y-%m-%d %H:%M:%S")

# Folder the parameter snapshots, p_L.txt and info.txt are read from.
inpath = f"out/m_{model_id}/{group_name_in}/T_{T_in}_{T_D}_Di_{D_int}_S_{SEED}"

# Evaluation output folder; created up front if it does not exist yet.
outpath = f"eval/m_{model_id}_{group_name_test}/{out_name}"

os.makedirs(outpath, exist_ok=True)

p_checkpoint_path = f"boot/m_{model_id}_T_{T_in}_{T_D}_Di_{D_int}"

P = model.get_P(hyper_params)


# Number of rows of the noise matrix drawn in get_rew.
P_total = P*T_D

def calc_reward(E_opt, GE, rho):
    """Convert objective values into a 0/1 success reward.

    An entry scores 1.0 where ``E_opt <= GE`` and 0.0 elsewhere. The second
    term is scaled by -0.0, so it never changes the numeric result. ``rho``
    is accepted but not used.
    """
    success = (E_opt <= GE + 0) * 1.0
    inert_term = (E_opt >= GE * 0.5) * (-0.0)
    return success + inert_term


# Replica budget; get_rew integer-divides it by the instance's T to obtain
# the number of replicas for a run.
R_base = 10_000

# Energy history of the most recent get_rew call (None until the first call).
E_hist = None

def get_rew(p_x, p_L, nonoise = True):
    """Evaluate one parameter set on instance configuration 0.

    The parameters ``p_x`` are perturbed by ``p_L @ noise`` (the noise is
    zeroed when ``nonoise`` is true), the solver is run via ``model.run_alg``,
    and the results are stored as side effects:

    * the module-level ``E_hist`` receives the recorded energy history minus
      the best known energy,
    * the best known energy is written back with ``inst_utils.save_GE``,
    * ``instance_setup.rew_inst`` / ``obj_inst`` / ``obj_max_inst`` are updated
      at index 0,
    * ``hyper_params`` gets its instance-dependent keys set.

    Nothing is returned.
    """
    global E_hist
    mom2 = 0

    inst_config_idx = 0
    config = instance_setup.get_config(inst_config_idx)
    J, norm, inst_seed, GE_base, N, T, rho, MIS, meta, clss = config

    # Best known energy: the stored value, capped by the configured baseline.
    instance_path = f"./instance/{clss}/N_{N}/"
    GE = min(inst_utils.get_GE(instance_path, N, inst_seed), GE_base)

    #number of replicas normalized to keep computation time constant
    R = R_base // T

    # The noise is always drawn and then zeroed if requested.
    noise = torch.randn(P_total, R) / (P_total) ** 0.5
    if nonoise:
        noise = noise * 0

    base_column = p_x.reshape(-1, 1)
    p_pert = base_column + torch.matmul(p_L, noise)

    t_modes = model.get_t_modes(T, T_D)

    # Instance-dependent switches passed to the solver through hyper_params.
    hyper_params["is_MIS"] = MIS
    print(hyper_params["is_MIS"])
    if MIS:
        hyper_params["aux_spin"] = torch.tensor(instance_setup.get_current_MIS_size())
        hyper_params["fix_aux"] = True

    hyper_params["is_MC"] = bool(clss.startswith("BA") or clss.startswith("tBA"))
    hyper_params["norm"] = norm

    J = J.reshape(1, N, N)

    E_opt, E_hist, _traj_rec = model.run_alg(
        p_pert.detach(),
        t_modes,
        J,
        N,
        R,
        T,
        hyper_params,
        device = device,
        rec_traj = True,
        B = 1,
    )

    # Lower the reference energy if this run found something better, then
    # express the history relative to that reference.
    GE = min(torch.min(E_opt).cpu().numpy(), GE)
    E_hist = E_hist - GE

    #save best found GE
    inst_utils.save_GE(instance_path, N, inst_seed, GE, "min")

    #convert objective value to reward
    reward = calc_reward(E_opt.cpu().numpy(), GE, rho)

    print(np.mean(reward))

    instance_setup.rew_inst[inst_config_idx] = np.mean(reward)
    instance_setup.obj_inst[inst_config_idx] = model.current_obj
    instance_setup.obj_max_inst[inst_config_idx] = model.current_obj_max

    print("mom2", mom2)
    

# Run instance_setup's preprocessing once at start-up, before any evaluation
# below calls get_rew.
# NOTE(review): what preprocess_inst prepares is not visible in this file —
# presumably the data later returned by instance_setup.get_config; confirm.
instance_setup.preprocess_inst()


def get_p(epoch):
    """Load the parameter snapshot of ``epoch`` together with ``p_L``.

    Both text files are read from below ``inpath`` with ``np.loadtxt`` and
    returned as float32 tensors in the order ``(p_x, p_L)``.
    """
    snapshot_file = f"{inpath}/snap/{snap_name}/p_txt/snap{epoch}.txt"
    p_L_file = f"{inpath}/p_L.txt"

    p_x = torch.tensor(np.loadtxt(snapshot_file), dtype = torch.float32)
    p_L = torch.tensor(np.loadtxt(p_L_file), dtype = torch.float32)
    return p_x, p_L



# Figure 1: energy traces of up to 40 replicas for each selected snapshot.
plt.figure(figsize = (5,3.5))

epoch_list = [19, 99]

color_list = ["blue", "red"]

for epoch, color in zip(epoch_list, color_list):
    p_x, p_L = get_p(epoch)

    print("eval without noise...")
    get_rew(p_x, p_L, nonoise=True)

    # Summary statistics collected by get_rew in instance_setup.
    print("mean", np.mean(instance_setup.rew_inst))
    print("median", np.median(instance_setup.rew_inst))
    print("mean obj", np.mean(instance_setup.obj_inst))
    print("mean obj max", np.mean(instance_setup.obj_max_inst))

    print("std obj", np.std(instance_setup.obj_inst))
    print("ste obj", np.std(instance_setup.obj_inst)/instance_setup.numb_inst_config**0.5)

    plt.xlabel("t")
    plt.ylabel("E")
    plt.yscale("symlog")

    steps = range(E_hist.shape[0])
    shown_replicas = min(E_hist.shape[1], 40)
    for replica in range(shown_replicas):
        trace = E_hist[:, replica]
        # Traces that reach the reference energy (value <= 0) are drawn
        # more opaque than the rest.
        reached_reference = np.min(trace.numpy()) <= 0
        alpha = 0.5 if reached_reference else 0.05

        plt.plot(steps, trace, color = color, alpha = alpha)


# Loaded before the first figure is shown; columns 0 and 1 are plotted below.
info = np.loadtxt(inpath + "/info.txt")

plt.tight_layout()
plt.show()
plt.close()

# Figure 2: column 1 of info.txt against column 0.
plt.figure(figsize = (5,3.5))

plt.xlabel("training epoch")
plt.ylabel("reward")

plt.plot(info[:,0], info[:,1])

plt.tight_layout()
plt.show()
plt.close()




