import argparse
import os
import time
import blobfile as bf
import torch as th

import torch.distributed as dist
import torch.nn.functional as F

import logging
# Configure the root logger so INFO-level records are emitted.  Assigning to
# `logging.level` merely creates an unused attribute on the module and does
# not change any logger's threshold, so use the supported API instead.
logging.basicConfig(level=logging.INFO)


from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
    add_dict_to_argparser,
    args_to_dict,
    classifier_and_diffusion_defaults,
    model_and_diffusion_defaults,
    create_model_and_diffusion,
)

import random
from torchvision.utils import save_image
import torch
from torch import nn
import numpy as np
from torchvision import datasets, transforms, models

from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision.utils import save_image
from collections import defaultdict
from vit_models import deit_base_patch16_LS



# Build the CLI on top of the guided-diffusion defaults so that every model
# and diffusion hyper-parameter can be overridden from the command line.
defaults = dict()
defaults.update(classifier_and_diffusion_defaults())
defaults.update(model_and_diffusion_defaults())
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)

parser.add_argument(
    '--clean_data', type=str, default="", help='path to clean ImageNet dataset')
parser.add_argument(
    '--corrupted_data', type=str, default="", help='path to ImageNet-C dataset')

# The corruption list is split into `num_workers` shards and this process
# evaluates shard `worker_id`.  The defaults describe a single worker that
# handles every corruption; without them, omitting either flag would feed
# `None` into the sharding arithmetic and crash at import time.
parser.add_argument(
    "--worker_id",
    type=int,
    default=0,
    help="index of the corruption shard evaluated by this process",
)
parser.add_argument(
    "--num_workers",
    type=int,
    default=1,
    help="number of shards the corruption list is split into",
)

# Flags describing the diffusion model architecture.  They are placed in
# front of the real command line, so a flag repeated by the user is parsed
# later and takes precedence.
MODEL_FLAGS = (
    "--attention_resolutions 32,16,8 --class_cond False "
    "--diffusion_steps 1000 --image_size 256 --learn_sigma True "
    "--noise_schedule linear --num_channels 256 --num_head_channels 64 "
    "--num_res_blocks 2 --resblock_updown True --use_fp16 True "
    "--use_scale_shift_norm True"
)
import sys
args = parser.parse_args(MODEL_FLAGS.split() + sys.argv[1:])

# Diffusion timesteps evaluated for every batch: 0, 50, ..., 450.  A value of
# 0 means the images are classified directly, without noising/denoising.
TIMES = list(range(0, 451, 50))
args.eval_batch_size = 64


# Corruption names to evaluate.  Each name doubles as a sub-directory of
# --corrupted_data and as part of the per-corruption checkpoint file name.
CORRUPTIONS = [
    'gaussian_noise',
    'shot_noise',
    'impulse_noise',
    'defocus_blur',
    'glass_blur',
    'motion_blur',
    'zoom_blur',
    'snow',
    'frost',
    'fog',
    'brightness',
    'contrast',
    'elastic_transform',
    'pixelate',
    'jpeg_compression',
]

def chunk(l, num_chunks):
    """Split ``l`` into ``num_chunks`` consecutive slices.

    Every slice except the last holds ``len(l) // num_chunks`` items; the
    last slice takes everything that is left, so it also absorbs the
    remainder.  When ``num_chunks`` exceeds ``len(l)`` the leading slices are
    empty and the last one contains the whole sequence.

    Args:
        l: Sequence supporting ``len`` and slicing.
        num_chunks: Number of slices to produce; must be positive.

    Returns:
        A list of exactly ``num_chunks`` slices of ``l``, in order.

    Raises:
        ValueError: If ``num_chunks`` is zero or negative.
    """
    # A negative count previously never terminated (the result list grew
    # forever) and zero surfaced as an opaque ZeroDivisionError.
    if num_chunks <= 0:
        raise ValueError(
            f"num_chunks must be a positive integer, got {num_chunks}")

    per_chunk = len(l) // num_chunks
    ret = [
        l[idx * per_chunk:(idx + 1) * per_chunk]
        for idx in range(num_chunks - 1)
    ]
    # The final slice is open-ended so no element is ever dropped.
    ret.append(l[(num_chunks - 1) * per_chunk:])
    return ret

