import os
from rdkit import Chem
from collections import deque
import numpy as np
from scipy.stats import spearmanr, linregress
import torch

from gflownet.featurizer import GraphStateFeaturizer


def smiles_hash(state):
    """Return a canonical SMILES string used as a topology key for `state`.

    Every node of the state becomes a carbon atom and every ``(u, v)`` pair in
    ``state.edge_list`` becomes a single bond between atoms ``u`` and ``v``.
    RDKit's canonical SMILES of that all-carbon skeleton is returned, so graphs
    that RDKit canonicalises identically share the same key.
    """
    skeleton = Chem.RWMol()
    atom_count = state.num_nodes
    for _ in range(atom_count):
        skeleton.AddAtom(Chem.Atom('C'))
    single_bond = Chem.BondType.SINGLE
    for edge in state.edge_list:
        begin, end = edge
        skeleton.AddBond(begin, end, order=single_bond)
    return Chem.MolToSmiles(skeleton)


def all_states_info(env):
    """Enumerate every state reachable from ``env.new()`` by breadth-first search.

    Children are generated with ``env.remove_duplicates`` forced to True. The
    caller's original setting is restored afterwards, even if enumeration raises.
    Transitions whose action type is ``'Stop'`` are not followed.

    Returns:
        dict keyed by ``smiles_hash(state)``, in BFS discovery order (root
        first). Each value is a dict with:
          - ``'num_trajs'``: number of action paths from the root, obtained by
            summing the counts of every parent that reaches the state. This
            assumes all paths to a state have the same length, i.e. a state is
            only ever discovered at a single BFS depth (TODO confirm per env).
          - ``'state'``: the state object itself.
          - ``'reward'``: ``np.exp(env.log_reward(state))``; ``None`` for the root.
          - ``'depth'``: BFS depth at which the state was first discovered.
    """
    saved_remove_duplicates = env.remove_duplicates
    env.remove_duplicates = True
    try:
        root = env.new()
        root_hash = smiles_hash(root)
        depth = 0
        state_info = {
            root_hash: {'num_trajs': 1, 'state': root, 'reward': None, 'depth': depth},
        }
        frontier = [(root, root_hash)]
        while frontier:
            next_frontier = []
            for state, st_hash in frontier:
                children = env.children(state)
                for child, act in zip(children.states, children.actions):
                    if act.type.name == 'Stop':
                        continue
                    ch_hash = smiles_hash(child)
                    if ch_hash in state_info:
                        # Another parent reaches an already-known state: add its paths.
                        state_info[ch_hash]['num_trajs'] += state_info[st_hash]['num_trajs']
                    else:
                        state_info[ch_hash] = {
                            'num_trajs': state_info[st_hash]['num_trajs'],
                            'state': child,
                            'reward': np.exp(env.log_reward(child)),
                            'depth': depth + 1,
                        }
                        next_frontier.append((child, ch_hash))
            frontier = next_frontier
            depth += 1
    finally:
        # Never leave the caller's env in a modified configuration.
        env.remove_duplicates = saved_remove_duplicates
    return state_info


def get_state_flows(env):
    """Compute the flow of every reachable state under a uniform backward policy.

    States come from ``all_states_info(env)``. Flows are accumulated level by
    level, from the deepest BFS depth up to depth 1:

        F(s) = R(s) + sum over children c of  F(c) * P_B(s | c),
        P_B(s | c) = 1 / len(env.parents(c).actions).

    The root (depth 0) is treated as having reward 0. It receives flow from its
    children but adds no reward of its own.

    Returns:
        1-D np.ndarray of flows, aligned with the iteration order of
        ``all_states_info(env)`` (root at index 0).
    """
    info = all_states_info(env)
    entries = list(info.values())
    index_of = {key: pos for pos, key in enumerate(info)}
    node_states = [entry['state'] for entry in entries]
    # The root's stored reward is None; substitute 0 so the array is numeric.
    terminal_reward = np.array([0] + [entry['reward'] for entry in entries[1:]])
    levels = np.array([entry['depth'] for entry in entries])

    flow = np.zeros(len(node_states))
    for level in range(max(levels), 0, -1):
        layer = np.nonzero(levels == level)[0]
        flow[layer] += terminal_reward[layer]
        for node in layer:
            parents = env.parents(node_states[node])
            backward_prob = 1 / len(parents.actions)
            parent_positions = [index_of[smiles_hash(p)] for p in parents.states]
            # Each parent action (duplicates included) receives an equal share.
            for pos in parent_positions:
                flow[pos] += backward_prob * flow[node]
    return flow


