# -*- coding: utf-8 -*-
import logging
import os
import time
import random
import json
from tqdm import tqdm
import sys
# import wandb
import torch
from itertools import chain
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, ConcatDataset
# from tensorboardX import SummaryWriter
from torch.optim.lr_scheduler import StepLR, MultiStepLR

import numpy as np
from configs.opts import parser
# from model.main_model_2 import AV_VQVAE_Encoder as AV_VQVAE_Encoder
# from model.main_model_2 import AV_VQVAE_Decoder as AV_VQVAE_Decoder
# from model.main_model_2 import Semantic_Decoder
# from model.main_model_3 import Baseline2_VQVAE_Encoder, MCL_VQVAE_Decoder, TextLSTM

from model.main_model_2_dcid import AVT_VQVAE_Encoder as AVT_VQVAE_Encoder_dcid
from model.main_model_2 import Semantic_Decoder,AVT_VQVAE_Encoder,AVT_VQVAE_Decoder


# from model.CLUB import CLUBSample_group_video, CLUBSample_group_audio
from model.CLUB import CLUBSample_group
from model.CPC import Cross_CPC_AVT

from utils import AverageMeter, Prepare_logger, get_and_save_args
from utils.container import metricsContainer
from utils.Recorder import Recorder
# from dataset.AVE_dataset import AVEDataset
import torch.nn.functional as F
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from bert_embedding import BertEmbedding
import pickle
import torch.nn.utils.rnn as rnn_utils
# from model.main_model_3_newm3oe import ImageLSTM
# from model.main_model_3_newm3oe import MCL_VQVAE_Encoder_Uni_weight
# from model.CPC import Cross_CPC_MCL_weight
# from model.ewc_moe_3 import EWC

# =================================  seed config ============================
# Seed every random number generator this script touches (Python, NumPy,
# PyTorch CPU and PyTorch CUDA) with the same fixed value.
SEED = 43
random.seed(SEED)
np.random.seed(seed=SEED)
torch.manual_seed(seed=SEED)
torch.cuda.manual_seed(seed=SEED)
# Request deterministic cuDNN kernels and switch off the cuDNN autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Expose only GPU index 3 to this process.
# NOTE(review): this presumably has to run before CUDA is initialised to have
# any effect -- confirm that none of the imports above touch the GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = "3"

# =============================================================================
# Module-level BERT embedder, built once at import time; collate_func_AVT
# calls it on the captions of every batch.
bert_embedding = BertEmbedding()
# The path is relative to the current working directory.
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# load a cnt.pkl that comes from a trusted source.
# NOTE(review): `id2idx` is not referenced in the visible part of this file.
with open('../../cnt.pkl', 'rb') as fp:
    id2idx = pickle.load(fp)

def _pad_or_truncate(feature, target_len):
    """Return *feature* as a float32 tensor with exactly *target_len* rows.

    Inputs shorter than *target_len* along axis 0 are zero-padded at the end;
    longer inputs are cut. The trailing dimensions are taken from *feature*
    itself, so the same helper serves audio and video features.
    """
    tensor = torch.from_numpy(feature).float()
    missing = target_len - tensor.shape[0]
    if missing > 0:
        padding = torch.zeros(missing, *tensor.shape[1:]).float()
        return torch.cat([tensor, padding])
    return tensor[:target_len]


