import os
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from benchrl.utils._functions import set_device


class BaseAlgorithm(ABC):
    """Base class for all reinforcement learning algorithms.

    Compatible with CleanRL patterns while providing modular structure for BenchRL v2.
    Designed for simplicity, performance, and research-friendly extensibility.

    Subclasses implement ``train_step``, ``save`` and ``load``. ``evaluate`` also
    calls ``self.get_action(obs, deterministic=...)`` (expected to return a tensor of
    actions), which this base class does not define; subclasses that use
    ``evaluate`` must provide it.
    """

    def __init__(
        self,
        env,
        algo_config: Dict[str, Any],
        device: str = "auto",
        writer: Optional["SummaryWriter"] = None,
    ):
        """Initialize base algorithm.

        Args:
            env: Vectorized environment exposing ``num_envs``,
                ``single_observation_space`` and ``single_action_space``.
            algo_config: Algorithm configuration dictionary (stored as-is).
            device: Device for computation ('auto', 'cpu', 'cuda'), resolved
                through ``set_device``.
            writer: Optional tensorboard writer used by ``log_metrics``.
        """
        self.env = env
        self.algo_config = algo_config
        self.device = set_device(device)
        self.writer = writer

        # Training state
        self.global_step = 0
        # Not incremented by this class; only read by print_progress.
        self.episode_count = 0
        # Returns of finished episodes. Not appended to by this class; the last
        # 100 entries are averaged by print_progress and save_checkpoint.
        self.episode_returns = []
        # Reference time for get_sps().
        self.start_time = time.time()

        # Environment info
        self.num_envs = env.num_envs
        self.obs_shape = env.single_observation_space.shape
        self.action_space = env.single_action_space

        # Performance tracking: (score, path) pairs of checkpoints kept on disk.
        self.checkpoint_performances = []
        self.last_checkpoint_step = 0

        # Observation dimensions: a 1-D shape is unwrapped to a plain int, any
        # other shape stays a tuple; flat_in_dim is the total element count.
        self.in_dim = self.obs_shape
        if len(self.in_dim) == 1:
            self.in_dim = self.in_dim[0]
        self.flat_in_dim = int(np.prod(self.obs_shape))

        # Action dimension: ``n`` for spaces that define it (e.g. discrete
        # spaces), otherwise the flattened size of the action shape.
        if hasattr(self.action_space, "n"):
            self.action_dim = int(self.action_space.n)
        else:
            self.action_dim = int(np.prod(self.action_space.shape))

        self.out_dim = self.action_dim

    @abstractmethod
    def train_step(self) -> Dict[str, float]:
        """Execute one training step.

        Returns:
            Dictionary of training metrics
        """
        pass

    @abstractmethod
    def save(self, path: str) -> None:
        """Save algorithm state.

        Args:
            path: Path to save checkpoint
        """
        pass

    @abstractmethod
    def load(self, path: str) -> None:
        """Load algorithm state.

        Args:
            path: Path to load checkpoint from
        """
        pass

    @staticmethod
    def _finished_episode_stats(infos) -> Optional[list]:
        """Extract ``(return, length)`` pairs of episodes reported as finished.

        Statistics are looked up first in ``infos["final_info"]`` and then in
        ``infos`` itself. In either place ``"episode"`` must hold per-env ``"r"``
        and ``"l"`` arrays; a sibling ``"_episode"`` boolean array, when present,
        selects which sub-environments actually finished (without it every entry
        is used).

        Returns:
            A (possibly empty) list of pairs when statistics were found, or
            ``None`` when this step reports no episode statistics.
        """
        if not isinstance(infos, dict):
            return None
        for source in (infos.get("final_info"), infos):
            if not isinstance(source, dict) or "episode" not in source:
                continue
            eps_info = source["episode"]
            ep_returns = np.asarray(eps_info["r"]).reshape(-1)
            ep_lengths = np.asarray(eps_info["l"]).reshape(-1)
            mask = source.get("_episode")
            if mask is None:
                mask = np.ones(ep_returns.shape[0], dtype=bool)
            else:
                mask = np.asarray(mask, dtype=bool).reshape(-1)
            return [
                (ret, length)
                for ret, length, keep in zip(ep_returns, ep_lengths, mask)
                if keep
            ]
        return None

    def evaluate(
        self,
        eval_env=None,
        num_episodes: int = 100,
        deterministic: bool = True,
        render: bool = False,
        seed: Optional[int] = None,
    ) -> Dict[str, float]:
        """Evaluate the current policy.

        Steps the policy until ``num_episodes`` episodes have finished. Episode
        returns/lengths come from the episode statistics reported in ``infos``
        (``infos["final_info"]["episode"]`` or ``infos["episode"]``, filtered by
        the ``_episode`` mask when present). On steps that report no statistics,
        returns and lengths accumulated here from raw rewards are recorded for
        every sub-environment whose ``terminated``/``truncated`` flag is set, so
        evaluation also terminates without an episode-statistics wrapper.

        Note: with ``eval_env=None`` the training env is used and reset, so any
        observation the training loop holds for it is stale afterwards.

        Args:
            eval_env: Environment for evaluation (uses training env if None)
            num_episodes: Number of episodes to evaluate
            deterministic: Whether to use deterministic policy
            render: Whether to render episodes (best-effort; errors are ignored)
            seed: Random seed for evaluation

        Returns:
            Dictionary with mean/std/median/min/max return, mean length and the
            number of episodes. If no episode was collected (``num_episodes <= 0``)
            the statistics are NaN and ``num_episodes`` is 0.
        """
        if eval_env is None:
            eval_env = self.env

        episode_returns = []
        episode_lengths = []

        obs, _ = eval_env.reset(seed=seed)
        obs = torch.tensor(obs, dtype=torch.float32, device=self.device)

        # Fallback accumulators, used only on steps that report no episode statistics.
        current_returns = np.zeros(eval_env.num_envs)
        current_lengths = np.zeros(eval_env.num_envs)

        with torch.no_grad():
            while len(episode_returns) < num_episodes:
                action = self.get_action(obs, deterministic=deterministic)
                obs, reward, terminations, truncations, infos = eval_env.step(action.cpu().numpy())
                obs = torch.tensor(obs, dtype=torch.float32, device=self.device)

                current_returns += reward
                current_lengths += 1
                dones = np.logical_or(terminations, truncations)

                finished = self._finished_episode_stats(infos)
                if finished is None:
                    finished = [
                        (current_returns[i], current_lengths[i])
                        for i in np.flatnonzero(dones)
                    ]

                for ep_return, ep_length in finished:
                    if len(episode_returns) >= num_episodes:
                        break
                    episode_returns.append(float(ep_return))
                    episode_lengths.append(float(ep_length))

                # Restart the fallback accumulators of sub-envs whose episode ended.
                current_returns[dones] = 0.0
                current_lengths[dones] = 0.0

                if render:
                    try:
                        eval_env.render()
                    except Exception:
                        # Rendering is best-effort; a failure must not abort evaluation.
                        pass

        returns = np.asarray(episode_returns, dtype=np.float64)
        lengths = np.asarray(episode_lengths, dtype=np.float64)

        if returns.size == 0:
            nan = float("nan")
            return {
                'mean_return': nan,
                'std_return': nan,
                'median_return': nan,
                'min_return': nan,
                'max_return': nan,
                'mean_length': nan,
                'num_episodes': 0
            }

        metrics = {
            'mean_return': np.mean(returns),
            'std_return': np.std(returns),
            'median_return': np.median(returns),
            'min_return': np.min(returns),
            'max_return': np.max(returns),
            'mean_length': np.mean(lengths),
            'num_episodes': len(returns)
        }

        return metrics

    def save_checkpoint(self, checkpoint_dir: str, max_checkpoints: int = 5) -> None:
        """Save a checkpoint and keep only the best-scoring ones on disk.

        The score is the mean of the last 100 entries of ``self.episode_returns``;
        nothing is saved while that list is empty. After saving, the lowest-scoring
        checkpoints (possibly the one just written) are deleted until at most
        ``max_checkpoints`` remain.

        Args:
            checkpoint_dir: Directory to save checkpoints (created if missing)
            max_checkpoints: Maximum number of checkpoints to keep (negative
                values are treated as 0)
        """
        if not self.episode_returns:
            return

        current_performance = np.mean(self.episode_returns[-100:])
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        checkpoint_name = f"checkpoint_step_{self.global_step}_{timestamp}.pth"
        checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)

        os.makedirs(checkpoint_dir, exist_ok=True)
        self.save(checkpoint_path)

        # Saving twice at the same step within one second reuses the path and
        # overwrites the file; drop the stale entry so pruning it later cannot
        # delete the file the new entry refers to.
        self.checkpoint_performances[:] = [
            entry for entry in self.checkpoint_performances
            if entry[1] != checkpoint_path
        ]
        self.checkpoint_performances.append((current_performance, checkpoint_path))

        keep = max(max_checkpoints, 0)
        if len(self.checkpoint_performances) > keep:
            # Best first, so the worst checkpoints are popped from the end.
            self.checkpoint_performances.sort(key=lambda x: x[0], reverse=True)
            while len(self.checkpoint_performances) > keep:
                _, path_to_remove = self.checkpoint_performances.pop()
                if os.path.exists(path_to_remove):
                    os.remove(path_to_remove)

        print(f"Checkpoint saved: {checkpoint_path} (performance: {current_performance:.2f})")
        self.last_checkpoint_step = self.global_step

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """Log metrics to tensorboard (no-op when no writer is configured).

        Args:
            metrics: Dictionary of metrics to log
            step: Step number (uses global_step if None; an explicit 0 is honored)
        """
        if self.writer is None:
            return

        if step is None:
            step = self.global_step
        for key, value in metrics.items():
            self.writer.add_scalar(key, value, step)

    def get_sps(self) -> int:
        """Get steps per second since ``start_time``, truncated to an int.

        Returns 0 when no (or negative) wall-clock time has elapsed, instead of
        dividing by zero.
        """
        elapsed = time.time() - self.start_time
        if elapsed <= 0:
            return 0
        return int(self.global_step / elapsed)

    def print_progress(self) -> None:
        """Print step count, episode count, SPS and mean return of the last 100 episodes."""
        mean_return = np.mean(self.episode_returns[-100:]) if self.episode_returns else 0.0
        print(f"Step: {self.global_step} | "
              f"Episodes: {self.episode_count} | "
              f"SPS: {self.get_sps()} | "
              f"Mean Return: {mean_return:.2f}")