import torch
import torch.nn as nn
from typing import Any, Callable, Dict, Tuple, Union, Optional
from data.language_environment import Language_Environment, interact_environment
from data.rl_data import DataPoint, RL_Dataset
from models.iql_model import IQL_Policy, PerTokenIQL, TransformerMLP
from models.base import Evaluator, InputType
from transformers.modeling_utils import PreTrainedModel
from utils.sampling_utils import *
from utils.torch_utils import get_transformer_logs
import wandb
import math

class PerUtteranceIQL(PerTokenIQL):
    """IQL variant that assigns values, Q-values and rewards per utterance.

    Reuses the ``PerTokenIQL`` machinery, but:
      * ``prepare_inputs`` swaps the token-level ``state_idxs`` / ``action_idxs`` /
        ``terminals`` / ``rewards`` for their utterance-level (``u_*``) counterparts;
      * the Q heads (and target Q heads) are replaced with scalar-output MLPs;
      * AWAC token weights are broadcast over every token of an utterance.
    """

    def __init__(self,
                 model: PreTrainedModel,
                 dataset: RL_Dataset,
                 device: Union[torch.device, str] = "cuda",
                 alpha: float = 0.005,
                 gamma=1.0,
                 beta=1.0,
                 transition_weight=0.0,
                 clip_weight: Optional[float] = None,
                 value_max: Optional[float] = None,
                 value_min: Optional[float] = None,
                 detach_v: bool = False,
                 detach_pi: bool = False,
                 detach_q: bool = False,
                 double_q: bool = False,
                 tau: float = 0.9,
                 seperate_policy: bool = False,
                 seperate_target: bool = False,
                 exp_weights: bool = False,
                 advanced_mlp: bool = False,
                ):
        # Positional call into PerTokenIQL. The literals 0.0 and 1.0 fill parent
        # parameters this subclass does not expose.
        # NOTE(review): confirm which PerTokenIQL parameters they correspond to.
        super(PerUtteranceIQL, self).__init__(model, dataset, device, alpha,
                                              gamma, beta, transition_weight, clip_weight,
                                              value_max, value_min, detach_v,
                                              detach_pi, detach_q, double_q, tau,
                                              seperate_policy, seperate_target,
                                              exp_weights, 0.0, advanced_mlp, 1.0)
        # Replace the parent's Q heads with scalar-output heads. The order
        # q -> q2 -> target_q -> target_q2 is deliberate: it fixes the order in
        # which parameter initialisation draws from the RNG.
        self.q = self._make_utterance_q_head()
        if self.double_q:
            self.q2 = self._make_utterance_q_head()
        self.target_q = self._make_utterance_q_head()
        if self.double_q:
            self.target_q2 = self._make_utterance_q_head()
        # Start each target head as an exact copy of its online head.
        for target_param, local_param in zip(self.target_q.parameters(), self.q.parameters()):
            target_param.data.copy_(local_param.data)
        if self.double_q:
            for target_param, local_param in zip(self.target_q2.parameters(), self.q2.parameters()):
                target_param.data.copy_(local_param.data)

    def _make_utterance_q_head(self) -> nn.Module:
        """Build a fresh head mapping an ``h_dim`` hidden state to a single scalar."""
        if not self.advanced_mlp:
            return nn.Sequential(
                nn.Linear(self.h_dim, self.h_dim*2),
                nn.ReLU(),
                nn.Linear(self.h_dim*2, 1),
            )
        # Match the transformer's feed-forward width; default to 4 * h_dim.
        inner_dim = 4 * self.h_dim if self.model.config.n_inner is None else self.model.config.n_inner
        return TransformerMLP(self.h_dim, inner_dim, 1, self.model.config.resid_pdrop)

    def prepare_inputs(self, items: InputType):
        """Prepare inputs as the parent does, then use utterance-level indices/rewards.

        The generic keys ``state_idxs``, ``action_idxs``, ``terminals`` and
        ``rewards`` are overwritten with the parent's ``u_*`` entries. Downstream
        code that reads the generic keys therefore operates per utterance.
        """
        data = super().prepare_inputs(items)
        data['state_idxs'], data['action_idxs'] = data['u_state_idxs'], data['u_action_idxs']
        data['terminals'], data['rewards'] = data['u_terminals'], data['u_rewards']
        return data

    def get_weights(self,
                    tokens: torch.Tensor,
                    vs: torch.Tensor,
                    qs: Optional[torch.Tensor],
                    state_idxs: torch.Tensor,
                    action_idxs: torch.Tensor,
                    terminals: torch.Tensor):
        """Compute per-token AWAC weights from per-utterance advantages.

        Args:
            tokens: token ids; only the shape is used, to size the output.
            vs: per-utterance values.
            qs: per-utterance Q-values. Used unconditionally, so it must not be
                ``None`` despite the annotation.
            state_idxs / action_idxs: for utterance ``x`` of sample ``i``, the
                tokens in ``[state_idxs[i, x], action_idxs[i, x])`` receive that
                utterance's weight.
            terminals: unused here; kept for signature compatibility.

        Returns:
            A tensor shaped like ``tokens``. Positions outside any utterance hold
            ``self.transition_weight``. If ``clip_weight`` is set, values are
            clipped from above.
        """
        weights = torch.full(tokens.shape, self.transition_weight).to(self.device)
        if self.exp_weights:
            w_values = torch.exp(self.beta * (qs - vs))
        else:
            # Binary advantage weighting: beta where Q > V, (1 - beta) otherwise.
            adv_sign = ((qs - vs) > 0.0).float()
            w_values = self.beta * adv_sign + (1 - self.beta) * (1 - adv_sign)
        # Number of utterances to fill per sample: 1 + position of the largest
        # action index. NOTE(review): assumes padded utterance slots never hold a
        # larger action index than the last real utterance — confirm.
        n = torch.argmax(action_idxs, dim=1)+1
        for i in range(tokens.shape[0]):
            for x in range(n[i].item()):
                weights[i] = torch.scatter(weights[i], dim=0,
                                           index=torch.arange(state_idxs[i, x], action_idxs[i, x]).to(self.device),
                                           src=w_values[i, x].repeat(action_idxs[i, x]-state_idxs[i, x]))
        if self.clip_weight is not None:
            weights = torch.clip(weights, max=self.clip_weight)
        return weights

    def get_qvs(self, items: InputType,
                prefix_embs: Optional[torch.Tensor]=None,
                prefix_attn_mask: Optional[torch.Tensor]=None,
                remove_prefix_position_embs: bool=False,
                qv_kwargs=None, policy_kwargs=None, target_kwargs=None,
                **kwargs):
        """Run the model and collect per-utterance values, Q-values and AWAC weights.

        Returns a dict with keys ``tokens``, ``attn_mask``, ``model_outputs``,
        ``vs`` (values at each utterance), ``qs`` (a tuple ``(q1, q2)`` when
        ``double_q``), ``vns`` (values shifted one utterance ahead), ``target_vs``,
        ``target_qs``, ``target_vns``, ``rs``, ``terminals``, ``logits`` and
        ``weights``. ``target_vs`` / ``target_vns`` are the same tensors as
        ``vs`` / ``vns``: this class has no separate target value head.
        """
        prepared_inputs = self.prepare_inputs(items)
        tokens, attn_mask = prepared_inputs['tokens'], prepared_inputs['attn_mask']
        s_idx, a_idx = prepared_inputs['state_idxs'], prepared_inputs['action_idxs']
        rs, terminals = prepared_inputs['rewards'], prepared_inputs['terminals']
        self_outputs = self(tokens, attn_mask, s_idx, a_idx,
                            prefix_embs, prefix_attn_mask,
                            remove_prefix_position_embs, qv_kwargs,
                            policy_kwargs, target_kwargs,
                            **kwargs)
        model_outputs, vs, qs = self_outputs['model_outputs'], self_outputs['vs'], self_outputs['qs']
        target_qs, logits = self_outputs['target_qs'], self_outputs['logits']
        # V(s_t) and V(s_{t+1}) along the utterance axis.
        vt = vs[:, :-1]
        vtp1 = vs[:, 1:]
        # Drop the trailing singleton output dimension of the scalar Q heads.
        if self.double_q:
            q1, q2 = qs
            q1, q2 = q1.squeeze(2), q2.squeeze(2)
            qs = (q1, q2,)
        else:
            qs = qs.squeeze(2)
        target_qs = target_qs.squeeze(2)
        with torch.no_grad():
            weights = self.get_weights(tokens, vt, target_qs, s_idx, a_idx, terminals)
        return {
                    'tokens': tokens,
                    'attn_mask': attn_mask,
                    'model_outputs': model_outputs,
                    'vs': vt,
                    'qs': qs,
                    'vns': vtp1,
                    'target_vs': vt,
                    'target_qs': target_qs,
                    'target_vns': vtp1,
                    'rs': rs,
                    'terminals': terminals,
                    'logits': logits,
                    'weights': weights,
                }

    def get_loss(self,
                 items: InputType,
                 awac_weight=0.0,
                 v_loss_weight=0.0,
                 q_loss_weight=0.0,
                 mc_returns=False):
        """Compute the weighted IQL training loss and logging statistics.

        The loss is
        ``awac_weight * token_loss + v_loss_weight * v_loss + q_loss_weight * q_loss``.
        If ``mc_returns`` is true, the value loss regresses onto the downstream
        rewards from ``get_downstream_rs``. Otherwise it regresses onto the
        target Q-values.

        Returns:
            ``(loss, logs, [postproc_f, hist_f])``. Each post-processing callable
            mutates an aggregated logs dict in place: ``postproc_f`` recomputes
            the combined ``'loss'`` entry, and ``hist_f`` attaches a
            ``wandb.Histogram`` of advantages.
        """
        prepared_inputs = self.prepare_inputs(items)
        a_idx = prepared_inputs['action_idxs']

        get_qvs_outputs = self.get_qvs(items,
                                       qv_kwargs={'output_attentions': True},
                                       policy_kwargs={'output_attentions': True},
                                       target_kwargs={'output_attentions': True},
                                       skip_policy_on_train=(awac_weight == 0.0),
                                      )
        tokens, attn_mask, model_outputs = get_qvs_outputs['tokens'], get_qvs_outputs['attn_mask'], get_qvs_outputs['model_outputs']
        vs, qs = get_qvs_outputs['vs'], get_qvs_outputs['qs']
        vns, target_qs, rs = get_qvs_outputs['vns'], get_qvs_outputs['target_qs'], get_qvs_outputs['rs']
        terminals, logits, weights = get_qvs_outputs['terminals'], get_qvs_outputs['logits'], get_qvs_outputs['weights']

        logs = {}
        transformer_logs = {}
        transformer_logs['qv_transformer_logs'] = get_transformer_logs(model_outputs['qv_model_outputs'].attentions, self.model, attn_mask)
        # The policy forward pass is skipped during training when awac_weight == 0.
        if self.lm_policy is not None and (not (self.training and awac_weight == 0.0)):
            transformer_logs['policy_transformer_logs'] = get_transformer_logs(model_outputs['policy_model_outputs'].attentions, self.lm_policy, attn_mask)
        if self.lm_target is not None:
            transformer_logs['target_transformer_logs'] = get_transformer_logs(model_outputs['target_model_outputs'].attentions, self.lm_target, attn_mask)
        # Number of non-terminal utterance positions; used to normalise averages.
        n = (1 - terminals[:, :-1]).sum().item()
        rs_downstream = self.get_downstream_rs(rs, self.gamma)
        if mc_returns:
            v_loss = self.get_v_loss(vs, rs_downstream, terminals)
        else:
            v_loss = self.get_v_loss(vs, target_qs, terminals)
        q_loss = self.get_q_loss(vns, qs, rs, self.gamma, terminals)
        token_loss = self.awac_loss(tokens, attn_mask, logits, weights)
        logs['token_loss'] = (token_loss.item(), n)
        loss = awac_weight * token_loss + v_loss_weight * v_loss + q_loss_weight * q_loss
        logs['v_loss'] = (v_loss.item(), n)
        logs['q_loss'] = (q_loss.item(), n)
        # Flatten the advantages of every non-terminal utterance in the batch.
        # Built with extend: sum(list_of_lists, []) is quadratic.
        advantages = []
        for i in range(tokens.shape[0]):
            n_valid = (1 - terminals[i, :-1]).sum().long().item()
            advantages.extend((target_qs[i] - vs[i])[:n_valid].detach().cpu().tolist())
        if self.double_q:
            q1, q2 = qs
            logs['q1_avg'] = ((q1 * (1 - terminals[:, :-1])).sum().item() / max(n, 1), n)
            logs['q1_var'] = (((((q1 - logs['q1_avg'][0]) ** 2)*(1 - terminals[:, :-1])).sum() / max(n, 1)).item(), 1)
            logs['q2_avg'] = ((q2 * (1 - terminals[:, :-1])).sum().item() / max(n, 1), n)
            logs['q2_var'] = (((((q2 - logs['q2_avg'][0]) ** 2)*(1 - terminals[:, :-1])).sum() / max(n, 1)).item(), 1)
        else:
            logs['q_avg'] = ((qs * (1 - terminals[:, :-1])).sum().item() / max(n, 1), n)
            logs['q_var'] = (((((qs - logs['q_avg'][0]) ** 2)*(1 - terminals[:, :-1])).sum() / max(n, 1)).item(), 1)
        logs['v_avg'] = ((vs * (1 - terminals[:, :-1])).sum().item() / max(n, 1), n)
        logs['v_var'] = (((((vs - logs['v_avg'][0]) ** 2)*(1 - terminals[:, :-1])).sum() / max(n, 1)).item(), 1)
        # Weight at the token just before each utterance's action index (clamped at 0).
        act_weights = torch.gather(weights, dim=1, index=torch.maximum(a_idx-1, torch.tensor(0).to(self.device)))
        logs['act_weight_avg'] = (((act_weights * (1 - terminals[:, :-1])).sum() / max(n, 1)).item(), n)
        logs['transformer'] = transformer_logs
        postproc_f = lambda l: l.update({'loss': awac_weight * l['token_loss'] + q_loss_weight * l['q_loss'] + v_loss_weight * l['v_loss']})
        hist_f = lambda l: l.update({'advantage_hist': wandb.Histogram(advantages)})
        return loss, logs, [postproc_f, hist_f]

