import os
import time
from typing import Dict, Any, Optional
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import tqdm

from benchrl.algorithms.base import BaseAlgorithm


class Trainer:
    """Main training controller for BenchRL v2.

    Drives the algorithm's ``train_step`` loop until its ``global_step``
    reaches ``total_timesteps`` and takes care of metric logging, periodic
    evaluation, periodic checkpointing and the end-of-run summary. Follows
    the class diagram structure with clean separation of concerns.

    The algorithm object is used through these members: ``global_step``,
    ``last_checkpoint_step``, ``episode_count``, ``episode_returns``,
    ``device``, ``train_step()``, ``get_sps()``, ``evaluate(...)``,
    ``save_checkpoint(...)``, ``save(...)`` and ``collect_rollouts()``.
    """

    def __init__(
        self,
        algorithm: BaseAlgorithm,
        env,
        config: Dict[str, Any],
        writer: Optional[SummaryWriter] = None,
        eval_env=None
    ):
        """Initialize trainer.

        Creates ``checkpoint_dir`` on disk if it does not exist yet.

        Args:
            algorithm: RL algorithm instance
            env: Training environment
            config: Configuration dictionary. Recognised keys (with the
                default used when a key is absent): ``total_timesteps``
                (50000), ``checkpoint_interval`` (10000), ``eval_interval``
                (None, i.e. no periodic evaluation), ``max_checkpoints`` (5),
                ``checkpoint_dir`` ('checkpoints'), ``save_final`` (True),
                ``eval_episodes`` (100), ``eval_deterministic`` (True).
            writer: Tensorboard writer; when None nothing is logged
            eval_env: Optional evaluation environment; the training
                environment is used for evaluation when this is not given
        """
        self.algorithm = algorithm
        self.env = env
        self.config = config
        self.writer = writer
        self.eval_env = eval_env

        # Training configuration
        self.total_timesteps = config.get('total_timesteps', 50000)
        self.checkpoint_interval = config.get('checkpoint_interval', 10000)
        self.eval_interval = config.get('eval_interval', None)
        self.max_checkpoints = config.get('max_checkpoints', 5)
        self.checkpoint_dir = config.get('checkpoint_dir', 'checkpoints')
        self.save_final_checkpoint = config.get('save_final', True)

        # Evaluation configuration
        self.eval_episodes = config.get('eval_episodes', 100)
        self.eval_deterministic = config.get('eval_deterministic', True)

        # Create checkpoint directory
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        # Training state. The clock starts at construction time, so the
        # "Training time" in the summary is measured from here.
        self.start_time = time.time()

    def train(self, seed: Optional[int] = None) -> None:
        """Main training loop.

        Repeatedly calls ``algorithm.train_step()`` until
        ``algorithm.global_step`` reaches ``total_timesteps``. After each
        step the returned metrics are logged, the progress bar is updated,
        and evaluation / checkpointing run when they are due. When the loop
        ends, a final evaluation (only if an ``eval_env`` was given), an
        optional final checkpoint and a summary follow. The writer, if any,
        is closed on exit, including when training raises.

        Args:
            seed: Random seed for training. Accepted for interface
                compatibility; this method does not use it.
        """
        print(f"Starting training for {self.total_timesteps} timesteps")
        print(f"Algorithm: {self.algorithm.__class__.__name__}")
        print(f"Environment: {self._env_label()}")
        print(f"Device: {self.algorithm.device}")
        print("-" * 50)

        try:
            # Training loop
            with tqdm.tqdm(total=self.total_timesteps, desc="Training") as pbar:
                while self.algorithm.global_step < self.total_timesteps:
                    # Remember where this step started so that an evaluation
                    # boundary jumped over by a multi-step update is noticed.
                    step_before = self.algorithm.global_step

                    # Execute one training step
                    metrics = self.algorithm.train_step()

                    # Log metrics
                    self._log_metrics(metrics)

                    # Advance the bar by however far global_step moved
                    pbar.update(self.algorithm.global_step - pbar.n)

                    # Refresh the postfix only when episodes were reported
                    if metrics.get('rollout/episodes', 0) != 0:
                        mean_return = np.mean(metrics.get('rollout/episodic_return', 0)).item()
                        pbar.set_postfix({
                            'Episodes': self.algorithm.episode_count,
                            'SPS': self.algorithm.get_sps(),
                            'Mean Return': f"{mean_return:.2f}",
                            'critic_loss': f"{metrics.get('sac/critic_loss', 0):.2f}",
                        })

                    # Periodic evaluation
                    if self._eval_due(step_before):
                        self._evaluate_policy()

                    # Periodic checkpointing (disabled when max_checkpoints <= 0)
                    if (self.algorithm.global_step - self.algorithm.last_checkpoint_step
                            >= self.checkpoint_interval and self.max_checkpoints > 0):
                        self.algorithm.save_checkpoint(self.checkpoint_dir, self.max_checkpoints)

            # Final evaluation and checkpoint
            if self.eval_env is not None:
                print("\nRunning final evaluation...")
                self._evaluate_policy()

            if self.save_final_checkpoint:
                print("Saving final checkpoint...")
                # Keep at least one slot so the final checkpoint is retained
                # even when periodic checkpointing is disabled.
                self.algorithm.save_checkpoint(self.checkpoint_dir, max(1, self.max_checkpoints))

            # Training summary
            self._print_training_summary()
        finally:
            # Close the writer even if training failed or was interrupted.
            if self.writer is not None:
                self.writer.close()

    def _env_label(self) -> str:
        """Return a printable id for the training environment.

        Looks at ``env.envs[0].spec.id`` when the environment exposes an
        ``envs`` sequence and at ``env.spec.id`` otherwise. Any missing
        piece (no ``envs``, no ``spec``, ``spec`` is None, no ``id``)
        yields ``'Unknown'`` instead of raising.
        """
        try:
            probe = self.env.envs[0]
        except (AttributeError, IndexError, KeyError, TypeError):
            probe = self.env

        env_id = getattr(getattr(probe, 'spec', None), 'id', None)
        return 'Unknown' if env_id is None else str(env_id)

    def _eval_due(self, previous_step: int) -> bool:
        """Tell whether the latest train step reached an evaluation boundary.

        A boundary is a positive multiple of ``eval_interval``. Comparing
        the interval index before and after the step catches boundaries
        that were stepped over (``global_step`` growing by more than one per
        train step) and does not fire again while ``global_step`` stands
        still on a multiple.

        Args:
            previous_step: Value of ``algorithm.global_step`` before the
                train step that just ran.

        Returns:
            True if an evaluation should run now. Always False when
            ``eval_interval`` is None or not positive.
        """
        interval = self.eval_interval
        if interval is None or interval <= 0:
            return False
        return self.algorithm.global_step // interval > previous_step // interval

    def _log_metrics(self, metrics: Dict[str, float]) -> None:
        """Log training metrics.

        Does nothing when no writer is configured. List-valued metrics are
        written element by element under the same key and step.

        Args:
            metrics: Dictionary of metrics to log
        """
        if self.writer is None:
            return

        step = self.algorithm.global_step
        for key, value in metrics.items():
            if isinstance(value, list):
                for v in value:
                    self.writer.add_scalar(key, v, step)
            else:
                self.writer.add_scalar(key, value, step)

        # Log additional system metrics
        self.writer.add_scalar("system/sps", self.algorithm.get_sps(), step)
        self.writer.add_scalar("system/episodes", self.algorithm.episode_count, step)

    def _evaluate_policy(self) -> Dict[str, float]:
        """Evaluate the current policy and log the result.

        Uses ``eval_env`` when one was provided and the training environment
        otherwise. Every returned metric is written to the writer (if any)
        under ``eval/<key>`` at the current global step.

        Returns:
            Dictionary of evaluation metrics; it must contain
            ``'mean_return'``, which is printed.
        """
        eval_env = self.eval_env or self.env

        print(f"\nEvaluating policy at step {self.algorithm.global_step}...")
        eval_metrics = self.algorithm.evaluate(
            eval_env=eval_env,
            num_episodes=self.eval_episodes,
            deterministic=self.eval_deterministic
        )

        # Log evaluation metrics
        if self.writer is not None:
            step = self.algorithm.global_step
            for key, value in eval_metrics.items():
                self.writer.add_scalar(f"eval/{key}", value, step)

        print(f"Evaluation complete - Mean return: {eval_metrics['mean_return']:.2f}")
        return eval_metrics

    def _print_training_summary(self) -> None:
        """Print training summary.

        Reports step and episode counts, wall-clock time since the trainer
        was constructed, the algorithm's SPS figure and, when any episode
        returns were recorded, the mean of the last 100 of them.
        """
        total_time = time.time() - self.start_time

        print("\n" + "="*50)
        print("TRAINING SUMMARY")
        print("="*50)
        print(f"Total timesteps: {self.algorithm.global_step}")
        print(f"Total episodes: {self.algorithm.episode_count}")
        print(f"Training time: {total_time:.2f}s")
        print(f"Average SPS: {self.algorithm.get_sps()}")

        if self.algorithm.episode_returns:
            recent_returns = self.algorithm.episode_returns[-100:]
            print(f"Final mean return (last 100 episodes): {sum(recent_returns) / len(recent_returns):.2f}")

        print("="*50)

    def collect_rollouts(self) -> Dict[str, float]:
        """Collect rollouts for the algorithm.

        Delegates to the algorithm's rollout collection method.
        This method exists for compatibility with the class diagram.

        Returns:
            Dictionary of rollout metrics
        """
        return self.algorithm.collect_rollouts()

    def checkpoint(self, path: Optional[str] = None) -> None:
        """Save a checkpoint.

        Args:
            path: Optional custom checkpoint path. When None, a checkpoint
                is saved into ``checkpoint_dir`` via
                ``algorithm.save_checkpoint``; otherwise ``algorithm.save``
                is called with ``path``.
        """
        if path is None:
            self.algorithm.save_checkpoint(self.checkpoint_dir, self.max_checkpoints)
        else:
            self.algorithm.save(path)

    def evaluate_policy(
        self,
        num_episodes: Optional[int] = None,
        deterministic: Optional[bool] = None
    ) -> Dict[str, float]:
        """Evaluate the current policy.

        Unlike the internal periodic evaluation, this neither prints nor
        writes anything to the writer.

        Args:
            num_episodes: Number of episodes to evaluate (uses config default
                if None; a value of 0 also falls back to the default)
            deterministic: Whether to use deterministic policy (uses config default if None)

        Returns:
            Dictionary of evaluation metrics
        """
        num_episodes = num_episodes or self.eval_episodes
        deterministic = deterministic if deterministic is not None else self.eval_deterministic

        return self.algorithm.evaluate(
            eval_env=self.eval_env or self.env,
            num_episodes=num_episodes,
            deterministic=deterministic
        )


