import argparse
import os
import random
import warnings

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode
import torchvision      
from tqdm import tqdm
import torch.nn as nn       
from PIL import Image   
from models import ConvNet      
from resnet import ResNet18
from utils import get_dataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Dataset keys handled by main(); anything else is rejected at parse time instead of
# failing later with an unbound-variable error.
DATASET_CHOICES = ('imagenet', 'imagenet100', 'imagenet-woof', 'imagenette', 'tiny', 'cifar10', 'cifar100')

parser = argparse.ArgumentParser(description='FKD Soft Label Generation on ImageNet-1K w/ Mix Augmentation')

# `%(default)s` is expanded by argparse, so the help text can never drift from `default=`.
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    help='model architecture: a torchvision model name for imagenet, '
                         'or resnet18 / convnet-4 for tiny (default: %(default)s)')
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
parser.add_argument('-b', '--batch-size', default=1024, type=int,
                    metavar='N',
                    help='mini-batch size (default: %(default)s); this is the total '
                         'batch size across all GPUs, which DataParallel splits')

##############################################################################
parser.add_argument('--dataset_name', default='imagenet', type=str, choices=DATASET_CHOICES,
                    help='dataset / teacher setup to use (default: %(default)s)')
parser.add_argument('--model_save_name', type=str,
                    help='path of the teacher state_dict to load (unused for imagenet, '
                         'which loads torchvision pretrained weights)')
parser.add_argument('--train_folder_root', type=str, required=True,
                    help='ImageFolder root of the images to produce soft labels for')
parser.add_argument('--file_save_name', type=str, required=True,
                    help='output path of the saved teacher outputs (written with torch.save)')
parser.add_argument('--init_acc', action='store_true', default=False,
                    help='evaluate the teacher on the validation set before saving')
parser.add_argument('--root_dir', type=str,
                    help='dataset root directory')
parser.add_argument('--subset', type=str,
                    help='overwritten with --dataset_name by main()')
parser.add_argument('--dataset_name_dict', type=str)
parser.add_argument('--init_resize', type=int,
                    help='forced to 256 for imagenet100, imagenet-woof and imagenette')
parser.add_argument('--input_size', type=int,
                    help='forced to 224 for imagenet100, imagenet-woof and imagenette')
##############################################################################

@torch.no_grad()
def eval(model, dl):
    """Return the top-1 accuracy (in percent) of ``model`` over the batches of ``dl``.

    Note: this name shadows the builtin ``eval`` within this module; it is kept
    because ``main`` calls it by this name.

    Args:
        model: classifier; switched to eval mode here. Predictions are taken with
            ``argmax(1)``, so dim 1 of its output is assumed to index classes.
        dl: iterable of ``(inputs, labels)`` batches. Both are moved to the
            module-level ``device``.

    Returns:
        float: percentage of samples whose argmax prediction equals the label.

    Raises:
        ValueError: if ``dl`` yields no samples (instead of a ZeroDivisionError).
    """
    model.eval()
    correct = 0
    total = 0

    for x, y in tqdm(dl):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        preds = model(x).argmax(1)
        # Accumulate on-device; a single conversion at the end avoids a host sync per batch.
        correct += (preds == y).sum()
        total += x.shape[0]

    if total == 0:
        raise ValueError('eval() received a data loader with no samples')
    return float(correct) / total * 100.


