from pathlib import Path
from os import makedirs
from os.path import dirname, join, abspath
import sys
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import yaml

# Add parent directory to path to find teacher_student package
sys.path.insert(0, dirname(dirname(abspath(__file__))))

from teacher_student.dataloader import TaskDataModule, get_dataset
from teacher_student.task_utils import update_tasks_mnist

# Repository root as a str: the parent of the directory that contains this script.
ROOT_DIR = dirname(dirname(Path(__file__).resolve()))


def collect_all_data(data_module):
    """Collect all images from a data module's train and test dataloaders.

    Only ``train_dataloader()`` and ``test_dataloader()`` are read (in that
    order). The validation split is not included.

    Args:
        data_module: object exposing ``train_dataloader()`` and
            ``test_dataloader()``, each yielding ``(images, labels)`` batches.

    Returns:
        torch.Tensor: all image batches concatenated along dim 0.

    Raises:
        ValueError: if neither dataloader yields any batch.
    """
    all_images = []
    # Iterate the loader factories so each loader is created only when it is consumed.
    for make_loader in (data_module.train_dataloader, data_module.test_dataloader):
        for images, _ in make_loader():
            all_images.append(images)

    # torch.cat([]) raises an obscure RuntimeError; report the real cause instead.
    if not all_images:
        raise ValueError("data module yielded no batches; cannot collect images")

    all_data = torch.cat(all_images, dim=0)
    print(f"Collected {all_data.shape[0]} samples with shape {all_data.shape}")

    return all_data


def compute_task_statistics(images):
    """
    Compute pixel-level mean and std across all images in a task

    Args:
        images: tensor of shape (N, C, H, W) or (N, H, W) where N is the
            number of images. A 4-D input with C == 1 has the channel
            dimension removed; integer dtypes (e.g. raw uint8 pixels) are
            converted to float first.

    Returns:
        mean_matrix: (H, W) mean at each pixel location
        std_matrix: (H, W) sample std at each pixel location (torch.std's
            default Bessel correction, so N == 1 yields NaN)
    """
    # Remove channel dimension if present (MNIST is grayscale); squeeze(1) is a
    # no-op when C > 1, in which case the results keep shape (C, H, W).
    if images.dim() == 4:
        images = images.squeeze(1)

    # torch.mean / torch.std are not defined for integer tensors.
    if not torch.is_floating_point(images):
        images = images.float()

    # Reduce over the batch dimension (dim=0).
    mean_matrix = torch.mean(images, dim=0)
    std_matrix = torch.std(images, dim=0)

    return mean_matrix, std_matrix


def frobenius_dot_product(matrix_a, matrix_b):
    """
    Return the Frobenius inner product of two equally shaped matrices.

    <A, B>_F = trace(A^T B), which equals the sum of the element-wise
    products A_ij * B_ij.

    Args:
        matrix_a: First matrix (H, W)
        matrix_b: Second matrix (H, W)

    Returns:
        float: the inner product as a Python scalar
    """
    elementwise = matrix_a * matrix_b
    return elementwise.sum().item()




def visualize_mean_matrices(mean1, mean2, task1_labels, task2_labels, similarity_score, save_dir, control_tasks):
    """
    Create heatmap visualizations of two tasks' mean matrices and save them as PNGs.

    Two files are written into ``save_dir``:
      - ``task_similarity_{control_tasks}.png``: Task 1 mean, Task 2 mean and
        their difference (Task 1 - Task 2), titled with ``similarity_score``.
      - ``task_stats_{control_tasks}.png``: histograms of both tasks' mean
        pixel intensities and a pixel-wise scatter plot with a y = x line.

    Args:
        mean1: Task 1 mean matrix; expected to be a 2-D (H, W) tensor
        mean2: Task 2 mean matrix, same shape as ``mean1``
        task1_labels: Task 1 labels, shown in the first subplot title
        task2_labels: Task 2 labels, shown in the second subplot title
        similarity_score: scalar shown in the figure title (4 decimals)
        save_dir: output directory (str or Path); it must already exist
        control_tasks: tag inserted into both output file names
    """
    save_dir = Path(save_dir)
    # Detach and move to CPU so .numpy() also works for CUDA / autograd tensors.
    m1 = mean1.detach().cpu().numpy()
    m2 = mean2.detach().cpu().numpy()
    diff_matrix = m1 - m2

    # --- Figure 1: both mean matrices and their difference ---
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        im1 = axes[0].imshow(m1, cmap='viridis', aspect='equal')
        axes[0].set_title(f'Task 1 Mean Matrix\nDigits: {task1_labels}')
        fig.colorbar(im1, ax=axes[0])

        im2 = axes[1].imshow(m2, cmap='viridis', aspect='equal')
        axes[1].set_title(f'Task 2 Mean Matrix\nDigits: {task2_labels}')
        fig.colorbar(im2, ax=axes[1])

        # Centre the diverging colormap on zero so colour encodes the sign of the
        # difference; fall back to autoscaling when the difference is all zeros.
        limit = float(np.abs(diff_matrix).max())
        diff_limits = {'vmin': -limit, 'vmax': limit} if limit > 0 else {}
        im3 = axes[2].imshow(diff_matrix, cmap='RdBu_r', aspect='equal', **diff_limits)
        axes[2].set_title('Difference Matrix\n(Task 1 - Task 2)')
        fig.colorbar(im3, ax=axes[2])

        for ax in axes:
            ax.set_xlabel('Pixel X')
            ax.set_ylabel('Pixel Y')

        fig.suptitle(f'Task Similarity Analysis\nFrobenius Dot Product: {similarity_score:.4f}',
                     fontsize=16, y=1.02)
        fig.tight_layout()
        # The suptitle sits above the canvas (y=1.02); bbox_inches='tight' keeps it in the PNG.
        fig.savefig(save_dir / f'task_similarity_{control_tasks}.png', bbox_inches='tight')
    finally:
        # Close explicitly: this runs once per control setting and open figures accumulate otherwise.
        plt.close(fig)

    # --- Figure 2: intensity distributions and pixel-wise comparison ---
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    try:
        flat1 = m1.flatten()
        flat2 = m2.flatten()

        # Histogram of pixel values in mean matrices
        axes[0].hist(flat1, bins=50, alpha=0.7, label='Task 1', density=True)
        axes[0].hist(flat2, bins=50, alpha=0.7, label='Task 2', density=True)
        axes[0].set_xlabel('Pixel Intensity')
        axes[0].set_ylabel('Density')
        axes[0].set_title('Distribution of Mean Pixel Intensities')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Scatter plot of corresponding pixels
        axes[1].scatter(flat1, flat2, alpha=0.6, s=1)
        axes[1].set_xlabel('Task 1 Mean Pixel Intensity')
        axes[1].set_ylabel('Task 2 Mean Pixel Intensity')
        axes[1].set_title('Pixel-wise Correlation Between Tasks')
        # y = x reference spanning both tasks' ranges, not only Task 1's.
        lo = float(min(flat1.min(), flat2.min()))
        hi = float(max(flat1.max(), flat2.max()))
        axes[1].plot([lo, hi], [lo, hi], 'r--', alpha=0.8)
        axes[1].grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(save_dir / f'task_stats_{control_tasks}.png', bbox_inches='tight')
    finally:
        plt.close(fig)
    
    
    