# CUDA device index used for every `.cuda(...)` call in this script.
RANK = 0
# Narrow the corruption list down to the shard assigned to this worker and
# echo it so the run's log records what is about to be evaluated.
_shards = chunk(CORRUPTIONS, args.num_workers)
CORRUPTIONS = _shards[args.worker_id]
print(CORRUPTIONS)
def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: Score tensor indexed as (sample, class).
        target: Tensor of ground-truth class indices, one per sample.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per entry of ``topk``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Indices of the `largest_k` best classes, transposed to (k, sample)
        # so that row j holds every sample's (j+1)-th guess.
        ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)[1].t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]

def denoise_augment(batch, t):
    """Noise ``batch`` to diffusion step ``t`` and return its denoised estimate.

    The batch is pushed forward with ``diffusion.q_sample`` using fresh
    Gaussian noise, then a single ``diffusion.p_sample`` call with the global
    ``score_model`` yields the predicted clean image (``'pred_xstart'``), which
    is passed through ``preprocess_denoised`` before being returned.

    Args:
        batch: Image tensor; the caller passes ``2 * images - 1``, so the
            values are presumably in [-1, 1] -- confirm against the loader.
        t: Diffusion timestep applied to every sample of the batch.
    """
    steps = torch.full(
        (batch.shape[0],), t, dtype=torch.long, device=batch.device)
    noise = torch.randn_like(batch)
    noised = diffusion.q_sample(x_start=batch, t=steps, noise=noise)
    with torch.autocast("cuda"):
        sample_out = diffusion.p_sample(
            score_model, noised, steps, clip_denoised=False)
    return preprocess_denoised(sample_out['pred_xstart'])


