import os
import torch
import logging
from multiHeadAttn import PlainMultiHeadAttention, ImageBindMultiheadAttention

def create_logger(args):
    """Configure and return the shared ``'my_logger'`` logger.

    The logger emits INFO-and-above records both to
    ``<args.savedir>/log_file.log`` and to the console, using the same
    timestamped format for both destinations.

    Args:
        args: Object exposing a ``savedir`` attribute (directory that will
            hold ``log_file.log``). The directory is created if missing.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger('my_logger')
    logger.setLevel(logging.INFO)
    # Do not hand records to ancestor loggers (avoids duplicate console output
    # when the root logger is configured as well).
    logger.propagate = False

    # ``getLogger`` returns the same object for a given name, so calling this
    # function twice would otherwise stack extra handlers and write every
    # record multiple times. Drop (and close) whatever was attached before.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    # FileHandler opens the file immediately and fails when the directory does
    # not exist. An empty ``savedir`` means "current directory": nothing to
    # create, and ``os.makedirs('')`` would raise.
    if args.savedir:
        os.makedirs(args.savedir, exist_ok=True)
    log_file = os.path.join(args.savedir, 'log_file.log')

    # One handler writes to the log file, the other prints to the console.
    file_handler = logging.FileHandler(log_file)
    console_handler = logging.StreamHandler()

    # Both destinations share the same level and record layout.
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in (file_handler, console_handler):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    # Emit a first record so both destinations can be checked right away.
    logger.info('This message will go to both the console and the log file.')
    return logger


def get_uni_enc_ckpt(dataset_name, model_name):
    """Validate the dataset/model names for a unimodal-encoder checkpoint lookup.

    Args:
        dataset_name: Must be one of ``'AVE'``, ``'ks'`` or ``'CREMA-D'``.
        model_name: Must be one of ``'resnet'`` or ``'vit'``.

    Returns:
        ``None``. The lookup itself is not implemented here.

    Raises:
        AssertionError: If either name is outside its accepted set.

    NOTE(review): ``load_uni_enc_ckpt`` unpacks the result of this function
    into ``(audio_path, video_path)``, which fails while this stub returns
    ``None`` -- the checkpoint paths still need to be filled in.
    """
    known_datasets = ('AVE', 'ks', 'CREMA-D')
    known_models = ('resnet', 'vit')
    assert dataset_name in known_datasets
    assert model_name in known_models
    return None
    
def freeze_encoders(MMmodel):
    """Turn off gradients for the non-classifier parameters of both encoders.

    Visits ``MMmodel.audio_model`` first and ``MMmodel.video_model`` second.
    Every parameter whose name does not contain ``'clf'`` gets
    ``requires_grad = False``; parameters with ``'clf'`` in their name are
    left untouched.

    Args:
        MMmodel: Object exposing ``audio_model`` and ``video_model``, each
            providing ``named_parameters()``.
    """
    for encoder_attr in ('audio_model', 'video_model'):
        # Resolve the attribute lazily so the audio encoder is processed
        # before the video encoder is even looked up.
        encoder = getattr(MMmodel, encoder_attr)
        for param_name, param in encoder.named_parameters():
            if 'clf' in param_name:
                continue
            param.requires_grad = False

def load_uni_enc_ckpt(args, model):
    """Load unimodal encoder checkpoints into ``model``'s two encoders.

    Args:
        args: Object exposing ``dataset_name`` and ``model``; both are passed
            to ``get_uni_enc_ckpt`` to obtain the audio and video checkpoint
            paths.
        model: Object exposing ``audio_model`` and ``video_model``, each
            accepting ``load_state_dict``.
    """
    audio_path, video_path = get_uni_enc_ckpt(args.dataset_name, args.model)

    # Deserialize onto the CPU: without ``map_location`` the tensors are
    # restored to the device they were saved from, which fails on hosts
    # lacking that device and needlessly occupies GPU memory.
    # ``load_state_dict`` copies the values into the existing parameters, so
    # the encoders stay on whatever device they already live on.
    # NOTE(review): ``torch.load`` unpickles its input -- only load
    # checkpoints from trusted sources.
    audio_ckpt = torch.load(audio_path, map_location='cpu')
    video_ckpt = torch.load(video_path, map_location='cpu')

    model.audio_model.load_state_dict(audio_ckpt)
    model.video_model.load_state_dict(video_ckpt)

def change_multihead(model, mode='A'):
    """Replace the ``attn`` submodule of every transformer layer in ``model``.

    Each new attention module is initialised from the one it replaces through
    ``set_parameters`` and then assigned back to the layer's ``attn``
    attribute.

    Args:
        model: For ``mode='A'`` an object whose layers are in ``model.blocks``;
            for ``mode='V'`` an object whose layers are in
            ``model.transformer.resblocks``.
        mode: ``'A'`` swaps in ``ImageBindMultiheadAttention(add_bias_kv=True)``;
            ``'V'`` swaps in ``PlainMultiHeadAttention()``.

    Raises:
        AssertionError: If ``mode`` is neither ``'A'`` nor ``'V'``.
    """
    assert mode in ['A', 'V']

    # Choose the layer container and a factory for the replacement module.
    # Factories are lambdas so the class is only looked up and instantiated
    # once per layer, inside the loop.
    if mode == 'A':
        layers = model.blocks
        build_attn = lambda: ImageBindMultiheadAttention(add_bias_kv=True)
    else:
        layers = model.transformer.resblocks
        build_attn = lambda: PlainMultiHeadAttention()

    for layer in layers:
        replacement = build_attn()
        replacement.set_parameters(layer.attn)
        layer.attn = replacement

