"""Default Worker class."""
from collections import defaultdict

import gym
import numpy as np
import os

from garage import TrajectoryBatch
from garage.experiment import deterministic
from garage.sampler.env_update import EnvUpdate
from garage.sampler.worker import Worker
from iod.disentanglement import linear_disentanglement


class DefaultWorker(Worker):
    """Worker that samples rollouts from one environment with one agent.

    Besides collecting trajectories, the worker can report linear
    disentanglement metrics between buffered ground-truth states and
    encoder outputs (see `collect_rollout`).

    Args:
        seed(int): The seed to use to initialize random number generators.
        max_path_length(int or float): The maximum length paths which will
            be sampled. Can be (floating point) infinity.
        worker_number(int): The number of the worker where this update is
            occurring. This argument is used to set a different seed for each
            worker.

    Attributes:
        agent(Policy or None): The worker's agent.
        encoder(object or None): The worker's encoder, set through
            `update_encoder`.
        env(gym.Env or None): The worker's environment.

    """

    # Ground-truth dimensions whose variance does not exceed this threshold
    # are treated as constant and excluded from the metric targets.
    _VARIANCE_THRESHOLD = 1e-8

    # Time offsets used for the multi-step difference metrics: fine-grained
    # for short-term dynamics, coarser for long-term dynamics.
    _DIFF_STEP_SIZES = (1, 2, 3, 4, 5, 10, 20)

    def __init__(
            self,
            *,  # Require passing by keyword, since everything's an int.
            seed,
            max_path_length,
            worker_number):
        super().__init__(seed=seed,
                         max_path_length=max_path_length,
                         worker_number=worker_number)
        self.agent = None
        self.encoder = None
        self.env = None
        self._observations = []
        self._last_observations = []
        self._actions = []
        self._rewards = []
        self._terminals = []
        self._lengths = []
        self._agent_infos = defaultdict(list)
        self._env_infos = defaultdict(list)
        self._prev_obs = None
        self._path_length = 0
        # Buffers read by collect_rollout(). Nothing in this class fills
        # them; they are created here so that collect_rollout() never hits
        # an AttributeError when `_deterministic_policy` is enabled before
        # anything has been recorded.
        # NOTE(review): presumably a subclass or external code appends to
        # these during rollouts -- confirm against the caller.
        self._ground_truth_states = []
        self._encoder_outputs = []
        self.worker_init()

    def worker_init(self):
        """Seed the random number generators for this worker.

        The seed is offset by the worker number so that each worker uses a
        different seed. Nothing is done when no seed was given.

        """
        if self._seed is not None:
            deterministic.set_seed(self._seed + self._worker_number)

    def update_agent(self, agent_update):
        """Update an agent, assuming it implements garage.Policy.

        Args:
            agent_update (np.ndarray or dict or garage.Policy): If a
                tuple, dict, or np.ndarray, these should be parameters to
                agent, which should have been generated by calling
                `policy.get_param_values`. Alternatively, a policy itself. Note
                that other implementations of `Worker` may take different types
                for this parameter.

        """
        if isinstance(agent_update, (dict, tuple, np.ndarray)):
            self.agent.set_param_values(agent_update)
        elif agent_update is not None:
            self.agent = agent_update

    def update_encoder(self, encoder_update):
        """Update the worker's encoder.

        Args:
            encoder_update (dict or tuple or np.ndarray or object or None):
                If a dict, tuple, or np.ndarray, it is passed to
                `load_state_dict` of the current encoder. Any other
                non-None value is taken as the new encoder itself; its
                `cpu()` method is called and the result is stored. None
                leaves the encoder untouched.

        """
        # NOTE(review): load_state_dict presumably expects a mapping;
        # confirm that tuple / np.ndarray updates are actually valid here.
        if isinstance(encoder_update, (dict, tuple, np.ndarray)):
            self.encoder.load_state_dict(encoder_update)
        elif encoder_update is not None:
            self.encoder = encoder_update.cpu()

    def update_env(self, env_update):
        """Use any non-None env_update as a new environment.

        A simple env update function. If env_update is not None, it should be
        the complete new environment, or an EnvUpdate that is applied to the
        current environment.

        This allows changing environments by passing the new environment as
        `env_update` into `obtain_samples`.

        Args:
            env_update(gym.Env or EnvUpdate or None): The environment to
                replace the existing env with. Note that other implementations
                of `Worker` may take different types for this parameter.

        Raises:
            TypeError: If env_update is not one of the documented types.

        """
        if env_update is None:
            return
        if isinstance(env_update, EnvUpdate):
            self.env = env_update(self.env)
        elif isinstance(env_update, gym.Env):
            # Close the environment being replaced before dropping it.
            if self.env is not None:
                self.env.close()
            self.env = env_update
        else:
            raise TypeError('Uknown environment update type.')

    def start_rollout(self):
        """Begin a new rollout."""
        self._path_length = 0
        self._prev_obs = self.env.reset()
        self.agent.reset()

    def step_rollout(self):
        """Take a single time-step in the current rollout.

        Returns:
            bool: True iff the path is done, either due to the environment
            indicating termination or due to reaching `max_path_length`.

        """
        if self._path_length < self._max_path_length:
            a, agent_info = self.agent.get_action(self._prev_obs)
            next_o, r, d, env_info = self.env.step(a)
            self._observations.append(self._prev_obs)
            self._rewards.append(r)
            self._actions.append(a)
            for k, v in agent_info.items():
                self._agent_infos[k].append(v)
            for k, v in env_info.items():
                self._env_infos[k].append(v)
            self._path_length += 1
            self._terminals.append(d)
            if not d:
                self._prev_obs = next_o
                return False
        # The path ended (terminal step or length limit): record its length
        # and the observation held in `_prev_obs` at that point.
        self._lengths.append(self._path_length)
        self._last_observations.append(self._prev_obs)
        return True

    def _linear_scores(self, targets, predictors):
        """Score `predictors` against the non-constant columns of `targets`.

        Args:
            targets (np.ndarray): 2-D array whose columns are the
                ground-truth quantities to be predicted.
            predictors (np.ndarray): Array of encoder quantities. It is
                passed to `linear_disentanglement` unfiltered, so every
                encoder dimension can contribute to the score.

        Returns:
            tuple[float, float]: The "r2" and "pearson" scores from
                `linear_disentanglement`, or (nan, nan) when every target
                column is constant.

        """
        active = np.var(targets, axis=0) > self._VARIANCE_THRESHOLD
        if not np.any(active):
            # Everything is constant, so there is nothing to predict.
            return float('nan'), float('nan')
        filtered = targets[:, active]
        r_square = linear_disentanglement(filtered, predictors, mode="r2")
        pearson = linear_disentanglement(filtered, predictors,
                                         mode="pearson")
        return float(r_square), float(pearson)

    def _compute_disentanglement_metrics(self, ground_truth_states,
                                         encoder_outputs):
        """Build the disentanglement log dict for the buffered samples.

        Args:
            ground_truth_states (list): Buffered ground-truth states, one
                entry per recorded step.
            encoder_outputs (list): Buffered encoder outputs.
                NOTE(review): assumed to be aligned one-to-one with
                `ground_truth_states` -- confirm against the code that
                fills the buffers.

        Returns:
            dict[str, float]: "r_square" and "pearson" for the absolute
                states, plus "r_square_diff_{k}_step" and
                "pearson_diff_{k}_step" for every step size k that is
                smaller than the number of buffered states.

        """
        ground_truth_matrix = np.stack(ground_truth_states,
                                       axis=0).astype(np.float64)
        encoder_matrix = np.stack(encoder_outputs, axis=0).astype(np.float64)

        r_square, pearson = self._linear_scores(ground_truth_matrix,
                                                encoder_matrix)
        log_dict = {"r_square": r_square, "pearson": pearson}

        # Multi-step differences assess how well the encoder tracks the
        # dynamics at several time scales.
        for step_size in self._DIFF_STEP_SIZES:
            if len(ground_truth_matrix) <= step_size:
                continue
            gt_diff = (ground_truth_matrix[step_size:] -
                       ground_truth_matrix[:-step_size])
            enc_diff = encoder_matrix[step_size:] - encoder_matrix[:-step_size]
            r2_diff, pearson_diff = self._linear_scores(gt_diff, enc_diff)
            log_dict[f"r_square_diff_{step_size}_step"] = r2_diff
            log_dict[f"pearson_diff_{step_size}_step"] = pearson_diff
        return log_dict

    def collect_rollout(self):
        """Collect the current rollout, clearing the internal buffers.

        Disentanglement metrics are computed only when the worker has a
        truthy `_deterministic_policy` attribute and both the ground-truth
        state buffer and the encoder output buffer are non-empty. Both
        buffers are cleared on every call either way.

        Returns:
            tuple[garage.TrajectoryBatch, dict]: A batch of the trajectories
                completed since the last call to collect_rollout(), and a
                dict of disentanglement metrics (empty when the metrics
                were not computed).

        """
        observations, self._observations = self._observations, []
        last_observations = self._last_observations
        self._last_observations = []
        actions, self._actions = self._actions, []
        rewards, self._rewards = self._rewards, []
        terminals, self._terminals = self._terminals, []
        lengths, self._lengths = self._lengths, []
        env_infos, self._env_infos = self._env_infos, defaultdict(list)
        agent_infos, self._agent_infos = self._agent_infos, defaultdict(list)

        # `_deterministic_policy` is never set by this class, hence getattr.
        # len() is used instead of truthiness so array-like buffers work.
        log_dict = {}
        if (getattr(self, '_deterministic_policy', False)
                and len(self._ground_truth_states) > 0
                and len(self._encoder_outputs) > 0):
            log_dict = self._compute_disentanglement_metrics(
                self._ground_truth_states, self._encoder_outputs)

        # Clear the metric buffers regardless of whether they were used.
        self._ground_truth_states = []
        self._encoder_outputs = []

        agent_info_arrays = {k: np.asarray(v) for k, v in agent_infos.items()}
        env_info_arrays = {k: np.asarray(v) for k, v in env_infos.items()}
        batch = TrajectoryBatch(self.env.spec, np.asarray(observations),
                                np.asarray(last_observations),
                                np.asarray(actions), np.asarray(rewards),
                                np.asarray(terminals), env_info_arrays,
                                agent_info_arrays,
                                np.asarray(lengths, dtype='i'))
        return batch, log_dict

    def rollout(self):
        """Sample a single rollout of the agent in the environment.

        Returns:
            tuple[garage.TrajectoryBatch, dict]: The collected trajectory
                and the metrics dict produced by collect_rollout().

        """
        self.start_rollout()
        while not self.step_rollout():
            pass
        return self.collect_rollout()

    def shutdown(self):
        """Close the worker's environment, if one was ever assigned."""
        if self.env is not None:
            self.env.close()
