import platform
import sys
import time
import numpy as np
import torch
import random
from tqdm import tqdm
from pathlib import Path
from loguru import logger
from pprint import PrettyPrinter
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from tools.utils import setup_seed, AverageMeter, a2t_ot_bilinear, t2a_ot_bilinear, a2t_and_t2a
from models.ASE_model_Float import ASE_mxh_ot
from data_handling.DataLoader import get_dataloader2
from models.BERT_Config import MODELS
import ot

# Shared BERT tokenizer used by ``train`` and by the validation routines in this
# module. It is built once, at import time, from a local copy of the
# ``bert-base-uncased`` files. Entry [1] of the MODELS registry is presumably the
# tokenizer class (TODO confirm against models/BERT_Config.py).
_bert_tokenizer_cls = MODELS["bert-base-uncased"][1]
tokenizer = _bert_tokenizer_cls.from_pretrained("/home/xxxx/bert-base-uncased")

# Entropic regularisation strength handed to the unbalanced Sinkhorn solver.
_OT_ENTROPIC_REG = 0.03


def _supervised_loss(sk, semi_ratio):
    """Return the (semi-)supervised matching loss ``-sum(log(diag(sk)))``.

    Only the diagonal of the square score matrix ``sk`` (the matched
    audio/caption pairs) contributes. Taking the diagonal directly avoids
    ``0 * log(0) = NaN`` when an off-diagonal score is exactly zero.

    Args:
        sk: square score tensor returned by the model. Presumably (batch, batch)
            matching probabilities -- TODO confirm against the model.
        semi_ratio: probability of dropping a pair's label. When ``> 0`` every
            row is kept independently with probability ``1 - semi_ratio``.
            Otherwise all rows contribute.

    Returns:
        Scalar loss tensor on ``sk``'s device.
    """
    diag = torch.diagonal(sk)
    if semi_ratio > 0.0:
        # A uniform draw makes ``semi_ratio`` the real drop probability.
        # The previous Gaussian draw kept a fraction 1 - Phi(semi_ratio) instead.
        keep = torch.rand(diag.size(0), device=sk.device) > semi_ratio
        diag = diag[keep]
    return -torch.sum(torch.log(diag))


def _ot_transfer_loss(audio_embed, caption_embed, reg_m, reg=_OT_ENTROPIC_REG):
    """Return the unbalanced optimal-transport regulariser between embeddings.

    The cost matrix holds pairwise Euclidean distances between the *columns*
    of ``audio_embed`` and ``caption_embed`` (hence the transposes), scaled so
    its maximum is 1. If the embeddings are (batch, dim) -- TODO confirm -- this
    transports mass between feature dimensions.

    The transport plan is solved on a detached CPU copy with uniform marginals.
    Gradients therefore flow only through the cost matrix.

    Args:
        audio_embed, caption_embed: 2-D tensors with the same number of rows.
        reg_m: marginal-relaxation weight for the unbalanced solver.
        reg: entropic regularisation strength.

    Returns:
        Scalar tensor ``sum(pi * cost)`` on the embeddings' device.
    """
    cost = torch.cdist(audio_embed.T, caption_embed.T, p=2)
    # Clamp so that a degenerate all-zero cost matrix does not divide by zero.
    cost = cost / cost.max().clamp_min(torch.finfo(cost.dtype).tiny)
    cost_cpu = cost.detach().cpu()
    # Uniform marginals. Norm- or variance-weighted marginals are alternatives.
    marginal_a = torch.ones(cost.size(0)) / cost.size(0)
    marginal_b = torch.ones(cost.size(1)) / cost.size(1)
    pi = ot.unbalanced.sinkhorn_knopp_unbalanced(
        marginal_a, marginal_b, cost_cpu, reg, reg_m=reg_m)
    return torch.sum(pi.to(cost.device) * cost)


def _save_checkpoint(path, model, optimizer, scheduler, epoch):
    """Save model, optimizer and LR-scheduler state plus ``epoch`` to ``path``."""
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'epoch': epoch,
    }, path)


