import matplotlib

matplotlib.use('Agg')

from utils import audio


from tasks.base_task import data_loader

import os
from tqdm import tqdm
from utils.hparams import hparams
import utils.plot as plot
from utils.plot import np_now, save_spectrogram, weight_to_figure_with_mask
from modules.fastspeech.video_tts import VideoTts
from tasks.tts.videotts.dataset import get_tts_dataloader, get_test_dataloader
from tasks.tts.transformer_tts import TransformerTtsTask
import torch
import torch.optim
import torch.utils.data
import torch.nn.functional as F
import utils
import torch.distributions
import numpy as np
from pathlib import Path
from utils.dsp import save_wav, reconstruct_waveform
from utils.ffmpeg_utils import combine_audio_video
from vocoders.base_vocoder import get_vocoder_cls


class VideoTtsTask(TransformerTtsTask):
    def __init__(self):
        """Initialise the task and parse the weighted mel-loss specification.

        ``hparams['mel_loss']`` is read as a ``|``-separated list whose entries
        are either ``name`` (weight 1.0) or ``name:weight``. Empty entries are
        skipped. The result is stored in ``self.loss_and_lambda`` as
        ``{name: weight}``.
        """
        super(VideoTtsTask, self).__init__()
        self.mse_loss_fn = torch.nn.MSELoss()
        loss_specs = hparams['mel_loss'].split("|")
        self.loss_and_lambda = {}
        for spec in loss_specs:
            if spec == '':
                continue
            weight = 1.0
            if ':' in spec:
                spec, weight_text = spec.split(":")
                weight = float(weight_text)
            self.loss_and_lambda[spec] = weight
        print("| Mel losses:", self.loss_and_lambda)
        # The vocoder is only instantiated when first needed.
        self.vocoder = None

    def build_tts_model(self):
        """Instantiate ``VideoTts`` and store it on ``self.model``."""
        network = VideoTts()
        self.model = network

    def build_model(self):
        """Build the network, optionally restore a checkpoint, and return it.

        A checkpoint is loaded (strictly) only when ``hparams['load_ckpt']`` is
        not the empty string. The architecture is printed before returning.
        """
        self.build_tts_model()
        ckpt_path = hparams['load_ckpt']
        if ckpt_path != '':
            self.load_ckpt(ckpt_path, strict=True)
        utils.print_arch(self.model)
        return self.model

    def _training_step(self, sample, batch_idx, _):
        loss_output = self.run_model(self.model, sample, is_training=True)
        total_loss = sum([v for v in loss_output.values() if isinstance(v, torch.Tensor) and v.requires_grad])
        loss_output['batch_size'] = sample['tokens'].shape[0]  # sample['txt_tokens'].size()[0]
        return total_loss, loss_output

    def validation_step(self, sample, batch_idx):
        """Evaluate one validation batch.

        Returns:
            The dict ``{'losses', 'total_loss', 'nsamples'}`` after being passed
            through ``utils.tensors_to_scalars``. ``total_loss`` is the sum of
            all loss entries and ``nsamples`` is the batch size.
        """
        losses, _ = self.run_model(self.model, sample, return_output=True)
        outputs = {
            'losses': losses,
            'total_loss': sum(losses.values()),
            'nsamples': sample['tokens'].shape[0],
        }
        return utils.tensors_to_scalars(outputs)

    def _validation_end(self, outputs):
        """Aggregate per-batch validation results into sample-weighted averages.

        Args:
            outputs: iterable of dicts as produced by ``validation_step``.

        Returns:
            ``{loss_name: average rounded to 4 decimals}``, always including
            ``'total_loss'``.
        """
        meters = {'total_loss': utils.AvgrageMeter()}
        for step_output in outputs:
            n_samples = step_output['nsamples']
            for loss_name, loss_value in step_output['losses'].items():
                meter = meters.get(loss_name)
                if meter is None:
                    meter = meters[loss_name] = utils.AvgrageMeter()
                meter.update(loss_value, n_samples)
            meters['total_loss'].update(step_output['total_loss'], n_samples)
        return {loss_name: round(meter.avg, 4) for loss_name, meter in meters.items()}

    def run_model(self, model: VideoTts, sample, return_output=False, is_training=False):
        """Run a forward pass (``infer=False``), collect losses and dump debug artifacts.

        Args:
            model: network invoked with the batch tensors taken from ``sample``.
            sample: batch dict. ``tokens``, ``token_lens`` and ``ids`` are
                required; all other entries are fetched with ``.get`` and may
                be ``None``.
            return_output: when true, the raw model output is returned as well.
            is_training: selects the ``train`` / ``val`` artifact directory and
                which tracked sample id (``self.train_att_id`` or
                ``self.val_att_id``) gets dumped.

        Returns:
            The ``losses`` dict, or ``(losses, output)`` if ``return_output``.

        Side effects:
            Writes attention / mel plots (and, once ``global_step`` exceeds
            10000, wav and mp4 files) under
            ``<work_dir>/model_outs/<train|val|validate>``. With
            ``hparams['validate']`` every batch item is dumped, otherwise only
            the tracked sample id when it is part of the batch.
        """
        tokens = sample['tokens']
        token_lens = sample['token_lens']
        mel = sample.get('mels')  # (B, T, H)
        ids = sample['ids']
        mel_lens = sample.get('mel_lens')
        output = model(tokens, sample.get('imgs'), token_lens, sample.get('img_lens'),
                       sample.get('pitches'), mel_lens, infer=False,
                       spk_img=sample.get('spk_img'), f0=sample.get('f0s'), uv=sample.get('uvs'))

        losses = {}
        self.add_mel_loss(output['mel_out'], mel, losses)

        if hparams['use_align_branch']:
            self.add_mel_loss(output['align_mel_out'], mel, losses, postfix='-align-branch',
                              local_lbd=hparams['align_branch_factor'])

        if hparams['use_diag_loss']:
            losses['diag_loss'] = output['diag_loss'] * hparams['diag_loss_factor']

        if output['entropy_loss'] is not None:
            losses['entropy_loss'] = output['entropy_loss'] * hparams['entropy_loss_factor']

        diag_mask = output['diagonal_mask']
        mel_out = output['mel_out']  # (B, T, H)
        attention = output['attn_map']

        validate_mode = hparams['validate']
        if validate_mode:
            assert not is_training, f"is_training = {is_training}"
            split = 'validate'
        else:
            split = 'train' if is_training else 'val'

        # Create the plot directories of the split that is actually written to.
        save_dir = Path(os.path.join(hparams['work_dir'], 'model_outs', split))
        os.makedirs(save_dir / 'tts_attention', exist_ok=True)
        os.makedirs(save_dir / 'tts_mel_plot', exist_ok=True)

        if validate_mode:
            for idx in tqdm(range(tokens.shape[0])):
                self._dump_tts_sample(ids[idx], idx, attention, diag_mask, mel_out, mel_lens, save_dir)
        else:
            att_id = self.train_att_id if is_training else self.val_att_id
            if att_id in ids:
                self._dump_tts_sample(att_id, ids.index(att_id), attention, diag_mask,
                                      mel_out, mel_lens, save_dir)

        if hparams['use_pitch_embed']:
            self.add_pitch_loss(output, sample, losses)
        if hparams['use_energy_embed']:
            self.add_energy_loss(output['energy_pred'], sample['energy'], losses)

        if not return_output:
            return losses
        return losses, output

    def _dump_tts_sample(self, att_id, idx, attention, diag_mask, mel_out, mel_lens, save_dir):
        """Write the debug artifacts of batch item ``idx`` below ``save_dir``."""
        step = self.global_step
        save_name = f'{att_id}_{step}_{hparams["vocoder"]}'
        weight_to_figure_with_mask(attention[idx], diag_mask[idx], plots=None,
                                   path=save_dir / 'tts_attention' / f'{save_name}.png')
        save_spectrogram(mel_out[idx], save_dir / 'tts_mel_plot' / save_name)
        if step > 10_000:
            # Audio/video export only happens after 10k steps.
            wav_path = save_dir / 'tts_wav' / f'{save_name}.wav'
            self.mel2wav(mel_out[idx][:mel_lens[idx]].T, wav_path)
            # Sample ids are split on the last '-' into episode and cut ids.
            ep_id, cut_id = att_id.rsplit('-', 1)
            video_path = Path(hparams['sent_video_path']) / f'{ep_id}' / f'{cut_id}.mp4'
            combine_audio_video(video_path, wav_path, save_dir / 'tts_mp4' / f'{save_name}.mp4')

    ############
    # Dataloader
    ############
    @data_loader
    def train_dataloader(self):
        """Build the training dataloader and remember its tracked sample id."""
        loader, tracked_id = get_tts_dataloader(hparams, prefix='train', use_ddp=self.trainer.use_ddp)
        self.train_att_id = tracked_id
        return loader

    @data_loader
    def val_dataloader(self):
        """Build the validation dataloader and remember its tracked sample id."""
        loader, tracked_id = get_tts_dataloader(hparams, prefix='val', use_ddp=self.trainer.use_ddp)
        self.val_att_id = tracked_id
        return loader

    @data_loader
    def test_dataloader(self):
        """Build the test dataloader.

        With ``hparams['use_gt_mel_eval']`` the regular TTS dataloader for
        ``hparams['testset_prefix']`` is used (and ``self.test_att_id`` is
        set); otherwise the dedicated test dataloader is returned.
        """
        if hparams['use_gt_mel_eval']:
            loader, tracked_id = get_tts_dataloader(hparams, prefix=hparams['testset_prefix'],
                                                    use_ddp=self.trainer.use_ddp)
            self.test_att_id = tracked_id
            return loader
        return get_test_dataloader(hparams)

    #############
    # mel -> wave
    ##############
    def mel2wav(self, m, save_path=None):
        """Convert a mel spectrogram to a waveform and optionally save it.

        Args:
            m: mel spectrogram; the original author documents the shape as
                ``(80, n)``. Tensors are first converted with ``np_now``.
            save_path: optional output path. Missing parent directories are
                created before writing.

        Returns:
            The generated waveform.

        Raises:
            ValueError: if ``hparams['mel_gen_mode']`` is neither ``'pwg'`` nor
                ``'wavernn'`` (previously this surfaced as an
                ``UnboundLocalError`` on ``wav``).
        """
        if isinstance(m, torch.Tensor):
            m = np_now(m)

        mode = hparams['mel_gen_mode']
        if mode == 'pwg':
            if self.vocoder is None:
                self.vocoder = get_vocoder_cls(hparams)()
            # spec2wav need mel: [T, 80]
            wav = self.vocoder.spec2wav(m.T)
        elif mode == 'wavernn':
            # Rescale and clip to [0, 1]; presumably the mels are normalised to
            # roughly [-4, 4] - TODO confirm against the preprocessing code.
            m = np.clip((m + 4) / 8, 0, 1)
            # (n_mels, n)
            wav = reconstruct_waveform(m, n_iter=hparams['gl_iters'])
        else:
            raise ValueError(f"Unsupported mel_gen_mode: {mode!r} (expected 'pwg' or 'wavernn')")

        if save_path:
            # Path.parent also handles a bare filename, where
            # os.makedirs(os.path.dirname(...)) would fail on ''.
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            # NOTE(review): generation is selected by 'mel_gen_mode' but saving
            # by 'vocoder' - confirm these two settings are meant to differ.
            if hparams['vocoder'] == 'pwg':
                audio.save_wav(wav, save_path, hparams['audio_sample_rate'],
                               norm=hparams['out_wav_norm'])
            else:
                save_wav(wav, save_path)

        return wav

    ############
    # losses
    ############
    def add_mel_loss(self, mel_out, target, losses, postfix='', mel_mix_loss=None, local_lbd=1.0):
        all_mel_loss = 0
        nonpadding = target.abs().sum(-1).ne(0).float()
        for loss_name, lbd in self.loss_and_lambda.items():
            if 'l1' == loss_name:
                l = self.l1_loss(mel_out, target)
            elif 'mse' == loss_name:
                l = self.mse_loss(mel_out, target)

            losses[f'{loss_name}{postfix}'] = l * lbd * local_lbd
            all_mel_loss += losses[f'{loss_name}{postfix}']
        return all_mel_loss


    def l1_loss(self, decoder_output, target):
        # decoder_output : B x T x n_mel
        # target : B x T x n_mel
        l1_loss = F.l1_loss(decoder_output, target, reduction='none')
        weights = self.weights_nonzero_speech(target)
        l1_loss = (l1_loss * weights).sum() / weights.sum()
        return l1_loss

    def mse_loss(self, decoder_output, target):
        # decoder_output : B x T x n_mel
        # target : B x T x n_mel
        assert decoder_output.shape == target.shape
        mse_loss = F.mse_loss(decoder_output, target, reduction='none')
        weights = self.weights_nonzero_speech(target)
        mse_loss = (mse_loss * weights).sum() / weights.sum()
        return mse_loss

    def add_pitch_loss(self, output, sample, losses):
        """Add the pitch loss when ``hparams['pitch_type']`` is ``'frame'``.

        Frames whose mel target (``sample['mels']``, [B, T, H]) is all zero are
        treated as padding and excluded through the mask handed to
        ``add_f0_loss``. For any other pitch type nothing is added.
        """
        mel_target = sample['mels']  # [B, T, H]
        f0_target = sample['f0s']
        uv_target = sample['uvs']
        frame_mask = mel_target.abs().sum(-1, keepdim=False).ne(0).float()
        if hparams['pitch_type'] != 'frame':
            return
        self.add_f0_loss(output['pitch_pred'], f0_target, uv_target, losses, nonpadding=frame_mask)

    def add_f0_loss(self, p_pred, f0, uv, losses, nonpadding, postfix=''):
        """Add masked f0 (and optionally voiced/unvoiced) losses to ``losses``.

        Args:
            p_pred: pitch prediction; channel 0 is compared against ``f0`` and,
                with ``hparams['use_uv']``, channel 1 holds the uv logits.
            f0: f0 target with the same shape as ``p_pred[..., 0]``.
            uv: uv target with the same shape as ``p_pred[..., 1]``.
            losses: dict updated in place with ``'uv<postfix>'`` /
                ``'f0<postfix>'``.
            nonpadding: mask multiplied into the element-wise losses.
            postfix: suffix appended to the loss keys.
        """
        assert p_pred[..., 0].shape == f0.shape
        frame_mask = nonpadding
        if hparams['use_uv']:
            assert p_pred[..., 1].shape == uv.shape
            uv_bce = F.binary_cross_entropy_with_logits(p_pred[:, :, 1], uv, reduction='none')
            uv_loss = (uv_bce * frame_mask).sum() / frame_mask.sum()
            losses[f'uv{postfix}'] = uv_loss * hparams['lambda_uv']
            # The f0 loss only covers frames whose uv target is 0.
            frame_mask = frame_mask * (uv == 0).float()

        f0_pred = p_pred[:, :, 0]
        if hparams['pitch_loss'] == 'l1':
            pitch_loss_fn = F.l1_loss
        else:
            pitch_loss_fn = F.mse_loss
        f0_err = pitch_loss_fn(f0_pred, f0, reduction='none')
        f0_loss = (f0_err * frame_mask).sum() / frame_mask.sum()
        losses[f'f0{postfix}'] = f0_loss * hparams['lambda_f0']

    def add_energy_loss(self, energy_pred, energy, losses):
        """Store the masked MSE energy loss under ``losses['e']``.

        Positions where the target energy equals 0 are excluded; the mean over
        the remaining positions is scaled by ``hparams['lambda_energy']``.
        """
        valid = (energy != 0).float()
        sq_err = F.mse_loss(energy_pred, energy, reduction='none')
        masked_mean = (sq_err * valid).sum() / valid.sum()
        losses['e'] = masked_mean * hparams['lambda_energy']

    ############
    # infer
    ############
    def test_start(self):
        """Instantiate the vocoder selected by ``hparams`` before testing."""
        vocoder_cls = get_vocoder_cls(hparams)
        self.vocoder = vocoder_cls()

    def test_end(self, outputs):
        pass

    def test_step(self, sample, batch_idx):
        """Produce the test artifacts for a single-item batch.

        Two modes, selected by ``hparams['use_gt_mel_eval']``:

        * disabled: run the model on the batch and write attention plots, the
          attention matrix (npy), the mel plot, the vocoded wav and that wav
          combined with ``sample['video_path']`` under
          ``<work_dir>/model_outs/test``.
        * enabled: vocode the ground-truth mel and combine both the vocoded
          audio and the preprocessed recording with the sentence video under
          ``<work_dir>/model_outs/eval``.

        Both modes assert a batch size of 1. Nothing is returned.
        """
        tokens = sample['tokens']
        token_lens = sample['token_lens']
        token_id = sample.get('token_id')
        video_id = sample.get('video_id')
        imgs = sample.get('imgs')
        img_lens = sample.get('img_lens')
        video_path = sample.get('video_path')
        spk_img = sample.get('spk_img')

        if hparams['use_gt_mel_eval']:
            gt_mels = sample.get('mels')
            ids = sample.get('ids')
            assert gt_mels is not None and gt_mels.shape[0] == 1
            save_dir = Path(os.path.join(hparams['work_dir'], 'model_outs', 'eval'))
            clip_id = ids[0]
            ep_id, cut_id = clip_id.rsplit('-', 1)
            video_path = Path(hparams['sent_video_path']) / f'{ep_id}' / f'{cut_id}.mp4'
            wav_save_path = save_dir / 'gtmel_wav' / f'gtmel_{clip_id}_{hparams["vocoder"]}.wav'
            self.mel2wav(gt_mels[0].T, wav_save_path)
            gtmel_mp4_path = save_dir / 'gtmel_mp4' / f'gtmel_{clip_id}_{hparams["vocoder"]}.mp4'
            combine_audio_video(video_path, wav_save_path, gtmel_mp4_path)

            # combine 16k hz gt recording with original video
            gt_audio_path = os.path.join(hparams['preprocessed_root'], ep_id, cut_id, 'audio.wav')
            combine_audio_video(video_path, gt_audio_path, save_dir / 'gt_mp4' / f'gt_{clip_id}.mp4')
            return

        output = self.model(tokens, imgs, token_lens, img_lens, None, None, infer=False, spk_img=spk_img)

        diag_mask = output['diagonal_mask']
        mel_out = output['mel_out']  # (B, T, H)
        attention = output['attn_map']
        save_dir = Path(os.path.join(hparams['work_dir'], 'model_outs', 'test'))
        attention_dir = save_dir / 'tts_attention'
        os.makedirs(attention_dir, exist_ok=True)
        os.makedirs(save_dir / 'tts_mel_plot', exist_ok=True)

        step = self.global_step
        assert tokens.shape[0] == 1
        for idx in range(tokens.shape[0]):
            mas_suffix = '-MAS' if hparams['use_mas'] else ''
            sample_id = f'text-{token_id}-video-{video_id}' + mas_suffix
            weight_to_figure_with_mask(attention[idx], diag_mask[idx], plots=None,
                                       path=attention_dir / f'{sample_id}_{step}.png')
            if hparams['use_mas'] and hparams.get('mas_path') is not None:
                weight_to_figure_with_mask(hparams['mas_path'][idx], diag_mask[idx], plots=None,
                                           path=attention_dir / f'{sample_id}_MASpath_{step}.png')
            plot.weight_to_npy(attention[idx], save_dir / 'tts_attention_npy' / f'{sample_id}_{step}.npy')
            save_spectrogram(mel_out[idx], save_dir / 'tts_mel_plot' / f'{sample_id}_{step}')
            wav_path = save_dir / 'tts_wav' / f'{sample_id}_{step}_{hparams["vocoder"]}.wav'
            self.mel2wav(mel_out[idx].T, wav_path)
            mp4_path = save_dir / 'tts_mp4' / f'{sample_id}_{step}_{hparams["vocoder"]}.mp4'
            combine_audio_video(video_path, wav_path, mp4_path)

    def after_infer(self, predictions, sil_start_frame=0):
        pass


if __name__ == '__main__':
    # Smoke test: load hparams, build the task and check that it can be pickled.
    import pickle
    from utils.hparams import set_hparams

    set_hparams()
    task = VideoTtsTask()
    pickle.dumps(task)
    print('Pickle Successfully!')
