from typing import Any, Dict, Optional
from data.rl_data import DataPoint
from models.base import Evaluator, InputType
from wordle.wordle_dataset import WordleListDataset
from wordle.wordle_env import WordleObservation
from wordle.wordle_game import N_CHARS, WordleGame
from collections import defaultdict
from models.iql_model import PerTokenIQL
import random


class Action_Ranking_Evaluator(Evaluator):
    """Evaluate whether a model's Q-values rank expert continuations above suboptimal ones.

    Items in ``branching_data`` with ``meta["kind"] == "expert"`` contribute the expert's
    next action from every entry of ``meta["prefixes"]``. Items with
    ``meta["kind"] == "branch_suboptimal"`` contribute the suboptimal next action taken
    from ``meta["start"]``. Only states that have at least one action of each kind are
    sampled during evaluation.
    """

    def __init__(self, branching_data: WordleListDataset) -> None:
        self.branching_data = branching_data
        # hashable state -> list of next actions observed from that state
        self.expert_actions = defaultdict(list)
        self.non_expert_actions = defaultdict(list)
        for i in range(self.branching_data.size()):
            item = self.branching_data.get_item(i)
            assert item.meta is not None and "kind" in item.meta
            if item.meta["kind"] == "expert":
                for prefix in item.meta["prefixes"]:
                    # prefix[1] is presumably the action history (it is passed as the action list to
                    # WordleGame in evaluate), so its length indexes the next action that was taken.
                    self.expert_actions[self.hashable_state(prefix)].append(item.meta["self_actions"][len(prefix[1])])
            elif item.meta["kind"] == "branch_suboptimal":
                start = item.meta["start"]
                self.non_expert_actions[self.hashable_state(start)].append(item.meta["self_actions"][len(start[1])])
            else:
                raise NotImplementedError("unsupported item kind: %r" % (item.meta["kind"],))
        # Only states with both an expert and a suboptimal continuation can be ranked.
        self.states = list(set(self.expert_actions.keys()).intersection(set(self.non_expert_actions.keys())))
        self.vocab = self.branching_data.get_item(0).meta["self"].game.vocab

    def hashable_state(self, state):
        """Return ``(state[0], tuple(state[1]))`` so the state can be used as a dict key."""
        return (
            state[0],
            tuple(state[1]),
        )

    def evaluate(self, model: PerTokenIQL, items: InputType) -> Optional[Dict[str, Any]]:
        """Estimate how often the model prefers the expert continuation of a branching state.

        One branching state is sampled per row of the prepared ``tokens`` batch. The state
        is advanced once with a random expert action and once with a random suboptimal
        action. The Q-values and target Q-values of the two resulting sequences are then
        compared at each of the last ``N_CHARS + 1`` non-terminal positions.

        Returns:
            Dict mapping ``"q_rank_acc_-k"`` and ``"q_target_rank_acc_-k"`` (k = 1 .. N_CHARS + 1,
            counting back from the last non-terminal position) to ``(accuracy, total)``.
        """
        assert not model.double_q
        # items only determine how many comparisons are sampled
        tokens = model.prepare_inputs(items)["tokens"]
        total_correct = [0 for _ in range(N_CHARS + 1)]
        total_correct_target = [0 for _ in range(N_CHARS + 1)]
        total = 0
        for _ in range(tokens.shape[0]):
            s, a = random.choice(self.states)
            key = self.hashable_state((s, a))
            obs = WordleObservation(WordleGame(s, self.vocab.update_vocab(s), list(a)))
            expert_state, _, _ = obs.game.next(random.choice(self.expert_actions[key]))
            non_expert_state, _, _ = obs.game.next(random.choice(self.non_expert_actions[key]))
            # DataPoint.from_obs takes (obs, tokenizer, token_reward), the same order used by
            # Action_Ranking_Evaluator_Adversarial in this file.
            expert_datapoint = DataPoint.from_obs(
                WordleObservation(expert_state), self.branching_data.tokenizer, self.branching_data.token_reward
            )
            non_expert_datapoint = DataPoint.from_obs(
                WordleObservation(non_expert_state), self.branching_data.tokenizer, self.branching_data.token_reward
            )
            iql_outputs = model.get_qvs([expert_datapoint, non_expert_datapoint])
            qs, target_qs, terminals = iql_outputs["qs"], iql_outputs["target_qs"], iql_outputs["terminals"]
            # Index of the last non-terminal position of each sequence. This assumes `terminals`
            # is a 0/1 mask with one trailing extra position (TODO: confirm against get_qvs).
            expert_last = (1 - terminals[0, :-1]).sum() - 1
            non_expert_last = (1 - terminals[1, :-1]).sum() - 1
            for i in range(N_CHARS + 1):
                total_correct[i] += int(qs[0, expert_last - i] > qs[1, non_expert_last - i])
                total_correct_target[i] += int(target_qs[0, expert_last - i] > target_qs[1, non_expert_last - i])
            total += 1
        return {
            **{("q_rank_acc_-%d" % (i + 1)): (total_correct[i] / total, total) for i in range(N_CHARS + 1)},
            **{
                ("q_target_rank_acc_-%d" % (i + 1)): (total_correct_target[i] / total, total)
                for i in range(N_CHARS + 1)
            },
        }


