import torch.nn as nn
import torch

def _to_model_inputs(x, y, device, invert_flag):
    """Convert one (x, y) batch to float32 / int64 tensors on ``device``, optionally swapping dims 1 and 2 of x."""
    x = torch.as_tensor(x, dtype=torch.float32, device=device)
    y = torch.as_tensor(y, dtype=torch.long, device=device)
    if invert_flag:
        # Swap axes 1 and 2 of a rank-3 input (assumes x is 3-D when this flag is set).
        x = x.permute(0, 2, 1)
    return x, y


def _evaluate_accuracy(model, valid_batch_x, valid_batch_y):
    """Return the fraction of samples whose classifier-logit argmax equals the label.

    Puts ``model`` in eval mode and runs under ``torch.no_grad()``; the caller
    is responsible for switching back to train mode.
    """
    model.eval()
    with torch.no_grad():
        _, cond_vec_test = model(x_enc=valid_batch_x, x_mark_enc=None)
        # Accuracy must be measured on the classifier's logits, not on the
        # raw conditioning vector returned by the encoder.
        logits = model.classifier(cond_vec_test)
        correct = (logits.argmax(dim=-1) == valid_batch_y).sum().item()
    return correct / valid_batch_x.shape[0]


def train_dynamic_embedder_single_letter(args):
    """Train ``args.model`` as a classifier with Adam + cross-entropy, validating periodically.

    Attributes read from ``args``:
        device: target device for the model and all batches.
        train_generator: iterator yielding ``(x, y, n)`` training batches
            (``n`` is ignored).
        valid_generator: indexable pair ``(x, y)`` holding the whole
            validation set, evaluated as a single batch.
        model: module called as ``model(x_enc=..., x_mark_enc=None)`` that
            returns a 2-tuple whose second item is fed to ``model.classifier``.
        learning_rate, weight_decay: Adam hyper-parameters.
        training_step: number of training iterations to run.
        invert_flag: if true, swap dims 1 and 2 of every input batch.
        update_flag: if false, losses are computed but weights are never updated.
        test_per_epoch (optional, default 20): validate every this many steps.

    Prints the validation accuracy at each evaluation and the best accuracy
    at the end. Returns ``None``; no checkpoint of the best model is kept.
    """
    device = args.device
    assert args.train_generator is not None, "args.train_generator is required"
    train_generator = args.train_generator
    assert args.valid_generator is not None, "args.valid_generator is required"
    valid_generator = args.valid_generator
    assert args.model is not None, "args.model is required"
    model = args.model
    model.to(device)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay
    )
    ce_loss_func = nn.CrossEntropyLoss()
    test_per_epoch = getattr(args, "test_per_epoch", 20)

    # The validation set is fixed, so convert and upload it only once.
    valid_batch_x, valid_batch_y = _to_model_inputs(
        valid_generator[0], valid_generator[1], device, args.invert_flag
    )

    best_accuracy = -1.0
    global_step = 0
    while global_step < args.training_step:
        model.train()
        train_batch_x, train_batch_y, _ = next(train_generator)
        train_batch_x, train_batch_y = _to_model_inputs(
            train_batch_x, train_batch_y, device, args.invert_flag
        )

        _, cond_vec = model(x_enc=train_batch_x, x_mark_enc=None)
        ce_loss_train = ce_loss_func(model.classifier(cond_vec), train_batch_y)

        # Validation happens after the training forward pass but before the
        # optimizer step; the training graph stays valid for backward().
        if global_step % test_per_epoch == 0:
            accuracy_tmp = _evaluate_accuracy(model, valid_batch_x, valid_batch_y)
            print('current accuracy: %.4f' % accuracy_tmp)
            best_accuracy = max(best_accuracy, accuracy_tmp)

        if args.update_flag:
            optimizer.zero_grad()
            ce_loss_train.backward()
            optimizer.step()

        global_step += 1

    print("best valid accuracy: %.4f" % best_accuracy)

    return