from collections import defaultdict
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Callable, Dict, Tuple, Union, Optional
import wandb
from models.gpt2_optional_final_ln import GPT2LMHeadModel, GPT2Model
from data.rl_data import DataPoint, List_RL_Dataset, RL_Dataset
from utils.torch_utils import get_transformer_logs
import copy
from models.base import BaseTransformer, Evaluator, InputType
from transformers.modeling_utils import PreTrainedModel
from utils.sampling_utils import *
import numpy as np
import math
from data.language_environment import Language_Environment, Language_Observation, interact_environment, Policy
from tqdm.auto import tqdm

from models.iql_model import PerTokenIQL, TransformerMLP, IQL_Policy, IQL_Evaluator, TopAdvantageNGrams


class TFILQL(PerTokenIQL):
    """PerTokenIQL variant that keeps no target networks.

    After the base constructor runs, the target Q heads (``target_q`` /
    ``target_q2``) and any separate target language model are deleted.
    ``forward`` therefore produces its target Q values as detached copies of
    the online Q heads. ``soft_update`` / ``hard_update`` are no-ops.
    Only GPT-2 backbones (``GPT2Model`` or ``GPT2LMHeadModel``) are supported.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        dataset: RL_Dataset,
        device: Union[torch.device, str] = "cuda",
        alpha: float = 0.005,
        gamma=1.0,
        beta=1.0,
        transition_weight=0.0,
        clip_weight: Optional[float] = None,
        value_max: Optional[float] = None,
        value_min: Optional[float] = None,
        detach_v: bool = False,
        detach_pi: bool = False,
        detach_q: bool = False,
        double_q: bool = False,
        tau: float = 0.9,
        seperate_policy: bool = False,
        seperate_target: bool = False,
        exp_weights: bool = False,
        dm_margin: float = 0.0,
        advanced_mlp: bool = False,
        cql_temp: float = 1.0,
    ):
        """Build the PerTokenIQL model, then drop every target network it created.

        All arguments are forwarded unchanged to ``PerTokenIQL.__init__``.
        ``seperate_target`` should be False: any target LM the base class
        builds is deleted right afterwards, so enabling it only costs memory
        during construction.
        """
        assert isinstance(model, GPT2Model) or isinstance(model, GPT2LMHeadModel)
        super().__init__(
            model=model,
            dataset=dataset,
            device=device,
            alpha=alpha,
            gamma=gamma,
            beta=beta,
            transition_weight=transition_weight,
            clip_weight=clip_weight,
            value_max=value_max,
            value_min=value_min,
            detach_v=detach_v,
            detach_pi=detach_pi,
            detach_q=detach_q,
            double_q=double_q,
            tau=tau,
            seperate_policy=seperate_policy,
            seperate_target=seperate_target,  # should be False (deleted below anyway)
            exp_weights=exp_weights,
            dm_margin=dm_margin,
            advanced_mlp=advanced_mlp,
            cql_temp=cql_temp,
        )

        # Target heads / target LM are never used by this class; free them.
        if hasattr(self, "target_q"):
            del self.target_q
        if hasattr(self, "target_q2"):
            del self.target_q2
        if self.lm_target is not None:
            # Remove the registered submodule first, then keep the attribute as None
            # so `self.lm_target is None` checks elsewhere keep working.
            del self.lm_target
            self.lm_target = None

    @staticmethod
    def _tf_prefix_position_ids(position_ids: Optional[torch.Tensor], prefix_embs: torch.Tensor) -> torch.Tensor:
        """Return the position ids occupied by the prefix embeddings."""
        prefix_t = prefix_embs.shape[1]
        if position_ids is not None:
            return position_ids[:, :prefix_t]
        # No explicit position ids are passed to the backbone in this case. This
        # assumes its defaults start at 0 (i.e. no past_key_values in the kwargs —
        # TODO confirm), so the prefix sits at positions 0..prefix_t-1 in every row.
        return torch.arange(prefix_t, device=prefix_embs.device).unsqueeze(0)

    @staticmethod
    def _tf_embed(
        backbone,
        prefix_embs: torch.Tensor,
        tokens: torch.Tensor,
        prefix_position_ids: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Concatenate the prefix embeddings with ``backbone``'s token embeddings.

        When ``prefix_position_ids`` is given, the backbone's position embedding
        is subtracted from the prefix first. The subtraction is out-of-place, so
        the caller's ``prefix_embs`` is never modified.
        """
        if prefix_position_ids is not None:
            prefix_embs = prefix_embs - backbone.wpe(prefix_position_ids)
        return torch.cat((prefix_embs, backbone.wte(tokens)), dim=1)

    @staticmethod
    def _tf_token_hidden(outputs, prefix_t: int, penultimate: bool) -> torch.Tensor:
        """Select the last (or second-to-last) hidden layer with the prefix positions dropped."""
        layer = -2 if penultimate else -1
        return outputs.hidden_states[layer][:, prefix_t:, :]

    def _tf_gather(self, hidden_states: torch.Tensor, idxs: torch.Tensor) -> torch.Tensor:
        """Gather ``hidden_states[b, idxs[b, i], :]`` for every row ``b`` and index ``i``."""
        # expand (not repeat): gather only reads the index, so no copy is needed.
        index = idxs.unsqueeze(2).expand(-1, -1, self.h_dim)
        return torch.gather(input=hidden_states, dim=1, index=index)

    @staticmethod
    def _tf_run(module, no_grad: bool, *args, **kwargs):
        """Call ``module(*args, **kwargs)``, under ``torch.no_grad()`` when ``no_grad`` is True."""
        if no_grad:
            with torch.no_grad():
                return module(*args, **kwargs)
        return module(*args, **kwargs)

    def forward(
        self,
        tokens: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        state_idxs: torch.Tensor,
        action_idxs: torch.Tensor,
        prefix_embs: Optional[torch.Tensor] = None,
        prefix_attn_mask: Optional[torch.Tensor] = None,
        remove_prefix_position_embs: bool = False,
        qv_kwargs=None,
        policy_kwargs=None,
        target_kwargs=None,
        skip_policy_on_train=False,
        detach_full_policy=False,
    ):
        """Run the backbone(s) and compute V values, Q values and policy logits.

        Args:
            tokens: token ids fed to every backbone (presumably (batch, seq) — TODO confirm).
            attn_mask: mask for ``tokens``; an all-ones long mask is used when None.
            state_idxs: per-row positions in the token part (prefix removed) fed to ``v``.
            action_idxs: per-row positions in the token part fed to ``q`` (and ``q2``).
            prefix_embs: optional embeddings prepended before the token embeddings.
                An empty prefix is used when None. Never modified in place.
            prefix_attn_mask: mask for the prefix. When given, explicit position
                ids ``cumsum(mask) - 1`` are passed to every backbone call.
            remove_prefix_position_embs: subtract each backbone's ``wpe`` position
                embedding from the prefix embeddings before use.
            qv_kwargs: extra kwargs for the main model call. When no separate
                target/policy LM exists, ``target_kwargs`` / ``policy_kwargs`` are
                merged into a copy of it. The caller's dict is left untouched.
            policy_kwargs: extra kwargs for the separate policy LM call.
            target_kwargs: extra kwargs for the separate target LM call.
            skip_policy_on_train: in training mode with a separate policy LM,
                skip the policy forward pass and return all-zero logits.
            detach_full_policy: run the policy LM and the ``pi`` head under no_grad.

        Returns:
            A dict with these keys:

            - ``model_outputs``: raw backbone outputs.
            - ``vs`` and ``target_vs``: the same tensor.
            - ``qs``: ``(qs, qs2)`` when ``double_q``, otherwise ``qs``.
            - ``target_qs``: a detached copy of the online Q values (element-wise
              minimum of both heads when ``double_q``), passed through ``clip_values``.
            - ``logits``: policy logits.
        """
        # Copy so merging policy/target kwargs never mutates the caller's dict.
        qv_kwargs = {} if qv_kwargs is None else dict(qv_kwargs)
        target_kwargs = {} if target_kwargs is None else target_kwargs
        policy_kwargs = {} if policy_kwargs is None else policy_kwargs
        if self.lm_target is None:
            qv_kwargs.update(target_kwargs)
        if self.lm_policy is None:
            qv_kwargs.update(policy_kwargs)

        if attn_mask is None:
            attn_mask = torch.ones(tokens.shape, dtype=torch.long).to(self.device)
        if prefix_embs is None:
            prefix_embs = torch.empty((tokens.shape[0], 0, self.h_dim)).to(self.device)
        prefix_t = prefix_embs.shape[1]
        set_pos_ids = prefix_attn_mask is not None
        if prefix_attn_mask is None:
            prefix_attn_mask = torch.ones(prefix_embs.shape[:2]).to(self.device)
        input_attn_mask = torch.cat((prefix_attn_mask, attn_mask), dim=1)
        position_ids = torch.cumsum(input_attn_mask, dim=1) - 1 if set_pos_ids else None
        prefix_position_ids = (
            self._tf_prefix_position_ids(position_ids, prefix_embs) if remove_prefix_position_embs else None
        )

        # Locate the module that owns wte/wpe for each model.
        if isinstance(self.model, GPT2Model):
            transformer = self.model
            target_transformer = self.lm_target
            policy_transformer = self.lm_policy
        elif isinstance(self.model, GPT2LMHeadModel):
            transformer = self.model.transformer
            target_transformer = None if self.lm_target is None else self.lm_target.transformer
            policy_transformer = None if self.lm_policy is None else self.lm_policy.transformer
        else:
            raise NotImplementedError
        is_base_model = isinstance(self.model, GPT2Model)

        shared_call_kwargs = dict(
            attention_mask=input_attn_mask,
            position_ids=position_ids,
            output_hidden_states=True,
        )

        model_outputs = self.model(
            inputs_embeds=self._tf_embed(transformer, prefix_embs, tokens, prefix_position_ids),
            **shared_call_kwargs,
            **qv_kwargs,
        )
        all_model_outputs = {
            "qv_model_outputs": model_outputs,
            "policy_model_outputs": model_outputs,
            "target_model_outputs": model_outputs,
        }
        hidden_states = self._tf_token_hidden(model_outputs, prefix_t, self.advanced_mlp)

        if self.lm_target is not None:
            # Only reported in model_outputs: target Q values below come from the
            # online Q heads, not from the target LM.
            all_model_outputs["target_model_outputs"] = self._tf_run(
                self.lm_target,
                True,
                inputs_embeds=self._tf_embed(target_transformer, prefix_embs, tokens, prefix_position_ids),
                **shared_call_kwargs,
                **target_kwargs,
            )

        # The penultimate layer is only used for the policy with a bare GPT2Model backbone.
        policy_penultimate = self.advanced_mlp and is_base_model
        skip_policy = skip_policy_on_train and self.training and self.lm_policy is not None
        if self.lm_policy is None:
            policy_hidden_states = self._tf_token_hidden(model_outputs, prefix_t, policy_penultimate)
        elif skip_policy:
            # Only used for its shape when building the zero logits below.
            policy_hidden_states = hidden_states
        else:
            policy_outputs = self._tf_run(
                self.lm_policy,
                detach_full_policy,
                inputs_embeds=self._tf_embed(policy_transformer, prefix_embs, tokens, prefix_position_ids),
                **shared_call_kwargs,
                **policy_kwargs,
            )
            all_model_outputs["policy_model_outputs"] = policy_outputs
            policy_hidden_states = self._tf_token_hidden(policy_outputs, prefix_t, policy_penultimate)

        state_hidden_states = self._tf_gather(hidden_states, state_idxs)
        action_hidden_states = self._tf_gather(hidden_states, action_idxs)
        v_input = state_hidden_states.detach() if self.detach_v else state_hidden_states
        q_input = action_hidden_states.detach() if self.detach_q else action_hidden_states
        vs = self.v(v_input).squeeze(2)
        qs = self.q(q_input)
        qs2 = self.q2(q_input) if self.double_q else None

        # No target networks: targets are detached copies of the online Q values.
        target_qs = qs.detach().clone()
        if self.double_q:
            target_qs = torch.minimum(target_qs, qs2.detach().clone())

        if skip_policy:
            logits = torch.zeros(
                (
                    policy_hidden_states.shape[0],
                    policy_hidden_states.shape[1],
                    self.dataset.tokenizer.num_tokens(),
                )
            ).to(self.device)
        else:
            pi_input = policy_hidden_states.detach() if self.detach_pi else policy_hidden_states
            logits = self._tf_run(self.pi, detach_full_policy, pi_input)

        return {
            "model_outputs": all_model_outputs,
            "vs": vs,
            "target_vs": vs,
            "qs": (qs, qs2) if self.double_q else qs,
            "target_qs": self.clip_values(target_qs),
            "logits": logits,
        }

    def soft_update(self):
        """No-op: this model keeps no target networks to update."""
        pass

    def hard_update(self):
        """No-op: this model keeps no target networks to copy into."""
        pass
