import argparse
import os
import random
from typing import List, Tuple

import networkx as nx
import torch
from tqdm import tqdm
from torch_geometric.data import Data

from rdkit import Chem
from rdkit.Chem import AllChem
from typing import List

from src.api.data_structures import MoleculeFragmentGraph
from src.constants import (
    BUILDING_BLOCKS,
    REACTIONS,
    N_BUILDING_BLOCKS,
    N_REACTIONS,
    N_CENTERS,
)
from src.utils import builder_utils, indexing_utils, xtb_utils


def load_compatibilities(bb1_path: str, bb2_path: str):
    """
    Read the saved compatibility tensors and derive a partner lookup table.

    Every row of the bb2 tensor is treated as a flattened key whose last
    component (``row % N_CENTERS``) is the partner's reaction-center index.
    The non-zero columns of that row are the partner building-block indices.
    Columns at or beyond ``N_BUILDING_BLOCKS`` are padding and are dropped.

    Args:
        bb1_path (str): Location of bb1_compatibilities.pt
        bb2_path (str): Location of bb2_compatibilities.pt

    Returns:
        Tuple ``(bb1_compat, bb2_compat, VALID_PARTNERS)`` where the first two
        are the loaded tensors and ``VALID_PARTNERS`` maps each bb2 row index
        to a list of ``(partner_bb_idx, partner_center_idx)`` tuples.
    """
    bb1_compat = torch.load(bb1_path)
    bb2_compat = torch.load(bb2_path)

    def _partners_for(row_idx):
        # The partner center is encoded in the lowest "digit" of the flat index.
        partner_center = row_idx % N_CENTERS
        nonzero_cols = bb2_compat[row_idx].nonzero(as_tuple=True)[0].tolist()
        return [
            (col, partner_center) for col in nonzero_cols if col < N_BUILDING_BLOCKS
        ]

    partner_table = {
        row_idx: _partners_for(row_idx) for row_idx in range(bb2_compat.shape[0])
    }
    return bb1_compat, bb2_compat, partner_table


def enumerate_possible_actions(
    molecule: MoleculeFragmentGraph,
    bb1_compat: torch.Tensor,
    VALID_PARTNERS: dict,
):
    """
    List every action that can grow ``molecule`` by one building block.

    For an empty molecule, each action is a single-element list ``[bb_idx]``
    (one per key of ``BUILDING_BLOCKS``) that seeds the molecule.

    Otherwise each action is ``[other_idx, rxn_idx, frag_order, cidx, other_cidx]``:
    attach building block ``other_idx`` through its reaction center
    ``other_cidx`` to the open center ``cidx`` of the existing fragment
    ``frag_order``, using reaction ``rxn_idx``.

    The compatibility tensors are indexed by the flat key
    ``rxn_idx * N_CENTERS**2 + cidx * N_CENTERS + other_cidx``. This is the same
    layout ``load_compatibilities`` uses when it derives ``other_cidx`` as
    ``flat_idx % N_CENTERS``.

    Args:
        molecule: The partially built molecule.
        bb1_compat: Tensor indexed ``[flat_idx, bb_idx]``. A zero means the
            existing fragment cannot take part in that reaction/center combination.
        VALID_PARTNERS: ``flat_idx -> [(other_idx, other_cidx), ...]`` as returned
            by ``load_compatibilities``.

    Returns:
        A list of actions (lists of ints) in the formats described above.
    """
    if molecule.num_fragments() == 0:
        return [[indexing_utils.smiles_to_idx(frag)] for frag in BUILDING_BLOCKS.keys()]

    results = []
    graph = molecule.fragment_graph

    for frag_order in graph.nodes:
        node_data = graph.nodes[frag_order]
        frag_idx = indexing_utils.smiles_to_idx(node_data["node"].smiles)

        for cidx, is_open in enumerate(node_data["rxn_center_available"]):
            if not is_open:
                continue

            for rxn_idx in range(N_REACTIONS):
                flat_idx_base = rxn_idx * (N_CENTERS * N_CENTERS) + cidx * N_CENTERS

                # Previously only flat_idx_base (partner center 0) was consulted,
                # so partners reacting through any other center were never proposed.
                for other_cidx in range(N_CENTERS):
                    flat_idx = flat_idx_base + other_cidx

                    # Skip if the existing fragment is NOT compatible with this
                    # reaction / center / partner-center combination.
                    if bb1_compat[flat_idx, frag_idx] == 0:
                        continue

                    for other_idx, oc in VALID_PARTNERS.get(flat_idx, []):
                        results.append([other_idx, rxn_idx, frag_order, cidx, oc])

    return results


