import torch 
import wandb 
import matplotlib.pyplot as plt 
import itertools 
import copy 
import argparse 
import sys 
sys.path.append('.') 

# gflownets and general utils 
from gfn.gflownet import GFlowNet, LEDGFlowNet 
from gfn.utils import compute_marginal_dist, MockOptimizer, OptGroup, SchGroup 

def get_argument_parser():
    """Build the command-line parser for the GFlowNet training script.

    Returns:
        argparse.ArgumentParser: parser with every training, environment,
        reward and visualization option registered (nothing is parsed here).
    """
    parser = argparse.ArgumentParser(description="GFlowNet Training Script")

    # Small typed helpers keep the option table below compact; each one
    # forwards to `parser.add_argument` with a fixed type/action.
    def add_int(flag, default, text):
        parser.add_argument(flag, type=int, default=default, help=text)

    def add_float(flag, default, text):
        parser.add_argument(flag, type=float, default=default, help=text)

    def add_switch(flag, text):
        parser.add_argument(flag, action="store_true", help=text)

    def add_choice(flag, default, options, text):
        parser.add_argument(flag, type=str, default=default, choices=options, help=text)

    # --- GFlowNet architecture and training -------------------------------
    add_int("--hidden_dim", 64, "Hidden dimension of the policy network")
    add_int("--num_layers", 2, "Number of layers in the policy network")
    add_int("--epochs", 2000, "Number of training epochs")
    add_int("--epochs_eval", 100, "Number of epochs for evaluation")
    add_int("--epochs_per_step", 250, "number of epochs per step")
    add_int("--num_steps", 10, "number of steps")
    add_switch("--use_scheduler", "Use learning rate scheduler")
    add_choice(
        "--criterion", "tb",
        ['tb', 'db', 'cb', 'dbc', 'td', 'regdb', 'subtb', 'fm'],
        "Loss function for training",
    )
    add_switch("--off_policy", "whether to pursue off-policy")
    parser.add_argument("--device", type=str, default="cpu", help="Device to use (cpu or cuda)")

    add_choice(
        "--env", "sets",
        ['sets', 'bags', 'sequences', 'phylogenetics', 'grids'],
        "Target domain",
    )

    add_float("--lamb_reg", 1e-19, "regularization strength")
    add_choice(
        "--gamma_func", "const",
        ['const', 'depth', 'inv_depth', 'learn', 'tb', 'subtb'],
        "weighting function for detailed balance",
    )

    # --- Options shared by every environment ------------------------------
    add_int("--batch_size", 128, "Batch size for training")
    add_float("--lr", 1e-3, "learning rate")
    add_switch("--disable_pbar", "disable tqdm")

    # --- Sets ---------------------------------------------------------------
    add_int("--set_size", 16, "Number of elements in the set")
    add_int("--src_size", 32, "Number of source vectors")

    # --- Phylogenetic inference --------------------------------------------
    add_int("--num_leaves", 7, "number of biological species")
    add_int("--num_nb", 4, "number of nucleotides (hypothetical)")
    add_int("--num_sites", 25, "number of observed sites")
    add_float("--temperature", 1., "temperature of the target")

    # --- Sequences ------------------------------------------------------------
    add_int("--seq_size", 8, "number of elements in the sequence")

    # --- Grid environment ---------------------------------------------------
    add_int("--grid_size", 12, "grid width")
    add_float("--ro", 1e-3, "baseline reward")

    # --- Reward and seed ------------------------------------------------------
    add_int("--seed", 42, "Random seed for reward generation")
    add_int("--fcs_bucket_size", 64, "bucket size for FCS")

    # --- Visualization --------------------------------------------------------
    add_int("--num_back_traj", 8, "Number of back-trajectories for evaluation")
    add_switch("--use_progress_bar", "use progress bar")

    # --- FL- and LED-GFlowNets ------------------------------------------------
    add_switch("--use_led_gfn", "Use LED-GFlowNet instead of GFlowNet")
    add_switch("--learn_potential", "Learn the potential function")
    add_int("--train_phi_rounds", 8, "number of rounds to train phi")

    # --- LA-GFlowNet ------------------------------------------------------------
    add_switch("--use_la_gfn", "Use Look-Ahead GFlowNet")

    # --- SubTB --------------------------------------------------------------------
    add_float("--lamb_subtb", .9, "exponential weighting for subtb")

    return parser

