import os
import random
from typing import Any, Dict, List, Optional, Tuple

import hydra
import hydra.utils as hutils
import numpy as np
import torch
import torch.nn as nn
import torchrl.envs.transforms as ttf
import yaml
from omegaconf import DictConfig, OmegaConf
from torch.optim.optimizer import Optimizer
from torchrl.collectors import DataCollectorBase
from torchrl.collectors.utils import split_trajectories
from torchrl.data import Composite, UnboundedContinuous
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.envs import ExplorationType, GymEnv, TensorDictPrimer
from torchrl.objectives.utils import TargetNetUpdater
from torchrl.record.loggers.wandb import WandbLogger

import src.async_rl.async_mujoco  # noqa: F401
from src.async_rl.module.noise_control import NoiseControlHook
from src.async_rl.module.rnn_sac import build_rec_sac_module
from src.async_rl.module.sac import SACLossHook, build_sac_module
from src.async_rl.module.trainer import RLTrainer
from src.async_rl.module.trainer.evalrollout_hook import EvalRolloutHook
from src.async_rl.module.trainer.logweights_hook import LogWeightsHook
from src.async_rl.module.trainer.save_hook import SaveHook
from src.async_rl.module.utils import get_primers_robust
from src.async_rl.sweep_utils.sweep import (
	generate_sweep_experiment_configs,
)
from src.async_rl.utils import format_integer


def _get_most_recent_checkpoint(save_dir: str) -> str:
	"""Return the path of the checkpoint with the highest step number.

	Candidates are the files in ``save_dir`` whose name starts with
	``"checkpoint"`` and ends with ``".pth"``. The step number is the text
	between the last ``"_"`` of the file name and the following ``"."``
	(e.g. ``checkpoint_1200.pth`` -> ``1200``).

	Args:
		save_dir: Directory to scan for checkpoint files.

	Returns:
		``save_dir`` joined with the name of the selected checkpoint.

	Raises:
		FileNotFoundError: If ``save_dir`` contains no checkpoint file.
	"""
	ckpt_files = [
		os.path.join(save_dir, f)
		for f in os.listdir(save_dir)
		if f.startswith("checkpoint") and f.endswith(".pth")
	]

	if not ckpt_files:
		raise FileNotFoundError("No checkpoint files found.")

	def _step_of(path: str) -> int:
		# Parse the file name only: the directory part may itself contain
		# "_" or "." (e.g. ".../run_01") and must not leak into the number.
		token = os.path.basename(path).split("_")[-1].split(".")[0]
		try:
			return int(token)
		except ValueError:
			# A checkpoint without a numeric suffix ranks below every
			# numbered one instead of aborting the whole lookup.
			return -1

	return max(ckpt_files, key=_step_of)


def _get_device(cfg: DictConfig) -> torch.device:
	"""Translate ``cfg.device`` into a ``torch.device``.

	An integer is taken as a CUDA device index (``3`` -> ``cuda:3``); any
	other value is handed to ``torch.device`` as-is.
	"""
	spec = cfg.device
	if isinstance(spec, int):
		spec = f"cuda:{spec}"
	return torch.device(spec)


