from typing import Dict, Optional

import torch

from src.gfn.containers import States, Trajectories, Transitions
from src.gfn.samplers import TrajectoriesSampler
from src.gfn.envs import Env
from src.gfn.losses import (
    EdgeDecomposableLoss,
    Loss,
    Parametrization,
    StateDecomposableLoss,
    DBParametrization,
    TBParametrization,
    SubTBParametrization,
    TrajectoryDecomposableLoss,
    RLParametrization
)
from src.gfn.distributions import Empirical_Dist

def trajectories_to_training_samples(
    trajectories: Trajectories, loss_fn: Loss
) -> tuple[States, States] | Transitions | Trajectories:
    """Convert sampled trajectories into the sample format ``loss_fn`` consumes.

    Dispatch is on the loss family:

    * ``StateDecomposableLoss`` -> a ``(intermediary_states, last_states)``
      pair of ``States`` taken from ``trajectories``.
    * ``TrajectoryDecomposableLoss`` -> ``trajectories`` itself, unchanged.
    * ``EdgeDecomposableLoss`` -> ``trajectories.to_transitions()``.

    Args:
        trajectories: The sampled trajectories.
        loss_fn: The loss the returned samples will be fed to.

    Returns:
        The training samples in the format described above.

    Raises:
        ValueError: If ``loss_fn`` belongs to none of the supported loss families.
    """
    if isinstance(loss_fn, StateDecomposableLoss):
        # Intermediary and terminal states are handed over as a pair so the
        # loss can treat them differently.
        return trajectories.intermediary_states, trajectories.last_states
    elif isinstance(loss_fn, TrajectoryDecomposableLoss):
        return trajectories
    elif isinstance(loss_fn, EdgeDecomposableLoss):
        return trajectories.to_transitions()
    else:
        raise ValueError(f"Loss {loss_fn} is not supported.")

def JSD(P, Q):
    """Compute the Jensen-Shannon divergence between two distributions P and Q.

    ``JSD(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M)`` with ``M = (P + Q) / 2``,
    using the natural logarithm and summing over all elements.

    Zero-probability entries follow the standard ``0 * log(0 / m) = 0``
    convention. Each KL term is restricted to the support of its own
    distribution, so zeros in either input are handled without producing NaN.
    The inputs are not modified.

    Args:
        P: Tensor of non-negative probabilities (assumed normalized by the caller).
        Q: Tensor of non-negative probabilities, broadcastable with ``P``.

    Returns:
        A 0-dim tensor holding the divergence.
    """
    # Materialize the broadcast shape so boolean masks line up with M.
    P, Q = torch.broadcast_tensors(P, Q)
    M = 0.5 * (P + Q)

    def _kl_to_mixture(A):
        # Only entries with A > 0 contribute; there M >= A / 2 > 0, so the log is finite.
        support = A > 0
        A_supp = A[support]
        return torch.sum(A_supp * torch.log(A_supp / M[support]))

    return 0.5 * (_kl_to_mixture(P) + _kl_to_mixture(Q))

def validate(
    env: Env,
    parametrization: Parametrization,
    sampler: TrajectoriesSampler,
    n_validation_samples: int = 1000,
) -> Dict[str, float]:
    """Compare the sampler's empirical terminal-state distribution with the target.

    Intended for environments whose target distribution is known
    (``env.true_dist_pmf``). The function:

    1. Draws ``n_validation_samples`` trajectories from ``sampler``.
    2. Builds the empirical pmf of their last states with ``Empirical_Dist``.
    3. Compares that pmf with the target pmf.

    Args:
        env: The environment providing ``true_dist_pmf``.
        parametrization: Currently unused by this function; kept for interface
            compatibility with existing callers.
        sampler: Sampler used to draw the validation trajectories.
        n_validation_samples: Number of trajectories to sample.

    Returns:
        A dictionary with the single key ``"l1_dist"``. Its value is half the L1
        distance between the empirical and target pmfs, divided by the target's
        total mass. This equals the total-variation distance when both pmfs are
        normalized.
    """
    true_dist_pmf = env.true_dist_pmf

    trajectories = sampler.sample(n_trajectories=n_validation_samples)
    final_states_dist_pmf = Empirical_Dist(env).pmf(trajectories.last_states)

    abs_diff_sum = torch.abs(final_states_dist_pmf - true_dist_pmf).sum().item()
    l1_dist = 0.5 * abs_diff_sum / true_dist_pmf.sum().item()
    return {"l1_dist": l1_dist}

import networkx as nx
def check_acylic(states_tensor):
    """Return True iff every row of ``states_tensor`` encodes a directed acyclic graph.

    Each row is treated as a flattened square adjacency matrix. Its last
    dimension must have ``n * n`` entries and is reshaped to ``(n, n)``, then
    passed to ``networkx.DiGraph`` (nonzero entries become edges). The
    row-major layout is assumed to match the environment's encoding; confirm
    against the caller.

    Args:
        states_tensor: Iterable of tensors, e.g. a ``(batch, n * n)`` tensor.

    Returns:
        True if every graph is acyclic (also for an empty batch), else False.

    Raises:
        ValueError: If a row's length is not a perfect square.
    """
    for flat_adjacency in states_tensor:
        n_entries = flat_adjacency.shape[-1]
        n_nodes = round(n_entries ** 0.5)
        if n_nodes * n_nodes != n_entries:
            raise ValueError(
                f"Expected a flattened square adjacency matrix, got {n_entries} entries."
            )
        # detach()/cpu() so tensors on GPU or requiring grad can be converted to numpy.
        adjacency = flat_adjacency.detach().cpu().reshape(n_nodes, n_nodes).numpy()
        if not nx.is_directed_acyclic_graph(nx.DiGraph(adjacency)):
            return False
    return True
##################################
# for param operation
##################################
def set_flat_params_to(model, flat_params):
    """Copy a flat parameter vector into ``model``'s parameters, in place.

    ``flat_params`` is consumed in the order of ``model.parameters()``. Each
    parameter takes the next ``param.numel()`` entries, reshaped to its size.
    Values are written through ``param.data``, so the copy is not recorded by
    autograd.

    Args:
        model: Module whose parameters are overwritten.
        flat_params: Tensor whose number of elements equals the model's total
            parameter count.

    Raises:
        ValueError: If the size of ``flat_params`` does not match the total
            number of model parameters. The check runs before any parameter is
            modified, so the model is never left partially updated.
    """
    params = list(model.parameters())
    expected = sum(param.numel() for param in params)
    if flat_params.numel() != expected:
        raise ValueError(
            f"flat_params has {flat_params.numel()} elements, "
            f"but the model has {expected} parameters."
        )

    prev_ind = 0
    for param in params:
        flat_size = param.numel()
        param.data.copy_(flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
        prev_ind += flat_size
