import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import inst_utils
import model
import instance_setup
import os
import sys
from datetime import datetime
import shutil
import json
import device_config

device = device_config.device

# Run settings. Each of the following may be overridden positionally from the
# command line, in this order:
#   argv[1] T_in, argv[2] D_int, argv[3] T_D, argv[4] SEED,
#   argv[5] group_name, argv[6] model_id, argv[7] inst_rho
T_in, D_int, T_D = 4, 3, 3
SEED = 22

use_DAS = True

model_id = 0
inst_rho = 0
group_name = "SN_1T_N_100"

# Convert the supplied arguments (at most 7) with the matching cast and keep
# the default for every position that was not supplied.
_cli_values = sys.argv[1:8]
_cli_casts = (int, int, int, int, str, int, float)
_cli_defaults = (T_in, D_int, T_D, SEED, group_name, model_id, inst_rho)
T_in, D_int, T_D, SEED, group_name, model_id, inst_rho = (
    [cast(raw) for cast, raw in zip(_cli_casts, _cli_values)]
    + list(_cli_defaults[len(_cli_values):])
)
del _cli_values, _cli_casts, _cli_defaults

# Some model variants pin the network shape regardless of the CLI values:
# -1 -> T_in=4, D_int=1; -2 -> T_in=5, D_int=1; 1 -> D_int=1.
if model_id in (-1, -2):
    T_in = 4 if model_id == -1 else 5
    D_int = 1
elif model_id == 1:
    D_int = 1


instance_setup.set_instance_group(group_name)

model.model_id = model_id

# Hyper-parameters handed to the model. get_batch_grad later (re)writes
# "is_MIS", "aux_spin", "fix_aux", "is_MC" and "norm" for each batch.
hyper_params = dict(
    T_in=T_in,
    D_int=D_int,
    T_D=T_D,
    is_MIS=False,
    fix_aux=False,
)

# Reward mode ("succ" or "obj", see calc_reward). "succ" is the fallback used
# when config.json does not define "rew_type".
rew_type = "succ"

with open('config.json', 'r', encoding="utf-8") as config_file:
    config = json.load(config_file)

# R: perturbed parameter samples (replicas) per instance,
# B: instances per batch, numb_epochs: number of training epochs.
R = config["R"]

B = config["B"]

numb_epochs = config["numb_epochs"]

rew_type = config.get("rew_type", rew_type)


now = datetime.now()
run_time_str = now.strftime("%Y-%m-%d %H:%M:%S")


# Results go to out/m_<model>/<group>[/rho_<rho>]/T_<T_in>_<T_D>_Di_<D_int>_S_<SEED>[_fw].
_run_tag = f"T_{T_in}_{T_D}_Di_{D_int}_S_{SEED}"
_group_dir = f"out/m_{model_id}/{group_name}"
if inst_rho > 0:
    _group_dir = f"{_group_dir}/rho_{inst_rho}"
outpath = f"{_group_dir}/{_run_tag}"

# Checkpoints are shared across seeds/groups for the same model shape.
p_checkpoint_path = f"boot/m_{model_id}_T_{T_in}_{T_D}_Di_{D_int}"


if not use_DAS:
    outpath = outpath + "_fw"

# exist_ok avoids the check-then-create race when several runs start at once.
os.makedirs(outpath, exist_ok=True)
os.makedirs(p_checkpoint_path + "/current", exist_ok=True)


# Parameters per time-mode (from the model) and over all T_D time-modes.
P = model.get_P(hyper_params)


P_total = P*T_D


# Slope of the "obj" reward; the training loop multiplies it by 1.5 every
# 10 epochs once the mean reward exceeds 0.5.
rew_scale = 0.005

