import os
import random
import numpy as np

import sys
from pathlib import Path
# Debug output: the directory containing this file, then that directory's parent.
print(Path(__file__).parents[0])
print(Path(__file__).parents[1])
# Append the parent of this file's directory to sys.path
# (presumably so the top-level `utils` import below resolves -- TODO confirm).
path_root = Path(__file__).parents[1]
print(path_root)
sys.path.append(str(path_root))

# Append the exlib sources to sys.path. The entry is a relative path, so it is
# resolved against the current working directory, not against this file.
sys.path.append('../exlib/src')
import exlib
from exlib.datasets.abdomen_organs import *
from datasets import load_dataset


import torch
from torch.utils.data import random_split
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Subset
from transformers import ViTImageProcessor
from .imagenet import ImageNetDatasetSmall, MaskedImageNetDatasetSmall, \
        ImageNetDataset, MaskedImageNetDataset, \
        ImageNetDatasetMedium, MaskedImageNetDatasetMedium
from .fgvc import MaskedFgvcDataset
from .cosmogrid import CosmogridDataset
from .abdomen import AbdomenDataset
from .pascal_voc07 import PascalVoc07Dataset
from exlib.datasets.eraser_movies import EraserMovies

from utils import convert_idx_masks_to_bool


# filter imagenet classes
def filter_classes(dataset, selected_classes):
    """Return a ``Subset`` of ``dataset`` limited to the selected class labels.

    ``dataset.imgs`` is read as a sequence of ``(item, class_label)`` pairs.
    A sample is kept when its label satisfies ``label in selected_classes``;
    the kept samples stay in their original order.
    """
    kept_positions = [
        position
        for position, (_, label) in enumerate(dataset.imgs)
        if label in selected_classes
    ]
    return Subset(dataset, kept_positions)

