import numpy as np
import random
import os

import torch
from torchvision import datasets, transforms

import sys
sys.path.append('..')
from models.resnet18_32x32 import ResNet18_32x32


def get_reference_set(sample_num_per_class=10):
    """Collect correctly-classified CIFAR-10 training images as a reference set.

    Seeds every RNG, loads the pretrained ``ResNet18_32x32`` weights and walks
    the shuffled, mean/std-normalized CIFAR-10 training set in batches of 1000,
    keeping only the images the model classifies correctly. Batches are
    consumed until each of the 10 classes has ``sample_num_per_class`` images;
    when the first batch already contains enough, only that batch is used.

    Args:
        sample_num_per_class: Number of images to keep for each class.

    Returns:
        Tensor of shape ``(sample_num_per_class, 10, C, H, W)`` on the model's
        device; ``reference_set[k, c]`` is the k-th kept image of class ``c``.

    Raises:
        RuntimeError: If the training set is exhausted before every class has
            ``sample_num_per_class`` correctly-classified images.
    """
    # Restrict visible GPUs before anything touches the CUDA runtime; the
    # variable has no effect once CUDA is initialized in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,2'

    SEED = 100
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(SEED)
    random.seed(SEED)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform_train = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = datasets.CIFAR10(root='../../datasets/',
                                     train=True,
                                     transform=transform_train,
                                     download=True)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=1000,
                                               shuffle=True,
                                               num_workers=2)

    # NOTE(review): ResNet18_32x32() presumably draws from the global torch
    # RNG for weight init, and the loader's shuffle order is drawn from that
    # RNG when iteration starts -- keep the model built before the loop so
    # the selected reference set stays reproducible.
    model = ResNet18_32x32()
    # map_location lets weights saved from a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load('../weights/resnet18_9554.pth', map_location=device))
    model = model.to(device)
    model.eval()

    num_classes = 10
    chunks = [[] for _ in range(num_classes)]  # per-class list of tensors
    counts = [0] * num_classes                 # images collected per class
    with torch.no_grad():
        for samples, labels in train_loader:
            samples, labels = samples.to(device), labels.to(device)
            preds = model(samples).argmax(dim=-1)

            # Keep only the images the model classifies correctly.
            mask = preds == labels
            print(mask.shape, mask.sum())
            samples = samples[mask]
            labels = labels[mask]
            print(len(samples))

            for c in range(num_classes):
                need = max(sample_num_per_class - counts[c], 0)
                picked = samples[labels == c][:need]
                chunks[c].append(picked)
                counts[c] += len(picked)

            if min(counts) >= sample_num_per_class:
                break

    if min(counts) < sample_num_per_class:
        raise RuntimeError(
            f"Reference Set Error: found only {counts} correctly-classified "
            f"samples per class, {sample_num_per_class} requested"
        )

    # Every class now has exactly sample_num_per_class images, so stacking
    # along dim=1 yields (sample_num_per_class, num_classes, C, H, W).
    reference_set = torch.stack([torch.cat(c_chunks, dim=0) for c_chunks in chunks], dim=1)
    print(reference_set.shape)
    return reference_set

def get_no_normlize_reference(sample_num_per_class=10):
    """Collect un-normalized CIFAR-10 training images, grouped by class.

    Seeds every RNG and walks the shuffled CIFAR-10 training set (transformed
    with only ``Resize(32)`` + ``ToTensor()``, i.e. no mean/std normalization)
    in batches of 1000. Batches are consumed until each of the 10 classes has
    ``sample_num_per_class`` images; when the first batch already contains
    enough, only that batch is used. No model is involved, so images are
    taken regardless of any prediction.

    Args:
        sample_num_per_class: Number of images to keep for each class.

    Returns:
        CPU tensor of shape ``(10 * sample_num_per_class, C, H, W)`` in
        class-major order: rows ``[c * n, (c + 1) * n)`` (``n`` being
        ``sample_num_per_class``) hold class ``c``.

    Raises:
        RuntimeError: If the training set is exhausted before every class has
            ``sample_num_per_class`` images (otherwise the class-major row
            layout would silently misalign).
    """
    # Restrict visible GPUs before anything touches the CUDA runtime; the
    # variable has no effect once CUDA is initialized in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,2'

    SEED = 100
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(SEED)
    random.seed(SEED)

    transform_train = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
    ])

    train_dataset = datasets.CIFAR10(root='../../datasets/',
                                     train=True,
                                     transform=transform_train,
                                     download=True)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=1000,
                                               shuffle=True,
                                               num_workers=2)

    num_classes = 10
    chunks = [[] for _ in range(num_classes)]  # per-class list of tensors
    counts = [0] * num_classes                 # images collected per class
    for samples, labels in train_loader:
        for c in range(num_classes):
            need = max(sample_num_per_class - counts[c], 0)
            picked = samples[labels == c][:need]
            chunks[c].append(picked)
            counts[c] += len(picked)
        if min(counts) >= sample_num_per_class:
            break

    if min(counts) < sample_num_per_class:
        raise RuntimeError(
            f"Reference Set Error: found only {counts} samples per class, "
            f"{sample_num_per_class} requested"
        )

    # Flatten class by class so the output stays class-major.
    return torch.cat([t for c_chunks in chunks for t in c_chunks], dim=0)

if __name__ == "__main__":
    # Script entry point: build the reference set once with default settings.
    _reference = get_reference_set()