import os
import json
import torch
import pickle
import os.path
import numpy as np
import torchvision
from PIL import Image
import torch.utils.data as data
from pycocotools.coco import COCO
import torchvision.transforms as transforms

# Image file extensions, each listed in lower- and upper-case form
# (lower-case first), used in ImageFolder's "no images found" error message.
IMG_EXTENSIONS = [
    variant
    for ext in ('.jpg', '.jpeg', '.png', '.ppm', '.bmp')
    for variant in (ext, ext.upper())
]

# Per-task constants keyed by Visual Decathlon dataset name. The name suggests
# weight-decay coefficients ("WDs"); they are not referenced in this chunk,
# so confirm their use against the training code.
VISUAL_DECATHLON_WDS = dict([
    ("aircraft", 0.0005),
    ("cifar100", 0.0),
    ("daimlerpedcls", 0.0005),
    ("dtd", 0.0),
    ("gtsrb", 0.0),
    ("omniglot", 0.0005),
    ("svhn", 0.0),
    ("ucf101", 0.0005),
    ("vgg-flowers", 0.0001),
    ("imagenet12", 0.0001),
])
                    
                    

def pil_loader(path):
    """Load the image file at ``path`` and return it converted to RGB.

    The file is opened explicitly and closed when this function returns.
    ``convert`` reads the pixel data and returns a new image before the handle
    is released, so the result does not depend on the closed file. Opening by
    path directly would leave the handle's lifetime to Pillow.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

class ImageFolder(data.Dataset):
    """Map-style dataset over an explicit list of ``(path, image_id)`` entries.

    Despite the name, this class does not scan ``root``. The caller supplies
    the image list ``imgs`` and, optionally, a positionally aligned ``labels``
    list. Each item is returned as ``(image, (target, image_id))``; ``target``
    is ``None`` when no labels were given.
    """

    def __init__(self, root, transform=None, target_transform=None, index=None,
            labels=None, imgs=None, loader=pil_loader, skip_label_indexing=0):
        """Store the (optionally index-subset) image and label lists.

        Args:
            root: Base directory; only stored and used in the error message.
            transform: Optional callable applied to each loaded image.
            target_transform: Optional callable applied to each target.
            index: Optional iterable of positions selecting a subset of
                ``imgs`` (and of ``labels``, see ``skip_label_indexing``).
            labels: Optional sequence of targets aligned with ``imgs``.
            imgs: Sequence of ``(path, image_id)`` entries.
            loader: Callable turning a path into an image.
            skip_label_indexing: When non-zero, ``index`` is applied to
                ``imgs`` only and ``labels`` is kept as given (presumably
                because the caller already subset it; confirm with callers).

        Raises:
            RuntimeError: If ``imgs`` is ``None`` or empty.
        """
        # Test for None first: the default ``imgs=None`` would otherwise fail
        # with a TypeError from len(None) instead of this descriptive error.
        if imgs is None or len(imgs) == 0:
            raise RuntimeError("Found 0 images in subfolders of: " + str(root) + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))

        if index is not None:
            # Materialize once: a one-shot iterator would otherwise be exhausted
            # by the image subset, leaving the label subset silently empty.
            index = list(index)
            imgs = [imgs[i] for i in index]
            if skip_label_indexing == 0 and labels is not None:
                labels = [labels[i] for i in index]

        self.root = root
        self.imgs = imgs
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, (target, image_id))`` for position ``index``."""
        entry = self.imgs[index]
        target = None if self.labels is None else self.labels[index]
        img = self.loader(entry[0])
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, (target, entry[1])

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.imgs)

def _read_decathlon_split(ann_file, data_dir):
    """Read one COCO-style Decathlon annotation file.

    Returns ``(imgnames, raw_labels)``:
    - ``imgnames`` is a list of ``(data_dir + '/' + file_name, image_id)`` in
      ``coco.getImgIds()`` order.
    - ``raw_labels`` holds ``category_id - 1`` for every annotation of those
      images.

    The two lists are later paired by position, which assumes exactly one
    annotation per image (TODO confirm for every split file).
    """
    coco = COCO(ann_file)
    img_ids = coco.getImgIds()
    annotations = coco.loadAnns(coco.getAnnIds(imgIds=img_ids))
    imgnames = [(data_dir + '/' + img['file_name'], img['id'])
                for img in coco.loadImgs(img_ids)]
    raw_labels = [int(ann['category_id']) - 1 for ann in annotations]
    return imgnames, raw_labels


