# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import os
# import tgt
import sys
from typing import Dict, List, Optional, Tuple
import numpy as np
from dataclasses import dataclass, field
from fairseq.dataclass.configs import FairseqDataclass
from fairseq.tasks import register_task
from fairseq.tasks.fairseq_task import FairseqTask
import torch
from fairseq import utils
import torch.nn.functional as F
from fairseq.data import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset
import json
from pathlib import Path
from tqdm import tqdm
from ..models.cosyvoice2.cli.frontend import CosyVoiceTextFrontEnd
import re
# Characters in the CJK Unified Ideographs range U+4E00..U+9FFF (one or more in a row).
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')


def contains_chinese(text):
    """Return True when *text* contains at least one character in U+4E00..U+9FFF."""
    found = chinese_char_pattern.search(text)
    return found is not None

def replace_i(text):
    """Upper-case every lowercase "i" that has whitespace immediately on both sides.

    Only whitespace counts as a delimiter. An "i" at the start or end of the
    string, or next to punctuation (e.g. "i," or "i'm"), is left unchanged.
    """
    standalone_i = r'(?<=\s)i(?=\s)'
    return re.sub(standalone_i, 'I', text)


logger = logging.getLogger(__name__)

def get_all_files(root, suffix):
    """Collect every path under *root* whose name ends with ``.{suffix}``.

    Each top-level entry of *root* is matched by its own name and is then
    searched recursively with ``Path.rglob``.

    Args:
        root: directory to scan.
        suffix: file extension without the leading dot (e.g. ``"json"``).

    Returns:
        Sorted list of unique path strings. Sorting makes the order
        reproducible across interpreter runs; ``list(set(...))`` depended on
        string-hash randomization.
    """
    pattern = f"*.{suffix}"
    found = set()
    for child in Path(root).iterdir():
        if str(child).endswith(f".{suffix}"):
            found.add(str(child))
        found.update(str(p) for p in child.rglob(pattern))
    return sorted(found)

def load_audio(manifest_path, max_keep, min_keep, text_frontend):
    """Load a JSON manifest and keep the items whose approximate length is in range.

    The manifest is a JSON object whose values are dicts with the keys:

    * ``"t"``: the transcript.
    * ``"c"``: codec token ids as one space-separated string.
    * ``"a"``: alignment entries, later passed to ``get_alignment``.

    Args:
        manifest_path: path to the JSON manifest (read as UTF-8).
        max_keep: drop items whose length exceeds this; ``None`` disables the check.
        min_keep: drop items whose length is below this; ``None`` disables the check.
        text_frontend: not used here; accepted for interface compatibility.

    Returns:
        ``(texts, codecs, align_texts, lengthes)``: parallel lists holding the
        normalized transcript, raw codec string, alignment data and length of
        every kept item. The lists are empty when nothing survives filtering.
    """
    n_long, n_short, error_num = 0, 0, 0
    texts, codecs, align_texts, lengthes = [], [], [], []
    with open(manifest_path, "r", encoding='utf-8') as f:
        data = json.load(f)
    for item in tqdm(data.values()):
        text = item["t"]
        if not contains_chinese(text):
            # capitalize() lower-cases everything after the first character;
            # replace_i then restores whitespace-delimited " i ".
            text = replace_i(text.capitalize())
        else:
            # Spaces inside Chinese text become full-width commas.
            text = text.replace(" ", "，")
        if len(text) < 2:
            n_short += 1
            continue
        codec = item["c"]
        align = item["a"]
        # NOTE(review): count(" ") is the number of separators, i.e. one less
        # than the number of codec tokens. Kept as-is so existing filtering
        # thresholds behave the same.
        length = int(codec.count(" ") + len(text))
        if min_keep is not None and length < min_keep:
            n_short += 1
        elif max_keep is not None and length > max_keep:
            n_long += 1
        else:
            texts.append(text)
            codecs.append(codec)
            align_texts.append(align)
            lengthes.append(length)
    # max()/min() of an empty list would raise ValueError and hide the real
    # problem (everything was filtered out), so guard the summary statistics.
    longest = max(lengthes) if lengthes else None
    shortest = min(lengthes) if lengthes else None
    logger.info(
        f"max_keep={max_keep}, min_keep={min_keep}, "
        f"loaded {len(lengthes)}, skipped {n_short} short and {n_long} long, not exist {error_num} "
        f"longest-loaded={longest}, shortest-loaded={shortest}"
    )
    if not lengthes:
        logger.warning(f"no samples kept from {manifest_path}")
    assert len(texts) == len(codecs) == len(align_texts) == len(lengthes)
    return texts, codecs, align_texts, lengthes

