import os
import sys
import time
import argparse
import random
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from tqdm import tqdm
import timm
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models
from torchattacks import VNIFGSM
from collections import OrderedDict

import models
from utils import AverageMeter, time_file_str

# Command-line interface. Arguments are parsed at module level and the result
# is read by main() through the module-level `args`.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_dir", type=str, help="path to imagenet dataset")
parser.add_argument("--save_dir", type=str, help="path to save the sampled dataset")
# NOTE: neither option is declared with required=True, so an omitted option is
# stored as None.
args = parser.parse_args()

def main():
    """Copy a per-class sample of correctly classified ImageNet val images.

    Steps:
      1. Load the ImageNet validation split from ``args.dataset_dir``.
      2. Run ``models.resnet50`` and ``models.resnet101`` (both pretrained) on
         every image and keep the images whose top-1 prediction is correct
         for *both* models.
      3. Sort the kept images by class index and let ``random_sample`` pick a
         small number of images per class.
      4. Copy the picked files to ``args.save_dir/<class dir>/<file name>``.

    Requires a CUDA device. Exits through ``parser.error`` when one of the
    two path options is missing.
    """
    # Fail before the (long) inference pass rather than at the copy step.
    if not args.dataset_dir or not args.save_dir:
        parser.error("both --dataset_dir and --save_dir are required")

    normalize_mean = [0.485, 0.456, 0.406]
    normalize_std = [0.229, 0.224, 0.225]
    val_dataset = datasets.ImageNet(
        args.dataset_dir,
        split='val',
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=normalize_mean, std=normalize_std),
        ]),
    )
    print(len(val_dataset.imgs))
    print(val_dataset.imgs[0])
    print(val_dataset.imgs[1])

    # batch_size=1 together with shuffle=False keeps the loader index `i`
    # aligned with val_dataset.imgs[i], which the bookkeeping below relies on.
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True
    )
    sample_img = []
    sample_target = {}
    model1 = models.resnet50(pretrained=True).to('cuda')
    model2 = models.resnet101(pretrained=True).to('cuda')
    model1.eval()
    model2.eval()

    # torch.no_grad() replaces the removed `Variable(..., volatile=True)`
    # idiom: no autograd graph is built during this inference-only pass.
    with torch.no_grad():
        # Wrapping the loader (not the enumerate) lets tqdm see its length.
        for i, (images, target) in enumerate(tqdm(val_loader)):
            images = images.cuda()
            target = target.cuda()

            # With a batch of one, top-1 precision is either 0 or 100, so
            # "> 0" means "this model classified the image correctly".
            # all() short-circuits: model2 is skipped when model1 is wrong.
            both_correct = all(
                accuracy(model(images), target, topk=(1,))[0].item() > 0
                for model in (model1, model2)
            )
            if both_correct:
                entry = val_dataset.imgs[i]
                sample_img.append(entry)
                sample_target[entry] = target.item()

    # random_sample is handed the images ordered by class index.
    sample_img = sorted(sample_img, key=lambda img: sample_target[img])
    print(f"There are {len(sample_img)} samples!")
    sample_img = random_sample(sample_img, sample_target)
    print(f"There are {len(sample_img)} samples!")

    for img in sample_img:
        img_path = img[0]
        # Mirror the "<class dir>/<file name>" tail of the source path.
        dir_name = os.path.basename(os.path.dirname(img_path))
        file_name = os.path.basename(img_path)
        mkdir_path = os.path.join(args.save_dir, dir_name)
        os.makedirs(mkdir_path, exist_ok=True)
        dest_img_path = os.path.join(mkdir_path, file_name)
        shutil.copy(img_path, dest_img_path)

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: 2-D tensor of class scores indexed as (sample, class).
        target: tensor holding one ground-truth class index per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element float tensor per entry of ``topk``:
        the percentage of samples whose target is among the k highest-scoring
        classes. A k larger than the number of classes is treated as
        "all classes" instead of raising.
    """
    # Never ask topk() for more entries than there are classes.
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)

    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()  # (maxk, batch_size)
    # reshape (rather than view) also accepts non-contiguous targets.
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:min(k, maxk)].reshape((-1,)).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def random_sample(sample_img, sample_target, num_per_class=2):
    """Pick at most ``num_per_class`` random images for every class label.

    Args:
        sample_img: sequence of hashable image entries.
        sample_target: mapping from each entry of ``sample_img`` to its
            class label.
        num_per_class: maximum number of images kept per label (default 2).

    Returns:
        A new list. Labels appear in the order they are first seen in
        ``sample_img``. A label with more than ``num_per_class`` images
        contributes a ``random.sample`` of that size; any other label
        contributes all of its images in their original order.

    Raises:
        ValueError: if ``num_per_class`` is negative.
    """
    if num_per_class < 0:
        raise ValueError("num_per_class must be non-negative")

    # Group explicitly by label so the result does not depend on the input
    # being pre-sorted; dicts keep first-seen order, so sorted input still
    # yields groups in ascending label order.
    groups = {}
    for img in sample_img:
        groups.setdefault(sample_target[img], []).append(img)

    final_sample_img = []
    for group in groups.values():
        if len(group) > num_per_class:
            final_sample_img += random.sample(group, num_per_class)
        else:
            final_sample_img += group
    return final_sample_img

# Invoke main() only when this file is executed as a script.
if __name__ == '__main__':
    main()