import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
import time
from machiavelli.utils import clean

from machiavelli.agent.base_agent import BaseAgent
from machiavelli.agent.memory import ReplayMemory, PrioritizedReplayMemory, Transition, State
from machiavelli.agent.drrn import DRRN


# Module-wide torch device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def build_state(tokenize_fn, obs, infos, prev_obs=None, prev_acts=None):
    """
    Assemble one State per (observation, info) pair.

    obs, prev_obs, prev_acts: list of strs.
    infos: list of dicts, one per observation. The keys read here are
        "player_character", "title", "short_summary" and "game_state"
        (its "stats" mapping and, when prev_obs is given, its
        "remaining_achievements" list).
    tokenize_fn: callable applied to every piece of text; whatever it
        returns is stored in the State unchanged.

    Every State is constructed positionally from, in order:
      1. the tokenized current observation,
      2. the tokenized game description,
      3. a float32 tensor (on `device`) of the game_state stats values,
      4. the tokenized "full" text -- the bare observation when prev_obs is
         None, otherwise observation, previous observation, previous action,
         remaining achievements and description joined with [CLS]/[SEP]
         markers.

    Returns a list of State objects, one per zipped input row.
    """
    def describe(info):
        return f"""You are {info["player_character"]} in the text-based adventure game {info["title"]}. {info["short_summary"]}"""

    def list_achievements(info):
        remaining = "\n".join(info['game_state']['remaining_achievements'])
        return "You are trying to obtain the following achievements:\n" + remaining

    def stats_tensor(info):
        values = list(info['game_state']['stats'].values())
        return torch.tensor(values, dtype=torch.float32, device=device)

    if prev_obs is not None:
        # History is available: fold it into the full-text field.
        template = "[CLS] %s [SEP] %s [SEP] %s [SEP] %s [SEP] %s [SEP]"

        def with_history(prev_ob, ob, info, act):
            full_text = template % (
                ob, prev_ob, act, list_achievements(info), describe(info))
            return State(
                tokenize_fn(ob),
                tokenize_fn(describe(info)),
                stats_tensor(info),
                tokenize_fn(full_text),
            )

        return [with_history(prev_ob, ob, info, act)
                for prev_ob, ob, info, act in zip(prev_obs, obs, infos, prev_acts)]

    # No history yet: the observation doubles as the full-text field.
    states = []
    for ob, info in zip(obs, infos):
        states.append(State(
            tokenize_fn(ob),
            tokenize_fn(describe(info)),
            stats_tensor(info),
            tokenize_fn(ob),
        ))
    return states


class DRRN_Agent(BaseAgent):
    """
    Agent that pairs a DRRN Q-network with a tokenizer and a prioritized
    replay memory, and trains the network with one-step TD updates.
    """

    def __init__(self, args):
        """
        Build the network, tokenizer, replay memory and optimizer from `args`.

        Attributes read from `args`: gamma, batch_size, lm_name, stats_dim,
        hidden_dim, memory_size, priority_fraction, output_dir, clip and
        learning_rate.
        """
        self.gamma = args.gamma
        self.batch_size = args.batch_size
        self.network = DRRN(args.lm_name, args.stats_dim,
                            args.hidden_dim).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(args.lm_name, model_max_length=512)
        self.memory = PrioritizedReplayMemory(
            args.memory_size, args.priority_fraction)
        self.save_path = args.output_dir
        self.clip = args.clip
        self.optimizer = torch.optim.Adam(
            self.network.parameters(), lr=args.learning_rate)
        # Most recently sampled batch of transitions (set by update()).
        self.transitions = None
        self.prev_obs = None
        self.prev_acts = None
        # Flipped to True once _configure() has been called.
        self.configure = False

    def _tokenize(self, text):
        """
        Clean and tokenize `text`, returning the encoding on `device`.

        The text is passed through `clean`, then tokenized with padding and
        truncation enabled and PyTorch tensors requested. An extra 'length'
        entry is added: a one-element tensor holding the size of the second
        dimension of 'input_ids'.
        """
        encoding = self.tokenizer(
            clean(text), padding=True, truncation=True, return_tensors="pt").to(device)
        encoding['length'] = torch.tensor(
            [encoding['input_ids'].size(1)]).to(device)
        return encoding

    def _configure(self, lm, alpha, beta):
        """Store `lm`, `alpha` and `beta` on the agent and mark it configured."""
        self.lm = lm
        self.alpha = alpha
        self.beta = beta
        self.configure = True

    def train(self):
        """Clear the remembered previous observations and actions."""
        self.prev_obs = None
        self.prev_acts = None

    def test(self):
        """Clear the remembered previous observations and actions."""
        self.prev_obs = None
        self.prev_acts = None

    def observe(self, transition, is_prior=False):
        """Push `transition` into the replay memory, forwarding `is_prior`."""
        self.memory.push(transition, is_prior)

    def act(self, states, poss_acts, lm=None, eps=None, alpha='', beta=1, k=-1):
        """
        Choose one action per batch entry from `poss_acts`.

        All arguments are forwarded unchanged to `self.network.act`, which
        returns one index per batch entry.

        Returns a 3-tuple `(actions, idxs, None)` where `actions[i]` is
        `poss_acts[i][idxs[i]]` and `idxs` is the network's raw output.
        """
        idxs = self.network.act(
            states, poss_acts, lm, eps, alpha, beta, k)
        act_ids = [poss_acts[batch][idx] for batch, idx in enumerate(idxs)]
        return act_ids, idxs, None

    def update(self):
        """
        Run one TD-learning step on a batch sampled from the replay memory.

        Returns the scalar loss as a Python float, or None (without touching
        the network) when the memory holds fewer than `batch_size`
        transitions. The sampled batch is kept on `self.transitions`.
        """
        if len(self.memory) < self.batch_size:
            return

        self.transitions = transitions = self.memory.sample(self.batch_size)
        batch = Transition(*zip(*transitions))

        # The bootstrap target is a constant with respect to the loss, so build
        # it without autograd: this avoids recording (and then discarding) a
        # full graph for the next-state forward pass.
        # TODO: Use a target network???
        with torch.no_grad():
            # Q(s', a') for all a', then the max over each state's actions.
            next_qvals = self.network(batch.next_state, batch.next_acts)
            # torch.stack keeps the maxima on their device instead of
            # converting them element by element as torch.tensor([...]) does.
            next_qvals = torch.stack([vals.max() for vals in next_qvals])
            # Zero all the next_qvals that are done
            done = torch.tensor(batch.done, dtype=torch.float, device=device)
            next_qvals = next_qvals * (1 - done)
            rewards = torch.tensor(
                batch.reward, dtype=torch.float, device=device)
            targets = rewards + self.gamma * next_qvals

        # Next compute Q(s, a)
        # Nest each action in a list - so that it becomes the only admissible cmd
        nested_acts = tuple([[a] for a in batch.act])
        qvals = self.network(batch.state, nested_acts)
        # One Q-value per sample; concatenate them into a single tensor.
        qvals = torch.cat(qvals)

        # Compute TD loss (Huber loss)
        loss = F.smooth_l1_loss(qvals, targets.detach())

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.network.parameters(), self.clip)
        self.optimizer.step()
        loss_val = loss.item()
        # Drop references to the batch tensors before returning.
        del next_qvals, qvals, targets, loss
        return loss_val
