import os
import time
import torch
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import logging

# Configuration & Model
from config.sdk.config import Config
from TLAD.sdk.model import TLADModel
from TLAD.sdk.utils import *

# Data Loader
from data_process.sdk.dataloader import get_data_loaders 

# Data Recording
from TLAD.sdk.data_record import SudokuDataRecorder

# ------------------------ Global Settings ------------------------
# Use the non-interactive 'Agg' matplotlib backend for all plotting in this script.
plt.switch_backend('Agg')

# Phase-2 step-wise energy is recorded only on epochs divisible by this value.
ENERGY_RECORD_STEP = 2

# Index (within the batch) of the sample whose step-wise predictions are recorded.
# NOTE: the name is misspelled ("RECOED"); it is kept because train() refers to it.
SAMPLE_IDX_FOR_RECOED = 0

# A Phase-2 dynamics visualization is attempted on epochs divisible by this value.
VISUALIZATION_FREQ = 5

# ------------------------ Main Training Pipeline ------------------------
def _best_test_result(phase1_accs, phase2_accs):
    """Locate the highest test accuracy across both training phases.

    Args:
        phase1_accs: Per-epoch test accuracies (percent) from Phase 1.
        phase2_accs: Per-epoch test accuracies (percent) from Phase 2.

    Returns:
        Tuple ``(best_acc, phase_name, epoch_within_phase)``. On ties the
        earliest epoch wins. When no epoch was run at all, returns
        ``(0.0, "N/A", 0)`` instead of raising.
    """
    all_accs = phase1_accs + phase2_accs
    if not all_accs:
        return 0.0, "N/A", 0
    best_acc = max(all_accs)
    best_epoch = all_accs.index(best_acc) + 1  # 1-based position over both phases
    if best_epoch <= len(phase1_accs):
        return best_acc, "Phase1", best_epoch
    return best_acc, "Phase2", best_epoch - len(phase1_accs)


def _record_step_energies(data_recorder, outputs, epoch, batch_idx):
    """Append one recorder row per reasoning step for the tracked sample.

    Reads ``outputs['step_energy']`` and ``outputs['step_probs']`` (requested
    from the model via ``return_step_wise=True``), converts the tracked
    sample's per-step prediction into a 9x9 grid of digits 1-9 and stores the
    step energy together with the grid's constraint error.
    """
    sample_idx = SAMPLE_IDX_FOR_RECOED
    with torch.no_grad():
        step_pairs = zip(outputs['step_energy'], outputs['step_probs'])
        for step_idx, (s_energy, s_probs) in enumerate(step_pairs, start=1):
            s_prob = s_probs[sample_idx]
            # argmax yields class ids 0..8; +1 maps them back to Sudoku digits.
            pred_grid = (s_prob.argmax(-1).cpu().numpy() + 1).reshape(9, 9)
            err_tuple = calculate_sudoku_constraint_error(pred_grid)
            data_recorder.append_step_energy(
                epoch=epoch,
                batch_idx=batch_idx,
                sample_idx=sample_idx,
                step_idx=step_idx,
                step_free_energy=s_energy,
                constraint_error_tuple=err_tuple
            )