def calc_reward(E_opt, GE, rho, reward_type = None, scale = None):
    """Convert energies reached by the solver into per-replica rewards.

    Args:
        E_opt: array of final energies, one per replica, for one instance.
        GE: best-known ground-state energy of that instance (lower is better).
        rho: unused; kept for call-site compatibility.
        reward_type: "obj" or "succ"; defaults to the module-level ``rew_type``.
        scale: slope of the "obj" reward; defaults to the module-level
            ``rew_scale`` (read at call time, so the training loop's
            adjustments take effect).

    Returns:
        "obj":  ``max(0, 1 - (E_opt - GE) * scale)`` element-wise.
        "succ": ``1.0`` where ``E_opt <= GE``, plus ``-0.5`` where
                ``E_opt >= 0.5 * GE`` (with negative energies this penalises
                replicas that end far from the ground state — assumes GE < 0).

    Raises:
        ValueError: if the reward type is not "obj" or "succ".
    """
    if reward_type is None:
        reward_type = rew_type

    if reward_type == "obj":
        if scale is None:
            scale = rew_scale
        return np.maximum(0.0, 1.0 - (E_opt - GE)*scale)*1.0

    if reward_type == "succ":
        return (E_opt <= GE)*1.0 + (E_opt >= GE*0.5)*(-0.5)

    raise ValueError(f"unknown reward type: {reward_type!r}")


# Number of gradient estimates accumulated into .grad per optimisation step.
batch_N = 1

# Source of the seeds used to reseed numpy's global RNG before each
# instance draw in get_batch_grad.
global_rng = np.random.default_rng(seed = 0)

# Latest energy history and trajectory recording from model.run_alg; filled
# in by get_batch_grad and read by the plotting code in the training loop.
E_hist = traj_rec = None