def _get_envs(
	cfg: DictConfig, device: torch.device
) -> Tuple[ttf.TransformedEnv, ttf.TransformedEnv]:
	"""Create the training and evaluation environments from ``cfg.env``.

	Both environments are ``GymEnv(cfg.env.env)`` instances wrapped with an
	``InitTracker`` and a ``DTypeCastTransform(torch.float64, torch.float32)``.
	Only the training environment additionally gets a ``RewardScaling`` with
	``cfg.env.reward_scale``; the evaluation environment is put in eval mode.

	Args:
		cfg: Configuration; reads ``cfg.env``.
		device: Device passed to both ``GymEnv`` constructors.

	Returns:
		``(train_env, eval_env)``.
	"""

	def _make_kwargs(attr: str) -> Dict[str, Any]:
		# Any failure to read or convert the entry (missing key, ``None``, ...)
		# falls back to "no extra keyword arguments".
		try:
			return dict(getattr(cfg.env, attr))
		except Exception:
			return {}

	train_make_kwargs = _make_kwargs("train_make_kwargs")
	eval_make_kwargs = _make_kwargs("eval_make_kwargs")

	# Optional ranges are forwarded, as tuples, to both environments.
	for range_name in ("period_range", "gamma_beta_range"):
		if not hasattr(cfg.env, range_name):
			continue
		if getattr(cfg.env, range_name) is None:
			continue
		train_make_kwargs[range_name] = tuple(getattr(cfg.env, range_name))
		eval_make_kwargs[range_name] = tuple(getattr(cfg.env, range_name))

	train_base = GymEnv(cfg.env.env, device=device, **train_make_kwargs)
	train_transforms = ttf.Compose(
		ttf.InitTracker(),
		ttf.DTypeCastTransform(torch.float64, torch.float32),
		ttf.RewardScaling(loc=0.0, scale=cfg.env.reward_scale),
	)
	env = ttf.TransformedEnv(train_base, train_transforms)

	eval_base = GymEnv(cfg.env.env, device=device, **eval_make_kwargs)
	eval_transforms = ttf.Compose(
		ttf.InitTracker(),
		ttf.DTypeCastTransform(torch.float64, torch.float32),
	)
	eval_env = ttf.TransformedEnv(eval_base, eval_transforms).eval()

	return env, eval_env


def _get_exp_name(cfg: DictConfig) -> str:
	"""Return the experiment name for ``cfg``.

	An explicit ``cfg.exp_name`` is returned unchanged. The special value
	``"auto"`` builds ``"<env name>-<loss name>"`` and appends
	``"-<name_suffix>"`` when ``cfg.name_suffix`` is not empty.
	"""
	explicit_name = cfg.exp_name
	if explicit_name != "auto":
		return explicit_name

	auto_name = f"{cfg.env.name}-{cfg.loss.name}"
	suffix = cfg.name_suffix
	return auto_name if suffix == "" else f"{auto_name}-{suffix}"


def _get_logger(cfg: DictConfig) -> Optional[WandbLogger]:
	"""Create the Weights & Biases logger, or ``None`` when it is disabled.

	When ``cfg.wandb.enable`` is truthy, a ``WandbLogger`` is built from the
	``cfg.wandb`` settings and the whole configuration is logged to it as
	hyper-parameters.
	"""
	if not cfg.wandb.enable:
		return None

	logger_kwargs = {
		"exp_name": _get_exp_name(cfg),
		"offline": cfg.wandb.offline,
		"project": cfg.wandb.project,
		"save_dir": cfg.wandb.save_dir,
		"entity": cfg.wandb.entity,
	}
	wandb_logger = WandbLogger(**logger_kwargs)
	wandb_logger.log_hparams(cfg)
	return wandb_logger


def _get_primers(
	cfg: DictConfig, loss_module: nn.Module
) -> List[ttf.Transform]:
	"""Collect the primer transforms to append to an environment.

	For ``cfg.loss.type == "rnn-sac"`` one primer is obtained from the actor
	network and one from the Q-value network of ``loss_module``. When
	``cfg.loss.signals_keys`` is set, a ``TensorDictPrimer`` with one
	``pred_<key>`` entry per signal (sized by ``cfg.loss.signals_dims``) is
	added as well.

	Returns:
		The primers, in the order described above (possibly empty).
	"""
	primers = []

	if cfg.loss.type == "rnn-sac":
		for network_name in ("actor_network", "qvalue_network"):
			primers.append(
				get_primers_robust(getattr(loss_module, network_name))
			)

	if cfg.loss.signals_keys is None:
		return primers

	pred_entries = {}
	# ``strict=True`` makes a keys/dims length mismatch an error.
	for key, dim in zip(
		cfg.loss.signals_keys, cfg.loss.signals_dims, strict=True
	):
		pred_entries[f"pred_{key}"] = UnboundedContinuous(dim)  # type: ignore
	primers.append(TensorDictPrimer(Composite(pred_entries)))
	return primers


