import torch
import time

model_id = 0

T_alt = 1

# Memoized temporal basis matrices, keyed by (T, T_D, T_alt): the matrix depends
# on all three, so keying on T alone could hand back a matrix of the wrong shape.
cached_t_modes = {}

def get_t_modes(T, T_D):
    """Return the (T_D, T) matrix of temporal basis functions over T time steps.

    The T_D rows are split into T_alt groups of T_D // T_alt rows. Group ``a``
    is masked to the time steps t with t % T_alt == a and holds a constant row
    followed by alternating cosine/sine rows

        cos(pi * (k + 1) * t / T) / (k + 2),  sin(pi * (k + 1) * t / T) / (k + 2)

    for k = 0, 1, ... Rows not filled by this pattern stay zero.

    Results are memoized in ``cached_t_modes``; the returned tensor is shared
    between calls, so callers must not modify it in place.
    """
    key = (T, T_D, T_alt)
    cached = cached_t_modes.get(key)
    if cached is not None:
        return cached

    rows_per_group = T_D // T_alt
    t_modes = torch.zeros(T_D, T)

    steps = torch.arange(T)
    t_range = steps / T
    for a in range(T_alt):
        alt_mask = (steps % T_alt == a) * 1
        base = rows_per_group * a
        t_modes[base, :] = alt_mask
        for k in range((rows_per_group - 1) // 2):
            freq = t_range * torch.pi * (k + 1)
            t_modes[base + 1 + 2 * k, :] = alt_mask * torch.cos(freq) / (k + 2)
            t_modes[base + 2 + 2 * k, :] = alt_mask * torch.sin(freq) / (k + 2)

    cached_t_modes[key] = t_modes
    return t_modes



def get_P(hyper_params):
    """Return the number of per-step parameters P used by the active model.

    The model is selected by the module-level ``model_id``:

    * -1 (``step_CAC``): 4 parameters.
    * -2 (``step_CAC2``): 5 parameters.
    * 0, 2: two-layer readout, T_in*D_int + D_int weights.
    * 1, 3: single linear readout, T_in weights.
    * 4, 6: as 0 / 2 plus one noise-scale parameter.
    * 5, 7: as 1 / 3 plus one noise-scale parameter.

    Raises:
        KeyError: if ``hyper_params`` lacks "T_in" or "D_int".
        ValueError: if ``model_id`` is not one of the values above.
    """
    T_in = hyper_params["T_in"]
    D_int = hyper_params["D_int"]

    if model_id == -1:
        return 4
    if model_id == -2:
        return 5
    if model_id in (0, 1, 2, 3, 4, 5, 6, 7):
        # Even ids use the two-layer readout, odd ids the single linear one.
        n_readout = T_in*D_int + D_int if model_id % 2 == 0 else T_in
        # Models 4-7 append one parameter: the noise scale that step reads as p[:, -1, :].
        return n_readout + (1 if model_id >= 4 else 0)
    # Previously fell through and returned None, which failed much later.
    raise ValueError(f"unsupported model_id: {model_id!r}")
# h_hist shape (B, T_in, N, R): the last T_in local fields, newest at index T_in - 1
# p shape (P, B, R): this step's parameters, P = get_P(hyper_params)
# J shape (B, N, N)
def step(p, h_hist, J, N, R, B, t, hyper_params, device = "cpu"):
    """Advance the learned spin-update dynamics by one time step (model_id >= 0).

    The variant is chosen by the module-level ``model_id``:

    * Readout of the field history into z2 (B, N, R). Even ids (0, 2, 4, 6)
      use two layers: z1 = sum_t W1[d, t] * h_hist[t], z1 = tanh(z1) + z1,
      z2 = sum_d w2[d] * z1[d]. Odd ids (1, 3, 5, 7) use a single linear
      combination of the T_in history slots.
    * Spin update s (B, N, R), with u ~ U[0, 1) and g ~ N(0, 1):
      0-1: z2 = 0.99*tanh(z2), s = sign(z2 + 1 - 2u), i.e. P(s = +1) = (1 + z2)/2;
      2-3: s = tanh(z2 + 0.01*g);
      4-5: s = tanh(z2 + sigma*g), sigma = the last parameter p[-1];
      6-7: s = sign(z2 + sigma*g).

    If hyper_params["fix_aux"] is true, s[b, aux_spin[b]] is clamped to +1
    (presumably aux_spin holds one spin index per batch element -- confirm).
    The field h = norm * J @ s is appended to h_hist and the oldest entry is
    dropped.

    Args:
        t: normalized time ti/T supplied by run_alg; not used here.

    Returns:
        h_hist: the updated (B, T_in, N, R) history.
        z2: for models 0-1 the pre-sampling values 0.99*tanh(.); for models
            2-7 the spins s themselves.
        E: (B, R), sum_i s_i (J s)_i of the greedy spins. These are sign(z2)
           (with the aux clamp) for models 0-5 and s itself for models 6-7.

    NOTE(review): any other model_id leaves z2/s unassigned and raises
    UnboundLocalError.
    """
    T_in = hyper_params["T_in"]
    D_int = hyper_params["D_int"]
    norm = hyper_params["norm"]
    
    #total params (computed but not used below)
    P = get_P(hyper_params)

    p = p.swapaxes(0,1)
    #shape (B, P, R)

    # Even model ids: two-layer readout of the field history.
    if(model_id == 0 or model_id == 2 or model_id == 4 or model_id == 6):
        # The first T_in*D_int parameters are the (D_int, T_in) first-layer
        # weights; the next D_int are the second-layer weights.
        layer1 = p[:,:T_in*D_int, :]
        layer2 = p[:,T_in*D_int:T_in*D_int+ D_int, :]

        
        # Contract over the T_in history slots -> (B, D_int, N, R).
        z1 = torch.sum(layer1.reshape(B, D_int, T_in, 1, R)*h_hist.reshape(B, 1, T_in, N, R), axis = 2)
        z1 = torch.tanh(z1) + z1
        # Contract over the D_int hidden units -> (B, N, R).
        z2 = torch.sum(layer2.reshape(B, D_int, 1, R)*z1.reshape(B, D_int, N, R), axis = 1)
        

    # Odd model ids: single linear readout over the T_in history slots -> (B, N, R).
    if(model_id == 1 or model_id == 3 or model_id == 5 or model_id == 7):
        layer = p[:, :T_in, :]
        z2 = torch.sum(layer.reshape(B, T_in, 1, R)*h_hist.reshape(B, T_in, N, R), axis = 1)
    
    

    #print(z2[0,0])

    # print(N, R)
    # print(z2.shape)

    # Models 0-1: stochastic +/-1 spins with P(s = +1) = (1 + z2)/2.
    if(model_id >= 0 and model_id < 2):
        z2 = torch.tanh(z2)
        z2 = z2*0.99

        s = torch.sign(z2 + 1 - torch.rand(B,N,R, device = device)*2)
    
    # Models 2-3: continuous spins with fixed-size Gaussian noise (std 0.01).
    elif(model_id >= 2 and model_id < 4):

        s = torch.tanh(z2 + 0.01*torch.randn(B,N,R, device = device))

    # Models 4-5: continuous spins; the noise scale is the last parameter.
    elif(model_id >= 4 and model_id < 6):

        s = torch.tanh(z2 + p[:,-1, :].reshape(B,1,R)*torch.randn(B,N,R, device = device))
    
    # Models 6-7: +/-1 spins; the noise scale is the last parameter.
    elif(model_id >= 6 and model_id < 8):

        s = torch.sign(z2 + p[:,-1, :].reshape(B,1,R)*torch.randn(B,N,R, device = device))
    
    
    # Clamp the auxiliary spin of every batch element to +1.
    if(hyper_params["fix_aux"]):
        s[range(B),hyper_params["aux_spin"],:] = 1
    h_raw = torch.matmul(J, s)
    h = norm*h_raw

    # Shift the history one slot towards index 0 (the oldest field wraps to the
    # end) and overwrite the last slot with the new field.
    h_hist = torch.roll(h_hist, shifts=(-1), dims=(1))
    h_hist[:,-1,:,:] = h

    #print(h_hist[:,0,0])

    # Greedy spins used for the energy: the sign of the noise-free readout.
    # Models 6-7 instead score the sampled spins (s_greedy aliases s there).
    s_greedy = torch.sign(z2)
    if(model_id >= 6 and model_id < 8):
        s_greedy = s
    
    if(hyper_params["fix_aux"]):
        s_greedy[range(B), hyper_params["aux_spin"],:] = 1
    
    # For models 6-7, J @ s_greedy equals h_raw computed above, so reuse it.
    if(model_id < 6):
        h_greedy = torch.matmul(J, s_greedy)
    else:
        h_greedy = h_raw

    # E[b, r] = sum_i s_greedy[b, i, r] * (J s_greedy)[b, i, r]
    E = torch.sum(s_greedy*h_greedy, axis = 1)

    #if(t >= 0.96):
        
        #print(h_hist[:,0,0], z2[0,0])
    #print(t, E[0])
    # Models 2-7 report the spins themselves as z2.
    if(model_id >= 2 and model_id < 8):
        z2 = s

    return h_hist, z2, E


def step_CAC(p, h_hist, J, N, R, B, t, hyper_params, device = "cpu"):
    """Advance the four-parameter amplitude / error-variable (CAC) dynamics one step (model_id == -1).

    Two layouts are accepted:

    * batched (as built by ``run_alg``): p (4, B, R), h_hist (B, >=3, N, R),
      J (B, N, N);
    * unbatched (legacy): p (4, R), h_hist (>=3, N, R), J (N, N).

    Slot 0 of h_hist holds the local field norm*J@s, slot 1 the amplitudes x,
    and slot 2 the log error variables e_log. h_hist is updated in place and
    returned. With p_k = c_k * exp(p[k] / 10) and c = (0.1, 0.1, 0.1, 0.9):

        x'     = clamp(x - p1*x + p2*(-x**3 - exp(e_log)*h), -4, 4)
        e_log' = e_log + p3*(1 - x**2)
        z2     = 0.99*tanh(p4*x')
        s      = sign(z2 + 1 - 2u),  u ~ U[0, 1)   (P(s = +1) = (1 + z2)/2)

    Returns:
        (h_hist, z2, E). z2 has shape (B, N, R) batched or (N, R) unbatched.
        E = sum_i s_i (J s)_i of the greedy spins sign(z2), with shape (B, R)
        batched or (R,) unbatched.
    """
    norm = hyper_params["norm"]

    # run_alg passes batched tensors, for which h_hist[1] / h_hist[2] would index
    # the batch axis. View the legacy unbatched layout as a batch of one so both
    # layouts share one code path; the view keeps the in-place h_hist writes
    # visible to the caller.
    unbatched = h_hist.dim() == 3
    h_b = h_hist.unsqueeze(0) if unbatched else h_hist
    n_batch = h_b.shape[0]

    def coeff(k, scale):
        # Batched p[k] is (B, R): insert the spin axis so it broadcasts as (B, 1, R).
        c = scale*torch.exp(p[k]/10)
        return c if unbatched else c.unsqueeze(1)

    p1 = coeff(0, 0.1)
    p2 = coeff(1, 0.1)
    p3 = coeff(2, 0.1)
    p4 = coeff(3, 0.9)

    x = h_b[:, 1]
    e_log = h_b[:, 2]

    x_ = x + p1*(-x) + p2*(-x**3 - torch.exp(e_log)*h_b[:, 0])
    e_log_ = e_log + p3*(1 - x**2)
    x_ = torch.clamp(x_, -4, 4)
    z2 = torch.tanh(p4*x_)*0.99

    s = torch.sign(z2 + 1 - torch.rand(n_batch, N, R, device = device)*2)
    # J is (B, N, N) or (N, N); matmul broadcasts the latter over the batch of one.
    h = norm*torch.matmul(J, s)

    h_b[:, 0] = h
    h_b[:, 1] = x_
    h_b[:, 2] = e_log_

    s_greedy = torch.sign(z2)
    h_greedy = torch.matmul(J, s_greedy)
    E = torch.sum(s_greedy*h_greedy, axis = 1)

    if unbatched:
        return h_hist, z2.squeeze(0), E.squeeze(0)
    return h_hist, z2, E



def step_CAC2(p, h_hist, J, N, R, B,t, hyper_params, device = "cpu"):
    """Advance the batched five-parameter amplitude / error-variable dynamics (model_id == -2).

    Shapes: p (5, B, R); J (B, N, N); h_hist (B, >=3, N, R). Slot 0 of h_hist
    holds the local field norm*J@s, slot 1 the amplitudes x, and slot 2 the log
    error variables e_log. h_hist is updated in place and returned.

    With p_k = c_k * exp(p[k] / 0.5) and c = (0.1, 0.1, 0.1, 0.2, 0.9):

        x'     = clamp(x - p1*x + p2*(-x**3 - p5*exp(e_log)*h), -4, 4)
        e_log' = e_log + p3*(1 - x**2)
        z2     = 0.99*tanh(p4*x')
        s      = z2 + 0.01*g,  g ~ N(0, 1)

    If hyper_params["fix_aux"] is true, spin hyper_params["aux_spin"][b] of each
    batch element b is clamped to +1. This applies both to s and to the greedy
    spins sign(z2), before any field or energy is computed (the same convention
    as ``step``).

    Returns:
        (h_hist, z2, E): z2 is (B, N, R); E (B, R) = sum_i s_i (J s)_i of the
        greedy spins.
    """
    norm = hyper_params["norm"]
    fix_aux = hyper_params["fix_aux"]

    x = h_hist[:,1,:,:]
    e_log = h_hist[:,2,:,:]

    # (P, B, R) -> (B, P, R); every coefficient broadcasts over the N spins.
    p = p.swapaxes(0,1)
    p1 = 0.1*torch.exp(p[:,0,:]/0.5).reshape(B, 1, R)
    p2 = 0.1*torch.exp(p[:,1,:]/0.5).reshape(B, 1, R)
    p3 = 0.1*torch.exp(p[:,2,:]/0.5).reshape(B, 1, R)
    p4 = 0.2*torch.exp(p[:,3,:]/0.5).reshape(B, 1, R)
    p5 = 0.9*torch.exp(p[:,4,:]/0.5).reshape(B, 1, R)

    x_ = x + p1*(-x) + p2*(-x**3 - p5*torch.exp(e_log)*h_hist[:,0,:,:])
    e_log_ = e_log + p3*(1 - x**2)
    x_ = torch.clamp(x_, -4, 4)
    z2 = torch.tanh(p4*x_)*0.99

    s = z2 + torch.randn(B, N, R, device = device)*0.01

    if fix_aux:
        # Index (batch b, spin aux_spin[b]), as ``step`` does. Indexing
        # s[aux_spin, :] would select whole batch elements instead.
        batch_idx = torch.arange(B, device = s.device)
        aux = hyper_params["aux_spin"]
        s[batch_idx, aux, :] = 1

    h = norm*torch.matmul(J, s)

    h_hist[:,0,:,:] = h
    h_hist[:,1,:,:] = x_
    h_hist[:,2,:,:] = e_log_

    s_greedy = torch.sign(z2)
    if fix_aux:
        s_greedy[batch_idx, aux, :] = 1
    # Compute the field from the *clamped* greedy spins so that E is the energy
    # of the configuration that is actually scored.
    h_greedy = torch.matmul(J, s_greedy)
    E = torch.sum(s_greedy*h_greedy, axis = 1)

    return h_hist, z2, E


# t_modes: the (T_D, T) temporal basis consumed by run_alg, where T_D is the
# number of modes and T the number of time steps.

# Objective summaries exported by run_alg. current_obj and current_obj_max are
# overwritten on every call; current_obj_top_10 is currently never updated.
current_obj = current_obj_top_10 = current_obj_max = 0

def run_alg(p_global, t_modes, J, N, R, T, hyper_params, device = "cpu", rec_traj = False, B = 1):
    """Run the model selected by ``model_id`` for T steps and keep the best energies.

    Args:
        p_global: mode coefficients, reshaped to (P, T_D, B, R). The parameters
            used at step ti are p[:, ti] = sum_d p_global[:, d] * t_modes[d, ti].
        t_modes: (T_D, T) temporal basis, e.g. from ``get_t_modes``.
        J: (B, N, N) coupling matrices.
        N: number of spins. R: size of the trailing run axis (presumably
            independent replicas -- confirm). T: number of steps. B: batch size.
        hyper_params: reads "T_D", "T_in", "is_MIS" and "is_MC" here, plus
            whatever the step function reads. When "is_MIS" is true,
            hyper_params["aux_spin"] is replaced in place by an int32 tensor
            on ``device``.
        device: torch device the simulation runs on.
        rec_traj: if true, also record z[0, :, 0] at every step.

    Returns:
        (E_opt, E_hist), or (E_opt, E_hist, traj_rec) with rec_traj, all on the
        CPU. E_opt (B, R) is the lowest energy seen per run, E_hist (T, R) the
        energies of batch element 0 at each step, and traj_rec is (T, N).

    Side effects:
        Sets the module globals current_obj / current_obj_max (mean / max over
        the R axis of the objective, per batch element) and prints timing and
        objective summaries.

    Raises:
        ValueError: if model_id is below -2 (no step function exists for it).
    """
    global current_obj, current_obj_top_10, current_obj_max

    # Resolve the step function once, so an unsupported id fails fast instead
    # of leaving ``z`` unbound inside the loop.
    if model_id >= 0:
        step_fn = step
    elif model_id == -1:
        step_fn = step_CAC
    elif model_id == -2:
        step_fn = step_CAC2
    else:
        raise ValueError(f"unsupported model_id: {model_id!r}")

    J = J.to(device)
    T_D = hyper_params["T_D"]
    T_in = hyper_params["T_in"]

    # Expand mode coefficients into per-step parameters of shape (P, T, B, R).
    # The basis follows p_global's device so the product never mixes devices.
    p_global = p_global.reshape(-1, T_D, B, R)
    basis = torch.as_tensor(t_modes, device = p_global.device)
    p = torch.sum(p_global.reshape(-1, 1, T_D, B, R)*basis.T.reshape(1, -1, T_D, 1, 1), axis = 2)
    p = p.to(device)

    h_hist = torch.zeros(B, T_in, N, R, device = device)
    if model_id == -2:
        # step_CAC2 reads its amplitudes from slot 1: start them at random values.
        h_hist[:,1,:,:] = torch.randn(B, N, R, device = device)

    E_opt = torch.zeros(B, R, device = device) + 10**10
    S_opt = torch.zeros(B, R, device = device)
    E_hist = torch.zeros(T, R, device = device)
    traj_rec = torch.zeros(T, N, device = device)

    is_MIS = hyper_params["is_MIS"]
    if is_MIS:
        # as_tensor does not copy when an earlier call already converted
        # aux_spin, and avoids torch.tensor's copy-construct warning.
        hyper_params["aux_spin"] = torch.as_tensor(hyper_params["aux_spin"], device = device, dtype = torch.int32)
        aux_spin_idx = hyper_params["aux_spin"]
        # Only spins with index below the auxiliary spin count towards the set.
        s_mask = (torch.arange(N, device = device).reshape(1, N) < aux_spin_idx.reshape(B, 1)).reshape(B, N, 1)
        # Couplings with the auxiliary spin's row and column zeroed. J_graph is
        # identical at every step, so build it once instead of once per step.
        J_graph = J.clone()
        J_graph[range(B), :, aux_spin_idx] = 0
        J_graph[range(B), aux_spin_idx, :] = 0

    tstart = time.time()

    for ti in range(T):
        h_hist, z, E = step_fn(p[:,ti,:], h_hist, J, N, R, B, float(ti)/T, hyper_params, device = device)

        if is_MIS:
            # Gauge-fix so the auxiliary spin reads +1, then map spins to 0/1
            # membership restricted to the masked spins.
            s = torch.sign(z)
            s = s*s[range(B), aux_spin_idx, :].reshape(B, 1, R)
            s_plus = (1 + s)/2
            s_plus = s_plus*s_mask

            # Only configurations with vio <= 0 are credited (presumably "no
            # edge inside the set"; this depends on J's sign convention).
            vio = torch.sum(s_plus*torch.matmul(J_graph, s_plus), axis = 1)
            indep_flag = vio <= 0

            S_current = torch.sum(s_plus, axis = 1)*indep_flag
            S_opt = torch.maximum(S_opt, S_current)

        if rec_traj:
            traj_rec[ti, :] = z[0, :, 0]

        E_opt = torch.minimum(E_opt, E)
        E_hist[ti, :] = E[0,:]

    print("alg time", time.time() - tstart, "s")
    current_obj = torch.mean(-E_opt, axis = 1)
    current_obj_max = torch.max(-E_opt, axis = 1).values
    if hyper_params["is_MC"]:
        # (sum_ij J_ij - s^T J s) / 4: the cut value, assuming J is the
        # symmetric edge-weight matrix.
        offset = torch.sum(J, axis = (1,2))
        obj = (offset.reshape(B,1) - E_opt)/4
        current_obj = torch.mean(obj, axis = 1)
        current_obj_max = torch.max(obj, axis = 1).values
        print("mcut", current_obj)
    if is_MIS:
        print("mean set size", torch.mean(S_opt, axis = 1))
        print("max set size", torch.max(S_opt, axis = 1).values)
        current_obj = torch.mean(S_opt, axis = 1)
        current_obj_max = torch.max(S_opt, axis = 1).values
    if rec_traj:
        return E_opt.cpu(), E_hist.cpu(), traj_rec.cpu()

    return E_opt.cpu(), E_hist.cpu()



    


    


