import torch 
import torch.nn as nn 
import tqdm 
import wandb 
import matplotlib.pyplot as plt 
import numpy as np 
import time 

from sal.utils import ModesList, TopKQueue, ReplayBuffer 
from sal.gym.sequences import Sequences 
from sal.gym.hypergrids import Hypergrid 
from sal.gym.sets import Set 
from sal.gflownet import GFlowNet, SALGFlowNet  

def eval_step(log, config, gfn, create_env, filename=None, **kwargs): 
    match config.env: 
        case 'grids': 
            return eval_step_hypergrids(log, gfn, create_env, filename, **kwargs) 
        case 'sequences': 
            return eval_step_sequences(log, gfn, create_env, filename, **kwargs) 
        case 'sets': 
            return eval_step_sets(log, gfn, create_env, filename, **kwargs) 

def train_step(config, 
               gfn: GFlowNet | SALGFlowNet, 
               create_env, 
               double_epochs: bool = False,
               multiply_epochs_by: float | None = None,  
               replay_buffer: ReplayBuffer | None = None, 
               mode_lst: ModesList | None = None, 
               topk_queue: TopKQueue | None = None, 
               topk_value: nn.Module | None = None, 
               max_time: int | None = None, 
               opt: torch.optim.Adam | None = None, 
               filename: str | None = None, 
               **kwargs
            ): 
    """Train ``gfn`` and track the best / modal samples seen along the way.

    The loop runs ``int(config.epochs * multiply_epochs_by)`` iterations. Each
    iteration optionally takes one gradient step on trajectories drawn from
    ``replay_buffer`` and then always takes one gradient step on a freshly
    created environment.

    Args:
        config: Run configuration; this function reads ``epochs``, ``lr``,
            ``env``, ``replay_buffer_freq`` and ``batch_size``.
        gfn: Model to train. Calling it with ``return_trajectories=True`` must
            yield ``(loss, trajectories)``.
        create_env: Zero-argument callable producing a fresh environment.
        double_epochs: Shorthand for ``multiply_epochs_by=2`` (ignored when
            ``multiply_epochs_by`` is given).
        multiply_epochs_by: Factor applied to ``config.epochs``.
        replay_buffer: Optional buffer; when given, new trajectories are
            appended every iteration and replayed with probability
            ``config.replay_buffer_freq`` once it is non-empty.
        mode_lst: Mode tracker; created from ``config`` when omitted.
        topk_queue: Top-k tracker; created from ``config`` when omitted.
        topk_value: Optional scorer used instead of ``env.log_reward()`` for
            the top-k queue; also forwarded to ``mode_lst.append``.
        max_time: Optional wall-clock budget in seconds; the loop stops after
            the first iteration that ends past it.
        opt: Optimizer to reuse; a two-group Adam is built when omitted.
        filename: When given, ``eval_step`` is run every 200 epochs and asked
            to save its figure there.
        **kwargs: Forwarded to the model, the replay buffer and ``eval_step``.

    Returns:
        ``(topk_queue, mode_lst, last_loss)`` where ``last_loss`` is the float
        loss of the final on-policy step, or NaN if no iteration ran.
    """
    if multiply_epochs_by is None: 
        multiply_epochs_by = 2. if double_epochs else 1. 

    pbar = tqdm.trange(int(config.epochs * multiply_epochs_by)) 
    if opt is None: 
        # The flow head is trained with a 10x larger learning rate than the logits.
        opt = torch.optim.Adam([
            {'params': gfn.pf.mlp_logit.parameters(), 'lr': config.lr},
            {'params': gfn.pf.mlp_flows.parameters(), 'lr': config.lr * 10}
        ])

    # NOTE(review): the schedule horizon is config.epochs, not the multiplied
    # iteration count above, so with multiply_epochs_by > 1 the decay finishes
    # before training does. Confirm this is intended.
    sch = torch.optim.lr_scheduler.PolynomialLR(opt, total_iters=config.epochs, power=1.) 

    if mode_lst is None:
        if config.env != 'grids': 
            mode_lst = ModesList.create_mode_lst(config, warmup=True, gfn=gfn, create_env=create_env, epochs=32) 
        else:
            mode_lst = ModesList.create_mode_lst(config) 

    if topk_queue is None:
        topk_queue = TopKQueue.create_topk_queue(config) 

    # Defined up front so the return below is valid even if the loop body never runs.
    last_loss = float('nan') 
    start = time.time() 
    for epoch in pbar: 
        # The random draw stays first so RNG consumption is independent of the buffer state.
        if (
            (np.random.uniform() <= config.replay_buffer_freq) and 
            replay_buffer is not None and 
            len(replay_buffer.trajectories) >= 1 
        ):
            replayed = replay_buffer.sample(config.batch_size)
            replay_loss = gfn.evaluate_loss_on_trajectories(replayed, **kwargs) 
            opt.zero_grad() 
            replay_loss.backward() 
            opt.step() 

        env = create_env() 
        loss, trajectories = gfn(env, return_trajectories=True, **kwargs) 

        if replay_buffer is not None: 
            replay_buffer.append(trajectories, **kwargs)  
        opt.zero_grad() 
        loss.backward()
        opt.step() 

        # The learning-rate schedule is only applied outside the grid environment.
        if config.env != 'grids': 
            sch.step() 

        last_loss = loss.detach().item() 
        pbar.set_postfix(loss=last_loss) 

        if topk_value is not None: 
            topk_queue.append(env.unique_input, topk_value(env)) 
        else: 
            topk_queue.append(env.unique_input, env.log_reward()) 

        mode_lst.append(env, topk_value)

        if max_time is not None and (time.time() - start) >= max_time: 
            break  

        if filename is not None and (epoch % 200 == 0): 
            with torch.no_grad(): 
                gfn.eval() 
                try: 
                    eval_step(dict(), config, gfn, create_env, filename=filename, **kwargs) 
                finally: 
                    # Never leave the model stuck in eval mode if evaluation fails.
                    gfn.train()  

    return topk_queue, mode_lst, last_loss  