def _save_config(cfg: DictConfig, path: str) -> None:
	"""Write ``cfg`` to ``path`` as YAML, prefixed with a Hydra preamble.

	The preamble adds a ``defaults`` list (``_self_`` plus overrides that
	set ``hydra/hydra_logging`` and ``hydra/job_logging`` to ``disabled``)
	and a ``hydra`` section with ``output_subdir: null`` and
	``run.dir: "."``. ``cfg`` is merged on top of it.
	"""
	hydra_defaults = [
		"_self_",
		{"override hydra/hydra_logging": "disabled"},
		{"override hydra/job_logging": "disabled"},
	]
	hydra_section = {
		"output_subdir": None,
		"run": {"dir": "."},
	}
	preamble = OmegaConf.create(
		{"defaults": hydra_defaults, "hydra": hydra_section}
	)
	merged = OmegaConf.merge(preamble, cfg)
	with open(path, "w") as out_file:
		OmegaConf.save(config=merged, f=out_file)


def _set_personal_conf(cfg: DictConfig) -> DictConfig:
	"""Copy the W&B entity from ``config.personal.yaml`` into ``cfg``.

	The file is opened relative to the current working directory and must
	contain a ``wandb.entity`` entry. ``cfg`` is modified in place and also
	returned.
	"""
	with open("config.personal.yaml", "r") as personal_file:
		personal_settings = yaml.safe_load(personal_file)
	cfg.wandb.entity = personal_settings["wandb"]["entity"]
	return cfg


def _set_seed(cfg: DictConfig) -> None:
	"""Seed ``random``, NumPy and PyTorch with ``cfg.seed``.

	Does nothing when ``cfg.seed`` is ``None``. Otherwise it also sets the
	cuDNN ``deterministic`` flag and clears the ``benchmark`` flag.
	"""
	seed = cfg.seed
	if seed is None:
		return

	for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
		seed_fn(seed)

	torch.backends.cudnn.benchmark = False
	torch.backends.cudnn.deterministic = True


def _get_available_gpus() -> List[int]:
	"""Return the indices ``0..N-1`` of the CUDA devices, or ``[]`` without CUDA."""
	gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
	return list(range(gpu_count))


def _generate_seeds(
	num_seeds: int, base_seed: Optional[int] = None
) -> List[int]:
	"""Draw ``num_seeds`` integer seeds in ``[0, 2**31 - 1)``.

	Args:
		num_seeds: Number of seeds to generate.
		base_seed: When given, the seeds are a deterministic function of it.
			When ``None``, they are drawn from NumPy's global random state.

	Returns:
		A list of ``num_seeds`` Python integers.
	"""
	if base_seed is not None:
		# A private generator yields the same stream as seeding the global
		# one would, without clobbering the process-wide NumPy RNG state.
		rng = np.random.RandomState(base_seed)
	else:
		rng = np.random
	return [int(rng.randint(0, 2**31 - 1)) for _ in range(num_seeds)]


def _run_single_experiment(
	cfg: DictConfig, run_id: int, device: torch.device
) -> Dict[str, Any]:
	"""Run one member of a multi-run experiment.

	Works on a copy of ``cfg``: pins the device, gives the run its own
	experiment name (when ``multi_run.enable`` is set) and its own
	``run_XX`` sub-directory under ``save_dir`` (when one is configured),
	then delegates to ``_run_single``.

	Args:
		cfg: Configuration of the run; not modified.
		run_id: Run number used in the experiment name and save directory.
		device: Device the run is executed on.

	Returns:
		The result dictionary produced by ``_run_single``.
	"""
	# This function runs both in the launching process (sequential mode),
	# where the script entry point has already registered these resolvers,
	# and in worker processes, where it may not have. ``replace=True`` makes
	# the registration idempotent instead of raising on the second call.
	OmegaConf.register_new_resolver(
		"concat", lambda x, y: f"{x}{y}", replace=True
	)
	OmegaConf.register_new_resolver(
		"as_tuple", lambda x, y: (x, y), replace=True
	)

	run_cfg = OmegaConf.create(cfg)
	run_cfg.device = str(device)

	base_exp_name = _get_exp_name(run_cfg)
	if run_cfg.multi_run.enable:
		run_cfg.exp_name = f"{base_exp_name}_run_{run_id:02d}"

	if run_cfg.save_dir is not None:
		run_cfg.save_dir = f"{run_cfg.save_dir}/run_{run_id:02d}"
		os.makedirs(run_cfg.save_dir, exist_ok=True)

	return _run_single(run_cfg)


