import logging
import os
import time
from pathlib import Path
from omegaconf import OmegaConf, DictConfig
import gc
try:
    import wandb  # type: ignore
except Exception:  # wandb may not be installed on non-zero ranks
    wandb = None  # type: ignore
from typing import Dict, List, Optional, Any
import torch
from dataclasses import dataclass, asdict

# Per-model-family token constants consumed by `get_token_conf`:
#   *_MODEL_NAMES       Hugging Face model ids handled by that family's settings.
#   *_PLACEHOLDER       token id passed to TokenConf as `rf_token_id`.
#   *_RESPONSE_KEYWORD  chat-template text that introduces the assistant turn.
#   *_RESPONSE_IDX      token ids passed to TokenConf as `response_token_ids`;
#                       presumably the tokenization of *_RESPONSE_KEYWORD -- verify
#                       against the model's tokenizer before changing either one.
#   *_USER_KEYWORD / *_USER_IDX  the same pair for the user turn (only Llama 3 defines them).
# All ids are tokenizer-specific magic numbers.

##### PHI 3 #####
PHI3_MODEL_NAMES = [
    "microsoft/Phi-3-mini-4k-instruct", 
    "microsoft/Phi-3.5-mini-instruct"
]
PHI3_PLACEHOLDER = 32002  # '<|placeholder1|>'
PHI3_RESPONSE_KEYWORD = "<|assistant|>\n"
# NOTE(review): a single id even though the keyword also ends in "\n" -- confirm intended.
PHI3_RESPONSE_IDX = [32001]

##### LLAMA 3 #####
LLAMA3_MODEL_NAMES = [
    "meta-llama/Llama-3.2-3B-Instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
]
LLAMA3_PLACEHOLDER = 128255
LLAMA3_RESPONSE_KEYWORD = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
LLAMA3_RESPONSE_IDX = [128009, 128006, 78191, 128007, 271]
LLAMA3_USER_KEYWORD = "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
LLAMA3_USER_IDX = [128009, 128006, 882, 128007, 271]

##### MISTRAL V3 #####
MISTRALV3_MODEL_NAMES = [
    "mistralai/Mistral-7B-Instruct-v0.3",
]
MISTRALV3_PLACEHOLDER = 34  # NOTE(review): which token id 34 is isn't recorded here -- confirm with the tokenizer.
MISTRALV3_RESPONSE_KEYWORD = "[/INST]"
MISTRALV3_RESPONSE_IDX = [4]

##### GEMMA 3 #####
GEMMA3_MODEL_NAMES = [
    "google/gemma-3-4b-it",
]
GEMMA3_PLACEHOLDER = 6  # <unused0>
GEMMA3_RESPONSE_KEYWORD = "<end_of_turn>\n<start_of_turn>model\n"
GEMMA3_RESPONSE_IDX = [106, 107, 105, 4368, 107]

# Every supported model id; shown in get_token_conf's error for unsupported models.
ALL_MODEL_NAMES = PHI3_MODEL_NAMES + LLAMA3_MODEL_NAMES + MISTRALV3_MODEL_NAMES + GEMMA3_MODEL_NAMES


def empty_cache(*, num_rounds: int = 3, sleep_seconds: float = 1.0):
    """Run garbage collection and release cached CUDA memory, repeatedly.

    Each round calls ``gc.collect()`` then ``torch.cuda.empty_cache()`` and then
    sleeps for ``sleep_seconds``. The defaults (3 rounds, 1 s apart) match the
    historical hard-coded behaviour.

    Args:
        num_rounds: How many collect/empty rounds to run (0 does nothing).
        sleep_seconds: Pause after each round, in seconds.
    """
    for _ in range(num_rounds):
        gc.collect()
        torch.cuda.empty_cache()
        time.sleep(sleep_seconds)


