import argparse
import os
import time
import blobfile as bf
import torch as th

import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW, SGD

import logging
# NOTE(review): this only creates/overwrites an attribute called `level` on the
# `logging` module object; it does not configure the root logger or any
# handler. If INFO-level log output is wanted, something like
# `logging.basicConfig(level=logging.INFO)` is presumably what was intended.
logging.level = logging.INFO

from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
    add_dict_to_argparser,
    args_to_dict,
    classifier_and_diffusion_defaults,
    model_and_diffusion_defaults,
    create_model_and_diffusion,
)

import random
from guided_diffusion.rn50_classifiers import RN50Classifier
from torchvision.utils import save_image
import torch
from torch import nn
import numpy as np
from torchvision import datasets, transforms, models
import augmentations
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision.utils import save_image
# Override the IMAGE_SIZE attribute of the project-local `augmentations`
# module (224 is also the crop size used by the transforms in main()).
augmentations.IMAGE_SIZE = 224
# Allow cuDNN to auto-tune convolution algorithms for the input shapes it sees.
torch.backends.cudnn.benchmark=True


# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
# Seed the parser with the classifier/diffusion defaults first and the
# model/diffusion defaults second, so the latter win on duplicate keys.
defaults = {}
defaults.update(classifier_and_diffusion_defaults())
defaults.update(model_and_diffusion_defaults())
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)
parser.add_argument(
    '--clean_data',
    type=str,
    default="",
    help='path to clean ImageNet dataset')

parser.add_argument(
    '--resume',
    '-r',
    type=str,
    default='workdirs/rn50_diffbase_scratch.pt',
    help='Checkpoint path for resume / test.')
parser.add_argument(
    '--print-freq',
    type=int,
    default=10,
    help='Training loss print frequency (batches).')

# Acceleration
parser.add_argument(
    '--num-workers',
    type=int,
    default=4,
    help='Number of pre-fetching threads.')

# Fixed diffusion-model flags; they are placed in front of the real command
# line before parsing.
MODEL_FLAGS = "--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
import sys
args = parser.parse_args(MODEL_FLAGS.split() + sys.argv[1:])

# ---------------------------------------------------------------------------
# Distributed setup
# ---------------------------------------------------------------------------
# The presence of LOCAL_RANK in the environment switches the script into
# DistributedDataParallel mode.
USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
if USE_PYTORCH_DDP:
    RANK = int(os.environ['LOCAL_RANK'])
    WORLD_SIZE = int(os.environ['WORLD_SIZE'])
else:
    RANK = 0
    WORLD_SIZE = 1

# The batch size is always overwritten here: 256 split across processes when
# more than one process is running, otherwise 64.
if WORLD_SIZE > 1:
    args.batch_size = 256 // WORLD_SIZE
else:
    args.batch_size = 64

if USE_PYTORCH_DDP:
    dist.init_process_group('nccl')
    torch.cuda.set_device(RANK)
    # Keep a handle on the built-in print, then shadow it module-wide so that
    # only the process with RANK 0 produces output (prefixed with its rank).
    print_fn = print

    def print(*values, **options):
        if RANK == 0:
            print_fn(RANK, *values, **options)


class ConcatDataset(torch.utils.data.Dataset):
    """Index-aligned view over several datasets of identical length.

    Item ``i`` gathers sample ``i`` from every wrapped dataset. The first
    element of each sample is passed through both ``train_transform`` and
    ``da_transform``; the label is taken from the first dataset's sample.

    When ``USE_PYTORCH_DDP`` is set, the indices are shuffled with a fixed
    seed and each process keeps only its own contiguous slice of
    ``len // WORLD_SIZE`` indices (any remainder is dropped).
    """

    def __init__(self, datasets, train_transform, da_transform):
        """Store the datasets/transforms and build this process's index list.

        Args:
            datasets: Sequence of indexable datasets whose items are
                ``(sample, label)`` pairs; all must have the same length.
            train_transform: Callable applied to each sample to produce the
                first list returned by ``__getitem__``.
            da_transform: Callable applied to each sample to produce the
                second list returned by ``__getitem__``.

        Raises:
            ValueError: If ``datasets`` is empty or the lengths differ.
        """
        super().__init__()
        lens = [len(d) for d in datasets]
        # Explicit check instead of `assert`, which is stripped under `-O`
        # and would let mismatched datasets through silently.
        if len(set(lens)) != 1:
            raise ValueError(
                "ConcatDataset needs at least one dataset and all datasets "
                f"must have the same length; got lengths {lens}")
        self.datasets = datasets
        self.all_idxs = list(range(lens[0]))
        self.train_transform = train_transform
        self.da_transform = da_transform
        if USE_PYTORCH_DDP:
            # Same seed on every process, so all processes see the same
            # permutation and the per-process slices do not overlap.
            random.Random(6158).shuffle(self.all_idxs)
            per_device = len(self.all_idxs) // WORLD_SIZE
            self.all_idxs = self.all_idxs[RANK * per_device:(RANK + 1) * per_device]
            print(RANK, WORLD_SIZE, len(self.all_idxs))

    def __getitem__(self, i):
        """Return ``(train_views, da_views, label)`` for local position ``i``.

        ``train_views`` and ``da_views`` are lists with one entry per wrapped
        dataset; ``label`` is the second element of the first dataset's item.
        """
        i = self.all_idxs[i]
        xys = [d[i] for d in self.datasets]
        train_views = [self.train_transform(xy[0]) for xy in xys]
        da_views = [self.da_transform(xy[0]) for xy in xys]
        return train_views, da_views, xys[0][1]

    def __len__(self):
        """Number of indices owned by this process."""
        return len(self.all_idxs)


