import torch
from torch import nn


from Utils.utils import cal_dist_gaussian
from Utils.metrics import LS_loss_new


# %% set default args
def deploy_args(
    fname,
    plot_everthing=False,
    plot_flag=False,
    save_model=False,
    save_period=None,
    peroid_loss=1,
    period_graph=50,
    num_epochs=1000,
    num_neighbors=5,
    top_k_iter=100,
    manual_flag=True,
    Batch_flag=False,
    pretrain_flag=False,
    pretrained_model="",
):
    """Collect run settings into a flat dictionary consumed by ``train``.

    Parameter names are kept as-is (including the historical spellings
    ``plot_everthing`` and ``peroid_loss``) for caller compatibility. Note the
    renamings between parameters and keys: ``peroid_loss`` -> 'period_loss',
    ``num_neighbors`` -> 'num_neighbours', ``top_k_iter`` -> 'num_iter'.

    Keys such as 'lr' and 'fea_dim' are NOT produced here; callers must add
    them before calling ``train`` if it needs them.

    Returns:
        dict mapping setting names to the given values.
    """
    return {
        'fname': fname,
        # visualization / checkpointing settings
        'plot_everthing': plot_everthing,
        'plot_flag': plot_flag,
        'period_loss': peroid_loss,
        'period_graph': period_graph,
        'save_model': save_model,
        'save_period': save_period,
        # learning settings
        'num_epochs': num_epochs,
        'num_neighbours': num_neighbors,
        'num_iter': top_k_iter,
        'manual_flag': manual_flag,
        'Batch_flag': Batch_flag,
        # pretraining settings
        'pretrain_flag': pretrain_flag,
        'pretrained_model': pretrained_model,
    }


# %% Training: FS with top-k
def train(net, train_data,devices, args, X_ori = None, toy = False):
    """Fit ``net`` with Adam by minimising ``LS_loss_new`` for a fixed number of epochs.

    Args:
        net: module whose forward pass returns ``(temperature, FS_mat, S, X_new)``.
            If the returned ``S`` is ``None``, the graph computed from the input
            data by ``cal_dist_gaussian`` is used in its place.
        train_data: ``(X, y)`` pair; only ``X`` is used. ``X`` must support
            ``.to(device)`` (presumably a ``torch.Tensor`` -- TODO confirm).
        devices: sequence of devices; only ``devices[0]`` is used.
        args: dict providing 'pretrain_flag', 'pretrained_model', 'num_epochs',
            'lr' and 'num_neighbours' (``deploy_args`` supplies all but 'lr').
        X_ori: unused; kept for interface compatibility.
        toy: if True, also return the last-epoch ``FS_mat``.

    Returns:
        ``(I, S)``, or ``(I, S, FS_mat)`` when ``toy`` is True. ``I`` is the sum
        of squares of the last-epoch ``FS_mat`` over its last dimension, with a
        trailing size-1 dimension; ``S`` is the graph used in the last epoch.

    Raises:
        ValueError: if ``args['num_epochs'] < 1``; no forward pass would run,
            so there would be nothing to return.
    """
    X, _ = train_data
    pretrained_flag = args['pretrain_flag']
    pretrained_model = args['pretrained_model']
    num_epochs = args['num_epochs']
    lr = args['lr']
    num_neighbour = args['num_neighbours']
    device = devices[0]

    # Fail early instead of hitting an UnboundLocalError at the return below.
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    if not pretrained_flag:
        # Only the weights of exact Linear/Conv2d modules are re-initialised;
        # biases and other module types keep their defaults.
        def init_weights(m):
            if type(m) in (nn.Linear, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
        net.apply(init_weights)
    else:
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # load on this host. NOTE: torch.load unpickles the file -- only load
        # trusted checkpoints.
        net.load_state_dict(torch.load(pretrained_model, map_location=device))

    net = net.to(device)

    # The optimizer is built after .to(device) so it tracks the moved parameters.
    trainer = torch.optim.Adam(net.parameters(), lr)

    # Fallback graph, computed from X before it is moved to the device.
    S_ori = cal_dist_gaussian(X, num_neighbours=num_neighbour)
    S_ori = torch.tensor(S_ori).to(device)
    X = X.to(device).to(torch.float32)

    for _ in range(num_epochs):
        net.train()
        trainer.zero_grad()
        _temperature, FS_mat, S, X_new = net(X)
        # Computed before the optimizer step so it reflects this forward pass.
        I = (FS_mat ** 2).sum(dim=-1).unsqueeze(-1)

        if S is None:
            S = S_ori
        loss = LS_loss_new(X_new, S)
        loss.backward()
        trainer.step()

    if toy:
        return I, S, FS_mat
    return I, S