import random 
import cv2
import albumentations as A
import random 
from torch.utils.data import DataLoader, ConcatDataset, Subset
from transformers import AutoTokenizer
from .dataset_registry import DatasetRegistry
import numpy as np 
import torch
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import random
import torch.distributed as dist
import lightning.pytorch as pl


import random 
import cv2
# import albumentations as A
import random 
from torch.utils.data import DataLoader, ConcatDataset, Subset
from transformers import AutoTokenizer
from files.datasets.dataset_registry import DatasetRegistry
from files.datasets.dataset_registry_instruct import DatasetRegistryInstruct

import files.datasets.dataset_omnimed_vqa
import files.datasets.dataset_path_vqa
import files.datasets.dataset_rad_vqa
import files.datasets.dataset_pmc_vqa
import files.datasets.dataset_vqa_med
import files.datasets.dataset_ubench_vqa
import files.datasets.dataset_slake_vqa


# Hugging Face model identifiers of vision-language models.
# NOTE(review): not referenced by get_sequential_dataloaders, which detects
# VL models by substring ("Qwen2-VL" / "gemma-3") instead — confirm intended use.
VL_MODEL_NAMES = {"Qwen/Qwen2-VL-7B-Instruct"}

class Transforms:
    """Albumentations pipelines that resize images to ``config.image_size``.

    Args:
        config: object whose ``image_size`` is indexable, with
            ``image_size[0]`` used as the target height and ``image_size[1]``
            as the target width.

    Attributes:
        train_transforms: ``A.Compose`` holding a single ``A.Resize``.
        valid_transforms: a separate ``A.Compose`` holding the same resize.
    """

    def __init__(self, config):
        height, width = config.image_size[0], config.image_size[1]
        # ``A.Resize`` is the public, documented alias of
        # ``A.augmentations.geometric.resize.Resize``; using it avoids
        # depending on albumentations' internal module layout.
        self.train_transforms = A.Compose([A.Resize(height=height, width=width)])
        self.valid_transforms = A.Compose([A.Resize(height=height, width=width)])



class MultiDataLoader:
    """Iterate several dataloaders back to back, as if they were one.

    All batches of ``dataloaders[0]`` are yielded first, then all batches of
    ``dataloaders[1]``, and so on; batches from different loaders are never
    mixed within one step.

    Args:
        dataloaders: sequence of iterables (typically ``DataLoader``s) that
            also support ``len()``.
    """

    def __init__(self, dataloaders):
        self.dataloaders = dataloaders

    def __iter__(self):
        for loader in self.dataloaders:
            yield from loader

    def __len__(self):
        """Return the total number of batches across all wrapped loaders."""
        return sum(len(loader) for loader in self.dataloaders)

    def set_epoch(self, epoch):
        """Forward ``epoch`` to each wrapped loader's sampler, if it supports it.

        ``DistributedSampler`` only changes its shuffle order when
        ``set_epoch`` is called before each epoch. This wrapper exposes no
        ``sampler`` attribute of its own, so presumably framework hooks that
        look for ``loader.sampler`` will not reach the inner samplers.
        Call this explicitly at the start of each epoch (TODO confirm
        against the trainer in use).

        Args:
            epoch: epoch index passed to every ``sampler.set_epoch``.
        """
        for loader in self.dataloaders:
            sampler = getattr(loader, "sampler", None)
            set_epoch = getattr(sampler, "set_epoch", None)
            if callable(set_epoch):
                set_epoch(epoch)


def _load_tokenizer(cfg):
    """Load the tokenizer for ``cfg.LLM.model_name`` with ``AutoTokenizer``.

    ``chaoyi-wu/PMC_LLAMA_7B`` is loaded with explicit ``<unk>``/``<s>``/``</s>``
    special tokens; every other model is loaded without overrides.
    """
    special_tokens = {}
    if cfg.LLM.model_name == 'chaoyi-wu/PMC_LLAMA_7B':
        # NOTE(review): presumably this checkpoint lacks usable special-token
        # definitions, hence the explicit overrides — confirm.
        special_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
    return AutoTokenizer.from_pretrained(
        cfg.LLM.model_name,
        cache_dir=cfg.LLM.cache_dir,
        force_download=False,
        **special_tokens,
    )