def _run_single(cfg: DictConfig) -> Dict[str, Any]:
	"""Build every training component from ``cfg`` and run one experiment.

	Steps: seed the RNGs, load the personal W&B entity, create the logger
	and the environments, build the SAC or recurrent-SAC loss module,
	append the primers to both environments, instantiate updater, collector,
	replay buffer, trainer and optimizers, register the hooks, optionally
	restore the latest checkpoint, and finally train.

	Args:
		cfg: Full experiment configuration.

	Returns:
		``{"num_params": ...}`` only when ``cfg.just_count_params`` is set
		(no training is performed); otherwise a dictionary with
		``num_params``, ``seed``, ``device`` and ``exp_name``.

	Raises:
		ValueError: If ``cfg.loss.type`` is neither ``"sac"`` nor
			``"rnn-sac"``, or if ``cfg.load_existing`` is set while
			``cfg.save_dir`` is ``None``.
	"""
	_set_seed(cfg)

	cfg = _set_personal_conf(cfg)

	device = _get_device(cfg)
	logger = _get_logger(cfg)
	env, eval_env = _get_envs(cfg, device)

	# Everything after environment creation runs under try/finally so both
	# environments are closed on every exit path: normal completion, the
	# ``just_count_params`` early return, and any exception.
	try:
		n_obs = (
			env.observation_spec["observation"].shape[0]
			if cfg.loss.n_obs == "auto"
			else cfg.loss.n_obs
		)

		rnns: Optional[Tuple[nn.Module, nn.Module]] = None
		if cfg.loss.type == "sac":
			loss_module, actor = build_sac_module(
				n_obs=n_obs,
				action_spec=env.action_spec.clone(),
				gamma=cfg.loss.gamma,
				hidden_size=cfg.loss.hidden_size,
				num_layers=cfg.loss.num_layers,
			)
		elif cfg.loss.type == "rnn-sac":
			loss_module, actor, rnns = build_rec_sac_module(
				n_obs=n_obs,
				action_spec=env.action_spec.clone(),
				gamma=cfg.loss.gamma,
				rnn_type=cfg.loss.rnn_type,
				depth=cfg.loss.depth,
				hidden_size=cfg.loss.hidden_size,
				min_alpha=cfg.loss.min_alpha,
				max_alpha=cfg.loss.max_alpha,
				pred_signals=cfg.loss.pred_signals,
				signals_keys=cfg.loss.signals_keys,
				signals_dims=cfg.loss.signals_dims,
				dropout=cfg.loss.dropout,
			)
		else:
			raise ValueError(f"Invalid loss type: {cfg.loss.type}")

		# The primers are built once per environment, so the two
		# environments never share a transform instance.
		for primer in _get_primers(cfg, loss_module):
			env = env.append_transform(primer)
		for primer in _get_primers(cfg, loss_module):
			eval_env = eval_env.append_transform(primer)

		num_params = sum(p.numel() for p in loss_module.parameters())
		if cfg.just_count_params:
			print(f"Number of parameters: {format_integer(num_params)}.")
			return {"num_params": num_params}
		if logger is not None:
			logger.log_hparams(
				{
					"num_params_text": format_integer(num_params),
					"num_params": num_params,
				}
			)

		updater: TargetNetUpdater = hutils.instantiate(
			cfg.updater, loss_module=loss_module
		)

		collector: DataCollectorBase = hutils.instantiate(
			cfg.collector,
			policy=actor,
			create_env_fn=env,
			device=device,
		)
		replay_buffer: ReplayBuffer = hutils.instantiate(cfg.buffer)

		trainer: RLTrainer = RLTrainer(
			collector=collector,
			buffer=replay_buffer,
			optim_steps_per_batch=cfg.trainer.optim_steps_per_batch,
			logger=logger,
			progress_bar=cfg.trainer.progress_bar,
			device=device,
			include_additional_keys=cfg.trainer.include_additional_keys,
		)

		# One optimizer per parameter group, all built from the same
		# ``cfg.optimizer`` settings.
		optimizer_qvalue: Optimizer = hutils.instantiate(
			cfg.optimizer,
			params=loss_module.qvalue_network_params.flatten_keys().values(),
		)
		optimizer_actor: Optimizer = hutils.instantiate(
			cfg.optimizer,
			params=loss_module.actor_network_params.flatten_keys().values(),
		)
		optimizer_alpha: Optimizer = hutils.instantiate(
			cfg.optimizer, params=[loss_module.log_alpha]
		)

		sac_hook = SACLossHook(
			loss=loss_module,
			opt_actor=optimizer_actor,
			opt_qvalue=optimizer_qvalue,
			opt_alpha=optimizer_alpha,
			gradient_clipping=cfg.loss.gradient_clipping,
			pred_missing_signals=cfg.loss.pred_signals,
		)
		sac_hook.register(trainer)
		if cfg.loss.type == "rnn-sac":
			# Sampled batches are split into trajectories (without tracking
			# gradients) before they reach the loss.
			trainer.register_hook(
				"sample",
				lambda b: torch.no_grad()(split_trajectories)(
					b, trajectory_key="...", done_key="end_slice"
				),
			)
			if rnns is not None and hasattr(rnns[0], "setup_for_step"):
				trainer.register_hook(
					"start_train", lambda: rnns[0].setup_for_step()
				)
				trainer.register_hook(
					"end_traj", lambda _m, _l: rnns[0].setup_for_step()
				)

		noise_control: Optional[NoiseControlHook] = hutils.instantiate(
			cfg.noise_control
		)
		if noise_control:
			noise_control.register(trainer)

		trainer.register_hook("end_optim_step", updater.step)

		recorder = EvalRolloutHook(
			env=eval_env,
			actor=actor,
			max_frames=cfg.collector.max_frames_per_traj,
			exploration_type=ExplorationType.DETERMINISTIC,  # type: ignore
			log_interval=cfg.trainer.log_interval,
			num_rollouts=cfg.trainer.num_eval_rollouts,
			metrics=cfg.trainer.eval_metrics,
		)
		recorder.register(trainer)

		if cfg.trainer.log_weights:
			log_weights_hook = LogWeightsHook(
				loss_module=loss_module,
				log_interval=cfg.trainer.log_interval,
			)
			log_weights_hook.register(trainer)

		if cfg.save_dir is not None:
			save_hook = SaveHook(
				save_dir=cfg.save_dir,
				save_interval=cfg.save_interval,
				override=False,
			)
			save_hook.register(trainer)
			_save_config(cfg, f"{cfg.save_dir}/config.yaml")

		if cfg.load_existing:
			if cfg.save_dir is None:
				# Without a directory there is nowhere to look for a
				# checkpoint; fail with a clear message.
				raise ValueError(
					"load_existing requires save_dir to be set."
				)
			# ``map_location`` lets a checkpoint written on one device be
			# restored by a run that uses a different one.
			checkpoint = torch.load(
				_get_most_recent_checkpoint(cfg.save_dir),
				map_location=device,
			)
			trainer.load_state_dict(checkpoint)

		trainer.train()
	finally:
		try:
			env.close()
		finally:
			eval_env.close()

	return {
		"num_params": num_params,
		"seed": cfg.seed,
		"device": str(device),
		"exp_name": cfg.exp_name
		if hasattr(cfg, "exp_name")
		else _get_exp_name(cfg),
	}


