# from pytorchcv.model_provider import get_model as ptcv_get_model
import torch
from torch.autograd import Variable
import numpy as np
from tqdm import trange
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
import tensorflow_datasets as tfds
import tensorflow as tf
import sys
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
from pathlib import Path
import os
import gc
import losses_classifier as losses
# from models import utils as mutils
# from models import ddpm, ncsnv2, ncsnpp, unet_classifier
# from models.ema import ExponentialMovingAverage
from utils import save_checkpoint, restore_checkpoint
import sde_lib
from torch.nn import functional as F

def _create_sde(config):
    """Instantiate the SDE named by ``config.training.sde`` (case-insensitive).

    Raises:
      NotImplementedError: if the SDE name is not vpsde, subvpsde or vesde.
    """
    sde_name = config.training.sde.lower()
    if sde_name == 'vpsde':
        return sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max,
                             N=config.model.num_scales)
    if sde_name == 'subvpsde':
        return sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max,
                                N=config.model.num_scales)
    if sde_name == 'vesde':
        return sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max,
                             N=config.model.num_scales)
    raise NotImplementedError(f"SDE {config.training.sde} unknown.")


def _create_eval_dataset(config, split='test', batch_size=400):
    """Build a batched, preprocessed TFDS split of CIFAR-10 or CIFAR-100.

    Each image is converted to float32 in [0, 1] and resized (antialiased) to
    ``config.data.image_size`` x ``config.data.image_size``. When
    ``config.data.uniform_dequantization`` is set, uniform noise is added and the
    result is rescaled, giving values in [0, 1).

    Args:
      config: experiment config; reads ``config.data.dataset``,
        ``config.data.image_size`` and ``config.data.uniform_dequantization``.
      split: TFDS split name to load.
      batch_size: examples per batch (the final batch may be smaller).

    Returns:
      A ``tf.data.Dataset`` of dicts with keys ``'image'`` and ``'label'``.

    Raises:
      NotImplementedError: if ``config.data.dataset`` is not a CIFAR10/CIFAR100 variant.
    """
    dataset_name = config.data.dataset
    if dataset_name == 'CIFAR10' or dataset_name.startswith('CIFAR10_'):
        dataset_builder = tfds.builder('cifar10')
    elif dataset_name == 'CIFAR100' or dataset_name.startswith('CIFAR100_'):
        dataset_builder = tfds.builder('cifar100')
    else:
        raise NotImplementedError(
            f'Calibration error evaluation for dataset {config.data.dataset} not yet supported.')

    def resize_op(img):
        img = tf.image.convert_image_dtype(img, tf.float32)
        return tf.image.resize(img, [config.data.image_size, config.data.image_size],
                               antialias=True)

    def preprocess_fn(d):
        """Convert/resize the image to float32 in [0, 1] and optionally dequantize it."""
        img = resize_op(d['image'])
        if config.data.uniform_dequantization:
            img = (tf.random.uniform(img.shape, dtype=tf.float32) + img * 255.) / 256.
        return dict(image=img, label=d.get('label', None))

    dataset_options = tf.data.Options()
    dataset_options.experimental_optimization.map_parallelization = True
    dataset_options.experimental_threading.private_threadpool_size = 48
    dataset_options.experimental_threading.max_intra_op_parallelism = 1
    read_config = tfds.ReadConfig(options=dataset_options)
    dataset_builder.download_and_prepare()
    ds = dataset_builder.as_dataset(split=split, shuffle_files=True, read_config=read_config)
    ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return ds.batch(batch_size, drop_remainder=False)


