import torch
from torchvision.models import resnet50
from torch import nn
from tqdm import tqdm, trange
import torchvision.transforms as transforms
import tensorflow as tf
import tensorflow_datasets as tfds
import gc
import sys
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import datasets
import os

def _load_class_samples(eval_dir, class_idx, num_required):
    """Load the first `num_required` stored samples for one class.

    Reads the 'samples' array from every file matching
    `<eval_dir>/<class_idx>/sample_*.npz` (in glob order) until enough rows
    have been collected.

    Args:
        eval_dir: Directory containing one sub-directory per class index.
        class_idx: Integer class index; also the sub-directory name.
        num_required: Number of samples to return for this class.

    Returns:
        A numpy array holding exactly `num_required` samples along axis 0.

    Raises:
        ValueError: If fewer than `num_required` samples are available
            (including the case where no sample file exists at all).
    """
    pattern = os.path.join(eval_dir, str(class_idx), 'sample_*.npz')
    chunks = []
    num_loaded = 0
    for sample_file in tf.io.gfile.glob(pattern):
        with tf.io.gfile.GFile(sample_file, "rb") as fin:
            # Index the archive while the file is still open: npz members
            # are read lazily on access.
            chunk = np.load(fin)['samples']
        chunks.append(chunk)
        num_loaded += len(chunk)
        # Only the first `num_required` rows are kept, so later files
        # cannot change the result.
        if num_loaded >= num_required:
            break

    # Check before concatenating so that an empty directory produces the same
    # clear error as a partially filled one.
    if num_loaded < num_required:
        raise ValueError(f'Not enough samples for class {class_idx}.')
    return np.concatenate(chunks, axis=0)[:num_required]


def _build_test_dataset(dataset_builder, split, image_size, batch_size):
    """Create a batched, prefetched tf.data pipeline over a TFDS split.

    Args:
        dataset_builder: A `tfds.core.DatasetBuilder` whose examples provide
            'image' and 'label' features.
        split: Name of the TFDS split to read (e.g. 'test').
        image_size: Target height and width of the resized images.
        batch_size: Number of examples per batch; the last batch may be
            smaller.

    Returns:
        A `tf.data.Dataset` yielding dicts with float32 'image' batches and
        the matching 'label' batches.
    """
    def preprocess_fn(d):
        # convert_image_dtype rescales integer images to floats in [0, 1].
        img = tf.image.convert_image_dtype(d['image'], tf.float32)
        img = tf.image.resize(img, [image_size, image_size], antialias=True)
        return dict(image=img, label=d['label'])

    dataset_options = tf.data.Options()
    dataset_options.experimental_optimization.map_parallelization = True
    dataset_options.experimental_threading.private_threadpool_size = 48
    dataset_options.experimental_threading.max_intra_op_parallelism = 1
    read_config = tfds.ReadConfig(options=dataset_options)

    dataset_builder.download_and_prepare()
    ds = dataset_builder.as_dataset(
        split=split, shuffle_files=True, read_config=read_config)
    ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=False)
    return ds.prefetch(tf.data.experimental.AUTOTUNE)


def eval_CAS(config, eval_dir):
    """Train a ResNet-50 on stored samples and score it on the real test set.

    For every class index in `range(config.data.num_classes)` the first 5000
    samples found under `<eval_dir>/<class_idx>/sample_*.npz` are loaded
    (presumably model-generated images -- confirm against the producer).
    A freshly initialised ResNet-50 is trained on them for 5 epochs with Adam,
    then evaluated on the CIFAR-10 test split from TFDS. The resulting
    accuracy is printed with the prefix 'CAS:'.

    The samples are treated as NHWC arrays with pixel values in [0, 255]:
    they are permuted to NCHW and divided by 255 before being fed to the
    classifier. A CUDA device is required.

    Args:
        config: Configuration object exposing `data.dataset`,
            `data.num_classes` and `data.image_size`.
        eval_dir: Directory with one sub-directory of `sample_*.npz` files
            per class index.

    Returns:
        The test accuracy as a Python float in [0, 1].

    Raises:
        NotImplementedError: If `config.data.dataset` is not a CIFAR10
            variant; raised before any loading or training happens.
        ValueError: If a class has fewer than 5000 stored samples.
    """
    # Fail fast: only CIFAR10 test data is wired up below, so reject other
    # datasets before spending time on loading samples and training.
    dataset_name = config.data.dataset
    if not (dataset_name == 'CIFAR10' or 'CIFAR10_SEMI' in dataset_name):
        raise NotImplementedError(
            f'CAS evaluation is not supported for dataset {dataset_name}.')

    device = 'cuda'
    num_epoch = 5
    bs = 128
    samples_per_class = 5000
    num_classes = config.data.num_classes

    # Size the classification head from the config so that it always covers
    # every label produced below.
    classifier = resnet50(num_classes=num_classes).to(device)
    optimizer = torch.optim.Adam(
        classifier.parameters(),
        lr=2e-4,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.
    )

    all_samples = []
    all_targets = []
    for class_idx in range(num_classes):
        gc.collect()
        all_samples.append(
            _load_class_samples(eval_dir, class_idx, samples_per_class))
        all_targets.append(np.full(samples_per_class, class_idx))

    data = torch.tensor(np.concatenate(all_samples), dtype=torch.float32)
    target = torch.tensor(np.concatenate(all_targets), dtype=torch.long)
    train_loader = DataLoader(
        dataset=TensorDataset(data, target),
        batch_size=bs,
        shuffle=True,
        drop_last=True,
    )
    ce_loss = nn.CrossEntropyLoss()

    classifier.train()
    for epoch in trange(num_epoch, desc='Total'):
        # Running sums give the per-epoch mean without re-summing a list on
        # every step.
        loss_sum, acc_sum, num_batches = 0.0, 0.0, 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
        for inp, targets in pbar:
            # NHWC -> NCHW, then rescale pixel values to [0, 1].
            inp = inp.permute(0, 3, 1, 2).to(device) / 255.
            targets = targets.to(device)

            optimizer.zero_grad()
            preds = classifier(inp)
            cur_loss = ce_loss(preds, targets)
            cur_loss.backward()
            optimizer.step()

            cur_acc = (torch.argmax(preds, dim=-1) == targets).float().mean()
            loss_sum += cur_loss.item()
            acc_sum += cur_acc.item()
            num_batches += 1
            pbar.set_postfix({
                'loss': f'{loss_sum / num_batches:.6e}',
                'acc': f'{acc_sum / num_batches:.6e}'}
            )
    classifier.eval()

    test_ds = _build_test_dataset(
        tfds.builder('cifar10'), 'test', config.data.image_size, bs)
    test_preds, test_targets = [], []
    # Inference only: disabling autograd avoids building a graph per batch.
    with torch.no_grad():
        for batch in tqdm(test_ds, desc='Test'):
            inp = torch.from_numpy(batch['image'].numpy()).to(device)
            targets = torch.from_numpy(batch['label'].numpy()).to(device)
            preds = classifier(inp.permute(0, 3, 1, 2))
            test_preds.append(torch.argmax(preds, dim=-1))
            test_targets.append(targets)

    cas = (torch.cat(test_preds) == torch.cat(test_targets)).float().mean().item()
    print(f'CAS: {cas:.6e}')
    return cas