from scripts.utils import compute_loss_all_batches
from scripts.dataLoader import ParseData
from scripts.dataLoader_bball import ParseData as BBallData 
from tqdm import tqdm
import argparse
import numpy as np
from random import SystemRandom
import torch
import torch.optim as optim
import scripts.utils as utils
from torch.distributions.normal import Normal
from scripts.create_latent_ode_model import create_LatentODE_model
import os
import sys
# from scripts.dataLoader_penalize import  ParseData as PenalizeData
from scripts.dataLoader_motion import ParseData as MotionData
# Root of the AttentionNet project. Every other hard-coded project path in
# this block is derived from it, so relocating the project means editing one
# line instead of several duplicated literals.
BASE_DIRECTORY = "/DIR/Projects/AttentionNet/"
DATA_DIRECTORY = BASE_DIRECTORY + "data/"
COMMON_CONFIG = BASE_DIRECTORY + "configs/common/common.cfg"

# Make the project root and its scripts/ folder importable.
# NOTE(review): these appends execute after the `scripts.*` imports at the top
# of the file, so they do not help those imports resolve; they only affect
# imports executed after this point.
sys.path.append(BASE_DIRECTORY)
sys.path.append(BASE_DIRECTORY + "scripts")

# Generative model for noisy data based on ODE.
# The only command-line flag is --config. The remaining attributes read later
# in this script (n_balls, data, lr, ...) are presumably populated from that
# file by utils.set_args_from_config -- TODO confirm against scripts/utils.
parser = argparse.ArgumentParser('Latent ODE')
parser.add_argument('--config', type=str, default=None,
                    help='Path to the experiment config file to load.')
args = parser.parse_args()
args = utils.set_args_from_config(args, args.config)
# Alternative used during development: load the shared common config instead.
# args = utils.set_args_from_config(args, COMMON_CONFIG)
print(args)

# for section_name in cfg.sections():
#     # loop over all keys in the section
#     for key in cfg[section_name]:
#         # get the value for the key
#         value = cfg[section_name][key]
#         # check if the value can be converted to an integer
#         try:
#             value = int(value)
#         except ValueError: 
#             try :
#                 value = float(value)
#             except ValueError:
#                 try:
#                     value = str(value)
#                     if value == "":
#                         value = None
#                 except ValueError:
#                     pass
               
        
#         setattr(args, key, value)



# Derived settings: total object count, the per-experiment output folder, and
# the dataset-specific file suffixes / number of ODE steps.
args.total_balls = args.n_balls + args.hide_balls
args.logging_dir = BASE_DIRECTORY + "Experiments_neurips_motion/" + \
    str(args.exp_no) + "_" + args.data + "_" + str(args.total_balls) + \
    "_" + str(args.hide_balls) + "/"

# rec_dims must be divisible by n_heads (presumably because the recognition
# dimension is split evenly across attention heads). Use an explicit check:
# `assert` is removed when Python runs with -O.
if int(args.rec_dims % args.n_heads) != 0:
    raise ValueError(
        "rec_dims ({}) must be divisible by n_heads ({})".format(
            args.rec_dims, args.n_heads))

args.save = args.logging_dir + args.save

# NOTE(review): args.dataset is only assigned here for "bball"; for the other
# datasets it presumably comes from the config file.
if args.data == "spring":
    args.suffix = '_springs' + str(args.total_balls)
    args.data_load_suffix = "_springs" + str(args.total_balls)
    args.total_ode_step = 60
elif args.data == "charged":
    args.suffix = '_charged' + str(args.total_balls)
    args.data_load_suffix = "_charged" + str(args.total_balls)
    args.total_ode_step = 60
elif args.data == "motion":
    # NOTE(review): unlike the other datasets these suffixes carry no leading
    # underscore -- confirm this matches the motion loader's file names.
    args.suffix = 'motion'
    args.total_ode_step = 49
    args.data_load_suffix = "motion"
elif args.data == "bball":
    args.dataset = BASE_DIRECTORY + "data_files/basketball"
    args.suffix = '_bball'
    args.total_ode_step = 49
    args.data_load_suffix = "_bball" + str(args.total_balls)


############ CPU AND GPU related, Mode related, Dataset Related
# Use the configured GPU index when CUDA is available, otherwise the CPU.
# (Indented with spaces; the previous version mixed tabs into a
# space-indented file.)
if torch.cuda.is_available():
    print("Using GPU" + "-" * 80)
    device = torch.device("cuda:" + str(args.gpu))
else:
    print("Using CPU" + "-" * 80)
    device = torch.device("cpu")

# `extrap` selects extrapolation vs interpolation. Compare its string form so
# that both the strings "True"/"False" and real booleans are accepted,
# depending on how the config value was parsed. Any other value leaves
# args.mode as it is.
extrap_flag = str(args.extrap)
if extrap_flag == "True":
    print("Running extrap mode" + "-" * 80)
    args.mode = "extrap"
elif extrap_flag == "False":
    print("Running interp mode" + "-" * 80)
    args.mode = "interp"


