import math
import os
import pickle
import time
from pathlib import Path

import fsspec
import hydra
import lightning as L
import omegaconf
import rich.syntax
import rich.tree
import torch
from rdkit import Chem
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import shutil

import src.graph_diffusion as graph_diffusion
from src import dataloader
from src.constants import COORDS_STD
from src.utils import builder_utils, indexing_utils, misc_utils, spatial_utils
from shepherd_score.score.gaussian_overlap import get_overlap, get_overlap_batch
from shepherd_score.alignment import crippen_align
from rdkit.Chem.rdShapeHelpers import ShapeTverskyIndex, ShapeTanimotoDist, ShapeProtrudeDist

# Custom OmegaConf resolvers usable from the YAML configs as ${name:args}.
# NOTE(review): the "eval" resolver hands config strings to Python's built-in
# eval(); only load configuration files from trusted sources.
for _resolver_name, _resolver_fn in (
    ("cwd", os.getcwd),
    ("device_count", torch.cuda.device_count),
    ("eval", eval),
    # Ceiling division for non-negative integers: (x + y - 1) // y.
    ("div_up", lambda x, y: (x + y - 1) // y),
):
    omegaconf.OmegaConf.register_new_resolver(_resolver_name, _resolver_fn)

## profiling
# NOTE(review): PyTorchProfiler is imported here but is not referenced in the
# visible part of this file.
from lightning.pytorch.profilers import PyTorchProfiler

# Sets TorchDynamo's suppress_errors flag for the whole process; presumably this
# makes compilation failures fall back to eager execution instead of raising --
# confirm against the PyTorch documentation.
torch._dynamo.config.suppress_errors = True


def _load_from_checkpoint(config: omegaconf.DictConfig):
    """Build the diffusion model used for evaluation/sampling.

    Args:
      config (DictConfig): Configuration composed by Hydra.

    Returns:
      A ``graph_diffusion.Diffusion`` instance. When ``"hf"`` is contained in
      ``config.backbone`` the model is constructed directly from ``config`` and
      moved to ``"cuda"``; otherwise it is restored from
      ``config.eval.checkpoint_path``.
    """
    uses_hf_backbone = "hf" in config.backbone
    if not uses_hf_backbone:
        return graph_diffusion.Diffusion.load_from_checkpoint(
            config.eval.checkpoint_path, config=config
        )

    model = graph_diffusion.Diffusion(config)
    return model.to("cuda")


@L.pytorch.utilities.rank_zero_only
def _print_config(
    config: omegaconf.DictConfig, resolve: bool = True, save_cfg: bool = True
) -> None:
    """Render the composed configuration as a Rich tree and optionally save it.

    Each top-level key of ``config`` becomes a branch; nested ``DictConfig``
    sections are rendered as YAML, every other value via ``str``.

    Args:
      config (DictConfig): Configuration composed by Hydra.
      resolve (bool): Whether to resolve reference fields of DictConfig.
      save_cfg (bool): Whether to also write the tree to
        ``<config.checkpointing.save_dir>/config_tree.txt``.
    """
    dim = "dim"
    tree = rich.tree.Tree("CONFIG", style=dim, guide_style=dim)

    for field in config.keys():
        branch = tree.add(field, style=dim, guide_style=dim)

        section = config.get(field)
        if isinstance(section, omegaconf.DictConfig):
            rendered = omegaconf.OmegaConf.to_yaml(section, resolve=resolve)
        else:
            rendered = str(section)

        branch.add(rich.syntax.Syntax(rendered, "yaml"))

    rich.print(tree)

    if not save_cfg:
        return

    tree_path = f"{config.checkpointing.save_dir}/config_tree.txt"
    with fsspec.open(tree_path, "w") as fp:
        rich.print(tree, file=fp)


def _reference_scoring_enabled(config) -> bool:
    """Return True when generated samples are scored against a reference molecule."""
    return bool(
        config.sampling.spatial.guidance != "None"
        or config.spatial.pharmacophore_conditioning
    )


def _load_centered_reference(config):
    """Load the reference molecule once and centre its conformer at the origin.

    The path is ``config.sampling.spatial.cond_mol_path`` when spatial guidance
    is active, otherwise ``config.spatial.pharm_cond_mol``.

    Returns:
      ``(mol, coords)``: the RDKit molecule whose conformer has been overwritten
      with the mean-centred coordinates, and those coordinates as returned by
      ``spatial_utils.sdf_to_coordinates`` minus their mean.

    Raises:
      ValueError: If RDKit cannot parse the reference file.
    """
    if config.sampling.spatial.guidance != "None":
        ref_path = config.sampling.spatial.cond_mol_path
    else:
        ref_path = config.spatial.pharm_cond_mol

    ref_mol = Chem.MolFromMolFile(ref_path)
    if ref_mol is None:
        raise ValueError(f"Could not read reference molecule from {ref_path}")

    ref_coords = spatial_utils.sdf_to_coordinates(ref_path)
    ref_coords = ref_coords - ref_coords.mean(dim=0, keepdim=True)

    conformer = ref_mol.GetConformer()
    for atom_idx, position in enumerate(ref_coords.tolist()):
        conformer.SetAtomPosition(atom_idx, position)
    return ref_mol, ref_coords


def _score_against_reference(mol_with_conf, ref_mol, ref_coords, device):
    """Align a generated molecule to the reference and compute similarity scores.

    Returns:
      ``(aligned_mol, gaussian_overlap, tanimoto_dist, protrude_dist)`` where
      the three scores are plain Python floats. Any failure propagates as an
      exception so the caller can record one consistent NaN row.
    """
    aligned_mol = crippen_align(ref_mol, mol_with_conf)
    aligned_coords = torch.tensor(
        aligned_mol.GetConformer().GetPositions(),
        dtype=torch.get_default_dtype(),
        device=device,
    )
    ref_coords = ref_coords.to(device=aligned_coords.device, dtype=aligned_coords.dtype)

    gaussian_overlap = get_overlap(aligned_coords, ref_coords).item()
    tanimoto_dist = ShapeTanimotoDist(aligned_mol, ref_mol)
    protrude_dist = ShapeProtrudeDist(aligned_mol, ref_mol)
    return aligned_mol, gaussian_overlap, tanimoto_dist, protrude_dist


def _copy_top_overlap_samples(overlaps, out_dir, top_k=50):
    """Copy the ``top_k`` samples with the highest Gaussian overlap.

    NaN scores (failed samples) are excluded before ranking; ``np.argsort``
    places NaN last, so ranking the raw array would select them as "best".

    Returns:
      Mean overlap of the copied samples, or NaN when no sample has a valid
      score.
    """
    scores = np.asarray(overlaps, dtype=float)
    valid_indices = np.flatnonzero(~np.isnan(scores))
    # Descending order of score, restricted to valid entries.
    ranked = valid_indices[np.argsort(scores[valid_indices])[::-1][:top_k]]

    Path(out_dir + "/best_samples").mkdir(parents=True, exist_ok=True)
    for rank_idx, idx in enumerate(ranked):
        shutil.copy(
            f"{out_dir}/sample_{idx}.sdf",
            f"{out_dir}/best_samples/sample_highest_gaussian_overlap_{rank_idx}.sdf",
        )

    if ranked.size == 0:
        return float("nan")
    return scores[ranked].mean()


def _report_similarity_scores(all_scores, out_dir):
    """Print summary statistics and write CSV files and histograms for all metrics."""
    df_scores = pd.DataFrame(all_scores)

    # Summary statistics over the valid (non-NaN) entries of each metric.
    for metric in df_scores.columns:
        valid_scores = df_scores[metric].dropna()
        if len(valid_scores) > 0:
            mean_score = np.mean(valid_scores)
            median_score = np.median(valid_scores)
            std_score = np.std(valid_scores)
            min_score = np.min(valid_scores)
            max_score = np.max(valid_scores)

            print(f"{metric} stats - Mean: {mean_score:.4f}, Median: {median_score:.4f}, "
                  f"Std: {std_score:.4f}, Min: {min_score:.4f}, Max: {max_score:.4f}")

    Path(out_dir).mkdir(parents=True, exist_ok=True)

    # One CSV per metric, plus one comprehensive CSV.
    for metric in df_scores.columns:
        df_metric = pd.DataFrame({metric: df_scores[metric]})
        df_metric.to_csv(f"{out_dir}/{metric}_scores.csv", index=False)
    df_scores.to_csv(f"{out_dir}/all_similarity_scores.csv", index=False)

    # Histograms, one subplot per metric.
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    axes = axes.flatten()
    for axis_idx, metric in enumerate(df_scores.columns):
        valid_scores = df_scores[metric].dropna()
        if len(valid_scores) > 0:
            axes[axis_idx].hist(valid_scores, bins=30, edgecolor='black', alpha=0.7)
            axes[axis_idx].set_title(f'Distribution of {metric.replace("_", " ").title()}')
            axes[axis_idx].set_xlabel(f'{metric} Value')
            axes[axis_idx].set_ylabel('Frequency')

    plt.tight_layout()
    plt.savefig(f"{out_dir}/all_similarity_histograms.png", dpi=300, bbox_inches='tight')
    plt.close()


def _upper_median(values):
    """Return ``sorted(values)[len(values) // 2]``, or NaN for an empty sequence."""
    if not values:
        return float("nan")
    return sorted(values)[len(values) // 2]


def generate_samples(config, logger):
    """Sample molecules from the model, save conformers and collect metrics.

    For every sampled graph that can be built into a molecule, the function
    stores its node/edge indices, SMILES and length, writes the conformer to
    ``<config.sampling.conf_out_dir>/sample_<k>.sdf`` and records its energy
    (NaN when ``spatial_utils.calc_energy`` returns a falsy value). When spatial
    guidance or pharmacophore conditioning is enabled, each molecule is aligned
    to the reference molecule and scored (Gaussian overlap, shape Tanimoto
    distance, shape protrude distance); the best-overlap samples are copied to
    ``best_samples/`` and CSV/histogram reports are written. Finally all
    metadata is pickled to ``metadata.pkl``.

    If alignment or scoring fails for a molecule, the unaligned conformer is
    saved and exactly one NaN is recorded for each of the three metrics, so the
    score lists stay aligned with the ``sample_<k>.sdf`` numbering.

    Args:
      config: Hydra configuration (reads ``eval``, ``sampling``, ``spatial`` and
        ``loader`` sections).
      logger: Logger used for progress and summary messages.

    Returns:
      List of SMILES strings, one per successfully built molecule.
    """
    logger.info("Generating samples.")
    model = _load_from_checkpoint(config=config)
    if config.eval.disable_ema:
        logger.info("Disabling EMA.")
        model.ema = None

    out_dir = config.sampling.conf_out_dir
    Path(out_dir).mkdir(parents=True, exist_ok=True)

    # Per-molecule records; index k corresponds to sample_<k>.sdf.
    all_nodes = []
    all_edges = []
    all_smiles = []
    all_lengths = []
    all_energies = []
    all_gaussian_overlaps = []
    all_shape_tanimoto_dists = []
    all_shape_protrude_dists = []

    score_samples = _reference_scoring_enabled(config)
    ref_template = None
    ref_coords = None
    if score_samples:
        # Loaded once instead of re-reading the file for every molecule.
        ref_template, ref_coords = _load_centered_reference(config)
    score_device = "cuda" if torch.cuda.is_available() else "cpu"

    pharm_cond = None
    if config.spatial.pharmacophore_conditioning:
        pharm_cond_mol = Chem.MolFromMolFile(config.spatial.pharm_cond_mol)
        types, pos, pharm_padding_mask = spatial_utils.mol_to_pharm_cond(
            pharm_cond_mol,
            n_subset=config.spatial.pharmacophore_subset,
            bs=config.loader.eval_batch_size,
        )
        pharm_cond = (types, pos, pharm_padding_mask)

    start_time = time.time()
    successful_count = 0  # Number of molecules built and saved so far
    for _ in range(config.sampling.num_sample_batches):
        nodes, edges, coords, _coords_true, node_mask, _edge_mask = model.restore_model_and_sample(
            num_steps=config.sampling.steps,
            cond=pharm_cond
        )
        for sample_idx in range(nodes.shape[0]):
            # remove padding
            length = node_mask[sample_idx].sum()
            node_onehot, edge_onehot = indexing_utils.remove_graph_padding(
                nodes[sample_idx], edges[sample_idx], length
            )
            assert torch.allclose(
                edge_onehot, edge_onehot.transpose(-3, -2)
            ), "Edge matrix is not symmetric"

            coords_trimmed = indexing_utils.remove_coords_padding(coords[sample_idx], length)
            if config.spatial.normalize:
                coords_trimmed = coords_trimmed * COORDS_STD

            # convert onehot to node and edge indices
            node = indexing_utils.onehot_to_node_indices(node_onehot)
            edge = indexing_utils.onehot_to_reaction_type_and_centers(edge_onehot)

            built_mol = builder_utils.build_molecule(node, edge, smiles=False)
            if not built_mol:
                continue

            all_smiles.append(Chem.MolToSmiles(built_mol))
            all_nodes.append(node.cpu().numpy())
            all_edges.append(edge.cpu().numpy())
            all_lengths.append(length.cpu().numpy())

            mol_with_conf = spatial_utils.coordinates_to_mol(
                node_onehot, edge_onehot, coords_trimmed
            )

            if score_samples:
                # Fresh copy per sample so nothing done during alignment can
                # leak into the next sample's reference.
                ref_mol = Chem.Mol(ref_template)
                try:
                    mol_to_save, gaussian_overlap, tanimoto_dist, protrude_dist = (
                        _score_against_reference(
                            mol_with_conf, ref_mol, ref_coords, score_device
                        )
                    )
                    print(f"Gaussian overlap: {gaussian_overlap:.4f}")
                    print(f"Scores - Gaussian: {gaussian_overlap:.4f}, "
                        f"Tanimoto Dist: {tanimoto_dist:.4f}, Protrude Dist: {protrude_dist:.4f}")
                except Exception as e:
                    # Best effort: keep the sample, mark its scores as missing.
                    print(f"Error calculating similarity scores: {e}")
                    mol_to_save = mol_with_conf
                    gaussian_overlap = tanimoto_dist = protrude_dist = float('nan')

                # Exactly one entry per metric per saved molecule.
                all_gaussian_overlaps.append(gaussian_overlap)
                all_shape_tanimoto_dists.append(tanimoto_dist)
                all_shape_protrude_dists.append(protrude_dist)
            else:
                mol_to_save = mol_with_conf

            spatial_utils.save_as_sdf(
                mol_to_save,
                f"{out_dir}/sample_{successful_count}.sdf",
            )
            energy = spatial_utils.calc_energy(mol_to_save)
            successful_count += 1  # Only increment for successful saves
            all_energies.append(energy if energy else float("nan"))

    all_scores = None
    highest_overlap_mean = None
    if score_samples:
        all_scores = {
            'gaussian_overlap': all_gaussian_overlaps,
            'shape_tanimoto_dist': all_shape_tanimoto_dists,
            'shape_protrude_dist': all_shape_protrude_dists
        }
        highest_overlap_mean = _copy_top_overlap_samples(all_gaussian_overlaps, out_dir)

    if all_scores:
        _report_similarity_scores(all_scores, out_dir)
    else:
        print("No similarity scores to plot")

    # Save metadata
    assert (
        len(all_nodes) == len(all_edges) == len(all_smiles) == len(all_lengths) == successful_count
    )
    valid_energies = [e for e in all_energies if not math.isnan(e)]
    metadata = {
        "nodes": all_nodes,
        "edges": all_edges,
        "smiles": all_smiles,
        "lengths": all_lengths,
        "similarity_scores": all_scores,  # All similarity metrics, or None
        "energy_median": _upper_median(valid_energies),
        "highest_overlap_mean": highest_overlap_mean
    }

    print("Overlap top 50: ", highest_overlap_mean)

    with open(f"{out_dir}/metadata.pkl", "wb") as f:
        pickle.dump(metadata, f)

    unique_smiles = len(set(all_smiles))
    total_smiles = len(all_smiles)
    elapsed_time = time.time() - start_time
    logger.info(
        f"Generated {total_smiles} total SMILES samples ({unique_smiles} unique) in {elapsed_time:.2f} seconds"
    )
    if all_energies:
        num_nans = len(all_energies) - len(valid_energies)
        if valid_energies:
            mean_energy = sum(valid_energies) / len(valid_energies)
            median_energy = _upper_median(valid_energies)
            min_energy = min(valid_energies)
            max_energy = max(valid_energies)
            logger.info(
                f"Energy stats (kcal/mol) - Mean: {mean_energy:.2f}, Median: {median_energy:.2f}, Min: {min_energy:.2f}, Max: {max_energy:.2f}, Nans: {num_nans}"
            )
        else:
            logger.info(f"No valid energies were computed (Nans: {num_nans})")

    # Show the first record as a quick sanity check (only if one exists).
    if all_smiles:
        print(all_nodes[0])
        print(all_edges[0])
        print(all_smiles[0])
        print(all_lengths[0])
        print(all_energies[0])
    return all_smiles


# NOTE: _ppl_eval below has not been tested.
def _ppl_eval(config, logger):
    """Run a validation pass of the loaded model on the validation dataloader.

    Loads the model via ``_load_from_checkpoint``, optionally drops its EMA
    weights, builds the optional W&B logger, callbacks and the Lightning
    trainer from ``config``, and calls ``trainer.validate``.

    Args:
      config: Hydra configuration.
      logger: Logger used for progress messages.
    """
    logger.info("Starting Zero Shot Eval.")

    model = _load_from_checkpoint(config=config)
    if config.eval.disable_ema:
        logger.info("Disabling EMA.")
        model.ema = None

    if config.get("wandb", None) is None:
        wandb_logger = None
    else:
        wandb_logger = L.pytorch.loggers.WandbLogger(
            config=omegaconf.OmegaConf.to_object(config), **config.wandb
        )

    # Instantiate every configured Lightning callback, if any.
    callbacks = (
        [hydra.utils.instantiate(cb_cfg) for _, cb_cfg in config.callbacks.items()]
        if "callbacks" in config
        else []
    )

    trainer = hydra.utils.instantiate(
        config.trainer,
        default_root_dir=os.getcwd(),
        callbacks=callbacks,
        strategy=hydra.utils.instantiate(config.strategy),
        logger=wandb_logger,
    )

    loaders = dataloader.get_dataloaders(config, skip_train=True, valid_seed=config.seed)
    _, valid_ds = loaders
    trainer.validate(model, valid_ds)


def _train(config, logger):
    """Train a fresh ``graph_diffusion.Diffusion`` model with Lightning.

    Builds the optional W&B logger, decides whether to resume from
    ``config.checkpointing.resume_ckpt_path``, instantiates callbacks,
    dataloaders, model and trainer from ``config``, then calls ``trainer.fit``.

    Args:
      config: Hydra configuration.
      logger: Logger used for progress messages.
    """
    logger.info("Starting Training.")

    if config.get("wandb", None) is None:
        wandb_logger = None
    else:
        wandb_logger = L.pytorch.loggers.WandbLogger(
            config=omegaconf.OmegaConf.to_object(config), **config.wandb
        )

    # Resume only when requested, a path is given, and that path exists.
    can_resume = (
        config.checkpointing.resume_from_ckpt
        and config.checkpointing.resume_ckpt_path is not None
        and misc_utils.fsspec_exists(config.checkpointing.resume_ckpt_path)
    )
    ckpt_path = config.checkpointing.resume_ckpt_path if can_resume else None

    # Lightning callbacks
    callbacks = (
        [hydra.utils.instantiate(cb_cfg) for _, cb_cfg in config.callbacks.items()]
        if "callbacks" in config
        else []
    )

    train_ds, valid_ds = dataloader.get_dataloaders(config)
    model = graph_diffusion.Diffusion(config)

    trainer = hydra.utils.instantiate(
        config.trainer,
        default_root_dir=os.getcwd(),
        callbacks=callbacks,
        strategy=hydra.utils.instantiate(config.strategy),
        logger=wandb_logger,
    )
    trainer.fit(model, train_ds, valid_ds, ckpt_path=ckpt_path)


@hydra.main(version_base=None, config_path="configs", config_name="config")
def main(config):
    """Entry point: seed, finalise the config and dispatch on ``config.mode``.

    ``"sample_eval"`` runs ``generate_samples``, ``"ppl_eval"`` runs
    ``_ppl_eval``; any other mode runs ``_train``.
    """
    L.seed_everything(config.seed)

    if "experiment" in config:
        # Experiment keys override the base config; the now-redundant
        # experiment branch is then dropped from the merged result.
        config = omegaconf.OmegaConf.merge(config, config.experiment)
        del config.experiment

    _print_config(config, resolve=True, save_cfg=True)

    logger = misc_utils.get_logger(__name__)
    mode_runners = {
        "sample_eval": generate_samples,
        "ppl_eval": _ppl_eval,
    }
    run = mode_runners.get(config.mode, _train)
    run(config, logger)


# Run the Hydra-decorated entry point only when executed as a script.
if __name__ == "__main__":
    main()
