import json
import logging
import math
import os
import time

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parallel.distributed import DistributedDataParallel

try:
    import torch.distributed.nn
    from torch import distributed as dist

    has_distributed = True
except ImportError:
    has_distributed = False

try:
    import wandb
except ImportError:
    wandb = None

# from torchviz import make_dot

from floats.model import get_input_dtype
from floats.sinkhorn import sinkhorn_knopp
from float_train.distributed import is_master
from float_train.precison import get_autocast

class AverageMeter:
    """Keep the latest observed value together with a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, running total, observation count and mean."""
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as if it had been observed ``n`` times, then refresh ``avg``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def backward(total_loss, scaler):
    """Backpropagate ``total_loss``; when an AMP grad scaler is supplied, scale the loss first."""
    loss_to_backprop = total_loss if scaler is None else scaler.scale(total_loss)
    loss_to_backprop.backward()

def _first_match_ranks(scores, gt_mask, higher_is_better):
    """Return the 1-based rank of the first ground-truth item for every query row.

    The whole row is ranked (not just a top-k prefix), so a query whose first
    match lies beyond position 10 gets its true rank instead of a fake "1".
    Every row of ``gt_mask`` must contain at least one True.
    """
    order = torch.argsort(scores, dim=1, descending=higher_is_better)
    hits = torch.gather(gt_mask.to(scores.device), 1, order)
    # argmax returns the first maximal index, i.e. the earliest position of a hit.
    return hits.float().argmax(dim=1) + 1


def get_audiotext_metrics(audio_features, text_features):
    """Compute audio<->text retrieval metrics from embedding matrices.

    Two score matrices are built: pairwise L2 distances (``torch.cdist``,
    smaller = better) and ``sinkhorn_knopp`` plans computed on those distances
    (larger = better). For each of the four score/direction combinations the
    result contains R@1, R@5, R@10 (percent) and the 1-based mean/median rank
    of the first correct item.

    Two layouts are supported:
      * ``len(audio_features) == len(text_features)``: rows form consecutive
        groups of 5, and every item of a group matches every item of the same
        group on the other side (presumably 5 captions per clip with the audio
        row repeated 5 times -- TODO confirm with the dataset).
      * ``len(text_features) == 5 * len(audio_features)``: audio ``i`` matches
        texts ``5*i .. 5*i+4``.

    Args:
        audio_features: 2-D tensor (num_audio, dim).
        text_features: 2-D tensor (num_text, dim).

    Returns:
        dict mapping ``"{name}_R@{k}"``, ``"{name}_mean_rank"`` and
        ``"{name}_median_rank"`` to floats, where ``name`` is one of
        ``logits_audio_to_text``, ``logits_text_to_audio``,
        ``sinkhorn_audio_to_text``, ``sinkhorn_text_to_audio``.

    Raises:
        AssertionError: if the two lengths fit neither layout.
    """
    num_captions = 5  # captions per audio clip
    num_audio, num_text = len(audio_features), len(text_features)

    logits_per_audio = torch.cdist(audio_features, text_features, p=2)
    logits_per_text = logits_per_audio.T

    # Uniform marginals for the optimal-transport plans.
    a = torch.ones(num_audio, device=audio_features.device) / num_audio
    b = torch.ones(num_text, device=audio_features.device) / num_text
    sinkhorn_per_audio = sinkhorn_knopp(logits_per_audio, a, b, reg=0.03, numItermax=10)
    sinkhorn_per_text = sinkhorn_knopp(logits_per_text, b, a, reg=0.03, numItermax=10)

    device = logits_per_audio.device
    if num_audio == num_text:
        print("***************len(audio_features) == len(text_features)***************")
        assert num_audio % num_captions == 0
        group = torch.arange(num_audio, device=device) // num_captions
        # Symmetric: row i matches column j iff both belong to the same group of 5.
        gt_audio_to_text = group.unsqueeze(1) == group.unsqueeze(0)
        gt_text_to_audio = gt_audio_to_text
    else:
        assert len(text_features) == len(audio_features) * 5, "len(text_features) != len(audio_features) * 5"
        print("***************len(audio_features) != len(text_features)***************")
        caption_owner = torch.arange(num_text, device=device) // num_captions  # audio index of each caption
        gt_audio_to_text = torch.arange(num_audio, device=device).unsqueeze(1) == caption_owner.unsqueeze(0)
        gt_text_to_audio = gt_audio_to_text.T

    score_tables = {
        "logits_audio_to_text": (logits_per_audio, gt_audio_to_text, False),
        "logits_text_to_audio": (logits_per_text, gt_text_to_audio, False),
        "sinkhorn_audio_to_text": (sinkhorn_per_audio, gt_audio_to_text, True),
        "sinkhorn_text_to_audio": (sinkhorn_per_text, gt_text_to_audio, True),
    }

    metrics = {}
    for name, (scores, gt_mask, higher_is_better) in score_tables.items():
        ranks = _first_match_ranks(scores, gt_mask, higher_is_better).float()
        for k in [1, 5, 10]:
            metrics[f"{name}_R@{k}"] = (ranks <= k).float().mean().item() * 100
        metrics[f"{name}_mean_rank"] = ranks.mean().item()
        metrics[f"{name}_median_rank"] = torch.median(ranks).item()

    return metrics