def create_optimizers(gfn, config):
    """Create the optimizer and learning-rate scheduler groups for `gfn`.

    Three optimizers are registered under the names 'opt_gfn', 'opt_phi' and
    'opt_gamma', each paired with a linear (power=1) PolynomialLR scheduler
    registered as 'sch_gfn', 'sch_phi' and 'sch_gamma'.

    Args:
        gfn: model exposing `pf`, `pb` and `log_z`; `phi` is read only for a
            LEDGFlowNet with `config.learn_potential`, and `gamma_func` only
            when `config.gamma_func == 'learn'`.
        config: namespace providing `lr`, `learn_potential`, `gamma_func`,
            `epochs_per_step`, `num_steps` and `use_scheduler`.

    Returns:
        tuple: `(opt_group, sch_group)`.
    """
    opt_group = OptGroup()
    sch_group = SchGroup()

    # Main optimizer: both policies share the base learning rate, while
    # log_z is trained with a 10x larger one.
    policy_params = itertools.chain(gfn.pf.parameters(), gfn.pb.parameters())
    gfn_param_groups = [
        {'params': policy_params, 'lr': config.lr},
        {'params': gfn.log_z, 'lr': config.lr * 1e1},
    ]
    opt_group.register(torch.optim.Adam(gfn_param_groups, weight_decay=1e-4), 'opt_gfn')

    # Potential optimizer: a real Adam only when a LED-GFlowNet learns its
    # potential; otherwise a MockOptimizer placeholder fills the slot.
    if isinstance(gfn, LEDGFlowNet) and config.learn_potential:
        phi_optimizer = torch.optim.Adam(gfn.phi.parameters(), lr=config.lr, weight_decay=1e-4)
    else:
        phi_optimizer = MockOptimizer()
    opt_group.register(phi_optimizer, 'opt_phi')

    # Gamma optimizer: only a learnable weighting function has parameters;
    # it is trained at 1% of the base learning rate.
    if config.gamma_func == 'learn':
        gamma_optimizer = torch.optim.Adam(
            gfn.gamma_func.parameters(), lr=1e-2 * config.lr, weight_decay=1e-4
        )
    else:
        gamma_optimizer = MockOptimizer()
    opt_group.register(gamma_optimizer, 'opt_gamma')

    # One scheduler per optimizer, all decaying over the full training length.
    for suffix in ('gfn', 'phi', 'gamma'):
        scheduler = torch.optim.lr_scheduler.PolynomialLR(
            getattr(opt_group, f'opt_{suffix}'),
            total_iters=config.epochs_per_step * config.num_steps,
            power=1.,
        )
        sch_group.register(scheduler, f'sch_{suffix}')

    if config.use_scheduler:
        return opt_group, sch_group

    # Scheduling disabled: swap every registered scheduler for a
    # MockOptimizer (presumably a no-op stand-in — defined in gfn.utils).
    for key in sch_group.schedulers:
        sch_group.schedulers[key] = MockOptimizer()

    return opt_group, sch_group

def create_pol(config): 
    pt = None 
    match config.env: 
        case 'sets' | 'bags': 
            from gfn.models.sets import ForwardPolicy, BackwardPolicy, \
                LearnablePotential, FixedPotential, ForwardPolicyLA 
            if config.use_la_gfn: 
                pf = ForwardPolicyLA(config.src_size, config.hidden_dim,  config.num_layers, device=config.device) 
            else: 
                pf = ForwardPolicy(config.src_size, config.hidden_dim, config.num_layers, device=config.device) 
            pb = BackwardPolicy(config.device) 
            if config.learn_potential:
                pt = LearnablePotential(config.src_size, config.hidden_dim, config.num_layers, config.device) 
            else: 
                pt = FixedPotential() 
        case 'sequences': 
            from gfn.models.sequences import ForwardPolicy, ForwardPolicyLA, BackwardPolicy
            if config.use_la_gfn: 
                pf = ForwardPolicyLA(config.seq_size, config.src_size, 
                        config.hidden_dim, config.num_layers, device=config.device) 
            else: 
                pf = ForwardPolicy(config.seq_size, config.src_size, config.hidden_dim, config.num_layers, device=config.device)
            pb = BackwardPolicy(config.device) 
        case 'phylogenetics': 
            from gfn.models.phylogenetics import ForwardPolicyMLP, ForwardPolicyLA, BackwardPolicy
            if config.use_la_gfn: 
                pf = ForwardPolicyLA(config.hidden_dim, config.num_leaves, device=config.device) 
            else:  
                pf = ForwardPolicyMLP(config.hidden_dim, config.num_leaves, device=config.device) 
            pb = BackwardPolicy(config.device) 
        case 'grids': 
            from gfn.models.grids import ForwardPolicy, BackwardPolicy 
            if config.use_la_gfn: 
                raise NotImplementedError('la-gfn not implemented for hypergrid') 
            pf = ForwardPolicy(config.hidden_dim, device=config.device) 
            pb = BackwardPolicy(device=config.device) 
        case _: 
            raise Exception(f'env: {config.env}')  
    
    return pf, pb, pt 