def _run_sweep_experiment(cfg: DictConfig) -> None:
	"""Run every experiment of a parameter sweep, one after another.

	A failing experiment is recorded with ``success: False`` and its error
	message, and the sweep continues. When ``cfg.sweep.save_results`` is
	set, the sweep settings, a per-experiment save directory and the
	aggregated results are written under ``cfg.sweep.results_dir``.

	Args:
		cfg: Configuration with sweep settings
	"""
	print(
		f"Starting parameter sweep with "
		f"{cfg.sweep.num_experiments} experiments"
	)

	# Each entry is an (experiment config, metadata) pair.
	sweep_plan = generate_sweep_experiment_configs(cfg, cfg.sweep)

	if cfg.sweep.save_results:
		results_dir = cfg.sweep.results_dir
		os.makedirs(results_dir, exist_ok=True)

		# Keep a copy of the sweep settings next to the results.
		with open(os.path.join(results_dir, "sweep_config.yaml"), "w") as f:
			yaml.dump(OmegaConf.to_container(cfg.sweep), f)

	outcomes = []

	for index, (exp_config, metadata) in enumerate(sweep_plan):
		ordinal = index + 1
		print(f"Running sweep experiment {ordinal}/{len(sweep_plan)}")
		print(f"  Sampled params: {metadata['sampled_params']}")

		if cfg.sweep.save_results:
			exp_config.save_dir = os.path.join(
				results_dir, f"exp_{index:03d}"
			)
			exp_config.save_interval = 10000

		try:
			result = _run_single(exp_config)
			record = {**result, **metadata, "success": True}
			if cfg.sweep.get("metrics", {}).get("track_signal_density", False):
				record["signal_density"] = metadata.get(
					"asynchronicity_rate", 0.0
				)

			outcomes.append(record)
			print(f"  Completed experiment {ordinal} successfully")

		except Exception as e:
			# One failed experiment must not abort the rest of the sweep.
			print(f"  Experiment {ordinal} failed with error: {e}")
			outcomes.append({**metadata, "success": False, "error": str(e)})

	if not cfg.sweep.save_results:
		return

	results_path = os.path.join(results_dir, "sweep_results.yaml")
	with open(results_path, "w") as f:
		yaml.dump(outcomes, f)

	print(f"Sweep completed. Results saved to {results_path}")

	succeeded = [r for r in outcomes if r.get("success", False)]
	print(
		f"Successfully completed {len(succeeded)}/"
		f"{len(outcomes)} experiments"
	)


