import pickle
import os

import numpy as np
import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader
import torch.distributed as dist

from .sampler_ddp import RandomIdentitySampler_DDP
from .cifar import get_cifar_100_datasets, get_cifar_10_datasets
from .imagenet import get_imagenet_100_datasets
from utils.datasets import *

# Directory holding the pickled class-split files read by get_class_splits.
osr_split_dir = './data/ssb_splits'

# Registry of dataset builders, keyed by the name that make_dataloader reads
# from cfg.DATASETS.NAMES.
__factory = dict(
    Cifar10=get_cifar_10_datasets,
    Cifar100=get_cifar_100_datasets,
    ImageNet100=get_imagenet_100_datasets,
)

def train_collate_fn(batch):
    """Collate ``(image, pid, idx)`` samples into one batch.

    Args:
        batch: sequence of 3-tuples ``(image_tensor, pid, idx)``.

    Returns:
        ``(images, pids, idx)`` where ``images`` is the image tensors stacked
        along a new leading dimension, ``pids`` is an int64 tensor and ``idx``
        is the untouched tuple of per-sample third elements.
    """
    image_seq, pid_seq, idx_seq = zip(*batch)
    pid_tensor = torch.tensor(pid_seq, dtype=torch.int64)
    image_tensor = torch.stack(image_seq, dim=0)
    return image_tensor, pid_tensor, idx_seq

def train_mask_collate_fn_cl(batch):
    """Collate two-view samples that also carry a ``mask_lab`` entry.

    Args:
        batch: sequence of 4-tuples ``(images, pid, idx, mask_lab)`` where
            ``images`` is indexable and holds two tensors (positions 0 and 1).

    Returns:
        ``([first_views, second_views], pids, idx, mask_lab)``: the two views
        are each stacked along a new leading dimension, ``pids`` and
        ``mask_lab`` are int64 tensors, and ``idx`` is returned as the tuple
        produced by unzipping the batch.
    """
    view_pairs, pid_seq, idx_seq, mask_seq = zip(*batch)

    # Separate the per-sample pairs into one list per view before stacking.
    first_views = [pair[0] for pair in view_pairs]
    second_views = [pair[1] for pair in view_pairs]
    stacked_first = torch.stack(first_views, dim=0)
    stacked_second = torch.stack(second_views, dim=0)

    pid_tensor = torch.tensor(pid_seq, dtype=torch.int64)
    # Go through a NumPy array first, as the original does, then cast to int64.
    mask_tensor = torch.tensor(np.array(mask_seq), dtype=torch.int64)
    return [stacked_first, stacked_second], pid_tensor, idx_seq, mask_tensor

def val_collate_fn(batch):
    """Collate ``(image, pid, idx)`` evaluation samples into one batch.

    Args:
        batch: sequence of 3-tuples ``(image_tensor, pid, idx)``.

    Returns:
        ``(images, pids, idx)``: images stacked along a new leading dimension,
        ``pids`` as an int64 tensor, and ``idx`` as the tuple obtained by
        unzipping the batch.
    """
    columns = tuple(zip(*batch))
    image_seq, pid_seq, idx_seq = columns
    pid_tensor = torch.tensor(pid_seq, dtype=torch.int64)
    return torch.stack(image_seq, dim=0), pid_tensor, idx_seq

