from collections import defaultdict
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Dict, Tuple, Union, Optional
import wandb
from models.gpt2_optional_final_ln import GPT2LMHeadModel, GPT2Model
from data.rl_data import DataPoint, List_RL_Dataset, RL_Dataset
from utils.torch_utils import get_transformer_logs
import copy
from models.base import BaseTransformer, Evaluator, InputType
from transformers.modeling_utils import PreTrainedModel
from utils.sampling_utils import *
import numpy as np
import math
from data.language_environment import Language_Environment, Language_Observation, interact_environment, Policy
from tqdm.auto import tqdm

from models.s_iql_model import S_ILQL, TransformerMLP, IQL_Policy, IQL_Evaluator


class IS_ILQL(S_ILQL):
    def __init__(
        self,
        model: PreTrainedModel,
        dataset: RL_Dataset,
        device: Union[torch.device, str] = "cuda",
        alpha: float = 0.005,
        gamma=1.0,
        beta=1.0,
        transition_weight=0.0,
        clip_weight: Optional[float] = None,
        value_max: Optional[float] = None,
        value_min: Optional[float] = None,
        detach_v: bool = False,
        detach_pi: bool = False,
        detach_q: bool = False,
        double_q: bool = False,
        tau: float = 0.9,
        seperate_policy: bool = False,
        seperate_target: bool = False,
        exp_weights: bool = False,
        dm_margin: float = 0.0,
        advanced_mlp: bool = False,
        cql_temp: float = 1.0,
        K: int = 1,
    ):
        """Initialise the parent ``S_ILQL`` and install a value head with ``K`` outputs.

        Every argument is forwarded unchanged (by keyword) to ``S_ILQL.__init__``.
        Afterwards the value head ``self.v`` left by the parent is deleted and
        rebuilt so that it emits exactly ``K`` scalars per position.

        ``model`` must be a ``GPT2Model`` or ``GPT2LMHeadModel`` (checked with
        ``assert``).
        """
        assert isinstance(model, (GPT2Model, GPT2LMHeadModel))
        super().__init__(
            model=model,
            dataset=dataset,
            device=device,
            alpha=alpha,
            gamma=gamma,
            beta=beta,
            transition_weight=transition_weight,
            clip_weight=clip_weight,
            value_max=value_max,
            value_min=value_min,
            detach_v=detach_v,
            detach_pi=detach_pi,
            detach_q=detach_q,
            double_q=double_q,
            tau=tau,
            seperate_policy=seperate_policy,
            seperate_target=seperate_target,  # should be False
            exp_weights=exp_weights,
            dm_margin=dm_margin,
            advanced_mlp=advanced_mlp,
            cql_temp=cql_temp,
            K=K,
        )

        # Drop the parent's value head and rebuild it with one output per head.
        del self.v
        n_value_outputs = self.K
        if self.advanced_mlp:
            config = self.model.config
            # Fall back to 4 * h_dim when the config does not fix an inner width.
            inner_dim = config.n_inner if config.n_inner is not None else 4 * self.h_dim
            self.v = TransformerMLP(self.h_dim, inner_dim, n_value_outputs, config.resid_pdrop)
        else:
            widened = self.h_dim * 2
            self.v = nn.Sequential(
                nn.Linear(self.h_dim, widened),
                nn.ReLU(),
                nn.Linear(widened, n_value_outputs),
            )

    def forward(
        self,
        tokens: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        state_idxs: torch.Tensor,
        action_idxs: torch.Tensor,
        prefix_embs: Optional[torch.Tensor] = None,
        prefix_attn_mask: Optional[torch.Tensor] = None,
        remove_prefix_position_embs: bool = False,
        qv_kwargs=None,
        policy_kwargs=None,
        target_kwargs=None,
        skip_policy_on_train=False,
        detach_full_policy=False,
    ):
        """Run the transformer(s) and compute per-head V, Q and policy logits.

        The Q/V transformer always runs. ``self.lm_target`` and
        ``self.lm_policy`` run additionally when they are not ``None``;
        otherwise their keyword arguments are merged into the Q/V call and the
        Q/V hidden states are reused.

        Args:
            tokens: token ids, embedded with the transformer's ``wte``.
            attn_mask: attention mask for ``tokens``; all-ones when ``None``.
            state_idxs / action_idxs: positions (along the token axis) whose
                hidden states feed the V head and the Q heads respectively.
            prefix_embs: optional embeddings prepended to the token embeddings.
            prefix_attn_mask: mask for the prefix. Position ids are only passed
                explicitly to the transformers when this is given.
            remove_prefix_position_embs: subtract the position embeddings from
                the prefix before it is fed to each transformer.
                NOTE(review): this indexes ``position_ids``, which is ``None``
                unless ``prefix_attn_mask`` is supplied.
            qv_kwargs / policy_kwargs / target_kwargs: extra keyword arguments
                for the respective transformer call. They are copied, never
                modified.
            skip_policy_on_train: in training mode with a separate policy
                model, skip that model and return all-zero logits.
            detach_full_policy: run the policy model and head under
                ``torch.no_grad()``.

        Returns:
            dict with
            ``model_outputs`` (raw outputs of the qv / policy / target calls),
            ``vs`` and ``target_vs`` reshaped to ``(b, t, K, 1)`` (the latter a
            detached copy), ``qs`` reshaped to ``(b, t, K, n_tokens)`` (a pair
            when ``double_q``), ``target_qs`` (clipped, detached; the
            element-wise minimum of both heads when ``double_q``) and
            ``logits``.
        """
        # Work on copies: the merges below must not leak into the caller's
        # dictionaries (they may be reused across calls).
        qv_kwargs = {} if qv_kwargs is None else dict(qv_kwargs)
        target_kwargs = {} if target_kwargs is None else dict(target_kwargs)
        policy_kwargs = {} if policy_kwargs is None else dict(policy_kwargs)
        if self.lm_target is None:
            qv_kwargs.update(target_kwargs)
        if self.lm_policy is None:
            qv_kwargs.update(policy_kwargs)
        if attn_mask is None:
            attn_mask = torch.ones(tokens.shape, dtype=torch.long).to(self.device)
        if prefix_embs is None:
            prefix_embs = torch.empty((tokens.shape[0], 0, self.h_dim)).to(self.device)
        prefix_t = prefix_embs.shape[1]
        set_pos_ids = prefix_attn_mask is not None
        if prefix_attn_mask is None:
            prefix_attn_mask = torch.ones(prefix_embs.shape[:2]).to(self.device)
        input_attn_mask = torch.cat((prefix_attn_mask, attn_mask), dim=1)
        position_ids = torch.cumsum(input_attn_mask, dim=1) - 1 if set_pos_ids else None
        if isinstance(self.model, GPT2Model):
            transformer = self.model
            if self.lm_target is not None:
                target_transformer = self.lm_target
            if self.lm_policy is not None:
                policy_transformer = self.lm_policy
        elif isinstance(self.model, GPT2LMHeadModel):
            transformer = self.model.transformer
            if self.lm_target is not None:
                target_transformer = self.lm_target.transformer
            if self.lm_policy is not None:
                policy_transformer = self.lm_policy.transformer
        else:
            raise NotImplementedError
        # Each model gets its own copy of the (unmodified) prefix, because each
        # subtracts its *own* position embeddings below.
        if self.lm_target is not None:
            target_prefix_embs = prefix_embs.clone()
        if self.lm_policy is not None:
            policy_prefix_embs = prefix_embs.clone()
        if remove_prefix_position_embs:
            # Out of place on purpose: the caller's tensor must stay untouched,
            # otherwise a second call would subtract the embeddings twice.
            prefix_embs = prefix_embs - transformer.wpe(position_ids[:, : prefix_embs.shape[1]])
        input_embeddings = torch.cat((prefix_embs, transformer.wte(tokens)), dim=1)
        model_outputs = self.model(
            inputs_embeds=input_embeddings,
            attention_mask=input_attn_mask,
            position_ids=position_ids,
            output_hidden_states=True,
            **qv_kwargs
        )
        all_model_outputs = {
            "qv_model_outputs": model_outputs,
            "policy_model_outputs": model_outputs,
            "target_model_outputs": model_outputs,
        }
        # The advanced MLP heads read the second-to-last hidden state.
        if self.advanced_mlp:
            hidden_states = model_outputs.hidden_states[-2][:, prefix_t:, :]
        else:
            hidden_states = model_outputs.hidden_states[-1][:, prefix_t:, :]
        if self.lm_target is None:
            target_hidden_states = hidden_states
        else:
            if remove_prefix_position_embs:
                target_prefix_embs -= target_transformer.wpe(position_ids[:, : prefix_embs.shape[1]])
            target_input_embeddings = torch.cat((target_prefix_embs, target_transformer.wte(tokens)), dim=1)
            with torch.no_grad():
                target_outputs = self.lm_target(
                    inputs_embeds=target_input_embeddings,
                    attention_mask=input_attn_mask,
                    position_ids=position_ids,
                    output_hidden_states=True,
                    **target_kwargs
                )
            all_model_outputs["target_model_outputs"] = target_outputs
            if self.advanced_mlp:
                target_hidden_states = target_outputs.hidden_states[-2][:, prefix_t:, :]
            else:
                target_hidden_states = target_outputs.hidden_states[-1][:, prefix_t:, :]
        if self.lm_policy is None:
            if isinstance(self.model, GPT2Model):
                policy_hidden_states = hidden_states
            else:
                policy_hidden_states = model_outputs.hidden_states[-1][:, prefix_t:, :]
        else:
            if skip_policy_on_train and self.training:
                # Placeholder only; the logits are replaced by zeros below.
                policy_hidden_states = hidden_states
            else:
                if remove_prefix_position_embs:
                    policy_prefix_embs -= policy_transformer.wpe(position_ids[:, : prefix_embs.shape[1]])
                policy_input_embeddings = torch.cat((policy_prefix_embs, policy_transformer.wte(tokens)), dim=1)
                if detach_full_policy:
                    with torch.no_grad():
                        policy_outputs = self.lm_policy(
                            inputs_embeds=policy_input_embeddings,
                            attention_mask=input_attn_mask,
                            position_ids=position_ids,
                            output_hidden_states=True,
                            **policy_kwargs
                        )
                else:
                    policy_outputs = self.lm_policy(
                        inputs_embeds=policy_input_embeddings,
                        attention_mask=input_attn_mask,
                        position_ids=position_ids,
                        output_hidden_states=True,
                        **policy_kwargs
                    )
                all_model_outputs["policy_model_outputs"] = policy_outputs
                if isinstance(self.model, GPT2Model):
                    if self.advanced_mlp:
                        policy_hidden_states = policy_outputs.hidden_states[-2][:, prefix_t:, :]
                    else:
                        policy_hidden_states = policy_outputs.hidden_states[-1][:, prefix_t:, :]
                else:
                    policy_hidden_states = policy_outputs.hidden_states[-1][:, prefix_t:, :]
        # Pick the hidden states at the requested state / action positions.
        state_gather_idx = state_idxs.unsqueeze(2).repeat(1, 1, self.h_dim)
        action_gather_idx = action_idxs.unsqueeze(2).repeat(1, 1, self.h_dim)
        state_hidden_states = torch.gather(input=hidden_states, dim=1, index=state_gather_idx)
        action_hidden_states = torch.gather(input=hidden_states, dim=1, index=action_gather_idx)
        # Gathered for parity with the parent implementation; not consumed below.
        action_target_hidden_states = torch.gather(input=target_hidden_states, dim=1, index=action_gather_idx)

        vss = self.v(state_hidden_states.detach() if self.detach_v else state_hidden_states)
        b, t, _ = vss.shape
        vss = vss.reshape(b, t, self.K, 1)
        vs = vss
        target_vs = vss.detach().clone()

        # The Q head emits K + 1 slices of n_tokens outputs: slices 1..K are the
        # trained heads, slices 0..K-1 (detached) serve as their targets.
        qss = self.q(action_hidden_states.detach() if self.detach_q else action_hidden_states)
        b, t, _ = qss.shape
        qss = qss.reshape(b, t, self.K + 1, self.n_tokens)
        qs = qss[:, :, 1:]
        target_qs = qss[:, :, :-1].detach().clone()
        if self.double_q:
            qs2s = self.q2(action_hidden_states.detach() if self.detach_q else action_hidden_states)
            qs2s = qs2s.reshape(b, t, self.K + 1, self.n_tokens)
            qs2 = qs2s[:, :, 1:]
            target_qs2 = qs2s[:, :, :-1].detach().clone()

        if skip_policy_on_train and self.training and self.lm_policy is not None:
            logits = torch.zeros(
                (
                    policy_hidden_states.shape[0],
                    policy_hidden_states.shape[1],
                    self.dataset.tokenizer.num_tokens(),
                )
            ).to(self.device)
        else:
            if detach_full_policy:
                with torch.no_grad():
                    logits = self.pi(policy_hidden_states.detach() if self.detach_pi else policy_hidden_states)
            else:
                logits = self.pi(policy_hidden_states.detach() if self.detach_pi else policy_hidden_states)
        return {
            "model_outputs": all_model_outputs,
            "vs": vs,
            "target_vs": target_vs,
            "qs": (qs, qs2) if self.double_q else qs,
            "target_qs": self.clip_values(torch.minimum(target_qs, target_qs2) if self.double_q else target_qs),
            "logits": logits,
        }

    def get_qvs(
        self,
        items: InputType,
        prefix_embs: Optional[torch.Tensor] = None,
        prefix_attn_mask: Optional[torch.Tensor] = None,
        remove_prefix_position_embs: bool = False,
        qv_kwargs=None,
        policy_kwargs=None,
        target_kwargs=None,
        **kwargs
    ):
        """Run the model on ``items`` and assemble everything the losses need.

        ``items`` is converted with ``self.prepare_inputs``; the remaining
        arguments (and ``**kwargs``) are forwarded to ``forward``.

        Returns:
            dict with the prepared ``tokens`` / ``attn_mask`` / ``rs`` /
            ``terminals``, the raw ``model_outputs`` and ``logits``, and
            - ``vs`` / ``target_vs``: values at all but the last state,
            - ``vns`` / ``target_vns``: values at all but the first state,
            - ``qs``: Q values of the tokens actually taken, ``(b, t, K, 1)``
              (a pair when ``double_q``),
            - ``target_qs``: target Q values of the taken tokens,
            - ``weights``: output of ``get_weights`` (computed without grad),
            - ``cql_term`` / ``dm_term``: the two regularisation losses.
        """
        prepared_inputs = self.prepare_inputs(items)
        tokens, attn_mask = prepared_inputs["tokens"], prepared_inputs["attn_mask"]
        s_idx, a_idx = prepared_inputs["state_idxs"], prepared_inputs["action_idxs"]
        rs, terminals = prepared_inputs["rewards"], prepared_inputs["terminals"]
        self_outputs = self(
            tokens,
            attn_mask,
            s_idx,
            a_idx,
            prefix_embs,
            prefix_attn_mask,
            remove_prefix_position_embs,
            qv_kwargs,
            policy_kwargs,
            target_kwargs,
            **kwargs
        )
        model_outputs, vs, qs = self_outputs["model_outputs"], self_outputs["vs"], self_outputs["qs"]
        target_qs, logits = self_outputs["target_qs"], self_outputs["logits"]
        vt = vs[:, :-1]
        vtp1 = vs[:, 1:]
        # Token chosen at each action position (the token following it).
        select_tokens = torch.gather(tokens[:, 1:], dim=1, index=a_idx)
        cql_term = self.get_cql_loss(qs, select_tokens, terminals)
        full_qs = qs
        # Same token index for each of the K heads: (b, t) -> (b, t, K, 1).
        token_index = select_tokens[:, :, None, None].repeat(1, 1, self.K, 1)
        if self.double_q:
            q1, q2 = qs
            qs = (
                torch.gather(q1, dim=-1, index=token_index),
                torch.gather(q2, dim=-1, index=token_index),
            )
        else:
            # forward() returns (b, t, K, n_tokens) here as well, so the single-Q
            # path needs the same K-aware gather as the double-Q path.
            qs = torch.gather(qs, dim=-1, index=token_index)
        dm_term = self.get_dm_loss(full_qs, qs, terminals, self.dm_margin)
        target_qs = torch.gather(target_qs, dim=-1, index=token_index)
        with torch.no_grad():
            weights = self.get_weights(tokens, vt, target_qs, s_idx, a_idx, terminals)
        return {
            "tokens": tokens,
            "attn_mask": attn_mask,
            "model_outputs": model_outputs,
            "vs": vt,
            "qs": qs,
            "vns": vtp1,
            "target_vs": vt,
            "target_qs": target_qs,
            "target_vns": vtp1,
            "rs": rs,
            "terminals": terminals,
            "logits": logits,
            "weights": weights,
            "cql_term": cql_term,
            "dm_term": dm_term,
        }

    def get_weights(
        self,
        tokens: torch.Tensor,
        vs: torch.Tensor,
        qs: Optional[torch.Tensor],
        state_idxs: torch.Tensor,
        action_idxs: torch.Tensor,
        terminals: torch.Tensor,
    ):
        weights = torch.full([self.K, *tokens.shape], self.transition_weight).to(self.device)
        if self.exp_weights:
            w_values = torch.exp(self.beta * (qs - vs))
        else:
            # w_values = ((qs - vs) > 0.0).float()
            adv_sign = ((qs - vs) > 0.0).float()
            w_values = self.beta * adv_sign + (1 - self.beta) * (1 - adv_sign)
        if action_idxs.shape[1] == 0:
            n = torch.zeros((tokens.shape[0],)).long().to(self.device)
        else:
            n = torch.argmax(action_idxs, dim=1) + 1
        for i in range(tokens.shape[0]):
            weights[:, i, :].scatter_(
                1, action_idxs[i, : n[i]].unsqueeze(0).expand(self.K, -1), w_values[i, : n[i], :, 0].permute(1, 0)
            )
        if self.clip_weight is not None:
            weights = torch.clip(weights, max=self.clip_weight)
        # print(list(map(lambda x: list(map(lambda y: (y[0], self.dataset.tokenizer.id_to_token(y[1].item()),), zip(*x))), zip(weights.detach().cpu().tolist(), tokens))))
        return weights

    def awac_loss(self, tokens, attn_mask, logits, w):
        w = w.detach()
        losses = F.cross_entropy(
            logits[:, :-1, :].reshape(-1, logits.shape[-1]), tokens[:, 1:].reshape(-1), reduction="none"
        )
        losses = losses.reshape(tokens.shape[0], tokens.shape[1] - 1)
        return (losses * w[:, :, :-1] * attn_mask[:, 1:]).sum() / attn_mask[:, 1:].sum()

    def get_v_loss(self, vs, target_qs, terminals):
        target_qs = target_qs.detach()
        return (
            (
                (target_qs >= vs).int() * self.tau * (target_qs - vs) ** 2
                + (target_qs < vs).int() * (1 - self.tau) * (target_qs - vs) ** 2
            )
            * (1 - terminals[:, :-1, None, None].repeat(1, 1, self.K, 1))
        ).sum() / max((1 - terminals[:, :-1]).sum().item(), 1.0)

    def get_q_loss(self, vns, qs, rs, gamma, terminals):
        vns = vns.detach()
        if self.double_q:
            q1, q2 = qs
            repeated_terminals = terminals[:, :, None, None].repeat(1, 1, self.K, 1)
            repeated_rs = rs[:, :, None, None].repeat(1, 1, self.K, 1)
            l1 = (
                (((1 - repeated_terminals[:, 1:]) * vns * gamma + repeated_rs - q1) ** 2)
                * (1 - repeated_terminals[:, :-1])
            ).sum() / max((1 - terminals[:, :-1]).sum().item(), 1.0)
            l2 = (
                (((1 - repeated_terminals[:, 1:]) * vns * gamma + repeated_rs - q2) ** 2)
                * (1 - repeated_terminals[:, :-1])
            ).sum() / max((1 - terminals[:, :-1]).sum().item(), 1.0)
            return l1 + l2
        return ((((1 - terminals[:, 1:]) * vns * gamma + rs - qs) ** 2) * (1 - terminals[:, :-1])).sum() / max(
            (1 - terminals[:, :-1]).sum().item(), 1.0
        )

    def get_cql_loss(self, qs, action_tokens, terminals):
        repeated_terminals = terminals.unsqueeze(-1).repeat(1, 1, self.K)
        n = (1 - terminals[:, :-1]).sum()
        if self.double_q:
            q1, q2 = qs
            b, t, k, d = q1.shape
            return (
                (
                    F.cross_entropy(
                        q1.reshape(-1, d) / self.cql_temp,
                        action_tokens.unsqueeze(-1).repeat(1, 1, k).reshape(-1),
                        reduction="none",
                    )
                ).reshape(b, t, k)
                * (1 - repeated_terminals[:, :-1])
                + (
                    F.cross_entropy(
                        q2.reshape(-1, d) / self.cql_temp,
                        action_tokens.unsqueeze(-1).repeat(1, 1, k).reshape(-1),
                        reduction="none",
                    )
                ).reshape(b, t, k)
                * (1 - repeated_terminals[:, :-1])
            ).sum() / max(n.item(), 1.0)
        b, t, d = qs.shape
        return (
            F.cross_entropy(qs.reshape(-1, d) / self.cql_temp, action_tokens.reshape(-1), reduction="none").reshape(
                b, t
            )
            * (1 - terminals[:, :-1])
        ).sum() / max(n.item(), 1.0)

    def get_dm_loss(self, qs, data_qs, terminals, margin):
        repeated_terminals = terminals.unsqueeze(-1).repeat(1, 1, self.K)
        n = (1 - terminals[:, :-1]).sum()
        if self.double_q:
            q1, q2 = qs
            data_q1, data_q2 = data_qs
            return (
                (
                    (torch.max(q1 - data_q1 + margin, torch.tensor(0.0).to(self.device)) ** 2).sum(dim=-1)
                    * (1 - repeated_terminals[:, :-1])
                )
                + (
                    (torch.max(q2 - data_q2 + margin, torch.tensor(0.0).to(self.device)) ** 2).sum(dim=-1)
                    * (1 - repeated_terminals[:, :-1])
                )
            ).sum() / max(n.item(), 1.0)
        return (
            (torch.max(qs - data_qs.unsqueeze(-1) + margin, torch.tensor(0.0).to(self.device)) ** 2).sum(dim=-1)
            * (1 - terminals[:, :-1])
        ).sum() / max(n.item(), 1.0)

    def soft_update(self):
        self.q[-1].weight.data[: self.n_tokens, :].copy_(
            self.alpha * self.q[-1].weight.data[self.n_tokens : 2 * self.n_tokens, :]
            + (1 - self.alpha) * self.q[-1].weight.data[: self.n_tokens, :]
        )
        self.q[-1].bias.data[: self.n_tokens].copy_(
            self.alpha * self.q[-1].bias.data[self.n_tokens : 2 * self.n_tokens]
            + (1 - self.alpha) * self.q[-1].bias.data[: self.n_tokens]
        )
        if self.double_q:
            self.q2[-1].weight.data[: self.n_tokens, :].copy_(
                self.alpha * self.q2[-1].weight.data[self.n_tokens : 2 * self.n_tokens, :]
                + (1 - self.alpha) * self.q2[-1].weight.data[: self.n_tokens, :]
            )
            self.q2[-1].bias.data[: self.n_tokens].copy_(
                self.alpha * self.q2[-1].bias.data[self.n_tokens : 2 * self.n_tokens]
                + (1 - self.alpha) * self.q2[-1].bias.data[: self.n_tokens]
            )

        if self.K != 1:
            self.v[-1].weight.data[0, :].copy_(
                self.alpha * self.v[-1].weight.data[1, :] + (1 - self.alpha) * self.v[-1].weight.data[0, :]
            )
            self.v[-1].bias.data[0].copy_(
                self.alpha * self.v[-1].bias.data[1] + (1 - self.alpha) * self.v[-1].bias.data[0]
            )

    def hard_update(self):
        pass


