from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import torch
from data.language_environment import Language_Observation
from data.tokenizer import Tokenizer


class TokenReward(ABC):
    """Interface for strategies that assign a reward to the tokens of a sequence."""

    @abstractmethod
    def get_token_reward(self, tokens: List[int]) -> List[float]:
        """Return the per-token rewards for ``tokens``.

        The concrete implementations in this file return ``len(tokens) - 1``
        values, i.e. one for every token after the first.
        """


class ConstantTokenReward(TokenReward):
    """Token reward that assigns the same constant ``c`` to every position."""

    def __init__(self, c: float = 0.0):
        self.c = c

    def get_token_reward(self, tokens: List[int]) -> List[float]:
        """Return ``c`` repeated ``len(tokens) - 1`` times (empty if that is <= 0)."""
        n_rewards = len(tokens) - 1
        return [self.c for _ in range(n_rewards)]


class SepcifiedTokenReward(TokenReward):
    """Token reward looked up from a per-token-id table, then scaled and shifted.

    NOTE: the class name is misspelled ("Specified"); it is kept as-is because
    it is the public name callers use.
    """

    def __init__(self, token_data: Dict[int, float], scale: float = 1.0, shift: float = 0.0):
        self.token_data = token_data
        self.scale = scale
        self.shift = shift

    def get_token_reward(self, tokens: List[int]) -> List[float]:
        """Return ``base * scale + shift`` for every token after the first.

        ``base`` is ``token_data[tok]`` when the token id is in the table and
        ``0.0`` otherwise.
        """
        rewards = []
        for tok in tokens[1:]:
            base = self.token_data[tok] if tok in self.token_data else 0.0
            rewards.append(base * self.scale + self.shift)
        return rewards


@dataclass
class DataPoint:
    """A tokenized interaction sequence annotated for token- and utterance-level RL.

    Instances are normally built with :meth:`from_obs`; :meth:`to_tensors`
    converts the index/reward lists into tensors, optionally truncated to a
    maximum token length.
    """

    # Text that was tokenized and the resulting token ids.
    raw_str: str
    tokens: List[int]
    # Token-level view (as built by ``from_obs``): every position inside an
    # action span is both a state index and an action index; ``state_idxs``
    # gets one extra trailing entry, the index of the last token.
    state_idxs: List[int]
    action_idxs: List[int]
    rewards: List[float]
    terminals: List[int]
    # Utterance-level view (as built by ``from_obs``): one entry per action
    # utterance, plus a trailing state entry for the last token.
    utterance_state_idxs: List[int]
    utterance_action_idxs: List[int]
    utterance_rewards: List[float]
    utterance_terminals: List[int]
    meta: Optional[Dict[str, Any]] = None

    @staticmethod
    def _truncate(states, actions, rewards, terminals, max_length: int):
        """Cut one (states, actions, rewards, terminals) group to fit ``max_length`` tokens."""
        # ``from_obs`` builds the index lists in ascending order, so counting
        # the entries below the limit gives the length of the prefix to keep.
        n_states = int((states < max_length).sum())
        # An action needs a following state, hence the ``n_states - 1`` cap.
        n_actions = max(min(int((actions < (max_length - 1)).sum()), n_states - 1), 0)
        return states[:n_states], actions[:n_actions], rewards[:n_actions], terminals[:n_states]

    def to_tensors(self, device, max_length: Optional[int]):
        """Convert the stored lists to tensors on ``device``.

        Args:
            device: Target passed to ``Tensor.to``.
            max_length: If not ``None``, ``tokens`` is cut to this many entries
                and both the token-level and utterance-level groups are
                truncated to the entries that still fit.

        Returns:
            The tuple ``(tok, s, a, r, term, u_s, u_a, u_r, u_term)``.
        """
        tok = torch.tensor(self.tokens).to(device)
        s = torch.tensor(self.state_idxs).long().to(device)
        a = torch.tensor(self.action_idxs).long().to(device)
        r = torch.tensor(self.rewards).to(device)
        term = torch.tensor(self.terminals).to(device)
        u_s = torch.tensor(self.utterance_state_idxs).long().to(device)
        u_a = torch.tensor(self.utterance_action_idxs).long().to(device)
        # Rewards keep their inferred dtype, exactly like ``r`` above; casting
        # to long would truncate fractional utterance rewards.
        u_r = torch.tensor(self.utterance_rewards).to(device)
        u_term = torch.tensor(self.utterance_terminals).long().to(device)
        if max_length is not None:
            tok = tok[:max_length]
            # Both views go through the same helper so the utterance-level
            # group is truncated by its own indices, not the token-level ones.
            s, a, r, term = self._truncate(s, a, r, term, max_length)
            u_s, u_a, u_r, u_term = self._truncate(u_s, u_a, u_r, u_term, max_length)
        return tok, s, a, r, term, u_s, u_a, u_r, u_term

    @classmethod
    def from_obs(
        cls,
        obs: "Language_Observation",
        tokenizer: "Tokenizer",
        token_reward: "TokenReward",
        meta: Optional[Dict[str, Any]] = None,
    ) -> "DataPoint":
        """Build a ``DataPoint`` from an observation.

        Args:
            obs: Observation providing ``to_sequence()`` (a list of
                ``(text, reward)`` pairs plus a terminal flag) and ``metadata()``.
                A pair whose reward is ``None`` is treated as a state
                utterance, any other pair as an action utterance.
            tokenizer: Supplies the special token ids, ``id_to_token`` and
                ``encode``.
            token_reward: Per-token reward strategy; its output is indexed by
                token position.
            meta: Extra metadata; on key collisions it overrides ``obs.metadata()``.

        Returns:
            The populated ``DataPoint``.
        """
        sequence, terminal = obs.to_sequence()
        obs_meta = obs.metadata()
        if obs_meta is not None:
            # Caller-supplied keys win over the observation's own metadata.
            meta = obs_meta if meta is None else {**obs_meta, **meta}

        # An empty sequence, or one whose first utterance is an action, opens
        # with the begin-of-action token; otherwise with begin-of-state.
        starts_with_action = len(sequence) == 0 or sequence[0][1] is not None
        first_id = tokenizer.boa_token_id if starts_with_action else tokenizer.bos_token_id
        parts = [tokenizer.id_to_token(first_id)]
        action_rewards = []
        for text, r in sequence:
            parts.append(text)
            if r is None:
                parts.append(tokenizer.id_to_token(tokenizer.eos_token_id))
            else:
                parts.append(tokenizer.id_to_token(tokenizer.eoa_token_id))
                action_rewards.append(r)
        if terminal:
            parts.append(tokenizer.id_to_token(tokenizer.eod_token_id))
        raw_str = "".join(parts)

        tokens = tokenizer.encode(raw_str)[0]
        token_rewards = token_reward.get_token_reward(tokens)
        state_idxs = []
        action_idxs = []
        reward = []
        utterance_state_idxs = []
        utterance_action_idxs = []
        utterance_rewards = []
        curr_idx = 0  # index of the most recent end-of-state / end-of-action token
        curr_action_idx = 0  # next unused entry of ``action_rewards``
        for i, t in enumerate(tokens):
            if t == tokenizer.eos_token_id:
                curr_idx = i
            elif t == tokenizer.eoa_token_id:
                span = range(curr_idx, i)
                span_rewards = [token_rewards[x] for x in span]
                action_reward = action_rewards[curr_action_idx]
                action_idxs.extend(span)
                state_idxs.extend(span)
                reward.extend(span_rewards)
                # The environment reward is credited to the span's last position.
                reward[-1] += action_reward
                utterance_action_idxs.append(i)
                utterance_state_idxs.append(curr_idx)
                utterance_rewards.append(action_reward + sum(span_rewards))
                curr_idx = i
                curr_action_idx += 1
        # The final token is always recorded as the last state of both views.
        last_idx = len(tokens) - 1
        state_idxs.append(last_idx)
        utterance_state_idxs.append(last_idx)
        terminals = ([0] * (len(state_idxs) - 1)) + [int(terminal)]
        utterance_terminals = ([0] * (len(utterance_state_idxs) - 1)) + [int(terminal)]
        return cls(
            raw_str,
            tokens,
            state_idxs,
            action_idxs,
            reward,
            terminals,
            utterance_state_idxs,
            utterance_action_idxs,
            utterance_rewards,
            utterance_terminals,
            meta=meta,
        )

    @staticmethod
    def get_token_reward(obs: "Language_Observation", tokenizer: "Tokenizer", token_reward: "TokenReward"):
        """Return the token-level ``rewards`` list of the ``DataPoint`` built from ``obs``."""
        return DataPoint.from_obs(obs, tokenizer, token_reward).rewards