def collate_func_AVT(samples, max_query_len=30, text_dim=768):
    """Collate audio / video / caption samples into padded batch tensors.

    Each sample is indexed positionally: ``sample[0]`` is the caption string,
    ``sample[1]`` the video feature array, ``sample[2]`` the audio feature
    array and ``sample[4]`` the sequence length shared by audio and video.
    (assumes audio is ``(T, 128)`` and video is ``(T, H, W, C)`` as in the
    original padding code -- TODO confirm against the dataset class)

    Args:
        samples: list of dataset items as described above.
        max_query_len: number of token slots per caption; longer captions are
            truncated, shorter ones zero-padded.
        text_dim: width of one token embedding produced by ``bert_embedding``.

    Returns:
        A 6-tuple ``(audio, audio_len, query, query_len, video, video_len)``:
        stacked padded audio features, float lengths, padded caption
        embeddings ``(batch, max_query_len, text_dim)``, integer caption
        lengths, stacked padded video features and float lengths.
    """
    bsz = len(samples)

    # ---- caption preprocess -------------------------------------------------
    # bert_embedding yields one (tokens, token_embeddings) pair per caption.
    # Some words are not in the vocabulary of the BERT model in use, so the
    # number of embeddings per caption is only known after this call.
    result = bert_embedding([sample[0] for sample in samples])
    query = np.zeros((bsz, max_query_len, text_dim), dtype=np.float32)
    query_len = np.zeros(bsz, dtype=np.int64)
    for i, (tokens, embeddings) in enumerate(result):
        token_embs = np.asarray([emb for _token, emb in zip(tokens, embeddings)])
        keep = min(token_embs.shape[0], max_query_len)
        query_len[i] = keep
        # A caption without tokens produces a (0,)-shaped array that cannot be
        # assigned into the (0, text_dim) slice; leave that row as zeros.
        if keep:
            query[i, :keep] = token_embs[:keep]
    query = torch.from_numpy(query).float()
    query_len = torch.from_numpy(query_len).long()

    # ---- audio / video preprocess -------------------------------------------
    # Both modalities are padded to the longest declared length in the batch.
    lengths = [sample[4] for sample in samples]
    max_length = max(lengths)
    audio_tensor = [_pad_or_truncate(sample[2], max_length) for sample in samples]
    video_tensor = [_pad_or_truncate(sample[1], max_length) for sample in samples]
    audio_len = torch.Tensor(lengths)
    video_len = torch.Tensor(lengths)

    return (
        torch.stack(audio_tensor),
        audio_len,
        query,
        query_len,
        torch.stack(video_tensor),
        video_len,
    )



def AVPSLoss(av_simm, soft_label):
    """Audio-visual pair similarity loss for the fully supervised setting.

    The similarities are rectified, normalised so each row sums to (almost)
    one, and compared with *soft_label* using a mean-squared error.
    See Eq.(8, 9) of the paper referenced by the original authors.

    Args:
        av_simm: similarity scores, normalised over the last dimension.
        soft_label: target distribution with the same shape as *av_simm*.

    Returns:
        Scalar tensor holding the mean-squared error.
    """
    positive_sim = F.relu(av_simm)
    row_total = positive_sim.sum(dim=-1, keepdim=True)
    # The small epsilon keeps rows with no positive similarity finite.
    normalised = positive_sim / (row_total + 1e-8)
    return F.mse_loss(normalised, soft_label)


