import copy
import logging
import math
from dataclasses import dataclass
from dataclasses import dataclass
from dataclasses import field
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint

from .AudioEncoder import Cnn10, Cnn14, ResNet38
from .TextEncoder import BertEncoder


@dataclass
class WavCfg:
    """Audio-signal / spectrogram parameters carried inside ``AudioCfg.wav``."""

    sr: int = 32000              # sample rate
    window_size: int = 1024      # STFT window size
    hop_length: int = 320        # frame shift (hop between STFT frames)
    mel_bins: int = 64           # number of mel filter banks


@dataclass
class AudioCfg:
    """Configuration for the audio tower built by ``_build_audio_tower``.

    ``wav`` may be given either as a ``WavCfg`` or as a plain dict of its
    fields (e.g. ``AudioCfg(**{"wav": {"sr": 16000}})``); a dict is converted
    to ``WavCfg`` in ``__post_init__``.
    """

    arch: str = "ResNet38"            # one of "Cnn10", "Cnn14", "ResNet38"
    # default_factory gives every AudioCfg its own WavCfg. A shared instance
    # default would alias between configs, and because WavCfg is a non-frozen
    # dataclass (hence unhashable), Python 3.11+ rejects it at class creation.
    wav: WavCfg = field(default_factory=WavCfg)
    freeze: bool = False              # whether to freeze the encoder parameters
    dropout: float = 0.2
    cnn_pretrained: bool = True       # ResNet38 only: load pretrained CNN weights
    spec_augmentation: bool = True    # presumably toggles spec augmentation in the encoder — confirm

    def __post_init__(self):
        # Allow nested dict configs so AudioCfg(**cfg_dict) yields a real WavCfg.
        if isinstance(self.wav, dict):
            self.wav = WavCfg(**self.wav)
    
@dataclass
class TextCfg:
    """Configuration for the text tower built by ``_build_text_tower``."""

    # Passed to BertEncoder as ``bert_type``.
    model_type: str = "bert-base-uncased"
    # Passed to BertEncoder as ``freeze``.
    freeze: bool = False
    # Passed to BertEncoder as ``dropout``.
    dropout: float = 0.2


def get_cast_dtype(precision: str):
    """Return a dtype suitable as ``cast_dtype`` for the given precision name.

    Only the exact strings 'bf16' and 'fp16' yield a dtype (``torch.bfloat16``
    and ``torch.float16`` respectively); every other value, including the
    'pure_*' variants, returns ``None``, meaning no cast.
    """
    if precision == 'bf16':
        return torch.bfloat16
    if precision == 'fp16':
        return torch.float16
    return None

def get_input_dtype(precision: str):
    """Return the dtype that inputs should use for the given precision name.

    'bf16' and 'pure_bf16' map to ``torch.bfloat16``; 'fp16' and 'pure_fp16'
    map to ``torch.float16``; anything else returns ``None``.
    """
    alias_table = (
        (torch.bfloat16, ('bf16', 'pure_bf16')),
        (torch.float16, ('fp16', 'pure_fp16')),
    )
    for dtype, aliases in alias_table:
        if precision in aliases:
            return dtype
    return None

def _load_resnet38_pretrained(model: nn.Module, path: str) -> None:
    """Copy pretrained CNN weights from the checkpoint at ``path`` into ``model`` in place.

    The checkpoint's state dict is read from its ``'model'`` key. Entries whose
    name contains ``'fc'`` or starts with ``'spec'`` / ``'logmel'`` are skipped,
    so ``model`` keeps its own values for those. ``load_state_dict`` runs with
    its default ``strict=True``, so any other checkpoint key that ``model``
    lacks makes it raise.
    """
    # load_state_dict copies values into the model's existing parameters, so the
    # checkpoint only needs to be readable; mapping to CPU avoids materialising
    # GPU-saved tensors on their original device (e.g. cuda:0 on every rank).
    pretrained_state = torch.load(path, map_location='cpu')['model']
    state = model.state_dict()
    state.update({
        name: tensor
        for name, tensor in pretrained_state.items()
        if not ('fc' in name or name.startswith(('spec', 'logmel')))
    })
    model.load_state_dict(state)


