from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any
from pathlib import Path


@dataclass
class ModelConfig:
    """Settings for the model: backbone identifier plus size/length limits.

    The code consuming these values lives outside this module, so the
    per-field notes below are inferred from the field names only.
    """

    # Identifier of the language model to use (looks like a model-hub id --
    # confirm against whatever loads it).
    llm_model_name: str = "microsoft/DialoGPT-medium"

    # MotifAgentConfig.validate() rejects non-positive values for these two.
    hidden_dim: int = 256
    max_motifs: int = 50

    # Remaining limits; not validated anywhere in this module.
    max_sites_per_motif: int = 10
    max_length: int = 1024


@dataclass
class TrainingConfig:
    """Hyperparameters for the training loop.

    Grouped as: general optimisation / rollout sizes, PPO-specific values,
    Set-BC values, and curriculum switches.  How each value is consumed is
    not visible in this module.
    """

    # --- General optimisation and rollout sizes ---
    learning_rate: float = 3e-4  # validate() requires > 0
    batch_size: int = 32
    num_episodes_per_iteration: int = 10
    num_iterations: int = 1000
    max_steps_per_episode: int = 100

    # --- PPO ---
    clip_epsilon: float = 0.2  # validate() requires a value in (0, 1]
    entropy_coef: float = 0.01
    value_loss_coef: float = 0.5
    max_grad_norm: float = 0.5
    ppo_epochs: int = 4
    mini_batch_size: int = 32
    gae_lambda: float = 0.95
    gamma: float = 0.99  # validate() requires a value in [0, 1]

    # --- Set-BC ---
    set_bc_weight: float = 1.0
    kl_coef: float = 0.01

    # --- Curriculum ---
    use_curriculum: bool = True
    curriculum_stages: int = 4


def _default_chemical_weights() -> Dict[str, float]:
    """Build a fresh copy of the default chemical reward term weights."""
    return {
        'validity': 1.0,
        'stability': 0.5,
        'functional_groups': 0.3,
        'property_alignment': 0.8,
        'synthetic_accessibility': 0.2,
        'novelty': 0.1
    }


def _default_topological_weights() -> Dict[str, float]:
    """Build a fresh copy of the default topological reward term weights."""
    return {
        'connectivity': 1.0,
        'edge_progress': 0.8,
        'topology_similarity': 0.6,
        'over_connection_penalty': 0.5
    }


@dataclass
class RewardConfig:
    """Weights for the reward signal.

    ``chemical_weight`` and ``topological_weight`` are the two top-level
    weights (``MotifAgentConfig.validate`` expects them to sum to 1.0).  The
    two dict fields hold per-term weights; each instance gets its own dict.
    """

    chemical_weight: float = 0.4
    topological_weight: float = 0.6

    # Per-term weights inside the chemical reward.
    chemical_weights: Dict[str, float] = field(
        default_factory=_default_chemical_weights
    )

    # Per-term weights inside the topological reward.
    topological_weights: Dict[str, float] = field(
        default_factory=_default_topological_weights
    )


@dataclass
class EnvironmentConfig:
    """Settings for the environment the agent runs in."""

    max_steps: int = 100

    # Accepted values are "reconstruction" and "generation"; anything else is
    # reported as an error by MotifAgentConfig.validate().
    mode: str = "reconstruction"

    # Validation switches; how they are applied is not visible in this module.
    chemical_validation: bool = True
    topology_validation: bool = True


@dataclass
class InferenceConfig:
    """Settings used at inference time.

    The consumer of these values is outside this module; field meanings are
    inferred from their names.
    """

    temperature: float = 0.1  # presumably a sampling temperature -- confirm
    beam_width: int = 1
    max_attempts: int = 5
    diversity_bonus: float = 0.1


