import os 
import numpy as np
import pandas as pd
import torch
import argparse
import sys
import os
from tqdm import tqdm

from rdkit.Chem.rdchem import ResonanceMolSupplier
from foundation_models import (
    MolFormerRegressor,
    T5Regressor,
)
from foundation_models import (
    get_molformer_tokenizer,
    get_t5_tokenizer,
)
from foundation_models.smi_ted import load_smi_ted
from problems.data_processor import (
    RedoxDataProcessor,
    SolvationDataProcessor,
    KinaseDockingDataProcessor,
    LaserEmitterDataProcessor,
    PhotovoltaicsPCEDataProcessor,
    PhotoswitchDataProcessor,
    AmpCDockingDataProcessor,
    D4DockingDataProcessor,
)
from problems.prompting import PromptBuilder
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
from rdkit.Chem import AllChem


from utils import helpers
from utils.configs import LLMFeatureType

from mordred import Calculator, descriptors, error

# Command-line options:
#   --feature_type       which molecular representation to compute;
#   --feature_reduction  and --prompt_type are only read by the LLM/MolFormer
#                        branch further below;
#   --problem            which benchmark dataset to featurize.
_FEATURE_TYPES = [
    "fingerprints",
    "molformer",
    "t5-base-chem",
    "mordred",
    "degree_of_conjugation",
    "force_field",
    "dft",
]
_PROBLEMS = ["redox-mer", "solvation", "kinase", "laser", "pce", "photoswitch", "ampc", "d4"]

parser = argparse.ArgumentParser()
parser.add_argument("--feature_type", choices=_FEATURE_TYPES, default="t5-base-chem")
parser.add_argument("--feature_reduction", choices=["default", "average"], default="average")
parser.add_argument(
    "--prompt_type",
    choices=["single-number", "just-smiles", "naive", "completion"],
    default="just-smiles",
)
parser.add_argument("--problem", choices=_PROBLEMS, default="redox-mer")
args = parser.parse_args()

# Per-problem dataset spec: (CSV path, objective column, whether the objective
# is maximized). Note that "solvation" reuses the redox-mer CSV with a
# different objective column. Every dataset keeps molecules in "SMILES".
_PROBLEM_SPECS = {
    "redox-mer": ("data/redox_mer_with_iupac.csv.gz", "Ered", False),
    "solvation": ("data/redox_mer_with_iupac.csv.gz", "Gsol", False),
    "kinase": ("data/enamine10k.csv.gz", "score", False),
    "laser": ("data/laser_multi10k.csv.gz", "Fluorescence Oscillator Strength", True),
    "pce": ("data/photovoltaics_pce10k.csv.gz", "pce", True),
    "photoswitch": ("data/photoswitches.csv.gz", "Pi-Pi* Transition Wavelength", True),
    "ampc": ("data/Zinc_AmpC_Docking_filtered.csv.gz", "dockscore", False),
    "d4": ("data/Zinc_D4_Docking_filtered.csv.gz", "dockscore", False),
}

_problem_spec = _PROBLEM_SPECS.get(args.problem)
if _problem_spec is None:
    print("Invalid test function!")
    sys.exit(1)

_csv_path, OBJ_COL, MAXIMIZATION = _problem_spec
dataset = pd.read_csv(_csv_path)
SMILES_COL = "SMILES"

if args.feature_type == "fingerprints":
    # Morgan fingerprint bit vectors (radius 3, 1024 bits); one float tensor
    # per molecule, in dataset order.
    morgan_gen = rdFingerprintGenerator.GetMorganGenerator(radius=3, fpSize=1024)
    features = []
    for smiles in tqdm(dataset[SMILES_COL], file=sys.stdout, dynamic_ncols=True, mininterval=0):
        fp_array = morgan_gen.GetFingerprintAsNumPy(Chem.MolFromSmiles(smiles))
        features.append(torch.tensor(fp_array).float())

    # Objective values, transformed by helpers.y_transform, one (1,) tensor each.
    transformed_y = helpers.y_transform(torch.tensor(dataset[OBJ_COL].to_numpy()), MAXIMIZATION)
    targets = list(transformed_y.unsqueeze(-1).float())