@dataclass
class TokenConf:
    """Special-token settings for one model family.

    Supports attribute access (``conf.rf_token_id``) as well as a small
    dict-like interface: ``conf["key"]``, ``conf["key"] = value`` and
    ``conf.get("key", default)``.
    """

    # Placeholder token id for this model family.
    rf_token_id: int
    # Chat-template text that introduces the assistant turn.
    response_keyword: str
    # Token ids associated with `response_keyword`.
    response_token_ids: List[int]
    # User-turn marker text and ids; None for families that do not define them.
    user_keyword: Optional[str] = None
    user_token_ids: Optional[List[int]] = None

    def __getitem__(self, key):
        """Return attribute `key`; a missing key raises AttributeError (not KeyError)."""
        return getattr(self, key)

    def __setitem__(self, key, value):
        """Set attribute `key` to `value`."""
        setattr(self, key, value)

    def get(self, key, default=None):
        """Return attribute `key`, or `default` if it does not exist."""
        return getattr(self, key, default)


def get_token_conf(model_name: str) -> TokenConf:
    """Return the special-token configuration for a supported model.

    Args:
        model_name: Hugging Face model id; must be one of ALL_MODEL_NAMES.

    Returns:
        A new TokenConf. Its token-id lists are copies, so mutating them never
        alters the module-level constants shared by every other caller.

    Raises:
        NotImplementedError: if `model_name` is not a supported model.
    """
    if model_name in PHI3_MODEL_NAMES:
        return TokenConf(
            rf_token_id=PHI3_PLACEHOLDER,
            response_keyword=PHI3_RESPONSE_KEYWORD,
            response_token_ids=list(PHI3_RESPONSE_IDX),
        )

    if model_name in LLAMA3_MODEL_NAMES:
        return TokenConf(
            rf_token_id=LLAMA3_PLACEHOLDER,
            response_keyword=LLAMA3_RESPONSE_KEYWORD,
            response_token_ids=list(LLAMA3_RESPONSE_IDX),
            user_keyword=LLAMA3_USER_KEYWORD,
            user_token_ids=list(LLAMA3_USER_IDX),
        )

    if model_name in MISTRALV3_MODEL_NAMES:
        return TokenConf(
            rf_token_id=MISTRALV3_PLACEHOLDER,
            response_keyword=MISTRALV3_RESPONSE_KEYWORD,
            response_token_ids=list(MISTRALV3_RESPONSE_IDX),
        )

    if model_name in GEMMA3_MODEL_NAMES:
        return TokenConf(
            rf_token_id=GEMMA3_PLACEHOLDER,
            response_keyword=GEMMA3_RESPONSE_KEYWORD,
            response_token_ids=list(GEMMA3_RESPONSE_IDX),
        )

    raise NotImplementedError(
        f"Model {model_name} not implemented; must be one of:\n{ALL_MODEL_NAMES}"
    )


def get_global_rank() -> int:
    """Return the global rank of the current process.

    Resolution order:
      1. ``torch.distributed.get_rank()`` when the process group is available
         and initialized (any error while querying it is ignored);
      2. the ``RANK`` environment variable;
      3. the ``SLURM_PROCID`` environment variable;
      4. ``0``.
    """
    try:
        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
    except Exception:
        pass  # best effort: fall back to the environment below
    env_rank = os.environ.get("RANK")
    if env_rank is None:
        env_rank = os.environ.get("SLURM_PROCID", "0")
    return int(env_rank)


def configure_rank_zero_logging():
    """Configure logging so only global rank 0 emits logs by default.

    Sets the root logger level to INFO on rank 0 and ERROR elsewhere. If the
    root logger has no handlers, it adds a StreamHandler. Every root handler
    then gets a filter that passes records only on rank 0.

    Safe to call more than once: rank filters installed by earlier calls are
    replaced rather than stacked. A later call, e.g. after the rank becomes
    known, therefore takes effect.
    """
    rank = get_global_rank()
    is_rank_zero = (rank == 0)

    class _RankZeroFilter(logging.Filter):
        # Marker so later calls can find filters from earlier calls. An isinstance
        # check cannot work because this class is re-created on every call.
        _rank_zero_filter = True

        def filter(self, record: logging.LogRecord) -> bool:
            return is_rank_zero

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO if is_rank_zero else logging.ERROR)

    if not root_logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
        root_logger.addHandler(handler)

    for handler in root_logger.handlers:
        # Drop any rank filter from a previous call so the current rank decides.
        stale_filters = [f for f in handler.filters if getattr(f, "_rank_zero_filter", False)]
        for stale in stale_filters:
            handler.removeFilter(stale)
        handler.addFilter(_RankZeroFilter())


