import rdkit.Chem as Chem
from typing import Dict
from torch.utils.data import Dataset, DataLoader
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
from peft import LoraConfig, get_peft_model, PeftModel, PeftConfig
from datasets import Dataset
from torch.utils.data import TensorDataset, DataLoader
import random
import os
from tqdm import tqdm
import pandas as pd
from sklearn.utils import shuffle
import time
from torch.nn.utils import clip_grad_norm_
import json
import datetime
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.multiprocessing as mp

from deepspeed import get_accelerator
import logging
import deepspeed
import subprocess
import wandb
from sklearn.metrics import r2_score
import numpy as np

import pathlib

import os
import re
import shutil
def obtain_functional_group_changes(graph_knowledge):
    """Summarise functional groups gained and lost between reactants and products.

    Entries of ``graph_knowledge`` that contain "Reactant" (resp. "Product") and
    both "Functional Group:" and "Count:" are parsed. The group key is the text
    after "Functional Group: " up to the first comma; the integer after
    "Count: " is summed per key on each side.

    NOTE(review): get_func_group() emits entries shaped like
    "...Functional Group: NAME(SMARTS). Count: N" (no comma), so the key parsed
    here keeps the ". Count: N" suffix and groups only cancel when their counts
    are identical. Fixing this changes the prompt text fed to the model, so it
    must be coordinated with how the training prompts were built - confirm
    before changing.

    Args:
        graph_knowledge: List of functional-group description strings.

    Returns:
        Dict with keys "New Functional Groups" and "Lost Functional Groups",
        each a list of description strings, or ["None"] when empty.
    """

    def tally(entries):
        # Sum counts per parsed group key, preserving first-seen key order.
        totals = {}
        for entry in entries:
            if "Functional Group:" not in entry or "Count:" not in entry:
                continue
            key = entry.split("Functional Group: ")[1].split(",")[0]
            totals[key] = totals.get(key, 0) + int(entry.split("Count: ")[1])
        return totals

    reactant_totals = tally([e for e in graph_knowledge if "Reactant" in e])
    product_totals = tally([e for e in graph_knowledge if "Product" in e])

    gained = []
    for key, n_product in product_totals.items():
        if key not in reactant_totals:
            if n_product > 0:
                gained.append(f"Functional Group: {key}. New: {n_product}")
        elif n_product > reactant_totals[key]:
            gained.append(
                f"Functional Group: {key}. Increased by: {n_product - reactant_totals[key]}"
            )

    lost = []
    for key, n_reactant in reactant_totals.items():
        if key not in product_totals:
            if n_reactant > 0:
                lost.append(f"Functional Group: {key}. Lost: {n_reactant}")
        elif n_reactant > product_totals[key]:
            lost.append(
                f"Functional Group: {key}. Decreased by: {n_reactant - product_totals[key]}"
            )

    return {
        "New Functional Groups": gained or ["None"],
        "Lost Functional Groups": lost or ["None"],
    }