def _decathlon_transforms(name, means, stds):
    """Return ``(transform_train, transform_test)`` for Decathlon task ``name``."""

    def compose(*geometric):
        # Every pipeline ends with tensor conversion and per-task normalization.
        return transforms.Compose(
            list(geometric) + [transforms.ToTensor(), transforms.Normalize(means, stds)])

    if name == 'daimlerpedcls':
        # This task is resized only, never cropped, in both splits.
        return compose(transforms.Resize(72)), compose(transforms.Resize(72))

    transform_test = compose(transforms.Resize(72), transforms.CenterCrop(72))
    if name in ('gtsrb', 'omniglot', 'svhn'):
        # No random-crop augmentation: training uses the test-time center crop.
        transform_train = compose(transforms.Resize(72), transforms.CenterCrop(72))
    else:
        # Random 64-pixel crops for training (evaluation uses 72-pixel crops).
        transform_train = compose(transforms.Resize(72), transforms.RandomCrop(64))
    return transform_train, transform_test


def prepare_data_loaders(dataset_names, data_dir, imdb_dir, shuffle_train=True, index=None):
    """Build train/val ``ImageFolder`` datasets for each Visual Decathlon task.

    For every name in ``dataset_names`` this function:
    - reads ``<imdb_dir>/<name>_train.json`` and ``<imdb_dir>/<name>_val.json``
      (COCO format);
    - maps category ids to 0-based class indices using one offset shared by
      both splits;
    - wraps each split with the per-task transforms and the mean/std from
      ``<data_dir>/decathlon_mean_std.pickle``.

    Args:
        dataset_names: Task names, e.g. ``'aircraft'``.
        data_dir: Directory prefixed to every image ``file_name``; also holds
            the mean/std pickle.
        imdb_dir: Directory containing the annotation JSON files.
        shuffle_train: Unused; kept for interface compatibility (this function
            returns datasets, not ``DataLoader``s).
        index: Optional positions selecting a subset of every training split.

    Returns:
        ``(train_sets, val_sets, num_classes, dataset_names)``. The first two
        are lists of ``ImageFolder`` datasets, one per task; ``num_classes``
        lists each task's class count.
    """
    train_loaders = []
    val_loaders = []
    num_classes = []

    # NOTE(review): pickle.load can execute arbitrary code; only load a trusted file.
    with open(os.path.join(data_dir, 'decathlon_mean_std.pickle'), 'rb') as handle:
        dict_mean_std = pickle.load(handle, encoding="iso-8859-1")

    for name in dataset_names:
        split_data = {}
        for split in ('train', 'val'):
            ann_file = os.path.join(imdb_dir, name + '_' + split + '.json')
            split_data[split] = _read_decathlon_split(ann_file, data_dir)
            print(f"Load {ann_file} as {split}_set dataset of {name}.")

        imgnames_train, raw_train = split_data['train']
        imgnames_val, raw_val = split_data['val']

        # Use ONE offset for both splits so a category id maps to the same class
        # index in train and val. Subtracting each split's own minimum would shift
        # labels whenever a split lacks the lowest class.
        all_raw = raw_train + raw_val
        offset = min(all_raw)
        labels_train = [lab - offset for lab in raw_train]
        labels_val = [lab - offset for lab in raw_val]
        # Size the classifier from the largest label in EITHER split.
        num_classes.append(int(max(all_raw) - offset + 1))

        means = dict_mean_std[name + 'mean']
        stds = dict_mean_std[name + 'std']
        transform_train, transform_test = _decathlon_transforms(name, means, stds)

        trainloader = ImageFolder(data_dir, transform_train, None, index, labels_train, imgnames_train)
        valloader = ImageFolder(data_dir, transform_test, None, None, labels_val, imgnames_val)

        train_loaders.append(trainloader)
        val_loaders.append(valloader)
        print("loader len: ", len(trainloader), len(valloader))

    return train_loaders, val_loaders, num_classes, dataset_names


