import torch
import wandb
from datetime import  datetime
from Config import Config,SamplerConfig,args_process
from src.gfn.containers.replay_buffer import Replay_Traj
from src.gfn.utils import  validate_dist


from argparse import ArgumentParser
from simple_parsing.helpers.serialization import encode
from tqdm import tqdm, trange
timestamp = datetime.now().strftime('-%Y%m%d_%H%M%S')
parser = ArgumentParser(description='DAG-GFlowNet')
parser.add_argument('--project', default='Hypergrid_test')
# Environment
environment = parser.add_argument_group('Environment')
environment.add_argument('--Env',default='HyperGrid', choices=['HyperGrid'])
environment.add_argument('--height',default=128,type=int)
environment.add_argument('--ndim',default=2,type=int)
environment.add_argument('--R0',default=0.01,type=float)
#Model
optimization = parser.add_argument_group('Method')
optimization.add_argument("--PB_parameterized",default= False)
optimization.add_argument("--lamb",type=float,default=0.99)
optimization.add_argument("--lamda",type=float,default=0.9)
optimization.add_argument("--weighing",default='geometric_within')
optimization.add_argument("--epsilon_decay",type=float,default=0.99)
optimization.add_argument("--epsilon_start",type=float,default=0.0)
optimization.add_argument("--epsilon_end",type=float,default=0.0)
# Optimization
optimization = parser.add_argument_group('Optimization')
optimization.add_argument("--train_mode",default='forward_phrase', choices=['forward_phrase','backward_phrase','two_phrase','MLE_phrase','Pesi_phrase'])
optimization.add_argument('--Loss',default='RLEval', choices=['DB','TB','RL','TRPO','Sub_TB','RLEval'])
optimization.add_argument("--seed", type=int, default=1)
optimization.add_argument("--optim",default={'lr':0.001,'lr_Z':0.1})
optimization.add_argument("--evaloptim",default={'lr_V':0.005})
optimization.add_argument("--log_reward_clip_min",type=float,default=-12.)
optimization.add_argument("--trpo_delta",type=float,default=0.01)
optimization.add_argument("--no_Z",type=bool,default=False)
optimization.add_argument("--GFNModuleConfig",default={'module_name': "NeuralNet",
                                                    'n_hidden_layers': 4,
                                                    'hidden_dim': 256})
optimization.add_argument("--VModuleConfig",default={'module_name': "NeuralNet",
                                                    'n_hidden_layers': 4,
                                                    'hidden_dim': 256})
optimization.add_argument("--batch_size", type=int, default=128)
optimization.add_argument("--n_iterations", type=int, default=2000)
optimization.add_argument("--device_str",default='cpu',choices=['cpu','cuda'])
# Replay buffer
replay = parser.add_argument_group('Replay Buffer')
replay.add_argument("--replay_buffer_size", type=int, default=20*128)
# Miscellaneous
misc = parser.add_argument_group('Miscellaneous')
misc.add_argument("--use_wandb", type=bool, default=False)
misc.add_argument("--validation_interval", type=int, default=10)
misc.add_argument("--validation_samples", type=int,default=10000)
###########################################################################
args = parser.parse_args()
torch.manual_seed(args.seed)
args = args_process(args)
env,parametrization,loss_fn = Config(args)
print(loss_fn)
trajectories_sampler,B_trajectories_sampler,_= SamplerConfig(env,parametrization)
replay_traj=Replay_Traj(env,args.replay_buffer_size)
###########################################################################
if args.Loss in ['TRPO','RL','RLEval']:
    assert args.epsilon_start == 0.0, 'epsilon_start should be 0.0 for on-policy method!'
name = args.Loss + '-B' if args.PB_parameterized else args.Loss + '-U'

if args.use_wandb:
    wandb.init(project=args.project)
    arg_code = encode(args)
    arg_code['R0'] = str(arg_code['R0'])
    arg_code['lamb'] = str(arg_code['lamb'])
    arg_code['lamda'] = str(arg_code['lamda'])
    wandb.config.update(arg_code)

epsilon=args.epsilon_start
states_visited = 0
for i in trange(args.n_iterations):
    training_samples = trajectories_sampler.sample(n_trajectories=args.batch_size)
    replay_traj.add(training_samples)
    ##################################################################################################3
    states_visited += len(  training_samples)
    epsilon = args.epsilon_end + (epsilon - args.epsilon_end) * args.epsilon_decay
    trajectories_sampler.actions_sampler.epsilon = epsilon
    training_samples.to_device(args.device_str)
    to_log = {"states_visited": states_visited}

    if args.train_mode in ['forward_phrase','two_phrase','MLE_phrase','Pesi_phrase']:
        backward_update = args.train_mode == 'forward_phrase' and args.PB_parameterized
        loss = loss_fn.update_model(training_samples, backward_update=backward_update)
        to_log["loss"] = loss.item()

    if args.train_mode in ['MLE_phrase']:
        assert len(parametrization.logit_PB.parameters()), 'Backward policy is not parameterized'
        B_loss = loss_fn.B_MLE(training_samples)

    if args.train_mode in ['Pesi_phrase']:
        assert len(parametrization.logit_PB.parameters()), 'Backward policy is not parameterized'
        for _ in range(8):
            B_training_samples = replay_traj.sample(args.batch_size)
            B_training_samples.to_device(args.device_str)
            B_loss = loss_fn.B_MLE(B_training_samples)
    #########################################
    if args.train_mode in ['backward_phrase','two_phrase']:
        last_states =   training_samples.last_states
        assert len(parametrization.logit_PB.parameters()), 'Backward policy is not parameterized'
        forward_update  = args.train_mode== 'backward_phrase'
        B_training_samples = B_trajectories_sampler.sample(n_trajectories=args.batch_size, states=last_states)
        B_training_samples.to_device(args.device_str)
        B_loss = loss_fn.B_update_model(B_training_samples,forward_update=forward_update)
        to_log["B_loss"]= B_loss.item()
    #
    if args.use_wandb: wandb.log(to_log, step=i)
    if (i+1) % (args.validation_interval*2) == 0 or i==0:
        validation_info,_ = validate_dist(env, parametrization, trajectories_sampler,args.validation_samples,exact=True)
        if args.use_wandb: wandb.log(validation_info, step=i)
        to_log.update(validation_info)
        tqdm.write(f"{i}: {to_log}")
    if (i+1) % (args.validation_interval*10) == 0 and i != 0:
       parametrization.save_state_dict('./scripts', '{}_{}_'.format(name,i))
       if  args.use_wandb:
           artifact = wandb.Artifact('{}-{}'.format(name,timestamp), type='model')
           artifact.add_file('./scripts/{}_{}_logit_PF.pt'.format(name,i))
           wandb.log_artifact(artifact)
if args.use_wandb: wandb.run.name=name+timestamp