import torch
import torch.nn.functional as F
import torch.utils.cpp_extension
import numpy as np
from torch.utils.data import Dataset, DataLoader
from typing import Dict, List, Union, Sequence, Mapping, Any, Optional
import lightning.pytorch as pl

from mcu.gallary.dataset import (
    DummyDataset, 
    RawDataset, 
    InstructionFollowingDataset, 
    ExtensionDataset, 
    ExtConcatDataset,
    PreferenceDataset, 
    AuxilaryFunctions, 
    write_video
)

class MineRLDataModule(pl.LightningDataModule):
    """Lightning data module that builds MineRL datasets and their loaders.

    The ``mode`` argument selects which dataset class is instantiated in
    :meth:`setup`:

    * ``'raw'``         -> ``RawDataset``
    * ``'conditioned'`` -> ``ExtensionDataset``
    * ``'ext_concat'``  -> ``ExtConcatDataset``
    * ``'preference'``  -> ``PreferenceDataset``
    * ``'dummy'``       -> ``DummyDataset``

    Any other value is rejected with ``ValueError`` when :meth:`setup` runs.
    """

    def __init__(
        self,
        mode: str = 'raw',
        batch_size: int = 8,
        num_workers: int = 8,
        train_shuffle: bool = True,
        prefetch_factor: int = 4,
        # below are parameters for dataset manager
        dataset_dirs: Optional[List[str]] = None,
        enable_video: bool = True,
        enable_action: bool = True,
        enable_clip: bool = False,
        enable_contractor_info: bool = False,
        frame_width: int = 128,
        frame_height: int = 128,
        decode_library: str = 'pyav',
        # below are parameters for extension dataset
        extension_dirs: Optional[List[str]] = None,
        sample_mode: str = 'uniform',
        samples_per_goal: int = 10000,
        padding_left: int = 128,
        padding_right: int = 20,
        win_len: int = 128,
        skip_frame: int = 1,
        split_ratio: float = 0.9,
        split_method: str = 'episode',
        goal_list: Optional[List[str]] = None,
        fixed_start: bool = False,
        # below are parameters for preference dataset
        positive_goals: Optional[Union[str, List]] = None,
        negative_goals: Optional[Union[str, List]] = None,
        num_preferences: int = 100,
        # below are parameters for instruction-following dataset
        instruction_following_args: Optional[Dict] = None,
        **unused_kwargs,
    ) -> None:
        """Store loader settings and group the dataset constructor arguments.

        Args:
            mode: Dataset flavour; see the class docstring for accepted values.
            batch_size: Batch size used by every data loader.
            num_workers: Number of loader worker processes.
            train_shuffle: Whether the training loader shuffles its dataset.
            prefetch_factor: Forwarded to ``DataLoader`` when
                ``num_workers > 0``.
            dataset_dirs, enable_video, enable_action, enable_clip,
            enable_contractor_info, frame_width, frame_height, decode_library:
                Collected into ``self.manager_kwargs`` and forwarded to every
                dataset constructor. The original annotation listed
                ``'pyav'`` and ``'opencv'`` for ``decode_library``.
            win_len, skip_frame, split_ratio, split_method:
                Forwarded to ``RawDataset`` in ``'raw'`` mode and, together
                with the extension parameters, to the extension-based
                datasets.
            extension_dirs, sample_mode, samples_per_goal, padding_left,
            padding_right, goal_list, fixed_start:
                Extra arguments for the ``'conditioned'``, ``'ext_concat'``
                and ``'preference'`` modes. The original annotation listed
                ``'balance'`` and ``'uniform'`` for ``sample_mode``.
            positive_goals, negative_goals, num_preferences:
                Extra arguments used only in ``'preference'`` mode.
            instruction_following_args: When not ``None``, keyword arguments
                for an additional ``InstructionFollowingDataset`` that is
                exposed as a second validation loader.
            **unused_kwargs: Forwarded to ``DummyDataset`` in ``'dummy'``
                mode; ignored in every other mode.

        List-valued parameters default to ``None`` and are replaced by a fresh
        empty list, so instances never share a mutable default object.
        """
        super().__init__()

        # Resolve the ``None`` sentinels to per-instance lists; explicitly
        # supplied values are passed through untouched.
        dataset_dirs = [] if dataset_dirs is None else dataset_dirs
        extension_dirs = [] if extension_dirs is None else extension_dirs
        goal_list = [] if goal_list is None else goal_list
        positive_goals = [] if positive_goals is None else positive_goals
        negative_goals = [] if negative_goals is None else negative_goals

        self.mode = mode
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_shuffle = train_shuffle
        self.prefetch_factor = prefetch_factor
        self.manager_kwargs = dict(
            dataset_dirs=dataset_dirs,
            enable_video=enable_video,
            enable_action=enable_action,
            enable_clip=enable_clip,
            enable_contractor_info=enable_contractor_info,
            frame_width=frame_width,
            frame_height=frame_height,
            decode_library=decode_library,
        )
        self.raw_dataset_kwargs = dict(
            win_len=win_len,
            skip_frame=skip_frame,
            split_ratio=split_ratio,
            split_method=split_method,
        )
        self.extension_dataset_kwargs = dict(
            extension_dirs=extension_dirs,
            sample_mode=sample_mode,
            samples_per_goal=samples_per_goal,
            padding_left=padding_left,
            padding_right=padding_right,
            win_len=win_len,
            skip_frame=skip_frame,
            split_ratio=split_ratio,
            split_method=split_method,
            goal_list=goal_list,
            fixed_start=fixed_start,
        )

        # An unrecognised mode leaves ``dataset_kwargs`` unset on purpose;
        # it is reported by ``setup`` via ``ValueError``.
        if mode in ['conditioned', 'ext_concat']:
            self.dataset_kwargs = self.extension_dataset_kwargs
        elif mode == 'raw':
            self.dataset_kwargs = self.raw_dataset_kwargs
        elif mode == 'dummy':
            self.dataset_kwargs = dict(
                win_len=win_len,
                **unused_kwargs,
                num_samples=10000,
            )
        elif mode == 'preference':
            self.dataset_kwargs = dict(
                positive_goals=positive_goals,
                negative_goals=negative_goals,
                num_preferences=num_preferences,
                **self.extension_dataset_kwargs,
            )

        self.instruction_following_args = instruction_following_args

    def prepare_data(self):
        """No-op: this module performs no download or preprocessing step."""
        pass

    def setup(self, stage: Optional[str] = None):
        """Instantiate the datasets and pick the collate function for ``mode``.

        Args:
            stage: Stage name supplied by the trainer. The training dataset is
                built when ``stage`` is ``'fit'`` or ``None`` (a direct
                ``setup()`` call); the validation dataset is always built. In
                ``'dummy'`` mode both datasets are built regardless of stage.

        Raises:
            ValueError: If ``self.mode`` is not one of the supported modes.
        """
        print("Setup the MineRL Data Module. ")

        if self.mode == 'dummy':
            # The dummy dataset takes no ``split_type`` and is stage-agnostic.
            self.train_dataset = DummyDataset(**self.manager_kwargs, **self.dataset_kwargs)
            self.val_dataset = DummyDataset(**self.manager_kwargs, **self.dataset_kwargs)
            self.collate_fn = AuxilaryFunctions.collate_fn
        else:
            # Built here rather than at class level so the dataset classes are
            # only looked up when ``setup`` actually runs.
            split_dataset_classes = {
                'conditioned': ExtensionDataset,
                'raw': RawDataset,
                'ext_concat': ExtConcatDataset,
                'preference': PreferenceDataset,
            }
            if self.mode not in split_dataset_classes:
                raise ValueError(f"Unknown mode: {self.mode}")
            dataset_cls = split_dataset_classes[self.mode]

            # ``None`` means "no specific stage", so the training split must
            # exist too; otherwise ``train_dataloader`` would fail later.
            if stage is None or stage == 'fit':
                self.train_dataset = dataset_cls(
                    **self.dataset_kwargs, split_type='train', **self.manager_kwargs
                )
            self.val_dataset = dataset_cls(
                **self.dataset_kwargs, split_type='val', **self.manager_kwargs
            )
            # Preference batches rely on the DataLoader's default collation.
            if self.mode == 'preference':
                self.collate_fn = None
            else:
                self.collate_fn = AuxilaryFunctions.collate_fn

        if self.instruction_following_args is not None:
            self.if_val_dataset = InstructionFollowingDataset(**self.instruction_following_args)

    def _make_dataloader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a ``DataLoader`` over ``dataset`` with this module's settings."""
        loader_kwargs = dict(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            collate_fn=self.collate_fn,
            pin_memory=True,
        )
        # ``DataLoader`` rejects ``prefetch_factor`` when loading happens in
        # the main process, so only forward it when workers are in use.
        if self.num_workers > 0:
            loader_kwargs['prefetch_factor'] = self.prefetch_factor
        return DataLoader(**loader_kwargs)

    def train_dataloader(self):
        """Return the loader over ``self.train_dataset``."""
        return self._make_dataloader(self.train_dataset, shuffle=self.train_shuffle)

    def val_dataloader(self):
        """Return the validation loaders as a list.

        The first entry iterates ``self.val_dataset``. A second entry over the
        instruction-following dataset is appended when ``setup`` created one.
        """
        res_dataloaders = [self._make_dataloader(self.val_dataset, shuffle=False)]
        if hasattr(self, 'if_val_dataset'):
            res_dataloaders.append(
                self._make_dataloader(self.if_val_dataset, shuffle=False)
            )
        return res_dataloaders