def get_datasets(dataset_name, processor=None, debug=False, transform=None,
                 need_cls=True, train_size=-1, val_size=-1, label2id=None,
                 class_type='variants', val_only=False, video_split=False, mode='train'):
    """Build the train/validation datasets selected by ``dataset_name``.

    Args:
        dataset_name: One of 'multirc', 'movies', 'voc', 'cosmogrid',
            'abdomen_organ', 'abdomen', 'mnist', 'imagenet', 'imagenet_m',
            'imagenet_s' or 'fgvc'.
        processor: Optional callable. The default image transform calls it as
            ``processor(image, return_tensors='pt')`` and keeps
            ``['pixel_values']``. For 'multirc' and 'movies' it is called on
            the two text columns with padding/truncation to 512.
        debug: Forwarded to the ImageNet and FGVC dataset constructors.
        transform: Optional per-image transform. When ``None`` a default is
            used (convert to RGB, then ``processor`` or a numpy array).
            'multirc' and 'movies' always use their own text transform, and
            the 'cosmogrid', 'abdomen_organ' and 'abdomen' branches do not
            forward it.
        need_cls: Forwarded to the ImageNet and FGVC dataset constructors.
        train_size: For 'multirc' and 'abdomen', ``-1`` keeps the whole split
            and any other value keeps the first ``train_size`` items.
            Ignored by 'abdomen_organ' and 'mnist'; otherwise forwarded to the
            dataset constructor as ``data_size``.
        val_size: Same as ``train_size`` but for the validation split.
        label2id: Forwarded to the 'imagenet' and 'imagenet_m' datasets.
        class_type: Forwarded to the 'fgvc' dataset.
        val_only: When true only the validation dataset is returned. Branches
            that support it also skip building the train dataset.
        video_split: 'abdomen_organ' only; selects the ``train_video`` /
            ``test_video`` splits instead of ``train`` / ``test``.
        mode: 'voc' only; forwarded to the dataset. Any value other than
            ``'train'`` makes the returned validation dataset load the
            ``test`` split instead of ``val``.

    Returns:
        ``(train_dataset, val_dataset)``, or just ``val_dataset`` when
        ``val_only`` is true.

    Raises:
        ValueError: If ``dataset_name`` is not recognized.
    """
    if transform is None:
        def transform(image):
            """Default image transform: RGB, then processor output or ndarray."""
            image = image.convert("RGB")
            if processor is None:
                return np.asarray(image)
            inputs = processor(image, return_tensors='pt')
            return inputs['pixel_values'].squeeze(0)

    def train_val_pair(dataset_cls, *args, **kwargs):
        """Build ``(train, val)`` from ``dataset_cls``; train is None if val_only."""
        train = None
        if not val_only:
            train = dataset_cls(*args, split='train', data_size=train_size,
                                **kwargs)
        val = dataset_cls(*args, split='val', data_size=val_size, **kwargs)
        return train, val

    if dataset_name == 'multirc':
        def tokenize_multirc(batch):
            # Encode (passage, query_and_answer) text pairs with the processor,
            # padded/truncated to 512; pass the batch through without one.
            if processor is None:
                return batch
            inputs = processor(batch['passage'],
                               batch['query_and_answer'],
                               padding='max_length',
                               truncation=True,
                               max_length=512)
            return {k: torch.tensor(v) for k, v in inputs.items()}

        def load_multirc_split(split):
            split_dataset = load_dataset('eraser_multi_rc', split=split)
            if processor is not None:
                # The raw text columns are dropped once they are encoded.
                split_dataset = split_dataset.map(
                    tokenize_multirc, batched=True,
                    remove_columns=['passage',
                                    'query_and_answer',
                                    'evidences'])
            return split_dataset

        train_dataset = None if val_only else load_multirc_split('train')
        val_dataset = load_multirc_split('validation')

        # Only truncate a train split that was actually loaded; previously a
        # Subset was built around None when val_only was set.
        if train_size != -1 and train_dataset is not None:
            train_dataset = torch.utils.data.Subset(train_dataset,
                                                    list(range(train_size)))
        if val_size != -1:
            val_dataset = torch.utils.data.Subset(val_dataset,
                                                  list(range(val_size)))
    elif dataset_name == 'movies':
        def tokenize_movies(batch):
            # Encode (passage, query) text pairs with the processor and carry
            # the label over; pass the batch through without a processor.
            if processor is None:
                return batch
            inputs = processor(batch['passage'],
                               batch['query'],
                               padding='max_length',
                               truncation=True,
                               max_length=512)
            batch_new = {k: torch.tensor(v) for k, v in inputs.items()}
            batch_new['label'] = batch['label']
            return batch_new

        train_dataset, val_dataset = train_val_pair(
            EraserMovies, '../datasets', transform=tokenize_movies)
    elif dataset_name == 'voc':
        voc_root = '../datasets/VOCdevkit/VOC2007'
        train_dataset = None
        if not val_only:
            train_dataset = PascalVoc07Dataset(root_dir=voc_root,
                                               split='train',
                                               transform=transform,
                                               data_size=train_size,
                                               mode=mode)
        # Outside of training mode the held-out data comes from the test split.
        eval_split = 'val' if mode == 'train' else 'test'
        val_dataset = PascalVoc07Dataset(root_dir=voc_root,
                                         split=eval_split,
                                         transform=transform,
                                         data_size=val_size,
                                         mode=mode)
    elif dataset_name == 'cosmogrid':
        train_dataset, val_dataset = train_val_pair(
            CosmogridDataset, '../datasets/cosmogrid')
    elif dataset_name == 'abdomen_organ':
        ABDX_PATH = "../datasets/abdomen_exlib"
        if not video_split:
            train_dataset = AbdomenOrgans(ABDX_PATH,
                                          image_height=352,
                                          image_width=640,
                                          split="train",
                                          return_idx=True)
            val_dataset = AbdomenOrgans(ABDX_PATH,
                                        image_height=352,
                                        image_width=640,
                                        split="test",
                                        return_idx=True)
        else:
            train_dataset = AbdomenOrgans(ABDX_PATH, split="train_video")
            val_dataset = AbdomenOrgans(ABDX_PATH, split="test_video")
    elif dataset_name == 'abdomen':
        images_path = "../datasets/gonogo/images"
        labels_path = '../datasets/gonogo/masks'

        # Random 80/20 train/val split of the full dataset.
        full_dataset = AbdomenDataset(images_path, labels_path)
        train_size_full = int(0.8 * len(full_dataset))
        val_size_full = len(full_dataset) - train_size_full
        train_dataset, val_dataset = torch.utils.data.random_split(
            full_dataset, [train_size_full, val_size_full])
        if train_size != -1:
            train_dataset = torch.utils.data.Subset(train_dataset,
                                                    list(range(train_size)))
        if val_size != -1:
            val_dataset = torch.utils.data.Subset(val_dataset,
                                                  list(range(val_size)))
        print('train_size', train_size)
        print('val_size', val_size)
    elif dataset_name == 'mnist':
        # The MNIST training set is split 50000/10000 into train/val.
        dataset = datasets.MNIST(root='/scratch/tmp/data', train=True,
                                 download=True,
                                 transform=transform)
        train_dataset, val_dataset = random_split(dataset, [50000, 10000])
    elif dataset_name == 'imagenet':
        train_dataset, val_dataset = train_val_pair(
            ImageNetDataset, '/scratch/vision_datasets/ImageNet_old/',
            transform=transform,
            debug=debug,
            need_cls=need_cls,
            label2id=label2id)
    elif dataset_name == 'imagenet_m':
        train_dataset, val_dataset = train_val_pair(
            ImageNetDatasetMedium, '/scratch/vision_datasets/ImageNet_old/',
            transform=transform,
            debug=debug,
            need_cls=need_cls,
            label2id=label2id)
    elif dataset_name == 'imagenet_s':
        train_dataset, val_dataset = train_val_pair(
            ImageNetDatasetSmall,
            '/scratch/vision_datasets/ImageNet_old/imagenet1k_select10_vit',
            transform=transform,
            debug=debug,
            need_cls=need_cls)
    elif dataset_name == 'fgvc':
        train_dataset, val_dataset = train_val_pair(
            MaskedFgvcDataset, '../datasets/fgvc-aircraft-2013b',
            transform=transform,
            debug=debug,
            need_cls=need_cls,
            class_type=class_type)
    else:
        raise ValueError(f'Unrecognized dataset {dataset_name}')

    if val_only:
        return val_dataset
    return train_dataset, val_dataset


