from arguments import parse_arguments
from functools import partial
from rdkit import Chem
import networkx as nx
from collections import defaultdict
import multiprocessing as mp
from datetime import datetime
from typing import List
from model.mydataclass import PathTool

def load_mol_graph(smiles: str):
    """Build the atom-adjacency graph of a molecule given as SMILES.

    Nodes are indexed by RDKit atom index; edges follow the bond adjacency
    matrix. Each node carries an ``"atom_indices"`` attribute, initially the
    singleton set holding that node's own atom index. Merges later grow
    these sets.

    Raises:
        ValueError: if RDKit cannot parse *smiles*.
    """
    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles signals a parse failure by returning None; without this
    # check the failure surfaces as an opaque error in GetAdjacencyMatrix.
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    graph = nx.Graph(Chem.rdmolops.GetAdjacencyMatrix(mol))
    for atom in mol.GetAtoms():
        graph.nodes[atom.GetIdx()]["atom_indices"] = {atom.GetIdx()}

    return graph

def load_batch_mols(batch: list):
    """Turn each SMILES string in *batch* into a ``(mol, graph)`` tuple.

    The output list has the same length and order as *batch*.
    """
    return [(Chem.MolFromSmiles(smi), load_mol_graph(smi)) for smi in batch]
    
def load_mols(data_set, num_workers):
    """Parse SMILES strings into ``(mol, graph)`` pairs, optionally in parallel.

    Args:
        data_set: sequence of SMILES strings (must support ``len`` and slicing).
        num_workers: number of worker processes; 1 (or fewer) parses in-process.

    Returns:
        List of ``(mol, graph)`` tuples in the same order as *data_set*.
    """
    print(f"[{datetime.now()}] Loading molecules...")
    if num_workers <= 1:
        # No point paying process-pool start-up and pickling cost for one worker.
        mols = load_batch_mols(data_set)
    else:
        batch_size = len(data_set) // num_workers + 1
        batches = [data_set[s : s + batch_size] for s in range(0, len(data_set), batch_size)]
        mols = []
        with mp.Pool(num_workers) as pool:
            # imap yields results in submission order, keeping mols aligned with data_set.
            for mols_batch in pool.imap(load_batch_mols, batches):
                mols.extend(mols_batch)
    print(f"[{datetime.now()}] Loading molecules finished. Total: {len(mols)} molecules.\n")
    return mols

def fragment2smiles(mol: Chem.rdchem.Mol, indices: List[int]):
    """Return the SMILES of the sub-fragment of *mol* spanned by *indices*.

    The fragment SMILES produced by RDKit is parsed again without
    sanitization and written out once more with ``MolToSmiles``. Callers in
    this file pass a set of atom indices; it is converted to a tuple first.
    """
    atom_ids = tuple(indices)
    fragment = Chem.MolFragmentToSmiles(mol, atom_ids)
    reparsed = Chem.MolFromSmiles(fragment, sanitize=False)
    return Chem.MolToSmiles(reparsed)

def get_batch_pair_stats(batch):
    """Enumerate the motif of every edge in a batch of indexed molecules.

    Args:
        batch: iterable of ``(molecule index, (mol, graph))`` pairs.

    Returns:
        ``(motifs, hits)`` where *motifs* lists the SMILES of every edge's
        merged fragment, and *hits* lists the same SMILES paired with the
        molecule index they came from.
    """
    motifs = []
    hits = []
    for mol_idx, (mol, graph) in batch:
        node_data = graph.nodes
        for u, v in graph.edges:
            merged = node_data[u]["atom_indices"] | node_data[v]["atom_indices"]
            smi = fragment2smiles(mol, merged)
            motifs.append(smi)
            hits.append((smi, mol_idx))
    return motifs, hits

