import os
import sys
import copy
import json
import torch
import random
import argparse
import subprocess
import numpy as np
import soundfile as sf
import subprocess
import concurrent.futures

from vl_load.vita.model.vita_tts.decoder.decoder import LLM2TTSCodecAR
from vl_load.vita.model.vita_tts.decoder.ticodec.vqvae_tester import VqvaeTester

class llm2TTS:
    """Streaming speech synthesizer: an autoregressive token decoder plus a codec.

    ``self.model`` (an ``LLM2TTSCodecAR``) yields codec token ids through its
    ``infer`` generator; ``self.codec_model`` (a ``VqvaeTester``) turns
    chunks of those tokens into waveform tensors.  Both models are moved to
    CUDA when the object is constructed.
    """

    def __init__(self, model_path):
        """
        Load the decoder and codec checkpoints found under ``model_path``.

        Parameters:
        - model_path (str): Directory containing ``decoder/model.json``,
          ``decoder/final.pt``, ``codec/model.json`` and ``codec/final.pt``.
        """
        self.model = self.get_model(model_path).cuda().to(self._preferred_dtype())
        self.infer = self.model.infer

        self.codec_model = VqvaeTester(config_path=model_path + "/codec/model.json",
                                       model_path=model_path + "/codec/final.pt",
                                       sample_rate=24000)
        self.codec_model = self.codec_model.cuda()
        self.codec_model.vqvae.generator.remove_weight_norm()
        self.codec_model.vqvae.encoder.remove_weight_norm()
        self.codec_model.eval()

    @staticmethod
    def _preferred_dtype():
        """Return bfloat16 when the CUDA device supports it, else float32."""
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32

    def get_model_conf(self, model_path):
        """
        Read the decoder configuration stored next to the checkpoint.

        Parameters:
        - model_path (str): Directory containing ``decoder/model.json``.

        Returns:
        - argparse.Namespace: The third element of the JSON document (a
          mapping of argument names to values) exposed as attributes.

        """
        model_conf = model_path + "/decoder/model.json"
        with open(model_conf, "rb") as f:
            print('reading a config file from ' + model_conf)
            confs = json.load(f)
        # The file holds exactly three entries; only the last one (the
        # argument mapping) is used here.
        _, _, args = confs
        return argparse.Namespace(**args)

    def get_model(self, model_path):
        """
        Build the decoder and initialise it from ``decoder/final.pt``.

        Only checkpoint tensors whose key exists in the freshly built model
        and whose shape matches are copied; every other parameter keeps its
        initial value and a message is printed for it.

        Parameters:
        - model_path (str): Directory containing the ``decoder`` folder.

        Returns:
        - LLM2TTSCodecAR: The model in eval mode (still on CPU).

        """
        args = self.get_model_conf(model_path)
        model = LLM2TTSCodecAR(args.idim, args.odim, args)

        # Resume from a snapshot; tensors are loaded onto the CPU first.
        snapshot_dict = torch.load(model_path + "/decoder/final.pt", map_location="cpu")
        # Checkpoints either wrap the weights under 'model' or are the
        # state dict itself.
        resume_model_dict = snapshot_dict['model'] if 'model' in snapshot_dict else snapshot_dict

        model_dict = model.state_dict()
        updates = {}
        for key, current in model_dict.items():
            if key not in resume_model_dict:
                print('Key {} has not in resume model'.format(key))
                continue
            loaded = resume_model_dict[key]
            if current.shape == loaded.shape:
                updates[key] = loaded
            else:
                print('Key {} has different shape, {} VS {}'.format(key, current.shape,
                                                                    loaded.shape))
        model_dict.update(updates)
        model.load_state_dict(model_dict)
        model.eval()
        return model

    def find_min_sum_index(self, buffer, syn, N, threshold):
        """
        Split an audio segment at its quietest point in the second half.

        A window of ``N`` samples is slid over ``|syn[0, 0, :]|``; among the
        windows starting at or after ``len // 2 - N // 2`` the one with the
        smallest sum is selected.  If its mean absolute amplitude is below
        ``threshold`` the segment is cut at the smallest-magnitude sample
        inside that window: everything before the cut is emitted (prefixed
        by ``buffer``) and the rest becomes the new buffer.  Otherwise the
        whole segment is appended to the buffer and nothing is emitted.

        Parameters:
        - buffer (torch.Tensor): Audio held back from previous calls; concatenated along the last axis.
        - syn (torch.Tensor): The current audio segment; indexed as ``syn[0, 0, :]``.
        - N (int): The size of the sliding window used to calculate the sum.
        - threshold (float): Upper bound on the window's mean absolute amplitude for a cut to happen.

        Returns:
        - tuple: ``(buffer, syn)`` - the updated buffer and the audio to emit,
          where ``syn`` is ``None`` when no cut was made.

        """
        arr = syn[0, 0, :]
        L = arr.numel()

        # A segment shorter than the window cannot be evaluated; hold it
        # back instead of letting conv1d raise.
        if L < N:
            return torch.cat([buffer, syn], dim=-1), None

        magnitude = arr.abs()
        kernel = torch.ones(N, device=arr.device, dtype=arr.dtype)
        # view(-1) keeps the result 1-D even when there is a single window
        # (L == N), which squeeze() would collapse to a 0-dim tensor.
        window_sums = torch.nn.functional.conv1d(
            magnitude.view(1, 1, -1), kernel.view(1, 1, -1), padding=0
        ).view(-1)

        # Only consider windows centred in the second half of the segment.
        first_window = L // 2 - N // 2
        min_sum, offset = torch.min(window_sums[first_window:], dim=0)

        # Absolute bounds of the selected window: exactly N samples.
        window_start = first_window + offset.item()
        window_end = min(L, window_start + N)

        # Cut at the lowest-magnitude sample inside that window.
        _, local_index = torch.min(magnitude[window_start:window_end], dim=0)
        cut = window_start + local_index.item()

        if min_sum / N < threshold:
            emitted = torch.cat([buffer, syn[:, :, :cut]], dim=-1)
            # clone() so the held-back tail does not alias the caller's tensor.
            return syn[:, :, cut:].clone(), emitted
        return torch.cat([buffer, syn], dim=-1), None

    def _decode_tokens(self, token):
        """Run the codec on a (1, T) token tensor together with its global tokens."""
        global_tokens = torch.tensor(self.codec_model.vqvae.h.global_tokens,
                                     device=token.device).unsqueeze(0).unsqueeze(0)
        return self.codec_model.vqvae(token.unsqueeze(-1), global_tokens)

    def run(self, hidden, top_k, prefix, codec_chunk_size=40, codec_padding_size=10,
            penalty_window_size=-1, penalty=1.1, N=2401, seg_threshold=0.01):
        """
        Run the speech decoder process.

        Tokens produced by ``self.infer`` are accumulated; whenever a full
        chunk (plus padding on each side) is available it is decoded by the
        codec, the padded regions are trimmed from the waveform, and the
        result is passed through ``find_min_sum_index`` so that audio is
        only emitted up to a quiet point.

        Parameters:
        - hidden (torch.Tensor): Forwarded to ``self.infer``; its device is used for all new tensors.
        - top_k (int): The number of top-k tokens to consider during inference.
        - prefix: Accepted for interface compatibility but currently ignored;
          it is overwritten with ``None`` before inference.
        - codec_chunk_size (int, default=40): The size of each chunk to process in the codec model.
        - codec_padding_size (int, default=10): The amount of padding to add on each side of the codec chunk.
        - penalty_window_size (int, default=-1): The window size for applying penalties during decoding.
        - penalty (float, default=1.1): The penalty factor.
        - N (int, default=2401): Sliding-window size passed to ``find_min_sum_index``.
        - seg_threshold (float, default=0.01): Threshold passed to ``find_min_sum_index``.

        Yields:
        - torch.Tensor: Intermediate audio segments generated by the codec model.

        """
        # Waveform positions per token, as assumed by the padding arithmetic below.
        codec_upsample_rate = 600
        left_padding = 0
        right_padding = codec_padding_size
        # NOTE(review): the caller-supplied prefix is discarded here; confirm
        # this is intentional before relying on the parameter.
        prefix = None
        buffer = torch.zeros([1, 1, 0]).to(hidden.device)
        with torch.no_grad():
            with torch.autocast(device_type="cuda", dtype=self._preferred_dtype()):
                print("Starting TTS...")
                token = torch.full((1, 0), self.model.vocab_size, dtype=torch.long, device=hidden.device)
                for next_token_id in self.infer(hidden, top_k, prefix, penalty_window_size, penalty):
                    token = torch.cat([token, next_token_id], dim=-1)
                    if token.size(1) != left_padding + codec_chunk_size + right_padding:
                        continue

                    syn = self._decode_tokens(token)
                    print("Codec Done")
                    # Explicit end index: a negative-slice form would yield an
                    # empty tensor when the right padding is zero.
                    end = max(syn.size(-1) - right_padding * codec_upsample_rate, 0)
                    syn = syn[:, :, left_padding * codec_upsample_rate: end]
                    left_padding = codec_padding_size
                    # Keep only the tokens needed as context for the next
                    # chunk (explicit start index, correct for zero padding).
                    keep = left_padding + right_padding
                    token = token[:, token.size(1) - keep:]
                    buffer, syn = self.find_min_sum_index(buffer, syn, N, seg_threshold)
                    if syn is not None:
                        yield syn
                if token.size(1) > 0:
                    # Decode whatever is left, including the right padding
                    # trimmed from the previous chunk.
                    print("Codec Done")
                    syn = self._decode_tokens(token)
                    syn = syn[:, :, left_padding * codec_upsample_rate:]
                    yield torch.cat([buffer, syn], dim=-1)
                elif buffer.size(-1) > 0:
                    # No tokens remain but audio is still held back: flush it.
                    yield buffer
