import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import inst_utils
import model
import instance_setup
import os
import sys
from datetime import datetime
import shutil
import device_config


device = device_config.device

# ---------------------------------------------------------------------------
# Defaults. Each can be overridden by positional command-line arguments:
#   argv[1] T_in   argv[2] D_int   argv[3] T_D   argv[4] SEED
#   argv[5] group_name_in   (training group; "FROM_BOOT" loads a bootstrap checkpoint)
#   argv[6] model_id   argv[7] group_name_test   argv[8] out_name
# ---------------------------------------------------------------------------
T_in = 4
D_int = 3
T_D = 3

SEED = 22

use_DAS = True

model_id = 0

group_name = "SN_1T_N_100"
# Both instance groups default to ``group_name``. Previously they were only bound
# when argv[5] / argv[7] were supplied, so shorter command lines crashed with a
# NameError at set_instance_group() below.
group_name_in = group_name
group_name_test = group_name

out_name = "eval1"

if len(sys.argv) > 1:
    T_in = int(sys.argv[1])

if len(sys.argv) > 2:
    D_int = int(sys.argv[2])

if len(sys.argv) > 3:
    T_D = int(sys.argv[3])

if len(sys.argv) > 4:
    SEED = int(sys.argv[4])

if len(sys.argv) > 5:
    group_name_in = sys.argv[5]

if len(sys.argv) > 6:
    model_id = int(sys.argv[6])

if len(sys.argv) > 7:
    group_name_test = sys.argv[7]

if len(sys.argv) > 8:
    out_name = sys.argv[8]

# Some model variants pin T_in / D_int regardless of the values given on the CLI.
if model_id == -1:
    T_in = 4
    D_int = 1
elif model_id == -2:
    T_in = 5
    D_int = 1
elif model_id == 1:
    D_int = 1

instance_setup.set_instance_group(group_name_test)

model.model_id = model_id

hyper_params = {
    "T_in": T_in,
    "D_int": D_int,
    "T_D": T_D,
    "is_MIS": False,
    "fix_aux": False,
}


now = datetime.now()
run_time_str = now.strftime("%Y-%m-%d %H:%M:%S")

# Directory holding the trained p.txt / p_L.txt (used unless group_name_in == "FROM_BOOT").
inpath = f"out/m_{model_id}/{group_name_in}/T_{T_in}_{T_D}_Di_{D_int}_S_{SEED}"

# Directory the evaluation results are written to; exist_ok avoids the
# check-then-create race of a separate os.path.exists() test.
outpath = f"eval/m_{model_id}_{group_name_test}/{out_name}"
os.makedirs(outpath, exist_ok=True)

# Root of the bootstrap checkpoints (used when group_name_in == "FROM_BOOT").
p_checkpoint_path = f"boot/m_{model_id}_T_{T_in}_{T_D}_Di_{D_int}"

P = model.get_P(hyper_params)

# Length of the flat parameter vector: presumably P parameters for each of the
# T_D blocks — TODO confirm against model.get_P.
P_total = P*T_D

def calc_reward(E_opt, GE, rho):
    """Return a success indicator for the energies found by the solver.

    The result is 1.0 wherever ``E_opt`` is at or below the reference ground energy
    ``GE``, and 0.0 otherwise. It works element-wise on NumPy arrays as well as on
    plain scalars. ``rho`` is accepted for interface compatibility but is unused.
    """
    # Earlier variants used a smooth reward (E_opt/GE)**5 and a penalty for
    # E_opt >= GE/2. That penalty carried a weight of -0.0 and never changed the
    # value, so only the hit indicator remains.
    hit = E_opt <= GE
    return 1.0 * hit


# R: replicas (perturbed parameter columns) run per instance.
# B: batch size, i.e. every configured instance is evaluated in one batch.
R, B = 30, instance_setup.numb_inst_config