def sample_random_molecules(
    n: int,
    length: List[int],
    bb1_compat: torch.Tensor,
    VALID_PARTNERS: dict,
    seed: int = None,
    conformer_args: argparse.Namespace = None,
    min_required_bb_idx: int = 93,
    max_attempts: int = None,
) -> List[Data]:
    """
    Sample ``n`` unique random molecules and return them as PyG ``Data`` graphs.

    Each molecule is grown by a randomized depth-first search over
    ``enumerate_possible_actions``. The search backtracks until the molecule
    has the target number of fragments, or until every branch is exhausted.

    Args:
        n: Number of unique molecules to return.
        length: Target fragment count. If a list, one value is drawn at random
            for each molecule; otherwise the value is used as-is.
        bb1_compat: Compatibility tensor forwarded to ``enumerate_possible_actions``.
        VALID_PARTNERS: Partner lookup forwarded to ``enumerate_possible_actions``.
        seed: If given, seeds Python's global ``random`` module.
        conformer_args: Namespace providing ``save_conformers``, ``xtb``,
            ``conformer_dir``, ``num_conformers``, ``rmsd_threshold`` and
            ``energy_cutoff``. If omitted, the module-level ``args`` created by
            the CLI entry point is used when it exists. Conformers are generated
            only when ``save_conformers`` is truthy.
        min_required_bb_idx: A molecule is kept only if at least one of its
            building-block indices is >= this value.
        max_attempts: Optional cap on the number of DFS sampling attempts.
            ``None`` (default) means keep trying until ``n`` molecules are found.

    Returns:
        A list of ``torch_geometric.data.Data`` with ``x`` (from
        ``indexing_utils.node_indices_to_onehot``), ``edge_index``, ``edge_attr``
        (from ``indexing_utils.reaction_type_and_centers_to_onehot``),
        ``data_index`` and ``smiles``.

    Raises:
        RuntimeError: If ``max_attempts`` is reached before ``n`` molecules are found.
    """
    if seed is not None:
        random.seed(seed)

    if conformer_args is None:
        # Backward compatibility: this used to read the global ``args``, which
        # only exists when the file is run as a script. Use it when present
        # instead of raising NameError when the module is imported.
        conformer_args = globals().get("args")

    def dfs(
        mfg: MoleculeFragmentGraph, target_length: int, actions_taken=None
    ) -> Tuple[List, MoleculeFragmentGraph]:
        """Extend ``mfg`` one random action at a time until it has
        ``target_length`` fragments; return ``(None, None)`` if all branches fail."""
        if actions_taken is None:
            actions_taken = []

        if mfg.num_fragments() == target_length:
            return actions_taken, mfg

        possible_actions = enumerate_possible_actions(mfg, bb1_compat, VALID_PARTNERS)
        random.shuffle(possible_actions)

        for action in possible_actions:
            # Work on a copy so a failed branch does not modify the parent graph
            # (assumes MoleculeFragmentGraph.copy() is independent of the original).
            new_mfg = mfg.copy()
            new_mfg.add_fragment(*action)

            final_actions, final_mfg = dfs(new_mfg, target_length, actions_taken + [action])
            if final_mfg is not None:
                return final_actions, final_mfg

        return None, None

    graph_data = []
    labeled_molecules = []
    seen_smiles = set()
    attempts = 0

    with tqdm(total=n) as pbar:
        while len(graph_data) < n:
            if max_attempts is not None and attempts >= max_attempts:
                raise RuntimeError(
                    f"Only sampled {len(graph_data)}/{n} molecules after "
                    f"{max_attempts} attempts"
                )
            attempts += 1

            # Pick molecule length
            target_length = random.choice(length) if isinstance(length, list) else length

            final_actions, mfg = dfs(MoleculeFragmentGraph(), target_length)
            if mfg is None or not final_actions:
                continue

            frag_indices = [action[0] for action in final_actions]
            # Action 0 only seeds the first fragment. Every later action i adds a
            # reaction edge (rxn_idx, existing frag_order, i, cidx, other_cidx).
            # This presumes add_fragment assigns sequential fragment orders.
            reaction_info = [
                (action[1], action[2], idx, action[3], action[4])
                for idx, action in enumerate(final_actions[1:], start=1)
            ]

            adj_matrix = torch.tensor(nx.to_numpy_array(mfg.fragment_graph))
            # Symmetrize so every edge appears in both directions.
            adj_matrix = torch.maximum(adj_matrix, adj_matrix.T)

            X = indexing_utils.node_indices_to_onehot(torch.tensor(frag_indices))
            E = indexing_utils.reaction_type_and_centers_to_onehot(
                adj_matrix, torch.tensor(reaction_info)
            )
            edge_index = adj_matrix.nonzero().t()
            edge_attr = E[edge_index[0], edge_index[1]]

            smiles, mol = mfg.to_smiles(), mfg.to_mol()
            # A "." in SMILES denotes disconnected components. The index threshold
            # presumably forces use of the extended building-block vocabulary;
            # TODO confirm.
            keep = (
                smiles not in seen_smiles
                and builder_utils.is_valid_smiles(mol)
                and "." not in smiles
                and max(frag_indices) >= min_required_bb_idx
            )
            if not keep:
                print(
                    f"Failed to reconstruct smiles or duplicate for {smiles}, {X.argmax(dim=-1)}"
                )
                continue

            seen_smiles.add(smiles)
            sample_graph = Data(
                x=X,
                edge_index=edge_index,
                edge_attr=edge_attr,
                data_index=len(graph_data),
                smiles=smiles,
            )
            graph_data.append(sample_graph)
            labeled_molecules.append(mol)
            pbar.update(1)

    if conformer_args is not None and getattr(conformer_args, "save_conformers", False):
        xtb_utils.process_molecules(
            labeled_molecules,
            do_xtb=conformer_args.xtb,
            out_path=conformer_args.conformer_dir,
            num_conformers=conformer_args.num_conformers,
            rmsd_threshold=conformer_args.rmsd_threshold,
            energy_cutoff=conformer_args.energy_cutoff,
        )
    return graph_data


