import argparse

import DataSet
from models import model
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def _model_device(net):
    """Return the device holding ``net``'s parameters (CUDA if it has none)."""
    first_param = next(net.parameters(), None)
    # A parameter-free model gives no hint, so keep the original CUDA default.
    return torch.device("cuda") if first_param is None else first_param.device


def test(net, loader):
    """Evaluate ``net`` on ``loader`` and return the top-1 accuracy.

    Args:
        net: Model called as ``net(audio, video)``; the third element of its
            return value is used as the class scores.
        loader: Iterable of batches, each a mapping with ``'video'``,
            ``'audio'`` and ``'label'`` tensors.

    Returns:
        Fraction of correctly classified samples as a float, or ``0.0`` when
        the loader yields no samples.

    Side effects:
        Switches ``net`` to eval mode (it is not switched back) and prints
        the accuracy.
    """
    net.eval()
    device = _model_device(net)
    correct = 0
    total = 0
    # No gradients are needed for evaluation; skipping them saves memory.
    with torch.no_grad():
        for data in loader:
            video = data['video'].to(device)
            audio = data['audio'].to(device)
            label = data['label'].long().to(device)
            _, _, y_hat = net(audio, video)
            # Count per sample so any batch size works, not just 1.
            correct += (y_hat.argmax(1) == label).sum().item()
            total += label.numel()
    accuracy = correct / total if total else 0.0
    print("Test ACC:{}".format(accuracy))
    return accuracy


# The name matches pytest's default ``test*`` pattern; opt out so importing
# this module into a test file does not collect the helper as a test case.
test.__test__ = False


def main():
    """Train ``model.FusionModel`` and report test accuracy after every epoch.

    Command-line flags select the dataset directories, the number of frames
    per video, batch size, epoch count, learning rate, dataloader worker
    count and the number of classes. Each epoch prints the mean batch loss
    and the training accuracy, then runs ``test`` on the test loader and
    steps the learning-rate scheduler. A CUDA device is required: the model
    and every training batch are moved with ``.cuda()``.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--nb-frames', type=int, default=32, help='frames of each video')
    parser.add_argument('--train_dataset', type=str, default='Train/')
    parser.add_argument('--test_dataset', type=str, default='Test/')

    parser.add_argument('--batch-size', type=int, default=24)
    parser.add_argument('--epoch', type=int, default=400)
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate, defaults to 1e-3')
    parser.add_argument('--nb-workers', type=int, default=16, help='Number of workers for dataloader.')
    parser.add_argument('--nb-class', type=int, default=309, help='Number of class for dataset.')
    args = parser.parse_args()

    net = model.FusionModel(args.nb_class).cuda()

    train_dataset = DataSet.AudioVideoDataset(args.train_dataset, num_frames=args.nb_frames)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.nb_workers,
                              shuffle=True, drop_last=False, pin_memory=True)
    print("DataSet:{}, DataLoader:{}".format(len(train_dataset), len(train_loader)))

    test_dataset = DataSet.AudioVideoTestDataset(args.test_dataset, num_frames=args.nb_frames)
    # Evaluation order cannot change the accuracy, so the test set is not shuffled.
    test_loader = DataLoader(test_dataset, batch_size=1, num_workers=args.nb_workers, shuffle=False,
                             drop_last=False, pin_memory=True)

    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    # NOTE(review): the LR is multiplied by 0.1 every 10 epochs, so with the
    # default 400 epochs it is already 1e-13 by epoch 100 -- confirm intended.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    criterion = nn.CrossEntropyLoss()

    for epoch in range(args.epoch):
        net.train()  # test() leaves the model in eval mode
        running_loss = 0.
        correct = 0
        for data in train_loader:
            video = data['video'].cuda()
            audio = data['audio'].cuda()
            label = data['label'].long().cuda()
            _, _, y_hat = net(audio, video, modal_drop=False)
            loss = criterion(y_hat, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            correct += y_hat.argmax(1).eq(label).sum().item()

        # Loss is averaged over batches, accuracy over samples (drop_last=False,
        # so every training sample is seen once per epoch).
        print("dropout baseline epoch:{}, loss:{}, running_acc:{}".format(
            epoch, running_loss / len(train_loader), correct / len(train_dataset)))

        test(net, loader=test_loader)
        scheduler.step()


if __name__ == "__main__":
    main()

