import logging
import os
import shutil
import json
import torch
import random
import numpy as np
import time
import pickle
from typing import Optional, Dict, Any
from datetime import datetime
from collections import defaultdict

def get_logger(name: str) -> logging.Logger:
    """
    Return the logger called *name*, attaching a console handler on first use.

    If the logger has no handlers yet, a ``StreamHandler`` with a
    timestamped format is attached and the level is set to INFO. A logger
    that already has handlers is returned untouched, so repeated calls never
    stack duplicate handlers.

    Args:
        name (str): Name of the logger

    Returns:
        logging.Logger: Configured logger instance
    """
    logger = logging.getLogger(name)
    if logger.handlers:
        # Already configured (by us or by someone else): leave it alone.
        return logger

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    logger.addHandler(stream_handler)
    logger.setLevel(logging.INFO)
    return logger


def create_timestamped_experiment_dir(base_output_dir: str = "output", 
                                      experiment_name: str = "llama_training") -> str:
    """
    Build the path of a new timestamped experiment directory.

    The path has the form ``<base_output_dir>/<experiment_name>_<YYYYmmdd_HHMMSS>``,
    using the current local time. Despite the function name, the directory is
    NOT created on disk here; whoever writes into it is responsible for
    creating it. The timestamp has one-second resolution, so two calls within
    the same second return the same path.

    Args:
        base_output_dir (str): Base output directory name
        experiment_name (str): Name of the experiment (used as the prefix of
            the directory name)

    Returns:
        str: Path of the timestamped experiment directory (not yet created)
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return os.path.join(base_output_dir, f"{experiment_name}_{timestamp}")


def save_current_src(save_path: str, 
                     project_root: Optional[str] = None,) -> None:
    """
    Save the current source code to specified path for experiment recording and version management.

    Copies the project's ``src/``, ``scripts/`` and ``config/`` directories and a
    top-level ``configurator.py`` (each only if present) into
    ``<save_path>/src_backup/``. Python bytecode caches (``__pycache__``,
    ``*.pyc``) are skipped. The backup is best-effort: permission problems and
    any other error are logged and swallowed, so a failed backup never aborts
    the caller.

    Args:
        save_path (str): Directory under which ``src_backup/`` is created
        project_root (str, optional): Root path of the project. If None, it is
            taken to be two directory levels above this file's directory
            (i.e. this file is assumed to live in ``<root>/src/utils/``).

    Returns:
        None
    """
    logger = get_logger(__name__)
    logger.info("Saving the current source code")

    if project_root is None:
        # Auto-detect project root (assuming this file is in src/utils/)
        current_file_dir = os.path.dirname(os.path.abspath(__file__))
        project_root = os.path.dirname(os.path.dirname(current_file_dir))

    # Bytecode caches are regenerated on import and only bloat the backup.
    skip_caches = shutil.ignore_patterns("__pycache__", "*.pyc")

    try:
        # Create backup directory
        src_backup_path = os.path.join(save_path, "src_backup")
        os.makedirs(src_backup_path, exist_ok=True)

        for dir_name in ("src", "scripts", "config"):
            dir_path = os.path.join(project_root, dir_name)
            # isdir rather than exists: a stray *file* with this name would make
            # copytree raise and abort all remaining copies.
            if os.path.isdir(dir_path):
                shutil.copytree(dir_path,
                                os.path.join(src_backup_path, dir_name),
                                dirs_exist_ok=True,
                                ignore=skip_caches)
                logger.info(f"Copied {dir_name} directory to {src_backup_path}")

        # Copy configurator.py if exists in root
        configurator_path = os.path.join(project_root, "configurator.py")
        if os.path.isfile(configurator_path):
            shutil.copy2(configurator_path, src_backup_path)
            logger.info(f"Copied configurator.py to {src_backup_path}")

        logger.info(f"Successfully saved source code to {src_backup_path}")

    except PermissionError as e:
        logger.warning(f"Permission denied when saving source code: {e}")
        logger.warning("Continuing without source code backup...")
    except Exception as e:
        # Best-effort backup: keep going, but record the traceback for debugging.
        logger.exception(f"Error saving source code: {e}")
        logger.warning("Continuing without source code backup...")


def save_experiment_config(config: Dict[Any, Any], 
                           save_path: str, 
                           filename: str = "experiment_config.json") -> None:
    """
    Save experiment configuration to JSON file.

    The file is written to ``<save_path>/configs/<filename>``, creating the
    ``configs`` directory if needed. Each top-level value that ``json`` cannot
    serialize is replaced by its ``str()``. Each top-level key that JSON
    cannot use as an object key is replaced by its ``str()`` as well. JSON
    object keys can only be str/int/float/bool/None.

    Args:
        config (Dict[Any, Any]): Configuration dictionary
        save_path (str): Directory path to save the config
        filename (str): Name of the config file

    Returns:
        None

    Raises:
        Exception: any error raised while writing the file is logged and re-raised.
    """
    logger = get_logger(__name__)

    config_dir = os.path.join(save_path, "configs")
    os.makedirs(config_dir, exist_ok=True)

    # Convert config to JSON-serializable format
    json_config = {}
    for key, value in config.items():
        # Without this, a single e.g. tuple key makes json.dump raise TypeError
        # even though every value has been sanitized.
        if not (key is None or isinstance(key, (str, int, float, bool))):
            key = str(key)
        try:
            json.dumps(value)  # Test if value is JSON serializable
            json_config[key] = value
        except (TypeError, ValueError):
            json_config[key] = str(value)  # Convert to string if not serializable

    try:
        with open(os.path.join(config_dir, filename), 'w', encoding='utf-8') as f:
            json.dump(json_config, f, indent=2)
        logger.info(f"Saved experiment config to {config_dir}")
    except Exception as e:
        logger.error(f"Error saving experiment config: {e}")
        raise


def create_experiment_readme(save_path: str, 
                             experiment_info: Dict[str, Any]) -> None:
    """
    Create a README file for the experiment with basic information.

    Writes ``<save_path>/README.md``, creating ``save_path`` if needed, from a
    fixed Markdown template. Keys missing from ``experiment_info`` render as
    ``Unknown``. Errors while writing the file are logged, not raised.

    Args:
        save_path (str): Experiment directory path
        experiment_info (Dict[str, Any]): Information about the experiment.
            Keys read: ``name``, ``start_time``, ``model_name``, ``dataset``,
            ``max_iters``, ``batch_size``, ``learning_rate``.

    Returns:
        None
    """
    logger = get_logger(__name__)
    
    os.makedirs(save_path, exist_ok=True)
    readme_path = os.path.join(save_path, "README.md")
    # normpath drops a trailing separator, which would otherwise make basename() return ''.
    dir_name = os.path.basename(os.path.normpath(save_path))
    
    # NOTE(review): "Model Size" renders model_name, same as "Model" -- confirm
    # whether a separate key was intended.
    readme_content = f"""# Experiment: {experiment_info.get('name', 'Unknown')}