def _default_transforms():
    """Build the default ``(train, test)`` torchvision pipelines.

    Both resize to 256x256 and normalise with the same per-channel statistics;
    the train pipeline adds a random 224 crop and horizontal flip, the test
    pipeline uses a deterministic 224 centre crop.
    """
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    train = T.Compose([
        T.Resize((256, 256)),
        T.RandomCrop((224, 224)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    test = T.Compose([
        T.Resize((256, 256)),
        T.CenterCrop((224, 224)),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    return train, test


def make_dataloader(cfg, args, train_transforms=None, test_transforms=None):
    """Build the train/test data loaders for the configured dataset and stage.

    Args:
        cfg: config node; reads ``DATALOADER.NUM_WORKERS``,
            ``DATALOADER.NUM_INSTANCE``, ``DATASETS.NAMES``,
            ``DATASETS.PROP_TRAIN_LABELS``, ``MODEL.UNINCD_STAGE``,
            ``MODEL.DIST_TRAIN``, ``SOLVER.IMS_PER_BATCH`` and
            ``TEST.IMS_PER_BATCH``.
        args: namespace providing ``common_classes`` and ``private_classes``.
        train_transforms: transform for training images; a default pipeline is
            used when ``None``.
        test_transforms: transform for evaluation images; a default pipeline is
            used when ``None``. Each argument is defaulted independently.

    Returns:
        For stages starting with ``'UNINCD'`` an 8-tuple
        ``(labelled_loader, unlabelled_loader, test_loader, num_known_classes,
        num_unknown_classes, num_classes, num_labelled, num_unlabelled)``.
        Otherwise a 6-tuple ``(train_loader, train_loader_unlabelled,
        test_loader, num_known_classes, num_unknown_classes, num_classes)``.

    Raises:
        KeyError: ``cfg.DATASETS.NAMES`` is not a registered dataset.
        ValueError: the stage / ``DIST_TRAIN`` combination has no loader
            recipe (previously this surfaced as an ``UnboundLocalError`` at
            the return statement).
    """
    stage = cfg.MODEL.UNINCD_STAGE
    is_unincd = stage.startswith('UNINCD')
    dist_train = cfg.MODEL.DIST_TRAIN

    # Fail fast, before any dataset is built, on configurations for which no
    # loaders can be produced.
    if dist_train and is_unincd:
        raise ValueError(
            f"MODEL.DIST_TRAIN is not supported for stage {stage!r}: the "
            "UNINCD stages need the two sequential feature loaders."
        )
    if not dist_train and not is_unincd and stage != 'pretrain_MASK':
        raise ValueError(
            f"Unsupported MODEL.UNINCD_STAGE {stage!r}: expected a stage "
            "starting with 'UNINCD' or exactly 'pretrain_MASK'."
        )

    # Default each transform on its own so a caller-supplied test transform is
    # never overwritten and a missing one never stays None.
    if train_transforms is None or test_transforms is None:
        default_train, default_test = _default_transforms()
        if train_transforms is None:
            train_transforms = default_train
        if test_transforms is None:
            test_transforms = default_test

    try:
        build_datasets = __factory[cfg.DATASETS.NAMES]
    except KeyError:
        raise KeyError(
            f"Unknown dataset {cfg.DATASETS.NAMES!r}; "
            f"available: {sorted(__factory)}"
        ) from None

    num_workers = cfg.DATALOADER.NUM_WORKERS
    # DataLoader rejects persistent_workers=True when num_workers == 0.
    keep_workers_alive = num_workers > 0

    dataset = build_datasets(
        train_transforms,
        test_transforms,
        common_classes=args.common_classes,
        private_classes=args.private_classes,
        prop_train_labels=cfg.DATASETS.PROP_TRAIN_LABELS,
        split_train_val=False,
        seed=0)

    # These stages read the training splits through the deterministic test
    # transform instead of the augmenting train transform.
    if is_unincd:
        dataset.train_labelled.transform = test_transforms
        img_num1 = len(dataset.train_labelled)
        dataset.train_unlabelled.transform = test_transforms
        img_num2 = len(dataset.train_unlabelled)
    elif stage.endswith('_MASK'):
        dataset.train_mask.transform = test_transforms
        dataset.train_unlabelled.transform = test_transforms

    num_known_classes = dataset.num_known_classes
    num_unknown_classes = dataset.num_unknown_classes
    all_classes = (
        set(dataset.common_classes)
        | set(dataset.private_classes)
        | set(dataset.unlabeled_classes)
    )
    num_classes = len(all_classes)

    if dist_train:
        print('DIST_TRAIN START')
        mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // dist.get_world_size()
        data_sampler = RandomIdentitySampler_DDP(
            dataset.train_labelled, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE)
        batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)
        train_loader = DataLoader(
            dataset.train_labelled,
            num_workers=num_workers,
            batch_sampler=batch_sampler,
            collate_fn=train_collate_fn,
            pin_memory=True,
        )
        # NOTE(review): this branch used to leave train_loader_unlabelled
        # undefined, so the return below always failed. It now mirrors the
        # sequential unlabelled loader of the 'pretrain_MASK' branch; confirm
        # this is the loader distributed training should receive.
        train_loader_unlabelled = DataLoader(
            dataset.train_unlabelled,
            num_workers=num_workers,
            batch_size=cfg.SOLVER.IMS_PER_BATCH,
            shuffle=False,
            collate_fn=train_collate_fn
        )
    elif is_unincd:
        train_loader1 = DataLoader(
            dataset.train_labelled,
            num_workers=num_workers,
            batch_size=cfg.SOLVER.IMS_PER_BATCH,
            shuffle=False,
            collate_fn=train_collate_fn,
            pin_memory=True,
            persistent_workers=keep_workers_alive,
        )
        train_loader2 = DataLoader(
            dataset.train_unlabelled,
            num_workers=num_workers,
            batch_size=cfg.SOLVER.IMS_PER_BATCH,
            shuffle=False,
            collate_fn=train_collate_fn,
            pin_memory=True,
            persistent_workers=keep_workers_alive,
        )
    else:  # stage == 'pretrain_MASK' (validated above)
        label_len = len(dataset.train_labelled)
        unlabelled_len = len(dataset.train_unlabelled)
        # Indices below label_len get weight 1, the rest label_len/unlabelled_len.
        # Presumably train_mask lists the labelled samples first, followed by
        # the unlabelled ones -- TODO confirm against the dataset builder.
        sample_weights = [
            1 if i < label_len else label_len / unlabelled_len
            for i in range(len(dataset.train_mask))
        ]
        sample_weights = torch.DoubleTensor(sample_weights)
        weighted_sampler = torch.utils.data.WeightedRandomSampler(
            sample_weights, num_samples=len(dataset.train_mask))
        train_loader = DataLoader(
            dataset.train_mask,
            num_workers=num_workers,
            batch_size=cfg.SOLVER.IMS_PER_BATCH,
            sampler=weighted_sampler,
            shuffle=False,
            collate_fn=train_mask_collate_fn_cl,
            drop_last=True
        )
        train_loader_unlabelled = DataLoader(
            dataset.train_unlabelled,
            num_workers=num_workers,
            batch_size=cfg.SOLVER.IMS_PER_BATCH,
            shuffle=False,
            collate_fn=train_collate_fn
        )

    test_loader = DataLoader(
        dataset.test, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
        collate_fn=val_collate_fn
    )

    if is_unincd:
        return train_loader1, train_loader2, test_loader,\
            num_known_classes, num_unknown_classes, num_classes,\
            img_num1, img_num2
    return train_loader, train_loader_unlabelled, test_loader, num_known_classes, num_unknown_classes, num_classes



# Datasets with a fixed contiguous split:
# name -> (number of private classes, exclusive end of the common range).
# private_classes = range(n_private); common_classes = range(n_private, stop).
_CONTIGUOUS_SPLITS = {
    'Cifar10': (2, 8),
    'Cifar100': (20, 80),
    'ImageNet100': (20, 80),
    'tinyimagenet': (40, 160),
    'chinese_traffic_signs': (28, 56),
}

# FGVC datasets that can optionally use a pickled SSB split:
# name -> (split file name, number of private classes, exclusive end of the
# common range used when the SSB split is not requested).
_FGVC_SPLITS = {
    'scars': ('scars_osr_splits.pkl', 98, 196),
    'aircraft': ('aircraft_osr_splits.pkl', 50, 100),
    'cub': ('cub_osr_splits.pkl', 100, 200),
}


def _load_split_pickle(file_name):
    """Load and return the pickled split stored as ``file_name`` under ``osr_split_dir``."""
    split_path = os.path.join(osr_split_dir, file_name)
    with open(split_path, 'rb') as handle:
        # NOTE: pickle can execute arbitrary code on load; only point
        # osr_split_dir at trusted split files.
        return pickle.load(handle)


def get_class_splits(args):
    """Attach ``private_classes`` and ``common_classes`` to ``args``.

    Args:
        args: namespace with a ``dataset_name`` attribute. For ``'scars'``,
            ``'cub'`` and ``'aircraft'`` an optional ``use_ssb_splits``
            attribute (treated as ``False`` when absent) selects the pickled
            SSB split instead of the fixed contiguous ranges.

    Returns:
        The same ``args`` object, mutated in place.

    Raises:
        NotImplementedError: ``args.dataset_name`` is not a known dataset.
    """
    name = args.dataset_name

    if name in _CONTIGUOUS_SPLITS:
        n_private, common_stop = _CONTIGUOUS_SPLITS[name]
        args.private_classes = range(n_private)
        args.common_classes = range(n_private, common_stop)

    elif name == 'herbarium_19':
        class_splits = _load_split_pickle('herbarium_19_class_splits.pkl')
        args.private_classes = class_splits['Old']
        args.common_classes = class_splits['New']

    elif name in _FGVC_SPLITS:
        split_file, n_private, common_stop = _FGVC_SPLITS[name]
        if getattr(args, 'use_ssb_splits', False):
            class_info = _load_split_pickle(split_file)
            args.private_classes = class_info['known_classes']
            open_set_classes = class_info['unknown_classes']
            # The common classes are the three difficulty buckets concatenated.
            args.common_classes = (
                open_set_classes['Hard'] + open_set_classes['Medium'] + open_set_classes['Easy']
            )
        else:
            args.private_classes = range(n_private)
            args.common_classes = range(n_private, common_stop)

    else:
        raise NotImplementedError(
            f"No class split defined for dataset {name!r}"
        )

    return args