import torch
from dgl.data import DGLDataset
import pandas as pd
from dgllife.utils import CanonicalAtomFeaturizer, smiles_to_bigraph
from utils import integer_label_protein
import numpy as np
import random
from utils import setup_seed


class DTIDataset(DGLDataset):
    """Drug-target interaction dataset loaded from a CSV file.

    Every CSV row yields one sample made of:
      * a DGL graph built from the ligand ``SMILES`` string
        (with self-loops and canonical atom features),
      * ``integer_label_protein`` applied to the target ``sequence``,
      * the row's ``Action Label`` as an integer.

    Args:
        task_path: Path to a CSV containing the columns ``SMILES``,
            ``sequence`` and ``Action Label``.
    """

    def __init__(self, task_path='./dataset/mol_seq.csv'):
        # Stored before the base constructor runs, because DGLDataset's
        # constructor is what invokes process() (see DGL DGLDataset docs).
        self.csv_file = task_path
        super().__init__(name="synthetic")

    def process(self):
        """Read the CSV and build graphs, protein encodings and labels."""
        frame = pd.read_csv(self.csv_file)

        self.smiles = frame['SMILES'].tolist()        # ligand input
        self.target_seq = frame['sequence'].tolist()  # target input
        self.labels = frame['Action Label'].tolist()  # interaction ground truth

        featurizer = CanonicalAtomFeaturizer()
        self.graphs = [
            smiles_to_bigraph(smi, add_self_loop=True, node_featurizer=featurizer)
            for smi in self.smiles
        ]
        self.target_emd = [integer_label_protein(protein) for protein in self.target_seq]

        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, i):
        """Return the ``(graph, protein_encoding, label)`` triple at index ``i``."""
        graph = self.graphs[i]
        protein = self.target_emd[i]
        label = self.labels[i]
        return graph, protein, label

    def __len__(self):
        """Return the number of samples (one per CSV row)."""
        return len(self.graphs)



class AgentDTIDataset(DGLDataset):
    """Drug-target interaction dataset that remaps two chosen labels.

    Rows whose ``Action Label`` equals ``first_label`` are relabelled 0 and
    rows equal to ``second_label`` are relabelled 1. With
    ``version="binary"`` only those rows are flagged in ``self.masks``.
    With any other ``version``, a random subset of the remaining rows is
    also relabelled 2 and flagged. That subset holds at most as many rows
    as classes 0 and 1 combined. Rows not flagged keep their original
    label value.

    Args:
        version: ``"binary"`` for the two-class setup; any other value adds
            the sampled third class.
        task_path: Path to a CSV containing the columns ``SMILES``,
            ``sequence`` and ``Action Label``.
        first_label: Original label value mapped to class 0.
        second_label: Original label value mapped to class 1.
    """

    def __init__(self, version="binary", task_path='./dataset/mol_seq.csv', first_label=0, second_label=1):
        self.version = version
        self.csv_file = task_path
        self.first_label = first_label
        self.second_label = second_label
        # Attributes are set first because DGLDataset's constructor is what
        # invokes process() (see DGL DGLDataset docs).
        super().__init__(name="synthetic")

    def _remap_labels(self, data):
        """Rewrite ``data['Action Label']`` in place; return a 0/1 int mask of remapped rows.

        Every index set is computed from the ORIGINAL labels before any
        assignment. Relabelling first and searching for ``second_label``
        afterwards lets freshly written 0s be matched again whenever
        ``second_label`` is 0.
        """
        labels = data['Action Label']
        frt_idx = data[labels == self.first_label].index
        sec_idx = data[labels == self.second_label].index

        groups = [(frt_idx, 0), (sec_idx, 1)]
        if self.version != "binary":
            other_idx = data[(labels != self.first_label) & (labels != self.second_label)].index
            other_idx = other_idx.tolist()
            # Caps the third class at the combined size of classes 0 and 1.
            # Draws from the module-level `random` state (seeded elsewhere).
            third_idx = random.sample(other_idx, min(len(frt_idx) + len(sec_idx), len(other_idx)))
            groups.append((third_idx, 2))

        masks = np.zeros(len(data), dtype=int)
        for idx, new_label in groups:
            # .loc writes into `data` itself; the chained form
            # `data[col][idx] = v` is unreliable (SettingWithCopy) and is a
            # silent no-op under pandas Copy-on-Write.
            data.loc[idx, 'Action Label'] = new_label
            # read_csv produces a default RangeIndex, so index labels double
            # as row positions here.
            masks[idx] = 1
        return masks

    def process(self):
        """Load the CSV, remap labels, and build graphs, protein encodings and labels."""
        data = pd.read_csv(self.csv_file)

        # Remaps 'Action Label' in place and records which rows were selected.
        self.masks = self._remap_labels(data)

        # Ligand input
        self.smiles = data['SMILES'].tolist()

        # Target input
        self.target_seq = data['sequence'].tolist()

        # Interaction ground truth (after remapping)
        self.labels = data['Action Label'].tolist()

        atom_featurizer = CanonicalAtomFeaturizer()
        self.graphs = [
            smiles_to_bigraph(sl, add_self_loop=True, node_featurizer=atom_featurizer)
            for sl in self.smiles
        ]
        self.target_emd = [integer_label_protein(seq) for seq in self.target_seq]

        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, i):
        """Return the ``(graph, protein_encoding, label)`` triple at index ``i``."""
        return self.graphs[i], self.target_emd[i], self.labels[i]

    def __len__(self):
        """Return the number of samples (one per CSV row, masked or not)."""
        return len(self.graphs)
