"""Dataset and utility functions."""

import os
import json
import random
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader

from prompt_strategies import PromptSelector


def set_seed(seed=42):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU generator plus the current and all CUDA devices), puts cuDNN into
    deterministic mode with benchmarking switched off, and finally exports
    the seed as ``PYTHONHASHSEED``.

    Args:
        seed: Value handed unchanged to every seeding call. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning for run-to-run repeatability.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # Stored as text because environment variables are strings.
    os.environ['PYTHONHASHSEED'] = f"{seed}"


def seed_worker(worker_id):
    """Re-seed NumPy and Python ``random`` from PyTorch's current seed.

    The value of ``torch.initial_seed()`` is reduced modulo 2**32 and used
    to seed both libraries. ``get_dataloader`` passes this function as the
    ``worker_init_fn`` of the loaders it builds.

    Args:
        worker_id: Accepted to match the ``worker_init_fn`` call signature;
            the value itself is not used.
    """
    seed_32bit = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(seed_32bit)


def get_dataloader(dataset, batch_size, shuffle=True, num_workers=4, seed=42):
    """Build a ``DataLoader`` with seeded shuffling and seeded workers.

    Args:
        dataset: Dataset handed to ``DataLoader`` unchanged.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader shuffles. Defaults to True.
        num_workers: Number of loader worker processes. Defaults to 4.
        seed: Seed for the ``torch.Generator`` created here. Defaults to 42.

    Returns:
        A ``DataLoader`` with ``pin_memory=True`` and ``seed_worker`` as its
        ``worker_init_fn``.
    """
    seeded_rng = torch.Generator()
    seeded_rng.manual_seed(seed)

    # The seeded generator is handed over only when shuffling; otherwise
    # the loader receives generator=None.
    if shuffle:
        loader_rng = seeded_rng
    else:
        loader_rng = None

    loader_kwargs = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'pin_memory': True,
        'worker_init_fn': seed_worker,
        'generator': loader_rng,
    }
    return DataLoader(dataset, **loader_kwargs)


def _to_channels_last_array(img):
    """Return ``img`` as a NumPy array clipped to [0, 1]; 3-D inputs get axis 0 moved last."""
    if hasattr(img, 'cpu'):
        # Tensor.numpy() refuses tensors that require grad, so drop the
        # autograd link before converting.
        source = img.detach() if hasattr(img, 'detach') else img
        array = source.cpu().numpy()
    else:
        # Convert once and reuse the result for both the rank check and
        # the transpose.
        array = np.array(img)

    # Only 3-D inputs are transposed; any other rank is passed through
    # unchanged, for tensors and arrays alike.
    if array.ndim == 3:
        array = array.transpose(1, 2, 0)
    return np.clip(array, 0.0, 1.0)


def preprocess_batch(images, prompts, processor, device, config):
    """Turn a batch of images and prompts into model inputs on ``device``.

    Each image is converted to a NumPy array, 3-D images have their first
    axis moved to the end, and values are clipped to [0, 1]. The arrays go
    through ``processor.image_processor`` with ``do_rescale=False`` and the
    prompts through ``processor.tokenizer``; both results are merged into
    one dict.

    Args:
        images: Iterable of images. Objects with a ``cpu`` method are
            treated as tensors; anything else goes through ``np.array``.
            3-D inputs are assumed to be channel-first — TODO confirm
            against callers.
        prompts: Text prompts passed to the tokenizer as ``text``.
        processor: Object exposing ``image_processor`` and ``tokenizer``
            callables.
        device: Target device for every tensor in the result.
        config: Mapping that provides ``'text_max_length'``.

    Returns:
        dict: Merged image and text inputs. Tensor values are moved to
        ``device`` with ``non_blocking=True``; other values are kept as is.
    """
    numpy_images = [_to_channels_last_array(img) for img in images]

    image_inputs = processor.image_processor(
        images=numpy_images, return_tensors="pt", do_rescale=False
    )
    text_inputs = processor.tokenizer(
        text=prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=config['text_max_length'],
    )

    merged = {**image_inputs, **text_inputs}
    return {
        key: value.to(device, non_blocking=True) if torch.is_tensor(value) else value
        for key, value in merged.items()
    }


def save_model_checkpoint(model, processor, save_path, task_info=None, config=None):
    """Write a checkpoint directory for ``model`` and return its path.

    The directory is created if needed (parents included) and receives, in
    this order:

    * ``clipseg_backbone/`` via ``model.clipseg_backbone.save_pretrained``
    * ``processor/`` via ``processor.save_pretrained``
    * ``peft_model.pth`` holding ``model.state_dict()``
    * ``checkpoint_info.json`` with the model type, ``task_info`` and
      ``config``; values JSON cannot encode are written as ``str(value)``

    Args:
        model: Object with a ``clipseg_backbone`` attribute and a
            ``state_dict()`` method.
        processor: Object with a ``save_pretrained`` method.
        save_path: Target directory, as a string or path.
        task_info: Optional metadata stored in the JSON file.
        config: Optional configuration stored in the JSON file.

    Returns:
        pathlib.Path: The checkpoint directory.
    """
    checkpoint_dir = Path(save_path)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    model.clipseg_backbone.save_pretrained(checkpoint_dir / "clipseg_backbone")
    processor.save_pretrained(checkpoint_dir / "processor")
    torch.save(model.state_dict(), checkpoint_dir / "peft_model.pth")

    metadata = {
        "model_type": "CLIPModalityLoRAModel",
        "task_info": task_info,
        "config": config,
    }
    with (checkpoint_dir / "checkpoint_info.json").open('w') as info_file:
        json.dump(metadata, info_file, indent=2, default=str)

    return checkpoint_dir


