import torch 
import wandb 
import matplotlib.pyplot as plt 

import itertools 
import os 
import sys 
sys.path.append('.') 

# gflownet and general utils 
from experiments.experiments_utils import get_argument_parser, \
    create_gfn, create_env, create_log_reward, create_optimizers 

from gfn.gflownet import GFlowNet 
from gfn.utils import compute_fcs, train_step, MockOptimizer 

# Weights & Biases project that runs of this script are logged under. The same
# value is exported through the process environment, presumably for code outside
# this module (or child processes) that reads it.
# NOTE(review): wandb's own project variable is ``WANDB_PROJECT``; the
# ``WANDB_PROJECT_NAME`` key is presumably consumed by project code — confirm.
WANDB_PROJECT_NAME = 'led_gflownets'
os.environ.update(WANDB_PROJECT_NAME=WANDB_PROJECT_NAME)

def create_gfn(config):
    """Build the GFlowNet used for the RegGraph experiments.

    The forward policy is ``ForwardPolicyLA`` when ``config.use_la_gfn`` is
    truthy and ``ForwardPolicy`` otherwise. Both are constructed with an input
    size of 1, ``config.hidden_dim`` hidden units and 3 layers. The backward
    policy is ``BackwardPolicy``.

    Args:
        config: Namespace providing ``use_la_gfn``, ``hidden_dim``,
            ``criterion`` and ``device``.

    Returns:
        A ``GFlowNet`` wrapping the two policies, using ``config.criterion``.
    """
    from gfn.models.reg_graphs import ForwardPolicy, ForwardPolicyLA, BackwardPolicy

    forward_cls = ForwardPolicyLA if config.use_la_gfn else ForwardPolicy
    device = config.device

    forward_policy = forward_cls(1, config.hidden_dim, num_layers=3, device=device)
    backward_policy = BackwardPolicy(device=device)
    return GFlowNet(
        forward_policy,
        backward_policy,
        criterion=config.criterion,
        device=device,
    )

def create_env(config, log_reward):
    """Construct a new ``RegGraph`` environment.

    Each call builds a fresh instance.

    Args:
        config: Namespace providing ``source_dir``, ``idx``, ``batch_size``
            and ``device``.
        log_reward: Reward object passed to ``RegGraph`` unchanged.

    Returns:
        A ``RegGraph`` created with ``device=config.device``.
    """
    from gfn.gym.reg_graphs import RegGraph

    positional_args = (config.source_dir, config.idx, config.batch_size, log_reward)
    environment = RegGraph(*positional_args, device=config.device)
    return environment

def create_log_reward(config):
    """Return the ``LogReward`` used by the RegGraph environment.

    ``pi`` is fixed to ``[0.1, 0.9]``. Only ``config.device`` is read from
    ``config``.
    NOTE(review): ``pi`` is presumably a pair of mixture weights — confirm in
    ``gfn.gym.reg_graphs``.
    """
    from gfn.gym.reg_graphs import LogReward

    # A new list on every call, so no caller ever shares it.
    pi_weights = [0.1, 0.9]
    reward = LogReward(pi=pi_weights, device=config.device)
    return reward

def _evaluate_and_log(gfn, create_env_func, config, history):
    """Evaluate ``gfn`` once and record the resulting FCS value.

    Runs ``compute_fcs`` inside ``gfn.off_policy()``, logs the value to wandb
    under the key ``'fcs'``, appends it to ``history`` and prints it.

    Args:
        gfn: The GFlowNet being trained.
        create_env_func: Zero-argument factory returning a new environment.
        config: Experiment configuration forwarded to ``compute_fcs``.
        history: List that the new value is appended to (mutated in place).

    Returns:
        The value returned by ``compute_fcs``.
    """
    with gfn.off_policy():
        fcs = compute_fcs(gfn, create_env_func, config)
        wandb.log({'fcs': fcs})
        history.append(fcs)
        print(fcs)
    return fcs


def main(config):
    """Train a GFlowNet on the RegGraph environment and track its FCS.

    The model is evaluated once before training and once after each of the
    ``config.num_steps`` training steps. Every value is logged to wandb and
    printed, and the full history is stored in ``wandb.run.summary['fcs']``.

    Args:
        config: Parsed command-line namespace. It must provide ``seed`` and
            ``num_steps`` plus the fields read by the ``create_*`` helpers.
            ``wandb.init`` must already have been called, because
            ``wandb.run`` is dereferenced at the end.
    """
    # Set the default floating-point dtype before any model is built below,
    # so parameters created without an explicit dtype are float64.
    torch.set_default_dtype(torch.float64)
    torch.manual_seed(config.seed)

    # Build the model and the reward. Environments are not built here: the
    # factory below returns a new RegGraph on every call.
    gfn = create_gfn(config)
    log_reward = create_log_reward(config)

    def create_env_func():
        return create_env(config, log_reward=log_reward)

    opt_group, sch_group = create_optimizers(gfn, config)

    summary_fcs = []

    # Baseline evaluation before any training step.
    _evaluate_and_log(gfn, create_env_func, config, summary_fcs)

    for _ in range(config.num_steps):
        # Enter training mode before every step, since evaluation runs
        # under gfn.off_policy().
        gfn.train()
        train_step(gfn, create_env_func, config, opt_group, sch_group)

        _evaluate_and_log(gfn, create_env_func, config, summary_fcs)

    wandb.run.summary['fcs'] = summary_fcs

if __name__ == '__main__':
    # Shared experiment CLI plus the two options specific to RegGraph runs.
    parser = get_argument_parser()
    parser.add_argument(
        '--idx', type=int, default=1, help='index of the dataset')
    parser.add_argument(
        '--source_dir', type=str, default='notebooks/graphs/',
        help='directory of the source')
    config = parser.parse_args()

    # Tag the run with the script identifier, the seed and the dataset index.
    run_tags = ['test_reg_graphs', f'{config.seed}', f'{config.idx}']
    wandb.init(project=WANDB_PROJECT_NAME, tags=run_tags)
    wandb.config.update(config)

    main(config)
