# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
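A minimal inference script: loads a trained SMG checkpoint, simulates a trajectory for one target structure, and writes the predicted frames as PDB files.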
"""
import torch
import wandb
# The first flag below was False when we tested this script, but True makes A100 training a lot faster (TF32 matmuls):
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import logging
import os

from dataset.h5_to_pdb import convert
from utility.parse_args import eval_parse_args
from dataset.md_dataset import get_datasets
from model.train_epoch import simulate
from model.get_model import get_model, get_optim

from utility import functions as uf

#################################################################################
#                             Inference Helper Functions                         #
#################################################################################

def create_logger(logging_dir):
    """
    Configure root logging to write to the console and to ``<logging_dir>/log.txt``.

    Each record is prefixed with an ANSI-blue timestamp. Configuration goes
    through ``logging.basicConfig``, so it is a no-op when the root logger
    already has handlers (the handlers are still constructed, which creates
    the log file).

    Args:
        logging_dir: existing directory in which ``log.txt`` is created/appended.

    Returns:
        The logger for this module.
    """
    log_file = f"{logging_dir}/log.txt"
    # Handlers are built first, in the same order as before: console, then file.
    console_and_file = [logging.StreamHandler(), logging.FileHandler(log_file)]
    logging.basicConfig(
        handlers=console_and_file,
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
    )
    return logging.getLogger(__name__)

#################################################################################
#                                  Inference Loop                               #
#################################################################################


def save_trajectory(args, ds_path, pos_predict, total_frames=100):
    """
    Write the predicted trajectory (and a companion frame per index) via ``convert``.

    For each frame index, ``convert`` is called twice:
      * with the predicted coordinates, writing under ``ds_path``;
      * without coordinates, writing under ``<target_struct>/`` relative to the
        current working directory (presumably the reference MD frame -- TODO
        confirm against ``dataset.h5_to_pdb.convert``).

    Args:
        args: evaluation arguments; only ``args.target_struct`` is read.
        ds_path: directory that receives the predicted-frame files.
        pos_predict: predicted positions indexable by frame (e.g. a torch tensor
            or list of tensors); each item must support ``.detach().cpu().numpy()``.
        total_frames: maximum number of frames to write. Defaults to 100, the
            previously hard-coded value, and is clamped to ``len(pos_predict)``.
    """
    struct = args.target_struct
    dataset_md = "data/MD/h5_files/MD.hdf5"
    map_dir = 'dataset/Maps/'
    # Coordinate-free frames go to a folder named after the structure, relative to the CWD.
    output_dir = f'{struct}/'
    os.makedirs(output_dir, exist_ok=True)
    # Never index past the end of the prediction: a hard-coded 100 raised
    # IndexError (after partially writing output) when fewer frames were returned.
    n_frames = min(total_frames, len(pos_predict))
    for frame in range(n_frames):
        # detach() so the conversion also works if the prediction still tracks gradients.
        coords = pos_predict[frame].detach().cpu().numpy()
        convert(struct, dataset_md, map_dir, ds_path, frame, coords)
        convert(struct, dataset_md, map_dir, output_dir, frame)

def main(eval_args):
    """
    Load a trained checkpoint, simulate ``eval_args.target_struct`` and save the trajectory.

    The results folder ``<output_path>/<target_struct>`` is created and receives
    ``log.txt`` plus the predicted frames. The model and datasets are rebuilt
    from the training args stored in the checkpoint. EMA weights are used when
    those args have ``ema_decay > 0``. The target protein is taken from the
    training split.

    Args:
        eval_args: parsed evaluation arguments; reads ``output_path``,
            ``target_struct`` and ``model_path``.

    Raises:
        FileNotFoundError: if ``eval_args.model_path`` is not an existing file.
    """
    ds_path = os.path.join(eval_args.output_path, eval_args.target_struct)
    os.makedirs(ds_path, exist_ok=True)  # Make results folder (holds all experiment subfolders)
    # Fall back to CPU instead of crashing on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype = torch.float32

    ckpt_path = eval_args.model_path
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f'Could not find SMG checkpoint at {ckpt_path}')
    # Keep all storages on the CPU while loading; the model is moved to `device` below.
    # NOTE(review): the checkpoint is unpickled (it stores a training-args object),
    # so only load trusted files. torch >= 2.6 defaults to weights_only=True, which
    # may reject such a checkpoint -- confirm and pass weights_only=False if needed.
    checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    args = checkpoint["args"]
    model_check_point = checkpoint["model_ema"] if args.ema_decay > 0 else checkpoint["model"]

    logger = create_logger(ds_path)

    datasets = get_datasets(args)
    model = get_model(args, device)
    model.load_state_dict(model_check_point)
    model.to(device)
    # Inference only: put layers such as dropout/batch-norm into evaluation mode
    # (harmless if `simulate` also does this).
    model.eval()
    pos, x = datasets['train'].get_protein(eval_args.target_struct)
    pos_predict, mse, elapsed_time = simulate(args, model, pos, x, device, dtype)
    # The frame/time figures in this message are fixed text, not derived from the run.
    message = (f'Simulate 99 frames, 100 ps per frame, total 10 ns, took {elapsed_time}\n'
               f'Loss on {eval_args.target_struct}: {mse}')
    print(message)
    logger.info(message)
    save_trajectory(eval_args, ds_path, pos_predict)
    logger.info("Done!")


if __name__ == "__main__":
    # Script entry point: parse the evaluation CLI arguments and run inference.
    main(eval_parse_args())
