import argparse
import datetime
import multiprocessing
import os
import pickle
import pprint
import queue
import sys
import time
import traceback
from concurrent.futures import TimeoutError

import torch
import yaml

from config import config, env_info
from experiment_runners import get_runner


# Fallback timeout, in seconds, for experiments whose config entry has no
# 'timeout' key. Read once at import time, so a config lacking
# defaults.experiment_timeout fails as soon as this module is imported.
DEFAULT_EXPERIMENT_TIMEOUT = config['defaults']['experiment_timeout']


def _run_wrapper(func, args, kwargs, q):
    """Child-process entry point: call ``func`` and report the outcome on ``q``.

    Puts exactly one ``(success, payload)`` tuple on ``q``:

    * ``(True, None)`` if ``func(*args, **kwargs)`` returned normally (its
      return value is not sent back to the parent);
    * ``(False, exc)`` if it raised an ``Exception``.

    ``exc`` is sent as-is only if it survives a pickle round trip. Otherwise a
    ``RuntimeError`` carrying the original type name and message is sent
    instead. An exception that cannot be pickled (or rebuilt from its pickle)
    would otherwise never reach a parent waiting on the queue.
    """
    try:
        func(*args, **kwargs)
    except Exception as exc:
        payload = exc
        try:
            # Approximate what the parent's q.get() will do; some exceptions
            # pickle fine but cannot be rebuilt (e.g. custom __init__ signatures).
            pickle.loads(pickle.dumps(payload))
        except Exception:
            payload = RuntimeError(f"{type(exc).__name__}: {exc}")
        q.put((False, payload))
    else:
        q.put((True, None))


def run_with_timeout(func, args=(), kwargs=None, timeout=30):
    """Run ``func(*args, **kwargs)`` in a child process, killing it after ``timeout`` seconds.

    ``func`` and its arguments are handed to a new ``multiprocessing.Process``
    running ``_run_wrapper``, so with the 'spawn' start method they must be
    picklable.

    Args:
        func: Callable to run in the child.
        args: Positional arguments for ``func``.
        kwargs: Keyword arguments for ``func`` (``None`` means no kwargs).
        timeout: Seconds to wait before terminating the child (``None`` waits
            indefinitely).

    Returns:
        A ``(success, payload)`` tuple:

        * whatever ``(success, payload)`` tuple the child put on the queue
          (``(False, exc)`` when ``func`` raised);
        * ``(False, TimeoutError)`` if the child was still running after
          ``timeout`` seconds and was terminated;
        * ``(False, RuntimeError)`` if the child exited without reporting
          anything (e.g. ``sys.exit()``, a crash, or being killed by a signal).
    """
    if kwargs is None:
        kwargs = {}
    q = multiprocessing.Queue()
    p = multiprocessing.Process(target=_run_wrapper, args=(func, args, kwargs, q))
    p.start()
    try:
        p.join(timeout)
        if p.is_alive():
            p.terminate()
            p.join()
            return False, TimeoutError("Function timed out")
        try:
            # A process that put items on a multiprocessing.Queue does not
            # terminate until they are flushed to the pipe, so once the child
            # has exited its result (if any) is already readable; the short
            # wait only guards against a child that never reported.
            return q.get(timeout=1.0)
        except queue.Empty:
            return False, RuntimeError(
                f"Child process exited with code {p.exitcode} without reporting a result")
    finally:
        # Don't leave an orphaned child behind if we are interrupted while waiting.
        if p.is_alive():
            p.terminate()
            p.join()
        q.close()