def eval_step_sequences(log: dict, gfn, create_env, filename: str = None, use_wandb: bool = True, **kwargs): 
    """Estimate the total-variation gap between learned and target distributions.

    Creates 128 environments, samples each with ``gfn.sample`` and collects the
    second element it returns together with ``env.log_reward()``. Both
    collections are turned into distributions over the sampled set by a
    softmax, and half their L1 distance is stored in ``log``.

    Args:
        log: Dictionary receiving the metric under ``'tv_sal'`` when
            ``'gflownets'`` is among ``kwargs`` and ``'tv_std'`` otherwise.
        gfn: Model exposing ``sample(env, **kwargs)``.
        create_env: Zero-argument callable producing a fresh environment.
        filename: When given, both distributions are also stored in ``log``
            and a scatter plot is saved to this path.
        use_wandb: Whether to push the raw values to an active wandb run.
        **kwargs: Forwarded to ``gfn.sample``.

    Returns:
        The total-variation estimate as a float.
    """
    is_sal = 'gflownets' in kwargs 
    key = 'tv_sal' if is_sal else 'tv_std' 
    key_pi = 'learned_pi_sal' if is_sal else 'learned_pi' 
    key_pt = 'learned_pt_sal' if is_sal else 'learned_pt' 

    learned, target = [], [] 
    for _ in range(128): 
        env = create_env() 
        _, pf = gfn.sample(env, **kwargs) 
        # Read the reward after sampling, once the environment has been rolled out.
        target.append(env.log_reward()) 
        learned.append(pf) 
    # Detach once so the tensors can be handed to matplotlib / converted freely.
    learned = torch.hstack(learned).detach() 
    target = torch.hstack(target).detach() 

    if use_wandb and wandb.run is not None: 
        wandb.run.summary['learned_target_dist'] = (
            learned.cpu().tolist(), target.cpu().tolist() 
        )

    def normalize(logits: torch.Tensor) -> torch.Tensor: 
        # Softmax over the sampled set; the result is already in probability space.
        return (logits - torch.logsumexp(logits, dim=0)).exp() 

    pt = normalize(learned) 
    pi = normalize(target) 
    # normalize() already exponentiates, so the distance is taken on pt / pi
    # directly; applying exp() a second time would compare exp(p), not p.
    l1_distance = (pt - pi).abs().sum().cpu() 
    tv = l1_distance.item() / 2 
    log[key] = tv 

    if filename is not None: 
        # NOTE(review): the learned distribution is stored under key_pi and the
        # target under key_pt, which looks swapped relative to the key names.
        # Kept as-is because downstream readers may rely on it; please confirm.
        log[key_pi] = pt.cpu().tolist() 
        log[key_pt] = pi.cpu().tolist() 

        plt.scatter(
            pt.cpu(), pi.cpu(), rasterized=True 
        )
        plt.title(str(l1_distance)) 
        plt.savefig(filename, bbox_inches='tight') 
        plt.clf() 

    return tv 