if __name__ == "__main__":
    # CLI entry point: parse options, load compatibilities, sample molecules and
    # optionally save them. ``args`` stays module-level because
    # ``sample_random_molecules`` may read it for conformer options.
    parser = argparse.ArgumentParser(
        description="Sample random molecules by building fragment graphs."
    )
    parser.add_argument(
        "-n", "--num_samples", type=int, default=100, help="Number of molecules to sample"
    )
    parser.add_argument(
        "-l", "--length",
        type=int,
        nargs="+",
        default=[2],
        help="Length(s) of molecules in fragments (default: 2). "
            "If multiple are provided, a random one will be chosen for each molecule."
    )
    parser.add_argument(
        "--save_graphs",
        action="store_true",
        help="Save the sampled graphs (list of torch_geometric Data) with torch.save",
    )
    parser.add_argument(
        "--output_dir", type=str, default="data/molecule_graphs", help="Directory to save molecules"
    )
    parser.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility")

    # Conformer options
    parser.add_argument("--save_conformers", action="store_true", help="Save conformers")
    parser.add_argument(
        "--num_conformers", type=int, default=50, help="Number of conformers to generate"
    )
    parser.add_argument(
        "--energy_cutoff", type=float, default=10.0, help="Energy cutoff (kcal/mol)"
    )
    parser.add_argument(
        "--rmsd_threshold", type=float, default=1.5, help="RMSD threshold for clustering"
    )
    parser.add_argument(
        "--conformer_dir", type=str, default="data/conformers", help="Directory to save conformers"
    )
    parser.add_argument(
        "--xtb",
        action="store_true",
        help="Run XTB optimization on generated conformers (only used with --save_conformers)",
    )

    # Paths to compatibility matrices
    parser.add_argument(
        "--bb1_path",
        type=str,
        default="vocabulary/compatibilities_plus20/bb1_compatibilities.pt",
        help="Path to bb1_compatibilities.pt",
    )
    parser.add_argument(
        "--bb2_path",
        type=str,
        default="vocabulary/compatibilities_plus20/bb2_compatibilities.pt",
        help="Path to bb2_compatibilities.pt",
    )

    args = parser.parse_args()

    # Make output dirs if needed
    if args.save_graphs:
        os.makedirs(args.output_dir, exist_ok=True)
    if args.save_conformers:
        os.makedirs(args.conformer_dir, exist_ok=True)

    # Load compatibilities (bb2_compat is only needed to build VALID_PARTNERS)
    bb1_compat, bb2_compat, VALID_PARTNERS = load_compatibilities(args.bb1_path, args.bb2_path)

    # Generate molecules
    molecules = sample_random_molecules(
        n=args.num_samples,
        length=args.length,
        bb1_compat=bb1_compat,
        VALID_PARTNERS=VALID_PARTNERS,
        seed=args.seed,
    )

    if args.save_graphs:
        out_path = os.path.join(args.output_dir, "dataset_list_full.pt")
        torch.save(molecules, out_path)
        print(f"Saved {len(molecules)} molecules to {out_path}")
    else:
        print(f"Generated {len(molecules)} molecules")
