import os
import re
import time
import torch
import torch.nn as nn
from tqdm import tqdm
import logging
import random
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F


# -------------------------- Logging Setup --------------------------
def setup_logging(cfg):
    """Configure root logging for an experiment and return its logger and id.

    The experiment id is built from ``cfg.EBA_STEPS``, ``cfg.HIDDEN_DIM`` and
    the integer part of ``cfg.LAMBDA_INIT``. Log records go both to
    ``<cfg.LOG_DIR>/train_<exp_id>.log`` and to the default stream handler.

    Args:
        cfg: Object exposing ``EBA_STEPS``, ``HIDDEN_DIM``, ``LAMBDA_INIT``
            and ``LOG_DIR`` attributes.

    Returns:
        Tuple ``(logger, exp_id)`` where ``logger`` is the ``"TLAD-Maze"``
        logger and ``exp_id`` is the experiment identifier string.
    """
    exp_id = (
        f"maze_s{cfg.EBA_STEPS}"
        f"_d{cfg.HIDDEN_DIM}"
        f"_l{int(cfg.LAMBDA_INIT)}"
    )
    log_dir = cfg.LOG_DIR
    os.makedirs(log_dir, exist_ok=True)

    root = logging.getLogger()
    # Iterate over a copy: removing from root.handlers while iterating it
    # skips every other handler, and basicConfig() is a no-op whenever any
    # handler is still attached to the root logger.
    for handler in list(root.handlers):
        root.removeHandler(handler)
        handler.close()  # release file descriptors held by old FileHandlers

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler(os.path.join(log_dir, f"train_{exp_id}.log"), encoding="utf-8"),
            logging.StreamHandler(),
        ],
    )
    return logging.getLogger("TLAD-Maze"), exp_id


# -------------------------- Checkpoint Save/Load --------------------------
def save_best_checkpoint(model, optimizer, epoch, metrics, cfg, exp_id):
    """Save the current best model state to ``<cfg.SAVE_DIR>/<exp_id>_best.pth``.

    The checkpoint is first written to a temporary sibling file and then
    moved over the destination, so an interrupted save cannot leave a
    truncated "best" checkpoint behind.

    Args:
        model: Model to save; an ``nn.DataParallel`` wrapper is unwrapped so
            the stored keys carry no ``module.`` prefix.
        optimizer: Optimizer whose state is stored, or ``None`` to store
            ``None`` instead.
        epoch: Epoch value recorded in the checkpoint.
        metrics: Metrics object recorded in the checkpoint.
        cfg: Config object exposing ``SAVE_DIR``; its class name is recorded
            under ``"config_name"``.
        exp_id: Experiment identifier used as the file-name prefix.

    Returns:
        Path of the written checkpoint file.
    """
    checkpoint_dir = cfg.SAVE_DIR
    os.makedirs(checkpoint_dir, exist_ok=True)

    raw_model = model.module if isinstance(model, nn.DataParallel) else model

    checkpoint = {
        "epoch": epoch,
        "model_state_dict": raw_model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
        "metrics": metrics,
        "config_name": cfg.__class__.__name__,
    }

    save_path = os.path.join(checkpoint_dir, f"{exp_id}_best.pth")
    tmp_path = f"{save_path}.tmp"
    try:
        torch.save(checkpoint, tmp_path)
        # os.replace swaps the file in one step, so the previous best
        # checkpoint stays intact until the new one is fully written.
        os.replace(tmp_path, save_path)
    finally:
        # Only present if the save or the replace failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    return save_path