def get_pair_stats(mols, num_workers):
    """Count the motif formed by every edge of every molecule graph.

    Args:
        mols: list of ``(mol, graph)`` tuples.
        num_workers: 1 (or fewer) counts in-process; otherwise a process pool
            of that size is used.

    Returns:
        ``(stats, indices)``: *stats* maps motif SMILES to its total count,
        and *indices* maps motif SMILES to ``{molecule index: count}``.
        Both are defaultdicts, so later incremental updates may add new keys.
    """
    stats = defaultdict(int)
    indices = defaultdict(lambda: defaultdict(int))
    print(f"[{datetime.now()}] Begin getting statistics.")
    if num_workers <= 1:
        for i, (mol, graph) in enumerate(mols):
            for (node1, node2) in graph.edges:
                atom_indices = graph.nodes[node1]["atom_indices"].union(graph.nodes[node2]["atom_indices"])
                motif_smiles = fragment2smiles(mol, atom_indices)
                stats[motif_smiles] += 1
                indices[motif_smiles][i] += 1
    else:
        # Size batches by the pool size (as load_mols/replace_pair do), not by
        # mp.cpu_count(): otherwise a pool larger than the CPU count gets fewer
        # batches than workers and some workers sit idle.
        indexed_mols = list(enumerate(mols))
        batch_size = len(indexed_mols) // num_workers + 1
        batches = [indexed_mols[s : s + batch_size] for s in range(0, len(indexed_mols), batch_size)]
        with mp.Pool(num_workers) as pool:
            for batch_stats, batch_indices in pool.imap(get_batch_pair_stats, batches):
                for motif in batch_stats:
                    stats[motif] += 1
                for (motif, i) in batch_indices:
                    indices[motif][i] += 1
    return stats, indices

def merge_nodes(graph: nx.Graph, node1: int, node2: int) -> None:
    """Contract *node2* into *node1* in place.

    *node1* inherits every neighbour of *node2* (no self-loop, no duplicate
    edge), and its ``"atom_indices"`` become the union of both nodes' sets.
    *node2* is removed from the graph.
    """
    merged_atoms = graph.nodes[node1]["atom_indices"] | graph.nodes[node2]["atom_indices"]
    absorbed = list(graph.adj[node2])
    for other in absorbed:
        needs_edge = other != node1 and other not in graph.adj[node1]
        if needs_edge:
            graph.add_edge(node1, other)
        graph.remove_edge(node2, other)
    graph.remove_node(node2)
    graph.nodes[node1]["atom_indices"] = merged_atoms

def update_stats(mol, graph, new_graph, node1, node2, stats, indices, i):
    """Incrementally adjust pair counts after merging *node2* into *node1*.

    *graph* is the state just before the merge and *new_graph* the state
    right after it. Every pair touching node1 or node2 in *graph* (except the
    merged node1-node2 pair itself) is decremented. Every pair touching the
    merged node1 in *new_graph* is incremented. Both the global *stats* and
    molecule *i*'s entry in *indices* are updated in place.
    """
    def _bump(atoms, delta):
        smi = fragment2smiles(mol, atoms)
        stats[smi] += delta
        indices[smi][i] += delta

    # Pairs that vanish: node1's neighbours first, then node2's.
    for center, partner in ((node1, node2), (node2, node1)):
        center_atoms = graph.nodes[center]["atom_indices"]
        for nbr in list(graph.neighbors(center)):
            if nbr == partner:
                continue
            _bump(center_atoms.union(graph.nodes[nbr]["atom_indices"]), -1)

    # Pairs that appear around the merged node.
    merged_atoms = new_graph.nodes[node1]["atom_indices"]
    for nbr in list(new_graph.neighbors(node1)):
        _bump(merged_atoms.union(new_graph.nodes[nbr]["atom_indices"]), 1)

def mp_put_change_info(mol, graph, new_graph, node1, node2, stats_change, indices_change, i):
    """Record, rather than apply, the count changes caused by one merge.

    Same bookkeeping as ``update_stats``, but each change is appended as
    ``(motif, delta)`` to *stats_change* and ``(motif, i, delta)`` to
    *indices_change*, so a worker process can ship them back to the parent.
    """
    def _deltas():
        # Pairs lost from the pre-merge graph (node1's side, then node2's side).
        for center, partner in ((node1, node2), (node2, node1)):
            for nbr in list(graph.neighbors(center)):
                if nbr != partner:
                    yield graph.nodes[center]["atom_indices"].union(graph.nodes[nbr]["atom_indices"]), -1
        # Pairs gained around the merged node.
        for nbr in list(new_graph.neighbors(node1)):
            yield new_graph.nodes[node1]["atom_indices"].union(new_graph.nodes[nbr]["atom_indices"]), 1

    for atoms, delta in _deltas():
        smi = fragment2smiles(mol, atoms)
        stats_change.append((smi, delta))
        indices_change.append((smi, i, delta))