from sal.gflownet import GFlowNet 
def eval_step_hypergrids(log: dict, gfn: GFlowNet, create_env, filename: str = None, use_wandb: bool = True, **kwargs): 
    """Compare the learned marginal over grid cells with the normalised reward.

    Enumerates every state of the grid described by ``create_env()``, loads
    them into one batched ``Hypergrid`` flagged as stopped and non-initial, and
    asks ``gfn.evaluate_marginal_on_backward_traj`` for the marginal of each
    state. Marginals and log-rewards are scattered into ``H x H`` images using
    the first two coordinates of each state (presumably the grid is 2-D; confirm
    before using with more dimensions).

    Args:
        log: Dictionary receiving the learned grid under ``'grid_learned_sal'``
            (when a truthy ``gflownets`` kwarg is present) or
            ``'grid_learned_std'``, and the target under ``'grid_target'``.
        gfn: Model to evaluate; it is switched to eval mode and left there.
        create_env: Zero-argument callable producing a prototype environment.
        filename: When given, both grids are drawn side by side and saved here.
        use_wandb: Whether to push the learned grid to an active wandb run.
        **kwargs: Only ``gflownets`` is read.
    """
    gfn.eval() 
    prototype = create_env() 

    # One row per grid cell: the cartesian product of every axis.
    axes = [torch.arange(prototype.H, device=prototype.device) for _ in range(prototype.dim)] 
    states = torch.cartesian_prod(*axes) 

    env = Hypergrid(
        prototype.dim, 
        prototype.H, 
        batch_size=states.shape[0], 
        log_reward=prototype._log_reward, 
        device=prototype.device, 
    )
    env.state = states 
    env.stopped[:] = 1. 
    env.is_initial[:] = 0. 
    env.update_mask() 
    log_rewards = env.log_reward() 

    gflownets = kwargs.get('gflownets') 
    if gflownets: 
        env.max_depth = prototype.max_depth 
        marginal_prob = gfn.evaluate_marginal_on_backward_traj(env, num_trajectories=32, gflownets=gflownets) 
        key = 'grid_learned_sal' 
    else: 
        marginal_prob = gfn.evaluate_marginal_on_backward_traj(env, num_trajectories=32) 
        key = 'grid_learned_std' 
    total_mass = marginal_prob.sum() 
    assert torch.isclose(total_mass, torch.ones_like(total_mass))

    learned_grid = torch.zeros((env.H, env.H), device=env.device) 
    target_grid = torch.zeros((env.H, env.H), device=env.device) 
    for idx, cell in enumerate(states): 
        row, col = cell[0], cell[1] 
        learned_grid[row, col] = marginal_prob[idx] 
        target_grid[row, col] = log_rewards[idx] 

    # Turn log-rewards into a distribution over cells.
    log_partition = torch.logsumexp(target_grid.flatten(), dim=0).cpu().item() 
    target_grid = (target_grid - log_partition).exp() 

    learned_grid = learned_grid.detach().cpu() 
    if use_wandb and wandb.run is not None: 
        wandb.run.summary['learned_target_dist'] = learned_grid.tolist() 
    if filename is not None: 
        for panel, image in enumerate((learned_grid, target_grid.cpu()), start=1): 
            plt.subplot(1, 2, panel) 
            plt.imshow(image) 
        plt.savefig(filename)
        plt.clf() 

    log[key] = learned_grid.tolist()
    log['grid_target'] = target_grid.detach().cpu().tolist() 
 