def _cleanup_processes() -> None:
	"""Emergency cleanup: terminate, then kill, all descendant processes.

	Every descendant of the current process is asked to terminate and given
	up to five seconds to exit; whatever is still alive afterwards is
	killed. The function is best-effort: any error is printed, not raised.
	"""
	import psutil

	try:
		current_process = psutil.Process()
		children = current_process.children(recursive=True)
		print(f"Cleaning up {len(children)} child processes...")
		for child in children:
			try:
				child.terminate()
			except (psutil.NoSuchProcess, psutil.AccessDenied):
				pass

		# ``wait_procs`` reports which of the processes are still alive.
		# Killing exactly those also covers survivors whose parent has just
		# exited: they are no longer listed as descendants of this process,
		# so re-querying the children alone would miss them.
		_, survivors = psutil.wait_procs(children, timeout=5)
		to_kill = list(survivors)
		# Also pick up descendants that appeared during the grace period.
		to_kill.extend(
			proc
			for proc in current_process.children(recursive=True)
			if proc not in survivors
		)
		for child in to_kill:
			try:
				child.kill()
			except (psutil.NoSuchProcess, psutil.AccessDenied):
				pass
	except Exception as e:
		print(f"Error during cleanup: {e}")


def _count_child_processes() -> int:
	"""Count the descendants of the current process.

	Returns:
		Number of child processes, found recursively.
	"""
	import psutil

	descendants = psutil.Process().children(recursive=True)
	return len(descendants)


def _run_experiments_sequentially(runs: List, results: List) -> None:
	"""Execute the runs in order, appending each result to ``results``.

	Each element of ``runs`` is a ``(run_cfg, run_id, device)`` triple.
	"""
	total = len(runs)
	for entry in runs:
		run_cfg, run_id, device = entry
		print(f"Starting run {run_id}/{total} on {device}")
		results.append(_run_single_experiment(run_cfg, run_id, device))
		print(f"Completed run {run_id}/{total}")


