import dataclasses
import datetime
import logging
import humanfriendly
from typing import Optional

import torch

from rau.tasks.common.data import Dataset
from rau.tasks.common.training_loop import (
    add_training_loop_arguments as common_add_training_loop_arguments,
    get_training_loop_kwargs as common_get_training_loop_kwargs,
    TrainingLoop,
    get_random_generator_and_seed,
    DictScoreAccumulator,
    OutOfCUDAMemoryError
)
from rau.tools.torch.model_interface import ModelInterface
from rau.tools.torch.saver import ModelSaver
from rau.tools.logging import Logger
from rau.training.early_stopping import UpdatesWithoutImprovement
from rau.tools.torch.profile import reset_memory_profiler, get_peak_memory
from rau.tools.ticker import TimedTicker

from .batching import group_into_batches
from .data import VocabularyContainer
from .model_interface import ModelInput

def add_training_loop_arguments(parser):
    """
    Register the training loop's command-line options on ``parser``.

    The common training loop options are added first, with a task-specific
    help message for the per-batch token limit. Two float options, each
    defaulting to 1.0, are then added to the same argument group: one
    coefficient for the language modeling loss term and one for the next
    symbols loss term.

    :param parser: The argument parser to which the options are added.
    """
    batch_limit_help = (
        'The maximum number of tokens allowed per batch. This puts a limit on '
        'the number of elements included in a single batch tensor, including '
        'BOS, EOS, and padding tokens. If a single example exceeds the limit, '
        'it is not discarded, but included in a batch by itself.'
    )
    group = common_add_training_loop_arguments(
        parser,
        max_tokens_per_batch_help=batch_limit_help
    )
    # Both auxiliary loss coefficients share the same type and default.
    for flag in (
        '--language-modeling-loss-coefficient',
        '--next-symbols-loss-coefficient'
    ):
        group.add_argument(flag, type=float, default=1.0)

def get_training_loop_kwargs(parser, args):
    """
    Build the keyword arguments for the training loop from parsed arguments.

    :param parser: The argument parser, passed through to the common helper.
    :param args: The parsed command-line arguments.
    :return: The common training loop keyword arguments, extended with the
        two loss coefficients read from ``args``.
    """
    kwargs = common_get_training_loop_kwargs(parser, args)
    # Copy the task-specific coefficients over under their attribute names.
    kwargs['language_modeling_loss_coefficient'] = getattr(
        args, 'language_modeling_loss_coefficient')
    kwargs['next_symbols_loss_coefficient'] = getattr(
        args, 'next_symbols_loss_coefficient')
    return kwargs

# A single unprepared example. The first element is the sequence of input
# tokens (the code in this file only ever reads `example[0]`).
# NOTE(review): the second element is presumably (is-positive label, optional
# next-symbols annotation) -- confirm against the data loading code.
Example = tuple[torch.Tensor, tuple[bool, Optional[torch.Tensor]]]
# A batch after preparation: the model input, plus a 5-tuple of targets that
# is unpacked by `get_prepared_batch_info` and `get_loss_terms` as
#   (expected recognition output,
#    expected language modeling output,
#    expected next symbols output,
#    next symbols padding mask,
#    positive output lengths).
# The language modeling and next symbols targets may be None.
# NOTE(review): the last two entries are marked Optional on the assumption
# that they are absent when no auxiliary targets exist -- confirm.
PreparedBatch = tuple[
    ModelInput,
    tuple[
        torch.Tensor,
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor],
        Optional[torch.Tensor]
    ]
]