def mp_process(motif: str, mols: list, batch: list):
    """Pool task: apply the merge operation *motif* to one batch of molecules.

    Mirrors the single-process branch of ``replace_pair``. Instead of mutating
    shared state, it records every change so the parent process can apply
    them with ``mp_update_stats``.

    Args:
        motif: SMILES of the pair to merge.
        mols: list of ``(mol, graph)`` tuples indexed by molecule id.
        batch: ``(molecule id, count)`` pairs taken from ``indices[motif]``.

    Returns:
        ``(mols_change, stats_change, indices_change)``:
        - ``mols_change`` holds ``(i, mol, new_graph)``;
        - ``stats_change`` holds ``(motif, delta)``;
        - ``indices_change`` holds ``(motif, i, delta)`` or ``(motif, i, "set_0")``.
    """
    mols_change = []
    stats_change = []
    indices_change = []
    for (i, freq) in batch:
        if freq < 1:
            continue
        (mol, graph) = mols[i]
        # Fewer than two nodes means there is no pair left to merge.
        # number_of_nodes is a method and must be called.
        if graph.number_of_nodes() < 2:
            continue
        new_graph = graph.copy()
        # Walk the edges of the unmodified graph; merges mutate new_graph, so
        # edges already consumed by an earlier merge fail has_edge and are skipped.
        for (node1, node2) in graph.edges:
            if not new_graph.has_edge(node1, node2):
                continue
            atom_indices = new_graph.nodes[node1]["atom_indices"].union(new_graph.nodes[node2]["atom_indices"])
            motif_smiles = fragment2smiles(mol, atom_indices)
            if motif_smiles == motif:
                # merge node1 and node2
                graph_before_merge = new_graph.copy()
                merge_nodes(new_graph, node1, node2)
                mp_put_change_info(mol, graph_before_merge, new_graph, node1, node2, stats_change, indices_change, i)
        mols_change.append((i, mol, new_graph))
        indices_change.append((motif, i, "set_0"))
    return mols_change, stats_change, indices_change

def mp_update_stats(mols: list, stats: dict, indices: dict, mols_change: list, stats_change: list, indices_change: list):
    """Apply the change lists produced by ``mp_process`` to the shared state.

    Args:
        mols: list of ``(mol, graph)`` tuples; entries are replaced in place.
        stats: motif SMILES -> total count. Expected to be a defaultdict (as
            built by ``get_pair_stats``), because deltas may name new motifs.
        indices: motif SMILES -> {molecule id -> count}, also defaultdict-backed.
        mols_change: ``(i, mol, new_graph)`` replacements for ``mols``.
        stats_change: ``(motif, delta)`` pairs added to ``stats``.
        indices_change: ``(motif, i, delta)`` triples added to ``indices``;
            a delta of ``"set_0"`` resets that entry to 0 instead.
    """
    for (i, mol, new_graph) in mols_change:
        mols[i] = (mol, new_graph)
    for (motif, change) in stats_change:
        stats[motif] += change
    for (motif, i, change) in indices_change:
        if change == "set_0":
            indices[motif][i] = 0
        else:
            indices[motif][i] += change

def replace_pair(motif: str, mols: list, stats: dict, indices: dict, num_workers: int):
    """Apply the merge operation *motif* to every molecule that contains it.

    For each molecule whose count in ``indices[motif]`` is positive, every
    edge whose merged fragment SMILES equals *motif* is contracted. The
    ``stats``/``indices`` counts are updated incrementally for the pairs that
    vanish or appear. ``mols``, ``stats`` and ``indices`` are mutated in place.
    Afterwards ``stats[motif]`` is 0.

    Args:
        motif: SMILES of the pair to merge.
        mols: list of ``(mol, graph)`` tuples indexed by molecule id.
        stats: motif SMILES -> total count.
        indices: motif SMILES -> {molecule id -> count}.
        num_workers: 1 (or fewer) runs in-process; otherwise a process pool.
    """
    mols_freq = list(indices[motif].items())
    if num_workers <= 1:
        for (i, freq) in mols_freq:
            if freq < 1:
                continue
            (mol, graph) = mols[i]
            # Fewer than two nodes means there is no pair left to merge.
            # number_of_nodes is a method and must be called.
            if graph.number_of_nodes() < 2:
                continue
            new_graph = graph.copy()
            # Walk the edges of the unmodified graph; merges mutate new_graph, so
            # edges already consumed by an earlier merge fail has_edge and are skipped.
            for (node1, node2) in graph.edges:
                if not new_graph.has_edge(node1, node2):
                    continue
                atom_indices = new_graph.nodes[node1]["atom_indices"].union(new_graph.nodes[node2]["atom_indices"])
                motif_smiles = fragment2smiles(mol, atom_indices)
                if motif_smiles == motif:
                    graph_before_merge = new_graph.copy()
                    merge_nodes(new_graph, node1, node2)
                    update_stats(mol, graph_before_merge, new_graph, node1, node2, stats, indices, i)
            mols[i] = (mol, new_graph)
            indices[motif][i] = 0
    else:
        batch_size = len(mols_freq) // num_workers + 1
        batches = [mols_freq[s : s + batch_size] for s in range(0, len(mols_freq), batch_size)]
        # NOTE(review): mols is bound into the partial, so it is presumably
        # pickled with every task sent to the pool. Shipping only the molecules
        # each batch needs would cut IPC cost; confirm before changing.
        func = partial(mp_process, motif, mols)
        with mp.Pool(num_workers) as pool:
            for mols_change, stats_change, indices_change in pool.imap(func, batches):
                mp_update_stats(mols, stats, indices, mols_change, stats_change, indices_change)
    stats[motif] = 0

