import torch 
import wandb 
import time 
import pprint  

import sys 
sys.path.append('.') 

from sal.experiments_utils import create_log_reward, create_gfn, create_env, get_argument_parser 
from sal.utils import TopKQueue, ModesList, ReplayBuffer, plot_topk_num_modes_histogram 
from sal.sal_utils import (
    train_step, 
    eval_step_sequences, 
    eval_step_hypergrids, 
    eval_step_sets, 
    get_env, 
    create_log_reward_model  
)

# Weights & Biases project that the ``__main__`` entry point passes to ``wandb.init``.
WANDB_PROJECT_NAME: str = "subgraph-asynchronous-learning"

def eval_step(log, config, gfn, create_env, filename=None, **kwargs): 
    match config.env: 
        case 'grids': 
            return eval_step_hypergrids(log, gfn, create_env, filename, **kwargs) 
        case 'sequences': 
            return eval_step_sequences(log, gfn, create_env, filename, **kwargs) 
        case 'sets': 
            return eval_step_sets(log, gfn, create_env, filename, **kwargs) 
      
def _make_replay_buffer(config): 
    """Return a new ``ReplayBuffer`` when ``config.use_replay_buffer`` is set, else ``None``."""
    if config.use_replay_buffer: 
        return ReplayBuffer(config.replay_buffer_size, device=config.device) 
    return None 


def main(config): 
    """Run the subgraph asynchronous learning (SAL) experiment end to end.

    The run proceeds in three training stages and a final output step.

    1. Baseline ("std"): train a GFlowNet on ``create_env(config, log_reward)``
       via ``train_step`` and evaluate it. It gets ``multiply_epochs_by=2`` and
       twice ``config.max_time`` when a time budget is given.
    2. Clients: train ``config.num_models`` GFlowNets. Client ``i`` uses
       ``get_env(config).create_env_on_depth(config, log_reward, i)``. All
       clients share one fresh ``ModesList`` and ``TopKQueue``.
    3. SAL: train a GFlowNet on ``create_env_maximum_depth`` with the reward
       model from ``create_log_reward_model``. The client models are passed as
       ``gflownets``. It is then evaluated on ``create_env_for_sal``.

    Output: if ``config.save_figures`` is set, ``log`` is written to
    ``log_<env>_<reward>.json``. Otherwise it is pretty-printed and, when a
    wandb run is active, stored in ``wandb.run.summary['log']``.

    Side effects: sets torch's default dtype to float64, seeds torch's RNG with
    ``config.seed``, and sets ``config.th``.
    """
    torch.set_default_dtype(torch.float64)
    torch.manual_seed(config.seed)

    log = {
        'env': config.env, 
        'reward': config.reward, 
        'num_models': config.num_models, 
        'max_depth': config.max_depth,
        'mode_th': config.q, 
        'seed': config.seed, 
        'epochs_eval': config.epochs_eval, 
        'elapsed_time_clients': [],
        'elapsed_time_sal': None, 
        'elapsed_time_std': None, 
        'topk_std': None, 
        'topk_sal': None, 
        'num_modes_std': None, 
        'num_modes_sal': None,
        'loss_std': None,  
        'loss_sal': None, 
        'loss_clients': [],   
        # environment-specific 
        'tv_std': None, 
        'tv_sal': None, 
        'grid_learned_std': None, 
        'grid_learned_sal': None, 
        'grid_target': None, 
    }

    # NOTE: the order of create_gfn / ReplayBuffer calls below is kept identical
    # to the original because model construction may consume torch's seeded RNG.

    # --- Stage 1: baseline ("std") run -------------------------------------
    log_reward = create_log_reward(config, create_gfn(config))  
    gfn = create_gfn(config, is_sal=True) 
    replay_buffer = _make_replay_buffer(config) 

    create_env_func = lambda: create_env(config, log_reward) 

    # multiply_epochs_by = 2 for most, but 4 for hypergrid; but use 1.5 for figure 4 (8 x 8 hypergrid)  
    start = time.perf_counter()
    topk_queue, mode_lst, log['loss_std'] = train_step(
        config, gfn, create_env_func,
        multiply_epochs_by=2, replay_buffer=replay_buffer,
        max_time=2 * config.max_time if config.max_time is not None else None,
    )
    log['elapsed_time_std'] = time.perf_counter() - start

    log['num_modes_std'] = len(mode_lst)
    log['topk_std'] = topk_queue.stats() 

    eval_step(log, config, gfn, create_env_func, 
              filename='eval_std.pdf' if config.save_figures else None)  

    # --- Stage 2: client models ---------------------------------------------
    # Copy the baseline mode list's threshold into config before building the
    # fresh ModesList / TopKQueue (presumably create_mode_lst reads config.th;
    # confirm in sal.utils). Both are shared by all clients and the SAL model.
    config.th = mode_lst.th 
    mode_lst = ModesList.create_mode_lst(config, warmup=False) 
    topk_queue = TopKQueue.create_topk_queue(config) 

    models = []
    for model_idx in range(config.num_models): 
        gfn = create_gfn(config, is_sal=True) 
        # Bind model_idx as a default argument so the factory keeps building this
        # client's environment even if it is invoked after the loop has advanced
        # (a plain closure would see the loop variable's latest value).
        create_env_func = lambda model_idx=model_idx: get_env(config).create_env_on_depth(
            config, log_reward, model_idx
        )
        replay_buffer = _make_replay_buffer(config) 

        start = time.perf_counter()
        client_loss = train_step(config, gfn, create_env_func, mode_lst=mode_lst, 
                                 replay_buffer=replay_buffer, 
                                 topk_queue=topk_queue, topk_value=log_reward, 
                                 max_time=config.max_time)[-1] 
        log['loss_clients'].append(client_loss)
        log['elapsed_time_clients'].append(time.perf_counter() - start)

        models.append(gfn) 

    # --- Stage 3: SAL model trained on top of the clients --------------------
    log_reward_model = create_log_reward_model(config, log_reward) 
    gfn = create_gfn(config, is_sal=True) 

    create_env_func = lambda: get_env(config).create_env_maximum_depth(config, log_reward_model)
    replay_buffer = _make_replay_buffer(config) 

    start = time.perf_counter()
    log['loss_sal'] = train_step(config, gfn, create_env_func, gflownets=models, 
                                 replay_buffer=replay_buffer,  
                                 mode_lst=mode_lst, topk_queue=topk_queue, topk_value=log_reward,
                                 max_time=config.max_time)[-1] 
    log['elapsed_time_sal'] = time.perf_counter() - start

    log['num_modes_sal'] = len(mode_lst)
    log['topk_sal'] = topk_queue.stats() 

    # Evaluation uses the original log_reward, not the learned reward model.
    create_env_func = lambda: get_env(config).create_env_for_sal(config, log_reward) 

    eval_step(log, config, gfn, create_env_func, 
              filename='eval_sal.pdf' if config.save_figures else None, gflownets=models) 

    # --- Output ---------------------------------------------------------------
    # Upload data to wandb 
    if wandb.run is not None and not config.save_figures: 
        wandb.run.summary['log'] = log 

    if config.save_figures:  
        import json 
        # Context manager guarantees the file is flushed and closed.
        with open(f'log_{config.env}_{config.reward}.json', 'w') as log_file:
            json.dump(log, log_file)
    else: 
        pprint.pprint(log)

