# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import multiprocessing

import flashy
import torch
import omegaconf
from torch import nn

from .. import models, quantization
from . import builders
from .compression import CompressionSolver
from transformers import HubertModel,  Wav2Vec2FeatureExtractor
from ..data.audio_utils import convert_audio

from ..utils.samples.manager import SampleManager
from ..utils.utils import get_pool_executor
from .compression import evaluate_audio_reconstruction

logger = logging.getLogger(__name__)


def fast_forward(wav: torch.Tensor, target_length: int) -> torch.Tensor:
    """Stretch or shrink ``wav`` along its last axis by picking samples.

    ``target_length`` positions are spread evenly (``torch.linspace``) between
    the first and the last sample and truncated to integer indices. Samples are
    therefore dropped when ``target_length`` is smaller than the input length
    (the audio is "fast-forwarded") and repeated when it is larger. No
    filtering or interpolation is applied.

    Args:
        wav: Audio tensor whose last dimension is indexed; the layout is
            expected to be ``(B, C, T)``.
        target_length: Number of samples to keep on the last dimension.

    Returns:
        A tensor with the same leading dimensions as ``wav`` and
        ``target_length`` entries on the last dimension.
    """
    last_index = wav.size(-1) - 1
    # Fractional positions are truncated toward zero by the integer cast.
    picks = torch.linspace(0, last_index, target_length).to(torch.long)
    return wav[..., picks]