## Experiment Information
- **Start Time**: {experiment_info.get('start_time', 'Unknown')}
- **Model**: {experiment_info.get('model_name', 'Unknown')}
- **Dataset**: {experiment_info.get('dataset', 'Unknown')}

## Directory Structure
```
{dir_name}/
├── checkpoints/          # Model checkpoints
├── train_monitor/        # Training monitor files  
├── logs/                # Training logs
├── src_backup/          # Source code backup
├── configs/             # Configuration files
├── wandb/              # Weights & Biases logs
└── README.md           # This file
```

## Training Configuration
- **Max Iterations**: {experiment_info.get('max_iters', 'Unknown')}
- **Batch Size**: {experiment_info.get('batch_size', 'Unknown')}
- **Learning Rate**: {experiment_info.get('learning_rate', 'Unknown')}
- **Model Size**: {experiment_info.get('model_name', 'Unknown')}

## Notes
Add any additional notes about this experiment here.
"""
    try:
        # Explicit UTF-8: the tree above uses box-drawing characters that a
        # non-UTF-8 platform default encoding (e.g. cp1252) cannot encode.
        with open(readme_path, 'w', encoding='utf-8') as f:
            f.write(readme_content)
        logger.info(f"Created experiment README at {readme_path}")
    except Exception as e:
        logger.error(f"Error creating experiment README: {e}")


def setup_experiment_environment(base_output_dir: str = "output",
                                 experiment_name: str = "llama_training",
                                 config: Optional[Dict[Any, Any]] = None,
                                 resume_path: Optional[str] = None,
                                 resume_iter: Optional[int] = None,) -> Dict[str, str]:
    """
    Set up complete experiment environment with timestamped directories.

    Fresh run (``resume_path`` falsy): builds a new timestamped experiment
    directory path under ``base_output_dir``. It then backs up the source code
    into that directory, saves ``config`` (if non-empty) and writes a README.

    Resumed run (``resume_path`` truthy): nothing is created or written. Only
    the sub-directory paths of the existing experiment at ``resume_path`` are
    returned.

    Args:
        base_output_dir (str): Base output directory (fresh runs only)
        experiment_name (str): Name of the experiment (fresh runs only)
        config (Dict[Any, Any], optional): Experiment configuration
        resume_path (str, optional): Path to resume experiment
        resume_iter (int, optional): Iteration number to resume from. Not
            used by this function at present; kept for interface compatibility.

    Returns:
        Dict[str, str]: Paths under ``'monitor_dir'``, ``'logs_dir'`` and
        ``'wandb_dir'``. Those three directories are not created here.
    """
    if resume_path:
        # Resumed runs write into the original experiment's directories.
        # NOTE(review): resume_iter is unused; if resumed outputs should go to a
        # per-iteration sub-directory, that needs to be wired up here.
        exp_dir_path = resume_path
    else:
        # Create timestamped experiment directory
        exp_dir_path = create_timestamped_experiment_dir(
            base_output_dir, experiment_name
        )

        # Save source code
        save_current_src(exp_dir_path)

        experiment_info = {
            'name': experiment_name,
            'start_time': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }

        # Save configuration if provided
        if config:
            save_experiment_config(config, exp_dir_path)
            experiment_info.update({
                'model_name': config.get('model_name', 'Unknown'),
                'dataset': config.get('dataset', 'Unknown'),
                'max_iters': config.get('max_iters', 'Unknown'),
                'batch_size': config.get('batch_size', 'Unknown'),
                'learning_rate': config.get('learning_rate', 'Unknown'),
            })

        create_experiment_readme(exp_dir_path, experiment_info)

    return {
        'monitor_dir': os.path.join(exp_dir_path, 'train_monitor'),
        'logs_dir': os.path.join(exp_dir_path, 'logs'),
        'wandb_dir': os.path.join(exp_dir_path, 'wandb'),
    }


def get_latest_experiment_dir(base_output_dir: str = "output", 
                              experiment_prefix: str = "llama_training") -> Optional[str]:
    """
    Get the path to the most recently modified experiment directory.

    Candidates are the immediate sub-directories of ``base_output_dir`` whose
    names start with ``experiment_prefix``. This is a plain prefix match, so
    ``llama_training_v2_...`` also matches the default prefix. The candidate
    with the newest ``os.path.getmtime`` wins, and ties go to the first in
    ``os.listdir`` order. A directory's mtime changes when entries are added
    to or removed from it, so the result is not necessarily the most recently
    *created* directory.

    Args:
        base_output_dir (str): Base output directory
        experiment_prefix (str): Prefix of experiment directories

    Returns:
        Optional[str]: Path to latest experiment directory, None if
        ``base_output_dir`` is not a directory or has no matching entries
    """
    logger = get_logger(__name__)

    # isdir rather than exists: os.listdir would raise on a regular file.
    if not os.path.isdir(base_output_dir):
        logger.warning(f"Base output directory {base_output_dir} does not exist")
        return None

    # Find all directories with the experiment prefix
    exp_dirs = [
        os.path.join(base_output_dir, item)
        for item in os.listdir(base_output_dir)
        if item.startswith(experiment_prefix)
        and os.path.isdir(os.path.join(base_output_dir, item))
    ]

    if not exp_dirs:
        logger.warning(f"No experiment directories found with prefix {experiment_prefix}")
        return None

    # max() returns the first maximal element, matching the previous stable
    # descending sort's tie-breaking without sorting the whole list.
    latest_dir = max(exp_dirs, key=os.path.getmtime)

    logger.info(f"Latest experiment directory: {latest_dir}")
    return latest_dir


def get_rng_states():
    """
    Snapshot the state of every random number generator used in training.

    Always captures the torch CPU, NumPy global and Python ``random`` states.
    When CUDA is available, also captures the current CUDA device's state.
    With more than one CUDA device, the states of all devices are captured too.

    Returns:
        dict: Generator states keyed by ``'torch_rng_state'``,
        ``'numpy_rng_state'``, ``'python_rng_state'`` and, when applicable,
        ``'torch_cuda_rng_state'`` / ``'torch_cuda_rng_state_all'``.
    """
    states = dict(
        torch_rng_state=torch.get_rng_state(),
        numpy_rng_state=np.random.get_state(),
        python_rng_state=random.getstate(),
    )

    if not torch.cuda.is_available():
        return states

    states['torch_cuda_rng_state'] = torch.cuda.get_rng_state()
    if torch.cuda.device_count() > 1:
        states['torch_cuda_rng_state_all'] = torch.cuda.get_rng_state_all()
    return states


def save_training_monitor(save_path: str,
                          config: Dict[Any, Any], 
                          model: torch.nn.Module, 
                          optimizer: torch.optim.Optimizer, 
                          lr: float, 
                          losses: Dict[str, float], 
                          token_num: int,
                          iter_num: int) -> str:
    """
    Save a full training snapshot to ``<save_path>/training_monitor_iter_<iter_num>.pkl``.

    The snapshot holds the config, model and optimizer state dicts, current
    learning rate, losses, all RNG states (``get_rng_states()``), token count,
    iteration number and a wall-clock timestamp. It is serialized with
    ``pickle``. Tensors are pickled as-is, on whatever device they live on.
    Loading a snapshot taken on GPU presumably needs a GPU (TODO confirm).

    The data is first written to a ``.tmp`` sibling file and then moved into
    place with ``os.replace``. An interrupted save therefore never leaves a
    truncated file under the final name.

    Args:
        save_path (str): Directory to write the snapshot into (created if needed)
        config (Dict[Any, Any]): Experiment configuration
        model (torch.nn.Module): Model whose ``state_dict()`` is saved
        optimizer (torch.optim.Optimizer): Optimizer whose ``state_dict()`` is saved
        lr (float): Current learning rate
        losses (Dict[str, float]): Loss values to record
        token_num (int): Number of tokens processed so far
        iter_num (int): Current iteration number (also used in the file name)

    Returns:
        str: Path of the written snapshot file.
    """
    # Collect all training states
    training_monitor = {
        'config': config,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'current_lr': lr,
        'losses': losses,
        'rng_states': get_rng_states(),
        'tokens_num': token_num,
        'iter_num': iter_num,
        'timestamp': time.time(),
    }

    os.makedirs(save_path, exist_ok=True)
    final_path = os.path.join(save_path, f'training_monitor_iter_{iter_num}.pkl')
    tmp_path = final_path + '.tmp'
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(training_monitor, f)
        os.replace(tmp_path, final_path)
    except BaseException:
        # Don't leave a half-written temp file behind, then propagate the error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    return final_path


def load_training_monitor(load_path: str) -> Dict[Any, Any]:
    """
    Load a training-monitor checkpoint that was written with ``pickle``.

    WARNING: unpickling can execute arbitrary code; only load files you trust.

    Args:
        load_path (str): Path of the ``.pkl`` snapshot file.

    Returns:
        Dict[Any, Any]: The unpickled snapshot dictionary.
    """
    with open(load_path, mode='rb') as snapshot_file:
        return pickle.load(snapshot_file)


def save_gradients(save_path: str,
                   model: torch.nn.Module, 
                   iter_num: int) -> None:
    """
    Save complete gradients from all model parameters, grouped by parameter type.

    Every parameter with a non-None ``.grad`` is assigned to a group by
    substring match on its name, checked in this order:
    ``embed`` -> 'embed', ``lm_head`` -> 'head', ``norm`` -> 'ln',
    ``q_proj``/``k_proj`` -> 'qk', ``v_proj``/``o_proj`` -> 'vo',
    ``mlp`` -> 'mlp'. A detached CPU copy of each gradient is pickled to
    ``<save_path>/gradients_iter_<iter_num>.pkl``, together with the global
    L2 norm over all stored gradients.

    Args:
        save_path (str): Directory to save gradients (created if needed)
        model (torch.nn.Module): The model (unwrapped)
        iter_num (int): Current iteration number (also used in the file name)

    Raises:
        ValueError: if a parameter with a gradient matches none of the groups.
    """
    # Collect complete gradients by parameter groups
    gradients_by_group = defaultdict(dict)
    total_grad_norm_squared = 0.0
    for name, param in model.named_parameters():
        if param.grad is None:
            continue

        # Determine which group this parameter belongs to (first match wins)
        if 'embed' in name:
            group = 'embed'
        elif 'lm_head' in name:
            group = 'head'
        elif 'norm' in name:
            group = 'ln'
        elif 'q_proj' in name or 'k_proj' in name:
            group = 'qk'
        elif 'v_proj' in name or 'o_proj' in name:
            group = 'vo'
        elif 'mlp' in name:
            group = 'mlp'
        else:
            raise ValueError(f"Unrecognized parameter group for parameter: {name}")

        # Exactly one copy: from GPU this is the device->host transfer; for a
        # CPU tensor copy=True still guarantees the snapshot doesn't alias .grad.
        grad_data = param.grad.detach().to('cpu', copy=True)
        gradients_by_group[group][name] = grad_data

        # Norm in float32 so fp16/bf16 gradients neither lose precision nor
        # overflow; a no-op conversion for float32 gradients.
        total_grad_norm_squared += grad_data.float().norm().item() ** 2

    gradient_data = {
        'iter_num': iter_num,
        'gradients_by_group': gradients_by_group,
        'total_grad_norm': total_grad_norm_squared ** 0.5,
        'timestamp': time.time(),
    }

    # Create gradients directory if it doesn't exist
    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, f'gradients_iter_{iter_num}.pkl'), 'wb') as f:
        pickle.dump(gradient_data, f)
