"""
Wandb logging utilities for GRPO training with system dict and reward components.

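This module provides wandb logging helpers for:
- System dict visualization (pretty-printed JSON per sample)
- Reward component curves (mean reward per component)
- Graph visualizations and network plots (base64 images logged as wandb.Image)
- Formatted reward strings
It also provides a no-op TrainerCallback factory kept for backward compatibility.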
"""

import json
import copy
from typing import Dict, List, Optional, Any, Tuple
import pandas as pd
import wandb
from transformers import TrainerCallback
from utils.logging_utils import create_wandb_logs


class WandbLogger:
    """Log GRPO training steps to wandb: a completions table plus reward-component curves.

    Per-step data (graph visualizations, network plots, component rewards,
    formatted reward strings, system dicts, component weights) is pulled from
    optional zero-argument getter methods on the reward functions. A getter that
    is missing, returns nothing, or raises is skipped.
    """

    def __init__(self):
        # NOTE(review): not read anywhere in this class; kept for external callers.
        self.last_component_rewards = {}

    def log_training_step(
        self,
        step: int,
        prompts: List[Dict],
        completions: List[Dict],
        rewards: List[float],
        reflections: Optional[List[str]] = None,
        gen_globals_list: Optional[List[Dict]] = None,
        reward_funcs: Optional[List] = None,
        is_eval_step: bool = False,
        answers: Optional[List[str]] = None
    ) -> None:
        """
        Log a training step to wandb with system dict and reward components.

        Does nothing when there is no active wandb run or no reward functions.
        Nothing is logged at step 0.

        Args:
            step: Training step number
            prompts: List of prompt dictionaries
            completions: List of completion dictionaries
            rewards: List of total rewards (not used by this method; the table
                shows the formatted reward strings from the reward functions)
            reflections: Optional reflection strings
            gen_globals_list: List of generated globals containing system_dict
            reward_funcs: List of reward functions to extract data from
            is_eval_step: Whether this is an evaluation step (metrics are
                prefixed with "eval_" instead of "train_")
            answers: Optional reference answers, shown in an "answer" column
        """
        if not (wandb.run and reward_funcs):
            return

        # Convert to string format for logging
        prompts_strs, completions_strs, reflections_strs = create_wandb_logs(
            prompts, completions, reflections
        )

        (graph_viz, network_plots, component_rewards,
         formatted_rewards, system_dicts, power_system_weights) = self._extract_reward_data(
            reward_funcs, gen_globals_list
        )

        self._log_component_curves(component_rewards, power_system_weights, step, is_eval_step)

        # Skip the table at step 0 to avoid meaningless initial points; the table
        # data is only built when it will actually be logged.
        if step > 0:
            table_data = self._create_table_data(
                step, prompts_strs, completions_strs, reflections_strs,
                system_dicts, formatted_rewards, graph_viz, network_plots, answers
            )
            df = pd.DataFrame(table_data)
            wandb.log({"completions": wandb.Table(dataframe=df), "step": step})

    def _extract_reward_data(
        self, reward_funcs: List, gen_globals_list: Optional[List[Dict]] = None
    ) -> Tuple[List, List, Dict, List, List, Dict]:
        """Collect loggable data from all reward functions.

        For every field except component rewards, the last reward function that
        returns a non-empty value wins. Component rewards from all functions are
        merged into one dict.

        Returns:
            (graph_visualizations, network_plots, component_rewards,
             formatted_reward_strings, system_dicts, power_system_weights)
        """
        graph_visualizations = []
        network_plots = []
        component_rewards = {}
        formatted_reward_strings = []
        system_dicts = []
        power_system_weights = {}

        for reward_func in reward_funcs:
            viz = self._get_method_result(reward_func, 'get_last_graph_visualizations')
            if viz:
                graph_visualizations = viz

            plots = self._get_method_result(reward_func, 'get_last_network_plots')
            if plots:
                network_plots = plots

            comp_rewards = self._get_method_result(reward_func, 'get_component_rewards_summary')
            if comp_rewards:
                component_rewards.update(comp_rewards)

            reward_strings = self._get_method_result(reward_func, 'get_formatted_reward_strings')
            if reward_strings:
                formatted_reward_strings = reward_strings

            # Only overwrite when something was found, so a later reward function
            # without components does not wipe dicts found by an earlier one.
            dicts = self._extract_system_dicts(reward_func, gen_globals_list)
            if dicts:
                system_dicts = dicts

            weights = self._get_method_result(reward_func, 'get_power_system_weights')
            if weights:
                power_system_weights = weights

        return (graph_visualizations, network_plots, component_rewards,
                formatted_reward_strings, system_dicts, power_system_weights)

    @staticmethod
    def _has_content(result: Any) -> bool:
        """Return True if ``result`` is truthy and, for lists, has a non-None entry."""
        if not result:
            return False
        if isinstance(result, list):
            return any(item is not None for item in result)
        return True

    def _get_method_result(self, reward_func, method_name: str):
        """Best-effort call of a zero-argument getter on a reward function.

        The getter is looked up first on ``reward_func`` itself, then on
        ``reward_func.__self__`` (reward functions passed as bound methods).
        Exceptions raised by a getter are printed and treated as "no result", so
        logging cannot abort a training step.

        Returns:
            The first result that passes ``_has_content``, or None.
        """
        owners = [reward_func]
        if hasattr(reward_func, '__self__'):
            owners.append(reward_func.__self__)

        for owner in owners:
            getter = getattr(owner, method_name, None)
            if getter is None:
                continue
            try:
                result = getter()
            except Exception as e:
                print(f"Error calling {method_name} on reward function: {e}")
                continue
            if self._has_content(result):
                return result

        return None

    @staticmethod
    def _format_system_dict(system_dict: Any) -> str:
        """Pretty-print ``system_dict`` as sorted JSON, falling back to ``str()``."""
        try:
            return json.dumps(system_dict, indent=2, sort_keys=True)
        except (TypeError, ValueError):
            # Non-serializable values, unsortable mixed-type keys, or cycles.
            return str(system_dict)

    def _extract_system_dicts(self, reward_func, gen_globals_list: Optional[List[Dict]] = None) -> List[str]:
        """Extract system_dict strings from gen_globals_list or the reward function.

        If ``gen_globals_list`` is non-empty, one entry is produced per element.
        Missing, empty, or ``None`` entries get a ``no_system_dict_found``
        placeholder. Otherwise ``details['system_dict']`` is read from each
        component returned by ``get_last_reward_components``. A
        ``not_available`` placeholder is used when it cannot be read.
        """
        if gen_globals_list:
            return [
                self._format_system_dict(
                    (gen_globals or {}).get('system_dict') or {"status": "no_system_dict_found"}
                )
                for gen_globals in gen_globals_list
            ]

        components = self._get_method_result(reward_func, 'get_last_reward_components')
        if not components:
            return []

        system_dicts = []
        for component in components:
            details = getattr(component, 'details', None)
            try:
                system_dict = details['system_dict']
            except (TypeError, KeyError, IndexError):
                # No details, details not subscriptable by key, or key missing.
                system_dict = {"status": "not_available"}
            system_dicts.append(self._format_system_dict(system_dict))

        return system_dicts

    @staticmethod
    def _fit_column(values: Optional[List], length: int, fill: Any = "N/A") -> List:
        """Pad with ``fill`` or truncate ``values`` to exactly ``length`` entries."""
        column = list(values or [])
        return (column + [fill] * (length - len(column)))[:length]

    def _create_table_data(
        self,
        step: int,
        prompts_strs: List[str],
        completions_strs: List[str],
        reflections_strs: Optional[List[str]],
        system_dicts: List[str],
        formatted_rewards: List[str],
        graph_viz: List,
        network_plots: List,
        answers: Optional[List[str]] = None
    ) -> Dict[str, List]:
        """Build the column dict for the wandb completions table.

        The number of prompts fixes the row count. Optional or derived columns
        are padded with "N/A" (images with None) or truncated to that length,
        because ``pd.DataFrame`` rejects columns of differing lengths.
        """
        n_rows = len(prompts_strs)
        table_data = {
            "step": [str(step)] * n_rows,
            "prompt": prompts_strs,
            "completion": completions_strs,
            "system_dict": self._fit_column(system_dicts, n_rows),
            "reward": self._fit_column(formatted_rewards, n_rows),
        }

        if answers:
            table_data["answer"] = self._fit_column(answers, n_rows)

        if reflections_strs:
            table_data["reflections"] = self._fit_column(reflections_strs, n_rows)

        if graph_viz:
            table_data["graph_visualization"] = self._process_images(graph_viz, n_rows)

        if network_plots:
            table_data["network_plot"] = self._process_images(network_plots, n_rows)

        return table_data

    def _process_images(self, images: List, expected_length: int) -> List:
        """Decode base64-encoded images into ``wandb.Image`` objects.

        Falsy entries and entries that fail to decode become None (decode errors
        are printed). The result is padded with None or truncated to
        ``expected_length``.
        """
        import base64
        import io

        wandb_images = []
        for img in images:
            if not img:
                wandb_images.append(None)
                continue
            try:
                # Imported lazily so PIL is only required when images are present.
                from PIL import Image

                pil_image = Image.open(io.BytesIO(base64.b64decode(img)))
                wandb_images.append(wandb.Image(pil_image))
            except Exception as e:
                print(f"Error processing image: {e}")
                wandb_images.append(None)

        return self._fit_column(wandb_images, expected_length, fill=None)

    @staticmethod
    def _component_metrics(
        component_rewards: Dict[str, List[float]],
        power_system_weights: Optional[Dict[str, float]],
        prefix: str
    ) -> Dict[str, float]:
        """Average each reward component into a ``{prefix}reward_components/<name>`` metric.

        ``total`` is always included. Other components are included only when
        their weight is non-zero, so negatively weighted (penalty) components are
        kept. Components with no recorded values are skipped.
        """
        weights = power_system_weights or {}
        metrics = {}
        for component_name, values in component_rewards.items():
            if not values:
                continue
            if component_name == 'total' or weights.get(component_name, 0.0) != 0:
                metrics[f"{prefix}reward_components/{component_name}"] = sum(values) / len(values)
        return metrics

    def _log_component_curves(
        self,
        component_rewards: Dict[str, List[float]],
        power_system_weights: Dict[str, float],
        step: int,
        is_eval_step: bool
    ) -> None:
        """Log mean reward-component values to wandb (skipped at step 0)."""
        # Skip step 0 to avoid meaningless (0, 0) points.
        if not component_rewards or step == 0:
            return

        prefix = "eval_" if is_eval_step else "train_"
        metrics = self._component_metrics(component_rewards, power_system_weights, prefix)
        if metrics:
            # Include the step, as the completions-table log does.
            metrics["step"] = step
            wandb.log(metrics)


class SimpleCallback(TrainerCallback):
    """No-op ``TrainerCallback``.

    Overrides none of the callback hooks, so every event is handled by the
    base-class defaults. It exists so code that expects a callback object
    still receives one.
    """


def create_script_upload_callback(script_path=None, additional_files=None):
    """
    Return a no-op callback in place of a script-upload callback.

    ``script_path`` and ``additional_files`` are accepted only so existing
    callers keep working; both are ignored. A fresh ``SimpleCallback`` is
    returned on every call.
    """
    del script_path, additional_files  # accepted for backward compatibility only
    no_op_callback = SimpleCallback()
    return no_op_callback

 