def load_checkpoint(model, load_path, device):
    """Restore model weights from a checkpoint file.

    Keys in the stored ``"model_state_dict"`` that start with ``module.``
    have that prefix dropped before loading, and the weights are loaded into
    the wrapped module when ``model`` is an ``nn.DataParallel``.

    Args:
        model: Target model, optionally wrapped in ``nn.DataParallel``.
        load_path: Path of the checkpoint file.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        Tuple ``(epoch, metrics)`` read from the checkpoint, defaulting to
        ``0`` and ``{}`` when those entries are absent.

    Raises:
        FileNotFoundError: If ``load_path`` does not exist.
    """
    if os.path.exists(load_path):
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        payload = torch.load(load_path, map_location=device)
    else:
        raise FileNotFoundError(f"Checkpoint not found: {load_path}")

    prefix = "module."
    weights = {}
    for name, tensor in payload["model_state_dict"].items():
        bare_name = name[len(prefix):] if name.startswith(prefix) else name
        weights[bare_name] = tensor

    target = model.module if isinstance(model, nn.DataParallel) else model
    target.load_state_dict(weights)

    saved_epoch = payload.get("epoch", 0)
    saved_metrics = payload.get("metrics", {})
    return saved_epoch, saved_metrics


# -------------------------- Maze-Specific Evaluation --------------------------
def validate_maze(model, loader, device, is_test=False):
    """Evaluate per-cell and whole-maze accuracy of ``model`` over ``loader``.

    Each batch is a mapping with ``'puzzle'`` and ``'solution'`` tensors. The
    model is called as ``model(puzzle, mode='reasoning')`` and must return a
    mapping whose ``'logits'`` entry is arg-maxed over its last dimension to
    obtain predictions.

    Args:
        model: Model to evaluate; it is switched to eval mode and left there.
        loader: Iterable of batches.
        device: Device the batch tensors are moved to.
        is_test: Only selects the progress-bar label ("Testing" vs
            "Validating").

    Returns:
        Tuple ``(pixel_accuracy, maze_accuracy)``: the fraction of correctly
        predicted cells and the fraction of samples whose cells are all
        correct. Both are ``0.0`` when ``loader`` yields no data.
    """
    model.eval()
    correct_pixels = 0
    total_pixels = 0
    correct_mazes = 0
    total_mazes = 0

    desc = "Testing" if is_test else "Validating"

    with torch.no_grad():
        for batch in tqdm(loader, desc=desc, leave=False):
            p = batch['puzzle'].to(device, non_blocking=True)
            s = batch['solution'].to(device, non_blocking=True)

            out = model(p, mode='reasoning')
            preds = out['logits'].argmax(dim=-1)

            hits = preds == s  # computed once, reused for both metrics
            correct_pixels += hits.sum().item()
            total_pixels += s.numel()

            # Flatten everything after the batch axis so a sample counts as
            # solved only if *all* of its cells match, whether the solution
            # is laid out as (B, L) or as (B, N, N).
            correct_mazes += hits.flatten(start_dim=1).all(dim=1).sum().item()
            total_mazes += p.size(0)

    # An empty loader would otherwise end in a ZeroDivisionError.
    if total_pixels == 0 or total_mazes == 0:
        return 0.0, 0.0
    return correct_pixels / total_pixels, correct_mazes / total_mazes


# ===================== Visualization: Same as Phase 1 =====================
def _draw_maze_panel(ax, maze_grid, title, color_map, path_grid=None, path_color=None):
    """Render one maze panel: walls, an optional path overlay and start/end markers."""
    img = np.full(maze_grid.shape + (3,), color_map["passage"], dtype=np.float32)
    img[maze_grid == 1] = color_map["wall"]
    if path_grid is not None and path_color is not None:
        img[path_grid == 1] = path_color

    ax.imshow(img, origin='upper', interpolation='nearest')

    # Cells equal to 2 / 3 are drawn as the start / end markers.
    start_r, start_c = np.where(maze_grid == 2)
    end_r, end_c = np.where(maze_grid == 3)
    if len(start_r) > 0:
        ax.scatter(start_c[0], start_r[0], color=color_map["start"], s=250, edgecolor='black', zorder=100)
    if len(end_r) > 0:
        ax.scatter(end_c[0], end_r[0], color=color_map["end"], s=250, edgecolor='black', zorder=100)

    ax.set_title(title, color='black', fontsize=16, pad=10)
    ax.axis('on')
    ax.tick_params(color='black', labelcolor='black')
    for spine in ax.spines.values():
        spine.set_color('black')
        spine.set_linewidth(1)


