#!/usr/bin/env python3
# coding: utf-8
import math
import torch
import torch.nn as nn

import torch.nn.functional as F
from torch import Tensor
from torch.nn.parameter import Parameter
from models.AudioEncoder import Cnn10, ResNet38, Cnn14
from models.TextEncoder import BertEncoder
from models.BERT_Config import MODELS

import ot

# Default compute device: CUDA when PyTorch reports a usable GPU, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class AudioEnc(nn.Module):
    """Wrap one of the supported CNN audio encoders, optionally pretrained and frozen.

    Args:
        config: Configuration object. This class reads
            ``config.cnn_encoder.model`` (``'Cnn10'``, ``'ResNet38'`` or
            ``'Cnn14'``), ``config.cnn_encoder.pretrained`` and
            ``config.training.freeze``. The whole object is also passed to the
            selected encoder's constructor.

    Raises:
        NotImplementedError: If ``config.cnn_encoder.model`` is not one of the
            supported encoder names.
    """

    def __init__(self, config):
        super().__init__()

        model_name = config.cnn_encoder.model
        if model_name == 'Cnn10':
            self.audio_enc = Cnn10(config)
        elif model_name == 'ResNet38':
            self.audio_enc = ResNet38(config)
        elif model_name == 'Cnn14':
            self.audio_enc = Cnn14(config)
        else:
            raise NotImplementedError('No such audio encoder network.')

        if config.cnn_encoder.pretrained:
            self._load_pretrained(model_name)

        if config.training.freeze:
            # Freeze every encoder weight so the optimizer leaves them untouched.
            for param in self.audio_enc.parameters():
                param.requires_grad = False

    def _load_pretrained(self, model_name):
        """Copy pretrained CNN weights into ``self.audio_enc``.

        The checkpoint is read from
        ``pretrained_models/audio_encoder/<model_name>.pth`` and must hold the
        weights under the ``'model'`` key.
        """
        # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
        # CPU-only host; load_state_dict copies the values into the existing
        # parameters, so their device is unaffected.
        checkpoint = torch.load(
            'pretrained_models/audio_encoder/{}.pth'.format(model_name),
            map_location='cpu',
        )
        pretrained_cnn = checkpoint['model']

        dict_new = self.audio_enc.state_dict().copy()
        for key, value in pretrained_cnn.items():
            # Skip any key containing 'fc' and any key that starts with
            # 'spec' or 'logmel'; those entries keep the encoder's own values.
            if 'fc' in key or key.startswith(('spec', 'logmel')):
                continue
            dict_new[key] = value
        self.audio_enc.load_state_dict(dict_new)

    def forward(self, inputs):
        """Run the wrapped encoder on ``inputs`` and return its output unchanged."""
        return self.audio_enc(inputs)



class ASE_mxh_ot(nn.Module):
    """Audio/text embedding model that couples a batch with entropic optimal transport.

    ``forward`` encodes audio and captions, L2-normalises both embeddings,
    builds a max-normalised pairwise Euclidean cost matrix and runs
    ``ot.sinkhorn`` with uniform marginals.

    Args:
        config: Configuration object. This class reads
            ``config.training.epsilon``, ``config.training.l2``,
            ``config.joint_embed``, ``config.cnn_encoder.model``,
            ``config.text_encoder`` and, for BERT,
            ``config.bert_encoder.type``. It is also passed to ``AudioEnc``
            and ``BertEncoder``.
    """

    def __init__(self, config):
        super().__init__()
        # Entropic regularisation strength handed to ot.sinkhorn in forward().
        self.epsilon = config.training.epsilon
        # Stored from the config; not read anywhere inside this class.
        self.l2 = config.training.l2
        joint_embed = config.joint_embed

        self.audio_enc = AudioEnc(config)
        # NOTE: audio_linear and text_linear are constructed here but forward()
        # does not call them; it projects audio with self.L instead.
        if config.cnn_encoder.model == 'Cnn10':
            self.audio_linear = nn.Sequential(
                nn.Linear(512, joint_embed),
                nn.ReLU(),
                nn.Linear(joint_embed, joint_embed),
            )
        elif config.cnn_encoder.model in ('ResNet38', 'Cnn14'):
            self.audio_linear = nn.Sequential(
                nn.Linear(2048, joint_embed * 2),
                nn.ReLU(),
                nn.Linear(joint_embed * 2, joint_embed),
            )

        if config.text_encoder == 'bert':
            self.text_enc = BertEncoder(config)
            bert_type = config.bert_encoder.type
            self.text_linear = nn.Sequential(
                nn.Linear(MODELS[bert_type][2], joint_embed * 2),
                nn.ReLU(),
                nn.Linear(joint_embed * 2, joint_embed),
            )

        # Audio projection actually used by forward().
        # NOTE(review): the sizes are hard-coded to 2048 -> 768. This looks
        # incompatible with the 512-d Cnn10 branch above and presumably assumes
        # a 768-d text encoder output - confirm before using other backbones.
        self.L = nn.Sequential(
            nn.Linear(2048, 2048),
            nn.ReLU(),
            nn.Linear(2048, 768),
        )

    def backbone_params(self):
        """
        Returns the parameters of audio_enc and text_enc for separate optimization.
        """
        params = []
        for attr in ('audio_enc', 'text_enc'):
            if hasattr(self, attr):
                params += list(getattr(self, attr).parameters())
        return params

    def encode_audio(self, audios):
        """Return the raw output of the audio encoder for ``audios``."""
        return self.audio_enc(audios)

    def encode_text(self, input_ids, attention_mask):
        """Return the raw output of the text encoder for the tokenised captions."""
        return self.text_enc(input_ids, attention_mask)

    def forward(self, audios, input_ids, attention_mask):
        """Embed both modalities and compute the Sinkhorn transport plan.

        Args:
            audios: Input for the audio encoder, or ``None`` to skip audio.
            input_ids: Token ids for the text encoder, or ``None`` to skip text.
            attention_mask: Attention mask passed with ``input_ids``.

        Returns:
            A tuple ``(SK, audio_embed, caption_embed)``. Each embedding is
            L2-normalised along the last dimension, or ``None`` when its input
            was ``None``. ``SK`` is the result of ``ot.sinkhorn`` on the
            pairwise cost matrix, or ``None`` when either modality is missing
            (no plan can be computed from a single modality).
        """
        audio_embed = None
        if audios is not None:
            audio_encoded = self.encode_audio(audios)
            audio_embed = F.normalize(self.L(audio_encoded), p=2, dim=-1)

        caption_embed = None
        if input_ids is not None:
            # The encoder output is used directly; text_linear is not applied.
            caption_encoded = self.encode_text(input_ids, attention_mask)
            caption_embed = F.normalize(caption_encoded, p=2, dim=-1)

        # Without both modalities there is no cost matrix to transport over.
        if audio_embed is None or caption_embed is None:
            return None, audio_embed, caption_embed

        # Uniform marginals, one per side, so rectangular batches also work.
        n_audio = audio_embed.size(0)
        n_caption = caption_embed.size(0)
        a = torch.ones(n_audio, device=audio_embed.device) / n_audio
        b = torch.ones(n_caption, device=audio_embed.device) / n_caption

        # Pairwise Euclidean cost, rescaled so its largest entry is 1. The
        # clamp avoids 0/0 = NaN when every distance is exactly zero.
        M = torch.cdist(audio_embed, caption_embed, p=2)
        M = M / M.max().clamp_min(torch.finfo(M.dtype).tiny)

        SK = ot.sinkhorn(a, b, M, reg=self.epsilon,
                         numItermax=10)

        return SK, audio_embed, caption_embed