from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch


@dataclass
class Episode:
    """Container for the per-step tensors of a single episode.

    All five tensors are indexed by time step along dimension 0 and must
    have the same length. In memory ``obs`` is kept as floats; ``load`` and
    ``save`` map it from/to an 8-bit ``[0, 255]`` representation on disk,
    which corresponds to an in-memory range of ``[-1, 1]``.

    On construction the episode is cut right after the first step flagged
    as ``end`` or ``trunc`` (that step is kept, later steps are dropped).
    """

    obs: torch.FloatTensor  # observations; [-1, 1] floats per load/save scaling
    act: torch.LongTensor  # actions, one per step
    rew: torch.FloatTensor  # rewards, one per step
    end: torch.ByteTensor  # non-zero where the episode terminated
    trunc: torch.ByteTensor  # non-zero where the episode was truncated

    def __post_init__(self):
        """Check that all fields share a length and trim after the first dead step."""
        # Kept as an assert (not a raise) so the exception type is unchanged.
        assert len(self.obs) == len(self.act) == len(self.rew) == len(self.end) == len(self.trunc), (
            'obs, act, rew, end and trunc must all have the same length'
        )
        dead = self.dead
        if dead.any():
            # argmax returns the index of the first maximal value, i.e. the
            # first dead step; +1 keeps that step in the episode.
            idx_end = int(torch.argmax(dead)) + 1
            self.obs = self.obs[:idx_end]
            self.act = self.act[:idx_end]
            self.rew = self.rew[:idx_end]
            self.end = self.end[:idx_end]
            self.trunc = self.trunc[:idx_end]

    def __len__(self) -> int:
        """Return the number of steps (size of ``obs`` along dimension 0)."""
        return self.obs.size(0)

    def to(self, device) -> Episode:
        """Return a new episode of the same class with every field moved to ``device``."""
        return type(self)(**{k: v.to(device) for k, v in self.__dict__.items()})

    @property
    def dead(self) -> torch.ByteTensor:
        """Per-step flag equal to 1 where ``end`` or ``trunc`` is set, else 0."""
        # The sum is 2 when both flags are set on the same step; clip back to 1.
        return (self.end + self.trunc).clip(max=1)

    @classmethod
    def load(cls, path: Path, map_location: Optional[torch.device] = None) -> Episode:
        """Load an episode written by ``save``.

        ``obs`` is rescaled from the stored ``[0, 255]`` range to ``[-1, 1]``;
        every other entry is passed to the constructor unchanged.

        Args:
            path: File to read.
            map_location: Forwarded to ``torch.load``.

        Returns:
            An instance of ``cls`` (so subclasses load as themselves).
        """
        # NOTE(review): torch.load unpickles the file; only load trusted files.
        d = torch.load(Path(path), map_location=map_location)
        return cls(**{k: v.div(255).mul(2).sub(1) if k == 'obs' else v for k, v in d.items()})

    def save(self, path: Path) -> None:
        """Write the episode to ``path``, storing ``obs`` as uint8 in ``[0, 255]``.

        Missing parent directories are created. Values are rounded to the
        nearest integer and clamped before the uint8 cast: a bare cast
        truncates toward zero, so float error in the rescaling could shift a
        pixel down by one on every load/save round trip, and out-of-range
        values would wrap around.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        d = {
            k: v.add(1).div(2).mul(255).round().clamp(0, 255).byte() if k == 'obs' else v
            for k, v in self.__dict__.items()
        }
        torch.save(d, path)

    def merge(self, other: Episode) -> Episode:
        """Return a new episode made of this one's steps followed by ``other``'s.

        The result goes through ``__post_init__`` again, so it is trimmed
        after the first dead step of the concatenation.
        """
        return Episode(
            torch.cat((self.obs, other.obs), dim=0),
            torch.cat((self.act, other.act), dim=0),
            torch.cat((self.rew, other.rew), dim=0),
            torch.cat((self.end, other.end), dim=0),
            torch.cat((self.trunc, other.trunc), dim=0),
        )

    def compute_metrics(self) -> dict:
        """Return the episode length and the summed reward (as a 0-dim tensor)."""
        return {'length': len(self), 'return': self.rew.sum()}