def _run_experiments_parallel(
	runs: List, results: List, max_concurrent: int, experiment_timeout: int
) -> None:
	"""Run experiments in parallel with proper cleanup.

	Every run is submitted to a ``ProcessPoolExecutor`` and its result (or
	an ``{"error": ...}`` entry) is appended to ``results`` in completion
	order. ``_cleanup_processes`` is registered with ``atexit`` for the
	duration of the call and unregistered again on the way out.

	Args:
		runs: ``(run_cfg, run_id, device)`` triples, one per experiment.
		results: List the outcomes are appended to (mutated in place).
		max_concurrent: Number of worker processes in the pool.
		experiment_timeout: Per-experiment time budget in seconds; the
			whole batch may take ``experiment_timeout * len(runs)``.

	Raises:
		KeyboardInterrupt: If user interrupts execution.
	"""
	import atexit
	from concurrent.futures import (
		ProcessPoolExecutor,
		TimeoutError,
		as_completed,
	)

	# Futures that have been submitted but not yet handled; these are the
	# ones cancelled on a batch timeout or a keyboard interrupt.
	active_futures = set()
	executor = None
	# Safety net in case the interpreter exits while workers are alive.
	atexit.register(_cleanup_processes)

	try:
		executor = ProcessPoolExecutor(max_workers=max_concurrent)
		print(f"Created ProcessPoolExecutor with {max_concurrent} workers")
		print(f"Current child processes: {_count_child_processes()}")

		try:
			# Maps each future back to its run id and device for reporting.
			future_to_run = {}
			for run_cfg, run_id, device in runs:
				future = executor.submit(
					_run_single_experiment, run_cfg, run_id, device
				)
				future_to_run[future] = (run_id, device)
				active_futures.add(future)

			print(f"Submitted {len(future_to_run)} experiments")
			print(f"Current child processes: {_count_child_processes()}")

			# Time budget for the whole batch; when it is exceeded, the
			# iteration below raises TimeoutError (handled further down).
			total_timeout = experiment_timeout * len(runs)
			completed_futures = as_completed(
				future_to_run, timeout=total_timeout
			)

			for future in completed_futures:
				run_id, device = future_to_run[future]
				active_futures.discard(future)
				try:
					# NOTE(review): futures yielded by as_completed should
					# already be finished, so this timeout is not expected
					# to trigger -- confirm whether it can be dropped.
					result = future.result(timeout=300)
					results.append(result)
					print(f"Completed run {run_id}/{len(runs)} on {device}")
					child_count = _count_child_processes()
					print(f"Current child processes: {child_count}")
				except TimeoutError:
					# The message reports ``experiment_timeout`` although
					# the wait above used a fixed 300 seconds.
					timeout_msg = f"Timeout after {experiment_timeout} seconds"
					timeout_print = (
						f"Run {run_id} timed out after "
						f"{experiment_timeout} seconds"
					)
					print(timeout_print)
					future.cancel()
					results.append({"error": timeout_msg})
				except Exception as exc:
					# A failed run is recorded; the other runs continue.
					print(f"Run {run_id} generated an exception: {exc}")
					results.append({"error": str(exc)})

		except TimeoutError:
			# Raised by the as_completed iteration once ``total_timeout``
			# has elapsed. No entry is appended to ``results`` for the
			# runs that are still outstanding.
			batch_timeout_msg = (
				"Overall experiment batch timed out. "
				"Cancelling remaining runs..."
			)
			print(batch_timeout_msg)
			# NOTE(review): Future.cancel() presumably only prevents runs
			# that have not started yet; already-running workers look like
			# they keep going -- confirm against concurrent.futures docs.
			for future in active_futures:
				future.cancel()
		except KeyboardInterrupt:
			interrupt_msg = (
				"Received keyboard interrupt. "
				"Cancelling all running experiments..."
			)
			print(interrupt_msg)
			for future in active_futures:
				future.cancel()
			raise
		finally:
			if executor:
				print("Shutting down executor...")
				# NOTE(review): with wait=True this presumably blocks until
				# running workers finish, even after a batch timeout --
				# verify that this is the intended behaviour.
				executor.shutdown(wait=True, cancel_futures=True)
				executor = None
				child_count = _count_child_processes()
				shutdown_msg = (
					f"Executor shut down. "
					f"Current child processes: {child_count}"
				)
				print(shutdown_msg)

	finally:
		# The atexit safety net is only needed while the pool can exist.
		try:
			atexit.unregister(_cleanup_processes)
		except ValueError:
			pass


