import torch
import wandb 
import matplotlib.pyplot as plt 
import tqdm 

import sys 
sys.path.append('.') 

# gflownet and general utils 
from experiments.experiments_utils import get_argument_parser, \
    create_gfn, create_env, create_log_reward, eval_step 
from var_red_gfn.utils import train_step 

# Weights & Biases project name; passed as `project=` to `wandb.init` in the script entry point.
WANDB_PROJECT_NAME = 'mode_seeking_gflownets' 
    
def main(config): 
    torch.set_default_dtype(torch.float64) 
    torch.manual_seed(config.seed) 

    # instantiate the gflownet 
    gfn = create_gfn(config) 
    log_reward = create_log_reward(config, gfn) 
    
    create_env_func = lambda: create_env(config, log_reward) 
    
    assert config.env in ['gmms', 'grids'] 

    # train the gflownet 
    alpha_div_criteria = ['renyi']
    # assert config.criterion in alpha_div_criteria 
    # Training hyperparameters 
    optimizer = torch.optim.Adam([
        {'params': gfn.pf.parameters(), 'lr': 1e-3}, 
        {'params': gfn.log_z, 'lr': 1e1} 
    ])
    if config.use_scheduler: 
        scheduler = torch.optim.lr_scheduler.PolynomialLR(optimizer, 
            total_iters=config.epochs_per_step*config.num_steps, power=.5) 
    else: 
        scheduler = None 

    with gfn.on_policy(): 
        # Train the model 
        train_step(gfn, create_env_func, epochs=config.epochs, 
            optimizer=optimizer, scheduler=scheduler, 
            use_progress_bar=config.use_progress_bar)     

    # Compute the marginals 
    match config.env: 
        case 'grids': 
            from var_red_gfn.gym.grids import Grid2D 
            learned_dist, target_dist, coordinates = Grid2D.compute_marginal(
                gfn, create_env_func, config.num_back_traj, filename=None 
            ) 
            matrix = torch.zeros((config.width + 1, config.height + 1), device=config.device) 
            matrix[
                coordinates[:, 0].long(), coordinates[:, 1].long()  
            ] = learned_dist 
            wandb.run.summary['samples'] = matrix.cpu().tolist() 
        case 'gmms': 
            from var_red_gfn.utils import sample_massive_batch
            with gfn.on_policy(): 
                samples = list() 
                for _ in tqdm.trange(config.epochs, disable=not config.use_progress_bar): 
                    env = create_env_func() 
                    env = gfn.sample(env) 
                    samples.append(env.state) 
                samples = torch.vstack(samples).cpu() 
                wandb.run.summary['samples'] = samples.cpu().tolist() 

if __name__ == '__main__':
    # Parse the command line, open a W&B run tagged with the key settings,
    # record the full configuration, then launch the experiment.
    config = get_argument_parser().parse_args()
    run_tags = [
        f'{getattr(config, key)}'
        for key in ('env', 'criterion', 'alpha', 'seed')
    ]
    wandb.init(project=WANDB_PROJECT_NAME, tags=run_tags)
    wandb.config.update(config)
    main(config)
