import sys
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import data
import functools
import PIL
from torchvision import transforms

from classifier import Classifier
from logger import Logger

from time import time


# Use the first CUDA GPU when available; otherwise fall back to the CPU so the
# script still runs (slowly) on machines without CUDA instead of crashing at .to().
device = "cuda:0" if torch.cuda.is_available() else "cpu"



class video_transforms(object):
    """Callable that turns a sequence of frames into a normalized video tensor.

    Each frame is converted with ``PIL.Image.fromarray`` (presumably HxWx3 uint8
    arrays -- TODO confirm against ``data.VideoDataset``), optionally cropped,
    converted with ``to_tensor`` and normalized with mean/std 0.5 per channel
    (for 8-bit frames ``to_tensor`` yields [0, 1], so this maps to [-1, 1]).
    Frames are stacked and permuted so the result is (channels, time, height, width).
    """

    def __init__(self, augmentation=False):
        # When True, __call__ applies a random resized crop to every frame.
        self.augmentation = augmentation

    def __call__(self, video):
        """Transform ``video`` (a sequence of frames) into a (C, T, H, W) tensor.

        With augmentation enabled, a single crop box (scale 0.8-1.0 of the frame
        area, aspect ratio 3/4-4/3) is sampled from the first frame and applied to
        every frame, so the crop is consistent over time. Each cropped frame is
        resized back to its own original height and width.
        """
        frames = []
        if self.augmentation:
            # RandomResizedCrop is the non-deprecated name of RandomSizedCrop.
            top, left, height, width = transforms.RandomResizedCrop.get_params(
                PIL.Image.fromarray(video[0]), scale=(0.8, 1.0), ratio=(3. / 4., 4. / 3.))
        for frame in video:
            img = PIL.Image.fromarray(frame)
            if self.augmentation:
                # PIL's Image.size is (width, height) but resized_crop expects
                # (height, width); passing img.size transposed non-square frames.
                img = transforms.functional.resized_crop(
                    img, top, left, height, width, (img.height, img.width))  # Augmentation
            tensor = transforms.functional.to_tensor(img)
            tensor = transforms.functional.normalize(tensor, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            frames.append(tensor)

        # Stacked frames are (T, C, H, W); permute to (C, T, H, W).
        return torch.stack(frames).permute(1, 0, 2, 3)


def load_data(data_folder, video_batch, video_length=10, shuffle=True, augmentation=True,
              num_workers=2, drop_last=True):
    """Build a DataLoader of video clips read from ``data_folder``.

    Args:
        data_folder: directory given to ``data.VideoFolderDataset``; its ``cache``
            argument is set to ``<data_folder>/local.db``.
        video_batch: number of clips per batch.
        video_length: forwarded to ``data.VideoDataset`` (presumably frames per
            clip -- confirm there).
        shuffle: reshuffle the clips every epoch.
        augmentation: enable the random resized crop in ``video_transforms``.
        num_workers: number of DataLoader worker processes (previously fixed at 2).
        drop_last: drop the final incomplete batch. Defaults to True (the previous
            hard-coded behaviour); pass False for evaluation so no clips are skipped.

    Returns:
        A ``torch.utils.data.DataLoader`` over the transformed clips.
    """
    dataset = data.VideoFolderDataset(data_folder, cache=os.path.join(data_folder, 'local.db'))
    # NOTE(review): the meaning of the positional ``2`` depends on the signature of
    # data.VideoDataset (possibly a frame-sampling step) -- confirm there.
    video_dataset = data.VideoDataset(dataset, video_length, 2, video_transforms(augmentation=augmentation))
    return DataLoader(video_dataset, batch_size=video_batch, drop_last=drop_last,
                      num_workers=num_workers, shuffle=shuffle)


def _train_one_epoch(net, dataloader, criterion, optimizer):
    """Run one optimization pass of ``net`` over ``dataloader``.

    Args:
        net: model called as ``net(inputs)``; outputs hold class scores along dim 1.
        dataloader: iterable of dicts with ``'images'`` and ``'categories'`` tensors.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``net``'s parameters.

    Returns:
        (accuracy, mean per-batch loss) over the batches seen. Both are NaN when
        ``dataloader`` yields no batches, instead of raising ZeroDivisionError.
    """
    net.train()
    correct = 0
    total = 0
    running_loss = 0.0
    num_iters = 0
    # Loop variable is ``batch`` (not ``data``) so the project ``data`` module is not shadowed.
    for batch in dataloader:
        inputs = batch['images'].to(device)
        labels = batch['categories'].to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Metrics need no autograd history.
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        num_iters += 1

    if num_iters == 0:
        return float('nan'), float('nan')
    return correct / total, running_loss / num_iters


def train(logger):
    """Train a 6-class ``Classifier`` on the mug64 video split, logging every epoch.

    Loads clips from the relative paths ``mug64_split/train`` and
    ``mug64_split/test``, trains with Adam and cross-entropy for ``num_epochs``
    epochs, evaluates on the test split after each epoch, and overwrites
    ``<logger.logdir>/model.pkl`` with the latest weights. The logger is closed
    when training finishes or raises.

    Args:
        logger: project ``Logger``; this function uses its ``log_string``,
            ``log_scalar_train``, ``log_scalar_eval``, ``logdir`` and ``close``.
    """
    num_epochs = 500
    batch_size = 32
    batch_size_test = 30
    learning_rate = 1e-4
    weight_decay = 1e-5

    try:
        # Model
        net = Classifier(n_channels=3, num_class=6)
        net = net.to(device)

        # Loss and optimizer (constant learning rate; no LR scheduler is used).
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)

        # Data
        logger.log_string("Loading dataset...")
        train_loader = load_data('mug64_split/train', video_batch=batch_size, video_length=10,
                                 shuffle=True, augmentation=True)
        # NOTE(review): load_data defaults to drop_last=True, so evaluation skips the
        # final partial test batch.
        test_loader = load_data('mug64_split/test', video_batch=batch_size_test, video_length=10,
                                shuffle=False, augmentation=False)
        # len(DataLoader) counts batches, not samples; log both.
        logger.log_string('Train samples %d, batches %d' % (len(train_loader.dataset), len(train_loader)))
        logger.log_string('Test samples %d, batches %d' % (len(test_loader.dataset), len(test_loader)))

        model_path = os.path.join(logger.logdir, 'model.pkl')
        for epoch in range(1, num_epochs + 1):
            tic = time()

            # Train, then eval and metrics
            acc_train, loss_train = _train_one_epoch(net, train_loader, criterion, optimizer)
            acc_eval, loss_eval = evaluate(net, test_loader, criterion)
            logger.log_string('Epoch %d: train loss: %.3f, eval loss: %.3f, train acc: %.3f, eval acc: %.3f, time: %.3f' % (epoch,
                  loss_train, loss_eval, acc_train, acc_eval, time() - tic))
            logger.log_scalar_train('Loss', loss_train, epoch)
            logger.log_scalar_train('Accuracy', acc_train, epoch)
            logger.log_scalar_eval('Loss', loss_eval, epoch)
            logger.log_scalar_eval('Accuracy', acc_eval, epoch)

            # Save the latest weights (overwritten every epoch).
            torch.save(net.state_dict(), model_path)

        logger.log_string('Finished Training')
    finally:
        # Close the logger even if loading or training raises.
        logger.close()
    

