import math
import copy
import gym
import random
import numpy as np
import statistics
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from dataclasses import dataclass
from typing import Dict, List, Tuple

# Import environment
import improved_hopper
from SnapshotENV import SnapshotEnv

####################################################################
# Ablation Study Configuration
####################################################################

@dataclass
class AblationConfig:
    """Configuration for one DAMCTS ablation variant.

    Each ``use_*`` flag switches one algorithmic component on or off; the
    numeric fields parameterize the components that remain enabled.

    Raises:
        ValueError: if ``power`` is 0 (``get_value`` takes the ``1/power``-th
            root) or ``epsilon_1`` is not positive (the epsilon-net size and
            level computations divide by it).
    """
    name: str  # label used as the key in result dicts and plots
    use_power_mean: bool = True  # power-mean value backup instead of the arithmetic mean
    use_epsilon_net: bool = True  # grid/sample epsilon-net instead of one random action
    use_epsilon_bonus: bool = True  # add L * eps_k**beta to the UCB score
    use_adaptive_epsilon: bool = True  # shrink epsilon as visits grow instead of fixing it
    power: float = 2.0  # exponent p of the power mean
    epsilon_1: float = 0.5  # initial (level-1) epsilon-net resolution
    beta: float = 1.0  # enters both the epsilon decay rate and the bonus exponent
    L: float = 1.0  # scale of the epsilon bonus
    ucb_constant: float = 2.0  # c in sqrt(c * ln(N) / n)
    description: str = ""  # human-readable summary printed by the study driver

    def __post_init__(self):
        if self.power == 0:
            raise ValueError("power must be non-zero (the power mean takes a 1/power root)")
        if self.epsilon_1 <= 0:
            raise ValueError(f"epsilon_1 must be positive, got {self.epsilon_1!r}")

# Define ablation variants.
# Every variant differs from the full-DAMCTS defaults of AblationConfig only in
# the fields it names explicitly; order here is the order they are run/plotted.
ABLATION_VARIANTS = [
    # Reference point: every component enabled, default parameters.
    AblationConfig(name="DAMCTS", description="Full DAMCTS algorithm"),

    # Single-component removals.
    AblationConfig(
        name="UCT-Style",
        use_power_mean=False,
        description="Standard UCT with DAMCTS epsilon-nets",
    ),
    # NOTE: this variant also lowers epsilon_1 from 0.5 to 0.1, so it changes
    # two settings at once relative to DAMCTS.
    AblationConfig(
        name="Fixed-Epsilon",
        use_adaptive_epsilon=False,
        epsilon_1=0.1,
        description="Fixed epsilon-net size (no adaptation)",
    ),
    AblationConfig(
        name="No-Epsilon-Bonus",
        use_epsilon_bonus=False,
        description="Power-mean + adaptive nets, no epsilon bonus",
    ),

    # Power-mean exponent sweep (default p = 2.0).
    *(
        AblationConfig(name=f"Power-{p}", power=p, description=f"Power-mean with p={p}")
        for p in (1.5, 3.0, 4.0, 5.0)
    ),

    # Initial epsilon sweep (default epsilon_1 = 0.5).
    *(
        AblationConfig(name=f"Epsilon-{eps}", epsilon_1=eps, description=text)
        for eps, text in ((0.25, "Smaller initial epsilon"), (0.75, "Larger initial epsilon"))
    ),

    # Beta sweep (default beta = 1.0); beta sets the epsilon decay rate.
    *(
        AblationConfig(name=f"Beta-{b}", beta=b, description=f"{speed} epsilon decay (β={b})")
        for b, speed in ((0.5, "Faster"), (2.0, "Slower"))
    ),

    # Baseline with every DAMCTS component disabled.
    AblationConfig(
        name="Random",
        use_epsilon_net=False,
        use_power_mean=False,
        use_epsilon_bonus=False,
        use_adaptive_epsilon=False,
        description="Random action selection",
    ),
]

####################################################################
# Environment Configuration
####################################################################

ENV_NAME = "ImprovedHopper-v0"
# Keyword arguments forwarded to gym.make(ENV_NAME, ...).
ENV_NOISE_CONFIG = dict(
    action_noise_scale=0.03,
    dynamics_noise_scale=0.02,
    obs_noise_scale=0.01,
)

