"""
Data Storage Module

Defines data structures for storing experiment trajectories and results.
Supports serialization to compressed NumPy (.npz) and JSON formats.
"""

import numpy as np
from dataclasses import dataclass, field, asdict
from typing import Dict, List, Tuple, Optional, Any
import json
from pathlib import Path


@dataclass
class RunTrajectory:
    """
    Complete trajectory for a single experimental run.

    Stores all time-series data and per-config statistics for analysis.

    Typical lifecycle: ``initialize(T, K, d)`` pre-allocates the per-step
    arrays, ``record_step`` fills one row per timestep, and ``finalize(B)``
    computes the cumulative and summary metrics from the recorded steps.
    """

    # Metadata
    run_id: str
    algorithm: str
    experiment_family: str  # 'S1', 'S2', 'Alibaba', etc.
    seed: int
    config: Dict[str, Any] = field(default_factory=dict)

    # Experiment parameters
    T: int = 0  # number of timesteps (length of every time series)
    K: int = 0  # number of configs (width of w_t / beta_t, length of N_theta)
    d: int = 0  # number of resources (width of a_t / p_t)

    # Time series (length T)
    theta_t: np.ndarray = field(default_factory=lambda: np.array([]))  # Config selected
    r_t: np.ndarray = field(default_factory=lambda: np.array([]))      # Reward observed
    a_t: np.ndarray = field(default_factory=lambda: np.array([]))      # Consumption (T x d)
    x_t: np.ndarray = field(default_factory=lambda: np.array([]))      # Admission decision
    p_t: np.ndarray = field(default_factory=lambda: np.array([]))      # Price vector (T x d)
    w_t: np.ndarray = field(default_factory=lambda: np.array([]))      # Mixture vector (T x K)

    # Cumulative metrics (derived, computed at end)
    cumulative_reward: np.ndarray = field(default_factory=lambda: np.array([]))
    cumulative_consumption: np.ndarray = field(default_factory=lambda: np.array([]))
    budget_remaining: np.ndarray = field(default_factory=lambda: np.array([]))

    # Per-config statistics
    N_theta: np.ndarray = field(default_factory=lambda: np.array([]))  # Sample counts (K,)

    # Algorithm internals (optional)
    beta_t: np.ndarray = field(default_factory=lambda: np.array([]))   # Confidence radii (T x K)

    # Final metrics
    total_reward: float = 0.0
    total_consumption: np.ndarray = field(default_factory=lambda: np.array([]))
    acceptance_rate: float = 0.0
    budget_violation: bool = False

    def initialize(self, T: int, K: int, d: int):
        """
        Pre-allocate zero-filled arrays for trajectory storage.

        Args:
            T: Number of timesteps.
            K: Number of configs.
            d: Number of resources.
        """
        self.T = T
        self.K = K
        self.d = d

        self.theta_t = np.zeros(T, dtype=np.int32)
        self.r_t = np.zeros(T)
        self.a_t = np.zeros((T, d))
        self.x_t = np.zeros(T, dtype=np.int32)
        self.p_t = np.zeros((T, d))
        self.w_t = np.zeros((T, K))
        self.beta_t = np.zeros((T, K))
        self.N_theta = np.zeros(K, dtype=np.int32)

    def record_step(
        self,
        t: int,
        theta: int,
        r: float,
        a: np.ndarray,
        x: int,
        p: np.ndarray,
        w: np.ndarray,
        beta: Optional[np.ndarray] = None
    ):
        """
        Record a single timestep into the pre-allocated arrays.

        Args:
            t: Timestep index (row to write).
            theta: Index of the selected config.
            r: Observed reward.
            a: Consumption vector (length d).
            x: Admission decision.
            p: Price vector (length d).
            w: Mixture vector (length K).
            beta: Optional confidence radii (length K); the row is left
                untouched when omitted.
        """
        self.theta_t[t] = theta
        self.r_t[t] = r
        self.a_t[t] = a
        self.x_t[t] = x
        self.p_t[t] = p
        self.w_t[t] = w
        if beta is not None:
            self.beta_t[t] = beta

    def finalize(self, B: np.ndarray):
        """
        Compute derived metrics after run completion.

        Rewards and consumption are weighted by the admission decision
        ``x_t`` before being accumulated.

        Args:
            B: Per-resource budget vector of length ``d`` (anything
                convertible to a float array is accepted).
        """
        T = self.T
        budget = np.asarray(B, dtype=float).reshape(1, -1)  # (1, d) row for broadcasting

        # Cumulative reward
        self.cumulative_reward = np.cumsum(self.r_t * self.x_t)

        # Cumulative consumption per resource
        accepted = self.x_t[:, np.newaxis]  # (T, 1)
        self.cumulative_consumption = np.cumsum(self.a_t * accepted, axis=0)

        # Budget remaining after each step
        self.budget_remaining = budget - self.cumulative_consumption

        # Final statistics; an empty run reports zeros instead of NaN.
        if T > 0:
            self.total_reward = float(self.cumulative_reward[-1])
            self.total_consumption = self.cumulative_consumption[-1]
            self.acceptance_rate = float(np.mean(self.x_t))
        else:
            self.total_reward = 0.0
            self.total_consumption = np.zeros(self.d)
            self.acceptance_rate = 0.0

        # Small negative tolerance absorbs floating-point round-off in the
        # cumulative sums. Cast to a Python bool so the value is JSON-safe.
        self.budget_violation = bool(np.any(self.budget_remaining < -1e-9))

        # Sample counts per config; indices outside [0, K) are ignored.
        selected = np.asarray(self.theta_t).astype(np.int64)
        in_range = selected[(selected >= 0) & (selected < self.K)]
        self.N_theta = np.bincount(in_range, minlength=self.K).astype(np.int32)

    def to_dict(self) -> Dict[str, Any]:
        """Convert summary fields to a JSON-serializable dictionary."""
        return {
            'run_id': self.run_id,
            'algorithm': self.algorithm,
            'experiment_family': self.experiment_family,
            'seed': self.seed,
            'config': self.config,
            'T': self.T,
            'K': self.K,
            'd': self.d,
            'total_reward': float(self.total_reward),
            'acceptance_rate': float(self.acceptance_rate),
            # numpy bool_ is not accepted by json; force a Python bool.
            'budget_violation': bool(self.budget_violation),
            'N_theta': np.asarray(self.N_theta).tolist(),
            'total_consumption': self.total_consumption.tolist() if isinstance(self.total_consumption, np.ndarray) else self.total_consumption,
        }

    def save_npz(self, filepath: str):
        """
        Save full trajectory to compressed numpy format.

        ``np.savez_compressed`` appends ``.npz`` to ``filepath`` when it does
        not already end with that extension; load with the resulting name.
        Metadata strings (run_id, algorithm, ...) are not stored.
        """
        np.savez_compressed(
            filepath,
            theta_t=self.theta_t,
            r_t=self.r_t,
            a_t=self.a_t,
            x_t=self.x_t,
            p_t=self.p_t,
            w_t=self.w_t,
            beta_t=self.beta_t,
            cumulative_reward=self.cumulative_reward,
            cumulative_consumption=self.cumulative_consumption,
            budget_remaining=self.budget_remaining,
            N_theta=self.N_theta,
            total_consumption=np.asarray(self.total_consumption, dtype=float),
            budget_violation=np.array(bool(self.budget_violation)),
            metadata=np.array([self.T, self.K, self.d, self.total_reward, self.acceptance_rate])
        )

    @classmethod
    def load_npz(cls, filepath: str, run_id: str = "", algorithm: str = "",
                 experiment_family: str = "", seed: int = 0) -> 'RunTrajectory':
        """
        Load trajectory from numpy format.

        The string/int metadata is not stored in the file and must be passed
        in. Files written before ``total_consumption``/``budget_violation``
        were saved are supported: those values are derived from the
        cumulative arrays instead.
        """
        traj = cls(
            run_id=run_id,
            algorithm=algorithm,
            experiment_family=experiment_family,
            seed=seed
        )
        total_consumption = None
        violation = None

        # Context manager closes the underlying zip file handle.
        with np.load(filepath) as data:
            traj.theta_t = data['theta_t']
            traj.r_t = data['r_t']
            traj.a_t = data['a_t']
            traj.x_t = data['x_t']
            traj.p_t = data['p_t']
            traj.w_t = data['w_t']
            traj.beta_t = data['beta_t']
            traj.cumulative_reward = data['cumulative_reward']
            traj.cumulative_consumption = data['cumulative_consumption']
            traj.budget_remaining = data['budget_remaining']
            traj.N_theta = data['N_theta']
            meta = data['metadata']
            if 'total_consumption' in data.files:
                total_consumption = data['total_consumption']
            if 'budget_violation' in data.files:
                violation = bool(data['budget_violation'])

        traj.T = int(meta[0])
        traj.K = int(meta[1])
        traj.d = int(meta[2])
        traj.total_reward = float(meta[3])
        traj.acceptance_rate = float(meta[4])

        if total_consumption is None:
            cumulative = np.asarray(traj.cumulative_consumption)
            if cumulative.ndim == 2 and cumulative.shape[0] > 0:
                total_consumption = cumulative[-1]
            else:
                total_consumption = np.zeros(traj.d)
        if violation is None:
            violation = bool(np.any(np.asarray(traj.budget_remaining) < -1e-9))

        traj.total_consumption = total_consumption
        traj.budget_violation = violation
        return traj