def load_audio_zip(manifest_path, max_keep, min_keep):
    """Load a tab-separated index manifest and filter its entries by length.

    Every non-blank line holds four tab-separated fields:

    * ``utt_id``
    * ``item_loc``: a ``"<file>:<offset>:<size>"`` string, kept verbatim and
      parsed by ``TPDataset.__getitem__``.
    * ``codec_len``: an integer.
    * ``text_len``: an integer.

    Args:
        manifest_path: path to the index file (read as UTF-8).
        max_keep: drop entries whose ``codec_len + text_len`` exceeds this;
            ``None`` disables the check.
        min_keep: drop entries whose ``codec_len + text_len`` is below this;
            ``None`` disables the check.

    Returns:
        ``(utt_ids, item_locs, lengthes)``: parallel lists for the kept
        entries; empty when nothing survives filtering.
    """
    n_long, n_short, error_num = 0, 0, 0
    utt_ids, item_locs, lengthes = [], [], []
    with open(manifest_path, "r", encoding='utf-8') as f:
        lines = f.readlines()
    for line in tqdm(lines):
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline) instead of
            # failing to unpack them.
            continue
        utt_id, item_zip_loc, codec_len, text_len = line.split("\t")
        text_len = int(text_len)
        # Mirror load_audio's "fewer than 2 text characters" rule. The previous
        # check, len(text_len) < 2, measured the number of *digits* and so
        # dropped every entry with text_len < 10.
        # NOTE(review): assumes text_len counts the same unit as len(text) in
        # load_audio; confirm against the script that builds this manifest.
        if text_len < 2:
            n_short += 1
            continue
        length = int(codec_len) + text_len
        if min_keep is not None and length < min_keep:
            n_short += 1
        elif max_keep is not None and length > max_keep:
            n_long += 1
        else:
            utt_ids.append(utt_id)
            item_locs.append(item_zip_loc)
            lengthes.append(length)
    # Guard max()/min(): an empty list would raise ValueError here.
    longest = max(lengthes) if lengthes else None
    shortest = min(lengthes) if lengthes else None
    logger.info(
        f"max_keep={max_keep}, min_keep={min_keep}, "
        f"loaded {len(lengthes)}, skipped {n_short} short and {n_long} long, not exist {error_num} "
        f"longest-loaded={longest}, shortest-loaded={shortest}"
    )
    if not lengthes:
        logger.warning(f"no samples kept from {manifest_path}")
    assert len(utt_ids) == len(item_locs) == len(lengthes)
    return utt_ids, item_locs, lengthes

def check_and_load_path(temp_list, wf=None, err_set=None):
    """Validate a list of numpy array files, optionally loading them.

    The function has two modes:

    * ``wf`` given: ``np.load`` each path in order. On the first failure, the
      path is printed to ``wf`` and checking stops.
    * ``wf`` is None: nothing is loaded. The check fails on the first path
      contained in ``err_set``; ``None`` is treated as an empty set.

    Args:
        temp_list: iterable of file paths.
        wf: optional writable text stream that receives failing paths.
        err_set: container of paths known to be bad (used only when ``wf`` is None).

    Returns:
        ``(items, ok)``: ``items`` holds the arrays loaded so far (always empty
        when ``wf`` is None); ``ok`` is True when every path passed.
    """
    if wf is None:
        # Previously ``name in None`` raised TypeError with the default arguments.
        known_bad = err_set if err_set is not None else ()
        ok = not any(name in known_bad for name in temp_list)
        return [], ok
    items = []
    for name in temp_list:
        try:
            items.append(np.load(name))
        except Exception:
            # Deliberate best effort: record the unreadable file and stop.
            print(f"{name}", file=wf)
            return items, False
    return items, True

