import argparse, os, sys, datetime, glob, importlib
from torch.utils.data import random_split, DataLoader, Dataset

import lightning as L
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from lightning import seed_everything

from torch.utils.data.dataloader import default_collate as custom_collate

import torch
# Reproducibility: make cuDNN pick deterministic algorithms and disable its
# per-shape autotuner (the autotuner may select different kernels run to run).
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# Allow faster reduced-precision internal math for float32 matmuls ("high"
# precision mode, e.g. TF32 on supporting GPUs).
torch.set_float32_matmul_precision("high")

def get_obj_from_str(string, reload=False):
    """Return the attribute named by a dotted path such as ``"pkg.mod.Class"``.

    Everything before the last dot is imported as a module; the final
    component is looked up on that module.

    Args:
        string: Dotted path ``"<module>.<attribute>"``. A path without any
            dot raises ``ValueError``.
        reload: If true, the module is re-executed via ``importlib.reload``
            before the attribute is read.

    Returns:
        The resolved attribute (class, function, ...).
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    # After a reload, import_module returns the same (now refreshed) module object.
    target_module = importlib.import_module(module_name, package=None)
    return getattr(target_module, attr_name)


def instantiate_from_config(config):
    """Build ``target(**params)`` from a config mapping or namespace-like object.

    ``config["target"]`` is a dotted import path resolved with
    ``get_obj_from_str``; ``config["params"]`` (optional) supplies the keyword
    arguments. Non-dict configs (e.g. jsonargparse/argparse namespaces) are
    first converted into a *fresh* dict, read from ``vars(config)`` with a
    fallback to the ``target`` / ``params`` attributes. The caller's object is
    never modified.

    Args:
        config: A ``dict``, or an object exposing ``target`` (and optionally
            ``params``) via ``vars()`` or as attributes.

    Returns:
        The newly constructed object.

    Raises:
        KeyError: If no ``target`` can be found in ``config``.
    """
    if not isinstance(config, dict):
        config_dict = {}
        try:
            # Copy: vars() returns the object's live __dict__, and the fallback
            # below adds keys that must not be written back onto the caller's object.
            config_dict = dict(vars(config))
        except (TypeError, AttributeError):
            # Objects without a __dict__ (e.g. __slots__ classes): use attributes only.
            pass

        if "target" not in config_dict:
            if hasattr(config, "target"):
                config_dict["target"] = config.target
            if hasattr(config, "params"):
                config_dict["params"] = config.params
            elif "params" not in config_dict:
                config_dict["params"] = {}

        config = config_dict

    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")

    params = config.get("params")
    # An explicitly empty `params` entry (None) means "no keyword arguments"
    # rather than a TypeError from `**None`.
    if params is None:
        params = {}
    return get_obj_from_str(config["target"])(**params)


class WrappedDataset(Dataset):
    """Adapt any sized, indexable object into a ``torch.utils.data.Dataset``.

    ``len()`` and item access are forwarded unchanged to the wrapped object,
    which is kept on ``self.data``.
    """

    def __init__(self, dataset):
        # Only a reference is stored; nothing is copied or materialized.
        self.data = dataset

    def __getitem__(self, idx):
        source = self.data
        return source[idx]

    def __len__(self):
        source = self.data
        return len(source)


class DataModuleFromConfig(L.LightningDataModule):
    """LightningDataModule whose train/validation/test datasets come from configs.

    Each split config is turned into a dataset with ``instantiate_from_config``.
    Only the splits that were provided get their ``*_dataloader`` hook replaced
    on the instance by the matching private builder.

    Args:
        batch_size: Batch size used by every dataloader.
        train: Optional dataset config for the training split.
        validation: Optional dataset config for the validation split.
        test: Optional dataset config for the test split.
        wrap: If true, each built dataset is wrapped in ``WrappedDataset``.
        num_workers: DataLoader worker count; ``None`` means ``batch_size * 2``.
    """

    def __init__(self, batch_size, train=None, validation=None, test=None,
                 wrap=False, num_workers=None):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        if num_workers is None:
            num_workers = batch_size * 2
        self.num_workers = num_workers
        # (split key, its config, public hook name, private builder)
        split_table = (
            ("train", train, "train_dataloader", self._train_dataloader),
            ("validation", validation, "val_dataloader", self._val_dataloader),
            ("test", test, "test_dataloader", self._test_dataloader),
        )
        for split, split_cfg, hook_name, builder in split_table:
            if split_cfg is None:
                continue
            self.dataset_configs[split] = split_cfg
            setattr(self, hook_name, builder)
        self.wrap = wrap

    def prepare_data(self):
        # Build every configured dataset once and discard the result
        # (presumably to trigger any download/caching side effects — confirm).
        for split_cfg in self.dataset_configs.values():
            instantiate_from_config(split_cfg)

    def setup(self, stage=None):
        # `stage` is accepted for the Lightning hook signature but not used.
        built = {
            split: instantiate_from_config(split_cfg)
            for split, split_cfg in self.dataset_configs.items()
        }
        if self.wrap:
            built = {split: WrappedDataset(ds) for split, ds in built.items()}
        self.datasets = built

    def _loader(self, split, shuffle):
        """Return a DataLoader over ``self.datasets[split]`` with the shared settings."""
        return DataLoader(
            self.datasets[split],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            collate_fn=custom_collate,
            pin_memory=True,
        )

    def _train_dataloader(self):
        return self._loader("train", shuffle=True)

    def _val_dataloader(self):
        return self._loader("validation", shuffle=False)

    def _test_dataloader(self):
        return self._loader("test", shuffle=False)


def pad_collate_fn(batch):
    """Collate items with variable-length waveforms by zero-padding the last dim.

    Each item is a dict with a ``"waveform"`` tensor and an ``"audio_path"``.
    The waveform is assumed to be 2-D, ``(channels, time)`` (TODO confirm
    against the dataset), since ``transpose(0, 1)`` requires at least 2 dims.
    Other keys in the items are dropped.

    Returns:
        ``{"waveform": (batch, channels, max_time) tensor, "audio_path": list}``.
    """
    # pad_sequence pads along dim 0, so move the time axis first: (time, channels).
    time_first = [item["waveform"].transpose(0, 1) for item in batch]
    padded = torch.nn.utils.rnn.pad_sequence(time_first, batch_first=True, padding_value=0.)
    # padded is (batch, max_time, channels); restore channels-before-time order.
    waveform = padded.permute(0, 2, 1)
    paths = [item["audio_path"] for item in batch]
    return {
        "waveform": waveform,
        "audio_path": paths,
    }

class PadDataModuleFromConfig(DataModuleFromConfig):
    """``DataModuleFromConfig`` variant whose loaders collate with ``pad_collate_fn``.

    Dataset construction is inherited. The loaders here use the module's
    batch size and worker count, pin memory, and shuffle only the train split.
    """

    def _padded_loader(self, split, shuffle):
        """Return a DataLoader over ``self.datasets[split]`` that zero-pads waveforms."""
        return DataLoader(
            self.datasets[split],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            collate_fn=pad_collate_fn,
            pin_memory=True,
        )

    def _train_dataloader(self):
        return self._padded_loader("train", shuffle=True)

    def _val_dataloader(self):
        return self._padded_loader("validation", shuffle=False)

    def _test_dataloader(self):
        return self._padded_loader("test", shuffle=False)

def main():
    """Construct a ``LightningCLI`` for this script and return it.

    With Lightning's default ``run=True``, constructing the CLI executes the
    subcommand given on the command line (see the LightningCLI docs).
    ``save_config_kwargs={"overwrite": True}`` allows the saved run config to
    replace an existing one.

    Returns:
        The ``LightningCLI`` instance. Programmatic callers can inspect it
        (e.g. ``cli.trainer``); the script entry point below ignores it.
    """
    return LightningCLI(
        save_config_kwargs={"overwrite": True}
    )


if __name__ == "__main__":
    main()