def run_experiments_from_config(exp_config, device, output_root=None, job_id=None, run_id=None):
    """Run every entry of ``exp_config['experiments']`` in order.

    A failure in one experiment is reported (an ``[ERROR]`` line on stdout and
    the traceback on stderr) and does not stop the remaining experiments.

    Args:
        exp_config: Parsed YAML config. May be ``None`` (what
            ``yaml.safe_load`` returns for an empty file); a missing or null
            ``experiments`` key means there is nothing to run.
        device: Passed through to ``run_experiment``.
        output_root: Passed through to ``run_experiment``.
        job_id: Passed through to ``run_experiment``.
        run_id: Passed through to ``run_experiment``.
    """
    experiments = (exp_config or {}).get('experiments') or []
    for experiment in experiments:
        print(f"{'#' * 80}\n{'#' * 80}")
        try:
            run_experiment(experiment, device, output_root, job_id=job_id, run_id=run_id)
        except Exception as e:
            # A malformed (non-mapping) entry must not crash the error report itself.
            if isinstance(experiment, dict):
                exp_name = experiment.get('name', 'Unnamed Experiment')
            else:
                exp_name = 'Unnamed Experiment'
            print(f"[ERROR] {datetime.datetime.now()} Unexpected error while running experiment '{exp_name}': {e}")
            traceback.print_exc()
        print(f"{'#' * 80}\n{'#' * 80}")


class _Tee:
    """File-like object that mirrors every write to several underlying streams.

    Each write is flushed on every stream immediately so the log file keeps
    pace with the console. Attributes not defined here (``isatty``,
    ``fileno``, ``encoding``, ...) are looked up on the first stream, so code
    that inspects ``sys.stdout``/``sys.stderr`` keeps working.
    """

    def __init__(self, *streams):
        self.streams = streams

    def write(self, msg):
        for s in self.streams:
            s.write(msg)
            s.flush()
        # Text-stream contract: report how many characters were written.
        return len(msg)

    def flush(self):
        for s in self.streams:
            s.flush()

    def __getattr__(self, attr):
        # Only reached for names not found normally; avoid infinite recursion
        # if ``streams`` itself is not set yet.
        if attr == 'streams':
            raise AttributeError(attr)
        return getattr(self.streams[0], attr)


