import torch
import torch.utils.data as data
import pandas as pd
import math
import numpy as np

from utils import OneHotEncoder
from data.loaders.utils import convert_to_tuple

class CDDataset(data.Dataset):
    """Dataset that turns one CSV file per sample into a dict of per-biomarker graphs.

    Item ``idx`` is built from ``file_paths[idx]`` by :meth:`_build_set_of_graphs`.
    Every column whose name starts with ``"npu"`` or ``"dnk"`` becomes one graph.
    The graph has one node per non-NaN measurement, with node features
    ``[value] + unit_emb + lab_code_emb + [sex] + biomarker_emb``. It also carries
    a matrix of pairwise sampling-date differences in days.

    Args:
        file_paths: Sequence of CSV paths, one sample per path.
        exp_config: Experiment config. Only its ``label_encoder`` is used here; it
            is indexed with the value of the CSV's ``dataset`` column.
        metadata_path: Directory containing ``unique_analysiscode.csv``,
            ``unique_units.csv`` and ``unique_lab_ids.csv``. The vocabulary is
            read from column ``"0"`` of each file.
        biom_one_hot_embedder, unit_one_hot_embedder, lab_code_one_hot_embedder:
            Callables applied to the float one-hot tensors; their results are
            converted with ``.tolist()``. They are called unconditionally when an
            item is built, so presumably they must not be ``None`` at that point.
        save_path: Stored on the instance; not used by the methods in this class.
        exclude_biomarkers: Analysis codes removed from the biomarker vocabulary.
            Defaults to no exclusions.
        anonimise_sample: If True, add Gaussian noise (std = 3% of ``|value|``) to
            every value and shift each sampling date independently (see
            ``_DATE_SHIFT_DAYS``).
    """

    # Candidate per-date shifts (in days) used when ``anonimise_sample`` is set:
    # the integers in [-60, -30] and [30, 60], so every date moves by at least 30 days.
    _DATE_SHIFT_DAYS = np.concatenate([np.arange(-60, -29), np.arange(30, 61)])

    def __init__(
        self,
        file_paths,
        exp_config,
        metadata_path,
        biom_one_hot_embedder=None,
        unit_one_hot_embedder=None,
        lab_code_one_hot_embedder=None,
        save_path=None,
        exclude_biomarkers=None,
        anonimise_sample=False,
    ):
        # Avoid a shared mutable default; ``None`` means "exclude nothing".
        if exclude_biomarkers is None:
            exclude_biomarkers = []

        self.file_paths = file_paths
        self.metadata_path = metadata_path

        self.unique_analysiscode = self._read_vocabulary("unique_analysiscode.csv")
        self.unique_units = self._read_vocabulary("unique_units.csv")
        self.unique_lab_codes = self._read_vocabulary("unique_lab_ids.csv")

        # NOTE(review): excluded codes are removed from the encoder vocabulary only.
        # _build_set_of_graphs still builds graphs for every npu*/dnk* column.
        # Confirm the input files never contain excluded biomarkers.
        excluded = set(exclude_biomarkers)
        self.unique_analysiscode = [x for x in self.unique_analysiscode if x not in excluded]
        self.biom_one_hot_embedder = biom_one_hot_embedder
        self.biom_encoder = OneHotEncoder(self.unique_analysiscode)

        self.unit_one_hot_embedder = unit_one_hot_embedder
        self.unit_encoder = OneHotEncoder(self.unique_units)

        self.lab_code_one_hot_embedder = lab_code_one_hot_embedder
        self.lab_encoder = OneHotEncoder(self.unique_lab_codes)

        self.exp_config = exp_config

        self.label_encoder = exp_config.label_encoder
        self.exclude_biomarkers = exclude_biomarkers

        self.save_path = save_path
        self.anonimise_sample = anonimise_sample

    def _read_vocabulary(self, filename):
        """Return column ``"0"`` of ``{metadata_path}/{filename}`` as a list."""
        return pd.read_csv(f"{self.metadata_path}/{filename}")["0"].tolist()

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        return self._build_set_of_graphs(self.file_paths[idx])

    @staticmethod
    def _prepare_dataframe(dataframe):
        """Return a copy of *dataframe* with parsed dates, sorted by ``samplingdate``.

        The index is reset after sorting so that row labels equal row positions.
        ``_build_set_of_graphs`` iterates rows by label (``Series.items``). Without
        the reset, a sorted frame's labels no longer match positions, and a
        positional lookup would read the wrong row. The input frame is not modified.
        """
        return (
            dataframe.assign(
                samplingdate=pd.to_datetime(dataframe["samplingdate"]),
                prediag_token_date=pd.to_datetime(dataframe["prediag_token_date"]),
            )
            .sort_values("samplingdate")
            .reset_index(drop=True)
        )

    def _build_set_of_graphs(
        self,
        file_path
        ):
        """Build one graph per biomarker column from the CSV at *file_path*.

        Returns:
            ``(graph_dict, label, graph_id)``:

            * ``graph_dict`` maps each ``npu*``/``dnk*`` column name to
              ``(x, node_distances, label)``.
            * ``x`` is a float tensor with one row per non-NaN measurement, in
              chronological order.
            * ``node_distances[i, j]`` is ``date_i - date_j`` in days.
            * ``label`` is ``torch.tensor([graph_label], dtype=float)``.
            * ``graph_id`` is the first row's ``lbnr``.

        Raises:
            ValueError: if a row with a measurement has a NaN normalised age.
        """
        dataframe = self._prepare_dataframe(pd.read_csv(file_path))

        graph_dict = {}

        raw_label = dataframe['dataset'].iloc[0]
        graph_label = self.label_encoder[raw_label]

        graph_id = dataframe['lbnr'].iloc[0]
        norm_age_at_sampling = dataframe["age_at_sampling_in_days_norm"]
        sex = dataframe["sex"].iloc[0]

        biomarkers = [col for col in dataframe.columns if col.startswith(("npu", "dnk"))]

        for biomarker in biomarkers:
            nodes = []
            sampling_dates_indices = []

            for idx, vals in dataframe[biomarker].items():
                biom_data = convert_to_tuple(vals)

                # Rows without a measurement for this biomarker produce no node.
                if math.isnan(biom_data[0]):
                    continue

                value = biom_data[0]
                unit = biom_data[1]
                lab_code = biom_data[2]

                biomarker_encoded, unit_encoded, lab_code_encoded = self._get_categorical_embeddings(biomarker, unit, lab_code)

                if self.anonimise_sample:
                    value += np.random.normal(loc=0.0, scale=0.03 * abs(value))

                value_feature_group = [value] + unit_encoded
                # The normalised age is only validated here; it is not a node feature.
                # ``idx`` is a label, and after _prepare_dataframe it equals the position.
                norm_age = norm_age_at_sampling.loc[idx]
                if math.isnan(norm_age):
                    raise ValueError(
                        f"Missing age_at_sampling_in_days_norm for row {idx} "
                        f"(biomarker {biomarker}) in {file_path}"
                    )

                nodes.append(value_feature_group + lab_code_encoded + [sex] + biomarker_encoded)
                sampling_dates_indices.append(idx)

            sampling_dates = pd.to_datetime(dataframe['samplingdate'].loc[sampling_dates_indices])

            if self.anonimise_sample:
                # Each date is shifted independently, so the pairwise distances are
                # perturbed as well (a constant shift would leave them unchanged).
                shift_days = np.random.choice(self._DATE_SHIFT_DAYS, size=len(sampling_dates))
                sampling_dates += pd.to_timedelta(shift_days, unit="D")

            node_distances = self._get_node_distances(sampling_dates)
            x = torch.tensor(nodes, dtype=torch.float)

            # Variant used for CD node-level graphs (also returns dates and prediag date):
            # graph_dict[biomarker] = (x, torch.tensor(node_distances, dtype=torch.float), torch.tensor([graph_label], dtype=torch.float), sampling_dates, dataframe["prediag_token_date"].iloc[0].strftime("%Y-%m-%d")) # FOR CD NODE-LEVEL GRAPHS
            graph_dict[biomarker] = (x, torch.tensor(node_distances, dtype=torch.float), torch.tensor([graph_label], dtype=torch.float))

        return graph_dict, torch.tensor([graph_label], dtype=torch.float), graph_id

    def _get_categorical_embeddings(self, bm, unit, lab_code):
        """Return ``(biomarker, unit, lab_code)`` embeddings as Python lists.

        Each code is one-hot encoded by its ``OneHotEncoder``, converted to a float
        tensor, passed through the matching embedder, and converted with ``.tolist()``.
        """
        biomarker_encoded = self.biom_one_hot_embedder(
            torch.tensor(self.biom_encoder.encode(bm).tolist()).to(torch.float)
        ).tolist()

        unit_encoded = self.unit_one_hot_embedder(
            torch.tensor(self.unit_encoder.encode(unit).tolist()).to(torch.float)
        ).tolist()

        lab_code_encoded = self.lab_code_one_hot_embedder(
            torch.tensor(self.lab_encoder.encode(lab_code).tolist()).to(torch.float)
        ).tolist()

        return biomarker_encoded, unit_encoded, lab_code_encoded

    def _get_node_distances(self, sampling_dates):
        """Return an (n, n) float array with ``[i, j] = date_i - date_j`` in days."""
        sampling_dates_np = sampling_dates.values.astype("datetime64[D]")
        node_distances = (sampling_dates_np[:, None] - sampling_dates_np[None, :]).astype(float)
        np.fill_diagonal(node_distances, 0)
        return node_distances
