import os
import sys
import numpy as np
import yaml
from similarity_clustering import cluster
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import QED
from rdkit.Chem import DataStructs
from scipy.stats import ttest_ind
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score

from rdkit.Chem import RDConfig
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer
        
def _read_llm_log(run_name):
    """Collect LLM-generated ligands and count LLM error lines from a run's log.

    Reads ``single_objective/logs/{run_name}.txt`` line by line and stops at the
    first line containing ``"1000/1000"`` (presumably the run's final progress
    marker — confirm against the logger).

    Returns:
        ``(llm_ligands, num_errors)``: a set of the second whitespace-separated
        token of every line containing ``"LLM-GENERATED:"`` (the generated
        SMILES), and the number of lines containing ``"NUM LLM ERRORS"``.
    """
    llm_ligands = set()  # set: membership is tested once per YAML entry later
    num_errors = 0
    with open(f"single_objective/logs/{run_name}.txt", 'r') as log_file:
        for line in log_file:
            if "1000/1000" in line:
                break
            if "LLM-GENERATED:" in line:
                llm_ligands.add(line.split()[1].strip())
            if "NUM LLM ERRORS" in line:
                num_errors += 1
    return llm_ligands, num_errors


def _load_results(run_name, limit, llm_only, llm_ligands):
    """Split a run's YAML results into the reference set and scored candidates.

    Each entry of ``single_objective/results/{run_name}.yaml`` maps a SMILES
    string to ``values``, where ``values[0]`` is parsed as a float score and
    ``values[1]`` as an int index (presumably the generation / oracle-call
    number — confirm against whatever writes this file).

    Returns:
        ``(cmet, ligands)``:

        * ``cmet`` lists the SMILES whose index is <= 120. They are used as the
          reference set for similarity.
        * ``ligands`` maps every other kept SMILES to ``-float(values[0])``.

        The following entries are skipped:

        * entries whose index is > 250, when ``limit`` is true;
        * candidates whose score is exactly 0;
        * candidates not in ``llm_ligands``, when ``llm_only`` is true.
    """
    cmet = []
    ligands = {}
    with open(f"single_objective/results/{run_name}.yaml", 'r') as file:
        data = yaml.safe_load(file)
    for ligand, values in data.items():
        index = int(values[1])
        if limit and index > 250:
            continue
        if index <= 120:
            cmet.append(ligand)
        # `not llm_only` (rather than `llm_only is False`) keeps this test
        # consistent with the truthiness check that decides whether the log
        # was read at all.
        elif (not llm_only or ligand in llm_ligands) and float(values[0]) != 0:
            ligands[ligand] = -float(values[0])
    return cmet, ligands


