import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import inst_utils
import model
import instance_setup
import os
import sys
from datetime import datetime
import shutil


device = "cpu"

T_in = 4
D_int = 3
T_D = 3

SEED = 22

use_DAS = True

model_id = 0

group_name = "SN_1T_N_100"

# Defaults for the input (training-run) group and the test-instance group.
# Both names are read unconditionally further down (instance setup and the
# input/output paths), so they must be bound even when fewer than 5 / 7
# positional CLI arguments are supplied.
group_name_in = group_name
group_name_test = group_name

out_name = "landscape1"

# Optional positional CLI overrides:
#   1: T_in   2: D_int   3: T_D   4: SEED   5: group_name_in
#   6: model_id   7: group_name_test   8: out_name
if(len(sys.argv) > 1):
    T_in = int(sys.argv[1])

if(len(sys.argv) > 2):
    D_int = int(sys.argv[2])

if(len(sys.argv) > 3):
    T_D = int(sys.argv[3])

if(len(sys.argv) > 4):
    SEED = int(sys.argv[4])

if(len(sys.argv) > 5):
    group_name_in = sys.argv[5]

if(len(sys.argv) > 6):
    model_id = int(sys.argv[6])

if(len(sys.argv) > 7):
    group_name_test = sys.argv[7]

if(len(sys.argv) > 8):
    out_name = sys.argv[8]

# Some model variants pin T_in / D_int regardless of the CLI values above.
if(model_id == -1):
    T_in = 4
    D_int = 1

if(model_id == -2):
    T_in = 5
    D_int = 1

if(model_id == 1):
    D_int = 1

# Select the group of test instances that get_config will index into.
instance_setup.set_instance_group(group_name_test)

model.model_id = model_id

# Hyper-parameters shared with the model; get_rew mutates this dict per run.
hyper_params = {
    "T_in": T_in,
    "D_int": D_int,
    "T_D": T_D,
    "is_MIS": False,
    "fix_aux": False
}


now = datetime.now()
run_time_str = now.strftime("%Y-%m-%d %H:%M:%S")

# Directory holding the trained parameters (p.txt, p_L.txt).
inpath = f"out/m_{model_id}/{group_name_in}/T_{T_in}_{T_D}_Di_{D_int}_S_{SEED}"

# Directory where landscape plots are written.
outpath = f"eval/landscape/m_{model_id}_{group_name_test}/{out_name}"

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(outpath, exist_ok=True)

# Root of bootstrap checkpoints (used when group_name_in == "FROM_BOOT").
p_checkpoint_path = f"boot/m_{model_id}_T_{T_in}_{T_D}_Di_{D_int}"

P = model.get_P(hyper_params)


# Total parameter count: P per block times T_D blocks (presumably one per
# time-dependence mode — TODO confirm against model.get_P).
P_total = P*T_D

def calc_reward(E_opt, GE, rho):
    """Return a binary success reward for the energies ``E_opt``.

    The reward is 1.0 wherever ``E_opt <= GE`` (the reference/best-known
    energy was reached) and 0.0 otherwise. It is computed elementwise for
    arrays and as a float for scalars.

    ``rho`` is accepted for interface compatibility and is not used.
    """
    reached_ground = E_opt <= GE
    return reached_ground * 1.0


# Replica budget; get_rew divides it by the instance's T to pick the replica count.
R_base = 200_000
# Replica count of the most recent get_rew call (reassigned there via `global R`).
R = 0

def _set_run_hyper_params(MIS, clss):
    """Update the module-level ``hyper_params`` for one run on an instance of class ``clss``.

    MIS-specific entries are reset when ``MIS`` is false, so settings from a
    previously evaluated MIS instance do not leak into a later non-MIS one.
    """
    hyper_params["is_MIS"] = MIS
    print(hyper_params["is_MIS"])
    if(MIS):
        hyper_params["aux_spin"] = torch.tensor(instance_setup.get_current_MIS_size())
        hyper_params["fix_aux"] = True
    else:
        # Restore the initial (non-MIS) state of hyper_params.
        hyper_params["fix_aux"] = False
        hyper_params.pop("aux_spin", None)

    # Instance classes whose name starts with "BA" or "tBA" run with is_MC enabled.
    hyper_params["is_MC"] = clss.startswith(("BA", "tBA"))

    # The norm from the instance config is overridden by a fixed value for every instance.
    hyper_params["norm"] = 0.02