if __name__ == '__main__':
    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)

    ############ Saving Path and Preload.
    utils.makedirs(args.save)
    utils.makedirs(args.save_graph)
    print(args.load)

    # Reuse the configured experiment ID if one is given, otherwise draw a
    # random integer in [0, 100000) to tag this run's logs and checkpoints.
    if args.exp_id is not None:
        experimentID = int(args.exp_id)
    else:
        experimentID = int(SystemRandom().random() * 100000)

    ############ Data Loader
    print("Loading dataset: " + args.dataset)

    # Pick exactly one loader class. The previous version used a standalone
    # `if` for "bball" followed by an if/elif/else, so the basketball loader
    # was built and then always overwritten by the generic ParseData.
    if getattr(args, "penalize", False):
        # The module-level import of this loader is commented out at the top
        # of the file; import it here so a missing module surfaces as an
        # ImportError rather than a NameError.
        from scripts.dataLoader_penalize import ParseData as PenalizeData
        loader_cls = PenalizeData
    elif args.data == "bball":
        loader_cls = BBallData
    elif args.data == "motion":
        loader_cls = MotionData
    else:
        loader_cls = ParseData
    dataloader = loader_cls(
        args.dataset, suffix=args.suffix, mode=args.mode, args=args)

    # NOTE(review): the previous version stored data_type="train" in the
    # test_* variables and data_type="test" in the train_* variables, i.e. it
    # trained on the test split and evaluated on the training split. The
    # "train" split is still loaded first in case the loader derives state
    # from it -- TODO confirm whether load order matters.
    train_encoder, train_decoder, train_graph, train_batch, _, _ = dataloader.load_data(
        sample_percent=args.sample_percent_train, batch_size=args.batch_size, data_type="train")
    test_encoder, test_decoder, test_graph, test_batch, _, _ = dataloader.load_data(
        sample_percent=args.sample_percent_test, batch_size=args.batch_size, data_type="test")

    input_dim = dataloader.feature  # TODO: feature dimension

    # Command line recorded in the log, minus a single "--load <ckpt>" pair.
    argv = list(sys.argv)
    if argv.count("--load") == 1:
        load_at = argv.index("--load")
        argv = argv[:load_at] + argv[load_at + 2:]
    input_command = " ".join(argv)

    # Observation noise and a standard-normal prior for the initial latent.
    obsrv_std = args.obsrv_std
    print("obsrv_std: ", obsrv_std)
    obsrv_std = torch.Tensor([obsrv_std]).to(device)
    z0_prior = Normal(torch.Tensor([0.0]).to(device),
                      torch.Tensor([1.]).to(device))

    model = create_LatentODE_model(
        args, input_dim, z0_prior, obsrv_std, device)

    # Optionally restore weights from a checkpoint under args.save.
    if args.load is not None:
        ckpt_path = os.path.join(args.save, args.load)
        utils.get_ckpt_model(ckpt_path, model, device)

    ##################################################################
    # Training
    log_dir = args.logging_dir + "logs/"
    log_path = log_dir + args.alias + "_" + args.z0_encoder + "_" + args.data + "_" + \
        str(args.sample_percent_train) + "_" + \
        args.mode + "_" + str(experimentID) + ".log"
    if not os.path.exists(log_dir):
        utils.makedirs(log_dir)
    logger = utils.get_logger(
        logpath=log_path, filepath=os.path.abspath(__file__))
    logger.info(input_command)
    logger.info(str(args))
    logger.info(args.alias)

    # Optimizer. Fail fast on an unknown name instead of hitting a NameError
    # on the first training step.
    if args.optimizer == "AdamW":
        optimizer = optim.AdamW(
            model.parameters(), lr=args.lr, weight_decay=args.l2)
    elif args.optimizer == "Adam":
        optimizer = optim.Adam(
            model.parameters(), lr=args.lr, weight_decay=args.l2)
    else:
        raise ValueError(
            "Unsupported optimizer {!r}; expected 'AdamW' or 'Adam'".format(args.optimizer))

    # LR scheduler, stepped once per epoch. An unknown name previously only
    # crashed (NameError) after a full epoch of training.
    if args.scheduler == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.niters, eta_min=1e-9)
    elif args.scheduler == "step":
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=args.step_lr_size, gamma=args.step_lr_gamma)
    else:
        raise ValueError(
            "Unsupported scheduler {!r}; expected 'cosine' or 'step'".format(args.scheduler))

    # KL annealing: within every epoch kl_coef is 0 for the first
    # `wait_until_kl_inc` batches, then 1 - 0.99 ** (itr - wait_until_kl_inc).
    wait_until_kl_inc = 10
    best_test_mse = np.inf
    n_iters_to_viz = 1

    def train_single_batch(model, batch_dict_encoder, batch_dict_decoder, batch_dict_graph, kl_coef, weights_type=None):
        """Run one optimisation step on a single batch.

        Returns:
            (loss, mse, likelihood, kl_first_p, std_first_p): `loss` as a
            Python float; the other four are taken unchanged from the dict
            returned by ``model.compute_all_losses``.
        """
        optimizer.zero_grad()
        train_res = model.compute_all_losses(
            batch_dict_encoder, batch_dict_decoder, batch_dict_graph,
            n_traj_samples=3, kl_coef=kl_coef, weights_type=weights_type)
        loss = train_res["loss"]
        loss.backward()
        # Clip the total gradient norm to args.clip before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        loss_value = loss.item()
        del loss
        torch.cuda.empty_cache()
        return loss_value, train_res["mse"], train_res["likelihood"], train_res["kl_first_p"], train_res["std_first_p"]

    def train_epoch(epo, args):
        """Train for one epoch over all `train_batch` training batches.

        Steps the LR scheduler once at the end and prints the current LR of
        each parameter group.

        Returns:
            (message_train, kl_coef): a formatted summary of the epoch's mean
            metrics and the KL coefficient used for the last batch.
        """
        model.train()
        loss_list = []
        mse_list = []
        likelihood_list = []
        kl_first_p_list = []
        std_first_p_list = []
        # Defined up front so an empty training set does not leave it unbound.
        kl_coef = 0.

        torch.cuda.empty_cache()
        for itr in tqdm(range(train_batch)):
            if itr < wait_until_kl_inc:
                kl_coef = 0.
            else:
                kl_coef = (1 - 0.99 ** (itr - wait_until_kl_inc))

            batch_dict_encoder = utils.get_next_batch_new(train_encoder, device)
            batch_dict_graph = utils.get_next_batch_new(train_graph, device)
            batch_dict_decoder = utils.get_next_batch(train_decoder, device)

            loss, mse, likelihood, kl_first_p, std_first_p = train_single_batch(
                model, batch_dict_encoder, batch_dict_decoder, batch_dict_graph, kl_coef,
                weights_type=args.weights_type)

            loss_list.append(loss)
            mse_list.append(mse)
            likelihood_list.append(likelihood)
            kl_first_p_list.append(kl_first_p)
            std_first_p_list.append(std_first_p)

            del batch_dict_encoder, batch_dict_graph, batch_dict_decoder
            torch.cuda.empty_cache()

        scheduler.step()
        for param_group in optimizer.param_groups:
            print(param_group['lr'])

        message_train = 'Epoch {:04d} [Train seq (cond on sampled tp)] | Loss {:.6f} | MSE {:.6F} | Likelihood {:.6f} | KL fp {:.4f} | FP STD {:.4f}|'.format(
            epo,
            np.mean(loss_list), np.mean(mse_list), np.mean(likelihood_list),
            np.mean(kl_first_p_list), np.mean(std_first_p_list))

        return message_train, kl_coef

    def checkpoint_path(prefix, epoch, mse):
        """Return the checkpoint path under args.save for `epoch`, tagged with
        the run settings and the given MSE value."""
        return os.path.join(args.save, prefix + str(
            experimentID) + "_" + args.z0_encoder + "_" + args.data + "_" + str(
            args.sample_percent_train) + "_" + args.mode + "_epoch_" + str(epoch) + "_mse_" + str(
            mse) + '.ckpt')

    def save_checkpoint(path):
        """Save the run's args together with the current model weights."""
        torch.save({
            'args': args,
            'state_dict': model.state_dict(),
        }, path)

    for epo in range(1, args.niters + 1):

        message_train, kl_coef = train_epoch(epo, args)
        # Unconditional per-epoch snapshot; its name carries the best test MSE
        # seen before this epoch (inf until the first evaluation).
        save_checkpoint(checkpoint_path("experiment_TRAIN_", epo, best_test_mse))

        if epo % n_iters_to_viz == 0:
            model.eval()
            # NOTE(review): evaluation is not wrapped in torch.no_grad();
            # confirm whether compute_loss_all_batches disables autograd itself.
            test_res = compute_loss_all_batches(model, test_encoder, test_graph, test_decoder,
                                                n_batches=test_batch, device=device,
                                                n_traj_samples=3, kl_coef=kl_coef, weights_type=args.weights_type)

            message_test = 'Epoch {:04d} [Test seq (cond on sampled tp)] | Loss {:.6f} | MSE {:.6F} | Likelihood {:.6f} | KL fp {:.4f} | FP STD {:.4f}|'.format(
                epo,
                test_res["loss"], test_res["mse"], test_res["likelihood"],
                test_res["kl_first_p"], test_res["std_first_p"])

            logger.info("Experiment " + str(experimentID))
            logger.info(message_train)
            logger.info(message_test)
            logger.info("KL coef: {}".format(kl_coef))
            print("data: %s, encoder: %s, sample: %s, mode:%s" % (
                args.data, args.z0_encoder, str(args.sample_percent_train), args.mode))

            # Keep a separate checkpoint whenever the test MSE improves.
            if test_res["mse"] < best_test_mse:
                best_test_mse = test_res["mse"]
                message_best = 'Epoch {:04d} [Test seq (cond on sampled tp)] | Best mse {:.6f}|'.format(
                    epo, best_test_mse)
                logger.info(message_best)
                save_checkpoint(checkpoint_path("experiment_", epo, best_test_mse))

            torch.cuda.empty_cache()
