from dataclasses import dataclass
from hmac import new

import torch

from matplotlib.transforms import Transform
from tensordict import LazyStackedTensorDict, pad_sequence, TensorDict, TensorDictBase
from tensordict.nn import TensorDictModule
from torchrl.envs.transforms import Reward2GoTransform, Transform
from torchrl.objectives.value.utils import _get_num_per_traj, _split_and_pad_sequence

# import sys
# sys.path.append("..")
# from ..data.pyaig.aig_env import AIGEnv
# reward_function = AIGEnv._reward_function

# T = len(experiences) # episode length
# for t in range(T-1):
# # for t in range(0):
#     # state s[t] sample k goals from states s[t+1]...s[T] and create goals set G
#     sub_goals_idxs = random.sample(range(t+1, T), min(num_sub_goals, T - t - 1))
#     for sub_goal_idx in sub_goals_idxs:
#         sub_goal_target = experiences[sub_goal_idx]["nodes"][-1, :].clone().unsqueeze(0)
#         new_experiences = copy.deepcopy(experiences[t:sub_goal_idx])
#         new_reward = torch.exp(torch.tensor( [-len(new_experiences[-1]["nodes"])] ))

#         for e in new_experiences:
#             e["target"] = sub_goal_target
#             e["reward"] = new_reward

#         # if len(new_experiences) == 1:
#         #     replay_buffer.add(new_experiences[0]) # type: ignore
#         # else:
#         #     # replay_buffer.extend(new_experiences) # type: ignore
#         replay_buffer.extend(LazyStackedTensorDict(*new_experiences, stack_dim=0)) #type: ignore


#         if allow_inv_experiences:
#             inv_experiences = copy.deepcopy(new_experiences) # inverse experiences
#             inv_sub_goal_target = ~sub_goal_target
#             for e in inv_experiences:
#                 e["target"] = inv_sub_goal_target

#             # if len(inv_experiences) == 1:
#             #     replay_buffer.add(inv_experiences[0]) # type: ignore
#             # else:
#             #     # replay_buffer.extend(inv_experiences) # type: ignore
#             replay_buffer.extend(LazyStackedTensorDict(*inv_experiences, stack_dim=0)) #type: ignore
# trajectories = _get_num_per_traj(sampled_td.get("terminated"))
# splitted_td = _split_and_pad_sequence(sampled_td, trajectories)
# splitted_achieved_goals = splitted_td.get(self.achieved_goal_key)


@dataclass
class HERConfig:
    """Configuration values for hindsight experience replay (HER).

    NOTE(review): no code visible in this file reads these fields; the
    descriptions below are inferred from the field names and from the
    commented-out prototype above -- confirm against the caller.
    """

    # Presumably the maximum number of subgoals sampled per step/trajectory.
    num_sub_goals: int = 4
    # Presumably whether experiences with a negated target are also generated.
    allow_inv_experiences: bool = True