def _first_job_hostname(env) -> str:
    """Return the first hostname of the job (the rendezvous host), else this node's hostname."""
    try:
        hostnames = list(env.hostnames)
    except Exception as exc:  # node-list lookup is best effort (e.g. missing attribute, parse failure)
        logging.warning(f"Could not read job hostnames ({exc}); using local hostname.")
        hostnames = []
    return hostnames[0] if hostnames and hostnames[0] else env.hostname


def setup_distributed_environment(master_port: int = 29500):
    """Set up environment variables for distributed training from SLURM/submitit.

    Copies SLURM_CPUS_PER_TASK into OMP_NUM_THREADS when present. If a
    submitit job environment is detected with more than one task, it also
    exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT and the
    NCCL async-error-handling flags.

    Args:
        master_port: Rendezvous TCP port written to MASTER_PORT.

    Returns:
        True if a submitit job environment was found. This includes a
        single-task job, where no rank variables are exported. Returns False
        if submitit is not importable or this process is not inside a
        submitit job.
    """
    try:
        if 'SLURM_CPUS_PER_TASK' in os.environ:
            os.environ['OMP_NUM_THREADS'] = os.environ['SLURM_CPUS_PER_TASK']

        import submitit
        env = submitit.JobEnvironment()
        logging.info("Submitit Environment:")
        logging.info(env)

        # These environment variables are all the Trainer needs
        if env.num_tasks > 1:
            os.environ['RANK'] = str(env.global_rank)
            os.environ['LOCAL_RANK'] = str(env.local_rank)
            os.environ['WORLD_SIZE'] = str(env.num_tasks)
            # All ranks must agree on one rendezvous host. `env.hostname` is each
            # process's own node, which differs across nodes in multi-node jobs.
            # Use the job's first node instead.
            os.environ['MASTER_ADDR'] = _first_job_hostname(env)
            os.environ['MASTER_PORT'] = str(master_port)

            # NCCL configuration for better reliability and debugging.
            # Newer PyTorch reads the TORCH_-prefixed name; older releases read the bare one.
            os.environ['NCCL_ASYNC_ERROR_HANDLING'] = '1'
            os.environ['TORCH_NCCL_ASYNC_ERROR_HANDLING'] = '1'

            print("-" * 80)
            print(f"Distributed environment set: rank={env.global_rank}/{env.num_tasks-1}")
            print(f"Visible CUDA device: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}")
            print(f"RANK:\t\t{os.environ.get('RANK', None)}")
            print(f"LOCAL_RANK:\t{os.environ.get('LOCAL_RANK', None)}")
            print(f"WORLD_SIZE:\t{os.environ.get('WORLD_SIZE', None)}")
            print(f"MASTER_ADDR:\t{os.environ.get('MASTER_ADDR', None)}")
            print(f"MASTER_PORT:\t{os.environ.get('MASTER_PORT', None)}")
            print("-" * 80)
        return True

    except (ImportError, RuntimeError):
        print("Not in submitit environment")
        return False


def to_dict(obj: Any) -> Dict[str, Any]:
    """
    Convert a DictConfig or dataclass instance to a plain dict.

    Args:
        obj: A DictConfig (resolved via OmegaConf), a dataclass instance, or
            anything else (e.g. an existing dict), which is returned unchanged.

    Returns:
        Plain dict representation, or `obj` itself when no conversion applies.
    """
    from dataclasses import is_dataclass

    if isinstance(obj, DictConfig):
        return OmegaConf.to_container(obj, resolve=True)
    # is_dataclass() checks the object's *type*, so containers with a dynamic
    # __getattr__ are not misdetected. It is also True for dataclass classes,
    # which asdict() rejects, so only convert instances.
    if is_dataclass(obj) and not isinstance(obj, type):
        return asdict(obj)
    return obj