@dataclasses.dataclass
class RecognitionTrainingLoop(TrainingLoop[
    Example,
    PreparedBatch,
    VocabularyContainer
]):
    """
    Training loop for a recognition model with optional auxiliary loss terms.

    Validation uses the recognition cross-entropy (lower is better). When the
    model also produces language modeling and/or next symbols outputs, their
    cross-entropy terms are tagged with the coefficients stored on this
    object.
    """

    # Coefficient attached to the language modeling cross-entropy term.
    language_modeling_loss_coefficient: float
    # Coefficient attached to the next symbols cross-entropy term.
    next_symbols_loss_coefficient: float

    def run(self,
        saver: ModelSaver,
        model_interface: ModelInterface,
        training_data: list[Example],
        validation_data: list[Example],
        vocabulary: VocabularyContainer,
        console_logger: logging.Logger,
        event_logger: Logger
    ) -> None:
        """
        Train the model held by ``saver``.

        The model is evaluated and saved once before any training (checkpoint
        #0), then evaluated again every ``examples_per_checkpoint`` training
        examples. The parameters are saved whenever the validation metric
        improves, the learning rate is reduced on plateaus, and training stops
        early once the early stopping criterion fires or after ``max_epochs``
        epochs.

        NOTE: When this function returns, the model's parameters will be those of
        the *last* epoch, not necessarily the *best* epoch.

        :param saver: Holds the model and saves its parameters.
        :param model_interface: Used to get the device and to run the model.
        :param training_data: Training examples. Shuffled in place every epoch.
        :param validation_data: Validation examples. Batched once up front.
        :param vocabulary: Only used to print a batch that ran out of memory.
        :param console_logger: Receives human-readable progress messages.
        :param event_logger: Receives structured events (``start_training``,
            ``checkpoint``, ``epoch``, ``train``).
        :raises ValueError: If the learning rate patience is less than 1.
        """
        device = model_interface.get_device(None)
        do_profile_memory = device.type == 'cuda'
        random_shuffling_generator, random_shuffling_seed = \
            get_random_generator_and_seed(self.random_shuffling_seed)
        console_logger.info(f'random shuffling seed: {random_shuffling_seed}')
        OptimizerClass = getattr(torch.optim, self.optimizer)
        optimizer = OptimizerClass(
            saver.model.parameters(),
            lr=self.initial_learning_rate
        )
        validation_metric = self.get_validation_metric_name()
        validation_metric_mode = self.get_validation_metric_mode()
        early_stopping = UpdatesWithoutImprovement(
            validation_metric_mode,
            patience=self.early_stopping_patience
        )
        if self.learning_rate_patience < 1:
            raise ValueError('learning rate patience must be at least 1')
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=validation_metric_mode,
            # According to PyTorch, a patience of 0 means we reduce the LR as
            # soon as performance does not improve, and a patience of 1 means
            # we wait one checkpoint. We subtract 1 so that the patience means
            # the number of epochs without improvement before reducing the LR.
            patience=self.learning_rate_patience - 1,
            factor=self.learning_rate_decay_factor,
            threshold=0.0
        )
        console_logger.info(f'training examples: {len(training_data)}')
        num_validation_examples = len(validation_data)
        console_logger.info(f'validation examples: {num_validation_examples}')
        # The validation batches never change, so build them only once.
        validation_batches = list(self.generate_batches(
            validation_data,
            self.max_tokens_per_batch
        ))
        console_logger.info(f'validation batches: {len(validation_batches)}')
        model_interface.on_before_process_pairs(
            saver,
            [training_data, validation_data]
        )
        # Only the batched form of the validation data is needed from here on.
        del validation_data
        event_logger.log('start_training', dict(
            num_training_examples=len(training_data),
            num_validation_examples=num_validation_examples,
            num_validation_batches=len(validation_batches),
            max_epochs=self.max_epochs,
            random_shuffling_seed=random_shuffling_seed,
            optimizer=self.optimizer,
            initial_learning_rate=self.initial_learning_rate,
            label_smoothing_factor=self.label_smoothing_factor,
            early_stopping_patience=self.early_stopping_patience,
            learning_rate_patience=self.learning_rate_patience,
            learning_rate_decay_factor=self.learning_rate_decay_factor,
            gradient_clipping_threshold=self.gradient_clipping_threshold,
            examples_per_checkpoint=self.examples_per_checkpoint
        ))
        epoch_no = 0
        examples_since_checkpoint = 0
        checkpoint_no = 0
        total_start_time = datetime.datetime.now()
        # Initial checkpoint before any training.
        console_logger.info(f'  checkpoint #{checkpoint_no}')
        validation_scores = self.evaluate(
            saver.model,
            model_interface,
            validation_batches
        )
        console_logger.info('    validation scores:')
        for key, value in validation_scores.items():
            console_logger.info(f'      {key}: {value:.2f}')
        validation_score = validation_scores[validation_metric]
        # Save the model parameters. The initial parameters are always saved,
        # so there is always a best checkpoint by the time training ends.
        is_best, _ = early_stopping.update(validation_score)
        console_logger.info('    saving parameters')
        saver.save()
        best_validation_scores = validation_scores
        # NOTE(review): the initial checkpoint and the first checkpoint made
        # during training both record `best_checkpoint_no == 0` and
        # `best_epoch_no == 0`, so the final summary cannot tell them apart.
        # Left unchanged because consumers of the event log may rely on the
        # current values.
        best_checkpoint_no = checkpoint_no
        best_epoch_no = 0
        event_logger.log('checkpoint', dict(
            is_best=is_best,
            scores=validation_scores
        ))
        # Training
        for _ in range(self.max_epochs):
            epoch_start_time = datetime.datetime.now()
            console_logger.info(f'epoch #{epoch_no + 1}')
            # Shuffle the examples before batching and the batches afterwards.
            random_shuffling_generator.shuffle(training_data)
            batches = list(self.generate_batches(
                training_data,
                self.max_tokens_per_batch
            ))
            random_shuffling_generator.shuffle(batches)
            epoch_loss = DictScoreAccumulator()
            if self.show_progress:
                progress_loss = DictScoreAccumulator()
                progress_num_examples = 0
                progress_start_time = datetime.datetime.now()
                ticker = TimedTicker(len(batches), 1)
            if do_profile_memory:
                reset_memory_profiler(device)
            should_stop = False
            for batch_no, batch in enumerate(batches):
                try:
                    loss_numerator, loss_denominator, loss_terms = self.run_parameter_update(
                        saver,
                        model_interface,
                        optimizer,
                        batch
                    )
                    # Track the total loss alongside the individual terms.
                    loss_terms['loss'] = (loss_numerator, loss_denominator)
                    epoch_loss.update(loss_terms)
                    if self.show_progress:
                        progress_loss.update(loss_terms)
                except OutOfCUDAMemoryError as e:
                    # Report the offending batch, then let the error propagate.
                    self.handle_out_of_cuda_memory(
                        vocabulary,
                        batch,
                        e.info,
                        device,
                        console_logger,
                        event_logger
                    )
                    raise
                batch_size = len(batch)
                if self.show_progress:
                    progress_num_examples += batch_size
                    ticker.progress = batch_no + 1
                    if ticker.tick():
                        progress_loss_dict = progress_loss.get_value()
                        progress_loss_value = progress_loss_dict.pop('loss')
                        progress_duration = datetime.datetime.now() - progress_start_time
                        progress_examples_per_second = progress_num_examples / progress_duration.total_seconds()
                        progress_parts = [
                            f'{ticker.int_percent}%',
                            f'loss: {progress_loss_value:.2f}',
                            f'examples/s: {progress_examples_per_second:.2f}'
                        ]
                        for key, value in progress_loss_dict.items():
                            progress_parts.append(f'{key}: {value:.2f}')
                        console_logger.info(f'  {" | ".join(progress_parts)}')
                        # Start a fresh measurement window.
                        progress_loss = DictScoreAccumulator()
                        progress_start_time = datetime.datetime.now()
                        progress_num_examples = 0
                examples_since_checkpoint += batch_size
                if examples_since_checkpoint >= self.examples_per_checkpoint:
                    console_logger.info(f'  checkpoint #{checkpoint_no + 1}')
                    validation_scores = self.evaluate(
                        saver.model,
                        model_interface,
                        validation_batches
                    )
                    console_logger.info('    validation scores:')
                    for key, value in validation_scores.items():
                        console_logger.info(f'      {key}: {value:.2f}')
                    validation_score = validation_scores[validation_metric]
                    # Update the learning rate.
                    lr_scheduler.step(validation_score)
                    # Show the current learning rate.
                    curr_learning_rate = lr_scheduler.get_last_lr()[0]
                    console_logger.info(f'    learning rate: {curr_learning_rate}')
                    # Decide whether to save the model parameters and whether to
                    # stop early.
                    is_best, should_stop = early_stopping.update(validation_score)
                    if is_best:
                        console_logger.info('    saving parameters')
                        saver.save()
                        best_validation_scores = validation_scores
                        best_checkpoint_no = checkpoint_no
                        best_epoch_no = epoch_no
                    event_logger.log('checkpoint', dict(
                        is_best=is_best,
                        scores=validation_scores
                    ))
                    # Reset the count of examples seen since the last checkpoint.
                    # If `examples_since_checkpoint` is not exactly equal to
                    # `self.examples_per_checkpoint` after `batch_size` is
                    # added to it, but is greater than it, include the extra
                    # examples in the updated count.
                    examples_since_checkpoint %= self.examples_per_checkpoint
                    checkpoint_no += 1
                    if should_stop:
                        console_logger.info('  stopping early')
                        break
            if should_stop:
                # An epoch cut short by early stopping is neither summarized
                # nor counted as completed.
                break
            epoch_loss_dict = epoch_loss.get_value()
            epoch_loss_value = epoch_loss_dict.pop('loss')
            epoch_duration = datetime.datetime.now() - epoch_start_time
            epoch_duration_seconds = epoch_duration.total_seconds()
            console_logger.info(f'  epoch loss: {epoch_loss_value:.2f}')
            if epoch_loss_dict:
                console_logger.info('  epoch scores:')
                for key, value in epoch_loss_dict.items():
                    console_logger.info(f'    {key}: {value:.2f}')
            console_logger.info(f'  epoch duration: {epoch_duration}')
            epoch_examples_per_second = len(training_data) / epoch_duration_seconds
            console_logger.info(f'  examples/s: {epoch_examples_per_second:.2f}')
            if do_profile_memory:
                peak_memory = get_peak_memory(device)
                console_logger.info(f'  peak CUDA memory: {humanfriendly.format_size(peak_memory)}')
            else:
                peak_memory = None
            event_logger.log('epoch', dict(
                loss=epoch_loss_value,
                scores=epoch_loss_dict,
                duration=epoch_duration_seconds,
                peak_memory=peak_memory,
                num_training_batches=len(batches)
            ))
            epoch_no += 1
        total_duration = datetime.datetime.now() - total_start_time
        console_logger.info('best validation scores:')
        for key, value in best_validation_scores.items():
            console_logger.info(f'  {key}: {value:.2f}')
        console_logger.info(f'completed epochs: {epoch_no}')
        console_logger.info(f'best epoch: #{best_epoch_no+1}')
        console_logger.info(f'completed checkpoints: {checkpoint_no}')
        console_logger.info(f'best checkpoint: #{best_checkpoint_no+1}')
        console_logger.info(f'checkpoints since improvement: {early_stopping.updates_since_improvement}')
        console_logger.info(f'total training duration: {total_duration}')
        event_logger.log('train', dict(
            best_validation_scores=best_validation_scores,
            num_epochs=epoch_no,
            best_epoch=best_epoch_no,
            num_checkpoints=checkpoint_no,
            best_checkpoint=best_checkpoint_no,
            checkpoints_since_improvement=early_stopping.updates_since_improvement,
            duration=total_duration.total_seconds()
        ))

    def get_validation_metric_name(self):
        """Return the key of the validation score used for model selection."""
        return 'recognition_cross_entropy'

    def get_validation_metric_mode(self):
        """Return ``'min'``: a lower validation metric is better."""
        return 'min'

    def generate_batches(self, examples, max_tokens):
        """Group ``examples`` into batches using the module-level helper."""
        return generate_batches(examples, max_tokens)

    def get_prepared_batch_info(self, prepared_batch):
        """
        Summarize the tensor sizes of a prepared batch.

        :param prepared_batch: A pair of the model input and a 5-tuple of
            targets.
        :return: A dict with the size of the input sequence and of each
            expected output tensor. The language modeling and next symbols
            entries are ``None`` when the corresponding tensor is ``None``.
        """
        (
            model_input,
            (
                recognition_expected_tensor,
                language_modeling_expected_tensor,
                next_symbols_expected_tensor,
                next_symbols_padding_mask,
                positive_output_lengths
            )
        ) = prepared_batch
        return dict(
            input_size=tuple(model_input.input_sequence.size()),
            recognition_output_size=tuple(recognition_expected_tensor.size()),
            language_modeling_output_size=(
                tuple(language_modeling_expected_tensor.size())
                if language_modeling_expected_tensor is not None
                else None
            ),
            next_symbols_output_size=(
                tuple(next_symbols_expected_tensor.size())
                if next_symbols_expected_tensor is not None
                else None
            )
        )

    def log_failed_batch(self, vocabulary, batch, info, console_logger, event_logger):
        """
        Print a description of a batch that could not be processed.

        :param vocabulary: Its ``input_vocab`` converts tokens to strings.
        :param batch: The examples in the failed batch. Only the first element
            of each example (the input sequence) is used.
        :param info: Tensor sizes as returned by
            :py:meth:`get_prepared_batch_info`, or ``None`` if unavailable.
        :param console_logger: Receives the human-readable description.
        :param event_logger: Not used by this method.
        :return: A dict with the entries of ``info`` (if any) plus an
            ``examples`` entry holding the sequences as lists of token
            strings.
        """
        if info is not None:
            console_logger.info(f'  input size: {info.get("input_size")}')
            console_logger.info(f'  recognition output size: {info.get("recognition_output_size")}')
            console_logger.info(f'  language modeling output size: {info.get("language_modeling_output_size")}')
            console_logger.info(f'  next symbols output size: {info.get("next_symbols_output_size")}')
        lengths = [len(x[0]) for x in batch]
        console_logger.info(f'  tokens: {sum(lengths)}')
        console_logger.info(f"  sequence lengths: {lengths}")
        token_strs = [
            [vocabulary.input_vocab.to_string(a) for a in x[0]]
            for x in batch
        ]
        sequences_str = '\n'.join(' '.join(x) for x in token_strs)
        console_logger.info(f'  sequences:\n{sequences_str}')
        # `info` may be None (see the guard above), and unpacking None with
        # `**` would raise a TypeError, so fall back to an empty dict.
        return dict(
            **(info if info is not None else {}),
            examples=token_strs
        )

    def get_loss(self, model, model_interface, prepared_batch):
        """
        Compute the per-example training loss terms for a batch.

        Numerators are left unreduced, denominators are summed, label
        smoothing is applied, and no accuracy scores are computed.

        :return: The dict returned by :py:func:`get_loss_terms`, where the
            language modeling and next symbols entries (when present) have
            their coefficient appended as a third tuple element.
        """
        loss_terms = get_loss_terms(
            model,
            model_interface,
            prepared_batch,
            numerator_reduction='none',
            denominator_reduction='sum',
            label_smoothing_factor=self.label_smoothing_factor,
            include_accuracy=False
        )
        # Assign coefficients to the loss terms.
        if 'language_modeling_cross_entropy' in loss_terms:
            loss_terms['language_modeling_cross_entropy'] += (self.language_modeling_loss_coefficient,)
        if 'next_symbols_cross_entropy' in loss_terms:
            loss_terms['next_symbols_cross_entropy'] += (self.next_symbols_loss_coefficient,)
        return loss_terms

    def evaluate_batch(self, model, model_interface, prepared_batch):
        """
        Compute summed evaluation scores, including accuracies, for a batch.

        Label smoothing is disabled for evaluation.

        :return: A dict mapping each score name to a ``(numerator,
            denominator)`` pair, where the numerator has been converted from
            a tensor to a Python number.
        """
        result = get_loss_terms(
            model,
            model_interface,
            prepared_batch,
            numerator_reduction='sum',
            denominator_reduction='sum',
            label_smoothing_factor=0.0,
            include_accuracy=True
        )
        return { k : (n.item(), d) for k, (n, d) in result.items() }