def get_func_group(smiles, Role):
    """Describe the functional groups detected in one molecule as prompt fragments.

    Args:
        smiles: SMILES string of the molecule to search.
        Role: Label used as the prefix of each fragment (e.g. "Reactant").

    Returns:
        One string per detected group name, formatted as
        "<Role>: [<smiles>]. Functional Group: <name>(<SMARTS>). Count: <n>".

    NOTE(review): every SMARTS in the table maps to the same name "S2", so
    find_functional_groups() keeps a single "S2" entry (the count of the last
    matching pattern), and the reverse lookup always reports the last SMARTS
    listed. Also, in SMARTS ``[R]`` means "atom in a ring", not a generic
    substituent placeholder. Both look unintended, but the text feeds the
    model's prompt - confirm against the training pipeline before changing.
    """
    smarts_to_name = {
        "[R]O": "S2",
        "[R]C([R])=C([R])[R]": "S2",
        "[R]C#C[R]": "S2",
        "[R]C=O": "S2",
        "[R]C=S": "S2",
        "[R]C(=O)[R]": "S2",
        "[R]C(=S)[R]": "S2",
        "[R]OC(=O)O[R]": "S2",
        "[R]C(=O)[O-]": "S2",
        "[R]C(=O)O[R]": "S2",
        "[R]C(=O)S[R]": "S2",
        "[R]OS(=O)O[R]": "S2",
        "[R]OS(=O)(=O)O[R]": "S2",
        "OP(O)O": "S2",
        "OP(O)(=O)O[R]": "S2",
        "OP(O[R])(=O)O[R]": "S2",
        "[R]OP(O[R])(=O)O[R]": "S2",
        "[R]S(=O)[R]": "S2",
        "[R]C(=O)O": "S2",
        "[R]SO": "S2",
        "[R]S(=O)O": "S2",
        "[R]S(=O)(=O)O": "S2",
        "[R]S(=O)(=O)[R]": "S2",
        "[R]C(O)O[R]": "S2",
        "[R]C([R])(O)O[R]": "S2",
        "[R]C(O[R])O[R]": "S2",
        "[R]OO": "S2",
        "[R]OO[R]": "S2",
        "[R]O[R]": "S2",
        "[R]S[R]": "S2",
        "[R]SS[R]": "S2",
        "[R]C(=O)N([R])[R]": "S2",
        "[R]P(=O)([R])N([R])[R]": "S2",
        "[R]N": "S2",
        "[R]N[R]": "S2",
        "[R]N([R])[R]": "S2",
        "[R]C(=N)[R]": "S2",
        "[R]C([R])=N[R]": "S2",
        "[R]C=N": "S2",
        "[R]C=N[R]": "S2",
        "[R]C(=O)N([R])C(=O)[R]": "S2",
        "[R]N=[N+]=[N-]": "S2",
        "[R]N=N[R]": "S2",
        "[R]OC#N": "S2",
        "[R]N=C=O": "S2",
        "[R]O[N+](=O)[O-]": "S2",
        "[R][N+]#[C-]": "S2",
        "[R]ON=O": "S2",
        "[R]N(=O)=O": "S2",
        "[R]N=O": "S2",
        "[R]C=NO": "S2",
        "[R]C([R])=NO": "S2",
        "[R]OC(=O)N([R])[R]": "S2",
        "[R]C#N": "S2",
        "[R]F": "S2",
        "[R]Cl": "S2",
        "[R]Br": "S2",
        "[R]I": "S2",
        "[R]B(O)O": "S2",
    }

    detected = find_functional_groups(smiles, smarts_to_name)
    # Reverse lookup name -> SMARTS; when names repeat, the last SMARTS wins.
    name_to_smarts = {name: smarts for smarts, name in smarts_to_name.items()}
    return [
        f"{Role}: [{smiles}]. Functional Group: {name}({name_to_smarts[name]}). Count: {count}"
        for name, count in detected.items()
    ]


def find_functional_groups(smiles: str, functional_groups: dict):
    """Count substructure matches of each SMARTS pattern in a molecule.

    Args:
        smiles: SMILES string of the molecule to search.
        functional_groups: Mapping of SMARTS pattern -> group name.

    Returns:
        Dict mapping group name -> number of matches returned by
        ``GetSubstructMatches``. Only groups with at least one match are
        included. If several patterns share a name, the count of the last
        matching pattern (in mapping order) overwrites earlier ones.

    Raises:
        ValueError: If ``smiles`` or one of the SMARTS patterns cannot be parsed.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES string: {smiles}")

    detected_groups = {}
    for group_smarts, group_name in functional_groups.items():
        patt = Chem.MolFromSmarts(group_smarts)
        if patt is None:
            # Otherwise RDKit fails later with an opaque argument-type error.
            raise ValueError(f"Invalid SMARTS pattern: {group_smarts}")
        # An empty tuple means no match, so a separate HasSubstructMatch call
        # would only repeat the same search.
        matches = mol.GetSubstructMatches(patt)
        if matches:
            detected_groups[group_name] = len(matches)

    return detected_groups


def _canonical_smiles(smiles: str, field: str) -> str:
    """Return RDKit's canonical SMILES for ``smiles``; raise ValueError if unparsable."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Chem.MolToSmiles(None) would otherwise fail with an unhelpful type error.
        raise ValueError(f"Invalid SMILES for '{field}': {smiles}")
    return Chem.MolToSmiles(mol)


