import os
import json
import re
from functools import partial
import math
import random
import re

import numpy as np
import torch
import torch.nn as nn
import copy
from transformers import Trainer
from transformers.trainer_pt_utils import nested_detach
from torch.utils.data import Dataset
import torch.distributed as dist
from dataclasses import asdict
from tqdm import tqdm

from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers.modeling_utils import PreTrainedModel
from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
from transformers.data.data_collator import DataCollator
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction

from sentence_transformers import SentenceTransformer
from sentence_transformers import util as stutil


from hexa.utils.metrics import MetricLogger
from inference import BB3InferenceAgent
from hexa.utils.self_learn_utils import Case, load_previous_wrong_data, load_previous_correct_data, short_answer_validator, long_answer_validator, Symbol, Prefix
from hexa.utils.dist_utils import gather


class CustomTrainer(Trainer):
    """HF ``Trainer`` that additionally logs per-task metrics and gradient statistics.

    Metrics are accumulated in the ``MetricLogger`` instances passed as
    ``train_metric_logger`` / ``eval_metric_logger`` and merged into the log
    dictionaries emitted by :meth:`log`.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
        train_metric_logger: Optional[MetricLogger] = None,
        eval_metric_logger: Optional[MetricLogger] = None,
    ):
        super(CustomTrainer, self).__init__(
            model,
            args,
            data_collator,
            train_dataset,
            eval_dataset,
            tokenizer,
            model_init,
            compute_metrics,
            callbacks,
            optimizers,
            preprocess_logits_for_metrics,
        )

        self.train_metric = train_metric_logger
        self.eval_metric = eval_metric_logger

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """Run the default training step, then log gradient statistics.

        Logs ``gnorm`` (L2 norm over all existing parameter gradients) and
        ``clip`` (the total norm returned by ``clip_grad_norm_``) to the train
        metric logger.

        NOTE(review): ``clip_grad_norm_`` clips the gradients in place here, i.e.
        after every call to ``training_step``. If gradient accumulation or AMP
        loss scaling is used, this may differ from the Trainer's own clipping at
        the optimizer step — confirm this is intended.
        """
        ret = super().training_step(model, inputs)
        with torch.no_grad():
            # Parameters that received no gradient (frozen or unused) have
            # ``grad is None``; skip them instead of raising AttributeError.
            grads = [p.grad.detach() for p in model.parameters() if p.grad is not None]
            # Take the square root of the summed squares; previously the
            # *squared* norm was logged under the name ``gnorm``.
            gnorm = sum(g.norm(2).item() ** 2 for g in grads) ** 0.5
            # ``clip_grad_norm`` (without underscore) is a deprecated alias.
            clipped_gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
            self.train_metric.log('gnorm', gnorm)
            self.train_metric.log('clip', clipped_gnorm)

        return ret

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Delegate to the default ``Trainer.prediction_step`` (hook for future logging)."""
        ret = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
        # log something..
        return ret

    def _log_step(self, model, inputs, outputs, task_ids, is_valid=False):
        """Log per-task token/example statistics for one batch.

        Args:
            model: unused; kept for signature compatibility.
            inputs: model inputs; ``input_ids`` and ``lm_labels`` are read to
                count non-pad context / label tokens.
            outputs: model outputs providing ``metric_loss``, ``metric_correct``
                and ``metric_target_tokens``.
            task_ids: per-example task identifiers used as metric keys.
            is_valid: route metrics to the eval logger instead of the train one.
        """
        metric = self.train_metric
        if is_valid:
            # triggered for prediction_step
            metric = self.eval_metric

        with torch.no_grad():
            metric.log_many('exs', task_ids, torch.ones_like(outputs['metric_target_tokens']).long())
            metric.log_many('loss', task_ids, outputs['metric_loss'], outputs['metric_target_tokens'])
            metric.log_many('ppl', task_ids, outputs['metric_loss'], outputs['metric_target_tokens'])
            metric.log_many('token_acc', task_ids, outputs['metric_correct'], outputs['metric_target_tokens'])
            metric.log_many('token_em', task_ids, outputs['metric_correct'] == outputs['metric_target_tokens'])
            metric.log_many('ctpb', task_ids, inputs['input_ids'].ne(self.tokenizer.pad_token_id).long().sum(dim=-1))
            metric.log_many('ltpb', task_ids, inputs['lm_labels'].ne(self.tokenizer.pad_token_id).long().sum(dim=-1))
            metric.log('exps', len(task_ids))
            unique_ids, cnts = np.unique(task_ids, return_counts=True)
            metric.log_many('expb', unique_ids, cnts)

    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute the loss with the model's own ``outputs['loss']`` and log metrics.

        ``inputs`` carries non-tensor metadata (``task_ids``, ``texts``,
        ``labels``, ``all_labels``) next to the model inputs. These keys are
        removed in place before the forward pass. When ``return_outputs`` is
        True, which happens on the prediction path, metrics go to the eval logger.

        Note: a configured HF label smoother is NOT applied. Previously its
        branch popped ``labels`` and then popped it again, which raised KeyError.
        """
        # TODO: logging per task
        task_ids = inputs.pop('task_ids')
        inputs.pop('texts')
        inputs.pop('labels', None)
        inputs.pop('all_labels')

        outputs = model(**inputs)
        loss = outputs['loss']

        self._log_step(model, inputs, outputs, is_valid=return_outputs, task_ids=task_ids)

        return (loss, outputs) if return_outputs else loss

    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training.

        Custom metrics from the train logger are merged in. If any key contains
        ``'eval'``, the metrics from the eval logger are merged instead.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """

        if self.state.epoch is not None:
            logs["epoch"] = round(self.state.epoch, 2)

        output = {**logs, **{"step": self.state.global_step}}
        self.state.log_history.append(output)

        # add custom metrics
        has_eval_loss = any('eval' in k for k in logs)
        metric = self.eval_metric if has_eval_loss else self.train_metric

        logs = metric.update_logs(logs)

        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)


