from pathlib import Path
import argparse
import json
import math
import os
import random
import signal
import subprocess
import sys
import time
import numpy as np
import pandas as pd
import torch.multiprocessing as mp
from torch import nn, optim
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.distributed as dist
import random
from model import SparseCL
from util import Logger, LARS, LossManager
import config
import util
from dataloader import ImageNetTransform

class AverageMeter(object):
    """Track the most recent value and a running weighted average of a metric.

    Attributes:
        name: Label printed in front of the average by ``__str__``.
        fmt: Format spec (for example ``':6.3f'``) applied to the average.
        val: Last finite value passed to ``update``.
        sum: Weighted sum of every accepted value.
        count: Total weight accepted so far.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Fold ``val`` into the statistics with weight ``n``.

        Infinite and NaN values are ignored so that one diverged step does
        not poison the running average.
        """
        if math.isinf(val) or math.isnan(val):
            return
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. update(x, n=0) on a fresh meter) used to
        # raise ZeroDivisionError; report an average of 0 instead.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        # Only the running average is displayed, not the latest value.
        fmtstr = '{name} {avg' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    """Print a tab-separated progress line built from a list of meters.

    Each line starts with ``prefix`` followed by ``[current/total]`` and then
    the ``str()`` of every meter, in the order they were given.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.prefix = prefix
        self.meters = meters
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)

    def display(self, batch):
        """Print the progress line for batch index ``batch``."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        columns = [header, *map(str, self.meters)]
        print('\t'.join(columns))

    def _get_batch_fmtstr(self, num_batches):
        """Return a template such as ``'[{:3d}/100]'`` for ``num_batches``.

        The current-batch field is padded to the width of the total so the
        printed lines stay aligned.
        """
        width = len(str(num_batches // 1))
        slot = f'{{:{width}d}}'
        return '[{}/{}]'.format(slot, slot.format(num_batches))

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` and optionally keep a "best" copy.

    Args:
        state: Object handed to ``torch.save`` (typically a dict of state).
        is_best: When true, the saved file is also copied to
            ``model_best.pth.tar`` in the current working directory.
        filename: Destination path of the checkpoint.
    """
    # Imported here because the module never imports shutil at file level;
    # without this the is_best branch raised NameError.
    import shutil

    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

def main_worker(args):
    """Run distributed training of a SparseCL model on one process/GPU.

    The function initialises the NCCL process group, wraps the model in
    DistributedDataParallel, then trains for ``args.epochs`` epochs. On the
    process with ``args.local_rank == 0`` it prints progress, writes per-epoch
    loss statistics to a CSV file and saves a checkpoint after every epoch.

    Args:
        args: Parsed configuration namespace. Fields read directly here are
            ``local_rank``, ``lr``, ``batch_size``, ``weight_decay``,
            ``momentum``, ``eta``, ``train_data_dir``, ``workers``,
            ``base_momentum``, ``epochs``, ``saver_dir``, ``save_name_pre``
            and ``checkpoint_dir``. The namespace is also passed on to
            ``SparseCL``, ``ImageNetTransform`` and the ``util`` schedule
            helpers. ``args.world_size`` and ``args.lr`` are overwritten.

    Raises:
        KeyError: If the ``WORLD_SIZE`` environment variable is not set.
        AssertionError: If ``args.batch_size`` is not divisible by the
            world size.
    """
    # Bind this process to its GPU and join the process group.
    torch.cuda.set_device(args.local_rank)
    args.world_size = int(os.environ["WORLD_SIZE"])
    torch.distributed.init_process_group(backend='nccl')

    # Build the model, switch BatchNorm layers to SyncBatchNorm, move it to
    # this process's GPU and wrap it for distributed data-parallel training.
    model = SparseCL(args)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.cuda(args.local_rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])

    # Scale the configured learning rate linearly with batch_size / 256.
    args.lr = args.lr * args.batch_size / 256
    optimizer = LARS(model.parameters(), args.lr, weight_decay=args.weight_decay, momentum=args.momentum, trust_coefficient=args.eta)
    # NOTE(review): scaler is only referenced by the commented-out AMP code
    # further down; the active training path does not use it.
    scaler = torch.cuda.amp.GradScaler(growth_interval=100)

    train_dataset = torchvision.datasets.ImageFolder(args.train_data_dir, ImageNetTransform(args))
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)

    # args.batch_size is the global batch size; each process loads an equal share.
    assert args.batch_size % args.world_size == 0
    per_device_batch_size = args.batch_size // args.world_size
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=per_device_batch_size, num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)

    # NOTE(review): start_time is never read in this function.
    start_time = time.time()
    start_epoch = 1
    print("Start training...")
    # Per-epoch averages; one entry is appended per epoch on rank 0.
    results = {'total_loss':[], 'alignment_loss': [], 'sparsity_loss':[]}
    # Initial value only; it is recomputed at every training step below.
    momentum_rate = args.base_momentum
    # NOTE(review): min_scale is only used by the commented-out scaler code.
    min_scale = 1024
    for epoch in range(start_epoch, args.epochs+1):
        # Tell the sampler which epoch this is (DistributedSampler uses it to
        # seed its per-epoch shuffling).
        train_sampler.set_epoch(epoch)
        print("epoch:%d, lr:%4f"%(epoch, optimizer.param_groups[0]["lr"]))
        # Epoch accumulators; only updated on rank 0.
        total_loss, alignment_loss, sparsity_loss, total_num = 0.0, 0.0, 0.0, 0

        # Fresh meters every epoch, so the displayed averages are per-epoch.
        batch_time = AverageMeter('Time', ':6.3f')
        learning_rates = AverageMeter('LR', ':2.4f')
        losses = AverageMeter('Loss', ':6.3f')
        alignments = AverageMeter('Alignment', ':6.3f')
        sparsities = AverageMeter('Sparsity', ':6.3f')
        progress = ProgressMeter(len(train_loader), [batch_time, learning_rates, losses, alignments, sparsities], prefix="Epoch: [{}/{}]".format(epoch, args.epochs))
        model.train()
        end = time.time()
        iters_per_epoch = len(train_loader)
        # Each batch unpacks as a pair of inputs plus a label that is ignored
        # (presumably two augmented views from ImageNetTransform - confirm).
        for step, ((x1, x2), _) in enumerate(train_loader):
            # Both schedules receive the fractional epoch position, counted
            # from 0 at the first step of epoch 1.
            lr = util.adjust_learning_rate(optimizer, (epoch-1) + step / iters_per_epoch, args)
            learning_rates.update(lr)
            momentum_rate = util.adjust_moco_momentum((epoch-1) + step / iters_per_epoch, args)

            x1 = x1.cuda(args.local_rank, non_blocking=True)
            x2 = x2.cuda(args.local_rank, non_blocking=True)

            #with torch.cuda.amp.autocast(True):
            # The model returns the total loss together with an object that
            # exposes alignment_loss and sparsity_loss.
            loss, loss_pack = model(x1, x2, momentum_rate)

            # Meters are weighted by the per-process batch size.
            losses.update(loss.item(), x1.size(0))
            alignments.update(loss_pack.alignment_loss.item(), x1.size(0))
            sparsities.update(loss_pack.sparsity_loss.item(), x1.size(0))

            optimizer.zero_grad()
            #scaler.scale(loss).backward()
            #scaler.unscale_(optimizer)
            #scaler.step(optimizer)
            #scaler.update()
            # Rank 0 logs the two loss components at every step.
            if args.local_rank == 0:
                print("alignment_loss:"+str(loss_pack.alignment_loss.item())+"  sparsity_loss:"+str(loss_pack.sparsity_loss.item()))
            #    print("scale:"+str((scaler._scale).item())+"  alignment_loss:"+str(loss_pack.alignment_loss.item())+"  sparsity_loss:"+str(loss_pack.sparsity_loss.item()))

            #if scaler._scale < min_scale:
            #    scaler._scale = torch.tensor(min_scale).to(scaler._scale)

            # Backward pass, gradient-norm clipping at 1.0, then the update.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            # Epoch statistics use the loss values read on rank 0, weighted by
            # the global batch size.
            if args.local_rank == 0:
                total_num += args.batch_size
                total_loss += loss.item() * args.batch_size
                alignment_loss += loss_pack.alignment_loss.item() * args.batch_size
                sparsity_loss += loss_pack.sparsity_loss.item() * args.batch_size

            # Print the meter summary every 10 steps on rank 0.
            if (step+1) % 10 == 0 and args.local_rank==0:
                progress.display(step+1)

        if args.local_rank ==0:
            # NOTE(review): total_num stays 0 if the loader yields no batches,
            # which would make these divisions fail.
            results['total_loss'].append(total_loss/total_num)
            results['alignment_loss'].append(alignment_loss/total_num)
            results['sparsity_loss'].append(sparsity_loss/total_num)
            # The whole history is rewritten to the CSV after every epoch.
            data_frame = pd.DataFrame(data=results, index=range(1, epoch + 1))
            data_frame.to_csv('{}/{}_statistics.csv'.format(args.saver_dir, args.save_name_pre), index_label='epoch')

            # The path does not include the epoch, so each epoch's checkpoint
            # overwrites the previous one.
            checkpoint_path = os.path.join(args.checkpoint_dir, 'checkpoint-'+args.save_name_pre+'.pth.tar')

            save_checkpoint({
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer' : optimizer.state_dict(),
            }, is_best=False, filename=checkpoint_path)
 
if __name__ == '__main__':
    # Parse the run configuration and replace stdout with the project Logger.
    args = config.parse_arg()
    sys.stdout = Logger(args)
    dict_args = vars(args)

    # Only the rank-0 process echoes the configuration, one option per line.
    if args.local_rank == 0:
        for option, setting in dict_args.items():
            print("{0}: {1}".format(option, setting))

    main_worker(args)

