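'''
Evaluate model predictions for the TOMG-Bench open-generation tasks
(MolCustom, MolEdit, MolOpt).

The script reads the benchmark ``test.csv`` of one subtask together with the
prediction CSV for that subtask (column ``outputs``) and logs the resulting
metrics to the console and to the file given by ``--log_file``.
'''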
import argparse
import logging
import os

import pandas as pd
from tqdm import tqdm

from utils.evaluation import mol_prop, calculate_novelty, calculate_similarity

# Only the first SLICE_NUM rows of the benchmark file are evaluated.
SLICE_NUM = 100

parser = argparse.ArgumentParser()
# NOTE(review): --name is parsed but not read anywhere in the visible script.
parser.add_argument("--name", type=str, default="llama3.1-8B",
                    help="Model name.")

# dataset settings
parser.add_argument("--benchmark", type=str, default="open_generation",
                    help="Benchmark family; selects the benchmark directory.")
parser.add_argument("--task", type=str, default="MolCustom",
                    help="Task inside the benchmark (MolCustom, MolEdit, MolOpt).")
parser.add_argument("--subtask", type=str, default="AtomNum",
                    help="Subtask to evaluate; also the prediction file stem.")

parser.add_argument("--output_dir", type=str, default="./predictions/",
                    help="Directory containing the <subtask>.csv prediction file.")
# Novelty is computed by default.  ``--calc_novelty`` is kept so existing
# command lines continue to work; with ``store_true`` and ``default=True`` it
# could never be disabled, so ``--no_calc_novelty`` provides the off switch.
parser.add_argument("--calc_novelty", action="store_true", default=True,
                    help="Compute novelty for MolCustom subtasks (default).")
parser.add_argument("--no_calc_novelty", dest="calc_novelty",
                    action="store_false",
                    help="Skip the novelty computation.")
parser.add_argument("--log_file", type=str, required=True,
                    help="File that receives a copy of the evaluation log.")

args = parser.parse_args()

log_file = args.log_file

# Console logging at INFO level via the root logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Additionally save the log to the requested file (default message format).
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.INFO)
logger.addHandler(file_handler)

# Benchmark rows (ground truth / instructions) for the selected subtask.
raw_file = "../datasets/TOMG-Bench/benchmarks/{}/{}/{}/test.csv".format(args.benchmark, args.task, args.subtask)
# Prediction file; os.path.join keeps the path correct whether or not
# --output_dir ends with a path separator.
target_file = os.path.join(args.output_dir, args.subtask + ".csv")

data = pd.read_csv(raw_file)
data = data[:SLICE_NUM]
try:
    target = pd.read_csv(target_file)
except Exception as err:
    # The default parser can fail on irregular prediction files; retry once
    # with the more tolerant pure-Python engine.  Unlike a bare ``except``,
    # this lets KeyboardInterrupt / SystemExit propagate.  If the retry fails
    # as well (e.g. the file does not exist) its exception is raised.
    logger.warning(
        "Default CSV parser failed on %s (%s); retrying with engine='python'",
        target_file,
        err,
    )
    target = pd.read_csv(target_file, engine='python')

# ---------------------------------------------------------------------------
# Small numeric helpers
# ---------------------------------------------------------------------------
def _mean(values):
    """Return the arithmetic mean of ``values``, or 0.0 when it is empty.

    An empty collection occurs when no prediction is valid (or no row was
    evaluated); reporting 0.0 avoids a ZeroDivisionError that would abort the
    run before the remaining metrics are logged.  ``len`` is used rather than
    truthiness so array-like inputs are accepted as well.
    """
    count = len(values)
    return sum(values) / count if count else 0.0


def _fraction(part, whole):
    """Return ``part / whole``, or 0.0 when ``whole`` is zero."""
    return part / whole if whole else 0.0


def _passed_indices(outcomes):
    """Return the positions whose outcome equals 1."""
    return [i for i, outcome in enumerate(outcomes) if outcome == 1]


# ---------------------------------------------------------------------------
# MolCustom: the output must match the requested counts
# ---------------------------------------------------------------------------
_ATOM_TYPES = ['carbon', 'oxygen', 'nitrogen', 'sulfur', 'fluorine', 'chlorine', 'bromine', 'iodine', 'phosphorus', 'boron', 'silicon', 'selenium', 'tellurium', 'arsenic', 'antimony', 'bismuth', 'polonium']
_FUNCTIONAL_GROUPS = ['benzene rings', 'hydroxyl', 'anhydride', 'aldehyde', 'ketone', 'carboxyl', 'ester', 'amide', 'amine', 'nitro', 'halo', 'nitrile', 'thiol', 'sulfide', 'disulfide', 'sulfoxide', 'sulfone', 'borane']
_BOND_TYPES = ['single', 'double', 'triple', 'rotatable', 'aromatic']


def _atom_prop(atom):
    """Map an atom column name to its ``mol_prop`` property key."""
    return "num_" + atom


def _functional_group_prop(group):
    """Map a functional-group column name to its ``mol_prop`` property key."""
    if group == "benzene rings":
        return "num_benzene_ring"
    return "num_" + group


def _bond_prop(bond):
    """Map a bond column name to its ``mol_prop`` property key."""
    if bond == "rotatable":
        return "rot_bonds"
    return "num_" + bond + "_bonds"


# subtask -> (benchmark columns, column -> property key, skip zero targets?)
# For BondNum a target count of 0 means "not constrained" and is skipped.
_MOL_CUSTOM_SPECS = {
    "AtomNum": (_ATOM_TYPES, _atom_prop, False),
    "FunctionalGroup": (_FUNCTIONAL_GROUPS, _functional_group_prop, False),
    "BondNum": (_BOND_TYPES, _bond_prop, True),
}


def _evaluate_mol_custom(columns, prop_key, skip_zero):
    """Score a MolCustom subtask and log accuracy, novelty and validity.

    A row counts as correct when its output passes the ``"validity"`` check
    and, for every column in ``columns``, ``mol_prop`` of the output equals
    the integer stored in the benchmark row.  When ``skip_zero`` is true,
    columns whose benchmark value is 0 are not checked.
    """
    flags = []
    valid_molecules = []

    # use tqdm to show the progress
    for idx in tqdm(range(len(data))):
        output = target["outputs"][idx]
        if not mol_prop(output, "validity"):
            flags.append(0)
            continue

        valid_molecules.append(output)
        flag = 1
        for column in columns:
            expected = int(data[column][idx])
            if skip_zero and expected == 0:
                continue
            if mol_prop(output, prop_key(column)) != expected:
                flag = 0
                break
        flags.append(flag)

    logger.info("=== Evaluating %s ===", args.subtask)
    logger.info("Accuracy: %f", _mean(flags))
    # indices of the rows that were answered correctly
    logger.info("index_true: %s", _passed_indices(flags))

    if args.calc_novelty:
        novelties = calculate_novelty(valid_molecules)
        logger.info("Novelty: %f", _mean(novelties))
    # NOTE: the "Validty" spelling is kept so existing log parsers keep working.
    logger.info("Validty: %f", _fraction(len(valid_molecules), len(flags)))


# ---------------------------------------------------------------------------
# MolEdit / MolOpt: the output is compared with the source molecule
# ---------------------------------------------------------------------------
def _log_edit_metrics(successed, similarities, valid_molecules):
    """Log success rate, similarity, validity and the successful indices."""
    logger.info("=== Evaluating %s ===", args.subtask)
    logger.info("Success Rate: %f", _mean(successed))
    # Similarity is averaged over valid outputs only; 0.0 when there are none.
    logger.info("Similarity: %f", _mean(similarities))
    # NOTE: the "Validty" spelling is kept so existing log parsers keep working.
    logger.info("Validty: %f", _fraction(len(valid_molecules), len(data)))
    # indices of the rows that were edited successfully
    logger.info("index_true: %s", _passed_indices(successed))


def _group_prop(group):
    """Map a group name from the benchmark file to its property key."""
    if group == "benzene ring":
        group = "benzene_ring"
    return "num_" + group


# subtask -> [(benchmark column holding the group name, required count change)]
_MOL_EDIT_SPECS = {
    "AddComponent": [("added_group", 1)],
    "DelComponent": [("removed_group", -1)],
    # removal is checked first, then the addition
    "SubComponent": [("removed_group", -1), ("added_group", 1)],
}


def _evaluate_mol_edit(changes):
    """Score a MolEdit subtask.

    ``changes`` lists ``(column, delta)`` pairs.  A valid output succeeds when,
    for every pair, its count of the group named in ``column`` equals the
    source molecule's count plus ``delta``.
    """
    valid_molecules = []
    successed = []
    similarities = []

    for idx in tqdm(range(len(data))):
        raw = data["molecule"][idx]
        requested = [(data[column][idx], delta) for column, delta in changes]
        target_mol = target["outputs"][idx]

        if not mol_prop(target_mol, "validity"):
            successed.append(0)
            continue

        valid_molecules.append(target_mol)
        # all() short-circuits on the first unmet requirement.
        achieved = all(
            mol_prop(target_mol, _group_prop(group)) == mol_prop(raw, _group_prop(group)) + delta
            for group, delta in requested
        )
        successed.append(1 if achieved else 0)
        similarities.append(calculate_similarity(raw, target_mol))

    _log_edit_metrics(successed, similarities, valid_molecules)


# subtask -> property key passed to mol_prop
_MOL_OPT_PROPS = {
    "LogP": "logP",
    "MR": "MR",
    "QED": "qed",
}


def _evaluate_mol_opt(prop):
    """Score a MolOpt subtask for the property ``prop``.

    The instruction text decides the direction: if it contains "lower" or
    "decrease" the property must become strictly smaller than the source
    molecule's value, otherwise strictly larger.
    """
    valid_molecules = []
    successed = []
    similarities = []

    for idx in tqdm(range(len(data))):
        raw = data["molecule"][idx]
        target_mol = target["outputs"][idx]
        instruction = data["Instruction"][idx]

        if not mol_prop(target_mol, "validity"):
            successed.append(0)
            continue

        valid_molecules.append(target_mol)
        similarities.append(calculate_similarity(raw, target_mol))

        wants_lower = "lower" in instruction or "decrease" in instruction
        new_value = mol_prop(target_mol, prop)
        old_value = mol_prop(raw, prop)
        if wants_lower:
            improved = new_value < old_value
        else:
            improved = new_value > old_value
        successed.append(1 if improved else 0)

    _log_edit_metrics(successed, similarities, valid_molecules)


# ---------------------------------------------------------------------------
# Dispatch on benchmark / task / subtask
# ---------------------------------------------------------------------------
if args.benchmark == "open_generation":
    # Task / subtask names not listed in the tables are left unevaluated.
    if args.task == "MolCustom":
        if args.subtask in _MOL_CUSTOM_SPECS:
            spec_columns, spec_prop_key, spec_skip_zero = _MOL_CUSTOM_SPECS[args.subtask]
            _evaluate_mol_custom(spec_columns, spec_prop_key, spec_skip_zero)
    elif args.task == "MolEdit":
        if args.subtask in _MOL_EDIT_SPECS:
            _evaluate_mol_edit(_MOL_EDIT_SPECS[args.subtask])
    elif args.task == "MolOpt":
        if args.subtask in _MOL_OPT_PROPS:
            _evaluate_mol_opt(_MOL_OPT_PROPS[args.subtask])
elif args.benchmark == "targeted_generation":
    # No evaluation is performed for this benchmark here.
    pass
else:
    raise ValueError("Invalid Benchmark Type")