def convert_sample_to_prompt(data: Dict[str,str]) -> str:
    """Build the yield-prediction prompt for one reaction sample.

    Args:
        data: Mapping with SMILES under 'product', 'catalyst', 'reactant_1',
            'reactant_2', 'ligand', 'base', 'solvent' and 'additive', plus a
            'reaction_type' value that is inserted verbatim.

    Returns:
        The prompt: reaction and condition description, per-molecule
        functional-group information, the new/lost group summary, and the
        closing question "What is the yield of this reaction?".

    Raises:
        ValueError: If any molecule SMILES cannot be parsed by RDKit.
        KeyError: If a required key is missing from ``data``.
    """
    product = _canonical_smiles(data['product'], 'product')
    catalyst = _canonical_smiles(data['catalyst'], 'catalyst')
    reactant_1 = _canonical_smiles(data['reactant_1'], 'reactant_1')
    reactant_2 = _canonical_smiles(data['reactant_2'], 'reactant_2')
    ligand = _canonical_smiles(data['ligand'], 'ligand')
    base = _canonical_smiles(data['base'], 'base')
    solvent = _canonical_smiles(data['solvent'], 'solvent')
    additive = _canonical_smiles(data['additive'], 'additive')

    reaction = (
        f"Here is a chemical reaction. Reactants are: {reactant_1}, {reactant_2}. Product is: {product}."
    )
    reaction_type = f"Reaction type is {data['reaction_type']}."
    condition = (
        f"The reaction conditions of this reaction are: Solvent: {solvent}. Catalyst: {catalyst}. Ligand: {ligand}. Base: {base}. Additive: {additive}."
    )

    # Collect functional-group descriptions for every molecule, tagged by role.
    smiles = [reactant_1, reactant_2, product, solvent, catalyst, ligand, base, additive]
    roles = ["Reactant", "Reactant", "Product", "Solvent", "Catalyst", "Ligand", "Base", "Additive"]
    graph_knowledge = []
    for smile, role in zip(smiles, roles):
        graph_knowledge.extend(get_func_group(smile, role))

    changes = obtain_functional_group_changes(graph_knowledge)

    fg_graph = ". ".join(graph_knowledge)
    fg_new = ". ".join(changes["New Functional Groups"])
    fg_lost = ". ".join(changes["Lost Functional Groups"])
    fg_prompt = f"{fg_graph}. New Functional Groups: {fg_new}. Lost Functional Groups: {fg_lost}"

    instruct = (
        f"{reaction} {reaction_type} {condition}"
        f" Functional groups information: {fg_prompt}."
    )
    return instruct + " What is the yield of this reaction?"