def _start_new_wandb_run(job_name: str, group: Optional[str], config_dict: Dict[str, Any],
                         output_dir: Path, run_id_file: Path) -> None:
    """Start a fresh W&B run and store its id in `run_id_file` so a restart can resume it."""
    wandb.init(name=job_name, group=group, config=config_dict)
    output_dir.mkdir(parents=True, exist_ok=True)
    run_id_file.write_text(wandb.run.id)


def init_wandb(config: DictConfig, force_no_resume: bool, output_dir: Path, group: Optional[str] = None):
    """Start, resume, or disable Weights & Biases logging for this job.

    W&B is used only when ``config.training_args.report_to`` contains
    ``"wandb"``. Otherwise ``WANDB_DISABLED=true`` is exported and nothing
    else happens.

    When enabled, ``WANDB_DISABLED=false`` is exported and the run id is kept
    in ``<output_dir>/wandb_run_id.txt``:
      * ``force_no_resume=True``: any stored id is deleted and a new run starts.
      * otherwise: a stored non-empty id is resumed (``resume="allow"``).
        If there is none, a new run starts and its id is stored.

    Args:
        config: Experiment config; converted to a plain dict and passed to wandb.init.
        force_no_resume: Always start a new run instead of resuming.
        output_dir: Job output directory (str or Path); its name is part of the run name.
        group: Optional W&B group name.

    Raises:
        ImportError: if W&B is requested but the ``wandb`` package could not be imported.
    """
    output_dir = Path(output_dir)
    wandb_run_id_file = output_dir / "wandb_run_id.txt"
    slurm_job_id = os.getenv("SLURM_JOB_ID", None)
    job_name = f"slurm-{slurm_job_id}-{output_dir.name}" if slurm_job_id else output_dir.name
    logging.info(f"Job name: {job_name}")
    config_dict = to_dict(config)

    # `or []` guards against an unset (None) report_to; `in` works for a list or a single string.
    report_to = config_dict['training_args']['report_to'] or []
    if "wandb" not in report_to:
        logging.info("W&B is disabled (`report_to != 'wandb'`).")
        os.environ["WANDB_DISABLED"] = "true"
        return

    if wandb is None:
        raise ImportError("`report_to` includes 'wandb' but the wandb package could not be imported.")

    os.environ["WANDB_DISABLED"] = "false"  # Ensure W&B is actually enabled

    if force_no_resume:
        # Force a NEW W&B run
        logging.info("`resume_checkpoint`=False => starting a NEW W&B run, ignoring old run_id.")
        wandb_run_id_file.unlink(missing_ok=True)
        _start_new_wandb_run(job_name, group, config_dict, output_dir, wandb_run_id_file)
        return

    # Attempt to resume the old run; an empty id file counts as "no run".
    existing_run_id = wandb_run_id_file.read_text().strip() if wandb_run_id_file.exists() else ""
    if existing_run_id:
        logging.info(f"Resuming W&B run ID: {existing_run_id}")
        wandb.init(
            name=job_name,
            group=group,
            id=existing_run_id,
            resume="allow",
            config=config_dict
        )
    else:
        logging.info("No existing W&B run ID => creating a NEW run.")
        _start_new_wandb_run(job_name, group, config_dict, output_dir, wandb_run_id_file)
        

