"""
Modified from the official ICVF codebase: https://github.com/dibyaghosh/icvf_release/blob/main/src/gc_dataset.py

Original license information:

MIT License

Copyright (c) 2023 Dibya Ghosh

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import dataclasses
import numpy as np
import jax
import ml_collections
from flax.core import frozen_dict


@dataclasses.dataclass
class GCDataset:
    """Goal-conditioned sampling wrapper around a transition dataset.

    The wrapped ``dataset`` object is expected to provide ``dataset_dict``
    (a mapping of arrays), ``__len__``, ``sample(batch_size, indx=...)`` and
    ``insert(data_dict)``, since those are the members this class uses.

    For every sampled transition index a goal index is drawn from a mixture
    of three sources (see ``sample_goals``): a uniformly random index, an
    index between the transition and the end of its trajectory, or the
    transition's own index.
    """

    # Wrapped dataset object (see class docstring for the members used).
    dataset: dict
    # Mixture weights of the three goal sources; must sum to 1.
    p_randomgoal: float
    p_trajgoal: float
    p_currgoal: float
    # Key in ``dataset.dataset_dict`` whose positive entries mark trajectory ends.
    terminal_key: str = "dones"
    # rewards = success * reward_scale + reward_shift
    reward_scale: float = 1.0
    reward_shift: float = -1.0
    # When True, masks are 0 for transitions whose goal is the current index.
    terminal: bool = True
    # Optional cap on how far ahead of ``indx`` a trajectory goal may lie.
    max_distance: int = None
    # Offset added to the goal index when looking up goal observations.
    curr_goal_shift: int = 0

    @staticmethod
    def get_default_config():
        """Return a ``ConfigDict`` holding the default sampling hyper-parameters."""
        return ml_collections.ConfigDict(
            {
                "p_randomgoal": 0.3,
                "p_trajgoal": 0.5,
                "p_currgoal": 0.2,
                "reward_scale": 1.0,
                "reward_shift": -1.0,
                "terminal": True,
                "max_distance": ml_collections.config_dict.placeholder(int),
                "curr_goal_shift": 0,
            }
        )

    def _refresh_terminal_locs(self):
        """Recompute ``self.terminal_locs`` from the wrapped dataset.

        The last index of the dataset is always appended, so every index has
        a trajectory end at or after it and the ``searchsorted`` lookup in
        ``sample_goals`` stays in range. A dataset without any positive
        terminal flag is therefore treated as one single trajectory.
        """
        (flagged,) = np.nonzero(self.dataset.dataset_dict[self.terminal_key] > 0)
        self.terminal_locs = np.concatenate(
            [flagged, [len(self.dataset) - 1]], axis=0
        )

    def __post_init__(self):
        """Locate trajectory ends and validate the goal-mixture weights."""
        self._refresh_terminal_locs()
        print(f"Number of terminal states: {len(self.terminal_locs)}")
        assert np.isclose(self.p_randomgoal + self.p_trajgoal + self.p_currgoal, 1.0)

    def insert(self, data_dict):
        """Insert ``data_dict`` into the wrapped dataset and refresh trajectory ends."""
        self.dataset.insert(data_dict)
        self._refresh_terminal_locs()

    def sample_goals(self, indx, p_randomgoal=None, p_trajgoal=None, p_currgoal=None):
        """Draw one goal index for every entry of ``indx``.

        Args:
            indx: array of transition indices.
            p_randomgoal: weight of uniformly random goals (defaults to the
                instance value; it is implied by the other two weights and
                is not used directly in the computation).
            p_trajgoal: weight of goals lying between ``indx`` and the end of
                its trajectory (defaults to the instance value).
            p_currgoal: weight of goals equal to ``indx`` itself (defaults to
                the instance value).

        Returns:
            Integer array of goal indices with the same length as ``indx``.
        """
        if p_randomgoal is None:
            p_randomgoal = self.p_randomgoal
        if p_trajgoal is None:
            p_trajgoal = self.p_trajgoal
        if p_currgoal is None:
            p_currgoal = self.p_currgoal

        batch_size = len(indx)

        # Random goals; the upper bound leaves room for ``curr_goal_shift``.
        goal_indx = np.random.randint(
            len(self.dataset) - self.curr_goal_shift, size=batch_size
        )

        # Goals from the same trajectory: first trajectory end at or after indx.
        final_state_indx = self.terminal_locs[np.searchsorted(self.terminal_locs, indx)]
        if self.max_distance is not None:
            final_state_indx = np.clip(final_state_indx, 0, indx + self.max_distance)

        # Random interpolation between the current index and the trajectory end.
        distance = np.random.rand(batch_size)
        middle_goal_indx = np.round(
            indx * distance + final_state_indx * (1 - distance)
        ).astype(int)

        # Probability of a trajectory goal given that the goal is not the
        # current state. When p_currgoal is 1 that condition never occurs, so
        # the value is irrelevant; guard against dividing by zero.
        if p_currgoal < 1.0:
            traj_fraction = p_trajgoal / (1.0 - p_currgoal)
        else:
            traj_fraction = 0.0

        goal_indx = np.where(
            np.random.rand(batch_size) < traj_fraction,
            middle_goal_indx,
            goal_indx,
        )

        # Goals at the current state
        goal_indx = np.where(np.random.rand(batch_size) < p_currgoal, indx, goal_indx)
        return goal_indx

    def sample(self, batch_size: int, indx=None):
        """Sample a batch and attach goals, rewards and masks.

        Args:
            batch_size: number of transitions to sample.
            indx: optional array of transition indices; drawn uniformly from
                ``[0, len(dataset) - 1)`` when omitted.

        Returns:
            The batch returned by the wrapped dataset, as a mutable dict, with
            the extra keys ``"rewards"``, ``"masks"`` and ``"goals"``.
        """
        if indx is None:
            indx = np.random.randint(len(self.dataset) - 1, size=batch_size)

        # Unfreeze so the keys below can be assigned even if the wrapped
        # dataset returns an immutable FrozenDict (mirrors GCSDataset.sample).
        batch = frozen_dict.unfreeze(self.dataset.sample(batch_size, indx=indx))
        goal_indx = self.sample_goals(indx)

        success = indx == goal_indx
        batch["rewards"] = success.astype(float) * self.reward_scale + self.reward_shift
        if self.terminal:
            batch["masks"] = 1.0 - success.astype(float)
        else:
            batch["masks"] = np.ones(batch_size)

        # Keep the shifted goal index inside the dataset; without the clip a
        # goal near the end plus a positive shift would index out of bounds.
        shifted_goal_indx = np.clip(
            goal_indx + self.curr_goal_shift, 0, len(self.dataset) - 1
        )
        # NOTE(review): this relies on the wrapped dataset supporting item
        # access by key; other methods go through ``dataset_dict`` — confirm.
        batch["goals"] = jax.tree_util.tree_map(
            lambda arr: arr[shifted_goal_indx],
            self.dataset["observations"],
        )

        return batch


@dataclasses.dataclass
class GCSDataset(GCDataset):
    """Goal-conditioned dataset that samples two goals per transition.

    Every transition gets a ``desired_goals`` entry and a ``goals`` entry.
    With probability ``p_samegoal`` both refer to the same index; otherwise
    they are drawn independently.
    """

    # Probability that the goal is forced to equal the desired goal.
    p_samegoal: float = 0.5
    # When True, desired goals are never drawn from the random-goal source.
    intent_sametraj: bool = False

    @staticmethod
    def get_default_config():
        """Return a ``ConfigDict`` holding the default sampling hyper-parameters."""
        return ml_collections.ConfigDict(
            {
                "p_randomgoal": 0.3,
                "p_trajgoal": 0.5,
                "p_currgoal": 0.2,
                "reward_scale": 1.0,
                "reward_shift": -1.0,
                "terminal": True,
                "p_samegoal": 0.5,
                "intent_sametraj": False,
                "max_distance": ml_collections.config_dict.placeholder(int),
                "curr_goal_shift": 0,
            }
        )

    def sample(self, batch_size: int, indx=None):
        """Sample a batch with goals, desired goals and matching rewards/masks.

        Args:
            batch_size: number of transitions to sample.
            indx: optional array of transition indices. When omitted, indices
                are drawn uniformly while leaving out the last 100 entries of
                the dataset; if the dataset is too small for that margin, the
                range ``[0, len(dataset) - 1)`` is used instead.

        Returns:
            A mutable dict holding the wrapped dataset's batch plus the keys
            ``"rewards"``, ``"desired_rewards"``, ``"masks"``,
            ``"desired_masks"``, ``"goals"`` and ``"desired_goals"``.
        """
        if indx is None:
            num_transitions = len(self.dataset)
            # presumably the margin keeps start indices away from the end of
            # the buffer — TODO confirm why 100 was chosen.
            upper = num_transitions - 100
            if upper <= 0:
                # Too few transitions for the margin: randint would raise on a
                # non-positive bound, so fall back to the base-class range.
                upper = max(num_transitions - 1, 1)
            indx = np.random.randint(upper, size=batch_size)

        batch = frozen_dict.unfreeze(self.dataset.sample(batch_size, indx=indx))

        if self.intent_sametraj:
            # Drop the random-goal source and give its weight to trajectory goals.
            desired_goal_indx = self.sample_goals(
                indx,
                p_randomgoal=0.0,
                p_trajgoal=1.0 - self.p_currgoal,
                p_currgoal=self.p_currgoal,
            )
        else:
            desired_goal_indx = self.sample_goals(indx)

        # Independent goal draw, replaced by the desired goal with p_samegoal.
        goal_indx = self.sample_goals(indx)
        goal_indx = np.where(
            np.random.rand(batch_size) < self.p_samegoal, desired_goal_indx, goal_indx
        )

        # Success is evaluated on the unshifted indices.
        success = indx == goal_indx
        desired_success = indx == desired_goal_indx

        batch["rewards"] = success.astype(float) * self.reward_scale + self.reward_shift
        batch["desired_rewards"] = (
            desired_success.astype(float) * self.reward_scale + self.reward_shift
        )

        if self.terminal:
            batch["masks"] = 1.0 - success.astype(float)
            batch["desired_masks"] = 1.0 - desired_success.astype(float)
        else:
            batch["masks"] = np.ones(batch_size)
            batch["desired_masks"] = np.ones(batch_size)

        # Apply the goal shift and keep the resulting indices inside the dataset.
        last_index = len(self.dataset) - 1
        goal_indx = np.clip(goal_indx + self.curr_goal_shift, 0, last_index)
        desired_goal_indx = np.clip(
            desired_goal_indx + self.curr_goal_shift, 0, last_index
        )
        batch["goals"] = self.dataset.sample(batch_size, indx=goal_indx)["observations"]
        batch["desired_goals"] = self.dataset.sample(
            batch_size, indx=desired_goal_indx
        )["observations"]

        return batch