class PerUtteranceIQL_Policy(IQL_Policy):
    """Policy that samples several candidate utterances and reranks them by Q-value.

    Only ``kind='rerank'`` is supported. For each input dialogue,
    ``num_generations`` continuations are sampled from the policy LM. Each is
    scored by the model's target Q-value at its last non-pad token, plus
    ``log_prob_weight`` times its sampling log-probability. Candidates are
    returned best-first.
    """

    def __init__(self, iql_model: PerTokenIQL,
                 kind: str, **generation_kwargs) -> None:
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``. Validating first also avoids any parent init work.
        if kind not in {'rerank'}:
            raise ValueError(f"unsupported kind {kind!r}; only 'rerank' is supported")
        # The parent is always configured for 'sample'; reranking is layered on top.
        super().__init__(iql_model, 'sample', **generation_kwargs)
        self.kind = kind

    def rerank_raw(self,
                   tokens: torch.Tensor, attn_mask: torch.Tensor,
                   state_idxs: torch.Tensor, action_idxs: torch.Tensor,
                   termination_condition: Callable[[str], bool],
                   num_generations=1, max_generation_len=None,
                   temp=1.0, top_k=None, top_p=None,
                   log_prob_weight=0.0,
                   prefix_embs: Optional[torch.Tensor]=None,
                   prefix_attn_mask: Optional[torch.Tensor]=None,
                   remove_prefix_position_embs: bool=False):
        """Sample ``num_generations`` continuations per input and rerank them.

        Args:
            tokens / attn_mask: the batch of prompt token ids and attention mask.
            state_idxs / action_idxs: indices passed to the model for the prompt pass.
            termination_condition: called on the full decoded string whenever an
                end-of-action token is sampled after the prompt; a truthy result
                stops generation for that sequence.
            num_generations: candidates sampled per input.
            max_generation_len: maximum number of tokens generated past the prompt
                (defaults to ``max_length + 1``).
            temp / top_k / top_p: forwarded to ``process_logits``.
            log_prob_weight: weight of the summed sampling log-prob in the score.

        Returns:
            ``(pairs, log_probs, kls)``.
              * ``pairs`` is a list of ``(input_str, [candidate strings best-first])``.
              * ``log_probs`` and ``kls`` are reshaped to
                ``(batch, num_generations)`` and permuted into the same order.
        """
        tokenizer = self.iql_model.dataset.tokenizer
        max_length = self.iql_model.dataset.max_len
        if max_length is None:
            max_length = self.iql_model.model.config.n_positions
        max_length = min(max_length, self.iql_model.model.config.n_positions)
        device = self.iql_model.device
        bsize = tokens.shape[0]
        n = bsize * num_generations
        if max_generation_len is None:
            max_generation_len = max_length+1
        input_strs = [tokenizer.decode(tokens[i, :][:attn_mask[i, :].sum().long()].tolist(), clean_up_tokenization_spaces=False) for i in range(len(tokens))]
        prefix_t = 0 if prefix_embs is None else prefix_embs.shape[1]
        # Encode the prompts once; their KV cache is reused for every candidate.
        model_outputs = self.iql_model(tokens, attn_mask,
                                       state_idxs, action_idxs,
                                       prefix_embs=prefix_embs,
                                       prefix_attn_mask=prefix_attn_mask,
                                       remove_prefix_position_embs=remove_prefix_position_embs,
                                       policy_kwargs={'use_cache': True})['model_outputs']['policy_model_outputs']
        dialogue_kvs = model_outputs.past_key_values
        dialogue_lens = attn_mask.sum(dim=1)
        # Replicate each prompt (and its cache) num_generations times, padded to max_length.
        tokens = pad_sequence(torch.repeat_interleave(tokens, num_generations, dim=0), max_length, tokenizer.pad_token_id, device, 1)
        dialogue_lens = torch.repeat_interleave(dialogue_lens, num_generations, dim=0)
        dialogue_kvs = map_all_kvs(lambda x: pad_sequence(torch.repeat_interleave(x, num_generations, dim=0), max_length, 0.0, device, 2), dialogue_kvs)
        log_probs = torch.full((dialogue_lens.shape[0],), 0.0).to(device)
        # Constant per-sample value log(N) - (N - 1) / N.
        # NOTE(review): presumably the analytic best-of-N KL estimate — confirm.
        kls = torch.full((dialogue_lens.shape[0],), math.log(num_generations)-((num_generations-1)/num_generations)).to(device)
        # 1 while a sequence is still generating, 0 once it has terminated.
        termination_mask = torch.full((dialogue_lens.shape[0],), 1).to(device)
        state_idxs_temp, action_idxs_temp = torch.zeros((dialogue_lens.shape[0], 1,)).long().to(device), torch.zeros((dialogue_lens.shape[0], 1,)).long().to(device)
        # Step from the shortest prompt; longer prompts are teacher-forced below
        # until their own end is reached.
        t = torch.min(dialogue_lens).int()
        while termination_mask.sum() > 0 and (t+prefix_t) < max_length:
            curr_token = tokens[:, t-1].unsqueeze(1)
            curr_dialogue_kvs = map_all_kvs(lambda x: x[:,:,:(t+prefix_t)-1,:], dialogue_kvs)
            iql_outputs = self.iql_model(curr_token, None, state_idxs_temp, action_idxs_temp,
                                         policy_kwargs={'use_cache': True, 'past_key_values': curr_dialogue_kvs})
            transformer_outputs, logits = iql_outputs['model_outputs']['policy_model_outputs'], iql_outputs['logits']
            # Active sequences may not emit padding; terminated ones are forced to pad.
            logits[:, 0, tokenizer.pad_token_id] = torch.where(termination_mask == 1, float('-inf'), 1e7)
            # Inside the original prompt, force the existing token (teacher forcing).
            logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]] = logits[torch.arange(0, n).to(device), torch.full((n,), 0).to(device), tokens[:, t]].masked_fill_(t < dialogue_lens, 1e7)
            logits = process_logits(logits, temp=temp, top_k=top_k, top_p=top_p)
            cat_dist = torch.distributions.categorical.Categorical(logits=logits[:, 0])
            new_tokens = cat_dist.sample()
            log_probs += cat_dist.log_prob(new_tokens)
            tokens[:, t] = new_tokens
            dialogue_kvs = update_kvs(dialogue_kvs, transformer_outputs.past_key_values, torch.arange(0, n).to(device), (t+prefix_t)-1)
            for idx in range(n):
                # An end-of-action token past the prompt ends the utterance;
                # the caller's condition decides whether the sequence stops.
                if tokens[idx, t] == tokenizer.eoa_token_id and t >= dialogue_lens[idx]:
                    termination_mask[idx] *= (1 - int(termination_condition(tokenizer.decode(tokens[idx, :].tolist(),
                                                                                             clean_up_tokenization_spaces=False))))
            t += 1
            termination_mask *= ((t-dialogue_lens) < max_generation_len).int()

        attn_mask = (tokens != tokenizer.pad_token_id).long()
        if prefix_embs is not None:
            prefix_embs = torch.repeat_interleave(prefix_embs, num_generations, dim=0)
        if prefix_attn_mask is not None:
            prefix_attn_mask = torch.repeat_interleave(prefix_attn_mask, num_generations, dim=0)
        # Score each candidate at its last non-pad token.
        action_idxs = (attn_mask.sum(dim=1)-1).unsqueeze(1)
        state_idxs = (attn_mask.sum(dim=1)-1).unsqueeze(1)
        q_outputs = self.iql_model(tokens, attn_mask,
                                   state_idxs, action_idxs,
                                   prefix_embs=prefix_embs,
                                   prefix_attn_mask=prefix_attn_mask,
                                   remove_prefix_position_embs=remove_prefix_position_embs)['target_qs']
        q_outputs = q_outputs.squeeze(2).squeeze(1)
        scores = (q_outputs + log_probs * log_prob_weight).reshape(-1, num_generations)
        order = torch.argsort(-scores, dim=1)
        output_strs = [tokenizer.decode(tokens[i, :].tolist(), clean_up_tokenization_spaces=False) for i in range(len(tokens))]
        processed_outputs = []
        for i in range(len(input_strs)):
            temp_outputs = []
            for x in range(num_generations):
                # Strip the prompt prefix, then truncate at the first pad / eoa token text.
                processed_str = output_strs[i*num_generations+order[i, x]][len(input_strs[i]):].strip()
                if tokenizer.id_to_token(tokenizer.pad_token_id) in processed_str:
                    processed_str = processed_str[:processed_str.find(tokenizer.id_to_token(tokenizer.pad_token_id))].strip()
                if tokenizer.id_to_token(tokenizer.eoa_token_id) in processed_str:
                    processed_str = processed_str[:processed_str.find(tokenizer.id_to_token(tokenizer.eoa_token_id))].strip()
                temp_outputs.append(processed_str)
            processed_outputs.append(temp_outputs)
        log_probs = torch.gather(log_probs.reshape(-1, num_generations), dim=1, index=order)
        kls = torch.gather(kls.reshape(-1, num_generations), dim=1, index=order)
        return list(zip(input_strs, processed_outputs)), log_probs, kls

    def generate(self, items: InputType,
                 termination_condition: Callable[[str], bool], **kwargs):
        """Prepare ``items`` and dispatch to the generation method for ``self.kind``.

        Extra ``kwargs`` are forwarded to the method (``rerank_raw`` for
        ``'rerank'``). Returns that method's ``(generations, log_probs, kls)``.
        """
        prepared_inputs = self.iql_model.prepare_inputs(items)
        tokens, attn_mask = prepared_inputs['tokens'], prepared_inputs['attn_mask']
        state_idxs, action_idxs = prepared_inputs['state_idxs'], prepared_inputs['action_idxs']
        if self.kind == 'rerank':
            method = self.rerank_raw
        else:
            raise NotImplementedError
        generations, log_probs, kls = method(tokens, attn_mask,
                                             state_idxs, action_idxs,
                                             termination_condition,
                                             **kwargs)
        return generations, log_probs, kls

