import torch
import wandb 
import matplotlib.pyplot as plt 
import json 
import os 

import sys 
sys.path.append('.') 

# gflownet and general utils 
from experiments.experiments_utils import (
                            get_argument_parser, 
                            create_gfn, 
                            create_env, 
                            create_log_reward,
                            create_opt,  
                            eval_step
                        )
from streaming_gfn.utils import (
                            train_gfn, 
                            LogRewardProduct, 
                        ) 

# Weights & Biases project name passed to `wandb.init` in the script entry point below.
WANDB_PROJECT_NAME = 'streaming_gfn_vi' 

def main(config): 
    """Train an initial GFlowNet, then a chain of streaming updates, and log the results.

    One model is trained from scratch; afterwards ``config.num_stm_upd``
    streaming models are trained, each receiving the previously trained model
    via ``previous_model``. After every streaming update the newest model is
    evaluated against an environment built from ``LogRewardProduct`` over all
    log-rewards created so far.

    Args:
        config: Parsed experiment arguments. This function reads ``seed``,
            ``num_stm_upd`` and ``criterion`` directly and forwards the whole
            object to the project helpers. NOTE: ``config.seed`` is
            incremented in place once per streaming update.

    Side effects:
        * Sets the global torch default dtype to ``float64`` and seeds torch.
        * Writes ``<script file name>_<criterion>.json`` into the current
          working directory with the keys ``'metric'``, ``'learned'`` and
          ``'tgt'``. ``'learned'`` and ``'tgt'`` only contain the streaming
          updates, so they are one entry shorter than ``'metric'``, which also
          holds the result of the initial training.
        * Stores the list of metrics under ``'l1'`` in the summary of the
          active wandb run; ``wandb.init`` must have been called beforehand.
    """
    torch.set_default_dtype(torch.float64) 
    torch.manual_seed(config.seed) 
    models = list() 
    posteriors = list() 
    metrics = list() 
    learned_lst = list() 
    tgt_lst = list() 

    # instantiate the gflownet 
    gfn = create_gfn(config) 
    posterior = create_log_reward(config, gfn) 
    
    create_env_func = lambda: create_env(config, posterior) 
    
    # First training 
    opt, sch = create_opt(gfn, config) 
    train_gfn(gfn, create_env_func, opt, sch, config)
    # `learned` and `tgt` of this first evaluation are only used by the
    # optional debugging plot below; they are not stored in the JSON output.
    metric, learned, tgt = eval_step(config, gfn, create_env_func, False, return_dist=True) 
    models.append(gfn) 
    posteriors.append(posterior) 
    metrics.append(metric)
    
    # plt.scatter(learned.cpu(), tgt.cpu(), rasterized=True); plt.savefig('sets_init.pdf'); plt.clf() 

    print(f'L1 [1]: {metrics[-1]}') 
    # Subsequent trainings 
    for stm_upd in range(config.num_stm_upd): 
        # Train a streaming gflownet 
        config.seed += 1 
        gfn = create_gfn(config, is_streaming=True) 
        posterior = create_log_reward(config, gfn) 
        create_env_func = lambda: create_env(config, posterior) 
        # update_logz(gfn, models[-1], create_env_func, config)
        opt, sch = create_opt(gfn, config) 
        train_gfn(gfn, create_env_func, opt, sch, config, previous_model=models[-1]) 
        models.append(gfn) 
        posteriors.append(posterior) 
        
        # Evaluate against the combination of every log-reward seen so far 
        log_reward_product = LogRewardProduct(posteriors) 
        create_env_full = lambda: create_env(config, log_reward_product) 
        metric, learned, tgt = eval_step(config, gfn, create_env_full, False, return_dist=True) 

        metrics.append(metric) 
        learned_lst.append(learned.cpu().tolist()) 
        tgt_lst.append(tgt.cpu().tolist()) 
        # plt.scatter(learned.cpu(), tgt.cpu(), rasterized=True); plt.savefig(f'sets_{stm_upd+1}.pdf'); plt.clf() 
        print(f'L1 [{stm_upd + 2}]: {metrics[-1]}') 
         
    print(metrics) 

    # Build the payload before opening the file so that a failing conversion
    # does not leave an empty/truncated JSON file behind.
    results = {
        'metric': [metric.item() for metric in metrics],
        'learned': learned_lst, 
        'tgt': tgt_lst 
    }
    results_path = f'{os.path.basename(__file__)}_{config.criterion}.json'
    # The context manager guarantees the handle is flushed and closed, which
    # the previous bare `open(...)` passed to `json.dump` did not.
    with open(results_path, 'w') as results_file:
        json.dump(results, results_file)

    wandb.run.summary['l1'] = metrics 
    
if __name__ == '__main__': 
    # Build the experiment configuration from the command line.
    config = get_argument_parser().parse_args()

    # Tags attached to the wandb run: a fixed marker plus the string form of
    # the environment, alpha, criterion and seed settings.
    run_tags = [
        'streaming_eval',
        f'{config.env}',
        f'{config.alpha}',
        f'{config.criterion}',
        f'{config.seed}',
    ]

    # Start the wandb run and record the full configuration before training,
    # since `main` writes to the active run's summary.
    wandb.init(project=WANDB_PROJECT_NAME, tags=run_tags)
    wandb.config.update(config)

    main(config)