@dataclass
class DataConfig:
    """Dataset locations, column names and split ratios."""

    # --- File locations (all optional; None means "not configured") ---
    train_data_path: Optional[str] = None
    val_data_path: Optional[str] = None
    test_data_path: Optional[str] = None
    motif_library_path: Optional[str] = None

    # --- Column names in the dataset ---
    smiles_column: str = "smiles"
    # Starts as an empty list that is private to each instance.
    properties_columns: List[str] = field(default_factory=list)

    # --- Split ratios; MotifAgentConfig.validate() expects them to sum to 1.0 ---
    train_ratio: float = 0.8
    val_ratio: float = 0.1
    test_ratio: float = 0.1

    # Optional cap on dataset size; None presumably means "no cap" -- confirm.
    max_molecules_per_dataset: Optional[int] = None


@dataclass
class LoggingConfig:
    """Logging destinations and how often to save/evaluate.

    The units of the two frequency fields are not defined in this module.
    """

    log_level: str = "INFO"
    log_dir: str = "logs"
    tensorboard_dir: str = "tensorboard"

    # Periodic actions (presumably counted in training iterations -- confirm).
    save_frequency: int = 100
    eval_frequency: int = 50


@dataclass
class SystemConfig:
    """Runtime/system-level settings."""

    # "auto", "cpu" or "cuda".  MotifAgentConfig.get_device() resolves "auto"
    # to CUDA when available and to CPU otherwise.
    device: str = "auto"

    num_workers: int = 4
    seed: int = 42
    deterministic: bool = False


@dataclass
class MotifAgentConfig:
    """Top-level configuration grouping every per-area config section.

    Each field holds one section dataclass and defaults to that section's own
    defaults.  The class also provides dict/JSON round-tripping, in-place
    updates, validation and device resolution.
    """

    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    rewards: RewardConfig = field(default_factory=RewardConfig)
    environment: EnvironmentConfig = field(default_factory=EnvironmentConfig)
    inference: InferenceConfig = field(default_factory=InferenceConfig)
    data: DataConfig = field(default_factory=DataConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)
    system: SystemConfig = field(default_factory=SystemConfig)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> 'MotifAgentConfig':
        """Build a config from a nested ``{section: {field: value}}`` mapping.

        Sections missing from ``config_dict`` use their defaults.  Section
        values are deep-copied so the new config never shares mutable state
        (such as the reward weight dicts) with the caller's mapping.

        Args:
            config_dict: Mapping of section name to that section's keyword
                arguments.  Top-level keys that are not sections are ignored.

        Returns:
            A new ``MotifAgentConfig``.

        Raises:
            TypeError: If a section contains a key that is not a field of the
                corresponding section dataclass.
        """
        import copy

        section_types = {
            'model': ModelConfig,
            'training': TrainingConfig,
            'rewards': RewardConfig,
            'environment': EnvironmentConfig,
            'inference': InferenceConfig,
            'data': DataConfig,
            'logging': LoggingConfig,
            'system': SystemConfig,
        }
        sections = {
            name: section_cls(**copy.deepcopy(config_dict.get(name, {})))
            for name, section_cls in section_types.items()
        }
        return cls(**sections)

    def to_dict(self) -> Dict[str, Any]:
        """Return the configuration as a nested plain dict.

        The result is an independent copy: ``asdict`` recurses into the
        section dataclasses and copies nested dicts/lists, so mutating the
        returned structure never alters this config.

        Returns:
            ``{section_name: {field_name: value}}`` for every section, in
            field declaration order.
        """
        from dataclasses import asdict

        return asdict(self)

    def save(self, file_path: str) -> None:
        """Write the configuration to ``file_path`` as indented JSON."""
        import json

        # Explicit encoding keeps the file independent of the platform locale.
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load(cls, file_path: str) -> 'MotifAgentConfig':
        """Read a JSON file written by ``save`` and build a config from it.

        Raises:
            TypeError: If the file contains a key that is not a field of the
                corresponding section (propagated from ``from_dict``).
        """
        import json

        with open(file_path, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)
        return cls.from_dict(config_dict)

    def update_from_dict(self, updates: Dict[str, Any]) -> None:
        """Apply ``{section: {field: value}}`` updates in place.

        Only declared dataclass fields can be targeted.  Unknown sections and
        unknown keys are silently ignored, so methods or dunder attributes can
        never be reached or overwritten through this call.

        Args:
            updates: Mapping of section name to a mapping of field updates.
        """
        from dataclasses import fields

        section_names = {spec.name for spec in fields(self)}
        for section, section_updates in updates.items():
            if section not in section_names:
                continue
            section_obj = getattr(self, section)
            allowed_keys = {spec.name for spec in fields(section_obj)}
            for key, value in section_updates.items():
                if key in allowed_keys:
                    setattr(section_obj, key, value)

    def validate(self) -> List[str]:
        """Check the configuration and return human-readable error messages.

        Returns:
            A list of error strings; an empty list means no problem was found
            by the checks below.
        """
        errors = []

        # Model validation
        if self.model.max_motifs <= 0:
            errors.append("max_motifs must be positive")

        if self.model.hidden_dim <= 0:
            errors.append("hidden_dim must be positive")

        # Training validation
        if self.training.learning_rate <= 0:
            errors.append("learning_rate must be positive")

        if self.training.clip_epsilon <= 0 or self.training.clip_epsilon > 1:
            errors.append("clip_epsilon must be in (0, 1]")

        if self.training.gamma < 0 or self.training.gamma > 1:
            errors.append("gamma must be in [0, 1]")

        # Reward validation (tolerance absorbs float rounding in the sum)
        if abs(self.rewards.chemical_weight + self.rewards.topological_weight - 1.0) > 1e-6:
            errors.append("chemical_weight + topological_weight should sum to 1.0")

        # Data validation (same tolerance as above)
        if abs(self.data.train_ratio + self.data.val_ratio + self.data.test_ratio - 1.0) > 1e-6:
            errors.append("Data split ratios must sum to 1.0")

        # Environment validation
        if self.environment.mode not in ["reconstruction", "generation"]:
            errors.append("mode must be 'reconstruction' or 'generation'")

        return errors

    def get_device(self):
        """Resolve ``system.device`` to a ``torch.device``.

        ``"auto"`` selects CUDA when ``torch.cuda.is_available()`` and CPU
        otherwise; any other value is passed to ``torch.device`` unchanged.
        """
        # Imported lazily so the module can be used without torch installed.
        import torch

        if self.system.device == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(self.system.device)


