# from pytorchcv.model_provider import get_model as ptcv_get_model
import torch
from torch.autograd import Variable
import numpy as np
from tqdm import trange
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
import tensorflow_datasets as tfds
import tensorflow as tf
import sys
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
from pathlib import Path
import os
import gc

def eval_acc(config, eval_dir, device='cuda', bs=256):
    """Measure per-class accuracy of saved samples with a pretrained CIFAR-10 ViT.

    For every class index ``i`` in ``range(config.data.num_classes)`` this loads
    all ``sample_*.npz`` files under ``<eval_dir>/<i>/`` (falling back to
    ``samples_*.npz`` when none match). It classifies the images stored under
    each file's ``'samples'`` key with ``nateraw/vit-base-patch16-224-cifar10``
    and counts how many are predicted as class ``i``. Misclassified images of a
    class are tiled into ``<eval_dir>/wrong_pred/vit_class_<i>.png``.

    NOTE(review): the ``'samples'`` arrays are assumed to be uint8 images laid
    out as (N, H, W, C). This is implied by ``Image.fromarray``, the
    ``(2, 0, 1)`` transpose and the ``/ 255`` scaling below; confirm with the
    producer of the .npz files.

    Args:
        config: Config object; only ``config.data.num_classes`` is read.
        eval_dir: Directory containing one sub-directory per class index.
        device: Torch device the classifier runs on (default ``'cuda'``).
        bs: Number of images per forward pass (default 256).

    Prints each class's accuracy and the average over the classes that had
    samples. Classes with no sample files are reported and skipped instead of
    raising ZeroDivisionError.
    """
    wrong_dir = os.path.join(eval_dir, 'wrong_pred')
    Path(wrong_dir).mkdir(parents=True, exist_ok=True)
    acc = []

    model_name = 'nateraw/vit-base-patch16-224-cifar10'
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
    model = ViTForImageClassification.from_pretrained(model_name).to(device)

    with torch.no_grad():
        for i in range(config.data.num_classes):
            wrong = []  # CHW float images in [0, 1] that were misclassified
            count, total = 0, 0
            class_dir = os.path.join(eval_dir, str(i))
            sample_files = tf.io.gfile.glob(os.path.join(class_dir, 'sample_*.npz'))
            if not sample_files:
                # Older runs used a different file prefix.
                sample_files = tf.io.gfile.glob(os.path.join(class_dir, 'samples_*.npz'))

            for sample_file in sample_files:
                gc.collect()
                # The context manager closes the NpzFile's underlying file handle;
                # indexing materializes the array, so it stays valid afterwards.
                with np.load(sample_file) as data:
                    np_samples = data['samples']
                images = [Image.fromarray(ele) for ele in np_samples]
                total += len(images)

                for idx in range(0, len(images), bs):
                    inputs = feature_extractor(images=images[idx:idx + bs], return_tensors="pt")
                    logits = model(**inputs.to(device))['logits']
                    # Compute predictions once per batch and move them to the CPU once.
                    preds = logits.argmax(dim=-1).cpu()
                    correct = preds == i
                    count += correct.sum().item()
                    for j in torch.nonzero(~correct).flatten().tolist():
                        img = np.transpose(np_samples[idx + j], (2, 0, 1))  # HWC -> CHW
                        wrong.append(torch.from_numpy(img) / 255)

            # Save once per class; the original rewrote the same file after every sample file.
            if wrong:
                nrow = int(len(wrong) ** 0.5)
                save_image(make_grid(torch.stack(wrong), nrow, padding=2),
                           os.path.join(wrong_dir, f'vit_class_{i}.png'))

            if total == 0:
                print(f"Class {i}: no samples found, skipped")
                continue
            class_acc = count / total
            print(f"Class {i} acc = {class_acc}")
            acc.append(class_acc)

        if acc:
            # Equals sum(acc) / num_classes when every class had samples.
            print(f"Average acc = {sum(acc) / len(acc)}")
        else:
            print("Average acc = n/a (no samples evaluated)")
    print()