@torch.no_grad()
def all_states_log_probs(env, model, featurizer: GraphStateFeaturizer):
    """Return the model's log-probability of stopping in each reachable state.

    States are visited breadth-first from ``env.new()``, with one model call
    per state. ``dist.logits`` is normalised with ``log_softmax`` over dim 0;
    this presumably holds one logit per action of the single featurised state
    (TODO confirm). The result is added to the accumulated log-probability of
    reaching that state.

    - A ``'Stop'`` action records the state's termination log-probability.
    - Any other action passes its log-mass on to the child. When a child is
      reached along several paths, the contributions are merged with
      ``np.logaddexp``.

    Returns:
        dict mapping ``smiles_hash(state)`` to log P(stop at state). Only
        states that offer a ``'Stop'`` action appear.
    """
    origin = env.new()
    origin_key = smiles_hash(env.new())
    log_stop = {}
    log_reach = {origin_key: 0.0}
    pending = deque()
    pending.append((origin, origin_key))
    while pending:
        current, current_key = pending.popleft()
        features = featurizer.transform([current])
        collated = featurizer.collate(features['graph'])
        dist, _, _ = model(collated)
        action_log_probs = torch.log_softmax(dist.logits, dim=0).cpu().numpy()
        # log P(reach current) + log P(action | current), one entry per action.
        joint = log_reach[current_key] + action_log_probs
        successors = features['children'][0]
        for child, action, lp in zip(successors.states, successors.actions, joint):
            if action.type.name == 'Stop':
                log_stop[current_key] = lp
                continue
            child_key = smiles_hash(child)
            if child_key in log_reach:
                log_reach[child_key] = np.logaddexp(log_reach[child_key], lp)
            else:
                log_reach[child_key] = lp
                pending.append((child, child_key))
    return log_stop
    


def get_acc_callback(env, model, sampler, num_samples=200, eval_every=20):
    """Build a trainer callback comparing sampled rewards with the target distribution.

    The target is p(x) ∝ R(x) over every non-root state returned by
    ``all_states_info(env)``. Rewards are histogrammed with ``np.bincount``,
    so they are assumed to be (numerically close to) non-negative integers.
    TODO confirm for each env.

    Every ``eval_every`` iterations the callback samples ``num_samples``
    trajectories with ``sampler`` in eval mode and logs to ``trainer.logger``:
      - ``'spearman'``: Spearman correlation between the target mass per reward
        value r (number of states with reward r, times r) and the histogram of
        sampled terminal rewards.
      - ``'accuracy'``: ``100 * min(1, mean sampled reward / E_p[R])``.

    Returns:
        ``callback(trainer)``. It restores the model's previous train/eval mode
        on exit, including when an error is raised.
    """
    def _to_bins(rewards):
        # np.bincount only accepts non-negative integers. Rewards arrive as
        # floats (np.exp of a log-reward), so round to the nearest integer
        # rather than truncating (astype alone would turn 2.9999999 into 2).
        return np.rint(np.asarray(rewards, dtype=float)).astype(np.int64)

    state_info = all_states_info(env)
    # The root is stored with reward None (it is not a terminal candidate).
    all_rewards = np.array(
        [v['reward'] for v in state_info.values() if v['reward'] is not None],
        dtype=float,
    )
    reward_count = np.bincount(_to_bins(all_rewards))
    length = len(reward_count)
    # Unnormalised target mass per reward value r: (#states with reward r) * r.
    true_dist = reward_count * np.arange(length)
    # E_p[R] under p(x) ∝ R(x): sum(R^2) / sum(R).
    mean_reward = np.sum(all_rewards * all_rewards) / np.sum(all_rewards)

    def callback(trainer):
        if (trainer.iter_num + 1) % eval_every != 0:
            return
        is_training = model.training
        model.eval()
        try:
            # Evaluation only: no autograd graph is needed while sampling.
            with torch.no_grad():
                trajectories = sampler.sample_trajectories(model, num_samples)
            sampled_rewards = [env.reward(tr.states[-1]) for tr in trajectories]
            sampled_dist = np.bincount(_to_bins(sampled_rewards), minlength=length)
            stat = spearmanr(true_dist, sampled_dist)
            acc = 100 * min(1, np.mean(sampled_rewards) / mean_reward)
            trainer.logger.log_dict({
                'spearman': stat.statistic,
                'accuracy': acc,
            })
        finally:
            model.train(is_training)
    return callback