def test(net, test_loader):
  """Evaluate ``net`` and the global ``orig_net`` on a dataset at every step in TIMES.

  For each batch and each timestep ``t`` in ``TIMES`` the images are either
  classified directly (``t == 0``) or first passed through
  ``denoise_augment`` (``t > 0``).  Both networks return a mapping from head
  name to logits; the logits are accumulated per head and per timestep over
  the whole dataset.

  After every batch a progress table is printed and redrawn in place:
    * a header row with the number of samples seen and the head names,
    * one row per timestep with the running accuracy of ``orig_net`` followed
      by that of ``net`` for every head,
    * an ``MN`` row with the accuracy of the softmax average over all timesteps,
    * ``MN<i>`` rows with the cumulative accuracies from ``get_cumm_accs``.

  Args:
    net: Model exposing a ``nets`` mapping (its keys become the table
      headers) and returning a mapping of logits when called.
    test_loader: Iterable yielding ``(images, targets)`` batches.

  Returns:
    A 4-tuple ``(total_loss / nums, total_correct, total_correct_orig, raw)``
    where the two ``total_correct*`` values come from ``get_cumm_accs`` and
    ``raw`` holds the accumulated logits and targets.  The loss is never
    accumulated, so the first element is always 0.
  """
  net.eval()
  total_loss = 0.
  # Per head: a list with one tensor per timestep, grown batch by batch.
  all_logits = defaultdict(list)
  all_orig_logits = defaultdict(list)
  all_targets = None
  nums = 0
  # Number of table lines printed for the previous batch (0 before the first).
  num_lines = 0
  headers = list(net.nets.keys())
  with torch.no_grad():
    for images, targets in test_loader:
      # NOTE(review): only the images are moved to the GPU; targets stay where
      # the loader put them.  The comparisons below therefore presumably rely
      # on the networks returning logits on that same device -- confirm.
      images, targets = images.cuda(RANK), targets
      nums += images.shape[0]
      if all_targets is None:
         all_targets = targets
      else:
         all_targets = torch.cat([all_targets,targets],dim=0)
      # ANSI "cursor up one line", once per previously printed line, so the
      # table below overwrites the one from the previous batch.
      for _ in range(num_lines):
        print("\033[F",end="")
      num_lines = 1
      print(f"{str(nums):^6s}"+"".join(list(map(lambda x:f"{x:^16s}",headers))))
      # Only referenced by the commented-out timing print further down.
      start_time = time.time()
      for idx,t in enumerate(TIMES):
        if t==0:
           # Timestep 0: classify the normalised input images as they are.
           logits = net(preprocess(images))
           orig_logits = orig_net(preprocess(images))
          #  loss = F.cross_entropy(logits, targets)
           total_loss += 0#float(loss.data)
        else:
           # Rescale as 2*x-1 (presumably [0, 1] -> [-1, 1]) before the
           # diffusion noising/denoising round trip.
           da_images = denoise_augment(2*images-1,t)
           logits = net(da_images)
           orig_logits = orig_net(da_images)
        for k in logits:
          if len(all_logits[k]) == idx:
            # First batch: create the slot for this timestep.
            all_logits[k].append(logits[k])
            all_orig_logits[k].append(orig_logits[k])
          else:
            # Later batches: extend the slot along the sample dimension.
            all_logits[k][idx] = torch.cat([all_logits[k][idx],logits[k]],dim=0)
            all_orig_logits[k][idx] = torch.cat([all_orig_logits[k][idx],orig_logits[k]],dim=0)
        
        # Running accuracy at this timestep: orig_net first, then net, per head.
        print_str = f"{str(t):^6s}"
        for k in headers:
          print_str += f"{(all_orig_logits[k][idx].max(dim=1)[1]==all_targets.data).float().mean():^8.4f}{(all_logits[k][idx].max(dim=1)[1]==all_targets.data).float().mean():^8.4f}"
        print(print_str)
        num_lines += 1
      # 'MN' row: accuracy of the softmax probabilities averaged over all timesteps.
      print_str = f"{'MN':^6s}"
      # print(time.time()-start_time)
      # num_lines += 1
      for k in headers:
            print_str += f"{(torch.softmax(torch.stack(all_orig_logits[k]),dim=-1).mean(dim=0).max(dim=1)[1]==all_targets.data).float().mean():^8.4f}{(torch.softmax(torch.stack(all_logits[k]),dim=-1).mean(dim=0).max(dim=1)[1]==all_targets.data).float().mean():^8.4f}"
      print(print_str)
      num_lines += 1
      
      total_correct = get_cumm_accs(all_logits,all_targets)
      total_correct_orig = get_cumm_accs(all_orig_logits,all_targets)
      # NOTE(review): `k` here is whatever the preceding `for k in headers`
      # loop left behind; the length equals the number of timesteps processed.
      for i in range(len(all_logits[k])):
         # 'MN<i>' row: cumulative accuracy over timesteps 0..i (index [0] of
         # the get_cumm_accs tuple).
         print_str = f"{f'MN{i}':^6s}"
         for k in headers:
            print_str += f"{total_correct_orig[k][0][i]:^8.4f}{total_correct[k][0][i]:^8.4f}"
         print(print_str)
         num_lines += 1      

  print()
  
  # Final statistics over the complete dataset.
  total_correct = get_cumm_accs(all_logits,all_targets)
  total_correct_orig = get_cumm_accs(all_orig_logits,all_targets)
  return total_loss / nums, total_correct, total_correct_orig, {"all_logits": all_logits, "all_orig_logits":all_orig_logits,"targets":all_targets}

def get_cumm_accs(all_logits, all_targets):
    """Compute cumulative and per-step accuracies for every head.

    Args:
        all_logits: Mapping from head name to a list of logit tensors, one
            per step, each indexed as (sample, class).
        all_targets: Tensor of ground-truth class indices.

    Returns:
        A dict mapping each head name to ``(cum_accs, instant_accs)``:
        ``instant_accs[i]`` is the accuracy of step ``i`` alone, while
        ``cum_accs[i]`` is the accuracy of the summed softmax probabilities
        of steps ``0..i``.
    """
    def _acc(scores):
        # Fraction of samples whose arg-max class matches the target.
        return (scores.max(dim=1)[1] == all_targets).float().mean()

    result = {}
    for head, step_logits_list in all_logits.items():
        running = None
        cumulative, per_step = [], []
        for step_logits in step_logits_list:
            step_probs = torch.softmax(step_logits, dim=1)
            running = step_probs if running is None else running + step_probs
            cumulative.append(_acc(running))
            per_step.append(_acc(step_probs))
        result[head] = (cumulative, per_step)
    return result

