import math
import copy
import gym
import random
import numpy as np
import statistics
import pickle

# Import your updated custom/stochastic envs
import Continuous_CartPole
import Continuous_Pendulum
import continuous_mountain_car
import continuous_acrobot
import improved_hopper


# Import the fast high-dimensional environments
try:
    import fast_high_dim_envs
except ImportError as e:
    print(f"Warning: Could not import fast_high_dim_envs: {e}")
    print("Make sure fast_high_dim_envs.py is in the same directory")

from SnapshotENV import SnapshotEnv

####################################################################
# 1) Environment IDs
####################################################################
# Gym IDs run by the main experiment loop, in this order. They are presumably
# registered as a side effect of importing the custom env modules above.
env_names = [
    "Continuous-CartPole-v0",
    "StochasticPendulum-v0",
    "StochasticMountainCarContinuous-v0",
    "StochasticContinuousAcrobot-v0",
    "ImprovedHopper-v0"
]

####################################################################
# 2) Noise configs
####################################################################
# Per-env keyword arguments forwarded to gym.make(env_id, **kwargs) by the
# main loop; an env ID missing from this dict is created with no extra
# kwargs. How each scale is applied is up to the env's own constructor.
ENV_NOISE_CONFIG = {
    "Continuous-CartPole-v0": {
        "action_noise_scale": 0.05,
        "dynamics_noise_scale": 0.5,
        "obs_noise_scale": 0.0
    },
    "StochasticPendulum-v0": {
        "action_noise_scale": 0.02,
        "dynamics_noise_scale": 0.1,
        "obs_noise_scale": 0.01
    },
    "StochasticMountainCarContinuous-v0": {
        "action_noise_scale": 0.05,
        "dynamics_noise_scale": 0.5,
        "obs_noise_scale": 0.0
    },
    "StochasticContinuousAcrobot-v0": {
        "action_noise_scale": 0.05,
        "dynamics_noise_scale": 0.7,
        "obs_noise_scale": 0.01
    },
    "ImprovedHopper-v0": {
        "action_noise_scale": 0.03,
        "dynamics_noise_scale": 0.02,
        "obs_noise_scale": 0.01
    }
}

####################################################################
# 3) Global config - ADAPTED FOR DAMCTS
####################################################################
num_seeds = 10          # seeds per (env, iteration budget) in the main loop
TEST_ITERATIONS = 150   # max real env steps per evaluation episode
discount = 0.99         # discount factor for rollouts, backups and evaluation returns
MAX_DAMCTS_DEPTH = 100  # default rollout horizon (DAMCTSNode.rollout / plan_damcts)
SCALE_UCB = 30.0        # multiplier on the exploration bonus in DAMCTSNode.ucb_score
MAX_VALUE = 1e100       # ucb_score for actions whose child is missing or unvisited

# DAMCTS-specific parameters
EPSILON_1 = 0.5  # epsilon at level k=1; level k uses EPSILON_1 * 2**(-(k-1)/(d + 2*BETA))
BETA = 1.0       # exponent in the epsilon schedule and in the L * eps**BETA bonus
L = 1.0          # scale of the eps**BETA bonus (presumably a Lipschitz constant)
POWER = 2.0      # exponent p of the power-mean value estimate
# Planning-time reward shaping: each raw reward r becomes
# max(0.01, (r + offset) * REWARD_SCALING). REWARD_SCALING is one global
# factor; only the offsets below are environment-specific.
REWARD_SCALING = 1.0

# Environment-specific reward offsets used only while planning (unknown env
# IDs fall back to 10.0); evaluation returns use the raw, unshifted reward.
REWARD_OFFSETS = {
    "Continuous-CartPole-v0": 5.0,              # Rewards are 0-1 per step
    "StochasticPendulum-v0": 5.0,               # Already shifted to be positive in env
    "StochasticMountainCarContinuous-v0": 5.0,  # Mixed pos/neg rewards
    "StochasticContinuousAcrobot-v0": 10.0,     # -1 per step + goal reward
    "ImprovedHopper-v0": 20.0                  # Can have negative rewards
}

