#!/usr/bin/env python3
"""
TensorBoard-based wandb replacement.
Usage: import tensorboard_wandb as wandb
"""

import glob
import json
import numbers
import os
import shutil
from collections.abc import Mapping
from typing import Dict, Any, Optional, Union

import torch
from torch.utils.tensorboard import SummaryWriter


class Config:
    """Mock of ``wandb.config``: a dictionary of run settings persisted to JSON.

    Settings can be read as items (``config["lr"]``), with ``get``, or as
    attributes (``config.lr``), and written as items or attributes. After every
    change the whole config is written to ``<parent._run_dir>/config.json``
    once the parent has a run directory; before that it lives only in memory.
    """

    def __init__(self, parent=None):
        # Underscore names bypass the config-key routing in __setattr__.
        self._config = {}
        # Owner object; its ``_run_dir`` attribute decides where config.json goes.
        self._parent = parent

    def _save_config(self):
        """Persist the config as ``config.json`` in the parent's run directory.

        Does nothing until the parent has a truthy ``_run_dir``. The JSON is
        written to ``config.json.tmp`` and moved into place with ``os.replace``,
        so an interrupted write does not leave a truncated ``config.json``.
        Values that are not JSON-serializable are stored via ``str``. Saving is
        best-effort: failures are printed as a warning, not raised.
        """
        if not (self._parent and self._parent._run_dir):
            return
        config_path = os.path.join(self._parent._run_dir, "config.json")
        tmp_path = config_path + ".tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.to_dict(), f, indent=2, default=str)
            os.replace(tmp_path, config_path)
        except Exception as e:
            print(f"Warning: Could not save config: {e}")

    def update(self, config_dict: Union[Dict[str, Any], object], allow_val_change: bool = True):
        """Merge settings into the config and persist them.

        Args:
            config_dict: A mapping (including dict subclasses such as
                ``OrderedDict``), another ``Config``, an object whose public
                attributes become keys (e.g. a dataclass instance or
                ``argparse.Namespace``; names starting with ``_`` are skipped),
                or an iterable of ``(key, value)`` pairs.
            allow_val_change: When False, keys that already exist keep their
                current value and only new keys are added.
        """
        if isinstance(config_dict, Config):
            new_items = config_dict.to_dict()
        elif isinstance(config_dict, Mapping):
            # Checked before ``__dict__``: dict subclasses (e.g. OrderedDict)
            # also have an instance ``__dict__`` that does not hold their items.
            new_items = dict(config_dict)
        elif hasattr(config_dict, '__dict__'):
            new_items = {k: v for k, v in config_dict.__dict__.items()
                         if not k.startswith('_')}
        else:
            new_items = dict(config_dict)

        if allow_val_change:
            self._config.update(new_items)
        else:
            for key, value in new_items.items():
                self._config.setdefault(key, value)

        self._save_config()

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. Private/dunder names
        # are never config keys; refusing them also avoids infinite recursion
        # when ``_config`` does not exist yet (e.g. during copy or unpickling).
        if name.startswith('_'):
            raise AttributeError(name)
        try:
            return self._config[name]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    def __setattr__(self, name, value):
        # ``config.lr = 0.1`` is routed to ``config["lr"] = 0.1`` (and saved);
        # private names stay ordinary instance attributes.
        if name.startswith('_'):
            object.__setattr__(self, name, value)
        else:
            self[name] = value

    def __getitem__(self, key):
        return self._config[key]

    def __setitem__(self, key, value):
        self._config[key] = value
        self._save_config()

    def __contains__(self, key):
        return key in self._config

    def __iter__(self):
        return iter(self._config)

    def keys(self):
        """Return a view of the config keys (enables ``dict(config)``)."""
        return self._config.keys()

    def items(self):
        """Return a view of the ``(key, value)`` pairs."""
        return self._config.items()

    def values(self):
        """Return a view of the config values."""
        return self._config.values()

    def get(self, key, default=None):
        """Return the value for ``key``, or ``default`` if it is missing."""
        return self._config.get(key, default)

    def to_dict(self):
        """Return a shallow copy of the settings as a plain dict."""
        return self._config.copy()

    def __repr__(self):
        return f"{type(self).__name__}({self._config!r})"