def plot_maze_2sample(model, loader, cfg, phase_name, epoch, save_dir):
    """Save side-by-side "truth vs. prediction" figures for the first two samples.

    For phase names starting with ``"P2"`` the model is called with
    ``mode='reasoning'`` and its ``'logits'`` entry is used; otherwise it is
    called with ``mode='pretrain'`` and its return value is used as logits.
    A cell is drawn as predicted path when the softmax probability of class
    index 1 exceeds 0.5 and the cell is not a wall (puzzle value 1).

    Figures are written to
    ``<save_dir>/<phase_name>_E<epoch>_S<k>.png`` with ``k`` in ``{1, 2}``.
    Iteration over ``loader`` stops as soon as two samples have been drawn.

    Args:
        model: Model to visualise; left in train mode on return.
        loader: Iterable of batches with ``'puzzle'`` and ``'solution'``.
        cfg: Config exposing ``GRID_SIZE`` and ``DEVICE``.
        phase_name: Training-phase label; selects the forward mode and
            prefixes the file names.
        epoch: Epoch value embedded in the file names.
        save_dir: Output directory (created if missing).
    """
    torch.cuda.empty_cache()
    model.eval()
    COLOR_MAP = {
        "wall": [0.2, 0.2, 0.2],
        "passage": [1.0, 1.0, 1.0],
        "true_path": [0.2, 0.7, 0.9],
        "pred_path": [1.0, 1.0, 0.0],
        "start": "green",
        "end": "red"
    }
    N = cfg.GRID_SIZE
    max_samples = 2
    sample_count = 0

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        for batch in loader:
            puzzle = batch['puzzle'].to(cfg.DEVICE)
            solution = batch['solution'].to(cfg.DEVICE)

            if phase_name.startswith("P2"):
                outputs = model(puzzle, mode='reasoning')
                logits = outputs['logits']
            else:
                logits = model(puzzle, mode='pretrain')

            pred_probs = F.softmax(logits, dim=-1)[..., 1]
            pred_path = (pred_probs > 0.5).float()

            remaining = max_samples - sample_count
            for b_idx in range(min(puzzle.shape[0], remaining)):
                sample_count += 1

                m_grid = puzzle[b_idx].cpu().numpy().reshape(N, N)
                t_grid = solution[b_idx].cpu().numpy().reshape(N, N)
                p_grid = pred_path[b_idx].cpu().numpy().reshape(N, N)
                p_grid[m_grid == 1] = 0  # never paint a prediction on a wall

                fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8), facecolor=COLOR_MAP["passage"])

                _draw_maze_panel(ax1, m_grid, "Left: True Solution (Light Blue)",
                                 COLOR_MAP, t_grid, COLOR_MAP["true_path"])
                _draw_maze_panel(ax2, m_grid, "Right: Prediction Only (Yellow)",
                                 COLOR_MAP, p_grid, COLOR_MAP["pred_path"])

                fig.tight_layout()
                save_path = os.path.join(save_dir, f"{phase_name}_E{epoch}_S{sample_count}.png")
                fig.savefig(save_path, facecolor=COLOR_MAP["passage"], bbox_inches='tight', dpi=100)
                plt.close(fig)  # close this figure explicitly, not "the current one"

            # Stop here: without this the loop kept running forward passes
            # over the whole loader after both figures were already saved.
            if sample_count >= max_samples:
                break

    model.train()


