import os
import torch.nn as nn

from continuum import ClassIncremental
from continuum.datasets import PyTorchDataset
from continuum.datasets import (
    CIFAR100
)

from PIL import Image

try:
    from torchvision.transforms import InterpolationMode

    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

import numpy as np
import json

from dataset.cifar_lt.imbalancedcifar100 import IMBALANCECIFAR100


def get_dataset(cfg, is_train=False):
    """Instantiate the dataset selected by ``cfg.dataset``.

    Supported values are ``"cifar100-lt"`` (the long-tailed IMBALANCECIFAR100
    wrapped in a continuum ``PyTorchDataset``) and ``"cifar100"`` (continuum's
    ``CIFAR100``). Data is stored under ``cfg.dataset_root/cfg.dataset`` and
    downloaded if missing.

    Args:
        cfg: Config object. Reads ``dataset`` and ``dataset_root``; for
            ``"cifar100-lt"`` also ``imb_factor`` and ``imb_type``.
        is_train: Forwarded as ``train=`` to select the split.

    Returns:
        ``(dataset, classes_names)``, where ``classes_names`` is taken from the
        wrapped dataset's ``classes`` attribute.

    Raises:
        ValueError: If ``cfg.dataset`` is not one of the supported names.
    """
    if cfg.dataset == "cifar100-lt":
        data_path = os.path.join(cfg.dataset_root, cfg.dataset)
        dataset = PyTorchDataset(
            data_path=data_path,
            dataset_type=IMBALANCECIFAR100,
            train=is_train,
            download=True,
            imb_factor=cfg.imb_factor,
            imb_type=cfg.imb_type,
        )
    elif cfg.dataset == "cifar100":
        data_path = os.path.join(cfg.dataset_root, cfg.dataset)
        dataset = CIFAR100(
            data_path=data_path,
            download=True,
            train=is_train,
        )
    else:
        # Must be raised: merely constructing the exception let execution fall
        # through to the return below and fail with UnboundLocalError.
        raise ValueError(f"'{cfg.dataset}' is not implemented.")

    # Both branches keep the wrapped dataset in `.dataset`, which exposes `.classes`.
    classes_names = dataset.dataset.classes
    return dataset, classes_names


def build_cl_scenarios(cfg, is_train, transforms) -> tuple:
    """Build a continual-learning scenario and the matching ordered class names.

    Args:
        cfg: Config object. Reads ``scenario``, ``initial_increment``,
            ``increment`` and ``class_order``, plus whatever ``get_dataset``
            reads.
        is_train: Forwarded to ``get_dataset`` to select the split.
        transforms: Object exposing a ``transforms`` list attribute (e.g. a
            ``Compose``); that list is passed to ``ClassIncremental``.

    Returns:
        ``(scenario, classes_names)``, where ``classes_names[i]`` is the name
        of class ``cfg.class_order[i]``.

    Raises:
        ValueError: If ``cfg.scenario`` is not a supported scenario.
    """
    supported_scenarios = ("class",)
    if cfg.scenario not in supported_scenarios:
        # Validate before get_dataset, which may download data (download=True).
        raise ValueError(
            f"You have entered `{cfg.scenario}` which is not a defined scenario, "
            f"please choose from {set(supported_scenarios)}."
        )

    dataset, classes_names = get_dataset(cfg, is_train)

    scenario = ClassIncremental(
        dataset,
        initial_increment=cfg.initial_increment,
        increment=cfg.increment,
        transformations=transforms.transforms,  # Convert Compose into list
        class_order=cfg.class_order,
    )

    # Re-index the names so position i names the i-th class in cfg.class_order.
    # NOTE(review): assumes cfg.class_order is a sequence of integer class ids;
    # a None value would fail to index here — confirm callers always set it.
    all_class_names = np.array(classes_names)
    order = np.array(cfg.class_order)
    ordered_names = all_class_names[order].tolist()

    return scenario, ordered_names
