

import torch
from contextlib import nullcontext
from torch.utils.data import DataLoader, TensorDataset, Dataset

class TransformedDataset(Dataset):
    """Wrap a dataset (or subset) and apply a transform to each sample's first element.

    Samples of the wrapped dataset are expected to be indexable sequences
    (tuple or list) whose element 0 is the value to transform; all remaining
    elements are passed through unchanged.
    """

    def __init__(self, dataset, transform=None):
        """Store the wrapped dataset and the optional transform.

        Args:
            dataset: Any object supporting ``dataset[index]`` and ``len(dataset)``.
            transform: Optional callable applied to element 0 of every sample.
                When ``None`` samples are returned untouched.
        """
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        """Return the sample at ``index``, with element 0 transformed if a transform is set.

        With a transform the result is always a tuple; without one the
        wrapped dataset's sample is returned as-is.
        """
        row = self.dataset[index]

        # Explicit None check: whether the transform runs must not depend on
        # the truthiness of the transform object itself.
        if self.transform is None:
            return row

        transformed_img = self.transform(row[0])
        # Unpacking (rather than tuple + row[1:]) also supports list samples,
        # where tuple/list concatenation would raise TypeError.
        return (transformed_img, *row[1:])

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

def collate_train(batch):
    """Collate training samples into ``(X, t, e)`` batched tensors.

    ``batch`` is either a single sample dict or an iterable of sample dicts,
    each holding tensors under the keys ``'image'``, ``'time'`` and
    ``'event'``. A single dict gets a new leading dimension of size 1; a
    collection of dicts is stacked along a new leading dimension.
    """
    fields = ('image', 'time', 'event')

    if isinstance(batch, dict):
        # A lone sample: add the batch dimension instead of stacking.
        X, t, e = (batch[name].unsqueeze(0) for name in fields)
    else:
        X, t, e = (
            torch.stack([sample[name] for sample in batch], dim=0)
            for name in fields
        )

    return X, t, e

def collate_test(batch):
    """Collate inference samples into a single batched image tensor.

    ``batch`` is either one sample dict or an iterable of sample dicts, each
    holding a tensor under the key ``'image'``. Only the images are returned.
    """
    if not isinstance(batch, dict):
        # Several samples: stack their images along a new leading dimension.
        images = [sample['image'] for sample in batch]
        return torch.stack(images, dim=0)

    # A lone sample: just add the batch dimension.
    return batch['image'].unsqueeze(0)

class Base:
    """Shared scaffolding for models trained and queried through ``DataLoader``s.

    The class stores the training configuration, wraps datasets with the
    configured transforms, builds the data loaders, and delegates the actual
    work to two hooks that subclasses must override: ``_fit`` and
    ``_survival_probability_at_times``.
    """

    def __init__(self,
                 net,
                 opt,
                 sch=None,
                 mixup=None,
                 discretizer=None,
                 train_transform=None,
                 test_transform=None,
                 epochs=100,
                 batch_size=128,
                 device=None):
        """Store the configuration and prepare the autocast context.

        Args:
            net: Stored as ``self.net`` (presumably the model; not used here).
            opt: Stored as ``self.opt`` (presumably the optimizer; not used here).
            sch: Stored as ``self.sch`` (presumably an LR scheduler; not used here).
            mixup: Stored as ``self.mixup``; not used by this class.
            discretizer: Stored as ``self.discretizer``; not used by this class.
            train_transform: Optional callable applied to element 0 of every
                training sample in ``fit``.
            test_transform: Optional callable applied to element 0 of every
                validation / inference sample.
            epochs: Stored as ``self.epochs``; not used by this class.
            batch_size: Batch size of every ``DataLoader`` built by this class.
            device: Target device. When ``None``, ``'cuda'`` is chosen if
                ``torch.cuda.is_available()`` and ``'cpu'`` otherwise.
        """
        self.net = net
        self.opt = opt
        self.sch = sch
        self.mixup = mixup
        self.discretizer = discretizer
        self.train_transform = train_transform
        self.test_transform = test_transform
        self.epochs = epochs
        self.batch_size = batch_size
        self.device = ('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
        self.non_blocking = True

        # Mixed precision is enabled only on CUDA devices with torch >= 2.0.0.
        self.use_amp = self._on_cuda() and (torch.__version__ >= '2.0.0')
        if self.use_amp:
            self.amp_ctx = torch.amp.autocast(device_type='cuda', dtype=torch.float16)
        else:
            self.amp_ctx = nullcontext()

    def _on_cuda(self):
        """Return True when ``self.device`` names a CUDA device."""
        # Compare only the device type so indexed devices ('cuda:0') and
        # torch.device objects are recognised, not just the literal 'cuda'.
        return str(self.device).split(':', 1)[0] == 'cuda'

    def fit(self, train_dataset, valid_dataset=None):
        """Build the data loaders and run the subclass training hook.

        Args:
            train_dataset: Dataset to train on; wrapped with
                ``train_transform`` when one is configured.
            valid_dataset: Optional validation dataset; wrapped with
                ``test_transform`` when one is configured.

        Returns:
            Whatever ``_fit(train_loader, valid_loader)`` returns;
            ``valid_loader`` is ``None`` when no validation dataset is given.
        """
        if self.train_transform is not None:
            train_dataset = TransformedDataset(train_dataset, transform=self.train_transform)

        # Only wrap a real dataset: wrapping None would make it look like a
        # validation set exists and produce a loader over None.
        if valid_dataset is not None and self.test_transform is not None:
            valid_dataset = TransformedDataset(valid_dataset, transform=self.test_transform)

        on_cuda = self._on_cuda()
        # Worker processes and pinned memory are used only on CUDA.
        num_workers = 4 if on_cuda else 0

        train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=on_cuda,
        )

        valid_loader = None
        if valid_dataset is not None:
            valid_loader = DataLoader(
                valid_dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=on_cuda,
            )

        return self._fit(train_loader, valid_loader)

    def _fit(self, train_loader, val_loader=None):
        """Training hook; must be overridden by subclasses."""
        raise NotImplementedError("Subclasses should implement this method.")

    def survival_probability_at_times(self, X, times=None):
        """Build an inference loader over ``X`` and run the subclass prediction hook.

        Args:
            X: A ``Dataset``, or anything accepted by ``TensorDataset`` (it is
                wrapped in one when it is not already a ``Dataset``).
            times: Passed through unchanged to the hook.

        Returns:
            Whatever ``_survival_probability_at_times(dataloader, times)`` returns.
        """
        if not isinstance(X, Dataset):
            X = TensorDataset(X)

        X = TransformedDataset(X, transform=self.test_transform)
        dataloader = DataLoader(
            X,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            pin_memory=self._on_cuda(),
        )

        return self._survival_probability_at_times(dataloader, times)

    def _survival_probability_at_times(self, dataloader, times):
        """Prediction hook; must be overridden by subclasses."""
        raise NotImplementedError("Subclasses should implement this method.")
    