import pathlib
from tqdm import tqdm
import torch
import torch.nn as nn

import promptrl.utils as utils

class Trainer(object):
    """Supervised trainer for a language-model agent, with labelled and interactive evaluation.

    The trainer drives ``agent.step_labelled`` on a labelled dataloader, periodically
    evaluates on held-out dataloaders and/or interactive environments, and writes
    checkpoints under ``checkpoint_dir``.
    """

    # Metric-name suffixes whose values are long texts/sequences (or the raw loss);
    # they are logged to the metric logger but left out of the human-readable summary.
    _UNPRINTED_METRIC_SUFFIXES = (
        'targets', 'target_actions', 'target_step_obs',
        'samples', 'sample_actions', 'sample_step_obs',
        'loss', 'interact_actions', 'interact_obs',
    )

    def __init__(self, accelerator, args, logger, tokenizer, agent, optimizer, lr_scheduler, checkpoint_dir=None):
        """Store the training components and resolve the checkpoint directory.

        Args:
            accelerator: object exposing ``backward(loss)`` and ``device``.
            args: namespace of hyper-parameters (epochs, batch_size, eval_every, ...).
            logger: object exposing ``info(msg)`` and ``log(metrics, step=None)``.
            tokenizer: tokenizer used to encode observations/goals and decode actions.
            agent: the model being trained and evaluated.
            optimizer: optimizer over the agent parameters.
            lr_scheduler: learning-rate scheduler stepped once per optimizer step.
            checkpoint_dir: base checkpoint directory; defaults to ``./checkpoints``.
                When ``args.exp_id`` is set, checkpoints go into ``<dir>/<exp_id>``.
        """
        self.accelerator = accelerator
        self.args = args
        self.logger = logger
        self.tokenizer = tokenizer
        self.agent = agent
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        if checkpoint_dir is None:
            self.checkpoint_dir = pathlib.Path.cwd() / 'checkpoints'
        else:
            self.checkpoint_dir = pathlib.Path(checkpoint_dir)
        if self.args.exp_id is not None:
            self.checkpoint_dir = self.checkpoint_dir / self.args.exp_id

    def get_checkpoint_path(self, epoch):
        """Return the checkpoint path for ``epoch``.

        The file name is ``<lm_base>_[ft_]<prompt_arch>_<obs_type>_<epoch>_<seed>.tar``,
        where ``ft_`` is present only when ``args.finetune_lm`` is truthy.
        """
        name = self.args.lm_base + '_'
        if self.args.finetune_lm:
            name += 'ft_'
        name += self.args.prompt_arch + '_'
        name += self.args.obs_type + '_'
        name += f'{epoch}_{self.args.seed}.tar'
        return self.checkpoint_dir / name

    def save(self, epoch):
        """Save agent, optimizer and LR-scheduler state plus ``epoch`` to the checkpoint path."""
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        path = self.get_checkpoint_path(epoch)
        state_dict = {
            'agent': self.agent.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'lr_scheduler': self.lr_scheduler.state_dict(),
            'epoch': epoch,
        }
        torch.save(state_dict, str(path))

    def load(self, path):
        """Restore agent weights from a checkpoint written by :meth:`save`.

        Only the ``'agent'`` entry is restored; the optimizer and LR-scheduler
        state stored in the checkpoint are not loaded. ``self.load_epoch`` is set
        to the checkpoint's ``'epoch'`` entry (0 when absent).
        """
        # Map tensors to CPU first so a checkpoint written on a GPU can be read on
        # any machine; assuming the agent is a torch.nn.Module, load_state_dict
        # then copies the values onto the agent's existing parameters/devices.
        checkpoint = torch.load(path, map_location='cpu')
        self.agent.load_state_dict(checkpoint['agent'])
        self.load_epoch = checkpoint.get('epoch', 0)

    def _should_checkpoint(self, epoch):
        """Return True if a checkpoint should be written at the end of ``epoch``.

        Disabled when ``args.checkpoint_every_epoch <= 0``. Otherwise the final
        epoch is always saved, and so is every positive multiple of
        ``checkpoint_every_epoch`` (epoch 0 is skipped unless it is the last one).
        """
        every = self.args.checkpoint_every_epoch
        if every <= 0:
            return False
        if epoch == self.args.epochs - 1:
            return True
        return epoch > 0 and epoch % every == 0

    @classmethod
    def _printable_metrics(cls, metrics):
        """Return the subset of ``metrics`` suitable for a human-readable log line."""
        return {k: v for k, v in metrics.items() if not k.endswith(cls._UNPRINTED_METRIC_SUFFIXES)}

    def train(self, dataloader, eval_dataloaders=None, guided_eval_dataloader=None, eval_envs=None):
        """Train the agent for ``args.epochs`` epochs with periodic evaluation.

        Args:
            dataloader: iterable of labelled training batches.
            eval_dataloaders: optional ``{name: dataloader}`` for labelled validation.
            guided_eval_dataloader: unsupported; passing one raises NotImplementedError
                when interactive evaluation runs.
            eval_envs: optional ``{name: env_factory}`` for interactive evaluation.
        """
        max_step = self.args.train_samples // self.args.batch_size
        pbar = tqdm(total=self.args.epochs * max_step)
        global_step = 0
        for epoch in range(self.args.epochs):
            self.logger.info(f'Epoch {epoch}')
            is_last_epoch = epoch == self.args.epochs - 1

            # Drop any partially accumulated gradients left over from the previous epoch.
            self.optimizer.zero_grad()
            grad_accum_count = 0
            for step, batch in enumerate(dataloader):
                is_last_step = step == max_step - 1
                self.agent.train()

                batch_loss = self.agent.step_labelled(global_step, batch)
                self.accelerator.backward(batch_loss)

                grad_accum_count += 1
                if is_last_step or grad_accum_count >= self.args.grad_accum_steps:
                    self.optimizer.step()
                    self.lr_scheduler.step()
                    # Clear gradients only after an optimizer step; clearing them on
                    # every batch would make gradient accumulation a no-op. The loss is
                    # not divided by grad_accum_steps here, so gradients are summed.
                    self.optimizer.zero_grad()
                    grad_accum_count = 0

                # Train metrics
                if (is_last_step and is_last_epoch) or (step % self.args.log_every == 0):
                    self._log_train_step(step, global_step, batch, batch_loss)

                at_eval_step = (step != 0) and (is_last_step or step % self.args.eval_every == 0)

                # Validation metrics
                if at_eval_step and (eval_dataloaders is not None):
                    self._run_validation(step, global_step, eval_dataloaders)

                # Interactive eval metrics
                if (at_eval_step
                        and (is_last_epoch or epoch % self.args.eval_every_epoch == 0)
                        and (eval_dataloaders is not None or guided_eval_dataloader is not None or eval_envs is not None)
                        and (self.args.eval_samples > 0)):
                    self._run_interactive_eval(step, global_step, guided_eval_dataloader, eval_envs)

                pbar.update(1)
                global_step += 1

            if self._should_checkpoint(epoch):
                self.save(epoch)
        pbar.close()

    def _log_train_step(self, step, global_step, batch, batch_loss):
        """Log the batch loss, one target/sample pair and labelled metrics for ``batch``."""
        self.logger.info(f'Train step {step}, loss={batch_loss.item():.6f}')

        with torch.no_grad():
            target, sample = self.agent.sample_labelled(global_step, batch)
            _, train_metrics = self.agent.step_labelled(global_step, batch, log_metrics=True)
            train_metrics = train_metrics.get_values()

        self.logger.info(f'Train target: "{repr(target)}" | sample: "{repr(sample)}"')
        self.logger.info(f'Train metrics: {train_metrics}')
        log_train_metrics = {'train/' + m: val for m, val in train_metrics.items()}
        self.logger.log(log_train_metrics, global_step)

    def _run_validation(self, step, global_step, eval_dataloaders):
        """Compute and log labelled metrics on every validation dataloader."""
        for eval_name, eval_loader in eval_dataloaders.items():
            val_metrics = self.evaluate(global_step, eval_loader)
            log_val_metrics = {eval_name + '/' + m: val for m, val in val_metrics.items()}
            self.logger.info(f'{eval_name} step {step}: {val_metrics}')
            self.logger.log(log_val_metrics, global_step)

    def _run_interactive_eval(self, step, global_step, guided_eval_dataloader, eval_envs):
        """Run interactive evaluation in each eval environment and log the results."""
        metrics = {}
        if self.args.task_kind == 'cooking':
            raise NotImplementedError('Code for cooking is outdated.')
        elif self.args.task_kind in ['alf', 'virtualhome']:
            if eval_envs is not None:
                for env_name, env_loader in eval_envs.items():
                    interact_metrics = self.interactive_evaluate(global_step, env_loader)
                    metrics.update({f'{env_name}_{key}': val for key, val in interact_metrics.items()})

                    # Show the last episode that was played in this environment.
                    self.logger.info(f'Interact {env_name} sample:')
                    self.logger.info(f'goal: {interact_metrics["interact_goal"]}')
                    self.logger.info('\n'.join(
                        f'{repr(s_action)}->{repr(s_obs)}'
                        for s_action, s_obs in zip(interact_metrics['interact_actions'], interact_metrics['interact_obs'])
                    ))
            if guided_eval_dataloader is not None:
                # The former guided-eval path depended on a `guided_evaluate` method
                # that this class no longer provides.
                raise NotImplementedError('Code for guided eval is outdated')
        self.logger.info(f'Eval step {step}')
        self.logger.info(self._printable_metrics(metrics))

        log_metrics = {'eval/' + m: val for m, val in metrics.items()}
        self.logger.log(log_metrics, global_step)

    def eval_only(self, eval_dataloaders, eval_envs):
        """Run a final labelled evaluation and a full interactive evaluation, logging both.

        Labelled metrics are logged under ``final_<name>/``; interactive metrics
        under ``final_eval_<env_name>/`` (by :meth:`interactive_eval_only`).
        """
        global_step = 0
        for eval_name, eval_loader in eval_dataloaders.items():
            val_metrics = self.evaluate(global_step, eval_loader)
            log_val_metrics = {'final_' + eval_name + '/' + m: val for m, val in val_metrics.items()}
            self.logger.info(f'{eval_name} step {global_step}: {val_metrics}')
            self.logger.log(log_val_metrics)

        for env_name, env_loader in eval_envs.items():
            metrics = self.interactive_eval_only(env_loader, f'final_eval_{env_name}/')
            self.logger.info(self._printable_metrics(metrics))

    def _process_first_obs(self, obs, infos):
        """Tokenize the first text observation and extract the goal.

        ``obs[0]`` is expected to be ``"<intro>\\n\\n<initial obs>\\n\\n<goal>"``.

        Returns:
            ``(init_obs, goal, goal_tags)``, where ``init_obs`` is a one-element list of
            token-id tensors on the accelerator device, ``goal`` is the goal text and
            ``goal_tags`` is an (empty) list of metric tags. The 3-tuple matches the
            unpacking done by the interactive evaluators and by the subclasses.
        """
        intro, init_obs, goal = obs[0].split('\n\n')
        prefix_obs = self.tokenizer(intro + ' ' + init_obs, return_tensors='pt')
        return [prefix_obs.input_ids.squeeze(0).to(self.accelerator.device)], goal, []

    def _process_obs(self, obs, info):
        """Tokenize a text observation (prefixed with a space) into a one-element list of id tensors."""
        return [self.tokenizer(' ' + obs[0], return_tensors='pt').input_ids.squeeze(0).to(self.accelerator.device)]

    def _log_eval_obs(self, obs, infos):
        """Return the human-readable observation recorded for an interactive step."""
        return obs[0]

    def _decode_action(self, action):
        """Decode the first generated sequence in ``action`` into a clean action string.

        The action text is taken from after the first ``[`` up to the next ``[`` (if
        any), with ``SEP]`` markers removed and whitespace stripped. When the decoded
        text contains no ``[``, the whole text is used instead of raising IndexError.
        """
        text = self.tokenizer.decode(action[0], skip_special_tokens=True)
        _, bracket, after = text.partition('[')
        body = after if bracket else text
        body = body.split('[', 1)[0]
        return body.replace('SEP]', '').strip()

    def evaluate(self, step, dataloader):
        """Average labelled metrics over at most ``args.eval_demo_samples`` batches of ``dataloader``."""
        self.agent.eval()
        running_metrics = utils.RunningAverages()
        with torch.no_grad():
            # zip with a range limits the dataloader to eval_demo_samples batches.
            for _, batch in zip(range(self.args.eval_demo_samples), dataloader):
                _, batch_metrics = self.agent.step_labelled(step, batch, log_metrics=True)
                running_metrics.merge(batch_metrics)

        return running_metrics.get_values()

    def _run_episode(self, eval_env, running_metrics, verbose=False):
        """Play one episode in ``eval_env`` with the agent, pushing its metrics into ``running_metrics``.

        Per step, ``admissible``, ``op_success`` and ``env_fail`` are pushed; per episode,
        ``success``, ``score`` and ``<tag>_success`` / ``<tag>_score`` for each goal tag.
        With ``verbose``, the goal and every step are written to the logger.

        Returns:
            ``(goal, goal_tags, success, score, sample_actions, step_obs)``.
        """
        device = self.accelerator.device
        obs, infos = eval_env.reset()
        self.agent.reset()
        assert obs[0] is not None
        init_obs, goal, goal_tags = self._process_first_obs(obs, infos)
        if verbose:
            self.logger.info(f'goal: {goal}, tags: {goal_tags}')
        goal_ids = self.tokenizer([goal], return_tensors='pt').input_ids.to(device)

        self.agent.observe_first(init_obs, goal_ids, infos)

        finished = False
        success = 0.
        score = 0.
        sample_actions = []
        step_obs = []
        prev_infos = infos
        while not finished:
            # generate action
            gen_action = self.agent.sample()
            gen_action_clean = self._decode_action(gen_action)

            # env step
            obs, scores, dones, infos = eval_env.step([gen_action_clean])
            finished = dones[0]
            sample_actions.append(gen_action_clean)
            lang_obs = self._log_eval_obs(obs, infos)
            step_obs.append(lang_obs)
            score += scores[0] if scores is not None else 0.
            if verbose:
                self.logger.info(f'(Total reward: {score}) {repr(gen_action_clean)}->{repr(lang_obs)}')
            success = float(infos['won'][0])
            op_success = float(lang_obs != 'Nothing happens.')
            # Admissibility is judged against the commands offered before this step.
            admissible = float(gen_action_clean in prev_infos['admissible_commands'][0])
            running_metrics.push('admissible', admissible, 1)
            running_metrics.push('op_success', op_success, 1)
            running_metrics.push('env_fail', admissible * (1 - op_success), 1)
            prev_infos = infos

            # update agent
            obs = self._process_obs(obs, infos)
            self.agent.observe(obs, goal_ids, gen_action_clean, infos)

        running_metrics.push('success', success, 1)
        running_metrics.push('score', score, 1)
        for tag in goal_tags:
            running_metrics.push(f'{tag}_success', success, 1)
            running_metrics.push(f'{tag}_score', score, 1)
        return goal, goal_tags, success, score, sample_actions, step_obs

    def interactive_evaluate(self, step, eval_env_loader):
        """Play ``args.eval_samples`` episodes in a fresh env and return averaged metrics.

        The returned dict also holds the last episode's goal, actions and observations
        under ``interact_goal`` / ``interact_actions`` / ``interact_obs`` (``None`` / empty
        lists when no episode is played). ``step`` is accepted for API symmetry and unused.
        """
        eval_env = eval_env_loader()
        self.agent.eval()
        running_metrics = utils.RunningAverages()
        goal, sample_actions, step_obs = None, [], []
        with torch.no_grad():
            for _ in range(self.args.eval_samples):
                goal, _, _, _, sample_actions, step_obs = self._run_episode(eval_env, running_metrics)

        metrics = {
            'interact_goal': goal,
            'interact_actions': sample_actions,
            'interact_obs': step_obs,
        }
        metrics.update(running_metrics.get_values())
        del eval_env
        return metrics

    def interactive_eval_only(self, eval_env_loader, pref='eval/'):
        """Play every game in a fresh env (``env.num_games``), logging metrics after each one.

        After each episode the running averages plus that episode's ``i_success`` and
        ``i_score`` are logged with every key prefixed by ``pref``.

        Returns:
            The metrics dict from the final episode.
        """
        eval_env = eval_env_loader()
        self.agent.eval()
        metrics = {}
        running_metrics = utils.RunningAverages()
        pbar = tqdm(total=eval_env.num_games)
        with torch.no_grad():
            for i in range(eval_env.num_games):
                self.logger.info(f'Interact {pref} sample #{i}:')
                _, _, success, score, _, _ = self._run_episode(eval_env, running_metrics, verbose=True)

                self.logger.info(f'{pref} sample #{i} {"success" if success else "fail"}, reward: {score}')
                metrics.update(running_metrics.get_values())
                metrics['i_success'] = success
                metrics['i_score'] = score
                log_metrics = {pref + m: val for m, val in metrics.items()}
                self.logger.log(log_metrics)

                pbar.update(1)
        pbar.close()
        del eval_env
        return metrics


