# ------------------------------------------
# Diffsound
# written by Dongchao Yang
# ------------------------------------------

import torch
import math
from torch import nn
from nn_ss.sound_synthesis2.utils.misc import instantiate_from_config
import time
import numpy as np
from PIL import Image
import os
from hydra.utils import instantiate
from torch.cuda.amp import autocast
import random


def disabled_train(self, mode=True):
    """Return ``self`` unchanged, ignoring the requested ``mode``.

    Intended as a drop-in override for ``nn.Module.train`` so that a
    module's train/eval mode can no longer be switched.
    """
    del mode  # deliberately ignored: the mode must not change
    return self


class ControlSpeech(nn.Module):
    """Thin wrapper that delegates training and inference to a transformer.

    The transformer is built from ``diffusion_config`` with
    ``instantiate_from_config``. ``forward`` passes the batch straight to
    it, and ``infer_one`` calls its ``predict_infer`` method.
    """

    def __init__(
            self,
            *,
            n_q=4,
            content_info=None,
            condition_info=None,
            learnable_cf=False,
            diffusion_config,
    ):
        """Build the wrapper and its transformer.

        Args:
            n_q: Stored as ``self.n_q``; not otherwise used in this class
                (presumably the number of codec quantizer levels -- TODO confirm).
            content_info: Dict describing the content input. Defaults to
                ``{'key': 'wav_token'}``; a fresh dict is created per instance
                so instances never share one mutable default.
            condition_info: Dict describing the condition input. Defaults to
                ``{'key': 'text_dpe_adapted'}`` (fresh dict per instance).
            learnable_cf: Accepted for config compatibility; unused here.
            diffusion_config: Config handed to ``instantiate_from_config``
                to build ``self.transformer``.
        """
        super().__init__()
        if content_info is None:
            content_info = {'key': 'wav_token'}
        if condition_info is None:
            condition_info = {'key': 'text_dpe_adapted'}
        self.n_q = n_q
        self.content_info = content_info
        self.condition_info = condition_info
        # Classifier guidance is not used in this stage.
        self.guidance_scale = 1.0
        self.transformer = instantiate_from_config(diffusion_config)
        self.truncation_forward = False
        # The last token id (1024) represents MASK.
        self.mask_id = 1024

    def parameters(self, recurse=True, name=None):
        """Return parameters, optionally restricted to named sub-modules.

        Args:
            recurse: Forwarded to the underlying ``parameters`` calls.
            name: ``None`` or ``'none'`` returns ``nn.Module.parameters``
                (a generator over all parameters). Otherwise a
                ``'+'``-separated list of attribute names (e.g.
                ``'transformer'``); the parameters of each named sub-module
                are concatenated into a single list.

        Returns:
            A generator (for ``name`` None/'none') or a list.

        Raises:
            AttributeError: if a component of ``name`` is not an attribute.
        """
        if name is None or name == 'none':
            return super().parameters(recurse=recurse)
        params = []
        for sub_name in name.split('+'):
            module = getattr(self, sub_name)
            try:
                # Some sub-modules override parameters() to accept ``name``.
                params += module.parameters(recurse=recurse, name=sub_name)
            except TypeError:
                # Plain nn.Module.parameters() has no ``name`` argument.
                params += module.parameters(recurse=recurse)
        return params

    @property
    def device(self):
        """Device reported by the wrapped transformer's ``device`` attribute."""
        return self.transformer.device

    def get_ema_model(self):
        """Return the wrapped transformer (the module tracked by EMA)."""
        return self.transformer

    def p_sample_with_truncation(self, func, sample_type):
        """Wrap ``func`` so that it is applied a second time at random.

        ``sample_type`` is a string like ``'q0.5'``; the number left after
        removing ``'q'`` is the probability of the second application. The
        second call replaces the first positional argument with the first
        call's output and reuses the 2nd and 3rd positional arguments, so
        the wrapper needs at least three positional arguments whenever the
        second application triggers.
        """
        truncation_rate = float(sample_type.replace('q', ''))

        def wrapper(*args, **kwargs):
            out = func(*args, **kwargs)
            if random.random() < truncation_rate:
                out = func(out, args[1], args[2], **kwargs)
            return out

        return wrapper

    def predict_start_with_truncation(self, func, sample_type):
        """Wrap ``func`` so its output keeps only the highest-scoring classes.

        The class axis is dim 1 of ``func``'s output. Dropped entries are set
        to -70, which is effectively zero probability when the scores are
        log-probabilities (assumed from the ``exp`` in the ``'r'`` path --
        confirm against the transformer).

        Args:
            func: Callable producing the scores to truncate.
            sample_type: ``'top<k>p'`` keeps the ``k`` largest scores along
                dim 1. ``'top<r>r'`` keeps the largest scores, in descending
                order, until their cumulative ``exp`` reaches ``r``. The entry
                that crosses ``r`` is included; this path expects a 3-D tensor.

        Returns:
            The wrapping callable.

        Raises:
            ValueError: if ``sample_type`` ends neither in ``'p'`` nor ``'r'``.
        """
        if sample_type[-1] == 'p':
            truncation_k = int(sample_type[:-1].replace('top', ''))

            def wrapper(*args, **kwargs):
                out = func(*args, **kwargs)
                val, ind = out.topk(k=truncation_k, dim=1)
                probs = torch.full_like(out, -70)
                probs.scatter_(1, ind, val)
                return probs

            return wrapper

        if sample_type[-1] == 'r':
            truncation_r = float(sample_type[:-1].replace('top', ''))

            def wrapper(*args, **kwargs):
                out = func(*args, **kwargs)
                # Applied to the whole tensor (every batch item), classes on dim 1.
                sorted_out, indices = torch.sort(out, 1, descending=True)
                cum_prob = torch.exp(sorted_out).cumsum(dim=1)
                below = cum_prob < truncation_r
                # Shift the mask right by one along dim 1. The top entry is then
                # always kept, and so is the entry whose mass crosses truncation_r.
                always_keep = torch.full_like(below[:, 0:1, :], True)
                keep_sorted = torch.cat((always_keep, below), dim=1)[:, :-1, :]
                # Undo the sort so the mask lines up with the original class order.
                keep = keep_sorted.gather(1, indices.argsort(1)).float()
                return keep * out + (1 - keep) * (-70)

            return wrapper

        raise ValueError(
            f"wrong sample type: {sample_type!r} "
            "(expected 'top<k>p' or 'top<r>r')"
        )

    @torch.no_grad()
    def generate_content_tmp(
            self,
            batch,
            condition=None,
            filter_ratio=0.0,
            temperature=1.0,
            content_ratio=0.0,
            return_rec=False,
            replicate=1,
            return_att_weight=False,
            sample_type="top0.85r"):
        """Switch to eval mode and run ``transformer.predict_infer`` on ``batch``.

        Only ``batch`` is used; the other arguments are accepted for
        interface compatibility and are currently ignored.

        Returns:
            dict with key ``'token_pred'`` holding the transformer's output.
        """
        self.eval()
        trans_out = self.transformer.predict_infer(batch=batch)
        return {'token_pred': trans_out}

    @torch.no_grad()
    def infer_one(self, batch):
        """Main inference entry point.

        Returns the dict produced by ``generate_content_tmp`` (key
        ``'token_pred'``).
        """
        return self.generate_content_tmp(batch)

    def forward(
            self,
            batch,
            name='none',
            **kwargs
    ):
        """Hand all processing directly to the transformer.

        ``name`` is accepted for interface compatibility and ignored; any
        extra keyword arguments are forwarded to the transformer.
        """
        return self.transformer(batch, **kwargs)
