import numpy as np
from typing import Dict, Tuple, Optional, Any, Union, List, TypeVar, cast, Literal, OrderedDict
import numpy as np
import torch
from termcolor import colored
from collections import OrderedDict as ordered_dict
from uuid import uuid4, UUID


def get_device() -> Literal["cuda:0", "cpu"]:
    """Choose the torch device string and report the choice on stdout.

    Returns:
        ``"cuda:0"`` when ``torch.cuda.is_available()`` is true, otherwise
        ``"cpu"``.
    """
    # The annotation used to promise "cuda", but the value returned for the
    # GPU case has always been "cuda:0"; the annotation now matches the value.
    # A lowercase local also avoids shadowing the module-level DEVICE constant.
    device: Literal["cuda:0", "cpu"] = (
        "cuda:0" if torch.cuda.is_available() else "cpu"
    )
    # The CPU fallback is printed in yellow, a GPU device in green.
    print(f"Using {colored(device, 'yellow' if device == 'cpu' else 'green')} device")
    return device


# Device string used by this module. It is computed once, when the module is
# imported, and get_device() prints the selected device at that point.
DEVICE = get_device()


class LazyFrames:
    """A fixed-length stack of frames kept as references to the input arrays.

    Every frame is stored next to the ``(low, high)`` byte bounds of its
    memory, as reported by NumPy's ``byte_bounds``. The stored pairs are
    exposed through :meth:`details`; this class has no ``resolve()`` method.

    Attributes set by the constructor:
        stack_num: number of frames in the stack.
        frame_shape: shape of the first frame.
        shape: ``(stack_num,) + frame_shape``.
        dtype: dtype of the first frame.
        mem: list of ``((low, high), frame)`` pairs, in input order.
    """

    def __init__(self, frames: List[np.ndarray], stack_num: int):
        """Record ``frames`` together with their memory byte bounds.

        Args:
            frames: the frames to stack; must contain exactly ``stack_num``
                arrays.
            stack_num: expected number of frames.

        Raises:
            AssertionError: if ``len(frames) != stack_num``.
        """
        assert len(frames) == stack_num

        # ``np.byte_bounds`` was removed from the top-level namespace in
        # NumPy 2.0 and now lives in ``numpy.lib.array_utils``. Fall back to
        # the old location for NumPy 1.x, where that module does not exist.
        try:
            from numpy.lib.array_utils import byte_bounds
        except ImportError:
            byte_bounds = np.byte_bounds

        self.stack_num = stack_num
        self.frame_shape = frames[0].shape
        self.mem: List[Tuple[Tuple[int, int], np.ndarray]] = [
            (byte_bounds(f), f) for f in frames
        ]

        self.shape = (self.stack_num,) + self.frame_shape
        self.dtype = frames[0].dtype

    def details(self) -> List[Tuple[Tuple[int, int], np.ndarray]]:
        """Return the stored ``((low, high), frame)`` pairs, in input order."""
        return self.mem

class LazyFramesGPU:
    """Frame stack whose frames are torch tensors shared through a cache.

    ``buffer_dict`` maps a frame key (the first element of each ``frames``
    entry) to a float32 tensor on ``DEVICE``. A frame whose key is already
    cached reuses that tensor; otherwise a new tensor is created and cached,
    evicting the oldest-inserted entry first when the cache already holds
    ``stack_num`` tensors. The caller's ``buffer_dict`` is mutated in place.

    Note: the number of ``frames`` is not checked against ``stack_num`` here.
    """

    def __init__(
        self, frames: List[Tuple[Tuple[int, int], np.ndarray]], buffer_dict: OrderedDict[Tuple[int, int], torch.Tensor], stack_num: int
    ):
        first_frame = frames[0][1]

        self.stack_num = stack_num
        self.frame_shape = first_frame.shape
        # One slot per stacked frame; every slot is filled in the loop below.
        self.mem: List[torch.Tensor] = cast(List[torch.Tensor], [None] * stack_num)
        self.shape = (stack_num,) + self.frame_shape
        self.dtype = first_frame.dtype
        self.buffer_dict = buffer_dict

        # The shared cache must not already be larger than one full stack.
        assert stack_num >= len(buffer_dict)

        for slot, (key, frame) in enumerate(frames):
            if key not in buffer_dict:
                # Cache miss: make room (first-in, first-out) and upload.
                if len(buffer_dict) == stack_num:
                    buffer_dict.popitem(last=False)
                buffer_dict[key] = torch.tensor(
                    frame, dtype=torch.float32, device=DEVICE
                )
            self.mem[slot] = buffer_dict[key]

    def resolve(self):
        """Return the list of per-frame tensors, in frame order."""
        return self.mem
    




# An action is carried as a torch tensor.
Action = torch.Tensor
# String-keyed dictionary of arbitrary values.
Info = Dict[str, Any]
# An action tensor paired with a list of Info dictionaries.
ActionInfo = Tuple[Action, List[Info]]

# Rewards are plain Python floats.
Reward = float


# Five tensors followed by a list of Info dictionaries.
# NOTE(review): the name suggests (state, action, reward, next state,
# next action, info) — confirm the field order against the code that builds
# these tuples.
SARSAI = Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[Info]
]

# A state is either a tensor or a LazyFrames stack.
AllowedStates = Union[torch.Tensor, LazyFrames]
# The same two state types, each additionally allowed to be None.
EmptyStates = Union[Optional[torch.Tensor], Optional[LazyFrames]]
# Union of the two aliases above; it has the same members as EmptyStates.
AllAllowedStates = Union[AllowedStates, EmptyStates]

# Short and long names bound to the same alias.
S = State = AllowedStates


def resolve_lazy_frames(lazy_frames: "LazyFramesGPU") -> torch.Tensor:
    """Stack the per-frame tensors of a lazy frame container into one tensor.

    Args:
        lazy_frames: container exposing ``resolve()`` (a list of equally
            shaped tensors) and ``shape``. In this file only
            ``LazyFramesGPU`` provides ``resolve()``; ``LazyFrames`` does
            not, so passing one raises ``AttributeError``.

    Returns:
        The tensors from ``lazy_frames.resolve()`` stacked along a new
        leading dimension.

    Raises:
        AssertionError: if the stacked shape differs from
            ``lazy_frames.shape``.
    """
    frame_tensors = lazy_frames.resolve()
    stacked = torch.stack(frame_tensors)

    # torch.Size is a tuple subclass, so it compares directly with the
    # container's plain-tuple ``shape``.
    assert stacked.shape == lazy_frames.shape, (
        f"stacked shape {tuple(stacked.shape)} != expected {lazy_frames.shape}"
    )
    return stacked