def run_experiment(experiment, device, output_root=None, job_id=None, run_id=None):
    """Set up one experiment and run it in a child process under a timeout.

    While the experiment runs, this process's stdout/stderr are mirrored into
    the experiment's main log file. Both are restored, and the log file is
    closed, before returning, even if something raises.

    Args:
        experiment: One entry of the config's ``experiments`` list; the keys
            ``name``, ``type``, ``timeout``, ``dataset`` and ``parameters`` are read.
        device: Passed to the runner class constructor.
        output_root: If given, overrides ``runner.base_output_dir``.
        job_id: Parallel job index; added to the parameters and used to make
            the display name and main log file unique.
        run_id: Added to the parameters so parallel jobs can share outputs.

    Raises:
        ValueError: if ``type`` does not resolve to a runner class.

    A timeout or failure inside ``runner.run`` is not raised. It is printed and
    appended to ``runner.conf_file``.
    """
    name = experiment.get('name', 'Unnamed Experiment')
    exp_type = experiment.get('type')
    timeout = experiment.get('timeout', DEFAULT_EXPERIMENT_TIMEOUT)
    if timeout is None:  # e.g. ``timeout:`` left empty in the YAML
        timeout = DEFAULT_EXPERIMENT_TIMEOUT
    # Cap the timeout to avoid overflow in the wait, and do it *before* it is
    # printed so the banner shows the value that is actually enforced.
    MAX_TIMEOUT = 604800  # 1 week in seconds
    timeout = min(timeout, MAX_TIMEOUT)
    dataset = experiment.get('dataset', 'mnist')
    # Work on a copy so job_id/run_id don't leak back into the caller's config;
    # ``or {}`` also covers a YAML ``parameters:`` key with no value.
    parameters = dict(experiment.get('parameters') or {})

    if job_id is not None:
        parameters['job_id'] = job_id
    if run_id is not None:
        parameters['run_id'] = run_id

    runner_class = get_runner(exp_type)
    if not runner_class:
        raise ValueError(f"Unknown experiment type: {exp_type}")

    runner = runner_class(device)
    if output_root is not None:
        runner.base_output_dir = output_root
    # Use the base name for setting up the shared directory
    runner.setup_experiment(name, dataset, exp_type, parameters)

    job_specific_name = name
    # If this is a parallel job, make its main log file unique and update its name for logging
    if job_id is not None:
        job_specific_name = f"{name}_job{job_id}"
        runner.name = job_specific_name  # For logging inside the runner
        runner.exp_paths['main_log_file'] = os.path.join(runner.logs_dir, f'main_job_{job_id}.log')

    exp_main_log = runner.exp_paths['main_log_file']
    log_dir = os.path.dirname(exp_main_log)
    if log_dir:  # os.makedirs('') raises for a bare file name
        os.makedirs(log_dir, exist_ok=True)
    run_header = f"\n\n{'*' * 80}\n{'*' * 80}\n[INFO] main.py run at {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]}\n{env_info}\n{'*' * 80}\n{'*' * 80}\n"

    saved_stdout, saved_stderr = sys.stdout, sys.stderr
    with open(exp_main_log, "a") as main_log:
        main_log.write(run_header)
        main_log.flush()
        # Mirror this process's console output into the log.
        # NOTE: under the 'spawn' start method the child that executes
        # runner.run is a fresh interpreter and does not inherit these
        # wrappers, so its own prints go only to the real stdout/stderr.
        sys.stdout = _Tee(saved_stdout, main_log)
        sys.stderr = _Tee(saved_stderr, main_log)
        try:
            print("#" * 80)
            print(
                f"[{datetime.datetime.now():%Y-%m-%d %H:%M:%S}] [INFO] Starting experiment: {job_specific_name} | Max-Timeout: {timeout} seconds")
            print(f"[INFO] Type: {exp_type} | Dataset: {dataset}")
            if job_id is not None:
                print(f"[INFO] Job ID: {job_id} / {parameters.get('num_jobs')}")
            print(f"[INFO] Runner class: {runner_class.__name__}")
            print("[INFO] Parameters:")
            pprint.pprint(parameters, indent=4)
            print("#" * 80)

            start_time = time.monotonic()
            success, result = run_with_timeout(runner.run, args=(job_specific_name, exp_type, dataset, parameters), timeout=timeout)
            elapsed_time = time.monotonic() - start_time

            if not success:
                if isinstance(result, TimeoutError):
                    err_msg = f"[ERROR] Experiment '{job_specific_name}' timed out after {timeout} seconds."
                else:
                    err_msg = f"[ERROR] Experiment '{job_specific_name}' failed: {result}"

                print(err_msg)
                # Separate handle name so the main log handle is not shadowed;
                # newline keeps successive error entries on their own lines.
                with open(runner.conf_file, "a") as conf_file:
                    conf_file.write(err_msg + "\n")
            else:
                print(f"***[SUCCESS] Experiment '{job_specific_name}' completed successfully.***")
            print(f"[INFO] Elapsed time for experiment '{job_specific_name}': {elapsed_time:.2f} seconds")
        finally:
            # Always undo the redirection; otherwise each later experiment would
            # stack another Tee and keep writing into this (closed) log file.
            sys.stdout, sys.stderr = saved_stdout, saved_stderr


def main():
    """Command-line entry point: parse arguments, load the YAML experiment config and run it."""
    # Experiment child processes are started with 'spawn' (force=True overrides
    # any method set earlier); presumably because PyTorch's CUDA runtime does
    # not support forked subprocesses.
    multiprocessing.set_start_method("spawn", force=True)
    parser = argparse.ArgumentParser()
    parser.add_argument("--expconf", type=str, default="experiments_config.yaml", help="Path to experiment config YAML")
    parser.add_argument("--output_root", type=str, default=None, help="Root directory for all experiment outputs")
    parser.add_argument("--job_id", type=int, default=None, help="Job ID for parallel execution.")
    parser.add_argument("--run_id", type=str, default=None, help="A unique ID for the run, to ensure shared output directories for parallel jobs.")
    args = parser.parse_args()

    # Explicit encoding: YAML files are normally UTF-8, independent of the platform locale.
    with open(args.expconf, 'r', encoding='utf-8') as f:
        exp_config = yaml.safe_load(f)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] Running experiments. Using device: {device}")
    print(f"config: {config}")
    run_experiments_from_config(exp_config, device, args.output_root, args.job_id, args.run_id)
    print("[INFO] All experiments completed")


# Run only when executed as a script. Child processes started with 'spawn'
# re-import this module under a different __name__, so they skip this.
if __name__ == "__main__":
    main()