class UtteranceIQL_Evaluator(Evaluator):
    """Score a per-utterance IQL model by rolling out its rerank policy in an environment.

    One episode is played per element of the evaluation batch. The batch
    contents themselves are not used, only its size.
    """

    def __init__(self, env: Language_Environment, verbose: bool, kind: str, **generation_kwargs) -> None:
        super().__init__()
        self.env, self.verbose, self.kind = env, verbose, kind
        # Forwarded verbatim to PerUtteranceIQL_Policy on every evaluate() call.
        self.generation_kwargs = generation_kwargs

    def evaluate(self, model: PerTokenIQL, items: InputType) -> Optional[Dict[str, Any]]:
        """Play one episode per batch element and report mean token/env rewards.

        Returns ``{'token_reward': (mean, count), 'env_reward': (mean, count)}``.
        """
        policy = PerUtteranceIQL_Policy(model, self.kind, **self.generation_kwargs)
        num_episodes = model.prepare_inputs(items)['tokens'].shape[0]
        env_total = 0
        token_total = 0
        for episode in range(1, num_episodes + 1):
            result, sequence = interact_environment(self.env, policy, None)
            # Environment reward is the third field of each transition.
            env_reward = sum(step[2] for step in sequence)
            token_reward = sum(DataPoint.get_token_reward(result, model.dataset.tokenizer, model.dataset.token_reward))
            env_total += env_reward
            token_total += token_reward
            if not self.verbose:
                continue
            print(result)
            print('='*25)
            print('token reward:', token_reward)
            print('env reward:', env_reward)
            print('avg token reward:', token_total / episode)
            print('avg env reward:', env_total / episode)
            print('='*25)
        token_stat = (token_total / num_episodes, num_episodes)
        env_stat = (env_total / num_episodes, num_episodes)
        return {'token_reward': token_stat, 'env_reward': env_stat}
