import argparse
import numpy as np
import torchvision.datasets as dset

from click import argument
import wandb
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision
import copy


from lightly.data import LightlyDataset
from lightly.data import SimCLRCollateFunction
from lightly.models.utils import deactivate_requires_grad

from byol import BYOL, BYOL_IN100
from simsiam import SimSiam
from imagenette_dataset import ImageNetteDataset, ImageNetDataset, ImageNetDatasetDepth

from utils import accuracy, AverageMeter, ProgressMeter, Summary
from eval import generate_embeddings, knn_predict

# Weights & Biases destination used by wandb.init in main() when
# --log_results is enabled. Both values were intentionally blanked to
# preserve anonymity; fill them in before running with logging turned on.
PROJECT_NAME=""
ENTITY_NAME=""
def _str2bool(value):
    """Convert a command-line string such as ``"False"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` maps every non-empty string (including
    ``"False"``) to ``True``; this converter accepts the usual spellings.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value))


def argument_parser(argv=None):
    """Build the command-line parser for linear evaluation / fine-tuning and parse it.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse parse ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('-a', '--arch', default="resnet18", type=str,
                        help='Architecture of backbone')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=100, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=256, type=int,
                        metavar='N',
                        help='mini-batch size (default: 256), this is the total '
                        'batch size of all GPUs on the current node when '
                        'using Data Parallel or Distributed Data Parallel')
    parser.add_argument('--lr', '--learning-rate', default=0.5, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                        help='use pre-trained model')
    parser.add_argument('--world-size', default=-1, type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--rank', default=-1, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--seed', default=1, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--gpu', default=0, type=int,
                        help='GPU id to use.')
    parser.add_argument('--multiprocessing-distributed', action='store_true',
                        help='Use multi-processing distributed training to launch '
                        'N processes per node, which has N GPUs. This is the '
                        'fastest way to use PyTorch for either single node or '
                        'multi node data parallel training')
    parser.add_argument('--root_path', default='../../imagenette2-160', type=str,
                        help='Dataset root containing train/ and val/ folders')
    parser.add_argument('--aug_root_path', default=None, type=str,
                        help='Name of checkpoint')
    parser.add_argument('--checkpoint_path', default=None, type=str,
                        help='Path of the pretrained SSL checkpoint to evaluate')

    parser.add_argument('--drop_depth', default=0.0, type=float,
                        help='To Normalize Input')
    parser.add_argument('--use_log_depth', action='store_true',
                        help='To Normalize Input')
    parser.add_argument('--resize_depth', action='store_true',
                        help='Resize depth image to 160')
    parser.add_argument('--depth_0_255', action='store_true',
                        help='Resize depth image to 160')
    parser.add_argument('--use_pfm', action='store_true',
                        help='Resize depth image to 160')
    parser.add_argument('--renormalize_depth', action='store_true',
                        help='Resize depth image to 160')

    parser.add_argument('--depth_path', default="../../DPT/imagenet-100-depth-map", type=str,
                        help='Path of depth map')
    parser.add_argument('--dataset', default='imagenette', type=str,
                        help='Dataset to be used')
    parser.add_argument('--num_classes', default=10, type=int,
                        help='Number of classes (overridden for imagenette / imagenet-100)')
    parser.add_argument('--simclr_transforms', action='store_true',
                        help='To Apply SimCLR transforms or not')
    parser.add_argument('--val_transforms', action='store_true',
                        help='To Apply SimCLR transforms or not')
    parser.add_argument('--image_size', default=224, type=int,
                        help='Input image size (forced to 224 for imagenet-100)')
    parser.add_argument('--exp_str', default="0", type=str,
                        help='Suffix appended to the run/store name')
    parser.add_argument('--method', default='byol', type=str,
                        help='SSL method of the checkpoint (byol or simsiam)')
    parser.add_argument('--finetune', action='store_true',
                        help='finetune')
    parser.add_argument('--use_amp', action='store_true',
                        help='Use automatic mixed precision during training')
    parser.add_argument('--zero_depth', action='store_true',
                        help='Zero the depth channel when training/validating depth models')
    # type=bool would turn "--log_results False" into True; parse explicitly.
    parser.add_argument('--log_results', default=True, type=_str2bool,
                        help='Log results')
    return parser.parse_args(argv)