def _gather_eval_split(model, dataloader, device, input_dtype, autocast, split_name, embed_dim=768):
    """Encode one evaluation split on this rank and merge the result across all ranks.

    Each rank writes the embeddings for the dataset ``indices`` it receives into
    full-size buffers. The buffers and a per-row write counter are then
    SUM-all-reduced. Rows written by several ranks (e.g. sampler padding) are
    averaged by their write count. This is a collective: every rank must call it.

    Args:
        model: callable as ``model(audios, input_ids, attention_mask)`` returning a
            dict with ``"audio_features"`` and ``"text_features"``.
        dataloader: yields ``(audios, tokenized, audio_ids, indices)`` and exposes
            ``num_samples``.
        device, input_dtype, autocast: as prepared by ``evaluate``.
        split_name: ``"val"`` or ``"test"``; only used in log/error messages.
        embed_dim: embedding width produced by the model.

    Returns:
        ``(audio_features, text_features, num_samples)``. The two float32 tensors
        have shape ``(num_samples, embed_dim)``, live on ``device``, and are
        identical on every rank.

    Raises:
        RuntimeError: if some row was written by no rank.
    """
    num_samples = dataloader.num_samples

    # Must be zero-initialised: each rank only fills the rows it was assigned,
    # and the SUM all-reduce below relies on every other row contributing 0.
    audio_feats = torch.zeros((num_samples, embed_dim), dtype=torch.float32, device=device)
    text_feats = torch.zeros((num_samples, embed_dim), dtype=torch.float32, device=device)
    write_mask = torch.zeros(num_samples, dtype=torch.int32, device=device)

    with torch.inference_mode():
        for batch in dataloader:
            audios, tokenized, _audio_ids, indices = batch
            audios = audios.to(device=device, dtype=input_dtype, non_blocking=True)
            input_ids = tokenized['input_ids'].to(device=device, non_blocking=True)
            attention_mask = tokenized['attention_mask'].to(device=device, non_blocking=True)

            with autocast():
                model_out = model(audios, input_ids, attention_mask)

            # Cast explicitly: under mixed precision the model may emit fp16/bf16,
            # and indexed assignment requires matching dtypes.
            audio_feats[indices] = model_out["audio_features"].to(audio_feats.dtype)
            text_feats[indices] = model_out["text_features"].to(text_feats.dtype)
            write_mask[indices] += 1

    # Combine coverage counters first so every rank sees the same mask and
    # raises (or not) consistently before the feature collectives.
    dist.all_reduce(write_mask, op=dist.ReduceOp.SUM)
    if torch.any(write_mask == 0):
        missing_ids = torch.nonzero(write_mask == 0).squeeze(1).tolist()
        raise RuntimeError(f"[{split_name.upper()} ERROR] Some indices were never written: {missing_ids}")

    dist.all_reduce(audio_feats, op=dist.ReduceOp.SUM)
    dist.all_reduce(text_feats, op=dist.ReduceOp.SUM)

    if torch.any(write_mask > 1):
        conflict_id = torch.where(write_mask > 1)[0]
        print(f"[Rank {dist.get_rank()}] >>> {split_name} write_mask has conflict index {conflict_id}")
        # Rows computed by several ranks were summed; turn them back into means.
        audio_feats /= write_mask.unsqueeze(1)
        text_feats /= write_mask.unsqueeze(1)

    return audio_feats, text_feats, num_samples


