import os

import torch
from torchvision.datasets import CIFAR10

from avalanche.benchmarks import nc_benchmark
from avalanche.benchmarks.datasets import default_dataset_location
from avalanche.benchmarks.datasets.multi_label_dataset.voc import MultiLabelVOC,MultiLabelVOC2
from avalanche.benchmarks.datasets.multi_label_dataset.nus_wide import NUS_WIDE
from avalanche.benchmarks.datasets.multi_label_dataset.coco.mycoco import MultiLabelCOCO
from torchvision.transforms import transforms
from avalanche.benchmarks.scenarios.generic_benchmark_creation import create_multi_dataset_generic_benchmark
from avalanche.benchmarks.utils import make_classification_dataset
import random
from pycocotools.coco import COCO


# Current user's home directory; every dataset path in this module is built from it.
usr_root = os.path.expanduser("~")


# COCO 2017 on-disk layout (image folders + instance annotation files).
# NOTE(review): these paths are not referenced by the benchmark builders in this
# module's visible code; presumably consumed elsewhere — verify before changing.
dataset_root = usr_root + "/data/Datasets/coco2017/"

traindata_root = dataset_root + 'train/data/'
train_annFile = dataset_root + 'annotations/instances_train2017.json'

valdata_root = dataset_root + 'validation/data/'
val_annFile = dataset_root + 'annotations/instances_val2017.json'


def MultiLabelBenchmark():
    """Build a three-experience multi-label benchmark over VOC 2012, NUS-WIDE and COCO.

    Each dataset contributes one train/test experience pair, tagged with task
    labels 0 (VOC), 1 (NUS-WIDE) and 2 (COCO). Train and test splits share one
    transform pipeline: ``ToTensor`` followed by ``Resize([300, 300])``.

    Returns:
        The benchmark produced by ``create_multi_dataset_generic_benchmark``.
    """
    resize_300 = transforms.Compose(
        [transforms.ToTensor(), transforms.Resize([300, 300])]
    )
    voc_root = usr_root + "/data/Datasets/VOC/"

    # VOC 2012: the "val" image set is used as the test split here.
    voc_splits = [
        MultiLabelVOC(
            root=voc_root,
            year="2012",
            image_set=split,
            transform=resize_300,
        )
        for split in ("train", "val")
    ]
    # NOTE(review): NUS_WIDE / MultiLabelCOCO are called with a `train=` flag
    # here, but multi_label_batchlearning_benchmark calls them with
    # `imageset=` / `image_set=`; confirm which signature the classes accept.
    nus_splits = [
        NUS_WIDE(transforms=resize_300, train=flag) for flag in (True, False)
    ]
    coco_splits = [
        MultiLabelCOCO(train=flag, transform=resize_300) for flag in (True, False)
    ]

    train_streams = (voc_splits[0], nus_splits[0], coco_splits[0])
    test_streams = (voc_splits[1], nus_splits[1], coco_splits[1])
    return create_multi_dataset_generic_benchmark(
        train_datasets=train_streams,
        test_datasets=test_streams,
        task_labels=[0, 1, 2],
    )


