from datetime import datetime
import argparse
import os

import torch
from torchvision import models
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import *

from datasets import *
from scratch_models import VGG16


def train(gpu, args):
    """Train and validate a ResNet-34 on one GPU as part of a DDP job.

    Written as the per-process entry point for
    ``torch.multiprocessing.spawn``: each spawned process drives one local
    GPU, joins the NCCL process group, and trains a
    ``DistributedDataParallel``-wrapped model on its shard of the training
    set.

    Every rank evaluates the complete validation set (the validation loader
    has no distributed sampler). The reported metrics are computed from the
    calling process's own batches only; nothing is aggregated across ranks,
    and only rank 0 prints.

    Args:
        gpu: Index of the local GPU this process uses; also the process's
            offset within its node when computing the global rank.
        args: Parsed command-line namespace. Reads ``nr``, ``gpus``,
            ``world_size``, ``batch_size``, ``imagenette``, ``num_workers``
            and ``epochs``.

    Raises:
        NotImplementedError: If no dataset flag (``args.imagenette``) is set.
    """
    # Global rank = node index * GPUs per node + local GPU index.
    rank = args.nr * args.gpus + gpu
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=args.world_size,
        rank=rank,
    )

    torch.manual_seed(0)

    model = models.resnet34()
    torch.cuda.set_device(gpu)
    model.cuda(gpu)
    batch_size = args.batch_size
    criterion = nn.CrossEntropyLoss().cuda(gpu)
    optimizer = torch.optim.Adam(model.parameters(), 1e-4)

    model = nn.parallel.DistributedDataParallel(model,
                                                device_ids=[gpu])

    # Data loading
    if args.imagenette:
        dataset_name = 'imagenette'
        # NOTE(review): pre_shuffle is a project-specific Imagenette option;
        # it is only enabled when more than one GPU per node is requested.
        if args.gpus > 1:
            train_dataset = Imagenette(image_dir='/data/strategicfreeze/imagenette/train', pre_shuffle=True)
        else:
            train_dataset = Imagenette(image_dir='/data/strategicfreeze/imagenette/train')

        validation_dataset = Imagenette(image_dir='/data/strategicfreeze/imagenette/val')
    else:
        raise NotImplementedError('Please select a dataset flag.')

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=args.world_size,
        rank=rank
    )
    # shuffle must stay False here: ordering is delegated to the sampler.
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        shuffle=False,
        batch_size=batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        sampler=train_sampler
    )
    validation_dataloader = torch.utils.data.DataLoader(
        dataset=validation_dataset,
        shuffle=False,
        batch_size=batch_size,
        num_workers=args.num_workers,
        pin_memory=True
    )

    for epoch in range(args.epochs):
        # The sampler derives its shuffle from the epoch number; without this
        # call every epoch would iterate the data in the same order.
        train_sampler.set_epoch(epoch)

        # Training start
        model.train()
        total = 0
        correct = 0
        training_loss = 0
        for images, labels in train_loader:
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by batch size so the epoch mean is
            # a per-sample average even when the last batch is smaller.
            training_loss += loss.item() * labels.size(0)

            # Accuracy
            predicted = outputs.argmax(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        # Training end

        # Validation start
        model.eval()
        validation_loss = 0
        validation_total = 0
        validation_correct = 0
        with torch.no_grad():
            for images, labels in validation_dataloader:
                images = images.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)

                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, labels)

                # Loss
                validation_loss += loss.item() * labels.size(0)

                # Accuracy
                predicted = outputs.argmax(1)
                validation_total += labels.size(0)
                validation_correct += predicted.eq(labels).sum().item()
        # Validation end

        # Epoch summary, computed once after both loops. An empty loader
        # yields NaN instead of raising ZeroDivisionError.
        nan = float('nan')
        mean_loss = training_loss / total if total else nan
        accuracy = 100 * correct / total if total else nan
        validation_mean_loss = validation_loss / validation_total if validation_total else nan
        validation_accuracy = 100 * validation_correct / validation_total if validation_total else nan
        if rank == 0:
            print(f"Epoch [{epoch + 1}/{args.epochs}], Training loss: {mean_loss:.4f}, Training accuracy: {accuracy:.4f}, Validation loss: {validation_mean_loss:.4f}, Validation accuracy: {validation_accuracy:.4f}")

    # Release the process group on normal completion.
    dist.destroy_process_group()


if __name__ == '__main__':
    # Command-line entry point: parse options, derive the world size and
    # spawn one training process per local GPU.
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N')
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')
    parser.add_argument('-nr', '--nr', default=0, type=int,
                        help='ranking within the nodes')
    # NOTE: with the default of 0 the epoch loop in train() does not execute;
    # pass --epochs explicitly to actually train.
    parser.add_argument('--epochs', default=0, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--imagenette', action='store_true', help='Imagenette Dataset')
    parser.add_argument('--num_workers', required=False, default=1, type=int,
                        help='Number of workers for dataloader')
    parser.add_argument('--batch_size', required=False, default=32, type=int,
                        help='Batch size of data')
    args = parser.parse_args()

    # One process per GPU on every node.
    args.world_size = args.gpus * args.nodes

    # Rendezvous settings for init_method='env://'. Values already present in
    # the environment are respected so that multi-node runs can point every
    # node at the real master; the single-machine defaults apply otherwise.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12353')

    # mp.spawn passes the process index as the first argument (``gpu``).
    mp.spawn(train, nprocs=args.gpus, args=(args,))
