import numpy as np
import torch
import torch.nn as nn
from transformers import get_scheduler

import promptrl.utils as utils

class LMAgent(object):
    """Base agent pairing a language model with a trainable prompt model.

    Subclasses implement ``_logits_labelled``, ``sample_labelled``,
    ``_prompt_first`` and ``_prompt``. This base class provides:

    * supervised training via ``step_labelled`` with per-task loss weights;
    * incremental rollout via ``reset`` -> ``observe_first`` ->
      (``sample`` -> ``observe``)*, which reuses the LM's ``past_key_values``
      between steps whenever the history was not trimmed.
    """

    # Maximum number of tokens generated per action in ``sample``; the same
    # amount is reserved in the context budget when ``observe`` trims history.
    ACTION_MAX_LENGTH = 20
    # Number of (action, observation) chunks kept in the rolling history when
    # ``args.dynamic_length_eval`` is off and ``args.max_obs_length`` is unset.
    DEFAULT_MAX_OBS_LENGTH = 6

    def __init__(self, accelerator, args, logger, tokenizer, model, prompt_model):
        self.accelerator = accelerator
        self.args = args
        self.logger = logger
        self.tokenizer = tokenizer
        self.model = model
        self.prompt_model = prompt_model

        self.cache = {}

        self.SEP = ' [SEP]'
        self.sep_input_ids = self.tokenizer(self.SEP, return_tensors='pt').input_ids.to(self.accelerator.device)
        self.sep_embeds = self.model._get_wte()(self.sep_input_ids).detach()
        # Generation stops on the first token of ' ['; presumably this marks
        # the start of the ' [SEP]' that closes an action.
        self.eos_token_id = self.tokenizer(' [').input_ids[0]

        # Clamp to >= 1 so a dataset smaller than one batch (or epochs == 0)
        # cannot cause a ZeroDivisionError in _get_lws / step_labelled.
        self.epoch_steps = max(1, args.train_samples // args.batch_size)
        self.max_steps = max(1, args.epochs * self.epoch_steps)
        self.prim_lw = utils.get_lw_scheduler(args.prim_lw)
        self.aux_lw = utils.get_lw_scheduler(args.aux_lw)

    def state_dict(self):
        """Return checkpoint state; the LM weights are included only when fine-tuned."""
        states = {
            'prompt_model': self.prompt_model.state_dict(),
        }
        if self.args.finetune_lm:
            states['model'] = self.model.state_dict()
        return states

    def load_state_dict(self, states):
        """Restore state produced by ``state_dict``."""
        self.prompt_model.load_state_dict(states['prompt_model'])
        if 'model' in states:
            # A checkpoint carrying LM weights is only valid for a fine-tuning run.
            assert self.args.finetune_lm
            self.model.load_state_dict(states['model'])

    def train(self):
        """Put the prompt model in train mode, and the LM too unless it is frozen in eval mode."""
        if self.args.finetune_lm or not self.args.lm_eval_mode:
            self.model.train()
        else:
            self.model.eval()
        self.prompt_model.train()

    def eval(self):
        """Put both the LM and the prompt model in eval mode."""
        self.model.eval()
        self.prompt_model.eval()

    def _logits_labelled(self, step, batch):
        """Return ``(output_ids, logits, loss_mask)`` for a labelled batch (subclass hook)."""
        raise NotImplementedError

    def _get_lws(self, step, task_ids):
        """Return a float32 tensor with one loss weight per example.

        Task id 0 uses the primary schedule ``prim_lw``; every other task id
        uses ``aux_lw``. Both are evaluated at progress ``step / max_steps``.
        """
        t = step / self.max_steps
        weights = [self.prim_lw(t) if task == 0 else self.aux_lw(t) for task in task_ids]
        return torch.tensor(weights, dtype=torch.float32)

    def step_labelled(self, step, batch, log_metrics=False):
        """Compute the task-weighted training loss for a labelled batch.

        Returns:
            ``total_loss`` (a scalar tensor) when ``log_metrics`` is False,
            otherwise ``(total_loss, running_metrics)``.
        """
        output_ids, logits, loss_mask = self._logits_labelled(step, batch)
        lws = self._get_lws(step, batch['tasks']).to(output_ids.device)
        batch_loss = utils.shifted_cross_ent(output_ids, logits, loss_mask=loss_mask)
        # Normalised by the configured batch size, not this batch's size.
        total_loss = (batch_loss * lws.unsqueeze(1)).sum() / self.args.batch_size

        if not log_metrics:
            return total_loss

        n = output_ids.shape[0]
        t = step / self.max_steps
        running_metrics = utils.RunningAverages()
        running_metrics.push('loss', total_loss.item(), n)
        running_metrics.push('epoch', step // self.epoch_steps, n)
        running_metrics.push('prim_lw', self.prim_lw(t), n)
        running_metrics.push('aux_lw', self.aux_lw(t), n)
        for i, task in enumerate(batch['tasks']):
            running_metrics.push(f'task{task}_loss', batch_loss[i].sum().item(), 1)

        pred_ids = logits.argmax(dim=-1)
        running_metrics.merge(utils.token_stats(pred_ids, output_ids, loss_mask, tasks=batch['log_tasks'], shift=True))

        # Return the same scalar loss as the non-logging path so callers can
        # treat both return shapes uniformly (e.g. call backward on it).
        return total_loss, running_metrics

    def sample_labelled(self, step, batch):
        """Return ``(target_text, sampled_text)`` for a labelled example (subclass hook)."""
        raise NotImplementedError

    def reset(self):
        """Clear the rollout cache; call before ``observe_first`` of each episode."""
        self.cache = {
            'prefix_queue': [],
            'init_prefixes': None,
            'prefixes': None,
            'past_key_values': None,
            'n_pos': 0,
            'prev_infos': None
        }

    def _prompt_first(self, obs, goal_ids, infos, task_id=0):
        # return token embeds + deep prompts + mask for first obs in sequence
        raise NotImplementedError

    def _prompt(self, obs, goal_ids, prev_action, infos, task_id=0):
        # return token embeds for obs and prev_action
        raise NotImplementedError

    def observe_first(self, obs, goal_ids, infos):
        """Start an episode: build the initial prefixes from the first observation."""
        assert self.cache['n_pos'] == 0
        self.cache['prev_infos'] = infos
        self.cache['init_prefixes'] = self._prompt_first(obs, goal_ids, infos)
        # Attention-mask positions not backed by inputs_embeds (deep prompt length).
        self.cache['deep_prompt_length'] = self.cache['init_prefixes']['attention_mask'].shape[1] - self.cache['init_prefixes']['inputs_embeds'].shape[1]
        self.cache['prefixes'] = dict(self.cache['init_prefixes'])

        self.cache['past_key_values'] = None

    def observe(self, obs, goal_ids, prev_action, infos):
        """Append ``(prev_action, obs)`` to the history and prepare the next LM input.

        The rolling history is trimmed from the front, either to fit
        ``args.max_context_tokens`` (``args.dynamic_length_eval``) or to at
        most ``args.max_obs_length`` chunks (default ``DEFAULT_MAX_OBS_LENGTH``).
        If anything was trimmed, or there is no cached ``past_key_values``
        from ``sample``, the full prefix is rebuilt from the initial prompt;
        otherwise only the new chunk is fed on top of the cached keys/values.
        """
        cache = self.cache
        cache['prev_infos'] = infos
        cache['n_pos'] += 1
        action_obs_embeds = self._prompt(obs, goal_ids, prev_action, infos)
        queue = cache.setdefault('prefix_queue', [])
        queue.append(action_obs_embeds)

        popped = False
        if self.args.dynamic_length_eval:
            # Tokens that are always present: initial prompt, deep prompt,
            # the separator before generation, the generated action and one spare.
            fixed_len = (cache['init_prefixes']['inputs_embeds'].shape[1] + cache['deep_prompt_length']
                         + self.sep_input_ids.shape[1] + self.ACTION_MAX_LENGTH + 1)
            # Stop once the queue is empty so an oversized initial prompt cannot
            # trigger pop() on an empty list.
            while queue and fixed_len + sum(ao.shape[1] for ao in queue) > self.args.max_context_tokens:
                queue.pop(0)
                popped = True
        else:
            max_obs_length = getattr(self.args, 'max_obs_length', None)
            if max_obs_length is None:
                max_obs_length = self.DEFAULT_MAX_OBS_LENGTH
            while len(queue) > max_obs_length:
                queue.pop(0)
                popped = True

        prefixes = cache['prefixes']
        if popped or cache['past_key_values'] is None:
            # Rebuild from scratch: cached keys/values no longer match the history.
            prefixes['inputs_embeds'] = torch.cat([cache['init_prefixes']['inputs_embeds']] + queue, dim=1)
            prefixes['attention_mask'] = torch.ones((1, prefixes['inputs_embeds'].shape[1] + cache['deep_prompt_length']), dtype=torch.float32, device=self.accelerator.device)
            prefixes['past_key_values'] = cache['init_prefixes']['past_key_values']
        else:
            # Incremental: feed only the new chunk and extend the mask to cover it.
            prefixes['inputs_embeds'] = action_obs_embeds
            prefixes['attention_mask'] = nn.functional.pad(prefixes['attention_mask'], (0, action_obs_embeds.shape[1]), 'constant', 1)
            prefixes['past_key_values'] = cache['past_key_values']

        # Invalidated until the next sample() re-encodes the prefixes.
        cache['past_key_values'] = None

    def sample(self):
        """Encode the pending prefixes (caching keys/values) and generate the next action ids."""
        self.cache['past_key_values'] = self.model(**self.cache['prefixes'], use_cache=True).past_key_values
        return self.model.generate(past_key_values=self.cache['past_key_values'], prefix_ids=self.sep_input_ids, max_length=self.ACTION_MAX_LENGTH, eos_token_id=self.eos_token_id).output_ids

    def direct_sample(self, obs, goal_ids, infos, task_id=0, max_length=20):
        """Generate from a single observation without touching the rollout cache."""
        prefixes = self._prompt_first(obs, goal_ids, infos, task_id=task_id)
        return self.model.generate(**prefixes, prefix_ids=self.sep_input_ids, max_length=max_length, eos_token_id=self.eos_token_id).output_ids

class IterPromptAgent(LMAgent):
    """Agent whose prompt model interleaves goals, observations and actions.

    Training sequences are assembled by ``prompt_model.fill_sequence``. At
    rollout time each observation is embedded by ``prompt_model`` and the
    previous action by the LM's word-token embedding (``_get_wte``).
    """

    def _logits_labelled(self, step, batch):
        """Run the LM on a filled training sequence.

        Returns:
            ``(batch['output_ids'], logits, batch['loss_mask'])``.
        """
        inputs = self.prompt_model.fill_sequence(
            task_ids=batch['tasks'],
            obs=batch['obs'],
            obs_precomputed=batch.get('obs_precomputed', None),
            goals=batch['goals'],
            goals_precomputed=batch.get('goals_precomputed', None),
            fill_arr=batch['fill_arr'],
            fill_locs=batch['fill_locs']
        )
        if 'past_key_values' in inputs and inputs['past_key_values'] is not None:
            # Deep prompts arrive as past_key_values; left-pad the mask with ones
            # to cover them (dim 2 of the first key is taken as their length).
            dp_len = inputs['past_key_values'][0][0].shape[2]
            inputs['attention_mask'] = nn.functional.pad(batch['attention_mask'], (dp_len, 0), value=1)
        else:
            inputs['attention_mask'] = batch['attention_mask']
        logits = self.model(**inputs).logits
        return batch['output_ids'], logits, batch['loss_mask']

    def sample_labelled(self, step, batch):
        """Generate the action at step ``min(3, n_steps)`` of the first example.

        The sequence is truncated right before that action, and the model's
        continuation is decoded next to the reference action.

        Returns:
            ``(target_text, sampled_text)``.
        """
        check_idx = min(3, len(batch['fill_locs'][0]))
        # Position in fill_arr where the checked action begins.
        stop_idx = batch['fill_locs'][0][check_idx - 1][1]
        with torch.no_grad():
            inputs = self.prompt_model.fill_sequence(
                task_ids=[batch['tasks'][0]],
                obs=[batch['obs'][0][:check_idx]] if batch['obs'] is not None else None,
                obs_precomputed=[batch['obs_precomputed'][0][:check_idx]] if batch.get('obs_precomputed', None) is not None else None,
                goals=batch['goals'][:1],
                goals_precomputed=batch['goals_precomputed'][:1] if batch.get('goals_precomputed', None) is not None else None,
                fill_arr=batch['fill_arr'][:1, :stop_idx],
                fill_locs=[batch['fill_locs'][0][:check_idx]]
            )
            gen_outputs = self.model.generate(
                **inputs,
                prefix_ids=batch['sep'],
                max_length=20
            ).output_ids

        target = self.tokenizer.decode(batch['actions'][0][check_idx - 1], skip_special_tokens=True)
        sample = self.tokenizer.decode(gen_outputs[0], skip_special_tokens=True)
        return target, sample

    def _prompt_first(self, obs, goal_ids, infos, task_id=0):
        """Embed the first observation (position 0) and the goal.

        Returns the prompt model's output dict with ``inputs_embeds`` and
        ``attention_mask`` stacked into tensors.
        """
        with torch.no_grad():
            prompts = self.prompt_model(
                task_ids=[task_id],
                obs=obs,
                obs_pos=[0],
                goals=goal_ids,
            )
        prompts['inputs_embeds'] = torch.stack(prompts['inputs_embeds'])
        prompts['attention_mask'] = torch.stack(prompts['attention_mask'])
        return prompts

    def _prompt(self, obs, goal_ids, prev_action, infos, task_id=0):
        """Embed ``' [SEP] <prev_action> [SEP]'`` followed by the new observation.

        The observation is embedded at position ``cache['n_pos']``. The result
        is concatenated along dim 1 (action embeddings first).
        """
        action_padded = self.SEP + ' ' + prev_action + self.SEP
        action_tok = self.tokenizer(action_padded, return_tensors='pt').input_ids.to(self.accelerator.device)
        with torch.no_grad():
            action_embeds = self.model._get_wte()(action_tok)
            obs_embeds = self.prompt_model(
                task_ids=[task_id],
                obs=obs,
                obs_pos=[self.cache['n_pos']],
                goals=goal_ids
            )['inputs_embeds']
        obs_embeds = torch.stack(obs_embeds)
        return torch.cat((action_embeds, obs_embeds), dim=1)

class RankPromptAgent(IterPromptAgent):
    """Choose actions by scoring every admissible command instead of generating freely."""

    def sample(self):
        """Return the token ids of the highest-scoring admissible command.

        The pending prefixes are encoded once (and cached). Each admissible
        command from the last ``infos`` is wrapped as ``' [SEP] <cmd> [SEP]'``
        and scored by ``model.batch_score`` on top of those cached keys/values.
        The result is a one-element list holding the winner's padded input ids.
        """
        prefix_kv = self.model(**self.cache['prefixes'], use_cache=True).past_key_values
        self.cache['past_key_values'] = prefix_kv

        commands = self.cache['prev_infos']['admissible_commands'][0]
        wrapped = [' [SEP] ' + command + ' [SEP]' for command in commands]
        candidates = self.tokenizer(wrapped, return_tensors='pt', padding=True).to(self.accelerator.device)

        # Prepend ones so the mask also covers every cached prefix position.
        n_cached = prefix_kv[0][0].shape[2]
        full_mask = nn.functional.pad(candidates.attention_mask, (n_cached, 0), 'constant', 1)
        scores = self.model.batch_score(
            input_ids=candidates.input_ids,
            attention_mask=full_mask,
            past_key_values=prefix_kv,
            batch_size=self.args.batch_size,
        )

        best = scores.argmax().item()
        return [candidates.input_ids[best]]

class CaptionPromptAgent(IterPromptAgent):
    """IterPromptAgent that feeds captions of observations to the LM.

    Each observation is first captioned by a separately loaded captioning
    agent. The decoded caption is re-tokenized with this agent's tokenizer
    and used in place of the raw observation.
    """

    def __init__(self, accelerator, args, logger, tokenizer, model, prompt_model):
        # TODO make scripts part of library
        from scripts.load_args import CaptionArgs, VHomeCaptionArgs
        from scripts.load import load_checkpoint_agent
        super().__init__(accelerator, args, logger, tokenizer, model, prompt_model)

        # Reject unknown task kinds explicitly instead of failing later with an
        # unbound local variable.
        caption_args_cls = {'alf': CaptionArgs, 'virtualhome': VHomeCaptionArgs}.get(args.task_kind)
        if caption_args_cls is None:
            raise ValueError(f"Unsupported task_kind {args.task_kind!r}; expected 'alf' or 'virtualhome'")
        self.caption_agent = load_checkpoint_agent(accelerator, logger, caption_args_cls())
        self.caption_agent.eval()

        self.caption_goal = 'Your task is to: caption the following observation'
        self.caption_goal_ids = self.caption_agent.tokenizer(self.caption_goal, return_tensors='pt').input_ids.squeeze(0).to(accelerator.device)

    def _decode_action(self, action):
        """Extract the caption text from the caption agent's generated ids.

        Decodes ``action[0]``, keeps the text between the first and second
        ``'['``, removes ``'SEP]'`` and strips whitespace. Raises IndexError if
        the decoded text contains no ``'['``.
        """
        gen_action_s = self.caption_agent.tokenizer.decode(action[0], skip_special_tokens=True)
        gen_action_clean = gen_action_s.split('[', 2)[1]
        gen_action_clean = gen_action_clean.replace('SEP]', '').strip()
        return gen_action_clean

    def _caption_obs(self, obs):
        """Caption ``obs`` and return it as a one-element list of 1-D token-id tensors."""
        caption_ids = self.caption_agent.direct_sample(obs, [self.caption_goal_ids], {}, task_id=1, max_length=50)
        caption = self._decode_action(caption_ids)
        caption_obs = [self.tokenizer(caption, return_tensors='pt').input_ids.squeeze(0).to(self.accelerator.device)]
        return caption_obs

    def observe_first(self, obs, goal_ids, infos):
        """Start an episode from the caption of the first observation."""
        return super().observe_first(self._caption_obs(obs), goal_ids, infos)

    def observe(self, obs, goal_ids, prev_action, infos):
        """Append ``prev_action`` and the caption of ``obs`` to the history."""
        return super().observe(self._caption_obs(obs), goal_ids, prev_action, infos)

class BeamAffordanceAgent(IterPromptAgent):
    """SayCan-style agent that re-ranks beam-search candidates by affordance.

    Candidate actions come from beam search over the LM. Each candidate is
    scored by a separately loaded affordance agent that sees every observation
    of the episode so far. The returned action maximizes
    ``affordance * exp(candidates.scores)``; the scores are presumably beam
    log-probabilities (TODO confirm).
    """

    def __init__(self, accelerator, args, logger, tokenizer, model, prompt_model):
        # TODO make scripts part of library
        from scripts.load_args import AffordanceArgs, VHomeAffordanceArgs
        from scripts.load import load_checkpoint_agent
        super().__init__(accelerator, args, logger, tokenizer, model, prompt_model)

        # Reject unknown task kinds explicitly instead of failing later with an
        # unbound local variable.
        aff_args_cls = {'alf': AffordanceArgs, 'virtualhome': VHomeAffordanceArgs}.get(args.task_kind)
        if aff_args_cls is None:
            raise ValueError(f"Unsupported task_kind {args.task_kind!r}; expected 'alf' or 'virtualhome'")
        self.affordance_agent = load_checkpoint_agent(accelerator, logger, aff_args_cls())
        self.affordance_agent.eval()

        self.affordance_goal = 'Your task is to: predict whether the following action is valid.'
        self.affordance_goal_ids = self.affordance_agent.tokenizer(self.affordance_goal, return_tensors='pt').input_ids.squeeze(0).to(accelerator.device)
        # Token ids whose logits are compared to judge validity.
        # NOTE(review): taken from the main tokenizer but used to index the
        # affordance model's logits; assumes both share a vocabulary — confirm.
        self.valid_id = self.tokenizer(' valid').input_ids[0]
        self.invalid_id = self.tokenizer(' invalid').input_ids[0]

    def reset(self):
        """Reset the rollout cache and the list of observations seen this episode."""
        super().reset()
        self.cache['past_obs'] = []

    def observe_first(self, obs, goal_ids, infos):
        """Record the first observation, then start the episode as usual."""
        assert len(self.cache['past_obs']) == 0
        self.cache['past_obs'].append(obs[0])
        return super().observe_first(obs, goal_ids, infos)

    def observe(self, obs, goal_ids, prev_action, infos):
        """Record ``obs`` for affordance scoring, then extend the history as usual."""
        # One stored observation per position seen so far (checked before n_pos increments).
        assert len(self.cache['past_obs']) == self.cache['n_pos'] + 1
        self.cache['past_obs'].append(obs[0])
        return super().observe(obs, goal_ids, prev_action, infos)

    def _decode_action(self, action):
        """Return the text between the first and second ``'['``, with ``'SEP]'`` removed and stripped.

        Raises IndexError if ``action`` contains no ``'['``.
        """
        gen_action_clean = action.split('[', 2)[1]
        gen_action_clean = gen_action_clean.replace('SEP]', '').strip()
        return gen_action_clean

    def _get_affordance_scores(self, past_obs, actions, epsilon=0.1):
        """Score each action string's validity with the affordance agent.

        The probability of ``' valid'`` vs ``' invalid'`` (softmax over just
        those two logits) is read at the last attended token of each action.
        It is then mapped to ``epsilon + (1 - epsilon) * p`` so that no
        candidate gets a zero score.

        Returns:
            A 1-D tensor with one score per action.
        """
        tokenized = self.tokenizer([a + ' [SEP]' for a in actions], return_tensors='pt', padding=True).to(self.accelerator.device)
        # Index of the last real token assumes right padding — TODO confirm tokenizer padding side.
        lengths = tokenized.attention_mask.sum(dim=1)
        n_actions = tokenized.input_ids.shape[0]
        # Scores are only used for ranking, so skip building an autograd graph.
        with torch.no_grad():
            prefixes = self.affordance_agent._prompt_first(past_obs, [self.affordance_goal_ids], {}, task_id=1)
            past_key_values = self.affordance_agent.model(**prefixes, use_cache=True).past_key_values
            # Share the single encoded prefix across all candidate actions.
            past_key_values = [[kv.expand(n_actions, -1, -1, -1) for kv in layer] for layer in past_key_values]
            attn = nn.functional.pad(tokenized.attention_mask, (past_key_values[0][0].shape[2], 0), 'constant', 1)
            logits = self.affordance_agent.model(input_ids=tokenized.input_ids, attention_mask=attn, past_key_values=past_key_values).logits
        scores = logits[..., (self.valid_id, self.invalid_id)]
        scores = scores[range(scores.shape[0]), lengths - 1]
        affordance = nn.functional.softmax(scores, dim=-1)[:, 0]
        affordance = epsilon + (1 - epsilon) * affordance
        return affordance

    def sample(self):
        """Beam-search candidate actions and return the one with the best combined score.

        Returns a one-element list holding the chosen candidate's output ids.
        """
        self.cache['past_key_values'] = self.model(**self.cache['prefixes'], use_cache=True).past_key_values
        candidates = self.model.generate_beam(past_key_values=self.cache['past_key_values'], prefix_ids=self.sep_input_ids, max_length=20, eos_token_id=self.eos_token_id, n_beams=self.args.beam_search_k)
        candidates_decoded = self.tokenizer.batch_decode(candidates.output_ids, skip_special_tokens=True)
        cand_actions = [self._decode_action(action) for action in candidates_decoded]

        # All observations of the episode, concatenated along axis 0.
        past_obs = [np.concatenate(self.cache['past_obs'], axis=0)]
        affordance_scores = self._get_affordance_scores(past_obs, cand_actions, self.args.affordance_epsilon)
        lm_scores = candidates.scores.exp()
        saycan_scores = affordance_scores * lm_scores

        return [candidates.output_ids[saycan_scores.argmax()]]