class Classifier(torch.nn.Module):
    """Classification head on top of a (pretrained) SSL backbone.

    The backbone output is flattened from dimension 1 and fed to a single
    linear layer. When ``args.finetune`` is False the backbone parameters are
    frozen (linear-probe setting); otherwise they stay trainable.

    Args:
        backbone: feature extractor; after flattening it must yield
            ``_FEATURE_DIMS[args.arch]`` values per sample.
        args: namespace with ``finetune`` (bool), ``arch`` (``"resnet18"`` or
            ``"resnet50"``) and ``num_classes`` (int).

    Raises:
        ValueError: if ``args.arch`` has no known feature width.
    """

    # Width of the flattened backbone features for each supported ResNet.
    _FEATURE_DIMS = {"resnet18": 512, "resnet50": 2048}

    def __init__(self, backbone, args):
        super().__init__()
        # Fail fast instead of leaving ``self.fc`` unset, which would only
        # surface later as an AttributeError inside ``forward``.
        if args.arch not in self._FEATURE_DIMS:
            raise ValueError(
                f"Unsupported architecture {args.arch!r}; expected one of "
                f"{sorted(self._FEATURE_DIMS)}")

        # use the pretrained ResNet backbone
        self.backbone = backbone

        # freeze the backbone for linear probing
        if not args.finetune:
            print("Freezing backbone")
            deactivate_requires_grad(backbone)

        # linear layer for the downstream classification task
        self.fc = nn.Linear(self._FEATURE_DIMS[args.arch], args.num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for the batch ``x``."""
        features = self.backbone(x).flatten(start_dim=1)
        return self.fc(features)


def _checkpoint_run_info(checkpoint_path):
    """Split a pretraining checkpoint path into ``(run_name, epoch_tag)``.

    ``run_name`` is the name of the directory holding the checkpoint.
    ``epoch_tag`` is the file name without its last extension, without a
    trailing ``.pth`` and without a leading ``model_``; e.g.
    ``runs/exp/model_100.pth.tar`` -> ``("exp", "100")``.
    """
    head, tail = os.path.split(checkpoint_path)
    epoch_tag = os.path.splitext(tail)[0]
    # Remove an exact suffix/prefix: str.rstrip/lstrip strip *character sets*,
    # which would mangle names such as "model_best" into "bes".
    if epoch_tag.endswith(".pth"):
        epoch_tag = epoch_tag[:-len(".pth")]
    if epoch_tag.startswith("model_"):
        epoch_tag = epoch_tag[len("model_"):]
    return os.path.basename(head), epoch_tag


def _load_pretrained_ssl_model(args, device):
    """Rebuild the self-supervised model described by ``args`` and load its weights.

    The backbone is a torchvision ResNet (``args.arch``) without its last
    child module, followed by an ``AdaptiveAvgPool2d(1)``. When
    ``args.depth_eval`` is True, ``conv1`` is replaced by a convolution with 4
    input channels (RGB + depth). The backbone is wrapped in the SSL model
    selected by ``args.method`` and the checkpoint's ``'state_dict'`` is loaded.

    Returns:
        The SSL model on ``device``; the caller uses its ``backbone`` attribute.

    Raises:
        ValueError: for an unsupported ``args.arch`` or ``args.method``.
    """
    if args.arch == 'resnet18':
        resnet = torchvision.models.resnet18()
    elif args.arch == "resnet50":
        resnet = torchvision.models.resnet50()
    else:
        raise ValueError(f"Unsupported architecture: {args.arch!r}")

    if args.depth_eval:
        resnet.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
    backbone = nn.Sequential(*list(resnet.children())[:-1],
                             nn.AdaptiveAvgPool2d(1))

    if args.method == "byol":
        if args.arch == "resnet18":
            old_model = BYOL_IN100(backbone)
        else:
            # Not imported at module level; imported here so only ResNet-50
            # BYOL runs depend on it.
            # NOTE(review): confirm byol.py defines BYOL_IN100_R50.
            from byol import BYOL_IN100_R50
            old_model = BYOL_IN100_R50(backbone)
    elif args.method == "simsiam":
        old_model = SimSiam(backbone)
    else:
        raise ValueError(f"Unsupported method: {args.method!r}")
    old_model = old_model.to(device)

    # map_location lets GPU-saved checkpoints load on the current device.
    ckpt = torch.load(args.checkpoint_path, map_location=device)
    old_model.load_state_dict(ckpt['state_dict'])
    print("Loaded weights from {}".format(args.checkpoint_path))
    print("kNN Acc with that model :", ckpt['best_acc1'], "at epoch:", ckpt['epoch'])
    return old_model


def _evaluate_corruptions(model, args):
    """Evaluate ``model`` on ImageNet-100-C and ImageNet-100-3DCC and log accuracies.

    ImageNet-100-C uses a fixed list of 19 distortions; 3DCC uses every entry
    of its root folder. For depth models (``args.depth_eval``) both benchmarks
    are evaluated a second time with precomputed depth maps. Per-distortion
    and mean accuracies go to W&B when ``args.log_results`` is set.
    """
    distortions = [
        'gaussian_noise', 'shot_noise', 'impulse_noise',
        'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur',
        'snow', 'frost', 'fog', 'brightness',
        'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression',
        'speckle_noise', 'gaussian_blur', 'spatter', 'saturate'
    ]
    base_folder_inc = "../../imagenet-100-c"
    in_c_accs = calculate_corruption_acc(distortions, model, base_folder_inc, args)

    print(in_c_accs)
    if args.log_results:
        for distortion_name, acc in zip(distortions, in_c_accs):
            wandb.log({distortion_name: acc})
        wandb.log({'avg_corruption_acc': np.mean(in_c_accs)})

    base_folder_3dcc = "../../imagenet-100-3dcc"
    distortions_3d = sorted(os.listdir(base_folder_3dcc))
    in_3dcc_accs = calculate_corruption_acc(distortions_3d, model, base_folder_3dcc, args)
    print(in_3dcc_accs)
    if args.log_results:
        for distortion_name, acc in zip(distortions_3d, in_3dcc_accs):
            wandb.log({distortion_name + "3d": acc})
        wandb.log({'avg_corruption3dcc_acc': np.mean(in_3dcc_accs)})

    if args.depth_eval:
        depth_path_2d = "../../DPT/imagenet-100-c-depth-map"
        depth_path_3d = "../../DPT/imagenet-100-3dcc-depth-map"
        in_3dcc_accs_with_depth = calculate_corruption_acc(
            distortions_3d, model, base_folder_3dcc, args, depth_path_3d)
        in_2d_accs_with_depth = calculate_corruption_acc(
            distortions, model, base_folder_inc, args, depth_path_2d)
        if args.log_results:
            for distortion_name, acc in zip(distortions_3d, in_3dcc_accs_with_depth):
                wandb.log({distortion_name + "3d" + "depth": acc})
            wandb.log({'avg_corruption3dcc_acc_depth': np.mean(in_3dcc_accs_with_depth)})
            for distortion_name, acc in zip(distortions, in_2d_accs_with_depth):
                wandb.log({distortion_name + "depth": acc})
            wandb.log({'avg_corruption_acc_depth': np.mean(in_2d_accs_with_depth)})


def main():
    """Train and evaluate a classifier on top of a pretrained SSL checkpoint.

    Parses arguments and loads ``--checkpoint_path``. It then trains a linear
    probe (or fine-tunes with ``--finetune``), saving checkpoints under
    ``runs/<store_name>``. Finally it evaluates the best model on the
    validation set and on the ImageNet-100-C / 3DCC corruption benchmarks.
    Checkpoints whose path contains ``"depth"`` are treated as 4-channel
    (RGB + depth) models.

    Raises:
        ValueError: if ``--checkpoint_path`` is missing, or for an unsupported
            architecture / SSL method.
    """
    args = argument_parser()
    if args.dataset == "imagenette":
        args.num_classes = 10
    if args.dataset == "imagenet-100":
        args.num_classes = 100
        args.image_size = 224
    if args.checkpoint_path is None:
        # Run naming, depth detection and weights all derive from this path.
        raise ValueError("--checkpoint_path is required")

    args.eval_method = "linear" if not args.finetune else "finetune"
    run_name, epoch_num = _checkpoint_run_info(args.checkpoint_path)

    args.store_name = '_'.join([run_name, str(epoch_num), args.eval_method, str(
        args.epochs), str(args.batch_size), str(args.lr), str(args.seed), args.exp_str])
    print("Store Name:", args.store_name)
    if args.log_results:
        wandb.init(project=PROJECT_NAME, entity=ENTITY_NAME,
                   name=args.store_name)
        wandb.config.update(args)
    if args.seed is not None:
        random.seed(args.seed)
        # Seed NumPy too so any np.random use (e.g. in data pipelines) is reproducible.
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Depth (RGB-D) checkpoints are recognised purely by their path.
    args.depth_eval = "depth" in args.checkpoint_path
    old_model = _load_pretrained_ssl_model(args, device)

    model = Classifier(old_model.backbone, args).to(device)
    # Linear probing only optimises the new head.
    params = model.parameters() if args.finetune else model.fc.parameters()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    val_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    print("Standard train transforms")
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    train_path = os.path.join(args.root_path, "train")
    val_path = os.path.join(args.root_path, "val")
    depth_train_path = os.path.join(args.depth_path, "train")
    depth_val_path = os.path.join(args.depth_path, "val")

    if args.depth_eval:
        train_dataset = ImageNetDatasetDepth(
            train_path, num_views=0, depth_path=depth_train_path,
            transform=train_transforms, train=True, args=args)
        val_dataset = ImageNetDatasetDepth(
            val_path, depth_path=depth_val_path, transform=val_transforms,
            train=False, args=args)
    else:
        train_dataset = ImageNetDataset(train_path, transform=train_transforms)
        val_dataset = ImageNetDataset(val_path, transform=val_transforms)

    print(f"Number of examples in training dataset: {len(train_dataset)}")
    print(f"Number of examples in validation dataset: {len(val_dataset)}")

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, args.epochs)

    save_path = f"runs/{args.store_name}"
    os.makedirs(save_path, exist_ok=True)

    best_acc1 = 0
    zero_depth = bool(args.depth_eval and args.zero_depth)
    print("Zero Depth: ", zero_depth)

    scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
    for epoch in range(args.epochs):
        train(train_loader, model, criterion, optimizer, scaler, epoch, args, zero_depth)

        acc1 = validate(val_loader, model, criterion, args, zero_depth)
        print(f"Epoch {epoch} Val Acc: {acc1}")
        if args.log_results:
            wandb.log({'epoch': epoch, 'acc': acc1})

        scheduler.step()

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if (epoch + 1) % 50 == 0 or is_best:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            }, is_best, filename=f"{save_path}/{str(epoch+1)}.pth.tar")
    print("Best Val Acc: ", best_acc1)
    if args.log_results:
        wandb.log({"best_acc": best_acc1})

    best_ckpt_path = f"{save_path}/model_best.pth.tar"
    if os.path.isfile(best_ckpt_path):
        model.load_state_dict(torch.load(best_ckpt_path, map_location=device)['state_dict'])
        print("Loaded Best Model")
    else:
        # Happens when no epoch beat the initial best_acc1 of 0 (or epochs == 0).
        print(f"No best checkpoint found at {best_ckpt_path}; using final weights")
    model.eval()
    if args.depth_eval:
        acc_without_depth = validate(val_loader, model, criterion, args, zero_depth=True)
        acc_with_depth = validate(val_loader, model, criterion, args, zero_depth=False)
        print("Accc without depth: ", acc_without_depth)
        print("Accc with depth: ", acc_with_depth)

    _evaluate_corruptions(model, args)

    
def calculate_corruption_acc(distortions, model, base_folder, args, depth_path=None):
    """Evaluate ``model`` on each corruption in ``distortions`` and collect accuracies.

    Each corruption is scored by ``show_performance`` (mean over its available
    severities). The per-corruption error and accuracy are printed, followed by
    the mean (unnormalized) corruption error.

    Returns:
        List of mean accuracies, in the same order as ``distortions``.
    """
    per_distortion_errors = []
    per_distortion_accs = []
    for name in distortions:
        err, acc = show_performance(name, model, base_folder, args, depth_path)
        per_distortion_errors.append(err)
        per_distortion_accs.append(acc)
        print('Distortion: {:15s}  | CE (unnormalized) (%): {:.2f} | Acc {:.2f}'.format(
            name, 100 * err, 100 * acc))
    mean_error_pct = 100 * np.mean(per_distortion_errors)
    print('mCE (unnormalized by AlexNet errors) (%): {:.2f}'.format(mean_error_pct))
    return per_distortion_accs


def show_performance(distortion_name, model, base_folder, args, depth_path=None):
    """Evaluate ``model`` on every available severity (1-5) of one corruption.

    Images are read from ``<base_folder>/<distortion_name>/<severity>``;
    severities whose folder is missing are skipped.

    - With ``depth_path``: ``ImageNetDatasetDepth`` pairs each image with the
      depth map under ``<depth_path>/<distortion_name>/<severity>``.
    - Without it: a plain ``ImageFolder`` is used (resize 256, center crop
      224). For depth models (``args.depth_eval``) an all-zero 4th channel is
      appended.

    The model should already be in eval mode; this function does not change
    its mode.

    Returns:
        ``(mean error rate, mean accuracy)`` over the evaluated severities.
        Both are NaN (with a NumPy warning) if no severity folder exists.
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    errs = []
    accs = []
    base_path = os.path.join(base_folder, distortion_name)
    base_path_depth = os.path.join(depth_path, distortion_name) if depth_path else None
    # Run on whatever device holds the model instead of assuming CUDA.
    device = next(model.parameters()).device

    # Transforms do not depend on the severity, so build them once.
    depth_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    image_transform = transforms.Compose([
        transforms.Resize(256), transforms.CenterCrop(224),
        transforms.ToTensor(), transforms.Normalize(mean, std)])

    for severity in range(1, 6):
        severity_dir = os.path.join(base_path, str(severity))
        if not os.path.isdir(severity_dir):
            print(f"Skipping {severity}")
            continue
        if depth_path:
            distorted_dataset = ImageNetDatasetDepth(
                severity_dir, depth_path=os.path.join(base_path_depth, str(severity)),
                transform=depth_transform, train=False, args=args)
        else:
            distorted_dataset = dset.ImageFolder(root=severity_dir, transform=image_transform)

        distorted_dataset_loader = torch.utils.data.DataLoader(
            distorted_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)

        correct = 0
        # Inference only: skip building autograd graphs (matters when fine-tuned
        # parameters still require grad).
        with torch.no_grad():
            for batch in distorted_dataset_loader:
                if depth_path:
                    data, target, _ = batch
                else:
                    data, target = batch
                data = data.to(device)
                if args.depth_eval and depth_path is None:
                    # Depth model on RGB-only data: append an empty depth
                    # channel sized from the batch rather than fixed 224x224.
                    data = torch.cat((data, torch.zeros_like(data[:, :1])), dim=1)

                output = model(data)
                pred = output.max(1)[1]
                correct += pred.eq(target.to(device)).sum().item()

        errs.append(1 - 1. * correct / len(distorted_dataset))
        accs.append(correct / len(distorted_dataset))

    print('\n=Average', tuple(errs))
    print('\n=Average Accuracy ', tuple(accs))
    return np.mean(errs), np.mean(accs)



def train(train_loader, model, criterion, optimizer, scaler, epoch, args, zero_depth=False):
    """Train ``model`` for one epoch over ``train_loader``.

    Args:
        train_loader: yields ``(images, target, _)`` batches; the third item
            is ignored.
        model: classifier being trained.
        criterion: loss applied to ``(output, target)``.
        optimizer: stepped once per batch through ``scaler``.
        scaler: gradient scaler used for the backward pass and optimizer step.
        epoch: epoch index, used only for display.
        args: needs ``finetune``, ``gpu``, ``use_amp``, ``log_results`` and
            ``print_freq``.
        zero_depth: if True, input channel 3 (the depth channel of 4-channel
            inputs) is set to zero before the forward pass.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    losses = AverageMeter('Loss', ':.4e')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # Only fine-tuning uses train mode; linear probing keeps the frozen
    # backbone in eval mode (presumably so BatchNorm statistics stay fixed).
    if args.finetune:
        model.train()
    else:
        model.eval()

    total_loss = 0.0
    end = time.time()
    for i, (images, target, _) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if zero_depth:
            # Works for any spatial resolution (no hard-coded 224x224).
            images[:, 3] = 0
        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast(args.use_amp):
            output = model(images)
            loss = criterion(output, target)

        # One host sync for the scalar loss, reused by the meters, the epoch
        # average and the W&B log.
        loss_value = loss.item()
        total_loss += loss_value
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss_value, images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        if args.log_results:
            # Log plain scalars rather than a graph-attached tensor and a list.
            wandb.log({'loss': loss_value, 'train_acc': float(acc1[0])})

        # compute gradient and do SGD step
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

    # Mean per-batch loss; max(..., 1) guards against an empty loader.
    avg_loss = total_loss / max(len(train_loader), 1)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")


def validate(val_loader, model, criterion, args, zero_depth=False):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    Args:
        val_loader: yields ``(images, target, _)`` batches; the third item is
            ignored.
        model: classifier to evaluate; it is switched to eval mode.
        criterion: loss applied to ``(output, target)`` (tracked for display).
        args: needs ``gpu`` and ``print_freq``.
        zero_depth: if True, input channel 3 (the depth channel of 4-channel
            inputs) is set to zero before the forward pass.

    Returns:
        ``top1.avg``, the running average of the top-1 accuracy reported by
        ``accuracy`` (presumably in percent — confirm in utils).
    """
    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1],
        prefix='Test: ')

    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target, _) in enumerate(val_loader):
            if zero_depth:
                # Works for any spatial resolution (no hard-coded 224x224).
                images[:, 3] = 0
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        progress.display_summary()

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    When ``is_best`` is True the file is also copied to
    ``model_best.pth.tar`` in the same directory.
    """
    torch.save(state, filename)
    if not is_best:
        return
    best_copy = os.path.join(os.path.dirname(filename), 'model_best.pth.tar')
    shutil.copyfile(filename, best_copy)


# Script entry point: run training and evaluation only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()
