import json
import torch
import torch.nn.functional as F
import torch.nn as nn
from typing import Sequence, List
from rdchiral.initialization import rdchiralReactants
from rdchiral.main import rdchiralRun
from rdchiral.initialization import rdchiralReaction
from rdkit import Chem
from rdkit.Chem import AllChem

#from syntheseus.reaction_prediction.inference.base import ExternalBackwardReactionModel
from syntheseus import BackwardReactionModel
from syntheseus import Bag, Molecule, SingleProductReaction


def smiles_to_fp(smiles, fp_size=2048):
    """
    Convert a SMILES string to a Morgan fingerprint tensor.

    The fingerprint is a radius-2, chirality-aware Morgan bit vector.

    Args:
        smiles (str): SMILES string
        fp_size (int): number of bits requested for the fingerprint
    Returns:
        torch.Tensor: uint8 tensor built from the fingerprint bit vector
    Raises:
        ValueError: if RDKit cannot parse ``smiles``
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles reports a parse failure by returning None; fail here
        # with a readable message instead of inside the fingerprint call.
        raise ValueError(f"Could not parse SMILES: {smiles!r}")
    fp = AllChem.GetMorganFingerprintAsBitVect(
        mol, radius=2, nBits=fp_size, useChirality=True
    )
    return torch.tensor(fp, dtype=torch.uint8)

def run_retro(product, template):
    """
    Apply a retro template to a product using rdchiral.

    Args:
        product: product in whatever form ``rdchiralRun`` accepts (the caller
            in this file passes an ``rdchiralReactants`` object)
        template (str): template string of the form ``lhs>>rhs``
    Returns:
        list: one entry per outcome returned by ``rdchiralRun``, each being
            the outcome string split on "."; an empty list if ``rdchiralRun``
            raised
    """
    left_side = template.split(">>")[0]
    if "." in left_side:
        # A multi-fragment left-hand side is wrapped in parentheses before
        # the template is handed to rdchiral.
        template = "(" + template.replace(">>", ")>>")
    template = rdchiralReaction(template)

    try:
        outputs = rdchiralRun(template, product)
    except Exception as e:
        print(f"Error {e} running retro reaction {template} on product {product}")
        return []

    return [outcome.split(".") for outcome in outputs]

def get_activation(name: str) -> nn.Module:
    """
    Return a freshly constructed activation module for ``name``.

    Args:
        name: one of "relu", "elu", "gelu", "leakyrelu", "sigmoid", "tanh"
    Returns:
        A new instance of the matching ``torch.nn`` activation module.
    Raises:
        KeyError: if ``name`` is not one of the supported activations
    """
    # Map to classes rather than instances so that only the requested
    # activation is constructed.
    _activations = {
        "relu": nn.ReLU,
        "elu": nn.ELU,
        "gelu": nn.GELU,
        "leakyrelu": nn.LeakyReLU,
        "sigmoid": nn.Sigmoid,
        "tanh": nn.Tanh,
    }

    try:
        activation_cls = _activations[name]
    except KeyError:
        raise KeyError(
            f"Unsupported activation {name!r}; expected one of {sorted(_activations)}"
        ) from None
    return activation_cls()

class Dense(nn.Module):
    """A biased linear projection followed by an activation module."""

    def __init__(self, in_features: int, out_features: int, hidden_act: nn.Module):
        """
        Args:
            in_features: size of each input sample
            out_features: size of each output sample
            hidden_act: module applied to the output of the linear layer
        """
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=True)
        self.hidden_act = hidden_act

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` with the linear layer, then apply the activation."""
        projected = self.linear(x)
        activated = self.hidden_act(projected)
        return activated
    