def evaluate(model, data, epoch, args, tb_writer=None, tokenizer=None):
    """Distributed retrieval evaluation on the ``val`` and/or ``test`` splits.

    Every rank encodes its shard of each split, and the shards are all-reduced
    into full embedding matrices. Only the master rank computes metrics and
    logs them to the console, TensorBoard, ``results.jsonl`` and wandb.

    ``val`` runs when ``args.val_frequency`` is set and ``epoch`` is a multiple
    of it, or is the final epoch. ``test`` runs whenever it is present in ``data``.

    Returns:
        dict of ``"{split}/{metric}"`` values on the master rank; an empty dict
        on other ranks or when no split ran. ``tokenizer`` is unused.
    """
    device = torch.device(args.device)
    model.eval()

    autocast = get_autocast(args.precision, device_type=device.type)
    input_dtype = get_input_dtype(args.precision)

    # Defined up front: the test split and the final emptiness check need it
    # even when validation is skipped this epoch.
    metrics = {}

    splits = []
    if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)):
        splits.append('val')
    if 'test' in data:
        splits.append('test')

    for split in splits:
        print(f"[Rank {dist.get_rank()}] >>> Running {split.upper()} evaluate at epoch {epoch}")
        audio_feats, text_feats, num_samples = _gather_eval_split(
            model, data[split].dataloader, device, input_dtype, autocast, split
        )

        if is_master(args):
            # Every 5th audio row is kept as the unique clip; the text side keeps
            # all rows (presumably 5 captions per clip, stored consecutively).
            split_metrics = get_audiotext_metrics(
                audio_features=audio_feats[::5].cpu(),
                text_features=text_feats.cpu()
            )
            metrics.update({f"{split}/{k}": v for k, v in split_metrics.items()})
            metrics.update({f"{split}/epoch": epoch, f"{split}/num_samples": num_samples})

    if not metrics:
        return metrics

    logging.info(
        f"Eval Epoch: {epoch} "
        + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
    )

    log_data = {name: val for name, val in metrics.items()}

    if args.save_logs:
        if tb_writer is not None:
            for name, val in log_data.items():
                tb_writer.add_scalar(name, val, epoch)

        with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
            f.write(json.dumps(metrics))
            f.write("\n")

    if args.wandb:
        assert wandb is not None, 'Please install wandb.'
        if 'train' in data:
            dataloader = data['train'].dataloader
            num_batches_per_epoch = dataloader.num_batches // args.accum_freq
            step = num_batches_per_epoch * epoch
        else:
            step = None
        log_data['epoch'] = epoch
        wandb.log(log_data, step=step)

    return metrics

