from scripts_no_temporal_encoding.utils import compute_loss_all_batches
from scripts_no_temporal_encoding.dataLoader import ParseData
# from scripts.dataLoader_bball import ParseData as BBallData 
from tqdm import tqdm
import argparse
import numpy as np
from random import SystemRandom
import torch
import torch.optim as optim
import scripts_no_temporal_encoding.utils as utils
from torch.distributions.normal import Normal
from scripts_no_temporal_encoding.create_latent_ode_model import create_LatentODE_model
import os
import sys
# from scripts.dataLoader_penalize import  ParseData as PenalizeData
# from scripts.dataLoader_motion import ParseData as MotionData
# from scripts.dataLoader_kuramoto import ParseData as KuramotoData
# Project root. The sys.path entries below are appended after the import
# lines at the top of this file have already run, so they only affect
# imports executed later in the script.
BASE_DIRECTORY = "/DIR/Projects/AttentionNet/"
sys.path.append(BASE_DIRECTORY)
sys.path.append(BASE_DIRECTORY + "scripts_no_temporal_encoding")

DATA_DIRECTORY = BASE_DIRECTORY + "data/"
COMMON_CONFIG = "/DIR/Projects/AttentionNet/configs/common/common.cfg"

# Command-line interface: only --config is parsed from argv; the remaining
# settings read further down are expected to be filled in by
# utils.set_args_from_config from that file.
parser = argparse.ArgumentParser('Latent ODE')
parser.add_argument('--config', type=str, default=None,
                    help='Path to the experiment config file.')
args = parser.parse_args()
args = utils.set_args_from_config(args, args.config)
# args = utils.set_args_from_config(args, COMMON_CONFIG)
print(args)

# for section_name in cfg.sections():
#     # loop over all keys in the section
#     for key in cfg[section_name]:
#         # get the value for the key
#         value = cfg[section_name][key]
#         # check if the value can be converted to an integer
#         try:
#             value = int(value)
#         except ValueError: 
#             try :
#                 value = float(value)
#             except ValueError:
#                 try:
#                     value = str(value)
#                     if value == "":
#                         value = None
#                 except ValueError:
#                     pass
               
        
#         setattr(args, key, value)



# Derived settings: total_balls, the per-experiment output directory, and
# dataset-specific suffixes / ODE step counts.
args.total_balls = args.n_balls + args.hide_balls
args.logging_dir = BASE_DIRECTORY + "neurips_no_temporal_encoding_experiments/" + \
    str(args.exp_no) + "_" + args.data + "_" + str(args.total_balls) + \
    "_" + str(args.hide_balls) + "/"
# rec_dims must be a multiple of n_heads. Validate explicitly: an `assert`
# here would be silently skipped when Python runs with -O.
if int(args.rec_dims % args.n_heads) != 0:
    raise ValueError(
        "rec_dims ({}) must be divisible by n_heads ({})".format(
            args.rec_dims, args.n_heads))
# args.save is interpreted relative to this experiment's logging directory.
args.save = args.logging_dir + args.save
if args.data == "spring":
    # args.dataset = '/DIR/Projects/AttentionNet/data_files/springs20_hide_10'
    args.suffix = '_springs' + str(args.total_balls)
    args.data_load_suffix = "_springs" + str(args.total_balls)
    args.total_ode_step = 60
elif args.data == "charged":
    args.suffix = '_charged' + str(args.total_balls)
    args.data_load_suffix = "_charged" + str(args.total_balls)
    args.total_ode_step = 60
elif args.data == "motion":
    # args.dataset = 'data/motion'
    args.suffix = 'motion'
    args.total_ode_step = 49
    args.data_load_suffix = "motion"
elif args.data == "kuramoto":
    # NOTE(review): unlike the other datasets this branch sets neither
    # total_ode_step nor data_load_suffix; presumably the config provides
    # them — confirm.
    args.suffix = '_kuramoto5'
elif args.data == "bball":
    args.dataset = BASE_DIRECTORY + "data_files/basketball"
    args.suffix = '_bball'
    args.total_ode_step = 49
    args.data_load_suffix = "_bball" + str(args.total_balls)


############ CPU AND GPU related, Mode related, Dataset Related
# Use the configured CUDA device when available, otherwise fall back to CPU.
if torch.cuda.is_available():
    print("Using GPU" + "-"*80)
    device = torch.device("cuda:" + str(args.gpu))
