from typing import Optional, Callable, List

import os
import glob
import os.path as osp
import pandas as pd 
import torch
import numpy as np
from torch_geometric.data import (InMemoryDataset, Data, download_url,
                                  extract_tar, extract_zip)
from dirgt.loader.dataset.mcoli_dataset_utils import generate_pyg_data, add_labels, reaction_direction, Direction
from contrabass import CobraMetabolicModel
from sklearn.model_selection import train_test_split, StratifiedKFold
from torch_geometric.graphgym.config import cfg


class MColiSyntheticDataset(InMemoryDataset):
    r"""Single-graph PyG dataset built from the iML1515 E. coli metabolic model.

    :meth:`process` loads the COBRA model file ``iML1515_glucose.json`` through
    contrabass and converts it to a PyG ``Data`` object with
    :func:`generate_pyg_data`. It then attaches labels with :func:`add_labels`
    and adds boolean ``train_mask`` / ``val_mask`` / ``test_mask`` node masks.
    The masks come from a stratified split of the node positions stored in
    ``mfg_essentialities_indeces.npy``, stratified on the ``essentiality``
    column of the node CSV. The split is made by two nested 80/20
    ``train_test_split`` calls seeded with ``cfg.seed``, giving roughly
    64/16/20 train/val/test.

    Args:
        root: Dataset root directory; raw files are expected in ``raw_dir``.
        pkl_dir: Stored on the instance; not read anywhere in this class.
        transform: Passed through to :class:`InMemoryDataset`.
        pre_transform: Applied once to the processed graph before it is saved.
        pre_filter: Passed through to :class:`InMemoryDataset`; not applied by
            :meth:`process`, since the dataset holds a single graph.
    """

    def __init__(self, root: str, pkl_dir: str, transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None):
        self.pkl_dir = pkl_dir
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = self._torch_load(self.processed_paths[0])

    @staticmethod
    def _torch_load(path: str):
        """Load a file previously written by :meth:`process`.

        These files contain pickled objects created locally by ``process``
        (not external input), so full unpickling is requested explicitly.
        PyTorch 2.6 changed ``torch.load`` to default to ``weights_only=True``,
        which can reject non-tensor objects such as a PyG ``Data``. Torch
        versions without the ``weights_only`` keyword raise ``TypeError``; in
        that case the call is retried without the keyword.
        """
        try:
            return torch.load(path, weights_only=False)
        except TypeError:
            return torch.load(path)

    @property
    def raw_file_names(self) -> List[str]:
        """Raw inputs, in ``raw_paths`` order.

        :meth:`process` reads entries 1-3. The edge CSV (entry 0) is listed
        but not read here.
        """
        return ['iML1515_mfg_edges.csv', 'iML1515_mfg_nodes_ess-label_fba_pred.csv',
                'mfg_essentialities_indeces.npy', 'iML1515_glucose.json']

    @property
    def processed_file_names(self) -> List[str]:
        """Processed outputs: the collated graph and the split-mask dict."""
        return ['data.pt', 'split_dict.pt']

    def download(self) -> None:
        """No-op: raw files must be placed in ``raw_dir`` manually."""
        pass

    def len(self) -> int:
        """The dataset always contains exactly one graph."""
        return 1

    def _split_node_info(self) -> pd.DataFrame:
        """Read the node CSV and add boolean train/val/test mask columns."""
        node_info = pd.read_csv(self.raw_paths[1]).set_index('id').rename(
            columns={'Unnamed: 0': 'index'})
        mask_cols = ['train_mask', 'val_mask', 'test_mask']
        node_info[mask_cols] = False

        # Positional row indices (used with .iloc) of the nodes eligible for
        # the split.
        onetoX_indx = np.load(self.raw_paths[2])

        _, test_split = train_test_split(
            onetoX_indx, test_size=0.2, random_state=cfg.seed,
            stratify=node_info.iloc[onetoX_indx]['essentiality'])
        # Built through a set exactly as before, so the seeded split stays
        # reproducible with previously processed datasets.
        train_val_idx = list(set(onetoX_indx) - set(test_split))
        train_split, validation_split = train_test_split(
            train_val_idx, test_size=0.2, random_state=cfg.seed,
            stratify=node_info.iloc[train_val_idx]['essentiality'])

        # Resolve each column by name rather than assuming the mask columns
        # are the last three of the frame.
        for col, split in zip(mask_cols, (train_split, validation_split, test_split)):
            node_info.iloc[split, node_info.columns.get_loc(col)] = True
        return node_info

    def process(self) -> None:
        """Build the labelled graph and its split masks, then save both.

        Writes ``(data, None)`` to ``processed_paths[0]``. Writes a dict of
        boolean node masks keyed ``'train'`` / ``'valid'`` / ``'test'`` to
        ``processed_paths[1]``.

        Raises:
            ValueError: If :func:`add_labels` returns no essential reactions.
        """
        bass_model = CobraMetabolicModel(self.raw_paths[3])
        model = bass_model.model()
        model.objective = 'BIOMASS_Ec_iML1515_core_75p37M'

        data, dictionary_reactions, _, _ = generate_pyg_data(model)
        data, er_ids = add_labels(data, bass_model, dictionary_reactions)
        data.y = torch.squeeze(data.y)

        if len(er_ids) == 0:
            raise ValueError("Essential reaction is empty. Check contrabass dependencies.")

        # ---- SPLITS -----
        node_info = self._split_node_info()
        num_nodes = data.x.shape[0]

        def build_mask(column: str) -> torch.Tensor:
            # Map each CSV row flagged in `column` to its graph node via the
            # reaction label.
            mask = torch.zeros(num_nodes, dtype=torch.bool)
            for reaction_id in node_info.loc[node_info[column], 'label']:
                mask[dictionary_reactions[reaction_id]] = True
            return mask

        data.train_mask = build_mask('train_mask')
        data.val_mask = build_mask('val_mask')
        data.test_mask = build_mask('test_mask')

        split_dict = {
            'train': data.train_mask,
            'valid': data.val_mask,
            'test': data.test_mask,
        }

        # Honour the user-supplied pre_transform (previously ignored).
        if self.pre_transform is not None:
            data = self.pre_transform(data)

        torch.save((data, None), self.processed_paths[0])
        torch.save(split_dict, self.processed_paths[1])

    def get_idx_split(self):
        """Return the saved dict of boolean masks keyed 'train'/'valid'/'test'."""
        return self._torch_load(self.processed_paths[1])
