import os
import torch

import numpy as np
import pandas as pd
from rdkit import Chem
import json
from sklearn.utils import shuffle
from rdkit.Chem import FragmentCatalog
from rdkit.Chem import RDConfig
from datasets import Dataset
from torch.utils.data import DataLoader
# (electrophile, nucleophile) reactant pairs whose reactions form the training set
# of the 'reactants_ood' split in the __main__ script; all other pairs are held out.
# Per the labels: the 6-substituted quinolines 1a-1d (Cl, Br, OTf, I) are paired
# with every boron partner 2a-2c, and the quinoline boron species 1e-1g are paired
# with the bromide 2d.
ood_13_1 = [
    (electrophile, nucleophile)
    for electrophile in ('1a, 6-Cl-Q', '1b, 6-Br-Q', '1c, 6-OTf-Q', '1d, 6-I-Q')
    for nucleophile in ('2a, Boronic Acid', '2b, Boronic Ester', '2c, Trifluoroborate')
] + [
    (electrophile, '2d, Bromide')
    for electrophile in ('1e, 6-BOH2-Q', '1f, 6-BPin-Q', '1g, 6-BF3K-Q')
]

def load_functional_groups_from_csv(csv_file):
    """Read a headerless ``SMARTS,GroupName`` CSV into a ``{SMARTS: GroupName}`` dict.

    Args:
        csv_file: path (or buffer) accepted by ``pd.read_csv``; every row, including
            the first, is treated as data.

    Returns:
        dict mapping each SMARTS pattern to its group name. If a SMARTS string
        appears on several rows, the last row's name is kept.
    """
    table = pd.read_csv(csv_file, header=None, names=['SMARTS', 'GroupName'])
    return dict(zip(table['SMARTS'], table['GroupName']))


# Compiled SMARTS queries keyed by pattern string. The same patterns are matched
# against every molecule, so each one is parsed only once per process.
_SMARTS_QUERY_CACHE = {}


def _smarts_query(smarts):
    """Return the compiled RDKit query for ``smarts``; raise ValueError if it does not parse."""
    query = _SMARTS_QUERY_CACHE.get(smarts)
    if query is None:
        query = Chem.MolFromSmarts(smarts)
        if query is None:
            # MolFromSmarts signals a parse failure by returning None; passing that
            # on to the substructure search would fail with an opaque type error.
            raise ValueError(f"Invalid SMARTS pattern: {smarts}")
        _SMARTS_QUERY_CACHE[smarts] = query
    return query