class RL_Dataset(ABC):
    """Base class for RL datasets; holds the tokenizer, token reward and length cap."""

    def __init__(self, tokenizer: Tokenizer, token_reward: TokenReward, max_len: Optional[int]) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.token_reward = token_reward
        self.max_len = max_len

    def collate(self, items: List[DataPoint], device):
        """Batch ``items`` into a dict of padded tensors on ``device``.

        Each item is converted with ``to_tensors(device, self.max_len)``; the
        resulting sequences are right-padded (batch first). Tokens are padded
        with the tokenizer's pad id, index tensors with 0, terminals with 1 and
        rewards with 0.0. ``attn_mask`` is 1.0 wherever the padded token tensor
        differs from the pad id.
        """
        per_item = [item.to_tensors(device, self.max_len) for item in items]
        (
            tokens,
            state_idxs,
            action_idxs,
            rewards,
            terminals,
            u_state_idxs,
            u_action_idxs,
            u_rewards,
            u_terminals,
        ) = zip(*per_item)
        pad_id = self.tokenizer.pad_token_id

        def pad(seqs, value):
            return torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=value)

        padded_tokens = pad(tokens, pad_id)
        attn_mask = (padded_tokens != pad_id).float()
        batch = {"tokens": padded_tokens, "attn_mask": attn_mask}
        batch["state_idxs"] = pad(state_idxs, 0)
        batch["action_idxs"] = pad(action_idxs, 0)
        batch["rewards"] = pad(rewards, 0.0)
        batch["terminals"] = pad(terminals, 1)
        batch["u_state_idxs"] = pad(u_state_idxs, 0)
        batch["u_action_idxs"] = pad(u_action_idxs, 0)
        batch["u_rewards"] = pad(u_rewards, 0.0)
        batch["u_terminals"] = pad(u_terminals, 1)
        return batch


class List_RL_Dataset(RL_Dataset):
    """RL dataset with index-based access to a known number of items."""

    @abstractmethod
    def get_item(self, idx: int) -> DataPoint:
        """Return the data point stored at position ``idx``."""

    @abstractmethod
    def size(self) -> int:
        """Return the number of items in the dataset."""


class Iterable_RL_Dataset(RL_Dataset):
    """RL dataset that produces items one at a time through sampling."""

    @abstractmethod
    def sample_item(self) -> DataPoint:
        """Return one sampled data point."""