def get_exact_eval_callback(env, model, featurizer, eval_every=20):
    """Build a trainer callback comparing the model's exact stop distribution with the target.

    The target is ``p(x) = R(x) / sum(R)`` over every non-root state returned by
    ``all_states_info(env)``. Every ``eval_every`` iterations the callback
    computes the model's exact stop probabilities with
    ``all_states_log_probs`` (in eval mode) and logs to ``trainer.logger``:
      - ``'spearman'`` / ``'pearson'``: rank / linear correlation between
        rewards and model probabilities. Each is ``None`` when SciPy rejects
        the input (e.g. ``linregress`` with identical x values).
      - ``'l1'``, ``'l2'``, ``'sup'``: L1, L2 and sup-norm (max absolute)
        distance between target and model probabilities.

    Returns:
        ``callback(trainer)``. It restores the model's previous train/eval mode
        on exit, including when an error is raised.
    """
    state_info = all_states_info(env)
    # One fixed ordering of terminal candidates (root excluded: its reward is
    # None), shared by both distributions so entries are compared per state.
    terminal_keys = sorted(k for k, v in state_info.items() if v['reward'] is not None)
    all_rewards = np.array([state_info[k]['reward'] for k in terminal_keys])
    true_probs = all_rewards / np.sum(all_rewards)

    def callback(trainer):
        if (trainer.iter_num + 1) % eval_every != 0:
            return
        is_training = model.training
        model.eval()
        try:
            state2log_prob = all_states_log_probs(env, model, featurizer)
            # Look each terminal state up by key. A state that never offered
            # a Stop action cannot be terminated in: probability 0. Mass the
            # model puts on keys outside terminal_keys (e.g. the root, if it
            # can stop) is not part of this comparison.
            model_probs = np.exp([state2log_prob.get(k, -np.inf) for k in terminal_keys])
            try:
                spearman = spearmanr(all_rewards, model_probs).statistic
            except ValueError:
                spearman = None
            try:
                pearson = linregress(all_rewards, model_probs).rvalue
            except ValueError:
                # e.g. linregress refuses inputs whose x values are all identical.
                pearson = None
            diff = true_probs - model_probs
            trainer.logger.log_dict({
                'spearman': spearman,
                'pearson': pearson,
                'l1': np.linalg.norm(diff, ord=1),
                'l2': np.linalg.norm(diff, ord=2),
                # Sup norm is max |diff|. |max diff| would ignore states where
                # the model over-estimates the target.
                'sup': np.max(np.abs(diff)),
            })
        finally:
            model.train(is_training)
    return callback


def get_save_callback(dir_path, save_every=20):
    """Return a trainer callback that checkpoints every ``save_every`` iterations.

    Iterations are counted from 1 (``trainer.iter_num + 1``). On every
    iteration divisible by ``save_every``, ``trainer.save`` is called with
    ``<dir_path>/model-<iteration>.pt``.
    """
    def callback(trainer):
        step = trainer.iter_num + 1
        if step % save_every:
            return
        trainer.save(os.path.join(dir_path, f'model-{step}.pt'))
    return callback