import lightning as L
import torch
import os
from astropy.table import Table
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, RandomHorizontalFlip, RandomVerticalFlip

# Public API for ``from <module> import *``: only the Lightning data module is
# exported; PROVABGSDataset is not listed here.
__all__ = ["PROVABGSDatasetModule"]


class PROVABGSDataset(Dataset):
    """Map-style dataset yielding ``(inputs, target)`` pairs from a column table.

    ``data`` is accessed as ``data[column][row]`` and ``len(data)`` is taken as
    the number of rows.

    Each item consists of:
      * ``inputs`` -- a dict keyed by ``input_fields``. Values are float32
        tensors, and every field except ``"image"`` is flattened to 1-D.
        ``transforms`` (if given) is applied to the ``"image"`` entry only.
      * ``target`` -- the ``output_fields`` values, each standardized as
        ``(value - mean[field]) / std[field]``, stacked along the last dim
        in ``output_fields`` order.

    Args:
        data: Table-like object indexable by column name, then by row.
        input_fields: Column names returned in the ``inputs`` dict.
        output_fields: Column names forming the regression target.
        mean: Mapping from output field to the value subtracted.
        std: Mapping from output field to the value divided by.
        limit_train_size: Optional cap on the reported dataset length.
        transforms: Optional callable applied to the ``"image"`` input.
    """

    def __init__(self, data, input_fields, output_fields, mean, std, limit_train_size=None, transforms=None):
        # Raw table and field selection.
        self.data = data
        self.input_fields = input_fields
        self.output_fields = output_fields
        # Per-field normalization statistics for the target.
        self.mean = mean
        self.std = std
        # Optional length cap and image augmentation.
        self.limit_train_size = limit_train_size
        self.transforms = transforms

    def __len__(self):
        n_rows = len(self.data)
        if self.limit_train_size is None:
            return n_rows
        return min(n_rows, self.limit_train_size)

    def __getitem__(self, idx):
        inputs = {}
        for field in self.input_fields:
            values = torch.tensor(self.data[field][idx].astype("float32"))
            if field == "image":
                # Images keep their spatial layout for the transforms below.
                inputs[field] = values
            else:
                inputs[field] = values.flatten()

        if self.transforms is not None and "image" in inputs:
            inputs["image"] = self.transforms(inputs["image"])

        normalized = []
        for field in self.output_fields:
            raw = self.data[field][idx].astype('float32')
            normalized.append(torch.tensor((raw - self.mean[field]) / self.std[field]))
        target = torch.stack(normalized, dim=-1)
        return inputs, target


class PROVABGSDatasetModule(L.LightningDataModule):
    """Lightning data module for the PROVABGS property-regression task.

    This module assumes the data has been prepared for the PROVABGS task as
    described in the ``scripts/data_provabgs_xmatch.py`` script. ``setup``
    reads two FITS tables from ``data_dir``::

        provabgs_{survey}_train_v{version}_w_image.fits
        provabgs_{survey}_eval_v{version}_w_image.fits

    Output fields are standardized using the mean/std of the (optionally
    truncated) training split. The same statistics are applied to the
    validation split.

    Args:
        data_dir: Directory containing the FITS files.
        survey: Survey tag used in the file names.
        version: Version tag used in the file names.
        batch_size: Batch size for both dataloaders.
        input_fields: Table columns returned as model inputs. ``None``
            selects ``["tok_image"]``. If ``"image"`` is among them, random
            horizontal/vertical flips are applied to it during training.
        output_fields: Table columns stacked into the regression target.
            ``None`` selects
            ``["Z_HP", "LOG_MSTAR", "TAGE_MW", "LOG_Z_MW", "sSFR"]``.
        num_workers: Number of ``DataLoader`` worker processes.
        limit_train_size: If given, only the first ``limit_train_size``
            training rows are used, both for training and for computing the
            normalization statistics.
    """

    # Immutable defaults. They are copied into a fresh list per instance so
    # that mutating one instance's field list cannot leak into others (the
    # previous list-literal defaults were shared across all instances).
    DEFAULT_INPUT_FIELDS = ("tok_image",)
    DEFAULT_OUTPUT_FIELDS = ("Z_HP", "LOG_MSTAR", "TAGE_MW", "LOG_Z_MW", "sSFR")

    def __init__(self,
                 data_dir: str = "data",
                 survey: str = "legacysurvey",
                 version: str = "1",
                 batch_size: int = 256,
                 input_fields=None,
                 output_fields=None,
                 num_workers: int = 10,
                 limit_train_size=None,
                 ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.survey = survey
        self.version = version
        self.input_fields = list(self.DEFAULT_INPUT_FIELDS if input_fields is None else input_fields)
        self.output_fields = list(self.DEFAULT_OUTPUT_FIELDS if output_fields is None else output_fields)
        self.num_workers = num_workers
        self.limit_train_size = limit_train_size
        # Flip augmentation is only meaningful for raw image inputs.
        self.transforms = Compose([RandomHorizontalFlip(), RandomVerticalFlip()]) if "image" in self.input_fields else None

    def setup(self, stage=None):
        """Load both splits, drop failed fits, and build the datasets.

        Args:
            stage: Lightning stage name. It is unused; both splits are always
                loaded.
        """
        train_file = os.path.join(self.data_dir, f'provabgs_{self.survey}_train_v{self.version}_w_image.fits')
        val_file = os.path.join(self.data_dir, f'provabgs_{self.survey}_eval_v{self.version}_w_image.fits')

        self.train_data = Table.read(train_file)
        self.val_data = Table.read(val_file)

        # Removing rows with failed provabgs fits. This shouldn't do anything,
        # but better safe than sorry.
        self.train_data = self.train_data[self.train_data['LOG_MSTAR'] > 0]
        self.val_data = self.val_data[self.val_data['LOG_MSTAR'] > 0]

        # Normalization statistics come from exactly the rows the training
        # dataset will expose, so honour limit_train_size here too.
        total_len = len(self.train_data)
        if self.limit_train_size is not None:
            total_len = min(total_len, self.limit_train_size)

        # Slice before casting so rows outside the limit are never converted.
        # NOTE(review): a constant column gives std == 0, which makes the
        # standardized targets inf/nan; consider guarding if that can occur.
        self.mean = {k: self.train_data[k][:total_len].astype('float32').mean() for k in self.output_fields}
        self.std = {k: self.train_data[k][:total_len].astype('float32').std() for k in self.output_fields}

        self.train_dataset = PROVABGSDataset(self.train_data, self.input_fields, self.output_fields, self.mean, self.std, self.limit_train_size, transforms=self.transforms)
        # Validation gets no length cap and no augmentation.
        self.val_dataset = PROVABGSDataset(self.val_data, self.input_fields, self.output_fields, self.mean, self.std)

    def train_dataloader(self):
        """Return a shuffled training loader that keeps the final partial batch."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=False, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return a deterministic validation loader over every validation sample.

        ``drop_last`` is False so the final partial batch is evaluated too.
        Dropping it would silently exclude up to ``batch_size - 1`` samples
        from validation metrics. It would also produce zero batches whenever
        the validation split is smaller than ``batch_size``.
        """
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, drop_last=False, num_workers=self.num_workers)
