import logging
import multiprocessing
import os
import pickle
import queue
import random
import threading
import time
import traceback
from copy import deepcopy
from pathlib import Path
from typing import Any

import numpy as np
import torch
import wandb
from seml.experiment import Experiment

from structured_llmuq.data.build import build_dataset
from structured_llmuq.model.build import build_model
from structured_llmuq.qa_orchestrator import Orchestrator
from structured_llmuq.utils.latent_encoder import build_latent_encoder

# Lock acquired by the metric worker thread in ``main`` while it computes and
# stores each sample's latent-space metrics.
# NOTE(review): within this file it is only ever acquired from that single
# worker thread; a ``threading.Lock`` would suffice unless other modules share
# it across processes — confirm before changing.
cache_lock = multiprocessing.Lock()

# seml experiment object; ``main`` below is registered through its
# ``automain`` decorator.
experiment = Experiment()


# NOTE: ``@experiment.automain`` starts the experiment as soon as the decorator
# runs when this file is executed as ``__main__``, so every helper used by
# ``main`` must be defined above it.


def _seed_everything(seed: int) -> None:
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Seed applied to every random number generator.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _metric_worker(orchestrator, results_queue: queue.Queue, computed_metrics: dict):
    """Consume generated samples and compute latent-space metrics.

    Runs until the ``(None, None)`` sentinel is received. A failure while
    mapping one sample is recorded in ``computed_metrics`` under that sample's
    id and the worker continues; any other failure terminates the process.

    Args:
        orchestrator (Orchestrator): Orchestrator used to map samples to latent space.
        results_queue (queue.Queue): Queue of ``(sample_id, run_result)`` tuples.
        computed_metrics (dict): Mapping from sample id to computed metrics,
            filled in by this worker.
    """
    try:
        while True:
            sample_id, result = results_queue.get()
            # Detect the sentinel by its id: real sample ids are ints, whereas a
            # generated result could in principle be None.
            if sample_id is None:
                break

            with cache_lock:
                try:
                    computed_metrics[sample_id] = orchestrator.map_to_latent_space(
                        result
                    )
                    logging.info(f"Computed metric for sample {sample_id}")
                except Exception as e:
                    logging.error(
                        f"Error computing metric for sample {sample_id}: {e}"
                    )
                    computed_metrics[sample_id] = {
                        "error": str(e),
                        "traceback": traceback.format_exc(),
                    }
            results_queue.task_done()
    except Exception as e:
        logging.error(f"Metric worker encountered an error: {e}")
        traceback.print_exc()
        # Terminates the whole process (not just this thread); os._exit skips
        # atexit handlers and finally blocks in other threads.
        os._exit(1)


def _log_progress(elapsed_sec: float, num_done: int, num_total: int) -> None:
    """Log the average time per sample and the estimated remaining time.

    Args:
        elapsed_sec: Total wall-clock seconds spent on the ``num_done`` samples.
        num_done: Number of samples processed so far (must be > 0).
        num_total: Total number of samples to process.
    """
    avg_sec = elapsed_sec / num_done
    est_sec = avg_sec * (num_total - num_done)
    logging.info(
        f"Average time taken for samples so far: {avg_sec:.2f} sec, {avg_sec / 60:.2f} min, {avg_sec / 3600:.2f} hr"
    )
    logging.info(
        f"Estimated remaining time: {est_sec:.2f} sec, {est_sec / 60:.2f} min, {est_sec / 3600:.2f} hr"
    )


def _resolve_run_name(wandb_conf: dict, wandb_dir: str, fallback_name: str) -> str:
    """Register the run with wandb and return its name.

    wandb logging is best-effort. If ``wandb.init`` fails or no run name is
    reported, ``fallback_name`` is returned so results can still be saved.

    Args:
        wandb_conf: Config dictionary attached to the wandb run.
        wandb_dir: Directory passed to ``wandb.init``.
        fallback_name: Name used when wandb does not provide one.

    Returns:
        The wandb run name, or ``fallback_name``.
    """
    logging.info("Logging results to wandb")
    run_name = None
    try:
        wandb.init(
            project="llm_uq",
            config=wandb_conf,
            dir=wandb_dir,
        )
        run_name = wandb.run.name
        wandb.finish()
    except Exception as e:
        logging.info(f"Wandb error {e}")
        traceback.print_exc()

    if not run_name:
        logging.warning(f"No wandb run name available; using {fallback_name!r}")
        return fallback_name
    return run_name


def _save_results(results: list, output_root: str, experiment_id: int, run_name: str) -> str:
    """Pickle ``results`` to ``<output_root>/experiment_<id>/<run_name>/results.pkl``.

    Args:
        results: Per-sample result dictionaries to serialize.
        output_root: Base output directory (a trailing slash is optional).
        experiment_id: Experiment identifier used as a sub-directory.
        run_name: Run name used as the innermost directory.

    Returns:
        The path of the written pickle file, as a string.
    """
    out_dir = Path(output_root) / f"experiment_{experiment_id}" / run_name
    out_dir.mkdir(parents=True, exist_ok=True)
    file = out_dir / "results.pkl"
    with file.open("wb") as f:
        pickle.dump(results, f, protocol=pickle.HIGHEST_PROTOCOL)
    logging.info(f"Results saved to {file}")
    return str(file)


