import torch 
import wandb 
import math 
import tqdm 
import pprint 
import os 
import json 

import sys 
sys.path.extend(['..', '.']) 

from sal.experiments_utils import get_argument_parser, create_gfn, create_env, create_log_reward  
from sal.pac_utils import (
    TrajectoriesDataLoader, 
    create_bayesian_gfn, 
    train_gfn_epoch, 
    train_bayesian_gfn_epoch, 
    get_dataset 
) 

# Weights & Biases project name; passed to `wandb.init(project=...)` in the
# script entry point at the bottom of this file.
# NOTE(review): "asyhchronous" looks like a typo of "asynchronous". The string
# is a runtime identifier, so confirm before renaming the project.
WANDB_PROJECT_NAME = 'subgraph-asyhchronous-learning' 
# Not referenced anywhere in the visible part of this file -- presumably a
# toggle for saving results locally; TODO confirm or remove.
SAVE_LOCAL_DATA = True 

def train_base_gfn(config, trajectories, train_dataloader, **bayesian_kwargs):
    """Train a freshly created Bayesian GFlowNet and report its FCS metric.

    The metric returned by ``evaluate_fcs_on_trajectories`` is measured on
    ``trajectories`` once before any optimisation step and once after
    training has finished.

    Args:
        config: Experiment configuration. This function reads
            ``config.momentum`` and ``config.epochs``; the object is also
            forwarded to ``create_bayesian_gfn``.
        trajectories: Trajectories the metric is evaluated on.
        train_dataloader: Loader handed to ``train_gfn_epoch`` every epoch.
        **bayesian_kwargs: Forwarded unchanged to ``create_bayesian_gfn``.

    Returns:
        A ``(fcs_before_training, fcs_after_training)`` tuple of Python
        scalars (obtained through ``.item()``).
    """
    model = create_bayesian_gfn(config, **bayesian_kwargs)
    # Alternative left disabled by the original author: Adam with lr=1e-3 for
    # `pf` and lr=1e-1 for `log_z`.
    optimizer = torch.optim.SGD(
        model.parameters(), lr=1e-3, momentum=config.momentum
    )

    def _measure_fcs(**eval_kwargs):
        # Evaluate in eval mode without autograd tracking, then switch the
        # model back to train mode before leaving the context.
        with torch.inference_mode():
            model.eval()
            score = model.evaluate_fcs_on_trajectories(
                trajectories, bucket_size=32, **eval_kwargs
            ).detach().cpu().item()
            model.train()
        return score

    # NOTE(review): the "before" measurement relies on the default value of
    # `perturb_params`, whereas the "after" measurement forces it to False --
    # confirm that the two numbers are meant to be comparable.
    fcs_before_training = _measure_fcs()

    progress = tqdm.tqdm(range(config.epochs))
    for _ in progress:
        epoch_loss = train_gfn_epoch(model, optimizer, train_dataloader, progress)
        # Early stopping once the value reported by the epoch (presumably a
        # running average of the loss) drops below 1e-4.
        if epoch_loss < 1e-4:
            break

    fcs_after_training = _measure_fcs(perturb_params=False)

    return fcs_before_training, fcs_after_training

def get_partitioned_data(alpha, num_trajectories):
    """Build the alpha / non-alpha / merged trajectory loaders.

    ``int(alpha * num_trajectories)`` trajectories go to the "alpha" split
    and the remaining ones to the "non-alpha" split. The second return value
    of the first ``get_dataset`` call is passed on to the second call, and
    the two loaders are finally merged into a third one.

    NOTE: ``config`` is not a parameter here. It is resolved from the module
    globals, where it is only bound when this file is executed as a script.

    Args:
        alpha: Fraction of ``num_trajectories`` assigned to the alpha split.
        num_trajectories: Total number of trajectories across both splits.

    Returns:
        ``(train_loader_full, train_loader_alpha, train_loader_non_alpha)``.
    """
    n_alpha = int(alpha * num_trajectories)
    n_rest = num_trajectories - n_alpha

    # Alpha split; keep the returned reward so the second split reuses it.
    alpha_trajs, log_reward = get_dataset(config, n_alpha)
    loader_alpha = TrajectoriesDataLoader(
        alpha_trajs, batch_size=config.batch_size, shuffle=True
    )

    # Non-alpha split, generated with the reward obtained above.
    rest_trajs, _ = get_dataset(config, n_rest, log_reward)
    loader_rest = TrajectoriesDataLoader(
        rest_trajs, batch_size=config.batch_size, shuffle=True
    )

    loader_full = loader_alpha.merge(loader_rest)

    # Sanity check: merging must neither drop nor duplicate trajectories.
    assert loader_full.num_trajs == (loader_alpha.num_trajs + loader_rest.num_trajs), (
        loader_full.num_trajs, loader_alpha.num_trajs, loader_rest.num_trajs
    )

    return loader_full, loader_alpha, loader_rest