def read_json(path):
    """Load and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 explicitly, so parsing does not depend on the
    platform's default locale encoding.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def create_iterator(data):
    """Lazily yield the elements of ``data`` in iteration order."""
    yield from data
class json_dataloader(Dataset):
    """Dataset of yield-prediction prompts enumerated from one JSON reaction spec.

    The spec is read with these keys::

        {"Reaction": {"Reactant1": [...], "Reactant2": [...], "Product": [...]},
         "Condition": {"Solvent": [...], "Ligand": [...], "Base": [...],
                       "Additive": [...]},
         "ReactionType": ...}

    ``Reactant1[i]``, ``Reactant2[i]`` and ``Product[i]`` form reaction ``i``.
    Only complete triples are used. Every reaction is combined with every
    solvent, ligand, base and additive, and the catalyst is fixed.

    A flat index is a mixed-radix number. Its digits, least significant first,
    select the reaction, solvent, ligand, base and additive. ``__len__``,
    ``__getitem__`` and ``get_a_idx`` share this decoding, so an ``index``
    returned by ``__getitem__`` maps back to its components via ``get_a_idx``.

    NOTE(review): at this point in the module ``Dataset`` is the class imported
    from ``datasets``, which rebinds the torch import. The list-of-indices branch
    of ``__getitem__`` presumably serves that class's batched ``__getitems__``
    path used by ``DataLoader``. Verify before changing the base class.
    """

    def __init__(self, json):
        self.json = json
        self.reaction = json["Reaction"]
        self.condition = json["Condition"]
        # Fixed catalyst for every sample: palladium(II) acetate.
        self.catalyst = "CC(=O)[O-].CC(=O)[O-].[Pd+2]"
        self.solvent = self.condition["Solvent"]
        self.ligand = self.condition["Ligand"]
        self.base = self.condition["Base"]
        self.reaction_type = json["ReactionType"]
        self.additive = self.condition["Additive"]
        self.reaction_1 = self.reaction["Reactant1"]
        self.reaction_2 = self.reaction["Reactant2"]
        self.product = self.reaction["Product"]

    @staticmethod
    def _index_radices(reaction, condition):
        """Return the digit sizes (reactions, solvents, ligands, bases, additives)."""
        # A reaction needs Reactant1[i], Reactant2[i] and Product[i] together.
        n_reactions = min(
            len(reaction["Reactant1"]),
            len(reaction["Reactant2"]),
            len(reaction["Product"]),
        )
        return (
            n_reactions,
            len(condition["Solvent"]),
            len(condition["Ligand"]),
            len(condition["Base"]),
            len(condition["Additive"]),
        )

    @staticmethod
    def _split_index(idx, radices):
        """Split ``idx`` into mixed-radix digits, least significant first.

        The last digit is also reduced modulo its radix, so out-of-range
        indices wrap around instead of raising.
        """
        digits = []
        for radix in radices:
            digits.append(idx % radix)
            idx //= radix
        return digits

    def __len__(self):
        total = 1
        for radix in self._index_radices(self.reaction, self.condition):
            total *= radix
        return total

    def __getitem__(self, idx):
        # Handle both single index and list of indices
        if isinstance(idx, (list, range)):
            batch_prompts = []
            batch_indices = []
            for i in idx:
                item = self._get_single_item(i)
                batch_prompts.append(item['prompt'])
                batch_indices.append(item['index'])
            return {
                'prompts': batch_prompts,
                'indices': batch_indices
            }
        else:
            return self._get_single_item(idx)

    def _get_single_item(self, idx):
        """Build the prompt for flat index ``idx``; return {'prompt', 'index'}."""
        r_idx, s_idx, l_idx, b_idx, a_idx = self._split_index(
            idx, self._index_radices(self.reaction, self.condition)
        )
        prompt = convert_sample_to_prompt({
            "reactant_1": self.reaction_1[r_idx],
            "reactant_2": self.reaction_2[r_idx],
            "solvent": self.solvent[s_idx],
            "ligand": self.ligand[l_idx],
            "base": self.base[b_idx],
            "additive": self.additive[a_idx],
            "catalyst": self.catalyst,
            "product": self.product[r_idx],
            "reaction_type": self.reaction_type
        })
        return {
            'prompt': prompt,
            'index': idx
        }

    @staticmethod
    def get_a_idx(json, idx):
        """Decode flat index ``idx`` the same way ``_get_single_item`` does.

        Returns:
            (reactant1_idx, reactant2_idx, solvent_idx, ligand_idx, base_idx,
            additive_idx). Reactants are paired, so the first two entries are
            always equal.
        """
        radices = json_dataloader._index_radices(json["Reaction"], json["Condition"])
        r_idx, s_idx, l_idx, b_idx, a_idx = json_dataloader._split_index(idx, radices)
        return r_idx, r_idx, s_idx, l_idx, b_idx, a_idx
def get_device():
    """Return the CUDA device when one is available, else the CPU device."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')

