
import pandas as pd
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
from utils import extract_coordinates
import torch.nn.functional as F



def parse_point(s):
    """Turn a coordinate string such as ``"(12.5, 40)"`` into a 2-element array.

    Surrounding whitespace and any leading ``(`` / trailing ``)`` characters
    are stripped, then the remainder is split on ``,`` and each half is
    converted with ``float``.

    Args:
        s: Text of the form ``"(x, y)"``.

    Returns:
        ``np.ndarray`` of two floats, ``[x, y]``.

    Raises:
        ValueError: If the text does not hold exactly two comma-separated
            numbers.
    """
    trimmed = s.strip()
    trimmed = trimmed.lstrip('(')
    trimmed = trimmed.rstrip(')')
    coords = [float(token) for token in trimmed.split(',')]
    x_val, y_val = coords
    return np.array([x_val, y_val])

def calculate_angle(p1, p2, p3):
    v1 = p1 - p2
    v2 = p3 - p2
    angle_rad = np.arccos(np.clip(np.dot(v1, v2) / 
                   (np.linalg.norm(v1) * np.linalg.norm(v2)), -1.0, 1.0))
    angle_deg = np.degrees(angle_rad)
    return angle_deg

def visualize_and_calculate(csv_path, target_id, output_path):
    """Plot one labelled image's landmarks and compare its angle with the CSV AOP.

    Selects the first row of ``csv_path`` whose ``ID`` equals ``target_id``
    and loads the image at that row's ``Path``. It then saves a figure to
    ``<output_path>/visualized_<target_id>.png`` showing PS1 (red), PS2
    (green) and FH1 (blue), plus the PS1-PS2 and PS1-FH1 segments. Finally
    it prints the angle at PS1 between PS2 and FH1, the row's ``AOP`` value,
    and their absolute difference.

    Args:
        csv_path: CSV with columns ``ID``, ``Path``, ``PS1``, ``PS2``, ``FH1``
            (``"(x, y)"`` strings parsed by ``parse_point``) and ``AOP``.
        target_id: Value matched against the ``ID`` column.
        output_path: Directory for the saved figure; created if missing.

    Returns:
        None. Prints a message and returns early if the ID is absent or the
        image cannot be read.
    """
    df = pd.read_csv(csv_path)
    matches = df[df['ID'] == target_id]
    if matches.empty:
        print(f"ID {target_id} not found in CSV.")
        return
    row = matches.iloc[0]

    image_path = row['Path']
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        print(f"Image not found at {image_path}")
        return
    # OpenCV loads BGR; matplotlib expects RGB.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    ps1 = parse_point(row['PS1'])
    ps2 = parse_point(row['PS2'])
    fh1 = parse_point(row['FH1'])

    if output_path:
        # os.makedirs('') raises, whereas joining '' below yields a
        # cwd-relative filename, so only create a non-empty directory.
        os.makedirs(output_path, exist_ok=True)
    vis_save_path = os.path.join(output_path, f"visualized_{target_id}.png")

    fig, ax = plt.subplots()
    try:
        ax.imshow(image_rgb)
        ax.scatter(*ps1, c='red', label='PS1')
        ax.scatter(*ps2, c='green', label='PS2')
        ax.scatter(*fh1, c='blue', label='FH1')
        ax.plot([ps1[0], ps2[0]], [ps1[1], ps2[1]], 'r-')
        ax.plot([ps1[0], fh1[0]], [ps1[1], fh1[1]], 'b-')
        ax.legend()
        ax.set_title(f"ID: {target_id}")
        fig.savefig(vis_save_path)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

    # Angle at the PS1 vertex, between the rays towards PS2 and FH1 -- the
    # same two segments drawn above.
    angle = calculate_angle(ps2, ps1, fh1)
    print(f"Calculated angle: {angle:.2f} degrees")

    aop_value = float(row['AOP'])
    print(f"AOP in CSV: {aop_value:.2f} degrees")
    print(f"Difference: {abs(angle - aop_value):.2f} degrees")