def calculate_task_similarity_with_viz(dm_task1, dm_task2, task1_labels, task2_labels, save_dir, control_tasks):
    """
    Compare two tasks by their pixel-wise mean images and plot the result.

    Gathers every image from each data module, computes per-pixel mean/std,
    scores the pair with the Frobenius dot product of the two mean matrices,
    prints a short report and saves plots via ``visualize_mean_matrices``.

    Returns:
        tuple: (similarity, task1_mean, task1_std, task2_mean, task2_std)
    """
    print("Collecting data from Task 1...")
    images_first = collect_all_data(dm_task1)

    print("Collecting data from Task 2...")
    images_second = collect_all_data(dm_task2)

    print("Computing pixel-level statistics...")
    mean_first, std_first = compute_task_statistics(images_first)
    mean_second, std_second = compute_task_statistics(images_second)

    # Similarity score: Frobenius inner product of the two mean matrices.
    score = frobenius_dot_product(mean_first, mean_second)

    report = (
        "\nResults:",
        f"Task 1 ({task1_labels}) - Mean matrix shape: {mean_first.shape}",
        f"Task 2 ({task2_labels}) - Mean matrix shape: {mean_second.shape}",
        f"Frobenius dot product similarity: {score:.6f}",
        "\nAdditional Statistics:",
        f"Task 1 mean intensity: {mean_first.mean().item():.4f} ± {mean_first.std().item():.4f}",
        f"Task 2 mean intensity: {mean_second.mean().item():.4f} ± {mean_second.std().item():.4f}",
    )
    for line in report:
        print(line)

    visualize_mean_matrices(mean_first, mean_second, task1_labels, task2_labels, score, save_dir, control_tasks)

    return score, mean_first, std_first, mean_second, std_second





def main():
    """
    Run the task-similarity analysis for every MNIST control-task setting.

    For each setting in ["half", "round", "top", "equal"], the two label sets
    come from ``update_tasks_mnist``. One TaskDataModule is built and set up per
    task, and their similarity is computed and plotted with
    ``calculate_task_similarity_with_viz``. All scores are then written to
    ``task_similarity_results.yaml`` in
    ``<ROOT_DIR>/experiments/evaluations/task_similarity/MNIST``.
    """
    root = Path(ROOT_DIR)
    # Use pathlib joins: os.path.join(ROOT_DIR, "/data/") discards ROOT_DIR because
    # of the leading slash and would point at the filesystem-root /data/.
    data_dir = root / "data"
    save_dir = root / "experiments" / "evaluations" / "task_similarity" / "MNIST"
    save_dir.mkdir(parents=True, exist_ok=True)

    seed = 42
    dataset_label = "MNIST"
    bs = 64
    cluster = False
    num_workers = 1 if cluster else 20
    dataset_class = get_dataset(dataset_label)

    def build_data_module(selected_labels):
        """Create a TaskDataModule restricted to ``selected_labels`` (not yet set up)."""
        return TaskDataModule(
            dataset_class=dataset_class,
            batch_size=bs,
            selected_labels=selected_labels,
            data_dir=data_dir,
            seed=seed,
            num_workers=num_workers,
        )

    result = {}

    for control_tasks in ["half", "round", "top", "equal"]:
        task1_labels, task2_labels = update_tasks_mnist(control_tasks)

        dm_task1 = build_data_module(task1_labels)
        dm_task2 = build_data_module(task2_labels)
        dm_task1.setup()
        dm_task2.setup()

        # Mean/std matrices are also returned, but only the score goes into the YAML summary.
        similarity_score, *_ = calculate_task_similarity_with_viz(
            dm_task1, dm_task2, task1_labels, task2_labels, save_dir, control_tasks
        )

        result[control_tasks] = {
            # Key spelling kept unchanged so existing readers of the results file still find it.
            'Frobneus_dot_product': similarity_score,
            # list() so yaml.dump emits plain sequences rather than !!python/tuple tags.
            'task1_labels': list(task1_labels),
            'task2_labels': list(task2_labels),
        }

    with open(save_dir / "task_similarity_results.yaml", "w", encoding="utf-8") as f:
        yaml.dump(result, f)

if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()