def generate_batches(examples, max_tokens):
    """
    Group ``examples`` into batches that respect a size limit.

    :param examples: The examples to group.
    :param max_tokens: The largest allowed product of the two quantities
        that ``group_into_batches`` passes to its size predicate.
        NOTE(review): presumably the batch size and the padded sequence
        length -- confirm against ``group_into_batches``.
    :return: Whatever ``group_into_batches`` returns for these examples.
    """
    def is_within_limit(b, n):
        # A batch is acceptable if its total size stays within the limit.
        return b * n <= max_tokens
    return group_into_batches(examples, is_within_limit)

def get_loss_terms(
    model,
    model_interface,
    prepared_batch,
    numerator_reduction,
    denominator_reduction,
    label_smoothing_factor,
    include_accuracy
):
    """
    :param numerator_reduction: This can be none or sum. If none, then all
        numerators are returned as a 1-D tensor of values, with one value per
        example. If sum, then they are returned as 0-D tensors with a single
        value, as if the none version had been summed.
    :param denominator_reduction: This can be none or sum. If none, then all
        denominators are returned as a 1-D tensor of values, with one value per
        example. A value of `None` is equivalent to a tensor of all 1's. If
        sum, then they are returned as floats or ints, as if the none version
        had been summed.
    """
    (
        model_input,
        (
            expected_recognition_output,
            expected_language_modeling_output,
            expected_next_symbols_output,
            next_symbols_padding_mask,
            positive_output_lengths
        )
    ) = prepared_batch
    (
        recognition_logits,
        language_modeling_logits,
        next_symbols_logits
    ) = model_interface.get_logits(
        model,
        model_input
    )
    result = {}
    # Compute the recognition loss using binary cross-entropy.
    recognition_loss = torch.nn.functional.binary_cross_entropy_with_logits(
        recognition_logits,
        expected_recognition_output,
        reduction=numerator_reduction
    )
    match denominator_reduction:
        case 'none':
            num_examples_denominator = None
        case 'sum':
            num_examples_denominator = len(recognition_logits)
        case _:
            raise ValueError
    result['recognition_cross_entropy'] = (
        recognition_loss,
        num_examples_denominator
    )
    if include_accuracy:
        # The model accepts iff the logit is >= 0 (the probability is >= 0.5).
        recognition_predictions = recognition_logits >= 0.0
        recognition_accuracy = recognition_predictions == expected_recognition_output
        match numerator_reduction:
            case 'none':
                pass
            case 'sum':
                recognition_accuracy = torch.sum(recognition_accuracy)
            case _:
                raise ValueError
        result['recognition_accuracy'] = (
            recognition_accuracy,
            num_examples_denominator
        )
    if language_modeling_logits is not None:
        # Compute the language modeling loss using cross-entropy.
        pad_index = model_interface.output_padding_index
        language_modeling_ce = torch.nn.functional.cross_entropy(
            language_modeling_logits.permute(0, 2, 1),
            expected_language_modeling_output,
            ignore_index=pad_index,
            reduction='none',
            label_smoothing=label_smoothing_factor
        )
        # Average over timesteps.
        language_modeling_mean_ce = torch.sum(language_modeling_ce, dim=1) / positive_output_lengths
        match numerator_reduction:
            case 'sum':
                language_modeling_loss = torch.sum(language_modeling_mean_ce)
            case 'none':
                language_modeling_loss = language_modeling_mean_ce
            case _:
                raise ValueError
        match denominator_reduction:
            case 'none':
                num_positive_denominator = None
            case 'sum':
                num_positive_denominator = len(expected_language_modeling_output)
            case _:
                raise ValueError
        result['language_modeling_cross_entropy'] = (
            language_modeling_loss,
            num_positive_denominator
        )
    if next_symbols_logits is not None:
        # Compute the valid symbols loss using binary cross-entropy.
        pad_index = model_interface.output_padding_index
        next_symbols_unmasked_ce = torch.nn.functional.binary_cross_entropy_with_logits(
            next_symbols_logits,
            expected_next_symbols_output,
            reduction='none'
        )
        # Average over alphabet symbols and mask out padding positions.
        next_symbols_alphabet_mean_ce = torch.mean(
            next_symbols_unmasked_ce,
            dim=2
        ) * next_symbols_padding_mask
        # Average over timesteps.
        next_symbols_mean_ce = torch.sum(next_symbols_alphabet_mean_ce, dim=1) / positive_output_lengths
        match numerator_reduction:
            case 'sum':
                next_symbols_loss = torch.sum(next_symbols_mean_ce)
            case 'none':
                next_symbols_loss = next_symbols_mean_ce
            case _:
                raise ValueError
        match denominator_reduction:
            case 'none':
                num_positive_denominator = None
            case 'sum':
                num_positive_denominator = len(expected_next_symbols_output)
            case _:
                raise ValueError
        result['next_symbols_cross_entropy'] = (
            next_symbols_loss,
            num_positive_denominator
        )
        if include_accuracy:
            next_symbols_predictions = next_symbols_logits >= 0.0
            next_symbols_accuracy = (
                (next_symbols_predictions == expected_next_symbols_output) *
                next_symbols_padding_mask[:, :, None]
            )
            next_symbols_set_accuracy = torch.all(next_symbols_accuracy, dim=2)
            next_symbols_string_accuracy = torch.all(next_symbols_set_accuracy, dim=1)
            match numerator_reduction:
                case 'none':
                    next_symbols_symbol_accuracy = torch.sum(next_symbols_accuracy, dim=(1, 2))
                    next_symbols_set_accuracy = torch.sum(next_symbols_set_accuracy, dim=1)
                case 'sum':
                    next_symbols_symbol_accuracy = torch.sum(next_symbols_accuracy)
                    next_symbols_set_accuracy = torch.sum(next_symbols_set_accuracy)
                    next_symbols_string_accuracy = torch.sum(next_symbols_string_accuracy)
                case _:
                    raise ValueError
            vocab_size = next_symbols_logits.size(2)
            match denominator_reduction:
                case 'none':
                    num_next_symbols_denominator = positive_output_lengths * vocab_size
                    num_next_symbols_sets_denominator = positive_output_lengths
                case 'sum':
                    num_next_symbols_sets_denominator = torch.sum(positive_output_lengths).item()
                    num_next_symbols_denominator = num_next_symbols_sets_denominator * vocab_size
                case _:
                    raise ValueError
            result['next_symbols_symbol_accuracy'] = (
                next_symbols_symbol_accuracy,
                num_next_symbols_denominator
            )
            result['next_symbols_set_accuracy'] = (
                next_symbols_set_accuracy,
                num_next_symbols_sets_denominator
            )
            result['next_symbols_string_accuracy'] = (
                next_symbols_string_accuracy,
                num_positive_denominator
            )
    return result
