from data.language_environment import interact_environment
from utils.misc import convert_path
from typing import List, Optional, Dict, Any, Tuple
from data.rl_data import ConstantTokenReward, List_RL_Dataset, Iterable_RL_Dataset, DataPoint, TokenReward
from data.language_environment import Policy
from wordle.wordle_env import WordleEnvironment, WordleObservation
from wordle.wordle_game import Vocabulary, WordleGame, WordleState
from wordle.wordle_tokenizer import WordleTokenizer
import pickle as pkl
import json
from tqdm.auto import tqdm
import random
import numpy as np


class WordleListDataset(List_RL_Dataset):
    """Finite RL dataset backed by an in-memory list of Wordle observations.

    Every entry of ``items`` is an ``(observation, meta)`` pair; both parts
    are forwarded to ``DataPoint.from_obs`` when the entry is requested.
    """

    def __init__(
        self,
        items: List[Tuple[WordleObservation, Optional[Dict[str, Any]]]],
        max_len: Optional[int],
        token_reward: TokenReward,
    ) -> None:
        """Create the dataset.

        Args:
            items: ``(observation, meta)`` pairs; ``meta`` may be ``None``.
            max_len: Passed through to the base dataset.
            token_reward: Passed through to the base dataset and to
                ``DataPoint.from_obs``.
        """
        tokenizer = WordleTokenizer()
        super().__init__(tokenizer, token_reward, max_len)
        self.items = items

    def get_item(self, idx: int):
        """Return the ``DataPoint`` built from the ``idx``-th stored pair."""
        entry = self.items[idx]
        return DataPoint.from_obs(entry[0], self.tokenizer, self.token_reward, entry[1])

    def size(self):
        """Return the number of stored ``(observation, meta)`` pairs."""
        return len(self.items)

    @classmethod
    def from_file(cls, file_path: str, max_len: Optional[int], vocab: Optional[Vocabulary], token_reward: TokenReward):
        """Build a dataset from a pickled dump of states and actions.

        The pickle must hold a dict with a ``"state_actions"`` list whose
        entries carry ``"state"``, ``"actions"`` and optionally ``"meta"``.
        When ``vocab`` is ``None`` the dict must also provide ``"vocab_path"``
        and ``"vocab_cache_path"`` (the latter may be ``None``), from which
        the vocabulary is loaded.

        Args:
            file_path: Path of the pickle file.
            max_len: Passed through to the constructor.
            vocab: Vocabulary to use, or ``None`` to load the one referenced
                by the file.
            token_reward: Passed through to the constructor.

        Returns:
            A new instance of ``cls`` with one item per ``"state_actions"``
            entry. Each item's metadata is the entry's ``"meta"`` dict (if
            any) extended with a ``"self"`` key pointing at its observation.
        """
        # NOTE: pickle can execute arbitrary code while loading; only point
        # this at files from a trusted source.
        with open(file_path, "rb") as f:
            d = pkl.load(f)
        if vocab is None:
            vocab = Vocabulary.from_file(convert_path(d["vocab_path"]))
            if d["vocab_cache_path"] is not None:
                vocab.cache.load(convert_path(d["vocab_cache_path"]))
        items = []
        # Single pass: build the observation and its metadata together
        # instead of re-walking the list by index afterwards.
        for item in tqdm(d["state_actions"]):
            obs = WordleObservation(WordleGame(item["state"], vocab.update_vocab(item["state"]), item["actions"]))
            meta = {**item["meta"], "self": obs} if "meta" in item else {"self": obs}
            items.append((obs, meta))
        # Use ``cls`` so subclasses calling ``from_file`` get their own type.
        return cls(items, max_len, token_reward)


class WordleIterableDataset(Iterable_RL_Dataset):
    """RL dataset whose samples are produced on demand by running a policy
    against a ``WordleEnvironment``."""

    def __init__(self, policy: Policy, vocab: Vocabulary, max_len: Optional[int], token_reward: TokenReward) -> None:
        """Set up the tokenizer/base dataset and an environment over ``vocab``.

        Args:
            policy: Policy handed to ``interact_environment`` for each sample.
            vocab: Vocabulary used to construct the ``WordleEnvironment``.
            max_len: Passed through to the base dataset.
            token_reward: Passed through to the base dataset and to
                ``DataPoint.from_obs``.
        """
        super().__init__(WordleTokenizer(), token_reward, max_len)
        self.env = WordleEnvironment(vocab)
        self.policy = policy

    def sample_item(self):
        """Run one environment interaction and wrap its result in a ``DataPoint``.

        Only element ``0`` of what ``interact_environment`` returns is used
        (presumably the resulting observation -- confirm against
        ``interact_environment``); no metadata is attached.
        """
        interaction = interact_environment(self.env, self.policy, None)
        return DataPoint.from_obs(interaction[0], self.tokenizer, self.token_reward, None)