def test_c(net, test_transform):
    """Evaluate ``net`` on every corruption in ``CORRUPTIONS``, with resume support.

    Severities are processed from 5 down to 1.  After each severity the
    accumulated results are written to ``workdirs/corr_<corruption>.pt``; when
    that file already exists, the entries it holds for the corruption are
    restored and the severities they cover are skipped.

    Args:
        net: Model forwarded to ``test``.
        test_transform: Transform applied to every image by ``ImageFolder``.
    """
    corruption_accs = {}

    for c in CORRUPTIONS:
        ckpt = f"workdirs/corr_{c}.pt"
        if os.path.exists(ckpt):
            # Restore only this corruption's entry.  Replacing the whole
            # accumulator with the file's content would discard results
            # gathered for corruptions processed earlier in this run.
            saved = torch.load(ckpt)
            if c in saved:
                corruption_accs[c] = saved[c]
        print(c)

        finished = corruption_accs.setdefault(c, [])
        # One entry is stored per completed severity, so the number of stored
        # entries tells which severity to resume from.
        for s in range(5 - len(finished), 0, -1):
            valdir = os.path.join(args.corrupted_data, c, str(s))
            val_loader = torch.utils.data.DataLoader(
                datasets.ImageFolder(valdir, test_transform),
                batch_size=args.eval_batch_size,
                shuffle=True,
                num_workers=1,
                pin_memory=True)

            _, acc1, acc1_orig, _ = test(net, val_loader)
            finished.append({'ft': acc1, 'orig': acc1_orig})

            print('\ts={}: {} {}'.format(
                s, acc1_orig, acc1))
            # Persist after every severity so an interrupted run loses at
            # most one severity level of work.
            torch.save(corruption_accs, ckpt)


      
def main():
    """Load the classifiers and the diffusion model, then run both evaluations.

    The networks, the diffusion objects and the preprocessing transforms are
    published as module globals because ``test`` and ``denoise_augment`` read
    them from there.  The corrupted-data evaluation (``test_c``) runs first;
    the clean validation set follows and its accuracies are saved to
    ``workdirs/corr_test.pt``.
    """
    global orig_net, score_model, diffusion, schedule_sampler, preprocess, preprocess_denoised, mean_t, std_t
    from nets_inference import get_nets
    orig_net, net = get_nets()

    # Diffusion model: built from the parsed CLI options, moved to the GPU,
    # loaded from the checkpoint, converted to fp16 and put in eval mode.
    model_kwargs = args_to_dict(args, model_and_diffusion_defaults().keys())
    score_model, diffusion = create_model_and_diffusion(**model_kwargs)
    schedule_sampler = create_named_schedule_sampler("uniform", diffusion)

    score_model = score_model.cuda(RANK)
    state_dict = torch.load(
        'workdirs/256x256_diffusion_uncond.pt', map_location='cpu')
    score_model.load_state_dict(state_dict)
    score_model.convert_to_fp16()
    score_model = score_model.eval()
    print("loaded score-model")

    # Per-channel normalisation statistics expected by the classifiers.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    mean_t = torch.FloatTensor(mean).reshape(1, 3, 1, 1)
    std_t = torch.FloatTensor(std).reshape(1, 3, 1, 1)
    preprocess = transforms.Normalize(mean, std)
    # Denoised images are first mapped through (x + 1) / 2 and only then
    # normalised like regular inputs.
    rescale = transforms.Normalize([-1, -1, -1], [2, 2, 2])
    preprocess_denoised = transforms.Compose(
        [rescale, transforms.Normalize(mean, std)])

    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    # The clean loader is created before the corruption sweep starts, exactly
    # as before, so a bad --clean_data path is reported up front.
    clean_dataset = datasets.ImageFolder(
        os.path.join(args.clean_data, 'val'), test_transform)
    val_loader = torch.utils.data.DataLoader(
        clean_dataset,
        batch_size=args.eval_batch_size,
        shuffle=True,
        num_workers=1)

    test_c(net, test_transform)

    _, acc1, acc1_orig, _ = test(net, val_loader)
    torch.save({"test": [{'ft': acc1, 'orig': acc1_orig}]}, "workdirs/corr_test.pt")
    

if __name__ == "__main__":
   main()
    