# Global parameters
NUM_SEEDS = 10          # seeds per (variant, budget); reduced for quicker ablation
TEST_ITERATIONS = 150   # maximum real environment steps per evaluation episode
DISCOUNT = 0.99         # discount factor for rollouts, backups and evaluation
MAX_DEPTH = 100         # default maximum rollout length
REWARD_OFFSET = 20.0    # added to raw rewards inside the planner
REWARD_SCALING = 1.0    # multiplier applied to the offset planner reward

# Planning-iteration budgets evaluated for every variant (reduced for ablation)
ITERATION_SCHEDULE = [100, 500, 1000, 2000]

####################################################################
# Modified DAMCTS Implementation for Ablation
####################################################################

def build_epsilon_net(action_dim, epsilon, lo=-1.0, hi=1.0, use_epsilon_net=True):
    """Return a finite set of candidate actions in the box ``[lo, hi]**action_dim``.

    Args:
        action_dim: Number of action dimensions.
        epsilon: Target resolution; smaller values give more points. Must be
            positive unless ``use_epsilon_net`` is False (then it is unused).
        lo, hi: Per-dimension bounds of the action box.
        use_epsilon_net: When False, return a single action drawn uniformly
            from NumPy's global RNG (a fresh draw on every call) instead of a net.

    Returns:
        List of float32 arrays of shape ``(action_dim,)``. For
        ``action_dim <= 4`` this is a regular grid with
        ``round((1/epsilon)**(1/action_dim))`` points per axis, clamped to
        [2, 15]. For higher dimensions it is
        ``int((1/epsilon)**2)`` uniform samples, clamped to [50, 1000].

    Raises:
        ValueError: if a net is requested with a non-positive ``epsilon``.
    """
    if not use_epsilon_net:
        # Ablation baseline: one fresh random action instead of a covering net.
        return [np.random.uniform(lo, hi, size=action_dim).astype(np.float32)]

    if epsilon <= 0:
        # 1/epsilon would divide by zero, or (for epsilon < 0) produce a
        # complex per-axis count that round() rejects.
        raise ValueError(f"epsilon must be positive, got {epsilon!r}")

    if action_dim <= 4:
        # Regular grid of roughly 1/epsilon points in total before clamping.
        per_dim = int(round((1 / epsilon) ** (1 / action_dim)))
        per_dim = max(2, min(15, per_dim))
        axes = [np.linspace(lo, hi, per_dim) for _ in range(action_dim)]
        mesh = np.meshgrid(*axes, indexing='ij')
        points = np.stack([m.ravel() for m in mesh], axis=-1)
        return [point.astype(np.float32) for point in points]

    # Higher dimensions: uniform random samples instead of a grid.
    n_samples = min(1000, max(50, int((1 / epsilon) ** 2)))
    samples = np.random.uniform(lo, hi, size=(n_samples, action_dim))
    return [sample.astype(np.float32) for sample in samples]