class WordleHumanDataset(Iterable_RL_Dataset):
    """RL dataset that replays recorded games as concrete Wordle action sequences.

    ``games`` holds ``(true_word, transitions)`` pairs and ``transitions``
    maps ``word -> transition -> candidate actions``. A sample is produced by
    picking a game, picking a word, and drawing one candidate action for each
    of the game's transitions; combinations for which some transition has no
    candidate are rejected and a new game is drawn.
    """

    def __init__(
        self,
        games: List[Tuple[str, List[str]]],
        transitions: Dict[str, Dict[str, List[str]]],
        use_true_word: bool,
        max_len: Optional[int],
        token_reward: TokenReward,
        game_indexes: Optional[List[int]],
        top_p: Optional[float],
    ) -> None:
        """Create the dataset.

        Args:
            games: ``(true_word, transitions)`` pairs.
            transitions: For each word, a mapping from a transition to the
                list of actions that may realise it.
            use_true_word: If true, replay each game against its own recorded
                word; otherwise against a word drawn at random from the keys
                of ``transitions``.
            max_len: Passed through to the base dataset.
            token_reward: Passed through to the base dataset and to
                ``DataPoint.from_obs``.
            game_indexes: If given, only the games at these indexes are kept
                (in this order).
            top_p: If given, only the ``int(n * top_p)`` games with the fewest
                transitions are kept, where ``n`` is the number of games left
                after ``game_indexes`` filtering.
        """
        tokenizer = WordleTokenizer()
        super().__init__(tokenizer, token_reward, max_len)
        self.games = games
        if game_indexes is not None:
            self.games = [self.games[idx] for idx in game_indexes]
        if top_p is not None:
            lens = [len(game) for _, game in self.games]
            n_keep = int(len(lens) * top_p)
            # argsort is ascending, so the shortest games come first.
            self.games = [self.games[idx] for idx in np.argsort(lens)[:n_keep]]
        self.transitions = transitions
        self.use_true_word = use_true_word
        # Built once here rather than on every sample_item() call.
        self._word_choices = list(transitions)

    def _sample_actions(self, true_word: str, game: List[str]) -> Optional[List[str]]:
        """Draw one action per transition of ``game`` for ``true_word``.

        Returns ``None`` as soon as a transition is missing for the word or
        has an empty candidate list. Raises ``KeyError`` if a non-empty game
        is replayed against a word absent from ``self.transitions``.
        """
        actions = []
        for transition in game:
            candidates = self.transitions[true_word].get(transition)
            if not candidates:
                return None
            actions.append(random.choice(candidates))
        return actions

    def sample_item(self):
        """Sample a playable (word, action sequence) and return it as a ``DataPoint``.

        Games are redrawn until every transition can be realised. The final
        state is obtained by applying the actions in order from
        ``WordleState.initial_state()``; the resulting observation is also
        attached to the data point's metadata under ``"obs"``.
        """
        true_word, game = random.choice(self.games)
        while True:
            if not self.use_true_word:
                # Ignore the recorded word and try a random known word.
                true_word = random.choice(self._word_choices)
            actions = self._sample_actions(true_word, game)
            if actions is not None:
                break
            # This combination cannot be replayed; draw another game.
            true_word, game = random.choice(self.games)
        state = WordleState.initial_state()
        for action in actions:
            state = state.transition_state(action, true_word)
        vocab = Vocabulary([true_word], state, cache=None, fill_cache=False)
        obs = WordleObservation(WordleGame(state, vocab, actions))
        return DataPoint.from_obs(obs, self.tokenizer, self.token_reward, {"obs": obs})

    @classmethod
    def from_file(
        cls,
        file_path: str,
        use_true_word: bool = False,
        max_len: Optional[int] = None,
        token_reward: Optional[TokenReward] = None,
        game_indexes: Optional[List[int]] = None,
        top_p: Optional[float] = None,
    ):
        """Build a dataset from a JSON file with ``"games"`` and ``"transitions"`` keys.

        Args:
            file_path: Path of the JSON file.
            use_true_word: See ``__init__``.
            max_len: See ``__init__``.
            token_reward: Reward to use; defaults to ``ConstantTokenReward(0.0)``.
            game_indexes: See ``__init__``.
            top_p: See ``__init__``.

        Returns:
            A new instance of ``cls``.
        """
        if token_reward is None:
            token_reward = ConstantTokenReward(0.0)
        with open(file_path, "r", encoding="utf-8") as f:
            d = json.load(f)
        # Use ``cls`` so subclasses calling ``from_file`` get their own type.
        return cls(d["games"], d["transitions"], use_true_word, max_len, token_reward, game_indexes, top_p)
