import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import argparse

# Compute device: the first GPU when CUDA is available, otherwise fall back to
# the CPU so the script still runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Command-line arguments.
def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    ``type=bool`` cannot be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy. The empty string is still
    accepted as False for backward compatibility with the old behaviour.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Parse the command-line options of the MI trajectory script.

    Args:
        argv: optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``video_name``, ``idxs``, ``image_file``,
        ``load_draw`` and ``method``.
    """
    parser = argparse.ArgumentParser(description="Use MINE to estimate mutual information for trajectories.")
    parser.add_argument("--video_name", type=str, required=True, help="Name of the video folder (e.g., 'animal2_s').")
    parser.add_argument("--idxs", type=int, nargs='+', required=True, help="List of reference point indices (e.g., 1000 1500 2100).")
    parser.add_argument("--image_file", type=str, default="rgbs/rgb_00001.jpg", help="Path to the reference image (default: 'rgbs/rgb_00001.jpg').")
    parser.add_argument("--load_draw", type=_str2bool, default=True, help="If True, load existing MI results instead of recalculating.")
    parser.add_argument("--method", type=str, choices=['MINE', 'MINDE', 'InfoNCE', 'KSG', 'InfoNet', 'InfoNet+', 'KNIFE', 'KSG-5'], required=True, help="Method to estimate mutual information.")
    return parser.parse_args(argv)

# Read the command-line arguments.
args = parse_args()

# Paths: raw video data, and a per-video results folder that holds the 2D
# trajectories and the MI result files.
data_root_path = os.path.join("data/point_odyssey/val", args.video_name)
mi_save_dir = os.path.join("results/3dtrack", args.video_name)  # results folder derived from the video name
os.makedirs(mi_save_dir, exist_ok=True)
track_save_path = mi_save_dir
track_file_path = os.path.join(track_save_path, "trajs_2d.npy")

# The trajectory file must already exist. Exit with a non-zero status so
# shell scripts / callers can detect the failure (plain exit() returns 0).
if not os.path.isfile(track_file_path):
    print(f"未找到 {track_file_path}，请检查路径")
    raise SystemExit(1)

# Load the trajectories.
track = np.load(track_file_path)
print(f"轨迹数据形状: {track.shape}")

# Load the background image (first frame by default).
image_path = os.path.join(data_root_path, args.image_file)
if not os.path.isfile(image_path):
    print(f"未找到 {image_path}，请检查路径")
    raise SystemExit(1)
image = cv2.imread(image_path)
# cv2.imread returns None (instead of raising) for unreadable/corrupt files.
if image is None:
    print(f"无法读取图像 {image_path}")
    raise SystemExit(1)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB

# Draw (and, if requested, first estimate) an MI map for every reference point.
from mpl_toolkits.axes_grid1 import make_axes_locatable


def _to_scalar(value):
    """Convert one estimator output to a plain Python float.

    Estimators may return Python floats, NumPy scalars/size-1 arrays or torch
    tensors (possibly on the GPU); ``.item()`` handles the latter uniformly,
    so the final ``np.array(results)`` is always a flat float array.
    """
    return value.item() if hasattr(value, "item") else float(value)


def _plot_mi_map(background, locations, ref_point, mi_values, save_path):
    """Scatter MI values (as colours) over *background* and save to *save_path*.

    Args:
        background: RGB image drawn underneath the points.
        locations: point coordinates, indexed as ``locations[:, 0]`` (x) and
            ``locations[:, 1]`` (y).
        ref_point: (x, y) of the reference point, drawn as a large red star.
        mi_values: one MI value per point, used as the scatter colour.
        save_path: output image path.
    """
    fig, ax = plt.subplots(figsize=(12, 7), dpi=300)
    try:
        ax.imshow(background)

        # MI values as colours.
        scatter = ax.scatter(
            locations[:, 0],
            locations[:, 1],
            c=mi_values,
            cmap='viridis',
            s=50,
            alpha=0.8,
            edgecolors='white',
            linewidth=0.5
        )

        # Reference point, made more prominent.
        ax.scatter(
            ref_point[0],
            ref_point[1],
            c='red',
            s=400,
            marker='*',
            label='Reference Point',
            edgecolors='white',
            linewidth=1.5
        )

        # make_axes_locatable keeps the colorbar exactly as tall as the image.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)  # size: width, pad: spacing
        cbar = fig.colorbar(scatter, cax=cax)
        cbar.set_label('Mutual Information', fontsize=12)
        cbar.ax.tick_params(labelsize=10)

        ax.axis('off')
        ax.legend(
            loc='upper right',
            fontsize=12,
            frameon=True,
            facecolor='white',
            edgecolor='black'
        )
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
    finally:
        # Always release the figure, even if drawing/saving fails.
        plt.close(fig)


location = track[0]  # [num_points, 2]: point positions in the first frame
num_points = track.shape[1]
plot_save_dir = os.path.join("results/3dtrack", 'plots')
os.makedirs(plot_save_dir, exist_ok=True)
# The per-method setup (imports, hyper-parameters, checkpoint loading) does not
# depend on the reference point, so it runs once, on the first point that
# actually needs estimating, instead of once per reference point.
estimator_ready = False

for idx in args.idxs:
    # A negative index would select a point via location[idx] but never match
    # `i == idx` below (so the self-pair would be estimated); indices past the
    # end would crash. Reject both.
    if not 0 <= idx < num_points:
        print(f"参考点索引 {idx} 超出范围 [0, {num_points})，跳过")
        continue

    mi_file_path = os.path.join(mi_save_dir, f"{args.video_name}_{idx}_{args.method}.npy")
    point = location[idx]  # reference point coordinates

    if args.load_draw:
        # Reuse previously saved MI results; they are NOT re-saved.
        if not os.path.isfile(mi_file_path):
            print(f"未找到 {mi_file_path}，请确保已计算并保存互信息结果。跳过点 {idx}")
            continue
        print(f"加载已有的互信息结果 {mi_file_path}，跳过估计过程。")
        results = np.load(mi_file_path)
        print(f"互信息结果形状 for point {idx}: {results.shape}")
        # Counts of unusually large (> 10) and negative (< -1) estimates.
        print(np.sum(results > 10), np.sum(results < -1))
    else:
        # Estimate MI from scratch, unless a result file already exists
        # (in which case this point is skipped entirely, including the plot).
        if os.path.isfile(mi_file_path):
            print(f"已有互信息估计结果 {mi_file_path}，跳过估计过程。")
            continue

        print(f"正在计算点 {idx} 的互信息..., 使用方法 {args.method}")

        if not estimator_ready:
            if args.method == 'MINE':
                from estimators.MINE import MINE

                class Hyperparams(object):
                    def __init__(self):
                        self.critic = 'neural'  # neural-network critic
                        self.lr = 5e-4  # learning rate
                        self.bs = 500  # batch size
                        self.n_bridges = 4
                        self.wd = 1e-5  # weight decay

                hyperparams = Hyperparams()
                # Joint dimension of (X, Y); each trajectory contributes d // 2 = 2 coords.
                d = 4
                architecture_critic = [d, 500, 500, 500, 1]
                architecture_encode = [d // 2, 200, 200, d // 2]

            elif args.method == 'MINDE':
                from estimators.MINDE import MINDE

                class Hyperparams(object):
                    def __init__(self):
                        self.critic = 'neural'  # ('neural', 'quadratic')
                        self.lr = 5e-4
                        self.bs = 500
                        self.n_bridges = 4
                        self.wd = 1e-5

                hyperparams = Hyperparams()
                d = 4
                architecture_critic = [d, 500, 500, 500, 1]
                architecture_encode = [d // 2, 200, 200, d // 2]
                hyperparams.t_patience = 500
                hyperparams.dim = d // 2
                hyperparams.device = device
                hyperparams.importance_sampling = True

            elif args.method == 'InfoNCE':
                from estimators.InfoNCE import InfoNCE

                class Hyperparams(object):
                    def __init__(self):
                        self.lr = 5e-4
                        self.bs = 500  # batch size
                        self.n_neg = 4  # number of negative samples
                        self.encode_x = False  # whether to encode x
                        self.encode_y = False  # whether to encode y
                        self.critic = 'neural'  # ('neural', 'quadratic')
                        self.max_iteration = 1500  # maximum training iterations

                hyperparams = Hyperparams()
                d = 4
                architecture_critic = [d, 500, 500, 500, 1]
                architecture_encoder_x = [d // 2, 200, 200, d // 2]  # assumed encoder sizes
                architecture_encoder_y = [d // 2, 200, 200, d // 2]  # assumed encoder sizes

            elif args.method in ('KSG', 'KSG-5'):
                from estimators.KSG import KSG

                class Hyperparams(object):
                    def __init__(self):
                        # 'KSG' uses 40 nearest neighbours, 'KSG-5' uses 5.
                        self.k_neighbors = 40 if args.method == 'KSG' else 5
                        self.tree_type = 'kd_tree'
                        self.tree_kwargs = {}  # extra tree arguments

                hyperparams = Hyperparams()

            elif args.method == 'KNIFE':
                from KNIFE import estimate_mi_knife

            elif args.method == 'InfoNet':
                from InfoNet_V1.infer import load_model, estimate_mi, compute_smi_mean
                config_path = "InfoNet_V1/configs/config.yaml"
                ckpt_path = "InfoNet_V1/saved/model_5000_32_1000-720--0.16.pt"
                infonet_model = load_model(config_path, ckpt_path)
                infonet_model.eval()

            elif args.method == 'InfoNet+':
                from InfoNet import estimate_mi_xy, load_model_from_checkpoint
                import hydra
                config_path = 'config'
                config_name = 'cfg_1-5d_new'
                checkpoint_path = 'ckpt_path'
                device = 'cuda' if torch.cuda.is_available() else 'cpu'

                with hydra.initialize(config_path=config_path, version_base='1.1'):
                    cfg = hydra.compose(config_name=config_name)
                model = load_model_from_checkpoint(
                    checkpoint_path=checkpoint_path,
                    config_path=config_path,
                    config_name=config_name,
                    device=device
                )
                print("Model loaded successfully")
                max_dim = 5  # maximum dimension, must match the model config
                softrank_reg = 1e-3  # regularisation parameter

            estimator_ready = True

        results = []
        for i in range(num_points):
            if i == idx:
                results.append(0.0)  # MI of the reference point with itself is set to 0
                continue

            # Trajectories of the current point and of the reference point.
            X = torch.tensor(track[:, i, :], dtype=torch.float32, device=device)  # [frames, 2]
            Y = torch.tensor(track[:, idx, :], dtype=torch.float32, device=device)  # [frames, 2]

            if args.method == 'MINE':
                estimator = MINE(None, None, architecture_critic, hyperparams)
                estimator.to(device)
                estimator.learn(X, Y)
                mi_est = estimator.MI(X, Y)
            elif args.method == 'MINDE':
                estimator = MINDE(None, None, None, hyperparams)
                estimator.to(device)
                estimator.learn(X, Y)
                mi_est = estimator.MI(X, Y)
            elif args.method in ('KSG', 'KSG-5'):
                ksg_estimator = KSG(k_neighbors=hyperparams.k_neighbors, tree_type=hyperparams.tree_type, tree_kwargs=hyperparams.tree_kwargs)
                mi_est = ksg_estimator(X.cpu().numpy(), Y.cpu().numpy(), std=False)
            elif args.method == 'InfoNCE':
                estimator = InfoNCE(architecture_encoder_x, architecture_encoder_y, architecture_critic, hyperparams)
                estimator.to(device)
                estimator.learn(X, Y)
                mi_est = estimator.MI(X, Y)
            elif args.method == 'KNIFE':
                mi_est = estimate_mi_knife(X, Y, train_steps=200)
            elif args.method == 'InfoNet':
                mi_est = compute_smi_mean(X.cpu().numpy(), Y.cpu().numpy(), infonet_model, proj_num=32, seq_len=X.shape[0], batchsize=8)
            elif args.method == 'InfoNet+':
                # This estimator takes a leading batch dimension: [1, frames, 2].
                mi_est = estimate_mi_xy(X.unsqueeze(0).cpu(), Y.unsqueeze(0).cpu(), model, max_dim, softrank_reg=softrank_reg)

            print('est MI:', mi_est)
            results.append(_to_scalar(mi_est))

        # Only freshly computed results are written to disk.
        results = np.array(results)
        np.save(mi_file_path, results)
        print(f"互信息计算完成，并保存至 {mi_file_path}")

    save_path = os.path.join(plot_save_dir, f"{args.method}_{args.video_name}_mi_{idx}.png")
    _plot_mi_map(image, location, point, results, save_path)
    print(f"图片已保存至 {save_path}")