def get_alignment(data, tot_len, tokenizer_func, sil_token, next_token):
    """Expand a per-character alignment into frame-level text tokens.

    Args:
        data: iterable of dicts with keys ``"char"``, ``"start_token"`` and
            ``"end_token"``. Modified in place: ``"end_token"`` values larger
            than ``tot_len`` are clamped to ``tot_len``.
        tot_len: number of output frames (TPDataset passes the codec length).
        tokenizer_func: callable turning a string into token ids. Its result
            is ``len()``-ed and indexed below, so it presumably returns a list
            of ints; TODO confirm against CosyVoiceTextFrontEnd.
        sil_token: id written to frames not covered by any entry.
        next_token: marker id written to the last frame of every entry and to
            the final frame; it is used to build ``boundary``.

    Returns:
        ``(aligned_text, boundary)``, both int64 tensors of length ``tot_len``.
        ``boundary`` is 1 where ``next_token`` was written and 0 elsewhere.
        In ``aligned_text``, those frames are replaced by the preceding frame's
        token as it was before the replacement (``sil_token`` for frame 0).
    """
    # NOTE(review): torch.ones(...) * sil_token is computed in float32, which
    # is exact only for ids below 2**24.
    aligned_text = (torch.ones(tot_len) * sil_token).to(torch.int32)
    # The first textual entry is tokenized without a leading space.
    first_text = True
    for char_item in data:
        # Clamp spans running past the available frames (mutates the caller's dict).
        if char_item["end_token"] > tot_len:
            char_item["end_token"] = tot_len
        char = char_item["char"]
        if char == "<SIL>":
            # Silence entries map to the added special token.
            char = '<|silence|>'
        elif char == "i":
            # A bare "i" is upper-cased; this branch leaves first_text unchanged.
            char = "I"
        else:
            if first_text:
                char = char
                first_text = False
            else:
                # Later alphabetic entries get a leading space (presumably so
                # the tokenizer treats them as word starts).
                if char.isalpha():
                    char = " "+char
        char = tokenizer_func(char)
        if len(char) != 1:
            # Not a single token: retry with a leading space on the raw character.
            temp = tokenizer_func(" "+char_item["char"].strip())
            if len(temp) == 1:
                char = temp
            else:
                # Handle entries that tokenize to more than one token (core logic).
                span_len = char_item["end_token"] - char_item["start_token"]
                if span_len / len(char) >= 2:
                    # Enough frames: split the span into len(char) chunks, one per token.
                    temp = (torch.ones(span_len) * char[0]).to(torch.int32)
                    start = 0
                    chunk = int(span_len // len(char))
                    for i in range(len(char)):
                        if i != len(char) - 1:
                            # NOTE(review): the slice stops at start+chunk-1, so
                            # the last frame of each non-final chunk keeps char[0]
                            # from the initialization above; possibly an
                            # off-by-one, confirm intent.
                            temp[start:start+chunk-1] = char[i]
                            start = start + chunk
                        else:
                            # The final chunk absorbs the remainder; its last frame
                            # is overwritten with next_token after the span is written.
                            temp[start:span_len-1] = char[i]
                    char = temp
                else:
                    # Too few frames to split: keep only the first token.
                    char = char[0] 
        
        start_idx = char_item["start_token"] # ordinal position, not an index (original note); NOTE(review): used directly as a 0-based slice start below
        end_idx = char_item["end_token"]
        # NOTE(review): unreachable, since end_token was already clamped to
        # tot_len at the top of the loop.
        if end_idx > tot_len:
            logger.info(
                f"{str(data)} \t end_idx:{end_idx} tot_len:{tot_len}"
            )
            end_idx = tot_len
        aligned_text[start_idx:end_idx] = char
        # Mark the last frame of the span as a boundary.
        aligned_text[end_idx-1] = next_token
    # The final frame is always a boundary.
    aligned_text[-1] = next_token
    
    boundary = torch.where(aligned_text == next_token, torch.tensor(1.0), torch.tensor(0.0))
    prev = torch.roll(aligned_text, shifts=1)
    prev[0] = sil_token  # special-case the first element: roll wraps the last frame around to position 0
    # Replace each boundary marker with the token of the preceding frame.
    aligned_text = torch.where(aligned_text == next_token, prev, aligned_text)
    return aligned_text.to(torch.long), boundary.to(torch.long)

class TPDataset(FairseqDataset):
    """Dataset of (text tokens, codec tokens, frame-aligned text) examples.

    Two manifest formats are supported:

    * ``*.json``: loaded eagerly by ``load_audio``; every transcript, codec
      string and alignment is kept in memory.
    * anything else: a tab-separated index read by ``load_audio_zip``. Each
      entry's ``"<file>:<offset>:<size>"`` location is read lazily in
      ``__getitem__`` and decoded as a JSON record with keys ``t``/``c``/``a``.
    """

    # Substrings of the manifest *file name* that mark an evaluation split.
    _EVAL_SPLIT_MARKERS = ("valid", "test", "dev")

    def __init__(
        self,
        manifest_path: str,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        max_sample_size = None,
        shuffle: bool = True,
        max_valid_sample_size = -1,
        text_frontend=None,
    ):
        """Load and filter the manifest.

        Args:
            manifest_path: ``.json`` manifest, or a tab-separated index for any
                other extension.
            max_keep_sample_size: drop items longer than this (None = no limit).
            min_keep_sample_size: drop items shorter than this (None = no limit).
            max_sample_size: cap on the size reported by ``size``/``num_tokens``
                (None = no cap).
            shuffle: break length ties randomly in ``ordered_indices``.
            max_valid_sample_size: replaces ``max_keep_sample_size`` when the
                manifest file name contains "valid", "test" or "dev". A negative
                value or None means no limit.
            text_frontend: provides ``_extract_text_token`` and
                ``norm_and_extract_token``.
        """
        # Only the file name is inspected. Matching the full path would also
        # catch parent directories such as "/home/devuser/" or "/data/testset/"
        # and silently apply the validation cap to a training manifest.
        manifest_name = os.path.basename(manifest_path)
        if any(marker in manifest_name for marker in self._EVAL_SPLIT_MARKERS):
            if max_valid_sample_size is not None and max_valid_sample_size < 0:
                # With the -1 default every sample would be dropped as "too long".
                max_valid_sample_size = None
            max_keep_sample_size = max_valid_sample_size
        self.shuffle = shuffle
        self.audio_pad_idx = 0
        self.text_frontend = text_frontend
        self.silence_token = self.text_frontend._extract_text_token('<|silence|>')[0]
        self.next_token = self.text_frontend._extract_text_token('<|nexttext|>')[0]
        self.text_pad_idx = self.text_frontend._extract_text_token('<|endoftext|>')[0]
        self.max_sample_size = max_sample_size
        if manifest_path.endswith(".json"):
            self.texts, self.codecs, self.align_texts, self.lengthes = load_audio(
                manifest_path, max_keep_sample_size, min_keep_sample_size, text_frontend=self.text_frontend
            )
        else:
            self.utt_ids, self.item_locs, self.lengthes = load_audio_zip(
                manifest_path, max_keep_sample_size, min_keep_sample_size,
            )
            # None signals that __getitem__ must read records lazily.
            self.texts = None

    @staticmethod
    def _normalize_text(text):
        """Normalize a transcript the same way ``load_audio`` does."""
        if not contains_chinese(text):
            # capitalize() lower-cases the rest; replace_i restores " i ".
            return replace_i(text.capitalize())
        return text.replace(" ", "，")

    def _read_record(self, index):
        """Read and decode the JSON record of item *index* from its ``file:offset:size`` location."""
        # rsplit keeps file paths that themselves contain ":" intact.
        zip_path, offset, file_size = self.item_locs[index].rsplit(":", 2)
        offset, file_size = int(offset), int(file_size)
        with open(zip_path, "rb") as f:
            f.seek(offset)
            byte_data = f.read(file_size)
        # Explicit check instead of assert, so it survives ``python -O``.
        if len(byte_data) <= 1:
            raise ValueError(f"empty record for item {index} at {self.item_locs[index]}")
        return json.loads(byte_data.decode("utf-8"))

    def __getitem__(self, index):
        """Return one example: text/codec/aligned_text/boundary LongTensors plus ``id``."""
        if self.texts is not None:
            text = self.texts[index]
            codec = self.codecs[index]
            align = self.align_texts[index]
        else:
            record = self._read_record(index)
            text = self._normalize_text(record["t"])
            codec = record["c"]
            align = record["a"]
        text = self.text_frontend.norm_and_extract_token(text).flatten().to(torch.long)
        codec = torch.LongTensor(np.array(codec.split(" "), dtype=int))
        aligned_text, boundary = get_alignment(
            align, len(codec),
            self.text_frontend._extract_text_token, self.silence_token, self.next_token
        )
        return {
            "text": text,
            "codec": codec,
            "aligned_text": aligned_text,
            "boundary": boundary,
            "id": index,
        }

    def __len__(self):
        return len(self.lengthes)

    def collater(self, samples):
        """Pad a list of ``__getitem__`` outputs into a fairseq batch dict."""
        samples = [s for s in samples if s["id"] is not None]
        if len(samples) == 0:
            return {}

        texts = [s["text"] for s in samples]
        text_lens = torch.LongTensor([temp.size(0) for temp in texts])
        codecs = [s["codec"] for s in samples]
        codec_lens = torch.LongTensor([temp.size(0) for temp in codecs])
        aligned_texts = [s["aligned_text"] for s in samples]
        boundarys = [s["boundary"] for s in samples]
        aligned_text_lens = torch.LongTensor([temp.size(0) for temp in aligned_texts])

        texts = data_utils.collate_tokens(
            texts,
            self.text_pad_idx,
            self.text_pad_idx,
            move_eos_to_beginning=False,
            pad_to_length=None,
        )
        aligned_texts = data_utils.collate_tokens(
            aligned_texts,
            self.text_pad_idx,
            self.text_pad_idx,
            move_eos_to_beginning=False,
            pad_to_length=None,
        )
        # NOTE(review): boundary is a 0/1 flag but is padded with text_pad_idx;
        # this is only harmless if the consumer masks padding via
        # aligned_text_lens. Confirm.
        boundarys = data_utils.collate_tokens(
            boundarys,
            self.text_pad_idx,
            self.text_pad_idx,
            move_eos_to_beginning=False,
            pad_to_length=None,
        )
        codecs = data_utils.collate_tokens(
            codecs,
            self.audio_pad_idx,
            self.audio_pad_idx,
            move_eos_to_beginning=False,
            pad_to_length=None,
        )
        return {
            "net_input": {
                "codecs": codecs,
                "aligned_texts": aligned_texts,
                "texts": texts,
                "boundarys": boundarys,
                "codec_lens": codec_lens,
                "aligned_text_lens": aligned_text_lens,
                "text_lens": text_lens,
            }
        }

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        """Size used for batching: the filtered length, capped at ``max_sample_size`` when set."""
        length = self.lengthes[index]
        # min(int, None) raises TypeError; None means "no cap".
        if self.max_sample_size is None:
            return length
        return min(length, self.max_sample_size)

    def ordered_indices(self):
        """Indices sorted by length, shortest first; ties broken randomly when ``shuffle``."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        # np.lexsort uses the last key as the primary key.
        order.append(self.lengthes)
        return np.lexsort(order)

    def collate_frames(self, frames, fbank_sizes=None, heights=None) -> torch.Tensor:
        """Zero-pad 2-D frames into a ``(B, max_len, max_height)`` tensor."""
        if fbank_sizes is None:
            fbank_sizes = [item.size(0) for item in frames]
        max_len = max(fbank_sizes)
        if heights is not None:
            max_height = max(heights)
        else:
            max_height = max([item.size(-1) for item in frames])
        out = frames[0].new_zeros((len(frames), max_len, max_height))
        for i, v in enumerate(frames):
            rows = min(max_len, v.size(0))
            cols = min(max_height, v.size(1))
            # Slice rows on the source as well (as collate_frame_tokens does) so a
            # frame longer than max_len is truncated instead of raising a shape error.
            out[i, :rows, :cols] = v[:rows, :cols]
        return out

    def collate_frame_tokens(self, frames, padding) -> torch.Tensor:
        """Pad 2-D token frames with *padding*; also return per-item lengths and heights."""
        lengthes = [item.size(0) for item in frames]
        max_len = max(lengthes)
        heights = [item.size(-1) for item in frames]
        max_height = max(heights)
        out = frames[0].new_ones((len(frames), max_len, max_height)) * padding
        for i, v in enumerate(frames):
            out[i, :min(max_len, v.size(0)), :min(max_height, v.size(1))] = v[:min(max_len, v.size(0)), :min(max_height, v.size(1))]
        return out, torch.LongTensor(lengthes), torch.LongTensor(heights)
    
def make_pad_mask(lengths: List[int], max_len: int = 0):
    """Build a boolean mask that is True at padded positions.

    Args:
        lengths: valid length of each sequence, shape ``(B,)``.
        max_len: number of mask columns; when not positive, ``max(lengths)`` is used.

    Returns:
        torch.BoolTensor of shape ``(B, max_len)`` whose entry ``[b, t]`` is
        True iff ``t >= lengths[b]``.

    Examples:
        >>> make_pad_mask([5, 3, 2]).int().tolist()
        [[0, 0, 0, 0, 0], [0, 0, 0, 1, 1], [0, 0, 1, 1, 1]]
    """
    width = max_len if max_len > 0 else max(lengths)
    positions = np.arange(0, width)[np.newaxis, :]  # (1, width)
    limits = np.expand_dims(lengths, -1)  # (B, 1)
    # Broadcasting (1, width) against (B, 1) yields the full (B, width) grid.
    return torch.BoolTensor(positions >= limits)

def create_cross_attention_mask(query_lens, key_lens):
    """Build a cross-attention validity mask from query and key lengths.

    Args:
        query_lens: ``(B,)`` valid query (decoder-side) lengths.
        key_lens: ``(B,)`` valid key/value (encoder-side) lengths.

    Returns:
        Bool tensor of shape ``(B, max(query_lens), max(key_lens))`` whose
        entry ``[b, q, k]`` is True iff ``q < query_lens[b]`` and ``k < key_lens[b]``.
    """
    def _valid_positions(lens):
        # (B, L) mask: True at positions inside each sequence's length.
        longest = max(lens)
        return torch.arange(longest)[None, :] < torch.tensor(lens)[:, None]

    valid_q = _valid_positions(query_lens)
    valid_k = _valid_positions(key_lens)
    # Outer AND of the query and key masks per batch item.
    return valid_q[:, :, None] & valid_k[:, None, :]

def fill_mask_tokens(tokens, tgt_len, mask_id):
    filled_matrix = (torch.ones(len(tokens), tgt_len) * mask_id).to(torch.long)
    for idx, token in enumerate(tokens):
        filled_matrix[idx][-len(token):] = token
    return filled_matrix

@dataclass
class CosyVoice2TPConfig(FairseqDataclass):
    """Configuration of the ``cosyvoice2_tp`` task."""

    data: str = field(default="", metadata={"help": "path to data directory"})
    # NOTE(review): CosyVoice2TPTask.load_dataset does not read this field;
    # training-split filtering uses max_sample_size instead.
    max_keep_size: Optional[int] = field(
        default=None,
        metadata={"help": "exclude sample longer than this"},
    )
    max_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "exclude training samples longer than this; also caps the per-sample size used for batching"},
    )
    max_valid_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "exclude valid/test/dev samples longer than this"},
    )
    min_sample_size: Optional[int] = field(
        default=None,
        metadata={"help": "exclude samples shorter than this"},
    )
    pretrained_home : str = field(default="", metadata={"help": "directory containing the pretrained llm.pt; falls back to $COSYVOICE2HOME if it does not exist"})
    data_type : str = field(default="json", metadata={"help": "manifest file extension: 'json' for a JSON manifest, anything else for a tab-separated index"})
    pretrained_tp_model : str = field(default="", metadata={"help": "optional fairseq checkpoint whose 'model' weights are loaded after llm.pt"})
    
    
@register_task("cosyvoice2_tp", dataclass=CosyVoice2TPConfig)
class CosyVoice2TPTask(FairseqTask):
    """Fairseq task ``cosyvoice2_tp``.

    Owns a CosyVoiceTextFrontEnd extended with the ``<|silence|>`` and
    ``<|nexttext|>`` special tokens and shares it with every TPDataset split.
    """

    cfg: CosyVoice2TPConfig

    def __init__(
        self,
        cfg: CosyVoice2TPConfig,
    ) -> None:
        super().__init__(cfg)
        logger.info(f"current directory is {os.getcwd()}")
        logger.info(f"Config {cfg}")
        self.cfg = cfg
        if not os.path.exists(self.cfg.pretrained_home):
            # Fall back to $COSYVOICE2HOME. Keep the configured value when the
            # variable is unset instead of storing None: build_model would
            # otherwise fail inside os.path.join(None, ...) rather than on the
            # missing checkpoint file.
            self.cfg.pretrained_home = os.environ.get(
                "COSYVOICE2HOME", self.cfg.pretrained_home
            )

        self.text_frontend = CosyVoiceTextFrontEnd()
        # Register the marker strings used by get_alignment / TPDataset.
        self.text_frontend.add_special_tokens({'additional_special_tokens':['<|silence|>', '<|nexttext|>']})
        self.silence_token = self.text_frontend._extract_text_token('<|silence|>')[0]
        self.next_token = self.text_frontend._extract_text_token('<|nexttext|>')[0]
        self.text_pad_idx = self.text_frontend._extract_text_token('<|endoftext|>')[0]
        logger.info(f"text tokenizer size:{len(self.text_frontend.tokenizer.tokenizer)}, silence token: {self.silence_token}, next token: {self.next_token}")

    @property
    def source_dictionary(self):
        return None

    @property
    def target_dictionary(self):
        return None

    @property
    def dictionaries(self):
        return [None]

    @classmethod
    def setup_task(
        cls, cfg: CosyVoice2TPConfig, **kwargs
    ):
        return cls(cfg)

    def load_dataset(self, split: str, **kwargs) -> None:
        """Build the TPDataset for *split* from ``{data}/{split}.{data_type}``."""
        manifest = f"{self.cfg.data}/{split}.{self.cfg.data_type}"

        # NOTE(review): cfg.max_keep_size is not consulted here; training
        # filtering uses max_sample_size.
        if "valid" in split or "test" in split or "dev" in split:
            max_keep = self.cfg.max_valid_sample_size
        else:
            max_keep = self.cfg.max_sample_size

        self.datasets[split] = TPDataset(
            manifest,
            max_keep_sample_size=max_keep,
            min_keep_sample_size=self.cfg.min_sample_size,
            max_sample_size=self.cfg.max_sample_size,
            max_valid_sample_size=self.cfg.max_valid_sample_size,
            text_frontend=self.text_frontend
        )
        logger.info(f"{split} dataloader is length of {len(self.datasets[split])}")

    def max_positions(self) -> Tuple[int, int]:
        return (sys.maxsize, sys.maxsize)

    def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array:
        # No size filtering here; TPDataset already filters by length when loading.
        return indices

    def build_model(self, cfg, from_checkpoint=False):
        """Build the model, initialize it from ``llm.pt`` and optionally a TP checkpoint.

        ``from_checkpoint`` is accepted for interface compatibility and not used.
        """
        from fairseq import models

        model = models.build_model(cfg, self)
        model = load_pretrained_models(model, os.path.join(self.cfg.pretrained_home, "llm.pt"))
        if os.path.exists(self.cfg.pretrained_tp_model):
            model = load_pretrained_tp_models(model, self.cfg.pretrained_tp_model)
        return model

from copy import deepcopy
def _load_matching_weights(model, state, tag=""):
    """Copy every entry of *state* whose key and shape match *model* into it.

    Shape-mismatched entries are reported and skipped. Model entries without a
    usable counterpart keep their current values.

    Args:
        model: module updated in place and returned.
        state: mapping from parameter/buffer name to tensor.
        tag: suffix appended to the printed summary labels (e.g. " tp").
    """
    model_state = model.state_dict()
    pretrained_dict = {}
    for k, v in state.items():
        if k not in model_state:
            continue
        target = model_state[k]
        if v.size() == target.size():
            pretrained_dict[k] = v.to(target.dtype)
        else:
            # Skipped: the model keeps its current value, and the key is
            # reported under "not loaded" rather than "loaded".
            print(k, v.size(), target.size(), "size error")
    # Keep model order; a set lookup avoids the former O(n^2) list.remove loop.
    not_loaded = [k for k in model_state if k not in pretrained_dict]
    print(f"loaded{tag}: " + str(pretrained_dict.keys()))
    print(f"not loaded{tag}: " + str(not_loaded))
    model_state.update(pretrained_dict)
    model.load_state_dict(model_state)
    return model


def load_pretrained_models(model, pretrained_checkpoint):
    """Initialize *model* from a file holding a plain state dict (e.g. ``llm.pt``)."""
    # map_location="cpu" avoids depending on the device the tensors were saved
    # from; load_state_dict copies them onto the model's own device.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(pretrained_checkpoint, map_location="cpu")
    return _load_matching_weights(model, state)


def load_pretrained_tp_models(model, pretrained_checkpoint):
    """Initialize *model* from the ``"model"`` entry of a checkpoint file."""
    # NOTE(review): full fairseq checkpoints also pickle non-tensor objects;
    # torch versions that default to weights_only=True may refuse them.
    state = torch.load(pretrained_checkpoint, map_location="cpu")["model"]
    return _load_matching_weights(model, state, tag=" tp")