def accuracy(output, target, topk=(1,)):
    """Return top-k accuracies (in percent) for each k in ``topk``.

    ``output`` holds per-class scores along dimension 1 and ``target`` holds
    the true class index for each row. The result is a list with one
    single-element tensor per requested k, in the order given by ``topk``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        batch_size = target.size(0)

        # Indices of the `largest_k` best-scoring classes per row, transposed
        # so that row j holds every sample's (j+1)-th best guess.
        _, top_idx = output.topk(largest_k, dim=1, largest=True, sorted=True)
        top_idx = top_idx.t()
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

        # A sample counts as correct for k if the target is among its first
        # k guesses, i.e. anywhere in the first k rows of `hits`.
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / batch_size)
            for k in topk
        ]

def main():
    """Train a ResNet-50 on clean and diffusion-denoised ImageNet images.

    Builds the diffusion model from the parsed ``args`` and loads its weights
    from ``workdirs/256x256_diffusion_uncond.pt``, sets up the data loaders,
    and trains the classifier with SGD. Training resumes from ``args.resume``
    when that file exists and stops at the first epoch boundary at which more
    than ``90 * 5000`` optimizer steps have been taken.
    """
    score_model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    schedule_sampler = create_named_schedule_sampler("uniform", diffusion)

    score_model = score_model.cuda(RANK)
    score_model.load_state_dict(
        torch.load('workdirs/256x256_diffusion_uncond.pt', map_location='cpu'))
    score_model.convert_to_fp16()
    score_model = score_model.eval()
    print("loaded score-model")

    # Load datasets
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    da_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    normalize = transforms.Normalize(mean, std)
    # (x - (-1)) / 2 maps the diffusion model's [-1, 1] range back to [0, 1]
    # before the usual mean/std normalisation is applied.
    normalize_denoised = transforms.Compose([
        transforms.Normalize([-1, -1, -1], [2, 2, 2]),
        transforms.Normalize(mean, std),
    ])
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    traindir = os.path.join(args.clean_data, 'train')
    valdir = os.path.join(args.clean_data, 'val')
    # No transform on the ImageFolder: ConcatDataset applies both transforms.
    train_dataset = datasets.ImageFolder(traindir)

    concat_dataset = ConcatDataset([train_dataset],
                                   train_transform=train_transform,
                                   da_transform=da_transform)
    print(args.batch_size)
    # Honour the --num-workers flag instead of a hard-coded single worker.
    train_loader = torch.utils.data.DataLoader(
        concat_dataset,
        batch_size=args.batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=args.num_workers)
    # NOTE(review): val_loader is constructed but never used in this function.
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, test_transform),
        batch_size=args.eval_batch_size,
        shuffle=True,
        num_workers=args.num_workers)

    def denoise_augment(batch):
        """Noise ``batch`` at sampled timesteps and return the model's x0 estimate.

        Timesteps come from ``schedule_sampler``; the batch is noised with
        ``diffusion.q_sample`` and the ``'pred_xstart'`` entry of one
        ``diffusion.p_sample`` call is returned after ``normalize_denoised``.
        """
        t, _ = schedule_sampler.sample(batch.shape[0], batch.device)
        z = torch.randn_like(batch)
        xt = diffusion.q_sample(x_start=batch, t=t, noise=z)
        with torch.autocast("cuda"):
            xstart = diffusion.p_sample(
                score_model,
                xt,
                t,
                clip_denoised=True,
            )['pred_xstart']

        return normalize_denoised(xstart)

    def make_checkpoint(net, optimizer, scheduler, num_steps):
        """Bundle step count and model/optimizer/scheduler state into a dict."""
        return {
            "steps": num_steps,
            "state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        }

    def train_1epoch(net, train_loader, optimizer, scheduler, num_steps):
        """Train for one epoch.

        Args:
            net: Classifier being trained.
            train_loader: Yields ``(images, da_images, targets)`` batches.
            optimizer: Optimizer stepped once per batch.
            scheduler: LR scheduler stepped once per batch.
            num_steps: Number of optimizer steps taken before this epoch.

        Returns:
            Tuple ``(loss_ema, acc1_ema, batch_ema, checkpoint)`` where
            ``checkpoint["steps"]`` is the updated step count.
        """
        data_ema = 0.
        batch_ema = 0.
        loss_ema = 0.
        acc1_ema = 0.
        acc5_ema = 0.

        net.train()
        end = time.time()

        for i, (images, da_images, targets) in enumerate(train_loader):
            # Compute data loading time
            data_time = time.time() - end
            optimizer.zero_grad()

            clean = images[0]
            da_clean = da_images[0]
            with torch.no_grad():
                # Rescale from [0, 1] to [-1, 1] before the diffusion model.
                da_clean = denoise_augment(2 * da_clean.cuda(RANK) - 1)

            images_all = torch.cat([normalize(clean).cuda(RANK), da_clean], dim=0)
            targets = targets.cuda(RANK)

            logits_all = net(images_all)
            logits_clean, logits_da = torch.split(logits_all, clean.size(0))

            # Cross-entropy is only computed on clean/denoised images
            loss = (F.cross_entropy(logits_clean, targets)
                    + F.cross_entropy(logits_da, targets))
            acc1, acc5 = accuracy(logits_clean, targets, topk=(1, 5))  # pylint: disable=unbalanced-tuple-unpacking

            loss.backward()
            optimizer.step()
            scheduler.step()
            # Compute batch computation time and update moving averages.
            batch_time = time.time() - end
            end = time.time()

            # The newest measurement carries weight 0.9 in each average.
            data_ema = data_ema * 0.1 + float(data_time) * 0.9
            batch_ema = batch_ema * 0.1 + float(batch_time) * 0.9
            loss_ema = loss_ema * 0.1 + float(loss) * 0.9
            acc1_ema = acc1_ema * 0.1 + float(acc1) * 0.9
            acc5_ema = acc5_ema * 0.1 + float(acc5) * 0.9

            if i % args.print_freq == 0:
                print(
                    'Batch {}: Data Time {:.3f} | Batch Time {:.3f} | Train Loss {:.3f} | Train Acc1 '
                    '{:.3f} | Train Acc5 {:.3f}'.format(num_steps, data_ema,
                                                        batch_ema, loss_ema, acc1_ema,
                                                        acc5_ema))
            num_steps += 1
            if RANK == 0 and num_steps % 100 == 0:
                # Write to a temporary file and move it into place, so an
                # interruption mid-write cannot corrupt the resume checkpoint.
                tmp_path = args.resume + '.tmp'
                torch.save(make_checkpoint(net, optimizer, scheduler, num_steps),
                           tmp_path)
                os.replace(tmp_path, args.resume)

        checkpoint = make_checkpoint(net, optimizer, scheduler, num_steps)
        return loss_ema, acc1_ema, batch_ema, checkpoint

    net = models.resnet50(weights=None).cuda(RANK)
    # Convert BatchNorm layers BEFORE wrapping, as in the PyTorch
    # convert_sync_batchnorm documentation; state_dict keys are unaffected.
    net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    if USE_PYTORCH_DDP:
        net = DDP(net, device_ids=[RANK], output_device=RANK)
    else:
        net = torch.nn.DataParallel(net)
    num_steps = 0
    optimizer = torch.optim.SGD(
        net.parameters(),
        0.1,
        momentum=0.9,
        nesterov=False,
        weight_decay=0.0001)
    # The scheduler is stepped per batch, so the LR is multiplied by 0.1
    # every 30 * 5000 optimizer steps.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=30 * 5000,
                                                gamma=0.1)
    print(optimizer.param_groups[0]['lr'])
    print(optimizer.param_groups[0]['weight_decay'])

    if os.path.exists(args.resume):
        # Load onto the CPU first: without map_location every process would
        # restore the tensors onto the GPU they were saved from.
        checkpoint = torch.load(args.resume, map_location='cpu')
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        num_steps = checkpoint.get("steps", 0)
        print(f"Loaded model finetuned for {num_steps} steps")

    # The step budget is only checked between epochs.
    while num_steps <= 90 * 5000:
        train_loss_ema, train_acc1_ema, batch_ema, checkpoint = train_1epoch(
            net, train_loader, optimizer, scheduler, num_steps)
        num_steps = checkpoint["steps"]

# Start training only when this file is executed as a script.
if __name__ == "__main__":
   main()
    