def evaluate(net, dataloader, criterion):
    """Compute accuracy and mean per-batch loss of ``net`` over ``dataloader``.

    The network is put in eval mode and gradients are disabled. It is left in
    eval mode on return, so callers that keep training must call ``net.train()``.

    Args:
        net: model called as ``net(inputs)``; outputs hold class scores along dim 1.
        dataloader: iterable of dicts with ``'images'`` and ``'categories'`` tensors.
        criterion: loss called as ``criterion(outputs, labels)``.

    Returns:
        (accuracy, mean per-batch loss). Both are NaN when ``dataloader`` yields
        no batches, instead of raising ZeroDivisionError.
    """
    net.eval()
    num_iters = 0
    correct = 0
    total = 0
    running_loss = 0.0
    with torch.no_grad():
        # Loop variable is ``batch`` (not ``data``) so the project ``data`` module is not shadowed.
        for batch in dataloader:
            inputs = batch['images'].to(device)
            labels = batch['categories'].to(device)
            outputs = net(inputs)
            running_loss += criterion(outputs, labels).item()

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            num_iters += 1

    if num_iters == 0:
        return float('nan'), float('nan')
    return correct / total, running_loss / num_iters


if __name__ == "__main__":
    # The log directory may be given as the first CLI argument; otherwise use the default.
    logdir = sys.argv[1] if len(sys.argv) > 1 else 'logs/default'
    logger = Logger(logdir)
    train(logger)