from pathlib import Path
import argparse
import json
import math
import os
import random
import signal
import subprocess
import sys
import time
import numpy as np
import pandas as pd
import torch.multiprocessing as mp
from torch import nn, optim
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.distributed as dist
import random
from torchvision import models, datasets, transforms
from util import Logger, LARS, LossManager, Pack
import config
import util
import torchvision.models as torchvision_models

class AverageMeter:
    """Keep the latest value of a metric together with its weighted running mean.

    Attributes:
        name: label used when the meter is printed.
        fmt: format spec (e.g. ``':f'``) applied to ``val`` and ``avg`` in ``__str__``.
        val: most recent value handed to ``update``.
        sum: accumulated ``val * n`` over all updates since the last reset.
        count: accumulated ``n`` over all updates since the last reset.
        avg: ``sum / count`` as of the last update.
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Set every running statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest value, counted ``n`` times in the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Splice the format spec into both placeholders, then fill them from
        # the instance attributes.
        spec = self.fmt
        template = ''.join(('{name} {val', spec, '} ({avg', spec, '})'))
        return template.format(**vars(self))


def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy (in percent) for each requested k.

    ``output`` is indexed as (sample, class) and ``target`` holds one class
    index per sample. The result is a list with one single-element tensor per
    entry of ``topk``, in the same order.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Indices of the ``largest_k`` best-scoring classes, transposed so that
        # row r holds every sample's (r+1)-th guess.
        ranked = output.topk(largest_k, dim=1, largest=True, sorted=True).indices.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        to_percent = 100.0 / num_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(to_percent)
            for k in topk
        ]

def adjust_learning_rate(optimizer, init_lr, epoch, args):
    """Apply a half-cosine learning-rate schedule to every parameter group.

    The rate equals ``init_lr`` at ``epoch == 0`` and reaches zero at
    ``epoch == args.linear_epochs``.
    """
    phase = math.pi * epoch / args.linear_epochs
    cur_lr = init_lr * 0.5 * (1. + math.cos(phase))
    for group in optimizer.param_groups:
        group['lr'] = cur_lr

def main_worker(args):
    """Train and evaluate a linear classifier on a frozen ResNet-50 encoder.

    Builds a torchvision ResNet-50, freezes every parameter except the final
    ``fc`` layer, loads the ``module.online_encoder.*`` weights (minus its
    ``fc``) from a checkpoint, and trains ``fc`` with SGD under a cosine
    learning-rate schedule using DistributedDataParallel. After each epoch,
    local rank 0 evaluates on the validation set and rewrites a CSV with the
    per-epoch top-1/top-5 accuracy. After the last epoch one extra row holding
    the best top-1/top-5 values is appended (index ``linear_epochs + 1``).

    Args:
        args: configuration namespace. Reads ``local_rank``,
            ``checkpoint_dir``, ``save_name_pre``, ``linear_lr``,
            ``linear_epochs``, ``momentum``, ``train_data_dir``,
            ``val_data_dir``, ``workers`` and ``results``. Overwrites
            ``batch_size`` and sets ``world_size``.

    Raises:
        KeyError: if the ``WORLD_SIZE`` environment variable is not set.
        AssertionError: if the checkpoint leaves any key other than
            ``fc.weight``/``fc.bias`` missing, if the number of trainable
            tensors is not two, or if the global batch size is not divisible
            by the world size.
    """
    # NOTE: the global batch size is fixed here and replaces whatever value
    # ``args.batch_size`` held before.
    args.batch_size = 1024
    device = args.local_rank
    is_main = device == 0
    torch.cuda.set_device(device)
    args.world_size = int(os.environ["WORLD_SIZE"])
    torch.distributed.init_process_group(backend='nccl')

    # Only the classification head stays trainable.
    model = torchvision_models.__dict__["resnet50"]()
    head_keys = {'fc.weight', 'fc.bias'}
    for name, param in model.named_parameters():
        if name not in head_keys:
            param.requires_grad = False

    model.fc.weight.data.normal_(mean=0.0, std=0.01)
    model.fc.bias.data.zero_()

    checkpoint_path = os.path.join(args.checkpoint_dir, 'checkpoint-'+args.save_name_pre+'.pth.tar')
    print("checkpoint_path:", checkpoint_path)
    # NOTE(review): torch.load unpickles the file, so the checkpoint must come
    # from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location="cpu")['state_dict']

    encoder_prefix = "module.online_encoder."
    for k in list(state_dict.keys()):
        # Keep only the online encoder's weights, excluding its own fc layer,
        # and strip the prefix so the keys match the bare ResNet.
        if k.startswith('module.online_encoder') and not k.startswith('module.online_encoder.fc'):
            state_dict[k[len(encoder_prefix):]] = state_dict[k]
        # Drop the original key, whether it was renamed or is unused.
        del state_dict[k]

    msg = model.load_state_dict(state_dict, strict=False)
    assert set(msg.missing_keys) == head_keys, \
        "checkpoint must supply every weight except fc: missing %s" % msg.missing_keys

    # Learning rate scaled linearly with the global batch size (base 256).
    init_lr = args.linear_lr * args.batch_size / 256

    model.cuda(device)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device])
    criterion = nn.CrossEntropyLoss().cuda(device)

    parameters = [p for p in model.parameters() if p.requires_grad]
    assert len(parameters) == 2, "expected only fc.weight and fc.bias to be trainable"
    optimizer = torch.optim.SGD(parameters, init_lr, momentum=args.momentum, weight_decay=0)

    assert args.batch_size % args.world_size == 0, \
        "global batch size must be divisible by the world size"
    per_device_batch_size = args.batch_size // args.world_size
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(args.train_data_dir, transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    val_dataset = torchvision.datasets.ImageFolder(args.val_data_dir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    # Training data is sharded across ranks by the sampler; the validation
    # loader is unsharded because only local rank 0 consumes it.
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=per_device_batch_size, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=args.workers, pin_memory=True)

    best_acc = argparse.Namespace(top1=0, top5=0)
    print("Start evaluation...")
    train_loss = LossManager()
    results = {'Acc@1': [], 'Acc@5': []}
    save_name = os.path.join(args.results, args.save_name_pre+"_"+str(args.linear_lr)+'_linear.csv')
    for epoch in range(1, args.linear_epochs+1):
        # Eval mode is used even while training the head: the backbone is
        # frozen, and eval mode keeps BatchNorm running statistics unchanged.
        model.eval()
        train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, init_lr, epoch-1, args)
        if is_main:
            print("epoch:%d, lr:%4f"%(epoch, optimizer.param_groups[0]["lr"]))
        for step, (x, target) in enumerate(train_loader):
            output = model(x.cuda(device, non_blocking=True))
            loss = criterion(output, target.cuda(device, non_blocking=True))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss.add_loss(Pack(loss=loss))

            if is_main and (step+1)%10==0:
                print(train_loss.pprint(window=20, prefix='Train Epoch: [{}/{}] Iters:[{}/{}]'.format(epoch, args.linear_epochs, step+1, len(train_loader))))

        train_loss.clear()

        # Evaluation (the model is still in eval mode from the top of the loop).
        if is_main:
            # Only this rank validates, so call the wrapped module directly:
            # a forward pass through the DDP wrapper may launch collective
            # operations (buffer broadcast) that the other ranks never join.
            eval_model = model.module
            top1 = AverageMeter('Acc@1')
            top5 = AverageMeter('Acc@5')
            with torch.no_grad():
                for images, target in val_loader:
                    output = eval_model(images.cuda(device, non_blocking=True))
                    acc1, acc5 = accuracy(output, target.cuda(device, non_blocking=True), topk=(1, 5))
                    top1.update(acc1[0].item(), images.size(0))
                    top5.update(acc5[0].item(), images.size(0))

            # Rewrite the whole CSV each epoch so partial results survive a crash.
            results['Acc@1'].append(top1.avg)
            results['Acc@5'].append(top5.avg)
            data_frame = pd.DataFrame(data=results, index=range(1, epoch+1))
            data_frame.to_csv(save_name, index_label='epoch')
            best_acc.top1 = max(best_acc.top1, top1.avg)
            best_acc.top5 = max(best_acc.top5, top5.avg)
            print("acc1:"+str(top1.avg)+"  acc5:"+str(top5.avg)+"  best_top1:"+str(best_acc.top1)+"  best_top5:"+str(best_acc.top5))

    if is_main:
        # Final row (index linear_epochs + 1) stores the best accuracies seen.
        results['Acc@1'].append(best_acc.top1)
        results['Acc@5'].append(best_acc.top5)
        data_frame = pd.DataFrame(data=results, index=range(1, args.linear_epochs + 2))
        data_frame.to_csv(save_name, index_label='epoch')
 
if __name__ == '__main__':
    # Script entry point: parse the configuration, echo it once, then run.
    args = config.parse_arg()
    dict_args = vars(args)

    # Only local rank 0 prints the settings so the log is not duplicated
    # by every process.
    if args.local_rank == 0:
        for key, value in dict_args.items():
            print("{0}: {1}".format(key, value))

    main_worker(args)
