from config import get_config
from train_eval import train, eval

import torch
import json
import argparse
import ml_collections
from loguru import logger
import os
import numpy as np
import random

# Print tensors in full (no "..." summarization) whenever they are printed or logged.
torch.set_printoptions(profile="full")

# Directory holding configuration files.
# NOTE(review): not referenced in the visible code -- confirm it is still needed.
config_dir = "./configs"

def set_seed(seed):
    """Seed the NumPy, Python and PyTorch random number generators.

    Also switches cuDNN to deterministic mode with benchmarking disabled and
    exports ``PYTHONHASHSEED`` to the process environment.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Same seed for each RNG, applied in a fixed order.
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning for deterministic algorithm selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NOTE(review): this only changes the environment from here on; the running
    # interpreter read PYTHONHASHSEED at startup -- confirm the intent.
    os.environ["PYTHONHASHSEED"] = str(seed)

    logger.info(f"Random seed set as {seed}")

def get_config(config_json):
    """Return the run configuration.

    Args:
        config_json: Path to a JSON config file, or ``None`` to use the
            project's default configuration from the ``config`` module.

    Returns:
        The configuration: an ``ml_collections.ConfigDict`` built from the
        JSON file, or whatever ``config.get_config()`` returns when no path
        is given.
    """
    if config_json is None:
        # This function shadows the module-level `from config import get_config`,
        # so a bare `get_config()` here would call itself without its required
        # argument and raise TypeError. Import the default factory under a
        # distinct local name instead.
        from config import get_config as _default_get_config
        return _default_get_config()

    logger.info(f"Reading config from JSON: {config_json}")
    with open(config_json, 'r') as f:
        return ml_collections.ConfigDict(json.load(f))

if __name__ == '__main__':
    # Script entry point: load the config, persist it, seed the RNGs, then
    # train (unless config.eval_only is set) and evaluate.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_json", type=str, default=None, help="Path to config json file.")
    parser.add_argument("--ckpt_dir", type=str, default='./checkpoint', help="Path to folder to save checkpoints.")
    args = parser.parse_args()
    config = get_config(args.config_json)

    # Persist the resolved config. When --config_json is given it is rewritten
    # in place (original behaviour). When it is omitted there is no path to
    # write to -- open(None, 'w') raises TypeError -- so fall back to a file
    # inside the checkpoint directory.
    if args.config_json is not None:
        config_path = args.config_json
    else:
        os.makedirs(args.ckpt_dir, exist_ok=True)
        config_path = os.path.join(args.ckpt_dir, "config.json")

    with open(config_path, 'w') as file:
        file.write(config.to_json_best_effort(sort_keys=True, indent=4) + '\n')
    logger.info("Config JSON file saved.")

    # Only reported here; not passed to train/eval in this script.
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    logger.info(f"Using device: {device}")

    set_seed(config.seed)

    if not config.eval_only:
        trained_model = train(config, args.ckpt_dir)

    # Evaluation always runs, after training or on its own.
    eval(config, args.ckpt_dir)
