# from pytorchcv.model_provider import get_model as ptcv_get_model
import torch
from torch.autograd import Variable
import numpy as np
from tqdm import trange
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
import tensorflow_datasets as tfds
import tensorflow as tf
import sys
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
from pathlib import Path
import os
import gc
import losses_classifier as losses
# from models import utils as mutils
# from models import ddpm, ncsnv2, ncsnpp, unet_classifier
# from models.ema import ExponentialMovingAverage
from utils import save_checkpoint, restore_checkpoint
import sde_lib

def eval_cls(config, eval_dir):
    class_model = mutils.create_model(config)
    ema = ExponentialMovingAverage(class_model.parameters(), decay=config.model.ema_rate)
    optimizer = losses.get_optimizer(config, class_model.parameters())
    state = dict(optimizer=optimizer, model=class_model, ema=ema, step=0)
    state = restore_checkpoint(config.model.class_path, state, device=config.device)

    def resize_op(img):
      img = tf.image.convert_image_dtype(img, tf.float32)
      return tf.image.resize(img, [config.data.image_size, config.data.image_size], antialias=True)

    def preprocess_fn(d):
      """Basic preprocessing function scales data to [0, 1) and randomly flips."""
      img = resize_op(d['image'])
      if config.data.uniform_dequantization:
        img = (tf.random.uniform(img.shape, dtype=tf.float32) + img * 255.) / 256.

      return dict(image=img, label=d.get('label', None))

    def create_dataset(dataset_builder, split):
        dataset_options = tf.data.Options()
        dataset_options.experimental_optimization.map_parallelization = True
        dataset_options.experimental_threading.private_threadpool_size = 48
        dataset_options.experimental_threading.max_intra_op_parallelism = 1
        read_config = tfds.ReadConfig(options=dataset_options)
        if isinstance(dataset_builder, tfds.core.DatasetBuilder):
            dataset_builder.download_and_prepare()
            ds = dataset_builder.as_dataset(
                split=split, shuffle_files=True, read_config=read_config)
        else:
            ds = dataset_builder.with_options(dataset_options)
        ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
        # ds = ds.batch(400, drop_remainder=False)
        return ds.batch(400, drop_remainder=False)


    if 'CIFAR10' in config.data.dataset:
        dataset_builder = tfds.builder('cifar10')
        train_split_name = 'train'
        eval_split_name = 'test'
    elif 'CIFAR100' in config.data.dataset:
        dataset_builder = tfds.builder('cifar100')
        train_split_name = 'train'
        eval_split_name = 'test'
    else:
        NotImplementedError(f'Calibration error evaluation for dataset {config.data.dataset} not yet supported.')

    train_ds = create_dataset(dataset_builder, train_split_name)
    eval_ds = create_dataset(dataset_builder, eval_split_name)

    # Setup SDEs
    if config.training.sde.lower() == 'vpsde':
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'subvpsde':
        sde = sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
        sampling_eps = 1e-3
    elif config.training.sde.lower() == 'vesde':
        sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales)
        sampling_eps = 1e-5
    else:
        raise NotImplementedError(f"SDE {config.training.sde} unknown.")

    bucket_num = 20
    buckets_conf = [[] for _ in range(bucket_num)]
    buckets_correct = [[] for _ in range(bucket_num)]
    with torch.no_grad():
        for next_eval in iter(eval_ds):
            batch = torch.from_numpy(next_eval['image']._numpy()).to(config.device).float().permute(0, 3, 1, 2)
            targets = torch.from_numpy(next_eval['label']._numpy()).to(config.device)
            ori_logits = mutils.get_score_fn(sde, class_model, train=False, continuous=config.training.continuous)(batch, torch.tensor([1e-5] * int(batch.shape[0]), device=batch.device))

            confidence = torch.max(ori_logits, dim=1)
            bucket_idx = confidence // 0.05
            correct = ori_logits.argmax(dim=-1) == targets
            for i in range(len(bucket_idx)):
                buckets_conf[int(bucket_idx[i].item())].append(confidence[i])
                buckets_correct[int(bucket_idx[i].item())].append(correct[i].item())
            print(buckets_conf)
            print(buckets_correct)
            exit(0)
