import sys
import os
from torch.utils import data
import torchvision
from torchvision import datasets
from torchvision import transforms, utils
import torch
import torchvision
from pathlib import Path

import numpy as np
from sklearn.model_selection import train_test_split
from termcolor import colored
from tqdm.auto import tqdm
from termcolor import colored
from metrics import ranking
import socket
from model.stylegan2.swish import Swish

def hashing_loss(b, cls, m=None, alpha=0.01):
    """Compute a pairwise contrastive hashing loss over all n^2 pairs of a batch.

    For every ordered pair (i, j), including i == j, ``y`` is 1 when the two
    samples are dissimilar and 0 when they are similar. Similar pairs are
    penalised by their code distance; dissimilar pairs are penalised by how far
    their distance falls short of the margin ``m``. A quantisation term weighted
    by ``alpha`` pulls every code entry's magnitude towards 1.

    Args:
        b: 2-D tensor of real-valued codes, one row per sample.
        cls: 1-D tensor of class ids (pairs are similar when ids are equal), or
            2-D tensor of label rows, presumably multi-hot (pairs are dissimilar
            when the dot product of their rows is 0).
        m: margin for dissimilar pairs; ``None`` means 2.
        alpha: weight of the quantisation term.

    Returns:
        Scalar tensor: mean pairwise loss + ``alpha * mean(| |b| - 1 |)``.

    Raises:
        ValueError: if ``cls`` is neither 1-D nor 2-D.
    """
    if m is None:
        m = 2

    if cls.dim() == 1:
        y = (cls.unsqueeze(0) != cls.unsqueeze(1)).float().view(-1)
    elif cls.dim() == 2:  # multi-label support
        y = ((cls.unsqueeze(0) * cls.unsqueeze(1)).sum(-1) == 0).float().view(-1)
    else:
        # Previously this fell through and failed later with an UnboundLocalError.
        raise ValueError(f'cls must be 1-D or 2-D, got shape {tuple(cls.shape)}')

    # Mean squared difference between every pair of code rows, flattened to n*n.
    dist = ((b.unsqueeze(0) - b.unsqueeze(1)) ** 2).mean(dim=2).view(-1)
    loss = (1 - y) / 2 * dist + y / 2 * (m - dist).clamp(min=0)

    return loss.mean() + alpha * (b.abs() - 1).abs().mean()

