from datasets.utils.base_dataset import BaseDataset, get_loader
import torch
from backbones.clevrcnn import CLEVRCNN, CLEVRDECODER
from datasets.utils.cub_creation import CUBDataset


class CUB(BaseDataset):
    """Dataset wrapper exposing the CUB splits, backbone and label metadata.

    The train/val/test splits are built with ``CUBDataset`` and wrapped with
    ``get_loader``. The numeric constants below mirror the values used
    throughout this class (312 concepts, 200 labels).
    """

    NAME = "cub"
    DATADIR = "data/CUB/CUB_200_2011"

    # Root folder handed to CUBDataset for every split.
    _BASE_PATH = "cv_datasets/cub/"
    # Number of concept entries (also used as the encoder's ``num_classes``).
    _N_CONCEPTS = 312
    # Number of target labels returned by ``get_labels``.
    _N_CLASSES = 200

    def get_data_loaders(self):
        """Build the three splits and their loaders.

        The datasets and loaders are also stored on the instance
        (``dataset_*`` / ``*_loader``) so that other methods such as
        ``print_stats`` can use them afterwards.

        Returns:
            Tuple ``(train_loader, val_loader, test_loader)``.
        """
        dataset_train, dataset_val, dataset_test = (
            CUBDataset(base_path=self._BASE_PATH, split=split)
            for split in ("train", "val", "test")
        )

        self.dataset_train = dataset_train
        self.dataset_val = dataset_val
        self.dataset_test = dataset_test

        batch_size = self.args.batch_size
        # Only the training loader is created with val_test=False.
        self.train_loader = get_loader(dataset_train, batch_size, val_test=False)
        self.val_loader = get_loader(dataset_val, batch_size, val_test=True)
        self.test_loader = get_loader(dataset_test, batch_size, val_test=True)

        return self.train_loader, self.val_loader, self.test_loader

    def get_backbone(self):
        """Return the ``(encoder, decoder)`` pair used for this dataset.

        Raises:
            NotImplementedError: if ``self.args.joint`` is set; no joint
                backbone exists for this dataset.
        """
        if self.args.joint:
            # Raise (not return) so the failure surfaces at the call site
            # instead of as a confusing unpacking error later on.
            raise NotImplementedError("Joint backbone is not available for CUB")
        return CLEVRCNN(num_classes=self._N_CONCEPTS), CLEVRDECODER()

    def get_split(self):
        """Return ``(1, (312,))`` regardless of ``self.args.joint``.

        NOTE(review): presumably (number of input parts, concept-space shape);
        confirm against how BaseDataset consumers use this value.
        """
        # Always a 1-tuple: a bare ``(312)`` would be the int 312, giving the
        # joint branch a different return type from the non-joint one.
        return 1, (self._N_CONCEPTS,)

    def get_concept_labels(self):
        """Return the concept names: the stringified indices ``"0"``..``"311"``."""
        return [str(i) for i in range(self._N_CONCEPTS)]

    def get_labels(self):
        """Return the label names: the stringified indices ``"0"``..``"199"``."""
        return [str(i) for i in range(self._N_CLASSES)]

    def print_stats(self):
        """Print the number of samples per split.

        Requires ``get_data_loaders`` to have been called first, since it reads
        the ``dataset_*`` attributes set there.
        """
        print("## Statistics ##")
        # NOTE(review): train is counted via ``.data`` while val/test use
        # ``len(dataset)``; kept as-is because CUBDataset is not visible here.
        print("Train samples", len(self.dataset_train.data))
        print("Validation samples", len(self.dataset_val))
        print("Test samples", len(self.dataset_test))