import json
import time
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import NamedTuple

import numpy as np
import torch
from transformers import Seq2SeqAdapterTrainer
from packaging import version
from torch import nn
from torch.utils.data import Dataset
from transformers.deepspeed import is_deepspeed_zero3_enabled

from transformers.trainer_utils import PredictionOutput, speed_metrics
import os
from .training_arguments import WrappedSeq2SeqTrainingArguments

_is_torch_generator_available = False
if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_torch_generator_available = True
    _is_native_amp_available = True
    from torch.cuda.amp import autocast


class EvalPrediction(NamedTuple):
    predictions: List[str]
    items: List[dict]


class CascadeSeq2SeqTrainerAdapter(Seq2SeqAdapterTrainer):
    def __init__(
            self,
            evaluator,
            args: WrappedSeq2SeqTrainingArguments,
            eval_examples: Optional[Dataset] = None,
            test_dataset: Optional[Dataset] = None,
            test_examples: Optional[Dataset] = None,
            ignore_pad_token_for_loss: bool = True,
            **kwargs,
    ) -> None:
        super().__init__(args=args, **kwargs)
        self.evaluator = evaluator
        self.eval_examples = eval_examples
        self.test_dataset = test_dataset
        self.test_examples = test_examples
        self.compute_metrics = self._compute_metrics
        self.ignore_pad_token_for_loss = ignore_pad_token_for_loss

    def train_all(self):
        self.origin_output_dir = self.args.output_dir
        if self.args.do_restart and self.is_world_process_zero():
            os.makedirs(os.path.join(self.origin_output_dir,'init'), exist_ok=True)
            self.model.save_all_adapters(os.path.join(self.origin_output_dir,'init'))
        torch.distributed.barrier()
        for cascade_step in range(self.args.max_cascade_steps):
            if self.args.do_restart and cascade_step>0:
                for adp in ['adapter1', 'adapter2']:
                    self.model.load_adapter(os.path.join(self.origin_output_dir,'init',adp))
            self.args.output_dir = os.path.join(self.origin_output_dir, 'cas_%d' % cascade_step)
            os.makedirs(self.args.output_dir, exist_ok=True)
            if cascade_step>0 or not self.args.start_from_first_castep:
                self.optimizer = None
                self.lr_scheduler = None
                self.train()
                self.save_state()
            else:
                max_steps = self.args.max_steps
                self.args.max_steps = 1
                state_dict = {k: v.clone() for k, v in self.model.state_dict().items()}
                self.train()
                self.args.max_steps = max_steps
                self.model.load_state_dict(state_dict, strict=False)
                del state_dict
            if cascade_step!=self.args.max_cascade_steps-1 and not self.args.do_origin:
                self.update_dataset(self.train_dataset, None, metric_key_prefix='cas_train', with_metric=False,
                                cascade_step=cascade_step)
            self.update_dataset(self.eval_dataset, self.eval_examples, metric_key_prefix='cas_eval', with_metric=True,
                                cascade_step=cascade_step)
            self.update_dataset(self.test_dataset, self.test_examples, metric_key_prefix='cas_test', with_metric=True,
                                cascade_step=cascade_step)
        self.args.output_dir = self.origin_output_dir
        self.save_state()


    def update_dataset(self, dataset, examples, metric_key_prefix, with_metric, cascade_step):
        if dataset is None:
            return
        output = self.predict(dataset, examples, metric_key_prefix=metric_key_prefix, with_metric=with_metric)
        if not self.args.do_origin:
            dataset.last_predictions = output.predictions
        if self.is_world_process_zero():
            torch.save(output.predictions, os.path.join(self.args.output_dir, f'{metric_key_prefix}_generation.pk'))
        if with_metric:
            metrics = output.metrics
            max_samples = len(dataset)
            metrics["samples"] = min(max_samples, len(dataset))
            split_name = f'{metric_key_prefix}_{cascade_step}'
            self.log_metrics(split_name, metrics)
            output_dir = self.args.output_dir
            self.args.output_dir = os.path.join(self.origin_output_dir, 'results')
            os.makedirs(self.args.output_dir, exist_ok=True)
            self.save_metrics(split_name, {split_name: metrics})
            self.args.output_dir = output_dir

    def evaluate(
            self,
            eval_dataset: Optional[Dataset] = None,
            ignore_keys: Optional[List[str]] = None,
            metric_key_prefix: str = "eval",
            max_length: Optional[int] = None,
            max_time: Optional[int] = None,
            num_beams: Optional[int] = None,
    ) -> Dict[str, float]:
        self._max_length = max_length if max_length is not None else self.args.generation_max_length
        self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
        self._max_time = max_time

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset

        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        # Temporarily disable metric computation, we will do it in the loop here.
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        try:
            output = self.evaluation_loop(
                eval_dataloader,
                description="Evaluation",
                prediction_loss_only=True,
                ignore_keys=ignore_keys,
                metric_key_prefix=metric_key_prefix,
            )
        finally:
            self.compute_metrics = compute_metrics

        self.log(output.metrics)

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output.metrics

    def predict(
            self,
            test_dataset: Optional[Dataset],
            test_examples: Optional[Dataset] = None,
            ignore_keys: Optional[List[str]] = None,
            metric_key_prefix: str = "predict",
            max_length: Optional[int] = None,
            max_time: Optional[int] = None,
            num_beams: Optional[int] = None,
            with_metric=True,
    ) -> PredictionOutput:
        self._max_length = max_length if max_length is not None else self.args.generation_max_length
        self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
        self._max_time = max_time

        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        test_dataloader = self.get_test_dataloader(test_dataset)
        start_time = time.time()

        # Temporarily disable metric computation, we will do it in the loop here.
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        try:
            output = self.evaluation_loop(
                test_dataloader,
                description="Prediction",
                ignore_keys=ignore_keys,
                metric_key_prefix=metric_key_prefix,
            )
        finally:
            self.compute_metrics = compute_metrics

        if not with_metric:
            self._memory_tracker.stop_and_update_metrics(output.metrics)
            return output
        if self.compute_metrics is not None:
            eval_preds = self._post_process_function(
                test_examples, output.predictions, metric_key_prefix)
            output.metrics.update(self.compute_metrics(eval_preds, section="test"))

        output.metrics.update(speed_metrics(metric_key_prefix, start_time, len(test_dataset)))

        for key in list(output.metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                output.metrics[f"{metric_key_prefix}_{key}"] = output.metrics.pop(key)

        self.log(output.metrics)
        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output

    def prediction_step(
            self,
            model: nn.Module,
            inputs: Dict[str, Union[torch.Tensor, Any]],
            prediction_loss_only: bool,
            ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
            labels (each being optional).
        """

        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
            )

        inputs = self._prepare_inputs(inputs)

        # XXX: adapt synced_gpus for fairscale as well
        gen_kwargs = {
            "max_length": self._max_length if self._max_length is not None else self.model.config.max_length,
            "num_beams": self._num_beams if self._num_beams is not None else self.model.config.num_beams,
            "synced_gpus": True if is_deepspeed_zero3_enabled() else False,
            "no_repeat_ngram_size": 0,  # FIXME: hard coding the no_repeat_ngram_size
        }

        if "description_input_ids" in inputs:
            gen_kwargs["description_input_ids"] = inputs["description_input_ids"]
        if "description_attention_mask" in inputs:
            gen_kwargs["description_attention_mask"] = inputs["description_attention_mask"]
        if "task_ids" in inputs:
            gen_kwargs["task_ids"] = inputs["task_ids"]
        if "last_predictions" in inputs:
            gen_kwargs["last_predictions"]=inputs["last_predictions"]

        generated_tokens = self.model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            **gen_kwargs,
        )
        # in case the batch is shorter than max length, the output should be padded
        if generated_tokens.shape[-1] < gen_kwargs["max_length"]:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])

        return (None, generated_tokens, None)

    def _post_process_function(
            self, examples: Dataset, predictions: np.ndarray, stage: str
    ) -> EvalPrediction:
        # assert isinstance(examples, Dataset)

        predictions = self.tokenizer.batch_decode(predictions, skip_special_tokens=True)

        return EvalPrediction(predictions=predictions, items=[examples[idx] for idx in range(len(predictions))])

    def _compute_metrics(self, eval_prediction: EvalPrediction, section) -> dict:
        return self.evaluator.evaluate(eval_prediction.predictions, eval_prediction.items, section)
