"""
conditional trellis
* Gives x0 control AND sample of x0 culture ()


"""
import time
import os
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import pickle
import yaml as yml
import numpy as np
import random
import torch
from sklearn.decomposition import PCA
import warnings


class trellis_dataset(Dataset):
    """Dataset of paired control (x0) / treated (x1) cell populations.

    One sample is built per (experiment, treatment, culture): x0 holds the
    control cells taken at concentration ``"0"`` and x1 holds the cells of the
    treatment at its highest concentration, both from the same culture.

    Args:
        split: list of experiment dicts nested as
            ``{treatment: {concentration: {culture: {cell_type: indices}}}}``.
            ``indices`` must expose ``.tolist()`` and index rows of ``data``.
            Experiments containing an empty dict at any depth are dropped.
        data: array indexed with lists of row indices (``data[idx_list]``).
        encoded_conditional: stored as ``self.encode_conditional``; not used
            by this class.
        num_prior: stored only (number of background cells chosen).
        cells_per_sample: stored only.
        num_conditions: stored only.
        num_components: number of PCA components, or ``None`` to skip PCA.
        use_small_exp_num: if true, keep a random subset of at most 6 (train)
            or 1 (val/test) samples.
        control: treatment names that identify the control (x0) condition.
        treatment: ordered treatment names; the position of a name is its
            one-hot index.
        culture: stored as ``self.culture`` and then overwritten by
            ``construct_data`` with the list of per-sample culture labels.
        cell_type: ``[pdo_key, fib_key]``; entry 0 is read for cultures
            ``"PDO"``/``"PDOF"`` and entry 1 for cultures ``"PDOF"``/``"F"``.
        prefix: one of ``"train"``, ``"val"``, ``"test"``.
        pca: already fitted PCA used for val/test; the train split fits its own.
        seed: seed of the ``numpy`` generator used for subsampling.

    Raises:
        ValueError: if ``prefix`` is not recognized, or if PCA embedding is
            requested for a split without a fitted PCA.
    """

    # Number of samples kept per split when ``use_small_exp_num`` is set.
    _SMALL_EXP_SIZES = {"train": 6, "val": 1, "test": 1}

    def __init__(
        self,
        split,
        data,
        encoded_conditional,
        num_prior=100,
        cells_per_sample=1000,
        num_conditions=1430,
        num_components=5,
        use_small_exp_num=False,
        control=set(["DMSO", "AH", "H2O"]),
        treatment=['O', 'S', 'VS', 'L', 'V', 'F', 'C', 'SF', 'CS', 'CF', 'CSF'],
        culture=["PDO", "PDOF", "F"],
        cell_type=["PDOs", "Fibs"],  # only PDO, can do 'PDOs' and/or 'Fibs'
        prefix="train",
        pca=None,
        seed=0,
    ):
        self.rng = np.random.default_rng(seed)

        self.data = data

        self.control = control  # identifies x0
        self.treatment = treatment
        self.culture = culture
        self.cell_type = cell_type

        self.num_components = num_components
        self.encode_conditional = encoded_conditional
        # number of background cells chosen (include PDO, Fib etc)
        self.num_prior = num_prior
        self.cells_per_sample = cells_per_sample
        self.num_conditions = num_conditions
        self.use_small_exp_num = use_small_exp_num
        self.prefix = prefix

        self.split = self.__filter_control__(split)
        self.exp_idx = self.__lst_exp__(self.split)

        self.pca = pca

        print("Constructing {} data ...".format(prefix))

        start = time.time()
        self.construct_data()
        end = time.time()
        # Print the elapsed time itself; nesting a second print() here would
        # report "None".
        print("done. Time (s):", end - start)

        print("... Data loaded!")

    def construct_data(self):
        """Build ``self.samples`` for the configured split.

        Collects the raw samples, optionally embeds them with PCA (fitted on
        the fly for the train split) and optionally subsamples them.

        Raises:
            ValueError: if ``self.prefix`` is not "train", "val" or "test".
        """
        (
            self.samples_tmp,
            self.culture,
            self.x0,
            self.x1,
            self.cell_cond,
            self.treat_cond,
        ) = self.select_experiments()

        small_size = self._SMALL_EXP_SIZES.get(self.prefix)
        if small_size is None:
            raise ValueError("prefix not recognized")

        if self.num_components is None:
            self.samples = self.samples_tmp
        else:
            if self.prefix == "train":
                self._fit_pca()
            self.samples = self.pca_embed_samples(self.samples_tmp)

        if self.use_small_exp_num:
            self.samples = self._subsample(self.samples, small_size)

    def _fit_pca(self):
        """Fit ``self.pca`` on all x0 and x1 cells of this split."""
        print("Fitting PCA for low-dim representation ...")
        x_train = np.concatenate(list(self.x0) + list(self.x1), axis=0)
        print(x_train.shape)

        self.pca = PCA(n_components=self.num_components)
        self.pca.fit(x_train)

        print("... PCA fit done!")

    def _subsample(self, samples, size):
        """Return a random subset of at most ``size`` samples (no repeats)."""
        # Clamp so that small splits do not make rng.choice raise.
        size = min(size, len(samples))
        idcs = self.rng.choice(np.arange(len(samples)), size=size, replace=False)
        return [samples[i] for i in idcs]

    def select_experiments(self):
        """Assemble one raw sample per (experiment, treatment, culture).

        Returns:
            tuple: ``(samples_tmp, cultures, sources, targets, cell_conds,
            treat_conds)``. ``samples_tmp`` holds
            ``(culture, x0, x1, cond_cell, cond_treat)`` tuples; the other
            five lists are parallel to it (one entry per sample).

        Raises:
            IndexError: if an experiment has no key in ``self.control``.
            ValueError: if a treatment key is not in ``self.treatment``.
        """
        samples_tmp, cultures, sources, targets = [], [], [], []
        cell_conds, treat_conds = [], []

        for exp in self.split:
            # First control key in dict order, so the choice is deterministic.
            control_keys = [key for key in exp if key in self.control]
            x0_treatment = control_keys[0]
            treatkeys = [key for key in exp if key not in self.control]

            for t in treatkeys:
                # Keep the original key object instead of rebuilding it from
                # its integer value.
                max_conc = max(exp[t], key=int)

                for culture in list(exp[t][max_conc].keys()):
                    x0_pdos_idx, x1_pdos_idx = [], []
                    x0_fibs_idx, x1_fibs_idx = [], []
                    if culture in ["PDOF", "PDO"]:
                        pdo_key = self.cell_type[0]
                        x0_pdos_idx = exp[x0_treatment]["0"][culture][pdo_key].tolist()
                        x1_pdos_idx = exp[t][max_conc][culture][pdo_key].tolist()

                    if culture in ["PDOF", "F"]:
                        fib_key = self.cell_type[1]
                        x0_fibs_idx = exp[x0_treatment]["0"][culture][fib_key].tolist()
                        x1_fibs_idx = exp[t][max_conc][culture][fib_key].tolist()

                    # PDO cells first, fibroblasts second.
                    x0_idx = x0_pdos_idx + x0_fibs_idx
                    x1_idx = x1_pdos_idx + x1_fibs_idx

                    x0 = np.array(self.data[x0_idx])
                    x1 = np.array(self.data[x1_idx])

                    # Cell-type one-hot for the x0 population: column 0 for
                    # the leading PDO rows, column 1 for the remaining rows.
                    x0_cell_pdos_idx = range(0, len(x0_pdos_idx))
                    x0_cell_fibs_idx = range(len(x0_pdos_idx), len(x0_idx))
                    cond_cell = np.zeros((x0.shape[0], len(self.cell_type)))
                    cond_cell[x0_cell_pdos_idx, 0] = 1
                    cond_cell[x0_cell_fibs_idx, 1] = 1

                    # Treatment one-hot, repeated for every x0 row.
                    treat_idx = self.treatment.index(t)
                    cond_treat = torch.nn.functional.one_hot(
                        torch.tensor(treat_idx).long(), num_classes=len(self.treatment)
                    )
                    cond_treat = cond_treat.expand(x0.shape[0], -1).detach().numpy()

                    samples_tmp.append((culture, x0, x1, cond_cell, cond_treat))
                    cultures.append(culture)
                    # Recorded per sample so that sources stays parallel to
                    # targets and no culture's control cells are dropped.
                    sources.append(x0)
                    targets.append(x1)
                    cell_conds.append(cond_cell)
                    treat_conds.append(cond_treat)

        print("{} {} data samples".format(len(samples_tmp), self.prefix))
        return samples_tmp, cultures, sources, targets, cell_conds, treat_conds

    def pca_embed_samples(self, samples_tmp):
        """Project x0 and x1 of every raw sample with ``self.pca``.

        Returns:
            list: ``(culture, x0_pca, x1_pca, x0, x1, cell_cond, treat_cond)``
            tuples, in the order of ``samples_tmp``.

        Raises:
            ValueError: if no PCA is available.
        """
        if self.pca is None:
            raise ValueError(
                "num_components is set but no fitted PCA was provided for "
                "the '{}' split".format(self.prefix)
            )
        samples = []
        for culture, x0, x1, cell_cond, treat_cond in samples_tmp:
            x0_pca = self.pca.transform(x0)
            x1_pca = self.pca.transform(x1)
            samples.append((culture, x0_pca, x1_pca, x0, x1, cell_cond, treat_cond))
        return samples

    def __filter_control__(self, split):
        """Return the experiments of ``split`` that contain no empty dict."""
        return [ls for ls in split if not self.has_empty_element(ls)]

    def has_empty_element(self, nested_dict):
        """
        Recursively checks if there is any empty dictionary in the nested structure.

        Parameters:
            nested_dict (dict): The nested dictionary to check.

        Returns:
            bool: True if any empty dictionary is found, otherwise False.
        """
        for value in nested_dict.values():
            # Only dict values are inspected; other containers are ignored.
            if isinstance(value, dict) and (
                not value or self.has_empty_element(value)
            ):
                return True
        return False

    def __lst_exp__(self, idx):
        """Index every (experiment, treatment, concentration) of ``self.split``.

        ``idx`` is accepted for interface compatibility but not used; the
        method reads ``self.split``. Experiments without a control key, or
        whose control has no ``"0"`` concentration, are skipped.

        Returns:
            dict: 1-based running counter -> ``(experiment_number, treatment,
            concentration)``.
        """
        # for v2 using only experiments as one output (to be squeezed)
        split_idx = {}
        count = 0
        for num, ls in enumerate(self.split):
            ctrl_key = list(set(ls.keys()).intersection(self.control))
            if len(ctrl_key) == 0 or "0" not in ls[ctrl_key[0]]:
                continue
            for key, value in ls.items():
                for conc in value:
                    count += 1
                    split_idx[count] = (num, key, conc)
        return split_idx

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __get_background__(self, exp_num, x0_treatment):
        """Return a copy of the "background" entry of an experiment's control."""
        return self.split[exp_num][x0_treatment]["0"][
            "background"
        ].copy()  # already list .tolist()

    def __getitem__(self, idx):
        """Return ``(idx, culture, x0, x1, x1_full, cell_cond, treat_cond)``.

        With PCA enabled, x0 and x1 are the embedded populations and
        ``x1_full`` is the un-embedded x1. Without PCA, x0 and x1 are the raw
        populations and ``x1_full`` is a dummy ``torch.tensor(0)``.
        """
        # Samples are 7-tuples exactly when they went through PCA embedding,
        # i.e. when num_components is set.
        if self.num_components is not None:
            culture, x0, x1, _, x1_full, cell_cond, treat_cond = self.samples[idx]
        else:
            culture, x0, x1, cell_cond, treat_cond = self.samples[idx]
            x1_full = torch.tensor(0)  # dummy tensor
        return (
            idx,
            culture,
            x0,
            x1,
            x1_full,
            cell_cond,
            treat_cond,
        )