def _to_serializable(value):
    """Return a JSON-friendly version of ``value``, recursing into lists and dicts."""
    # None maps to JSON null natively; turning it into the string 'None'
    # would lose that on a round trip.
    if value is None or isinstance(value, (bool, int, float, str)):
        return value
    if isinstance(value, list):
        return [_to_serializable(item) for item in value]
    if isinstance(value, dict):
        return {key: _to_serializable(item) for key, item in value.items()}
    # Everything else (torch.dtype, devices, paths, tuples, ...) is stored
    # as its string form.
    return str(value)


def serialize_config(config):
    """Return a copy of ``config`` whose values can be written as JSON.

    Rules applied to each top-level entry:

    * the ``'device'`` entry is always stored as ``str(value)``;
    * ``None``, ``bool``, ``int``, ``float`` and ``str`` values are kept;
    * lists and dicts are rebuilt with the same rules applied to their
      contents, so nested non-serializable values are converted too;
    * any other value is stored as ``str(value)``.

    Args:
        config: Mapping of configuration names to values.

    Returns:
        dict: New dictionary with the same keys; ``config`` is not modified.
    """
    return {
        key: str(value) if key == 'device' else _to_serializable(value)
        for key, value in config.items()
    }


class MedVLSMDataset(Dataset):
    """Image / mask / prompt dataset stored under ``<data_root>/<dataset_name>``.

    Files read by this class:

    * ``anns/<split>.json`` — annotations, indexed by integer position;
      each entry provides ``'img_name'`` and ``'mask_name'``.
    * ``images/<img_name>`` — loaded as RGB.
    * ``masks/<mask_name>`` — loaded as 8-bit greyscale.
    """

    def __init__(self, data_root, dataset_name, split, prompt_strategy, config):
        """Load the annotations of one split.

        Args:
            data_root: Directory containing the dataset folders.
            dataset_name: Name of the dataset folder inside ``data_root``.
            split: Split name; selects ``anns/<split>.json``.
            prompt_strategy: Passed unchanged to ``PromptSelector``.
            config: Mapping that provides ``'image_size'``.
        """
        self.data_root = Path(data_root)
        self.dataset_path = self.data_root / dataset_name
        self.split = split
        self.prompt_selector = PromptSelector(prompt_strategy)
        self.config = config
        # JSON text is UTF-8; do not depend on the platform default encoding.
        annotation_file = self.dataset_path / 'anns' / f'{split}.json'
        with open(annotation_file, 'r', encoding='utf-8') as f:
            self.annotations = json.load(f)

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load, resize and binarize the sample at position ``idx``.

        Returns:
            dict with keys:
                ``'image'``: float tensor (3, S, S) with values in [0, 1].
                ``'mask'``: long tensor (S, S) with values 0 or 1.
                ``'prompt'``: result of ``PromptSelector.select_prompt``.
                ``'image_path'``: image path as a string.
                ``'img_name'``: the annotation's ``'img_name'``.
            where ``S`` is ``config['image_size']``.
        """
        ann = self.annotations[idx]
        image_path = self.dataset_path / 'images' / ann['img_name']
        mask_path = self.dataset_path / 'masks' / ann['mask_name']

        # convert() returns a fully loaded image, so the file handles can be
        # released as soon as the with-blocks end.
        with Image.open(image_path) as raw_img:
            img = raw_img.convert('RGB')
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert('L')
        prompt = self.prompt_selector.select_prompt(ann)

        # HWC uint8 -> CHW float scaled to [0, 1].
        img_t = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
        # The mask keeps its raw 0-255 intensities.
        mask_t = torch.from_numpy(np.array(mask)).float()

        sz = self.config['image_size']
        # align_corners=False is what bilinear mode uses by default; it is
        # spelled out so the behaviour does not depend on that default.
        img_t = F.interpolate(
            img_t.unsqueeze(0), size=(sz, sz), mode='bilinear', align_corners=False
        )[0]
        # Nearest-neighbour resizing avoids inventing intermediate mask values.
        mask_t = F.interpolate(
            mask_t.unsqueeze(0).unsqueeze(0), size=(sz, sz), mode='nearest'
        )[0, 0]

        return {
            'image': torch.clamp(img_t, 0.0, 1.0),
            # On the 0-255 scale this marks every non-zero pixel as foreground.
            'mask': (mask_t > 0.5).long(),
            'prompt': prompt,
            'image_path': str(image_path),
            'img_name': ann['img_name']
        }