class TemplRel(nn.Module):
    """
    Feed-forward classifier mapping an input vector to template logits.

    The network is a stack of ``Dense`` layers sized by ``args.hidden_sizes``,
    with dropout applied after every hidden layer, followed by a linear output
    layer with ``args.n_templates`` outputs.

    Fields read from ``args``: ``hidden_sizes``, ``hidden_activation``,
    ``fp_size``, ``skip_connection``, ``n_templates`` and ``dropout``.
    """

    def __init__(self, args):
        """
        Args:
            args: configuration object exposing the fields listed on the class.
                ``hidden_sizes`` may be a comma-separated string ("512,256"),
                a single int, or an iterable of ints.
        Raises:
            ValueError: if ``hidden_sizes`` is empty, or if
                ``skip_connection`` is not "none" when more than one hidden
                layer is requested.
        """
        super().__init__()

        # Previously only the string form was handled, leaving
        # self.hidden_sizes undefined for list/tuple/int configurations.
        hidden_sizes = args.hidden_sizes
        if isinstance(hidden_sizes, str):
            hidden_sizes = hidden_sizes.split(",")
        elif isinstance(hidden_sizes, int):
            hidden_sizes = [hidden_sizes]
        self.hidden_sizes = [int(size) for size in hidden_sizes]
        if not self.hidden_sizes:
            raise ValueError("args.hidden_sizes must contain at least one layer size")

        self.args = args
        self.layers = self._build_layers(args)
        print(self.layers)
        self.output_layer = nn.Linear(
            self.hidden_sizes[-1], args.n_templates, bias=True
        )

        # we will do all the dropout here in TemplRel, for backward compatibility
        self.dropout = nn.Dropout(args.dropout)
        # Defined for users of the model; not used by forward().
        self.criterion = nn.CrossEntropyLoss(ignore_index=-1, reduction="mean")

    def _build_layers(self, args) -> nn.ModuleList:
        """
        Build the hidden stack.

        The first layer projects ``args.fp_size`` inputs to
        ``hidden_sizes[0]``; each subsequent layer maps ``hidden_sizes[i]`` to
        ``hidden_sizes[i + 1]``. All layers share one activation module.

        Raises:
            ValueError: if a layer beyond the first is needed and
                ``args.skip_connection`` is not "none".
        """
        hidden_act = get_activation(args.hidden_activation)
        # input projection layer; no skip connection here
        layers = nn.ModuleList(
            [Dense(args.fp_size, self.hidden_sizes[0], hidden_act=hidden_act)]
        )

        for in_features, out_features in zip(
            self.hidden_sizes[:-1], self.hidden_sizes[1:]
        ):
            if args.skip_connection != "none":
                raise ValueError(f"Unsupported skip_connection: {args.skip_connection}")
            layers.append(Dense(in_features, out_features, hidden_act=hidden_act))

        return layers

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the hidden stack and return unnormalized logits."""
        for layer in self.layers:
            x = layer(x)
            x = self.dropout(x)

        logits = self.output_layer(x)  # returning *unnormalized* logits

        return logits
    
class NeuralSymPredictor(BackwardReactionModel):
    """
    One-step retro predictor, which gives the highest scoring transformations
    for a given target.

    ``setup`` populates ``self.templates`` and ``self.model``; both are read
    by ``_get_reactions_single``.
    """

    def setup(self, model_path, templates_path):
        """
        Load the template table and the trained template-relevance model.

        Args:
            model_path (str): path to a trained model checkpoint containing
                the keys "args" and "state_dict"
            templates_path (str): path to a JSON object mapping template
                index (as a string) to template
        """
        # Load the templates; JSON keys are strings, the model emits int indices
        with open(templates_path, "r") as f:
            template_dict = json.load(f)
        self.templates = {int(k): v for k, v in template_dict.items()}

        # Load the model.
        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # objects, so only load checkpoints from a trusted source.
        retro_checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
        pretrain_args = retro_checkpoint["args"]
        self.model = TemplRel(pretrain_args)
        # Drop every "module." occurrence from the parameter names
        # (presumably left by a DataParallel-style wrapper — confirm).
        state_dict = {
            k.replace("module.", ""): v
            for k, v in retro_checkpoint["state_dict"].items()
        }
        self.model.load_state_dict(state_dict)
        self.model.eval()
        print("Loaded retro model!")

    def _get_reactions_single(self, target, top_n=50):
        """
        Predict precursor sets for a single target molecule.

        Args:
            target: molecule object exposing a ``smiles`` attribute
            top_n (int): number of top scoring templates to apply; clamped to
                the number of templates the model scores
        Returns:
            list: ``SingleProductReaction`` objects, one per distinct set of
                reactant SMILES, sorted by descending
                ``metadata["probability"]``. The probability is the summed
                softmax score of every template that produced that set,
                renormalized over all returned sets.
        """
        # Convert the target SMILES to fingerprint
        target_fp = smiles_to_fp(target.smiles).float().unsqueeze(0)
        target_rd = rdchiralReactants(target.smiles)
        # Run the model
        with torch.no_grad():
            output = self.model(target_fp)
        probs = F.softmax(output, dim=1)
        # torch.topk raises when k exceeds the size of the dimension
        k = min(top_n, probs.shape[1])
        top_scores, top_indices = torch.topk(probs, k)
        # tolist() yields plain Python floats/ints (valid dict keys, no numpy scalars)
        top_scores = top_scores[0].tolist()
        top_indices = top_indices[0].tolist()

        # For each unique set of reactants, add the scores of all templates
        # that produced it.
        # NOTE(review): frozenset collapses repeated reactant SMILES within a
        # single outcome (e.g. "A.A" becomes {A}) — confirm this is intended.
        prec_to_score = {}
        for score, template_idx in zip(top_scores, top_indices):
            template = self.templates[template_idx]
            try:
                pred_reactants = run_retro(target_rd, template)
            except Exception as e:
                print(f"Issue applying template {template}:\n {e} \n target: {target.smiles}")
                continue
            for reactants in pred_reactants:
                prec = frozenset(reactants)
                prec_to_score[prec] = prec_to_score.get(prec, 0.0) + score

        # Renormalize scores; skipped when there is nothing to normalize by
        total_score = sum(prec_to_score.values())
        if total_score > 0:
            for prec in prec_to_score:
                prec_to_score[prec] /= total_score

        final_predictions = [
            SingleProductReaction(
                reactants=Bag([Molecule(smiles=r) for r in prec]),
                product=target,
                metadata={"probability": score},
            )
            for prec, score in prec_to_score.items()
        ]
        return sorted(
            final_predictions,
            key=lambda r: r.metadata["probability"],
            reverse=True,
        )

    def _get_reactions(
        self, inputs: List[Molecule], num_results: int
    ) -> List[Sequence[SingleProductReaction]]:
        """
        Predict reactions for each input molecule.

        Args:
            inputs: target molecules
            num_results: number of templates to apply per molecule, and the
                maximum number of reactions returned per molecule
        Returns:
            One list of reactions per input, in input order.
        """
        return [
            self._get_reactions_single(mol, top_n=num_results)[:num_results]
            for mol in inputs
        ]