####################################################################
# 4) Epsilon-net construction for DAMCTS
####################################################################
def build_epsilon_net(env_name, action_dim, epsilon, lo=-1.0, hi=1.0):
    """Return a finite set of candidate actions (an "epsilon-net") in [lo, hi]^d.

    Args:
        env_name: accepted for interface compatibility; not used.
        action_dim: dimensionality d of the action vector (must be >= 1).
        epsilon: resolution parameter (must be > 0); smaller values give more
            points, up to the caps below.
        lo, hi: per-dimension action bounds.

    Returns:
        A list of float32 arrays, each of shape (action_dim,).

    For d <= 4 the net is a regular grid with
    ``clip(round((1/epsilon) ** (1/d)), 2, 20)`` points per axis.

    For d > 4 it is a uniform random sample whose size depends on d and
    epsilon. The sample comes from a private generator seeded by
    (action_dim, n_samples, epsilon), so repeated calls with the same
    arguments return the same actions. This matters because tree nodes
    find their children by comparing actions: a net redrawn on every call
    would never match existing children. It also leaves the global NumPy
    RNG state untouched.

    Raises:
        ValueError: if action_dim < 1 or epsilon is not > 0.
    """
    if action_dim < 1:
        raise ValueError(f"action_dim must be >= 1, got {action_dim}")
    if not epsilon > 0:  # also rejects NaN
        raise ValueError(f"epsilon must be > 0, got {epsilon}")

    if action_dim <= 4:
        # Grid-based discretization. (1/epsilon)**(1/d) points per axis.
        # NOTE(review): this gives about 1/epsilon points in total, while
        # DAMCTSNode.epsilon_level sizes level k as (1/epsilon)**d; for d > 1
        # the two disagree. Confirm which one the paper intends.
        per_dim = int(round((1 / epsilon) ** (1 / action_dim)))
        per_dim = max(2, min(20, per_dim))  # keep the grid non-trivial but bounded

        if action_dim == 1:
            return [np.array([a], dtype=np.float32) for a in np.linspace(lo, hi, per_dim)]

        axes = [np.linspace(lo, hi, per_dim)] * action_dim
        mesh = np.meshgrid(*axes, indexing='ij')
        points = np.stack([m.ravel() for m in mesh], axis=-1)
        return [point.astype(np.float32) for point in points]

    # High dimensions: random sampling with practical size limits.
    if action_dim <= 8:
        n_samples = min(1000, max(50, int((1 / epsilon) ** 2)))  # square scaling
    else:
        n_samples = min(2000, max(100, int(500 * (1 / epsilon))))  # linear scaling

    # Exact bit pattern of epsilon as a non-negative int: deterministic and
    # valid for any positive float.
    eps_key = int.from_bytes(np.float64(epsilon).tobytes(), "little")
    rng = np.random.default_rng((int(action_dim), int(n_samples), eps_key))
    samples = rng.uniform(lo, hi, size=(n_samples, action_dim))
    return [sample.astype(np.float32) for sample in samples]

def get_env_action_space(env_name):
    """Return ``(action_dim, low, high)`` describing an env's action box.

    Every action dimension shares the same ``[low, high]`` bounds. IDs not
    in the table get the default ``(1, -1.0, 1.0)``.
    """
    action_spaces = {
        "Continuous-CartPole-v0": (1, -1.0, 1.0),
        "StochasticPendulum-v0": (1, -2.0, 2.0),
        "StochasticMountainCarContinuous-v0": (1, -1.0, 1.0),
        "StochasticContinuousAcrobot-v0": (1, -1.0, 1.0),
        "ImprovedHopper-v0": (3, -1.0, 1.0),
    }
    fallback = (1, -1.0, 1.0)
    return action_spaces.get(env_name, fallback)