elif args.feature_type == "degree_of_conjugation":
    def are_neighbor_bonds_conjugated(
        suppl, conj_dict: dict, bond, traversed_bonds=[]
    ) -> dict:
        bond_idx = bond.GetIdx()
        conj_grp_idx = suppl.GetBondConjGrpIdx(bond_idx)
        if (
            bond.GetIdx() not in traversed_bonds and bond.GetIsConjugated()
        ):  # does not add to conjugated dictionary if bond is not conjugated
            traversed_bonds.append(bond.GetIdx())
            if conj_grp_idx in conj_dict:
                conj_dict[conj_grp_idx] += 1
            else:
                conj_dict[conj_grp_idx] = 1
        else:
            traversed_bonds.append(bond.GetIdx())
        end_atom = bond.GetEndAtom()
        next_bonds = end_atom.GetBonds()
        # remove the bond that was used to get to the next bond
        next_bonds = [bond for bond in next_bonds if bond.GetIdx() not in traversed_bonds]

        if len(next_bonds) == 0:
            return conj_dict
        else:
            for bond in next_bonds:
                conj_dict = are_neighbor_bonds_conjugated(
                    suppl, conj_dict, bond, traversed_bonds
                )
            return conj_dict
        
    def max_num_of_bonds_with_conjugation(smile):
        mol = Chem.MolFromSmiles(smile)
        suppl = ResonanceMolSupplier(mol)
        start_bond = mol.GetBondWithIdx(0)
        conj_dict = {}
        conj_dict = are_neighbor_bonds_conjugated(
            suppl, conj_dict, start_bond, traversed_bonds=[]
        )
        try:
            max_degree = max(conj_dict.values())
        except:
            max_degree = 0

        return max_degree
    
    features =[
        torch.tensor(max_num_of_bonds_with_conjugation(smile)).float().reshape(-1)
        for smile in tqdm(dataset[SMILES_COL], file=sys.stdout, dynamic_ncols=True, mininterval=0)
    ]
    
    targets = list(
        helpers.y_transform(torch.tensor(dataset[OBJ_COL].to_numpy()), MAXIMIZATION)
        .unsqueeze(-1)
        .float()
    )
elif args.feature_type == "force_field":
    def get_force_field_features(smiles):

        mol = Chem.MolFromSmiles(smiles)
        mol = Chem.AddHs(mol)

        # Generate a 3D conformer
        # The argument is the number of conformers to generate
        status = AllChem.EmbedMolecule(mol, randomSeed=42, maxAttempts=100, useRandomCoords=True, enforceChirality=False)
        
        if status != 0:
            print(smiles)
            print("Failed to generate conformer.")
            
        # Assign stereochemistry
        Chem.AssignStereochemistry(mol, force=True, cleanIt=True)

        # Optionally, you can optimize the geometry
        try:
            AllChem.MMFFOptimizeMolecule(mol, maxIters=1000)
        except:
            print("Failed to optimize molecule.")
    
        # Correctly get molecule properties
        mol_properties = AllChem.MMFFGetMoleculeProperties(mol)
        
        if mol_properties is None:
            print("Failed to get molecule properties.")
            force_field = AllChem.UFFGetMoleculeForceField(mol) 
        else:
            # Create the force field
            force_field = AllChem.MMFFGetMoleculeForceField(mol, mol_properties)

        # Opimize 
        force_field.Minimize() 
        energy = force_field.CalcEnergy()
        #gradients = force_field.CalcGrad()
        return energy 
    features = [
        torch.tensor(
            [get_force_field_features(mol)]
        ).float()
        for mol in tqdm(dataset[SMILES_COL], file=sys.stdout, dynamic_ncols=True, mininterval=0)
    ]
    targets = list(
        helpers.y_transform(torch.tensor(dataset[OBJ_COL].to_numpy()), MAXIMIZATION)
        .unsqueeze(-1)
        .float()
    )
# force field failed on: all
elif args.feature_type == "mordred":
    # Initialize Mordred calculator
    calc = Calculator(descriptors, ignore_3D=True)

    # Process molecules
    features = []
    for mol_smiles in tqdm(dataset[SMILES_COL], file=sys.stdout, dynamic_ncols=True, mininterval=0):
        mol = Chem.MolFromSmiles(mol_smiles)
        descriptors = []
        for desc in calc(mol):
            if isinstance(desc, error.Error):
                descriptors.append(np.nan)  # Replace errors with NaN
            else:
                descriptors.append(desc)
        features.append(descriptors)

    features = np.array(features).astype(np.float64)

    # Convert to torch tensors, dropping NaN-containing columns
    valid_columns = ~np.isnan(features).any(axis=0)  # Mask for columns without NaN
    features = features[:, valid_columns]
    features = [torch.tensor(feature).float() for feature in features]
    
    targets = list(
        helpers.y_transform(torch.tensor(dataset[OBJ_COL].to_numpy()), MAXIMIZATION)
        .unsqueeze(-1)
        .float()
    )