def get_rew(p_x, p_L, p_pert_dir, inst_config_idx, nonoise = True, dim = 1, p_pert_dir2 = None):
    """Evaluate the mean reward on a 1-D slice or 2-D plane around ``p_x``.

    For each grid point the parameters are set to
    ``p_x + p_L @ noise + pert * p_pert_dir + pert2 * p_pert_dir2``.
    ``model.run_alg`` is run on instance ``inst_config_idx`` with them, and the
    per-replica ``calc_reward`` values are averaged.

    Args:
        p_x: base parameter tensor (flattened to a column).
        p_L: matrix applied to the Gaussian noise; assumed (P_total, P_total) — TODO confirm.
        p_pert_dir: first perturbation direction, same number of elements as ``p_x``.
        inst_config_idx: index passed to ``instance_setup.get_config``.
        nonoise: if True the noise term is multiplied by zero.
        dim: 2 for a grid over ``(pert, pert2)``; any other value gives a line scan over ``pert``.
        p_pert_dir2: second direction; defaults to zeros.

    Returns:
        dim == 2: ``(pert_grid, pert2_grid, rew_grid)``, 2-D arrays laid out like ``np.meshgrid``.
        otherwise: ``(pert_range, rew_list)``, 1-D arrays.

    Side effects:
        - Sets the module-global ``R`` (replica count).
        - Mutates the module-global ``hyper_params``.
        - Prints progress.
        - Persists the best energy seen via ``inst_utils.save_GE``.
    """
    global R

    J, norm, inst_seed, GE_base, N, T, rho, MIS, meta, clss = instance_setup.get_config(inst_config_idx)

    instance_path = f"./instance/{clss}/N_{N}/"
    GE = min(inst_utils.get_GE(instance_path, N, inst_seed), GE_base)

    # Number of replicas normalised to keep computation time constant across T.
    R = R_base//T

    if(p_pert_dir2 is None):
        p_pert_dir2 = torch.zeros_like(p_pert_dir)

    pr1 = np.arange(-1.0, 1.0, 0.05)
    if(dim == 2):
        X, Y = np.meshgrid(pr1, np.arange(-1.0, 1.0, 0.05))
        pert_range = X.reshape(-1)
        pert2_range = Y.reshape(-1)
    else:
        pert_range = pr1
        pert2_range = pr1*0
    rew_list = np.zeros_like(pert_range)

    # The instance matrix shape does not change between grid points.
    J = J.reshape(1, N, N)

    for idx, (pert, pert2) in enumerate(zip(pert_range, pert2_range)):
        # Noise is drawn even when it is zeroed, so the torch RNG stream
        # does not depend on `nonoise`.
        noise = torch.randn(P_total,R)/(P_total)**0.5
        if(nonoise):
            noise = noise*0

        p_pert = p_x.reshape(-1,1) + torch.matmul(p_L, noise) + p_pert_dir.reshape(-1,1)*pert + p_pert_dir2.reshape(-1,1)*pert2

        print(p_pert[0,0])

        t_modes = model.get_t_modes(T, T_D)
        _set_run_hyper_params(MIS, clss)

        E_opt, E_hist, traj_rec = model.run_alg(p_pert.detach(), t_modes, J, N, R, T, hyper_params, device = device, rec_traj = True)

        # Track the best energy found so far (including this run) and persist it.
        GE = min(torch.min(E_opt).cpu().numpy(), GE)
        inst_utils.save_GE(instance_path, N, inst_seed, GE, "min")

        # Convert objective values to per-replica rewards and average them.
        reward = calc_reward(E_opt.cpu().numpy(), GE, rho)
        print(pert, np.mean(reward))
        rew_list[idx] = np.mean(reward)

    if(dim == 2):
        n_cols = pr1.shape[0]
        return pert_range.reshape(-1, n_cols), pert2_range.reshape(-1, n_cols), rew_list.reshape(-1, n_cols)
    return pert_range, rew_list
    

instance_setup.preprocess_inst()

# Load the base parameters (p.txt) and the noise matrix (p_L.txt), either
# from a bootstrap checkpoint (version -SEED) or from the trained-run directory.
if(group_name_in == "FROM_BOOT"):
    v = -SEED
    boot_path = p_checkpoint_path + f"/v{v}"
    param_dir = boot_path
else:
    param_dir = inpath
p_x = torch.tensor(np.loadtxt(param_dir + "/p.txt"), dtype = torch.float32)
p_L = torch.tensor(np.loadtxt(param_dir + "/p_L.txt"), dtype = torch.float32)


# Fixed seed so the random perturbation directions are reproducible.
torch.manual_seed(1)

p_pert = torch.randn(p_x.shape)*0.04

# 1: reward slice along p_pert for several instances; 2: reward over the (p_pert, p_pert2) plane for one instance.
dim = 2

if(dim == 1):
    print("eval without noise...")

    rew_total = 0

    n_inst = 5

    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    for i in range(n_inst):
        inst_config_idx = i+0
        pert_range, rew_list = get_rew(p_x, p_L, p_pert, inst_config_idx, nonoise=True)

        # Variance of a mean of R rewards that are each 0 or 1 (calc_reward is an indicator).
        rew_var = rew_list*(1 - rew_list)/R

        # Shaded band: +/- 2 standard deviations of the mean reward.
        plt.fill_between(pert_range, rew_list -  2*rew_var**0.5, rew_list + 2*rew_var**0.5, alpha = 0.2, color = colors[i])
        plt.plot(pert_range, rew_list, alpha = 0.5, color = colors[i])

        rew_total += rew_list

    # Dashed line: reward averaged over instances.
    plt.plot(pert_range, rew_total/n_inst, dashes = [5,5])

    plt.savefig(outpath + "/rew_slice.png")

    plt.show()
    plt.close()


elif(dim == 2):

    p_pert2 = torch.randn(p_x.shape)*0.04

    inst_config_idx = 1

    pert_range, pert2_range, rew_list = get_rew(p_x, p_L, p_pert, inst_config_idx, nonoise=True, dim=2, p_pert_dir2 = p_pert2)

    plt.pcolormesh(pert_range, pert2_range, rew_list)
    plt.colorbar()

    # Persist the landscape like the dim == 1 branch does for its slice.
    plt.savefig(outpath + "/rew_plane.png")

    plt.show()
    plt.close()