def get_masked_datasets(dataset_name, processor, mask_dir,
                        seg_mask_cut_off=2,
                        debug=False,
                        transform=None, mask_transform=None, need_cls=True,
                        train_size=-1, val_size=-1, label2id=None,
                        class_type='variants', val_only=False
                        ):
    """Build train/validation datasets that also yield segmentation masks.

    Args:
        dataset_name: One of 'cosmogrid', 'imagenet', 'imagenet_m',
            'imagenet_s' or 'fgvc'.
        processor: Callable used by the default image and mask transforms, or
            ``None`` to skip it. Images are passed as
            ``processor(image, return_tensors='pt')``; masks are passed with
            ``do_rescale=False, do_normalize=False``.
        mask_dir: Mask location. Passed as ``mask_suffix`` to the cosmogrid
            dataset and as the second positional argument to the others.
        seg_mask_cut_off: Number of masks the default ``mask_transform``
            returns for a stacked-mask input. Coerced with ``int``; must be
            at least 1.
        debug: Forwarded to the ImageNet and FGVC dataset constructors.
        transform: Optional per-image transform. When ``None`` a default is
            used (convert to RGB, then ``processor`` or a numpy array). Not
            forwarded to the cosmogrid dataset.
        mask_transform: Optional per-sample mask transform. When ``None`` the
            default described on the nested ``mask_transform`` is used.
        need_cls: Forwarded to the 'fgvc' dataset only; the ImageNet branches
            always pass ``need_cls=True``.
        train_size: Forwarded to the train dataset as ``data_size``.
        val_size: Forwarded to the validation dataset as ``data_size``.
        label2id: Forwarded to the 'imagenet' and 'imagenet_m' datasets.
        class_type: Forwarded to the 'fgvc' dataset.
        val_only: When true the train dataset is not built and only the
            validation dataset is returned.

    Returns:
        ``(train_dataset, val_dataset)``, or just ``val_dataset`` when
        ``val_only`` is true.

    Raises:
        ValueError: If ``seg_mask_cut_off`` is below 1 or ``dataset_name`` is
            not recognized.
    """
    seg_mask_cut_off = int(seg_mask_cut_off)
    if seg_mask_cut_off <= 0:
        # The message names the actual parameter (it used to refer to a
        # non-existent `input_mask_cut_off`).
        raise ValueError('seg_mask_cut_off needs to be at least 1')

    if transform is None:
        def transform(image):
            """Default image transform: RGB, then processor output or ndarray."""
            image = image.convert("RGB")
            if processor is None:
                return np.asarray(image)
            inputs = processor(image, return_tensors='pt')
            return inputs['pixel_values'].squeeze(0)

    if mask_transform is None:
        def mask_transform(mask):
            """Turn a mask tensor into float mask(s) run through ``processor``.

            A single 2-D bool mask is expanded to 3 channels and processed as
            one image. Anything else is treated as a stack of masks (index
            masks are first converted with ``convert_idx_masks_to_bool``) and
            is reduced or padded to exactly ``seg_mask_cut_off`` masks.
            """
            if len(mask.shape) == 2 and mask.dtype == torch.bool:
                height, width = mask.shape
                # Fake a 3-channel image so the processor accepts the mask.
                mask = mask.unsqueeze(0).expand(3, height, width).float()
                if processor is None:
                    return mask
                inputs = processor(mask,
                                   do_rescale=False,
                                   do_normalize=False,
                                   return_tensors='pt')
                # First image, first channel of the processed batch.
                return inputs['pixel_values'][0][0]

            if mask.dtype != torch.bool:
                if len(mask.shape) == 2:
                    mask = mask.unsqueeze(0)
                mask = convert_idx_masks_to_bool(mask)
            num_masks, height, width = mask.shape
            mask = mask.unsqueeze(1).expand(num_masks,
                                            3,
                                            height,
                                            width).float()

            # Too few masks: tile the stack until there are enough to slice.
            # NOTE(review): an empty stack (num_masks == 0) raises
            # ZeroDivisionError here -- confirm callers never pass one.
            if num_masks < seg_mask_cut_off:
                repeat_count = seg_mask_cut_off // num_masks + 1
                mask = torch.cat([mask] * repeat_count, dim=0)

            # If the first (cut_off - 1) masks leave some pixel uncovered, the
            # last slot becomes a "compensation" mask holding exactly the
            # uncovered pixels; otherwise the first cut_off masks are kept.
            mask_sum = torch.sum(mask[:seg_mask_cut_off - 1],
                                 dim=0, keepdim=True).bool()
            if False in mask_sum:
                mask = mask[:seg_mask_cut_off - 1]
                compensation_mask = (1 - mask_sum.int()).bool()
                mask = torch.cat([mask, compensation_mask])
            else:
                mask = mask[:seg_mask_cut_off]

            if processor is None:
                return mask[:, 0]
            inputs = processor(mask,
                               do_rescale=False,
                               do_normalize=False,
                               return_tensors='pt')
            # Keep one channel per processed mask.
            return inputs['pixel_values'][:, 0]

    def train_val_pair(dataset_cls, *args, **kwargs):
        """Build ``(train, val)`` from ``dataset_cls``; train is None if val_only."""
        train = None
        if not val_only:
            train = dataset_cls(*args, split='train', data_size=train_size,
                                **kwargs)
        val = dataset_cls(*args, split='val', data_size=val_size, **kwargs)
        return train, val

    if dataset_name == 'cosmogrid':
        train_dataset, val_dataset = train_val_pair(
            CosmogridDataset, '../datasets/cosmogrid',
            mask_suffix=mask_dir,
            mask_transform=mask_transform)
    elif dataset_name == 'imagenet':
        # NOTE(review): need_cls is hard-coded to True in the ImageNet
        # branches, so the caller's need_cls is ignored -- confirm intended.
        train_dataset, val_dataset = train_val_pair(
            MaskedImageNetDataset,
            '/scratch/vision_datasets/ImageNet_old', mask_dir,
            transform=transform,
            mask_transform=mask_transform,
            debug=debug,
            need_cls=True,
            label2id=label2id)
    elif dataset_name == 'imagenet_m':
        train_dataset, val_dataset = train_val_pair(
            MaskedImageNetDatasetMedium,
            '/scratch/vision_datasets/ImageNet_old', mask_dir,
            transform=transform,
            mask_transform=mask_transform,
            debug=debug,
            need_cls=True,
            label2id=label2id)
    elif dataset_name == 'imagenet_s':
        train_dataset, val_dataset = train_val_pair(
            MaskedImageNetDatasetSmall,
            '/scratch/vision_datasets/ImageNet_old/imagenet1k_select10_vit',
            mask_dir,
            transform=transform,
            mask_transform=mask_transform,
            debug=debug,
            need_cls=True)
    elif dataset_name == 'fgvc':
        train_dataset, val_dataset = train_val_pair(
            MaskedFgvcDataset, '../datasets/fgvc-aircraft-2013b', mask_dir,
            transform=transform,
            mask_transform=mask_transform,
            debug=debug,
            need_cls=need_cls,
            class_type=class_type)
    else:
        raise ValueError(f'Unrecognized dataset {dataset_name}')

    if val_only:
        return val_dataset
    return train_dataset, val_dataset