def _subsample_for_inference(dataset, num_samples, seed=42):
    """Return a ``Subset`` of ``num_samples`` distinct indices drawn reproducibly.

    Side effect: reseeds the global RNGs through ``pl.seed_everything``
    (Python ``random``, NumPy and torch), so randomness consumed after this
    call is also determined by ``seed``.
    """
    pl.seed_everything(seed)
    indices = random.sample(range(len(dataset)), num_samples)
    return Subset(dataset, indices)


def get_sequential_dataloaders(cfg, phase):
    """Build one DataLoader per dataset in ``cfg.dataset_list`` and chain them.

    Args:
        cfg: experiment config. Fields read here: ``dataset_list``,
            ``LLM.model_name``, ``LLM.cache_dir``, ``img_token_num``,
            ``n_inference_lim``, ``is_inference``, ``batch_size``,
            ``num_workers``, ``prefetch_factor`` (plus whatever
            ``Transforms`` and the dataset registries read).
        phase: ``'train'`` or ``'valid'``.

    Returns:
        MultiDataLoader that yields every batch of the first dataset, then
        every batch of the second, and so on.

    Raises:
        ValueError: if ``phase`` is neither ``'train'`` nor ``'valid'``.
    """
    # The transform and tokenizer do not depend on the dataset, so build them
    # once instead of once per dataset.
    transforms = Transforms(cfg)
    if phase == 'train':
        transform = transforms.train_transforms
    elif phase == 'valid':
        transform = transforms.valid_transforms
    else:
        raise ValueError(f"phase must be 'train' or 'valid', got {phase!r}")

    model_name = cfg.LLM.model_name
    use_instruct_registry = ("Qwen2-VL" in model_name) or ("gemma-3" in model_name.lower())
    tokenizer = None  # loaded lazily: only the non-instruct registry uses it
    use_distributed = dist.is_available() and dist.is_initialized()

    dataloaders = []
    for name in cfg.dataset_list:
        if use_instruct_registry:
            dataset = DatasetRegistryInstruct.get_dataset(name, cfg, name, phase)
        else:
            if tokenizer is None:
                tokenizer = _load_tokenizer(cfg)
            # Fresh list per dataset, as before, in case a dataset mutates it.
            img_padding = [-100] * cfg.img_token_num
            dataset = DatasetRegistry.get_dataset(name, cfg, transform, tokenizer, img_padding, name, phase)

        print('---- DATASET LENGTH ----')
        print(f'Name of the dataset : {name} and phase : {phase}')
        print(len(dataset))

        # Optional reproducible subsampling for inference mode.
        num_samples = cfg.n_inference_lim
        if cfg.is_inference and len(dataset) > num_samples:
            dataset = _subsample_for_inference(dataset, num_samples)

        # NOTE(review): both phases are shuffled, and DistributedSampler pads
        # by repeating samples so every rank gets the same count. Its
        # set_epoch must also be called each epoch to reshuffle; this
        # function does not call it. Confirm both are intended for 'valid'.
        sampler = DistributedSampler(dataset, shuffle=True) if use_distributed else None

        # torch's DataLoader rejects prefetch_factor when num_workers == 0
        # (prefetching only applies to worker processes), so pass it only
        # when workers are used.
        loader_kwargs = {}
        if cfg.num_workers > 0:
            loader_kwargs['prefetch_factor'] = cfg.prefetch_factor

        loader = DataLoader(
            dataset,
            batch_size=cfg.batch_size,
            sampler=sampler,
            num_workers=cfg.num_workers,
            pin_memory=True,
            shuffle=(sampler is None),  # sampler and shuffle are mutually exclusive
            **loader_kwargs,
        )
        dataloaders.append(loader)

    return MultiDataLoader(dataloaders)