else:
    print("Using CPU" + "-" * 80)
    device = torch.device("cpu")

# Map the `extrap` flag to args.mode. Normalising through str().lower()
# accepts "True"/"False" strings as before, and also bools or other casings
# in case the config loader yields them; any other value leaves args.mode
# untouched, as the original comparison did.
_extrap_flag = str(args.extrap).strip().lower()
if _extrap_flag == "true":
    print("Running extrap mode" + "-"*80)
    args.mode = "extrap"
elif _extrap_flag == "false":
    print("Running interp mode" + "-" * 80)
    args.mode = "interp"


if __name__ == '__main__':
    # Seed torch and numpy with the configured seed.
    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    ############ Saving Path and Preload.
    file_name = os.path.basename(__file__)[:-3]  # this script's name minus the 3-char ".py"
    for target_dir in (args.save, args.save_graph):
        utils.makedirs(target_dir)
    print(args.load)

    # Use the experiment ID from the config when present; otherwise draw a
    # random integer in [0, 100000) via SystemRandom.
    if args.exp_id is None:
        experimentID = int(SystemRandom().random() * 100000)
    else:
        experimentID = int(args.exp_id)

    ############ Data Loader
    print("Loading dataset: " + args.dataset)
    print(args.data)
    # Only ParseData is imported at module level; the module-level imports of
    # the alternative loaders are commented out, so referencing them directly
    # raised NameError. Import each one lazily from the module named in those
    # commented-out lines, so a branch either works or fails with an
    # ImportError that names the missing module.
    # NOTE(review): these loaders come from the `scripts` package rather than
    # `scripts_no_temporal_encoding` — confirm their output matches this model.
    if args.data == "bball":
        print("Loading bball dataset")
        from scripts.dataLoader_bball import ParseData as BBallData
        dataloader = BBallData(args.dataset, suffix=args.suffix, mode=args.mode, args=args)
    elif getattr(args, "penalize", False):
        # getattr: a config without a `penalize` key means "not penalized"
        # instead of raising AttributeError for every non-bball dataset.
        from scripts.dataLoader_penalize import ParseData as PenalizeData
        dataloader = PenalizeData(args.dataset, suffix=args.suffix, mode=args.mode, args=args)
    elif args.data == "motion":
        from scripts.dataLoader_motion import ParseData as MotionData
        dataloader = MotionData(args.dataset, suffix=args.suffix, mode=args.mode, args=args)
    elif args.data == "kuramoto":
        from scripts.dataLoader_kuramoto import ParseData as KuramotoData
        dataloader = KuramotoData(args.dataset, suffix=args.suffix, mode=args.mode, args=args)
    else:
        dataloader = ParseData(
            args.dataset, suffix=args.suffix, mode=args.mode, args=args)

    # load_data returns six values; only encoder/decoder/graph inputs and the
    # batch count are used by this script.
    test_encoder, test_decoder, test_graph, test_batch, _, _ = dataloader.load_data(
        sample_percent=args.sample_percent_test, batch_size=args.batch_size, data_type="test")
    train_encoder, train_decoder, train_graph, train_batch, _, _ = dataloader.load_data(
        sample_percent=args.sample_percent_train, batch_size=args.batch_size, data_type="train")

    input_dim = dataloader.feature  # TODO: feature dimension

    # Reconstruct the invoking command line for the log, dropping one
    # "--load <value>" pair when exactly one "--load" token is present.
    argv_tokens = list(sys.argv)
    load_flag_positions = [pos for pos, token in enumerate(argv_tokens) if token == "--load"]
    if len(load_flag_positions) == 1:
        flag_pos = load_flag_positions[0]
        del argv_tokens[flag_pos:flag_pos + 2]
    input_command = " ".join(argv_tokens)

    # Observation std from the config as a 1-element tensor, and a
    # Normal(0, 1) prior for z0, both placed on `device`.
    print("obsrv_std: ", args.obsrv_std)
    obsrv_std = torch.Tensor([args.obsrv_std]).to(device)
    prior_loc = torch.Tensor([0.0]).to(device)
    prior_scale = torch.Tensor([1.]).to(device)
    z0_prior = Normal(prior_loc, prior_scale)

    model = create_LatentODE_model(args, input_dim, z0_prior, obsrv_std, device)

    # When args.load is set, pass the checkpoint at args.save/args.load to
    # utils.get_ckpt_model (presumably loads its weights into `model`).
    if args.load is not None:
        ckpt_path = os.path.join(args.save, args.load)
        utils.get_ckpt_model(ckpt_path, model, device)

    ##################################################################
    # Training
    log_dir = args.logging_dir + "logs/"
    log_path = log_dir + args.alias + "_" + args.z0_encoder + "_" + args.data + "_" + \
        str(args.sample_percent_train) + "_" + \
        args.mode + "_" + str(experimentID) + ".log"
    if not os.path.exists(log_dir):
        utils.makedirs(log_dir)
    logger = utils.get_logger(
        logpath=log_path, filepath=os.path.abspath(__file__))
    logger.info(input_command)
    logger.info(str(args))
    logger.info(args.alias)

    # Optimizer. An unrecognised name previously left `optimizer` undefined
    # and only failed later with a NameError; fail fast with a clear message.
    if args.optimizer == "AdamW":
        optimizer = optim.AdamW(
            model.parameters(), lr=args.lr, weight_decay=args.l2)
    elif args.optimizer == "Adam":
        optimizer = optim.Adam(
            model.parameters(), lr=args.lr, weight_decay=args.l2)
    else:
        raise ValueError(
            "Unsupported optimizer {!r}; expected 'AdamW' or 'Adam'".format(args.optimizer))

    # LR scheduler (train_epoch calls scheduler.step() once per epoch).
    # Same fail-fast treatment as the optimizer for unrecognised names.
    if args.scheduler == "cosine":
        # Anneals from args.lr down to eta_min over args.niters scheduler steps.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.niters, eta_min=1e-9)
    elif args.scheduler == "step":
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=args.step_lr_size, gamma=args.step_lr_gamma)
    else:
        raise ValueError(
            "Unsupported scheduler {!r}; expected 'cosine' or 'step'".format(args.scheduler))

    # Not read by train_epoch, which sets its own local copy of this value.
    wait_until_kl_inc = 10
    best_test_mse = np.inf
    # Evaluate on the test set every `n_iters_to_viz` epochs.
    n_iters_to_viz = 1

    def train_single_batch(model, batch_dict_encoder, batch_dict_decoder, batch_dict_graph, kl_coef, weights_type = None):
        """Take one optimizer step on a single training batch.

        Computes the losses with 3 trajectory samples, back-propagates the
        total loss, clips the gradient norm to ``args.clip`` and steps the
        shared ``optimizer``.

        Returns:
            (loss, mse, likelihood, kl_first_p, std_first_p): ``loss`` as a
            Python number; the other four are returned exactly as found in
            the dict produced by ``model.compute_all_losses``.
        """
        optimizer.zero_grad()
        results = model.compute_all_losses(
            batch_dict_encoder, batch_dict_decoder, batch_dict_graph,
            n_traj_samples=3, kl_coef=kl_coef, weights_type=weights_type)

        total_loss = results["loss"]
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        scalar_loss = total_loss.item()
        # Drop the local reference (the tensor is still held by `results`)
        # and release cached CUDA blocks between batches.
        del total_loss
        torch.cuda.empty_cache()

        metric_keys = ("mse", "likelihood", "kl_first_p", "std_first_p")
        return (scalar_loss,) + tuple(results[key] for key in metric_keys)

    def train_epoch(epo, args):
        """Train over all `train_batch` batches, then step the LR scheduler.

        Args:
            epo: epoch number; only used in the returned summary message.
            args: run configuration; ``args.weights_type`` is forwarded to
                ``train_single_batch``.

        Returns:
            ``(message_train, kl_coef)``: a formatted line with the per-batch
            means of loss / MSE / likelihood / KL fp / FP STD, and the KL
            coefficient used on the last batch (``0.`` if the epoch had no
            batches).
        """
        model.train()
        loss_list = []
        mse_list = []
        likelihood_list = []
        kl_first_p_list = []
        std_first_p_list = []

        # KL warm-up: the KL weight is 0 for the first `wait_until_kl_inc`
        # batches, then grows as 1 - 0.99**k. `itr` is the index within this
        # epoch, so the warm-up restarts at the beginning of every epoch.
        wait_until_kl_inc = 10
        # Bound before the loop so an epoch with zero batches returns 0.
        # instead of raising UnboundLocalError at the return statement.
        kl_coef = 0.

        torch.cuda.empty_cache()
        for itr in tqdm(range(train_batch)):
            if itr < wait_until_kl_inc:
                kl_coef = 0.
            else:
                kl_coef = 1 - 0.99 ** (itr - wait_until_kl_inc)

            batch_dict_encoder = utils.get_next_batch_new(train_encoder, device)
            batch_dict_graph = utils.get_next_batch_new(train_graph, device)
            batch_dict_decoder = utils.get_next_batch(train_decoder, device)

            loss, mse, likelihood, kl_first_p, std_first_p = train_single_batch(
                model, batch_dict_encoder, batch_dict_decoder, batch_dict_graph,
                kl_coef, weights_type=args.weights_type)

            # Collect per-batch metrics for the epoch summary.
            loss_list.append(loss)
            mse_list.append(mse)
            likelihood_list.append(likelihood)
            kl_first_p_list.append(kl_first_p)
            std_first_p_list.append(std_first_p)

            del batch_dict_encoder, batch_dict_graph, batch_dict_decoder
            torch.cuda.empty_cache()

        scheduler.step()
        # Print the learning rate of every param group after the step.
        for param_group in optimizer.param_groups:
            print(param_group['lr'])

        message_train = 'Epoch {:04d} [Train seq (cond on sampled tp)] | Loss {:.6f} | MSE {:.6F} | Likelihood {:.6f} | KL fp {:.4f} | FP STD {:.4f}|'.format(
            epo,
            np.mean(loss_list), np.mean(mse_list), np.mean(likelihood_list),
            np.mean(kl_first_p_list), np.mean(std_first_p_list))

        return message_train, kl_coef

    def _checkpoint_path(prefix, epoch, mse):
        """Return args.save/<prefix><experimentID>_<encoder>_<data>_<sample>_<mode>_epoch_<epoch>_mse_<mse>.ckpt."""
        name_parts = [prefix + str(experimentID), args.z0_encoder, args.data,
                      str(args.sample_percent_train), args.mode,
                      "epoch", str(epoch), "mse", str(mse)]
        return os.path.join(args.save, "_".join(name_parts) + '.ckpt')

    def _save_checkpoint(path):
        """Save the run config and the current model state_dict to *path*."""
        torch.save({
            'args': args,
            'state_dict': model.state_dict(),
        }, path)

    for epo in range(1, args.niters + 1):
        message_train, kl_coef = train_epoch(epo, args)

        # Per-epoch snapshot. Its name carries the best test MSE recorded
        # *before* this epoch's evaluation (initially np.inf).
        ckpt_path = _checkpoint_path("experiment_TRAIN_", epo, best_test_mse)
        _save_checkpoint(ckpt_path)

        if epo % n_iters_to_viz != 0:
            continue

        model.eval()
        test_res = compute_loss_all_batches(
            model, test_encoder, test_graph, test_decoder,
            n_batches=test_batch, device=device,
            n_traj_samples=3, kl_coef=kl_coef, weights_type=args.weights_type)

        message_test = 'Epoch {:04d} [Test seq (cond on sampled tp)] | Loss {:.6f} | MSE {:.6F} | Likelihood {:.6f} | KL fp {:.4f} | FP STD {:.4f}|'.format(
            epo,
            test_res["loss"], test_res["mse"], test_res["likelihood"],
            test_res["kl_first_p"], test_res["std_first_p"])

        logger.info("Experiment " + str(experimentID))
        logger.info(message_train)
        logger.info(message_test)
        logger.info("KL coef: {}".format(kl_coef))
        print("data: %s, encoder: %s, sample: %s, mode:%s" % (
            args.data, args.z0_encoder, str(args.sample_percent_train), args.mode))

        # On a new best test MSE, log it and keep a separately named checkpoint.
        if test_res["mse"] < best_test_mse:
            best_test_mse = test_res["mse"]
            logger.info('Epoch {:04d} [Test seq (cond on sampled tp)] | Best mse {:.6f}|'.format(
                epo, best_test_mse))
            ckpt_path = _checkpoint_path("experiment_", epo, best_test_mse)
            _save_checkpoint(ckpt_path)

        torch.cuda.empty_cache()
