import json
from argparse import ArgumentParser, ArgumentTypeError, BooleanOptionalAction
from datetime import datetime

import torch
import wandb
from simple_parsing.helpers.serialization import encode
from tqdm import tqdm, trange

from Config import Config, SamplerConfig
from src.gfn.containers.replay_buffer import ReplayBuffer
from src.gfn.utils import trajectories_to_training_samples, validate
# Run timestamp, used to name wandb model artifacts.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')


def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``,
    so ``--cuda False`` would silently *enable* the option. This parser accepts
    the usual spellings and rejects anything else with a clear error.

    Args:
        value: The raw string from the command line (or an already-parsed bool).

    Returns:
        The parsed boolean.

    Raises:
        ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = ArgumentParser(description='DAG-GFlowNet')
parser.add_argument('--feature', default=True, action=BooleanOptionalAction)
# Environment
environment = parser.add_argument_group('Environment')
environment.add_argument('--Env', default='HyperGrid', choices=['HyperGrid', 'DiscreteEBM', 'BayesianNetwork'])
environment.add_argument('--height', default=256, type=int)
environment.add_argument('--ndim', default=2, type=int)
environment.add_argument('--R0', default=0.01, type=float)
environment.add_argument('--alpha', default=1.0, type=float)
# Optimization
optimization = parser.add_argument_group('Optimization')
optimization.add_argument('--Q_estimation', type=_str2bool, default=False)
optimization.add_argument('--TD_estimation', type=_str2bool, default=False)
optimization.add_argument('--Loss', default='RL', choices=['DB', 'TB', 'RL', 'TRPO'])
optimization.add_argument("--seed", type=int, default=1)
# Dict-valued options: when given on the command line they must be a JSON
# object, e.g. --optim '{"lr": 0.01, "lr_Z": 0.1, "lr_V": 0.005}'.
# (argparse does not apply ``type`` to non-string defaults, so the dict
# defaults below are used as-is.)
optimization.add_argument("--optim", type=json.loads,
                          default={'lr': 0.01, 'lr_Z': 0.1, 'lr_V': 0.005})
optimization.add_argument("--GFNModuleConfig", type=json.loads,
                          default={'module_name': "NeuralNet",
                                   'n_hidden_layers': 4,
                                   'hidden_dim': 256})
optimization.add_argument("--PB_parameterized", type=_str2bool, default=False)
optimization.add_argument("--PG_used", type=_str2bool, default=False)
optimization.add_argument("--batch_size", type=int, default=128)
optimization.add_argument("--n_iterations", type=int, default=2000)
optimization.add_argument("--device_str", default='cpu', choices=['cpu', 'cuda'])
optimization.add_argument("--cuda", type=_str2bool, default=False)
# Replay buffer
replay = parser.add_argument_group('Replay Buffer')
replay.add_argument("--replay_buffer_size", type=int, default=0)
# Miscellaneous
misc = parser.add_argument_group('Miscellaneous')
misc.add_argument("--use_wandb", type=_str2bool, default=False)
misc.add_argument("--validation_interval", type=int, default=10)
misc.add_argument("--validation_samples", type=int, default=1000)
args = parser.parse_args()
# Seed torch's global RNG so that ``--seed`` actually takes effect (the call
# was previously commented out, leaving the flag silently ignored).
torch.manual_seed(args.seed)
# Run on CUDA only when it is available AND was requested, either with
# ``--cuda True`` or with ``--device_str cuda``; otherwise use the CPU.
# (Previously ``--device_str`` was overwritten without being consulted.)
cuda_requested = args.cuda or args.device_str == 'cuda'
args.device_str = "cuda" if torch.cuda.is_available() and cuda_requested else "cpu"
# ``--PG_used`` forces a parameterized backward policy.
if args.PG_used:
    args.PB_parameterized = True
# Build environment, parametrization and loss from the parsed arguments, then
# the trajectory sampler and the ``B_`` (presumably backward) sampler.
env, parametrization, loss_fn = Config(args)
trajectories_sampler, B_trajectories_sampler = SamplerConfig(env, parametrization)
# Optional replay buffer, enabled only for a positive --replay_buffer_size.
if args.replay_buffer_size > 0:
    replay_buffer = ReplayBuffer(env, loss_fn, capacity=args.replay_buffer_size)
