import os
import re
from datetime import datetime

import hydra
from hydra.utils import instantiate
import numpy as np
from omegaconf import DictConfig, OmegaConf, ListConfig
from torch.utils.tensorboard import SummaryWriter

from benchrl.environments.registry import get_env_builder
from benchrl.trainers import Trainer
from benchrl.utils._functions import build_agent

def run_experiment(agent, env, config):
    """Run training using the new Trainer architecture.

    Sets up the run directory, optional wandb tracking and a TensorBoard
    writer, then delegates the training loop to ``Trainer``.

    Args:
        agent: Algorithm instance to train. Its ``writer`` attribute is
            replaced with the TensorBoard writer created here, and
            ``agent.load`` is called when resuming from a checkpoint.
        env: Environment handed to the ``Trainer`` unchanged.
        config: OmegaConf config. Reads ``experiment`` (``load_path``,
            ``seed``, ``track``, ``checkpoint_interval``,
            ``max_checkpoints`` and the optional ``num_eval_episodes`` /
            ``eval_interval``), ``environment.env_id`` and ``algorithm``
            (``name``, ``total_timesteps``).
    """
    timestamp = datetime.now().strftime("%Y:%m:%d_%H:%M:%S")

    # Computed before branching: the wandb group below needs it even when
    # resuming from a checkpoint (previously a NameError in that case).
    env_id = re.sub("/", '-', config.environment.env_id)

    if config.experiment.load_path is not None:
        path_model = config.experiment.load_path
        print(f"Loading model from {path_model}")
        agent.load(path_model)
        # Reuse the first two components of the checkpoint path as the run
        # root. NOTE(review): assumes a relative path with at least two
        # components (e.g. runs/<run_name>/...) — confirm with callers.
        parts = path_model.split(os.sep)
        run_name = os.path.join(parts[0], parts[1], "retraining_checkpoint")
        base_dir = run_name
        print(f"Using base directory for logging {base_dir}")
    else:
        run_name = f"{env_id}_{config.algorithm.name}_seed{config.experiment.seed}_{timestamp}"
        base_dir = os.path.join("runs", run_name)

    os.makedirs(base_dir, exist_ok=True)

    # Initialize wandb if tracking is enabled
    if config.experiment.track:
        dict_config = OmegaConf.to_container(config)
        # Imported lazily so wandb/dotenv are only required when tracking.
        import wandb
        from dotenv import load_dotenv
        load_dotenv()
        WANDB_API_KEY = os.getenv('WANDB_API_KEY')
        WANDB_PROJECT = os.getenv('WANDB_PROJECT')
        WANDB_ENTITY = os.getenv('WANDB_ENTITY')
        wandb.login(key=WANDB_API_KEY)
        wandb.init(
            project=WANDB_PROJECT,
            entity=WANDB_ENTITY,
            group=f"{config.algorithm.name}_{env_id}",
            sync_tensorboard=True,
            config=dict_config,
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )

    # Create tensorboard writer in the same directory that holds the
    # checkpoints, so a resumed run does not log to a nested "runs/" path.
    writer = SummaryWriter(base_dir)
    # Iterate the config's own keys; vars() on a DictConfig would expose its
    # internal attributes instead of the hyperparameters.
    hyperparameter_rows = "\n".join(
        f"|{key}|{value}|" for key, value in config.algorithm.items()
    )
    writer.add_text(
        "hyperparameters",
        f"|param|value|\n|-|-|\n{hyperparameter_rows}",
    )

    # Update agent's writer
    agent.writer = writer

    # Create trainer configuration
    trainer_config = {
        'total_timesteps': config.algorithm.total_timesteps,
        'checkpoint_interval': config.experiment.checkpoint_interval,
        'max_checkpoints': config.experiment.max_checkpoints,
        'checkpoint_dir': os.path.join(base_dir, "checkpoints"),
        'save_final': True,
        'eval_episodes': getattr(config.experiment, 'num_eval_episodes', 100),
        'eval_deterministic': True,
        'eval_interval': getattr(config.experiment, 'eval_interval', None),
    }

    # Create and run trainer
    trainer = Trainer(
        algorithm=agent,
        env=env,
        config=trainer_config,
        writer=writer
    )

    trainer.train(seed=config.experiment.seed)


@hydra.main(config_path="../configs", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    """Build the environment and agent, then either evaluate or train.

    If ``cfg.experiment.evaluate`` is truthy, the agent is loaded from
    ``cfg.experiment.load_path`` and evaluated on the environment;
    otherwise training is delegated to ``run_experiment``.

    Raises:
        ValueError: If evaluation is requested without a load path.
    """
    print(f"======= BenchRL v2 Experiment Parameters =======\n{OmegaConf.to_yaml(cfg)}")

    # Build environment
    builder = get_env_builder()
    experiment = cfg.experiment
    env = builder.build_env(
        seed=experiment.seed,
        capture_video=experiment.capture_video,
        video_folder=experiment.video_folder,
        **cfg.environment
    )

    # Build agent
    agent = build_agent(env, cfg)

    # Training path: hand everything over and stop here.
    if not experiment.evaluate:
        run_experiment(agent=agent, env=env, config=cfg)
        return

    # Evaluation path: a checkpoint is mandatory.
    checkpoint_path = experiment.load_path
    if checkpoint_path is None:
        raise ValueError("Load path must be specified for evaluation.")
    agent.load(checkpoint_path)
    eval_metrics = agent.evaluate(env, num_episodes=experiment.num_eval_episodes)
    print(f"Evaluation results: {eval_metrics}")


# Script entry point. main() is called without arguments here; the
# @hydra.main decorator on main is expected to supply ``cfg``.
if __name__ == "__main__":
    main()