import lightning as L
import torch
import torchvision.transforms as T
import os
import numpy as np
from astropy.table import Table
from torch.utils.data import DataLoader, Dataset
from collections import Counter

__all__ = ["GZ10DatasetModule"]


class GZ10Dataset(Dataset):
    """GZ10 dataset wrapper yielding ``(inputs, label)`` pairs.

    Two input representations are supported:

    * ``pretokenized=True``: each ``tok_image`` entry is cast to float32 and
      flattened into a 1-D tensor.
    * ``pretokenized=False``: channels ``[5, 6, 7, 8]`` of the raw image are
      selected, resized to 224x224 (with random flips/translation/rotation in
      ``mode="train"``) and normalized per channel.

    Args:
        data: Row-indexable table (e.g. an astropy ``Table``) whose
            ``tok_image`` and ``tok_label`` columns are indexed per row.
        input_fields: Must be ``["tok_image"]``.
        output_fields: Must be ``["tok_label"]``.
        pretokenized: ``True`` if ``tok_image`` holds tokens, ``False`` if it
            holds raw images.
        mode: ``"train"`` enables random augmentations; any other value uses
            the deterministic resize + normalize transform.
        channel_stats: Optional ``(means, stds)`` pair (one value per selected
            channel) used for normalization instead of statistics computed
            from ``data``. Useful to normalize a validation split with the
            training statistics. Ignored when ``pretokenized`` is ``True``.

    Raises:
        NotImplementedError: If unsupported input/output fields are requested.
        ValueError: If ``channel_stats`` has the wrong length, or statistics
            must be computed from an empty ``data``.
    """

    def __init__(
        self,
        data,
        input_fields,
        output_fields,
        pretokenized=True,
        mode="train",
        channel_stats=None,
    ):
        if input_fields != ["tok_image"] or output_fields != ["tok_label"]:
            raise NotImplementedError(
                "Only 'tok_image' is supported for input fields and 'tok_label' is supported for output fields"
            )
        self.data = data
        self.input_fields = input_fields
        self.output_fields = output_fields
        self.pretokenized = pretokenized

        if not self.pretokenized:
            # Raw-image mode only uses these four channels of each image.
            self.channel_indices = [5, 6, 7, 8]
            if channel_stats is None:
                self.channel_means, self.channel_stds = self._compute_channel_stats(
                    self.channel_indices
                )
            else:
                self.channel_means, self.channel_stds = self._as_stat_tensors(
                    channel_stats, len(self.channel_indices)
                )
            self.transform = self._build_transform(mode)

    def _build_transform(self, mode):
        """Return the resize (+ augment) + normalize pipeline for ``mode``."""
        resize = T.Resize((224, 224), interpolation=T.InterpolationMode.BILINEAR)
        normalize = T.Normalize(
            mean=self.channel_means.tolist(), std=self.channel_stds.tolist()
        )
        if mode == "train":
            print("Using train transform with random augmentations")
            return T.Compose([
                resize,
                T.RandomHorizontalFlip(),
                T.RandomVerticalFlip(),
                T.RandomAffine(0, translate=(0.2, 0.2)),
                T.RandomRotation(180),
                normalize,
            ])
        print("Using val transform without augmentations")
        return T.Compose([resize, normalize])

    @staticmethod
    def _as_stat_tensors(channel_stats, num_channels):
        """Convert a caller-supplied ``(means, stds)`` pair to float32 tensors.

        Raises:
            ValueError: If either sequence does not have ``num_channels`` values.
        """
        means, stds = channel_stats
        means = torch.as_tensor(means, dtype=torch.float32).flatten()
        stds = torch.as_tensor(stds, dtype=torch.float32).flatten()
        if means.numel() != num_channels or stds.numel() != num_channels:
            raise ValueError(
                f"channel_stats must provide {num_channels} means and stds, "
                f"got {means.numel()} and {stds.numel()}"
            )
        return means, stds

    def _compute_channel_stats(self, channel_indices):
        """Compute per-channel mean and std of ``tok_image`` over ``self.data``.

        Statistics are taken over all rows, heights and widths of the selected
        channels, at the stored image resolution (i.e. before the 224x224
        resize; the original author notes this is 96x96). Population std
        (``ddof=0``) is used and accumulation is done in float64.

        Returns:
            ``(means, stds)`` as float32 tensors of length ``len(channel_indices)``.

        Raises:
            ValueError: If ``self.data`` has no rows.
        """
        column = self.data["tok_image"]
        num_rows = len(self.data)
        if num_rows == 0:
            raise ValueError(
                "Cannot compute channel statistics from an empty dataset; "
                "pass channel_stats explicitly"
            )
        # Select the channels per image so unused channels are never stacked.
        images = np.stack(
            [np.asarray(column[row])[channel_indices] for row in range(num_rows)]
        )
        channel_means = images.mean(axis=(0, 2, 3), dtype=np.float64)
        channel_stds = images.std(axis=(0, 2, 3), dtype=np.float64)
        return (
            torch.tensor(channel_means, dtype=torch.float32),
            torch.tensor(channel_stds, dtype=torch.float32),
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(inputs, label)`` for row ``idx``.

        ``inputs`` is a dict keyed by input field: a flattened float32 tensor
        in pretokenized mode, else the transformed selected-channel image.
        ``label`` is the int64 output(s), stacked and squeezed (a 0-d tensor
        for the single ``tok_label`` field holding a scalar).
        """
        if self.pretokenized:
            inputs = {
                key: torch.tensor(self.data[key][idx].astype("float32")).flatten()
                for key in self.input_fields
            }
        else:
            image = self.data["tok_image"][idx][self.channel_indices, ...].astype("float32")
            inputs = {"tok_image": self.transform(torch.tensor(image, dtype=torch.float32))}

        output = [torch.tensor(self.data[key][idx].astype("int64")) for key in self.output_fields]
        output = torch.stack(output, dim=0).squeeze()
        return inputs, output


class GZ10DatasetModule(L.LightningDataModule):
    """Lightning data module for the GZ10 Legacy Survey dataset.

    Loads ``{data_dir}/gz10_legacysurvey_v{version}.fits`` with astropy and
    splits its rows into training and validation sets. Version ``"3"`` is the
    raw-image version: its rows are shuffled with a seeded generator before
    splitting. All other versions are pretokenized and split in file order.

    Two ratios define the split:

    * ``split_ratio``: the fraction of (ordered) rows, from the front, that
      is used for training.
    * ``MAX_SPLIT_RATIO`` (0.9): the validation set is always the rows from
      this fraction onward. Lowering ``split_ratio`` therefore shrinks the
      training set without changing the validation set.

    Args:
        data_dir: Directory containing the FITS file.
        version: Dataset version string used in the file name.
        batch_size: Batch size for both dataloaders.
        input_fields: Input columns; defaults to ``["tok_image"]``.
        output_fields: Output columns; must be ``["tok_label"]`` (the default).
        num_workers: Worker processes per dataloader.
        split_ratio: Training fraction, in ``(0, MAX_SPLIT_RATIO]``.
        seed: Seed for the raw-image shuffle. It is fixed so that every
            process and every ``setup`` call draw the same split.

    Raises:
        ValueError: If ``output_fields`` is not ``["tok_label"]``, or
            ``split_ratio`` is outside ``(0, MAX_SPLIT_RATIO]``.
    """

    def __init__(
        self,
        data_dir: str = "data",
        version: str = "1",
        batch_size: int = 256,
        input_fields=None,
        output_fields=None,
        num_workers: int = 10,
        split_ratio: float = 0.9,
        seed: int = 42,
    ):
        super().__init__()
        # None defaults avoid sharing one mutable list across all instances.
        if input_fields is None:
            input_fields = ["tok_image"]
        if output_fields is None:
            output_fields = ["tok_label"]
        if output_fields != ["tok_label"]:
            raise ValueError("Only 'tok_label' is supported for output fields")

        # The validation set always starts at MAX_SPLIT_RATIO. A larger
        # split_ratio would make training and validation rows overlap.
        self.MAX_SPLIT_RATIO = 0.9
        if not 0.0 < split_ratio <= self.MAX_SPLIT_RATIO:
            raise ValueError(
                f"split_ratio must be in (0, {self.MAX_SPLIT_RATIO}], got {split_ratio}; "
                "larger values would overlap the training and validation sets"
            )

        self.data_dir = data_dir
        self.batch_size = batch_size
        self.version = version
        self.input_fields = input_fields
        self.output_fields = output_fields
        self.num_workers = num_workers
        self.split_ratio = split_ratio
        self.seed = seed
        # Version 3 is the raw image version
        self.pretokenized = version != "3"

    def setup(self, stage=None):
        """Load the FITS table and build the train and validation datasets.

        Args:
            stage: Lightning stage name. It is unused; both datasets are
                always built.
        """
        dset_file = os.path.join(
            self.data_dir, f"gz10_legacysurvey_v{self.version}.fits"
        )

        data = Table.read(dset_file)
        print(f"Loaded dataset from {dset_file} with {len(data)} rows")

        num_rows = len(data)
        if self.pretokenized:
            # The pretokenized files are already split into train and
            # validation portions, so rows are taken in file order.
            row_order = np.arange(num_rows)
        else:
            # Seeded: setup() runs in every process (e.g. each DDP rank). An
            # unseeded shuffle would give each process a different split and
            # leak one rank's validation rows into another rank's training set.
            row_order = np.random.default_rng(self.seed).permutation(num_rows)
        train_idx = row_order[: int(num_rows * self.split_ratio)]
        val_idx = row_order[int(num_rows * self.MAX_SPLIT_RATIO) :]
        self.train_data = data[train_idx]
        self.val_data = data[val_idx]

        self.train_dataset = GZ10Dataset(
            self.train_data, self.input_fields, self.output_fields, self.pretokenized, mode="train"
        )
        # NOTE(review): in raw-image mode the validation dataset derives its
        # normalization statistics from the validation rows rather than reusing
        # the training statistics. Consider sharing them so both splits are
        # normalized identically.
        self.val_dataset = GZ10Dataset(
            self.val_data, self.input_fields, self.output_fields, self.pretokenized, mode="val"
        )

    def train_dataloader(self):
        """Shuffled training loader. Drops the last incomplete batch."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Deterministic validation loader over every validation sample."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            # Keep the final partial batch so no validation sample is skipped.
            drop_last=False,
            num_workers=self.num_workers,
        )