@dataclass
class ExperimentResults:
    """
    Aggregated results across multiple runs.

    Stores oracle values and per-algorithm statistics. Per-run values are
    accumulated as lists by ``add_run`` and converted to arrays by
    ``finalize``. ``save_json`` writes only summaries; ``load_json`` restores
    them into ``loaded_summaries`` so a loaded object can be re-serialized.
    """

    # Experiment metadata
    experiment_id: str
    family: str  # 'S1', 'S2', 'Alibaba', etc.
    n_runs: int

    # Experiment parameters
    T: int = 0
    K: int = 0
    d: int = 0
    budget_factor: float = 1.0

    # Oracle values (computed once from data). Ratios and regret divide by
    # T * V_mix / T * V_star, i.e. the values are scaled by the horizon T.
    V_mix: float = 0.0
    V_star: float = 0.0
    w_star: np.ndarray = field(default_factory=lambda: np.array([]))
    p_star: np.ndarray = field(default_factory=lambda: np.array([]))

    # Algorithm names
    algorithms: List[str] = field(default_factory=list)

    # Per-algorithm results: Dict[algorithm_name, list of values across runs]
    total_rewards: Dict[str, np.ndarray] = field(default_factory=dict)
    competitive_ratios: Dict[str, np.ndarray] = field(default_factory=dict)
    fixed_config_ratios: Dict[str, np.ndarray] = field(default_factory=dict)
    acceptance_rates: Dict[str, np.ndarray] = field(default_factory=dict)
    budget_violations: Dict[str, np.ndarray] = field(default_factory=dict)

    # Regret metrics
    regrets: Dict[str, np.ndarray] = field(default_factory=dict)

    # Summaries restored by load_json (the JSON file holds only summaries,
    # not per-run values); get_summary falls back to these.
    loaded_summaries: Dict[str, Dict[str, float]] = field(default_factory=dict)

    def _per_run_stores(self) -> List[Dict[str, Any]]:
        """Return every per-algorithm dictionary of per-run values."""
        return [
            self.total_rewards,
            self.competitive_ratios,
            self.fixed_config_ratios,
            self.acceptance_rates,
            self.budget_violations,
            self.regrets,
        ]

    @staticmethod
    def _append(store: Dict[str, Any], algorithm: str, value: Any):
        """Append to ``store[algorithm]``, turning a finalized array back into a list."""
        values = store[algorithm]
        if isinstance(values, np.ndarray):
            values = values.tolist()
            store[algorithm] = values
        values.append(value)

    def add_run(self, algorithm: str, trajectory: 'RunTrajectory'):
        """
        Add results from a single run.

        Reads ``total_reward``, ``acceptance_rate`` and ``budget_violation``
        from ``trajectory``. The competitive ratio and regret are recorded
        only when ``V_mix`` and ``T`` are positive. The fixed-config ratio is
        recorded only when ``V_star`` and ``T`` are positive. Safe to call
        again after ``finalize``.
        """
        if algorithm not in self.algorithms:
            self.algorithms.append(algorithm)
        for store in self._per_run_stores():
            store.setdefault(algorithm, [])

        reward = trajectory.total_reward

        self._append(self.total_rewards, algorithm, reward)
        self._append(self.acceptance_rates, algorithm, trajectory.acceptance_rate)
        self._append(self.budget_violations, algorithm, int(trajectory.budget_violation))

        # Compute ratios; guard T so an empty horizon cannot divide by zero.
        if self.V_mix > 0 and self.T > 0:
            mix_benchmark = self.T * self.V_mix
            self._append(self.competitive_ratios, algorithm, reward / mix_benchmark)
            self._append(self.regrets, algorithm, mix_benchmark - reward)
        if self.V_star > 0 and self.T > 0:
            self._append(self.fixed_config_ratios, algorithm, reward / (self.T * self.V_star))

    def finalize(self):
        """Convert per-run lists to numpy arrays for every known algorithm."""
        for store in self._per_run_stores():
            for alg in self.algorithms:
                if alg in store:  # absent for algorithms known only from load_json
                    store[alg] = np.array(store[alg])

    def get_summary(self, algorithm: str) -> Dict[str, float]:
        """
        Get summary statistics for an algorithm.

        Metrics with no recorded values are NaN. When no per-run data exists
        but a summary was restored by ``load_json``, that summary is returned.

        Raises:
            KeyError: if the algorithm is unknown.
        """
        if algorithm not in self.total_rewards and algorithm in self.loaded_summaries:
            return dict(self.loaded_summaries[algorithm])

        def mean(store: Dict[str, Any]) -> float:
            values = np.asarray(store[algorithm], dtype=float)
            return float(np.mean(values)) if values.size else float('nan')

        def std(store: Dict[str, Any]) -> float:
            values = np.asarray(store[algorithm], dtype=float)
            return float(np.std(values)) if values.size else float('nan')

        return {
            'mean_reward': mean(self.total_rewards),
            'std_reward': std(self.total_rewards),
            'mean_competitive_ratio': mean(self.competitive_ratios),
            'std_competitive_ratio': std(self.competitive_ratios),
            'mean_fixed_ratio': mean(self.fixed_config_ratios),
            'mean_regret': mean(self.regrets),
            'std_regret': std(self.regrets),
            'mean_acceptance_rate': mean(self.acceptance_rates),
            'violation_rate': mean(self.budget_violations),
        }

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization (summaries only, no per-run values)."""
        result = {
            'experiment_id': self.experiment_id,
            'family': self.family,
            'n_runs': self.n_runs,
            'T': self.T,
            'K': self.K,
            'd': self.d,
            'budget_factor': self.budget_factor,
            'V_mix': self.V_mix,
            'V_star': self.V_star,
            'w_star': self.w_star.tolist() if isinstance(self.w_star, np.ndarray) else self.w_star,
            'p_star': self.p_star.tolist() if isinstance(self.p_star, np.ndarray) else self.p_star,
            'algorithms': self.algorithms,
            'summaries': {}
        }

        for alg in self.algorithms:
            result['summaries'][alg] = self.get_summary(alg)

        return result

    def save_json(self, filepath: str):
        """Save results to JSON file."""
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load_json(cls, filepath: str) -> 'ExperimentResults':
        """
        Load results from JSON file.

        Per-run values are not stored in the file; the saved summaries are
        kept in ``loaded_summaries`` so ``get_summary``/``to_dict`` still work.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        results = cls(
            experiment_id=data['experiment_id'],
            family=data['family'],
            n_runs=data['n_runs']
        )
        results.T = data['T']
        results.K = data['K']
        results.d = data['d']
        results.budget_factor = data.get('budget_factor', 1.0)
        results.V_mix = data['V_mix']
        results.V_star = data['V_star']
        results.w_star = np.array(data['w_star'])
        results.p_star = np.array(data['p_star'])
        results.algorithms = list(data['algorithms'])
        results.loaded_summaries = {
            alg: dict(summary) for alg, summary in data.get('summaries', {}).items()
        }

        return results


def create_results_directory(base_path: str, experiment_family: str) -> Path:
    """
    Ensure ``<base_path>/<experiment_family>/trajectories`` exists.

    Missing parent directories are created and an existing tree is reused.

    Returns:
        The ``<base_path>/<experiment_family>`` directory.
    """
    family_dir = Path(base_path).joinpath(experiment_family)
    trajectories_dir = family_dir.joinpath('trajectories')
    trajectories_dir.mkdir(parents=True, exist_ok=True)
    return family_dir
