# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import os
import random
# import tgt
import sys
from typing import Dict, List, Optional, Tuple
import numpy as np
from dataclasses import dataclass, field
from fairseq.dataclass.configs import FairseqDataclass
from fairseq.tasks import register_task
from fairseq.tasks.fairseq_task import FairseqTask
import torch
from fairseq import utils
import torch.nn.functional as F
from fairseq.data import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset
import json
from pathlib import Path
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from ..utils.attn_mask import *

logger = logging.getLogger(__name__)

def map_speed(mark, real_min, real_mean, real_max):
    """Map a relative speed mark onto a real speed value, piecewise linearly.

    ``mark`` is first clamped to ``[0.5, 2.0]``. The lower half ``[0.5, 1]``
    maps linearly onto ``[real_min, real_mean]`` and the upper half
    ``(1, 2]`` onto ``(real_mean, real_max]``.
    """
    clamped = max(min(mark, 2.0), 0.5)
    if clamped <= 1:
        lower_fraction = (clamped - 0.5) / (1 - 0.5)
        return real_min + lower_fraction * (real_mean - real_min)
    upper_fraction = (clamped - 1) / (2 - 1)
    return real_mean + upper_fraction * (real_max - real_mean)

def load_audio(
    manifest_path, max_keep_sample_size, min_keep_sample_size
):
    """Parse a segment-level TSV manifest and filter its items by total length.

    Each non-empty line must have exactly 8 tab-separated columns::

        item_zip_loc  codec_lens  text_lens  prompt_codec_lens
        prompt_text_lens  prompt_energy  emos  prompt_scalers

    Every column after ``item_zip_loc`` is a ``"|"``-separated list with one
    entry per segment. ``prompt_energy`` is read from the line but not used.
    An item's length is the sum over segments of
    ``codec_len + text_len + prompt_text_len + prompt_codec_len * prompt_scaler``.
    Segments are zipped together (including the emotion column), so the
    shortest list bounds the sum.

    Args:
        manifest_path: path of the UTF-8 manifest file.
        max_keep_sample_size: items longer than this are skipped; ``None``
            disables the upper bound.
        min_keep_sample_size: items shorter than this are skipped; ``None``
            disables the lower bound.

    Returns:
        ``(item_zip_locs, emos, prompt_scalers, lengths)`` as parallel lists.
        ``emos`` and ``prompt_scalers`` hold the raw ``"|"``-separated strings.
    """
    n_long, n_short = 0, 0
    item_zip_locs, lengthes, emos, promp_scalers = [], [], [], []
    with open(manifest_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines instead of failing the 8-column unpack.
                continue
            (
                item_zip_loc,
                codec_lens_str,
                text_lens_str,
                prompt_codec_lens,
                prompt_text_lens,
                _prompt_energy,
                emos_str,
                prompt_scalers_str,
            ) = line.split("\t")
            length = 0
            for codec_len, text_len, prompt_codec_len, prompt_text_len, prompt_scaler, _emo in zip(
                np.array(codec_lens_str.split("|"), dtype=float),
                np.array(text_lens_str.split("|"), dtype=float),
                np.array(prompt_codec_lens.split("|"), dtype=float),
                np.array(prompt_text_lens.split("|"), dtype=float),
                np.array(prompt_scalers_str.split("|"), dtype=float),
                # The emotion column is unused here, but it still bounds the zip.
                np.array(emos_str.split("|"), dtype=str),
            ):
                length += (codec_len + text_len + prompt_text_len + prompt_codec_len*prompt_scaler)
            # A None threshold (the config default) means "no limit".
            if min_keep_sample_size is not None and length < min_keep_sample_size:
                n_short += 1
            elif max_keep_sample_size is not None and length > max_keep_sample_size:
                n_long += 1
            else:
                item_zip_locs.append(item_zip_loc)
                emos.append(emos_str)
                promp_scalers.append(prompt_scalers_str)
                lengthes.append(length)
    if lengthes:
        range_msg = f"longest-loaded={max(lengthes)}, shortest-loaded={min(lengthes)}"
    else:
        # max()/min() would raise on an empty list.
        range_msg = "longest-loaded=n/a, shortest-loaded=n/a"
    logger.info(
        f"max_keep={max_keep_sample_size}, min_keep={min_keep_sample_size}, "
        f"loaded {len(lengthes)}, skipped {n_short} short and {n_long} long, "
        f"{range_msg}"
    )
    return item_zip_locs, emos, promp_scalers, lengthes

def load_prompt_infos(prompt_json):
    """Index prompt metadata by prompt file path.

    ``prompt_json`` holds a nested mapping ``{emo: {spk: [item, ...]}}``. Each
    item is a dict with at least a ``"filepath"`` key. The nesting is
    flattened into ``{item["filepath"]: item}``. When the same path occurs
    more than once, the item encountered last wins.
    """
    with open(prompt_json, "r", encoding="utf-8") as handle:
        nested = json.load(handle)
    return {
        entry["filepath"]: entry
        for by_speaker in nested.values()
        for entries in by_speaker.values()
        for entry in entries
    }

class CELSDSDataset(FairseqDataset):
    """Multi-segment emotional speech dataset read from a TSV manifest.

    Each manifest row (parsed by ``load_audio``) describes one item made of
    several segments. The row's first column is a ``"<path>:<offset>:<size>"``
    locator of a UTF-8 JSON list stored inside a larger file. Each list entry
    is one segment with space-separated target speech tokens (``"st"``),
    target text tokens (``"tt"``) and a ``"prompt_wav"`` key. That key indexes
    the ``prompt_json`` metadata (see ``load_prompt_infos``) to obtain the
    segment's prompt speech/text tokens. Prompt speech tokens are resampled
    per segment by the manifest's prompt scaler.
    """

    def __init__(
        self,
        manifest_path: str,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        max_sample_size: Optional[int] = None,
        shuffle: bool = True,
        max_valid_sample_size = -1,
        prompt_json="",
        attention_bias="",
        attention_range="",
        multi_ab=False
    ):
        """
        Args:
            manifest_path: TSV manifest parsed by ``load_audio``.
            max_keep_sample_size: upper length filter forwarded to ``load_audio``.
            min_keep_sample_size: lower length filter forwarded to ``load_audio``.
            max_sample_size: stored on the instance; not otherwise used here.
            shuffle: stored on the instance; ``ordered_indices`` ignores it.
            max_valid_sample_size: accepted for interface compatibility; unused.
            prompt_json: prompt metadata file read by ``load_prompt_infos``.
            attention_bias: ``"_"``-separated keys of ``atten_bias_func``. An
                empty string disables attention-bias generation.
            attention_range: ``"<low>_<top>"`` floats passed to the bias
                functions. Required whenever ``attention_bias`` is set.
            multi_ab: if True, the collater also stacks bias keys "0".."6".
        """
        self.item_zip_locs, self.emos, self.promp_scalers, self.lengthes = load_audio(
            manifest_path, max_keep_sample_size, min_keep_sample_size,
        )
        self.shuffle = shuffle
        self.max_sample_size = max_sample_size
        self.text_pad_idx = 50257  # padding id for text tokens
        self.audio_pad_idx = 0  # padding id for speech tokens, emotion ids and length vectors
        self.prompt_info = load_prompt_infos(prompt_json)
        self.emo_dict = {
            "sad":0, 
            "happy":1, 
            "angry":2, 
            "surprised":3, 
            "neutral":4
        }
        self.attention_bias = attention_bias
        # Parsed "<low>_<top>" floats. Always defined (None when not given) so
        # that get_attn_bias can report a clear error instead of an AttributeError.
        self.attention_range = (
            list(map(float, attention_range.split("_")))
            if attention_range != ""
            else None
        )
        # Attention-bias builders (from ..utils.attn_mask), selected by key.
        self.atten_bias_func = {
            "0": lower_triangle,
            "1": tgt_st_paired_emo_and_all_tt,
            "2": all_paired_emo,
            "3": all_st_paired_emo,
            "4": tgt_st_paired_emo,
            "5": st_paired_emo_tgt_st_all_tt,
            "6": all_st_paired_emo_and_all_tt,
        }
        self.multi_ab = multi_ab

    def load_json_fromzip(self, info):
        """Read and decode a JSON document stored at a byte range of a file.

        Args:
            info: ``"<path>:<offset>:<size>"``. The string is split from the
                right, so the path itself may contain ':'.

        Returns:
            The decoded JSON value.
        """
        zip_path, offset, file_size = info.rsplit(":", 2)
        offset = int(offset)
        file_size = int(file_size)
        with open(zip_path, "rb") as f:
            f.seek(offset)
            byte_data = f.read(file_size)
        # Decode the byte range as UTF-8 JSON.
        return json.loads(byte_data.decode("utf-8"))

    def get_attn_bias(self, 
                    text_side_prompt_lens, 
                    text_side_tgt_lens, 
                    speech_side_prompt_lens, 
                    speech_side_speech_lens, 
                    value=1, top=5, low=1):
        """Build one attention-bias tensor per key listed in ``self.attention_bias``.

        The four per-segment length sequences are zipped together, so the
        shortest bounds the segment count. They are turned into lists of
        ``(name, length)`` spans:

        - text prompt: ``t_0{i}``
        - text target: ``t_1{i}``
        - speech prompt: ``s_0{i}``
        - speech target: ``s_1{i}``

        Some spans are lengthened to account for extra tokens (see the inline
        comments). Those offsets presumably mirror the model's sequence
        layout; TODO confirm.

        Args:
            text_side_prompt_lens: per-segment prompt text lengths.
            text_side_tgt_lens: per-segment target text lengths.
            speech_side_prompt_lens: per-segment prompt speech-token lengths.
            speech_side_speech_lens: per-segment target speech-token lengths.
            value, top, low: unused. The ``top``/``low`` actually passed to the
                bias functions come from ``self.attention_range``.

        Returns:
            dict mapping each bias key to ``torch.tensor(bias_func(...))``.

        Raises:
            ValueError: if ``attention_range`` was not configured as
                ``"<low>_<top>"``.
        """
        if self.attention_range is None or len(self.attention_range) < 2:
            raise ValueError(
                "attention_bias is set but attention_range is missing; "
                "expected '<low>_<top>'"
            )
        text_side_prompt = []
        text_side_tgt = []
        speech_side_prompt = []
        speech_side_speech = []
        for i, (a, b, c, d) in enumerate(zip(
            text_side_prompt_lens, 
            text_side_tgt_lens, 
            speech_side_prompt_lens, 
            speech_side_speech_lens
        )):
            if i == 0:
                a = a + 1  # sos
            text_side_prompt.append((f"t_0{i}", a + 1))
            text_side_tgt.append((f"t_1{i}", b + 1))  # +1 for emotion
            if i == 0:
                c = c + 1  # bos
            speech_side_prompt.append((f"s_0{i}", c))
            speech_side_speech.append((f"s_1{i}", d))

        range_low = float(self.attention_range[0])
        range_top = float(self.attention_range[1])
        attn_bias_dict = {}
        for func_key in self.attention_bias.split("_"):
            attn_bias_dict[func_key] = torch.tensor(self.atten_bias_func[func_key](
                text_side_prompt, 
                text_side_tgt, 
                speech_side_prompt, 
                speech_side_speech, 
                top=range_top, 
                low=range_low
            ))
        return attn_bias_dict

    @staticmethod
    def _parse_tokens(token_str):
        """Convert a space-separated string of integer token ids to a 1-D long tensor."""
        return torch.tensor(np.array(token_str.split(" "), dtype=int)).to(torch.long)

    def __getitem__(self, index):
        """Load item ``index`` and return its tensors.

        The following are concatenated across segments into 1-D tensors:

        - target speech tokens
        - prompt speech tokens
        - the frame-level emotion ids for target and prompt speech

        Target and prompt text tokens are padded to
        ``(num_segments, max_text_len)`` with ``text_pad_idx``. Segments come
        from zipping the JSON list, the prompt scalers and the emotion labels,
        so the shortest of the three bounds the segment count.
        """
        json_items = self.load_json_fromzip(self.item_zip_locs[index])
        scalers = np.array(self.promp_scalers[index].split("|"), dtype=float)
        tgt_codec, tgt_text, prompt_codec, prompt_text = [], [], [], []
        tgt_emos, tgt_emos_prompt, emos = [], [], []
        prompt_text_lens, text_lens, prompt_codec_lens, tgt_codec_lens = [], [], [], []

        for json_item, scaler, emo in zip(json_items, scalers, self.emos[index].split("|")):
            tgt_codec.append(self._parse_tokens(json_item['st']))
            tgt_codec_lens.append(tgt_codec[-1].size(0))
            tgt_text.append(self._parse_tokens(json_item['tt']))
            text_lens.append(tgt_text[-1].size(0))

            prompt_info = self.prompt_info[json_item['prompt_wav']]
            # Stretch/shrink the prompt speech tokens by this segment's scaler.
            prompt_codec.append(resample_by_stride(
                self._parse_tokens(prompt_info['speech_token']), scaler
            ).to(torch.long))
            prompt_codec_lens.append(prompt_codec[-1].size(0))
            prompt_text.append(self._parse_tokens(prompt_info['text_token']))
            prompt_text_lens.append(prompt_text[-1].size(0))

            emo_id = self.emo_dict[emo]
            # One emotion id per prompt speech token, per target speech token,
            # and once per segment.
            tgt_emos_prompt.append(torch.tensor([emo_id] * prompt_codec[-1].size(0), dtype=torch.long))
            tgt_emos.append(torch.tensor([emo_id] * tgt_codec[-1].size(0), dtype=torch.long))
            emos.append(torch.tensor([emo_id], dtype=torch.long))

        return_dict = {
            "id": index, 
            "tgt_codec": torch.cat(tgt_codec, dim=0), 
            "tgt_text": pad_sequence(tgt_text, batch_first=True, padding_value=self.text_pad_idx), 
            "prompt_codec": torch.cat(prompt_codec, dim=0), 
            "prompt_text": pad_sequence(prompt_text, batch_first=True, padding_value=self.text_pad_idx), 
            "tgt_emos": torch.cat(tgt_emos, dim=0), 
            "tgt_emos_prompt": torch.cat(tgt_emos_prompt, dim=0), 
            "emos": torch.cat(emos, dim=0),
            "prompt_text_lens": torch.tensor(prompt_text_lens).to(torch.long),
            "text_lens": torch.tensor(text_lens).to(torch.long),
            "prompt_codec_lens": torch.tensor(prompt_codec_lens).to(torch.long),
            "tgt_codec_lens": torch.tensor(tgt_codec_lens).to(torch.long),
            "length": self.lengthes[index]
        }

        if self.attention_bias != "":
            return_dict["attn_bias"] = self.get_attn_bias(
                prompt_text_lens, 
                text_lens, 
                prompt_codec_lens, 
                tgt_codec_lens
            )

        return return_dict

    def __len__(self):
        return len(self.lengthes)

    def _collate_1d(self, values):
        """Right-pad a list of 1-D tensors into a ``(B, T)`` tensor with ``audio_pad_idx``."""
        return data_utils.collate_tokens(
            values,
            self.audio_pad_idx,
            self.audio_pad_idx,
            move_eos_to_beginning=False,
            pad_to_length=None,
        )

    def collater(self, samples):
        """Merge ``__getitem__`` outputs into a padded batch.

        Samples whose ``id`` is None are dropped. If the longest/shortest
        length ratio then exceeds 1.11, every sample not longer than
        ``0.9 * max_len`` is dropped too, and the count is logged.

        Returns:
            ``{}`` for an empty batch; otherwise ``{"id": ..., "net_input": {...}}``.
            When attention biases are present, one randomly chosen bias key is
            used for ``net_input["attn_bias"]``. With ``multi_ab``, bias keys
            "0".."6" (all must be configured) are stacked along dim 1 into
            ``net_input["all_bias"]``.
        """
        # target = max(sizes) -> random_crop not used
        # target = max_sample_size -> random_crop used for long
        samples = [s for s in samples if s["id"] is not None]
        if len(samples) == 0:
            return {}

        lengths = [s["length"] for s in samples]
        max_len = max(lengths)
        min_len = min(lengths)

        # Keep-threshold r vs. the equivalent max/min ratio (1 / r):
        #   0.75 -> 1.33, 0.80 -> 1.25, 0.85 -> 1.176, 0.90 -> 1.11, 0.95 -> 1.05
        # If the lengths differ too much, dynamically drop the shorter samples.
        filter_out = 0
        if max_len / (min_len + 1e-5) > 1.11:  # task-dependent; e.g. 1.5 or 3.0 also work
            kept = [s for s in samples if s["length"] > 0.9 * max_len]  # keep the longer samples
            filter_out = len(samples) - len(kept)
            samples = kept
        if filter_out != 0:
            logger.info(f"filter out num: {filter_out}, ori lengthes: {str(lengths)}, threshold: {0.9 * max_len}")

        ids = [s["id"] for s in samples]
        tgt_codecs = [s["tgt_codec"] for s in samples]
        prompt_codecs = [s["prompt_codec"] for s in samples]

        # Total (concatenated-across-segments) speech token counts per sample.
        tgt_codecs_lens = torch.tensor([temp.size(0) for temp in tgt_codecs]).to(torch.long)
        prompt_codecs_lens = torch.tensor([temp.size(0) for temp in prompt_codecs]).to(torch.long)

        tgt_texts, tgt_iter_num, _ = self.collate_frame_tokens(
            [s["tgt_text"] for s in samples], self.text_pad_idx
        )
        prompt_texts, _, _ = self.collate_frame_tokens(
            [s["prompt_text"] for s in samples], self.text_pad_idx
        )

        net_input = {
            "tgt_codecs": self._collate_1d(tgt_codecs),
            "prompt_codecs": self._collate_1d(prompt_codecs),
            "tgt_texts": tgt_texts,
            "prompt_texts": prompt_texts,
            "tgt_emos": self._collate_1d([s["tgt_emos"] for s in samples]),
            "tgt_emos_prompt": self._collate_1d([s["tgt_emos_prompt"] for s in samples]),
            "tgt_text_lens": self._collate_1d([s["text_lens"] for s in samples]),
            "prompt_text_lens": self._collate_1d([s["prompt_text_lens"] for s in samples]),
            "iter_num": tgt_iter_num,  # number of segments per sample
            "tgt_codecs_lens": tgt_codecs_lens,
            "prompt_codecs_lens": prompt_codecs_lens,
            "emos": self._collate_1d([s["emos"] for s in samples]),
            "prompt_each_codec_lens": self._collate_1d([s["prompt_codec_lens"] for s in samples]),
            "tgt_each_codec_lens": self._collate_1d([s["tgt_codec_lens"] for s in samples]),
        }

        if "attn_bias" in samples[0].keys():
            random_key = random.choice(list(samples[0]["attn_bias"].keys()))
            net_input["attn_bias"] = self.collate_frames(
                [i["attn_bias"][random_key] for i in samples]
            )

            if self.multi_ab:
                all_bias = [
                    self.collate_frames([i["attn_bias"][str(key)] for i in samples])
                    for key in range(7)
                ]
                net_input["all_bias"] = torch.stack(all_bias, 1)

        batch = {
            "id": torch.tensor(ids),
            "net_input": net_input,
        }
        return batch

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        return self.lengthes[index]

    def ordered_indices(self):
        """Return indices sorted by length, longest first.

        Ties are broken by larger index first. ``self.shuffle`` is ignored.
        """
        order = [np.arange(len(self)), self.lengthes]
        return np.lexsort(order)[::-1]  # longest first

    def collate_frames(self, frames, fbank_sizes=None, heights=None) -> torch.Tensor:
        """Zero-pad a list of 2-D tensors into a ``(B, max_len, max_height)`` tensor.

        Args:
            frames: 2-D tensors. Rows/columns beyond the target size are cropped.
            fbank_sizes: optional per-item row counts (defaults to each
                ``size(0)``). Only their maximum is used.
            heights: optional per-item column counts (defaults to each
                ``size(-1)``). Only their maximum is used.
        """
        if fbank_sizes is None:
            fbank_sizes = [item.size(0) for item in frames]
        max_len = max(fbank_sizes)
        if heights is not None:
            max_height = max(heights)
        else:
            max_height = max([item.size(-1) for item in frames])
        out = frames[0].new_zeros((len(frames), max_len, max_height))
        for i, v in enumerate(frames):
            rows = min(max_len, v.size(0))
            cols = min(max_height, v.size(1))
            # Crop rows as well as columns. An explicit fbank_sizes smaller
            # than v.size(0) previously caused a shape mismatch.
            out[i, :rows, :cols] = v[:rows, :cols]
        return out

    def collate_frame_tokens(self, frames, padding) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Pad 2-D token tensors into a ``(B, max_rows, max_cols)`` tensor filled with ``padding``.

        Returns:
            ``(padded, row_counts, col_counts)``. The counts are the original
            ``size(0)`` / ``size(-1)`` of each input.
        """
        lengthes = [item.size(0) for item in frames]
        max_len = max(lengthes)
        heights = [item.size(-1) for item in frames]
        max_height = max(heights)
        out = frames[0].new_ones((len(frames), max_len, max_height)) * padding
        for i, v in enumerate(frames):
            rows = min(max_len, v.size(0))
            cols = min(max_height, v.size(1))
            out[i, :rows, :cols] = v[:rows, :cols]
        return out, torch.tensor(lengthes), torch.tensor(heights)
    
def make_pad_mask(lengths: List[int], max_len: int = 0):
    """Make a boolean mask that is True at padded positions.

    Args:
        lengths (List[int]): Batch of lengths (B,). Any 1-D sequence accepted
            by ``np.asarray`` works.
        max_len (int): Width of the mask. If not positive, the largest entry
            of ``lengths`` is used (0 for an empty batch).

    Returns:
        torch.Tensor: bool mask of shape ``(B, max_len)``, True where
        ``position >= length``.

    Examples:
        >>> make_pad_mask([5, 3, 2]).int().tolist()
        [[0, 0, 0, 0, 0], [0, 0, 0, 1, 1], [0, 0, 1, 1, 1]]
    """
    if max_len <= 0:
        # max() of an empty batch would raise; an empty batch gets width 0.
        max_len = max(lengths) if len(lengths) > 0 else 0
    # (1, max_len) positions vs. (B, 1) lengths broadcast to (B, max_len).
    positions = np.arange(0, max_len)[np.newaxis, :]
    limits = np.asarray(lengths)[:, np.newaxis]
    return torch.from_numpy(positions >= limits)

@dataclass
class CosyvoiceConfig(FairseqDataclass):
    """Task configuration for the ``cosyvoice_emosft_dataset`` task."""

    data: str = field(default="", metadata={"help": "directory containing the <split>.tsv manifests"})
    max_keep_size: Optional[int] = field(
        default=None,
        metadata={"help": "exclude sample longer than this"},
    )
    max_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "drop training samples whose total length exceeds this (None: no limit)"},
    )
    max_valid_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "drop valid/test/dev samples whose total length exceeds this (None: no limit)"},
    )
    min_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "drop samples whose total length is below this (None: no limit)"},
    )
    pretrained_home : str = field(default="", metadata={"help": "directory whose llm.pt, if present, initializes the model"})
    # NOTE(review): the help strings of the fields below are placeholders
    # copied from `data`; none of these fields is referenced elsewhere in this
    # file. Confirm their meaning with the consuming code.
    processed_home : str = field(default="", metadata={"help": "path to data directory"})
    audio_home : str = field(default="", metadata={"help": "path to data directory"})
    bert_pretrained_dir : str = field(default="", metadata={"help": "path to data directory"})
    persona_feature_home : str = field(default="", metadata={"help": "path to data directory"})
    persona_json_path : str = field(default="", metadata={"help": "path to data directory"})
    contextual_num : int = field(default=2, metadata={"help": "path to data directory"})
    # NOTE(review): annotated as str but the default is the int 2. The value is
    # left unchanged because the effective default depends on fairseq/omegaconf
    # coercion; confirm the intended default.
    instruct_bert_home : str = field(default=2, metadata={"help": "path to data directory"})
    only_refined_text : bool = field(default=False, metadata={"help": "path to data directory"})
    train_fm : bool = field(default=False, metadata={"help": "path to data directory"})
    pretrained_frameemo : str = field(default="", metadata={"help": "fairseq checkpoint whose shape-matching weights initialize the model, if it exists"})
    speechcraft_origin_infos : str = field(default="", metadata={"help": "path to data directory"})
    speechcraft_descriptions : str = field(default="", metadata={"help": "path to data directory"})
    speechcraft_emo_embs_home : str = field(default="", metadata={"help": "path to data directory"})
    prompt_json : str = field(default="", metadata={"help": "JSON prompt metadata {emo: {spk: [items with filepath/speech_token/text_token]}}"})
    use_no_denoised : bool = field(default=False, metadata={"help": "path to data directory"})
    attention_bias : str = field(default="", metadata={"help": "'_'-separated attention-bias keys ('0'-'6'); empty disables attention bias"})
    attention_range: str = field(default="", metadata={"help": "'<low>_<top>' floats passed to the attention-bias functions"})
    multi_ab: bool = field(default=False, metadata={"help": "also provide all seven attention-bias variants as net_input['all_bias']"})
    
@register_task("cosyvoice_emosft_dataset", dataclass=CosyvoiceConfig)
class CosyvoiceSFTTask(FairseqTask):
    """Fairseq task serving ``CELSDSDataset`` splits.

    The model is built through fairseq and optionally warm-started from
    ``<pretrained_home>/llm.pt`` and/or the ``pretrained_frameemo`` checkpoint.
    The task has no dictionaries.
    """
    cfg: CosyvoiceConfig
    def __init__(
        self,
        cfg: CosyvoiceConfig,
    ) -> None:
        super().__init__(cfg)
        logger.info(f"current directory is {os.getcwd()}")
        logger.info(f"Config {cfg}")
        self.cfg = cfg

    @property
    def source_dictionary(self):
        return None

    @property
    def target_dictionary(self):
        return None

    @property
    def dictionaries(self):
        return [None]

    @classmethod
    def setup_task(
        cls, cfg: CosyvoiceConfig, **kwargs
    ):
        return cls(cfg)

    def load_dataset(self, split: str, **kwargs) -> None:
        """Create ``self.datasets[split]`` from ``<data>/<split>.tsv``.

        Splits whose name contains "valid", "test" or "dev" are length-filtered
        with ``max_valid_sample_size``; all others use ``max_sample_size``.
        """
        manifest = f"{self.cfg.data}/{split}.tsv"

        if "valid" in split or "test" in split or "dev" in split:
            max_keep = self.cfg.max_valid_sample_size
        else:
            max_keep = self.cfg.max_sample_size

        self.datasets[split] = CELSDSDataset(
            manifest,
            max_keep_sample_size=max_keep,
            min_keep_sample_size=self.cfg.min_sample_size,
            max_sample_size=self.cfg.max_sample_size,
            max_valid_sample_size=self.cfg.max_valid_sample_size,
            prompt_json=self.cfg.prompt_json,
            attention_bias=self.cfg.attention_bias,
            attention_range=self.cfg.attention_range,
            multi_ab=self.cfg.multi_ab
        )
        logger.info(f"{split} dataloader is length of {len(self.datasets[split])}")

    def max_positions(self) -> Tuple[int, int]:
        return (sys.maxsize, sys.maxsize)

    def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array:
        # Length filtering already happens when the manifest is loaded.
        return indices

    def build_model(self, cfg, from_checkpoint=False):
        """Build the model and load any configured pretrained weights into it."""
        from fairseq import models
        model = models.build_model(cfg, self, from_checkpoint)
        llm_ckpt = os.path.join(self.cfg.pretrained_home, "llm.pt")
        # Only look for llm.pt when pretrained_home is configured. With the
        # default "" the join yields "llm.pt", which would silently pick up a
        # file from the current working directory.
        if self.cfg.pretrained_home and os.path.exists(llm_ckpt):
            model = load_pretrained_models(model, llm_ckpt)
        if self.cfg.pretrained_frameemo and os.path.exists(self.cfg.pretrained_frameemo):
            model = load_pretrained_fairseq_models(model, self.cfg.pretrained_frameemo)
        return model

from copy import deepcopy
def load_pretrained_models(model, pretrained_checkpoint):
    """Copy shape-compatible weights from a raw state-dict file into ``model``.

    The file must hold a flat ``{param_name: tensor}`` mapping. Keys absent
    from the model are ignored. Keys whose size differs keep the model's own
    value (a "size error" line is printed). Matching tensors are cast to the
    model's dtype.

    Returns:
        The same ``model`` object, updated in place.
    """
    # Load onto CPU. load_state_dict copies into the model's own tensors, so
    # the checkpoint need not be materialized on the device it was saved from.
    state = torch.load(pretrained_checkpoint, map_location="cpu")
    model_state = model.state_dict()
    pretrained_dict = {}
    for k, v in state.items():
        if k not in model_state:
            continue
        target = model_state[k]
        if v.size() == target.size():
            pretrained_dict[k] = v.to(target.dtype)
        else:
            print(k, v.size(), target.size(), "size error")
    # Preserve model order in the report; O(n) instead of repeated list.remove.
    not_loaded = [k for k in model_state if k not in pretrained_dict]
    print("loaded: " + str(pretrained_dict.keys()))
    print("not loaded: " + str(not_loaded))
    model_state.update(pretrained_dict)
    model.load_state_dict(model_state)
    return model

def load_pretrained_fairseq_models(model, pretrained_checkpoint):
    """Copy shape-compatible weights from a fairseq checkpoint into ``model``.

    The weights are read from the checkpoint's ``"model"`` entry. Keys absent
    from the model are ignored. Keys whose size differs keep the model's own
    value (a "size error" line is printed). Matching tensors are cast to the
    model's dtype.

    Returns:
        The same ``model`` object, updated in place.
    """
    # Load onto CPU. load_state_dict copies into the model's own tensors, so
    # the checkpoint need not be materialized on the device it was saved from.
    state = torch.load(pretrained_checkpoint, map_location="cpu")["model"]
    model_state = model.state_dict()
    pretrained_dict = {}
    for k, v in state.items():
        if k not in model_state:
            continue
        target = model_state[k]
        if v.size() == target.size():
            pretrained_dict[k] = v.to(target.dtype)
        else:
            print(k, v.size(), target.size(), "size error")
    # Preserve model order in the report; O(n) instead of repeated list.remove.
    not_loaded = [k for k in model_state if k not in pretrained_dict]
    print("fairseq loaded: " + str(pretrained_dict.keys()))
    print("fairseq not loaded: " + str(not_loaded))
    model_state.update(pretrained_dict)
    model.load_state_dict(model_state)
    return model

def resample_by_stride(tensor, scale):
    """Uniformly resample ``tensor`` along dim 0 to ``int(len(tensor) * scale)`` items.

    The output always has at least one item, except for an empty input, which
    returns an empty tensor. Indices are evenly spaced over
    ``[0, len - 1]`` and rounded, so upsampling repeats items and downsampling
    skips items.
    """
    length = len(tensor)
    if length == 0:
        # Nothing to sample from; indexing [0] would raise IndexError.
        return tensor.clone()
    new_length = max(1, int(length * float(scale)))  # new length
    indices = torch.linspace(0, length - 1, new_length).round().long()  # evenly spaced sample indices
    return tensor[indices]