import torch

# def update_vis(counter="total"):
#     if "init_complete" not in state or not state["init_complete"]:
#         return False
#     step_every = state["visualization_step_"+counter] if "visualization_step_"+counter in state else state["visualization_step_total"]
#     counter_i = state[counter+"_i"] if counter+"_i" in state else state["total_i"]
#     return step_every == 1 or not counter_i % int(step_every)

def plot_pca(state, event, input, bias, plot_id):
    """Scatter-plot ``input`` projected onto its two dominant singular directions.

    The flattened input ``X`` is projected onto the top-2 left singular vectors
    of ``X^T`` (equivalently, the top-2 right singular vectors of ``X``).
    ``bias`` is projected onto the same directions. Each point is labelled by
    which side of the projected bias it falls on along each direction:
    ``label = 2 * [c0 < b0] + [c1 < b1] + 1``, giving labels 1..4. The projected
    bias is appended as one extra point with label 5.

    NOTE(review): ``X`` is not mean-centred before the SVD, so this is not
    classical PCA. The directions are those of the raw data. Confirm this is
    intended.

    Args:
        state: mapping whose ``"vis"`` entry exposes a visdom-style
            ``scatter(X, Y, win=..., opts=...)`` method.
        event: object whose ``flatten(input)`` is assumed to return a 2-D
            ``(N, D)`` tensor. TODO confirm against the event implementation.
        input: tensor passed to ``event.flatten``.
        bias: tensor of length ``D``. This is required because
            ``bias[None] @ U[:, :2]`` needs ``bias`` to match X's column count.
        plot_id: used to build the plot window name and title.

    The flattened input must have at least 2 rows and 2 columns. Otherwise
    fewer than two singular directions exist and indexing component 1 fails.
    """
    k = 2
    # Plotting only: detach so the SVD/projections are not tracked by autograd
    # and the result can be converted to NumPy for the plotting backend.
    X = event.flatten(input).detach()
    U, _, _ = torch.svd(torch.t(X))
    directions = U[:, :k]
    C = torch.mm(X, directions)
    b = torch.mm(bias[None].detach(), directions)
    labels = (C[:, 0] < b[:, 0]).type(torch.int) * 2 + (C[:, 1] < b[:, 1]).type(torch.int) + 1

    # Add the projected bias as an extra point with its own label (5).
    C = torch.cat((C, b))
    labels = torch.cat((labels, torch.tensor([5], device=labels.device, dtype=labels.dtype)))

    size = 10
    # Convert both the points and the labels to host NumPy arrays, so the
    # plotting backend never receives a CUDA tensor or one that requires grad.
    state["vis"].scatter(
        C.cpu().numpy(),
        labels.cpu().numpy(),
        win="pca_layer" + str(plot_id),
        opts={"title": "pca_" + str(plot_id), "markersize": size},
    )