def multi_label_batchlearning_benchmark(dataset_name, seed):
    """Build a single-experience multi-label benchmark with train, val and test streams.

    Args:
        dataset_name: One of ``"voc"``, ``"nus-wide"`` or ``"coco"``.
        seed: Currently unused by this function; kept for interface compatibility.

    Returns:
        The benchmark produced by ``create_multi_dataset_generic_benchmark``:
        one training experience and one test experience (task label 0), plus
        an extra ``"val"`` stream.

    Raises:
        NotImplementedError: If ``dataset_name`` is not a supported dataset.
    """
    supported = ("voc", "nus-wide", "coco")
    # Validate before building anything so the error names the bad input.
    if dataset_name not in supported:
        raise NotImplementedError(
            f"Unsupported dataset {dataset_name!r}; expected one of {supported}"
        )

    # Shared preprocessing for all splits: tensor conversion, then resize.
    train_trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize([224, 224]),
    ])

    if dataset_name == "voc":
        voc_root = usr_root + "/data/Datasets/VOC/"
        train_dataset = MultiLabelVOC(
            root=voc_root,
            year="2012",
            image_set="train",
            transform=train_trans
        )
        # NOTE(review): the "test" image set feeds the validation stream and the
        # "val" image set feeds the test stream; confirm this swap is intentional.
        val_dataset = MultiLabelVOC(
            root=voc_root,
            year="2012",
            image_set="test",
            transform=train_trans
        )
        test_dataset = MultiLabelVOC(
            root=voc_root,
            year="2012",
            image_set="val",
            transform=train_trans
        )
    elif dataset_name == "nus-wide":
        train_dataset = NUS_WIDE(imageset="train", transforms=train_trans)
        val_dataset = NUS_WIDE(imageset="val", transforms=train_trans)
        test_dataset = NUS_WIDE(imageset="test", transforms=train_trans)
    else:  # "coco" — guaranteed by the check above
        train_dataset = MultiLabelCOCO(image_set="train", transform=train_trans)
        val_dataset = MultiLabelCOCO(image_set="val", transform=train_trans)
        test_dataset = MultiLabelCOCO(image_set="test", transform=train_trans)

    print(len(train_dataset), len(val_dataset), len(test_dataset))
    return create_multi_dataset_generic_benchmark(
        train_datasets=(train_dataset,),
        test_datasets=(test_dataset,),
        other_streams_datasets={"val": (val_dataset,)},
        task_labels=[0],
    )


def get_onehot(target, num_classes=10):
    """Return a one-hot vector for a class index.

    Args:
        target: Integer class index in ``[0, num_classes)``.
        num_classes: Length of the returned vector. Defaults to 10, the value
            previously hard-coded (CIFAR-10).

    Returns:
        A 1-D tensor of length ``num_classes`` (torch's default float dtype)
        with 1 at position ``target`` and 0 elsewhere.

    Raises:
        IndexError: If ``target`` is outside ``[0, num_classes)``. Negative
            indices are rejected rather than silently wrapping from the end.
    """
    if not 0 <= target < num_classes:
        raise IndexError(
            f"target {target} out of range for {num_classes} classes"
        )
    one_hot = torch.zeros(num_classes)
    one_hot[target] = 1
    return one_hot


class OneCIFAR10(CIFAR10):
    """CIFAR-10 dataset whose ``targets`` are stored as one-hot vectors.

    The constructor accepts the same arguments as torchvision's ``CIFAR10``.
    After the parent constructor has populated ``self.targets``, every entry
    is replaced by ``get_onehot(entry)``.
    """

    def __init__(
        self,
        root,
        train=True,
        transform=None,
        target_transform=None,
        download=False,
    ):
        super().__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )
        # Replace the parent's class-index labels with one-hot encodings.
        self.targets = list(map(get_onehot, self.targets))


def cifar10_batchlearning_benchmark():
    """Build a single-experience CIFAR-10 benchmark with one-hot targets.

    Both splits use ``OneCIFAR10`` rooted at Avalanche's default location for
    ``"cifar10"`` with ``download=True``, and share one transform pipeline:
    ``ToTensor`` followed by ``Resize([224, 224])``.

    Returns:
        The benchmark produced by ``create_multi_dataset_generic_benchmark``
        with one training and one test experience, both under task label 0.
    """
    to_224 = transforms.Compose(
        [transforms.ToTensor(), transforms.Resize([224, 224])]
    )
    cifar_root = default_dataset_location("cifar10")

    # Train split first, then test split (same order as the constructor calls).
    train_dataset, test_dataset = (
        OneCIFAR10(cifar_root, train=is_train, transform=to_224, download=True)
        for is_train in (True, False)
    )

    print(len(train_dataset), len(test_dataset))
    return create_multi_dataset_generic_benchmark(
        train_datasets=(train_dataset,),
        test_datasets=(test_dataset,),
        task_labels=[0],
    )