def visualize_keypoints(model, val_loader, device, epoch, save_path, n_samples=4):
    """Save a grid of ground-truth vs. predicted keypoints for one validation batch.

    The model runs on up to ``n_samples`` images from the first batch of
    ``val_loader``. The figure is written to
    ``<save_path>/epoch_<epoch>_visualization.png`` and has two rows per
    sample:

    * top: the image with ground-truth (blue circles) and predicted
      (red crosses) keypoints;
    * bottom: the model output summed over its first axis, overlaid on the
      image.

    Args:
        model: Module whose raw output is passed to ``extract_coordinates``.
            Its train/eval mode is restored before returning.
        val_loader: Iterable yielding ``(imgs, _, landmarks)`` or
            ``(imgs, _, landmarks, original_imgs)``. When original images are
            present they are displayed instead of the network inputs.
        device: Device the inputs are moved to before the forward pass.
        epoch: Used only in the output filename.
        save_path: Output directory; created if missing.
        n_samples: Maximum number of samples to plot; clamped to the batch
            size.

    Raises:
        ValueError: If ``n_samples`` < 1 or a batch does not have 3 or 4 items.

    Note:
        Keypoints are read as ``[:, 0::2]`` (x) and ``[:, 1::2]`` (y) and
        scaled by image width/height. They are therefore assumed to be flat,
        interleaved coordinates normalized to [0, 1]. TODO confirm against
        ``extract_coordinates`` and the dataset.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}.")

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            try:
                batch = next(iter(val_loader))
            except StopIteration:
                print("Visualization Warning: Validation loader is empty, skipping visualization.")
                return

            num_returns = len(batch)
            if num_returns == 4:
                imgs_batch, _, landmarks_batch, original_imgs = batch
                use_original_img_for_display = True
            elif num_returns == 3:
                imgs_batch, _, landmarks_batch = batch
                original_imgs = None
                use_original_img_for_display = False
            else:
                raise ValueError(
                    f"Unsupported number of items in val_loader batch: {num_returns}. "
                    "Expected 3 or 4."
                )

            imgs = imgs_batch[:n_samples].to(device)
            landmarks = landmarks_batch[:n_samples].to(device)

            outputs = model(imgs.float())
            preds = extract_coordinates(outputs)
    finally:
        # Visualization should not silently flip a training model into eval mode.
        model.train(was_training)

    # The batch can be smaller than n_samples; plot only what actually exists.
    n_show = len(imgs)
    if n_show == 0:
        print("Visualization Warning: Validation batch is empty, skipping visualization.")
        return

    outputs_cpu = outputs.cpu().numpy()
    preds_cpu = preds.cpu().numpy()
    landmarks_cpu = landmarks.cpu().numpy()

    if use_original_img_for_display:
        imgs_to_display = list(original_imgs)[:n_show]
    else:
        # NCHW tensor -> NHWC array for imshow.
        imgs_nhwc = np.transpose(imgs.cpu().numpy(), (0, 2, 3, 1))
        if imgs_nhwc.shape[-1] == 1:
            # imshow rejects (H, W, 1); drop the singleton channel.
            imgs_nhwc = imgs_nhwc[..., 0]
        imgs_to_display = np.clip(imgs_nhwc, 0, 1)

    # squeeze=False keeps `axes` 2-D even when n_show == 1.
    fig, axes = plt.subplots(
        nrows=2, ncols=n_show, figsize=(n_show * 5, 10), squeeze=False
    )
    try:
        for i in range(n_show):
            img_display = imgs_to_display[i]
            h, w = img_display.shape[:2]
            cmap = 'gray' if np.ndim(img_display) == 2 else None

            ax_img = axes[0, i]
            ax_img.imshow(img_display, cmap=cmap)
            # Ground-truth keypoints (blue)
            ax_img.scatter(landmarks_cpu[i, 0::2] * w, landmarks_cpu[i, 1::2] * h,
                           s=80, c='blue', marker='o', label='Ground Truth', alpha=0.8, edgecolors='w')
            # Predicted keypoints (red)
            ax_img.scatter(preds_cpu[i, 0::2] * w, preds_cpu[i, 1::2] * h,
                           s=80, c='red', marker='x', label='Prediction', alpha=0.8)
            ax_img.set_title(f"Sample {i+1}")
            ax_img.axis('off')

            ax_heatmap = axes[1, i]
            heatmap_agg = np.sum(outputs_cpu[i], axis=0)
            ax_heatmap.imshow(img_display, cmap=cmap, alpha=0.4)
            # Stretch the heatmap over the image's pixel extent so both layers
            # line up even when the heatmap resolution differs from the image.
            ax_heatmap.imshow(heatmap_agg, cmap='jet', alpha=0.6,
                              extent=(-0.5, w - 0.5, h - 0.5, -0.5))
            ax_heatmap.set_title("Predicted Heatmap")
            ax_heatmap.axis('off')

        # Every top-row axis has the same two series; take the legend from the first.
        handles, labels = axes[0, 0].get_legend_handles_labels()
        if handles:
            fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.05), ncol=2)

        # Reserve the bottom 5% of the figure for the shared legend.
        fig.tight_layout(rect=[0, 0.05, 1, 1])

        os.makedirs(save_path, exist_ok=True)
        save_file = os.path.join(save_path, f"epoch_{epoch}_visualization.png")
        fig.savefig(save_file)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

    print(f"Visualization saved to {save_file}")

# if __name__ == '__main__':
    # visualize_and_calculate("/home/chenyu/MICCAI/dataset/Labeled.csv", 12, "./")