def get_batch_grad(p_x, p_L, nonoise = False):
    """Accumulate a stochastic reward-gradient estimate into ``p_x.grad`` / ``p_L.grad``.

    Runs ``batch_N`` rounds. Each round:

    1. samples ``B`` instance configs with weights
       ``1 / max(0.1, cumulative_rew_inst) ** inst_rho`` (uniform when
       ``inst_rho == 0``);
    2. draws ``R`` parameter perturbations per instance,
       ``p_pert = p_x + p_L @ noise`` with ``noise ~ N(0, 1/P_total)``;
    3. runs ``model.run_alg`` on all perturbed parameter sets;
    4. turns energies into rewards with ``calc_reward`` and into advantages
       (reward minus the per-instance mean reward);
    5. back-propagates a surrogate loss whose gradient w.r.t. ``p_x`` is
       proportional to the advantage-weighted perturbation ``p_L @ noise``.

    Args:
        p_x: flat parameter vector of length ``P_total`` (assumed to be a leaf
            tensor with ``requires_grad=True``).
        p_L: ``(P_total, P_total)`` perturbation scale matrix (same assumption).
        nonoise: if True the perturbation is zeroed, so every replica uses ``p_x``.

    Side effects:
        Sets the globals ``E_hist`` (shifted by the first instance's ground
        energy) and ``traj_rec``; adds to ``cumulative_rew_inst``; writes
        per-instance stats into ``instance_setup.rew_inst`` / ``obj_inst`` /
        ``obj_max_inst``; mutates ``hyper_params``; persists improved ground
        energies through ``inst_utils.save_GE``. Returns None.
    """
    global E_hist, traj_rec, cumulative_rew_inst
    for i in range(batch_N):

        # Sampling weights: with inst_rho > 0, configs with low cumulative
        # reward are drawn more often (floor of 0.1 keeps the weights finite).
        p = 1.0/np.maximum(0.1, cumulative_rew_inst)**inst_rho
        p = p/np.sum(p)

        J_list = []
        norm_all = torch.zeros(B)
        GE_all = torch.zeros(B)
        inst_seed_all = []
        inst_config_idx_all = []
        # Per-instance GE storage directory. It is recorded per instance so the
        # save loop below writes each GE next to its own instance, not next to
        # the last one sampled.
        instance_path_all = []

        aux_spin_all = torch.zeros(B, dtype = torch.int32)

        for b in range(B):
            np.random.seed(int(global_rng.random()*10000))
            inst_config_idx = np.random.choice(range(instance_setup.numb_inst_config), 1, p = p)[0]
            J, norm, inst_seed, GE_base, N, T, rho, MIS, meta, clss = instance_setup.get_config(inst_config_idx)

            instance_path = f"./instance/{clss}/N_{N}/"
            GE = min(inst_utils.get_GE(instance_path, N, inst_seed), GE_base)

            J_list.append(J)
            norm_all[b] = norm
            GE_all[b] = GE
            inst_seed_all.append(inst_seed)
            inst_config_idx_all.append(inst_config_idx)
            instance_path_all.append(instance_path)
            aux_spin_all[b] = instance_setup.get_current_MIS_size()

        # NOTE: N, T, MIS, clss, norm and rho below come from the LAST sampled
        # instance; the batch is assumed homogeneous in these (J_all needs a
        # common N).
        J_all = torch.zeros(B,N,N)
        for b in range(B):
            J_all[b,:,:] = J_list[b]

        # R perturbations per instance; columns are grouped instance-major so
        # that reshape(P_total, B, R) below lines up with adv_all.reshape(-1).
        noise = torch.randn(P_total, R*B)/(P_total)**0.5
        if(nonoise):
            noise = noise*0

        p_pert = p_x.reshape(-1,1) + torch.matmul(p_L, noise)

        t_modes = model.get_t_modes(T, T_D)
        hyper_params["is_MIS"] = MIS
        if(MIS):
            hyper_params["aux_spin"] = aux_spin_all
            # NOTE(review): fix_aux is never reset to False once set — confirm
            # this is intended for groups mixing MIS and non-MIS instances.
            hyper_params["fix_aux"] = True

        hyper_params["is_MC"] = False
        if(clss.startswith("BA")):
            hyper_params["is_MC"] = True

        hyper_params["norm"] = norm
        E_opt, E_hist, traj_rec = model.run_alg(p_pert.reshape(P_total, B, R).detach(), t_modes, J_all, N, R, T, hyper_params, device = device, rec_traj = True, B = B)

        # Best-known ground energy: lower it to the best replica found.
        GE_all = torch.minimum(torch.min(E_opt, axis = 1).values.cpu(), GE_all)

        # Recorded history is expressed relative to the first instance's GE.
        E_hist = E_hist - GE_all[0]

        # Persist the best-known GE of every instance in its own directory.
        for GE, inst_seed, instance_path in zip(GE_all, inst_seed_all, instance_path_all):
            inst_utils.save_GE(instance_path, N, inst_seed, GE, "min")

        # Convert objective values to rewards, then to per-instance advantages.
        adv_all = torch.zeros(B, R)
        for b in range(B):
            inst_config_idx = inst_config_idx_all[b]

            reward = calc_reward(E_opt[b, :].cpu().numpy(), GE_all[b].cpu().numpy(), rho)

            adv = reward - np.mean(reward)
            adv_all[b, :] = torch.tensor(adv)

            instance_setup.rew_inst[inst_config_idx] = np.mean(reward)
            cumulative_rew_inst[inst_config_idx] += np.mean(reward)
            instance_setup.obj_inst[inst_config_idx] = model.current_obj[b].cpu().numpy()
            instance_setup.obj_max_inst[inst_config_idx] = model.current_obj_max[b].cpu().numpy()

        # Surrogate: p_x is detached on the left, so the gradient flows only
        # through p_pert and equals the advantage-weighted perturbation.
        loss = torch.mean(torch.sum((p_x.detach().reshape(-1,1) - p_pert)**2, axis = 0)*adv_all.reshape(-1))

        loss = loss/batch_N

        loss.backward()
    
    

torch.manual_seed(SEED)


# Initial search distribution: mean p_x = 0 and perturbation scale matrix
# p_L = 0.1*sqrt(P_total)*I. The torch.randn(...)*0.0 terms are numerically
# no-ops but advance the seeded torch RNG; they are kept so results for a
# given SEED stay reproducible.
p_x = torch.randn(P_total)*0.0 + (0.0)

p_L = torch.randn(P_total, P_total)*0.0 + 0.1*torch.diag(torch.ones(P_total))*P_total**0.5


# Per-column rescaling hook; with exponent 0.0 this divides every column by 1.
p_L = p_L/(1 + torch.tensor(range(P_total))).reshape(1,-1)**0.0

if SEED < 0:
    # Negative SEED: warm-start from checkpoint version v = -SEED.
    v = -SEED
    boot_path = p_checkpoint_path + f"/v{v}"
    # ndmin keeps p_x 1-D and p_L 2-D even when P_total == 1.
    p_x = torch.tensor(np.loadtxt(boot_path + "/p.txt", ndmin = 1), dtype = torch.float32)
    p_L = torch.tensor(np.loadtxt(boot_path + "/p_L.txt", ndmin = 2), dtype = torch.float32)