class TensorBoardWandb:
    """TensorBoard-based wandb replacement.

    Implements the subset of the ``wandb`` API used by typical training
    scripts. Scalars, images, histograms and text go to a TensorBoard
    ``SummaryWriter``. The run config is kept in ``self.config``, which saves
    itself as ``config.json`` in the run directory.
    """

    class _MockRun:
        """Read-only view of the owner's run metadata, returned by ``run``."""

        def __init__(self, parent):
            self.parent = parent

        @property
        def name(self):
            return self.parent._name

        @property
        def id(self):
            return self.parent._id

        @property
        def project(self):
            return self.parent._project

        @property
        def dir(self):
            return self.parent._run_dir

        @property
        def config(self):
            return self.parent.config

    def __init__(self):
        self.writer: Optional[SummaryWriter] = None
        self.config = Config(parent=self)  # Config reads self._run_dir to save config.json
        self._run_dir = None
        self._project = None
        self._name = None
        self._id = None
        self._step = 0  # step used by the next log() call that has no explicit step
        self._hparams_logged = False  # config text already written for the current run

    # Run metadata, so the object returned by init() can be used like a wandb Run.
    @property
    def name(self):
        return self._name

    @property
    def id(self):
        return self._id

    @property
    def project(self):
        return self._project

    @property
    def dir(self):
        return self._run_dir

    def init(self,
             project: str = "default_project",
             name: Optional[str] = None,
             id: Optional[str] = None,
             resume: Union[str, bool, None] = "allow",
             mode: str = "online",
             dir: Optional[str] = None,
             **kwargs):
        """Start a run: create the run directory and a TensorBoard writer.

        The run directory is ``<dir or "runs">/<project>/<name>`` or, when
        ``id`` is given, ``<dir or "runs">/<project>/<name>-<id>``.

        Args:
            project: Project sub-directory.
            name: Run name; defaults to ``"default_run"``.
            id: Optional run id; defaults to ``"default_id"`` for ``self.id``.
            resume: Unless ``None``, ``False`` or ``"never"``, an existing
                ``config.json`` in the run directory is merged into ``config``.
            mode: Accepted for wandb compatibility; not used.
            dir: Base directory; defaults to ``"runs"``.
            **kwargs: Other wandb arguments. ``config`` (anything
                ``Config.update`` accepts) is merged into ``self.config`` and
                saved; everything else is ignored.

        Returns:
            self, so ``run = wandb.init(...)`` exposes ``run.name``/``run.dir``/...
        """
        # A previous init() without finish() would otherwise leak its writer.
        if self.writer is not None:
            self.writer.close()
            self.writer = None

        self._project = project
        self._name = name or "default_run"
        self._id = id or "default_id"
        # Each run counts steps from zero and logs its own config text.
        self._step = 0
        self._hparams_logged = False

        base_dir = os.path.join("runs" if dir is None else dir, project)
        run_leaf = f"{self._name}-{id}" if id else self._name
        self._run_dir = os.path.join(base_dir, run_leaf)
        os.makedirs(self._run_dir, exist_ok=True)

        self.writer = SummaryWriter(log_dir=self._run_dir)

        config_path = os.path.join(self._run_dir, "config.json")
        if resume not in (None, False, "never") and os.path.exists(config_path):
            try:
                with open(config_path, 'r', encoding='utf-8') as f:
                    existing_config = json.load(f)
            except (OSError, ValueError) as e:
                print(f"Warning: Could not load existing config: {e}")
            else:
                if isinstance(existing_config, dict):
                    self.config._config.update(existing_config)

        # wandb.init(config=...) is the usual way to pass hyperparameters;
        # applied after the resumed config so explicit values win.
        init_config = kwargs.get("config")
        if init_config is not None:
            self.config.update(init_config)

        print(f"TensorBoard logging initialized at: {self._run_dir}")
        return self

    def _log_hparams_once(self):
        """Write the current config as a markdown text summary at step 0, once per run.

        Text is used instead of ``add_hparams`` because ``add_hparams`` creates
        separate event files. Failures are printed as a warning, not raised.
        """
        if self._hparams_logged or self.writer is None:
            return
        try:
            config_text = "## Configuration\n\n" + "\n\n".join(
                f"**{key}**: {value}" for key, value in self.config.to_dict().items()
            )
            self.writer.add_text("config/hyperparameters", config_text, 0)
            self._hparams_logged = True
            print("Logged configuration to TensorBoard as text")
        except Exception as e:
            print(f"Warning: Could not log config to TensorBoard: {e}")

    def log(self, data: Dict[str, Any], step: Optional[int] = None, commit: bool = True):
        """Log a dict of values to TensorBoard.

        Each value is dispatched by type:

        * Real numbers (Python or numpy scalars) and one-element tensors are
          logged as scalars.
        * 3-D tensors are logged as an image and 4-D tensors as an image batch.
        * Other tensors and list-like values are logged as histograms, or as a
          scalar if they have one element.
        * List-like values that cannot become a tensor are logged as a
          ``<key>_length`` scalar.
        * Anything else is logged as text via ``str(value)``.

        Args:
            data: Mapping of tag -> value.
            step: Explicit global step. When omitted, the internal counter is used.
            commit: When True (the default), the internal counter advances and
                the writer is flushed. With ``commit=False`` the next ``log``
                call without an explicit step reuses the same step, as in wandb.

        Raises:
            RuntimeError: If ``init()`` has not been called, or after ``finish()``.
        """
        if self.writer is None:
            raise RuntimeError("Must call wandb.init() first")

        self._log_hparams_once()

        if step is None:
            step = self._step
            if commit:
                self._step += 1

        for key, value in data.items():
            self._log_value(key, value, step)

        if commit:
            self.writer.flush()

    def _log_value(self, key: str, value: Any, step: int) -> None:
        """Write one value under ``key`` with the writer call matching its type."""
        writer = self.writer
        if isinstance(value, numbers.Real):
            # numbers.Real also covers numpy scalar types such as np.float32,
            # which are not subclasses of Python int/float.
            writer.add_scalar(key, value, step)
        elif isinstance(value, torch.Tensor):
            if value.numel() == 1:
                writer.add_scalar(key, value.item(), step)
            elif value.dim() == 3:  # assumed (C, H, W) image
                writer.add_image(key, value, step)
            elif value.dim() == 4:  # assumed (B, C, H, W) image batch
                writer.add_images(key, value, step)
            else:
                writer.add_histogram(key, value, step)
        elif hasattr(value, '__len__') and not isinstance(value, str):
            try:
                tensor_value = torch.tensor(value)
                if tensor_value.numel() > 1:
                    writer.add_histogram(key, tensor_value, step)
                else:
                    writer.add_scalar(key, tensor_value.item(), step)
            except Exception:
                # Not convertible to a tensor (or empty): record the length instead.
                writer.add_scalar(f"{key}_length", len(value), step)
        else:
            try:
                writer.add_text(key, str(value), step)
            except Exception as e:
                print(f"Warning: Could not log {key!r} to TensorBoard: {e}")

    def watch(self, model, log_freq: int = 1000, log_graph: bool = True, **kwargs):
        """Stand-in for ``wandb.watch`` that does nothing and returns ``model``.

        Extra wandb keyword arguments (``criterion``, ``log``, ``idx``, ...)
        are accepted and ignored. Real model watching would need sample input
        for TensorBoard's graph tracing.
        """
        print("TensorBoard wandb mock: watch() called (no-op)")
        return model

    def save(self, filename: str, **kwargs):
        """Copy ``filename`` into the run directory.

        If no file exists at ``filename``, it is treated as a glob pattern and
        every matching file is copied. Files that are already inside the run
        directory are left as they are. Extra wandb arguments (``base_path``,
        ``policy``) are accepted and ignored.

        Raises:
            RuntimeError: If ``init()`` has not been called.
        """
        if self._run_dir is None:
            raise RuntimeError("Must call wandb.init() first")

        sources = [filename] if os.path.exists(filename) else glob.glob(filename)
        if not sources:
            print(f"Warning: wandb.save() found no file matching {filename}")
            return

        for src in sources:
            if not os.path.isfile(src):
                print(f"Warning: wandb.save() skipping non-file {src}")
                continue
            dest_path = os.path.join(self._run_dir, os.path.basename(src))
            try:
                shutil.copy2(src, dest_path)
            except shutil.SameFileError:
                # Common wandb pattern: write into run.dir, then call save().
                continue
            print(f"Saved {src} to {dest_path}")

    def finish(self, **kwargs):
        """Close the TensorBoard writer. Extra wandb arguments (``exit_code``, ``quiet``) are ignored.

        Config changes are saved as they happen, so nothing else is written here.
        """
        if self.writer is not None:
            self.writer.close()
            self.writer = None
            print(f"TensorBoard logging finished. View with: tensorboard --logdir={self._run_dir}")

    def alert(self, title: str, text: str, level: str = "INFO", **kwargs):
        """Log an alert as text under ``alert/<level>`` at the next step.

        Does nothing before ``init()``. Extra wandb arguments are ignored.
        """
        if self.writer is not None:
            self.writer.add_text(f"alert/{level}", f"{title}: {text}", self._step)

    @property
    def run(self):
        """Mock run object exposing ``name``, ``id``, ``project``, ``dir`` and ``config``."""
        return self._MockRun(self)