def evaluate_single(model, data, epoch, args, tb_writer=None, tokenizer=None):
    """Single-process retrieval evaluation on the ``val`` and/or ``test`` splits.

    Only the master rank does any work; other ranks return ``{}`` immediately.
    Split selection and logging match ``evaluate``. The difference is that the
    whole split must be covered by this process's dataloader, because nothing
    is all-reduced.

    Returns:
        dict of ``"{split}/{metric}"`` values, or ``{}``. ``tokenizer`` is unused.

    Raises:
        RuntimeError: if the dataloader did not yield every index in
            ``range(num_samples)`` (e.g. it is wrapped in a sharding sampler).
    """
    device = torch.device(args.device)
    model.eval()
    autocast = get_autocast(args.precision, device_type=device.type)
    input_dtype = get_input_dtype(args.precision)
    metrics = {}
    if not is_master(args):
        return metrics

    for split in ['val', 'test']:
        if split not in data:
            continue
        if split == 'val' and not (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)):
            continue

        dataloader = data[split].dataloader
        samples = dataloader.num_samples

        # Zero-initialised plus a coverage mask: a row the dataloader never
        # yields must be detected, not silently scored as uninitialised memory.
        all_audio_features = torch.zeros((samples, 768), dtype=torch.float32, device=device)
        all_text_features = torch.zeros((samples, 768), dtype=torch.float32, device=device)
        written = torch.zeros(samples, dtype=torch.bool, device=device)

        with torch.inference_mode():
            for batch in dataloader:
                audios, tokenized, audio_ids, indices = batch
                audios = audios.to(device=device, dtype=input_dtype, non_blocking=True)
                input_ids = tokenized['input_ids'].to(device=device, non_blocking=True)
                attention_mask = tokenized['attention_mask'].to(device=device, non_blocking=True)

                with autocast():
                    model_out = model(audios, input_ids, attention_mask)
                    audio_features = model_out["audio_features"]
                    text_features = model_out["text_features"]

                # Indexed assignment needs matching dtypes; AMP may yield fp16/bf16.
                all_audio_features[indices] = audio_features.to(all_audio_features.dtype)
                all_text_features[indices] = text_features.to(all_text_features.dtype)
                written[indices] = True

        if not bool(written.all()):
            missing_ids = torch.nonzero(~written).squeeze(1).tolist()
            raise RuntimeError(f"[{split.upper()} ERROR] Some indices were never written: {missing_ids}")

        # Every 5th audio row is kept as the unique clip (presumably 5 captions
        # per clip stored consecutively -- confirm with the dataset).
        split_metrics = get_audiotext_metrics(
            audio_features=all_audio_features[::5].cpu(),
            text_features=all_text_features.cpu()
        )
        metrics.update({f"{split}/{k}": v for k, v in split_metrics.items()})
        metrics.update({f"{split}/epoch": epoch, f"{split}/num_samples": samples})
    if not metrics:
        return metrics

    logging.info(
        f"Eval Epoch: {epoch} "
        + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
    )

    log_data = {name: val for name, val in metrics.items()}

    if args.save_logs:
        if tb_writer is not None:
            for name, val in log_data.items():
                tb_writer.add_scalar(name, val, epoch)

        with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
            f.write(json.dumps(metrics))
            f.write("\n")

    if args.wandb:
        assert wandb is not None, 'Please install wandb.'
        if 'train' in data:
            dataloader = data['train'].dataloader
            num_batches_per_epoch = dataloader.num_batches // args.accum_freq
            step = num_batches_per_epoch * epoch
        else:
            step = None
        log_data['epoch'] = epoch
        wandb.log(log_data, step=step)

    return metrics