def _expected_calibration_error(confidences, correct, bucket_num=20):
    """Bin predictions by confidence and compute the expected calibration error.

    Predictions are grouped into ``bucket_num`` equal-width bins over [0, 1]. ECE is
    the count-weighted mean over bins of |mean accuracy - mean confidence|.

    Args:
      confidences: 1-D float tensor of prediction confidences in [0, 1].
      correct: 1-D bool tensor of the same length; True where the prediction is right.
      bucket_num: number of equal-width bins.

    Returns:
      ``(ece, bucket_acc, bucket_conf)``: ``ece`` is a float (NaN when there are no
      predictions). ``bucket_acc`` / ``bucket_conf`` are lists of per-bin mean
      accuracy / mean confidence, with 0.0 for empty bins.
    """
    confidences = confidences.detach().double().cpu()
    correct = correct.detach().double().cpu()
    # Multiply-then-floor rather than `conf // (1 / bucket_num)`: floor-dividing by an
    # inexact float step misplaces boundary values (e.g. 0.15 // 0.05 == 2.0).
    # Confidence 1.0 would map to bin `bucket_num`, so clamp it into the last bin.
    bucket_idx = (confidences * bucket_num).long().clamp(0, bucket_num - 1)
    counts = torch.bincount(bucket_idx, minlength=bucket_num).double()
    conf_sums = torch.bincount(bucket_idx, weights=confidences, minlength=bucket_num)
    acc_sums = torch.bincount(bucket_idx, weights=correct, minlength=bucket_num)

    # Empty bins have zero sums, so dividing by max(count, 1) yields 0.0 for them.
    safe_counts = counts.clamp(min=1)
    bucket_acc = acc_sums / safe_counts
    bucket_conf = conf_sums / safe_counts

    total = counts.sum().item()
    if total == 0:
        ece = float('nan')
    else:
        ece = ((bucket_acc - bucket_conf).abs() * counts).sum().item() / total
    return ece, bucket_acc.tolist(), bucket_conf.tolist()


def eval_calib_error(config, eval_dir, bucket_num=20):
    """Evaluate the expected calibration error (ECE) of a trained classifier.

    1. Restores the classifier checkpoint at ``config.model.class_path``.
    2. Classifies the CIFAR-10/100 test split through ``mutils.get_score_fn``, at
       the fixed time t = 1e-5 for every example.
    3. Bins predictions by max-softmax confidence into ``bucket_num`` equal-width
       bins.
    4. Prints each bin's (accuracy, confidence) pair, then prints the overall ECE.

    Args:
      config: experiment config (model, data, training and device settings). Note
        that ``config.model.name`` is overwritten with ``config.model.classifier``
        when that attribute exists.
      eval_dir: unused; kept for interface compatibility.
      bucket_num: number of equal-width confidence bins (default 20).

    Returns:
      The expected calibration error as a float (NaN if the split is empty).

    Raises:
      NotImplementedError: for an unsupported dataset or SDE.
    """
    # The module-level imports of these project modules are commented out. Without
    # them this function fails with a NameError, so import them locally.
    # Importing ddpm/ncsnv2/ncsnpp/unet_classifier presumably registers model
    # architectures for `mutils.create_model` — TODO confirm.
    from models import utils as mutils
    from models import ddpm, ncsnv2, ncsnpp, unet_classifier  # noqa: F401
    from models.ema import ExponentialMovingAverage

    if hasattr(config.model, 'classifier'):
        config.model.name = config.model.classifier
    class_model = mutils.create_model(config)
    ema = ExponentialMovingAverage(class_model.parameters(), decay=config.model.ema_rate)
    optimizer = losses.get_optimizer(config, class_model.parameters())
    state = dict(optimizer=optimizer, model=class_model, ema=ema, step=0)
    state = restore_checkpoint(config.model.class_path, state, device=config.device)
    # NOTE(review): the EMA state is restored but never copied into `class_model`,
    # so the raw (non-EMA) weights are evaluated. Confirm this is intended.

    sde = _create_sde(config)
    eval_ds = _create_eval_dataset(config)
    # Build the score function once instead of once per batch.
    score_fn = mutils.get_score_fn(sde, class_model, train=False,
                                   continuous=config.training.continuous)

    all_conf, all_correct = [], []
    with torch.no_grad():
        for next_eval in eval_ds:
            images = next_eval['image'].numpy()
            # Images arrive channels-last (NHWC); the model is fed NCHW.
            batch = torch.from_numpy(images).to(config.device).float().permute(0, 3, 1, 2)
            targets = torch.from_numpy(next_eval['label'].numpy()).to(config.device)
            t = torch.full((batch.shape[0],), 1e-5, device=batch.device)
            probs = F.softmax(score_fn(batch, t), dim=1)
            confidence, predicted = probs.max(dim=-1)
            all_conf.append(confidence.cpu())
            all_correct.append((predicted == targets).cpu())

    confidences = torch.cat(all_conf) if all_conf else torch.empty(0)
    correct = torch.cat(all_correct) if all_correct else torch.empty(0, dtype=torch.bool)
    ece, buckets_acc, buckets_conf = _expected_calibration_error(
        confidences, correct, bucket_num)
    for acc, conf in zip(buckets_acc, buckets_conf):
        print(acc, conf)
    print(ece)
    return ece