# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import os
import sys
import math
from typing import Union, Sequence, Callable, Optional
import torch
from tensordict import (
    TensorDictBase,
)
from tensordict.nn import TensorDictModule
from torchrl.envs.common import EnvBase
from torchrl.envs.utils import (
    ExplorationType,
    set_exploration_type,
)

# Timeout constants (presumably expressed in seconds -- TODO confirm against their users).
_TIMEOUT = 1.0
_MIN_TIMEOUT = 1e-3  # should be several orders of magnitude smaller than the time spent collecting a trajectory
# _MAX_IDLE_COUNT is the maximum number of times a DataLoader worker can time out on its queue.
# It is read from the MAX_IDLE_COUNT environment variable and defaults to 1000.
_MAX_IDLE_COUNT = int(os.environ.get("MAX_IDLE_COUNT", 1000))

# Module-level default exploration type.
DEFAULT_EXPLORATION_TYPE: ExplorationType = ExplorationType.RANDOM

# True when sys.platform starts with "darwin" (i.e. macOS).
_is_osx = sys.platform.startswith("darwin")

from torchrl.collectors import SyncDataCollector

class GuidedSyncDataCollector(SyncDataCollector):
    """Synchronous collector in which a teacher policy can override the student.

    At every collection step the student ``policy`` and the ``teacher_policy``
    are evaluated on the same input. The teacher *issues* advice wherever its
    ``"raw_energy"`` output exceeds ``teacher_kwargs["threshold"]``. Issued
    advice is optionally capped by a budget (``teacher_kwargs["fix"]``) and is
    then *taken* with probability ``cur_issue_rate``. Wherever advice is taken,
    the teacher's entries listed in ``_ADVICE_OVERRIDE_KEYS`` overwrite the
    student's.

    Args:
        create_env_fn: forwarded unchanged to ``SyncDataCollector``.
        policy: the student policy, forwarded to ``SyncDataCollector``.
        teacher_policy: callable evaluated on a copy of the policy input. Its
            output must contain ``"raw_energy"`` and every key of
            ``_ADVICE_OVERRIDE_KEYS``.
        frames_per_batch: forwarded unchanged to ``SyncDataCollector``.
        teacher_kwargs: configuration dict. ``"follow_prob"`` (initial
            probability of taking issued advice) and ``"fix"`` (truthy to
            enable the advice budget) are read at construction time;
            ``"num_transfer"``, ``"linear_decay_advice"``,
            ``"ex_decay_advice"`` and ``"threshold"`` are read on every
            ``rollout`` call. The dict is stored by reference.
        advice_reset_interval: number of counted frames after which the advice
            budget is reset (only used when ``fix`` is truthy).
        interval_advice_rate: fraction of ``advice_reset_interval`` that may be
            spent on advice within one interval (only used when ``fix`` is
            truthy).
        **kwargs: forwarded unchanged to ``SyncDataCollector``.

    Raises:
        KeyError: if ``teacher_kwargs`` lacks ``"follow_prob"`` or ``"fix"``.
    """

    # Entries copied from the teacher's output over the student's output
    # wherever advice is taken.
    _ADVICE_OVERRIDE_KEYS = (
        "action",
        "done",
        "terminated",
        "truncated",
        "sample_log_prob",
    )
    # Bookkeeping entries that ``rollout`` writes into the policy output.
    _ADVICE_INFO_KEYS = ("issue_advice", "take_advice", "raw_energy")

    def __init__(
        self,
        create_env_fn: Union[
            EnvBase, "EnvCreator", Sequence[Callable[[], EnvBase]]  # noqa: F821
        ],  # noqa: F821
        policy: Optional[
            Union[
                TensorDictModule,
                Callable[[TensorDictBase], TensorDictBase],
            ]
        ],
        teacher_policy: Optional[
            Union[
                TensorDictModule,
                Callable[[TensorDictBase], TensorDictBase],
            ]
        ],
        *,
        frames_per_batch: int,
        teacher_kwargs: Optional[dict] = None,
        advice_reset_interval: int = 5000,
        interval_advice_rate: float = 0.125,
        **kwargs
    ):
        # ``None`` replaces the former shared mutable ``{}`` default.
        if teacher_kwargs is None:
            teacher_kwargs = {}

        # Read the mandatory entries first so that a bad configuration fails
        # before the base class builds any environment.
        init_issue_rate = teacher_kwargs['follow_prob']
        fix = teacher_kwargs["fix"]

        super().__init__(
            create_env_fn=create_env_fn,
            policy=policy,
            frames_per_batch=frames_per_batch,
            **kwargs
        )

        self.teacher_policy = teacher_policy
        self.teacher_kwargs = teacher_kwargs
        self.init_issue_rate = init_issue_rate
        self.cur_issue_rate = self.init_issue_rate
        self.fix = fix

        if self.fix:
            # Budget bookkeeping; these attributes only exist in "fix" mode.
            self.advice_count = 0
            self.advice_frames = 0
            self.advice_reset_interval = advice_reset_interval
            self.interval_advice_rate = interval_advice_rate
            self.max_advice_count = self.advice_reset_interval * self.interval_advice_rate
            print(
                f"[GuidedSyncDataCollector] interval_advice_rate={self.interval_advice_rate}, "
                f"advice_reset_interval={self.advice_reset_interval}"
            )

    def _refresh_issue_rate(self) -> None:
        """Recompute ``cur_issue_rate`` from ``self._frames`` and ``teacher_kwargs``."""
        num_transfer = self.teacher_kwargs['num_transfer']
        if self.teacher_kwargs['linear_decay_advice']:
            self.cur_issue_rate = (1 - self._frames / num_transfer) * self.init_issue_rate

        # The exponential schedule, when enabled, overrides the linear one.
        if self.teacher_kwargs['ex_decay_advice'] >= 1:
            decay_rate = self.teacher_kwargs['ex_decay_advice']  # steepness of the decay
            self.cur_issue_rate = self.init_issue_rate * math.exp(
                -decay_rate * (self._frames / num_transfer)
            )

        # Past ``num_transfer`` frames advice is never taken.
        if self._frames > num_transfer:
            self.cur_issue_rate = 0

    def _apply_advice_budget(self, take_advice: torch.Tensor) -> torch.Tensor:
        """Cap ``take_advice`` so the running advice count stays within budget.

        Updates ``advice_frames`` and ``advice_count`` and returns the
        (possibly truncated) boolean mask.
        """
        self.advice_frames += take_advice.numel()

        if self.advice_count >= self.max_advice_count:
            # Budget already reached: disable advice for this step.
            return torch.zeros_like(take_advice, dtype=torch.bool)

        requested = take_advice.sum().item()
        if self.advice_count + requested <= self.max_advice_count:
            self.advice_count += requested
            return take_advice

        # Only part of the requested advice fits: keep the first ``allowed``
        # positions (in flattened order) and drop the rest.
        allowed = int(self.max_advice_count - self.advice_count)
        flat_indices = take_advice.view(-1).nonzero(as_tuple=False)
        capped = torch.zeros_like(take_advice, dtype=torch.bool)
        capped.view(-1)[flat_indices[:allowed]] = True
        self.advice_count += allowed
        return capped

    @torch.no_grad()
    def rollout(self) -> TensorDictBase:
        """Computes a rollout in the environment, letting the teacher override the student.

        Returns:
            TensorDictBase containing the computed rollout. Steps that went
            through the policy branch carry the boolean entries
            ``"issue_advice"`` and ``"take_advice"`` and the teacher's
            ``"raw_energy"``.

        """
        if self.reset_at_each_iter:
            self._shuttle.update(self.env.reset())

        self._final_rollout.fill_(("collector", "traj_ids"), -1)

        # Boolean placeholders for the advice flags.
        for flag_key in ("issue_advice", "take_advice"):
            self._final_rollout.set(
                flag_key,
                torch.zeros(
                    self._final_rollout.shape,
                    dtype=torch.bool,
                    device=self._shuttle.device,
                ),
            )

        self._refresh_issue_rate()

        tensordicts = []

        # Start a new budget interval once enough frames have been counted.
        if self.fix and self.advice_frames >= self.advice_reset_interval:
            self.advice_frames = 0
            self.advice_count = 0

        with set_exploration_type(self.exploration_type):
            for t in range(self.frames_per_batch):
                if (
                    self.init_random_frames is not None
                    and self._frames < self.init_random_frames
                ):
                    # NOTE(review): no advice entries are written on this
                    # branch -- confirm this is intended when random and
                    # policy steps can share a batch.
                    self.env.rand_action(self._shuttle)
                else:
                    if self._cast_to_policy_device and self.policy_device is not None:
                        policy_input = self._shuttle.to(
                            self.policy_device, non_blocking=True
                        )
                    else:
                        policy_input = self._shuttle

                    # The teacher gets its own copy so that the student's
                    # in-place writes do not leak into the teacher's input.
                    policy_input_clone = policy_input.copy()
                    policy_output = self.policy(policy_input)
                    t_policy_output = self.teacher_policy(policy_input_clone)

                    issue_advice = t_policy_output['raw_energy'] > self.teacher_kwargs['threshold']
                    take_advice = issue_advice

                    if self.fix:
                        take_advice = self._apply_advice_budget(take_advice)

                    # Take each remaining piece of advice with probability
                    # ``cur_issue_rate``. The explicit ``&`` guarantees that
                    # positions without advice are never marked as taken,
                    # even if the rate exceeds 1.
                    draws = torch.rand_like(take_advice.float())
                    take_advice = take_advice & (draws > (1 - self.cur_issue_rate))

                    policy_output['issue_advice'] = issue_advice
                    policy_output['take_advice'] = take_advice
                    policy_output['raw_energy'] = t_policy_output['raw_energy']

                    for k in self._ADVICE_OVERRIDE_KEYS:
                        policy_output[k][take_advice] = t_policy_output[k][take_advice]

                    if self._shuttle is not policy_output:
                        # ad-hoc update shuttle; the advice entries are
                        # included so they are not lost when the policy
                        # returns a tensordict other than the shuttle.
                        self._shuttle.update(
                            policy_output,
                            keys_to_update=list(
                                {
                                    *self._policy_output_keys,
                                    *self._ADVICE_OVERRIDE_KEYS,
                                    *self._ADVICE_INFO_KEYS,
                                }
                            ),
                        )

                if self._cast_to_policy_device and self.env_device is not None:
                    env_input = self._shuttle.to(self.env_device, non_blocking=True)
                else:
                    env_input = self._shuttle
                env_output, env_next_output = self.env.step_and_maybe_reset(env_input)

                if self._shuttle is not env_output:
                    # ad-hoc update shuttle
                    next_data = env_output.get("next")
                    if self._shuttle_has_no_device:
                        # Make sure
                        next_data.clear_device_()
                    self._shuttle.set("next", next_data)

                if self.storing_device is not None:
                    tensordicts.append(
                        self._shuttle.to(self.storing_device, non_blocking=False)
                    )
                else:
                    tensordicts.append(self._shuttle)

                # carry over collector data without messing up devices
                collector_data = self._shuttle.get("collector").copy()
                self._shuttle = env_next_output
                if self._shuttle_has_no_device:
                    self._shuttle.clear_device_()
                self._shuttle.set("collector", collector_data)

                self._update_traj_ids(env_output)

                if (
                    self.interruptor is not None
                    and self.interruptor.collection_stopped()
                ):
                    # Collection was interrupted: write what has been
                    # gathered so far into the leading slice and stop.
                    try:
                        torch.stack(
                            tensordicts,
                            self._final_rollout.ndim - 1,
                            out=self._final_rollout[: t + 1],
                        )
                    except RuntimeError:
                        # Retry with the output buffer unlocked.
                        with self._final_rollout.unlock_():
                            torch.stack(
                                tensordicts,
                                self._final_rollout.ndim - 1,
                                out=self._final_rollout[: t + 1],
                            )
                    break
            else:
                # Loop completed without interruption: stack the full batch.
                try:
                    self._final_rollout = torch.stack(
                        tensordicts,
                        self._final_rollout.ndim - 1,
                        out=self._final_rollout,
                    )
                except RuntimeError:
                    # Retry with the output buffer unlocked.
                    with self._final_rollout.unlock_():
                        self._final_rollout = torch.stack(
                            tensordicts,
                            self._final_rollout.ndim - 1,
                            out=self._final_rollout,
                        )
        return self._final_rollout





