from typing import List, Callable, Union, Dict, Any, TypeVar, Tuple, overload

import torch as th
import numpy as np


class Buffer():
    def __init__(self, trajectories: th.Tensor, device: th.device | str = 'cuda', batch_size: int = 256) -> None:
        self.device = th.device(device)
        self.batch_size = batch_size
        self.capacity = trajectories['obs'].shape[0]

        self.obs = trajectories['obs'].to(self.device).float()
        self.action = trajectories['action'].to(self.device).float()
        self.reward = trajectories['reward'].to(self.device).float()
        self.next_obs = trajectories['next_obs'].to(self.device).float()
        self.truncated = trajectories['truncated'].to(self.device).float()
        self.terminated = trajectories['terminated'].to(self.device).float()

    def sample(self, batch_size: int = -1) -> List[th.Tensor]:
        if batch_size == -1:
            batch_size = self.batch_size
        batch_inds = np.random.randint(0, self.capacity, size=batch_size)

        return self.obs[batch_inds], self.action[batch_inds], self.next_obs[batch_inds], self.reward[batch_inds], self.terminated[batch_inds], {}


class ExpertBuffer():
    def __init__(self, expert_files: Dict[str, str], device: th.device | str = 'cuda', batch_size: int = 256) -> None:
        self.expert_buffers = {}
        self.device = th.device(device)
        self.batch_size = batch_size

        for env_name, file in expert_files.items():
            trajectories = th.load(file)
            self.expert_buffers[env_name] = Buffer(trajectories, device=self.device, batch_size=batch_size)
        
    def sample(self, env_name: str, batch_size: int = -1) -> List[th.Tensor]:
        if env_name not in self.expert_buffers:
            raise ValueError(f"Environment {env_name} not found in expert buffers.")
        
        if batch_size == -1:
            batch_size = self.batch_size

        return self.expert_buffers[env_name].sample(batch_size)


class LearnerBuffer():
    def __init__(self, capacity: int, obs_dim: int, action_dim: int, batch_size: int = 256) -> None:
        self.capacity = capacity
        self.batch_size = batch_size
        self.obs = np.zeros((capacity, obs_dim), dtype=np.float32)
        self.action = np.zeros((capacity, action_dim), dtype=np.float32)
        self.reward = np.zeros((capacity, 1), dtype=np.float32) 
        self.next_obs = np.zeros((capacity, obs_dim), dtype=np.float32)
        self.truncated = np.zeros((capacity, 1), dtype=np.float32)
        self.terminated = np.zeros((capacity, 1), dtype=np.float32)
        self.ptr = 0
        self.size = 0

    def add(self, obs: np.ndarray, action: np.ndarray, reward: np.ndarray, next_obs: np.ndarray, terminated: np.ndarray, truncated: np.ndarray) -> None:
        self.obs[self.ptr] = obs
        self.action[self.ptr] = action
        self.reward[self.ptr] = reward
        self.next_obs[self.ptr] = next_obs
        self.truncated[self.ptr] = truncated
        self.terminated[self.ptr] = terminated

        self.ptr += 1
        if self.size < self.capacity:
            self.size += 1
        if self.ptr == self.capacity:
            self.ptr = 0
    
    def sample(self, batch_size: int = -1, output_device: str = 'cuda') -> List[th.Tensor]:
        if batch_size == -1:
            batch_size = self.batch_size
        batch_inds = np.random.randint(0, self.size, size=batch_size)

        return (th.from_numpy(self.obs[batch_inds]).to(output_device).float(), 
                th.from_numpy(self.action[batch_inds]).to(output_device).float(), 
                th.from_numpy(self.next_obs[batch_inds]).to(output_device).float(), 
                th.from_numpy(self.reward[batch_inds]).to(output_device).float(), 
                th.from_numpy(self.terminated[batch_inds]).to(output_device).float(), 
                th.from_numpy(self.truncated[batch_inds]).to(output_device).float())
            