import torch 
import wandb 
import matplotlib.pyplot as plt 
import itertools 
import copy 
import argparse 
import sys 
import numpy as np 
sys.path.append('.') 

# gflownets and general utils 
from sal.gflownet import GFlowNet, SALGFlowNet, EPGFlowNet
from sal.utils import MockOptimizer, OptGroup, SchGroup 

def get_argument_parser():
    """Build the command-line parser for the GFlowNet training script.

    Returns:
        argparse.ArgumentParser: parser with every training, environment,
        replay-buffer and variant-specific flag registered. Nothing is parsed
        here; the caller invokes ``parse_args`` on the result.
    """
    parser = argparse.ArgumentParser(description="GFlowNet Training Script")
    add = parser.add_argument  # shorthand; every flag goes through the same parser

    # --- GFlowNet architecture and training -------------------------------
    add("--hidden_dim", type=int, default=64,
        help="Hidden dimension of the policy network")
    add("--num_layers", type=int, default=2,
        help="Number of layers in the policy network")
    add("--epochs", type=int, default=2000,
        help="Number of training epochs")
    add("--epochs_eval", type=int, default=100,
        help="Number of epochs for evaluation")
    add("--epochs_per_step", type=int, default=250,
        help="number of epochs per step")
    add("--num_steps", type=int, default=10,
        help="number of steps")
    add("--use_scheduler", action="store_true",
        help="Use learning rate scheduler")
    add("--criterion", type=str, default="tb",
        choices=["tb", "db", "cb", "dbc", "td", "regdb", "subtb", "fm", "atb"],
        help="Loss function for training")
    add("--off_policy", action="store_true",
        help="whether to pursue off-policy")
    add("--device", type=str, default="cpu",
        help="Device to use (cpu or cuda)")

    add("--env", type=str, default="sets",
        choices=["sets", "bags", "sequences", "grids"],
        help="Target domain")

    add("--lamb_reg", type=float, default=1e-19,
        help="regularization strength")
    add("--gamma_func", type=str, default="const",
        choices=["const", "depth", "inv_depth", "learn", "tb", "subtb"],
        help="weighting function for detailed balance")

    # --- Chi-squared divergence -------------------------------------------
    add("--compute_chi_squared_div", action="store_true",
        help="chi^2 divergences experiments")

    # --- Environment parameters -------------------------------------------

    # List and queue of modes
    add("--q", type=float, default=0.9,
        help="quantile of the reward for mode")
    add("--th", type=float, default=-np.log(2),
        help="threshold of the reward for being a mode")

    # Generic
    add("--batch_size", type=int, default=128,
        help="Batch size for training")
    add("--lr", type=float, default=1e-3,
        help="learning rate")
    add("--disable_pbar", action="store_true",
        help="disable tqdm")
    add("--delta", type=float, default=5e-2,
        help="confidence for PAC bounds")
    add("--topk", type=int, default=100,
        help="top-k rewards")

    # Sets
    add("--set_size", type=int, default=16,
        help="Number of elements in the set")
    add("--src_size", type=int, default=32,
        help="Number of source vectors")
    add("--force_mask_idx", type=int, default=None,
        help="Mask out action this action on training")
    add("--hide_important_region", action="store_true",
        help="Hide high-probability regions")

    # Phylogenetic inference
    add("--num_leaves", type=int, default=7,
        help="number of biological species")
    add("--num_nb", type=int, default=4,
        help="number of nucleotides (hypothetical)")
    add("--num_sites", type=int, default=25,
        help="number of observed sites")
    add("--temperature", type=float, default=1.0,
        help="temperature of the target")

    # Sequences
    add("--seq_size", type=int, default=8,
        help="number of elements in the sequence")
    add("--vocab_size", type=int, default=1,
        help="vocab size appended to the sequence")
    add("--reward", type=str, default="additive",
        choices=["additive", "tfbind8", "tfbind10", "bits"],
        help="reward type")
    add("--reward_exp", type=float, default=3.0,
        help="reward exponent for tfbind tasks")
    add("--num_modes", type=int, default=78,
        help="number of modes for the bit sequence generation")

    # Grid environment
    # (disabled) add("--grid_size", type=int, default=12, help="grid width")
    # (disabled) add("--ro", type=float, default=1e-3, help="baseline reward")
    add("--H", type=int, default=8,
        help="hypergrid size")
    add("--dim", type=int, default=2,
        help="hypergrid dimensionality")

    # --- Reward and seed --------------------------------------------------
    add("--seed", type=int, default=42,
        help="Random seed for reward generation")
    add("--fcs_bucket_size", type=int, default=64,
        help="bucket size for FCS")

    # --- Visualization ----------------------------------------------------
    add("--num_back_traj", type=int, default=8,
        help="Number of back-trajectories for evaluation")
    add("--use_progress_bar", action="store_true",
        help="use progress bar")

    # --- Replay buffer ----------------------------------------------------
    add("--use_replay_buffer", action="store_true",
        help="use replay buffer")
    add("--replay_buffer_size", type=int, default=2048,
        help="Size of the replay buffer")
    # The default stays the int -1: argparse applies `type` only to string defaults.
    add("--replay_buffer_freq", type=float, default=-1,
        help="Frequency of sampling from RB")

    # --- FL- and LED-GFlowNets --------------------------------------------
    add("--use_led_gfn", action="store_true",
        help="Use LED-GFlowNet instead of GFlowNet")
    add("--learn_potential", action="store_true",
        help="Learn the potential function")
    add("--train_phi_rounds", type=int, default=8,
        help="number of rounds to train phi")

    # --- LA-GFlowNet ------------------------------------------------------
    add("--use_la_gfn", action="store_true",
        help="Use Look-Ahead GFlowNet")

    # --- SubTB ------------------------------------------------------------
    add("--lamb_subtb", type=float, default=0.9,
        help="exponential weighting for subtb")

    return parser