# Leaf tensors optimised by the training loop (replaces the deprecated
# torch.autograd.Variable wrapper).
p_x = p_x.detach().requires_grad_(True)
p_L = p_L.detach().requires_grad_(True)

# Step size of the parameter update in the training loop.
Lr = 0.2

print("preprocess...")
instance_setup.preprocess_inst()
print("done")

# Per-epoch histories, used for the progress plots and info.txt.
epoch_rec, rew_rec, wsize_rec = [], [], []
p_rec = []
lamb_rec = []
rew_all_rec = []

# Best mean reward seen at the 50-epoch checkpoints.
rew_base = 0

# Running sum of the mean reward per instance config; when inst_rho > 0 it
# biases instance sampling in get_batch_grad towards low-reward configs.
cumulative_rew_inst = np.zeros_like(instance_setup.rew_inst)

def create_snapshot(epoch, src_dir = None, stamp = None):
    """Copy every regular file of the output directory into a per-run snapshot tree.

    A file ``name.ext`` is copied to
    ``<src_dir>/snap/<stamp>/<name_ext>/snap<epoch>.ext``, so successive
    snapshots of the same file end up side by side in one directory.

    Args:
        epoch: epoch number embedded in the snapshot file name.
        src_dir: directory to snapshot; defaults to the module-level ``outpath``.
        stamp: run identifier used as the snapshot sub-directory; defaults to
            the module-level ``run_time_str`` (script start time).
            NOTE(review): that timestamp contains ':' which is not valid in
            Windows paths — confirm this only runs on POSIX.
    """
    if src_dir is None:
        src_dir = outpath
    if stamp is None:
        stamp = run_time_str

    files = os.listdir(src_dir)

    snap_dir = src_dir + "/snap/" + stamp.replace("/", "_") + "/"

    print("creating snapshot...", epoch)
    for filename in files:
        # Skip the snapshot tree itself, hidden files and any other
        # directory (shutil.copyfile only copies regular files).
        if filename == "snap" or filename.startswith("."):
            continue
        if not os.path.isfile(os.path.join(src_dir, filename)):
            continue
        print("copying", filename)
        dirname = filename.replace(".","_")
        # splitext copes with extension-less names and names with several
        # dots; ext includes the leading "." (or is "" when there is none).
        ext = os.path.splitext(filename)[1]
        os.makedirs(snap_dir + dirname, exist_ok = True)
        shutil.copyfile(src_dir + "/" + filename, snap_dir + dirname + f"/snap{epoch}{ext}")
    print("done")


# Loop invariants: plots/checkpoints are written every `stride` epochs
# (less often on CUDA).
stride = 5

if(device == "cuda"):
    stride = 50

np.set_printoptions(suppress=True)


def plot_network(ax_, p_):
    """Draw the network weights ``p_`` (parameters at one time point) on ``ax_``.

    Models 0/2/4/6 are drawn as two layers (T_in inputs -> D_int hidden -> one
    output), models 1/3/5/7 as one layer (T_in inputs -> one output); other
    model ids draw nothing. Edges are red for positive and blue for negative
    weights, relative to the largest |weight| in the layer.
    """
    if(model_id == 0 or model_id == 2 or model_id == 4 or model_id == 6):
        print("2 layer")
        layer1 = p_[:T_in*D_int].reshape(D_int, T_in)
        # Fall back to 1.0 when all weights are zero (e.g. an untrained
        # p_x = 0): 0/0 would give NaN colours, which matplotlib rejects.
        scale = np.max(np.abs(layer1)) or 1.0

        ax_.set_ylim((-2,2))
        for i in range(T_in):
            for j in range(D_int):
                weight = layer1[j, i]/scale
                ax_.plot([-i + (T_in-1)/2, -j + (D_int-1)/2], [1, 0], color = ((1 + weight)/2, 0.0 , (1 - weight)/2))

        layer2 = p_[T_in*D_int:T_in*D_int+ D_int]
        scale = np.max(np.abs(layer2)) or 1.0

        for j in range(D_int):
            weight = layer2[j]/scale
            ax_.plot([-j + (D_int-1)/2, 0], [0,-1], color = ((1 + weight)/2, 0.0 , (1 - weight)/2))

    if(model_id == 1 or model_id == 3 or model_id == 5 or model_id == 7):
        print("1 layer")
        layer = p_[:T_in]
        scale = np.max(np.abs(layer)) or 1.0

        ax_.set_ylim((-2,2))
        for i in range(T_in):
            weight = layer[i]/scale
            ax_.plot([-i + (T_in-1)/2, 0], [1,0], color = ((1 + weight)/2, 0.0 , (1 - weight)/2))