def create_gamma_func(config): 
    from gfn.utils import GammaFuncConst 

    match config.env: 
        case 'sets' | 'bags': 
            from gfn.models.sets import GammaFuncDepth, GammaFuncInvDepth, LearnableGamma 
        case 'sequences': 
            from gfn.models.sequences import GammaFuncDepth, GammaFuncInvDepth, LearnableGamma 
        case 'phylogenetics': 
            from gfn.models.phylogenetics import GammaFuncDepth, GammaFuncInvDepth, LearnableGamma 
        case 'grids': 
            from gfn.models.grids import GammaFuncDepth, GammaFuncInvDepth, LearnableGamma 

    total_iters = config.epochs_per_step * config.num_steps 
    match config.gamma_func: 
        case 'const': 
            gamma_func = GammaFuncConst(total_iters=total_iters) 
        case 'depth': 
            gamma_func = GammaFuncDepth(total_iters=total_iters) 
        case 'inv_depth': 
            gamma_func = GammaFuncInvDepth(total_iters=total_iters) 
        case 'learn': 
            match config.env:  
                case 'sets' | 'bags': 
                    gamma_func = LearnableGamma(2*config.src_size, config.hidden_dim, total_iters, config.device) 
                case 'sequences': 
                    gamma_func = LearnableGamma(2*(config.seq_size+1), config.hidden_dim, total_iters, config.device) 
                case 'phylogenetics': 
                    gamma_func = LearnableGamma(config.num_leaves, config.hidden_dim, total_iters, config.device) 
                case 'grids': 
                    gamma_func = LearnableGamma(2, config.hidden_dim, total_iters, config.device)
        case _: 
            gamma_func = None  
    
    return gamma_func 

def create_gfn(config):
    """Assemble the GFlowNet (or LED-GFlowNet) described by `config`.

    Policies, potential and weighting function are always created first; the
    weighting function is only passed on to the plain GFlowNet.

    Args:
        config: namespace providing `use_led_gfn`, `criterion`, `device`,
            `lamb_reg`, `lamb_subtb` and everything `create_pol` /
            `create_gamma_func` read.

    Returns:
        LEDGFlowNet when `config.use_led_gfn` is set, otherwise GFlowNet.
    """
    forward_policy, backward_policy, potential = create_pol(config)
    weighting = create_gamma_func(config)

    # Keyword arguments common to both model classes.
    shared = {'criterion': config.criterion, 'device': config.device}

    if not config.use_led_gfn:
        return GFlowNet(
            forward_policy,
            backward_policy,
            gamma_func=weighting,
            lamb_reg=config.lamb_reg,
            lamb_subtb=config.lamb_subtb,
            **shared,
        )

    # Fourth positional argument is fixed at 0.1; its meaning is defined by
    # LEDGFlowNet, which is not visible from this file.
    return LEDGFlowNet(forward_policy, backward_policy, potential, .1, **shared)

def create_env(config, log_reward=None): 
    match config.env: 
        case 'sets': 
            from gfn.gym.sets import Set 
            return Set(config.src_size, config.set_size, config.batch_size, log_reward, device=config.device) 
        case 'bags': 
            from gfn.gym.sets import Bag  
            return Bag(config.src_size, config.set_size, config.batch_size, log_reward, config.device) 
        case 'sequences': 
            from gfn.gym.sequences import Sequences
            return Sequences(config.seq_size, config.src_size, config.batch_size, log_reward, device=config.device) 
        case 'phylogenetics': 
            from gfn.gym.phylogenetics import Trees 
            return Trees(config.num_leaves, config.batch_size, log_reward, device=config.device) 
        case 'grids': 
            from gfn.gym.grids import Grid2D 
            return Grid2D(config.grid_size, config.grid_size, config.batch_size, log_reward, config.device) 
        
def create_log_reward(config, gflownet): 
    match config.env: 
        case 'sets' | 'bags': 
            from gfn.gym.sets import Set, LogReward 
            log_reward = LogReward(config.src_size, config.seed, device=config.device) 
            sets = Set(config.src_size, config.set_size, config.batch_size, log_reward, device=config.device) 
            sets = gflownet.sample(sets) 
            log_reward.shift = sets.log_reward().max()  
            return log_reward 
        case 'sequences': 
            from gfn.gym.sequences import LogReward 
            return LogReward(config.seq_size, config.seq_size, config.seed, device=config.device) 
        case 'phylogenetics': 
            from gfn.gym.phylogenetics import Trees, LogReward 
            tree = Trees(config.num_leaves, batch_size=1, log_reward=None, device=config.device) 
            with gflownet.off_policy(): 
                tree = gflownet.sample(tree, seed=42) 
            # Simulate JC69 
            Q = 5e-1 * torch.ones((config.num_nb, config.num_nb), device=config.device) 
            Q[torch.arange(config.num_nb), torch.arange(config.num_nb)] -= Q.sum(dim=-1) 
            pi = torch.ones((config.num_nb,), device=config.device) / config.num_nb   
            sites = Trees.sample_from_phylogeny(tree, Q, config.num_sites, pi, device=config.device)
            sites = sites[:, :config.num_leaves] 
            # Tree's likelihood using Felsenstein's algorithm 
            log_reward = LogReward(pi, sites, Q, config.temperature) 
            # Shift the reward for enhanced numerical stability 
            env = Trees(config.num_leaves, batch_size=config.batch_size, log_reward=log_reward, device=config.device) 
            values = gflownet.sample(env).log_reward() 
            log_reward.shift = values.max() 
            return log_reward 
        case 'grids': 
            from gfn.gym.grids import LogRewardSparse 
            return LogRewardSparse(device=config.device)  