def create_optimizers(gfn, config):
    """Create the optimizer and learning-rate-scheduler groups for ``gfn``.

    Three optimizers are registered under the names ``opt_gfn``, ``opt_phi``
    and ``opt_gamma``, and three schedulers under ``sch_gfn``, ``sch_phi`` and
    ``sch_gamma``.

    Args:
        gfn: model exposing ``pf`` and ``pb`` (with ``parameters()``), ``log_z``
            and, when ``config.gamma_func == 'learn'``, ``gamma_func``.
        config: namespace providing ``lr``, ``gamma_func``, ``epochs_per_step``,
            ``num_steps`` and ``use_scheduler``.

    Returns:
        tuple: ``(opt_group, sch_group)``.
    """
    optimizers = OptGroup()
    schedulers = SchGroup()

    # --- Optimizers -------------------------------------------------------
    # Both policies train at the base learning rate; log_z uses a 100x larger one.
    policy_params = itertools.chain(gfn.pf.parameters(), gfn.pb.parameters())
    param_groups = [
        {'params': policy_params, 'lr': config.lr},
        {'params': gfn.log_z, 'lr': config.lr * 1e2},
    ]
    optimizers.register(torch.optim.Adam(param_groups, weight_decay=1e-4), 'opt_gfn')

    # MockOptimizer is presumably a no-op stand-in -- confirm in sal.utils.
    optimizers.register(MockOptimizer(), 'opt_phi')

    # gamma only gets a real optimizer when its weighting function is learned.
    gamma_optimizer = MockOptimizer()
    if config.gamma_func == 'learn':
        gamma_optimizer = torch.optim.Adam(
            gfn.gamma_func.parameters(),
            lr=1e-2 * config.lr,
            weight_decay=1e-4,
        )
    optimizers.register(gamma_optimizer, 'opt_gamma')

    # --- Schedulers -------------------------------------------------------
    # Each optimizer gets a PolynomialLR schedule (power=1) over the full run.
    horizon = config.epochs_per_step * config.num_steps
    for suffix in ('gfn', 'phi', 'gamma'):
        scheduler = torch.optim.lr_scheduler.PolynomialLR(
            getattr(optimizers, f'opt_{suffix}'),
            total_iters=horizon,
            power=1.,
        )
        schedulers.register(scheduler, f'sch_{suffix}')

    # With scheduling disabled, replace every registered scheduler by a mock.
    if not config.use_scheduler:
        for key in schedulers.schedulers:
            schedulers.schedulers[key] = MockOptimizer()

    return optimizers, schedulers

