import torch 
import torch.nn as nn
import torch.distributions as D

import os 
import tqdm 
import wandb 
import copy 
import numpy as np 
import seaborn as sns 
import matplotlib.pyplot as plt 

from abc import ABCMeta, abstractmethod 
from typing import Dict, Tuple, List 
from copy import deepcopy 

class MockOptimizer(torch.optim.Optimizer):
    """No-op stand-in for a ``torch.optim.Optimizer``.

    ``Optimizer.__init__`` is deliberately not called, so only the members
    defined here (``param_groups``, ``step``, ``zero_grad``) are usable.
    """

    def __init__(self):
        self.param_groups = list()

    def step(self, closure=None):
        """Do nothing.

        ``closure`` is accepted (and ignored) so callers written against the
        ``Optimizer.step(closure=None)`` signature keep working.
        """
        return None

    def zero_grad(self, set_to_none: bool = True):
        """Do nothing; accepts the same argument as ``Optimizer.zero_grad``."""
        pass

class OptGroup:
    """Named collection of optimizers, readable as attributes (e.g. ``group.opt_gfn``)."""

    def __init__(self):
        self.optimizers = dict()

    def register(self, opt, name):
        """Store ``opt`` under ``name`` (overwriting any previous entry)."""
        self.optimizers[name] = opt

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. Read through __dict__
        # so a half-built instance (e.g. during copy/unpickling, before
        # `optimizers` exists) cannot recurse forever, and raise AttributeError
        # (not KeyError) so hasattr/getattr-with-default/copy keep working.
        optimizers = self.__dict__.get('optimizers', {})
        try:
            return optimizers[name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
    
class SchGroup:
    """Named collection of LR schedulers, readable as attributes (e.g. ``group.sch_gfn``)."""

    def __init__(self):
        self.schedulers = dict()

    def register(self, sch, name):
        """Store ``sch`` under ``name`` (overwriting any previous entry)."""
        self.schedulers[name] = sch

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. Read through __dict__
        # so a half-built instance (e.g. during copy/unpickling, before
        # `schedulers` exists) cannot recurse forever, and raise AttributeError
        # (not KeyError) so hasattr/getattr-with-default/copy keep working.
        schedulers = self.__dict__.get('schedulers', {})
        try:
            return schedulers[name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

def sample_from_dirichlet_multinomial(total_counts: torch.Tensor, concentration: torch.Tensor):
    """Draw per-row category counts using one shared Dirichlet probability vector.

    A single ``p ~ Dirichlet(concentration)`` is drawn and shared by every row.
    Row ``i`` then receives ``total_counts[i]`` categorical draws from ``p``,
    summed into per-category counts.

    Args:
        total_counts: 1-D integer tensor ``(batch_size,)``: number of draws per row.
        concentration: 1-D tensor ``(dim,)`` of Dirichlet concentrations.

    Returns:
        ``(batch_size, dim)`` tensor of counts, on ``concentration``'s device
        and with its dtype (also when every count is zero).
    """
    batch_size = total_counts.shape[0]
    total_counts_sum = int(total_counts.sum(dim=0))

    counts = torch.zeros(
        (batch_size, concentration.shape[0]), dtype=concentration.dtype, device=concentration.device
    )
    if total_counts_sum == 0:
        return counts

    p = D.Dirichlet(concentration).sample((1,)).squeeze(dim=0)
    samples = D.OneHotCategorical(probs=p).sample((total_counts_sum,))  # (total_counts_sum, dim)

    # Owner row of each draw: row i owns the next total_counts[i] draws.
    row_ids = torch.repeat_interleave(
        torch.arange(batch_size, device=concentration.device),
        total_counts.to(concentration.device),
    )
    # Sum each row's one-hot draws without materialising a (batch, total) matrix.
    return counts.index_add_(0, row_ids, samples.to(counts.dtype))

class TopKQueue:
    """Fixed-size store of the ``k`` distinct keys with the largest values seen so far.

    Unused slots hold an all-zero key with value ``-inf``.
    """

    def __init__(self, k: int, dim: int, device: str = 'cpu'):
        self.k = k
        self.dim = dim
        self.device = device
        self.queue_key = torch.zeros((k, dim), device=self.device)
        self.queue_value = -torch.ones((k,), device=self.device) * torch.inf

    def append(self, key: torch.Tensor, value: torch.Tensor):
        """Merge a batch of ``(key, value)`` pairs, keeping the top-``k`` distinct keys.

        Args:
            key: ``(batch, dim)`` tensor of keys (a single ``(dim,)`` key is also accepted).
            value: ``(batch,)`` tensor of scores for the keys.

        Incoming keys already stored in the queue are ignored. A key repeated
        within the batch is only considered once (its first occurrence).
        """
        key = key.reshape(-1, self.dim)
        value = value.reshape(-1)

        # Keys already stored. Empty slots (value -inf) are placeholders and must
        # not hide a genuine key that happens to equal their all-zero filler.
        occupied = self.queue_value > -torch.inf
        matches_queue = (
            self.queue_key.reshape(1, self.k, self.dim) == key.reshape(-1, 1, self.dim)
        ).all(dim=2)
        is_in_queue = (matches_queue & occupied.reshape(1, self.k)).any(dim=1)

        # Keys repeated within this batch: drop every occurrence after the first.
        n = key.shape[0]
        same_key = (key.reshape(-1, 1, self.dim) == key.reshape(1, -1, self.dim)).all(dim=2)
        positions = torch.arange(n, device=key.device)
        earlier = positions.reshape(1, n) < positions.reshape(n, 1)  # (i, j): j comes before i
        is_repeat = (same_key & earlier).any(dim=1)

        keep = ~(is_in_queue | is_repeat)
        key = key[keep]
        value = value[keep]

        # Concatenate the key and value stacks and keep the k largest
        cat_key = torch.vstack([self.queue_key, key])
        cat_value = torch.hstack([self.queue_value, value])
        indices_topk = torch.topk(cat_value, k=self.k).indices

        self.queue_key = cat_key[indices_topk]
        self.queue_value = cat_value[indices_topk]

    @staticmethod
    def create_topk_queue(config):
        """Build a queue sized for ``config.env``.

        Raises:
            ValueError: if ``config.env`` is not one of the supported environments.
        """
        match config.env:
            case 'grids':
                return TopKQueue(k=config.H*2, dim=config.dim, device=config.device)
            case 'sequences':
                return TopKQueue(k=config.topk, dim=config.seq_size, device=config.device)
            case 'sets' | 'bags':
                return TopKQueue(k=config.topk, dim=config.src_size, device=config.device)
            case _:
                raise ValueError(f'Unsupported env for TopKQueue: {config.env!r}')

    def stats(self):
        """Return ``(mean, std)`` of the stored values as Python floats."""
        return (
            self.queue_value.mean().cpu().detach().item(),
            self.queue_value.std().cpu().detach().item()
        )

class ModesList:
    """Collect distinct terminal states whose log-reward reaches a threshold ("modes").

    The threshold ``th`` is either given directly or estimated by ``warmup``
    as the ``q``-quantile of sampled log-rewards.
    """

    def __init__(self, q=.9, th=None):
        self.th = th
        self.q = q
        # One tensor of newly discovered modes per `append` call.
        self.modes = list()

    def append(self, env, log_reward_func: nn.Module | None = None):
        """Record the states of ``env`` that are modes and not already recorded.

        Args:
            env: batch whose ``unique_input`` rows are compared as states.
            log_reward_func: optional scorer used instead of ``env.log_reward()``.
        """
        assert self.th is not None
        rewards = env.log_reward() if log_reward_func is None else log_reward_func(env)
        above_threshold = rewards >= self.th

        if not self.modes:
            seen_before = torch.zeros_like(above_threshold, dtype=bool)
        else:
            states = env.unique_input
            known = torch.vstack(self.modes).view(1, -1, states.shape[1])
            pairwise_equal = known == env.unique_input.view(env.batch_size, 1, -1)
            seen_before = pairwise_equal.all(dim=-1).any(dim=1)

        fresh_states = env.unique_input[above_threshold & ~seen_before]
        self.modes.append(torch.unique(fresh_states, dim=0))

    def warmup(self, gfn, create_env_func, epochs):
        """Set ``th`` to the ``q``-quantile of log-rewards over ``epochs`` sampled batches."""
        collected = []
        for _ in range(epochs):
            sampled = gfn.sample(create_env_func())
            if isinstance(sampled, tuple):  # SAGFlowNet
                sampled, _ = sampled
            collected.append(sampled.log_reward())
        self.th = torch.quantile(torch.hstack(collected), q=self.q)

    def __len__(self):
        """Total number of recorded modes."""
        return len(torch.vstack(self.modes)) if self.modes else 0

    @staticmethod
    def create_mode_lst(config, warmup=False, gfn=None, create_env=None, epochs=None):
        """Build a ModesList from ``config``, optionally estimating ``th`` via ``warmup``."""
        if warmup:
            mode_lst = ModesList(q=config.q)
            mode_lst.warmup(gfn, create_env, epochs)
            return mode_lst
        return ModesList(q=config.q, th=config.th)

def train_step(gfn, create_env, config, opt_group, sch_group, use_wandb=True, mode_lst: ModesList = None):
    """Run ``config.epochs_per_step`` optimisation iterations of ``gfn`` and return it.

    Each iteration:
    - draws a fresh batch with ``create_env()`` and computes ``loss = gfn(env)``;
    - backpropagates, then steps both optimizers (``opt_gamma``, ``opt_gfn``)
      and both schedulers (``sch_gfn``, ``sch_gamma``);
    - logs the loss to the progress bar (and to W&B when ``use_wandb``);
    - steps ``gfn.gamma_func`` when ``config.criterion == 'td'``;
    - records modes of the batch in ``mode_lst`` when one is given.
    """
    # Kept in the log for compatibility with existing dashboards; never updated here.
    loss_ls = torch.nan
    for _ in (pbar := tqdm.trange(config.epochs_per_step, disable=config.disable_pbar)):

        opt_group.opt_gfn.zero_grad()
        opt_group.opt_gamma.zero_grad()

        env = create_env()
        loss = gfn(env)

        loss.backward()
        opt_group.opt_gamma.step()
        opt_group.opt_gfn.step()

        sch_group.sch_gfn.step()
        sch_group.sch_gamma.step()

        # Log a plain float: the loss tensor still references its autograd graph,
        # and the progress bar would otherwise print "tensor(..., grad_fn=...)".
        log = {'loss_gfn': loss.detach().item(), 'loss_ls': loss_ls}
        if use_wandb:
            wandb.log(log)
        pbar.set_postfix(**log)

        if config.criterion == 'td':
            gfn.gamma_func.step()

        if mode_lst is not None:
            mode_lst.append(env)

    return gfn
        
class Environment(metaclass=ABCMeta):
    """Abstract batch of environment states.

    Holds the per-element bookkeeping tensors shared by all environments
    (``batch_ids``, ``traj_size``, ``stopped``, ``is_initial`` and the optional
    ``node_indices``). Subclasses implement the transitions (``apply``,
    ``backward``) and the state encoding ``unique_input``.
    """

    def __init__(self, batch_size: int, max_trajectory_length: int, log_reward: nn.Module, device: torch.device):
        self.batch_size = batch_size
        self.device = device
        self.max_trajectory_length = max_trajectory_length
        per_element = (self.batch_size,)
        self.batch_ids = torch.arange(self.batch_size, device=self.device)
        self.traj_size = torch.ones(per_element, device=self.device)
        self.stopped = torch.zeros(per_element, device=self.device)
        self.is_initial = torch.ones(per_element, device=self.device)
        self._log_reward = log_reward

        self.node_indices = None

    @abstractmethod
    def apply(self, actions: torch.Tensor):
        pass

    @abstractmethod
    def backward(self, actions: torch.Tensor):
        pass

    @torch.no_grad()
    def log_reward(self, **reward_params):
        """Evaluate the reward module on this batch (gradients disabled)."""
        return self._log_reward(self, **reward_params)

    @torch.no_grad()
    def merge(self, batch_state):
        """Append the elements of ``batch_state`` to this batch in place."""
        offset = self.batch_size
        self.batch_ids = torch.hstack([self.batch_ids, offset + batch_state.batch_ids])
        self.batch_size = offset + batch_state.batch_size
        for attr in ('stopped', 'is_initial', 'traj_size'):
            setattr(self, attr, torch.hstack([getattr(self, attr), getattr(batch_state, attr)]))
        if isinstance(self.node_indices, torch.Tensor):
            self.node_indices = torch.hstack([self.node_indices, batch_state.node_indices])

    def get(self, indices):
        """Return a deep copy restricted to the elements at ``indices`` (renumbered from 0)."""
        subset = deepcopy(self)
        subset.batch_size = len(indices)
        subset.batch_ids = torch.arange(subset.batch_size, device=subset.device)
        for attr in ('stopped', 'is_initial', 'traj_size'):
            setattr(subset, attr, getattr(self, attr)[indices])
        if isinstance(self.node_indices, torch.Tensor):
            subset.node_indices = self.node_indices[indices]
        return subset

    @property
    def unique_input(self):
        """Per-element state encoding; must be provided by subclasses."""
        raise NotImplementedError

def lexsort(tensor, descending: bool = False):
    """Return row indices that sort a 2-D ``tensor`` lexicographically.

    Column 0 is the primary key; ties are broken by column 1, then column 2, and
    so on. Rows that tie on every column keep their original relative order.

    Args:
        tensor: ``(n_rows, n_keys)`` tensor of sort keys.
        descending: sort every key in descending order.

    Returns:
        1-D ``LongTensor`` of length ``n_rows`` indexing into ``tensor``.
    """
    indices = torch.arange(tensor.shape[0], device=tensor.device)
    # Stable-sort by the least significant key first and finish with the primary
    # key; composing the permutations keeps `indices` pointing at original rows.
    for col in reversed(range(tensor.shape[1])):
        order = torch.argsort(tensor[indices, col], stable=True, descending=descending)
        indices = indices[order]
    return indices

class ReplayBuffer:
    """Bounded buffer of trajectories, ranked by (length, log-reward).

    Storage is step-major: ``self.trajectories[t]`` is a tuple
    ``(states_t, actions_t, states_tp1)`` batching step ``t`` of every stored
    trajectory. ``self.log_rewards[i]`` is the log-reward of the terminal state
    of trajectory ``i``. Trajectories are padded to the environment's
    ``max_trajectory_length`` with no-op steps (action ``-1``).
    """

    def __init__(self, size, max_states_freq=2, device='cpu'):
        self.size = size  # maximum number of stored trajectories
        self.trajectories: List[Tuple[Environment, torch.Tensor, Environment]] = list()
        self.device = device
        # A given terminal state may be stored at most `max_states_freq` times
        self.max_states_freq = max_states_freq
        self.log_rewards = None

    def _pad(self, trajectories, traj_length):
        """Return a new list holding ``trajectories`` padded to ``traj_length`` steps.

        Padding steps repeat the last state with action -1. The caller's list is
        copied, never mutated.
        """
        padded = list(trajectories)
        last_seen_state = padded[-1][-1]
        for _ in range(traj_length - len(padded)):
            padded.append(
                (
                    last_seen_state,
                    -torch.ones((last_seen_state.batch_size,), dtype=int, device=self.device),
                    last_seen_state
                )
            )
        return padded

    @staticmethod
    def _trajectory_lengths(trajectories):
        """Per trajectory, count the steps whose next state is not stopped (``stopped != 1``)."""
        final_states = trajectories[-1][-1]
        lengths = torch.zeros((final_states.batch_size,), device=final_states.device)
        for _, _, states_tp1 in trajectories:
            lengths += (states_tp1.stopped != 1)
        return lengths

    def append(self, trajectories: List[Tuple[Environment, torch.Tensor, Environment]], **kwargs):
        """Add a batch of trajectories, then keep the best ``size`` ones.

        Args:
            trajectories: step-major list of ``(states_t, actions, states_tp1)``;
                it is not modified.
            **kwargs: forwarded to ``log_reward`` of the terminal states.

        To keep the buffer diverse, an incoming trajectory is discarded when its
        terminal state already occurs ``max_states_freq`` or more times in the
        buffer. Survivors are then ranked by trajectory length, ties broken by
        log-reward (both descending), and the top ``size`` are kept.
        """
        traj_length = trajectories[-1][-1].max_trajectory_length
        trajectories = self._pad(trajectories, traj_length)
        new_terminal = trajectories[-1][-1]

        if len(self.trajectories) < 1:
            self.trajectories = trajectories
            self.log_rewards = new_terminal.log_reward(**kwargs)
        else:
            terminal_states = new_terminal.unique_input
            curr_terminal_states = self.trajectories[-1][-1].unique_input
            dim = terminal_states.shape[1]
            # Number of copies of each incoming terminal state already stored
            freq = (
                terminal_states.view(-1, 1, dim) == curr_terminal_states.view(1, -1, dim)
            ).all(dim=2).sum(dim=1)
            # `<` so that, after insertion, a state occurs at most `max_states_freq` times
            (state_should_be_added,) = torch.where(freq < self.max_states_freq)
            if state_should_be_added.numel() == 0:
                # Keep the replay buffer as it is
                return
            for idx, (states_t, actions, states_tp1) in enumerate(trajectories):
                states_t_curr, actions_curr, states_tp1_curr = self.trajectories[idx]
                states_t_curr.merge(states_t.get(state_should_be_added))
                actions_curr = torch.hstack([actions_curr, actions[state_should_be_added]])
                states_tp1_curr.merge(states_tp1.get(state_should_be_added))
                self.trajectories[idx] = (states_t_curr, actions_curr, states_tp1_curr)
            self.log_rewards = torch.hstack(
                [self.log_rewards, new_terminal.log_reward(**kwargs)[state_should_be_added]]
            )

        # Filter the trajectories according to `size`: prefer longer trajectories,
        # then higher rewards. Lengths are computed the same way on every call.
        trajectory_length = self._trajectory_lengths(self.trajectories)
        num_kept = min(self.size, self.trajectories[-1][-1].batch_size)
        indices = lexsort(
            torch.cat([trajectory_length.view(-1, 1), self.log_rewards.view(-1, 1)], dim=1), descending=True
        )[:num_kept]
        for idx, (states_t, actions, states_tp1) in enumerate(self.trajectories):
            self.trajectories[idx] = (
                states_t.get(indices), actions[indices], states_tp1.get(indices)
            )
        self.log_rewards = self.log_rewards[indices]

    @torch.no_grad()
    def sample(self, batch_size):
        """Sample ``batch_size`` stored trajectories with replacement, proportionally to their reward.

        Returns:
            Step-major list of ``(states_t, actions, states_tp1)`` for the sampled trajectories.
        """
        # Softmax in float64 and renormalise: np.random.choice requires `p` to sum
        # to 1 within a tight float64 tolerance, which float32 softmax can miss.
        probs = torch.softmax(self.log_rewards.double(), dim=0).cpu().numpy()
        probs = probs / probs.sum()
        indices = np.random.choice(
            self.trajectories[-1][-1].batch_size, size=(batch_size,), p=probs
        )
        return [
            (states_t.get(indices), actions[indices], states_tp1.get(indices))
            for (states_t, actions, states_tp1) in self.trajectories
        ]

class Swish(nn.Module):
    """Swish activation: multiplies the input elementwise by its sigmoid."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(x)
        return gate * x

class BaseNN(nn.Module):
    """Multi-layer perceptron with optional output clipping.

    Architecture: ``Linear(input_dim, hidden_dim)`` followed by ``num_layers``
    repetitions of ``act -> Linear``. The last ``Linear`` maps to
    ``output_dim``; it defaults to ``hidden_dim``. With ``num_layers <= 0``
    the network is a single ``Linear(input_dim, output_dim)``.

    Args:
        input_dim: size of the input features.
        hidden_dim: width of the hidden layers.
        num_layers: number of (activation, Linear) pairs after the first Linear.
        output_dim: output size; ``None`` means ``hidden_dim``.
        act: activation module reused between layers; ``None`` creates a fresh
            ``nn.LeakyReLU()`` for this instance.
        clip: if True, clamp outputs to ``[-clip_value, clip_value]``.
        clip_value: bound used when ``clip`` is True.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim=None, act=None, clip: bool = False,
                 clip_value: float = 5e2):
        super(BaseNN, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.output_dim = output_dim if output_dim is not None else hidden_dim
        self.clip = clip
        self.clip_value = clip_value
        # Avoid a module instance as a default argument (shared by every BaseNN).
        if act is None:
            act = nn.LeakyReLU()

        if num_layers <= 0:
            # No hidden layers: a single affine map must still honour output_dim.
            self.model = nn.Sequential(nn.Linear(input_dim, self.output_dim))
        else:
            self.model = nn.Sequential(nn.Linear(input_dim, hidden_dim))
            for layer in range(num_layers):
                self.model.append(act)
                out_features = self.output_dim if layer == num_layers - 1 else self.hidden_dim
                self.model.append(nn.Linear(self.hidden_dim, out_features))

    def forward(self, x):
        out = self.model(x)
        if self.clip:
            out = torch.clip(out, -self.clip_value, self.clip_value)
        return out

class BaseTransformer(nn.Module):
    """Single multi-head self-attention layer with MLP-produced queries/keys/values.

    An ``BaseNN`` projects the input to ``3 * hidden_dim * num_heads`` features,
    split into q, k, v. The attention output is mapped to ``output_dim`` by a
    linear layer.

    NOTE(review): with a 2-D input ``(L, input_dim)``, ``nn.MultiheadAttention``
    treats it as one unbatched sequence of length ``L``, so rows attend to each
    other. A 3-D input ``(N, L, input_dim)`` is treated as a batch
    (``batch_first=True``). Confirm which form callers pass.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim, num_heads=4, device='cpu'):
        super(BaseTransformer, self).__init__()
        self.device = device
        self.qkv = BaseNN(input_dim, 3 * hidden_dim * num_heads, num_layers=num_layers).to(self.device)
        self.attn = torch.nn.MultiheadAttention(
            hidden_dim * num_heads,
            num_heads=num_heads,
            batch_first=True
        ).to(self.device)
        self.out = nn.Linear(num_heads * hidden_dim, output_dim).to(self.device)

    def forward(self, x):
        # Split along the feature (last) axis. This equals the previous dim=1
        # split for 2-D inputs and is also correct for batched 3-D inputs.
        q, k, v = torch.chunk(self.qkv(x), chunks=3, dim=-1)
        attn_out, _ = self.attn(q, k, v, need_weights=False)
        out = self.out(attn_out)
        return out

class ForwardPolicyMeta(nn.Module, metaclass=ABCMeta):
    """Base class for forward policies with epsilon-uniform exploration.

    Subclasses provide ``get_latent_emb(batch_state, **kwargs)`` and
    ``get_pol(latent_emb, forward_mask) -> (pol, gflows)``. In training mode,
    actions are sampled from ``(1 - eps) * pol + eps * uniform``. The uniform
    part covers the actions allowed by the mask (all actions when there is no
    mask). In eval mode, actions are sampled from ``pol`` itself.
    """

    masked_value = -1e5

    def __init__(self, eps=.05, device='cpu'):
        super(ForwardPolicyMeta, self).__init__()
        self.eps = eps
        self.seed = None
        self.device = device

        # Optional action index whose probability is forced to zero during training.
        self.force_mask_idx = None

    @abstractmethod
    def get_latent_emb(self):
        pass

    @abstractmethod
    def get_pol(self):
        pass

    def set_seed(self, seed):
        """Make subsequent action sampling reproducible (a fresh generator seeded with ``seed`` per call)."""
        self.seed = seed

    def unset_seed(self):
        self.seed = None

    def get_actions(self, pol, mask=None):
        """Sample one action per row from the exploration-mixed policy.

        Args:
            pol: ``(batch, num_actions)`` action probabilities.
            mask: optional ``(batch, num_actions)`` tensor; entries equal to 1
                mark the actions the uniform exploration term may pick.

        Returns:
            ``(actions, exp_pol)``: the sampled action indices ``(batch,)`` and
            the normalised sampling distribution they were drawn from.
        """
        if mask is None:
            # Uniform over all actions. It must be a distribution so that the
            # mixture below (returned as the sampling policy) sums to one.
            uniform_pol = torch.ones_like(pol) / pol.shape[1]
        else:
            uniform_pol = torch.where(mask == 1., 1., 0.)
            uniform_pol = uniform_pol / uniform_pol.sum(dim=1, keepdims=True)

        eps = 0. if not self.training else self.eps
        exp_pol = pol * (1 - eps) + eps * uniform_pol

        if self.force_mask_idx is not None and self.training:
            exp_pol[:, self.force_mask_idx] = 0
            exp_pol /= exp_pol.sum(dim=1, keepdims=True)

        if self.seed is not None:
            g = torch.Generator(device=self.device)
            g.manual_seed(self.seed)
            actions = torch.multinomial(exp_pol, num_samples=1, replacement=True, generator=g)
        else:
            actions = torch.multinomial(exp_pol, num_samples=1, replacement=True)
        actions = actions.squeeze(dim=-1)
        return actions, exp_pol

    def forward(self, batch_state, actions=None, return_ps=False, perturb_params=False, **kwargs):
        """Evaluate (and, if ``actions`` is None, sample) forward actions for ``batch_state``.

        Returns:
            ``(actions, log pol[actions], X, log sampling_pol[actions])``. ``X`` is
            ``log pol[:, -1]`` (the last action column) when ``return_ps`` is True,
            else the ``gflows`` returned by ``get_pol``. When ``actions`` is
            given, the sampling policy is ``pol`` itself.
        """
        del perturb_params  # compatibility with `BayesianPolicy`
        if not hasattr(batch_state, 'forward_mask'):
            batch_state.forward_mask = None
        latent_emb = self.get_latent_emb(batch_state, **kwargs)
        pol, gflows = self.get_pol(latent_emb, batch_state.forward_mask)
        spol = pol  # sampling policy
        if actions is None:
            actions, spol = self.get_actions(pol, batch_state.forward_mask)
        if return_ps:
            return actions, \
                torch.log(pol[batch_state.batch_ids, actions]), \
                torch.log(pol[batch_state.batch_ids, -1]), \
                torch.log(spol[batch_state.batch_ids, actions])
        else:
            return actions, \
                torch.log(pol[batch_state.batch_ids, actions]), \
                gflows, \
                torch.log(spol[batch_state.batch_ids, actions])

def log_artifact_tensor(tensor, artifact_name):
    """Save ``tensor`` to a local file named ``artifact_name`` and log it to the active W&B run.

    The artifact is created with the same name and ``type='tensor'``.
    """
    local_path = artifact_name
    torch.save(tensor, local_path)
    tensor_artifact = wandb.Artifact(artifact_name, type='tensor')
    tensor_artifact.add_file(local_path)
    wandb.run.log_artifact(tensor_artifact)

def load_artifact_tensor(artifact_name):
    """Download the latest ``type='tensor'`` artifact ``artifact_name`` and load its tensor file.

    The project comes from the ``WANDB_PROJECT_NAME`` environment variable.
    The entity is fixed to ``weekday``.
    """
    project_name = os.environ["WANDB_PROJECT_NAME"]
    artifact_ref = f'weekday/{project_name}/{artifact_name}:latest'
    artifact = wandb.run.use_artifact(artifact_ref, type='tensor')
    download_dir = artifact.download()
    return torch.load(os.path.join(download_dir, artifact_name))

def log_artifact_module(module, artifact_name, artifact_filename):
    """Save ``module``'s state dict to ``artifact_filename`` and log it as a ``type='model'`` W&B artifact."""
    state = module.state_dict()
    torch.save(state, artifact_filename)
    model_artifact = wandb.Artifact(artifact_name, type='model')
    model_artifact.add_file(artifact_filename)
    wandb.run.log_artifact(model_artifact)

def load_artifact_module(module, artifact_name, artifact_filename):
    """Load the latest ``type='model'`` artifact ``artifact_name`` into ``module`` and return it.

    The state dict is read from the basename of ``artifact_filename`` inside the
    downloaded artifact directory. The project comes from ``WANDB_PROJECT_NAME``;
    the entity is fixed to ``weekday``.
    """
    project_name = os.environ["WANDB_PROJECT_NAME"]
    artifact_ref = f'weekday/{project_name}/{artifact_name}:latest'
    artifact = wandb.run.use_artifact(artifact_ref, type='model')
    download_dir = artifact.download()
    weights_path = os.path.join(download_dir, os.path.basename(artifact_filename))
    module.load_state_dict(torch.load(weights_path))
    return module

@torch.no_grad()
def sample_massive_batch(gflownet, create_env, num_batches, num_back_traj=1, use_progress_bar=True, **kwargs):
    """Sample ``num_batches`` batches from ``gflownet`` and merge them into one environment.

    One batch is drawn up front and ``num_batches - 1`` more in the loop. Every
    batch is scored the same way: ``gflownet.sample_many_backward`` and
    ``log_reward`` both receive ``**kwargs``.

    Returns:
        ``(env, marginal_log, log_rewards)``: the merged environment, the
        row-stacked outputs of ``sample_many_backward`` (one row per sample), and
        the concatenated log-rewards.
    """
    def _sample_env():
        # Some samplers (e.g. SAGFlowNet) return (env, extra); keep only the env.
        sampled = gflownet.sample(create_env(), **kwargs)
        if isinstance(sampled, tuple):
            sampled, _ = sampled
        return sampled

    env = _sample_env()
    marginal_log = gflownet.sample_many_backward(env, num_trajectories=num_back_traj, **kwargs)
    log_rewards = env.log_reward(**kwargs)

    for _ in tqdm.trange(num_batches - 1, disable=not use_progress_bar):
        env_i = _sample_env()
        env.merge(env_i)
        marginal_log = torch.vstack([
            marginal_log,
            gflownet.sample_many_backward(env_i, num_trajectories=num_back_traj, **kwargs),
        ])
        log_rewards = torch.hstack([log_rewards, env_i.log_reward(**kwargs)])

    return env, marginal_log, log_rewards

def unique(x, dim=-1):
    """Unique slices of ``x`` along ``dim``, with inverse map, counts and first-occurrence indices.

    Returns:
        ``(values, inverse, counts, first_index)``, where ``values``, ``inverse``
        and ``counts`` are as in ``torch.unique(..., dim=dim)``. ``first_index[u]``
        is the smallest position ``i`` along ``dim`` with ``inverse[i] == u``.
    """
    values, inverse, counts = torch.unique(x, return_inverse=True, return_counts=True, dim=dim)
    # With `dim` given, `inverse` is 1-D (one entry per slice of x along `dim`).
    positions = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
    # A min-reduction picks the first occurrence deterministically. A plain
    # scatter_ with duplicate indices leaves the winner unspecified.
    first_index = inverse.new_empty(values.size(dim)).scatter_reduce_(
        0, inverse, positions, reduce='amin', include_self=False
    )
    return values, inverse, counts, first_index

def marginal_dist(env, marginal_log, log_rewards, dim=-1, normalize=True):
    """Learned vs. target distribution over the distinct terminal states in ``env``.

    Args:
        env: batch whose ``unique_input`` rows identify terminal states.
        marginal_log: ``(batch, num_back_traj)`` log importance weights, one
            column per backward trajectory.
        log_rewards: ``(batch,)`` log-rewards of the states in ``env``.
        dim: passed to ``unique``; the accumulation below indexes unique states
            along axis 0, so presumably ``dim=0`` (as all callers in this file pass).
        normalize: if True, both outputs are normalised probability vectors.

    Returns:
        ``(learned_dist, target_dist)``, indexed by unique state. When
        ``normalize`` is False, ``learned_dist`` is the unnormalised average of
        ``exp(marginal_log)`` and ``target_dist`` holds raw log-rewards.
    """
    values, inverse, counts, indices = unique(env.unique_input, dim=dim)
    num_back_traj = marginal_log.shape[1]
    # Sum exp(marginal_log) over every backward trajectory (all columns) and
    # every duplicate of a state (rows sharing an `inverse` id), then average.
    marginal_prob_batch = torch.zeros(
        (values.size(0), num_back_traj), dtype=marginal_log.dtype, device=values.device
    )
    marginal_prob_batch.index_add_(0, inverse, marginal_log.exp())
    marginal_prob_batch = marginal_prob_batch.sum(dim=-1) / (counts * num_back_traj)

    if normalize:
        learned_dist = marginal_prob_batch / marginal_prob_batch.sum()
        # Compute the target distribution
        target_dist = (log_rewards[indices] - torch.logsumexp(log_rewards[indices], dim=0)).exp()
    else:
        learned_dist = marginal_prob_batch
        target_dist = log_rewards[indices]

    return learned_dist, target_dist

def compute_marginal_dist(gflownet, create_env, num_batches, num_back_traj, use_progress_bar=False):
    """Sample ``num_batches`` batches and return ``(learned, target)`` distributions over their distinct states."""
    samples, marginal_log, log_rewards = sample_massive_batch(
        gflownet,
        create_env,
        num_batches=num_batches,
        num_back_traj=num_back_traj,
        use_progress_bar=use_progress_bar,
    )
    return marginal_dist(samples, marginal_log, log_rewards, dim=0)

class LogRewardProduct(nn.Module):
    """Log-reward of a product of rewards: the sum of the component log-rewards."""

    def __init__(self, device):
        super().__init__()
        self.device = device

    @torch.no_grad()
    def forward(self, batch_state, log_rewards: List[nn.Module]):
        """Sum ``f(batch_state)`` over every ``f`` in ``log_rewards`` (zeros if empty)."""
        total = torch.zeros((batch_state.batch_size,), device=self.device)
        for component in log_rewards:
            total += component(batch_state)
        return total

@torch.no_grad()
def compute_marginal_grid(gfn, create_env_func, num_back_traj):
    """Learned and target probabilities of every cell of a grid environment.

    Builds one batch element per cell of the ``(width + 1) x (height + 1)``
    grid, all marked as stopped. The learned marginal is estimated from
    ``num_back_traj`` backward trajectories per cell.

    Returns:
        ``(learned_prob, target_prob, pos)``: two normalised probability vectors
        over cells and the ``(num_cells, 2)`` float tensor of cell coordinates.
    """
    env = create_env_func()
    # One batch element per grid cell
    env.batch_size = (env.width + 1) * (env.height + 1)
    env.batch_ids = torch.arange(env.batch_size, device=env.device)
    env.forward_mask = torch.zeros((env.batch_size, 3), device=env.device)
    env.backward_mask = torch.zeros((env.batch_size, 2), device=env.device)
    env.stopped = torch.ones((env.batch_size,), device=env.device)
    env.is_initial = torch.zeros((env.batch_size,), device=env.device)

    # Enumerate all (x, y) cells in row-major order. indexing='ij' is the
    # historical default, now passed explicitly (omitting it emits a warning).
    xs, ys = torch.meshgrid(
        torch.arange(env.width + 1, device=env.device),
        torch.arange(env.height + 1, device=env.device),
        indexing='ij',
    )
    env.pos = torch.stack((xs, ys), dim=0).flatten(start_dim=1).t().type(env.forward_mask.dtype)

    env.update_backward_mask()

    # Only the last forward action is enabled (presumably the stop action).
    env.forward_mask[:, -1] = 1

    marginal_log = gfn.sample_many_backward(env, num_back_traj)
    # Monte Carlo estimate: log of the mean importance weight over backward trajectories.
    learned_log_prob = torch.logsumexp(marginal_log, dim=1) - np.log(num_back_traj)
    target_log_prob = env.log_reward()

    learned_log_prob = learned_log_prob - torch.logsumexp(learned_log_prob, dim=0)
    target_log_prob = target_log_prob - torch.logsumexp(target_log_prob, dim=0)
    return learned_log_prob.exp(), target_log_prob.exp(), env.pos

@torch.no_grad()
def compute_fcs(gflownet, create_env, config, **kwargs):
    """Total-variation-style distance between the learned and target distributions.

    - ``config.env == 'grids'``: exact TV over all grid cells
      (via ``compute_marginal_grid``).
    - Otherwise, each of ``config.epochs_eval`` rounds:
      1. draws a bucket of ``config.fcs_bucket_size`` samples;
      2. restricts both distributions to the distinct states in that bucket;
      3. records half their L1 distance.
      The mean over rounds is returned.

    ``config.batch_size`` is temporarily lowered to at most
    ``config.fcs_bucket_size`` (presumably read by ``create_env``). It is
    always restored, including when sampling raises.
    """
    if config.env == 'grids':
        learned_prob, target_prob, _ = compute_marginal_grid(gflownet, create_env, num_back_traj=config.num_back_traj)
        return (learned_prob - target_prob).abs().sum() / 2.

    l1_lst = list()
    original_batch_size = config.batch_size
    config.batch_size = min(config.fcs_bucket_size, config.batch_size)
    try:
        for _ in tqdm.trange(config.epochs_eval, disable=config.disable_pbar):
            samples, marginal_log, log_rewards = sample_massive_batch(
                gflownet,
                create_env,
                # sample_massive_batch draws exactly `num_batches` batches in total
                num_batches=config.fcs_bucket_size // config.batch_size,
                num_back_traj=config.num_back_traj,
                use_progress_bar=False, **kwargs)
            assert samples.batch_size == config.fcs_bucket_size
            # Compute the L1 distance for the sampled subset
            learned_dist, target_dist = marginal_dist(samples, marginal_log, log_rewards, dim=0)
            l1_lst.append((learned_dist - target_dist).abs().sum())
    finally:
        # Re-assign the initial batch size
        config.batch_size = original_batch_size
    l1_lst = .5 * torch.tensor(l1_lst)
    return l1_lst.mean()

@torch.no_grad()
def compute_tv(gflownet, create_env, config, verbose=True, return_dist=False):
    """Exact total variation distance between the learned and target distributions.

    Enumerates the whole state space via ``env.list_all_states`` in chunks of at
    most ``config.batch_size``. Per chunk, it collects unnormalised learned
    probabilities and log-rewards, then normalises both globally.

    Returns:
        ``tv``, or ``(tv, learned_dist, target_dist)`` when ``return_dist`` is True.
    """
    env = create_env()
    learned_parts, target_parts = [], []

    if verbose:
        print('Enumerating the state space')
    for batch_state in env.list_all_states(max_batch_size=config.batch_size):
        assert batch_state.batch_size <= config.batch_size
        log_rewards = batch_state.log_reward()
        marginal_log = gflownet.sample_many_backward(batch_state, num_trajectories=config.num_back_traj)
        learned_chunk, target_chunk = marginal_dist(
            batch_state, marginal_log, log_rewards, dim=0, normalize=False
        )
        learned_parts.append(learned_chunk)
        target_parts.append(target_chunk)

    if verbose:
        print('Computing the L1 norm')
    log_learned = torch.log(torch.hstack(learned_parts))
    log_target = torch.hstack(target_parts)

    learned_dist = (log_learned - torch.logsumexp(log_learned, dim=0)).exp()
    target_dist = (log_target - torch.logsumexp(log_target, dim=0)).exp()

    tv = .5 * (learned_dist - target_dist).abs().sum()
    return (tv, learned_dist, target_dist) if return_dist else tv

def plot_topk_num_modes_histogram(data: Dict[str, Dict[str, Tuple | float]], filename: str | None = None):
    """Draw two side-by-side bar charts comparing the ``'std'`` and ``'sal'`` runs.

    Args:
        data: ``data['num_modes'][key]`` is the number of modes for run ``key``.
            ``data['topk'][key]`` is a tuple whose first entry is plotted;
            presumably ``(mean, std)`` as returned by ``TopKQueue.stats()``.
        filename: if given, the figure is saved there and then cleared.
    """
    number_of_modes = data['num_modes']
    keys = ['std', 'sal']
    plt.subplot(1, 2, 1)
    sns.histplot(
        x=keys, weights=[number_of_modes[k] for k in keys], discrete=True
    )

    topk_avg = data['topk']
    plt.subplot(1, 2, 2)
    sns.histplot(
        x=keys, weights=[topk_avg[k][0] for k in keys], discrete=True
    )

    plt.tight_layout()

    if filename is not None:
        plt.savefig(
            filename, bbox_inches='tight'
        )
        plt.clf()