class LossAccumulator:
    """Sum named loss terms over gradient-accumulation micro-steps and report their mean."""

    def __init__(self, keywords, accumulation_steps, manual_keywords=None):
        """
        Initialize the LossAccumulator with given keywords and number of steps for gradient accumulation.

        Args:
            keywords (list of str): Names of the loss terms to track; `update` rejects any other name.
            accumulation_steps (int): Number of `update` calls per optimizer step. `ready_to_step`
                returns True once at least this many have been recorded.
            manual_keywords: Unused; accepted only for backward compatibility.
        """
        self.keywords = keywords
        self.accumulation_steps = accumulation_steps
        self.reset()

    def reset(self):
        """Zero every accumulated sum and the step counter."""
        self.loss_sums = {key: 0.0 for key in self.keywords}
        self.current_step = 0

    def update(self, **kwargs):
        """
        Add one micro-step's loss values and advance the step counter by one.

        Args:
            kwargs: keyword -> loss value for this micro-step. A tracked keyword that
                is not passed contributes nothing to its sum for this step.

        Raises:
            ValueError: if any name is not a tracked keyword. All names are checked
                before anything is added, so a rejected call leaves the state unchanged.
        """
        unknown = [key for key in kwargs if key not in self.loss_sums]
        if unknown:
            raise ValueError(f"Unknown keyword '{unknown[0]}'. Expected one of: {', '.join(self.keywords)}")

        for key, value in kwargs.items():
            self.loss_sums[key] += value
        self.current_step += 1

    def compute_loss(self, prefix: str = ""):
        """
        Return the mean of each loss term over the micro-steps accumulated so far.

        The divisor is the number of `update` calls actually recorded. A partial
        window (fewer than `accumulation_steps` updates) and an over-full one are
        therefore both averaged correctly. For a full window this is the same as
        dividing by `accumulation_steps`.

        Args:
            prefix: String prepended to every key of the result.

        Returns:
            dict: {prefix + keyword: mean loss value}.

        Raises:
            ValueError: if no step has been accumulated since the last reset.
        """
        if self.current_step == 0:
            raise ValueError("No steps have been accumulated yet. Cannot compute loss.")

        num_steps = self.current_step
        return {f"{prefix}{key}": value / num_steps for key, value in self.loss_sums.items()}

    def ready_to_step(self):
        """
        Check if the accumulator is ready for a gradient update step.

        Returns:
            bool: True once at least `accumulation_steps` updates have been recorded.
        """
        return self.current_step >= self.accumulation_steps

    def finalize_and_reset(self, prefix: str = ""):
        """
        Compute the averaged losses, reset the accumulator, and return the averages.

        Args:
            prefix: String prepended to every key of the result.

        Returns:
            dict: {prefix + keyword: mean loss value}.

        Raises:
            ValueError: if no step has been accumulated since the last reset.
        """
        avg_losses = self.compute_loss(prefix)
        self.reset()
        return avg_losses


class ManualLossAccumulator:
    """Accumulate per-keyword loss sums, each with its own explicit step count."""

    def __init__(self, keywords):
        """
        Args:
            keywords (list of str): Names of the loss terms to track; `update` rejects any other name.
        """
        self.keywords = keywords
        self.reset()

    def reset(self):
        """Reset the accumulated losses and step counter."""
        self.loss_sums = {key: 0.0 for key in self.keywords}
        self.loss_steps = {key: 0 for key in self.keywords}

    def update(self, **kwargs):
        """
        Add pre-summed loss values together with the number of steps each covers.

        Expects inputs of dicts with the form {keyword: {"value": value, "steps": n}}.

        Raises:
            ValueError: if any name is not a tracked keyword. All names are checked
                before anything is added, so a rejected call leaves the state unchanged.
        """
        unknown = [key for key in kwargs if key not in self.loss_sums]
        if unknown:
            raise ValueError(f"Unknown keyword '{unknown[0]}'. Expected one of: {', '.join(self.keywords)}")

        for key, entry in kwargs.items():
            self.loss_sums[key] += entry['value']
            self.loss_steps[key] += entry['steps']

    def compute_loss(self, prefix: str = ""):
        """
        Compute the averaged loss values for all keywords after accumulation.

        Args:
            prefix: String prepended to every key of the result.

        Returns:
            dict: {prefix + keyword: sum / steps}. A keyword with zero accumulated
            steps maps to the sentinel -1 instead of dividing by zero.
        """
        avg_losses = {}
        for key, value in self.loss_sums.items():
            step = self.loss_steps[key]
            if step == 0:
                avg_losses[f"{prefix}{key}"] = -1
            else:
                avg_losses[f"{prefix}{key}"] = value / step
        return avg_losses

    def finalize_and_reset(self, prefix: str = ""):
        """
        Finalize the accumulated loss, reset the accumulator, and return the averaged loss.

        Args:
            prefix: String prepended to every key of the result.

        Returns:
            dict: A dictionary with the average loss value for each keyword (-1 where no steps were recorded).
        """
        avg_losses = self.compute_loss(prefix)
        self.reset()
        return avg_losses