####################################################################
# Node classes and planning loop for DAMCTS
####################################################################
class DAMCTSNode:
    """Node of the DAMCTS search tree.

    A non-root node represents taking ``action`` from its parent's state. The
    transition is simulated once, at construction, with
    ``env.get_result(parent.snapshot, action)``. The resulting snapshot,
    observation, shaped reward and done flag are cached on the node.

    Statistics filled in by :meth:`back_propagate`:
        visit_count: number of backups through this node.
        value_sum: sum of backed-up returns.
        value_sum_power: sum of returns ** POWER (for the power mean).
    """

    def __init__(self, parent, action, eps_net_func, env, env_name):
        """Create a node and, for non-root nodes, simulate the incoming step.

        Args:
            parent: parent node, or None for a root.
            action: action taken from ``parent`` (array of shape (dim,)), or
                None for a root.
            eps_net_func: callable mapping epsilon -> list of candidate actions.
            env: planning env providing ``get_result``; unused for a root.
            env_name: env ID, used to look up the planning reward offset.
        """
        self.parent = parent
        self.action = action  # shape=(dim,) or None if root
        self.children = set()
        self.visit_count = 0
        self.value_sum = 0.0
        self.value_sum_power = 0.0  # sum of return ** POWER, for power-mean backups
        self.env_name = env_name  # used for the env-specific reward offset

        if parent is None:
            # Root node: there is no incoming transition to simulate.
            self.snapshot = None
            self.obs = None
            self.immediate_reward = 0.0
            self.is_done = False
        else:
            snap, obs, r, done, _ = env.get_result(parent.snapshot, self.action)
            self.snapshot = snap
            self.obs = obs
            self.immediate_reward = self._shape_reward(r, env_name)
            self.is_done = done

        self.eps_net_func = eps_net_func

    @staticmethod
    def _shape_reward(reward, env_name):
        """Map a raw env reward to the planning reward used inside the tree.

        Adds the env-specific offset (10.0 for unknown env IDs), multiplies by
        REWARD_SCALING and clamps to at least 0.01. Presumably this keeps all
        planning returns positive so the power mean in
        :meth:`get_power_mean_value` is meaningful.
        """
        reward_offset = REWARD_OFFSETS.get(env_name, 10.0)
        return max(0.01, (reward + reward_offset) * REWARD_SCALING)

    def __repr__(self):
        return f"DAMCTSNode(action={self.action}, visits={self.visit_count}, value={self.value_sum})"

    def safe_delete(self):
        """Recursively tear down this node's subtree.

        Deletes the ``parent`` and ``children`` attributes of every node in the
        subtree, breaking the parent/child reference cycles. The nodes must
        not be used afterwards.
        """
        for child in self.children:
            child.safe_delete()
        del self.parent
        del self.children

    def is_root(self):
        """True if this node has no parent."""
        return (self.parent is None)

    def is_leaf(self):
        """True if no children have been created yet."""
        return len(self.children) == 0

    def get_mean_value(self):
        """Arithmetic mean of backed-up returns (0.0 if never visited)."""
        return self.value_sum / self.visit_count if self.visit_count > 0 else 0.0

    def get_power_mean_value(self):
        """Power mean (mean of return**POWER) ** (1/POWER); 0.0 if never visited."""
        if self.visit_count > 0:
            return (self.value_sum_power / self.visit_count) ** (1.0 / POWER)
        return 0.0

    def epsilon_level(self):
        """Return ``(k, eps_k, size_k)`` for this node's current visit count.

        k is the smallest level (starting at 1) with
        ``max(visit_count, 1) <= size_k ** 2``, where
        ``eps_k = EPSILON_1 * 2 ** (-(k-1) / (d + 2*BETA))`` and
        ``size_k = floor((1/eps_k) ** d)``; the exponent is capped at 3 when
        d > 4 to avoid overflow. size_k grows with k, so the loop terminates.
        """
        n = max(self.visit_count, 1)
        k = 1
        d = self.get_action_dim()

        while True:
            eps = EPSILON_1 * (2 ** (-(k-1)/(d + 2*BETA)))

            if d <= 4:
                size = int(np.floor((1/eps) ** d))
            else:
                # Approximation for high dimensions to avoid overflow.
                size = int(np.floor((1/eps) ** min(d, 3)))

            # Stay at level k while n <= |N_k|^2.
            if n <= size * size:
                return k, eps, size
            k += 1

    def get_action_dim(self):
        """Action dimensionality: len(action), else ``_action_dim`` if set, else 1."""
        if self.action is not None:
            return len(self.action)
        # Roots have no action; DAMCTSRoot stores the dimension explicitly.
        if hasattr(self, '_action_dim'):
            return self._action_dim
        return 1  # Default

    def ucb_score(self, action, parent_visits):
        """Score ``action`` for selection at this node.

        Returns MAX_VALUE if the action has no child yet or its child is
        unvisited, so such actions are tried first. Otherwise returns
        ``power-mean value + L * eps_k**BETA + exploration bonus``.

        The exploration bonus is the polynomial form
        ``SCALE_UCB * parent_visits**(1/4) / child_visits**(1/2)``, not the
        logarithmic UCB1 bonus. The eps term is identical for every action at
        this node, so it does not change which action wins.
        """
        child = self.get_child_for_action(action)

        if child is None or child.visit_count == 0:
            return MAX_VALUE

        pm_value = child.get_power_mean_value()

        _, eps_k, _ = self.epsilon_level()
        eps_bonus = L * (eps_k ** BETA)

        # sqrt(sqrt(N) / n) == N**0.25 / n**0.5. child.visit_count > 0 is
        # guaranteed by the early return above.
        if parent_visits > 0:
            ucb_bonus = SCALE_UCB * np.sqrt(np.sqrt(parent_visits) / child.visit_count)
        else:
            ucb_bonus = MAX_VALUE

        return pm_value + eps_bonus + ucb_bonus

    def selection(self):
        """Descend by UCB over the current epsilon-net; return the node to expand.

        Stops at a leaf, a terminal node, or a node whose best-scoring action
        has no child yet (that node is returned so expand() can create it).
        """
        if self.is_leaf() or self.is_done:
            return self

        _, eps_k, _ = self.epsilon_level()
        epsilon_net = self.eps_net_func(eps_k)

        best_action = None
        best_score = -float('inf')
        for action in epsilon_net:
            score = self.ucb_score(action, self.visit_count)
            if score > best_score:  # strict: the first maximal action wins ties
                best_score = score
                best_action = action

        child = self.get_child_for_action(best_action)
        # Children are only created in expand(), never here.
        if child is not None:
            return child.selection()
        return self

    def get_child_for_action(self, action):
        """Return the child whose action matches ``action`` (atol=1e-6), or None."""
        for child in self.children:
            if child.action is not None and np.allclose(child.action, action, atol=1e-6):
                return child
        return None

    def expand(self, env):
        """Create children for all not-yet-explored actions of the current net.

        Each new child simulates its transition via ``env.get_result``. Then
        selection is re-run from this node. Terminal nodes are returned
        unchanged.
        """
        if self.is_done:
            return self

        _, eps_k, _ = self.epsilon_level()
        epsilon_net = self.eps_net_func(eps_k)

        for action in epsilon_net:
            if self.get_child_for_action(action) is None:
                child = DAMCTSNode(self, action, self.eps_net_func, env, self.env_name)
                self.children.add(child)

        return self.selection()

    def rollout(self, env, max_depth=MAX_DAMCTS_DEPTH):
        """Random rollout from this node's state.

        Returns the discounted sum of shaped rewards over at most
        ``max_depth`` steps. The first step is undiscounted. Terminal nodes
        return 0.0.
        """
        if self.is_done:
            return 0.0

        env.load_snapshot(self.snapshot)
        # Loop invariants.
        action_dim = self.get_action_dim()
        has_action_space = hasattr(env, 'action_space')
        total = 0.0
        discount_factor = 1.0

        for _ in range(max_depth):
            if has_action_space:
                # NOTE(review): gym spaces typically sample from their own
                # RNG, so np.random.seed may not make this reproducible.
                # Confirm for the installed gym version.
                action = env.action_space.sample()
            else:
                action = np.random.uniform(-1.0, 1.0, size=action_dim).astype(np.float32)

            _, r, done, _ = env.step(action)
            # Same reward shaping as the tree nodes.
            total += self._shape_reward(r, self.env_name) * discount_factor
            discount_factor *= discount

            if done:
                break

        return total

    def back_propagate(self, rollout_reward):
        """Back up one sampled return from this node to the root.

        Args:
            rollout_reward: discounted return collected after reaching this
                node's state (0.0 for terminal leaves).

        This node's sampled return is ``immediate_reward + discount *
        rollout_reward``. That full return, including this node's own reward,
        is the future return passed to the parent. Every reward on the path
        therefore reaches every ancestor; previously the parent received
        only the discounted rollout, dropping this node's reward.
        """
        total_return = self.immediate_reward + discount * rollout_reward

        self.value_sum += total_return
        self.value_sum_power += total_return ** POWER
        self.visit_count += 1

        if not self.is_root():
            self.parent.back_propagate(total_return)

