"""FPVR configuration (Future-Past Visitation Redundancy).

Usage:
    from visual_minigrid_maze.config import Config
    config = Config()          # default config
    config.beta = 2.0          # override a parameter

Or create from a dict:
    config = Config.from_dict({'beta': 2.0, 'steps': 5000})
"""

import json
from dataclasses import asdict, dataclass, field
from typing import Literal, Optional


@dataclass
class Config:
    """Global configuration for FPVR (Future-Past Visitation Redundancy).

    Every field has a default, so ``Config()`` is directly usable. The
    feature dimensions are resolved in ``__post_init__``:

    * ``phi_dim`` left at its default (400) while ``env_size`` differs from
      its default (20) is rescaled to ``env_size ** 2``.
    * ``psi_dim`` follows the final ``phi_dim`` when it is ``None`` or was
      equal to ``phi_dim`` on entry; any other ``psi_dim`` is kept as given.
    """

    # ========== Environment ==========
    env_size: int = 20                 # environment size
    env_seed: int = 123                # environment RNG seed

    # ========== Training ==========
    steps: int = 9000                  # total environment steps
    k_train: int = 5                   # number of training updates per env step
    reset_interval: int = 3000         # reset windowed coverage stats every K steps (0 disables)
    print_interval: int = 100          # print training loss every K steps (0 disables)

    # ========== Model ==========
    phi_dim: int = 400                 # φ feature dimension (default = env_size^2)
    psi_dim: Optional[int] = None      # ψ feature dimension (None => equals phi_dim)
    lr: float = 0.001                  # learning rate

    # ========== FPVR ==========
    beta: float = 1.0                  # softmax temperature (larger => stronger preference for lower redundancy)
    sf_gamma: float = 0.9              # successor-feature discount factor
    lambda_c: float = 0.95             # decay factor for persistence vector c

    # ========== ZCA whitening ==========
    whitening_update_every: int = 100  # update whitening matrix every N training steps

    # ========== Replay buffer ==========
    capacity: int = 3000               # replay buffer capacity
    batch_size: int = 64               # batch size
    update_after: int = 1              # minimum buffer size before training starts
    update_every: int = 1              # train every N steps

    # ========== Algorithm options ==========
    sf_target: Literal["uniform_policy", "current_policy", "min_redundancy"] = "min_redundancy"  # SR target mode

    # ========== Visualization ==========
    visualize: bool = True             # whether to visualize in real time
    render_delay: float = 0.001        # render delay (seconds)

    def __post_init__(self) -> None:
        """Resolve ``phi_dim`` and ``psi_dim`` after field assignment."""
        # Decide up front whether psi_dim should track phi_dim: it does when
        # it was left unset or already mirrors phi_dim. A psi_dim that was
        # deliberately set to a different value must survive the rescaling
        # below instead of being silently overwritten.
        psi_follows_phi = self.psi_dim is None or self.psi_dim == self.phi_dim

        # phi_dim still at its default (400 = 20 * 20) with a non-default
        # env_size: rescale it to the number of grid cells (env_size^2).
        if self.phi_dim == 400 and self.env_size != 20:
            self.phi_dim = self.env_size * self.env_size

        if psi_follows_phi:
            self.psi_dim = self.phi_dim

    @classmethod
    def from_dict(cls, config_dict: dict) -> "Config":
        """Create a config from a Python dict.

        The dict entries are passed to the constructor as keyword arguments,
        so a key that is not a field name raises ``TypeError``.
        """
        return cls(**config_dict)

    @classmethod
    def from_json(cls, json_path: str) -> "Config":
        """Load a config from a UTF-8 JSON file at ``json_path``."""
        with open(json_path, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)
        return cls.from_dict(config_dict)

    def to_dict(self) -> dict:
        """Convert to a Python dict mapping field names to values."""
        return asdict(self)

    def to_json(self, json_path: str) -> None:
        """Save the config as an indented UTF-8 JSON file at ``json_path``."""
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

    def __repr__(self) -> str:
        """Return a multi-line listing with one ``name = value`` row per field."""
        lines = ["FPVR Configuration:"]
        lines.append("=" * 50)
        for key, value in asdict(self).items():
            lines.append(f"  {key:20s} = {value}")
        lines.append("=" * 50)
        return "\n".join(lines)


# Preset configurations
class Presets:
    """Factory for ready-made ``Config`` instances."""

    @staticmethod
    def default():
        """Return a ``Config`` with every field left at its default."""
        return Config()

    @staticmethod
    def fast_test():
        """Return a quick sanity-check preset.

        Uses fewer steps, a single training update per step, a smaller
        environment, and visualization turned off.
        """
        overrides = {
            "env_size": 10,
            "steps": 500,
            "k_train": 1,
            "visualize": False,
        }
        return Config(**overrides)

    @staticmethod
    def high_exploration():
        """Return a high-exploration preset (``beta`` and ``sf_gamma`` raised)."""
        overrides = {"beta": 2.0, "sf_gamma": 0.99}
        return Config(**overrides)


if __name__ == "__main__":
    # Demo: round-trip the default configuration through a JSON file.
    json_file = "config_default.json"

    config = Config()
    print(config)

    # Write the defaults to disk in the current working directory.
    config.to_json(json_file)
    print("\nSaved config to config_default.json")

    # Read them back and show the reloaded configuration.
    loaded_config = Config.from_json(json_file)
    print("\nLoaded config from JSON:", loaded_config, sep="\n")