class SpeechTokenizerSolver(CompressionSolver):
    def __init__(self, cfg: omegaconf.DictConfig):
        """Initialise the parent solver and sort ``cfg.losses_other`` into buckets.

        Each ``(name, weight)`` entry of ``cfg.losses_other`` is handled as follows:

        * names containing ``"penalty"`` only record their weight in
          ``other_loss_weights`` (no criterion is built for them);
        * entries with a strictly positive weight get a criterion in
          ``other_losses`` and their weight in ``other_loss_weights``;
        * every other entry gets a criterion in ``self.info_losses``, which is
          not created here (presumably by the parent constructor).

        Args:
            cfg: Solver configuration; must provide a ``losses_other`` mapping.
        """
        super().__init__(cfg)
        self.other_losses = nn.ModuleDict()
        self.other_loss_weights = {}
        for name, coeff in self.cfg.losses_other.items():
            if "penalty" in name:
                # Only the weight is needed for penalty terms.
                self.other_loss_weights[name] = coeff
                continue
            criterion = builders.get_loss(name, self.cfg)
            if coeff > 0:
                self.other_losses[name] = criterion
                self.other_loss_weights[name] = coeff
            else:
                # Non-positive weight: keep the criterion for reporting only.
                self.info_losses[name] = criterion

    def build_model(self):
        """Instantiate the compression model, its optimizer state and the semantic model.

        The compression model, optimizer and LR scheduler are built from
        ``self.cfg`` and registered with the solver (``register_stateful``,
        ``register_best_state``, ``register_ema``). A pretrained HuBERT model and
        its feature extractor are then loaded from ``cfg.semantic_model.path``.
        HuBERT is put in eval mode and moved to ``self.device``; it is neither
        registered as stateful nor handed to the optimizer.
        """
        # Trainable model with its optimizer and learning-rate schedule.
        codec = models.builders.get_compression_model(self.cfg)
        self.model = codec.to(self.device)
        trainable_params = self.model.parameters()
        self.optimizer = builders.get_optimizer(trainable_params, self.cfg.optim)
        self.lr_scheduler = builders.get_lr_scheduler(
            self.optimizer,
            self.cfg.schedule,
            self.total_updates,
        )
        self.register_stateful("model", "optimizer", "lr_scheduler")
        self.register_best_state("model")
        self.register_ema("model")

        # Pretrained semantic model used to produce the semantic targets.
        semantic_path = self.cfg.semantic_model.path
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(semantic_path)
        # The warning about ``encoder.pos_conv_embed.conv.weight_g`` emitted while
        # loading can be ignored; it is a Hugging Face bug.
        hubert = HubertModel.from_pretrained(semantic_path)
        hubert.eval()
        self.semantic_model = hubert.to(self.device)

    @torch.no_grad()
    def get_semantic_features(self, x):
        """Compute the semantic-model hidden states used as the semantic target.

        The waveform is first re-timed with ``fast_forward`` so that its length
        is scaled by ``model.frame_rate / semantic_model.frame_rate`` (plus one
        extra semantic hop, expressed in samples at the model sample rate). It is
        then passed through ``convert_audio`` (presumably resampling to the
        semantic sample rate with a single channel), the feature extractor, and
        the semantic model. Runs without gradient tracking.

        Args:
            x: Input waveform tensor; the last dimension is time. The rest of
                the layout is assumed to be ``(B, C, T)`` — TODO confirm.

        Returns:
            The mean over all hidden states when
            ``cfg.semantic_model.target_layer == 'avg'``, otherwise the hidden
            state selected by ``target_layer``.
        """
        sem_cfg = self.cfg.semantic_model
        sem_frame_rate = sem_cfg.frame_rate  # e.g. 50
        sem_sample_rate = sem_cfg.sample_rate  # e.g. 16000
        sem_hop = sem_sample_rate // sem_frame_rate  # e.g. 320
        codec_sample_rate = self.model.sample_rate

        # The re-timed length must be an exact integer number of samples.
        scaled_length = x.size(-1) * self.model.frame_rate
        assert scaled_length % sem_frame_rate == 0
        target_length = scaled_length // sem_frame_rate
        # For some reason HuBERT needs one extra hop length worth of samples.
        target_length += sem_hop * codec_sample_rate // sem_sample_rate

        retimed = fast_forward(x, target_length)
        resampled = convert_audio(retimed, codec_sample_rate, sem_sample_rate, 1)
        # Keep index 0 of the second dimension and hand NumPy arrays to the extractor.
        clips = [clip.cpu().numpy() for clip in resampled[:, 0]]
        encoded = self.feature_extractor(clips, sampling_rate=sem_sample_rate, return_tensors="pt")
        output = self.semantic_model(encoded.input_values.to(self.device), output_hidden_states=True)

        layer = sem_cfg.target_layer
        if layer == 'avg':
            return torch.stack(output.hidden_states).mean(dim=0)
        return output.hidden_states[layer]

    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
        """Perform one training or valid step on a given batch.

        The batch is moved to ``self.device`` and cloned to serve as the
        reconstruction target. Semantic targets are computed with
        ``get_semantic_features`` and the model is expected to return a
        ``(QuantizedResult, semantic_output)`` pair.

        Losses are split into two groups:

        * *balanced* losses (adversarial, feature-matching and ``aux_losses``),
          which are backpropagated through ``self.balancer``;
        * *other* losses (the quantizer penalty and ``self.other_losses``),
          which are weighted by ``self.other_loss_weights`` and backpropagated
          directly.

        When training, the adversaries are additionally updated with probability
        ``1 / cfg.adversarial.every`` (drawn from ``self.rng``), and an optimizer
        step (with optional gradient clipping and LR scheduling) is performed.
        Losses in ``self.info_losses`` are always evaluated without gradients
        and only reported.

        Args:
            idx: Index of the step; not used by this method.
            batch: Input tensor, moved to ``self.device`` before use.
            metrics: Dictionary updated in place with all computed values.

        Returns:
            The same ``metrics`` dictionary that was passed in.
        """
        x = batch.to(self.device)
        y = x.clone()

        # Semantic targets are computed from the clean input.
        semantic_target = self.get_semantic_features(y)

        qres, semantic_output = self.model(x)
        assert isinstance(qres, quantization.QuantizedResult)
        y_pred = qres.x
        # Log bandwidth in kb/s
        metrics["bandwidth"] = qres.bandwidth.mean()

        if self.is_training:
            d_losses: dict = {}
            # Train the adversaries only on a random fraction of the steps:
            # the draw succeeds with probability 1 / cfg.adversarial.every.
            if (
                len(self.adv_losses) > 0
                and torch.rand(1, generator=self.rng).item()
                <= 1 / self.cfg.adversarial.every
            ):
                for adv_name, adversary in self.adv_losses.items():
                    disc_loss = adversary.train_adv(y_pred, y)
                    d_losses[f"d_{adv_name}"] = disc_loss
                metrics["d_loss"] = torch.sum(torch.stack(list(d_losses.values())))
            metrics.update(d_losses)

        # Losses handed to the balancer vs. losses backpropagated directly.
        balanced_losses: dict = {}
        other_losses: dict = {}

        # penalty from quantization
        if qres.penalty is not None and qres.penalty.requires_grad:
            other_losses["penalty"] = qres.penalty  # penalty term from the quantizer

        # adversarial losses
        for adv_name, adversary in self.adv_losses.items():
            adv_loss, feat_loss = adversary(y_pred, y)
            balanced_losses[f"adv_{adv_name}"] = adv_loss
            balanced_losses[f"feat_{adv_name}"] = feat_loss

        # auxiliary losses
        for loss_name, criterion in self.aux_losses.items():
            loss = criterion(y_pred, y)
            balanced_losses[loss_name] = loss

        # other losses
        for loss_name, criterion in self.other_losses.items():
            if "semantic" in loss_name:
                loss = criterion(semantic_output, semantic_target)
                # The target differs from the waveform target, so this is kept
                # among the "other" losses.
                other_losses[loss_name] = loss
            else:
                loss = criterion(y_pred, y)
                other_losses[loss_name] = loss

        # weighted losses
        metrics.update(balanced_losses)
        metrics.update(other_losses)
        metrics.update(qres.metrics)

        if self.is_training:
            # backprop losses that are not handled by balancer
            other_loss = torch.tensor(0.0, device=self.device)
            # NOTE(review): every key of other_losses, including "penalty", must
            # have an entry in self.other_loss_weights or this lookup raises KeyError.
            for key in other_losses:
                other_loss += self.other_loss_weights[key] * other_losses[key]
            if other_loss.requires_grad:
                # The graph is retained because the balancer backward below
                # goes through the same model outputs.
                other_loss.backward(retain_graph=True)
                # L2 norm of the gradients produced by the "other" losses alone.
                ratio1 = sum(
                    p.grad.data.norm(p=2).pow(2)
                    for p in self.model.parameters()
                    if p.grad is not None
                )
                assert isinstance(ratio1, torch.Tensor)
                metrics["ratio1"] = ratio1.sqrt()

            # balancer losses backward, returns effective training loss
            # with effective weights at the current batch.
            metrics["g_loss"] = self.balancer.backward(balanced_losses, y_pred)
            # add metrics corresponding to weight ratios
            metrics.update(self.balancer.metrics)
            # Gradients are not zeroed between the two backward passes, so this
            # norm is computed on the accumulated gradients.
            ratio2 = sum(
                p.grad.data.norm(p=2).pow(2)
                for p in self.model.parameters()
                if p.grad is not None
            )
            assert isinstance(ratio2, torch.Tensor)
            metrics["ratio2"] = ratio2.sqrt()

            # optim
            flashy.distrib.sync_model(self.model)
            if self.cfg.optim.max_norm:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.cfg.optim.max_norm
                )
            self.optimizer.step()
            if self.lr_scheduler:
                self.lr_scheduler.step()
                metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
            self.optimizer.zero_grad()

        # informative losses only
        info_losses: dict = {}
        with torch.no_grad():
            for loss_name, criterion in self.info_losses.items():
                if "semantic" in loss_name:
                    # Semantic criteria compare against the semantic target.
                    loss = criterion(semantic_output, semantic_target)
                    info_losses[loss_name] = loss
                else:
                    loss = criterion(y_pred, y)
                    info_losses[loss_name] = loss

        metrics.update(info_losses)

        # aggregated GAN losses: this is useful to report adv and feat across different adversarial loss setups
        adv_losses = [
            loss for loss_name, loss in metrics.items() if loss_name.startswith("adv")
        ]
        if len(adv_losses) > 0:
            metrics["adv"] = torch.sum(torch.stack(adv_losses))
        feat_losses = [
            loss for loss_name, loss in metrics.items() if loss_name.startswith("feat")
        ]
        if len(feat_losses) > 0:
            metrics["feat"] = torch.sum(torch.stack(feat_losses))

        return metrics

    def evaluate(self):
        """Run the evaluate stage: reconstruct audio and score the reconstructions.

        Every batch of the ``"evaluate"`` dataloader is passed through the model
        without gradients. The reconstruction and the reference are moved to CPU
        and submitted to a pool executor (created with a ``"spawn"``
        multiprocessing context) that runs ``evaluate_audio_reconstruction``.
        Results are folded into a running average as they complete and finally
        averaged across distributed workers.

        Returns:
            The averaged metrics returned by ``flashy.distrib.average_metrics``.
        """
        self.model.eval()
        stage = str(self.current_stage)

        eval_loader = self.dataloaders["evaluate"]
        num_batches = len(eval_loader)
        inference_progress = self.log_progress(
            f"{stage} inference",
            eval_loader,
            total=num_batches,
            updates=self.log_updates,
        )
        average = flashy.averager()

        futures = []
        spawn_ctx = multiprocessing.get_context("spawn")
        with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=spawn_ctx) as pool:
            # Inference pass: scoring jobs are queued while the model keeps running.
            for batch in inference_progress:
                x = batch.to(self.device)
                with torch.no_grad():
                    qres, _ = self.model(x)

                estimate = qres.x.cpu()
                target = batch.cpu()  # should already be on CPU but just in case
                future = pool.submit(evaluate_audio_reconstruction, estimate, target, self.cfg)
                futures.append(future)

            # Collection pass: fold each finished job into the running average.
            metrics_progress = self.log_progress(
                f"{stage} metrics",
                futures,
                updates=self.log_updates,
            )
            for future in metrics_progress:
                metrics = average(future.result())

        metrics = flashy.distrib.average_metrics(metrics, num_batches)
        return metrics

    def generate(self):
        """Run the generate stage: reconstruct reference audio and store the samples.

        Each batch of the ``"generate"`` dataloader is unpacked as a
        ``(reference, extra)`` pair; only the reference is used. It is passed
        through the model without gradients, and the reconstruction is handed to
        a ``SampleManager`` together with the reference for the current epoch.
        A distributed barrier is hit once all batches are processed.
        """
        self.model.eval()
        samples = SampleManager(self.xp, map_reference_to_sample_id=True)
        stage = str(self.current_stage)

        gen_loader = self.dataloaders["generate"]
        progress = self.log_progress(
            stage,
            gen_loader,
            total=len(gen_loader),
            updates=self.log_updates,
        )

        for reference, _ in progress:
            with torch.no_grad():
                qres, _ = self.model(reference.to(self.device))
            assert isinstance(qres, quantization.QuantizedResult)

            # Samples are stored from CPU copies of the reference and the estimate.
            samples.add_samples(
                qres.x.cpu(),
                self.epoch,
                ground_truth_wavs=reference.cpu(),
            )

        flashy.distrib.barrier()