def get_rew(p_x, p_L, nonoise = True):
    """Evaluate the parameters on every configured instance and record the results.

    All ``B`` instance configs from ``instance_setup`` are run in one batch, each
    with ``R`` replicas. Every replica uses ``p_x`` plus a perturbation
    ``p_L @ noise``, where the noise is standard normal scaled by
    ``1/sqrt(P_total)``. When ``nonoise`` is True the perturbation is zeroed.

    Nothing is returned. Side effects:
      * the best-known ground energy of each instance is saved via
        ``inst_utils.save_GE`` under that instance's own directory;
      * ``instance_setup.rew_inst``, ``obj_inst`` and ``obj_max_inst`` are written
        at each instance's config index;
      * the module-level ``hyper_params`` dict is mutated;
      * the mean reward of every instance is printed.

    Args:
        p_x: flat parameter vector (reshaped to a column internally).
        p_L: matrix applied to the noise; presumably (P_total, P_total) — TODO confirm.
        nonoise: zero the perturbation. Noise is still drawn, so the RNG stream
            advances identically in both modes.
    """
    J_list = []
    norm_all = torch.zeros(B)
    GE_all = torch.zeros(B)
    inst_seed_all = []
    inst_config_idx_all = []
    # Per-instance storage location and size for the ground-energy record.
    # Previously the save loop reused clss/N from the last instance, which wrote
    # every result under the last instance's directory.
    instance_path_all = []
    inst_N_all = []

    aux_spin_all = torch.zeros(B, dtype = torch.int32)

    for b in range(B):
        inst_config_idx = b
        J, norm, inst_seed, GE_base, N, T, rho, MIS, meta, clss = instance_setup.get_config(inst_config_idx)

        instance_path = f"./instance/{clss}/N_{N}/"
        # Reference ground energy: the stored record or the config baseline, whichever is lower.
        GE = min(inst_utils.get_GE(instance_path, N, inst_seed), GE_base)

        J_list.append(J)
        norm_all[b] = norm
        GE_all[b] = GE
        inst_seed_all.append(inst_seed)
        inst_config_idx_all.append(inst_config_idx)
        instance_path_all.append(instance_path)
        inst_N_all.append(N)
        aux_spin_all[b] = instance_setup.get_current_MIS_size()

    # Batch the couplings. This requires every instance to share the same N.
    # N, T, rho, MIS, norm and clss used below come from the LAST instance.
    J_all = torch.zeros(B, N, N)
    for b in range(B):
        J_all[b, :, :] = J_list[b]

    noise = torch.randn(P_total, R*B)/(P_total)**0.5
    if nonoise:
        noise = noise*0

    # One perturbed parameter column per (instance, replica) pair.
    p_pert = p_x.reshape(-1, 1) + torch.matmul(p_L, noise)

    t_modes = model.get_t_modes(T, T_D)
    hyper_params["is_MIS"] = MIS
    if MIS:
        hyper_params["aux_spin"] = aux_spin_all
        # NOTE(review): never reset to False, so it persists across later calls.
        hyper_params["fix_aux"] = True

    hyper_params["is_MC"] = clss.startswith(("BA", "tBA"))

    hyper_params["norm"] = norm

    E_opt, _E_hist, _traj_rec = model.run_alg(p_pert.reshape(P_total, B, R).detach(), t_modes, J_all, N, R, T, hyper_params, device = device, rec_traj = True, B = B)

    # Tighten each instance's reference with the best energy any replica reached.
    GE_all = torch.minimum(torch.min(E_opt, dim = 1).values.cpu(), GE_all)

    # Persist the best ground energy found, each to its own instance directory.
    for GE, inst_seed, instance_path, inst_N in zip(GE_all, inst_seed_all, instance_path_all, inst_N_all):
        inst_utils.save_GE(instance_path, inst_N, inst_seed, GE, "min")

    # Convert objective values to rewards and record per-instance statistics.
    for b, inst_config_idx in enumerate(inst_config_idx_all):
        reward = calc_reward(E_opt[b, :].cpu().numpy(), GE_all[b].cpu().numpy(), rho)
        mean_reward = np.mean(reward)

        instance_setup.rew_inst[inst_config_idx] = mean_reward
        instance_setup.obj_inst[inst_config_idx] = model.current_obj[b].cpu().numpy()
        instance_setup.obj_max_inst[inst_config_idx] = model.current_obj_max[b].cpu().numpy()
        print(mean_reward)
    

print("preprocess")
instance_setup.preprocess_inst()
print("done")

# Choose where the trained parameters live. The sentinel input group "FROM_BOOT"
# selects bootstrap checkpoint version -SEED; otherwise use this run's output dir.
if group_name_in == "FROM_BOOT":
    v = -SEED
    boot_path = p_checkpoint_path + f"/v{v}"
    param_dir = boot_path
else:
    param_dir = inpath

p_x = torch.tensor(np.loadtxt(param_dir + "/p.txt"), dtype = torch.float32)
p_L = torch.tensor(np.loadtxt(param_dir + "/p_L.txt"), dtype = torch.float32)


def _report_and_save(suffix):
    """Print summary stats of the latest get_rew pass and save its per-instance arrays.

    Reads ``instance_setup.rew_inst`` / ``obj_inst`` / ``obj_max_inst``. Writes
    ``rew_<suffix>.txt``, ``obj_<suffix>.txt`` and ``obj_max_<suffix>.txt`` into
    ``outpath``.
    """
    rew = instance_setup.rew_inst
    obj = instance_setup.obj_inst
    obj_max = instance_setup.obj_max_inst

    print("mean", np.mean(rew))
    print("median", np.median(rew))
    print("mean obj", np.mean(obj))
    print("mean obj max", np.mean(obj_max))

    print("std obj", np.std(obj))
    print("ste obj", np.std(obj)/instance_setup.numb_inst_config**0.5)

    np.savetxt(outpath + f"/rew_{suffix}.txt", rew)
    np.savetxt(outpath + f"/obj_{suffix}.txt", obj)
    np.savetxt(outpath + f"/obj_max_{suffix}.txt", obj_max)


print("eval without noise...")
get_rew(p_x, p_L, nonoise=True)
_report_and_save("centered")

print("eval with noise...")
# This pass previously also used nonoise=True, so the *_noise.txt outputs
# silently duplicated the centered evaluation.
get_rew(p_x, p_L, nonoise=False)
_report_and_save("noise")
