import torch
import torchvision
import numpy
import time

import tqdm
from apex import amp
from einops import einops
from torch import nn
# import matplotlib
# import matplotlib.pyplot as plt
from torch.autograd import Variable

from models.Baseline import MelSpectrogramResNet, preprocess
from models.R2plus1D import R2Plus1DClassifier
# from transformer_resnet import ResNet20
from dataset import read_vggsound_a
from TinyImageNet import load_tinyimagenet
from logger import MetricsLogger
import os

# Make only GPU index 0 visible to CUDA in this process.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})


def main():
    """Train an R(2+1)D classifier on the VGGSound data with apex mixed precision.

    Every ``log_every`` optimizer steps, four values are passed to the
    ``MetricsLogger``:
      - the running training accuracy of the current epoch,
      - the test accuracy computed by ``eval``,
      - the global step,
      - the mean training loss over the batches since the previous log line.

    After each epoch the model weights are saved to ``Path<k>.ckpt``, where
    ``k`` is the number of completed epochs. The logger is closed even if
    training raises.
    """
    batch_size = 64
    n_epochs = 300
    lr = 0.001
    momentum = 0.9
    weight_decay = 1e-4
    num_classes = 309
    log_every = 100  # optimizer steps between evaluation / log lines

    data_loader_train, data_loader_test = read_vggsound_a(batch_size)

    print(torch.cuda.is_available())

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger = MetricsLogger(log_file="logs/ResNet_Baseline_B.log")
    try:
        model = R2Plus1DClassifier(num_classes=num_classes, channel=1,
                                   layer_sizes=(2, 2, 2, 2)).to(device)
        cost = nn.CrossEntropyLoss().to(device)

        optimizer = torch.optim.SGD(model.parameters(), lr,
                                    momentum=momentum,
                                    weight_decay=weight_decay)

        # apex O1 mixed precision; the backward pass goes through amp.scale_loss below.
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

        # The scheduler is stepped with an accuracy (higher is better), so it
        # must run in "max" mode. The default "min" mode would treat every
        # accuracy gain as a lack of improvement.
        # NOTE(review): patience=1000 exceeds n_epochs=300, so the LR is never
        # actually reduced. Confirm this is intended.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="max", patience=1000, factor=0.1)

        print(model)
        step = 0  # global optimizer-step counter across all epochs
        # These accumulators are reset only when a log line is written. The
        # logged mean therefore always covers exactly the batches since the
        # previous log, even when an epoch boundary falls in between.
        loss_since_log = 0.0
        batches_since_log = 0
        model.train()
        since = time.time()
        for epoch in range(n_epochs):
            print("Epoch {}/{}".format(epoch + 1, n_epochs))
            epoch_correct = 0
            epoch_seen = 0
            for x, labels in tqdm.tqdm(data_loader_train):
                step += 1
                x, labels = x.to(device), labels.to(device)
                x = preprocess(x)
                x = einops.rearrange(x, '(C B) T H W -> B C T H W', C=1)
                outputs = model(x)
                loss = cost(outputs, labels)

                loss_since_log += loss.item()
                batches_since_log += 1
                pred = outputs.argmax(dim=1)
                epoch_seen += labels.size(0)
                epoch_correct += (pred == labels).sum().item()

                if step % log_every == 0:
                    train_acc = 100 * epoch_correct / epoch_seen
                    test_acc = eval(model, data_loader_test, device)
                    model.train()  # eval() switches the model to eval mode
                    logger.log_metrics(train_acc, test_acc, step,
                                       loss_since_log / batches_since_log, epoch + 1)
                    loss_since_log = 0.0
                    batches_since_log = 0

                optimizer.zero_grad()
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()

            # Save after training, so Path<k>.ckpt holds the weights after k
            # completed epochs. This includes the final trained model.
            torch.save(model.state_dict(), 'Path' + str(epoch + 1) + '.ckpt')

            epoch_acc = 100 * epoch_correct / epoch_seen if epoch_seen else 0.0
            scheduler.step(epoch_acc)

        time_used = time.time() - since
        logger.log_direct('Time: {:.0f}m {:.0f}s'.format(time_used // 60, time_used % 60))
    finally:
        logger.close()


def eval(model, data_loader_test, device):
    """Return the top-1 accuracy, in percent, of ``model`` on ``data_loader_test``.

    NOTE: this function shadows the builtin ``eval``. The name is kept because
    callers use it.

    The model is switched to eval mode and left there. Callers that continue
    training must call ``model.train()`` themselves.

    Before the forward pass, each batch goes through ``preprocess``. It is then
    rearranged from ``(C*B, T, H, W)`` to ``(B, C, T, H, W)`` with ``C=1``,
    matching the training loop in ``main``. Gradients are disabled throughout.

    Raises:
        ValueError: if the loader yields no samples, because accuracy is
            undefined in that case.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x_test, label_test in data_loader_test:
            x_test, label_test = x_test.to(device), label_test.to(device)
            x_test = preprocess(x_test)
            x_test = einops.rearrange(x_test, '(C B) T H W -> B C T H W', C=1)
            # Inside no_grad there is no autograd history, so `.data` is unnecessary.
            pred = model(x_test).argmax(dim=1)
            total += label_test.size(0)
            correct += (pred == label_test).sum().item()
    if total == 0:
        raise ValueError("eval(): data_loader_test yielded no samples")
    return 100 * correct / total


if __name__ == '__main__':
    main()