def prepare_val_data_loaders(dataset_names, data_dir, imdb_dir, batch_size=100, num_workers=4):
    """Build an evaluation ``DataLoader`` for each task's stripped test split.

    Reads ``<imdb_dir>/<name>_test_stripped.json`` (COCO format) for every name
    in ``dataset_names``. Each loader serves that file's images unshuffled, with
    the per-task test transform and the mean/std from
    ``<data_dir>/decathlon_mean_std.pickle``.

    Args:
        dataset_names: Task names, e.g. ``'aircraft'``.
        data_dir: Directory prefixed to every image ``file_name``; also holds
            the mean/std pickle.
        imdb_dir: Directory containing the ``*_test_stripped.json`` files.
        batch_size: Images per batch (default 100, the former hard-coded value).
        num_workers: DataLoader worker processes (default 4, the former
            hard-coded value).

    Returns:
        A list of ``DataLoader``s, one per task. Each yields
        ``(images, (targets, image_ids))`` where every target is the
        placeholder ``-1``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load a trusted file.
    with open(os.path.join(data_dir, 'decathlon_mean_std.pickle'), 'rb') as handle:
        dict_mean_std = pickle.load(handle, encoding="iso-8859-1")

    val_loaders = []
    for name in dataset_names:
        coco = COCO(os.path.join(imdb_dir, name + '_test_stripped.json'))
        imgnames_val = [(data_dir + '/' + img['file_name'], img['id'])
                        for img in coco.loadImgs(coco.getImgIds())]

        means = dict_mean_std[name + 'mean']
        stds = dict_mean_std[name + 'std']
        # Same test-time pipeline as training-side evaluation: every task except
        # daimlerpedcls gets a 72-pixel center crop after resizing.
        geometric = [transforms.Resize(72)]
        if name != 'daimlerpedcls':
            geometric.append(transforms.CenterCrop(72))
        transform_test = transforms.Compose(
            geometric + [transforms.ToTensor(), transforms.Normalize(means, stds)])

        # The stripped test files carry no labels, and the default collate
        # function cannot batch ``None`` targets (it raises TypeError). Give every
        # image a placeholder -1 target so batches can be formed.
        dummy_labels = [-1] * len(imgnames_val)
        dataset = ImageFolder(data_dir, transform_test, None, None, dummy_labels, imgnames_val)
        val_loaders.append(torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=False,
            num_workers=num_workers, pin_memory=True))

    return val_loaders

def get_submit_json(args, model_union, output_path="./results.json"):
    """Predict every Decathlon test image and write the submission JSON.

    ``model_union.models[1:]`` is paired by position with the nine tasks below;
    ``models[0]`` is skipped (presumably the imagenet12 model; confirm). For
    each non-``None`` model, this runs inference on that task's
    ``*_test_stripped.json`` split and records one
    ``{'category_id': ..., 'image_id': ...}`` entry per image.

    The submitted category id is
    ``10_000_000 * (task_number + 1) + predicted_class + 1``.

    Args:
        args: Namespace providing ``dataset_name`` (data root containing
            ``annotations/``) and ``local_rank`` (CUDA device index).
        model_union: Object whose ``models`` list holds the per-task models.
        output_path: Destination JSON file (default ``./results.json``).

    Side effects:
        Puts each visited model in eval mode and writes ``output_path``.
    """
    datasets = [
        "aircraft",
        "cifar100",
        "daimlerpedcls",
        "dtd",
        "gtsrb",
        "omniglot",
        "svhn",
        "ucf101",
        "vgg-flowers"
        ]
    # Task numbers used in submission category ids. Number 5 is absent from
    # this mapping (presumably imagenet12, which is not predicted here).
    datasets2num = {
        "aircraft" : 0,
        "cifar100" : 1,
        "daimlerpedcls" : 2,
        "dtd" : 3,
        "gtsrb" : 4,
        "omniglot" : 6,
        "svhn" : 7,
        "ucf101" : 8,
        "vgg-flowers" : 9
    }
    val_loaders = prepare_val_data_loaders(
        datasets,
        args.dataset_name,
        os.path.join(args.dataset_name, "annotations/")
        )

    results = []
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for idx, model in enumerate(model_union.models[1:]):
            if model is None:
                continue
            model.eval()
            # Exact integer form of the former float expression ``10e6 * (n + 1)``.
            category_base = 10_000_000 * (datasets2num[datasets[idx]] + 1)
            # Targets are placeholders for the unlabeled test split and are not used.
            for img, (_, image_ids) in val_loaders[idx]:
                outputs = model(img.cuda(args.local_rank))
                _, predicted = torch.max(outputs, 1)
                # tolist() yields plain Python ints; numpy int64 values are not
                # JSON serializable.
                for prediction, image_id in zip(predicted.cpu().tolist(),
                                                image_ids.cpu().tolist()):
                    results.append({
                        'category_id': category_base + prediction + 1,
                        'image_id': image_id,
                    })

    # json.dump writes str, so the file must be opened in text mode.
    with open(output_path, 'w') as fh:
        json.dump(results, fh)
