import argparse

from click import argument
import wandb
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision
import copy


from lightly.data import LightlyDataset
from lightly.data import SimCLRCollateFunction
from lightly.loss import NegativeCosineSimilarity
from lightly.models.utils import update_momentum

from simsiam import SimSiam
from byol import BYOL, BYOL_IN100, BYOL_IN100_R50
from imagenette_dataset import ImageNetteDataset, ImageNetDataset

from utils import accuracy, AverageMeter, ProgressMeter, Summary
from eval import generate_embeddings, knn_predict

# The wandb project and entity names are intentionally left blank to preserve
# anonymity; set them to your own values to log runs to a specific project.
PROJECT_NAME = ""
ENTITY_NAME = ""

def _str2bool(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``no`` to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as ``True``; this converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def argument_parser(argv=None):
    """Build the CLI parser for this training script and parse arguments.

    Args:
        argv: optional list of argument strings; ``None`` (the default) parses
            ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with all training options.
    """
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('-a', '--arch', default="resnet18", type=str,
                        help='Architecture of backbone (resnet18 or resnet50)')
    parser.add_argument('--method', default="byol", type=str,
                        help='Self-supervised method to train (byol or simsiam)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=200, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=256, type=int,
                        metavar='N',
                        help='mini-batch size (default: 256), this is the total '
                        'batch size of all GPUs on the current node when '
                        'using Data Parallel or Distributed Data Parallel')
    parser.add_argument('--lr', '--learning-rate', default=0.06, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float,
                        metavar='W', help='weight decay (default: 5e-4)',
                        dest='weight_decay')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')
    parser.add_argument('--world-size', default=-1, type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--rank', default=-1, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--seed', default=1, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--gpu', default=0, type=int,
                        help='GPU id to use.')
    parser.add_argument('--multiprocessing-distributed', action='store_true',
                        help='Use multi-processing distributed training to launch '
                        'N processes per node, which has N GPUs. This is the '
                        'fastest way to use PyTorch for either single node or '
                        'multi node data parallel training')
    parser.add_argument('--root_path', default='../../imagenette2-160', type=str,
                        help='Dataset root directory (must contain train/ and val/)')
    parser.add_argument('--aug_root_path', default=None, type=str,
                        help='Root directory of additional augmented views '
                        '(default: none)')
    parser.add_argument('--dataset', default='imagenette', type=str,
                        help='Dataset to be used (imagenette or imagenet-100)')
    parser.add_argument('--simclr_transforms', action='store_true',
                        help='Apply SimCLR transforms to the training views')
    parser.add_argument('--randomize', action='store_true',
                        help='Randomize the input to online and target networks')
    parser.add_argument('--ordered', action='store_true',
                        help='Views are ordered.')
    parser.add_argument('--val_transforms', action='store_true',
                        help='Use the validation transforms for the training views')

    parser.add_argument('--std_transforms', action='store_true',
                        help='Use standard training transforms')
    parser.add_argument('--image_size', default=128, type=int,
                        help='Input image resolution (crop size)')
    parser.add_argument('--num_views', default=2, type=int,
                        help='Number of views')
    parser.add_argument('--adampi_prob', default=1.0, type=float,
                        help='Probability associated with AdaMPI views (default: 1.0)')
    parser.add_argument('--exp_str', default="0", type=str,
                        help='Suffix appended to the experiment/run name')
    parser.add_argument('--use_amp', action='store_true',
                        help='Use automatic mixed precision training')

    # _str2bool instead of type=bool: bool("False") is True.
    parser.add_argument('--log_results', default=True, type=_str2bool,
                        help='Log results to Weights & Biases (true/false)')
    args = parser.parse_args(argv)
    return args

def slurm_infos():
    """Gather SLURM job metadata from environment variables.

    The result is used as a wandb run summary. Each key has the form
    ``slurm/<field>``; a value is ``None`` when the corresponding environment
    variable is not set (for example when running outside a SLURM job).
    """
    key_to_env = (
        ('slurm/job_id', 'SLURM_JOB_ID'),
        ('slurm/job_user', 'SLURM_JOB_USER'),
        ('slurm/job_partition', 'SLURM_JOB_PARTITION'),
        ('slurm/cpus_per_node', 'SLURM_JOB_CPUS_PER_NODE'),
        ('slurm/num_nodes', 'SLURM_JOB_NUM_NODES'),
        ('slurm/nodelist', 'SLURM_JOB_NODELIST'),
        ('slurm/cluster_name', 'SLURM_CLUSTER_NAME'),
        ('slurm/array_task_id', 'SLURM_ARRAY_TASK_ID'),
    )
    return {summary_key: os.getenv(env_name) for summary_key, env_name in key_to_env}

def main():
    """Pre-train a BYOL or SimSiam model with periodic kNN evaluation.

    Artifacts are written to ``runs/<store_name>``. If that folder already
    exists, the run name gets a ``pre-1`` suffix, and training resumes from the
    old folder's ``latest_model.pth.tar`` when that file is present.

    Raises:
        ValueError: if ``--dataset``, ``--arch`` or ``--method`` is not one of
            the values this script knows how to build.
    """
    args = argument_parser()
    if args.dataset == "imagenette":
        args.num_classes = 10
    elif args.dataset == "imagenet-100":
        args.num_classes = 100
        args.image_size = 224
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")
    # Fail fast, before a wandb run or a run folder is created.
    if args.arch not in ("resnet18", "resnet50"):
        raise ValueError(f"Unsupported architecture: {args.arch!r}")
    if args.method not in ("byol", "simsiam"):
        raise ValueError(f"Unsupported method: {args.method!r}")

    simclr = "simclr" if args.simclr_transforms else "std"
    if args.val_transforms:
        simclr = "val"
    if args.aug_root_path is not None:
        if "adampi" in args.aug_root_path:
            aug_root_folder = "adampi"
            # e.g. imagenette2-160-full-adampi_50_0.4_0.2_0.1 -> "50_0.4_0.2_0.1"
            aug_root_angle = "_".join(os.path.basename(args.aug_root_path).split("_")[1:])
            aug_root_name = aug_root_folder + aug_root_angle
        else:
            aug_root_angle, aug_root_folder = os.path.split(args.aug_root_path)
            aug_root_name = aug_root_folder + aug_root_angle[-4:]
        print(aug_root_folder, aug_root_angle)
    else:
        aug_root_name = "None"
    if args.randomize:
        args.exp_str = "randomize_" + args.exp_str
    if args.ordered:
        args.exp_str = "ordered_" + args.exp_str
    if args.adampi_prob != 1.0:
        args.exp_str = str(args.adampi_prob) + args.exp_str
    args.store_name = '_'.join([
        args.dataset, str(args.arch), args.method, "views", str(args.num_views),
        simclr, aug_root_name, str(args.epochs), str(args.batch_size),
        str(args.lr), str(args.seed), args.exp_str])

    if os.path.isdir(f"runs/{args.store_name}"):
        print("Folder already exists")
        if os.path.exists(f"runs/{args.store_name}/latest_model.pth.tar"):
            args.resume = f"runs/{args.store_name}/latest_model.pth.tar"
        args.store_name += "pre-1"
    print("Store Name:", args.store_name)
    if args.log_results:
        wandb.init(project=PROJECT_NAME, entity=ENTITY_NAME,
                   name=args.store_name)
        wandb.config.update(args)
        wandb.run.summary.update(slurm_infos())
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.dataset == "imagenet-100" and args.image_size != 224:
        print("Training ImageNet-100 without 224!!! Cross-Verify")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.arch == "resnet18":
        resnet = torchvision.models.resnet18()
    else:
        resnet = torchvision.models.resnet50()

    # Drop the classification layer and keep the pooled feature extractor.
    backbone = nn.Sequential(*list(resnet.children())[:-1],
                             nn.AdaptiveAvgPool2d(1))
    if args.method == "byol":
        if args.arch == "resnet18":
            model = BYOL_IN100(backbone).to(device)
        else:
            model = BYOL_IN100_R50(backbone).to(device)
        # Only the online networks are optimized; the momentum copies are
        # updated by update_momentum inside train_byol.
        params = list(model.backbone.parameters()) \
            + list(model.projection_head.parameters()) \
            + list(model.prediction_head.parameters())
        criterion = NegativeCosineSimilarity()
        train = train_byol
    else:
        model = SimSiam(backbone).to(device)
        params = model.parameters()
        criterion = NegativeCosineSimilarity()
        train = train_simsiam
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    val_resize = 256 if args.dataset == "imagenet-100" else 128
    val_transforms = transforms.Compose([
        transforms.Resize(val_resize),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        normalize,
    ])
    if args.simclr_transforms:
        print("Using SimCLR transforms")
        train_transforms = SimCLRCollateFunction(input_size=args.image_size).transform
    elif args.val_transforms:
        print("Using Val Transforms")
        train_transforms = val_transforms
    else:
        print("Standard train transforms")
        train_transforms = transforms.Compose([
            transforms.RandomResizedCrop(args.image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

    train_path = os.path.join(args.root_path, "train")
    val_path = os.path.join(args.root_path, "val")

    if args.dataset == "imagenette":
        train_dataset = ImageNetteDataset(
            train_path, num_views=args.num_views, aug_root_path=args.aug_root_path,
            transform=train_transforms, randomize=args.randomize, ordered=args.ordered)
        std_train_dataset = ImageNetteDataset(
            train_path, transform=val_transforms)
        val_dataset = ImageNetteDataset(
            val_path, transform=val_transforms)
    else:
        # Note transform = none uses the solo-learn transforms
        train_dataset = ImageNetDataset(train_path, aug_root_path=args.aug_root_path,
                                        num_views=args.num_views, transform=None, args=args)
        std_train_dataset = ImageNetDataset(train_path, transform=val_transforms, args=args)
        val_dataset = ImageNetDataset(val_path, transform=val_transforms, args=args)

    print(f"Number of examples in training dataset: {len(train_dataset)}")
    print(f"Number of examples in validation dataset: {len(val_dataset)}")

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    # Un-augmented, unshuffled training set: the kNN memory bank for validate().
    std_train_loader = torch.utils.data.DataLoader(
        std_train_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    args.knn_k = 200
    args.knn_t = 0.1
    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.epochs)
    scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)

    # Initialized BEFORE resuming so the restored value is not overwritten.
    best_acc1 = 0
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            scaler.load_state_dict(checkpoint['scaler'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    save_path = f"runs/{args.store_name}"
    os.makedirs(save_path, exist_ok=True)

    for epoch in range(args.start_epoch, args.epochs):
        print(f"Epoch : {epoch}")
        if args.log_results:
            wandb.log({'epoch': epoch})

        # train for one epoch (the train function switches the model to train mode)
        train(train_loader, model, criterion, optimizer, scaler, epoch, args)

        scheduler.step()
        torch.save(
            _checkpoint_state(epoch, model, optimizer, scheduler, scaler, best_acc1),
            f"{save_path}/latest_model.pth.tar")

        # Evaluate every 20 epochs, and every epoch once args.epochs - epoch < 50.
        if (epoch + 1) % 20 == 0 or (args.epochs - epoch) < 50:
            acc1 = validate(val_loader, std_train_loader, model, criterion, args)
            print(f"Epoch {epoch} kNN Acc: {acc1}")
            if args.log_results:
                wandb.log({'epoch': epoch, 'acc': acc1})

            # remember best acc@1 and save checkpoint
            is_best = acc1 > best_acc1
            best_acc1 = max(acc1, best_acc1)

            if (epoch + 1) % 50 == 0 or is_best:
                save_checkpoint(
                    _checkpoint_state(epoch, model, optimizer, scheduler, scaler, best_acc1),
                    is_best, filename=f"{save_path}/{epoch + 1}.pth.tar")
    print("Best kNN Acc: ", best_acc1)
    if args.log_results:
        wandb.log({"best_acc": best_acc1})


def _checkpoint_state(epoch, model, optimizer, scheduler, scaler, best_acc1):
    """Bundle everything needed to resume training after 0-based ``epoch``."""
    return {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'best_acc1': best_acc1,
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'scaler': scaler.state_dict(),
    }


def train_byol(train_loader, model, criterion, optimizer, scaler, epoch, args):
    """Run one BYOL training epoch over ``train_loader``.

    Each batch is unpacked as ``(images, target, _)``; only ``images[0]`` and
    ``images[1]`` are used as the two views. Before every optimizer step the
    momentum backbone and projection head are updated from the online ones
    via ``update_momentum(..., m=0.99)``. The loss is symmetrized: each view's
    prediction is compared with the other view's momentum projection.

    Args:
        train_loader: loader yielding ``(images, target, _)`` batches.
        model: BYOL model exposing ``backbone``, ``projection_head``, their
            ``*_momentum`` copies and ``forward_momentum``.
        criterion: loss applied to (prediction, momentum projection) pairs.
        optimizer: optimizer over the online parameters.
        scaler: ``GradScaler`` (a no-op when AMP is disabled).
        epoch: epoch index, used for logging only.
        args: namespace providing ``gpu``, ``use_amp``, ``log_results`` and
            ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    total_loss = 0.0
    num_batches = 0

    end = time.time()
    for i, (images, _target, _) in enumerate(train_loader):
        x0, x1 = images[0], images[1]
        if args.gpu is not None:
            x0 = x0.cuda(args.gpu, non_blocking=True)
            x1 = x1.cuda(args.gpu, non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        optimizer.zero_grad()
        update_momentum(model.backbone, model.backbone_momentum, m=0.99)
        update_momentum(model.projection_head,
                        model.projection_head_momentum, m=0.99)

        # compute output
        with torch.cuda.amp.autocast(args.use_amp):
            p0 = model(x0)
            z0 = model.forward_momentum(x0)
            p1 = model(x1)
            z1 = model.forward_momentum(x1)
            loss = 0.5 * (criterion(p0, z1) + criterion(p1, z0))
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # One device->host sync per step, reused for the meter, the running
        # total and wandb (which then receives a plain float, not a tensor).
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1
        losses.update(loss_value, x0.size(0))

        if args.log_results and i % args.print_freq == 0:
            wandb.log({'loss': loss_value})

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
    # max(..., 1) avoids ZeroDivisionError for an empty loader.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")


def train_simsiam(train_loader, model, criterion, optimizer, scaler, epoch, args):
    """Run one SimSiam training epoch over ``train_loader``.

    Each batch is unpacked as ``(images, target, _)``; only ``images[0]`` and
    ``images[1]`` are used as the two views. ``model(x)`` returns a pair
    ``(z, p)``, and the loss is the symmetrized
    ``0.5 * (criterion(z0, p1) + criterion(z1, p0))``.

    Args:
        train_loader: loader yielding ``(images, target, _)`` batches.
        model: SimSiam model whose forward returns ``(z, p)``.
        criterion: loss applied to the cross-view pairs.
        optimizer: optimizer over the model parameters.
        scaler: ``GradScaler`` (a no-op when AMP is disabled).
        epoch: epoch index, used for logging only.
        args: namespace providing ``gpu``, ``use_amp``, ``log_results`` and
            ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    total_loss = 0.0
    num_batches = 0

    end = time.time()
    for i, (images, _target, _) in enumerate(train_loader):
        optimizer.zero_grad()
        x0, x1 = images[0], images[1]
        if args.gpu is not None:
            x0 = x0.cuda(args.gpu, non_blocking=True)
            x1 = x1.cuda(args.gpu, non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        with torch.cuda.amp.autocast(args.use_amp):
            z0, p0 = model(x0)
            z1, p1 = model(x1)
            loss = 0.5 * (criterion(z0, p1) + criterion(z1, p0))
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # One device->host sync per step, reused for the meter, the running
        # total and wandb (which then receives a plain float, not a tensor).
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1
        losses.update(loss_value, x0.size(0))

        if args.log_results and i % args.print_freq == 0:
            wandb.log({'loss': loss_value})

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
    # max(..., 1) avoids ZeroDivisionError for an empty loader.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")


def validate(val_loader, train_loader, model, criterion, args):
    """Return the kNN top-1 accuracy (percent) of ``model.backbone`` on ``val_loader``.

    Embeddings produced by ``generate_embeddings(model, train_loader)`` form
    the kNN bank. Each validation feature is L2-normalized, moved to CPU and
    classified with ``knn_predict`` using ``args.knn_k`` neighbours and
    temperature ``args.knn_t``. A sample counts as correct when the first
    column of the prediction (presumably the top-ranked label) equals its target.

    Args:
        val_loader: loader yielding ``(images, target, _)`` batches to evaluate.
        train_loader: loader whose embeddings form the kNN bank.
        model: model exposing ``backbone``; it is switched to eval mode.
        criterion: unused; kept for interface compatibility.
        args: namespace providing ``gpu``, ``num_classes``, ``knn_k`` and ``knn_t``.

    Returns:
        float: accuracy in percent, or 0.0 if ``val_loader`` yields no samples.
    """
    model.eval()
    embeddings, _, feature_targets = generate_embeddings(
        model, train_loader)
    correct = 0
    total = 0

    with torch.no_grad():
        for images, target, _ in val_loader:
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            # flatten(1) instead of squeeze(): squeeze() would also drop the
            # batch dimension of a final batch of size 1, breaking dim=1 below.
            feature = model.backbone(images).flatten(start_dim=1)
            feature = F.normalize(feature, dim=1).cpu()
            pred_labels = knn_predict(
                feature, embeddings, feature_targets,
                args.num_classes, args.knn_k, args.knn_t)
            correct += (pred_labels[:, 0] == target).float().sum().item()
            total += images.size(0)

    if total == 0:
        return 0.0
    return (correct / total) * 100

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    When ``is_best`` is true, the same state is also written to
    ``model_best.pth.tar`` in the directory containing ``filename``.
    """
    torch.save(state, filename)
    if not is_best:
        return
    best_path = os.path.join(os.path.dirname(filename), 'model_best.pth.tar')
    torch.save(state, best_path)


# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