class JumpStartSyncDataCollector(SyncDataCollector):
    """Synchronous collector in which the teacher drives the first steps of each episode.

    While the ``"step_count"`` entry of the teacher's output is below
    ``current_issue_step`` the teacher's entries listed in
    ``_ADVICE_OVERRIDE_KEYS`` replace the student's. ``current_issue_step``
    starts at ``init_guided_step`` and is lowered by ``set_current_stage``.
    """

    # Entries copied from the teacher's output over the student's output
    # wherever advice is taken.
    _ADVICE_OVERRIDE_KEYS = (
        "action",
        "done",
        "terminated",
        "truncated",
        "sample_log_prob",
    )
    # Boolean bookkeeping entries that ``rollout`` writes into the policy output.
    _ADVICE_FLAG_KEYS = ("issue_advice", "take_advice")

    def __init__(
            self,
            create_env_fn: Union[
                "EnvBase", "EnvCreator", Sequence[Callable[[], "EnvBase"]]
            ],
            policy: Optional[
                Union[
                    "TensorDictModule",
                    Callable[[TensorDictBase], TensorDictBase],
                ]
            ],
            teacher_policy: Optional[
                Union[
                    "TensorDictModule",
                    Callable[[TensorDictBase], TensorDictBase],
                ]
            ],
            *,
            frames_per_batch: int,
            teacher_kwargs: dict,
            **kwargs
    ):
        """
        Args:
            create_env_fn: Factory function(s) that create environment(s).
            policy: The student policy that acts in the environment.
            teacher_policy: The teacher policy that can override student actions.
            frames_per_batch: Number of environment steps to collect per rollout call.
            teacher_kwargs (dict):
                - init_guided_step (int): initial number of guide-steps H_1
                  (defaults to 300).
                - n_stages (int): number of stages in the curriculum
                  (defaults to 20).
                - tolerance (float): user-defined tolerance, stored on the
                  collector but not used by the collector itself
                  (defaults to 0.15).
            **kwargs: forwarded unchanged to ``SyncDataCollector``.
        """
        super().__init__(
            create_env_fn=create_env_fn,
            policy=policy,
            frames_per_batch=frames_per_batch,
            **kwargs
        )

        # Store teacher policy and core hyperparameters
        self.teacher_policy = teacher_policy
        self.init_guided_step = teacher_kwargs.get("init_guided_step", 300)
        self.n_stages = teacher_kwargs.get("n_stages", 20)
        self.tolerance = teacher_kwargs.get("tolerance", 0.15)  # larger tolerance might hurt performance
        print(f"the init_guided_step is: {self.init_guided_step} \n n_stages is: {self.n_stages} \n tolerance: {self.tolerance}")

        # Current stage in the curriculum; advanced externally through
        # ``set_current_stage``.
        self.current_stage = 0
        # Number of steps from the start of an episode during which the
        # teacher issues advice.
        self.current_issue_step = self.init_guided_step

    def set_current_stage(self):
        """
        Advances the current curriculum stage by one, capping at (n_stages - 1),
        and then updates the number of guiding steps for this stage.
        If the last stage is reached, current_issue_step is set to 0.
        """
        last_stage = self.n_stages - 1
        self.current_stage = min(self.current_stage + 1, last_stage)

        if self.current_stage == last_stage:
            # Final stage: the teacher no longer intervenes.
            self.current_issue_step = 0
        else:
            # Shrink the guided horizon linearly with the progress through
            # the stages; the guard avoids dividing by zero when there is
            # at most one stage.
            if self.n_stages > 1:
                frac = self.current_stage / last_stage
                new_issue_step = self.init_guided_step * (1 - frac)
            else:
                new_issue_step = self.init_guided_step
            self.current_issue_step = int(max(0, new_issue_step))

        print(f"[JumpStartSyncDataCollector] Stage set to {self.current_stage}, "
              f"current_issue_step={self.current_issue_step}")

    @torch.no_grad()
    def rollout(self) -> TensorDictBase:
        """
        Collects one batch of frames with teacher overrides at episode start.

        - We only check if the environment's step_count < self.current_issue_step.
        - If yes, teacher always issues and overrides the student's action.
        - If no, the teacher does not intervene.

        Returns:
            TensorDictBase containing the computed rollout. Steps that went
            through the policy branch carry the boolean entries
            ``"issue_advice"`` and ``"take_advice"``.
        """
        if self.reset_at_each_iter:
            self._shuttle.update(self.env.reset())

        # Reset the final rollout's trajectory IDs
        self._final_rollout.fill_(("collector", "traj_ids"), -1)

        # Add boolean placeholders for teacher advice
        for flag_key in self._ADVICE_FLAG_KEYS:
            self._final_rollout.set(
                flag_key,
                torch.zeros(
                    self._final_rollout.shape,
                    dtype=torch.bool,
                    device=self._shuttle.device,
                ),
            )

        tensordicts = []

        with set_exploration_type(self.exploration_type or ExplorationType.RANDOM):
            for t in range(self.frames_per_batch):
                if (
                        self.init_random_frames is not None
                        and self._frames < self.init_random_frames
                ):
                    self.env.rand_action(self._shuttle)
                else:
                    if self._cast_to_policy_device and self.policy_device is not None:
                        policy_input = self._shuttle.to(
                            self.policy_device, non_blocking=True
                        )
                    else:
                        policy_input = self._shuttle

                    # The teacher gets its own clone so that the student's
                    # in-place writes do not leak into the teacher's input.
                    policy_input_clone = policy_input.clone()
                    policy_output = self.policy(policy_input)
                    t_policy_output = self.teacher_policy(policy_input_clone)

                    # Teacher logic: issue advice while
                    # step_count < current_issue_step.
                    current_step_counts = t_policy_output["step_count"]
                    issue_advice = (current_step_counts < self.current_issue_step).to(torch.bool)
                    # Drop the trailing singleton dimension of step_count.
                    # Using -1 instead of a hard-coded 1 gives the same
                    # result for a (B, 1) tensor and also handles other
                    # batch shapes.
                    issue_advice = issue_advice.squeeze(-1)

                    # We take advice if it is issued (no randomness here, by design)
                    take_advice = issue_advice.clone()

                    policy_output["issue_advice"] = issue_advice
                    policy_output["take_advice"] = take_advice

                    # Where we take advice, override student's outputs
                    for k in self._ADVICE_OVERRIDE_KEYS:
                        policy_output[k][take_advice] = t_policy_output[k][take_advice]

                    if self._shuttle is not policy_output:
                        # The advice flags are included so they are not lost
                        # when the policy returns a tensordict other than
                        # the shuttle.
                        self._shuttle.update(
                            policy_output,
                            keys_to_update=list(
                                {
                                    *self._policy_output_keys,
                                    *self._ADVICE_OVERRIDE_KEYS,
                                    *self._ADVICE_FLAG_KEYS,
                                }
                            ),
                        )

                if self._cast_to_policy_device and self.env_device is not None:
                    env_input = self._shuttle.to(self.env_device, non_blocking=True)
                else:
                    env_input = self._shuttle

                env_output, env_next_output = self.env.step_and_maybe_reset(env_input)

                if self._shuttle is not env_output:
                    next_data = env_output.get("next")
                    if self._shuttle_has_no_device:
                        next_data.clear_device_()
                    self._shuttle.set("next", next_data)

                if self.storing_device is not None:
                    tensordicts.append(
                        self._shuttle.to(self.storing_device, non_blocking=False)
                    )
                else:
                    tensordicts.append(self._shuttle)

                # Carry the collector bookkeeping over to the next shuttle.
                collector_data = self._shuttle.get("collector").clone()
                self._shuttle = env_next_output
                if self._shuttle_has_no_device:
                    self._shuttle.clear_device_()
                self._shuttle.set("collector", collector_data)

                self._update_traj_ids(env_output)

                if (
                        self.interruptor is not None
                        and self.interruptor.collection_stopped()
                ):
                    # Collection was interrupted: write what has been
                    # gathered so far into the leading slice and stop.
                    try:
                        torch.stack(
                            tensordicts,
                            self._final_rollout.ndim - 1,
                            out=self._final_rollout[: t + 1],
                        )
                    except RuntimeError:
                        # Retry with the output buffer unlocked.
                        with self._final_rollout.unlock_():
                            torch.stack(
                                tensordicts,
                                self._final_rollout.ndim - 1,
                                out=self._final_rollout[: t + 1],
                            )
                    break
            else:
                # Loop completed without interruption: stack the full batch.
                try:
                    self._final_rollout = torch.stack(
                        tensordicts,
                        self._final_rollout.ndim - 1,
                        out=self._final_rollout,
                    )
                except RuntimeError:
                    # Retry with the output buffer unlocked.
                    with self._final_rollout.unlock_():
                        self._final_rollout = torch.stack(
                            tensordicts,
                            self._final_rollout.ndim - 1,
                            out=self._final_rollout,
                        )

        return self._final_rollout


