import torch 
import wandb 
import matplotlib.pyplot as plt 

import itertools 
import os 
import sys 
sys.path.append('.') 

# gflownet and general utils 
from experiments.experiments_utils import get_argument_parser, \
    create_gfn, create_env, create_log_reward, create_optimizers 

from gfn.gflownet import LEDGFlowNet 
from gfn.utils import compute_fcs, train_step, MockOptimizer

# Weights & Biases project that runs from this script are logged to.
WANDB_PROJECT_NAME = 'led_gflownets'
# Mirror the project name into the process environment.
# NOTE(review): wandb itself reads WANDB_PROJECT (not WANDB_PROJECT_NAME);
# presumably this variable is consumed by other project code — confirm.
os.environ.update({'WANDB_PROJECT_NAME': WANDB_PROJECT_NAME})

def main(config):
    """Train a GFlowNet and track its FCS evaluation metric in Weights & Biases.

    The model is evaluated once before training and once after each of the
    ``config.num_steps`` training steps. Every evaluation result is logged to
    wandb under the key ``'fcs'`` and printed; the full history is stored in
    the active wandb run's summary at the end.

    Args:
        config: Parsed command-line namespace. This function reads
            ``config.seed`` and ``config.num_steps`` directly and forwards
            ``config`` to the project helpers.

    Note:
        Sets torch's global default dtype to ``float64`` as a side effect.
        Requires an active wandb run (``wandb.run`` is written at the end).
    """
    torch.set_default_dtype(torch.float64)
    torch.manual_seed(config.seed)

    # Build the model, its log-reward, and the optimizer/scheduler groups.
    gfn = create_gfn(config)
    log_reward = create_log_reward(config, gfn)

    def create_env_func():
        # Environment factory handed to both training and evaluation helpers.
        return create_env(config, log_reward=log_reward)

    opt_group, sch_group = create_optimizers(gfn, config)

    def evaluate():
        """Compute FCS inside ``gfn.off_policy()``, log and print it, and return it."""
        # NOTE(review): the model is never switched out of train() mode before
        # evaluation; confirm that off_policy() (or compute_fcs) handles this.
        with gfn.off_policy():
            fcs = compute_fcs(gfn, create_env_func, config)
            wandb.log({'fcs': fcs})
            print(fcs)
        return fcs

    # Evaluate once before training, then after every training step.
    summary_fcs = [evaluate()]
    for _ in range(config.num_steps):
        gfn.train()
        train_step(gfn, create_env_func, config, opt_group, sch_group)
        summary_fcs.append(evaluate())

    wandb.run.summary['fcs'] = summary_fcs

if __name__ == '__main__':
    # Parse the command line once; the namespace drives the whole run.
    config = get_argument_parser().parse_args()

    # Tags make runs filterable in the wandb UI by their key settings.
    run_tags = [
        'test_led_gfn',
        f'led_gfn_{config.use_led_gfn}',
        f'learn_phi_{config.learn_potential}',
        f'{config.seed}',
        f'{config.env}',
    ]
    wandb.init(project=WANDB_PROJECT_NAME, tags=run_tags)
    # Store every parsed argument in the run's config.
    wandb.config.update(config)

    main(config)