def find_functional_groups(smiles: str, functional_groups: dict):
    """Count substructure matches of each functional group in a molecule.

    Args:
        smiles: SMILES string of the molecule to search.
        functional_groups: mapping of SMARTS pattern -> group name, as returned by
            ``load_functional_groups_from_csv``.

    Returns:
        dict mapping group name -> number of matches reported by
        ``GetSubstructMatches``. It only contains groups with at least one match.
        If several SMARTS share a group name, the last matching pattern's count wins.

    Raises:
        ValueError: if ``smiles`` or any SMARTS pattern cannot be parsed.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES string: {smiles}")

    detected_groups = {}
    for group_smarts, group_name in functional_groups.items():
        # GetSubstructMatches returns an empty tuple when nothing matches, so a
        # separate HasSubstructMatch pre-check would just search twice.
        matches = mol.GetSubstructMatches(_smarts_query(group_smarts))
        if matches:
            detected_groups[group_name] = len(matches)

    return detected_groups



def extract_latents_of_percond_inputs(raw_df, save_json_path,
                                      func_groups_csv='./functional_groups_smiles_codes.csv'):
    """Build yield-prediction instruction prompts for each Suzuki-Miyaura reaction row.

    For each row of ``raw_df``, this function:

    * maps the reactant, catalyst, ligand, base and solvent names to SMILES with the
      lookup tables below;
    * canonicalises the SMILES with RDKit;
    * annotates each molecule with functional-group matches from ``func_groups_csv``;
    * turns the result into an instruction string that asks for the reaction yield.

    Args:
        raw_df: DataFrame with columns ``Reactant_1_Name``, ``Reactant_2_Name``,
            ``Catalyst_1_Short_Hand``, ``Ligand_Short_Hand``,
            ``Reagent_1_Short_Hand``, ``Solvent_1_Short_Hand`` and
            ``Product_Yield_PCT_Area_UV``.
        save_json_path: file the list of instruction records is written to as JSON.
        func_groups_csv: headerless ``SMARTS,GroupName`` CSV passed to
            ``load_functional_groups_from_csv``. It defaults to the path that was
            previously hard-coded.

    Returns:
        ``(results, results_raw_prompt)``.

        * ``results`` is a list of ``{'instruction', 'input', 'output', 'history'}``
          dicts; ``output`` is the row's raw ``Product_Yield_PCT_Area_UV`` value.
        * ``results_raw_prompt`` holds the intermediate prompt pieces for each row.

    Raises:
        KeyError: if a row names a reagent that is missing from the lookup tables.
        ValueError: if a looked-up SMILES string cannot be parsed by RDKit.
    """
    reactant_1_smiles = {
        '6-chloroquinoline': 'C1=C(Cl)C=CC2=NC=CC=C12.CCC1=CC(=CC=C1)CC',
        '6-Bromoquinoline': 'C1=C(Br)C=CC2=NC=CC=C12.CCC1=CC(=CC=C1)CC',
        '6-triflatequinoline': 'C1C2C(=NC=CC=2)C=CC=1OS(C(F)(F)F)(=O)=O.CCC1=CC(=CC=C1)CC',
        '6-Iodoquinoline': 'C1=C(I)C=CC2=NC=CC=C12.CCC1=CC(=CC=C1)CC',
        '6-quinoline-boronic acid hydrochloride': 'C1C(B(O)O)=CC=C2N=CC=CC=12.Cl.O',
        'Potassium quinoline-6-trifluoroborate': '[B-](C1=CC2=C(C=C1)N=CC=C2)(F)(F)F.[K+].O',
        '6-Quinolineboronic acid pinacol ester': 'B1(OC(C(O1)(C)C)(C)C)C2=CC3=C(C=C2)N=CC=C3.O'
    }

    reactant_2_smiles = {
        '2a, Boronic Acid': 'CC1=CC=C2C(C=NN2C3OCCCC3)=C1B(O)O',
        '2b, Boronic Ester': 'CC1=CC=C2C(C=NN2C3OCCCC3)=C1B4OC(C)(C)C(C)(C)O4',
        '2c, Trifluoroborate': 'CC1=CC=C2C(C=NN2C3OCCCC3)=C1[B-](F)(F)F.[K+]',
        '2d, Bromide': 'CC1=CC=C2C(C=NN2C3OCCCC3)=C1Br'
    }

    catalyst_smiles = {
        # NOTE(review): '~' is SMARTS any-bond notation rather than standard SMILES;
        # confirm Chem.MolFromSmiles accepts this string.
        'Pd(OAc)2': 'CC(=O)O~CC(=O)O~[Pd]'
    }

    ligand_smiles = {
        'P(tBu)3': 'CC(C)(C)P(C(C)(C)C)C(C)(C)C',
        'P(Ph)3 ': 'c3c(P(c1ccccc1)c2ccccc2)cccc3',
        'AmPhos': 'CC(C)(C)P(C1=CC=C(C=C1)N(C)C)C(C)(C)C',
        'P(Cy)3': 'C1(CCCCC1)P(C2CCCCC2)C3CCCCC3',
        'P(o-Tol)3': 'CC1=CC=CC=C1P(C2=CC=CC=C2C)C3=CC=CC=C3C',
        'CataCXium A': 'CCCCP(C12CC3CC(C1)CC(C3)C2)C45CC6CC(C4)CC(C6)C5',
        'SPhos': 'COc1cccc(c1c2ccccc2P(C3CCCCC3)C4CCCCC4)OC',
        'dtbpf': 'CC(C)(C)P(C1=CC=C[CH]1)C(C)(C)C.CC(C)(C)P(C1=CC=C[CH]1)C(C)(C)C.[Fe]',
        'XPhos': 'P(c2ccccc2c1c(cc(cc1C(C)C)C(C)C)C(C)C)(C3CCCCC3)C4CCCCC4',
        'dppf': 'C1=CC=C(C=C1)P([C-]2C=CC=C2)C3=CC=CC=C3.C1=CC=C(C=C1)P([C-]2C=CC=C2)C3=CC=CC=C3.[Fe+2]',
        'Xantphos': 'O6c1c(cccc1P(c2ccccc2)c3ccccc3)C(c7cccc(P(c4ccccc4)c5ccccc5)c67)(C)C',
        'None': ''
    }

    reagent_1_smiles = {
        'NaOH': '[OH-].[Na+]',
        'NaHCO3': '[Na+].OC([O-])=O',
        'CsF': '[F-].[Cs+]',
        'K3PO4': '[K+].[K+].[K+].[O-]P([O-])([O-])=O',
        'KOH': '[K+].[OH-]',
        'LiOtBu': '[Li+].[O-]C(C)(C)C',
        'Et3N': 'CCN(CC)CC',
        'None': ''
    }

    solvent_1_smiles = {
        'MeCN': 'CC#N.O',
        'THF': 'C1CCOC1.O',
        'DMF': 'CN(C)C=O.O',
        'MeOH': 'CO.O',
        'MeOH/H2O_V2 9:1': 'CO.O',
        'THF_V2': 'C1CCOC1.O'
    }

    func_groups = load_functional_groups_from_csv(func_groups_csv)

    # Both caches live only for this call. The dataset reuses a handful of molecules
    # across many rows, so each distinct SMILES is canonicalised and
    # substructure-searched once instead of once per row.
    canonical_cache = {}
    fg_cache = {}

    def canonical(smiles):
        """Return RDKit-canonical SMILES for ``smiles`` (memoised)."""
        if smiles not in canonical_cache:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
            canonical_cache[smiles] = Chem.MolToSmiles(mol)
        return canonical_cache[smiles]

    def functional_group_lines(smiles, role):
        """Return ``get_func_group(func_groups, smiles, role)`` (memoised; do not mutate)."""
        key = (smiles, role)
        if key not in fg_cache:
            fg_cache[key] = get_func_group(func_groups, smiles, role)
        return fg_cache[key]

    # Every row shares the same coupling product.
    can_product = canonical('C1=C(C2=C(C)C=CC3N(C4OCCCC4)N=CC2=3)C=CC2=NC=CC=C12')
    roles = ['Reactant', 'Reactant', 'Product', 'Solvent', 'Catalyst', 'Ligand', 'Base']

    results = []
    results_raw_prompt = []
    for _, row in raw_df.iterrows():
        yields = row['Product_Yield_PCT_Area_UV']

        catalyst = canonical(catalyst_smiles[row['Catalyst_1_Short_Hand']])
        reactant_1 = canonical(reactant_1_smiles[row['Reactant_1_Name']])
        reactant_2 = canonical(reactant_2_smiles[row['Reactant_2_Name']])
        ligand = canonical(ligand_smiles[row['Ligand_Short_Hand']])
        reagent = canonical(reagent_1_smiles[row['Reagent_1_Short_Hand']])
        solvent = canonical(solvent_1_smiles[row['Solvent_1_Short_Hand']])

        # Example of the target prompt format (its molecules are not from this dataset):
        #   Here is a chemical reaction. Reactants are: ClC=1SC2=C(C=NC(=C2)Cl)N1,C1(CC1)B(O)O.
        #   Product is: ClC1=CC2=C(C=N1)N=C(S2)C2CC2. Reaction type is Chloro Suzuki coupling.
        #   The reaction conditions of this reaction are: Solvent: O,C1(=CC=CC=C1)C.
        #   Catalyst: C1=CC=C(C=C1)P(C1(C=CC=C1)[Fe]C1(P(C2=CC=CC=C2)C2=CC=CC=C2)C=CC=C1)C1=CC=CC=C1.Cl[Pd]Cl.
        #   Atmosphere: N#N.
        #   Additive: [Cs]OC(=O)O[Cs].
        #   Functional Group: Alkene([R]C([R])=C([R])[R]). Number: 2. Reactant:[ClC=1SC2=C(C=NC(=C2)Cl)N1]
        prompt_dict = {
            'reaction': f'Here is a chemical reaction. Reactants are: {reactant_1}, {reactant_2}. Product is: {can_product}.',
            'reaction_type': 'Reaction type is Suzuki Miyaura.',
            'condition': f'The reaction conditions of this reaction are: Solvent: {solvent}. Catalyst: {catalyst}. Ligand: {ligand}. Base: {reagent}.',
            'graph_knowledge': [],
        }

        # Collect functional-group lines for every molecule, tagged with its role.
        molecules = [reactant_1, reactant_2, can_product, solvent, catalyst, ligand, reagent]
        for smiles, role in zip(molecules, roles):
            # extend() copies the cached lines into this row's fresh list.
            prompt_dict['graph_knowledge'].extend(functional_group_lines(smiles, role))

        prompt_dict['graph_knowledge_changes'] = obtain_functional_group_changes(prompt_dict['graph_knowledge'])
        results_raw_prompt.append(prompt_dict)

        changes = prompt_dict['graph_knowledge_changes']
        fg_graph = '. '.join(prompt_dict['graph_knowledge'])
        fg_new = '. '.join(changes['New Functional Groups'])
        fg_lost = '. '.join(changes['Lost Functional Groups'])
        fg_prompt = f"{fg_graph}. New Functional Groups: {fg_new}. Lost Functional Groups: {fg_lost}"

        instruct = (f"{prompt_dict['reaction']} {prompt_dict['reaction_type']} {prompt_dict['condition']}"
                    f" Functional groups information: {fg_prompt}."
                    " What is the yield of this reaction?")

        results.append({'instruction': instruct, 'input': '', 'output': yields, 'history': []})

    # Use a context manager so the output file is flushed and closed deterministically.
    with open(save_json_path, 'w') as fh:
        json.dump(results, fh)
    return results, results_raw_prompt




def obtain_functional_group_changes(graph_knowledge):
    """Summarise functional groups gained and lost between reactants and product.

    Args:
        graph_knowledge: list of strings in the format produced by ``get_func_group``,
            i.e. ``"<Role>: [<smiles>]. Functional Group: <name>(<smarts>). Count: <n>"``.
            Only entries whose role is ``Reactant`` or ``Product`` are considered.
            Counts of the same group are summed over all molecules with that role.

    Returns:
        dict with keys ``"New Functional Groups"`` and ``"Lost Functional Groups"``.
        Each value is a list of description strings, or ``["None"]`` when empty.
    """
    import re

    # The group label is everything after "Functional Group: " up to the final
    # ". Count: <n>" (a ", Count: <n>" separator is also accepted).
    # Splitting on "," alone left ". Count: <n>" inside the label, so one group
    # seen with different counts was treated as unrelated groups.
    # It also truncated SMARTS that contain commas.
    entry_pattern = re.compile(r"Functional Group: (?P<fg>.*)[.,] Count: (?P<count>\d+)\s*$")

    def extract_functional_group_details(entries, role):
        """Sum match counts per functional-group label for entries of ``role``."""
        details = {}
        for entry in entries:
            # Entries start with "<Role>:"; test the prefix rather than searching the
            # whole string, which also contains the SMILES and group name.
            if not entry.startswith(f"{role}:"):
                continue
            match = entry_pattern.search(entry)
            if match is None:
                continue
            fg_type = match.group("fg")
            details[fg_type] = details.get(fg_type, 0) + int(match.group("count"))
        return details

    def describe_growth(after, before, grew_label, fresh_label):
        """List groups whose count in ``after`` exceeds ``before`` (or that are absent from it)."""
        lines = []
        for fg, count in after.items():
            if fg in before:
                if count > before[fg]:
                    lines.append(f"Functional Group: {fg}. {grew_label}: {count - before[fg]}")
            elif count > 0:
                lines.append(f"Functional Group: {fg}. {fresh_label}: {count}")
        return lines

    reactant_fg_details = extract_functional_group_details(graph_knowledge, "Reactant")
    product_fg_details = extract_functional_group_details(graph_knowledge, "Product")

    new_functional_groups = describe_growth(product_fg_details, reactant_fg_details, "Increased by", "New")
    lost_functional_groups = describe_growth(reactant_fg_details, product_fg_details, "Decreased by", "Lost")

    return {"New Functional Groups": new_functional_groups or ["None"],
            "Lost Functional Groups": lost_functional_groups or ["None"]}



def get_func_group(func_groups_df, smiles, Role):
    """Describe every functional group detected in ``smiles`` as one prompt line each.

    Args:
        func_groups_df: mapping of SMARTS pattern -> group name (a dict, despite the
            name), as returned by ``load_functional_groups_from_csv``.
        smiles: molecule to search, passed to ``find_functional_groups``.
        Role: label prefixed to every line (e.g. ``'Reactant'``, ``'Product'``).

    Returns:
        list of ``"<Role>: [<smiles>]. Functional Group: <name>(<smarts>). Count: <n>"``
        strings, one per detected group name.
    """
    hits = find_functional_groups(smiles, func_groups_df)
    # Invert the table so each detected name can be shown with its SMARTS. When
    # several patterns share a name, the last one listed is used.
    smarts_for = {name: pattern for pattern, name in func_groups_df.items()}
    return [
        f'{Role}: [{smiles}]. Functional Group: {name}({smarts_for[name]}). Count: {count}'
        for name, count in hits.items()
    ]

from sklearn.cluster import KMeans

def select_representatives(df, n_clusters=10, samples_per_cluster=10, random_state=42):
    """Pick representative rows by clustering one-hot encoded categorical features.

    The DataFrame is one-hot encoded with ``pd.get_dummies`` and clustered with
    ``KMeans``. From each cluster, ``samples_per_cluster`` rows are drawn at random,
    or every row is taken if the cluster is smaller than that.

    Args:
        df: input DataFrame. It is assumed to be already cleaned of missing values;
            this is not checked.
        n_clusters: number of KMeans clusters (default 10).
        samples_per_cluster: number of rows drawn per cluster (default 10).
        random_state: seed for both KMeans and the per-cluster sampling, so results
            are reproducible.

    Returns:
        list of ``df`` index labels for the selected rows, grouped by cluster id in
        ascending order.
    """
    encoded = pd.get_dummies(df)
    labels = KMeans(n_clusters=n_clusters, random_state=random_state).fit_predict(encoded)

    chosen = []
    for cluster_id in range(n_clusters):
        # Index labels of the rows assigned to this cluster, in original row order.
        members = encoded.index[labels == cluster_id]
        if len(members) < samples_per_cluster:
            chosen.extend(members.tolist())
        else:
            picked = pd.Series(members).sample(n=samples_per_cluster, random_state=random_state)
            chosen.extend(picked.tolist())
    return chosen


if __name__ == '__main__':
    # Script entry point. It builds a train/validation split of the Suzuki-Miyaura
    # dataset, saves the split indices, and writes instruction-style CSVs.
    print(os.getcwd())  # relative output paths below resolve against this directory
    df = pd.read_excel('/Users/donghan/code/yield_prediction_preprocessing/data/suzuki/aap9112_data_file_s1.xlsx')
    train_size = 30
    tag = 'suzuki_miyaura_fg_changes'
    select_method = 'reactants_ood'  # one of 'random', 'cluster', 'reactants_ood'

    if select_method == 'random':
        ids = shuffle(np.arange(len(df)), random_state=42)
        train_idx = ids[:train_size]
        test_idx = ids[train_size:]
    elif select_method == 'cluster':
        new_df = df[['Reactant_1_Name', 'Reactant_2_Name', 'Ligand_Short_Hand', 'Catalyst_1_Short_Hand',
                     'Reagent_1_Short_Hand', 'Solvent_1_Short_Hand']]
        # About 10 samples per cluster. Keep at least one cluster so that very small
        # train sizes do not ask KMeans for 0 clusters.
        n_clusters = max(1, round(train_size / 10))
        selected_indices = select_representatives(new_df, n_clusters=n_clusters,
                                                  samples_per_cluster=10, random_state=42)
        all_indices = np.arange(len(new_df))
        train_idx = np.array(selected_indices)
        test_idx = all_indices[~np.isin(all_indices, selected_indices)]
        print(train_idx)
        print(test_idx)
    elif select_method == 'reactants_ood':
        raw_df = pd.read_csv('/Users/donghan/code/yield_prediction_preprocessing/data/suzuki/experiment_index.csv')
        print(raw_df.head())
        train_reactants = [tuple(pair) for pair in ood_13_1]
        in_train = raw_df[["electrophile", "nucleophile"]].apply(tuple, 1).isin(train_reactants).to_numpy()
        # NOTE(review): these are raw_df index labels, and they are used below as
        # positions into `results` (built from `df`). This assumes
        # experiment_index.csv lists the reactions in the same order as the Excel
        # file; confirm.
        train_idx = np.asarray(raw_df.index[in_train])
        test_idx = np.asarray(raw_df.index[~in_train])
        print(len(train_idx))
        print(len(test_idx))
    else:
        raise ValueError(f"Unknown select_method: {select_method!r}")

    data_info_dict = {
        "train_idx": train_idx,
        "val_idx": test_idx
    }
    train_size = len(train_idx)

    out_dir = os.path.join('data', 'data4regression', f'{tag}_{train_size}')
    # makedirs also creates missing parent directories (os.mkdir would fail on them).
    os.makedirs(out_dir, exist_ok=True)
    torch.save(data_info_dict, os.path.join(out_dir, 'split_idx.pt'))

    results, raw_dict = extract_latents_of_percond_inputs(df, save_json_path='./suzuki_miyaura_func.json')
    res_lens = [len(r['instruction']) for r in results]
    print('max lens', max(res_lens), '. mean lens', np.mean(res_lens))

    train_df = pd.DataFrame([results[i] for i in train_idx])
    test_df = pd.DataFrame([results[i] for i in test_idx])
    all_df = pd.DataFrame(results)

    all_df.to_csv(os.path.join(out_dir, 'all.csv'), index=False)
    train_df.to_csv(os.path.join(out_dir, 'train.csv'), index=False)
    test_df.to_csv(os.path.join(out_dir, 'test.csv'), index=False)