# Process-wide singleton behind the module-level API, so that
# ``import tensorboard_wandb as wandb`` can be used like the real package.
_wandb_instance = TensorBoardWandb()

# Re-export the singleton's bound methods under wandb's function names.
(
    init,
    log,
    watch,
    save,
    finish,
    alert,
) = (
    _wandb_instance.init,
    _wandb_instance.log,
    _wandb_instance.watch,
    _wandb_instance.save,
    _wandb_instance.finish,
    _wandb_instance.alert,
)

# Shared config object (the same instance the singleton saves to disk).
config = _wandb_instance.config
# Evaluated once at import: a mock run object whose attributes read the
# singleton's current run metadata on every access.
run = _wandb_instance.run

# Additional aliases for compatibility
def login(key: Optional[str] = None, **kwargs):
    """Stand-in for ``wandb.login``; TensorBoard needs no credentials.

    Every argument is accepted and ignored.

    Returns:
        Always ``True``, so ``if wandb.login(): ...`` checks succeed.
    """
    notice = "TensorBoard wandb mock: login() called (no-op)"
    print(notice)
    return True

def sweep(sweep_config: Dict[str, Any], project: Optional[str] = None, **kwargs):
    """Stand-in for ``wandb.sweep``: registers nothing and returns a fixed id.

    Extra wandb arguments such as ``entity`` or ``prior_runs`` are accepted
    and ignored, so existing call sites keep working.

    Returns:
        The constant sweep id ``"mock_sweep_id"``.
    """
    print("TensorBoard wandb mock: sweep() called (not implemented)")
    return "mock_sweep_id"