def _build_audio_tower(
        embed_dim: int,
        audio_cfg: AudioCfg,
        cast_dtype: Optional[torch.dtype] = None,
        *,
        pretrained_path: str = 'float_distributed/pretrained_models/audio_encoder/ResNet38.pth',
):
    """Build an audio backbone followed by a Linear-ReLU-Linear projection head.

    Args:
        embed_dim: Width used to size the projection head.
        audio_cfg: An ``AudioCfg`` or a dict of its fields; ``arch`` selects
            the backbone ("Cnn10", "Cnn14" or "ResNet38").
        cast_dtype: If given, the whole tower is cast to this dtype.
        pretrained_path: Checkpoint loaded into the ResNet38 backbone when
            ``audio_cfg.cnn_pretrained`` is true. Ignored for other archs.

    Returns:
        ``nn.Sequential(backbone, Linear, ReLU, Linear)``.

    Raises:
        ValueError: If ``audio_cfg.arch`` is not a supported backbone.
    """
    if isinstance(audio_cfg, dict):
        audio_cfg = AudioCfg(**audio_cfg)
    model_type = audio_cfg.arch

    # Each branch picks the backbone plus (backbone feature width, hidden width, output width).
    if model_type == "Cnn10":
        model = Cnn10(audio_cfg)
        in_dim, hidden_dim, out_dim = 512, embed_dim, embed_dim
    elif model_type == "Cnn14":
        model = Cnn14(audio_cfg)
        in_dim, hidden_dim, out_dim = 2048, embed_dim * 2, embed_dim
    elif model_type == "ResNet38":
        model = ResNet38(audio_cfg)
        if audio_cfg.cnn_pretrained:
            _load_resnet38_pretrained(model, pretrained_path)
        # NOTE(review): unlike the other backbones this head ends at a fixed 768
        # rather than embed_dim — presumably to match the BERT text tower's
        # output width (_build_text_tower ignores embed_dim). Confirm before changing.
        in_dim, hidden_dim, out_dim = 2048, embed_dim * 2, 768
    else:
        raise ValueError(f"Unsupported audio encoder: {model_type}")

    audio_head = nn.Sequential(
        model,
        nn.Linear(in_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, out_dim),
    )

    if cast_dtype is not None:
        audio_head = audio_head.to(dtype=cast_dtype)

    return audio_head

def _build_text_tower(
        embed_dim: int,
        text_cfg: TextCfg,
        cast_dtype: Optional[torch.dtype] = None,
    ):
    """Build the BERT text encoder described by ``text_cfg``.

    Args:
        embed_dim: Not used here; it appears to be accepted only for signature
            parity with ``_build_audio_tower``. The output width is whatever
            ``BertEncoder`` produces.
        text_cfg: A ``TextCfg`` or a dict of its fields.
        cast_dtype: If given, the encoder is cast to this dtype.

    Returns:
        The ``BertEncoder`` module.
    """
    cfg = TextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
    encoder = BertEncoder(
        bert_type=cfg.model_type,
        dropout=cfg.dropout,
        freeze=cfg.freeze,
    )
    return encoder if cast_dtype is None else encoder.to(dtype=cast_dtype)



class AudioTextModel(nn.Module):
    """Audio-text dual encoder built from an audio tower and a BERT text tower.

    ``forward`` returns the L2-normalised embeddings of whichever modalities
    are supplied (``None`` for a missing one), either as the tuple
    ``(audio_features, text_features)`` or, when ``output_dict`` is true, as a
    dict with keys ``"audio_features"`` and ``"text_features"``.
    """
    output_dict: torch.jit.Final[bool]

    def __init__(self,
                 embed_dim: int,
                 audio_cfg: AudioCfg,
                 text_cfg: TextCfg,
                 cast_dtype: Optional[torch.dtype] = None,
                 output_dict: bool = False,
                 ):
        """
        Args:
            embed_dim: Projection width forwarded to both tower builders.
            audio_cfg: ``AudioCfg`` (or a dict of its fields) for the audio tower.
            text_cfg: ``TextCfg`` (or a dict of its fields) for the text tower.
            cast_dtype: Optional dtype both towers are cast to.
            output_dict: If true, ``forward`` returns a dict instead of a tuple.
        """
        super().__init__()
        self.output_dict = output_dict

        self.audio = _build_audio_tower(embed_dim, audio_cfg, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, cast_dtype)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable or disable gradient checkpointing in the audio tower.

        ``self.audio`` is an ``nn.Sequential`` (backbone + projection head),
        which has no ``set_grad_checkpointing`` of its own. The call is therefore
        forwarded to the audio submodules that define it. If none do, a warning
        is logged instead of raising ``AttributeError``.
        """
        audio = self.audio
        if callable(getattr(audio, 'set_grad_checkpointing', None)):
            targets = [audio]
        else:
            targets = [m for m in audio.children()
                       if callable(getattr(m, 'set_grad_checkpointing', None))]
        if not targets:
            logging.getLogger(__name__).warning(
                "set_grad_checkpointing: no audio submodule supports gradient checkpointing")
        for module in targets:
            module.set_grad_checkpointing(enable)

    def encode_audio(self, audio, normalize: bool=False):
        """Embed ``audio`` with the audio tower.

        If ``normalize`` is true, the result is L2-normalised over the last dim.
        """
        features = self.audio(audio)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, input_ids, attention_mask, normalize: bool=False):
        """Embed tokenised text with the text tower.

        If ``normalize`` is true, the result is L2-normalised over the last dim.
        """
        features = self.text(input_ids, attention_mask)
        return F.normalize(features, dim=-1) if normalize else features

    def forward(self,
                audio: Optional[torch.Tensor]=None,
                input_ids: Optional[torch.Tensor]=None,
                attention_mask: Optional[torch.Tensor]=None,
    ):
        """Encode whichever modalities are given; the missing one yields ``None``.

        Text is encoded only when ``input_ids`` is provided. Both outputs are
        L2-normalised.
        """
        audio_features = self.encode_audio(audio, normalize=True) if audio is not None else None
        text_features = self.encode_text(input_ids, attention_mask, normalize=True) if input_ids is not None else None

        if self.output_dict:
            out_dict = {
                "audio_features": audio_features,
                "text_features": text_features,
            }
            return out_dict
        return audio_features, text_features