class BootstrapHelper:
    """Per-rank bookkeeping for one bootstrap (self-learning) loop.

    The helper collects correct ("good") and incorrect ("bad") generated cases.
    It switches to fine-tuning mode once enough good cases are gathered, and it
    saves the collected cases as JSON under ``root_dir/<name_space>_<rank>``.
    """

    def __init__(self, rank: int, world_size: int,
                 num_loop: int,
                 inc_rate: float,
                 num_bootstrap: int,
                 root_dir='./experiment/selfLearn/data',
                 name_space='',
                 base_threshold: float = 0.25,
                 use_cos_sim: bool = False,
                 skip_recorrected: bool = False,
                 bt_inc_rate: float = 0.0,
                 use_specific_thresholds: bool = True):
        """
        Args:
            rank / world_size: this process's rank and the number of processes.
            num_loop: index of the current self-learning loop. For ``num_loop > 0``
                the previous loop's wrong cases are loaded.
            inc_rate: per-loop growth of the cosine-similarity threshold.
            num_bootstrap: base number of good cases to collect across all ranks.
            root_dir: parent directory for the saved case files.
            name_space: experiment name, used in folder names and data loading.
            base_threshold: validator threshold for generic tasks.
            use_cos_sim: judge answers by sentence-embedding cosine similarity.
            skip_recorrected: skip cases already answered correctly (with the
                same history and at least the same score) in earlier loops.
            bt_inc_rate: per-loop growth of the number of cases to collect.
            use_specific_thresholds: use per-task thresholds instead of one.
        """
        self.skip_recorrected = skip_recorrected
        # text -> {'score', 'history'} of previously correct cases. Always a
        # dict, so membership checks behave the same in every loop.
        self.prev_good_cases = {}
        self.good_cases = []
        self.hinted_good_cases = []
        self.num_skipped_samples = 0
        if num_loop > 0:
            self.bad_cases = load_previous_wrong_data(name_space, num_loop - 1, num_bootstrap)
            if self.skip_recorrected:
                for i in range(num_loop):
                    prev_correct_data = load_previous_correct_data(name_space, i, num_bootstrap)
                    self.prev_good_cases.update(prev_correct_data)
                # Keep only the fields is_already_bootstrapped_data() reads.
                self.prev_good_cases = {
                    key: {'score': item['score'], 'history': item['history']}
                    for key, item in self.prev_good_cases.items()
                }
                print('SKIP RECORRECTED DATA!!, Number of prev_good_cases:', len(self.prev_good_cases))
            # Rank 0 prints, matching the other one-off messages (rank 1 may not exist).
            if rank == 0:
                print(f'Previous wrong data in loop ({num_loop-1}) is loaded.')
        else:
            self.bad_cases = {}
        self.bad_cases_len = 0
        self.num_loop = num_loop
        self.rank = rank
        self.world_size = world_size
        self.parent_dir = root_dir
        if name_space:
            foldername = f'{name_space}_{self.rank}'
        else:
            foldername = str(self.rank)
        self.root_dir = os.path.join(root_dir, foldername)
        os.makedirs(self.root_dir, exist_ok=True)
        self.total_num_bootstrap = num_bootstrap
        self.total_target_num_bootstrap = int(num_bootstrap * (1 + bt_inc_rate * num_loop))
        self.target_num_good = self.total_target_num_bootstrap // self.world_size
        self.inc_rate = inc_rate
        self.base_threshold = base_threshold
        self.use_specific_thresholds = use_specific_thresholds

        self.use_cos_sim = use_cos_sim
        if use_cos_sim:
            model_name_or_path = 'sentence-transformers_all-MiniLM-L6-v2'
            cache_folder = None
            self.sim_evaluator = SentenceTransformer(model_name_or_path, cache_folder=cache_folder)

        # 0: bootstrapping, 1: ready to fine-tune (see maybe_finetune).
        self._mode = 0
        if self.rank == 0:
            print(f'Each machine would collect {self.target_num_good} of {self.total_num_bootstrap} of good samples')

        # task_id -> [num plain good cases, num hinted good cases]
        self.num_good_per_task = {}

    def sim_score(self, predicted: str, reference: str):
        """Return the max cosine similarity between the embeddings of the two inputs."""
        e_p = self.sim_evaluator.encode(predicted, convert_to_tensor=True, show_progress_bar=False)
        e_r = self.sim_evaluator.encode(reference, convert_to_tensor=True, show_progress_bar=False)
        similarity = stutil.cos_sim(e_p, e_r)
        return similarity.max().item()

    def get_case_validator(self, task_id, baseline=None):
        """Return a ``validator(predicted, reference) -> (is_correct, score)`` for ``task_id``.

        With ``use_cos_sim``, answers are judged by embedding similarity, either
        against ``baseline`` (when given) or against a threshold that grows with
        ``num_loop``. Otherwise task-specific validators from
        ``hexa.utils.self_learn_utils`` are used.
        """
        task_id = task_id.lower()

        if self.use_cos_sim:
            # ``is not None`` so an explicit baseline of 0.0 is honoured.
            if baseline is not None:
                def cos_sim_validator_with_baseline(predicted, reference):
                    cos_sim = self.sim_score(predicted, reference)
                    return cos_sim - baseline > 0, cos_sim - baseline
                validator = cos_sim_validator_with_baseline
            else:
                threshold = self.base_threshold * (1 + self.inc_rate * self.num_loop)

                def cos_sim_validator(predicted, reference):
                    cos_sim = self.sim_score(predicted, reference)
                    return cos_sim > threshold, cos_sim
                validator = cos_sim_validator
        else:
            if task_id.startswith('triviaqa'):
                validator = short_answer_validator
            elif self.use_specific_thresholds and task_id.startswith(('convai', 'fits', 'googlesgd')):
                validator = partial(long_answer_validator, threshold=0.35)
            else:
                # MSC, WoI, WoW, MsMarco (or every non-TriviaQA task with fixed thresholds)
                validator = partial(long_answer_validator, threshold=self.base_threshold)

        return validator

    @staticmethod
    def case2json(data: List[Case]):
        """Map each case's ``text`` to its ``asdict`` form (later duplicates win)."""
        return {case.text: asdict(case) for case in data}

    def save_data(self):
        """Dump good (plain + hinted) and bad cases for this loop to JSON files in ``root_dir``."""
        # {key: text, value: case} in a form of JSON
        # save good and bad cases in different files
        good_case_fpath = os.path.join(self.root_dir, f'good_cases_{self.num_loop}_{self.total_num_bootstrap}.json')
        with open(good_case_fpath, 'w') as fout:
            json.dump(self.case2json(self.good_cases + self.hinted_good_cases), fout, indent=4)

        bad_case_fpath = os.path.join(self.root_dir, f'bad_cases_{self.num_loop}_{self.total_num_bootstrap}.json')
        with open(bad_case_fpath, 'w') as fout:
            self.prev_bad_cases_json = self.bad_cases
            json.dump(self.prev_bad_cases_json, fout, indent=4)

    def load_finetune_data(self):
        pass

    def log_msg(self):
        """Print the collection progress of this rank, overall and per task."""
        msg = f'rank: {self.rank}, t: {self.num_loop}, num_good: {len(self.good_cases)} + {len(self.hinted_good_cases)}, num_bad: {self.bad_cases_len}'
        print(msg)

        msg = '\t'
        for task_id, tup in self.num_good_per_task.items():
            msg += f'{task_id}: {sum(tup)}({tup[0]} + {tup[1]}) '
        print(msg)

    def count_good_cases(self):
        return len(self.good_cases) + len(self.hinted_good_cases)

    def maybe_finetune(self):
        """Save the data and switch to fine-tune mode once this rank hit its target."""
        if self.count_good_cases() >= self.target_num_good:
            self.save_data()
            self._mode = 1

    def should_bootstrap(self):
        return self._mode == 0

    def should_finetune(self):
        return self._mode == 1

    def update_case(self, case: Case, num_guidance):
        '''
        Update previously generated wrong answers of given case
        '''
        if case.text in self.bad_cases:
            latest_bad_case = Case(json_dict=self.bad_cases[case.text])
            case.update_wrong_answers(latest_bad_case, n=num_guidance)

    def is_already_bootstrapped_data(self, case: Case):
        """True if ``skip_recorrected`` and an earlier loop already solved this
        case (same history) with a score >= ``case.score``."""
        ret = False
        if self.skip_recorrected:
            prev_score = -1
            if case.text in self.prev_good_cases:
                prev_item = self.prev_good_cases[case.text]
                if case.history == prev_item['history']:
                    prev_score = prev_item['score']
            ret = prev_score >= case.score
        if ret:
            self.num_skipped_samples += 1
            print(f"{self.rank}-th node's skipped_samples: {self.num_skipped_samples}")
        return ret

    def extend_good_cases(self, good_cases: List[Case], hinted_good_cases: List[Case]):
        """Record good cases and update the per-task [plain, hinted] counters."""
        self.good_cases.extend(good_cases)
        self.hinted_good_cases.extend(hinted_good_cases)

        for case in good_cases:
            self.num_good_per_task.setdefault(case.task_id, [0, 0])[0] += 1

        for case in hinted_good_cases:
            self.num_good_per_task.setdefault(case.task_id, [0, 0])[1] += 1

    def extend_bad_cases(self, bad_cases: List[Case]):
        """Merge bad cases into ``self.bad_cases`` (keyed by text).

        For known texts, only the wrong answers are merged. De-duplication
        keeps first-seen order; ``set`` ordering varied between processes, which
        made the stored order non-deterministic.
        """
        for case in bad_cases:
            if case.text in self.bad_cases:
                stored = self.bad_cases[case.text]
                merged = list(stored['wrong_answers']) + list(case.wrong_answers)
                stored['wrong_answers'] = list(dict.fromkeys(merged))
            else:
                json_dict = asdict(case)
                # Retrieved documents are not persisted with bad cases.
                json_dict.pop('search_docs', None)
                self.bad_cases[case.text] = json_dict
        self.bad_cases_len += len(bad_cases)

    def clear_bad_case(self, case: Case):
        """Forget a previously recorded bad case once it has been answered correctly."""
        self.bad_cases.pop(case.text, None)