def create_pol(config, is_sal):
    """Instantiate the forward and backward policies for ``config.env``.

    The policy classes are imported lazily from the model module that matches
    the environment. When ``config.use_la_gfn`` is set, the look-ahead forward
    policy is used instead of the default one.

    Args:
        config: namespace with ``env``, ``device``, ``hidden_dim``,
            ``num_layers``, ``use_la_gfn`` and the environment-specific sizes.
        is_sal: currently not consulted; the SAL-specific forward policy for
            sequences is disabled.

    Returns:
        tuple: ``(forward_policy, backward_policy)``.

    Raises:
        NotImplementedError: look-ahead policy requested for the 'grids' env.
        AssertionError: look-ahead policy requested for 'sequences' with
            ``config.vocab_size != 1``.
        Exception: ``config.env`` is not one of the supported environments.
    """
    env = config.env

    if env in ('sets', 'bags'):
        from sal.models.sets import ForwardPolicy, BackwardPolicy, ForwardPolicyLA
        if not config.use_la_gfn:
            forward = ForwardPolicy(
                config.src_size,
                config.hidden_dim,
                config.num_layers,
                force_mask_idx=config.force_mask_idx,
                device=config.device,
                compute_chi_squared_div=config.compute_chi_squared_div,
            )
        else:
            forward = ForwardPolicyLA(
                config.src_size, config.hidden_dim, config.num_layers, device=config.device
            )
        backward = BackwardPolicy(config.device)

    elif env == 'sequences':
        from sal.models.sequences import ForwardPolicy, ForwardPolicyLA, BackwardPolicy
        if not config.use_la_gfn:
            # One action per combination of `vocab_size` source tokens.
            forward = ForwardPolicy(
                config.seq_size,
                config.src_size ** config.vocab_size,
                config.hidden_dim,
                config.num_layers,
                device=config.device,
            )
        else:
            assert config.vocab_size == 1
            forward = ForwardPolicyLA(
                config.seq_size,
                config.src_size,
                config.hidden_dim,
                config.num_layers,
                device=config.device,
            )
        backward = BackwardPolicy(config.device)

    elif env == 'grids':
        from sal.models.hypergrids import ForwardPolicy, BackwardPolicy
        if config.use_la_gfn:
            raise NotImplementedError('la-gfn not implemented for hypergrid')
        forward = ForwardPolicy(config.dim, config.hidden_dim, config.num_layers, config.device)
        backward = BackwardPolicy(config.device)

    else:
        raise Exception(f'env: {config.env}')

    return forward, backward

def create_gfn(config, is_sal: bool = False, is_ep: bool = False):
    """Build the GFlowNet variant selected by the two flags.

    Policies are created via ``create_pol``. ``is_sal`` takes precedence over
    ``is_ep``; with neither set, a plain ``GFlowNet`` is returned.

    Args:
        config: namespace with ``criterion``, ``device``, ``lamb_subtb`` and
            whatever ``create_pol`` reads.
        is_sal: return a ``SALGFlowNet``.
        is_ep: return an ``EPGFlowNet`` (only checked when ``is_sal`` is false).

    Returns:
        The constructed ``SALGFlowNet``, ``EPGFlowNet`` or ``GFlowNet``.
    """
    forward_policy, backward_policy = create_pol(config, is_sal)

    if is_sal:
        return SALGFlowNet(
            forward_policy, backward_policy,
            criterion=config.criterion, device=config.device,
        )

    if is_ep:
        return EPGFlowNet(forward_policy, backward_policy, device=config.device)

    return GFlowNet(
        forward_policy, backward_policy,
        lamb_subtb=config.lamb_subtb,
        criterion=config.criterion,
        device=config.device,
    )

def create_env(config, log_reward=None): 
    match config.env: 
        case 'sets': 
            from sal.gym.sets import Set 
            return Set(config.src_size, config.set_size, config.batch_size, log_reward, device=config.device) 
        case 'bags': 
            from sal.gym.sets import Bag  
            return Bag(config.src_size, config.set_size, config.batch_size, log_reward, config.device) 
        case 'sequences': 
            from sal.gym.sequences import Sequences, SequencesVocabSize 
            if config.vocab_size <= 1: 
                return Sequences(config.seq_size, config.src_size, config.batch_size, log_reward, device=config.device) 
            else: 
                return SequencesVocabSize(config.seq_size, config.src_size, config.vocab_size,
                                          config.batch_size, log_reward, device=config.device) 
        case 'grids': 
            from sal.gym.hypergrids import Hypergrid 
            return Hypergrid(config.dim, config.H, config.batch_size, log_reward, device=config.device) 
        
def create_log_reward(config, gflownet): 
    match config.env: 
        case 'sets' | 'bags': 
            from sal.gym.sets import Set, LogReward 
            log_reward = LogReward(config.src_size, config.seed, config, 
                                   temperature=config.temperature, device=config.device) 
            sets = Set(config.src_size, config.set_size, config.batch_size, log_reward, device=config.device) 
            sets = gflownet.sample(sets) 
            log_reward.shift = sets.log_reward().max()  
            return log_reward 
        case 'sequences': 
            from sal.gym.sequences import LogReward, LogRewardBits, LogRewardTFN 
            match config.reward: 
                case 'additive': 
                    return LogReward(config.seq_size, config.seq_size, config.seed, device=config.device) 
                case 'tfbind8':
                    assert config.seq_size == 8 and config.src_size == 4  
                    return LogRewardTFN(8, exp=config.reward_exp, device=config.device)
                case 'tfbind10': 
                    assert config.seq_size == 10 and config.src_size == 4 
                    return LogRewardTFN(10, exp=config.reward_exp, device=config.device) 
                case 'bits': 
                    assert config.src_size == 2 
                    return LogRewardBits(config.src_size, config.seq_size, config.num_modes, device=config.device)
        case 'grids': 
            from sal.gym.hypergrids import LogReward 
            return LogReward()  
