
import sys
import os
import pandas as pd
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler
from utils.utils import to_torch_dataloader



class ActiveLearningDataset:
    """Molecule library plus its current train/screen split for active learning.

    The full library (SMILES<->index lookups, SMILES, fingerprints, labels and
    graphs) is loaded once in ``__init__`` via :meth:`load`.  The split is then
    created with :meth:`get_start_data` and grown with :meth:`add_data`; both
    refresh the cached ``idx_*``, ``x_*``, ``y_*`` and ``smiles_*`` attributes.

    ``args`` is read for ``mode``, ``dataset``, ``architecture``,
    ``start_num``, ``start_active_num``, ``train_batch_size`` and
    ``infer_batch_size``.
    """

    def __init__(self, args):
        self.args = args

        self.smiles_index, self.index_smiles, self.smiles, self.fp, self.y, self.graph = self.load()

        # Filled in by get_start_data() / add_data().
        self.idx_train, self.idx_screen = None, None
        self.x_train, self.y_train, self.smiles_train = None, None, None
        self.x_screen, self.y_screen, self.smiles_screen = None, None, None

    def load(self):
        """Load the screening library of ``args.dataset`` from disk.

        Returns:
            ``(smiles_index, index_smiles, smiles, fp, y, graph)`` exactly as
            stored by ``torch.save`` in ``<root>/<dataset>/screen``.

        Raises:
            ValueError: if ``args.mode`` has no configured data root (only
                ``"a"`` and ``"d"`` do).
        """
        print('Loading Data!')
        data_roots = {"a": "data", "d": "data_pcba"}
        mode = self.args.mode
        if mode not in data_roots:
            # Previously this fell through and crashed with UnboundLocalError.
            raise ValueError(f"No data directory configured for mode {mode!r}; expected 'a' or 'd'")
        path = os.path.join(data_roots[mode], self.args.dataset, "screen")

        def _read(name):
            # NOTE(review): torch.load unpickles arbitrary objects; only use on
            # trusted files.  torch>=2.6 defaults to weights_only=True, which
            # may reject non-tensor objects such as graphs -- confirm.
            return torch.load(os.path.join(path, name))

        index_smiles = _read('index_smiles')
        smiles_index = _read('smiles_index')
        smiles = _read('smiles')
        fp = _read('x')
        y = _read('y')
        graph = _read('graphs2')

        return smiles_index, index_smiles, smiles, fp, y, graph

    @staticmethod
    def _balanced_sample_weights(labels):
        """Return one sampling weight per binary label (0/1).

        Class-0 samples get ``1 - frac(0)`` and class-1 samples ``1 - frac(1)``,
        so the rarer class is drawn more often.  When only one class is present
        balancing is impossible (every weight would be 0 and the sampler would
        fail), so uniform weights are returned instead.
        """
        n = len(labels)
        if n == 0:
            return []
        frac = [float((labels == c).sum()) / n for c in (0, 1)]
        if min(frac) == 0.0:
            return [1.0] * n
        class_weights = [1 - frac[0], 1 - frac[1]]
        # int() so float labels (0.0/1.0) and 0-d tensors index the list.
        return [class_weights[int(label)] for label in labels]

    def construct_dataloader(self):
        """Build the data loaders for the current split.

        Returns:
            ``(train_loader, train_loader_balanced, screen_loader)``: the
            training set in order, the training set drawn with class-balancing
            weights (with replacement), and the screening pool in order.
        """
        weights = self._balanced_sample_weights(self.y_train)
        sampler = WeightedRandomSampler(weights, num_samples=len(self.y_train), replacement=True)

        train_loader = to_torch_dataloader(self.x_train, self.y_train, batch_size=self.args.infer_batch_size, shuffle=False, pin_memory=True)
        train_loader_balanced = to_torch_dataloader(self.x_train, self.y_train, batch_size=self.args.train_batch_size, sampler=sampler, shuffle=False, pin_memory=True)
        screen_loader = to_torch_dataloader(self.x_screen, self.y_screen, batch_size=self.args.infer_batch_size, shuffle=False, pin_memory=True)

        return train_loader, train_loader_balanced, screen_loader

    def _set_split(self, idx_train, idx_screen):
        """Store a train/screen index split and materialize its data."""
        self.idx_train, self.idx_screen = idx_train, idx_screen
        self.x_train, self.y_train, self.smiles_train = self.idx_to_data(self.idx_train)
        self.x_screen, self.y_screen, self.smiles_screen = self.idx_to_data(self.idx_screen)

    def get_start_data(self):
        """Create the initial split (see :meth:`get_start_idx`)."""
        self._set_split(*self.get_start_idx())

    def add_data(self, smiles_pick):
        """Move the molecules in ``smiles_pick`` from the screen pool to training."""
        self._set_split(*self.add_idx(smiles_pick))

    def _complement_idx(self, idx):
        """Return, in ascending order, all library indices not in ``idx``."""
        return np.setdiff1d(np.arange(len(self.y)), idx)

    def get_start_idx(self):
        """Choose the initial training indices; the rest form the screen pool.

        Modes:
            ``"a"``: ``start_active_num`` random actives plus
                ``start_num - start_active_num`` other random molecules (which
                may also be active).
            ``"d"``: as ``"a"`` but the extra molecules are drawn only from
                inactives.
            ``"e"``: the first ``start_num`` molecules.
                NOTE(review): :meth:`load` has no data root for ``"e"``.

        All draws are without replacement, so the training indices are unique.

        Returns:
            ``(idx_train, idx_screen)`` integer arrays; ``idx_train`` is
            shuffled, ``idx_screen`` is ascending.

        Raises:
            ValueError: for an unknown ``args.mode``; also raised by
                ``np.random.choice`` if a pool is too small for the request.
        """
        mode = self.args.mode
        if mode in ("a", "d"):
            # Random actives to seed the training set.
            hit_idx = np.where(self.y == 1)[0]
            select_hit_idx = hit_idx[np.random.choice(len(hit_idx), size=self.args.start_active_num, replace=False)]

            # "a": any other molecule; "d": inactive molecules only.
            excluded = select_hit_idx if mode == "a" else hit_idx
            pool = self._complement_idx(excluded)
            n_other = self.args.start_num - self.args.start_active_num
            select_other_idx = np.random.choice(pool, size=n_other, replace=False)

            idx_train = np.random.permutation(np.concatenate((select_hit_idx, select_other_idx)))
        elif mode == "e":
            idx_train = np.random.permutation(np.arange(0, self.args.start_num, 1))
        else:
            raise ValueError(f"Unsupported mode {mode!r}; expected 'a', 'd' or 'e'")

        # Every index not used for training goes to the screen pool.
        idx_screen = self._complement_idx(idx_train)
        assert len(np.intersect1d(idx_screen, idx_train)) == 0, "Something went wrong selecting train/screen samples"

        return idx_train, idx_screen

    def add_idx(self, smiles_pick):
        """Return the split after adding the molecules in ``smiles_pick``.

        Returns:
            ``(idx_train, idx_screen)`` where ``idx_train`` is the current
            training indices followed by the picked ones.
        """
        # Explicit int dtype so an empty pick does not promote idx_train to float.
        pick_idx = np.array([self.smiles_index[smi] for smi in smiles_pick], dtype=np.int64)

        idx_train = np.concatenate((self.idx_train, pick_idx))
        idx_screen = self._complement_idx(idx_train)

        return idx_train, idx_screen

    def idx_to_data(self, idx):
        """Return ``(x, y, smiles)`` for the given index or index array.

        ``x`` is the fingerprint rows for ``architecture == 'mlp'`` and a list of
        graphs otherwise.  A plain ``int`` is wrapped so a length-1 batch is
        returned.
        """
        if isinstance(idx, int):
            idx = [idx]
        if self.args.architecture == 'mlp':
            return self.fp[idx], self.y[idx], self.smiles[idx]
        return [self.graph[i] for i in idx], self.y[idx], self.smiles[idx]
        

