import os
from typing import Optional, List, Union, Dict, Any
import time
from utils.logger import Logger

class Config:
    """
    Configuration class for counterfactual data augmentation experiments.
    Generates a unique name for each experiment run and stores experiment parameters.

    Constructing a Config has filesystem side effects: it creates
    ``<results_dir>/<user>/`` and ``<results_dir>/<user>/run_<timestamp>/``
    and builds a ``Logger`` configured to write ``run_<timestamp>.log`` inside
    the latter directory.

    NOTE: ``experiment_name`` has one-second resolution, so two Configs created
    within the same second share the same experiment directory.
    """
    def __init__(
        self,
        # Dataset parameters
        dataset_path: str = "./data/train.jsonl",

        # DQN parameters
        learning_rate: float = 5e-4,
        learning_starts: int = 50,
        gamma: float = 0.99,
        batch_size: int = 16,
        target_update_interval: int = 50,
        exploration_fraction: float = 0.5,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.1,
        buffer_size: int = 500,
        verbose: int = 1,

        # Experiment parameters
        random_seed: int = 42,  # Random seed for reproducibility
        text_embedding_dim: int = 1536,
        max_steps: int = 6,  # we index starting from 0 so this is 7 steps
        idx_to_start: int = 0,
        global_step_counter: int = 0,
        log_interval: int = 20,
        # Mutable accumulators default to None and get a fresh object per
        # instance; a literal {} / [] default is evaluated once and would be
        # shared (and mutated) by every Config created without these arguments.
        model_usage_counts: Optional[Dict[Any, Any]] = None,
        model_usage_type_counts: Optional[Dict[Any, Any]] = None,
        model_usage_type_score_sums: Optional[Dict[Any, Any]] = None,
        model_score_sums: Optional[Dict[Any, Any]] = None,
        step_times: Optional[List[Any]] = None,
        eval_mode: bool = False,

        # Output parameters
        results_dir: str = "/datasets/uig/results/",

        # Additional parameters that can be passed as a dictionary
        **kwargs: Any
    ):
        """
        Build the experiment configuration and create its output directories.

        Args:
            dataset_path: Path to the training data file.
            learning_rate .. verbose: DQN hyperparameters, stored verbatim as
                attributes of the same name.
            random_seed .. log_interval: Experiment parameters, stored verbatim.
            model_usage_counts, model_usage_type_counts,
            model_usage_type_score_sums, model_score_sums, step_times:
                Accumulator containers. If omitted, each instance gets its own
                new empty dict/list; if passed, the given object is stored by
                reference (not copied).
            eval_mode: Stored verbatim as ``self.eval_mode``.
            results_dir: Base results directory. The current user name is
                joined onto it as a sub-directory.
            **kwargs: Extra attributes set on the instance last, so they may
                overwrite any attribute assigned above.
        """
        # Generate a unique name for this experiment run
        self.timestamp = time.strftime("%Y%m%d-%H%M%S")
        self.experiment_name = f"run_{self.timestamp}"

        # Dataset parameters
        self.dataset_path = dataset_path

        # Experiment parameters
        self.random_seed = random_seed
        self.text_embedding_dim = text_embedding_dim
        self.max_steps = max_steps
        self.idx_to_start = idx_to_start
        self.global_step_counter = global_step_counter
        self.log_interval = log_interval
        self.model_usage_counts = {} if model_usage_counts is None else model_usage_counts
        self.model_usage_type_counts = {} if model_usage_type_counts is None else model_usage_type_counts
        self.model_usage_type_score_sums = {} if model_usage_type_score_sums is None else model_usage_type_score_sums
        self.model_score_sums = {} if model_score_sums is None else model_score_sums
        self.step_times = [] if step_times is None else step_times
        self.eval_mode = eval_mode

        # Output parameters: results are namespaced per user.
        self.results_dir = self._user_results_dir(results_dir)
        os.makedirs(self.results_dir, exist_ok=True)

        # Create an experiment directory
        self.experiment_dir = self.get_experiment_dir()
        self.logger = Logger(log_to_console=False, log_to_file=True, log_file=os.path.join(self.experiment_dir, f"{self.experiment_name}.log"))

        # DQN parameters
        self.learning_rate = learning_rate
        self.learning_starts = learning_starts
        self.gamma = gamma
        self.batch_size = batch_size
        self.target_update_interval = target_update_interval
        self.exploration_fraction = exploration_fraction
        self.exploration_initial_eps = exploration_initial_eps
        self.exploration_final_eps = exploration_final_eps
        self.buffer_size = buffer_size
        self.verbose = verbose
        self.tensorboard_log = os.path.join(self.experiment_dir, "tensorboard")

        # WandB parameters
        self.wandb_project = "maestro"
        self.wandb_name = self.experiment_name
        self.wandb_dir = os.path.join(self.experiment_dir, "wandb")
        self.checkpoint_dir = os.path.join(self.experiment_dir, "checkpoints")

        # Stats directory
        self.stats_dir = os.path.join(self.experiment_dir, "stats")

        # Store any additional parameters (applied last, so they can override)
        for key, value in kwargs.items():
            setattr(self, key, value)

    @staticmethod
    def _resolve_user() -> str:
        """
        Return the user name used to namespace the results directory.

        ``$USER`` is preferred when set. Otherwise this falls back to
        ``getpass.getuser()``, and finally to ``"unknown_user"``. This avoids a
        ``TypeError`` from concatenating ``None`` when ``$USER`` is unset.
        """
        user = os.getenv("USER")
        if user is not None:
            return user
        import getpass
        try:
            return getpass.getuser()
        except (KeyError, OSError, ImportError):
            # getpass.getuser() raises one of these when no user can be found
            # (the exact type depends on Python version and platform).
            return "unknown_user"

    @classmethod
    def _user_results_dir(cls, results_dir: str) -> str:
        """
        Join the current user name onto ``results_dir``.

        ``os.path.join`` (rather than ``+``) keeps the user as a separate path
        component even when ``results_dir`` lacks a trailing separator.
        """
        return os.path.join(results_dir, cls._resolve_user())

    def get_experiment_dir(self) -> str:
        """
        Returns the directory path for this experiment's results.

        The path is ``<self.results_dir>/<self.experiment_name>``; the
        directory is created if it does not already exist.
        """
        exp_dir = os.path.join(self.results_dir, f"{self.experiment_name}")
        os.makedirs(exp_dir, exist_ok=True)
        return exp_dir

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the config to a dictionary for serialization.

        Returns a shallow copy of the instance attributes. Note that it
        includes non-primitive values such as ``logger`` (a ``Logger``
        instance), so it presumably needs filtering before e.g. JSON dumping.
        """
        return dict(self.__dict__)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> 'Config':
        """
        Create a Config instance from a dictionary.

        Keys that match constructor parameters are passed as arguments; any
        other key is set as an attribute after construction via ``**kwargs``.

        NOTE(review): feeding back the output of ``to_dict()`` is not a clean
        round trip. ``results_dir`` there already ends with the user name and
        gets it appended again. Derived attributes (``experiment_dir``,
        ``logger``, ``timestamp``, ...) from the dict overwrite the freshly
        created ones, while a new run directory is still created on disk.
        """
        config = cls(**config_dict)
        return config

    def __str__(self) -> str:
        """
        Returns a string representation of the config.

        The first line carries the experiment name; each remaining attribute is
        listed as an indented ``key: value`` line in insertion order.
        """
        return f"Experiment Config (Name: {self.experiment_name})\n" + "\n".join(
            f"  {k}: {v}" for k, v in self.__dict__.items() if k != 'experiment_name'
        )