def train(config):
    """Train ``ASE_mxh_ot`` and evaluate its best checkpoints on the test split.

    Outputs go under ``exp-outputs/<folder_name>/``:
      * ``logging/output.txt`` -- loguru log, mirrored to stdout
      * ``logging/tensorboard`` -- TensorBoard events
      * ``models/{a2t,t2a}_best_model.pth`` -- best checkpoint per direction,
        judged by the validation R@1+R@5+R@10 sum

    Each batch's loss is ``_supervised_loss`` plus
    ``config.training.float_lambda`` times ``_ot_transfer_loss``. Training uses
    Adam with a StepLR schedule (x0.1 every 20 epochs).

    Args:
        config: experiment configuration. Attributes read include ``exp_name``,
            ``dataset``, ``path.resume_model`` and the ``training.*`` fields
            referenced below.
    """
    setup_seed(config.training.seed)

    # ---- output folders and logging -------------------------------------------------
    exp_name = config.exp_name
    folder_name = '{}_data_{}_noise{}_loss{}_eps{}_m{}_lr_{}_semi{}_'.format(
        exp_name, config.dataset,
        config.training.noise_p,
        config.training.loss,
        config.training.epsilon,
        config.training.m,
        config.training.lr,
        config.training.semi_ratio)

    log_output_dir = Path('exp-outputs', folder_name, 'logging')
    model_output_dir = Path('exp-outputs', folder_name, 'models')
    log_output_dir.mkdir(parents=True, exist_ok=True)
    model_output_dir.mkdir(parents=True, exist_ok=True)
    a2t_ckpt = model_output_dir / 'a2t_best_model.pth'
    t2a_ckpt = model_output_dir / 't2a_best_model.pth'

    def _is_main_record(record):
        # Only records bound with indent=1 are emitted. ``.get`` avoids a
        # KeyError for records logged without that binding.
        return record['extra'].get('indent') == 1

    logger.remove()
    log_format = '{time: YYYY-MM-DD at HH:mm:ss} | {message}'
    for sink in (sys.stdout, log_output_dir.joinpath('output.txt')):
        logger.add(sink, format=log_format, level='INFO', filter=_is_main_record)
    main_logger = logger.bind(indent=1)

    writer = SummaryWriter(log_dir=str(log_output_dir / 'tensorboard'))

    printer = PrettyPrinter()
    main_logger.info('Training setting:\n'
                     f'{printer.pformat(config)}')

    # ---- model, optimiser, data -----------------------------------------------------
    if torch.cuda.is_available():
        device = 'cuda'
        device_name = torch.cuda.get_device_name(torch.cuda.current_device())
    else:
        device, device_name = 'cpu', platform.processor()
    main_logger.info(f'Process on {device_name}')

    model = ASE_mxh_ot(config).to(device)

    if config.training.separate_backbone_lr:
        backbone_ids = {id(param) for param in model.backbone_params()}
        other_params = [param for param in model.parameters() if id(param) not in backbone_ids]
        model_params = [
            {'params': other_params},
            {'params': model.backbone_params(), 'lr': config.training.backbone_lr},
        ]
    else:
        model_params = model.parameters()

    optimizer = torch.optim.Adam(params=model_params, lr=config.training.lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

    train_loader = get_dataloader2('train', config, config.dataset)
    val_loader = get_dataloader2('val', config, config.dataset)
    test_loader = get_dataloader2('test', config, config.dataset)

    main_logger.info(f'Size of training set: {len(train_loader.dataset)}, size of batches: {len(train_loader)}')
    main_logger.info(f'Size of validation set: {len(val_loader.dataset)}, size of batches: {len(val_loader)}')
    main_logger.info(f'Size of test set: {len(test_loader.dataset)}, size of batches: {len(test_loader)}')

    # ---- optional resume ------------------------------------------------------------
    start_epoch = 1
    if config.training.resume:
        checkpoint = torch.load(config.path.resume_model, map_location=device)
        model.load_state_dict(checkpoint['model'])
        # Checkpoints written by earlier versions of this function stored the
        # *model* weights under 'optimizer'. Restore only a genuine optimizer
        # state dict.
        optim_state = checkpoint.get('optimizer')
        if isinstance(optim_state, dict) and 'param_groups' in optim_state:
            optimizer.load_state_dict(optim_state)
        sched_state = checkpoint.get('scheduler')
        if sched_state is not None:
            scheduler.load_state_dict(sched_state)
        # The stored epoch finished training before it was saved; continue
        # with the next one.
        start_epoch = checkpoint['epoch'] + 1
        main_logger.info(f"resume from {config.path.resume_model}")

    # ---- training loop --------------------------------------------------------------
    recall_sum_a2t = []
    recall_sum_t2a = []

    for epoch in range(start_epoch, config.training.epochs + 1):
        main_logger.info(f'Training for epoch [{epoch}]')

        epoch_loss = AverageMeter()
        start_time = time.time()
        model.train()

        for audios, captions, _, _ in tqdm(train_loader, total=len(train_loader)):
            optimizer.zero_grad()
            audios = audios.to(device)

            tokenized = tokenizer(captions, add_special_tokens=True, padding=True, return_tensors='pt')
            input_ids = tokenized['input_ids'].to(device)
            attention_mask = tokenized['attention_mask'].to(device)
            sk, audio_embed, caption_embed = model(audios, input_ids, attention_mask)

            s_loss = _supervised_loss(sk, config.training.semi_ratio)
            transfer_loss = _ot_transfer_loss(audio_embed, caption_embed,
                                              config.training.float_reg_m)
            loss = s_loss + config.training.float_lambda * transfer_loss
            loss.backward()
            optimizer.step()

            epoch_loss.update(loss.item())

        writer.add_scalar('train/loss', epoch_loss.avg, epoch)
        elapsed_time = time.time() - start_time
        main_logger.info(f'Training statistics:\tloss for epoch [{epoch}]: {epoch_loss.avg:.3f},'
                         f'\ttime: {elapsed_time:.1f}, lr: {scheduler.get_last_lr()[0]:.6f}.')
        # Step before checkpointing so the saved optimizer/scheduler state is
        # already the one for the next epoch, which is where a resume continues.
        scheduler.step()

        main_logger.info("Validating...")
        r_sum_a2t, r_sum_t2a = validate(val_loader, model, device, writer, epoch)
        recall_sum_a2t.append(r_sum_a2t)
        recall_sum_t2a.append(r_sum_t2a)

        if r_sum_a2t >= max(recall_sum_a2t):
            main_logger.info('Model saved.')
            _save_checkpoint(a2t_ckpt, model, optimizer, scheduler, epoch)
        if r_sum_t2a >= max(recall_sum_t2a):
            main_logger.info('Model saved.')
            _save_checkpoint(t2a_ckpt, model, optimizer, scheduler, epoch)

    writer.close()

    # ---- final evaluation on the test split ----------------------------------------
    main_logger.info('-' * 90)
    main_logger.info('Training done. Start evaluating.')

    best_checkpoint_t2a = torch.load(t2a_ckpt, map_location=device)
    model.load_state_dict(best_checkpoint_t2a['model'])
    best_epoch_t2a = best_checkpoint_t2a['epoch']
    main_logger.info(f'Best checkpoint (Caption-to-audio) occurred in {best_epoch_t2a} th epoch.')
    validate_t2a(test_loader, model, device)

    best_checkpoint_a2t = torch.load(a2t_ckpt, map_location=device)
    model.load_state_dict(best_checkpoint_a2t['model'])
    best_epoch_a2t = best_checkpoint_a2t['epoch']
    main_logger.info(f'Best checkpoint (Audio-to-caption) occurred in {best_epoch_a2t} th epoch.')
    validate_a2t(test_loader, model, device)
    main_logger.info('Evaluation done.')

def validate(data_loader, model, device, writer, epoch):
    """Run retrieval evaluation on ``data_loader`` and log the results.

    Every sample is encoded with the model in eval mode under ``no_grad``. The
    audio/caption embeddings are written into dense float32 arrays at the row
    positions given by each batch's fourth element. Three scorings follow:
    caption-to-audio (``t2a_ot_bilinear``), audio-to-caption
    (``a2t_ot_bilinear``) and the combined ``a2t_and_t2a``.

    R@1 for each direction goes to TensorBoard under ``valid/r1_t2a`` and
    ``valid/r1_a2t``. All three result sets go to the ``indent=1`` logger.

    Args:
        data_loader: yields ``(audios, captions, _, indexs)`` batches.
        model: returns ``(_, audio_embeds, caption_embeds)``.
        device: device the inputs are moved to.
        writer: TensorBoard ``SummaryWriter``.
        epoch: global step used for the TensorBoard scalars.

    Returns:
        ``(r_sum_a2t, r_sum_t2a)``: R@1 + R@5 + R@10 for each direction.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    val_logger = logger.bind(indent=1)
    model.eval()
    with torch.no_grad():
        # Dense arrays holding every embedding in the dataset, filled by index.
        audio_embs, cap_embs = None, None
        for audios, captions, _, indexs in tqdm(data_loader, total=len(data_loader)):
            audios = audios.to(device)

            tokenized = tokenizer(captions, add_special_tokens=True, padding=True, return_tensors='pt')
            input_ids = tokenized['input_ids'].to(device)
            attention_mask = tokenized['attention_mask'].to(device)

            _, audio_embeds, caption_embeds = model(audios, input_ids, attention_mask)

            if audio_embs is None:
                audio_embs = np.zeros((len(data_loader.dataset), audio_embeds.size(1)), dtype=np.float32)
                cap_embs = np.zeros((len(data_loader.dataset), caption_embeds.size(1)), dtype=np.float32)

            audio_embs[indexs] = audio_embeds.cpu().numpy()
            cap_embs[indexs] = caption_embeds.cpu().numpy()

        if audio_embs is None:
            # Fail clearly instead of passing None into the metric helpers.
            raise ValueError('validate() received a data loader that produced no batches')

        # caption -> audio
        r1, r5, r10, _, medr, meanr = t2a_ot_bilinear(audio_embs, cap_embs)
        writer.add_scalar("valid/r1_t2a", r1, epoch)

        # audio -> caption (tag previously misspelled as "valid/r1_at2")
        r1_a, r5_a, r10_a, _, medr_a, meanr_a = a2t_ot_bilinear(audio_embs, cap_embs)
        writer.add_scalar("valid/r1_a2t", r1_a, epoch)

        # combined audio-to-text and text-to-audio retrieval
        r1_at, r5_at, r10_at, _, medr_at, meanr_at = a2t_and_t2a(audio_embs, cap_embs)

        val_logger.info('Audio to caption and Caption to Audio: r1: {:.4f}, r5: {:.4f}, '
                        'r10: {:.4f}, medr: {:.4f}, meanr: {:.4f}'.format(
                            r1_at, r5_at, r10_at, medr_at, meanr_at))

        val_logger.info('Audio to caption: r1: {:.4f}, r5: {:.4f}, '
                        'r10: {:.4f}, medr: {:.4f}, meanr: {:.4f}'.format(
                            r1_a, r5_a, r10_a, medr_a, meanr_a))

        val_logger.info('Caption to audio: r1: {:.4f}, r5: {:.4f}, '
                        'r10: {:.4f}, medr: {:.4f}, meanr: {:.4f}'.format(
                            r1, r5, r10, medr, meanr))

    return r1_a + r5_a + r10_a, r1 + r5 + r10
    

def validate_a2t(data_loader, model, device):
    """Log audio-to-caption retrieval metrics over the whole of ``data_loader``.

    The model is put in eval mode and run under ``torch.no_grad()``. Each batch's
    audio and caption embeddings are stored in dense float32 arrays at the row
    positions given by the batch's fourth element. The filled arrays are then
    scored with ``a2t_ot_bilinear``.

    Results are only logged, through the ``indent=1`` logger. Nothing is
    returned.
    """
    val_logger = logger.bind(indent=1)
    model.eval()
    a2t_metrics = dict.fromkeys(("r1", "r5", "r10", "mean", "median"), 0)

    with torch.no_grad():
        audio_store = None
        caption_store = None

        for _, batch in tqdm(enumerate(data_loader), total=len(data_loader)):
            audios, captions, _, rows = batch
            tokens = tokenizer(captions, add_special_tokens=True, padding=True, return_tensors='pt')
            _, audio_out, caption_out = model(
                audios.to(device),
                tokens['input_ids'].to(device),
                tokens['attention_mask'].to(device),
            )

            # Lazily size the stores once the embedding widths are known.
            if audio_store is None:
                n_rows = len(data_loader.dataset)
                audio_store = np.zeros((n_rows, audio_out.size(1)), dtype=np.float32)
                caption_store = np.zeros((n_rows, caption_out.size(1)), dtype=np.float32)

            audio_store[rows] = audio_out.cpu().numpy()
            caption_store[rows] = caption_out.cpu().numpy()

        # audio -> text retrieval
        r1_a, r5_a, r10_a, _, medr_a, meanr_a = a2t_ot_bilinear(audio_store, caption_store)
        for key, value in zip(("r1", "r5", "r10", "median", "mean"),
                              (r1_a, r5_a, r10_a, medr_a, meanr_a)):
            a2t_metrics[key] += value

        val_logger.info('Audio to caption: r1: {:.4f}, r5: {:.4f}, '
                        'r10: {:.4f}, medr: {:.4f}, meanr: {:.4f}'.format(
                            a2t_metrics['r1'], a2t_metrics['r5'], a2t_metrics['r10'],
                            a2t_metrics['median'], a2t_metrics['mean']))

        
def validate_t2a(data_loader, model, device):
    """Log caption-to-audio and combined retrieval metrics over ``data_loader``.

    The model is put in eval mode and moved to ``device``. Every batch is
    encoded under ``torch.no_grad()`` into dense float32 embedding arrays,
    indexed by each batch's fourth element. The arrays are then scored with
    ``t2a_ot_bilinear`` and with ``a2t_and_t2a``.

    Both result sets are logged through the ``indent=1`` logger. Nothing is
    returned.
    """
    val_logger = logger.bind(indent=1)
    model.eval()
    metric_keys = ("r1", "r5", "r10", "mean", "median")
    t2a_metrics = dict.fromkeys(metric_keys, 0)
    a2t_t2a_metrics = dict.fromkeys(metric_keys, 0)
    # Order in which the retrieval helpers return the metrics we keep.
    returned_order = ("r1", "r5", "r10", "median", "mean")

    with torch.no_grad():
        audio_store = None
        caption_store = None

        model = model.to(device)
        for _, batch in tqdm(enumerate(data_loader), total=len(data_loader)):
            audios, captions, _, rows = batch
            tokens = tokenizer(captions, add_special_tokens=True, padding=True, return_tensors='pt')
            _, audio_out, caption_out = model(
                audios.to(device),
                tokens['input_ids'].to(device),
                tokens['attention_mask'].to(device),
            )

            # Lazily size the stores once the embedding widths are known.
            if audio_store is None:
                n_rows = len(data_loader.dataset)
                audio_store = np.zeros((n_rows, audio_out.size(1)), dtype=np.float32)
                caption_store = np.zeros((n_rows, caption_out.size(1)), dtype=np.float32)

            audio_store[rows] = audio_out.cpu().numpy()
            caption_store[rows] = caption_out.cpu().numpy()

        # text -> audio retrieval
        r1, r5, r10, _, medr, meanr = t2a_ot_bilinear(audio_store, caption_store)
        for key, value in zip(returned_order, (r1, r5, r10, medr, meanr)):
            t2a_metrics[key] += value

        val_logger.info('Caption to audio: r1: {:.4f}, r5: {:.4f}, '
                        'r10: {:.4f}, medr: {:.4f}, meanr: {:.4f}'.format(
                            t2a_metrics['r1'], t2a_metrics['r5'], t2a_metrics['r10'],
                            t2a_metrics['median'], t2a_metrics['mean']))

        # combined audio-to-text and text-to-audio retrieval
        r1_at, r5_at, r10_at, _, medr_at, meanr_at = a2t_and_t2a(audio_store, caption_store)
        for key, value in zip(returned_order, (r1_at, r5_at, r10_at, medr_at, meanr_at)):
            a2t_t2a_metrics[key] += value

        val_logger.info('Audio to caption and Caption to Audio: r1: {:.4f}, r5: {:.4f}, '
                        'r10: {:.4f}, medr: {:.4f}, meanr: {:.4f}'.format(
                            a2t_t2a_metrics['r1'], a2t_t2a_metrics['r5'], a2t_t2a_metrics['r10'],
                            a2t_t2a_metrics['median'], a2t_t2a_metrics['mean']))