def read_data_from_csv(path):
    """Load a CSV file into a ``Dataset`` (the class imported from ``datasets``).

    Args:
        path: Path to the CSV file.

    Returns:
        A dataset with one column per CSV column and no extra index column.
    """
    data_df = pd.read_csv(path)
    # from_pandas is the DataFrame constructor; from_dict expects a dict of lists.
    # preserve_index=False keeps the DataFrame's row index out of the columns.
    return Dataset.from_pandas(data_df, preserve_index=False)

# LoRA adapter settings applied in main() when --lora is set. The adapter
# weights are restored with a strict load_state_dict, so these values have to
# match the checkpoint being loaded.
lora_config = LoraConfig(
    r=8,  # rank of the low-rank adapter matrices
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
)
class YieldPredLayer(nn.Module):
    """Regression head mapping an embedding to ``output_size`` values.

    The head is two Linear layers applied back to back with no activation
    between them, so it is effectively one affine map. The SiLU in
    ``self.act`` is constructed but not used by ``forward``. The layer layout
    keeps parameter names ``predictor.0.*`` and ``predictor.1.*``.

    Args:
        input_size: Width of the incoming embedding (last dimension of ``x``).
        hidden_size: Width of the intermediate Linear layer.
        output_size: Number of predicted values per input (default 1).
    """

    def __init__(self, input_size, hidden_size, output_size=1):
        super(YieldPredLayer, self).__init__()
        # Not used in forward(); kept so the module's attributes are unchanged.
        self.act = nn.SiLU()
        self.predictor = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, x):
        """Return predictions of shape ``(*x.shape[:-1], output_size)``."""
        return self.predictor(x)
class LlamaWithLoss(nn.Module):
    """Backbone plus pooled regression head, with an MSE training loss.

    Args:
        llama: Backbone called as ``llama(**inputs)``. Its output must expose
            ``last_hidden_state``.
        predictor: Head applied to the pooled embedding.
    """

    def __init__(self, llama, predictor):
        super(LlamaWithLoss, self).__init__()
        self.llama = llama
        self.loss_func = torch.nn.MSELoss()

        self.predictor = predictor

    @staticmethod
    def _pool(last_hidden_state, attention_mask, pooling_method):
        """Reduce ``(batch, seq, hidden)`` states to ``(batch, hidden)``.

        When ``attention_mask`` is given, padded positions are excluded.
        'last_token' takes the last attended position of each row, which is
        correct for both left and right padding. 'mean' averages attended
        positions only. Without a mask, it falls back to the final position
        or a plain mean.
        """
        if pooling_method not in ('mean', 'last_token'):
            raise ValueError(f"Unknown pooling_method: {pooling_method!r}")
        if attention_mask is None:
            if pooling_method == 'mean':
                return last_hidden_state.mean(dim=1)
            return last_hidden_state[:, -1, :]

        mask = attention_mask.to(last_hidden_state.device)
        if pooling_method == 'mean':
            weights = mask.unsqueeze(-1).to(last_hidden_state.dtype)
            summed = (last_hidden_state * weights).sum(dim=1)
            counts = weights.sum(dim=1).clamp(min=1)  # avoid 0/0 on all-padding rows
            return summed / counts

        positions = torch.arange(mask.shape[1], device=mask.device)
        last_pos = (positions * (mask != 0)).max(dim=1).values
        rows = torch.arange(mask.shape[0], device=mask.device)
        return last_hidden_state[rows, last_pos]

    def forward(self, inputs, y, pooling_method='last_token', return_loss=True):
        """Run the backbone and head.

        Args:
            inputs: Mapping of backbone keyword inputs (e.g. a tokenizer's
                BatchEncoding). ``attention_mask`` is used for pooling if present.
            y: Regression targets, flattened before the loss. Unused when
                ``return_loss`` is False.
            pooling_method: 'last_token' (default) or 'mean'.
            return_loss: Return ``(embeddings, mse_loss)`` when True, otherwise
                ``(embeddings, predictions)``.

        Raises:
            ValueError: If ``pooling_method`` is not recognised.
        """
        # Only the final layer is used, so all hidden states are not requested.
        outputs = self.llama(**inputs)
        embeddings = self._pool(
            outputs.last_hidden_state, inputs.get('attention_mask'), pooling_method
        )
        pred = self.predictor(embeddings)
        if return_loss:
            loss = self.loss_func(pred.view(-1), y.view(-1))
            return embeddings, loss
        return embeddings, pred
    