def train_one_epoch(model, data, loss, epoch, optimizer, scaler, scheduler, args, tb_writer=None):
    """Train ``model`` for one epoch over ``data['train']``.

    Each batch is unpacked as ``(audios, tokenized, audio_ids, indices)``, where
    ``tokenized`` provides ``input_ids`` and ``attention_mask``.

    With ``args.accum_freq == 1`` every batch is a full optimizer step. With
    ``args.accum_freq > 1`` the model outputs of ``accum_freq`` consecutive
    batches are first cached without gradients. Each batch is then re-encoded
    with gradients, and its loss is computed on the concatenation of its fresh
    outputs and the cached outputs of the other batches. The optimizer steps
    once per ``accum_freq`` batches.

    Args:
        loss: called as ``loss(**model_outputs, output_dict=True)``; must return
            a dict of loss tensors, which are summed into the total loss.
        scaler: optional AMP grad scaler (``None`` disables scaling).
        scheduler: not used here (the per-step call below is disabled).

    Progress, losses and throughput are logged on the master rank every
    ``args.log_every_n_steps`` steps and on the final step.
    """
    device = torch.device(args.device)
    autocast = get_autocast(args.precision, device_type=device.type)
    input_dtype = get_input_dtype(args.precision)

    model.train()
    data['train'].set_epoch(epoch)
    dataloader = data['train'].dataloader
    num_batches_per_epoch = dataloader.num_batches // args.accum_freq
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    if args.accum_freq > 1:
        accum_audios, accum_input_ids, accum_attention_masks, accum_features = [], [], [], {}

    losses_m = {}
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()
    for i, batch in enumerate(dataloader):
        i_accum = i // args.accum_freq
        step = num_batches_per_epoch * epoch + i_accum

        # NOTE(review): the per-step LR schedule is disabled, so ``scheduler`` is
        # never called in this function -- confirm that is intentional.
        # if not args.skip_scheduler:
            # scheduler(step)

        audios, tokenized, audio_ids, indices = batch
        audios = audios.to(device=device, dtype=input_dtype, non_blocking=True)

        input_ids = tokenized['input_ids'].to(device=device, non_blocking=True)
        attention_mask = tokenized['attention_mask'].to(device=device, non_blocking=True)

        data_time_m.update(time.time() - end)
        optimizer.zero_grad()
        if args.accum_freq == 1:
            with autocast():
                model_out = model(audios, input_ids, attention_mask)
                losses = loss(**model_out, output_dict=True)
                total_loss = sum(losses.values())

            backward(total_loss, scaler)

        else:
            # First, cache the features without any gradient tracking.
            with torch.no_grad():
                with autocast():
                    model_out = model(audios, input_ids, attention_mask)

                    for key, val in model_out.items():
                        if key in accum_features:
                            accum_features[key].append(val)
                        else:
                            accum_features[key] = [val]

                accum_audios.append(audios)
                accum_input_ids.append(input_ids)
                accum_attention_masks.append(attention_mask)

            if ((i + 1) % args.accum_freq) > 0:
                continue
            # Now, ready to take gradients for the last accum_freq batches.
            # Re-do the forward pass for those batches, and use the cached features from the other batches as negatives.
            # Call backwards each time, but only step optimizer at the end.
            optimizer.zero_grad()
            for j in range(args.accum_freq):
                audios = accum_audios[j]
                input_ids = accum_input_ids[j]
                attention_mask = accum_attention_masks[j]
                with autocast():
                    model_out = model(audios, input_ids, attention_mask)
                    # Batched outputs: splice this micro-batch's grad-tracking output
                    # in place of its cached copy. Outputs that cannot be concatenated
                    # along dim 0 (0-dim tensors such as a learnable scale, or
                    # non-tensors) are passed through from the fresh forward pass.
                    inputs = {}
                    inputs_no_accum = {}
                    for key, val in model_out.items():
                        if isinstance(val, torch.Tensor) and val.dim() > 0 and key in accum_features:
                            accumulated = accum_features[key]
                            inputs[key] = torch.cat(accumulated[:j] + [val] + accumulated[j + 1:])
                        else:
                            inputs_no_accum[key] = val

                    losses = loss(**inputs, **inputs_no_accum, output_dict=True)
                    del inputs
                    del inputs_no_accum
                    total_loss = sum(losses.values())
                    losses["loss"] = total_loss
                backward(total_loss, scaler)

        if scaler is not None:
            if args.horovod:
                optimizer.synchronize()
                scaler.unscale_(optimizer)
                if args.grad_clip_norm is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
                with optimizer.skip_synchronize():
                    scaler.step(optimizer)
            else:
                if args.grad_clip_norm is not None:
                    # Gradients must be unscaled before clipping to a true-norm threshold.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
                scaler.step(optimizer)
            scaler.update()
        else:
            if args.grad_clip_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
            optimizer.step()

        # reset gradient accum, if enabled
        if args.accum_freq > 1:
            accum_audios, accum_input_ids, accum_attention_masks, accum_features = [], [], [], {}

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i_accum + 1
        if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch):
            # Master rank only: report progress, loss, throughput and LR.
            batch_size = len(audios)
            num_samples = batch_count * batch_size * args.accum_freq * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * batch_count / num_batches_per_epoch

            # Loss meters are only updated on logging steps, so their averages
            # cover logged steps only.
            for key, val in losses.items():
                if key not in losses_m:
                    losses_m[key] = AverageMeter()
                losses_m[key].update(val.item(), batch_size)

            loss_log = " ".join(
                [
                    f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})"
                    for loss_name, loss_m in losses_m.items()
                ]
            )

            samples_per_second = args.accum_freq * args.batch_size * args.world_size / batch_time_m.val
            samples_per_second_per_gpu = args.accum_freq * args.batch_size / batch_time_m.val
            logging.info(
                f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                f"Data (t): {data_time_m.avg:.3f} "
                f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu "
                f"LR: {optimizer.param_groups[0]['lr']:5f} " + loss_log
            )

            # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
            log_data = {
                "data_time": data_time_m.val,
                "batch_time": batch_time_m.val,
                "samples_per_second": samples_per_second,
                "samples_per_second_per_gpu": samples_per_second_per_gpu,
                "lr": optimizer.param_groups[0]["lr"]
            }
            log_data.update({name: val.val for name, val in losses_m.items()})

            log_data = {"train/" + name: val for name, val in log_data.items()}

            if tb_writer is not None:
                for name, val in log_data.items():
                    tb_writer.add_scalar(name, val, step)

            if args.wandb:
                assert wandb is not None, 'Please install wandb.'
                log_data['step'] = step  # for backwards compatibility
                wandb.log(log_data, step=step)

            # resetting batch / data time meters per log window
            batch_time_m.reset()
            data_time_m.reset()