def _run_multi_experiment(cfg: DictConfig) -> None:
	"""Run ``cfg.multi_run.num_runs`` copies of the experiment with different seeds.

	Seeds come from ``cfg.multi_run.seeds`` or are generated from
	``cfg.seed``. Runs are assigned round-robin to the configured GPUs, else
	the detected GPUs, else the CPU. With a concurrency of 1 the runs are
	executed sequentially, otherwise in parallel. Child processes that are
	still around afterwards are cleaned up.

	Raises:
		ValueError: If the number of provided seeds differs from
			``cfg.multi_run.num_runs``.
	"""
	print(f"Starting multi-run experiment with {cfg.multi_run.num_runs} runs")

	baseline_children = _count_child_processes()
	print(f"Initial child processes: {baseline_children}")

	# --- seeds ---------------------------------------------------------
	if cfg.multi_run.seeds is None:
		seeds = _generate_seeds(cfg.multi_run.num_runs, cfg.seed)
	else:
		seeds = list(cfg.multi_run.seeds)
		if len(seeds) != cfg.multi_run.num_runs:
			raise ValueError(
				f"Number of provided seeds ({len(seeds)}) doesn't match "
				f"num_runs ({cfg.multi_run.num_runs})"
			)

	# --- devices -------------------------------------------------------
	detected_gpus = _get_available_gpus()
	if cfg.multi_run.gpus is not None:
		devices = [torch.device(f"cuda:{gpu}") for gpu in cfg.multi_run.gpus]
	elif detected_gpus:
		devices = [torch.device(f"cuda:{gpu}") for gpu in detected_gpus]
	else:
		devices = [torch.device("cpu")]

	# Default concurrency: one run per device.
	max_concurrent = (
		len(devices)
		if cfg.multi_run.max_concurrent is None
		else cfg.multi_run.max_concurrent
	)

	print(f"Using devices: {[str(d) for d in devices]}")
	print(f"Max concurrent runs: {max_concurrent}")
	print(f"Seeds: {seeds}")

	# --- run descriptions: (config copy, 1-based run id, device) ---------
	runs = []
	for run_id, seed in enumerate(seeds, start=1):
		seeded_cfg = OmegaConf.create(cfg)
		seeded_cfg.seed = seed
		assigned_device = devices[(run_id - 1) % len(devices)]
		runs.append((seeded_cfg, run_id, assigned_device))

	results = []

	try:
		if max_concurrent == 1:
			_run_experiments_sequentially(runs, results)
		else:
			_run_experiments_parallel(
				runs,
				results,
				max_concurrent,
				cfg.multi_run.get("experiment_timeout", 259200),
			)

	finally:
		# More children than at the start means something was left behind.
		remaining_children = _count_child_processes()
		if remaining_children > baseline_children:
			leaked_count = remaining_children - baseline_children
			print(f"Warning: {leaked_count} processes may have leaked")
			_cleanup_processes()

	print(f"Multi-run experiment completed. Results: {results}")
	print(f"Final child processes: {_count_child_processes()}")


@hydra.main(version_base="1.3", config_path="conf", config_name="base")
def _run(cfg: DictConfig) -> None:
	"""Hydra entry point: dispatch to sweep, multi-run or single-run mode.

	A sweep takes precedence when ``cfg`` has a ``sweep`` section that is
	enabled; otherwise ``cfg.multi_run.enable`` selects multi-run mode.
	"""
	sweep_enabled = hasattr(cfg, "sweep") and cfg.sweep.enable
	if sweep_enabled:
		_run_sweep_experiment(cfg)
		return
	if cfg.multi_run.enable:
		_run_multi_experiment(cfg)
		return
	_run_single(cfg)


if __name__ == "__main__":
	# Register the custom OmegaConf resolvers before handing control to the
	# Hydra entry point; the two registrations are independent of each other.
	OmegaConf.register_new_resolver("as_tuple", lambda x, y: (x, y))
	OmegaConf.register_new_resolver("concat", lambda x, y: f"{x}{y}")
	_run()