def main(config):
    """Run the baseline and the PAC-Bayes style GFlowNet experiment.

    Stages:
      1. Baseline: train a Bayesian GFlowNet on ``config.num_trajectories``
         trajectories with ``train_base_gfn`` and log its FCS before and
         after training.
      2. Partition the data with ``get_partitioned_data`` into an alpha
         split, a non-alpha split and their union.
      3. Prior: train a model on the alpha split for ``config.epochs``
         epochs, optimising only the posterior mean.
      4. Posterior: create a second model, install the mean learned in
         stage 3 as its prior mean, and train mean and rho on the full
         loader for ``config.epochs // 2`` epochs.
      5. Evaluate the posterior on the non-alpha trajectories and compute
         ``inverted_kl_bound`` from that risk, the KL term, the number of
         non-alpha trajectories and ``config.delta``.

    The collected numbers are pretty-printed and, when a wandb run is
    active, stored in its summary under the key ``'log'``.

    Args:
        config: Experiment configuration. This function reads ``seed``,
            ``num_trajectories``, ``alpha``, ``momentum``, ``epochs`` and
            ``delta``; the object is also forwarded to the project helpers.
    """
    printer = pprint.PrettyPrinter()
    torch.manual_seed(config.seed)

    log = dict()

    def _bayesian_kwargs(optimize_posterior_rho):
        # Settings shared by every model built here; only the decision to
        # optimise the posterior rho differs between the stages.
        return {
            'prior_stddev': math.sqrt(1e-6),
            'optimize_prior_mean': False,
            'optimize_prior_rho': False,
            'optimize_posterior_mean': True,
            'optimize_posterior_rho': optimize_posterior_rho,
        }

    # --- Stage 1: train the model on the full data ---
    trajectories, _ = get_dataset(config, num_trajectories=config.num_trajectories)
    # NOTE(review): batch size is fixed to 128 here, whereas
    # get_partitioned_data uses config.batch_size -- confirm this is intended.
    train_dataloader = TrajectoriesDataLoader(trajectories, batch_size=128, shuffle=True)
    fcs_before_training, fcs_after_training = train_base_gfn(
        config, trajectories, train_dataloader, **_bayesian_kwargs(True)
    )

    log['fcs_before_training'] = fcs_before_training
    log['fcs_after_training'] = fcs_after_training

    # --- Stage 2: partition the data for the Bayesian GFlowNet ---
    alpha = config.alpha
    num_trajectories = config.num_trajectories
    train_loader_full, train_loader_alpha, train_loader_non_alpha = get_partitioned_data(
        alpha, num_trajectories
    )
    log['alpha'] = alpha
    log['num_trajectories'] = num_trajectories

    # --- Stage 3: train the prior distribution's mean ---
    # Today's posterior is tomorrow's prior: only the posterior mean is
    # optimised here, and it is later installed as the prior mean of the
    # model trained in stage 4.
    prior_gfn = create_bayesian_gfn(config, **_bayesian_kwargs(False))
    # lr = 1e-3 seems to be consistent
    optim = torch.optim.SGD(prior_gfn.parameters(), lr=1e-3, momentum=config.momentum)

    pbar = tqdm.tqdm(range(config.epochs))
    for _ in pbar:
        train_gfn_epoch(
            prior_gfn, optim, train_loader_alpha, pbar=pbar
        )

    # --- Stage 4: train the posterior distribution's mean and variance ---
    posterior_gfn = create_bayesian_gfn(config, **_bayesian_kwargs(True))
    posterior_gfn.pf.update_prior_mean(prior_gfn.pf.mlp_logit_posterior)

    optim = torch.optim.SGD(posterior_gfn.parameters(), lr=1e-3, momentum=config.momentum)
    pbar = tqdm.tqdm(range(config.epochs//2))
    for _ in pbar:
        train_bayesian_gfn_epoch(
            gfn=posterior_gfn, train_dataloader=train_loader_full,
            optim=optim, config=config, pbar=pbar
        )

    # --- Stage 5: evaluate the (bounded) risk on the non-alpha data ---
    with torch.inference_mode():
        posterior_gfn.eval()
        risk = posterior_gfn.evaluate_fcs_on_trajectories(
            train_loader_non_alpha.trajectories, bucket_size=32, num_samples=32, perturb_params=True
        )
        log['fcs_bayesian'] = risk.cpu().item()
        risk_bound = posterior_gfn.pf.inverted_kl_bound(
            risk, posterior_gfn.pf.kl(), train_loader_non_alpha.num_trajs, delta=config.delta
        )
        log['fcs_bayesian_bound'] = risk_bound.cpu().item()

    printer.pprint(log)

    # Only attach the results when a wandb run has been initialised.
    if wandb.run is not None:
        wandb.run.summary['log'] = log

if __name__ == '__main__':
    # Script entry point: extend the shared argument parser with the options
    # specific to this experiment, start a wandb run and launch `main`.
    parser = get_argument_parser()
    parser.add_argument('--num_trajectories', type=int, default=int(5e3), help='number of trajectories')
    # Every torch.optim.SGD call in this file omits `nesterov=True`, so this
    # value is plain (non-Nesterov) momentum; the help text says so.
    parser.add_argument('--momentum', type=float, default=.9, help='momentum for SGD')
    # `alpha` is the fraction of trajectories that get_partitioned_data puts
    # into the alpha split.
    parser.add_argument('--alpha', type=float, default=.6, help='data partitioning')
    # NOTE: `config` must remain a module-level name, because
    # get_partitioned_data reads it as a global instead of a parameter.
    config = parser.parse_args()
    wandb.init(project=WANDB_PROJECT_NAME, tags=['bayesian-learning', f'{config.env}', f'{config.reward}'])
    wandb.config.update(config)
    main(config)
