import argparse

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import DataSet
from models import model


def test(fusion_model, loader):
    """Evaluate top-1 classification accuracy of the fusion model.

    The model is switched to eval mode and left there; callers that keep
    training afterwards must call ``fusion_model.train()`` themselves.

    Args:
        fusion_model: Module called as ``fusion_model(audio, video)`` and
            returning a 3-tuple whose last element holds per-class scores
            (the predicted class is taken with ``argmax`` over dim 1).
        loader: Iterable of batches, each a mapping with ``'video'``,
            ``'audio'`` and ``'label'`` tensors. Any batch size is accepted.

    Returns:
        Fraction of correctly classified samples as a float, or ``0.0`` when
        the loader yields no samples.
    """
    fusion_model.eval()

    # Send inputs to wherever the model lives; fall back to CUDA (the
    # previous hard-coded behaviour) for a model without parameters.
    first_param = next(fusion_model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    correct = 0
    total = 0
    # No gradients are needed for evaluation; skipping them avoids building
    # an autograd graph for every sample.
    with torch.no_grad():
        for data in loader:
            video = data['video'].to(device)
            audio = data['audio'].to(device)
            label = data['label'].long().to(device)
            _, _, y_hat = fusion_model(audio, video)
            # Count per sample rather than per batch so the result is a true
            # accuracy for any batch size.
            correct += (y_hat.argmax(1) == label).sum().item()
            total += label.numel()

    acc = correct / total if total else 0.0
    print("ACC:{}".format(acc))
    return acc


# Despite its name this helper is not a pytest test; opt out of collection in
# case the module is ever imported from a test file.
test.__test__ = False

if __name__ == "__main__":
    # Entry point: trains the fusion model with cross-entropy on the labels
    # plus a lam-weighted MSE term that pulls its per-modality features
    # toward those produced by the single-modality models loaded from
    # audio.pt and video.pt. Accuracy on the test set is printed once before
    # training and again after every epoch.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--nb-frames', type=int, default=32, help='frames of each video')
    parser.add_argument('--train_dataset', type=str, default='Train/')
    parser.add_argument('--test_dataset', type=str, default='Test/')

    parser.add_argument('--batch-size', type=int, default=24)
    parser.add_argument('--epoch', type=int, default=400)
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate, defaults to 1e-3')
    parser.add_argument('--lam', type=float, default=50)

    parser.add_argument('--nb-workers', type=int, default=16, help='Number of workers for dataloader.')
    parser.add_argument('--nb-class', type=int, default=309, help='Number of class for dataset.')
    args = parser.parse_args()

    fusion_model = model.FusionModel(args.nb_class).cuda()
    audio_model = model.AudioModel(args.nb_class).cuda()
    video_model = model.VideoModel(args.nb_class).cuda()
    # NOTE(review): torch.load unpickles the checkpoint files; only point
    # these paths at trusted files.
    audio_model.load_state_dict(torch.load("audio.pt"))
    video_model.load_state_dict(torch.load("video.pt"))

    # The single-modality models are never optimised here: they stay in eval
    # mode and are only called under torch.no_grad() to produce targets.
    audio_model.eval()
    video_model.eval()

    AVDataset = DataSet.AudioVideoDataset(args.train_dataset, num_frames=args.nb_frames)

    AVDataloader = DataLoader(AVDataset, batch_size=args.batch_size, num_workers=args.nb_workers, shuffle=True, drop_last=False, pin_memory=True)
    print("DataSet:{}, DataLoader:{}".format(len(AVDataset), len(AVDataloader)))
    print("lambda:{}".format(args.lam))
    test_ADataset = DataSet.AudioVideoTestDataset(args.test_dataset, num_frames=args.nb_frames)
    # Accuracy does not depend on sample order, so the evaluation loader is
    # not shuffled. Batch size stays at 1 for evaluation.
    test_ADataloader = DataLoader(test_ADataset, batch_size=1, num_workers=args.nb_workers, shuffle=False,
                                  drop_last=False, pin_memory=True)

    optimizer = torch.optim.Adam(fusion_model.parameters(), lr=args.lr)

    # Single source of truth for the LR decay interval: used by the scheduler
    # and echoed in the per-epoch log line.
    lr_step_size = 5
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_step_size, gamma=0.1)

    ce_loss = nn.CrossEntropyLoss()
    mse_loss = nn.MSELoss()

    lam = args.lam
    num_batches = len(AVDataloader)

    # Accuracy before any training step; skipped when no epochs are requested.
    if args.epoch > 0:
        test(fusion_model, test_ADataloader)

    for epoch in range(args.epoch):
        # test() leaves the model in eval mode, so re-enable training mode.
        fusion_model.train()
        running_ce_loss = 0.
        running_mse_loss = 0.

        for data in AVDataloader:
            video, audio, label = data['video'].cuda(), data['audio'].cuda(), data['label'].long().cuda()
            # Feature targets come from the fixed single-modality models; no
            # gradients are tracked for them.
            with torch.no_grad():
                v_f_target, _ = video_model(video)
                a_f_target, _ = audio_model(audio)

            a_f, v_f, y_hat = fusion_model(audio, video)
            ce = ce_loss(y_hat, label)
            # Both modalities contribute to the feature-matching term.
            mse = lam * (mse_loss(a_f, a_f_target) + mse_loss(v_f, v_f_target))
            loss = ce + mse
            running_ce_loss += ce.item()
            running_mse_loss += mse.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Losses are reported as the mean over the batches of this epoch.
        print("full distill, step_size:{},lam:{}, epoch:{}, ce_loss:{}, mse_loss:{}".format(
            lr_step_size, lam, epoch, running_ce_loss/num_batches, running_mse_loss/num_batches)
        )

        scheduler.step()
        test(fusion_model, loader=test_ADataloader)