@experiment.automain
def main(
    dataset: dict[str, Any],
    model: dict[str, Any],
    latent_encoder: dict[str, Any],
    output: dict[str, Any],
    orchestrator: dict[str, Any],
    seed: int,  # seml automatically assigns a random seed
    experiment: int,
    verbose: str = "DEBUG",
    environment_variables: dict[str, str] | None = None,
    wandb_dir: str = ".",
):
    """Run a full experiment pipeline.

    Builds the dataset, model, latent encoder and orchestrator. Answers for
    every dataset sample are generated on the main thread, while a background
    thread maps each finished sample to latent-space metrics. The run is then
    registered with wandb and the merged results are pickled to disk.

    Args:
        dataset (dict): Dataset config; ``dataset_name`` and the optional
            ``limit`` are passed to :func:`build_dataset`.
        model (dict): Model configuration passed to :func:`build_model`. Its
            ``token`` entry is also forwarded to the latent encoder config.
        latent_encoder (dict): Latent encoder configuration passed to
            :func:`build_latent_encoder`; must contain a ``config`` dict.
        output (dict): Output configuration; ``path`` is the base directory
            for saved results.
        orchestrator (dict): Configuration passed to :class:`Orchestrator`.
        seed (int): Random seed provided by seml to ensure reproducibility.
        experiment (int): Experiment identifier used for the output directory
            and the wandb config.
        verbose (str): Logging level name; unknown names fall back to INFO.
        environment_variables (dict, optional): Extra environment variables,
            set before the model is built.
        wandb_dir (str): Directory passed to ``wandb.init``.

    Returns:
        str: Path of the written ``results.pkl`` file.

    Raises:
        RuntimeError: If the metric worker thread stops before all samples
            have been generated.
    """
    logging.basicConfig(
        level=getattr(logging, verbose.upper(), logging.INFO), force=True
    )
    _seed_everything(seed)

    if environment_variables:
        os.environ.update(environment_variables)

    os.environ["TORCHDYNAMO_DISABLE"] = "1"  # Gemma 3 caching error

    logging.info(
        f"Running experiment with dataset={dataset}, model={model}, seed={seed}, output={output}, latent_encoder={latent_encoder} "
    )
    # Snapshot the raw config dicts before the parameter names below are
    # rebound to the constructed objects.
    # NOTE(review): the model config carries a "token" entry; if it is an
    # access token, it is written both to the log above and to wandb.
    # Consider redacting it.
    wandb_conf = {
        "experiment": experiment,
        "dataset": dataset,
        "model": model,
        "latent_encoder": latent_encoder,
        "output": output,
        "seed": seed,
    }

    dataset = build_dataset(
        dataset_name=dataset["dataset_name"], limit=dataset.get("limit")
    )
    logging.info("Dataset loaded")

    model_conf = deepcopy(model)
    model = build_model(**model_conf)
    logging.info("Model loaded")

    # Deep copy so that injecting the token does not mutate the dict logged to wandb.
    latent_encoder_conf = deepcopy(latent_encoder)
    latent_encoder_conf["config"]["token"] = model_conf["token"]
    latent_encoder = build_latent_encoder(**latent_encoder_conf)
    logging.info("Latent encoder built")

    orchestrator = Orchestrator(
        dataset=dataset,
        model=model,
        latent_encoder=latent_encoder,
        config=orchestrator,
    )

    # Generation runs on this thread; latent-space mapping runs on the worker.
    results_queue = queue.Queue()
    computed_metrics = {}  # sample id -> metrics, written by the worker thread
    worker_thread = threading.Thread(
        target=_metric_worker,
        args=(orchestrator, results_queue, computed_metrics),
        name="metric-worker",
    )
    worker_thread.start()

    results = {}
    total_time = 0.0
    num_samples = len(dataset)
    try:
        for i in range(num_samples):
            # The worker only stops on the sentinel, so an early exit is a failure.
            if not worker_thread.is_alive():
                raise RuntimeError(
                    f"Metric worker thread exited unexpectedly before sample {i}"
                )

            start_time = time.perf_counter()
            cpu_rng_state = torch.get_rng_state()
            gpu_rng_state = torch.cuda.get_rng_state_all()
            results[i] = orchestrator.generate_answers(i)
            # Restore the pre-generation RNG state so every sample starts from
            # the same state, independent of the samples generated before it.
            torch.set_rng_state(cpu_rng_state)
            torch.cuda.set_rng_state_all(gpu_rng_state)
            results_queue.put((i, results[i]))
            total_time += time.perf_counter() - start_time

            _log_progress(total_time, i + 1, num_samples)
    finally:
        # Always stop the (non-daemon) worker. Otherwise an exception during
        # generation would leave it blocked on results_queue.get() and hang
        # the process.
        results_queue.put((None, None))
        worker_thread.join()

    merged_results = [
        result.to_dict() | {"metrics": computed_metrics[i]}
        for i, result in results.items()
    ]

    fallback_name = f"run_seed{seed}_{time.strftime('%Y%m%d-%H%M%S')}"
    run_name = _resolve_run_name(wandb_conf, wandb_dir, fallback_name)

    return _save_results(merged_results, output["path"], experiment, run_name)