def analyze_results(run_name, limit=False, llm_only=True):
    """Print summary statistics for one optimisation run and return its top scores.

    Candidate scores are compared as "lower is better": candidates are sorted
    ascending, "BEST" is the minimum, and a candidate beats the threshold when
    its score is below -11.

    The function prints, in order:

    * the sizes of the reference and candidate sets;
    * the mean of the top-10 scores, without and with clustering;
    * the best clustered score and the std of the top-10 clustered scores;
    * each of the top-10 clustered SMILES, followed by their mean QED, mean SA
      score, QED std, and mean nearest-neighbour Tanimoto similarity to the
      reference set;
    * the LLM error count and how many top-10 scores beat the threshold;
    * the mean and std of the top-10 scores whose max similarity is < 0.5.

    Args:
        run_name: Base name of ``single_objective/logs/<run_name>.txt`` and
            ``single_objective/results/<run_name>.yaml``.
        limit: If true, ignore result entries whose index exceeds 250.
        limit: If true, ignore result entries whose index exceeds 250.
        llm_only: If true, only ligands reported as ``LLM-GENERATED`` in the
            log are scored, and LLM error lines are counted. If false, the log
            is not read.

    Returns:
        The scores of up to 10 clustered candidates, best (lowest) first.

    Raises:
        OSError: if a required log or results file cannot be opened.
    """
    print("RUN: " + run_name)

    if llm_only:
        llm_ligands, num_errors = _read_llm_log(run_name)
    else:
        llm_ligands, num_errors = set(), 0

    cmet, ligands = _load_results(run_name, limit, llm_only, llm_ligands)

    sorted_ligands = sorted(ligands, key=ligands.get)
    print(len(cmet))
    print(len(sorted_ligands))
    best_10 = [ligands[smiles] for smiles in sorted_ligands[:10]]

    # NOTE(review): assumes `cluster` returns a subset of its input SMILES;
    # any other item would make `ligands.get` return None and break the sort.
    clustered = sorted(cluster(sorted_ligands), key=ligands.get)
    top_clustered = clustered[:10]
    best_10_cluster = [ligands[smiles] for smiles in top_clustered]

    print("AVG TOP TEN: " + str(np.mean(best_10)))
    print("AVG TOP TEN (CLUSTERED): " + str(np.mean(best_10_cluster)))
    # min() raises on an empty list; report nan as np.mean/np.std do.
    print("BEST: " + str(min(best_10_cluster, default=float("nan"))))
    print("STDEV TOP 10 (CLUSTERED): " + str(np.std(best_10_cluster)))
    print("BEST 10 LIGANDS (CLUSTERED):")

    threshold = -11          # score cut-off for "NUM BETTER THAN THRESHOLD"
    unique_sim_cutoff = 0.5  # max similarity below which a hit counts as "unique"
    qed = []
    sa = []
    sim = []
    unique = []
    num_better_than_threshold = 0

    # Neither the generator nor the reference fingerprints depend on the
    # candidate, so build them once instead of once per top-10 ligand.
    morgan = AllChem.GetMorganGenerator(radius=2, fpSize=512)
    cmet_fingerprints = [
        morgan.GetFingerprint(Chem.MolFromSmiles(smiles)) for smiles in cmet
    ]

    for ligand in top_clustered:
        mol = Chem.MolFromSmiles(ligand)
        qed.append(QED.qed(mol))
        sa.append(sascorer.calculateScore(mol))

        fingerprint = morgan.GetFingerprint(mol)
        # Highest Tanimoto similarity to any reference ligand; 0 if there are none.
        max_sim = max(
            (DataStructs.TanimotoSimilarity(fingerprint, ref) for ref in cmet_fingerprints),
            default=0,
        )
        sim.append(max_sim)

        if ligands[ligand] < threshold:
            num_better_than_threshold += 1
        if max_sim < unique_sim_cutoff:
            unique.append(ligands[ligand])
        print(ligand)

    print("AVG QED (clustered): " + str(np.mean(qed)))
    print("AVG SA (clustered): " + str(np.mean(sa)))
    print("STDEV QED: " + str(np.std(qed)))
    print("AVG MAX SIM: " + str(np.mean(sim)))
    print("NUMBER OF LLM ERRORS: " + str(num_errors))
    print("NUM BETTER THAN THRESHOLD: " + str(num_better_than_threshold))
    best_unique = sorted(unique)[:10]
    print("UNIQUE GENERATION MEAN: " + str(np.mean(best_unique)))
    print("UNIQUE GENERATION STD: " + str(np.std(best_unique)))

    return best_10_cluster

if __name__ == "__main__":
    # Compare the top-10 clustered scores of two runs with Welch's t-test
    # (equal_var=False drops the equal-variance assumption) and print the
    # two-sided p-value. Guarded so importing this module does not re-run
    # the whole analysis.
    values1 = analyze_results("GPT-4_c-met_docking", limit=False, llm_only=False)
    values2 = analyze_results("custom_c-met_original_prompt_30k", limit=False, llm_only=True)
    _, p = ttest_ind(values1, values2, alternative="two-sided", equal_var=False)
    print(p)