class KickStartSyncDataCollector(SyncDataCollector):
    """Synchronous collector that records a teacher-student distillation term.

    The student always acts. At every policy step the cross-entropy between
    the teacher's and the student's ``"logits"`` is computed, scaled by a
    linearly decaying weight ``lambda_k`` and stored as ``"distill_loss"``
    (together with ``"lambda_k"``) in the collected data.
    """

    # Entries that ``rollout`` adds to the policy output.
    _DISTILL_KEYS = ("distill_loss", "lambda_k")

    def __init__(
            self,
            create_env_fn: Union[
                "EnvBase", "EnvCreator", Sequence[Callable[[], "EnvBase"]]
            ],
            policy: Optional[
                Union[
                    "TensorDictModule",
                    Callable[[TensorDictBase], TensorDictBase],
                ]
            ],
            teacher_policy: Optional[
                Union[
                    "TensorDictModule",
                    Callable[[TensorDictBase], TensorDictBase],
                ]
            ],
            *,
            frames_per_batch: int,
            teacher_kwargs: dict,
            **kwargs
    ):
        """
        Kickstarting-style collector with linear decay of lambda_k.

        Args:
            create_env_fn: forwarded unchanged to ``SyncDataCollector``.
            policy: the student policy; its output must contain ``"logits"``.
            teacher_policy: the teacher policy; its output must contain
                ``"logits"``.
            frames_per_batch: number of environment steps per rollout call.
            teacher_kwargs (dict):
                - init_lambda (float): starting value for distillation weight
                  (defaults to 1.0).
                - imitation_ends (int): frame count at which lambda_k reaches
                  zero (defaults to 1_000_000).
            **kwargs: forwarded unchanged to ``SyncDataCollector``.
        """
        super().__init__(
            create_env_fn=create_env_fn,
            policy=policy,
            frames_per_batch=frames_per_batch,
            **kwargs
        )

        self.teacher_policy = teacher_policy
        self.init_lambda = teacher_kwargs.get("init_lambda", 1.0)
        self.imitation_ends = teacher_kwargs.get("imitation_ends", 1_000_000)

    def compute_lambda_k(self) -> float:
        """Linearly decays from init_lambda to 0 over imitation_ends frames.

        Returns:
            ``0.0`` once ``self._frames`` has reached ``imitation_ends``,
            otherwise ``init_lambda * (1 - self._frames / imitation_ends)``.
        """
        # Checking the bound first also avoids a division by zero when
        # ``imitation_ends`` is 0 (``_frames`` is presumably never negative).
        if self._frames >= self.imitation_ends:
            return 0.0
        decay_ratio = 1.0 - (self._frames / self.imitation_ends)
        return self.init_lambda * decay_ratio

    @torch.no_grad()
    def rollout(self) -> TensorDictBase:
        """Collects one batch of frames and records the distillation term.

        NOTE(review): unlike the sibling collectors in this file, this rollout
        does not cast the shuttle to the policy/env device and does not honor
        ``self.interruptor`` -- confirm this is intended.

        Returns:
            TensorDictBase containing the computed rollout, freshly stacked
            (the previous ``_final_rollout`` buffer is not reused as output).
        """
        if self.reset_at_each_iter:
            self._shuttle.update(self.env.reset())

        self._final_rollout.fill_(("collector", "traj_ids"), -1)

        tensordicts = []

        with set_exploration_type(self.exploration_type or ExplorationType.RANDOM):
            for t in range(self.frames_per_batch):
                lambda_k = self.compute_lambda_k()

                if (
                        self.init_random_frames is not None
                        and self._frames < self.init_random_frames
                ):
                    self.env.rand_action(self._shuttle)
                else:
                    policy_input = self._shuttle
                    # The teacher gets its own clone so that the student's
                    # in-place writes do not leak into the teacher's input.
                    policy_input_clone = policy_input.clone()

                    # Get student and teacher outputs
                    policy_output = self.policy(policy_input)
                    t_policy_output = self.teacher_policy(policy_input_clone)

                    # Cross-entropy H(teacher, student) over the last dim.
                    teacher_probs = torch.softmax(t_policy_output["logits"], dim=-1)
                    log_student_probs = torch.log_softmax(policy_output["logits"], dim=-1)
                    cross_entropy = -(teacher_probs * log_student_probs).sum(-1)

                    # Save scaled distillation loss and the weight used.
                    policy_output["distill_loss"] = lambda_k * cross_entropy.detach()
                    policy_output["lambda_k"] = torch.full_like(cross_entropy, lambda_k)

                    if self._shuttle is not policy_output:
                        # Both distillation entries are included so that
                        # ``lambda_k`` is not lost when the policy returns a
                        # tensordict other than the shuttle.
                        self._shuttle.update(
                            policy_output,
                            keys_to_update=list(
                                {*self._policy_output_keys, *self._DISTILL_KEYS}
                            ),
                        )

                env_input = self._shuttle
                env_output, env_next_output = self.env.step_and_maybe_reset(env_input)

                if self._shuttle is not env_output:
                    next_data = env_output.get("next")
                    if self._shuttle_has_no_device:
                        next_data.clear_device_()
                    self._shuttle.set("next", next_data)

                if self.storing_device is not None:
                    tensordicts.append(self._shuttle.to(self.storing_device))
                else:
                    tensordicts.append(self._shuttle)

                # Carry the collector bookkeeping over to the next shuttle.
                collector_data = self._shuttle.get("collector").clone()
                self._shuttle = env_next_output
                if self._shuttle_has_no_device:
                    self._shuttle.clear_device_()
                self._shuttle.set("collector", collector_data)

                self._update_traj_ids(env_output)

        self._final_rollout = torch.stack(tensordicts, self._final_rollout.ndim - 1)
        return self._final_rollout