class DAMCTSRoot(DAMCTSNode):
    """Root node for DAMCTS that doesn't need an action from a parent.

    Wraps an existing snapshot/observation instead of simulating a transition.
    It stores the action dimensionality explicitly because it has no
    ``action`` to infer it from.
    """
    def __init__(self, snapshot, obs, eps_net_func, action_dim=1, env_name="default"):
        super().__init__(parent=None, action=None, eps_net_func=eps_net_func, env=None, env_name=env_name)
        self.snapshot = snapshot
        self.obs = obs
        self.immediate_reward = 0.0
        self.is_done = False
        self._action_dim = action_dim

    @staticmethod
    def to_root(node):
        """Promote ``node`` to be the root of its subtree (tree reuse).

        Copies the node's snapshot, observation and statistics into a new root
        and moves the node's children under it.

        The children are re-parented to the new root. Later backups then
        update the new root's visit count, which drives its epsilon level and
        UCB exploration term. Without this they would go to the discarded
        ``node`` and its old ancestors.

        Args:
            node: a DAMCTSNode, typically the chosen child of the old root.

        Returns:
            A new DAMCTSRoot owning the subtree. ``node`` is left with an empty
            children set so tearing it down cannot affect the new tree.
        """
        root = DAMCTSRoot(
            node.snapshot,
            node.obs,
            node.eps_net_func,
            action_dim=node.get_action_dim(),
            env_name=node.env_name
        )
        root.children = node.children
        for child in root.children:
            child.parent = root
        node.children = set()

        root.value_sum = node.value_sum
        root.value_sum_power = node.value_sum_power
        root.visit_count = node.visit_count
        root.is_done = node.is_done
        return root