elif args.feature_type == "dft":
    dft_descriptors = ["alpha", "cv", "g298", "gap", "h298", "homo", "lumo", "mu", "r2", "u0", "u298", "zpve"]
    features = []
    for dft in tqdm(dft_descriptors, file=sys.stdout, dynamic_ncols=True, mininterval=0):
        model = load_smi_ted(
            folder='smi_ted',
            ckpt_filename=f'smi-ted-Light-Finetune_seed0_qm9_{dft}.pt',
        ).cuda()
        dft_feature = []
        with torch.no_grad():
            for smile in tqdm(dataset[SMILES_COL].to_list(), file=sys.stdout, dynamic_ncols=True, mininterval=0):
                embedding = model.extract_embeddings(smile).cuda().reshape(-1)
                outputs = model.net(embedding).detach().float() # (16, 1)
                dft_feature.append(outputs.cpu())
        features.append(torch.concatenate(dft_feature, dim=0)) # (n_samples)
    features = torch.stack(features, dim=-1) # (n_samples, n_dfts)
    features = [feature for feature in features]

    targets = list(
        helpers.y_transform(torch.tensor(dataset[OBJ_COL].to_numpy()), MAXIMIZATION)
        .unsqueeze(-1)
        .float()
    )
else:  # LLM & MolFormer features
    if args.feature_type == "molformer":
        tokenizer = get_molformer_tokenizer()
    elif args.feature_type == "t5-base-chem":
        foundation_model_real = "GT4SD/multitask-text-and-chemistry-t5-base-augm"
        tokenizer = get_t5_tokenizer(foundation_model_real)
    
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        # print(tokenizer.pad_token, tokenizer.pad_token_id)

    DEFAULT_REDUCTIONS = {
        "t5-base-chem": LLMFeatureType.LAST_TOKEN,
        "molformer": None,
    }
    reduction = (
        DEFAULT_REDUCTIONS[args.feature_type]
        if args.feature_reduction == "default"
        else LLMFeatureType.AVERAGE
    )

    if args.feature_type == "molformer":
        llm_feat_extractor = MolFormerRegressor(tokenizer)
    elif args.feature_type == "t5-base-chem":
        llm_feat_extractor = T5Regressor(
            kind="GT4SD/multitask-text-and-chemistry-t5-base-augm",
            tokenizer=tokenizer,
            reduction=reduction,
        )
    else:
        raise NotImplementedError  # TO-DO!

    llm_feat_extractor.cuda()
    llm_feat_extractor.eval()
    llm_feat_extractor.freeze_params()

    prompt_builder = PromptBuilder(kind=args.prompt_type)
    DATA_PROCESSORS = {
        "redox-mer": RedoxDataProcessor,
        "solvation": SolvationDataProcessor,
        "kinase": KinaseDockingDataProcessor,
        "laser": LaserEmitterDataProcessor,
        "pce": PhotovoltaicsPCEDataProcessor,
        "photoswitch": PhotoswitchDataProcessor,
        "ampc": AmpCDockingDataProcessor, #
        "d4": D4DockingDataProcessor,
    }
    data_processor = DATA_PROCESSORS[args.problem](prompt_builder, tokenizer)
    append_eos = args.feature_type != "molformer" and ("t5" not in args.feature_type)
    dataloader = data_processor.get_dataloader(
        dataset, shuffle=False, append_eos=append_eos
    )

    features, targets = [], []
    for data in tqdm(dataloader, file=sys.stdout, dynamic_ncols=True, mininterval=0):
        with torch.no_grad():
            feat = llm_feat_extractor.forward_features(data)

        features += list(feat.cpu())
        targets += list(helpers.y_transform(data["labels"], MAXIMIZATION))

# Save features and targets to data/cache/<problem>/<feature_type>_{feats,targets}.bin
cache_path = f"data/cache/{args.problem}/"
# exist_ok avoids the exists()/makedirs() race (FileExistsError) when several
# feature runs for the same problem start concurrently.
os.makedirs(cache_path, exist_ok=True)

torch.save(features, os.path.join(cache_path, f"{args.feature_type}_feats.bin"))
torch.save(targets, os.path.join(cache_path, f"{args.feature_type}_targets.bin"))