def agent(sweep_id: str, function=None, count: Optional[int] = None, **kwargs):
    """Stand-in for ``wandb.agent``: calls ``function`` once, if one is given.

    No hyperparameter search is performed, so ``count`` is accepted but not
    used. Extra wandb arguments (``entity``, ``project``) are ignored.
    """
    print("TensorBoard wandb mock: agent() called (not implemented)")
    if function is not None:
        function()

# Lightweight stand-ins for wandb data/media types. They only package their
# arguments into plain dicts; extra wandb keyword arguments (e.g. ``fps`` for
# Video, ``mode`` for Image) are kept in the dict instead of raising TypeError.


class Table(dict):
    """Stand-in for ``wandb.Table``: a dict of the constructor arguments.

    ``Table(columns=[...], data=[...])`` behaves like ``dict(...)``.
    ``add_data(*row)`` appends ``list(row)`` to the ``"data"`` entry.
    """

    def add_data(self, *row):
        """Append one row to ``self["data"]``, creating or listifying it as needed."""
        rows = self.get("data")
        if not isinstance(rows, list):
            rows = [] if rows is None else list(rows)
            self["data"] = rows
        rows.append(list(row))


def Image(x, caption=None, **kwargs):
    """Stand-in for ``wandb.Image``; returns ``{"image": x, "caption": caption, **kwargs}``."""
    return {"image": x, "caption": caption, **kwargs}


def Audio(x, caption=None, sample_rate=None, **kwargs):
    """Stand-in for ``wandb.Audio``; the sample rate is kept under ``"sample_rate"``."""
    return {"audio": x, "caption": caption, "sample_rate": sample_rate, **kwargs}


def Video(x, caption=None, **kwargs):
    """Stand-in for ``wandb.Video``; returns ``{"video": x, "caption": caption, **kwargs}``."""
    return {"video": x, "caption": caption, **kwargs}