# ===================== Plot Loss Components =====================
def plot_loss_components(history, save_path):
    """Plot training-loss curves and validation accuracies side by side.

    Args:
        history: Mapping with the series ``'epoch'``, ``'total'``, ``'L_ce'``,
            ``'L_dice'``, ``'L_degree'``, ``'val_pixel'`` and ``'val_maze'``.
        save_path: Destination of the rendered figure.
    """
    epochs = history['epoch']
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: total loss emphasised, individual components dashed.
    loss_ax.plot(epochs, history['total'], label='Total Loss', linewidth=2, color='black')
    component_series = (
        ('L_ce', 'CE Loss'),
        ('L_dice', 'Dice Loss'),
        ('L_degree', 'Degree Loss'),
    )
    for key, label in component_series:
        loss_ax.plot(epochs, history[key], label=label, linestyle='--')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Phase 1: Training Loss Components')
    loss_ax.legend()
    loss_ax.grid(True, linestyle=':', alpha=0.7)

    # Right panel: validation accuracies with point markers.
    accuracy_series = (
        ('val_pixel', 'Val Pixel Acc', 'o'),
        ('val_maze', 'Val Maze Acc', 's'),
    )
    for key, label, marker in accuracy_series:
        acc_ax.plot(epochs, history[key], label=label, marker=marker)
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.set_title('Phase 1: Validation Accuracy')
    acc_ax.legend()
    acc_ax.grid(True, linestyle=':', alpha=0.7)

    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close(fig)


# ===================== Measure Inference Time =====================
def measure_inference_time(model, test_loader, device, num_warmup=5, num_measure=20):
    """Time ``model(puzzle, mode='reasoning')`` over batches of ``test_loader``.

    Up to ``num_warmup`` untimed batches are run first, then up to
    ``num_measure`` batches are timed with ``time.perf_counter``. Only the
    forward call is timed; moving the batch to ``device`` is excluded.

    Args:
        model: Model to benchmark; switched to eval mode and left there.
        test_loader: Iterable of batches with a ``'puzzle'`` tensor.
        device: Target device as a string or ``torch.device``.
        num_warmup: Maximum number of untimed warm-up batches.
        num_measure: Maximum number of timed batches.

    Returns:
        Tuple ``(avg_ms_per_sample, samples_per_second)``. Both are ``0.0``
        when no batch was measured.
    """
    model.eval()
    total_time = 0.0
    total_samples = 0

    # Compare on the device *type*: a plain `device == 'cuda'` is false for
    # 'cuda:0' and for torch.device objects, which silently skipped the
    # synchronisation and produced meaningless GPU timings.
    sync_cuda = (
        isinstance(device, (str, torch.device))
        and torch.device(device).type == 'cuda'
    )

    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            if i >= num_warmup:
                break
            puzzle = batch['puzzle'].to(device, non_blocking=True)
            model(puzzle, mode='reasoning')
            if sync_cuda:
                torch.cuda.synchronize()

    with torch.no_grad():
        loop = tqdm(test_loader, desc="Measuring inference time")
        for i, batch in enumerate(loop):
            if i >= num_measure:
                break
            puzzle = batch['puzzle'].to(device, non_blocking=True)
            batch_size = puzzle.shape[0]
            total_samples += batch_size

            # Drain pending GPU work (including the async copy above) so the
            # timer brackets only the forward pass.
            if sync_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            model(puzzle, mode='reasoning')
            if sync_cuda:
                torch.cuda.synchronize()
            end = time.perf_counter()

            total_time += (end - start)

    # Empty loader or num_measure == 0: avoid dividing by zero.
    if total_samples == 0 or total_time <= 0.0:
        return 0.0, 0.0

    avg_time_per_sample = (total_time / total_samples) * 1000
    throughput = total_samples / total_time
    return avg_time_per_sample, throughput


# def get_valid_path_region(inputs: torch.Tensor, grid_size: int) -> torch.Tensor:
#     B, L = inputs.shape
#     N = int(grid_size) 
#     assert L == N * N, f"Expected {N*N}, got {L}"
    