# =============================================================================
# Hydra Override Utilities
# =============================================================================

# Substrings checked against each Hydra task override by `_filter_overrides`:
# an override containing any of these is dropped from the generated label.
FILTERED_OVERRIDE_KEYS = [
    'training_args.output_dir',
    'training_args.report_to',
    'hydra.launcher',
]

# Parent-config prefixes that `_filter_overrides` strips from each remaining
# override. ORDER MATTERS: the first keyword (in list order) found in an override
# is used, and everything up to its last occurrence is removed.
SPLIT_ON_KW = [
    'training_args.',
    'model_config.',
    'script_args.',
    '@script_args=',
    'script_args/',
    'training_args/',
    'insert_sampler/',
    'adv_attack.',
]

def _strip_parent_prefix(override: str, keywords: List[str]) -> str:
    """Drop everything up to the last occurrence of the first keyword (in order) found in `override`."""
    for kw in keywords:
        if kw in override:
            return override.split(kw)[-1]
    return override


def _filter_overrides(
    overrides: List[str],
    *,
    filtered_keys: Optional[List[str]] = None,
    split_on_keywords: Optional[List[str]] = None,
) -> str:
    """
    Build a compact, comma-separated label from Hydra task overrides for directory naming.

    Overrides containing any of `filtered_keys` are dropped. Each remaining
    override has its parent-config prefix removed using `split_on_keywords`:
    the first keyword found wins, and text up to its last occurrence is cut.

    Args:
        overrides: Hydra task override strings (e.g. "training_args.lr=1e-4").
        filtered_keys: Substrings marking overrides to drop; defaults to FILTERED_OVERRIDE_KEYS.
        split_on_keywords: Ordered prefixes to strip; defaults to SPLIT_ON_KW.

    Returns:
        The cleaned overrides joined with ',' or 'default' if none remain.
    """
    if filtered_keys is None:
        filtered_keys = FILTERED_OVERRIDE_KEYS
    if split_on_keywords is None:
        split_on_keywords = SPLIT_ON_KW

    kept = [o for o in overrides if not any(key in o for key in filtered_keys)]
    if not kept:
        return 'default'
    return ','.join(_strip_parent_prefix(o, split_on_keywords) for o in kept)


def get_hydra_overrides() -> Dict[str, Any]:
    """
    Get comprehensive override information from Hydra configuration.

    Returns:
        Dict containing:
            - task_overrides: List of task-level overrides
            - hydra_overrides: List of hydra-specific overrides
            - override_dirname: String representation of overrides for directory naming
            - job_name: Name of the current job
            - output_dir: Hydra's output directory
            - filtered_override_dirname: Compact label built by `_filter_overrides`
              ('default' when no relevant overrides exist)

        If the Hydra config cannot be accessed (e.g. Hydra is not installed or no
        Hydra run is active), a warning is logged and placeholder values are
        returned under the same keys.
    """
    try:
        from hydra.core.hydra_config import HydraConfig
        hydra_cfg = HydraConfig.get()
    except Exception as e:
        logging.warning(f"Could not access Hydra config: {e}")
        return {
            'task_overrides': [],
            'hydra_overrides': [],
            'override_dirname': '',
            'job_name': 'unknown',
            'output_dir': 'outputs',
            # Same key as the success path; 'default' is the label used for "no overrides".
            'filtered_override_dirname': 'default',
            # Legacy key from the old fallback, kept for backward compatibility.
            'filtered_output_dir': [],
        }

    return {
        'task_overrides': hydra_cfg.overrides.task,
        'hydra_overrides': hydra_cfg.overrides.hydra,
        'override_dirname': hydra_cfg.job.override_dirname,
        'job_name': hydra_cfg.job.name,
        'output_dir': hydra_cfg.runtime.output_dir,
        'filtered_override_dirname': _filter_overrides(hydra_cfg.overrides.task),
    }