def merging_operation_learning(
    train_path: str,
    operation_path: str,
    num_iters: int,
    min_frequency: int,
    num_workers: int,
    mp_threshold: int,
    log_path: str,
):
    """Greedily learn up to *num_iters* merge operations from a SMILES file.

    Each iteration picks the most frequent motif (ties broken by the larger
    SMILES string), merges it in every molecule, and appends it to
    *operation_path*. Learning stops early when no motif reaches
    *min_frequency*. Iterations whose motif count is at least *mp_threshold*
    use *num_workers* processes; the others run in-process.

    Args:
        train_path: text file with one SMILES per line.
        operation_path: output file; one learned motif SMILES per line.
        num_iters: maximum number of merge operations to learn.
        min_frequency: minimum count for a motif to be learned.
        num_workers: process-pool size for the parallel stages.
        mp_threshold: motif count at or above which a merge runs in parallel.
        log_path: output file with ``"<motif> <frequency>"`` per line.

    Returns:
        List of ``(motif_smiles, frequency)`` tuples in learning order.
    """

    print(f"[{datetime.now()}] Learning merging operations from {train_path}.")
    print(f"Number of workers: {num_workers}. Total number of CPUs: {mp.cpu_count()}.\n")

    # Read inside a context manager so the training file is closed promptly.
    with open(train_path) as train_file:
        data_set = [smi.strip("\n") for smi in train_file]
    mols = load_mols(data_set, num_workers)

    trace = []
    with open(operation_path, "w") as output:
        stats, indices = get_pair_stats(mols, num_workers)

        for i in range(num_iters):
            print(f"[{datetime.now()}] Iteration {i}.")
            # default=None: an empty table (e.g. no molecule has a bond) would
            # otherwise make max() raise ValueError.
            motif = max(stats, key=lambda x: (stats[x], x), default=None)
            if motif is None or stats[motif] < min_frequency:
                print(f"No motif has frequency >= {min_frequency}. Stopping.\n")
                break
            print(f"[Iteration {i}] Most frequent motif: {motif}, frequency: {stats[motif]}.\n")
            trace.append((motif, stats[motif]))
            # Only frequent motifs are worth the process-pool overhead.
            workers = num_workers if stats[motif] >= mp_threshold else 1
            replace_pair(motif, mols, stats, indices, workers)
            output.write(f"{motif}\n")
    print(f"[{datetime.now()}] Merging operation learning finished.")
    print(f"The merging operations are in {operation_path}.\n\n")

    with open(log_path, "w") as f:
        for (motif, num) in trace:
            f.write(f"{motif} {num}\n")

    return trace

if __name__ == "__main__":
    # Resolve CLI options and the file locations derived from them, then learn.
    cli_args = parse_arguments()
    paths = PathTool.from_arguments(cli_args)

    learning_trace = merging_operation_learning(
        train_path=paths.train_file,
        operation_path=paths.operation_path,
        num_iters=cli_args.num_iters,
        min_frequency=cli_args.min_frequency,
        num_workers=cli_args.num_workers,
        mp_threshold=cli_args.mp_thd,
        log_path=paths.operation_learning_log_path,
    )