"""Results aggregation and analysis for parameter sweeps."""

import json
import os
from pathlib import Path
from typing import Any, Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yaml


class SweepResultsAnalyzer:
	"""Analyze and visualize results from parameter sweeps.

	Results are read from a YAML file inside ``results_dir``, flattened
	into one row per experiment, and kept in ``results_df`` for the
	filtering, plotting and reporting methods.

	Attributes:
		results_dir (Path): Directory containing sweep results.
		results_df (Optional[pd.DataFrame]): Flattened results, or ``None``
			until ``load_results()`` has been called.
	"""

	# Scalar metrics copied from each result entry (missing ones become 0.0).
	_METRIC_KEYS = (
		"mean_reward",
		"std_reward",
		"final_reward",
		"num_episodes",
	)

	def __init__(self, results_dir: str) -> None:
		"""Initialize the analyzer with results directory.

		Args:
			results_dir (str): Directory containing sweep results.
		"""
		self.results_dir = Path(results_dir)
		self.results_df: Optional[pd.DataFrame] = None

	@staticmethod
	def _flatten_result(result: Dict[str, Any]) -> Dict[str, Any]:
		"""Flatten one raw result entry into a single-level row dict.

		Args:
			result (Dict[str, Any]): One entry of the results file.

		Returns:
			Dict[str, Any]: Row with ``experiment_id``, ``success``, the
				metric columns, and one ``<category>_<param>`` column per
				sampled parameter.
		"""
		flat_result: Dict[str, Any] = {
			"experiment_id": result.get("experiment_id"),
			"success": result.get("success", False),
		}
		for key in SweepResultsAnalyzer._METRIC_KEYS:
			flat_result[key] = result.get(key, 0.0)
		flat_result["asynchronicity_rate"] = result.get(
			"asynchronicity_rate", 0.0
		)
		flat_result["signal_density"] = result.get("signal_density", 0.0)

		# A missing or null ``sampled_params`` simply contributes no columns.
		sampled = result.get("sampled_params")
		if not isinstance(sampled, dict):
			return flat_result

		# "env" is handled first so its columns keep their leading position.
		ordered = [("env", sampled["env"])] if "env" in sampled else []
		ordered.extend(
			(category, params)
			for category, params in sampled.items()
			if category != "env"
		)

		for category, params in ordered:
			if params is None:
				continue
			if not isinstance(params, dict):
				# Category holds a bare value instead of a mapping.
				flat_result[str(category)] = params
				continue
			for param, value in params.items():
				is_numeric_pair = (
					category == "env"
					and isinstance(value, list)
					and len(value) == 2
					and all(isinstance(v, (int, float)) for v in value)
				)
				if is_numeric_pair:
					# Two-element env lists are treated as [low, high].
					flat_result[f"env_{param}_avg"] = sum(value) / 2.0
					flat_result[f"env_{param}_range"] = value[1] - value[0]
				else:
					flat_result[f"{category}_{param}"] = value

		return flat_result

	def load_results(
		self, results_file: str = "sweep_results.yaml"
	) -> pd.DataFrame:
		"""Load sweep results from file.

		The file must contain a YAML list with one mapping per experiment.
		An empty file is treated as a sweep with zero experiments.

		Args:
			results_file (str): Name of the results file.

		Returns:
			pd.DataFrame: DataFrame with all experiment results.

		Raises:
			FileNotFoundError: If the results file does not exist.
			ValueError: If the top-level YAML value is not a list.
		"""
		results_path = self.results_dir / results_file

		if not results_path.exists():
			raise FileNotFoundError(f"Results file not found: {results_path}")

		with open(results_path, "r", encoding="utf-8") as f:
			results_data = yaml.safe_load(f)

		# An empty YAML document parses to None rather than an empty list.
		if results_data is None:
			results_data = []
		if not isinstance(results_data, list):
			raise ValueError(
				f"Expected a list of results in {results_path}, "
				f"got {type(results_data).__name__}"
			)

		self.results_df = pd.DataFrame(
			[self._flatten_result(result) for result in results_data]
		)
		return self.results_df

	def filter_successful_experiments(self) -> pd.DataFrame:
		"""Filter to only successful experiments.

		Rows whose ``success`` value is missing count as unsuccessful. A
		frame without a ``success`` column (e.g. an empty sweep) yields an
		empty result instead of a ``KeyError``.

		Returns:
			pd.DataFrame: DataFrame with only successful experiments.

		Raises:
			ValueError: If results are not loaded.
		"""
		if self.results_df is None:
			raise ValueError("Results not loaded. Call load_results() first.")

		if "success" not in self.results_df.columns:
			return self.results_df.iloc[0:0]

		# ``eq(True)`` gives a clean boolean mask even when the column is
		# object-typed because of null entries.
		return self.results_df[self.results_df["success"].eq(True)]

	def plot_asynchronicity_vs_performance(
		self,
		performance_metric: str = "mean_reward",
		asynchronicity_metric: str = "asynchronicity_rate",
		save_path: Optional[str] = None,
	) -> None:
		"""Plot asynchronicity rate vs performance.

		Draws a scatter plot of the successful experiments and, when the
		data supports it, a dashed linear trend line.

		Args:
			performance_metric (str): Which performance metric to use.
			asynchronicity_metric (str): Which asynchronicity metric to
				use.
			save_path (Optional[str]): Optional path to save the plot.

		Raises:
			ValueError: If results are not loaded.
		"""
		successful_df = self.filter_successful_experiments()

		if successful_df.empty:
			print("No successful experiments to plot")
			return

		x_values = successful_df[asynchronicity_metric]
		y_values = successful_df[performance_metric]

		plt.figure(figsize=(10, 6))
		plt.scatter(x_values, y_values, alpha=0.6, s=50)

		# A line fit needs complete (x, y) pairs and at least two distinct
		# x values; otherwise the least-squares problem is degenerate.
		complete = x_values.notna() & y_values.notna()
		x_fit = x_values[complete]
		y_fit = y_values[complete]
		if x_fit.nunique() > 1:
			trend = np.poly1d(np.polyfit(x_fit, y_fit, 1))
			x_trend = np.linspace(x_fit.min(), x_fit.max(), 100)
			plt.plot(x_trend, trend(x_trend), "r--", alpha=0.8, linewidth=2)

		plt.xlabel(f"Asynchronicity Rate ({asynchronicity_metric})")
		plt.ylabel(f"Performance ({performance_metric})")
		plt.title("Asynchronicity vs Performance")
		plt.grid(True, alpha=0.3)

		if save_path:
			plt.savefig(save_path, dpi=300, bbox_inches="tight")

		plt.show()

	def plot_parameter_correlation_matrix(
		self, save_path: Optional[str] = None
	) -> None:
		"""Plot correlation matrix of parameters and performance.

		Only numeric columns of the successful experiments are included.

		Args:
			save_path (Optional[str]): Optional path to save the plot.

		Raises:
			ValueError: If results are not loaded.
		"""
		successful_df = self.filter_successful_experiments()

		if successful_df.empty:
			print("No successful experiments to analyze")
			return

		numeric_cols = successful_df.select_dtypes(include=[np.number]).columns
		if len(numeric_cols) == 0:
			print("No numeric columns to analyze")
			return

		correlation_matrix = successful_df[numeric_cols].corr()
		coefficients = correlation_matrix.to_numpy()

		plt.figure(figsize=(12, 10))
		plt.imshow(
			coefficients, cmap="coolwarm", aspect="auto", vmin=-1, vmax=1
		)
		plt.colorbar(label="Correlation Coefficient")
		plt.xticks(
			range(len(numeric_cols)), numeric_cols, rotation=45, ha="right"
		)
		plt.yticks(range(len(numeric_cols)), numeric_cols)

		# Annotate every cell; x is the column index, y is the row index.
		for (row, col), coefficient in np.ndenumerate(coefficients):
			plt.text(
				col,
				row,
				f"{coefficient:.2f}",
				ha="center",
				va="center",
				fontsize=8,
			)

		plt.title("Parameter Correlation Matrix")
		plt.tight_layout()

		if save_path:
			plt.savefig(save_path, dpi=300, bbox_inches="tight")

		plt.show()

	@staticmethod
	def _column_stats(series: pd.Series, prefix: str) -> Dict[str, float]:
		"""Return mean/std/min/max of ``series`` as built-in floats.

		Args:
			series (pd.Series): Numeric column to summarize.
			prefix (str): Prefix for the keys of the returned dict.

		Returns:
			Dict[str, float]: ``<prefix>_mean``, ``_std``, ``_min``, ``_max``.
		"""
		return {
			f"{prefix}_mean": float(series.mean()),
			f"{prefix}_std": float(series.std()),
			f"{prefix}_min": float(series.min()),
			f"{prefix}_max": float(series.max()),
		}

	def generate_summary_report(self) -> Dict[str, Any]:
		"""Generate a summary report of the sweep results.

		The report always holds the experiment counts and success rate.
		When at least one experiment succeeded it also holds statistics for
		the reward metrics and asynchronicity rate, plus the row with the
		highest ``mean_reward`` under ``best_experiment``.

		Returns:
			Dict[str, Any]: Dictionary with summary statistics.

		Raises:
			ValueError: If results are not loaded.
		"""
		if self.results_df is None:
			raise ValueError("Results not loaded. Call load_results() first.")

		successful_df = self.filter_successful_experiments()
		total_experiments = len(self.results_df)
		successful_experiments = len(successful_df)

		summary: Dict[str, Any] = {
			"total_experiments": total_experiments,
			"successful_experiments": successful_experiments,
			"success_rate": (
				successful_experiments / total_experiments
				if total_experiments > 0
				else 0.0
			),
		}

		if successful_experiments == 0:
			return summary

		performance_stats: Dict[str, float] = {}
		for metric in ("mean_reward", "final_reward"):
			if metric in successful_df.columns:
				performance_stats.update(
					self._column_stats(successful_df[metric], metric)
				)
		summary["performance_stats"] = performance_stats

		if "asynchronicity_rate" in successful_df.columns:
			summary["asynchronicity_stats"] = self._column_stats(
				successful_df["asynchronicity_rate"], "asynchronicity_rate"
			)

		# idxmax is undefined when every mean_reward is missing, so only
		# report a best experiment when at least one value is present.
		if (
			"mean_reward" in successful_df.columns
			and successful_df["mean_reward"].notna().any()
		):
			best_idx = successful_df["mean_reward"].idxmax()
			best_row = successful_df.loc[best_idx].to_dict()
			# Unwrap numpy scalars so the dict holds plain Python values.
			summary["best_experiment"] = {
				key: value.item() if isinstance(value, np.generic) else value
				for key, value in best_row.items()
			}

		return summary

	@classmethod
	def _json_safe(cls, value: Any) -> Any:
		"""Recursively convert ``value`` into strict-JSON-friendly data.

		Numpy scalars are unwrapped and non-finite floats become ``None``
		(written as ``null``), since NaN/Infinity are not valid JSON.
		"""
		if isinstance(value, dict):
			return {key: cls._json_safe(item) for key, item in value.items()}
		if isinstance(value, (list, tuple)):
			return [cls._json_safe(item) for item in value]
		if isinstance(value, np.generic):
			value = value.item()
		if isinstance(value, float) and not np.isfinite(value):
			return None
		return value

	def save_summary_report(self, save_path: Optional[str] = None) -> None:
		"""Save summary report to file.

		The report is written as JSON. Non-finite numbers (for example the
		standard deviation of a single experiment) are stored as ``null``.

		Args:
			save_path (Optional[str]): Optional path to save the report.
				Defaults to ``summary_report.json`` inside ``results_dir``.

		Raises:
			ValueError: If results are not loaded.
		"""
		summary = self.generate_summary_report()

		target = (
			self.results_dir / "summary_report.json"
			if save_path is None
			else save_path
		)

		with open(target, "w", encoding="utf-8") as f:
			# default=str is a last resort for values with no JSON form.
			json.dump(self._json_safe(summary), f, indent=2, default=str)

		print(f"Summary report saved to {target}")


