import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
import tensorflow_datasets as tfds


import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

import datasets
from datasets.utils.logging import disable_progress_bar
disable_progress_bar()

from jax import numpy as jnp

from .config import DATASETS_CONFIG
from .utils import DictWrapper

import numpy as np


class UnifiedDataLoader:
    """Single entry point for batched image datasets from several backends.

    The backend is selected by ``DATASETS_CONFIG[dataset]['source']``:
    ``'tfds'`` (TensorFlow Datasets), ``'image_folder'`` (torchvision
    ``ImageFolder``) or ``'hf'`` (Hugging Face ``datasets``). Whatever the
    backend, batches can be requested as JAX arrays or Torch tensors.

    Layout conventions applied by the iterators: ``'jax'`` batches are
    produced channels-last, ``'torch'`` batches channels-first (the iterators
    transpose between ``(0, 3, 1, 2)`` and ``(0, 2, 3, 1)`` accordingly).
    """

    def __init__(self, dataset, batch_size, data_dir, preprocessor, framework='jax'):
        """Build the underlying dataset and its batched loader.

        Args:
            dataset: Key into ``DATASETS_CONFIG``.
            batch_size: Number of samples per batch; incomplete trailing
                batches are dropped by every backend.
            data_dir: Directory used as TFDS ``data_dir``, Hugging Face
                ``cache_dir`` or prefix of the image-folder root.
            preprocessor: Object exposing ``tfds_pp`` (used for TFDS) and
                ``torch_pp`` (used for image-folder and Hugging Face data).
            framework: Default output framework, ``'jax'`` or ``'torch'``.

        Raises:
            ValueError: If the configured data source is not supported.
        """
        self.batch_size = batch_size
        self.framework = framework
        self.data_dir = data_dir
        self.preprocessor = preprocessor

        data_source = DATASETS_CONFIG[dataset]['source']

        if data_source == 'tfds':
            self.dataset, self.loader = self._build_tfds_loader(dataset)
            self.loader_type = 'tfds'
        elif data_source == 'image_folder':
            self.dataset, self.loader = self._build_image_folder_loader(dataset)
            self.loader_type = 'torch'
        elif data_source == 'hf':
            self.dataset, self.loader = self._build_hf_loader(dataset)
            self.loader_type = 'torch'
        else:
            raise ValueError(
                f"Unsupported data source {data_source!r}. "
                "Supported options: 'tfds', 'image_folder' or 'hf'."
            )

    def _build_tfds_loader(self, dataset):
        """Return ``(unbatched_ds, batched_ds)`` for a TFDS-backed dataset."""
        tfds_name = DATASETS_CONFIG[dataset]['name']
        split = DATASETS_CONFIG[dataset]['split']

        ds = tfds.load(tfds_name, split=split, data_dir=self.data_dir)

        def ds_map(sample):
            return {'image': self.preprocessor.tfds_pp(sample['image']),
                    'label': sample['label']}

        ds = ds.map(ds_map)

        return ds, ds.batch(self.batch_size, drop_remainder=True)

    def _build_image_folder_loader(self, dataset):
        """Return ``(dataset, DataLoader)`` for an ``ImageFolder`` directory."""
        # NOTE(review): plain string concatenation -- data_dir/root_folder must
        # already provide the path separator between them.
        root_folder = self.data_dir + DATASETS_CONFIG[dataset]['root_folder']
        ds = ImageFolder(root=root_folder, transform=self.preprocessor.torch_pp)
        ds = DictWrapper(ds)
        _loader = DataLoader(ds, batch_size=self.batch_size, drop_last=True)

        return ds, _loader

    def _build_hf_loader(self, dataset):
        """Return ``(dataset, DataLoader)`` for a Hugging Face dataset.

        Datasets without a ``'label'`` feature get a constant label of 0.
        The per-sample integer labels are cached on ``ds.targets`` so the
        class-filtering helpers can use them.
        """
        hf_name = DATASETS_CONFIG[dataset]['name']
        split = DATASETS_CONFIG[dataset]['split']

        # NOTE(security): trust_remote_code=True lets the dataset repository
        # run its own loading script; only use with trusted dataset names.
        ds = datasets.load_dataset(hf_name, split=split, trust_remote_code=True,
                                   cache_dir=self.data_dir)

        has_label = 'label' in ds.features

        def pp_func(x):
            return {'image': self.preprocessor.torch_pp(x['image']),
                    'label': x['label'] if has_label else 0}

        ds = ds.map(pp_func)
        ds = ds.with_format("torch", columns=["image", "label"], dtype=torch.float32)
        ds.targets = [int(label) for label in ds["label"]]

        return ds, DataLoader(ds, batch_size=self.batch_size, drop_last=True)

    def get_iterator(self, framework=None, as_dict=True):
        """Get the dataset loader as a framework-friendly iterator.

        Args:
            framework: ``'jax'`` or ``'torch'``; defaults to ``self.framework``.
            as_dict: For ``'torch'``, yield ``{'image', 'label'}`` dicts when
                true and ``(x, y)`` tuples otherwise. ``'jax'`` batches are
                always dicts.

        Raises:
            ValueError: If ``self.loader_type`` is not a known loader type.
        """
        if framework is None:
            framework = self.framework
        if self.loader_type == 'tfds':
            return self._tfds_iterator(self.loader, framework, as_dict)
        if self.loader_type == 'torch':
            return self._torch_iterator(self.loader, framework, as_dict)
        raise ValueError(f"Unknown loader type {self.loader_type!r}.")

    def set_framework(self, framework):
        """Set the default output framework and return ``self`` for chaining."""
        self.framework = framework
        return self

    def __iter__(self):
        """Iterate directly over the object, Torch ``DataLoader`` style.

        With framework ``'torch'`` each batch is an ``(x, y)`` tuple; with
        framework ``'jax'`` each batch is an ``{'image', 'label'}`` dict.
        """
        return self.get_iterator(framework=self.framework, as_dict=False)

    def __len__(self):
        """Return the number of batches in the loader.

        Raises:
            ValueError: If a TFDS loader has infinite or unknown cardinality.
        """
        if self.loader_type == 'tfds':
            # Convert to a plain int so len() never sees a NumPy scalar.
            cardinality = int(tf.data.experimental.cardinality(self.loader).numpy())
            if cardinality == tf.data.experimental.INFINITE_CARDINALITY:
                raise ValueError("Dataset has infinite cardinality.")
            if cardinality == tf.data.experimental.UNKNOWN_CARDINALITY:
                raise ValueError("Dataset has unknown cardinality.")
            return cardinality
        return len(self.loader)

    def get_loader_by_class(self, framework=None):
        """Return ``loader_by_class`` such that ``loader_by_class(c)`` iterates
        over batches containing only samples of class ``c``.

        Batches are dicts of size ``self.batch_size``; incomplete batches are
        dropped. For torch-backed data the samples of the class are visited in
        random order.

        Raises:
            ValueError: If ``self.loader_type`` is not a known loader type.
        """
        if framework is None:
            framework = self.framework

        if self.loader_type == 'tfds':
            def loader_by_class(c):
                def filter_fn(x):
                    return x['label'] == c
                _ds_c = self.dataset.filter(filter_fn)
                return self._tfds_iterator(
                    _ds_c.batch(self.batch_size, drop_remainder=True), framework)

        elif self.loader_type == 'torch':
            targets = torch.tensor(self.dataset.targets)

            def loader_by_class(c):
                # nonzero() returns an (N, 1) tensor; flatten it to plain
                # Python ints so the dataset is indexed with integers rather
                # than with 1-element tensors.
                target_idx = (targets == c).nonzero().flatten().tolist()
                sampler = torch.utils.data.sampler.SubsetRandomSampler(target_idx)
                return self._torch_iterator(
                    DataLoader(self.dataset, sampler=sampler,
                               batch_size=self.batch_size, drop_last=True),
                    framework)

        else:
            raise ValueError(f"Unknown loader type {self.loader_type!r}.")

        return loader_by_class

    def get_random_batch(self, _batch_size, framework=None):
        """Return one randomly drawn batch of ``_batch_size`` samples as a dict.

        TFDS data is shuffled with a buffer of ``10 * _batch_size`` elements;
        torch-backed data is sampled by drawing indices with replacement.

        Raises:
            ValueError: If ``self.loader_type`` is not a known loader type.
        """
        if framework is None:
            framework = self.framework

        if self.loader_type == 'tfds':
            shuffled_loader = (self.dataset
                               .shuffle(buffer_size=10 * _batch_size)
                               .batch(_batch_size, drop_remainder=True))
            _iter = self._tfds_iterator(shuffled_loader, framework)
        elif self.loader_type == 'torch':
            n_samples = len(self.dataset.targets)
            # Plain Python ints keep dataset indexing backend-agnostic.
            target_idx = np.random.randint(0, n_samples, size=_batch_size).tolist()
            sampler = torch.utils.data.sampler.SubsetRandomSampler(target_idx)
            shuffled_loader = DataLoader(self.dataset, sampler=sampler,
                                         batch_size=_batch_size, drop_last=True)
            _iter = self._torch_iterator(shuffled_loader, framework)
        else:
            raise ValueError(f"Unknown loader type {self.loader_type!r}.")
        return next(_iter)

    @staticmethod
    def _tfds_iterator(loader, framework, as_dict=True):
        """Transform a TFDS loader into a framework-friendly iterator.

        ``'jax'`` always yields dicts; ``'torch'`` yields dicts or ``(x, y)``
        tuples depending on ``as_dict`` and moves channels to axis 1.

        Raises:
            ValueError: On the first batch if ``framework`` is unknown.
        """
        for batch in loader:
            x = batch['image']
            y = batch['label']
            # NOTE(review): labels are cast to uint8, which cannot represent
            # class ids above 255 -- confirm against the configured datasets.
            if framework == 'jax':
                x = jnp.array(x, dtype=jnp.float32)
                y = jnp.array(y, dtype=jnp.uint8)
                yield {'image': x, 'label': y}
            elif framework == 'torch':
                x = torch.as_tensor(np.transpose(x.numpy(), axes=(0, 3, 1, 2)),
                                    dtype=torch.float32)
                y = torch.as_tensor(y.numpy(), dtype=torch.uint8)
                if as_dict:
                    yield {'image': x, 'label': y}
                else:
                    yield x, y
            else:
                raise ValueError("Unknown framework. Supported options: 'jax' or 'torch'.")

    @staticmethod
    def _torch_iterator(loader, framework, as_dict=True):
        """Transform a Torch loader into a framework-friendly iterator.

        ``'jax'`` always yields dicts and moves channels to the last axis;
        ``'torch'`` yields dicts or ``(x, y)`` tuples depending on ``as_dict``.

        Raises:
            ValueError: On the first batch if ``framework`` is unknown.
        """
        for batch in loader:
            x = batch['image']
            y = batch['label']
            if framework == 'jax':
                x = jnp.transpose(jnp.array(x.numpy(), dtype=jnp.float32),
                                  axes=(0, 2, 3, 1))
                # NOTE(review): uint8 cannot represent class ids above 255.
                y = jnp.array(y.numpy(), dtype=jnp.uint8)
                yield {'image': x, 'label': y}
            elif framework == 'torch':
                x = torch.as_tensor(x.numpy(), dtype=torch.float32)
                # NOTE(review): labels are float32 here but uint8 on the TFDS
                # path -- kept as-is for callers relying on it.
                y = torch.as_tensor(y.numpy(), dtype=torch.float32)
                if as_dict:
                    yield {'image': x, 'label': y}
                else:
                    yield x, y
            else:
                raise ValueError("Unknown framework. Supported options: 'jax' or 'torch'.")