def get_hostip():
    """Return the local IPv4 address the OS selects for reaching 8.8.8.8.

    A UDP (SOCK_DGRAM) ``connect`` only chooses a route and local address; the
    socket is closed even if ``connect`` raises (e.g. OSError when no route exists).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(("8.8.8.8", 80))
        return s.getsockname()[0]

class ExternalLogger(object):
    """Optional Neptune experiment logger.

    When ``args.external_logger == 'neptune'`` a Neptune run is created and stored
    in ``self.run``. Otherwise ``self.run`` is None and every logging method is a no-op.
    """

    def __init__(self, args, run_name=None):
        """Create the logger.

        Args:
            args: namespace with ``external_logger``; for Neptune also
                ``external_logger_args`` (the Neptune project, or None for the
                backward-compatible default project).
            run_name: optional name for the Neptune run.

        Raises:
            Exception: if Neptune is selected but ``~/.neptune`` does not exist.
        """
        if args.external_logger != 'neptune':
            self.run = None
            return

        project = args.external_logger_args
        token_path = os.path.expanduser('~/.neptune')
        if not os.path.exists(token_path):
            raise Exception('Please create .neptune file to store your credential!')
        # The first line of ~/.neptune holds the API token; close the file promptly.
        with open(token_path) as token_file:
            api_token = token_file.readline().strip()
        if project is None:
            # backward comp
            project = 'khoadoan106/coophash'

        import neptune.new as neptune
        self.run = neptune.init(
            project=project,
            api_token=api_token,
            name=run_name,
        )
        self.set_val('IP', get_hostip())
        self.set_val('CMD', " ".join(sys.argv[:]))

    def log_val(self, key, val, step=None):
        """Append ``val`` to the series ``key`` (no-op without a run)."""
        if self.run is not None:
            self.run[key].log(val, step=step)

    def set_val(self, key, val):
        """Assign ``val`` to the field ``key`` (no-op without a run)."""
        if self.run is not None:
            self.run[key] = val

    def log_img(self, key, img, step=None):
        """Log the image file at path ``img`` under ``key``.

        Only string paths are logged; any other value is ignored.
        """
        if self.run is not None:
            from neptune.new.types import File
            if isinstance(img, str):
                self.run[key].log(File(img), step=step)

    def cleanup(self):
        """Stop the Neptune run, if one was started."""
        if self.run is not None:
            self.run.stop()

def get_activation(activation):
    """Return a freshly constructed activation module for the name ``activation``.

    Recognised names: 'swish', 'lrelu' (LeakyReLU default slope), 'lrelu1'
    (slope 0.1), 'lrelu2' (slope 0.2), 'relu' and 'tanh'. ``None``, and any
    value that is not one of these names, yields ``None``.
    """
    builders = {
        'swish': lambda: Swish(),
        'lrelu': lambda: torch.nn.LeakyReLU(),
        'lrelu1': lambda: torch.nn.LeakyReLU(0.1),
        'lrelu2': lambda: torch.nn.LeakyReLU(0.2),
        'relu': lambda: torch.nn.ReLU(),
        'tanh': lambda: torch.nn.Tanh(),
    }
    if not isinstance(activation, str):
        return None
    build = builders.get(activation)
    return None if build is None else build()
    
class AddSaltAndPepperNoise(object):
    """Tensor transform that sets random elements to 0 (pepper) or 1 (salt).

    With probability ``prob`` (decided via NumPy's global RNG) noise is applied.
    Each element independently becomes 0 with probability ``noise_prob / 2`` and
    1 with probability ``noise_prob / 2``. The input tensor is modified in place
    and returned.
    """

    def __init__(self, prob=0.1, noise_prob=0.1):
        self.prob = prob
        self.noise_prob = noise_prob

    def __call__(self, tensor):
        if not np.random.rand() < self.prob:
            return tensor
        half = self.noise_prob / 2
        draws = torch.rand(tensor.shape)
        tensor[draws < half] = 0.
        tensor[draws > 1 - half] = 1.
        return tensor

    def __repr__(self):
        return '{0}(prob={1}, noise_prob={2})'.format(
            self.__class__.__name__, self.prob, self.noise_prob)

def create_corrupted_train_dataset(args):
    """Create the CIFAR-10 training data split into a clean and a corrupted part.

    The 50,000 CIFAR-10 training images are randomly split (no fixed seed) into
    40,000 clean and 10,000 corrupted examples. Every example in the corrupted
    subset has the corruption applied (p=1); the split itself decides which
    examples are corrupted.

    ``args.train_corruption`` selects the corruption:

    * ``'salt_<p>'``: salt-and-pepper noise with per-element noise probability ``p``.
    * ``'erase'``, ``'erase_<smin>_<smax>'`` or ``'erase_<smin>_<smax>_<rmin>_<rmax>'``:
      ``RandomErasing`` with fill value 1 (applied before normalisation). The
      defaults are scale (0.1, 0.2) and ratio (0.3, 0.3).

    Returns:
        (clean_dataset, corrupt_dataset): ``torch.utils.data.Subset`` objects.

    Raises:
        ValueError: if the corruption spec is missing, unrecognised or malformed,
            or ``args.dataset`` is not ``'cifar10'``.
    """
    spec = args.train_corruption
    if spec is None:
        raise ValueError('args.train_corruption must be set to build a corrupted dataset')
    params = spec.split('_')
    if 'salt' in spec:
        if len(params) < 2:
            raise ValueError(f'Malformed salt corruption {spec!r}; expected salt_<noise_prob>')
        # 100% will be corrupted; since we split the dataset, this is ok.
        corruption = AddSaltAndPepperNoise(1, float(params[1]))
    elif 'erase' in spec:
        # Each range needs both endpoints; a lone min would previously IndexError.
        if len(params) in (2, 4):
            raise ValueError(
                f'Malformed erase corruption {spec!r}; expected erase[_smin_smax[_rmin_rmax]]')
        scale = (float(params[1]), float(params[2])) if len(params) >= 3 else (0.1, 0.2)
        ratio = (float(params[3]), float(params[4])) if len(params) >= 5 else (0.3, 0.3)
        # 100% will be corrupted; since we split the dataset, this is ok.
        corruption = transforms.RandomErasing(p=1, scale=scale, ratio=ratio, value=1, inplace=False)
    else:
        raise ValueError(f'Unsupported train_corruption {spec!r}; expected salt_* or erase*')
    print(f'Using train corrupt_transform {corruption}')

    if args.dataset != 'cifar10':
        raise ValueError(f'Invalid dataset {args.dataset}')

    clean_transform = transforms.Compose(
        [
            transforms.Resize(args.size),
            transforms.CenterCrop(args.size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )
    corrupt_transform = transforms.Compose(
        [
            transforms.Resize(args.size),
            transforms.CenterCrop(args.size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            corruption,
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    clean_indices, corrupt_indices = train_test_split(np.arange(50000), train_size=40000)
    clean_dataset = torchvision.datasets.CIFAR10(args.root, train=True, transform=clean_transform, download=True)
    corrupt_dataset = torchvision.datasets.CIFAR10(args.root, train=True, transform=corrupt_transform, download=True)

    clean_dataset = torch.utils.data.Subset(clean_dataset, clean_indices)
    corrupt_dataset = torch.utils.data.Subset(corrupt_dataset, corrupt_indices)

    return clean_dataset, corrupt_dataset

def create_corrupted_test_datasets(args):
    """Create CIFAR-10 evaluation loaders with a clean and a corrupted retrieval database.

    The 50,000 CIFAR-10 training images are randomly split (no fixed seed) into a
    40,000-image clean database and a 10,000-image corrupted database. Every image
    in the corrupted part gets the corruption described by ``args.train_corruption``:

    * ``'salt_<p>'``: salt-and-pepper noise with per-element noise probability ``p``.
    * ``'erase'``, ``'erase_<smin>_<smax>'`` or ``'erase_<smin>_<smax>_<rmin>_<rmax>'``:
      ``RandomErasing`` with fill value 1. The defaults are scale (0.1, 0.2) and
      ratio (0.3, 0.3).

    The 10,000 CIFAR-10 test images form the (uncorrupted) query set. The sampled
    loaders draw ``sampled_db_size`` database images (default 5000), split between
    the clean and corrupted parts in proportion to their sizes (4000 / 1000 for the
    default). They also draw ``sampled_query_size`` query images (default 200).

    Returns:
        (db_clean_loader, db_corrupt_loader, query_loader,
         sampled_db_clean_loader, sampled_db_corrupt_loader, sampled_query_loader)

    Raises:
        ValueError: if the corruption spec is missing, unrecognised or malformed,
            or ``args.dataset`` is not ``'cifar10'``.
    """
    spec = args.train_corruption
    if spec is None:
        raise ValueError('args.train_corruption must be set to build corrupted test datasets')
    params = spec.split('_')
    if 'salt' in spec:
        if len(params) < 2:
            raise ValueError(f'Malformed salt corruption {spec!r}; expected salt_<noise_prob>')
        # 100% will be corrupted; since we split the dataset, this is ok.
        corruption = AddSaltAndPepperNoise(1, float(params[1]))
    elif 'erase' in spec:
        # Each range needs both endpoints; a lone min would previously IndexError.
        if len(params) in (2, 4):
            raise ValueError(
                f'Malformed erase corruption {spec!r}; expected erase[_smin_smax[_rmin_rmax]]')
        scale = (float(params[1]), float(params[2])) if len(params) >= 3 else (0.1, 0.2)
        ratio = (float(params[3]), float(params[4])) if len(params) >= 5 else (0.3, 0.3)
        # 100% will be corrupted; since we split the dataset, this is ok.
        corruption = transforms.RandomErasing(p=1, scale=scale, ratio=ratio, value=1, inplace=False)
    else:
        raise ValueError(f'Unsupported train_corruption {spec!r}; expected salt_* or erase*')
    print(f'Using train corrupt_transform {corruption}')

    if 'sampled_db_size' not in args:
        db_size = 5000
        query_size = 200
    else:
        db_size = args.sampled_db_size
        query_size = args.sampled_query_size
    # Report the sizes actually used (args may not define them).
    print(colored('Using {} for sample db, {} sample query in testing'.format(db_size, query_size), 'red'))

    if args.dataset != 'cifar10':
        raise ValueError(f'Invalid dataset {args.dataset}')

    test_data_transforms = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    clean_transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])
    corrupt_transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        corruption,
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])

    clean_db_indices, corrupt_db_indices = train_test_split(np.arange(50000), train_size=40000)
    db_clean_dataset = torch.utils.data.Subset(
        torchvision.datasets.CIFAR10(args.root, train=True, transform=clean_transform, download=True),
        clean_db_indices)
    db_corrupt_dataset = torch.utils.data.Subset(
        torchvision.datasets.CIFAR10(args.root, train=True, transform=corrupt_transform, download=True),
        corrupt_db_indices)

    # Sample both database parts in proportion to their sizes so the total is db_size.
    n_clean, n_corrupt = len(db_clean_dataset), len(db_corrupt_dataset)
    sampled_clean_size = int(db_size * n_clean // (n_clean + n_corrupt))
    sampled_corrupt_size = int(db_size) - sampled_clean_size
    small_clean_db_indices, _ = train_test_split(np.arange(n_clean), train_size=sampled_clean_size)
    small_corrupt_db_indices, _ = train_test_split(np.arange(n_corrupt), train_size=sampled_corrupt_size)

    # Same root as the database datasets.
    query_dataset = torchvision.datasets.CIFAR10(args.root, download=True, train=False, transform=test_data_transforms)
    small_query_indices, _ = train_test_split(np.arange(len(query_dataset)), train_size=query_size)

    # FULL EVALUATION
    db_clean_loader = data.DataLoader(db_clean_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    db_corrupt_loader = data.DataLoader(db_corrupt_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)
    query_loader = data.DataLoader(query_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    sampled_db_clean_dataset = torch.utils.data.Subset(db_clean_dataset, small_clean_db_indices)
    sampled_db_corrupt_dataset = torch.utils.data.Subset(db_corrupt_dataset, small_corrupt_db_indices)
    sampled_query_dataset = torch.utils.data.Subset(query_dataset, small_query_indices)

    sampled_db_clean_loader = data.DataLoader(sampled_db_clean_dataset, batch_size=args.batch_size,
                                              shuffle=False, num_workers=4)
    # Built from the corrupted sample (not the clean one).
    sampled_db_corrupt_loader = data.DataLoader(sampled_db_corrupt_dataset, batch_size=args.batch_size,
                                                shuffle=False, num_workers=4)
    sampled_query_loader = data.DataLoader(sampled_query_dataset, batch_size=args.batch_size,
                                           shuffle=False, num_workers=4)

    return (db_clean_loader, db_corrupt_loader, query_loader,
            sampled_db_clean_loader, sampled_db_corrupt_loader, sampled_query_loader)
    

def _train_transform(size, channels=3, flip=True):
    """Build the training transform used by ``create_train_dataset``.

    The pipeline resizes and center-crops to ``size``, optionally applies a
    random horizontal flip, converts to a tensor, and normalises each of the
    ``channels`` channels with mean 0.5 / std 0.5 (in place).
    """
    steps = [transforms.Resize(size), transforms.CenterCrop(size)]
    if flip:
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize((0.5,) * channels, (0.5,) * channels, inplace=True))
    return transforms.Compose(steps)


def create_train_dataset(args):
    """Create the training dataset selected by ``args.dataset``.

    Supported values:

    * 'cifar10': the CIFAR-10 train split under ``args.root``. With
      ``args.ood_training``, all class-0 examples are removed.
    * 'mnist_1c': the MNIST train split (1 channel, no flip).
    * 'nuswide', 'coco', 'coco80': project ``TensorDataset`` directories under ../../data.
    * 'nuswide_hashgan': project ``NumpyDataset``.
    * 'coco_hashgan': project ``FilenameDataset``.
    * 'cmnist': project ``ColoredMNIST`` (3 channels, no flip).

    Raises:
        ValueError: for any other ``args.dataset`` (the previous fallback referenced
            undefined names and always failed with NameError).
    """
    name = args.dataset
    if name == 'cifar10':
        dataset = torchvision.datasets.CIFAR10(args.root, train=True, transform=_train_transform(args.size),
                                               download=True)
        if args.ood_training:
            # Keep only examples whose target is not class 0.
            remove0_indices = np.nonzero(np.array(dataset.targets) != 0)[0]
            print(colored('Removing examples for OOD Experiments: {} vs {} examples'.format(
                len(remove0_indices), len(dataset)), 'red'))
            dataset = torch.utils.data.Subset(dataset, remove0_indices)
    elif name == 'mnist_1c':
        dataset = torchvision.datasets.MNIST('../../data', train=True,
                                             transform=_train_transform(args.size, channels=1, flip=False),
                                             download=True)
    elif name in ('nuswide', 'coco', 'coco80'):
        from data.tensor_dataset import TensorDataset
        # dataset name -> (data directory, TensorDataset name)
        tensor_sources = {
            'nuswide': ('../../data/nuswide_isize32_TOP10', 'nuswide'),
            'coco': ('../../data/coco_isize32_TOP10', 'coco'),
            'coco80': ('../../data/coco_isize32_ORIGINAL', 'coco'),
        }
        data_dir, ds_name = tensor_sources[name]
        dataset = TensorDataset(data_dir, ds_name, ds_type='db', transform=_train_transform(args.size),
                                data_path=None)
    elif name == 'nuswide_hashgan':
        from data.tensor_dataset import NumpyDataset
        dataset = NumpyDataset('../../data/', name, ds_type='db', transform=_train_transform(args.size))
    elif name == 'coco_hashgan':
        from data.tensor_dataset import FilenameDataset
        dataset = FilenameDataset('../../data/coco', '../../data', name, ds_type='db',
                                  transform=_train_transform(args.size))
    elif name == 'cmnist':
        from data.color_mnist import ColoredMNIST
        dataset = ColoredMNIST(root='../../data', env='train', transform=_train_transform(args.size, flip=False))
    else:
        raise ValueError(f'Invalid dataset {name}')

    return dataset

def data_sampler(dataset, shuffle):
    """Return a sampler over ``dataset``: random order when ``shuffle`` is truthy, else sequential."""
    sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
    return sampler_cls(dataset)

def _split_classification_datasets(args, dataset_cls, train_transform, test_transform,
                                   train_size, sampled_db_size, sampled_query_size):
    """Build (train, db, query, sampled_db, sampled_query) datasets for a torchvision dataset class.

    Only three underlying dataset objects are created: the train split with
    ``train_transform``, and the train and test splits with ``test_transform``.
    Every returned dataset is one of them or a Subset/ConcatDataset view over
    them. All splits use ``random_state=args.data_seed``.
    """
    root = '../../data'
    train_source = dataset_cls(root, download=True, train=True, transform=train_transform)
    db_source = dataset_cls(root, download=True, train=True, transform=test_transform)
    query_source = dataset_cls(root, download=True, train=False, transform=test_transform)

    if args.full_fid_evaluation:
        # Standard splits as-is: train/db are the whole train split, query the whole test split.
        sampled_db_indices, _ = train_test_split(np.arange(len(db_source)), train_size=sampled_db_size,
                                                 random_state=args.data_seed)
        sampled_query_indices, _ = train_test_split(np.arange(len(query_source)), train_size=sampled_query_size,
                                                    random_state=args.data_seed)
        return (train_source, db_source, query_source,
                torch.utils.data.Subset(db_source, sampled_db_indices),
                torch.utils.data.Subset(query_source, sampled_query_indices))

    # ``targets`` may be a list or a tensor depending on the dataset class; normalise for stratification.
    train_targets = np.asarray(db_source.targets)
    test_targets = np.asarray(query_source.targets)

    train_indices, db_indices_1 = train_test_split(np.arange(len(db_source)), train_size=train_size,
                                                   stratify=train_targets, random_state=args.data_seed)
    query_indices, db_indices_2 = train_test_split(np.arange(len(query_source)), train_size=1000,
                                                   stratify=test_targets, random_state=args.data_seed)
    # The sampled DB comes only from the training-split remainder.
    sampled_db_indices, _ = train_test_split(db_indices_1, train_size=sampled_db_size,
                                             stratify=train_targets[db_indices_1], random_state=args.data_seed)
    sampled_query_indices, _ = train_test_split(query_indices, train_size=sampled_query_size,
                                                stratify=test_targets[query_indices], random_state=args.data_seed)

    train_dataset = torch.utils.data.Subset(train_source, train_indices)
    query_dataset = torch.utils.data.Subset(query_source, query_indices)
    # DB is a mix of the remaining training and test examples.
    db_dataset = torch.utils.data.ConcatDataset([
        torch.utils.data.Subset(db_source, db_indices_1),
        torch.utils.data.Subset(query_source, db_indices_2),
    ])
    sampled_db_dataset = torch.utils.data.Subset(db_source, sampled_db_indices)
    sampled_query_dataset = torch.utils.data.Subset(query_source, sampled_query_indices)

    if 'train_db_evaluation' in args and args.train_db_evaluation:
        train_dataset = torch.utils.data.ConcatDataset([train_dataset, db_dataset])

    return train_dataset, db_dataset, query_dataset, sampled_db_dataset, sampled_query_dataset


def create_datasets(args):
    """Create train/db/query DataLoaders (full and sampled) for 'mnist_1c' or 'cifar10'.

    Without ``args.full_fid_evaluation`` the data is split as follows; all splits
    use ``args.data_seed``:

    * train: a stratified sample of the train split (10000 for MNIST, 5000 for CIFAR-10).
    * query: 1000 stratified test images.
    * db: all remaining train and test images.

    With ``args.full_fid_evaluation`` the standard train/test splits are used directly.
    The sampled loaders use ``args.sampled_db_size`` / ``args.sampled_query_size``
    (defaults 10000 / 200 when absent).

    Returns:
        (train_loader, db_loader, query_loader, sampled_db_loader, sampled_query_loader)

    Raises:
        ValueError: for any other ``args.dataset``.
    """
    if args.dataset == 'mnist_1c':
        dataset_cls, use_flip, train_size = datasets.MNIST, False, 10000
    elif args.dataset == 'cifar10':
        dataset_cls, use_flip, train_size = datasets.CIFAR10, True, 5000
    else:
        raise ValueError(f'Invalid dataset {args.dataset}')

    if 'sampled_db_size' not in args:
        sampled_db_size = 10000
        sampled_query_size = 200
    else:
        sampled_db_size = args.sampled_db_size
        sampled_query_size = args.sampled_query_size
    # Report the sizes actually used (args may not define them).
    print(colored('Using {} for sample db, {} sample query in testing'.format(sampled_db_size, sampled_query_size), 'red'))

    # A single-value mean/std is applied to every channel (also for 3-channel CIFAR-10).
    train_steps = [transforms.Resize(args.size), transforms.CenterCrop(args.size)]
    if use_flip:
        train_steps.append(transforms.RandomHorizontalFlip())
    train_steps += [transforms.ToTensor(), transforms.Normalize((0.5, ), (0.5, ), inplace=True)]
    train_transform = transforms.Compose(train_steps)
    test_transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, )),
    ])

    train_dataset, db_dataset, query_dataset, sampled_db_dataset, sampled_query_dataset = \
        _split_classification_datasets(args, dataset_cls, train_transform, test_transform,
                                       train_size, sampled_db_size, sampled_query_size)

    print(colored(f'Data sizes: Train {len(train_dataset)} DB {len(db_dataset)}/{len(sampled_db_dataset)} Query {len(query_dataset)}/{len(sampled_query_dataset)}', 'blue'))
    # FULL EVALUATION
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                   sampler=data_sampler(train_dataset, shuffle=True),
                                   drop_last=True, pin_memory=True, num_workers=2)
    db_loader = data.DataLoader(db_dataset, batch_size=args.batch_size, shuffle=False)
    query_loader = data.DataLoader(query_dataset, batch_size=args.batch_size, shuffle=False)

    sampled_db_loader = data.DataLoader(sampled_db_dataset, batch_size=args.batch_size, shuffle=False)
    sampled_query_loader = data.DataLoader(sampled_query_dataset, batch_size=args.batch_size, shuffle=False)

    return train_loader, db_loader, query_loader, sampled_db_loader, sampled_query_loader
    
def create_models(args):
    """Build ``(generator, g_ema, discriminator)`` for the architecture named by ``args.model``.

    Supported names:

    * anything starting with 'sngan': ``model.sngan.model_resnet`` 3-channel models.
    * 'original', 'original_v2' (3 channels), 'original_1c', 'original_1c_v2'
      (1 channel): ``model.original`` models. 'original_v2' passes ``args.ngf``
      as ``ndf`` to the discriminator and as ``ngf`` to the generators.
    * 'style2' / 'style2_leaky': ``model.stylegan2.stylegan2`` models.

    ``g_ema`` is built with the same constructor arguments as ``generator``.
    Within each branch the modules are constructed in the original order
    (construction may consume the global RNG).

    Raises:
        ValueError: if ``args.model`` is not a supported name.
    """
    model_name = args.model
    if model_name.startswith('sngan'):
        from model.sngan.model_resnet import ConditionalGenerator, HashDiscriminator
        discriminator = HashDiscriminator(args.n_classes, 3, args.size, args.hash_dim, multi_labels=args.multi_labels)
        generator = ConditionalGenerator(args.n_classes, 3, args.latent, multi_labels=args.multi_labels)
        g_ema = ConditionalGenerator(args.n_classes, 3, args.latent, multi_labels=args.multi_labels)
    elif model_name.startswith('original'):
        if model_name not in ('original', 'original_v2', 'original_1c', 'original_1c_v2'):
            raise ValueError(f'Invalid model {model_name}')
        from model.original import ConditionalGenerator
        if model_name == 'original':
            from model.original import HashDiscriminator
            discriminator = HashDiscriminator(args.n_classes, 3, args.size, args.hash_dim,
                                              multi_labels=args.multi_labels)
            generator = ConditionalGenerator(args.n_classes, 3, args.latent, multi_labels=args.multi_labels)
            g_ema = ConditionalGenerator(args.n_classes, 3, args.latent, multi_labels=args.multi_labels)
        elif model_name == 'original_v2':
            from model.original import HashDiscriminatorV2
            discriminator = HashDiscriminatorV2(args.n_classes, 3, args.size, args.hash_dim,
                                                multi_labels=args.multi_labels, ndf=args.ngf)
            generator = ConditionalGenerator(args.n_classes, 3, args.latent, multi_labels=args.multi_labels,
                                             ngf=args.ngf)
            g_ema = ConditionalGenerator(args.n_classes, 3, args.latent, multi_labels=args.multi_labels,
                                         ngf=args.ngf)
        elif model_name == 'original_1c':
            from model.original import HashDiscriminator
            discriminator = HashDiscriminator(args.n_classes, 1, args.size, args.hash_dim,
                                              multi_labels=args.multi_labels)
            generator = ConditionalGenerator(args.n_classes, 1, args.latent, multi_labels=args.multi_labels)
            g_ema = ConditionalGenerator(args.n_classes, 1, args.latent, multi_labels=args.multi_labels)
        else:  # 'original_1c_v2'
            from model.original import HashDiscriminatorV2
            discriminator = HashDiscriminatorV2(args.n_classes, 1, args.size, args.hash_dim,
                                                multi_labels=args.multi_labels)
            generator = ConditionalGenerator(args.n_classes, 1, args.latent, multi_labels=args.multi_labels)
            g_ema = ConditionalGenerator(args.n_classes, 1, args.latent, multi_labels=args.multi_labels)
    elif 'style2' in model_name:
        if model_name not in ('style2', 'style2_leaky'):
            raise ValueError(f'Invalid model {model_name}')
        from model.stylegan2.stylegan2 import ConditionalGenerator, HashDiscriminator
        generator = ConditionalGenerator(
            args.n_classes, args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
        )
        if model_name == 'style2':
            discriminator = HashDiscriminator(
                args.n_classes, args.hash_dim, args.size, channel_multiplier=args.channel_multiplier
            )
        else:  # 'style2_leaky'
            from model.stylegan2.stylegan2 import HashDiscriminatorLeaky
            discriminator = HashDiscriminatorLeaky(
                args.n_classes, args.hash_dim, args.size, channel_multiplier=args.channel_multiplier
            )
        g_ema = ConditionalGenerator(
            args.n_classes, args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
        )
    else:
        raise ValueError(f'Invalid model {model_name}')
    return generator, g_ema, discriminator


def get_codes(device, discriminator, dataloader):
    """Run ``discriminator`` in code-only mode over ``dataloader``; collect codes and labels.

    The discriminator is put into eval mode and is not switched back. Inference
    runs under ``torch.no_grad()`` so no autograd graph is kept per batch; the
    outputs are detached and converted to NumPy anyway.

    Returns:
        (codes, labels): per-batch codes joined with ``np.vstack`` and per-batch
        labels joined with ``np.concatenate``.
    """
    discriminator.eval()
    codes, labels = [], []
    with torch.no_grad():
        for obs_data, obs_label in tqdm(dataloader, total=len(dataloader)):
            out_h = discriminator(None, obs_data.to(device), return_code_only=True)
            codes.append(out_h.detach().cpu().numpy())
            labels.append(obs_label.detach().cpu().numpy())
    return np.vstack(codes), np.concatenate(labels)

def get_predictions(device, discriminator, dataloader):
    """Run ``discriminator`` in prediction-only mode over ``dataloader``; collect outputs and labels.

    The discriminator is put into eval mode and is not switched back. Inference
    runs under ``torch.no_grad()`` so no autograd graph is kept per batch; the
    outputs are detached and converted to NumPy anyway.

    Returns:
        (predictions, labels): per-batch predictions joined with ``np.vstack`` and
        per-batch labels joined with ``np.concatenate``.
    """
    discriminator.eval()
    predictions, labels = [], []
    with torch.no_grad():
        for obs_data, obs_label in tqdm(dataloader, total=len(dataloader)):
            out_class = discriminator(None, obs_data.to(device), return_prediction_only=True)
            predictions.append(out_class.detach().cpu().numpy())
            labels.append(obs_label.detach().cpu().numpy())
    return np.vstack(predictions), np.concatenate(labels)

def hashing_evaluate(args, device, discriminator, db_loader, query_loader, Rs=(1, 5, 10, 100, 1000)):
    """Evaluate retrieval with binarised discriminator codes.

    Codes are computed for the database and query sets and binarised with
    ``sign`` (zeros map to +1). They are then scored with
    ``ranking.calculate_all_metrics`` at the cut-offs ``Rs``; the database size
    is appended when it is not already among them. Single-label targets are
    one-hot encoded to width ``args.n_classes`` for both database and query, so
    the two label matrices always have the same width.

    Returns:
        (precisions, mAPs) as returned by ``ranking.calculate_all_metrics``.
    """
    db_features, db_labels = get_codes(device, discriminator, db_loader)
    query_features, query_labels = get_codes(device, discriminator, query_loader)

    Rs = list(Rs)  # private copy: never mutate the caller's sequence or a shared default
    N = db_features.shape[0]
    if N not in Rs:
        Rs.append(N)

    threshold = 0
    db_b = np.sign(db_features - threshold)
    query_b = np.sign(query_features - threshold)
    db_b[db_b == 0] = 1
    query_b[query_b == 0] = 1

    if len(db_labels.shape) > 1:  # multi-labels
        precisions, mAPs = ranking.calculate_all_metrics(db_b, db_labels,
                                                         query_b, query_labels, Rs=Rs)
    else:  # 1 label
        precisions, mAPs = ranking.calculate_all_metrics(
            db_b, ranking.one_hot_label(db_labels, args.n_classes),
            query_b, ranking.one_hot_label(query_labels, args.n_classes), Rs=Rs)
    print(db_features.min(), db_features.max())
    return precisions, mAPs

def calculate_multilabel_accuracy(prediction, label):
    """Fraction of samples whose thresholded prediction shares a positive class with ``label``.

    Args:
        prediction: (n, d) scores; a class counts as predicted when its score is >= 0.
        label: tensor with n*d elements, presumably (n, d) multi-hot. Any numeric
            dtype is accepted (cast to float32 here).

    Returns:
        0-dim float32 tensor: mean over samples of (predicted . label > 0).
    """
    n, d = prediction.size()
    pred_output = (prediction >= 0).to(torch.float32)
    # Row-wise dot product; elementwise form avoids matmul dtype/contiguity requirements.
    overlap = (pred_output * label.reshape(n, d).to(torch.float32)).sum(dim=1)
    return torch.mean((overlap > 0).to(torch.float32))

def generate_random_multi_labels(size, n_classes):
    """Return a (size, n_classes) float32 one-hot label matrix.

    Despite the name, the labels are deterministic: row ``i`` has its single 1
    at column ``i % n_classes``, so the classes cycle evenly.
    """
    class_ids = torch.arange(size) % n_classes
    return torch.nn.functional.one_hot(class_ids, num_classes=n_classes).to(torch.float32)

def create_multi_label_v(indices, n_classes):
    """Return a float32 multi-hot vector of length ``n_classes`` with ones at ``indices``."""
    multi_hot = np.zeros((n_classes,), np.float32)
    multi_hot[indices] = 1.0
    return multi_hot

def sample_negative_labels(label, n_classes, multi_labels=False):
    """Draw, for every sample in ``label``, a label that the sample does not have.

    Single-label mode (``multi_labels=False``): each entry of ``label`` is a class id.
    The result holds one class id drawn uniformly from the other ``n_classes - 1`` classes.

    Multi-label mode: each row of ``label`` is compared against 0 to find its zero
    positions. A random count between 1 and the number of zero positions is drawn,
    and that many zero positions are chosen without replacement. They are turned
    into a float32 multi-hot row via ``create_multi_label_v``. A row with no zero
    entries makes ``np.random.randint`` raise.

    Uses NumPy's global RNG. Returns a CPU tensor built from the stacked samples.
    """
    label_np = label.detach().cpu().numpy()
    negatives = []
    for row in label_np:
        if multi_labels:
            free_positions = np.nonzero(row == 0)[0]
            how_many = np.random.randint(1, len(free_positions) + 1)
            chosen = np.random.choice(free_positions, how_many, replace=False)
            negatives.append(create_multi_label_v(chosen, n_classes))
        else:
            others = [c for c in range(n_classes) if c != row]
            negatives.append(np.random.choice(others, 1)[0])
    return torch.tensor(np.array(negatives))