"""
Weights & Biases utilities for experiment tracking.
"""

import wandb
import os
import numpy as np
import torch
from typing import Dict, Any, Optional, List
import matplotlib.pyplot as plt
import seaborn as sns


class WandbLogger:
    """Centralized Weights & Biases logging utility.

    Wraps a single wandb run. Every logging method is a silent no-op while
    no run is active (before `init_run` or after `finish_run`). All calls go
    through the run handle owned by this instance rather than the
    module-level `wandb.log`/`wandb.config`/`wandb.finish`, so they cannot
    land on a run that was started elsewhere in the process.
    """

    def __init__(self, project_name: str, entity: str = "anonymous"):
        """Store the project/entity used by later `init_run` calls."""
        self.project_name = project_name
        self.entity = entity
        # Handle returned by wandb.init(); None while no run is active.
        self.run = None

    def init_run(self, config: Dict[str, Any], run_name: str, tags: Optional[List[str]] = None):
        """Initialize a new wandb run.

        Args:
            config: Configuration dictionary attached to the run.
            run_name: Display name of the run.
            tags: Optional list of tags; an empty list is used when omitted.

        Returns:
            The object returned by `wandb.init`, also stored on `self.run`.
        """
        self.run = wandb.init(
            project=self.project_name,
            entity=self.entity,
            name=run_name,
            config=config,
            tags=tags or [],
            reinit=True
        )
        return self.run

    def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None):
        """Log a metrics dictionary to this logger's run, if one is active."""
        if self.run is not None:
            self.run.log(metrics, step=step)

    def log_hyperparameters(self, params: Dict[str, Any]):
        """Merge `params` into the config of this logger's run."""
        if self.run is not None:
            self.run.config.update(params)

    def log_figure(self, figure, name: str, step: Optional[int] = None):
        """Log a matplotlib figure under key `name` as a wandb image."""
        if self.run is not None:
            self.run.log({name: wandb.Image(figure)}, step=step)

    def log_table(self, data: List[List], columns: List[str], name: str,
                  step: Optional[int] = None):
        """Log row data as a wandb table under key `name`.

        Args:
            data: Table rows, one inner list per row.
            columns: Column headers.
            name: Key the table is logged under.
            step: Optional explicit step, forwarded like in `log_metrics`.
        """
        if self.run is not None:
            table = wandb.Table(data=data, columns=columns)
            self.run.log({name: table}, step=step)

    def finish_run(self):
        """Finish this logger's run; no-op when no run is active."""
        if self.run is not None:
            try:
                self.run.finish()
            finally:
                # Drop the handle even if finish() raised, so later calls
                # do not try to log to a half-closed run.
                self.run = None


def init_wandb_run(project: str, config: Dict[str, Any], run_name: str, 
                   tags: Optional[List[str]] = None,
                   entity: str = "anonymous") -> "wandb.sdk.wandb_run.Run":
    """Initialize a wandb run with standard configuration.

    Args:
        project: Name of the wandb project.
        config: Configuration dictionary attached to the run.
        run_name: Display name of the run.
        tags: Optional list of tags; an empty list is used when omitted.
        entity: wandb entity to log under. Defaults to the value that was
            previously hard-coded, matching `WandbLogger`'s default.

    Returns:
        The run object returned by `wandb.init`.
    """
    return wandb.init(
        project=project,
        entity=entity,
        name=run_name,
        config=config,
        tags=tags or [],
        reinit=True
    )


def log_experiment_metadata(experiment_type: str, research_question: str, 
                          baseline_methods: List[str], environment: str,
                          framework: str = "DATE-GFN",
                          min_steps: int = 2000,
                          seeds: int = 8):
    """Log standardized experiment metadata to the active run's config.

    Writes through the module-level `wandb.config`, so a run must already
    have been initialized.

    Args:
        experiment_type: Recorded under ``experiment_type``.
        research_question: Recorded under ``research_question``.
        baseline_methods: Recorded under ``baseline_methods``.
        environment: Recorded under ``environment``.
        framework: Recorded under ``framework``.
        min_steps: Recorded under ``min_steps``.
        seeds: Recorded under ``seeds``.

    The last three were previously fixed constants; their defaults keep the
    old values so existing callers record the same metadata, while callers
    whose experiments differ can now record what was actually run.
    """
    metadata = {
        "experiment_type": experiment_type,
        "research_question": research_question,
        "baseline_methods": baseline_methods,
        "environment": environment,
        "framework": framework,
        "min_steps": min_steps,
        "seeds": seeds
    }
    wandb.config.update(metadata)


