import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import logging
import matplotlib.pyplot as plt
import numpy as np

# ------------------------ Global Settings ------------------------
# Switch to the non-interactive Agg backend so figures are rendered to files
# and no display/GUI toolkit is required.
plt.switch_backend('Agg')
# Use DejaVu Sans as the sans-serif font family for all figures.
plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
# Draw minus signs with the ASCII hyphen instead of the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False

# ------------------------ Import Modules ------------------------
from config.maze.config_maze5 import Config 
from TLAD.maze.model import TLADModel
from TLAD.maze.utils import *
from data_process.maze.dataloader import get_maze_loaders 


# ===================== Loss: Phase 1 Authenticity =====================
class Phase1AuthenticityLoss(nn.Module):
    """Composite maze loss: cross-entropy + soft Dice + soft-degree penalty.

    The three terms are combined as
    ``ce + cfg.DICE * dice + cfg.DEGREE * degree``.

    Cell codes used by this loss (taken from the comparisons in ``forward``):
    input cells equal to ``1`` are excluded from every term (presumably walls
    -- TODO confirm against the dataset), input ``2`` / ``3`` mark the start /
    end cell, and target ``1`` marks a path cell.

    Attributes:
        grid_size: Side length of the square maze grid (``cfg.GRID_SIZE``).
        dice_weight: Multiplier of the Dice term (``cfg.DICE``).
        degree_weight: Multiplier of the degree term (``cfg.DEGREE``).
        metrics: Python floats of the most recent forward pass, keyed by
            ``'L_ce'``, ``'L_dice'``, ``'L_degree'`` and ``'total'``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.grid_size = cfg.GRID_SIZE
        self.dice_weight = cfg.DICE
        self.degree_weight = cfg.DEGREE
        # Defined up front so that reading ``metrics`` before the first
        # forward pass yields zeros instead of raising AttributeError.
        self.metrics = {'L_ce': 0.0, 'L_dice': 0.0, 'L_degree': 0.0, 'total': 0.0}

    def compute_soft_degree(self, pred_probs):
        """Sum the path probabilities of each cell's 4-connected neighbours.

        Args:
            pred_probs: Per-cell path probabilities, reshaped internally to
                ``(B, 1, grid_size, grid_size)``.

        Returns:
            Tensor of shape ``(B, grid_size, grid_size)`` holding, for every
            cell, the summed probability of its up/down/left/right neighbours
            (cells outside the grid contribute zero through the padding).
        """
        side = self.grid_size
        x = pred_probs.view(pred_probs.shape[0], 1, side, side)
        # Cross-shaped kernel: the centre is 0 so a cell never counts itself.
        kernel = torch.tensor([[[[0, 1, 0],
                                 [1, 0, 1],
                                 [0, 1, 0]]]], dtype=x.dtype, device=x.device)
        return F.conv2d(x, kernel, padding=1).squeeze(1)

    def forward(self, logits, targets, inputs):
        """Compute the combined loss and refresh ``self.metrics``.

        Args:
            logits: Class scores of shape ``(B, H*W, C)``; class index 1 is
                read as the "path" class.
            targets: Ground-truth labels of shape ``(B, H*W)``.
            inputs: Puzzle cell codes of shape ``(B, H*W)``.

        Returns:
            Scalar loss tensor. When every input cell equals 1 a fresh zero
            leaf tensor is returned, so ``backward()`` succeeds but no
            gradient reaches the model.
        """
        device = logits.device
        batch_size = inputs.shape[0]
        side = self.grid_size

        # Only cells whose input code differs from 1 are supervised.
        mask_class = (inputs != 1)
        if mask_class.sum() == 0:
            zero_loss = torch.tensor(0.0, device=device, requires_grad=True)
            self.metrics = {'L_ce': 0.0, 'L_dice': 0.0, 'L_degree': 0.0, 'total': 0.0}
            return zero_loss

        # --- Cross-entropy over the supervised cells ---
        masked_logits = logits[mask_class]
        masked_targets = targets[mask_class]
        ce_loss = F.cross_entropy(masked_logits, masked_targets.long(), reduction='mean')

        # --- Soft Dice on the path class (+1 smoothing in both terms) ---
        probs = F.softmax(masked_logits, dim=-1)
        pred_path = probs[:, 1]
        target_path = (masked_targets == 1).float()
        inter = (pred_path * target_path).sum()
        union = pred_path.sum() + target_path.sum()
        dice_loss = 1.0 - (2.0 * inter + 1.0) / (union + 1.0)

        # --- Soft degree: expected number of path neighbours per cell ---
        pred_probs_full = F.softmax(logits, dim=-1)[..., 1]
        target_grid = targets.view(batch_size, side, side)
        input_grid = inputs.view(batch_size, side, side)

        soft_degree = self.compute_soft_degree(pred_probs_full)
        target_degree = torch.zeros_like(soft_degree)
        is_start = (input_grid == 2)
        is_end = (input_grid == 3)
        is_path = (target_grid == 1)
        # Endpoints are expected to touch one path neighbour, interior path
        # cells two.
        target_degree[is_start | is_end] = 1.0
        target_degree[is_path & ~(is_start | is_end)] = 2.0

        # The degree penalty is evaluated on ground-truth path cells only.
        mask_degree = (target_grid == 1) & (input_grid != 1)
        if mask_degree.sum() > 0:
            degree_loss = torch.abs(soft_degree - target_degree)[mask_degree].mean()
        else:
            degree_loss = torch.tensor(0.0, device=device)

        total_loss = ce_loss + self.dice_weight * dice_loss + self.degree_weight * degree_loss
        self.metrics = {
            'L_ce': ce_loss.item(),
            'L_dice': dice_loss.item(),
            'L_degree': degree_loss.item(),
            'total': total_loss.item()
        }
        return total_loss


# ===================== Main Training Function =====================
def train(s):
    """Run the two-phase TLAD maze training pipeline and log a final summary.

    Phase 1 trains the model in ``mode='pretrain'`` with
    ``Phase1AuthenticityLoss``. Phase 2 continues from the final phase-1
    weights in ``mode='reasoning'``, adding the last reported energy value to
    the same task loss. After every epoch the model is evaluated on the
    validation and test loaders, and a checkpoint is saved whenever the
    validation maze accuracy matches or beats the best value of that phase.

    Args:
        s: Seed identifier. It is only echoed in the final summary; this
            function does not seed any random number generator itself.
    """
    cfg = Config()
    for d in [cfg.VIS_DIR, cfg.SAVE_DIR, cfg.LOG_DIR]:
        os.makedirs(d, exist_ok=True)
    logger, exp_id = setup_logging(cfg)

    train_loader, val_loader, test_loader = get_maze_loaders(cfg)
    model = TLADModel(cfg).to(cfg.DEVICE)

    # Best validation / test maze accuracy across BOTH phases. These are used
    # for the final report only; checkpoint selection relies on validation.
    best_val_maze = -1.0
    best_val_phase = None
    best_val_epoch = None

    best_test_maze = -1.0
    best_test_phase = None
    best_test_epoch = None

    # ------------------------ PHASE 1 ------------------------
    logger.info(">>> Starting PHASE 1: Authenticity Learning")
    criterion_p1 = Phase1AuthenticityLoss(cfg).to(cfg.DEVICE)
    optimizer1 = optim.Adam(model.parameters(), lr=cfg.LR_PHASE1)
    best_maze_acc = 0.0

    loss_history_p1 = {
        'epoch': [],
        'L_ce': [], 'L_dice': [], 'L_degree': [], 'total': [],
        'val_pixel': [], 'val_maze': []
    }

    phase1_epoch_times = []

    for epoch in range(1, cfg.EPOCHS_PHASE1 + 1):
        epoch_start = time.perf_counter()

        model.train()
        pbar = tqdm(train_loader, desc=f"P1-E{epoch}", ncols=110)

        total_ce = total_dice = total_degree = total_loss = 0.0
        num_batches = 0

        for batch in pbar:
            puzzle = batch['puzzle'].to(cfg.DEVICE)
            solution = batch['solution'].to(cfg.DEVICE)
            logits = model(puzzle, mode='pretrain')
            loss = criterion_p1(logits, solution, puzzle)

            optimizer1.zero_grad()
            loss.backward()
            optimizer1.step()

            # The criterion stores the per-term floats of its last call.
            metrics = criterion_p1.metrics
            total_ce += metrics['L_ce']
            total_dice += metrics['L_dice']
            total_degree += metrics['L_degree']
            total_loss += metrics['total']
            num_batches += 1

            pbar.set_postfix(
                Loss=f"{metrics['total']:.4f}",
                CE=f"{metrics['L_ce']:.3f}",
                Dice=f"{metrics['L_dice']:.3f}",
                Deg=f"{metrics['L_degree']:.3f}"
            )

        # Training time only: evaluation below is excluded from the timing.
        epoch_time = time.perf_counter() - epoch_start
        phase1_epoch_times.append(epoch_time)

        # Guard against an empty loader so the averages are 0.0 rather than
        # a ZeroDivisionError.
        batch_divisor = max(num_batches, 1)
        avg_ce = total_ce / batch_divisor
        avg_dice = total_dice / batch_divisor
        avg_degree = total_degree / batch_divisor
        avg_total = total_loss / batch_divisor

        val_pixel_acc, val_maze_acc = validate_maze(model, val_loader, cfg.DEVICE)
        test_pixel_acc, test_maze_acc = validate_maze(model, test_loader, cfg.DEVICE)

        # Update the global (cross-phase) best validation and test records.
        current_phase = "Phase1"
        current_epoch = epoch

        if val_maze_acc > best_val_maze:
            best_val_maze = val_maze_acc
            best_val_phase = current_phase
            best_val_epoch = current_epoch

        if test_maze_acc > best_test_maze:
            best_test_maze = test_maze_acc
            best_test_phase = current_phase
            best_test_epoch = current_epoch

        logger.info(
            f"P1-E{epoch} | "
            f"Train Loss: {avg_total:.4f} (CE={avg_ce:.3f}, Dice={avg_dice:.3f}, Degree={avg_degree:.3f}) | "
            f"Val Pixel: {val_pixel_acc:.2%}, Val Maze: {val_maze_acc:.2%} | "
            f"Test Pixel: {test_pixel_acc:.2%}, Test Maze: {test_maze_acc:.2%} | "
            f"Time: {epoch_time:.2f}s"
        )

        # plot_maze_2sample(model, val_loader, cfg, "P1", epoch, cfg.VIS_DIR)

        loss_history_p1['epoch'].append(epoch)
        loss_history_p1['L_ce'].append(avg_ce)
        loss_history_p1['L_dice'].append(avg_dice)
        loss_history_p1['L_degree'].append(avg_degree)
        loss_history_p1['total'].append(avg_total)
        loss_history_p1['val_pixel'].append(val_pixel_acc)
        loss_history_p1['val_maze'].append(val_maze_acc)

        # ``>=`` means ties overwrite the checkpoint with the later epoch.
        if val_maze_acc >= best_maze_acc:
            best_maze_acc = val_maze_acc
            save_best_checkpoint(
                model, optimizer1, epoch,
                {
                    "val_maze_acc": val_maze_acc,
                    "test_maze_acc_at_val_best": test_maze_acc,
                    "val_pixel_acc": val_pixel_acc,
                    "test_pixel_acc": test_pixel_acc
                },
                cfg, exp_id
            )

    plot_loss_components(loss_history_p1, os.path.join(cfg.VIS_DIR, "P1_loss_components.png"))
    logger.info(">>> PHASE 1 completed.")

    # ------------------------ PHASE 2 ------------------------
    logger.info("\n>>> Starting PHASE 2: Thermodynamic Reasoning Training")

    # Phase 2 reuses the phase-1 loss as its task loss.
    criterion_p2 = Phase1AuthenticityLoss(cfg).to(cfg.DEVICE)

    # Per-module parameter groups; the free-energy module trains with a 10x
    # larger learning rate than the rest.
    optimizer2 = optim.Adam([
        {'params': model.perception.parameters(), 'lr': cfg.LR_PHASE2},
        {'params': model.stem.parameters(), 'lr': cfg.LR_PHASE2},
        {'params': model.value_head.parameters(), 'lr': cfg.LR_PHASE2},
        {'params': model.logit_head.parameters(), 'lr': cfg.LR_PHASE2},
        {'params': model.free_energy.parameters(), 'lr': cfg.LR_PHASE2 * 10},
    ], lr=cfg.LR_PHASE2)

    best_val_maze_acc_phase2 = 0.0
    phase2_epoch_times = []

    for epoch in range(1, cfg.EPOCHS_PHASE2 + 1):
        epoch_start = time.perf_counter()

        model.train()
        pbar = tqdm(train_loader, desc=f"P2-E{epoch}", ncols=110)
        total_loss_epoch = 0.0
        num_batches_p2 = 0

        for batch in pbar:
            puzzle = batch['puzzle'].to(cfg.DEVICE)
            solution = batch['solution'].to(cfg.DEVICE)

            outputs = model(puzzle, mode='reasoning')
            logits = outputs['logits']
            # Only the last entry of the energy sequence enters the loss;
            # an empty sequence contributes nothing.
            last_energy = outputs['energy'][-1] if outputs['energy'] else 0.0

            loss_task = criterion_p2(logits, solution, puzzle)
            # NOTE(review): cfg.W_TASK scales the energy term, not the task
            # loss -- confirm the config name matches the intent.
            total_loss = loss_task + last_energy * cfg.W_TASK

            optimizer2.zero_grad()
            total_loss.backward()
            optimizer2.step()

            pbar.set_postfix(
                Loss=f"{total_loss.item():.4f}",
                Task=f"{loss_task.item():.4f}",
                Energy=f"{last_energy:.2f}",
                L_wall=f"{model.free_energy.lambda_wall.item():.1f}"
            )
            total_loss_epoch += total_loss.item()
            num_batches_p2 += 1

        epoch_time = time.perf_counter() - epoch_start
        phase2_epoch_times.append(epoch_time)

        # Average over the batches actually seen; safe for an empty loader.
        avg_loss = total_loss_epoch / max(num_batches_p2, 1)
        val_pixel, val_maze = validate_maze(model, val_loader, cfg.DEVICE)
        test_pixel, test_maze = validate_maze(model, test_loader, cfg.DEVICE)

        current_phase = "Phase2"
        current_epoch = epoch

        if val_maze > best_val_maze:
            best_val_maze = val_maze
            best_val_phase = current_phase
            best_val_epoch = current_epoch

        if test_maze > best_test_maze:
            best_test_maze = test_maze
            best_test_phase = current_phase
            best_test_epoch = current_epoch

        logger.info(
            f"P2-E{epoch} | "
            f"Val Pixel: {val_pixel:.2%}, Val Maze: {val_maze:.2%} | "
            f"Test Pixel: {test_pixel:.2%}, Test Maze: {test_maze:.2%} | "
            f"Avg Loss: {avg_loss:.4f} | Time: {epoch_time:.2f}s"
        )

        # visualization
        # plot_maze_2sample(model, val_loader, cfg, "P2", epoch, cfg.VIS_DIR)

        # Phase-2 checkpoints use their own best tracker and a separate
        # experiment id; the stored epoch continues the phase-1 numbering.
        if val_maze >= best_val_maze_acc_phase2:
            best_val_maze_acc_phase2 = val_maze
            save_best_checkpoint(
                model, optimizer2, epoch + cfg.EPOCHS_PHASE1,
                {
                    "val_maze_acc": val_maze,
                    "test_maze_acc_at_val_best": test_maze,
                    "val_pixel_acc": val_pixel,
                    "test_pixel_acc": test_pixel,
                    "phase": "Phase2"
                },
                cfg, exp_id + "_phase2"
            )

    # ------------------------ Final Summary ------------------------
    phase1_total = sum(phase1_epoch_times)
    phase1_avg = phase1_total / len(phase1_epoch_times) if phase1_epoch_times else 0.0
    phase2_total = sum(phase2_epoch_times)
    phase2_avg = phase2_total / len(phase2_epoch_times) if phase2_epoch_times else 0.0
    total_train_time = phase1_total + phase2_total

    logger.info("\n=== Measuring Inference Performance ===")
    avg_inference_ms, throughput = measure_inference_time(model, test_loader, cfg.DEVICE)

    final_summary = (
        "\n=== Final Experiment Summary ===\n"
        # Trailing newline keeps the seed on its own line in the report.
        f">>> SEED: {s}\n"
        f"Total Training Time: {total_train_time:.2f}s (Phase1: {phase1_total:.2f}s | Phase2: {phase2_total:.2f}s)\n"
        f"Average Epoch Time: Phase1: {phase1_avg:.2f}s | Phase2: {phase2_avg:.2f}s\n"
        f"Inference Performance: {avg_inference_ms:.2f}ms/sample | {throughput:.2f} samples/sec\n\n"
        f"Best Validation Maze Accuracy:\n"
        f"{best_val_maze*100:.2f}% achieved in {best_val_phase}, Epoch {best_val_epoch}\n\n"
        f"Best Test Maze Accuracy:\n"
        f"{best_test_maze*100:.2f}% achieved in {best_test_phase}, Epoch {best_test_epoch}\n"
        "=================================="
    )
    logger.info(final_summary)
    print(final_summary)

    logger.info("\n=== TLAD Maze Training Complete ===")


if __name__ == "__main__":
    def seed_all(seed):
        """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN."""
        import random

        os.environ['PYTHONHASHSEED'] = str(seed)

        # Every seeding entry point takes the same integer seed.
        seeders = (
            random.seed,
            np.random.seed,
            torch.manual_seed,
            torch.cuda.manual_seed,
            torch.cuda.manual_seed_all,
        )
        for apply_seed in seeders:
            apply_seed(seed)

        # Fixed cuDNN algorithm selection, with autotuning disabled.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # Values noted by the author next to the seed: 0, 1234, 42 (2026, 228)
    # -- presumably the other seeds used for repeated runs.
    s = 42
    seed_all(s)
    train(s)