import copy
import random
from typing import (
    Callable,
    Iterator,
    Optional,
    Union,
    TypeVar,
    Literal,
    List,
    Generic,
    cast,
    Tuple,
    List,
    Dict,
    Any,
)
from utils.step import Step, NotNoneStep, AllStep
from utils.transition import Transition
import numpy as np
import torch
from utils.common import AllowedStates, Action, Reward, Info, State
import math
from typing_extensions import Self
import warnings


class SingleEpisode:
    """Container for the steps of one episode.

    Steps are kept in ``_steps``. The final stored entry is treated as the
    trailing state of the episode: it is excluded from ``steps`` and from
    ``len(episode)``, and is what ``last_state`` reports. Consecutive entries
    are linked through the ``"next"`` key of each entry's info dict.
    """

    def __init__(self):
        self._steps: List[Step] = []
        # Set once a terminal step/transition has been appended.
        self.end = False
        # Episode-level info, filled in through ``add_info``.
        self.info: Info = dict()

    @property
    def steps(self) -> List[NotNoneStep]:
        """Return all stored steps except the last, converted to ``NotNoneStep``.

        The list is rebuilt on every access, so callers that index it
        repeatedly should bind it to a local first.
        """
        if len(self._steps) >= 2:
            # Only the most recent completed step is checked here.
            (_, la, lr, _) = self._steps[-2].details
            assert la is not None
            assert lr is not None

        return [NotNoneStep.from_step(s) for s in self._steps[:-1]]

    @property
    def last_state(self) -> Optional[State]:
        """State of the last stored entry, or ``None`` with fewer than two entries."""
        return self._steps[-1].state if len(self._steps) >= 2 else None

    @property
    def len(self) -> int:
        """Number of entries in ``steps`` (stored entries minus the trailing one)."""
        return max(len(self._steps) - 1, 0)

    def __len__(self) -> int:
        return self.len

    def append_step(self, step: Step) -> Self:
        """Append ``step`` and link the previous entry to it via ``info["next"]``.

        If ``step`` is terminal it must carry no action and no reward, and the
        episode is marked as ended. A warning is emitted when that terminal
        step is the only entry of the episode.

        Returns:
            ``self``, to allow chaining.
        """
        assert not self.end, "cannot append step into ended episode"

        if self._steps:
            (_, _, _, prev_info) = self._steps[-1].details
            assert "next" not in prev_info
            prev_info["next"] = step

        self._steps.append(step)
        if step.is_end():
            # Warn only for the degenerate case the message describes; the
            # original warned on every terminal step regardless of length.
            if len(self._steps) == 1:
                warnings.warn(
                    "an episode ends with only one step, make sure this is expected!",
                    stacklevel=2,
                )
            assert step.action is None
            assert step.reward is None
            self.end = True

        return self

    @classmethod
    def from_list(cls, sari: Tuple[List[State], List[Action], List[Reward],
                                   List[Info]]):
        """Build an ended episode from parallel state/action/reward/info lists.

        ``sari`` must hold one more state and info than actions and rewards;
        the extra trailing state/info form the terminal entry, and the last
        info must have a truthy ``"end"`` value.
        """
        (s, a, r, info) = sari
        assert info[-1]["end"]
        assert len(s) == len(a) + 1 == len(r) + 1 == len(info)

        inst = cls()

        # NOTE(review): each non-terminal info is required to already contain
        # "next", while append_step asserts that the previous entry's info has
        # no "next" before linking. Unless Step copies or strips its info,
        # these two checks conflict -- confirm against utils.step.Step.
        for state, action, reward, step_info in zip(s, a, r, info):
            assert 'next' in step_info
            inst.append_step(Step(state, action, reward, step_info))

        inst.append_step(Step(s[-1], None, None, info[-1]))
        inst.end = True

        return inst

    def add_info(self, add_info: Callable[[List[NotNoneStep]], Info]) -> Info:
        """Compute extra info from ``steps`` and merge it into ``self.info``.

        A key that already exists must map to an equal value.

        Returns:
            The dict produced by ``add_info``.
        """
        to_add_info = add_info(self.steps)

        for k, v in to_add_info.items():
            assert k not in self.info or self.info[k] == v
            self.info[k] = v

        return to_add_info

    def append_transition(self, transition: Transition) -> Self:
        """Append a ``(s1, s2)`` transition to the episode.

        On an empty episode both halves are stored. Otherwise the trailing
        entry (which must have no action/reward and whose info must be a
        subset of ``s1``'s info) is replaced by a step that keeps the trailing
        entry's state and takes ``s1``'s action, reward and info, and ``s2``
        becomes the new trailing entry. The episode is marked as ended when
        ``s2`` is terminal.

        Returns:
            ``self``, to allow chaining.
        """
        assert not self.end, "cannot append transition into a ended episode"
        (s1, s2) = transition.as_tuple()
        (_, a, r, i1) = s1.details

        if not self._steps:
            first = s1.to_step()
            assert "next" not in first.info
            first.info["next"] = s2
            self._steps.append(first)
        else:
            (ls, la, lr, li) = self._steps[-1].details

            assert la is None
            assert lr is None
            assert li.items() <= i1.items()

            assert "next" not in i1
            i1["next"] = s2
            self._steps[-1] = Step(ls, a, r, i1)

        self._steps.append(s2)
        if s2.is_end():
            self.end = True

        return self

    def sample_step(self, steps: int) -> List[Tuple[int, NotNoneStep]]:
        """Sample ``steps`` entries uniformly with replacement.

        Returns:
            A list of ``(index, step)`` pairs.
        """
        assert len(self) > 0
        picked = random.choices(range(len(self)), k=steps)
        # Build the converted list once instead of once per sampled index.
        available = self.steps
        return [(i, available[i]) for i in picked]

    def sample_continuous(self,
                          length: int,
                          strict: bool = True
                          ) -> Tuple[List[int], List[NotNoneStep]]:
        """Sample a window of consecutive steps.

        Args:
            length: Requested window length.
            strict: When true the episode must hold at least ``length`` steps
                and the window has exactly that length. When false the start
                is drawn from every index and the window is cut off at the
                end of the episode, so it may be shorter.

        Returns:
            ``(indices, steps)`` for the sampled window.
        """
        assert len(self) > 0
        total = len(self)
        if strict:
            assert total >= length

        start = random.choice(
            range(total if not strict else total - (length - 1)))

        ids = list(range(start, min(total, start + length)))
        if strict:
            assert len(ids) == length
        else:
            assert len(ids) <= length

        # Build the converted list once instead of once per index.
        available = self.steps
        return ids, [available[i] for i in ids]

    def chunks(
        self,
        chunk_size: int,
        shuffle: bool = True,
        last: Union[Literal["drop"], Literal["keep"],
                    Literal["merge"]] = "drop",
    ) -> Iterator[List[Tuple[NotNoneStep, int]]]:
        """Yield the steps in chunks of ``chunk_size`` ``(step, index)`` pairs.

        Args:
            chunk_size: Number of steps per full chunk (at least 1).
            shuffle: Visit the steps in a random permutation (drawn with
                ``np.random.permutation``) instead of in order.
            last: What to do with a trailing chunk shorter than ``chunk_size``:
                ``"drop"`` discards it, ``"keep"`` yields it as its own
                chunk, ``"merge"`` appends it to the final full chunk (or
                yields it alone when there is no full chunk).

        The yielded indices are NumPy integers taken from the index array.
        """
        assert self.len > 0
        assert chunk_size >= 1
        # Build the converted list once; the property rebuilds it per access.
        available = self.steps
        idx = (np.random.permutation(len(available))
               if shuffle else np.arange(len(available)))

        # Start offset of the trailing partial chunk (== len(idx) if none).
        last_i = (len(idx) // chunk_size) * chunk_size
        for i in range(0, len(idx), chunk_size):
            if last == "drop" and i == last_i:
                return

            if last == "keep" and i == last_i:
                yield [(available[j], j) for j in idx[last_i:]]
                return

            if last == "merge" and i == (last_i - chunk_size):
                # Final full chunk: take everything that is left with it.
                yield [(available[j], j) for j in idx[i:]]
                return

            yield [(available[j], j) for j in idx[i:i + chunk_size]]

    def compute_returns(self, gamma: float = 0.99) -> List[float]:
        """Return the discounted return of every step of an ended episode.

        Computed backwards as ``G_t = reward_t + gamma * G_{t+1}`` with the
        return after the final step taken as 0.
        """
        assert self.end, "cannot compute returns in non-end episode"
        steps = self.steps

        running = None
        returns = np.empty((len(self), ))
        for i, s in enumerate(reversed(steps)):
            running = gamma * (0 if running is None else running) + s.reward
            returns[-(i + 1)] = running

        return returns.tolist()

    def compute_gae_advanatges(
        self,
        rewards: List[float],
        values: List[float],
        gamme: float = 0.99,
        _lambda: float = 0.95,
    ) -> Tuple[List[float], List[float]]:
        """Compute generalized-advantage estimates for the episode.

        The method and parameter spellings (``advanatges``, ``gamme``) are
        kept as they are for compatibility with existing callers.

        Args:
            rewards: One reward per step (``len(self)`` values).
            values: One value per step plus one extra trailing value, i.e.
                ``len(self) + 1`` values; ``values[t + 1]`` is used as the
                bootstrap for step ``t``.
            gamme: Discount factor.
            _lambda: Weight applied to the accumulated advantage of the
                following step.

        Returns:
            ``(advantages, advantages + values[:-1])`` as two lists.
        """
        n = len(self)
        assert n == len(rewards) == len(values) - 1

        # One extra zero slot so that advantages[t + 1] exists for the last t.
        advantages = np.zeros(n + 1)

        for t in reversed(range(n)):
            td_err = rewards[t] + gamme * values[t + 1] - values[t]

            advantages[t] = td_err + (gamme * _lambda * advantages[t + 1])

        return (
            advantages[:-1].tolist(),
            (advantages[:-1] + np.array(values)[:-1]).tolist(),
        )