def plan_damcts(root, n_iter, env, max_depth=None):
    """Run ``n_iter`` DAMCTS simulations from ``root`` (modified in place).

    Each simulation selects a leaf. A terminal leaf backs up a future return
    of 0.0. Otherwise the leaf is expanded, a random rollout is run from the
    node returned by expansion, and the rollout return is backed up.

    Args:
        root: tree root, e.g. a DAMCTSRoot.
        n_iter: number of simulations.
        env: snapshot-capable planning env passed to expand/rollout.
        max_depth: rollout horizon. None (the default) means MAX_DAMCTS_DEPTH,
            the value previously hard-coded here.
    """
    if max_depth is None:
        max_depth = MAX_DAMCTS_DEPTH

    for _ in range(n_iter):
        leaf = root.selection()
        if leaf.is_done:
            leaf.back_propagate(0.0)
            continue
        new_leaf = leaf.expand(env)
        rollout_value = new_leaf.rollout(env, max_depth=max_depth)
        new_leaf.back_propagate(rollout_value)

####################################################################
# 5) Main experiment
####################################################################
if __name__ == "__main__":
    results_filename = "damcts_results_fast_high_dims.txt"

    # Same iteration schedule as UCT: 16 geometrically spaced budgets from 3
    # up to roughly 3000 planning iterations; only the first 6 are run here.
    base = 1000 ** (1.0 / 15.0)
    samples = [int(3 * (base ** i)) for i in range(16)]
    samples_to_use = samples[0:6]

    def make_eps_net_func(env_name, action_dim, lo, hi):
        """Return an ``epsilon -> epsilon-net`` callable for one environment.

        All env-specific values are bound as arguments, not read from the
        loop variables at call time. The callable is therefore unaffected by
        later loop iterations.
        """
        def eps_net_func(epsilon):
            return build_epsilon_net(env_name, action_dim, epsilon, lo, hi)
        return eps_net_func

    # The context manager closes the results file even if an experiment
    # raises. UTF-8 is explicit because each result line contains a
    # non-ASCII "±".
    with open(results_filename, "a", encoding="utf-8") as f_out:
        for envname in env_names:
            print(f"\n{'='*60}")
            print(f"Starting experiments for {envname}")
            print(f"{'='*60}")

            # A) Environment-specific settings. Noise parameters go straight
            # to the env constructor through gym.make.
            stoch_kwargs = ENV_NOISE_CONFIG.get(envname, {})
            action_dim, lo, hi = get_env_action_space(envname)

            # NOTE(review): max_depth is computed but not passed on; the
            # rollouts run by plan_damcts below use MAX_DAMCTS_DEPTH.
            # Confirm which horizon is intended.
            if envname in ["ImprovedHopper-v0"]:
                max_depth = 100  # Medium episodes for locomotion environments
            else:
                max_depth = 50   # Shorter episodes for simpler environments

            eps_net_func = make_eps_net_func(envname, action_dim, lo, hi)

            # B) Planning env wrapped for snapshot save/restore.
            planning_env = SnapshotEnv(gym.make(envname, **stoch_kwargs).env)
            planning_env._env_name = envname  # env name for reward offset lookup
            root_obs_ori = planning_env.reset()
            root_snapshot_ori = planning_env.get_snapshot()

            for ITERATIONS in samples_to_use:
                print(f"\nRunning {ITERATIONS} iterations for {envname}")
                seed_returns = []

                for seed_i in range(num_seeds):
                    if seed_i % 5 == 0:
                        print(f"  Seed {seed_i}/{num_seeds}")

                    random.seed(seed_i)
                    np.random.seed(seed_i)

                    # Every seed starts from the same initial state.
                    root_obs = copy.copy(root_obs_ori)
                    root_snapshot = copy.copy(root_snapshot_ori)

                    root = DAMCTSRoot(
                        snapshot=root_snapshot,
                        obs=root_obs,
                        eps_net_func=eps_net_func,
                        action_dim=action_dim,
                        env_name=envname
                    )
                    plan_damcts(root, n_iter=ITERATIONS, env=planning_env)

                    # C) Evaluation: act in an independent copy of the env and
                    # re-plan after every step. The snapshot is trusted local
                    # data produced by planning_env.get_snapshot().
                    test_env = pickle.loads(root_snapshot)
                    total_reward = 0.0
                    current_discount = 1.0
                    done = False

                    for _ in range(TEST_ITERATIONS):
                        if len(root.children) == 0:
                            # Nothing expanded yet => fall back to the zero action.
                            best_child = None
                            best_action = np.zeros(action_dim, dtype=np.float32)
                        else:
                            # Greedy with respect to the power-mean value.
                            best_child = max(root.children, key=lambda c: c.get_power_mean_value())
                            best_action = best_child.action

                        s, r, done, _ = test_env.step(best_action)
                        # Evaluation uses the raw reward (no planning offset/scaling).
                        total_reward += r * current_discount
                        current_discount *= discount

                        if done:
                            test_env.close()
                            break

                        # Prune every subtree except the one being reused.
                        for child in list(root.children):
                            if child is not best_child:
                                child.safe_delete()
                                root.children.remove(child)

                        if best_child is None:
                            # No action has been explored yet – start a fresh tree from here
                            planning_env.load_snapshot(pickle.dumps(test_env))  # sync planner
                            root = DAMCTSRoot(
                                snapshot=planning_env.get_snapshot(),
                                obs=s,  # current observation from test_env
                                eps_net_func=eps_net_func,
                                action_dim=action_dim,
                                env_name=envname
                            )
                        else:
                            # NOTE(review): the reused subtree is rooted at the
                            # planner's simulated snapshot for this action, not
                            # at test_env's actual state. In stochastic envs the
                            # two can differ; confirm this open-loop reuse is
                            # intended.
                            root = DAMCTSRoot.to_root(best_child)

                        plan_damcts(root, n_iter=ITERATIONS, env=planning_env)

                    if not done:
                        test_env.close()

                    seed_returns.append(total_reward)

                # Statistics. The reported interval is 2 population standard
                # deviations of the per-seed returns.
                mean_return = statistics.mean(seed_returns)
                std_return = statistics.pstdev(seed_returns)
                interval = 2.0 * std_return

                msg = (f"Env={envname}, ITER={ITERATIONS}: "
                       f"Mean={mean_return:.3f} ± {interval:.3f} "
                       f"(over {num_seeds} seeds)")
                print(msg)
                f_out.write(msg + "\n")
                f_out.flush()  # Ensure data is written immediately

    print("Done! Results saved to", results_filename)
