import torch
from typing import Callable, Dict, Generator, Tuple, Optional

from src.hp_student.storage.replay_storage import ReplayStorage
from src.hp_student.storage.storage import Dataset, Storage, Transition
from src.physical_design import ENVELOPE_P
from src.utils.utils import ActionMode, energy_value


class DualReplayStorage(Storage):
    """Replay storage that keeps HP-Student and HA-Teacher transitions in two separate buffers.

    Appended datasets are routed to one of two ``ReplayStorage`` instances based on their
    action-type flags. When sampling, the share of each batch drawn from the HA-Teacher buffer
    is derived from a safety indicator computed on the most recently stored observation.
    """

    def __init__(self, environment_count: int, hp_max_size: int, ha_max_size: int, device: str = "cpu",
                 hp_initial_size: int = 0, ha_initial_size: int = 0) -> None:
        """Creates the two underlying replay buffers.

        Args:
            environment_count (int): Forwarded to both buffers as ``environment_count``.
            hp_max_size (int): ``max_size`` of the HP-Student buffer.
            ha_max_size (int): ``max_size`` of the HA-Teacher buffer.
            device (str): Device forwarded to both buffers.
            hp_initial_size (int): ``initial_size`` of the HP-Student buffer.
            ha_initial_size (int): ``initial_size`` of the HA-Teacher buffer.
        """
        self._env_count = environment_count
        self.hp_max_size = hp_max_size
        self.ha_max_size = ha_max_size
        self.hp_initial_size = hp_initial_size
        self.ha_initial_size = ha_initial_size
        self.device = device

        # HP-Student Buffer
        self.hp_storage = ReplayStorage(environment_count=environment_count, max_size=hp_max_size,
                                        initial_size=hp_initial_size, device=device)

        # HA-Teacher Buffer
        self.ha_storage = ReplayStorage(environment_count=environment_count, max_size=ha_max_size,
                                        initial_size=ha_initial_size, device=device)

        # Which buffer received the most recent append; None until the first append.
        self.last_data_type: Optional[ActionMode] = None

    @property
    def initialized(self) -> bool:
        """Returns whether the storage is initialized (both underlying buffers must be)."""
        return self.hp_storage.initialized and self.ha_storage.initialized

    @property
    def hp_transition_count(self) -> int:
        """Returns the number of transitions stored in the HP-Student buffer."""
        return self.hp_storage.sample_count

    @property
    def ha_transition_count(self) -> int:
        """Returns the number of transitions stored in the HA-Teacher buffer."""
        return self.ha_storage.sample_count

    def append(self, dataset: Dataset, action_type=None) -> None:
        """Appends a dataset of transitions to the buffer selected by its action type.

        The whole dataset goes to the HA-Teacher buffer if any flag equals 1; otherwise it goes
        to the HP-Student buffer if any flag equals 0.

        Args:
            dataset (Dataset): The dataset of transitions.
            action_type: Action-type flags for the dataset (1 = HA-Teacher, 0 = HP-Student).
                A tensor, or anything ``torch.as_tensor`` can convert. Must not be None.
        Raises:
            RuntimeError: If ``action_type`` is None or contains neither 0 nor 1.
        """
        # `None == 1` is a plain bool that torch.any() rejects with an opaque TypeError,
        # so reject the missing flag up front with the documented error instead.
        if action_type is None:
            raise RuntimeError(f"Unrecognized action type: {action_type} for dataset")

        # No-op for tensors; lets plain ints / sequences be compared element-wise as well.
        flags = torch.as_tensor(action_type)

        if torch.any(flags == 1):
            self.ha_storage.append(dataset=dataset, action_type=action_type)
            self.last_data_type = ActionMode.TEACHER

        elif torch.any(flags == 0):
            self.hp_storage.append(dataset=dataset, action_type=action_type)
            self.last_data_type = ActionMode.STUDENT

        else:
            raise RuntimeError(f"Unrecognized action type: {action_type} for dataset")

    def batch_generator(self, batch_size: int, batch_count: int) -> Generator[Transition, None, None]:
        """Returns a generator that yields batches mixed from both buffers.

        The split between the buffers is fixed once per call from the most recently stored
        observation. If only one buffer holds transitions, it supplies the whole batch.

        Args:
            batch_size (int): The size of the batches.
            batch_count (int): The number of batches to yield.
        Returns:
            A generator that yields batches of transitions.
        Raises:
            RuntimeError: If nothing has been appended yet (no last data type recorded).
        """
        # Pick the buffer that received the latest append; its last row drives the safety status.
        if self.last_data_type == ActionMode.STUDENT:
            latest_storage = self.hp_storage
        elif self.last_data_type == ActionMode.TEACHER:
            latest_storage = self.ha_storage
        else:
            raise RuntimeError(f"Unknown action type pointer: {self.last_data_type}")

        # NOTE(review): assumes the newest row sits at index sample_count - 1; confirm this
        # still holds once the underlying ReplayStorage is full and starts overwriting.
        last_idx = latest_storage.sample_count - 1
        boundary_state = latest_storage._data['actor_observations'][last_idx]

        # Calculate the safety status indicator by sT*P*s (first two observation entries are skipped)
        Vs = energy_value(boundary_state[2:].cpu().numpy(), ENVELOPE_P) * 0.01

        # HA-Buffer share grows with Vs but is clamped to [1, batch_size - 1];
        # the HP-Buffer receives the remainder.
        ha_batch_size = max(min(batch_size - 1, int(batch_size * Vs)), 1)
        hp_batch_size = batch_size - ha_batch_size

        has_hp = self.hp_transition_count > 0
        has_ha = self.ha_transition_count > 0

        # With one buffer empty, let the other one fill the whole batch so that callers
        # still receive `batch_size` transitions instead of only a partial share.
        if has_hp and not has_ha:
            hp_batch_size = batch_size
        elif has_ha and not has_hp:
            ha_batch_size = batch_size

        # Only sample from buffers that hold data and were assigned a non-empty share.
        generators = []
        if has_hp and hp_batch_size > 0:
            generators.append(self.hp_storage.batch_generator(hp_batch_size, batch_count))
        if has_ha and ha_batch_size > 0:
            generators.append(self.ha_storage.batch_generator(ha_batch_size, batch_count))

        for _ in range(batch_count):
            batches = [next(gen) for gen in generators]

            # Merge the per-buffer batches key by key along the batch dimension.
            yield {
                key: torch.cat([batch[key] for batch in batches if key in batch], dim=0)
                for key in (batches[0] if batches else ())
            }

    @property
    def sample_count(self) -> int:
        """Returns the number of individual transitions stored in the storage."""
        return self.hp_transition_count + self.ha_transition_count