def _build_model_and_data(args):
    """Build ``(model, normalize, im_size, val_ds)`` for ``args.dataset_name``.

    The returned values are:

    - ``model``: the teacher, wrapped in ``torch.nn.DataParallel`` and moved to CUDA.
    - ``normalize``: the dataset's ``transforms.Normalize``.
    - ``im_size``: the side length that images under ``--train_folder_root`` are resized to.
    - ``val_ds``: the validation dataset, used only with ``--init_acc``.

    Raises:
        ValueError: for an unsupported ``args.dataset_name``, or for an unsupported
            ``args.arch`` when the dataset is ``tiny``.
    """
    imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    name = args.dataset_name

    if name == 'imagenet':
        # torchvision's pretrained weights; --model_save_name is not used for this dataset.
        model = models.__dict__[args.arch](pretrained=True)
        model = torch.nn.DataParallel(model).cuda()
        print(f'pretrained model: {args.arch}')
        normalize = imagenet_normalize
        val_trans = transforms.Compose([transforms.Resize(256, interpolation=InterpolationMode.BICUBIC),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(), normalize])
        val_ds = torchvision.datasets.ImageFolder(os.path.join(args.root_dir, 'val'), transform=val_trans)
        return model, normalize, 224, val_ds

    if name in ('imagenet100', 'imagenet-woof', 'imagenette'):
        # Sizes are forced here; they are set on `args` before get_dataset(args, ...)
        # receives it (presumably read there — confirm in utils.get_dataset).
        args.init_resize = 256
        args.input_size = 224
        model = ResNet18(num_cls=100 if name == 'imagenet100' else 10)
        model = torch.nn.DataParallel(model).cuda()
        model.load_state_dict(torch.load(args.model_save_name))
        _, val_ds, _, _, _ = get_dataset(args, None)
        return model, imagenet_normalize, 224, val_ds

    if name == 'tiny':
        if args.arch == 'resnet18':
            model = models.__dict__[args.arch](num_classes=200)
            # 64x64 inputs: replace the stem with a 3x3 stride-1 conv and drop the max-pool.
            model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            model.maxpool = nn.Identity()
        elif args.arch == 'convnet-4':
            model = ConvNet(num_classes=200, net_depth=4, im_size=(64, 64))
        else:
            raise ValueError(f"unsupported --arch for tiny: {args.arch!r} "
                             f"(expected 'resnet18' or 'convnet-4')")
        model = torch.nn.DataParallel(model).cuda()
        model.load_state_dict(torch.load(args.model_save_name))
        normalize = imagenet_normalize
        val_ds = torchvision.datasets.ImageFolder(os.path.join(args.root_dir, 'val', 'images'),
                                                  transform=transforms.Compose([transforms.ToTensor(), normalize]))
        return model, normalize, 64, val_ds

    if name in ('cifar10', 'cifar100'):
        if name == 'cifar10':
            num_classes, cifar_cls = 10, torchvision.datasets.CIFAR10
        else:
            num_classes, cifar_cls = 100, torchvision.datasets.CIFAR100
        model = ConvNet(num_classes=num_classes)
        model = torch.nn.DataParallel(model).cuda()
        model.load_state_dict(torch.load(args.model_save_name))
        normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
        # download=True: fetched into ./data if not already present.
        val_ds = cifar_cls(root='./data', train=False, download=True,
                           transform=transforms.Compose([transforms.ToTensor(), normalize]))
        return model, normalize, 32, val_ds

    raise ValueError(f'unsupported --dataset_name: {name!r}')


def main():
    """Entry point: build the teacher for ``--dataset_name`` and save its outputs.

    With ``--init_acc`` the teacher's validation accuracy is printed first. The model's
    outputs on every image under ``--train_folder_root`` are then saved to
    ``--file_save_name`` by ``save_soft_lbls``.
    """
    args = parser.parse_args()
    # NOTE(review): --subset is always overwritten with the dataset name; presumably
    # consumed by get_dataset — confirm before removing the flag.
    args.subset = args.dataset_name

    model, normalize, im_size, val_ds = _build_model_and_data(args)

    # Images to label are resized straight to im_size x im_size (no crop).
    train_trans = transforms.Compose([transforms.Resize((im_size, im_size), interpolation=InterpolationMode.BICUBIC),
                                      transforms.ToTensor(), normalize])

    model.eval()

    if args.init_acc:
        val_dl = torch.utils.data.DataLoader(val_ds, batch_size=args.batch_size,
                                             shuffle=False, num_workers=args.workers, pin_memory=True)
        acc = eval(model, val_dl)
        print(f'Initial Accuracy: {acc:.2f}%')

    save_soft_lbls(model, args.train_folder_root, train_trans, args.file_save_name,
                   args.batch_size)

@torch.no_grad()
def save_soft_lbls(model, train_folder_root, train_trans, file_save_name, batch_size, num_workers=8):
    """Run ``model`` over every image under ``train_folder_root`` and save the outputs.

    Images are read with ``torchvision.datasets.ImageFolder`` without shuffling, so
    row ``k`` of the saved tensor corresponds to the dataset's ``k``-th sample.

    Args:
        model: network to evaluate; switched to eval mode here. Inputs are moved
            with ``.cuda()``, so the model is expected to live on CUDA.
        train_folder_root: ImageFolder root directory of the images to label.
        train_trans: transform applied to each image.
        file_save_name: output path passed to ``torch.save``.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes (default 8, the previously
            hard-coded value).

    The saved object is a CPU tensor: the per-batch model outputs concatenated
    along dim 0.
    """
    # Guard against a caller passing a model still in train mode (BatchNorm/Dropout).
    model.eval()

    dataset = torchvision.datasets.ImageFolder(train_folder_root, transform=train_trans)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                         num_workers=num_workers, pin_memory=True)

    outputs = []
    # Labels are not needed: saved rows are aligned with the dataset by order.
    for images, _ in tqdm(loader):
        outputs.append(model(images.cuda(non_blocking=True)).cpu())

    soft_targets = torch.cat(outputs, dim=0)
    print(soft_targets.shape)

    torch.save(soft_targets, file_save_name)


if __name__ == '__main__':
    # Script entry point: parse the command line and generate the soft labels.
    main()
