import torch
import wandb 
import matplotlib.pyplot as plt 

import sys 
sys.path.append('.') 

# gflownet and general utils 
from experiments.experiments_utils import (
                            get_argument_parser, 
                            create_gfn, 
                            create_env, 
                            create_log_reward, eval_step 
                        )
from streaming_gfn.utils import train_gfn 

# Weights & Biases project name passed to wandb.init() when this file is run as a script.
WANDB_PROJECT_NAME = 'streaming_gfn_vi' 

def main(config):
    """Train a GFlowNet for ``config.num_steps`` rounds and record evaluations.

    The model is evaluated once before any training (recorded at epoch 0) and
    again after every call to ``train_gfn``. Only the evaluation after the last
    training round is plotted. The collected metrics and their epoch indices
    are printed and, when a wandb run is active, written to the run summary.

    Args:
        config: Parsed command-line namespace (from ``get_argument_parser``).
            Fields read directly here: ``seed``, ``use_scheduler``,
            ``epochs_per_step`` and ``num_steps``. The whole object is also
            forwarded to the project helpers.

    Returns:
        A ``(metric, epochs)`` tuple. ``metric`` is the list of ``eval_step``
        results, and ``epochs`` holds the matching cumulative epoch counts
        (``0, epochs_per_step, 2 * epochs_per_step, ...``).
    """
    torch.set_default_dtype(torch.float64)
    torch.manual_seed(config.seed)

    # instantiate the gflownet
    gfn = create_gfn(config)
    log_reward = create_log_reward(config, gfn)

    def create_env_func():
        # Zero-argument factory: each call builds an environment via
        # create_env with the shared config and log_reward.
        return create_env(config, log_reward)

    # Training hyperparameters: the policy network and log_z (presumably the
    # learned log-partition estimate) use separate learning rates.
    optimizer = torch.optim.Adam([
        {'params': gfn.pf.parameters(), 'lr': 1e-3},
        {'params': gfn.log_z, 'lr': 1e-1},
    ])
    scheduler = None
    if config.use_scheduler:
        # PolynomialLR with power=1 decays the LR linearly to zero over
        # total_iters scheduler steps. This assumes train_gfn steps the
        # scheduler once per epoch -- TODO confirm.
        total_iters = config.epochs_per_step * config.num_steps
        print(f'Number of training iterations: {total_iters}')
        scheduler = torch.optim.lr_scheduler.PolynomialLR(
            optimizer, total_iters=total_iters, power=1.
        )

    # Baseline evaluation of the untrained model, recorded at epoch 0.
    metric = [eval_step(config, gfn, create_env_func, plot=False)]
    epochs = [0]
    print(metric[-1])

    for step in range(config.num_steps):
        # Train step. Re-enter train mode each round in case eval_step
        # switched the model to eval mode.
        gfn.train()
        train_gfn(gfn, create_env_func, optimizer, scheduler, config)
        epochs.append(config.epochs_per_step * (step + 1))

        # evaluate the gflownet; only plot after the final training round
        is_last_step = step == config.num_steps - 1
        metric.append(eval_step(config, gfn, create_env_func, plot=is_last_step))
        print(metric[-1])

    print(metric)
    # wandb.run is None when main() is called without a prior wandb.init()
    # (e.g. from a notebook or another script). Skip the summary instead of
    # crashing after all training has finished; the results were printed above.
    if wandb.run is not None:
        wandb.run.summary['metric'] = metric
        wandb.run.summary['epochs'] = epochs
    return metric, epochs

if __name__ == '__main__':
    # Script entry point: parse CLI options, open a wandb run, record the
    # full config on it, then train.
    config = get_argument_parser().parse_args()

    # Tags: a fixed experiment-family label plus this run's env, criterion and seed.
    run_tags = [
        'learning_objectives',
        f'{config.env}',
        f'{config.criterion}',
        f'{config.seed}',
    ]
    wandb.init(project=WANDB_PROJECT_NAME, tags=run_tags)
    wandb.config.update(config)

    main(config)