def _make_loader(dataset, batch_size, shuffle, collate_fn=None):
    """Wrap *dataset* in a DataLoader with the settings shared by every split."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=8,
        pin_memory=True,
        collate_fn=collate_fn,
    )


def _build_dataloaders(dataset_name, batch_size):
    """Build the data loaders configured for *dataset_name*.

    Returns a dict keyed by split name ('train', and 'val' / 'test' where the
    dataset provides them). Only the training loader is shuffled.

    Raises:
        NotImplementedError: if no loader configuration exists for the name.
    """
    loaders = {}
    if dataset_name == 'ave':
        from dataset.AVE_dataset import AVEDataset
        data_root = '/root/autodl-tmp/AVE-ECCV18-master/data'
        for split in ('train', 'val', 'test'):
            loaders[split] = _make_loader(
                AVEDataset(data_root, split=split), batch_size, shuffle=(split == 'train'))
    elif dataset_name == 'vggsound':
        from dataset.VGGSOUND_dataset import VGGSoundDataset
        meta_csv_path = '/root/autodl-tmp/feature_extractor/vggsound-avel40k.csv'
        audio_fea_base_path = '/root/autodl-tmp/feature_vggsound_50k/audio/zip'
        video_fea_base_path = '/root/autodl-tmp/feature_vggsound_50k/video/zip'
        avc_label_base_path = '/root/autodl-tmp/feature_vggsound_50k/label/zip'
        for split in ('train', 'val', 'test'):
            loaders[split] = _make_loader(
                VGGSoundDataset(meta_csv_path, audio_fea_base_path, video_fea_base_path,
                                avc_label_base_path, split=split),
                batch_size, shuffle=(split == 'train'))
    elif dataset_name in ('vggsound81k', 'vggsound179k'):
        from dataset.VGGSOUND_dataset179k import VGGSoundDataset
        if dataset_name == 'vggsound81k':
            meta_csv_path = '../feature_extractor/feature_extractor/video_name_vggsound81k_checked_va.csv'
            audio_fea_base_path = '../vggsoundAll/feature/audio/zip'
            video_fea_base_path = '../vggsoundAll/feature/video/zip'
        else:
            meta_csv_path = '/root/autodl-tmp/vggsoundAll/video_name_vggsound179k_checked.csv'
            audio_fea_base_path = '/root/autodl-tmp/vggsoundAll/feature/audio/zip'
            video_fea_base_path = '/root/autodl-tmp/vggsoundAll/feature/video/zip'
        avc_label_base_path = '...'  # self-supervised setting: there are no labels
        loaders['train'] = _make_loader(
            VGGSoundDataset(meta_csv_path, audio_fea_base_path, video_fea_base_path,
                            avc_label_base_path, split='train'),
            batch_size, shuffle=True)
    elif dataset_name == 'VALOR32k':
        from dataset.VALOR32K_dataset import Valor32k_VAT
        # All three splits are merged into one training set.
        combined_dataset = ConcatDataset(
            [Valor32k_VAT(mode='train'), Valor32k_VAT(mode='val'), Valor32k_VAT(mode='test')])
        loaders['train'] = _make_loader(
            combined_dataset, batch_size, shuffle=True, collate_fn=collate_func_AVT)
    else:
        # Also reached for 'vggsound_AT', for which no loader was ever configured.
        raise NotImplementedError
    return loaders


def _count_modalities_per_code(count_a, count_v, count_t, min_share=0.05):
    """Return, per codebook entry, how many modalities make real use of it.

    A modality counts as a user of an entry when it contributes at least
    *min_share* of the entry's total hits. Entries that were never hit get 0.
    The inputs are copied, so the caller's arrays are left untouched.
    """
    counts = np.stack(
        [np.asarray(c, dtype=np.float64) for c in (count_a, count_v, count_t)], axis=1)
    totals = counts.sum(axis=1, keepdims=True)
    was_hit = (counts != 0).any(axis=1, keepdims=True)
    uses = (counts >= totals * min_share) & was_hit
    return uses.sum(axis=1)


def main():
    """Plot a t-SNE map of the VQ codebook coloured by modality usage.

    Steps: parse the run configuration, build the data loaders, restore the
    encoder from a hard-coded checkpoint, load previously saved per-modality
    hit counts for every codebook entry, project the codebook to 2-D with
    t-SNE and save a scatter plot in which each used entry is coloured by how
    many modalities (audio / video / text) use it.

    Several module-level globals are (re)assigned here because other
    functions in this file read them: the three ``codebook_count_*`` tensors,
    ``codebook_size``, ``args``, ``logger``, ``dataset_configs`` and the
    ``best_accuracy*`` statistics.

    Raises:
        NotImplementedError: for an unsupported dataset or model name.
    """
    global codebook_count_t, codebook_count_v, codebook_count_a, codebook_size
    global args, logger, dataset_configs
    global best_accuracy, best_accuracy_epoch

    codebook_size = 400
    codebook_count_v = torch.zeros(codebook_size)
    codebook_count_a = torch.zeros(codebook_size)
    codebook_count_t = torch.zeros(codebook_size)
    best_accuracy, best_accuracy_epoch = 0, 0

    # ---- configuration ------------------------------------------------------
    dataset_configs = get_and_save_args(parser)
    parser.set_defaults(**dataset_configs)
    args = parser.parse_args()
    # Hard-coded: this overrides whatever model name the parser produced.
    args.model_name = 'FCCID'
    print(args.model_name)

    # Create the snapshot directory used for copied code and saved models.
    if not os.path.exists(args.snapshot_pref):
        os.makedirs(args.snapshot_pref)
    if os.path.isfile(args.resume):
        args.snapshot_pref = os.path.dirname(args.resume)

    logger = Prepare_logger(args, eval=args.evaluate)

    # ---- data ---------------------------------------------------------------
    # NOTE(review): the plotting path below never iterates these loaders. They
    # are kept because validate_epoch, which collects the codebook counts,
    # consumes a loader -- drop this call if only the plot is needed.
    dataloaders = _build_dataloaders(args.dataset_name, args.batch_size)

    # ---- model --------------------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    video_dim, audio_dim, text_dim = 512, 128, 768
    video_output_dim, audio_output_dim, text_output_dim = 2048, 256, 256
    embedding_dim = 256  # dimension of one codeword

    if args.model_name == 'DCID':
        text_lstm_dim = 128
        Text_ar_lstm = nn.LSTM(text_dim, text_lstm_dim, num_layers=2,
                               batch_first=True, bidirectional=True)
        Encoder = AVT_VQVAE_Encoder_dcid(
            audio_dim, video_dim, text_lstm_dim * 2,
            audio_output_dim, video_output_dim, text_output_dim,
            codebook_size, embedding_dim)
        Text_ar_lstm.double()
        Encoder.double()
        Text_ar_lstm.to(device)
        Encoder.to(device)

        # map_location keeps a GPU-saved checkpoint loadable on a CPU-only host.
        checkpoints = torch.load('../checkpoints/nips2023_AVT_vgg40k_size400.pt',
                                 map_location=device)
        Encoder.load_state_dict(checkpoints['Encoder_parameters'])
        Text_ar_lstm.load_state_dict(checkpoints['Text_ar_lstm_parameters'])
        codebook = Encoder.Cross_quantizer.embedding
    elif args.model_name == 'FCCID':
        Encoder = AVT_VQVAE_Encoder(
            audio_dim, video_dim, text_dim,
            audio_output_dim, video_output_dim, text_output_dim,
            codebook_size, embedding_dim)
        Encoder.double()
        Encoder.to(device)

        checkpoints = torch.load(
            "../checkpoints/fc/steps/CUnicode2-[400]-model-att41-step2000.pt",
            map_location=device)
        Encoder.load_state_dict(checkpoints['Encoder_parameters'])
        codebook = Encoder.Cross_quantizer_coarse.embedding
    else:
        # Includes the legacy 'origin' model: its encoder class is not imported
        # in this file and it has no codebook plotting path.
        raise NotImplementedError(
            "codebook plotting is not implemented for model '{}'".format(args.model_name))

    start_epoch = checkpoints['epoch']
    logger.info("Resume from number {}-th model.".format(start_epoch))

    # ---- per-modality codebook hit counts ------------------------------------
    # The counts file is expected to have been written beforehand with the
    # keys read below; this function only consumes it.
    checkpoint = torch.load('codebook_counts_split_'+args.model_name+'.pth',
                            map_location='cpu')
    codebook_count_t = checkpoint['codebook_count_t']
    codebook_count_v = checkpoint['codebook_count_v']
    codebook_count_a = checkpoint['codebook_count_a']

    # ---- t-SNE projection of the codebook to 2-D ------------------------------
    # detach() lets the conversion work even if the codebook requires grad.
    tsne = TSNE(n_components=2)
    embedding = tsne.fit_transform(codebook.detach().cpu().numpy())

    # ---- classify every entry by the number of modalities using it -----------
    color_labels = {
        "#00FF00": "avt",       # green: used by all three modalities
        "#0000FF": "av/vt/at",  # blue: used by exactly two modalities
        "#FF0000": "a/v/t",     # red: used by a single modality
    }
    color_by_modalities = {3: "#00FF00", 2: "#0000FF", 1: "#FF0000"}

    # The helper works on copies, so the global count tensors keep their raw
    # values instead of being overwritten with 0/1 flags.
    n_modalities = _count_modalities_per_code(
        codebook_count_a.cpu().numpy(),
        codebook_count_v.cpu().numpy(),
        codebook_count_t.cpu().numpy(),
    )
    valid_code = [i for i in range(codebook_size) if n_modalities[i] > 0]
    all_color = [color_by_modalities[int(n_modalities[i])] for i in valid_code]

    # Share of each colour among the used entries, in percent.
    n_valid = len(valid_code)
    color_counts = {
        color: (all_color.count(color) / n_valid * 100 if n_valid else 0.0)
        for color in color_labels
    }

    # ---- scatter plot ---------------------------------------------------------
    plt.axis('off')
    drawn = [(i, color) for i, color in zip(valid_code, all_color) if i < len(embedding)]
    if drawn:
        drawn_idx = [i for i, _ in drawn]
        plt.scatter(embedding[drawn_idx, 0], embedding[drawn_idx, 1],
                    s=6, marker='o', c=[color for _, color in drawn])

    # Legend: one entry per colour, labelled with its share of the used codes.
    legend_labels = [f"{color_labels[color]}: {count:.1f}%" for color, count in color_counts.items()]
    handles = [plt.Line2D([0], [0], marker='o', color='w', label=label, markersize=10, markerfacecolor=color)
               for color, label in zip(color_counts.keys(), legend_labels)]
    plt.legend(handles=handles, loc='lower left')

    # Save the figure, creating the output directory if it is missing.
    figure_path = f'../src/fig/{args.model_name}_rgb.png'
    os.makedirs(os.path.dirname(figure_path), exist_ok=True)
    plt.savefig(figure_path, dpi=400, bbox_inches='tight')


def _export_log(epoch, total_step, batch_idx, lr, loss_meter):
    msg = 'Epoch {}, Batch {}, lr = {:.5f}, '.format(epoch, batch_idx, lr)
    for k, v in loss_meter.items():
        msg += '{} = {:.4f}, '.format(k, v)
    # msg += '{:.3f} seconds/batch'.format(time_meter)
    print(msg)
    sys.stdout.flush()
    loss_meter.update({"batch": total_step})

def to_eval(all_models):
    """Call ``eval()`` on every model in *all_models*."""
    for model in all_models:
        model.eval()


def to_train(all_models):
    """Call ``train()`` on every model in *all_models*."""
    for model in all_models:
        model.train()

def save_models(Encoder, optimizer, epoch_num, total_step, path):
    """Write a training checkpoint to *path* and log where it went.

    The checkpoint is a dict with the encoder state dict, the optimizer
    state dict, the epoch number and the global step count.
    """
    checkpoint = dict(
        Encoder_parameters=Encoder.state_dict(),
        optimizer=optimizer.state_dict(),
        epoch=epoch_num,
        total_step=total_step,
    )
    torch.save(checkpoint, path)
    logging.info('save model to {}'.format(path))

def _count_codebook_hits(counter, indices):
    """Increment ``counter[k]`` once for every code index ``k`` in ``indices``.

    Each element of ``indices`` must be convertible with ``.item()``.
    """
    for index in indices:
        code = index.item()
        counter[code] = counter[code] + 1


@torch.no_grad()
def validate_epoch(Encoder, Text_ar_lstm, val_dataloader):
    """Tally how often each codebook entry is selected per modality.

    Iterates over ``val_dataloader``, quantises the text, audio and visual
    features of every batch with ``Encoder`` and adds the resulting code
    indices to ``codebook_count_t``, ``codebook_count_a`` and
    ``codebook_count_v``.  Those three counters and ``args`` are not defined
    in this function; they are expected to exist at module level.

    Only ``args.model_name`` values ``'DCID'`` and ``'FCCID'`` are handled;
    any other value leaves the counters untouched.

    Args:
        Encoder: model exposing the ``*_VQ_Encoder_indices`` (DCID) or
            ``*_VQ_Encoder_C_split_indices`` (FCCID) methods.
        Text_ar_lstm: recurrent text model, used only in the DCID branch.
        val_dataloader: iterable of batches; items 0, 2, 3 and 4 of each
            batch are read as audio feature, text feature, text lengths and
            visual feature respectively.

    Returns:
        ``torch.zeros(1)`` -- a placeholder; the useful output is the
        side effect on the module-level counters.
    """
    for n_iter, batch_data in enumerate(val_dataloader):
        if n_iter % 10 == 0:
            print(n_iter)

        # Feed input to model
        audio_feature = batch_data[0]
        text_feature, text_len, visual_feature = batch_data[2], batch_data[3], batch_data[4]

        # ``Tensor.cuda()`` returns a copy instead of moving in place, so the
        # result has to be assigned for the transfer to take effect.
        text_feature = text_feature.cuda().to(torch.float64)
        visual_feature = visual_feature.cuda().to(torch.float64)
        audio_feature = audio_feature.cuda().to(torch.float64)
        # ``text_len`` deliberately stays where the loader put it:
        # ``pack_padded_sequence`` requires its lengths on the CPU.

        if args.model_name == 'DCID':
            packed_input = rnn_utils.pack_padded_sequence(
                text_feature, text_len, batch_first=True, enforce_sorted=False)
            packed_output, _ = Text_ar_lstm(packed_input)
            lstm_out, _ = rnn_utils.pad_packed_sequence(packed_output, batch_first=True)
            B, L, embed_dim = lstm_out.shape

            # Mean-pool every sequence over its valid (unpadded) time steps.
            # NOTE(review): the buffer is float32, so the pooled values are
            # rounded to float32 before the float64 cast below; kept as in
            # the original -- confirm this matches the training pipeline.
            t_feature = torch.zeros(B, 1, embed_dim).cuda()
            for i in range(B):
                t_feature[i, 0, :] = torch.mean(lstm_out[i, :text_len[i]], dim=0, keepdim=False)
            t_feature = t_feature.to(torch.float64)

            _count_codebook_hits(codebook_count_t, Encoder.Text_VQ_Encoder_indices(t_feature))
            _count_codebook_hits(codebook_count_a, Encoder.Audio_VQ_Encoder_indices(audio_feature))
            _count_codebook_hits(codebook_count_v, Encoder.Video_VQ_Encoder_indices(visual_feature))

        elif args.model_name == 'FCCID':
            _count_codebook_hits(codebook_count_t, Encoder.Text_VQ_Encoder_C_split_indices(text_feature))
            _count_codebook_hits(codebook_count_a, Encoder.Audio_VQ_Encoder_C_split_indices(audio_feature))
            _count_codebook_hits(codebook_count_v, Encoder.Video_VQ_Encoder_C_split_indices(visual_feature))

    return torch.zeros(1)


# def compute_accuracy_supervised(is_event_scores, event_scores, labels):
#     # labels = labels[:, :, :-1]  # 28 denote background
#     _, targets = labels.max(-1)
#     # pos pred
#     is_event_scores = is_event_scores.sigmoid()
#     scores_pos_ind = is_event_scores > 0.5
#     scores_mask = scores_pos_ind == 0
#     _, event_class = event_scores.max(-1)  # foreground classification
#     pred = scores_pos_ind.long()
#     pred *= event_class[:, None]
#     # add mask
#     pred[scores_mask] = 28  # 28 denotes bg
#     correct = pred.eq(targets)
#     correct_num = correct.sum().double()
#     acc = correct_num * (100. / correct.numel())
#
#     return acc

def mi_first_forward(audio_feature, visual_feature, Encoder, Video_mi_net, Audio_mi_net, optimizer_video_mi_net, optimizer_audio_mi_net):
    """Run one likelihood-maximisation step for both MI networks.

    The encoder is run once and the outputs that are used are detached, so
    gradients only reach ``Video_mi_net`` and ``Audio_mi_net``.  Each network
    is then updated by minimising the negative of its ``loglikeli`` value,
    video first and audio second.

    Returns:
        ``(optimizer_video_mi_net, lld_video_loss, optimizer_audio_mi_net,
        lld_audio_loss)``.
    """
    for mi_optimizer in (optimizer_video_mi_net, optimizer_audio_mi_net):
        mi_optimizer.zero_grad()

    encoder_outputs = Encoder(audio_feature, visual_feature)
    _, vid_club, aud_enc, vid_quant, aud_quant, _, _, _ = encoder_outputs
    # Cut the graph so the encoder receives no gradient from this step.
    vid_club, aud_enc, vid_quant, aud_quant = (
        tensor.detach() for tensor in (vid_club, aud_enc, vid_quant, aud_quant)
    )

    def _fit(mi_net, mi_optimizer, quantized, target):
        # Negative log-likelihood step for a single MI network.
        nll = -mi_net.loglikeli(quantized, target)
        nll.backward()
        mi_optimizer.step()
        return nll

    lld_video_loss = _fit(Video_mi_net, optimizer_video_mi_net, vid_quant, vid_club)
    lld_audio_loss = _fit(Audio_mi_net, optimizer_audio_mi_net, aud_quant, aud_enc)

    return optimizer_video_mi_net, lld_video_loss, optimizer_audio_mi_net, lld_audio_loss

def mi_second_forward(audio_feature, visual_feature, Encoder, Video_mi_net, Audio_mi_net, Decoder):
    """Run the encoder, both MI estimators and the decoder for one batch.

    Returns:
        A 9-tuple ``(audio_embedding_loss, video_embedding_loss,
        mi_audio_loss, mi_video_loss, audio_recon_loss, video_recon_loss,
        audio_class, video_class, cmcm_loss)``.
    """
    (vid_enc, vid_club, aud_enc, vid_quant, aud_quant,
     aud_embed_loss, vid_embed_loss, cross_modal_loss) = Encoder(audio_feature, visual_feature)

    # MI estimates between the quantised codes and the per-modality features.
    vid_mi = Video_mi_net.mi_est(vid_quant, vid_club)
    aud_mi = Audio_mi_net.mi_est(aud_quant, aud_enc)

    decoded = Decoder(visual_feature, audio_feature, vid_enc, aud_enc, vid_quant, aud_quant)
    vid_recon, aud_recon, vid_cls, aud_cls = decoded

    return (
        aud_embed_loss, vid_embed_loss,
        aud_mi, vid_mi,
        aud_recon, vid_recon,
        aud_cls, vid_cls,
        cross_modal_loss,
    )


def compute_accuracy_supervised(event_scores, labels):
    """Return the percentage of samples whose predicted class is correct.

    The final entry of the last axis of ``labels`` is excluded (presumably
    the background class -- confirm against the dataset).  For every sample
    the target is the largest per-step argmax index over the remaining
    entries, and the prediction is the argmax of ``event_scores``.

    Returns:
        A float64 scalar tensor in the range [0, 100].
    """
    per_step_class = labels[:, :, :-1].max(-1).indices
    target_class = per_step_class.max(-1).values
    predicted_class = event_scores.max(-1).indices
    hits = predicted_class.eq(target_class)
    return hits.sum().double() * (100. / hits.numel())




# def save_checkpoint(state_dict, top1, task, epoch):
#     model_name = f'{args.snapshot_pref}/model_epoch_{epoch}_top1_{top1:.3f}_task_{task}_best_model.pth.tar'
#     torch.save(state_dict, model_name)


if __name__ == '__main__':
    # Run the entry point only when this file is executed as a script.
    main()