class AblationDAMCTSNode:
    """DAMCTS search-tree node whose value/exploration rules follow an AblationConfig.

    A non-root node represents the state reached by applying ``action`` in
    ``parent``'s state; that transition is simulated once, at construction,
    through ``env.get_result``.
    """

    def __init__(self, parent, action, eps_net_func, env, config: AblationConfig):
        """Create a node and, unless it is a root, simulate its incoming transition.

        Args:
            parent: Parent node, or None for a root (then no environment call is made).
            action: Action leading from ``parent`` to this node (None for a root).
            eps_net_func: Callable mapping an epsilon to a list of candidate actions.
            env: Snapshot-capable planning environment; only used when
                ``parent`` is not None.
            config: Ablation variant controlling value and exploration rules.
        """
        self.parent = parent
        self.action = action
        # A list rather than a set: a set of nodes iterates in id()-hash order,
        # which made tie-breaking (e.g. max() over unvisited children with equal
        # value) differ between runs even with fixed seeds.
        self.children = []
        self.visit_count = 0
        self.value_sum = 0.0
        self.value_sum_power = 0.0  # running sum of return**power (power-mean mode only)
        self.config = config

        # Environment interaction
        if parent is None:
            self.snapshot = None
            self.obs = None
            self.immediate_reward = 0.0
            self.is_done = False
        else:
            snap, obs, r, done, _ = env.get_result(parent.snapshot, self.action)
            self.snapshot = snap
            self.obs = obs
            # Shaped reward, floored at 0.01 so returns stay positive and the
            # fractional powers in get_value/back_propagate remain real.
            self.immediate_reward = max(0.01, (r + REWARD_OFFSET) * REWARD_SCALING)
            self.is_done = done

        self.eps_net_func = eps_net_func
        self.action_dim = 3  # Hopper has 3 action dimensions

    def get_value(self):
        """Return the node's value estimate (0.0 if never visited).

        With ``config.use_power_mean`` this is the power mean
        ``(mean(return**p))**(1/p)`` of backed-up returns; otherwise it is
        their arithmetic mean.
        """
        if self.visit_count == 0:
            return 0.0

        if self.config.use_power_mean:
            return (self.value_sum_power / self.visit_count) ** (1.0 / self.config.power)
        else:
            return self.value_sum / self.visit_count

    def epsilon_level(self):
        """Return ``(k, eps_k, size)`` describing the current epsilon-net resolution.

        Without adaptive epsilon this is always ``(1, epsilon_1, int(1/epsilon_1))``.
        Otherwise ``k`` is the smallest level with ``max(visit_count, 1) <= size**2``,
        where ``eps_k = epsilon_1 * 2**(-(k-1)/(action_dim + 2*beta))`` and
        ``size = int((1/eps_k)**min(action_dim, 3))``, so the net is refined
        as the node accumulates visits.
        """
        if not self.config.use_adaptive_epsilon:
            return 1, self.config.epsilon_1, int(1/self.config.epsilon_1)

        n = max(self.visit_count, 1)
        k = 1
        while True:
            eps = self.config.epsilon_1 * (2 ** (-(k-1)/(self.action_dim + 2*self.config.beta)))
            size = int((1/eps) ** min(self.action_dim, 3))
            if n <= size * size:
                return k, eps, size
            k += 1

    def ucb_score(self, action, parent_visits):
        """Score ``action`` for selection; infinite if its child is missing or unvisited.

        Otherwise: child value, plus the optional epsilon bonus
        ``L * eps_k**beta``, plus ``sqrt(ucb_constant * ln(parent_visits) / child_visits)``.
        """
        child = self.get_child_for_action(action)

        if child is None or child.visit_count == 0:
            return float('inf')

        # Base value
        base_value = child.get_value()

        # Epsilon bonus
        eps_bonus = 0.0
        if self.config.use_epsilon_bonus:
            k, eps_k, _ = self.epsilon_level()
            eps_bonus = self.config.L * (eps_k ** self.config.beta)

        # UCB bonus
        ucb_bonus = 0.0
        if parent_visits > 0 and child.visit_count > 0:
            ucb_bonus = np.sqrt(self.config.ucb_constant * np.log(parent_visits) / child.visit_count)

        return base_value + eps_bonus + ucb_bonus

    def get_child_for_action(self, action):
        """Return the child whose action matches ``action`` within 1e-6, else None."""
        for child in self.children:
            if child.action is not None and np.allclose(child.action, action, atol=1e-6):
                return child
        return None

    def selection(self):
        """Descend the tree by maximum UCB score over the current epsilon-net.

        Returns the current node when it has no children, is terminal, or its
        best-scoring net action has no child yet (i.e. it needs expanding).
        """
        if len(self.children) == 0 or self.is_done:
            return self

        # Get epsilon net
        k, eps_k, _ = self.epsilon_level()
        epsilon_net = self.eps_net_func(eps_k)

        # Find best action (first maximum wins ties)
        best_action = None
        best_score = -float('inf')

        for action in epsilon_net:
            score = self.ucb_score(action, self.visit_count)
            if score > best_score:
                best_score = score
                best_action = action

        # Descend if the chosen action already has a child
        child = self.get_child_for_action(best_action)
        if child is not None:
            return child.selection()
        else:
            return self

    def expand(self, env):
        """Add a child for every current epsilon-net action that has none, then re-select.

        Returns ``self`` for terminal nodes; otherwise the node chosen by
        ``self.selection()``, typically one of the new unvisited children
        (they score infinity).
        """
        if self.is_done:
            return self

        k, eps_k, _ = self.epsilon_level()
        epsilon_net = self.eps_net_func(eps_k)

        # Add new children (each simulates its transition in the constructor)
        for action in epsilon_net:
            if self.get_child_for_action(action) is None:
                child = AblationDAMCTSNode(self, action, self.eps_net_func, env, self.config)
                self.children.append(child)

        return self.selection()

    def rollout(self, env, max_depth=MAX_DEPTH):
        """Return the discounted shaped return of a random-action rollout.

        Restores this node's snapshot, then takes up to ``max_depth`` steps
        with ``env.action_space.sample()``. Terminal nodes return 0.0.
        """
        if self.is_done:
            return 0.0

        env.load_snapshot(self.snapshot)
        total = 0.0
        discount_factor = 1.0

        for _ in range(max_depth):
            action = env.action_space.sample()
            obs, r, done, _ = env.step(action)
            scaled_reward = max(0.01, (r + REWARD_OFFSET) * REWARD_SCALING)
            total += scaled_reward * discount_factor
            discount_factor *= DISCOUNT
            if done:
                break

        return total

    def back_propagate(self, rollout_reward):
        """Record one sampled return at this node and propagate it to the root.

        Args:
            rollout_reward: Return accumulated *after* this node: a rollout
                value at the leaf, or a child's discounted total return when
                called recursively.
        """
        total_return = self.immediate_reward + rollout_reward

        self.value_sum += total_return
        if self.config.use_power_mean:
            self.value_sum_power += total_return ** self.config.power
        self.visit_count += 1

        if self.parent is not None:
            # The parent's future return is this node's FULL return (including
            # its immediate reward), discounted one step. Passing only the
            # rollout part would drop every descendant's immediate reward from
            # the ancestors' value estimates.
            self.parent.back_propagate(total_return * DISCOUNT)