class VizTrainer(Trainer):
    """Trainer for visual environments whose observations are frame stacks.

    Observations are passed to the agent as-is (optionally truncated to the
    most recent ``args.limit_frames`` frames when ``obs_type`` contains ``'img'``);
    the goal and metric tags come from ``infos`` instead of the observation text.
    """

    def _keep_recent_frames(self, obs):
        """Keep only the last ``args.limit_frames`` entries of each observation, when enabled."""
        if 'img' not in self.args.obs_type or self.args.limit_frames is None:
            return obs
        limit = self.args.limit_frames
        return [frames[-limit:] for frames in obs]

    def _log_eval_obs(self, obs, infos):
        """Describe a step as its frame count plus the environment's text feedback."""
        num_frames = obs[0].shape[0]
        feedback = infos["feedback"][0]
        return f'{num_frames} Frames: {feedback}'

    def _process_first_obs(self, obs, infos):
        """Return ``(obs, goal, tags)`` for the first step of an episode.

        The goal is the third ``\\n\\n``-separated section of the feedback text.
        Tags are derived from the game-file's parent directory name, which is
        split on ``-`` into task type, object, (ignored), receptacle and scene id.
        """
        obs = self._keep_recent_frames(obs)
        _, _, goal = infos['feedback'][0].split('\n\n')

        gamefile = infos['extra.gamefile'][0]
        task_id = gamefile.split('/')[-2]
        task_type, obj, _, rec, scene_id = task_id.split('-')
        tags = [f'task_{task_type}', f'obj_{obj}', f'rec_{rec}', f'scene_{scene_id}']
        return obs, goal, tags

    def _process_obs(self, obs, info):
        """Return the step observation, truncated to the most recent frames when enabled."""
        return self._keep_recent_frames(obs)

