import json
import os
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import einops
import h5pickle
import hydra
import lightning as L
import numpy as np
import scipy.spatial
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader as torch_Dataloader


from torch.utils.data import Sampler


@dataclass
class LitDataModule(L.LightningDataModule):
    """Lightning data module that builds train/val/test datasets and loaders.

    Fields:
        sampler: A factory (class or callable), or None. It is called as
            ``sampler(dataset=..., shuffle=..., seed=...)`` to build a
            per-split sampler. If None, the DataLoader uses its default
            sequential sampling.
        dataset: A dataset factory (class or callable). It is called with
            ``include_full_data``, ``do_cache``, ``n_steps``,
            ``jump_by_sequence`` and ``split`` keyword arguments.
        mode: Not read by this class. NOTE(review): presumably consumed
            elsewhere; confirm.
        n_steps: Forwarded to the dataset factory.
        seed: Forwarded to the sampler factory.
        batch_size, num_workers, pin_memory, persistent_workers: Forwarded to
            ``torch.utils.data.DataLoader``. ``persistent_workers`` is only
            honoured when ``num_workers > 0``.
        include_full_data: Forwarded to the dataset factory.
        cache_dataset: Forwarded to the dataset factory as ``do_cache``.
        split: None builds all three splits; otherwise only the named one
            ('train', 'val' or 'test') is built.
        jump_by_sequence: Forwarded to the dataset factory.
        follow_batch: Not read by this class. NOTE(review): the name suggests
            a batching/collate option used elsewhere; confirm.
    """

    # Factory invoked in _make_dataloader; may be None (see class docstring).
    sampler: Sampler
    # Factory invoked in setup() once per requested split.
    dataset: Dataset

    mode: str
    n_steps: int

    seed: Optional[int] = None
    batch_size: int = 4
    num_workers: int = 1
    pin_memory: bool = True
    persistent_workers: bool = True
    include_full_data: bool = False
    cache_dataset: bool = True
    split: Optional[str] = None
    jump_by_sequence: bool = False

    follow_batch: Optional[List[str]] = field(
        default_factory=lambda: ["enc_pos", "supernode_index"]
    )

    def __post_init__(self):
        # The dataclass-generated __init__ does not call the parent
        # initializer, so run LightningDataModule's setup here.
        super().__init__()

    def setup(self, stage: Optional[str] = None) -> None:
        """Instantiate the dataset(s) selected by ``self.split``.

        Args:
            stage: Lightning stage name. It is accepted for hook compatibility
                but not used; the choice of split comes from ``self.split``.

        Raises:
            ValueError: If ``self.split`` is not None, 'train', 'val' or 'test'.
        """
        ctor = partial(
            self.dataset,
            include_full_data=self.include_full_data,
            do_cache=self.cache_dataset,
            n_steps=self.n_steps,
            jump_by_sequence=self.jump_by_sequence,
        )

        if self.split is None:
            # Initialize all three splits.
            self.train_dataset = ctor(split="train")
            self.val_dataset = ctor(split="val")
            self.test_dataset = ctor(split="test")
        elif self.split == "train":
            self.train_dataset = ctor(split="train")
        elif self.split == "val":
            self.val_dataset = ctor(split="val")
        elif self.split == "test":
            # test_dataloader() reads test_dataset, so it must be set here.
            # val_dataset is kept as an alias so validation loops run on the
            # test split, as this branch has always done.
            self.test_dataset = ctor(split="test")
            self.val_dataset = self.test_dataset
        else:
            raise ValueError(
                f"Unknown split {self.split!r}; expected None, 'train', 'val' or 'test'."
            )

    def _make_dataloader(self, dataset: Dataset, shuffle: bool) -> torch_Dataloader:
        """Build a DataLoader for ``dataset`` using the module's loader settings.

        Args:
            dataset: The split's dataset instance.
            shuffle: Passed to the sampler factory. It has no effect when
                ``self.sampler`` is None. NOTE(review): in that case training
                data is not shuffled; confirm this is intended.

        Returns:
            A configured ``torch.utils.data.DataLoader``.
        """
        sampler: Optional[Sampler] = (
            None
            if self.sampler is None
            else self.sampler(dataset=dataset, shuffle=shuffle, seed=self.seed)
        )
        return torch_Dataloader(
            dataset,
            sampler=sampler,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            # DataLoader raises ValueError for persistent_workers=True when
            # num_workers == 0, so only enable it when workers exist.
            persistent_workers=self.persistent_workers and self.num_workers > 0,
        )

    def train_dataloader(self):
        """Return the training loader (sampler built with shuffle=True)."""
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (sampler built with shuffle=False)."""
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the test loader (sampler built with shuffle=False)."""
        return self._make_dataloader(self.test_dataset, shuffle=False)