class AblationDAMCTSRoot(AblationDAMCTSNode):
    """Tree root built from an existing environment snapshot.

    A root is not reached by an action, so no transition is simulated; its
    snapshot and observation are supplied by the caller.
    """

    def __init__(self, snapshot, obs, eps_net_func, config: AblationConfig):
        # parent=None makes the base initializer skip env.get_result, so no
        # environment object is needed here.
        super().__init__(None, None, eps_net_func, None, config)
        self.snapshot, self.obs = snapshot, obs
        self.is_done = False
        self.immediate_reward = 0.0

def plan_ablation_damcts(root, n_iter, env):
    """Run ``n_iter`` select/expand/simulate/backup cycles starting at ``root``.

    A terminal leaf is backed up with a future return of 0.0. Otherwise the
    leaf is expanded, the node returned by ``expand`` is rolled out once, and
    that rollout value is backed up from it.
    """
    for _iteration in range(n_iter):
        leaf = root.selection()
        if leaf.is_done:
            leaf.back_propagate(0.0)
            continue
        frontier = leaf.expand(env)
        frontier.back_propagate(frontier.rollout(env))

####################################################################
# Ablation Study Execution
####################################################################

def run_ablation_experiment(config: AblationConfig, n_iterations: int, seed: int) -> float:
    """Plan with one ablation variant, act in an evaluation copy, return its score.

    Args:
        config: Ablation variant to run.
        n_iterations: Planning iterations at the initial root. Each re-plan
            after a real step uses ``n_iterations // 4``.
        seed: Seed for Python's and NumPy's global RNGs.

    Returns:
        Discounted sum (factor ``DISCOUNT``) of raw environment rewards over
        at most ``TEST_ITERATIONS`` evaluation steps. The "Random" variant
        acts with ``action_space.sample()`` and never re-plans.
    """
    # Set seeds
    random.seed(seed)
    np.random.seed(seed)

    # Create environment
    env = gym.make(ENV_NAME, **ENV_NOISE_CONFIG).env
    planning_env = SnapshotEnv(env)

    # Reset environment
    obs = planning_env.reset()
    snapshot = planning_env.get_snapshot()

    def eps_net_func(epsilon):
        return build_epsilon_net(3, epsilon, -1.0, 1.0, config.use_epsilon_net)

    root = AblationDAMCTSRoot(snapshot, obs, eps_net_func, config)
    plan_ablation_damcts(root, n_iterations, planning_env)

    # Evaluation runs in an independent copy of the initial state. Snapshots
    # are treated as pickled environments here; they are produced by this
    # process only (pickle must never be used on untrusted data).
    test_env = pickle.loads(snapshot)
    total_reward = 0.0
    discount_factor = 1.0
    is_random = config.name == "Random"

    try:
        for _ in range(TEST_ITERATIONS):
            best_child = None
            if len(root.children) == 0:
                action = np.zeros(3, dtype=np.float32)
            elif is_random:
                action = test_env.action_space.sample()
            else:
                best_child = max(root.children, key=lambda c: c.get_value())
                action = best_child.action

            obs, reward, done, _ = test_env.step(action)
            total_reward += reward * discount_factor
            discount_factor *= DISCOUNT

            if done:
                break

            # Re-plan from the state the evaluation env actually reached.
            # Transitions are noisy, so best_child.snapshot (the planner's own
            # sample of this step) generally differs from it and would not
            # match the real ``obs``.
            if best_child is not None:
                root = AblationDAMCTSRoot(pickle.dumps(test_env), obs, eps_net_func, config)
                plan_ablation_damcts(root, n_iterations // 4, planning_env)
    finally:
        test_env.close()

    return total_reward

def run_ablation_study():
    """Run every ablation variant at every planning budget and collect returns.

    For each config in ``ABLATION_VARIANTS`` and each budget in
    ``ITERATION_SCHEDULE``, runs ``NUM_SEEDS`` seeded experiments. A seed
    that raises is reported and skipped rather than aborting the study.

    Returns:
        ``{config_name: {n_iter: {'mean': float, 'std': float, 'raw': list}}}``,
        where ``'std'`` is NumPy's default (population) standard deviation.
        If every seed of a cell failed, ``'mean'`` and ``'std'`` are NaN, so the
        cell cannot be mistaken for a genuine zero return, and ``'raw'`` is empty.
    """
    results = {}

    print("Running DAMCTS Ablation Study on Hopper")
    print("=" * 50)

    for config in ABLATION_VARIANTS:
        print(f"\nTesting {config.name}: {config.description}")
        results[config.name] = {}

        for n_iter in ITERATION_SCHEDULE:
            print(f"  Iterations: {n_iter}")
            seed_results = []

            for seed in range(NUM_SEEDS):
                try:
                    reward = run_ablation_experiment(config, n_iter, seed)
                except Exception as e:  # deliberate best-effort: one bad seed must not stop the study
                    print(f"    Seed {seed}: Failed ({e})")
                    continue
                seed_results.append(reward)
                print(f"    Seed {seed}: {reward:.2f}")

            if seed_results:
                mean_reward = np.mean(seed_results)
                std_reward = np.std(seed_results)
                results[config.name][n_iter] = {
                    'mean': mean_reward,
                    'std': std_reward,
                    'raw': seed_results,
                }
                print(f"    Mean ± Std: {mean_reward:.2f} ± {std_reward:.2f}")
            else:
                results[config.name][n_iter] = {'mean': float('nan'), 'std': float('nan'), 'raw': []}
                print("    All seeds failed")

    return results

def plot_ablation_results(results: Dict):
    """Plot a 2x2 summary of ablation results, save it to PNG and display it.

    Panels: (1) mean return vs planning iterations for every variant, (2) mean
    return per variant at the final budget, (3) DAMCTS minus each
    single-component ablation, (4) sensitivity to power, initial epsilon and
    beta. The final budget is ``max(ITERATION_SCHEDULE)``.

    Args:
        results: ``{config_name: {n_iter: {'mean': float, 'std': float, ...}}}``
            as returned by ``run_ablation_study``. Variants or budgets missing
            from it are skipped instead of raising ``KeyError``.

    Side effects:
        Writes ``damcts_ablation_hopper.png`` (300 dpi) and calls ``plt.show()``.
    """
    final_iteration = max(ITERATION_SCHEDULE)
    plt.figure(figsize=(15, 10))

    plt.subplot(2, 2, 1)
    _plot_performance_vs_iterations(results)

    plt.subplot(2, 2, 2)
    _plot_final_comparison(results, final_iteration)

    plt.subplot(2, 2, 3)
    _plot_component_contributions(results, final_iteration)

    plt.subplot(2, 2, 4)
    _plot_parameter_sensitivity(results, final_iteration)

    plt.tight_layout()
    plt.savefig('damcts_ablation_hopper.png', dpi=300, bbox_inches='tight')
    plt.show()


def _final_mean(results, config_name, final_iteration):
    """Return the mean return of ``config_name`` at ``final_iteration``, or None if absent."""
    cell = results.get(config_name, {}).get(final_iteration)
    return None if cell is None else cell['mean']


def _plot_performance_vs_iterations(results):
    """Panel 1: one error-bar line (mean ± std) per variant across planning budgets."""
    for config_name, config_results in results.items():
        iterations = sorted(config_results)
        means = [config_results[it]['mean'] for it in iterations]
        stds = [config_results[it]['std'] for it in iterations]
        plt.errorbar(iterations, means, yerr=stds, label=config_name, marker='o')

    plt.xlabel('Planning Iterations')
    plt.ylabel('Average Return')
    plt.title('Performance vs Planning Iterations')
    plt.legend()
    plt.grid(True, alpha=0.3)


def _plot_final_comparison(results, final_iteration):
    """Panel 2: bar chart of each variant's mean ± std at the final budget."""
    names = [name for name, config_results in results.items() if final_iteration in config_results]
    means = [results[name][final_iteration]['mean'] for name in names]
    stds = [results[name][final_iteration]['std'] for name in names]

    plt.bar(range(len(names)), means, yerr=stds)
    plt.xticks(range(len(names)), names, rotation=45, ha='right')
    plt.ylabel('Average Return')
    plt.title(f'Final Performance Comparison ({final_iteration} iterations)')
    plt.grid(True, alpha=0.3)


def _plot_component_contributions(results, final_iteration):
    """Panel 3: DAMCTS's final mean return minus that of each single-component ablation."""
    # Component -> variant that disables it. Fixed-Epsilon also sets
    # epsilon_1=0.1, so its bar mixes two changes.
    # No variant removes only the epsilon-net ("Random" disables every
    # component), so that component is not charted, rather than being shown
    # as a bar that is always zero.
    ablations = {
        'Power Mean': 'UCT-Style',
        'Epsilon Bonus': 'No-Epsilon-Bonus',
        'Adaptive Epsilon': 'Fixed-Epsilon',
    }
    baseline = _final_mean(results, 'DAMCTS', final_iteration)

    components = list(ablations)
    impacts = []
    for component in components:
        perf = _final_mean(results, ablations[component], final_iteration)
        # Missing data (baseline or ablation) is drawn as 0.
        impacts.append(0 if baseline is None or perf is None else baseline - perf)

    plt.bar(components, impacts)
    plt.ylabel('Performance Impact')
    plt.title('Component Contribution Analysis')
    plt.xticks(rotation=45, ha='right')
    plt.grid(True, alpha=0.3)


def _plot_parameter_sensitivity(results, final_iteration):
    """Panel 4: final mean return at (low, default, high) for power, epsilon_1 and beta."""
    parameter_variants = {
        'Power': ['Power-1.5', 'DAMCTS', 'Power-3.0'],
        'Epsilon': ['Epsilon-0.25', 'DAMCTS', 'Epsilon-0.75'],
        'Beta': ['Beta-0.5', 'DAMCTS', 'Beta-2.0'],
    }

    for param_name, variants in parameter_variants.items():
        values = [_final_mean(results, variant, final_iteration) for variant in variants]
        # Only draw a curve when all three points are available.
        if None not in values:
            plt.plot([0, 1, 2], values, marker='o', label=param_name)

    plt.xlabel('Parameter Value (Low, Default, High)')
    plt.ylabel('Average Return')
    plt.title('Parameter Sensitivity Analysis')
    plt.legend()
    plt.grid(True, alpha=0.3)

def save_ablation_results(results: Dict, filename: str = 'damcts_ablation_results.txt'):
    """Write a human-readable summary of ablation results to a text file.

    Writes one section per variant, with one ``"<n> iterations: mean ± std"``
    line per planning budget in ascending order.

    Args:
        results: ``{config_name: {n_iter: {'mean': float, 'std': float, ...}}}``.
        filename: Output path; an existing file is overwritten.
    """
    # Explicit UTF-8: the report contains "±", which the platform-default
    # encoding (e.g. an ASCII/C locale) may be unable to encode.
    with open(filename, 'w', encoding='utf-8') as f:
        f.write("DAMCTS Ablation Study Results - Hopper Environment\n")
        f.write("=" * 50 + "\n\n")

        for config_name, config_results in results.items():
            f.write(f"{config_name}:\n")
            for iteration, result in sorted(config_results.items()):
                f.write(f"  {iteration} iterations: {result['mean']:.3f} ± {result['std']:.3f}\n")
            f.write("\n")
            f.write("\n")

# Script entry point: run every variant, then plot and persist the results.
if __name__ == "__main__":
    results = run_ablation_study()
    plot_ablation_results(results)
    save_ablation_results(results)

    for message in (
        "\nAblation study complete!",
        "Results saved to damcts_ablation_results.txt",
        "Plots saved to damcts_ablation_hopper.png",
    ):
        print(message)