def create_sweep_analysis_script(results_dir: str) -> str:
	"""Create a standalone analysis script for the sweep results.

	The script is written to ``analyze_results.py`` inside ``results_dir``
	and marked executable. The directory must already exist.

	Paths are embedded in the generated source as ``repr()`` literals so
	that backslashes and quote characters in ``results_dir`` cannot
	corrupt the generated code.

	Args:
		results_dir (str): Directory containing sweep results.

	Returns:
		str: Path to the created analysis script.
	"""
	results_dir = os.fspath(results_dir)
	async_plot_path = os.path.join(
		results_dir, "asynchronicity_vs_performance.png"
	)
	correlation_plot_path = os.path.join(
		results_dir, "parameter_correlations.png"
	)

	# Doubled braces are literal braces in the generated script; single
	# braces are filled in here, using !r to emit valid string literals.
	script_content = f'''#!/usr/bin/env python3
"""
Analyze parameter sweep results.

This script analyzes the results from a parameter sweep experiment,
generating plots and summary statistics.
"""

import sys
from pathlib import Path

# Add the project root to the path
project_root = Path(__file__).parent.parent
sys.path.append(str(project_root))

from src.async_rl.utils.analysis import SweepResultsAnalyzer

def main():
    # Initialize analyzer
    analyzer = SweepResultsAnalyzer({results_dir!r})

    # Load results
    print("Loading sweep results...")
    results_df = analyzer.load_results()
    print(f"Loaded {{len(results_df)}} experiments")

    # Generate summary report
    print("Generating summary report...")
    analyzer.save_summary_report()

    # Create plots
    print("Creating plots...")

    # Asynchronicity vs performance plot
    analyzer.plot_asynchronicity_vs_performance(
        save_path={async_plot_path!r}
    )

    # Parameter correlation matrix
    analyzer.plot_parameter_correlation_matrix(
        save_path={correlation_plot_path!r}
    )

    print("Analysis complete!")

if __name__ == "__main__":
    main()
'''

	script_path = os.path.join(results_dir, "analyze_results.py")
	# Python source files are read as UTF-8, so write the script as UTF-8.
	with open(script_path, "w", encoding="utf-8") as f:
		f.write(script_content)

	# rwxr-xr-x so the script can be run directly via its shebang line.
	os.chmod(script_path, 0o755)

	return script_path