class TrainerBuilder:
    """Fluent builder that assembles a Trainer step by step.

    Each ``with_*`` method stores one component and returns the builder
    itself so calls can be chained; ``build`` validates the required
    components and constructs the Trainer.
    """

    def __init__(self):
        # Nothing is configured yet; the config starts out as an empty dict.
        self.algorithm = self.env = None
        self.config = {}
        self.writer = self.eval_env = None

    def with_algorithm(self, algorithm: BaseAlgorithm) -> 'TrainerBuilder':
        """Store the algorithm to train.

        Args:
            algorithm: RL algorithm instance

        Returns:
            This builder, to allow chaining
        """
        self.algorithm = algorithm
        return self

    def with_env(self, env) -> 'TrainerBuilder':
        """Store the training environment.

        Args:
            env: Training environment

        Returns:
            This builder, to allow chaining
        """
        self.env = env
        return self

    def with_config(self, config: Dict[str, Any]) -> 'TrainerBuilder':
        """Store the configuration dictionary (kept by reference, not copied).

        Args:
            config: Configuration dictionary

        Returns:
            This builder, to allow chaining
        """
        self.config = config
        return self

    def with_writer(self, writer: SummaryWriter) -> 'TrainerBuilder':
        """Store the tensorboard writer.

        Args:
            writer: Tensorboard writer

        Returns:
            This builder, to allow chaining
        """
        self.writer = writer
        return self

    def with_eval_env(self, eval_env) -> 'TrainerBuilder':
        """Store the evaluation environment.

        Args:
            eval_env: Evaluation environment

        Returns:
            This builder, to allow chaining
        """
        self.eval_env = eval_env
        return self

    def build(self) -> Trainer:
        """Create the Trainer from the components collected so far.

        Returns:
            Configured Trainer instance

        Raises:
            ValueError: If the algorithm or the environment has not been set
        """
        # Checked in this order so a missing algorithm is reported first.
        required = (
            ("Algorithm", self.algorithm),
            ("Environment", self.env),
        )
        for label, component in required:
            if component is None:
                raise ValueError(f"{label} is required")

        trainer_kwargs = {
            'algorithm': self.algorithm,
            'env': self.env,
            'config': self.config,
            'writer': self.writer,
            'eval_env': self.eval_env,
        }
        return Trainer(**trainer_kwargs)