def create_comparison_plot(results: Dict[str, List[float]], 
                          metric_name: str, 
                          title: str = None) -> plt.Figure:
    """Draw a bar chart comparing methods on a single metric.

    Each key of `results` becomes one bar whose height is the mean of its
    values and whose error bar is their standard deviation. Bars whose
    method name contains "DATE" (case-insensitive) are recolored red and
    made fully opaque.

    Args:
        results: Mapping from method name to its list of scores.
        metric_name: Used as the y-axis label and in the default title.
        title: Optional title; defaults to "<metric_name> Comparison".

    Returns:
        The matplotlib figure holding the chart.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    labels = list(results.keys())
    heights = []
    spreads = []
    for label in labels:
        scores = results[label]
        heights.append(np.mean(scores))
    for label in labels:
        spreads.append(np.std(results[label]))

    # One pastel color per method, spread evenly across the Set3 colormap.
    palette = plt.cm.Set3(np.linspace(0, 1, len(labels)))
    bars = ax.bar(labels, heights, yerr=spreads, capsize=5, color=palette, alpha=0.8)

    # Make the DATE-GFN bar(s) stand out from the baselines.
    for position, label in enumerate(labels):
        if 'DATE' not in label.upper():
            continue
        highlighted = bars[position]
        highlighted.set_color('#d62728')
        highlighted.set_alpha(1.0)

    heading = title or f'{metric_name} Comparison'
    ax.set_ylabel(metric_name)
    ax.set_title(heading)
    ax.tick_params(axis='x', rotation=45)
    plt.tight_layout()

    return fig


def create_learning_curves(results: Dict[str, Dict[str, List[float]]], 
                          metric_name: str,
                          title: Optional[str] = None) -> plt.Figure:
    """Create learning curves plot.

    Args:
        results: Mapping from method name to a dict with a required
            ``'values'`` entry and an optional ``'steps'`` entry. When
            ``'values'`` is one-dimensional it is plotted directly; when it
            has more than one dimension, axis 0 is averaged out (the code
            treats it as the seed axis) and a +/- one standard deviation
            band is drawn around the mean.
        metric_name: Used as the y-axis label and in the default title.
        title: Optional title; defaults to "<metric_name> Learning Curves".

    Returns:
        The matplotlib figure holding the curves.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    # The highlight red is deliberately not part of the baseline palette,
    # so no baseline curve can share the DATE-GFN color.
    highlight_color = '#d62728'
    baseline_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#9467bd', '#8c564b']

    for i, (method, data) in enumerate(results.items()):
        values = np.asarray(data['values'])

        steps = data.get('steps')
        if steps is None:
            # The x-axis length is the LAST dimension. Using len(values)
            # would give the number of seeds for 2-D input and make the
            # plot call fail with a shape mismatch.
            steps = np.arange(values.shape[-1])

        # Same highlight rule for single- and multi-seed curves.
        highlighted = 'DATE' in method.upper()
        color = highlight_color if highlighted else baseline_colors[i % len(baseline_colors)]
        linewidth = 3 if highlighted else 2

        if values.ndim > 1:  # Multiple seeds
            mean_values = np.mean(values, axis=0)
            std_values = np.std(values, axis=0)

            ax.plot(steps, mean_values, label=method, color=color, linewidth=linewidth)
            ax.fill_between(steps, mean_values - std_values, mean_values + std_values, 
                           color=color, alpha=0.2)
        else:
            ax.plot(steps, values, label=method, color=color, linewidth=linewidth)

    ax.set_xlabel('Training Steps')
    ax.set_ylabel(metric_name)
    ax.set_title(title or f'{metric_name} Learning Curves')
    if results:
        # Skip the legend when nothing was plotted; matplotlib warns about
        # a legend with no labeled artists.
        ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()

    return fig