class TrellisDatamodule(pl.LightningDataModule):
    """Lightning data module that builds train/val/test ``trellis_dataset`` objects.

    All loading happens in ``__init__``: the split pickle, the marker list,
    the data array and the metadata are read, the treatment/concentration
    vocabulary is derived from the train split, and the three datasets are
    constructed (val/test reuse the PCA fitted on train when
    ``num_components`` is set).
    """

    def __init__(
        self,
        batch_size=1,
        split_source="/data_splits.pickle", 
        marker_cols="/trellis_marker_col.yaml",
        data_path="/data_normalized.npy",  
        full_metadata_data_path="/data_full",
        metadata_path="/reduced_metadata.pickle",  
        val_test_path=None, 
        eval_frac_val=1.0,
        eval_frac_test=1.0,
        control=set(["DMSO", "AH", "H2O"]),
        treatment=["O", "S", "VS", "L", "V", "F", "C", "SF", "CS", "CF", "CSF"],
        id_by_replicate=True,
        culture=["PDO", "PDOF", "F"],
        cell_type=["PDOs", "Fibs"],
        separation=[
            "Patient",
            "Culture",
            "Plate",
        ],  # values used to distinguish replicates
        min_data_pts=1000,
        treatment_test=None,
        name="trellis_replicates",
        num_prior=100,  # number of samples co-cultured with the same background
        num_components=None,
        use_small_exp_num=False,
        seed=0,
    ):
        """Load all inputs and build the train/val/test datasets.

        Args:
            batch_size: stored as ``self.batch_size``; the dataloaders below
                use a fixed batch size of 1.
            split_source: path to a pickle holding a mapping with "train",
                "val" and "test" entries (lists of experiment dicts).
            marker_cols: path to a YAML file with a "marker" entry listing
                the marker column names.
            data_path: path to a ``.npy`` array; its last column is dropped.
            full_metadata_data_path: accepted but not used in this class.
            metadata_path: path to a pickled pandas object with a
                "Treatment" column.
            val_test_path: optional pickle loaded into ``self.x_idx``; when
                ``None``, ``divide_x0_x1`` is called for "val" and "test".
            eval_frac_val: ``frac`` passed to ``divide_x0_x1("val", ...)``.
            eval_frac_test: ``frac`` passed to ``divide_x0_x1("test", ...)``.
            control, treatment, culture, cell_type: forwarded to every
                ``trellis_dataset``.
            id_by_replicate, separation, min_data_pts, treatment_test, name:
                stored only.
            num_prior, num_components, use_small_exp_num, seed: forwarded to
                every ``trellis_dataset``.
        """
        super().__init__()
        self.batch_size = batch_size
        self.split_source = split_source
        self.cell_type = cell_type  # filtered cell type
        self.separation = separation  # values used to distinguish replicates
        self.min_data_pts = min_data_pts  # values used to select how much for metrics (validation and test)
        self.treatment_test = treatment_test  # to filter for specific treatment
        self.name = name
        self.num_prior = num_prior
        self.num_components = num_components
        self.use_small_exp_num = use_small_exp_num
        self.seed = seed

        # NOTE: pickle.load executes arbitrary code from the file; only point
        # split_source / val_test_path at trusted files.
        with open(self.split_source, "rb") as handle:
            self.data_splits = pickle.load(handle)
            # from these split we will get the background cells

        # Use a context manager so the YAML file handle is closed.
        with open(marker_cols) as handle:
            self.marker_cols = list(yml.safe_load(handle)["marker"])
        self.input_dim = len(self.marker_cols)
        self.non_marker_cols = [
            "Treatment",
            "Culture",
            "Date",
            "Patient",
            "Concentration",
            "Replicate",
            "Cell_type",
            "Plate",
            "Batch",
        ]
        # The last column of the stored array is dropped.
        # NOTE(review): presumably the column that is "missing half the
        # data" mentioned in an earlier comment - confirm.
        self.data = np.load(data_path)[:, :-1]

        self.metadata = pd.read_pickle(metadata_path)  # dropped all marker columns
        self.metadata["idx"] = self.metadata.index

        self.eval_frac_val = eval_frac_val
        self.eval_frac_test = eval_frac_test
        self.control = control
        self.treatment = treatment
        self.culture = culture
        self.id_by_replicate = id_by_replicate
        self.unique_treatments = self.metadata["Treatment"].unique()

        self.__prepare_data__()
        self.x_idx = {}
        # self.divide_x0_x1('train') # not doing fixed x0, x1 for train

        if val_test_path is not None:
            with open(val_test_path, "rb") as handle:
                self.x_idx = pickle.load(handle)
        else:
            # NOTE(review): divide_x0_x1 is not defined in the visible part
            # of this class; it must be provided elsewhere - confirm.
            self.divide_x0_x1("val", frac=self.eval_frac_val)
            self.divide_x0_x1("test", frac=self.eval_frac_test)

        self.train_dataset = self._build_dataset("train")

        # val/test reuse the PCA fitted on the train split (if any).
        shared_pca = (
            self.train_dataset.pca if self.num_components is not None else None
        )
        self.val_dataset = self._build_dataset("val", pca=shared_pca)
        self.test_dataset = self._build_dataset("test", pca=shared_pca)

        print("DataModule initialized")

    def _build_dataset(self, prefix, pca=None):
        """Construct the ``trellis_dataset`` for the split named ``prefix``."""
        return trellis_dataset(
            split=self.data_splits[prefix],
            data=self.data,
            encoded_conditional=self.encode_conditional,
            num_prior=self.num_prior,
            num_components=self.num_components,
            use_small_exp_num=self.use_small_exp_num,
            control=self.control,
            treatment=self.treatment,
            culture=self.culture,
            cell_type=self.cell_type,
            prefix=prefix,
            pca=pca,
            seed=self.seed,
        )

    def __prepare_data__(self):
        """
        data split format:
        [{'treatment1':{'concentration 1':[ID, ID,...],''}, }
        ,{},...,{}]

        Derives ``self.treatments``, ``self.concentrations`` and
        ``self.conditionals`` (their combined length) from the train split.
        """
        print("Preparing data")
        self.treatments, self.concentrations = self.get_conditional_arrangement()
        self.conditionals = len(self.treatments) + len(
            self.concentrations
        )  # this will be used for specifying the model
        print("Data prepared")

    def get_conditional_arrangement(self):
        """keep track of the possible treatments and concentrations

        Scans every experiment of the train split. The treatment named by
        ``self.no_treatment`` ("H2O") is excluded. Concentrations are
        accumulated across all experiments, not only the first experiment in
        which a treatment appears.

        Returns:
            dictionary (treatment: [concentration1, concentration2, ...]), list (concentration1, concentration2, ...
        """
        self.no_treatment = "H2O"
        possible_concentrations = []
        treatments = {}
        for ls in self.data_splits["train"]:
            for key, value in ls.items():
                if key == self.no_treatment:
                    continue
                known = treatments.setdefault(key, [])
                for conc in value:
                    if conc not in possible_concentrations:
                        possible_concentrations.append(conc)
                    if conc not in known:
                        known.append(conc)
        return treatments, possible_concentrations

    def train_dataloader(self):
        """Return the shuffled train loader."""
        # batch_size is fixed to 1 here, independent of self.batch_size.
        # NOTE(review): presumably because every item is a whole population
        # with its own number of cells - confirm before wiring batch_size in.
        return DataLoader(
            self.train_dataset,
            batch_size=1,
            num_workers=4,
            shuffle=True,
        )

    def val_dataloader(self):
        """Return the validation loader (batch size 1, no shuffling)."""
        return DataLoader(
            self.val_dataset,
            batch_size=1,
            num_workers=4,
            shuffle=False,
        )

    def test_dataloader(self):
        """Return the test loader (batch size 1, no shuffling)."""
        return DataLoader(
            self.test_dataset,
            batch_size=1,
            num_workers=4,
            shuffle=False,
        )

    def encode_conditional(self, treatment, concentration):
        """encode the treatment and concentration into a one-hot vector

        The no-treatment condition (``self.no_treatment``) is encoded as an
        all-zero vector.

        Args:
            treatment (str): treatment
            concentration (str): concentration

        Returns:
            np.array: one-hot vector of the treatment and concentration

        Raises:
            ValueError: if the treatment or concentration is not part of the
                vocabulary derived from the train split.
        """
        treat = np.zeros(len(self.treatments))
        conc = np.zeros(len(self.concentrations))
        if treatment != self.no_treatment:
            treat[list(self.treatments).index(treatment)] = 1
            conc[self.concentrations.index(concentration)] = 1
        return np.concatenate([treat, conc])

    def decode_conditional(self, encoded):
        """decode the one-hot vector into treatment and concentration

        Args:
            encoded (np.array): one-hot vector of the treatment and concentration

        Returns:
            tuple: treatment, concentration
        """
        # NOTE(review): an all-zero vector (the encoding of no_treatment)
        # decodes to the first treatment and first concentration, because
        # argmax of zeros is 0 - confirm whether callers need that case.
        n_treat = len(self.treatments)
        treat = list(self.treatments)[np.argmax(encoded[:n_treat])]
        conc = self.concentrations[np.argmax(encoded[n_treat:])]
        return treat, conc
    