import os
import pickle
from typing import Optional

import lightning as L
import torch
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader, Dataset

from aion_eval.utils import flatten_dict, index_collated

# Only the Lightning data module is exported; DESIDDPayneDataset is an
# internal helper constructed by it.
__all__ = ["DESIDDPayneDatasetModule"]


class DESIDDPayneDataset(Dataset):
    """Map-style dataset yielding ``(inputs, target)`` pairs.

    Args:
        data: Mapping from field name to per-sample values. Input fields may
            be nested dicts (their length is read from their first entry);
            output fields must support ``[idx].astype("float32")``
            (presumably numpy arrays — confirm against the data script).
        input_fields: Keys of ``data`` returned, converted to tensors and
            flattened, as model inputs.
        output_fields: Keys of ``data`` whose values are standardized as
            ``(x - mean[key]) / std[key]`` and stacked along the last axis.
        mean: Per-output-field value subtracted from the raw target.
        std: Per-output-field value the centered target is divided by.
        limit_train_size: Optional cap on the length reported by ``len()``.
    """

    def __init__(
        self, data, input_fields, output_fields, mean, std, limit_train_size=None
    ):
        self.data = data
        self.input_fields = input_fields
        self.output_fields = output_fields
        self.mean = mean
        self.std = std
        self.limit_train_size = limit_train_size

    def __len__(self):
        first_field = self.data[self.input_fields[0]]
        # A nested (dict) input field takes its length from its first entry.
        n_rows = (
            len(first_field[list(first_field)[0]])
            if isinstance(first_field, dict)
            else len(first_field)
        )
        cap = self.limit_train_size
        return n_rows if cap is None else min(n_rows, cap)

    def __getitem__(self, idx):
        # Select sample `idx` from every input field, tensorize each leaf,
        # then flatten the nested structure into a single-level mapping.
        features = flatten_dict(
            tree_map(
                torch.tensor,
                {
                    field: index_collated(self.data[field], idx)
                    for field in self.input_fields
                },
            )
        )

        targets = []
        for field in self.output_fields:
            raw_value = self.data[field][idx].astype("float32")
            targets.append(
                torch.tensor((raw_value - self.mean[field]) / self.std[field])
            )
        return features, torch.stack(targets, dim=-1)


class DESIDDPayneDatasetModule(L.LightningDataModule):
    """Lightning data module for the DESI DD-Payne regression task.

    Loads the pickled train/eval splits prepared by the
    ``scripts/data_desiddpayne_xmatch.py`` script and wraps them in
    ``DESIDDPayneDataset``. Regression targets (``output_fields``) are
    standardized with the mean/std of the training split; the validation
    split reuses those training statistics. When ``limit_train_size`` is set,
    the statistics are computed on the first ``limit_train_size`` training
    rows only, i.e. the same rows the training dataset exposes.

    Files read from ``data_dir``::

        desiddpayne_{full_survey_name}_train_v{version}.pkl
        desiddpayne_{full_survey_name}_eval_v{version}.pkl

    where ``full_survey_name`` is ``survey`` or ``"{survey}-{survey_dataset_name}"``.
    """

    def __init__(
        self,
        data_dir: str = "data",
        survey: str = "gaia",
        survey_dataset_name: Optional[str] = "parallax_sample",
        version: str = "1",
        batch_size: int = 256,
        # List defaults are kept for signature compatibility; they are copied
        # below so instances never share (or mutate) the default lists.
        input_fields=[
            "tok_xp_bp",
            "tok_xp_rp",
            "tok_flux_g_gaia",
            "tok_flux_bp_gaia",
            "tok_flux_rp_gaia",
            "tok_parallax",
            "tok_ra",
            "tok_dec",
        ],
        output_fields=[
            "TEFF",
            "LOGG",
            "FEH",
            "C_FE",
            "N_FE",
            "MG_FE",
            "O_FE",
            "AL_FE",
            "SI_FE",
            "CA_FE",
            "TI_FE",
            "CR_FE",
            "MN_FE",
            "NI_FE",
            "RV",
            "VMIC",
        ],
        num_workers: int = 10,
        limit_train_size: Optional[int] = None,
    ):
        """
        Args:
            data_dir: Directory containing the pickled splits.
            survey: Survey name used in the file names.
            survey_dataset_name: Optional sub-dataset name appended to
                ``survey`` with a ``-``; ``None`` omits it.
            version: Dataset version used in the file names.
            batch_size: Batch size for both dataloaders.
            input_fields: Keys of the loaded data used as model inputs.
            output_fields: Keys of the loaded data used as standardized targets.
            num_workers: DataLoader worker count for both dataloaders.
            limit_train_size: Optional cap on the number of training rows
                used (for the dataset length and the normalization statistics).
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.survey = survey
        self.survey_dataset_name = survey_dataset_name
        self.version = version
        self.input_fields = list(input_fields)
        self.output_fields = list(output_fields)
        self.num_workers = num_workers
        self.full_survey_name = (
            survey
            if survey_dataset_name is None
            else f"{survey}-{survey_dataset_name}"
        )
        self.limit_train_size = limit_train_size

    def _load_split(self, split: str):
        """Unpickle ``desiddpayne_{full_survey_name}_{split}_v{version}.pkl``."""
        path = os.path.join(
            self.data_dir,
            f"desiddpayne_{self.full_survey_name}_{split}_v{self.version}.pkl",
        )
        # NOTE: pickle can execute arbitrary code while loading; only point
        # data_dir at trusted, locally produced files.
        with open(path, "rb") as f:
            return pickle.load(f)

    def setup(self, stage: Optional[str] = None):
        """Load both splits, compute target normalization, build the datasets.

        Args:
            stage: Lightning stage name; unused — both splits are always loaded.
        """
        self.train_data = self._load_split("train")
        self.val_data = self._load_split("eval")

        # Normalize with statistics of exactly the rows the training dataset
        # serves: the first `limit_train_size` rows, or every row when the
        # limit is None (a `[:None]` slice keeps everything).
        n_stats = self.limit_train_size
        self.mean = {}
        self.std = {}
        for key in self.output_fields:
            values = self.train_data[key][:n_stats].astype("float32")
            self.mean[key] = values.mean()
            self.std[key] = values.std()

        self.train_dataset = DESIDDPayneDataset(
            self.train_data,
            self.input_fields,
            self.output_fields,
            self.mean,
            self.std,
            self.limit_train_size,
        )
        # Validation uses the training statistics and is never truncated.
        self.val_dataset = DESIDDPayneDataset(
            self.val_data, self.input_fields, self.output_fields, self.mean, self.std
        )

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training dataset."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Return an unshuffled DataLoader over the validation dataset."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=False,
            num_workers=self.num_workers,
        )