class Action_Ranking_Evaluator_Adversarial(Evaluator):
    """Rank expert actions against suboptimal and adversarial ones using model Q-values.

    Two comparisons are made per sample:
      * "initial": from a state keyed by ``s_0``, an expert ``a_0`` versus a suboptimal ``a_0``;
      * "branch": from a state keyed by ``s_2``, an expert ``a_2`` versus an adversarial ``a_2``.
    """

    def __init__(self, adversarial_data: WordleListDataset) -> None:
        self.adversarial_data = adversarial_data
        # hashable state -> candidate actions from that state, grouped by source
        self.expert_actions = defaultdict(list)
        self.adversarial_actions = defaultdict(list)
        self.suboptimal_actions = defaultdict(list)
        for idx in range(self.adversarial_data.size()):
            entry = self.adversarial_data.get_item(idx)
            assert entry.meta is not None and "kind" in entry.meta
            meta = entry.meta
            kind = meta["kind"]
            if kind == "expert":
                self.expert_actions[self.hashable_state(meta["s_0"])].append(meta["a_0"])
                if "s_2" in meta:
                    self.expert_actions[self.hashable_state(meta["s_2"])].append(meta["a_2"])
            elif kind == "adversarial":
                self.adversarial_actions[self.hashable_state(meta["s_2"])].append(meta["a_2"])
            elif kind == "suboptimal":
                self.suboptimal_actions[self.hashable_state(meta["s_0"])].append(meta["a_0"])
            else:
                raise NotImplementedError

        def _shared_with_expert(table):
            # states that also have at least one expert action recorded
            return list(set(self.expert_actions.keys()).intersection(set(table.keys())))

        self.initial_states = _shared_with_expert(self.suboptimal_actions)
        self.branch_states = _shared_with_expert(self.adversarial_actions)
        self.vocab = self.adversarial_data.get_item(0).meta["self"].game.vocab

    def hashable_state(self, state):
        """Return ``(state[0], tuple(state[1]))`` so the state can be used as a dict key."""
        head, history = state[0], state[1]
        return (head, tuple(history))

    def _rollout_pair(self, candidate_states, rival_actions):
        """Sample a state, then advance it once with an expert action and once with a rival action."""
        s, a = random.choice(candidate_states)
        game = WordleObservation(WordleGame(s, self.vocab.update_vocab(s), list(a))).game
        key = self.hashable_state((s, a))
        expert_a = random.choice(self.expert_actions[key])
        rival_a = random.choice(rival_actions[key])
        expert_state, _, _ = game.next(expert_a)
        rival_state, _, _ = game.next(rival_a)
        return expert_state, rival_state

    def _score_pair(self, model, expert_state, rival_state, hits, target_hits):
        """Add 1 to hits[k] / target_hits[k] when the expert sequence wins at position -(k+1)."""
        tokenizer = self.adversarial_data.tokenizer
        token_reward = self.adversarial_data.token_reward
        datapoints = [
            DataPoint.from_obs(WordleObservation(state), tokenizer, token_reward)
            for state in (expert_state, rival_state)
        ]
        outputs = model.get_qvs(datapoints)
        qs, target_qs, terminals = outputs["qs"], outputs["target_qs"], outputs["terminals"]
        # Last non-terminal index of each sequence. This assumes `terminals` is a 0/1 mask
        # (TODO: confirm against get_qvs).
        expert_end = (1 - terminals[0, :-1]).sum() - 1
        rival_end = (1 - terminals[1, :-1]).sum() - 1
        for offset in range(len(hits)):
            hits[offset] += int(qs[0, expert_end - offset] > qs[1, rival_end - offset])
            target_hits[offset] += int(target_qs[0, expert_end - offset] > target_qs[1, rival_end - offset])

    def evaluate(self, model: PerTokenIQL, items: InputType) -> Optional[Dict[str, Any]]:
        """Compare expert Q-values against suboptimal (first action) and adversarial (third action) ones.

        One initial comparison and one branch comparison are sampled per row of the prepared
        ``tokens`` batch.

        Returns:
            Dict mapping ``initial_q_rank_acc_-k``, ``initial_q_target_rank_acc_-k``,
            ``branch_q_rank_acc_-k`` and ``branch_q_target_rank_acc_-k`` (k = 1 .. N_CHARS + 1)
            to ``(accuracy, total)``.
        """
        # evaluate Q-values for suboptimal versus expert at the first action
        # evaluate Q-values for expert versus adversarial at the third action
        assert not model.double_q
        n_samples = model.prepare_inputs(items)["tokens"].shape[0]
        width = N_CHARS + 1
        initial_hits, initial_target_hits = [0] * width, [0] * width
        branch_hits, branch_target_hits = [0] * width, [0] * width
        total = 0
        for _ in range(n_samples):
            expert_state, suboptimal_state = self._rollout_pair(self.initial_states, self.suboptimal_actions)
            self._score_pair(model, expert_state, suboptimal_state, initial_hits, initial_target_hits)
            expert_state, adversarial_state = self._rollout_pair(self.branch_states, self.adversarial_actions)
            self._score_pair(model, expert_state, adversarial_state, branch_hits, branch_target_hits)
            total += 1

        metrics = {}
        for key_fmt, counts in (
            ("initial_q_rank_acc_-%d", initial_hits),
            ("initial_q_target_rank_acc_-%d", initial_target_hits),
            ("branch_q_rank_acc_-%d", branch_hits),
            ("branch_q_target_rank_acc_-%d", branch_target_hits),
        ):
            for i, count in enumerate(counts):
                metrics[key_fmt % (i + 1)] = (count / total, total)
        return metrics
