import torch
import numpy as np
from pfrl.wrappers.atari_wrappers import LazyFrames


def batch_experiences(experiences, device, gamma):
    """Collate a list of (n-step) experiences into a dict of batched tensors.

    Args:
        experiences: Non-empty sequence of experiences. Each experience is a
            non-empty sequence of transition dicts providing the keys
            ``"state"``, ``"action"``, ``"reward"``, ``"next_state"`` and
            ``"is_state_terminal"``.
        device: Device passed to ``torch.as_tensor`` for every tensor.
        gamma: Discount factor used for the accumulated reward and for the
            bootstrap discount.

    Returns:
        A dict with one entry per experience along the first dimension:

        * ``"state"``: float32, state of the first transition.
        * ``"action"``: int64, action of the first transition.
        * ``"reward"``: float32, ``sum(gamma**i * reward_i)`` over the
          transitions of the experience.
        * ``"next_state"``: float32, next state of the *last* transition.
        * ``"is_state_terminal"``: float32, 1.0 if any transition of the
          experience is terminal, else 0.0.
        * ``"discount"``: float32, ``gamma ** len(experience)``.

        When the observations are ``LazyFrames``, ``"state"`` and
        ``"next_state"`` are additionally divided by 255.

    Raises:
        IndexError: If ``experiences`` or its first experience is empty.
    """
    # Only the very first observation is inspected; the whole batch is
    # treated as image-based or not according to it.
    img_env = isinstance(experiences[0][0]["state"], LazyFrames)

    def _states_to_tensor(states):
        # Stack a list of observations into a single float32 tensor.
        if img_env:
            # LazyFrames are materialised one by one before stacking, then
            # the values are rescaled by 1/255 on the target device.
            frames = np.asarray(
                [np.asarray(obs) for obs in states], dtype=np.float32
            )
            return torch.as_tensor(frames, device=device) / 255
        return torch.as_tensor(
            np.asarray(states, dtype=np.float32), device=device
        )

    def _to_tensor(values, dtype):
        # Convert a list of per-experience scalars to a tensor on `device`.
        return torch.as_tensor(np.asarray(values, dtype=dtype), device=device)

    return {
        "state": _states_to_tensor([exp[0]["state"] for exp in experiences]),
        "action": _to_tensor(
            [exp[0]["action"] for exp in experiences], np.int64
        ),
        "reward": _to_tensor(
            [
                sum(
                    (gamma ** i) * transition["reward"]
                    for i, transition in enumerate(exp)
                )
                for exp in experiences
            ],
            np.float32,
        ),
        # The reward and discount cover every transition of the experience,
        # so the bootstrap state must be the one reached after the LAST
        # transition -- for image and non-image observations alike.
        "next_state": _states_to_tensor(
            [exp[-1]["next_state"] for exp in experiences]
        ),
        "is_state_terminal": _to_tensor(
            [
                any(transition["is_state_terminal"] for transition in exp)
                for exp in experiences
            ],
            np.float32,
        ),
        "discount": _to_tensor(
            [gamma ** len(exp) for exp in experiences], np.float32
        ),
    }


def torch_log(x):
    """Return the element-wise natural log of ``x``, floored at 1e-10.

    Values below 1e-10 (including zero and negatives) are raised to 1e-10
    before the logarithm is taken.
    """
    floor = 1e-10
    return x.clamp(min=floor).log()