class IS_IQL_Policy(IQL_Policy):
    def __init__(self, iql_model: IS_ILQL, kind: str, **generation_kwargs) -> None:
        """Forward all arguments unchanged to the ``IQL_Policy`` constructor."""
        super().__init__(iql_model, kind, **generation_kwargs)

    def beam_raw(
        self,
        tokens: torch.Tensor,
        attn_mask: torch.Tensor,
        state_idxs: torch.Tensor,
        action_idxs: torch.Tensor,
        termination_condition: Callable[[np.ndarray], bool],
        max_generation_len: Optional[int] = None,
        beam_width=1,
        temp=1.0,
        top_k=None,
        top_p=None,
        exp_adv=False,
        adv_weight=0.0,
        adv_clip=None,
        include_logits=True,
        include_adv=True,
        prefix_embs: Optional[torch.Tensor] = None,
        prefix_attn_mask: Optional[torch.Tensor] = None,
        remove_prefix_position_embs: bool = False,
    ):
        """Beam-search continuations using policy logits plus an advantage term.

        The prompt is encoded once to fill the key/value caches; generation
        then proceeds one token at a time with ``beam_width`` beams per prompt.
        At each step the beam score is the log-softmax of
        ``processed policy logits (if include_logits) + advantage logits
        (if include_adv)`` added to the running beam score. Positions still
        inside a prompt are forced to the prompt token, and finished beams are
        forced to emit the pad token.

        Args:
            tokens, attn_mask, state_idxs, action_idxs: the prompt batch.
            termination_condition: called with the decoded text of a beam each
                time that beam emits the end-of-action token; a truthy result
                marks the beam as finished.
            max_generation_len: maximum number of generated tokens per beam;
                defaults to ``max_length + 1`` (i.e. no extra limit).
            beam_width: beams per prompt.
            temp, top_k, top_p: passed to ``process_logits``.
            exp_adv, adv_weight, adv_clip: control the advantage logits.
            prefix_embs, prefix_attn_mask, remove_prefix_position_embs:
                forwarded to the model for the prompt encoding only.

        Returns:
            ``(pairs, scores, neg_logit_scores)`` where ``pairs`` is a list of
            ``(prompt_text, [continuation per beam])``, ``scores`` are the
            final beam scores ``(batch, beam_width)`` and ``neg_logit_scores``
            is the negated sum of the raw policy log-probabilities of the
            selected tokens.
        """
        tokenizer = self.iql_model.dataset.tokenizer
        max_length = self.iql_model.dataset.max_len
        if max_length is None:
            max_length = self.iql_model.model.config.n_positions
        max_length = min(max_length, self.iql_model.model.config.n_positions)
        device = self.iql_model.device
        bsize, vocab_size = tokens.shape[0], tokenizer.num_tokens()
        n = bsize * beam_width
        if max_generation_len is None:
            max_generation_len = max_length + 1
        input_strs = [
            tokenizer.decode(tokens[i, :][: attn_mask[i, :].sum().long()].tolist(), clean_up_tokenization_spaces=False)
            for i in range(len(tokens))
        ]
        prefix_t = 0 if prefix_embs is None else prefix_embs.shape[1]

        # Encode the prompts once to obtain the initial key/value caches.
        model_outputs = self.iql_model(
            tokens,
            attn_mask,
            state_idxs,
            action_idxs,
            prefix_embs=prefix_embs,
            prefix_attn_mask=prefix_attn_mask,
            remove_prefix_position_embs=remove_prefix_position_embs,
            qv_kwargs={"use_cache": True},
            policy_kwargs={"use_cache": True},
            target_kwargs={"use_cache": True},
        )["model_outputs"]
        kvs = {"qv": model_outputs["qv_model_outputs"].past_key_values}
        if self.iql_model.lm_target is not None:
            kvs["target"] = model_outputs["target_model_outputs"].past_key_values
        if self.iql_model.lm_policy is not None:
            kvs["policy"] = model_outputs["policy_model_outputs"].past_key_values
        original_dialogue_lens = attn_mask.sum(dim=1)
        batch_indicator = torch.stack(beam_width * [torch.arange(0, bsize).to(device)], dim=1)

        # Replicate every prompt beam_width times and pad to max_length.
        tokens = pad_sequence(
            torch.repeat_interleave(tokens, beam_width, dim=0), max_length, tokenizer.pad_token_id, device, 1
        )
        dialogue_lens = torch.repeat_interleave(original_dialogue_lens, beam_width, dim=0)

        def widen_cache(x):
            return pad_sequence(torch.repeat_interleave(x, beam_width, dim=0), max_length, 0.0, device, 2)

        for name in list(kvs):
            kvs[name] = map_all_kvs(widen_cache, kvs[name])

        curr_scores = torch.zeros(bsize, beam_width).to(device)  # (batch, k)
        logit_scores = torch.zeros(bsize, beam_width).to(device)  # (batch, k)
        termination_mask = torch.full((n,), 1).to(device)
        # During stepping the model sees a single token, so both index tensors
        # simply point at position 0.
        state_idxs_temp = torch.zeros((dialogue_lens.shape[0], 1)).long().to(device)
        action_idxs_temp = torch.zeros((dialogue_lens.shape[0], 1)).long().to(device)
        t = torch.min(dialogue_lens).int()
        base_logits = torch.full((dialogue_lens.shape[0],), 0.0).to(device)
        rows = torch.arange(0, n).to(device)
        first_step = torch.full((n,), 0).to(device)
        while termination_mask.sum() > 0 and (t + prefix_t) < max_length:
            curr_token = tokens[:, t - 1].unsqueeze(1)
            # Number of cached positions that precede the token fed this step.
            cache_len = (t + prefix_t) - 1

            def trim_cache(x):
                return x[:, :, :cache_len, :]

            curr_kvs = map_all_kvs(trim_cache, kvs["qv"])
            curr_target_kvs, curr_policy_kvs = curr_kvs, curr_kvs
            if "target" in kvs:
                curr_target_kvs = map_all_kvs(trim_cache, kvs["target"])
            if "policy" in kvs:
                curr_policy_kvs = map_all_kvs(trim_cache, kvs["policy"])
            iql_outputs = self.iql_model(
                curr_token,
                None,
                state_idxs_temp,
                action_idxs_temp,
                qv_kwargs={"use_cache": True, "past_key_values": curr_kvs},
                policy_kwargs={"use_cache": True, "past_key_values": curr_policy_kvs},
                target_kwargs={"use_cache": True, "past_key_values": curr_target_kvs},
            )
            model_outputs, logits = iql_outputs["model_outputs"], iql_outputs["logits"]

            # Finished beams may only emit pad; live beams may never emit pad.
            logits[:, 0, tokenizer.pad_token_id] = torch.where(termination_mask == 1, float("-inf"), 1e7)
            # Beams still inside their prompt are forced onto the prompt token.
            logits[rows, first_step, tokens[:, t]] = logits[rows, first_step, tokens[:, t]].masked_fill_(
                t < dialogue_lens, 1e7
            )
            edited_logits = process_logits(logits.clone(), temp=temp, top_k=top_k, top_p=top_p)

            vs, qs = iql_outputs["target_vs"], iql_outputs["target_qs"]
            if exp_adv:
                # Average the advantage over the K heads.
                adv_logits = adv_weight * (qs - vs).mean(2)
            else:
                # NOTE(review): this branch still assumes values without a head
                # axis. With the (n, 1, K, 1) values this model returns,
                # `vs.unsqueeze(2)` no longer lines up with `qs`; how the sign
                # advantage should be reduced over heads needs to be decided
                # before exp_adv=False can be used here.
                adv_sign = ((qs - vs.unsqueeze(2)) > 0.0).float()
                adv_logits = adv_weight * adv_sign + (1 - adv_weight) * (1 - adv_sign)
                adv_logits = torch.log(adv_logits)
            if adv_clip is not None:
                adv_logits = torch.clip(adv_logits, max=adv_clip)
            adv_logits[:, 0, tokenizer.pad_token_id] = torch.where(termination_mask == 1, float("-inf"), 1e7)
            adv_logits[rows, first_step, tokens[:, t]] = adv_logits[rows, first_step, tokens[:, t]].masked_fill_(
                t < dialogue_lens, 1e7
            )

            full_logits = (
                (edited_logits if include_logits else 0.0)
                + (adv_logits if include_adv else 0.0)
                + base_logits.unsqueeze(1).unsqueeze(2)
            )

            scores = (
                (
                    torch.log(F.softmax(full_logits, dim=-1)).reshape(1, bsize, beam_width, -1).permute(3, 0, 1, 2)
                    + curr_scores
                )
                .permute(1, 2, 3, 0)
                .reshape(1, bsize, -1)
            )  # (time, batch, k*vocab)
            # At the first generated step all beams of a prompt are identical,
            # so only the first beam's candidates are kept.
            scores[0, :, vocab_size:] = scores[0, :, vocab_size:].masked_fill_(
                (t == original_dialogue_lens).unsqueeze(1).repeat(1, scores.shape[2] - vocab_size), float("-inf")
            )
            curr_scores, top_k_ = torch.topk(scores[0, :, :], k=beam_width, dim=1)  # (batch, k), (batch, k)

            # Flat row (in the batch * k layout) each surviving beam came from,
            # and the token it appends. Computed once and reused below.
            beam_src = (
                batch_indicator * beam_width + torch.div(top_k_, vocab_size, rounding_mode="trunc")
            ).reshape(-1)
            next_tokens = top_k_.reshape(-1) % vocab_size  # (batch*k,)

            tokens = tokens[beam_src, :]
            logits = logits[beam_src, :, :]
            logit_scores += (
                torch.gather(
                    torch.log(F.softmax(logits, dim=-1)).squeeze(1),
                    dim=1,
                    index=next_tokens.unsqueeze(1),
                )
                .squeeze(1)
                .reshape(-1, beam_width)
            )
            tokens[:, t] = next_tokens

            def reorder(x):
                return x[beam_src, :, :, :]

            # Reorder the stored caches to follow the surviving beams and write
            # in the entry produced by this step.
            for name in list(kvs):
                fixed_kvs = map_all_kvs(reorder, model_outputs[name + "_model_outputs"].past_key_values)
                kvs[name] = update_kvs(map_all_kvs(reorder, kvs[name]), fixed_kvs, rows, cache_len)

            termination_mask = termination_mask[beam_src]
            for idx in range(n):
                if tokens[idx, t] == tokenizer.eoa_token_id and t >= dialogue_lens[idx]:
                    termination_mask[idx] *= 1 - int(
                        termination_condition(
                            tokenizer.decode(tokens[idx, :].tolist(), clean_up_tokenization_spaces=False)
                        )
                    )
            t += 1
            termination_mask *= ((t - dialogue_lens) < max_generation_len).int()

        output_strs = [tokenizer.decode(tokens[i, :].tolist(), clean_up_tokenization_spaces=False) for i in range(n)]
        pad_str = tokenizer.id_to_token(tokenizer.pad_token_id)
        eoa_str = tokenizer.id_to_token(tokenizer.eoa_token_id)
        processed_outputs = []
        for i, input_str in enumerate(input_strs):
            temp_outputs = []
            for x in range(beam_width):
                # Strip the prompt, then cut at the first pad / end-of-action marker.
                processed_str = output_strs[i * beam_width + x][len(input_str) :].strip()
                if pad_str in processed_str:
                    processed_str = processed_str[: processed_str.find(pad_str)].strip()
                if eoa_str in processed_str:
                    processed_str = processed_str[: processed_str.find(eoa_str)].strip()
                temp_outputs.append(processed_str)
            processed_outputs.append(temp_outputs)
        return list(zip(input_strs, processed_outputs)), curr_scores, -logit_scores