class VHomeTrainer(Trainer):
    """Trainer for VirtualHome environments.

    The goal text is built from ``infos['goal']`` and logged observations are
    captions listing the interactable objects reported in ``infos``.
    """

    def _keep_recent_frames(self, obs):
        """Keep only the last ``args.limit_frames`` entries of each observation, when enabled."""
        if 'img' not in self.args.obs_type or self.args.limit_frames is None:
            return obs
        limit = self.args.limit_frames
        return [frames[-limit:] for frames in obs]

    def _log_eval_obs(self, obs, infos):
        """Caption the interactable objects (excluding ``'character'``) in sorted order."""
        visible = infos['interactable_objects'][0]
        # NOTE(review): discard() mutates the collection stored inside ``infos`` in
        # place, so later consumers of the same ``infos`` (e.g. agent.observe in the
        # evaluation loop) also no longer see 'character'. Confirm this is intended.
        visible.discard('character')
        names = ['a ' + name for name in sorted(visible)]

        if not names:
            listing = 'nothing'
        elif len(names) == 1:
            listing = names[0]
        else:
            listing = ', '.join(names[:-1]) + ', and ' + names[-1]
        return 'In front of you, you see ' + listing + '.'

    def _process_first_obs(self, obs, infos):
        """Return ``(obs, goal, [])`` where the goal joins the listed sub-goals into one sentence."""
        obs = self._keep_recent_frames(obs)
        raw_goal = infos['goal'][0]
        assert isinstance(raw_goal, list)

        if len(raw_goal) == 1:
            return obs, f'Your task is to: {raw_goal[0]}.', []
        # NOTE(review): an empty goal list reaches raw_goal[-1] below and raises IndexError.
        joined = ', '.join(raw_goal[:-1]) + ', and ' + raw_goal[-1]
        return obs, 'Your task is to: ' + joined + '.', []

    def _process_obs(self, obs, info):
        """Return the step observation, truncated to the most recent frames when enabled."""
        return self._keep_recent_frames(obs)