def _save_results(results, output_path, tag):
    """Write ``(idx, yield)`` rows to ``<output_path>/results_<tag>.csv``."""
    results_df = pd.DataFrame(results, columns=["idx", "yield"])
    results_df.to_csv(os.path.join(output_path, f"results_{tag}.csv"), index=False)


def main(args):
    """Predict a yield for every prompt enumerated from ``args.json_path``.

    Loads the backbone (LoRA-wrapped when ``args.lora`` is truthy) and the
    regression head. Weights are restored from the DeepSpeed ZeRO checkpoint in
    ``args.checkpoint_dir``. Predictions go to CSV files in ``args.output_path``,
    one row per sample: the dataset index and the predicted yield. A file is
    written at every 1000th batch, and once more at the end for any rows left
    over.
    """
    device = get_device()
    output_path = args.output_path
    os.makedirs(output_path, exist_ok=True)

    # Construct the model and load the checkpoint.
    model = AutoModel.from_pretrained(args.pretrained_model_path, device_map='auto')
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path)
    if tokenizer.pad_token is None:
        # Batched tokenization with padding requires a pad token; fall back to EOS.
        tokenizer.pad_token = tokenizer.eos_token
    # Size the head from the backbone; fall back to 4096 if the config lacks it.
    hidden_size = getattr(model.config, "hidden_size", 4096)
    if args.lora:
        model = get_peft_model(model, lora_config)
    state_dict = get_fp32_state_dict_from_zero_checkpoint(args.checkpoint_dir)
    # Built on CPU: the whole wrapper is moved to CPU for loading, then to `device`.
    predictor = YieldPredLayer(hidden_size, 1024, 1)
    model = LlamaWithLoss(model, predictor)
    model = model.cpu()
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    print("Model loaded")

    # Load the dataset. Each row records its dataset index, so shuffling adds
    # nothing at inference; a fixed order keeps output files reproducible.
    data = read_json(args.json_path)
    dataset = json_dataloader(data)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
    print("Dataset loaded")

    # Inference
    results = []
    last_batch = None
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader)):
            prompts, idx = batch['prompts'], batch['indices']
            inputs = tokenizer(prompts, max_length=3000, padding='longest', truncation=True, return_tensors="pt").to(device)
            _, pred = model(inputs, None, return_loss=False)
            # One row per sample with plain numbers, not tensor/array reprs.
            indices = torch.as_tensor(idx).view(-1).tolist()
            yields = pred.detach().float().view(-1).cpu().tolist()
            results.extend(zip(indices, yields))
            last_batch = i
            if i % 1000 == 0:
                print(f"Processed {i} batches")
                _save_results(results, output_path, i)
                results = []

    # Rows accumulated after the last periodic save would otherwise be lost.
    if results:
        _save_results(results, output_path, last_batch)


    
if __name__ == "__main__":
    # Paths that main() uses unconditionally are required, so a missing flag
    # fails at argument parsing rather than deep inside main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_path", type=str, default='bert-base-uncased', help="The path to the pretrained model")
    parser.add_argument("--lora", type=int, default=0, help="Use LoRA")
    parser.add_argument("--checkpoint_dir", type=str, required=True, help="The path to the checkpoint directory")
    parser.add_argument("--json_path", type=str, required=True, help="The path to the json file")
    parser.add_argument("--output_path", type=str, required=True, help="The path to the output directory")
    args = parser.parse_args()
    main(args)
   