class AIGSubGoalAssigner(Transform):
    """Relabel each trajectory with a goal taken from one of its own steps.

    For every trajectory in the batch, the last row of the ``"nodes"`` entry
    at the selected subgoal step is used as the new goal. The goal is written
    under ``desired_goal_key`` both at the root and under ``"next"``, and the
    ``("next", "done")`` flag is moved from the last step to the subgoal step.

    Args:
        desired_goal_key (str): Key under which the relabelled goal is
            stored. Defaults to ``"target"``.
    """

    def __init__(
        self,
        desired_goal_key: str = "target",
    ):
        # Initialise the Transform / nn.Module machinery (same call as
        # HERSubGoalSampler) so the instance can be registered as a submodule.
        super().__init__(
            in_keys=None,  # type: ignore
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.desired_goal_key = desired_goal_key

    def forward(
        self, trajectories: TensorDictBase, subgoals_idxs: torch.Tensor
    ) -> TensorDictBase:
        """Assign the subgoal of each trajectory as its desired goal.

        Args:
            trajectories: Batch of trajectories with a two-dimensional batch
                size ``(batch, time)``; modified in place.
            subgoals_idxs: One subgoal step index per trajectory.
                NOTE(review): the caller in this file passes shape
                ``(batch, 1)`` -- confirm for other callers.

        Returns:
            The same ``trajectories`` object with goal and done flags updated.
        """
        # Unpacking also enforces the two-dimensional batch size.
        batch_size, _ = trajectories.shape
        flat_idxs = subgoals_idxs.flatten(0)

        new_goals = []
        for i in range(batch_size):
            # Last row of the "nodes" entry at the subgoal step.
            subgoal = trajectories[i][subgoals_idxs[i]].get_nestedtensor("nodes")[-1][
                -1, :
            ]
            # Clear the flag on the last step first, then set it on the
            # subgoal step, so a subgoal on the last step still ends up True.
            trajectories[i, -1]["next", "done"][0] = False
            trajectories[i, flat_idxs[i]]["next", "done"][0] = True

            desired_goal_shape = trajectories[i][self.desired_goal_key].shape
            new_goals.append(subgoal.expand(desired_goal_shape))

        stacked_goals = torch.stack(new_goals)
        trajectories["next", self.desired_goal_key] = stacked_goals
        trajectories[self.desired_goal_key] = stacked_goals

        return trajectories


class AIGRewardTransform(Transform):
    """Rewrite the rewards of relabelled trajectories.

    The step selected as subgoal receives a ``("next", reward_key)`` of 1.0,
    the last step receives 0.0 (unless it is the subgoal step itself), and
    ``reward_key`` is then recomputed from ``("next", reward_key)`` with
    torchrl's ``Reward2GoTransform`` using discount ``gamma``.

    Args:
        gamma (float): Discount factor passed to ``Reward2GoTransform``.
            Defaults to 0.99.
        reward_key (str): Name of the reward entry. Defaults to ``"reward"``.
    """

    def __init__(
        self,
        gamma: float = 0.99,
        reward_key: str = "reward",
    ):
        # Initialise the Transform / nn.Module machinery (same call as
        # HERSubGoalSampler) so the instance can be registered as a submodule.
        super().__init__(
            in_keys=None,  # type: ignore
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.reward_key = reward_key
        self.gamma = gamma

    def forward(
        self, trajectories: TensorDictBase, subgoals_idxs: torch.Tensor
    ) -> TensorDictBase:
        """Assign hindsight rewards and recompute the reward-to-go.

        Args:
            trajectories: Batch of trajectories with a two-dimensional batch
                size ``(batch, time)``; rewards are written in place.
            subgoals_idxs: One subgoal step index per trajectory.

        Returns:
            The tensordict returned by ``Reward2GoTransform.inv``, with the
            ``reward_key`` of the first step of each trajectory set to 0.0.
        """
        # Unpacking also enforces the two-dimensional batch size.
        batch_size, _ = trajectories.shape
        flat_idxs = subgoals_idxs.flatten(0)
        next_reward_key = ("next", self.reward_key)

        # NOTE: an earlier, size-dependent reward was replaced by a constant.
        subgoal_reward = 1.0
        for i in range(batch_size):
            # Zero the last step before rewarding the subgoal step, so that a
            # subgoal located on the last step keeps its reward.
            trajectories[i][-1][next_reward_key][0] = 0.0
            trajectories[i][flat_idxs[i]][next_reward_key][0] = subgoal_reward

        trajectories = Reward2GoTransform(
            self.gamma, next_reward_key, self.reward_key
        ).inv(trajectories)

        for i in range(batch_size):
            trajectories[i][0][self.reward_key][0] = 0.0

        return trajectories


class AIGNegateTarget(Transform):
    """Double a batch of trajectories by appending copies with negated targets.

    Args:
        target_key (str): Key of the target entry that is negated with ``~``
            (both at the root and under ``"next"``). Defaults to ``"target"``.
    """

    def __init__(
        self,
        target_key: str = "target",
    ):
        # Initialise the Transform / nn.Module machinery (same call as
        # HERSubGoalSampler) so the instance can be registered as a submodule.
        super().__init__(
            in_keys=None,  # type: ignore
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.target_key = target_key

    def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
        """Return ``trajectories`` followed by a copy whose target is inverted.

        The input is left untouched; the negated copy is built from a deep
        clone. ``~`` requires a boolean or integer target tensor.

        Returns:
            The concatenation ``[trajectories, negated]`` along dimension 0.
        """
        negated = trajectories.clone(True)
        stacked_goals = torch.stack(
            [~negated[i][self.target_key] for i in range(negated.batch_size[0])]
        )
        negated["next", self.target_key] = stacked_goals
        negated[self.target_key] = stacked_goals

        return torch.cat([trajectories, negated], dim=0)  # type: ignore


class HERSubGoalSampler(Transform):
    """Sample, for each trajectory, the step indices to use as hindsight subgoals.

    The result is a tensordict holding the indices under ``subgoal_idx_key``
    with shape ``[batch_size, num_sampled]``.

    Two strategies are implemented: ``"last"`` uses the last step of each
    trajectory as the only subgoal; any other value (``"future"`` by default)
    draws up to ``num_samples`` distinct indices from ``1 .. trajectory_len - 2``.

    Args:
        num_samples (int): Maximum number of subgoals sampled per trajectory.
            Defaults to 4.
        subgoal_idx_key (str): Key under which the indices are stored.
            Defaults to ``"subgoal_idx"``.
        strategy (str): ``"last"`` or ``"future"``. Defaults to ``"future"``.
    """

    def __init__(
        self,
        num_samples: int = 4,
        subgoal_idx_key: str = "subgoal_idx",
        strategy: str = "future",
    ):
        super().__init__(
            in_keys=None,  # type: ignore
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.num_samples = num_samples
        self.subgoal_idx_key = subgoal_idx_key
        self.strategy = strategy

    def forward(self, trajectories: TensorDictBase) -> TensorDictBase:
        """Return the sampled subgoal indices for ``trajectories``.

        Args:
            trajectories: A single trajectory (one batch dimension) or a batch
                of trajectories with batch size ``(batch, time)``.

        Returns:
            A tensordict with the indices under ``subgoal_idx_key``. For the
            sampling strategy it is the output of ``pad_sequence`` called with
            ``return_mask=True``.
        """
        if len(trajectories.shape) == 1:
            trajectories = trajectories.unsqueeze(0)

        batch_size, trajectory_len = trajectories.shape

        if self.strategy == "last":
            # Use the explicit non-negative index instead of -1: it selects
            # the same step, and consumers that slice with ``[: idx + 1]``
            # then keep the whole trajectory rather than an empty slice.
            last_idx = trajectory_len - 1
            return TensorDict(
                {self.subgoal_idx_key: torch.full((batch_size, 1), last_idx)},
                batch_size=batch_size,
            )

        # The first and last steps are excluded from the candidates; clamp at
        # zero so that very short trajectories yield no index instead of
        # making ``randperm`` raise on a negative size.
        num_candidates = max(trajectory_len - 2, 0)
        subgoal_idxs = [
            TensorDict(
                {
                    self.subgoal_idx_key: (torch.randperm(num_candidates) + 1)[
                        : self.num_samples
                    ]
                },
                batch_size=torch.Size(),
            )
            for _ in range(batch_size)
        ]
        return pad_sequence(subgoal_idxs, pad_dim=0, return_mask=True)


class HERSubGoalAssigner(Transform):
    """Relabel each trajectory with a goal it actually achieved.

    For every trajectory, the ``achieved_goal_key`` entry at the selected
    subgoal step is broadcast over the whole trajectory and stored under
    ``desired_goal_key``; the ``("next", "done")`` flag of that step is set.

    Args:
        achieved_goal_key (str): Key of the goal reached at each step.
            Defaults to ``"achieved_goal"``.
        desired_goal_key (str): Key of the goal the agent is conditioned on.
            Defaults to ``"desired_goal"``.
    """

    def __init__(
        self,
        achieved_goal_key: str = "achieved_goal",
        desired_goal_key: str = "desired_goal",
    ):
        # Initialise the Transform / nn.Module machinery (same call as
        # HERSubGoalSampler) so the instance can be registered as a submodule.
        super().__init__(
            in_keys=None,  # type: ignore
            in_keys_inv=None,
            out_keys_inv=None,
        )
        self.achieved_goal_key = achieved_goal_key
        self.desired_goal_key = desired_goal_key

    def forward(
        self, trajectories: TensorDictBase, subgoals_idxs: torch.Tensor
    ) -> TensorDictBase:
        """Assign the achieved goal at each subgoal step as the desired goal.

        Args:
            trajectories: Batch of trajectories with a two-dimensional batch
                size ``(batch, time)``; modified in place.
            subgoals_idxs: One subgoal step index per trajectory (any shape
                with one element per trajectory).

        Returns:
            The same ``trajectories`` object with relabelled goals.
        """
        # Unpacking also enforces the two-dimensional batch size.
        batch_size, _ = trajectories.shape
        # Plain Python integers: integer indexing yields views, so in-place
        # writes on the selected step reach ``trajectories`` (indexing with a
        # one-element tensor would operate on a copy).
        flat_idxs = subgoals_idxs.flatten().tolist()

        new_goals = []
        for i in range(batch_size):
            step = trajectories[i, flat_idxs[i]]
            subgoal = step[self.achieved_goal_key]
            desired_goal_shape = trajectories[i][self.desired_goal_key].shape
            new_goals.append(subgoal.expand(desired_goal_shape))
            step["next", "done"][0] = True

        # Write the goals on the batched tensordict itself (as
        # AIGSubGoalAssigner does) rather than on a temporary sub-tensordict.
        if new_goals:
            trajectories[self.desired_goal_key] = torch.stack(new_goals)

        return trajectories


class HERRewardTransform(Transform):
    """Placeholder reward transform that leaves the trajectories unchanged.

    It exists so that a reward step can always be called after relabelling;
    replace it with a task-specific transform to recompute rewards.
    """

    def __init__(self):
        # Initialise the Transform / nn.Module machinery (same call as
        # HERSubGoalSampler) so the instance can be registered as a submodule.
        super().__init__(
            in_keys=None,  # type: ignore
            in_keys_inv=None,
            out_keys_inv=None,
        )

    def forward(
        self,
        trajectories: TensorDictBase,
        subgoals_idxs: torch.Tensor | None = None,
    ) -> TensorDictBase:
        """Return ``trajectories`` untouched.

        Args:
            trajectories: Relabelled trajectories.
            subgoals_idxs: Accepted (and ignored) so the transform can be
                called with the same two arguments as the other reward
                transforms in this file.
        """
        return trajectories


class HindsightExperienceReplayTransform(Transform):
    """Hindsight Experience Replay (HER) as an inverse (replay-buffer) transform.

    HER creates additional experiences from a trajectory by pretending that a
    state reached along the way was the goal. This class chains:

    - a sub-goal sampler, which picks the step indices used as goals,
    - a sub-goal assigner, which writes those goals into copies of the
      trajectory,
    - a reward transform, which recomputes rewards for the new goals,
    - an optional post transform applied to the relabelled trajectories.

    Args:
        SubGoalSampler (Transform, optional): Defaults to a new
            ``HERSubGoalSampler()``.
        SubGoalAssigner (Transform, optional): Defaults to a new
            ``HERSubGoalAssigner()``.
        RewardTransform (Transform, optional): Defaults to a new
            ``HERRewardTransform()``.
        PostTransaform (Transform, optional): Extra transform applied after
            the rewards are assigned. The spelling is kept because the name
            is part of the public signature.
        assign_subgoal_idxs (bool): If True, the sampled index is stored in
            every relabelled trajectory under the sampler's
            ``subgoal_idx_key``. Defaults to False.
    """

    # Defined on the class so that ``_call`` raises the intended ValueError
    # even if the base class does not provide this attribute.
    ENV_ERR = (
        "The HindsightExperienceReplayTransform is only an inverse transform "
        "and can only be applied to the replay buffer."
    )

    def __init__(
        self,
        SubGoalSampler: Transform | None = None,
        SubGoalAssigner: Transform | None = None,
        RewardTransform: Transform | None = None,
        PostTransaform: Transform | None = None,
        assign_subgoal_idxs: bool = False,
    ):
        super().__init__(
            in_keys=None,  # type: ignore
            in_keys_inv=None,
            out_keys_inv=None,
        )
        # Build the defaults per instance: module objects used as default
        # argument values would be created once and shared by every instance.
        self.SubGoalSampler = (
            HERSubGoalSampler() if SubGoalSampler is None else SubGoalSampler
        )
        self.SubGoalAssigner = (
            HERSubGoalAssigner() if SubGoalAssigner is None else SubGoalAssigner
        )
        self.RewardTransform = (
            HERRewardTransform() if RewardTransform is None else RewardTransform
        )
        self.PostTransaform = PostTransaform
        self.assign_subgoal_idxs = assign_subgoal_idxs

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Return ``tensordict`` followed by its hindsight augmentation.

        NOTE(review): ``her_augmentation`` returns a flat batch of steps, so
        the input presumably has a single batch dimension -- confirm.
        """
        augmentation_td = self.her_augmentation(tensordict)
        return torch.cat([tensordict, augmentation_td], dim=0)  # type: ignore

    def _inv_apply_transform(self, tensordict: TensorDictBase) -> torch.Tensor:
        """Return only the hindsight augmentation of ``tensordict``."""
        return self.her_augmentation(tensordict)  # type: ignore

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Return ``tensordict`` unchanged (no forward-direction effect)."""
        return tensordict

    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Always raise: this transform is not meant for the forward path."""
        raise ValueError(self.ENV_ERR)

    def her_augmentation(self, trajectories: TensorDictBase):
        """Build relabelled copies of ``trajectories`` for sampled subgoals.

        Args:
            trajectories: A single trajectory (one batch dimension) or a batch
                with batch size ``(batch, time)``. The input is not modified;
                all work happens on a deep clone.

        Returns:
            A tensordict with a single batch dimension containing, for every
            relabelled trajectory, its steps up to and including the subgoal
            step, concatenated one after another.
        """
        if len(trajectories.shape) == 1:
            trajectories = trajectories.unsqueeze(0)
        batch_size, trajectory_length = trajectories.shape

        new_trajectories = trajectories.clone(True)

        # Sample subgoal indices.
        subgoal_idxs = self.SubGoalSampler(new_trajectories)
        idx_key = self.SubGoalSampler.subgoal_idx_key
        has_masks = "masks" in subgoal_idxs.keys()

        # Create one copy of each trajectory per sampled subgoal.
        augmented_list = []
        list_idxs = []
        for i in range(batch_size):
            idxs = subgoal_idxs[i][idx_key]
            if has_masks:
                # Drop the padded entries.
                idxs = idxs[subgoal_idxs[i]["masks", idx_key]]

            list_idxs.append(idxs.unsqueeze(-1))
            new_traj = (
                new_trajectories[i]
                .expand((idxs.numel(), trajectory_length))
                .clone(True)
            )

            if self.assign_subgoal_idxs:
                new_traj[idx_key] = idxs.unsqueeze(-1).repeat(1, trajectory_length)

            augmented_list.append(new_traj)
        augmented_trajectories = torch.cat(augmented_list, dim=0)
        associated_idxs = torch.cat(list_idxs, dim=0)

        # Assign subgoals to the new trajectories.
        augmented_trajectories = self.SubGoalAssigner.forward(augmented_trajectories, associated_idxs)  # type: ignore

        # Adjust the rewards based on the new subgoals.
        augmented_trajectories = self.RewardTransform.forward(augmented_trajectories, associated_idxs)  # type: ignore

        # Apply the optional post transform.
        if self.PostTransaform is not None:
            augmented_trajectories = self.PostTransaform.forward(augmented_trajectories)

        # Keep each trajectory up to and including its subgoal step. The
        # modulo lets a post transform return more trajectories than there
        # are indices (e.g. the batch followed by a modified copy of it).
        flat_idxs = associated_idxs.flatten(0)
        num_idxs = flat_idxs.numel()
        return torch.cat(
            [
                traj[: flat_idxs[i % num_idxs] + 1]
                for i, traj in enumerate(augmented_trajectories)
            ],
            dim=0,
        )

    def sample_and_create_trajectories(self, trajectory: TensorDict):
        """Legacy helper: sample subgoals and return truncated trajectory copies.

        Not used by ``her_augmentation``; kept for backward compatibility.

        Args:
            trajectory: A single trajectory with one batch dimension.

        Returns:
            A list of deep-cloned prefixes ``trajectory[:idx]``, one per
            sampled index. Each prefix carries the index under
            ``"subgoal_idx"``.
        """
        trajectory_len = trajectory.batch_size[0]
        # The sample count lives on the sampler; ``None`` keeps every index.
        num_samples = getattr(self.SubGoalSampler, "num_samples", None)
        idxs = (torch.randperm(trajectory_len - 2) + 1)[:num_samples]
        trajectory = trajectory.expand((idxs.numel(), trajectory_len))
        trajectory["subgoal_idx"] = idxs.unsqueeze(-1).repeat(1, trajectory_len)

        return [traj[: idxs[i]].clone(True) for i, traj in enumerate(trajectory)]

    def assign_subgoal(
        self,
        trajectory,
    ):
        """Legacy helper: set ``"target"`` from the trajectory's ``"nodes"``.

        Not used by ``her_augmentation``; kept for backward compatibility.
        The last row of the last ``"nodes"`` element is repeated for every
        step and written in place under ``"target"``.
        """
        trajectory_len = trajectory.batch_size[0]
        trajectory["target"] = trajectory.get_nestedtensor("nodes")[-1][-1, :].repeat(
            trajectory_len, 1
        )

    def assign_reward(self, trajectory):
        """Legacy helper: set the final reward and done flag, then reward-to-go.

        Not used by ``her_augmentation``; kept for backward compatibility.
        The last step's ``"reward"`` becomes
        ``min(exp(2 * num_inputs - num_nodes), 1)`` and its ``"done"`` flag is
        set, after which ``Reward2GoTransform`` is applied. The value returned
        by ``Reward2GoTransform.inv`` is discarded, as before.
        """
        num_nodes = trajectory.get_nestedtensor("nodes")[-1].shape[-2]
        new_reward = torch.clamp(
            torch.exp(trajectory["num_inputs"][-1] * 2 - num_nodes), max=1
        )
        trajectory["reward"][-1] = new_reward
        trajectory["done"][-1] = True
        Reward2GoTransform(0.99, "reward", "reward", "done").inv(trajectory)