#     device = inputs.device
#     grid = inputs.view(B, N, N)
#     is_wall = (grid == 1)
#     is_start = (grid == 2)
#     is_end = (grid == 3)
#     traversable = ~is_wall

#     # Kernel for 4-connectivity
#     kernel = torch.tensor([[[[0, 1, 0],
#                              [1, 0, 1],
#                              [0, 1, 0]]]], dtype=torch.float32, device=device)

#     # === Step 1: Compute dist from start ===
#     dist_S = torch.full((B, N, N), float('inf'), device=device)
#     dist_S[is_start] = 0.0

#     # Iterative distance propagation (BFS via convolution)
#     max_steps = int(0.5 * N * N)
#     for _ in range(max_steps):
#         padded = F.pad(dist_S.unsqueeze(1), (1,1,1,1), value=float('inf'))
#         neighbors = torch.cat([
#             padded[:, :, :-2, 1:-1],
#             padded[:, :, 2:, 1:-1],
#             padded[:, :, 1:-1, :-2],
#             padded[:, :, 1:-1, 2:],
#         ], dim=1)
#         min_neighbor = torch.min(neighbors, dim=1).values
#         new_dist = torch.minimum(dist_S, min_neighbor + 1.0)
#         updated = torch.where(traversable, new_dist, torch.tensor(float('inf'), device=device))
#         if torch.equal(dist_S, updated):
#             break
#         dist_S = updated

#     # === Step 2: Compute dist from end ===
#     dist_E = torch.full((B, N, N), float('inf'), device=device)
#     dist_E[is_end] = 0.0

#     for _ in range(max_steps):
#         padded = F.pad(dist_E.unsqueeze(1), (1,1,1,1), value=float('inf'))
#         neighbors = torch.cat([
#             padded[:, :, :-2, 1:-1],
#             padded[:, :, 2:, 1:-1],
#             padded[:, :, 1:-1, :-2],
#             padded[:, :, 1:-1, 2:],
#         ], dim=1)
#         min_neighbor = torch.min(neighbors, dim=1).values
#         new_dist = torch.minimum(dist_E, min_neighbor + 1.0)
#         updated = torch.where(traversable, new_dist, torch.tensor(float('inf'), device=device))
#         if torch.equal(dist_E, updated):
#             break
#         dist_E = updated

#     # === Step 3: Extract path via distance sum ===
#     # Total shortest distance from S to E
#     total_dist = dist_S[is_end]  # (num_end_points,) — but we need per-batch

#     # Since each maze has exactly one end, reshape
#     total_dist = total_dist.view(B, -1).min(dim=1).values  # (B,), handle potential multi-end safely

#     # Broadcast to (B, N, N)
#     total_dist = total_dist.view(B, 1, 1)

#     # Path condition: dist_S + dist_E == total_dist
#     on_path_2d = (torch.abs(dist_S + dist_E - total_dist) < 1e-3) & traversable

#     return on_path_2d.view(B, L)


# Cache: raw bytes of one puzzle row -> its precomputed valid-region tensor.
__precomputed_valid_map = {}


def _puzzle_cache_keys(puzzles: torch.Tensor) -> list:
    """Return one ``bytes`` cache key per sample, using a single host copy."""
    host = puzzles.detach().cpu().numpy()
    return [row.tobytes() for row in host]