def create_ablation_heatmap(results: Dict[str, float], 
                           param_name: str,
                           metric_name: str,
                           title: Optional[str] = None) -> plt.Figure:
    """Create heatmap for ablation studies.

    Args:
        results: Flat mapping from key to value. Keys are either
            ``"<param>"`` (single-metric case) or ``"<param>_<metric>"``.
            A key is split at its FIRST underscore only, so metric names
            may themselves contain underscores (e.g. ``"0.1_val_loss"``).
        param_name: Used as the x-axis label and in the default title.
        metric_name: Row label / y-axis label in the single-metric case.
        title: Optional title; defaults to "<param_name> Ablation Study".

    Returns:
        The matplotlib figure holding the heatmap. Parameter/metric
        combinations missing from `results` are shown as 0.0.

    Raises:
        ValueError: If `results` is empty.
    """
    if not results:
        # Fail before a figure is allocated; an empty matrix cannot be
        # rendered as a heatmap anyway.
        raise ValueError("results must contain at least one entry")

    fig, ax = plt.subplots(figsize=(10, 8))

    # Convert results to matrix format. partition() keeps everything after
    # the first underscore as the metric name, so the rebuilt lookup key
    # below always matches the original key.
    split_keys = [key.partition('_') for key in results]
    params = sorted({param for param, _, _ in split_keys})
    metrics = sorted({metric for _, sep, metric in split_keys if sep})

    if not metrics:  # Single metric case
        metrics = [metric_name]
        data = np.array([[results.get(f"{param}", 0.0) for param in params]])
    else:
        data = np.array([[results.get(f"{param}_{metric}", 0.0) 
                         for param in params] for metric in metrics])

    sns.heatmap(data, xticklabels=params, yticklabels=metrics, 
                annot=True, fmt='.3f', cmap='RdYlBu_r', ax=ax)

    ax.set_xlabel(param_name)
    ax.set_ylabel('Metrics' if len(metrics) > 1 else metric_name)
    ax.set_title(title or f'{param_name} Ablation Study')
    plt.tight_layout()

    return fig


def log_research_question_summary(rq_number: int, 
                                 key_findings: Dict[str, Any],
                                 significance_tests: Dict[str, float],
                                 alpha: float = 0.05):
    """Log summary of research question results.

    Emits two log entries through the module-level `wandb.log`: first a
    dictionary with the raw findings, the p-values and a "completed"
    status, then a table with one row per finding.

    Args:
        rq_number: Research question number, used in the logged key names.
        key_findings: Mapping from finding name to its value.
        significance_tests: Mapping from finding name to p-value. Findings
            without an entry are shown with p-value "N/A".
        alpha: Significance threshold; a finding is marked significant
            when its p-value is strictly below it.
    """
    summary = {
        f"RQ{rq_number}_key_findings": key_findings,
        f"RQ{rq_number}_significance": significance_tests,
        f"RQ{rq_number}_status": "completed"
    }

    wandb.log(summary)

    # Create summary table
    findings_data = []
    for finding, value in key_findings.items():
        p_value = significance_tests.get(finding, "N/A")
        # Accept any real numeric p-value. np.float32 and plain ints are
        # not instances of `float` and would otherwise always be marked
        # as not significant. Booleans are excluded on purpose.
        is_numeric = (
            isinstance(p_value, (int, float, np.integer, np.floating))
            and not isinstance(p_value, bool)
        )
        significant = "✓" if is_numeric and p_value < alpha else "✗"
        findings_data.append([finding, str(value), str(p_value), significant])

    table = wandb.Table(
        data=findings_data,
        columns=["Finding", "Value", "P-value", "Significant"]
    )
    wandb.log({f"RQ{rq_number}_summary_table": table})


