from datasets import Dataset
import torch
from sklearn.model_selection import train_test_split
from datasets import DatasetDict
from rdkit import RDLogger
import gzip
import os
import pandas as pd
from tqdm import tqdm
from rdkit import Chem


def random_data_split(dataset, task):
    """Randomly split the labeled entries 60/20/20 into train/val/test.

    An entry counts as unlabeled when the first value of its
    ``'prop_labels'`` list is NaN. Labeled and unlabeled positions are
    derived from the NaN mask itself, so unlabeled entries may appear
    anywhere in ``dataset`` (not only at the end).

    Args:
        dataset: Sequence of dicts, each with a ``'prop_labels'`` list of
            floats (all lists must have the same length).
        task: Name that is only used in the printed summary line.

    Returns:
        Tuple ``(train_index, val_index, test_index, unlabeled_index)`` of
        lists of integer positions into ``dataset``.
    """
    labels = torch.tensor([entry['prop_labels'] for entry in dataset])
    unlabeled_mask = torch.isnan(labels[:, 0])

    # torch.nonzero yields positions in ascending order, so when every NaN
    # row sits at the end these lists equal range(labeled_len) and
    # range(labeled_len, len(dataset)) respectively.
    full_idx = torch.nonzero(~unlabeled_mask).flatten().tolist()
    unlabeled_index = torch.nonzero(unlabeled_mask).flatten().tolist()

    train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
    train_index, test_index = train_test_split(full_idx, test_size=test_ratio, random_state=42)
    # The validation share is expressed relative to what is left after the
    # test split has been removed.
    train_index, val_index = train_test_split(train_index, test_size=valid_ratio/(valid_ratio+train_ratio), random_state=42)
    print(task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index), 'unlabeled len', len(unlabeled_index))
    return train_index, val_index, test_index, unlabeled_index
    
    
def fixed_split(dataset, task):
    """Split ``dataset`` using a hard-coded test set for a known task.

    The test positions are fixed per task; all remaining positions are
    divided into train and validation with an 80/20 random split (seed 42).

    Args:
        dataset: Sized collection; only ``len(dataset)`` is used.
        task: Task name. Only ``'O2-N2'`` is recognised.

    Returns:
        Tuple ``(train_index, val_index, test_index, [])``; the last element
        is an always-empty list of unlabeled positions.

    Raises:
        ValueError: If ``task`` is not a recognised task name.
    """
    # Guard clause: reject unknown tasks before doing any work.
    if task != 'O2-N2':
        raise ValueError('Invalid task name: {}'.format(task))

    test_index = [42,43,92,122,197,198,251,254,257,355,511,512,549,602,603,604]

    # Everything that is not held out for testing is available for train/val.
    every_position = list(range(len(dataset)))
    remaining = list(set(every_position) - set(test_index))

    train_ratio = 0.8
    split = train_test_split(remaining, test_size=1-train_ratio, random_state=42)
    train_index = split[0]
    val_index = split[1]

    print(task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
    return train_index, val_index, test_index, []


def _gather_split(data_list, indices):
    """Collect the entries at ``indices`` into a column-oriented dict.

    Returns a dict mapping every key of the first entry to the list of that
    key's values over the selected entries, plus an ``"indices"`` column
    holding the selected positions themselves.
    """
    indices = [int(i) for i in indices]
    columns = {key: [data_list[i][key] for i in indices] for key in data_list[0]}
    columns["indices"] = indices
    return columns


def preprocess_dataset(args):
    """Build train/validation/test datasets from a gzipped CSV and save them.

    Reads ``<args.root_path>/<args.data_name>.csv.gz``, parses each row's
    ``"smiles"`` column with RDKit (rows that fail to parse are skipped),
    and records the re-serialised SMILES, one float per target column, the
    combined ``"prop_labels"`` list, and the ``"SA"``/``"SC"`` columns as
    ``"sas"``/``"scs"``. The entries are then split, unlabeled entries are
    appended to the training split, and the result is written to
    ``data/<args.data_name>_processed``.

    Args:
        args: Object providing ``root_path``, ``data_name`` and ``target``.
            ``target`` is a ``'-'``-separated list of CSV column names. A
            ``data_name`` with exactly two ``'-'``-separated parts selects
            the fixed split; anything else selects the random split.

    Returns:
        A ``DatasetDict`` with ``"train"``, ``"validation"`` and ``"test"``
        splits; each row also carries its position in the parsed entry list
        under ``"indices"``.

    Raises:
        ValueError: If no row of the CSV yields a parsable molecule.
    """
    RDLogger.DisableLog('rdApp.*')
    data_path = os.path.join(args.root_path, f"{args.data_name}.csv.gz")
    # Decompress on the fly and let pandas parse the CSV stream.
    with gzip.open(data_path, 'rb') as f:
        data = pd.read_csv(f)

    target_prop = args.target.split('-')
    data_list = []
    # Wrapping the iterator makes the bar advance on every row, including
    # the ones skipped below, so it always reaches 100%.
    for _, row in tqdm(data.iterrows(), total=len(data)):
        mol = Chem.MolFromSmiles(row["smiles"], sanitize=False)
        if mol is None:
            continue

        labels = [float(row[target]) for target in target_prop]
        entry = {"smiles": Chem.MolToSmiles(mol)}
        entry.update(zip(target_prop, labels))
        entry["prop_labels"] = labels
        entry["sas"] = float(row["SA"])
        entry["scs"] = float(row["SC"])
        data_list.append(entry)

    if not data_list:
        raise ValueError(f"No valid molecules could be parsed from {data_path}")

    if len(args.data_name.split('-')) == 2:
        train_index, val_index, test_index, unlabeled_index = fixed_split(data_list, args.data_name)
    else:
        train_index, val_index, test_index, unlabeled_index = random_data_split(data_list, args.data_name)

    # Unlabeled entries are only ever used for training.
    train_index = list(train_index) + list(unlabeled_index)

    dataset = DatasetDict({
        "train": Dataset.from_dict(_gather_split(data_list, train_index)),
        "validation": Dataset.from_dict(_gather_split(data_list, val_index)),
        "test": Dataset.from_dict(_gather_split(data_list, test_index)),
    })

    dataset.save_to_disk(f"data/{args.data_name}_processed")
    return dataset
        