def get_default_config() -> MotifAgentConfig:
    """Return a new ``MotifAgentConfig`` with every section at its defaults."""
    default_config = MotifAgentConfig()
    return default_config


def create_reconstruction_config() -> MotifAgentConfig:
    """Return the default config tuned for the "reconstruction" mode.

    Relative to the defaults this sets the environment mode, keeps the Set-BC
    weight at 1.0 and shifts the reward split to 0.3 chemical / 0.7
    topological.
    """
    config = get_default_config()

    # Reward split (the two weights still sum to 1.0).
    rewards = config.rewards
    rewards.chemical_weight = 0.3
    rewards.topological_weight = 0.7

    config.training.set_bc_weight = 1.0
    config.environment.mode = "reconstruction"
    return config


def create_generation_config() -> MotifAgentConfig:
    """Return the default config tuned for the "generation" mode.

    Relative to the defaults this lowers the Set-BC weight, raises the entropy
    coefficient, shifts the reward split to 0.6 chemical / 0.4 topological and
    raises the inference temperature and diversity bonus.
    """
    config = get_default_config()
    config.environment.mode = "generation"

    training = config.training
    training.set_bc_weight = 0.2
    training.entropy_coef = 0.05  # higher than the default, for more exploration

    # Reward split (the two weights still sum to 1.0).
    rewards = config.rewards
    rewards.chemical_weight = 0.6
    rewards.topological_weight = 0.4

    inference = config.inference
    inference.temperature = 0.5
    inference.diversity_bonus = 0.2
    return config


def create_fast_training_config() -> MotifAgentConfig:
    """Return the default config with smaller training and length settings.

    Uses 500 iterations, 5 episodes per iteration, 2 PPO epochs and a model
    ``max_length`` of 512.
    """
    config = get_default_config()
    config.model.max_length = 512

    training = config.training
    training.num_iterations = 500
    training.num_episodes_per_iteration = 5
    training.ppo_epochs = 2
    return config