class SelfLearner(Trainer):
    def __init__(
            self,
            agent: BB3InferenceAgent,
            rank: int,
            world_size: int,
            num_loop: int,
            num_bootstrap: int,
            is_finetune: bool = False,
            args: TrainingArguments = None,
            data_collator: Optional[DataCollator] = None,
            train_dataset: Optional[Dataset] = None,
            eval_dataset: Optional[Dataset] = None,
            tokenizer: Optional[PreTrainedTokenizerBase] = None,
            model_init: Callable[[], PreTrainedModel] = None,
            compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
            callbacks: Optional[List[TrainerCallback]] = None,
            optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
            preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
            train_metric_logger: Optional[MetricLogger] = None,
            eval_metric_logger: Optional[MetricLogger] = None,
            name_space: str = None,
            use_cos_sim: bool = False,
            inc_rate: float = 0.0,
            bt_inc_rate: float = 0.0,
            base_threshold: float = 0.25,
            save_eval_samples: bool = False
    ):
        """Trainer around ``agent.model`` that also drives bootstrap data generation.

        When ``is_finetune`` is False, ``name_space`` selects the bootstrap
        variant by its prefix and by substrings:

        * prefix ``all`` -> ``hint_type`` 0; ``hexa`` -> 1; ``star`` -> 1 with
          naive prompts; ``nohint`` -> 2. Any other value (including ``None``)
          raises ``AssertionError``.
        * ``hexa`` only: ``lowest`` / ``highest`` (request 5 samples with
          ``sample_scheme`` 'low' / 'high'), ``est_num_<k>`` (number of samples),
          ``n_guidance_<k>`` (``num_guidance``), ``symbol_*`` / ``prefix_*``
          (guided-prompt formatting).
        * any prefix: ``wogt`` (``without_gt``), ``random_hint``
          (``use_random_hint``), ``fixed_thresh`` (one threshold for every task).
        """
        super(SelfLearner, self).__init__(
            agent.model,
            args,
            data_collator,
            train_dataset,
            eval_dataset,
            tokenizer,
            model_init,
            compute_metrics,
            callbacks,
            optimizers,
            preprocess_logits_for_metrics,
        )

        self.agent = agent
        self.train_metric = train_metric_logger
        self.eval_metric = eval_metric_logger
        self.is_finetune = is_finetune
        self.rank = rank
        self.world_size = world_size
        self.save_eval_samples = save_eval_samples
        self.name_space = name_space
        self.num_bootstrap = num_bootstrap
        self.num_loop = num_loop
        if self.save_eval_samples:
            self.eval_samples = []
        if not is_finetune:
            self.is_naive = False
            self.without_gt = False
            self.use_random_hint = False
            self.random_hint = None
            self.return_final_samples = 0
            self.sample_scheme = None
            self.add_all = False
            self.skip_recorrected = True
            self.guided_prompt_symbol = Symbol.ALPHABET
            self.guided_prompt_prefix = Prefix.Null
            # Passed as ``n`` to Case.update_wrong_answers during guided retries.
            # Previously this was only set for 'hexa' name spaces, so every other
            # variant crashed with AttributeError (at the print below and in _bootstrap).
            self.num_guidance = 4

            print(name_space)
            # Guard against the ``None`` default; an empty string hits the error below.
            ns = name_space or ''

            if ns.startswith('all'):
                self.hint_type = 0
            elif ns.startswith('hexa'):
                self.hint_type = 1
                if 'lowest' in ns:
                    self.return_final_samples = 5
                    self.sample_scheme = 'low'
                    print('sample_scheme: lowest')
                if 'highest' in ns:
                    self.return_final_samples = 5
                    self.sample_scheme = 'high'
                    print('sample_scheme: highest')
                items = re.findall(r'est_num_\d+', ns)
                if len(items) > 0:
                    self.return_final_samples = int(items[0].split('_')[-1])

                items = re.findall(r'n_guidance_\d+', ns)
                if len(items) > 0:
                    self.num_guidance = int(items[0].split('_')[-1])

                if 'symbol_bullet' in ns:
                    self.guided_prompt_symbol = Symbol.BULLET

                if 'symbol_number' in ns:
                    self.guided_prompt_symbol = Symbol.NUMBER

                if 'symbol_random' in ns:
                    self.guided_prompt_symbol = Symbol.RANDOM

                if 'prefix_multiple' in ns:
                    self.guided_prompt_prefix = Prefix.Multiple

                if 'prefix_guide' in ns:
                    self.guided_prompt_prefix = Prefix.Guide

            elif ns.startswith('star'):
                self.hint_type = 1
                self.is_naive = True
            elif ns.startswith('nohint'):
                self.hint_type = 2
            else:
                # Explicit raise (not ``assert``) so the check survives ``python -O``.
                raise AssertionError(f"{name_space} does not support!")

            if 'wogt' in ns:
                self.without_gt = True
            if 'random_hint' in ns:
                self.use_random_hint = True

            use_specific_thresholds = True
            if 'fixed_thresh' in ns:
                use_specific_thresholds = False
                print(f"USE FIXED THRESHOLDS {base_threshold}")

            self.helper = BootstrapHelper(rank=rank,
                                          world_size=world_size,
                                          num_loop=num_loop,
                                          num_bootstrap=num_bootstrap,
                                          name_space=name_space,
                                          inc_rate=inc_rate,
                                          base_threshold=base_threshold,
                                          use_cos_sim=use_cos_sim,
                                          skip_recorrected=self.skip_recorrected,
                                          bt_inc_rate=bt_inc_rate,
                                          use_specific_thresholds=use_specific_thresholds)

            print(f'guided_prompt_prefix: {self.guided_prompt_prefix}')
            print(f'guided_prompt_symbol: {self.guided_prompt_symbol}')
            print(f'without_gt: {self.without_gt}, use_random_hint:{self.use_random_hint}')
            print(f'return_final_samples: {self.return_final_samples}')
            print(f'num_guidance: {self.num_guidance}')

    def e2e_evaluate(self, task_name=''):
        """Evaluate the agent end-to-end on the eval set and report the metrics.

        Each batch goes through ``_e2e_evaluate``. Per-task ``exs``, ``score``,
        ``bleu``, ``f1`` and ``loss`` are logged to ``self.eval_metric``. Every
        aggregated ``*loss*`` entry gets a matching ``*ppl*`` entry (``exp(loss)``,
        or ``inf`` if that overflows). Then the ``on_evaluate`` callbacks run.

        Args:
            task_name: used to name the saved sample file when
                ``save_eval_samples`` is enabled.
        """
        self.agent.model.eval()
        eval_dataloader = self.get_eval_dataloader()
        print(len(eval_dataloader))
        metric = self.eval_metric
        # Pass the dataloader itself (not ``enumerate(...)``) so tqdm knows the total.
        for inputs in tqdm(eval_dataloader):
            task_ids = inputs['task_ids']
            outputs = self._e2e_evaluate(inputs)
            num_examples = len(task_ids)
            device = inputs['input_ids'].device
            with torch.no_grad():
                # One unit weight per example, on the batch's device.
                metric.log_many('exs', task_ids,
                                torch.ones(num_examples, dtype=torch.long, device=device))
                metric.log_many('score', task_ids, outputs['score'],
                                torch.ones(num_examples, dtype=torch.long, device=device))
                metric.log_many('bleu', task_ids, outputs['bleus'],
                                torch.ones(num_examples, dtype=torch.long, device=device))
                metric.log_many('f1', task_ids, outputs['f1'],
                                torch.ones(num_examples, dtype=torch.long, device=device))
                metric.log_many('loss', task_ids, outputs['loss'], outputs['tokens'])

        if self.save_eval_samples:
            self._save_eval_samples(task_name)
        metrics = metric.update_logs({})
        ppl_metrics = {}
        for key, value in metrics.items():
            if 'loss' in key:
                try:
                    ppl = math.exp(value)
                except OverflowError:
                    # A diverged loss (> ~709) must not abort the evaluation.
                    ppl = float('inf')
                ppl_metrics[key.replace('loss', 'ppl')] = round(ppl, 4)
        metrics.update(ppl_metrics)
        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)



    @torch.no_grad()
    def _e2e_evaluate(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """Generate and score the agent's response for one dialogue turn.

        Only batch size 1 is supported (this is asserted), because the agent
        keeps a single dialogue history and memory. After scoring, the agent's
        memory and history are advanced with the gold response
        (``all_labels[0][0]``), not the generated one. At ``episode_done`` the
        agent is reset instead.

        Returns:
            dict with per-example lists ``score``, ``bleus`` and ``f1``, plus the
            concatenated ``loss`` / ``tokens`` entries reported by the agent.
        """
        task_ids = inputs['task_ids']
        texts = inputs['texts']
        all_labels = inputs['all_labels']
        episode_done = inputs['episode_done']
        bs = len(texts)
        assert bs == 1

        outputs = {}
        scores = []
        bleus = []
        loss = []
        tokens = []
        f1s = []
        bootstrap_mode = self.save_eval_samples
        for i in range(bs):
            task_id = task_ids[i]
            # Strip ``__special__`` markers from the input text.
            text = re.sub(r'__[\w-]+__', '', texts[i]).strip()
            labels = all_labels[i]
            history = self.agent.get_dialogue_history_str()
            bot_res = self.agent.dummy_act(text, memory_label=labels[0],
                                           return_loss=True, lm_labels=inputs['lm_labels'][i].unsqueeze(0))
            bot_res.force_set('history', history)
            case = Case(text, labels, bot_res, task_id, bootstrap_mode=bootstrap_mode)
            if self.save_eval_samples:
                self.eval_samples.append(case)
            scores.append(case.score)
            bleus.append(bot_res['bleu'].value())
            f1s.append(bot_res['f1'].value())
            loss.extend(bot_res['loss'])
            tokens.extend(bot_res['tokens'])

        if not episode_done[0]:
            # Continue the dialogue with the gold response (batch size 1 only).
            bot_res_copy = copy.deepcopy(bot_res)
            bot_res_copy.force_set('text', all_labels[0][0])
            self.agent.update_memory(bot_res_copy)
            self.agent.update_history(texts[0], bot_res_copy)
        else:
            self.agent.reset()

        outputs['score'] = scores
        outputs['bleus'] = bleus
        outputs['f1'] = f1s
        outputs['loss'] = loss
        outputs['tokens'] = tokens
        return outputs

    def _save_eval_samples(self, task_name='', root_dir='./experiment/selfLearn/eval'):
        """Write this rank's eval cases and merge all ranks' files on rank 0.

        Each rank dumps ``self.eval_samples`` to ``{task_name}_{rank}.json`` in
        ``{root_dir}/{name_space}_{num_loop}``. After a barrier (only when a
        ``torch.distributed`` process group exists), rank 0 merges the per-rank
        files into ``{task_name}.json`` and deletes them. The in-memory sample
        list is cleared, so the next task's dump contains only its own cases.
        """
        dir_path = os.path.join(root_dir, f'{self.name_space}_{self.num_loop}')
        os.makedirs(dir_path, exist_ok=True)
        result_json = BootstrapHelper.case2json(self.eval_samples)
        local_fpath = os.path.join(dir_path, f'{task_name}_{self.rank}.json')
        with open(local_fpath, 'w') as fout:
            json.dump(result_json, fout)
        # Previously never cleared, so later tasks re-saved earlier tasks' samples.
        self.eval_samples = []

        # barrier() raises without an initialised process group (single-process runs).
        if dist.is_available() and dist.is_initialized():
            dist.barrier()
        if self.rank == 0:
            gathered_data = {}
            gathered_data_path = os.path.join(dir_path, f'{task_name}.json')
            local_files = []
            for i in range(self.world_size):
                fpath = os.path.join(dir_path, f'{task_name}_{i}.json')
                local_files.append(fpath)
                with open(fpath, 'r') as fin:
                    data = json.load(fin)
                gathered_data.update(data)

            with open(gathered_data_path, 'w') as fout:
                json.dump(gathered_data, fout, indent=4)

            print(f"GATHERED DATA({task_name}) LEN: {len(gathered_data)}")
            for fpath in local_files:
                os.remove(fpath)

    def bootstrap(self):
        """Run bootstrap generation over the training set.

        The model is put in eval mode. Each training batch is handed to
        ``_bootstrap`` until ``self.helper.should_bootstrap()`` returns False.
        Every 10th step is run with verbose printing.
        """
        self.agent.model.eval()
        loader = self.get_train_dataloader()
        for idx, batch in enumerate(loader):
            if not self.helper.should_bootstrap():
                break
            verbose = idx % 10 == 0
            self._bootstrap(batch, do_print=verbose)

    def __test_loader(self):
        """Debug helper: report which episodes each rank's train loader finishes.

        Counts batches whose first ``episode_done`` flag is set and collects
        their ``episode_id`` on the GPU. It then gathers the ids from all ranks
        and prints them sorted on rank 0. It requires CUDA and an initialised
        process group.
        """
        loader = self.get_train_dataloader()
        print(f'{self.rank}:', len(loader))
        num_finished = 0
        finished = []
        for batch in loader:
            done_flags = batch['episode_done']
            ep_id = batch['episode_id']
            if not done_flags[0]:
                continue
            num_finished += 1
            finished.append(torch.LongTensor([ep_id]).cuda())

        print(f'done, {self.rank}:', num_finished)

        local_ids = torch.cat(finished)
        torch.distributed.barrier()
        gathered = gather(local_ids)
        if self.rank == 0:
            all_ids = torch.cat(gathered).cpu().numpy().tolist()
            print(sorted(all_ids))

    @torch.no_grad()
    def _bootstrap(self, inputs: Dict[str, Union[torch.Tensor, Any]], do_print: bool = False):
        """Generate, judge and record the agent's response for one dialogue turn.

        Only batch size 1 is supported (this is asserted). The steps are:

        1. Generate a response. If ``return_final_samples`` is set, several
           candidates are requested and one is picked according to
           ``sample_scheme``. The response is wrapped in a ``Case`` judged by
           the task's validator.
        2. Correct cases are good and incorrect ones are bad. If the helper
           reports the case as already bootstrapped, nothing is recorded.
           Otherwise, for ``hint_type`` 0 and 1 the bad cases are retried with a
           guided prompt. ``hint_type == 1`` records the results; 0 records
           nothing. ``hint_type == 2`` records the good and bad cases without
           retries.
        3. Advance the agent's memory and history with the gold response, or
           reset the agent at the end of an episode.
        4. Log progress, let the helper decide whether to switch to fine-tuning,
           and, with ``use_random_hint``, keep a pool of at most 1000 gold labels.

        If a turn's text is empty after stripping ``__token__`` markers, nothing
        is generated for it and the agent history is not advanced for it.
        """
        task_ids = inputs['task_ids']
        input_texts = inputs['texts']
        all_labels = inputs['all_labels']
        episode_done = inputs['episode_done']
        bs = len(input_texts)
        assert bs == 1

        good_cases = []
        bad_cases = []

        # Both keep these defaults when the turn is skipped below. Previously an
        # empty turn left ``do_skip`` unbound (UnboundLocalError).
        bot_res_copy = None
        do_skip = False
        for i in range(bs):
            task_id = task_ids[i]
            input_text = re.sub(r'__[\w-]+__', '', input_texts[i]).strip()
            if not input_text:
                continue
            labels = all_labels[i]
            history = self.agent.get_dialogue_history_str()
            bot_res = self.agent.dummy_act(input_text, do_print=do_print, memory_label=labels[0],
                                           return_final_samples=self.return_final_samples)
            validator = self.helper.get_case_validator(task_id)

            # All sampled candidates and their validator scores. Only filled when
            # several samples are requested.
            output_texts = None
            scores = None
            # Candidates other than the chosen response (for ``add_all``).
            # This is a separate list, so ``output_texts`` stays aligned with ``scores``.
            other_texts = None
            if self.return_final_samples:
                output_texts = bot_res['text']
                if isinstance(output_texts, str):
                    output_texts = [output_texts]

                if self.sample_scheme:
                    scores = []
                    for output_text in output_texts:
                        _, score = validator(labels, output_text)
                        scores.append(score)

                if self.sample_scheme == 'high':
                    resp = output_texts[np.argmax(scores)]
                else:
                    resp = random.choice(output_texts)

                bot_res.force_set('text', resp)
                if self.add_all:
                    other_texts = list(output_texts)
                    other_texts.remove(resp)

            # Snapshot before 'history' is attached; used to advance the dialogue below.
            bot_res_copy = bot_res.copy()
            bot_res.force_set('history', history)

            case = Case(input_text, labels, bot_res, task_id, validator=validator)
            do_skip = self.helper.is_already_bootstrapped_data(case)
            if case.is_correct:
                good_cases.append(case)
                self.helper.clear_bad_case(case)
            else:
                if self.sample_scheme == 'low' and scores:
                    # Use the lowest-scoring sample as the most recent wrong answer.
                    assert len(case.wrong_answers) > 0
                    case.wrong_answers[-1] = output_texts[np.argmin(scores)]
                if self.add_all and other_texts is not None:
                    normalized = [' '.join(text.split()) for text in other_texts]
                    case.wrong_answers.extend(np.unique(normalized).tolist())
                bad_cases.append(case)

        if not do_skip:
            if self.hint_type == 2:
                # no guided prompt
                self.helper.extend_good_cases(good_cases, [])
                self.helper.extend_bad_cases(bad_cases)
            else:
                # with guided prompt: retry every bad case with hints
                hinted_good_cases = []
                hinted_bad_cases = []
                for case in bad_cases:
                    self.helper.update_case(case, num_guidance=self.num_guidance)
                    input_text = case.text
                    if not input_text:
                        continue
                    if input_text[-1] not in ('.', '?', ',', '!'):
                        input_text += '.'
                    answer_pool = None
                    if self.use_random_hint:
                        if self.random_hint is None:
                            # No gold labels collected yet to draw a hint from.
                            continue
                        answer_pool = [random.choice(self.random_hint)]

                    new_text = input_text + ' ' + case.generate_guided_prompt(is_naive=self.is_naive,
                                                                              without_gt=self.without_gt,
                                                                              answer_pool=answer_pool,
                                                                              prefix=self.guided_prompt_prefix,
                                                                              symbol=self.guided_prompt_symbol)
                    if do_print:
                        print(new_text)
                    history = self.agent.get_dialogue_history_str()
                    hinted_bot_res = self.agent.dummy_act(new_text, do_print=do_print)
                    hinted_bot_res.force_set('history', history)
                    # The hinted case keeps the original (un-hinted) text as its key.
                    hinted_case = Case(case.text, case.labels, hinted_bot_res, case.task_id, validator=case.validator)
                    hinted_case.update_wrong_answers(case, n=self.num_guidance)
                    if hinted_case.is_correct:
                        hinted_good_cases.append(hinted_case)
                        self.helper.clear_bad_case(hinted_case)
                    else:
                        hinted_bad_cases.append(hinted_case)

                if self.hint_type == 1:
                    self.helper.extend_good_cases(good_cases, hinted_good_cases)
                    self.helper.extend_bad_cases(hinted_bad_cases)
                # hint_type == 0: guided responses are generated but not recorded.

        if not episode_done[0]:
            if bot_res_copy is not None:
                # Replace bot generated response with ground truth, only works for batch size 1!
                bot_res_copy.force_set('text', all_labels[0][0])
                self.agent.update_memory(bot_res_copy)
                # Update the history with the given text and label
                self.agent.update_history(input_texts[0], bot_res_copy)
        else:
            # if episode ended, reset the agent history and memory
            self.agent.reset()

        self.helper.log_msg()
        self.helper.maybe_finetune()

        if self.use_random_hint:
            if self.random_hint is None:
                self.random_hint = []
            self.random_hint.extend(all_labels[0])
            random.shuffle(self.random_hint)
            self.random_hint = self.random_hint[-1000:]

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Run `Trainer.training_step`, then log gradient statistics and clip gradients.

        After the base class has run the step, this logs the global L2 norm of all
        available parameter gradients as ``gnorm``. When clipping is enabled
        (``args.max_grad_norm`` is set and > 0), it then clips the gradients in place
        with ``clip_grad_norm_`` and logs that function's return value as ``clip``.
        That return value is the total norm *before* clipping.

        NOTE(review): this clips after every call to `training_step`, i.e. possibly
        before gradient accumulation is complete and (with fp16 loss scaling)
        possibly before the gradients are unscaled. The base Trainer also clips at
        the optimizer step. Confirm this extra clipping is intended.

        Returns:
            The loss tensor returned by `Trainer.training_step`.
        """
        ret = super().training_step(model, inputs)
        with torch.no_grad():
            # Frozen or unused parameters have no gradient; skip them instead of
            # failing on `None.detach()`.
            sq_sum = 0.0
            for param in model.parameters():
                if param.grad is None:
                    continue
                sq_sum += param.grad.detach().norm(2).item() ** 2
            # `sq_sum` is the squared norm; take the root so `gnorm` is the real L2 norm.
            self.train_metric.log('gnorm', math.sqrt(sq_sum))

            max_norm = self.args.max_grad_norm
            # Mirror the Trainer's own guard: a missing or zero max_grad_norm means
            # "no clipping". Passing 0 to clip_grad_norm_ would zero every gradient.
            if max_norm is not None and max_norm > 0:
                # `clip_grad_norm_` is the in-place, non-deprecated API.
                total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
                self.train_metric.log('clip', total_norm)

        return ret

    def remove_meta_data(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        """
        Strip the non-model metadata entries from a batch, in place.

        Removes ``task_ids``, ``texts``, ``labels`` and ``all_labels`` (in that
        order) from ``inputs``. A missing key raises ``KeyError``; keys popped
        before the missing one stay removed.

        Returns:
            The value that was stored under ``task_ids``.
        """
        popped_task_ids = inputs.pop('task_ids')
        for meta_key in ('texts', 'labels', 'all_labels'):
            inputs.pop(meta_key)
        return popped_task_ids

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Follows the structure of `Trainer.prediction_step`, adapted to batches that
        carry metadata (`task_ids`, `texts`, `labels`, `all_labels`) alongside the
        model inputs. When labels are present, the loss comes from `compute_loss`,
        which strips the metadata itself and logs per-task statistics to the eval
        metric logger. Otherwise `model.forward_loss` is called directly on the
        stripped inputs.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model, plus the metadata keys listed above.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        inputs = self._prepare_inputs(inputs)
        # `compute_loss` strips the metadata itself and needs `task_ids` for per-task
        # logging, so it must receive the full batch. Stripping `inputs` in place here
        # made its second `pop('task_ids')` raise KeyError. Everything else uses a
        # stripped shallow copy.
        model_inputs = dict(inputs)
        self.remove_meta_data(model_inputs)

        has_labels = all(model_inputs.get(k) is not None for k in self.label_names)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels:
            labels = nested_detach(tuple(model_inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if has_labels:
                with self.compute_loss_context_manager():
                    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                loss = loss.mean().detach()

                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                else:
                    logits = outputs[1:]
            else:
                loss = None
                with self.compute_loss_context_manager():
                    outputs = model.forward_loss(**model_inputs)
                if isinstance(outputs, dict):
                    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                else:
                    logits = outputs
                # TODO: this needs to be fixed and made cleaner later.
                if self.args.past_index >= 0:
                    self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels)

    def _log_step(self, model, inputs, outputs, task_ids, is_valid=False):
        metric = self.train_metric
        if is_valid:
            # triggered for prediction_step
            metric = self.eval_metric

        with torch.no_grad():
            metric.log_many('exs', task_ids, torch.ones_like(outputs['metric_target_tokens']).long())
            metric.log_many('loss', task_ids, outputs['metric_loss'], outputs['metric_target_tokens'])
            metric.log_many('ppl', task_ids, outputs['metric_loss'], outputs['metric_target_tokens'])
            metric.log_many('token_acc', task_ids, outputs['metric_correct'], outputs['metric_target_tokens'])
            metric.log_many('token_em', task_ids, outputs['metric_correct'] == outputs['metric_target_tokens'])
            # metric.log_many('ctpb', task_ids, inputs['input_ids'].ne(self.tokenizer.pad_token_id).long().sum(dim=-1))
            # metric.log_many('ltpb', task_ids, inputs['lm_labels'].ne(self.tokenizer.pad_token_id).long().sum(dim=-1))
            metric.log('exps', len(task_ids))
            unique_ids, cnts = np.unique(task_ids, return_counts=True)
            metric.log_many('expb', unique_ids, cnts)

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        Compute the loss through the model's ``forward_loss`` and log batch statistics.

        Strips the batch metadata with ``remove_meta_data`` (mutating ``inputs``),
        runs ``forward_loss`` on the underlying model, and forwards the outputs to
        ``_log_step``. When ``return_outputs`` is True (the call made from
        ``prediction_step``), the statistics go to the eval metric logger.

        Args:
            model: the model, possibly wrapped in ``nn.DataParallel`` /
                ``DistributedDataParallel`` (or another wrapper exposing ``.module``
                under torch.distributed).
            inputs: batch dict including the metadata keys removed by
                ``remove_meta_data``.
            return_outputs: also return the raw ``forward_loss`` outputs.

        Returns:
            ``outputs['loss']``, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        task_ids = self.remove_meta_data(inputs)

        # Unwrap only when the model really is a wrapper. During evaluation the
        # Trainer passes the bare model even when torch.distributed is initialized,
        # and `model.module` then raised AttributeError. `nn.DataParallel` (multi-GPU
        # without torch.distributed) was never unwrapped, and it has no `forward_loss`.
        # NOTE(review): calling the wrapped module directly bypasses the wrapper's
        # forward(); confirm DDP gradient synchronisation behaves as intended.
        core = model
        is_parallel_wrapper = isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel))
        if is_parallel_wrapper or (dist.is_initialized() and hasattr(model, 'module')):
            core = model.module
        outputs = core.forward_loss(**inputs)
        loss = outputs['loss']

        self._log_step(model, inputs, outputs, is_valid=return_outputs, task_ids=task_ids)

        return (loss, outputs) if return_outputs else loss

    def log(self, logs: Dict[str, float]) -> None:
        """
        Record `logs` in the trainer state and hand them to the callbacks.

        `logs` is mutated: an ``epoch`` entry (rounded to 2 decimals) is added when
        the state knows the current epoch. The entry appended to
        ``state.log_history`` is a copy of `logs` plus ``step``. It is taken before
        the metric logger's ``update_logs`` runs. Only the dict returned by
        ``update_logs`` is passed to ``callback_handler.on_log``.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        current_epoch = self.state.epoch
        if current_epoch is not None:
            logs["epoch"] = round(current_epoch, 2)

        history_entry = dict(logs)
        history_entry["step"] = self.state.global_step
        self.state.log_history.append(history_entry)

        # Any key containing "eval" (e.g. "eval_loss") marks an evaluation log, whose
        # custom metrics come from the eval logger instead of the train logger.
        is_eval_log = any("eval" in key for key in logs)
        metric_logger = self.eval_metric if is_eval_log else self.train_metric

        enriched_logs = metric_logger.update_logs(logs)
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, enriched_logs)