from sal.gflownet import SALGFlowNet 
def eval_step_sets(log: dict, gfn: SALGFlowNet, create_env, filename: str = None, use_wandb: bool = True, **kwargs):
    """Estimate the total-variation gap between learned and target distributions on sets.

    For each evaluation round a fresh environment is sampled off-policy, the
    unnormalised log-marginal of the resulting states is estimated with
    ``gfn.evaluate_marginal_on_backward_traj`` and paired with their
    log-reward. Both collections are softmax-normalised over the sampled set
    and half their L1 distance is reported.

    Args:
        log: Metrics dictionary. ``log['epochs_eval']`` sets the number of
            rounds; when absent (e.g. the empty dict ``train_step`` passes for
            periodic evaluation) 128 rounds are used, the same count
            ``eval_step_sequences`` uses. The result is written under
            ``'tv_sal'`` when ``'gflownets'`` is among ``kwargs`` and
            ``'tv_std'`` otherwise.
        gfn: Model exposing ``off_policy()``, ``sample`` and
            ``evaluate_marginal_on_backward_traj``.
        create_env: Zero-argument callable producing a fresh environment.
        filename: When given, both distributions are also stored in ``log``
            and a scatter plot is saved to this path.
        use_wandb: Whether to push the distributions to an active wandb run.
        **kwargs: Forwarded to ``evaluate_marginal_on_backward_traj``.

    Returns:
        The total-variation estimate as a float.
    """
    is_sal = 'gflownets' in kwargs 
    key = 'tv_sal' if is_sal else 'tv_std' 
    key_pi = 'learned_pi_sal' if is_sal else 'learned_pi' 
    key_pt = 'learned_pt_sal' if is_sal else 'learned_pt' 

    # Fall back to a default instead of raising KeyError on a log without 'epochs_eval'.
    num_rounds = log.get('epochs_eval', 128) 

    log_pt_lst = [] 
    log_pi_lst = [] 
    for _ in tqdm.tqdm(range(num_rounds)): 
        states = create_env() 
        with gfn.off_policy(): 
            states, _ = gfn.sample(states)
        log_pt = gfn.evaluate_marginal_on_backward_traj(states, num_trajectories=8, normalize=False, **kwargs)   

        log_pt_lst.append(log_pt) 
        log_pi_lst.append(states.log_reward())  

    def normalize(logits: torch.Tensor) -> torch.Tensor: 
        # Softmax over the sampled set.
        return (logits - torch.logsumexp(logits, dim=0)).exp() 

    # Detach so the tensors can be handed to matplotlib even outside no_grad.
    pt = normalize(torch.hstack(log_pt_lst)).detach() 
    pi = normalize(torch.hstack(log_pi_lst)).detach() 

    if wandb.run is not None and use_wandb: 
        wandb.run.summary['learned_target_dist'] = (pt.cpu().tolist(), pi.cpu().tolist()) 

    if filename is not None: 
        # NOTE(review): the learned distribution is stored under key_pi and the
        # target under key_pt, which looks swapped relative to the key names.
        # Kept as-is because downstream readers may rely on it; please confirm.
        log[key_pi] = pt.cpu().tolist() 
        log[key_pt] = pi.cpu().tolist() 
        plt.scatter(
            pt.cpu(), pi.cpu(), rasterized=True  
        )
        plt.savefig(filename) 
        plt.clf() 

    tv = (pt - pi).abs().sum().cpu().item() / 2 
    log[key] = tv 
    return tv 

def get_env(config): 
    """Return the environment class selected by ``config.env``.

    Raises:
        ValueError: If ``config.env`` is not ``'sequences'``, ``'grids'`` or ``'sets'``.
    """
    env_name = config.env 
    if env_name == 'sequences': 
        return Sequences 
    if env_name == 'grids': 
        return Hypergrid 
    if env_name == 'sets': 
        return Set 
    raise ValueError 

def create_log_reward_model(config, log_reward): 
    """Build the ``LogRewardModel`` of the environment selected by ``config.env``.

    The class is imported lazily from the module of the chosen environment and
    instantiated with ``log_reward`` and ``config.device``.

    Raises:
        ValueError: If ``config.env`` is not ``'sequences'``, ``'grids'`` or ``'sets'``.
    """
    env_name = config.env 
    if env_name == 'sequences': 
        from sal.gym.sequences import LogRewardModel 
    elif env_name == 'grids': 
        from sal.gym.hypergrids import LogRewardModel 
    elif env_name == 'sets': 
        from sal.gym.sets import LogRewardModel 
    else: 
        raise ValueError 
    return LogRewardModel(log_reward, config.device) 