def _bfs_grid_distance(seed_mask: torch.Tensor, traversable: torch.Tensor,
                       max_steps: int) -> torch.Tensor:
    """Return 4-connected step distances from ``seed_mask`` over ``traversable`` cells.

    Both masks have the same 3-D shape (batch first). Unreached and
    non-traversable cells hold ``inf``.
    """
    inf_val = float('inf')
    dist = torch.full(seed_mask.shape, inf_val, device=seed_mask.device)
    dist[seed_mask] = 0.0
    blocked = ~traversable

    for _ in range(max_steps):
        # Each sweep relaxes every cell against its four neighbours, so the
        # frontier advances exactly one step per iteration.
        padded = F.pad(dist.unsqueeze(1), (1, 1, 1, 1), value=inf_val)
        neighbors = torch.cat([
            padded[:, :, :-2, 1:-1],
            padded[:, :, 2:, 1:-1],
            padded[:, :, 1:-1, :-2],
            padded[:, :, 1:-1, 2:],
        ], dim=1)
        relaxed = torch.minimum(dist, neighbors.min(dim=1).values + 1.0)
        relaxed = relaxed.masked_fill(blocked, inf_val)
        if torch.equal(dist, relaxed):
            break  # converged: nothing left to propagate
        dist = relaxed
    return dist


def register_precomputed_valid_regions(puzzles: torch.Tensor, valid_regions: torch.Tensor):
    """
    Internal use only. Register precomputed valid regions for acceleration.

    ``valid_regions[i]`` is stored under a key derived from the raw bytes of
    ``puzzles[i]`` and is later returned by ``get_valid_path_region`` for an
    input row with identical bytes.
    """
    for i, key in enumerate(_puzzle_cache_keys(puzzles)):
        __precomputed_valid_map[key] = valid_regions[i]


def clear_precomputed_valid_regions():
    """Drop every registered precomputed valid region."""
    __precomputed_valid_map.clear()


def get_valid_path_region(inputs: torch.Tensor, grid_size: int) -> torch.Tensor:
    """Return a mask of the cells lying on a shortest start-to-end path.

    Cell values are interpreted as: ``1`` wall, ``2`` start, ``3`` end; every
    non-wall cell is traversable with 4-connectivity. A cell is selected when
    its distance from the start plus its distance to the end equals the
    shortest start-to-end distance.

    If every row of ``inputs`` has a registered precomputed region, those
    are stacked and returned instead of being recomputed.

    Args:
        inputs: Tensor of shape ``(B, grid_size * grid_size)``.
        grid_size: Side length ``N`` of the square grid.

    Returns:
        Tensor of shape ``(B, N * N)``; for computed results it is boolean.
        Rows whose maze has no end cell, or whose end is unreachable from the
        start, are all ``False``.
    """
    B, L = inputs.shape
    N = int(grid_size)
    assert L == N * N, f"Expected {N*N}, got {L}"

    # Fast path: use the cache only when *every* sample is registered.
    cached = [__precomputed_valid_map.get(key) for key in _puzzle_cache_keys(inputs)]
    if cached and all(region is not None for region in cached):
        try:
            return torch.stack(cached, dim=0).to(inputs.device)
        except RuntimeError:
            # Cached entries could not be stacked (e.g. mismatched shapes or
            # devices); fall back to computing the regions.
            pass

    grid = inputs.reshape(B, N, N)
    traversable = grid != 1
    is_start = grid == 2
    is_end = grid == 3

    # A shortest path visits each cell at most once, so N*N sweeps always
    # suffice. The previous cap of N*N/2 was too small for long winding
    # corridors and produced an all-False mask for them. The convergence
    # check inside the BFS keeps the common case cheap.
    max_steps = N * N
    dist_S = _bfs_grid_distance(is_start, traversable, max_steps)
    dist_E = _bfs_grid_distance(is_end, traversable, max_steps)

    through = dist_S + dist_E
    # Per-sample shortest length, taken over that sample's end cells only.
    # Masking (rather than boolean-indexing then reshaping) stays correct
    # when samples have different numbers of end cells, or none at all.
    total_dist = through.masked_fill(~is_end, float('inf')).flatten(1).min(dim=1).values
    total_dist = total_dist.view(B, 1, 1)

    on_path_2d = (
        torch.isfinite(total_dist)
        & (torch.abs(through - total_dist) < 1e-3)
        & traversable
    )
    return on_path_2d.reshape(B, L)