if __name__ == '__main__': 
    parser = get_argument_parser() 
    parser.add_argument(
        '--max_depth', type=int, default=6,
        help='maximum depth; recorded in the run log',
    )
    parser.add_argument(
        '--num_models', type=int, default=10,
        help='number of client GFlowNets trained before the SAL model',
    )
    parser.add_argument(
        '--save_figures', action='store_true',
        help='save evaluation figures and write the log to a JSON file instead of printing it',
    )
    parser.add_argument(
        '--perform_time_analysis', action='store_true',
        help='tag the wandb run as a time-analysis run',
    )
    parser.add_argument(
        '--max_time', type=float, default=None,
        help='time budget forwarded to train_step (the baseline run gets twice this value)',
    )
    config = parser.parse_args() 
    # Validate CLI input with parser.error (clean usage message, exit code 2)
    # instead of `assert`, which is stripped when Python runs with -O.
    if config.env == 'grids' and config.num_models >= config.H: 
        parser.error(
            f'--num_models ({config.num_models}) must be smaller than H ({config.H}) '
            f'when --env is grids'
        )
    wandb.init(project=WANDB_PROJECT_NAME, tags=[
        'asynchronous-learning', f'{config.env}', f'{config.reward}', f'{config.H}',
        f'time_analysis={config.perform_time_analysis}', f'{config.seed}'
    ]) 
    wandb.config.update(config) 
    main(config) 