class IS_IQL_Evaluator(IQL_Evaluator):
    def __init__(self, env: Language_Environment, verbose: bool, kind: str, **generation_kwargs) -> None:
        """Forward all arguments unchanged to the ``IQL_Evaluator`` constructor."""
        super().__init__(env, verbose, kind, **generation_kwargs)

    def evaluate(self, model: IS_ILQL, items: InputType) -> Optional[Dict[str, Any]]:
        """Roll out one environment episode per item and report mean rewards.

        ``items`` only determines how many episodes are run (the batch size of
        ``model.prepare_inputs(items)["tokens"]``); each episode is produced by
        ``interact_environment`` with a fresh ``IS_IQL_Policy``. Every
        ``(result, sequence)`` pair is appended to ``self.all_results`` and the
        policy's ``logprobs_all`` is appended to ``self.all_entropy``.

        Returns:
            ``{"token_reward": (mean, n), "env_reward": (mean, n)}`` where ``n``
            is the number of episodes.
        """
        policy = IS_IQL_Policy(model, self.kind, **self.generation_kwargs)
        n_episodes = model.prepare_inputs(items)["tokens"].shape[0]
        total_token_reward = 0
        total_env_reward = 0
        for i in range(n_episodes):
            result, sequence = interact_environment(self.env, policy, None)
            self.all_results.append((result, sequence))
            # The third field of every step in the sequence is its reward.
            env_reward = sum(step[2] for step in sequence)
            token_reward = sum(DataPoint.get_token_reward(result, model.dataset.tokenizer, model.dataset.token_reward))
            total_env_reward += env_reward
            total_token_reward += token_reward
            if self.verbose:
                print(result)
                print("=" * 25)
                print("token reward:", token_reward)
                print("env reward:", env_reward)
                print("avg token reward:", total_token_reward / (i + 1))
                print("avg env reward:", total_env_reward / (i + 1))
                print("=" * 25)
        self.all_entropy.extend(policy.logprobs_all)
        return {
            "token_reward": (total_token_reward / n_episodes, n_episodes),
            "env_reward": (total_env_reward / n_episodes, n_episodes),
        }
