"""
Evaluation script for rule-based analysis of autoregressive trajectory
generation from a locally saved JUMP-DIFFUSION model.

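This script loads a trained model (``best_model.pt``) and its ``config.json``
from a local run directory, evaluates how well its autoregressive rollouts
adhere to the known rules of the data-generating process, and saves the
analysis and visualizations to ``<run_path>/autoregressive_evaluation``.
The test trajectories are read from
``data/test_data_size{image_size}_jump_prob0.3.pt``.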

Usage:
    python local_autoregressive_evaluation.py --run_path /path/to/local/run
"""
import torch
import numpy as np
import yaml
import matplotlib.pyplot as plt
from pathlib import Path
from types import SimpleNamespace
import argparse
import random
import json

# Import necessary functions from the existing project files
from data import JumpDataLoader, DataConfig
from network import UNetModel
from alternative_network import UNet
# --- Helper Functions (copied from other scripts for stand-alone use) ---

def set_seed(seed: int):
    """Seed Python's, NumPy's and PyTorch's RNGs (plus CUDA, if present) with `seed`."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    # Seed the current CUDA device and then every visible device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def load_json_config(config_path: str) -> SimpleNamespace:
    """
    Load a JSON config file and convert it to nested SimpleNamespaces.

    JSON objects become ``SimpleNamespace`` instances, so values are reached as
    attributes (e.g. ``config.model.memory_length``). JSON arrays become tuples
    (e.g. ``attention_resolutions``). Objects nested inside arrays are
    converted as well, so attribute access works at every level.

    Args:
        config_path: Path to the JSON file.

    Returns:
        The converted top-level JSON value.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config_dict = json.load(f)

    def dict_to_namespace(d):
        if isinstance(d, dict):
            return SimpleNamespace(**{k: dict_to_namespace(v) for k, v in d.items()})
        if isinstance(d, list):
            # Recurse into elements so objects inside arrays become namespaces too.
            return tuple(dict_to_namespace(v) for v in d)
        return d

    return dict_to_namespace(config_dict)

def binarize_top_k(batch_tensor, k=9):
    """
    For each image in a batch, set its k largest entries to 1 and the rest to 0.

    Handles both 4D [B, C, H, W] and 5D [B, T, C, H, W] tensors. A 5D tensor is
    binarized frame by frame and returned in its original shape.

    All entries tied with the k-th largest value are kept, so an image can end
    up with more than k ones. If every image in the batch is constant, there is
    no meaningful "top k" and the input is returned unchanged (in its original
    shape).

    Args:
        batch_tensor: 4D or 5D tensor of per-pixel scores (non-contiguous
            inputs are accepted).
        k: Number of entries to keep per image. It is clamped to the number of
            entries per image, so small images do not make ``topk`` fail.

    Returns:
        Float tensor of 0/1 values with the same shape as ``batch_tensor``
        (except in the all-constant case described above).
    """
    original_shape = batch_tensor.shape

    # Treat each frame of a 5D tensor as an independent image.
    if batch_tensor.dim() == 5:
        B, T, C, H, W = original_shape
        batch_tensor = batch_tensor.reshape(B * T, C, H, W)

    batch_size = batch_tensor.shape[0]

    # reshape (not view) so transposed / sliced inputs do not raise.
    flat = batch_tensor.reshape(batch_size, -1)

    if torch.all(flat == flat[:, :1]):
        if len(original_shape) == 5:
            return batch_tensor.reshape(original_shape)
        return batch_tensor

    k = min(k, flat.shape[1])
    thresholds = torch.topk(flat, k, dim=1).values[:, -1]
    # One threshold per image, broadcast over all remaining dimensions.
    thresholds = thresholds.view(batch_size, *([1] * (batch_tensor.dim() - 1)))
    result = (batch_tensor >= thresholds).float()

    # If the original input was 5D, restore its shape.
    if len(original_shape) == 5:
        result = result.view(original_shape)

    return result


def get_cube_center(binarized_frame):
    """
    Return the centroid of the active pixels of a frame as a tensor ``[x, y]``.

    The frame is squeezed, and the mean (row, col) of its nonzero entries is
    returned in (col, row) order. When there are no active pixels the sentinel
    ``[-1., -1.]`` is returned. The result lives on the frame's device.
    """
    dev = binarized_frame.device
    active = binarized_frame.squeeze().nonzero(as_tuple=False)
    if active.shape[0] == 0:
        return torch.tensor([-1., -1.], device=dev)
    centroid = active.float().mean(dim=0)
    row_c, col_c = centroid[0], centroid[1]
    return torch.tensor([col_c, row_c], device=dev)

def is_single_box(binarized_frame, cube_size=3):
    """
    Check that a single binarized frame contains exactly one filled square.

    Returns True only when the leading (batch) dimension is 1, the frame sums
    to ``cube_size**2``, it has exactly that many nonzero pixels, and their
    bounding box spans ``cube_size`` rows and ``cube_size`` columns.
    """
    area = cube_size * cube_size
    if binarized_frame.shape[0] != 1 or binarized_frame.sum() != area:
        return False
    coords = binarized_frame.squeeze().nonzero(as_tuple=False)
    if coords.shape[0] != area:
        return False
    # Exactly `area` pixels inside a cube_size x cube_size box => filled square.
    span = coords.max(dim=0).values - coords.min(dim=0).values
    return bool(span[0] == cube_size - 1 and span[1] == cube_size - 1)

def calculate_gt_jump_ratio_vectorized(test_set, image_size, device):
    """
    Return the percentage of ground-truth transitions that are vertical jumps.

    For every frame of ``test_set`` (shape [N, T, C, H, W], channel dim
    squeezed), the mass-weighted mean row index is computed. A transition
    between consecutive frames counts as a jump when that row centre moves by
    more than 2 pixels. ``image_size`` is unused. Returns 0 when there are no
    transitions.
    """
    n_traj, n_steps, _, height, _ = test_set.shape

    rows = torch.arange(height, device=device, dtype=torch.float32).view(1, 1, height, 1)
    frames = test_set.squeeze(2)
    # Clamp avoids division by zero for empty frames.
    mass = frames.sum(dim=[2, 3]).clamp(min=1e-6)
    centre_y = (frames * rows).sum(dim=[2, 3]) / mass

    dy = centre_y[:, 1:] - centre_y[:, :-1]
    jump_count = (dy.abs() > 2).float().sum().item()
    step_count = n_traj * (n_steps - 1)

    if step_count <= 0:
        return 0
    return (jump_count / step_count) * 100


def autoregressive_sampler(net, initial_frames, mem_len, config, device):
    """
    Roll out trajectories frame by frame with the jump-diffusion network and
    collect rule-adherence statistics along the way.

    Each new frame is produced by simulating ``config.model.disc_steps`` jump
    steps between times ``t_idx - 1`` and ``t_idx``. These steps are conditioned
    on the last ``mem_len`` binarized frames and start from the most recent
    one. The raw result is kept in the noisy trajectory. It is also binarized
    to ``cube_size**2`` pixels and appended to the memory window.

    Args:
        net: Network called as ``net(input, time)``. Its output is split along
            dim 1 into (log jump rate, mean, log std).
        initial_frames: Seed frames; written into ``[:, :mem_len]`` of the
            output trajectories (presumably [B, mem_len, 1, H, W]).
        mem_len: Number of conditioning frames kept in memory.
        config: Namespace providing ``data.image_size``, ``data.no_timesteps``,
            ``data.cube_size`` and ``model.disc_steps``.
        device: Device for the output buffers and time tensors.

    Returns:
        ``(full_gen_traj, stats, full_gen_traj_noisy)``:
        - ``full_gen_traj``: the binarized trajectory.
        - ``stats``: a dict of rule-following counters.
        - ``full_gen_traj_noisy``: the raw (pre-binarization) trajectory.
        Both trajectories have shape
        [B, no_timesteps, 1, image_size, image_size].
    """
    net.eval()
    num_samples = initial_frames.shape[0]
    image_size = config.data.image_size
    disc_steps = config.model.disc_steps
    cube_size = config.data.cube_size
    # Keep as many pixels as a well-formed cube has, so that binarization
    # agrees with is_single_box for any cube size (9 for the default 3x3).
    num_cube_pixels = cube_size * cube_size

    full_gen_traj = torch.zeros(
        num_samples, config.data.no_timesteps, 1, image_size, image_size, device=device
    )
    full_gen_traj_noisy = full_gen_traj.clone()
    full_gen_traj[:, :mem_len] = initial_frames
    full_gen_traj_noisy[:, :mem_len] = initial_frames
    current_memory = binarize_top_k(initial_frames, k=num_cube_pixels)

    stats = {
        'x_direction_correct': 0, 'x_direction_total': 0,
        'determinable_wrong_jumps': 0, 'total_generated_jumps': 0,
        'well_formed_cubes': 0,
    }

    with torch.no_grad():
        for t_idx in range(mem_len, config.data.no_timesteps):
            current_memory = current_memory.contiguous()
            # The jump process starts from the most recent memory frame.
            x = current_memory[:, -1]
            # Drop the singleton channel dim so memory frames stack as channels.
            mem_reshaped = current_memory.view(num_samples, mem_len, image_size, image_size)

            t1_val, t2_val = t_idx - 1, t_idx

            for i in range(disc_steps):
                t_val = t1_val + (i / disc_steps) * (t2_val - t1_val)
                t = t_val * torch.ones(num_samples, device=device)
                t2 = t2_val * torch.ones(num_samples, device=device)
                h = 1 / disc_steps

                # NOTE(review): t2 and t are concatenated along dim 0 (length
                # 2 * num_samples) while the input batch has num_samples rows;
                # presumably the network splits this into two time embeddings — confirm.
                time_in = torch.cat([t2, t], dim=0)
                # (B, 1 + mem_len, H, W): current state followed by the memory frames.
                inpu = torch.cat([x, mem_reshaped], dim=1)

                out = net(inpu, time_in)
                lambda_net_raw, mean_net, sig_net_raw = out.split(1, dim=1)
                lambda_net = torch.exp(lambda_net_raw)
                sig_net = torch.exp(sig_net_raw)

                # Proposal z ~ N(mean, sig^2). Each element jumps to z with
                # probability 1 - exp(-lambda * h) and otherwise keeps its value.
                # (The randn/bernoulli order is kept fixed for reproducibility.)
                z = mean_net + sig_net * torch.randn_like(mean_net)
                rt = torch.exp(-lambda_net * h).clamp(0, 1)
                m = torch.bernoulli(1 - rt)
                x = (1 - m) * x + m * z

            next_frame = x
            full_gen_traj_noisy[:, t_idx] = next_frame
            next_frame_binarized = binarize_top_k(next_frame, k=num_cube_pixels)
            full_gen_traj[:, t_idx] = next_frame_binarized

            # Batched statistics would be complex here, so loop over the batch for clarity.
            for sample_idx in range(num_samples):
                sample_memory = current_memory[sample_idx:sample_idx+1]
                sample_next_frame = next_frame_binarized[sample_idx:sample_idx+1]

                if is_single_box(sample_next_frame, cube_size=cube_size):
                    stats['well_formed_cubes'] += 1

                    pos_t_minus_1 = get_cube_center(sample_memory[:, -1])
                    pos_t_new = get_cube_center(sample_next_frame)

                    # X-direction rule. Expected direction = last x displacement in
                    # memory, forced to +1 near x < 2 and -1 near the right border.
                    # It needs two remembered frames, so skip it when mem_len < 2
                    # (sample_memory[:, -2] would raise IndexError).
                    if mem_len >= 2:
                        pos_t_minus_2 = get_cube_center(sample_memory[:, -2])
                        if pos_t_minus_2[0] != -1 and pos_t_minus_1[0] != -1:
                            expected_x_dir = torch.sign(pos_t_minus_1[0] - pos_t_minus_2[0])
                            if pos_t_minus_1[0] < 2:
                                expected_x_dir = 1
                            elif pos_t_minus_1[0] > image_size - cube_size:
                                expected_x_dir = -1
                            actual_x_dir = torch.sign(pos_t_new[0] - pos_t_minus_1[0])
                            if expected_x_dir != 0:
                                stats['x_direction_total'] += 1
                                if actual_x_dir == expected_x_dir:
                                    stats['x_direction_correct'] += 1

                    # Y-jump rule. A move with |dy| > 2 counts as a jump. Its
                    # expected direction is that of the latest jump in memory,
                    # forced to +1 / -1 when that jump ended near the top / bottom.
                    is_new_move_a_jump = abs(pos_t_new[1] - pos_t_minus_1[1]) > 2
                    if is_new_move_a_jump:
                        stats['total_generated_jumps'] += 1
                        new_jump_dir = torch.sign(pos_t_new[1] - pos_t_minus_1[1])
                        last_jump_dir_in_memory, last_jump_y_pos = None, -1
                        # Scan memory from newest to oldest transition.
                        for j in range(mem_len - 1, 0, -1):
                            p1 = get_cube_center(sample_memory[:, j-1])
                            p2 = get_cube_center(sample_memory[:, j])
                            if abs(p2[1] - p1[1]) > 2:
                                last_jump_dir_in_memory, last_jump_y_pos = torch.sign(p2[1] - p1[1]), p2[1]
                                break
                        # Only jumps whose expected direction is determinable are scored.
                        if last_jump_dir_in_memory is not None:
                            expected_jump_dir = last_jump_dir_in_memory
                            if last_jump_y_pos < 4:
                                expected_jump_dir = 1
                            elif last_jump_y_pos > image_size - 2 - cube_size:
                                expected_jump_dir = -1
                            if new_jump_dir != expected_jump_dir:
                                stats['determinable_wrong_jumps'] += 1

            # Slide the memory window: drop the oldest frame, append the new one.
            current_memory = torch.cat([current_memory[:, 1:], next_frame_binarized.unsqueeze(1)], dim=1)

    return full_gen_traj, stats, full_gen_traj_noisy

def visualize_autoregressive_trajectories(gen_traj, gt_traj, idx, output_dir):
    """
    Save a two-row figure: ``gt_traj`` on the top row, ``gen_traj`` below.

    Args:
        gen_traj: Generated trajectory. Frame ``t`` is read as
            ``gen_traj[0, t, 0]`` (presumably shape [1, T, 1, H, W]).
        gt_traj: Reference trajectory indexed the same way. It must have at
            least as many frames as ``gen_traj``.
        idx: Sample index used in the title and the file name.
        output_dir: ``pathlib.Path`` directory the PNG is written to.
    """
    timesteps = gen_traj.shape[1]
    # squeeze=False keeps `axes` 2D even when there is a single timestep.
    fig, axes = plt.subplots(2, timesteps, figsize=(timesteps * 1.5, 3.5), squeeze=False)
    fig.suptitle(f"Jump Autoregressive Generation vs. GT (Sample {idx})", fontsize=16)
    for t in range(timesteps):
        for row, traj in enumerate((gt_traj, gen_traj)):
            ax = axes[row, t]
            ax.imshow(traj[0, t, 0].detach().cpu(), cmap="gray")
            ax.axis("off")
    # Keep the subplot grid below the suptitle. A top bound > 1 lets the first
    # row extend past the figure edge and cover the title.
    fig.tight_layout(rect=[0, 0, 1, 0.93])
    fig_path = output_dir / f"jump_autoregressive_comparison_{idx}.png"
    fig.savefig(fig_path)
    print(f"Saved visualization to: {fig_path}")
    plt.close(fig)

def main(run_path: str):
    """
    Run the rule-based autoregressive evaluation for one local training run.

    Steps:
    - Read ``config.json`` and ``best_model.pt`` from ``run_path``.
    - Read the test trajectories from
      ``data/test_data_size{image_size}_jump_prob0.3.pt``.
    - Generate trajectories with ``autoregressive_sampler``.
    - Save up to 10 comparison figures.
    - Write the rule-following statistics to
      ``run_path/autoregressive_evaluation/jump_autoregressive_stats.txt``.

    A missing input or an unsupported ``config.model.network`` value is
    reported with an ``ERROR:`` message, and the function returns early.

    Args:
        run_path: Path to the local run directory.
    """
    run_path = Path(run_path)
    if not run_path.exists():
        print(f"ERROR: Provided run path does not exist: {run_path}")
        return

    print(f"--- Starting rule-based autoregressive evaluation for local run: {run_path.name} ---")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    output_dir = run_path / "autoregressive_evaluation"
    output_dir.mkdir(parents=True, exist_ok=True)

    print("Loading configuration from run directory...")
    config_path = run_path / "config.json"
    if not config_path.exists():
        print(f"ERROR: config.json not found in {run_path}")
        return
    config = load_json_config(str(config_path))
    set_seed(config.seed)

    print("Loading test dataset (full trajectories)...")
    test_data_path = Path("data") / f"test_data_size{config.data.image_size}_jump_prob0.3.pt"
    if not test_data_path.exists():
        print(f"ERROR: Test data not found at {test_data_path}.")
        return
    # NOTE: torch.load unpickles the file; only point this at trusted local data.
    test_set = torch.load(test_data_path).to(device)

    print("Loading model from local file...")
    model_path = run_path / "best_model.pt"
    if not model_path.exists():
        print(f"ERROR: best_model.pt not found in {run_path}")
        return

    # in_channels = memory_length + 1: the sampler feeds the current state
    # concatenated with the memory frames.
    if config.model.network == "oai":
        net = UNetModel(
            image_size=config.data.image_size,
            in_channels=config.model.memory_length + 1,
            model_channels=config.model.model_channels,
            out_channels=3,
            num_res_blocks=config.model.num_res_blocks,
            attention_resolutions=tuple(config.model.attention_resolutions),
            channel_mult=tuple(config.model.channel_mult),
            dropout=config.training.dropout,
            use_scale_shift_norm=True,
        ).to(device)
    elif config.model.network == "simple":
        net = UNet(
            in_channels=config.model.memory_length + 1,
            out_channels=3,
            base_channels=config.model.unet_base_channels,
        ).to(device)
    else:
        # Without this branch `net` would be unbound below.
        print(f"ERROR: Unsupported network type '{config.model.network}' in config.json")
        return
    net.load_state_dict(torch.load(model_path, map_location=device))
    print("Model loaded successfully.")

    inferred_mem_len = config.model.memory_length

    print("Calculating ground truth jump ratio...")
    gt_jump_ratio = calculate_gt_jump_ratio_vectorized(test_set, config.data.image_size, device)

    # Never request more samples than the test set holds. The report's
    # denominators are computed from this number, not from the actual batch size.
    num_samples_to_analyze = min(512, test_set.shape[0])
    print(f"Generating and analyzing {num_samples_to_analyze} autoregressive samples in a single batch...")

    # Select a batch of random starting trajectories.
    rand_indices = torch.randperm(test_set.shape[0])[:num_samples_to_analyze]
    gt_trajectories_batch = test_set[rand_indices]
    initial_frames_batch = gt_trajectories_batch[:, :inferred_mem_len]

    # Generate all trajectories in one batched call.
    generated_trajectories_batch, total_stats, gen_noisy = autoregressive_sampler(
        net, initial_frames_batch, inferred_mem_len, config, device
    )

    # Visualize the first 10 samples.
    # NOTE(review): the top row shows the raw (pre-binarization) generated
    # frames, not the ground truth, even though the figure title says "GT".
    for i in range(min(10, num_samples_to_analyze)):
        visualize_autoregressive_trajectories(
            generated_trajectories_batch[i:i+1],
            gen_noisy[i:i+1],
            i,
            output_dir,
        )

    report_lines = ["--- Jump-Diffusion Autoregressive Rule-Following Statistics ---"]
    gen_total_steps = num_samples_to_analyze * (config.data.no_timesteps - inferred_mem_len)
    if gen_total_steps > 0:
        well_formed_ratio = (total_stats['well_formed_cubes'] / gen_total_steps) * 100
        report_lines.append(f"Well-Formed Cube Ratio: {well_formed_ratio:.2f}%")
    gen_jump_ratio = (total_stats['total_generated_jumps'] / gen_total_steps) * 100 if gen_total_steps > 0 else 0
    report_lines.append(f"Generated Y-Jump Ratio: {gen_jump_ratio:.2f}% (Ground Truth is ~{gt_jump_ratio:.2f}%)")
    if total_stats['x_direction_total'] > 0:
        x_acc = (total_stats['x_direction_correct'] / total_stats['x_direction_total']) * 100
        report_lines.append(f"X-Direction Accuracy (based on memory): {x_acc:.2f}%")
    else:
        report_lines.append("X-Direction Accuracy: Not enough motion in memory to evaluate.")
    if total_stats['total_generated_jumps'] > 0:
        wrong_jump_ratio = (total_stats['determinable_wrong_jumps'] / total_stats['total_generated_jumps']) * 100
        report_lines.append(f"Determinable Wrong Jumps (of all generated jumps): {wrong_jump_ratio:.2f}%")
    else:
        report_lines.append("Determinable Wrong Jumps: No jumps were generated.")
    report_lines.append("------------------------------------------------")

    report = "\n".join(report_lines)
    print("\n" + report)
    stats_file_path = output_dir / "jump_autoregressive_stats.txt"
    with open(stats_file_path, 'w', encoding='utf-8') as f:
        f.write(report)
    print(f"\nSaved statistics to: {stats_file_path}")
    print("\nAutoregressive evaluation for Jump-Diffusion model complete.")

if __name__ == "__main__":
    # Command-line entry point: evaluate the run directory given by --run_path.
    cli = argparse.ArgumentParser(description="Evaluate a locally saved Jump-Diffusion model.")
    cli.add_argument(
        "--run_path",
        type=str,
        required=True,
        help="Path to the local run directory (e.g., local_runs/jump_model/...).",
    )
    main(cli.parse_args().run_path)