for epoch in range(numb_epochs):

    get_batch_grad(p_x, p_L)

    with torch.no_grad():
        # Move along the accumulated surrogate gradient (`+=`, i.e. ascent).
        # Epoch 0 only collects statistics and leaves the parameters alone.
        if(epoch >= 1):
            p_x.data += 100*p_x.grad*Lr
            if(use_DAS):
                p_L.data += (P_total**0.5)*(100)*p_L.grad*Lr

        # Rows of p_L with RMS <= 10 are untouched; rows with RMS in (10, 1000]
        # are scaled back to RMS 10 (larger rows are divided by 100).
        p_L.data = p_L.data*(10/torch.clamp(torch.mean(p_L.data**2, axis = 1)**0.5, 10, 1000).reshape(-1,1))

        p_x.grad.zero_()
        p_L.grad.zero_()

    print("epoch", f"{epoch}/{numb_epochs}", "rew", np.mean(instance_setup.rew_inst), np.mean(instance_setup.obj_inst), np.mean(instance_setup.obj_t10_inst), np.mean(instance_setup.obj_max_inst))

    # Spectrum of p_L p_L^T (symmetric PSD by construction). eigvalsh is the
    # symmetric solver: real eigenvalues, no spurious complex parts.
    LLt = torch.matmul(p_L, p_L.T).detach().numpy()
    e = np.linalg.eigvalsh(LLt)
    p_rec.append(p_x.detach().numpy() + 0.0)

    epoch_rec.append(epoch)
    rew_rec.append(np.mean(instance_setup.rew_inst))
    wsize_rec.append(torch.mean(p_L**2).detach().cpu().numpy()**0.5)
    lamb_rec.append(sorted(e))
    rew_all_rec.append(instance_setup.rew_inst + 0.0)

    # Every 10 epochs, make the "obj" reward stricter once the mean reward
    # exceeds 0.5 (calc_reward reads rew_scale at call time).
    if(epoch % 10 == 9):
        if(np.mean(instance_setup.rew_inst) > 0.5):
            rew_scale = rew_scale*1.5

    if(epoch % stride == stride - 1):

        # Reward curves: per instance (linear and log), percentiles (log), mean.
        fig, ax = plt.subplots(2,2, figsize = (12,10))

        ax[0,0].set_xlabel("epoch")
        ax[0,0].set_ylabel("reward")
        for i in range(instance_setup.numb_inst_config):
            ax[0,0].plot(epoch_rec, [rew[i] for rew in rew_all_rec])

        ax[1,0].set_yscale("log")
        ax[1,0].set_xlabel("epoch")
        ax[1,0].set_ylabel("reward")
        for i in range(instance_setup.numb_inst_config):
            ax[1,0].plot(epoch_rec, [rew[i] for rew in rew_all_rec])

        ax[1,1].set_xlabel("epoch")
        ax[1,1].set_ylabel("reward")
        ax[1,1].plot(epoch_rec, rew_rec)

        percentile_list = [0,25,50,75,100]
        ax[0,1].set_yscale("log")
        ax[0,1].set_xlabel("epoch")
        ax[0,1].set_ylabel("reward")
        for pct in percentile_list:
            ax[0,1].plot(epoch_rec, [np.percentile(rew, pct) for rew in rew_all_rec])

        plt.savefig(outpath + "/rew_progress.png")
        plt.close()

        plt.xlabel("epoch")
        plt.ylabel("window size")
        plt.plot(epoch_rec, wsize_rec)
        plt.savefig(outpath + "/wsize_progress.png")
        plt.close()

        np.savetxt(outpath + "/info.txt", np.array([epoch_rec, rew_rec, wsize_rec]).T)

        np.savetxt(outpath + "/rew.txt", instance_setup.rew_inst)

        np.savetxt(outpath + "/p.txt", p_x.detach().numpy())
        np.savetxt(outpath + "/p_L.txt", p_L.detach().numpy())

        # Rolling checkpoint shared by runs with the same model shape.
        os.makedirs(p_checkpoint_path + "/current", exist_ok = True)
        np.savetxt(p_checkpoint_path + "/current/p.txt", p_x.detach().numpy())
        np.savetxt(p_checkpoint_path + "/current/p_L.txt", p_L.detach().numpy())

        with open(p_checkpoint_path + "/current/info.txt", "w") as file:
            file.write(group_name + "\n")
            file.write(f"SEED {SEED}\n")

        for i in range(P_total):
            plt.plot(epoch_rec, [p[i] for p in p_rec])

        plt.savefig(outpath + "/p_progress.png")
        plt.close()

        for i in range(P_total):
            plt.plot(epoch_rec, [lamb[i] for lamb in lamb_rec])

        plt.savefig(outpath + "/lamb_progress.png")
        plt.close()

        # Current schedule: parameters expanded over T_ time points via t_modes.
        T_ = 100
        t_modes = model.get_t_modes(T_, T_D)
        t_range = np.array(range(T_))/T_
        p_mode = p_x.detach().numpy().reshape(-1, T_D)

        p = np.dot(p_mode, t_modes)

        for i in range(p.shape[0]):
            plt.plot(t_range, p[i, :])

        plt.savefig(outpath + "/anneal_current.png")
        plt.close()

        fig, ax = plt.subplots(2,2, figsize = (12,10))

        ax[0,0].set_title("initial network")
        plot_network(ax[0,0], p[:,0])

        ax[1,0].set_title("final network")
        plot_network(ax[1,0], p[:,-1])

        plt.savefig(outpath + "/network_weights.png")
        plt.close()

        # Energy trajectories (relative to GE) of up to 40 recorded columns.
        fig, ax = plt.subplots(2,2, figsize = (12,10))

        ax[0,0].set_xlabel("t")
        ax[0,0].set_ylabel("E")
        for i in range(min(E_hist.shape[1], 40)):
            ax[0,0].plot(range(E_hist.shape[0]), E_hist[:,i])

        ax[1,0].set_xlabel("t")
        ax[1,0].set_ylabel("E")
        ax[1,0].set_yscale("symlog")
        for i in range(min(E_hist.shape[1], 40)):
            ax[1,0].plot(range(E_hist.shape[0]), E_hist[:,i])

        ax[1,1].set_xlabel("t")
        ax[1,1].set_ylabel("E")
        ax[1,1].set_yscale("symlog")
        for i in range(min(E_hist.shape[1], 40)):
            ax[1,1].plot(range(E_hist.shape[0]), E_hist[:,i], color = "blue", alpha = 0.05)

        ax[1,1].plot(range(E_hist.shape[0]), E_hist.mean(axis = 1), color = "black")

        plt.savefig(outpath + "/E_traj_current.png")
        plt.close()

        # State trajectories of up to 10 recorded columns (fewer if absent).
        n_traj = min(traj_rec.shape[1], 10)

        plt.figure(figsize = (5,3.5))

        plt.xlabel("t", fontsize = 18)
        plt.ylabel("x", fontsize = 18)

        for i in range(n_traj):
            plt.plot(range(traj_rec.shape[0]), traj_rec[:,i])

        plt.tight_layout()
        plt.savefig(outpath + "/x_traj_current.png")
        plt.close()

        plt.figure(figsize = (5,3.5))

        plt.xlabel("t", fontsize = 18)
        plt.ylabel("x", fontsize = 18)

        # Same trajectories with every other time step sign-flipped.
        flip_vec = (-1)**np.array(range(traj_rec.shape[0]))
        for i in range(n_traj):
            plt.plot(range(traj_rec.shape[0]), flip_vec*traj_rec[:,i].numpy())

        plt.tight_layout()
        plt.savefig(outpath + "/x_traj_flipped_current.png")
        plt.close()

    if((epoch % stride == stride-1 and epoch <= 50) or epoch % 50 == 49):
        create_snapshot(epoch)

    if(epoch % 50 == 49):
        rew_current = np.mean(instance_setup.rew_inst)

        rew_base = max(rew_base, rew_current)