class ExperimentTracker:
    """High-level experiment tracking with automatic metric aggregation.

    Forwards every logged metric to a `WandbLogger` and additionally keeps
    the numeric ones in `metrics_history` (metric name -> list of floats)
    so that rolling and final summaries can be computed.
    """

    # Name fragments (matched case-insensitively) of metrics for which a
    # smaller value is better; every other metric is treated as
    # higher-is-better when picking the "best" value.
    _LOWER_IS_BETTER = ("error", "loss")

    def __init__(self, project: str, experiment_name: str):
        """Create a tracker that logs to `project` under `experiment_name`."""
        self.logger = WandbLogger(project)
        self.experiment_name = experiment_name
        self.metrics_history = {}
        # Step passed to the most recent log_step() call; None until then.
        self._last_step = None

    @staticmethod
    def _to_scalar(value: Any) -> Optional[float]:
        """Return `value` as a float, or None if it is not a numeric scalar."""
        if isinstance(value, (int, float, np.integer, np.floating)):
            return float(value)
        # Zero-dimensional arrays/tensors expose .item(); converting also
        # avoids keeping the original object alive in the history.
        item = getattr(value, "item", None)
        if callable(item):
            try:
                return float(item())
            except (TypeError, ValueError, RuntimeError):
                return None
        return None

    def start_experiment(self, config: Dict[str, Any], tags: Optional[List[str]] = None):
        """Start tracking an experiment.

        Clears any history left over from a previous experiment on this
        tracker so its values cannot leak into the new run's summaries.
        """
        self.metrics_history = {}
        self._last_step = None
        self.logger.init_run(config, self.experiment_name, tags)

    def log_step(self, metrics: Dict[str, Any], step: int):
        """Log metrics for a single step.

        All entries are forwarded to the logger unchanged; only numeric
        scalars are stored for aggregation, so media, strings and similar
        values do not break the summary statistics.
        """
        self.logger.log_metrics(metrics, step)
        self._last_step = step

        # Store for aggregation
        for key, value in metrics.items():
            scalar = self._to_scalar(value)
            if scalar is None:
                continue
            self.metrics_history.setdefault(key, []).append(scalar)

    def log_epoch_summary(self, epoch: int, additional_metrics: Optional[Dict[str, Any]] = None,
                          window: int = 10):
        """Log epoch-level summary statistics.

        For every tracked metric, logs the mean, standard deviation and
        mean step-to-step change over the last `window` recorded values,
        together with the epoch number under key ``"epoch"``.

        Args:
            epoch: Epoch number being summarized.
            additional_metrics: Extra entries merged into the summary;
                they override computed entries with the same key.
            window: Number of most recent values to aggregate.

        Raises:
            ValueError: If `window` is smaller than 1.
        """
        if window < 1:
            raise ValueError("window must be at least 1")

        summary = {"epoch": epoch}

        for metric_name, values in self.metrics_history.items():
            if not values:  # Only if we have data
                continue
            recent_values = values[-window:]
            summary[f"{metric_name}_mean"] = np.mean(recent_values)
            summary[f"{metric_name}_std"] = np.std(recent_values)
            summary[f"{metric_name}_trend"] = np.mean(np.diff(recent_values)) if len(recent_values) > 1 else 0

        if additional_metrics:
            summary.update(additional_metrics)

        # wandb ignores data logged at a step lower than one already used.
        # The epoch number is normally far below the last training step, so
        # attach the summary to that step instead of using `epoch` as step.
        step = epoch if self._last_step is None else self._last_step
        self.logger.log_metrics(summary, step=step)

    def finish_experiment(self):
        """Log a final summary and finish the run.

        For every tracked metric the summary holds the last value
        (``final_``), the best value (``best_``: minimum for metrics whose
        name contains "error" or "loss", maximum otherwise) and the number
        of recorded values (``convergence_step_``). The run is finished
        even if logging the summary fails.
        """
        final_summary = {}
        for metric_name, values in self.metrics_history.items():
            if not values:
                continue
            lowered = metric_name.lower()
            lower_is_better = any(tag in lowered for tag in self._LOWER_IS_BETTER)
            final_summary[f"final_{metric_name}"] = values[-1]
            final_summary[f"best_{metric_name}"] = min(values) if lower_is_better else max(values)
            final_summary[f"convergence_step_{metric_name}"] = len(values)

        try:
            self.logger.log_metrics(final_summary)
        finally:
            self.logger.finish_run()
