import argparse

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import DataSet
from models import model


def test(audio_model, video_model, linear_cls, loader):
    linear_cls.eval()
    running_acc = 0
    count = 0
    for i, data in enumerate(loader):
        count += 1
        video, audio, label = data['video'].cuda(), data['audio'].cuda(), data['label'].long().cuda()
        v_f, _ = video_model(video)
        a_f, _ = audio_model(audio)

        y_hat = linear_cls(torch.cat((v_f, a_f), 1))
        running_acc += (y_hat.argmax(1) == label).float().item()
    print("Test ACC:{}".format(running_acc / count))
    return running_acc / count

def _parse_args():
    """Parse the command-line options for the fine-tuning run."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--nb-frames', type=int, default=32, help='frames of each video')
    parser.add_argument('--train_dataset', type=str, default='Train/')
    parser.add_argument('--test_dataset', type=str, default='Test/')

    parser.add_argument('--batch-size', type=int, default=24)
    parser.add_argument('--step_size', type=int, default=1)

    parser.add_argument('--epoch', type=int, default=20)
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate, defaults to 1e-3')
    parser.add_argument('--nb-workers', type=int, default=16, help='Number of workers for dataloader.')
    parser.add_argument('--nb-class', type=int, default=309, help='Number of class for dataset.')
    return parser.parse_args()


def main():
    """Fine-tune a linear classifier on frozen audio and video features.

    Loads the audio and video models from ``audio.pt`` / ``video.pt``, keeps
    them in eval mode, and trains only ``linear_cls`` on the concatenation of
    their features.  Test accuracy is reported once before training and again
    after every epoch.
    """
    args = _parse_args()

    # The input width must match the concatenated video + audio feature size.
    # NOTE(review): 1024 presumably means 512 per modality - confirm against
    # models.model.
    linear_cls = nn.Sequential(
        nn.Linear(1024, args.nb_class)
    ).cuda()

    audio_model = model.AudioModel(args.nb_class).cuda()
    video_model = model.VideoModel(args.nb_class).cuda()

    # Stage checkpoints on the CPU: load_state_dict copies the values into the
    # CUDA-resident parameters, so loading does not depend on the GPU index the
    # checkpoint was saved from and no second GPU copy is allocated.
    audio_model.load_state_dict(torch.load("audio.pt", map_location="cpu"))
    video_model.load_state_dict(torch.load("video.pt", map_location="cpu"))

    # The feature extractors are never trained here; keep them in eval mode.
    audio_model.eval()
    video_model.eval()

    train_set = DataSet.AudioVideoDataset(args.train_dataset, num_frames=args.nb_frames)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=args.nb_workers,
                              shuffle=True, drop_last=False, pin_memory=True)
    print("DataSet:{}, DataLoader:{}".format(len(train_set), len(train_loader)))

    test_set = DataSet.AudioVideoTestDataset(args.test_dataset, num_frames=args.nb_frames)
    # Accuracy does not depend on sample order, so evaluation is not shuffled.
    test_loader = DataLoader(test_set, batch_size=1, num_workers=args.nb_workers, shuffle=False,
                             drop_last=False, pin_memory=True)

    # Only the linear head is optimized.
    optimizer = torch.optim.Adam(linear_cls.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.1)
    criterion = nn.CrossEntropyLoss()

    # Baseline accuracy of the untrained head, reported once before training.
    if args.epoch > 0:
        test(audio_model, video_model, linear_cls, test_loader)

    for epoch in range(args.epoch):
        # test() leaves the head in eval mode, so re-enable training each epoch.
        linear_cls.train()
        running_loss = 0.
        count = 0
        correct = 0
        for data in train_loader:
            video, audio, label = data['video'].cuda(), data['audio'].cuda(), data['label'].long().cuda()
            # No gradients are needed through the frozen feature extractors.
            with torch.no_grad():
                v_f, _ = video_model(video)
                a_f, _ = audio_model(audio)
            y_hat = linear_cls(torch.cat((v_f, a_f), 1))
            correct += y_hat.max(1)[1].eq(label).sum().item()

            loss = criterion(y_hat, label)

            running_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            count += 1

        print("full set finetune step_size:{}, epoch:{}, loss:{}, running_acc:{}".format(
            args.step_size, epoch, running_loss / count, correct / len(train_set)))

        test(audio_model, video_model, linear_cls, loader=test_loader)
        scheduler.step()


if __name__ == "__main__":
    main()