def train():
    """
    Main TLAD training pipeline (two-phase).
    1. Phase 1: Perception module pre-training (model called with mode='pretrain')
    2. Phase 2: Thermodynamic reasoning training (model called with mode='reasoning'),
       with periodic step-wise energy recording and dynamics visualization
    3. Post-training: accuracy curves, timing statistics, inference benchmark
       and a final summary written to the log and stdout

    The checkpoint is saved (is_best=True) whenever validation puzzle accuracy
    improves, in either phase.
    """
    # Initialize config and logging (exp ID for traceability)
    cfg = Config()
    logger, exp_id = setup_logging(cfg)
    logger.info("=== TLAD Training Pipeline Initialized ===")
    logger.info(f"Device: {cfg.DEVICE} | Batch Size: {cfg.BATCH_SIZE}")
    logger.info(f"Phase 1 Epochs: {cfg.EPOCHS_PHASE1} | Phase 2 Epochs: {cfg.EPOCHS_PHASE2}")

    # Data record for energy (CSV_DIR is optional in the config)
    csv_dir = getattr(cfg, 'CSV_DIR', None)
    data_recorder = SudokuDataRecorder(csv_dir, exp_id)

    # Data loaders
    train_loader, val_loader, test_loader = get_data_loaders(cfg)
    logger.info("Data loaders initialized (train/val/test)")

    # Model initialization
    model = TLADModel(cfg).to(cfg.DEVICE)
    if cfg.DEVICE == 'cuda':
        model = torch.nn.DataParallel(model)
        model = model.to(memory_format=torch.channels_last)
    logger.info("TLAD model initialized")

    # Metric tracking (val/test puzzle accuracy in percent, epoch wall time in seconds)
    phase1_metrics = {"val_puzzle": [], "test_puzzle": [], "epoch_times": []}
    phase2_metrics = {"val_puzzle": [], "test_puzzle": [], "epoch_times": []}

    # Best model tracking (by validation puzzle accuracy, as a fraction)
    best_val_acc = 0.0
    best_val_epoch = 0
    best_val_phase = "Phase1"
    best_checkpoint_path = None

    # ------------------------ Phase 1: Perception Module Pre-training ------------------------
    logger.info("\n=== Phase 1: Perception Module Pre-training ===")
    optimizer_phase1 = optim.Adam(model.parameters(), lr=cfg.LR_PHASE1)
    criterion = torch.nn.CrossEntropyLoss()

    for epoch in range(1, cfg.EPOCHS_PHASE1 + 1):
        epoch_start = time.perf_counter()

        # Training loop
        model.train()
        train_loop = tqdm(train_loader, desc=f"Phase1 E{epoch}")
        for batch in train_loop:
            puzzle = batch['puzzle'].to(cfg.DEVICE, non_blocking=True)
            solution = batch['solution'].to(cfg.DEVICE, non_blocking=True)

            logits = model(puzzle, mode='pretrain')
            # Solutions hold digits 1..9; shift to class ids 0..8 for cross-entropy.
            loss = criterion(logits.view(-1, 9), (solution - 1).view(-1))

            optimizer_phase1.zero_grad()
            loss.backward()
            optimizer_phase1.step()

            train_loop.set_postfix(loss=f"{loss.item():.4f}")

        # Epoch timing (training only; evaluation below is excluded)
        epoch_time = time.perf_counter() - epoch_start
        phase1_metrics["epoch_times"].append(epoch_time)

        # Evaluation
        val_puzzle, val_cell = validate(model, val_loader, cfg.DEVICE, is_test=False)
        test_puzzle, test_cell = validate(model, test_loader, cfg.DEVICE, is_test=True)

        # Update metrics
        phase1_metrics["val_puzzle"].append(val_puzzle * 100)
        phase1_metrics["test_puzzle"].append(test_puzzle * 100)

        # Update best model
        if val_puzzle > best_val_acc:
            best_val_acc = val_puzzle
            best_val_epoch = epoch
            best_val_phase = "Phase1"
            best_checkpoint_path = save_checkpoint(
                model=model, optimizer=optimizer_phase1, epoch=epoch,
                metrics={"val_puzzle": val_puzzle, "test_puzzle": test_puzzle},
                cfg=cfg, exp_id=exp_id, is_best=True
            )

        # Log results
        log_msg = (
            f"[Phase1 E{epoch}] "
            f"Val: Puzzle={val_puzzle:.1%}, Cell={val_cell:.1%} | "
            f"Test: Puzzle={test_puzzle:.1%}, Cell={test_cell:.1%} | "
            f"Time={epoch_time:.2f}s"
        )
        logger.info(log_msg)
        print(log_msg)

    # ------------------------ Phase 2: Thermodynamic Reasoning Training ------------------------
    logger.info("\n=== Phase 2: Thermodynamic Reasoning Training ===")
    # DataParallel hides the real network behind `.module`; unwrap once.
    core_model = model.module if hasattr(model, 'module') else model
    # Layer-specific learning rates (free energy module has a 10x higher LR)
    optimizer_phase2 = optim.Adam([
        {'params': core_model.perception.parameters(), 'lr': cfg.LR_PHASE2},
        {'params': core_model.free_energy.parameters(), 'lr': cfg.LR_PHASE2 * 10},
        {'params': core_model.value_head.parameters(), 'lr': cfg.LR_PHASE2},
        {'params': core_model.logit_head.parameters(), 'lr': cfg.LR_PHASE2},
    ], lr=cfg.LR_PHASE2)

    for epoch in range(1, cfg.EPOCHS_PHASE2 + 1):
        epoch_start = time.perf_counter()

        # Training loop
        model.train()
        train_loop = tqdm(train_loader, desc=f"Phase2 E{epoch}")
        for batch_idx, batch in enumerate(train_loop):
            puzzle = batch['puzzle'].to(cfg.DEVICE, non_blocking=True)
            solution = batch['solution'].to(cfg.DEVICE, non_blocking=True)

            # Record only the first batch of every ENERGY_RECORD_STEP-th epoch.
            # The batch *index* must be tested here: `batch` itself is a dict,
            # so comparing it with 0 is always False and disables recording.
            is_recording_batch = (epoch % ENERGY_RECORD_STEP == 0) and (batch_idx == 0)
            outputs = model(puzzle, mode='reasoning', return_step_wise=is_recording_batch)

            final_energy = outputs['energy'][-1]
            loss_task = criterion(outputs['logits'].view(-1, 9), (solution - 1).view(-1))
            # NOTE(review): the energy term is scaled by cfg.W_TASK — confirm
            # this is the intended weight for the energy loss.
            loss_energy = final_energy * cfg.W_TASK
            total_loss = loss_task + loss_energy

            optimizer_phase2.zero_grad()
            total_loss.backward()
            optimizer_phase2.step()

            train_loop.set_postfix(loss=f"{total_loss.item():.4f}", energy=f"{final_energy:.2f}")

            if is_recording_batch:
                _record_step_energies(data_recorder, outputs, epoch, batch_idx)

        data_recorder.flush_all()

        # Epoch timing
        epoch_time = time.perf_counter() - epoch_start
        phase2_metrics["epoch_times"].append(epoch_time)

        # Evaluation
        val_puzzle, val_cell = validate(model, val_loader, cfg.DEVICE, is_test=False)
        test_puzzle, test_cell = validate(model, test_loader, cfg.DEVICE, is_test=True)

        # Update metrics
        phase2_metrics["val_puzzle"].append(val_puzzle * 100)
        phase2_metrics["test_puzzle"].append(test_puzzle * 100)

        # Save best model (epochs are numbered globally across both phases)
        global_epoch = cfg.EPOCHS_PHASE1 + epoch
        if val_puzzle > best_val_acc:
            best_val_acc = val_puzzle
            best_val_epoch = global_epoch
            best_val_phase = "Phase2"
            best_checkpoint_path = save_checkpoint(
                model=model, optimizer=optimizer_phase2, epoch=global_epoch,
                metrics={"val_puzzle": val_puzzle, "test_puzzle": test_puzzle},
                cfg=cfg, exp_id=exp_id, is_best=True
            )

        # Log results
        log_msg = (
            f"[Phase2 E{epoch}] "
            f"Val: Puzzle={val_puzzle:.1%}, Cell={val_cell:.1%} | "
            f"Test: Puzzle={test_puzzle:.1%}, Cell={test_cell:.1%} | "
            f"Time={epoch_time:.2f}s"
        )
        logger.info(log_msg)
        print(log_msg)

        # Periodic visualization (reduce IO overhead). Best-effort: a failure
        # here is logged but must not abort training.
        if epoch % VISUALIZATION_FREQ == 0:
            try:
                sample_batch = next(iter(val_loader))
                sample_puzzle = sample_batch['puzzle'][:1].to(cfg.DEVICE, non_blocking=True)

                # Do not rely on validate() having left the model in eval mode;
                # the next epoch switches back with model.train().
                model.eval()
                with torch.no_grad():
                    vis_outputs = model(sample_puzzle, mode='reasoning')

                visualize_tlad_dynamics(sample_puzzle, vis_outputs, global_epoch, cfg, exp_id)
                logger.info(f"Visualization saved for Phase2 E{epoch} (Global E{global_epoch})")
            except Exception as e:
                # logger.exception keeps the traceback in the log file.
                logger.exception(f"Visualization error: {str(e)}")
                print(f"Visualization error: {str(e)}")

    # ------------------------ Post-Training Analysis ------------------------
    # Plot accuracy curves (val/test)
    logger.info("\n=== Generating Accuracy Curves ===")
    plot_accuracy_curve(phase1_metrics["val_puzzle"], phase2_metrics["val_puzzle"], cfg, exp_id, curve_type="val")
    plot_accuracy_curve(phase1_metrics["test_puzzle"], phase2_metrics["test_puzzle"], cfg, exp_id, curve_type="test")

    # Training time metrics
    phase1_total = sum(phase1_metrics["epoch_times"])
    phase1_avg = phase1_total / len(phase1_metrics["epoch_times"]) if phase1_metrics["epoch_times"] else 0.0
    phase2_total = sum(phase2_metrics["epoch_times"])
    phase2_avg = phase2_total / len(phase2_metrics["epoch_times"]) if phase2_metrics["epoch_times"] else 0.0
    total_train_time = phase1_total + phase2_total

    # Best test accuracy (safe even when both phases ran zero epochs)
    best_test_acc, best_test_phase, best_test_phase_epoch = _best_test_result(
        phase1_metrics["test_puzzle"], phase2_metrics["test_puzzle"]
    )

    # Inference performance
    logger.info("\n=== Measuring Inference Performance ===")
    avg_inference_ms, throughput = measure_inference_time(model, test_loader, cfg.DEVICE)

    if best_checkpoint_path is not None:
        logger.info(f"Best checkpoint: {best_checkpoint_path}")

    # Final summary
    final_summary = (
        "\n=== Final Experiment Summary ===\n"
        f"Total Training Time: {total_train_time:.2f}s (Phase1: {phase1_total:.2f}s | Phase2: {phase2_total:.2f}s)\n"
        f"Average Epoch Time: Phase1: {phase1_avg:.2f}s | Phase2: {phase2_avg:.2f}s\n"
        f"Best Validation Accuracy: {best_val_acc*100:.2f}% ({best_val_phase} E{best_val_epoch})\n"
        f"Best Test Accuracy: {best_test_acc:.2f}% ({best_test_phase} E{best_test_phase_epoch})\n"
        f"Inference Performance: {avg_inference_ms:.2f}ms/sample | {throughput:.2f} samples/sec\n"
        "=================================="
    )
    logger.info(final_summary)
    print(final_summary)

# ------------------------ Entry Point ------------------------
# Runs only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    # Fixed random seed (0), set before any training work starts.
    # seed_all is not defined in this file; presumably it comes from the
    # `TLAD.sdk.utils` star import and seeds the RNGs in use — TODO confirm.
    seed_all(seed=0)
    # Launch the two-phase training pipeline
    train()