def create_output_dir_with_overrides(base_output_dir: str, override_dirname: Optional[str] = None) -> str:
    """
    Append a filesystem-friendly override label to an output directory path.

    Args:
        base_output_dir: Base output directory path.
        override_dirname: Label to append. When None, Hydra's
            ``job.override_dirname`` is used; if Hydra cannot be queried, no
            label is appended.

    Returns:
        ``"<base_output_dir>/<label>"`` with every '/' and ' ' in the label
        replaced by '_', or ``base_output_dir`` unchanged when the label is empty.
    """
    label = override_dirname
    if label is None:
        try:
            from hydra.core.hydra_config import HydraConfig
            label = HydraConfig.get().job.override_dirname
        except Exception:
            label = ""

    if not label:
        return base_output_dir

    # Only path separators and spaces are rewritten; '=', ',' and '.' are kept.
    safe_label = label.translate(str.maketrans({"/": "_", " ": "_"}))
    return f"{base_output_dir}/{safe_label}"


def log_hydra_override_information(logger_instance: Optional[logging.Logger] = None) -> Dict[str, Any]:
    """
    Log the current Hydra override information and return it.

    Args:
        logger_instance: Logger to write to. Defaults to this module's logger
            (``logging.getLogger(__name__)``), not the root logger.

    Returns:
        The dictionary returned by ``get_hydra_overrides()``.
    """
    log = logger_instance if logger_instance is not None else logging.getLogger(__name__)
    overrides_info = get_hydra_overrides()

    banner = "=" * 60
    log.info(banner)
    log.info("HYDRA OVERRIDE INFORMATION")
    log.info(banner)

    task_overrides = overrides_info['task_overrides']
    log.info(f"Task Overrides ({len(task_overrides)} items):")
    for i, override in enumerate(task_overrides, 1):
        log.info(f"  {i:2d}. {override}")

    # The Hydra-level section is only printed when there is something to show.
    hydra_overrides = overrides_info['hydra_overrides']
    if hydra_overrides:
        log.info(f"Hydra Overrides ({len(hydra_overrides)} items):")
        for i, override in enumerate(hydra_overrides, 1):
            log.info(f"  {i:2d}. {override}")

    log.info(f"Override Dirname: '{overrides_info['override_dirname']}'")
    log.info(f"Job Name: {overrides_info['job_name']}")

    return overrides_info


def save_hydra_overrides_to_file(output_dir: str, overrides_info: Optional[Dict[str, Any]] = None) -> Path:
    """
    Save override information to ``<output_dir>/hydra_overrides.txt`` for reproducibility.

    `output_dir` is created if needed. The file is written as UTF-8, so
    non-ASCII override values round-trip regardless of the platform's default
    encoding.

    Args:
        output_dir: Directory to save the overrides file
        overrides_info: Override information dict (if None, will fetch from Hydra).
            Must provide 'job_name', 'output_dir', 'override_dirname',
            'task_overrides' and 'hydra_overrides'.

    Returns:
        Path to the saved overrides file
    """
    if overrides_info is None:
        overrides_info = get_hydra_overrides()

    override_file = Path(output_dir) / "hydra_overrides.txt"
    override_file.parent.mkdir(parents=True, exist_ok=True)

    lines = [
        "# Hydra Overrides Used for Reproducibility",
        f"# Job: {overrides_info['job_name']}",
        f"# Output Directory: {overrides_info['output_dir']}",
        f"# Override Dirname: {overrides_info['override_dirname']}",
        "",
        "## Task Overrides:",
    ]
    task_overrides = overrides_info['task_overrides']
    if task_overrides:
        lines.extend(f"{override}" for override in task_overrides)
    else:
        lines.append("# No task overrides")

    hydra_overrides = overrides_info['hydra_overrides']
    if hydra_overrides:
        lines.extend(["", "## Hydra Overrides:"])
        lines.extend(f"{override}" for override in hydra_overrides)

    override_file.write_text("\n".join(lines) + "\n", encoding="utf-8")

    logging.info(f"Saved overrides to: {override_file}")
    return override_file