else:
    replay_buffer = None

if args.use_wandb:
    # NOTE(review): 'xxxx' looks like a placeholder wandb project name.
    wandb.init(project='xxxx')
    wandb.config.update(encode(args))

# Epsilon-schedule constants.
# NOTE(review): none of these are read anywhere in this script; confirm they
# are unused elsewhere and remove them.
epsilon_start = 1.
epsilon_end = 0.1
epsilon_decay = 0.95
epsilon = epsilon_start
# Incremented by the number of sampled trajectories each iteration.
states_visited = 0

# The extra ``B_update_model`` step runs only for the RL/TRPO losses and only
# when ``logit_PB`` has parameters. Evaluated once, assuming the parameter set
# does not change during training. ``list(...)`` makes the emptiness check
# work whether ``parameters()`` returns a sized container or a generator (as
# ``torch.nn.Module.parameters`` does, where a bare ``len()`` raises).
update_backward = (
    args.Loss in ('RL', 'TRPO')
    and len(list(parametrization.logit_PB.parameters())) > 0
)
# Checkpoint/artifact names embed alpha. Not every environment necessarily
# defines ``env.alpha`` (TODO confirm for HyperGrid), so fall back to --alpha.
ckpt_alpha = getattr(env, 'alpha', args.alpha)

for i in trange(args.n_iterations):
    trajectories = trajectories_sampler.sample(n_trajectories=args.batch_size)
    training_samples = trajectories_to_training_samples(trajectories, loss_fn)
    # Capture last states of the freshly sampled batch, before a replay
    # buffer (if any) replaces ``training_samples`` below.
    training_last_states = training_samples.last_states
    states_visited += len(trajectories)
    if replay_buffer is not None:
        replay_buffer.add(training_samples)
        training_samples = replay_buffer.sample(n_trajectories=args.batch_size)

    training_samples.to_device(args.device_str)
    if args.Loss == 'TRPO':
        loss = loss_fn.trpo_update_model(training_samples)
    else:
        loss = loss_fn.update_model(training_samples)
    to_log = {"loss": loss.item(), "states_visited": states_visited}

    if update_backward:
        # Sample from the B_ sampler starting at the fresh batch's last states.
        # The batch size follows --batch_size (it used to be hard-coded to 128,
        # which only matched the default).
        B_trajectories = B_trajectories_sampler.sample(
            n_trajectories=args.batch_size, states=training_last_states
        )
        B_training_samples = trajectories_to_training_samples(B_trajectories, loss_fn)
        # NOTE(review): unlike ``training_samples``, these are not moved to
        # ``args.device_str`` (the to_device call was disabled); confirm this
        # is intended for CUDA runs.
        B_loss = loss_fn.B_update_model(B_training_samples)
        # Separate key: logging this as "loss" overwrote the forward loss.
        to_log["B_loss"] = B_loss.item()

    if (i + 1) % args.validation_interval == 0 and i != 0:
        validation_info = validate(env, parametrization, trajectories_sampler, args.validation_samples)
        to_log.update(validation_info)
        tqdm.write(f"{i}: {to_log}")
    if args.use_wandb:
        wandb.log(to_log, step=i)

    if (i + 1) % (args.validation_interval * 10) == 0 and i != 0:
        # save_state_dict appends the component name to this prefix; the
        # artifact below picks up those files by the same prefix.
        ckpt_prefix = f"{args.Loss}_{ckpt_alpha}_{i}_"
        parametrization.save_state_dict('./scripts', ckpt_prefix)
        if args.use_wandb:
            artifact = wandb.Artifact(f"{args.Loss}_{ckpt_alpha}_{timestamp}", type='model')
            for component in ('logit_PF', 'logit_PB', 'logZ'):
                artifact.add_file(f"./scripts/{ckpt_prefix}{component}.pt")
            wandb.log_artifact(artifact)
