import logging
import argparse
import json
from datetime import datetime
import torch
from simulator import sample_backward
import os

# Timestamp that gives each run its own log file.
# NOTE(review): ':' is not a valid file-name character on Windows; consider
# "%Y-%m-%d--%H-%M-%S" if this script has to run there.
current_time = datetime.now().strftime("%Y-%m-%d--%H:%M:%S")

logger = logging.getLogger('LOG')
# Let every record through the logger; each handler filters by its own level.
logger.setLevel(logging.DEBUG)

# Console output: INFO and above, terse format.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_formatter = logging.Formatter("%(levelname)s - %(message)s")
console_handler.setFormatter(console_formatter)

# File output: full DEBUG detail with timestamps. FileHandler does not create
# missing parent directories, so make sure logs/ (relative to the CWD) exists.
os.makedirs('logs', exist_ok=True)
file_handler = logging.FileHandler(f'logs/main-{current_time}.log')
file_handler.setLevel(logging.DEBUG)
file_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
file_handler.setFormatter(file_formatter)

logger.addHandler(console_handler)
logger.addHandler(file_handler)


# Command-line interface. Note that adjacent string literals are concatenated
# verbatim, so each fragment must carry its own trailing space.
parser = argparse.ArgumentParser(description="Script for sampling from a diffusion model.")
parser.add_argument("--model", type=str, required=True,
                    help="Select model to sample from.")
parser.add_argument("--number-of-samples", type=int, default=1,
                    help="Specify the number of samples.")
parser.add_argument("--number-of-noise-realizations", type=int, default=1,
                    help="Specify the number of noise realizations.")
parser.add_argument("--root", type=str, default="./samples/",
                    help="Root directory to save generated data. "
                    "Data will be saved as pytorch tensors "
                    "in <root>/trajectory/backward/<output>.pt, "
                    "<root>/lyap-exp/<output>.pt, and "
                    "<root>/lyap-vec/<output>.pt.")
parser.add_argument("--output", type=str, required=True,
                    help="Base file name (without the .pt extension) of the saved tensors.")
parser.add_argument("--T", type=float, default=0.9,
                    help="The stopping time for the forward process.")
parser.add_argument("--n-grid", type=int, default=1000,
                    help="Number of discretization points to sample using Euler-Maruyama.")
parser.add_argument("--conf-file", default=None,
                    help="JSON file, looked up in the conf/ directory next to this script, "
                    "containing the configuration for model <model>.")
parser.add_argument("--perturb-size", type=float, default=0.0,
                    help="Size of perturbation, if implemented.")
parser.add_argument("--no-spectrum", action='store_true',
                    help="Skip calculating the Lyapunov spectrum after the trajectories "
                    "have been calculated. Calculating it is rather slow due to many "
                    "QR factorizations.")
parser.add_argument("--noise-schedule", type=str, default=None,
                    help="Specifies the noise schedule for sampling.")

def main():
    """Parse CLI arguments, run backward sampling and save the results with torch.

    If ``--conf-file`` is given, it is read from ``<script dir>/conf/`` and its
    contents are forwarded as keyword arguments to ``sample_backward``
    (presumably a JSON object mapping parameter names to values — verify
    against ``sample_backward``). The backward trajectory is always saved; the
    Lyapunov exponents and vectors are saved only when ``sample_backward``
    returns them (i.e. they are not ``None``).
    """
    args = parser.parse_args()

    logger.debug("=============== ARGUMENTS ===============")
    logger.debug(json.dumps(vars(args), indent=2))

    # Resolve against this script's location rather than the CWD; abspath
    # avoids an empty dirname when the script is launched as `python main.py`.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    logger.debug('Working directory: ' + script_dir)

    kwargs = {}
    if args.conf_file:
        conf_file = os.path.join(script_dir, 'conf', args.conf_file)
        with open(conf_file, mode="r", encoding="utf-8") as file:
            kwargs = json.load(file)

    logger.debug(f"kwargs: {json.dumps(kwargs, indent=2)}")

    # Create the output directories up front so a bad --root fails before the
    # (potentially long) sampling run.
    bwd_string, le_string, lv_string = process_dir(args.root,
                                                   args.model,
                                                   args.output)
    calculate_spectrum = not args.no_spectrum
    trajectory, lexps, lvects = sample_backward(args.model,
                                                args.n_grid, args.T,
                                                args.number_of_samples,
                                                args.number_of_noise_realizations,
                                                args.perturb_size,
                                                calculate_spectrum,
                                                args.noise_schedule,
                                                **kwargs)

    logger.info(f"Saving backward trajectory in: {bwd_string}")
    torch.save(trajectory, bwd_string)
    if lexps is not None:
        logger.info(f"Saving Lyapunov Exponents in: {le_string}")
        torch.save(lexps, le_string)
    if lvects is not None:
        logger.info(f"Saving Lyapunov Vectors in: {lv_string}")
        torch.save(lvects, lv_string)


def process_dir(root, model, output):
    """Create the output directories under *root* and return the output file paths.

    Args:
        root: Root directory for all outputs; a trailing separator is optional.
        model: Currently unused — the model name is not part of the output path.
            Kept for interface compatibility.
        output: Base file name; ``.pt`` is appended.

    Returns:
        Tuple ``(backward_path, lyap_exp_path, lyap_vec_path)`` pointing to
        ``<root>/trajectory/backward/<output>.pt``, ``<root>/lyap-exp/<output>.pt``
        and ``<root>/lyap-vec/<output>.pt``. The three directories are created
        if they do not already exist.
    """
    file_name = output + '.pt'
    # os.path.join inserts a separator only when needed, so both "./samples"
    # and "./samples/" yield the same paths.
    out_dirs = (
        os.path.join(root, 'trajectory', 'backward'),
        os.path.join(root, 'lyap-exp'),
        os.path.join(root, 'lyap-vec'),
    )
    for out_dir in out_dirs:
        os.makedirs(out_dir, exist_ok=True)
    return tuple(os.path.join(out_dir, file_name) for out_dir in out_dirs)

if __name__ == '__main__':
    # Run the CLI only when executed as a script; importing this module
    # performs the logging/parser setup above but does not sample.
    main()