import os
import sys
import traceback

# Startup diagnostics: print where the script runs from so import failures
# in a remote environment can be traced to a wrong working directory or path.
print("🔍 DEBUG: Starting train.py")
print(f"🔍 DEBUG: Current working directory: {os.getcwd()}")
print(f"🔍 DEBUG: Python path: {sys.path[:3]}...")  # Only the first 3 entries are shown
print(f"🔍 DEBUG: train.py file location: {__file__}")

# External dependencies. NOTE: the messages below say "Standard library", but
# this group also contains third-party packages (torch, matplotlib, tqdm,
# numpy); a failure here usually means one of those is not installed.
try:
    import torch
    import torch.nn as nn
    import matplotlib.pyplot as plt
    from tqdm import tqdm
    import json
    import time
    from datetime import datetime
    import numpy as np
    print("✅ DEBUG: Standard library imports successful")
except ImportError as e:
    print(f"❌ DEBUG: Standard library import failed: {e}")
    traceback.print_exc()
    sys.exit(1)

# Make the directory containing this file importable so the project-local
# packages (data, utils, models) resolve. `append` places it LAST on
# sys.path, so any earlier entry providing a same-named package wins.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
print(f"🔍 DEBUG: Added to path: {os.path.dirname(os.path.abspath(__file__))}")

# Project-local imports. Each group is imported separately; on failure the
# script reports whether the package directory exists next to this file,
# lists its contents, prints the traceback and exits with status 1.
try:
    print("🔍 DEBUG: Attempting to import data.dl_getter...")
    from data.dl_getter import get_dataloader
    print("✅ DEBUG: data.dl_getter import successful")
except ImportError as e:
    print(f"❌ DEBUG: data.dl_getter import failed: {e}")
    print("🔍 DEBUG: Checking data directory...")
    data_dir = os.path.join(os.path.dirname(__file__), 'data')
    print(f"   Data directory exists: {os.path.exists(data_dir)}")
    if os.path.exists(data_dir):
        print(f"   Data directory contents: {os.listdir(data_dir)}")
    traceback.print_exc()
    sys.exit(1)

# Argument parsing and accuracy helpers.
try:
    print("🔍 DEBUG: Attempting to import utils...")
    from utils.get_args import get_args
    from utils.accuracy import topk_accuracy, should_compute_top5
    print("✅ DEBUG: utils imports successful")
except ImportError as e:
    print(f"❌ DEBUG: utils import failed: {e}")
    print("🔍 DEBUG: Checking utils directory...")
    utils_dir = os.path.join(os.path.dirname(__file__), 'utils')
    print(f"   Utils directory exists: {os.path.exists(utils_dir)}")
    if os.path.exists(utils_dir):
        print(f"   Utils directory contents: {os.listdir(utils_dir)}")
    traceback.print_exc()
    sys.exit(1)

# Model factory and weight-normalization helpers.
try:
    print("🔍 DEBUG: Attempting to import models...")
    from models.factory import create_pcn_from_backbone
    from models.normalization.unified_weight_norm import apply_unified_norm
    print("✅ DEBUG: models imports successful")
except ImportError as e:
    print(f"❌ DEBUG: models import failed: {e}")
    print("🔍 DEBUG: Checking models directory...")
    models_dir = os.path.join(os.path.dirname(__file__), 'models')
    print(f"   Models directory exists: {os.path.exists(models_dir)}")
    if os.path.exists(models_dir):
        print(f"   Models directory contents: {os.listdir(models_dir)}")
    traceback.print_exc()
    sys.exit(1)

print("✅ DEBUG: All imports successful, proceeding with train.py execution")

# c_shared not available in remote environment - create local fallback
def save_args_to_yaml(args, filepath):
    """Write ``args`` to a YAML file (local fallback for ``save_args_to_yaml``).

    Args:
        args: An object exposing ``__dict__`` (e.g. ``argparse.Namespace``),
            which is dumped via ``vars(args)``; any other object (such as a
            plain dict) is dumped as-is.
        filepath: Destination path. An existing directory, or a path that
            ends with a path separator, is treated as a directory and the
            file is written to ``<filepath>/args.yaml``. Anything else is
            used as the file path itself.

    Raises:
        OSError: If the parent directory cannot be created or the file
            cannot be written.
    """
    import yaml

    filepath = os.fspath(filepath)

    # A trailing separator means "directory" even if it does not exist yet;
    # opening such a path as a file would fail after it had been created.
    separators = tuple(sep for sep in (os.sep, os.altsep, '/') if sep)
    if os.path.isdir(filepath) or filepath.endswith(separators):
        filepath = os.path.join(filepath, 'args.yaml')

    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create the parent when there is one.
    parent_dir = os.path.dirname(filepath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    args_dict = vars(args) if hasattr(args, '__dict__') else args
    with open(filepath, 'w') as f:
        yaml.dump(args_dict, f, default_flow_style=False)



def initialize_detailed_logging(args):
    """Create the in-memory record that accumulates this run's metrics.

    Returns ``None`` when ``args.save_detailed_metrics`` is falsy. Otherwise
    returns a dict holding the experiment name (``args.experiment_name`` or a
    timestamp-based ``exp_YYYYMMDD_HHMMSS`` default), the ISO start time, every
    argument converted to ``str`` (so the record is JSON-serializable), an
    empty ``epochs`` list and an empty ``final_results`` dict.
    """
    if args.save_detailed_metrics:
        run_name = args.experiment_name
        if not run_name:
            # No explicit name: derive one from the current wall-clock time.
            run_name = "exp_" + datetime.now().strftime('%Y%m%d_%H%M%S')

        started_at = datetime.now().isoformat()

        stringified_args = {}
        for key, value in vars(args).items():
            stringified_args[key] = str(value)

        return {
            'experiment_name': run_name,
            'start_time': started_at,
            'args': stringified_args,
            'epochs': [],
            'final_results': {},
        }

    return None

def log_epoch_metrics(experiment_data, epoch, tr_loss, tr_acc, te_loss, te_acc, elapsed_time, te_acc_top5=None):
    """Append one epoch's metrics to ``experiment_data['epochs']``.

    Does nothing when ``experiment_data`` is ``None`` (detailed logging
    disabled). Train metrics are always stored; test loss/accuracy are stored
    only when both are given, and top-5 accuracy only alongside them.
    """
    if experiment_data is not None:
        record = dict(
            epoch=epoch,
            train_loss=float(tr_loss),
            train_accuracy=float(tr_acc),
            elapsed_time=elapsed_time,
            timestamp=datetime.now().isoformat(),
        )

        # Test metrics are optional: an epoch may be logged without evaluation.
        missing_test_metrics = te_loss is None or te_acc is None
        if not missing_test_metrics:
            record.update(test_loss=float(te_loss), test_accuracy=float(te_acc))
            if te_acc_top5 is not None:
                record['test_accuracy_top5'] = float(te_acc_top5)

        experiment_data['epochs'].append(record)

def save_detailed_metrics(experiment_data, args, milestone_data=None):
    """Finalize ``experiment_data`` and write it to ``detailed_metrics.json``.

    Args:
        experiment_data: The record built during training, or ``None`` when
            detailed logging is disabled (in which case nothing happens).
        args: Object with a ``log_dir`` attribute; the JSON file is written
            to ``<log_dir>/detailed_metrics.json``.
        milestone_data: Optional dict with the keys ``best_acc_20``,
            ``best_acc_50``, ``best_acc_100``, ``best_acc_overall`` and the
            matching ``best_epoch_*`` keys; merged into ``final_results``.

    Epochs logged without test metrics are tolerated: ``final_test_accuracy``
    is ``None`` if the last epoch has no test accuracy, and
    ``best_test_accuracy`` is the maximum over the epochs that do have one
    (``None`` if there are none).

    Raises:
        KeyError: If ``milestone_data`` is given but lacks an expected key.
        OSError: If the log directory or file cannot be written.
    """
    if experiment_data is None:
        return

    experiment_data['end_time'] = datetime.now().isoformat()
    epochs = experiment_data['epochs']
    if epochs:
        last_epoch = epochs[-1]
        # Only epochs that were actually evaluated carry 'test_accuracy'.
        test_accuracies = [e['test_accuracy'] for e in epochs if 'test_accuracy' in e]
        final_results = {
            'final_train_accuracy': last_epoch['train_accuracy'],
            'final_test_accuracy': last_epoch.get('test_accuracy'),
            'best_test_accuracy': max(test_accuracies) if test_accuracies else None,
            'total_epochs_logged': len(epochs),
        }

        # Add milestone accuracy data if provided (output key -> source key).
        if milestone_data:
            milestone_keys = (
                ('best_acc_until_20', 'best_acc_20'),
                ('best_acc_until_50', 'best_acc_50'),
                ('best_acc_until_100', 'best_acc_100'),
                ('best_acc_overall', 'best_acc_overall'),
                ('best_epoch_at_20', 'best_epoch_20'),
                ('best_epoch_at_50', 'best_epoch_50'),
                ('best_epoch_at_100', 'best_epoch_100'),
                ('best_epoch_overall', 'best_epoch_overall'),
            )
            for out_key, src_key in milestone_keys:
                final_results[out_key] = milestone_data[src_key]

        experiment_data['final_results'] = final_results

    # Make sure the target directory exists before opening the file.
    if args.log_dir:
        os.makedirs(args.log_dir, exist_ok=True)

    filename = os.path.join(args.log_dir, 'detailed_metrics.json')
    with open(filename, 'w') as f:
        json.dump(experiment_data, f, indent=2)

    print("Detailed metrics saved to {}".format(filename))

def train_one_epoch(model, optimizer, train_loader, args, experiment_data=None, epoch=None):
    """Run one training pass over ``train_loader``.

    For every batch, ``model.forward(x, y, optimizer)`` is called; it must
    return a dict with ``'pred'`` and ``'loss'`` and may include
    ``'initial_pred'``. The optimizer is handed to the model rather than
    stepped here.

    Args:
        model: Object exposing ``backbone_model``, ``forward`` and optionally
            ``reset_solver_memory``.
        optimizer: Passed through to ``model.forward``.
        train_loader: Iterable of ``(x, y)`` batches.
        args: Needs ``device``; ``track_batch_metrics`` is optional.
        experiment_data: Detailed-logging record, or ``None``.
        epoch: Epoch number stored with per-batch history entries.

    Returns:
        ``(mean_loss, accuracy)`` where the loss is weighted by batch size and
        accuracy is based on ``output_dict['pred']``.

    Raises:
        ZeroDivisionError: If ``train_loader`` yields no samples.
    """
    model.backbone_model.train()
    total_loss = 0
    correct = 0
    correct_forward = 0
    cnt = 0
    pbar = tqdm(train_loader)

    # Loop-invariant: decide once whether per-batch history is recorded.
    track_batches = experiment_data is not None and getattr(args, 'track_batch_metrics', False)
    batch_history = []

    for batch_idx, (x, y) in enumerate(pbar):
        x, y = x.to(args.device), y.to(args.device)
        batch_size = x.size(0)

        # PCN training step - returns both initial and final predictions
        output_dict = model.forward(x, y, optimizer)

        # Final prediction drives the main accuracy
        pred_final = output_dict['pred'].argmax(dim=1)
        correct += (pred_final == y).sum().item()

        # Initial feedforward prediction drives Acc_forward; fall back to the
        # final prediction when the model does not report one.
        if 'initial_pred' in output_dict:
            pred_forward = output_dict['initial_pred'].argmax(dim=1)
        else:
            pred_forward = pred_final
        forward_hits = (pred_forward == y).sum().item()
        correct_forward += forward_hits

        # Read the scalar loss once; every .item() call copies to the host.
        loss_value = output_dict['loss'].item()
        total_loss += loss_value * batch_size
        cnt += batch_size

        # Reset solver memory for next batch
        if hasattr(model, 'reset_solver_memory'):
            model.reset_solver_memory()

        if track_batches:
            # NOTE(review): per-batch 'accuracy' is the forward (initial)
            # accuracy, not the final one - confirm this is intended.
            batch_history.append({
                'epoch': epoch,
                'batch': batch_idx,
                'loss': loss_value,
                'accuracy': forward_hits / batch_size,
            })

        pbar.set_description("Train Loss: {:.4f}, Acc: {:.4f}, Acc_forward: {:.4f}".format(
            total_loss/cnt, correct/cnt, correct_forward/cnt))

    # Store batch history if tracking enabled
    if experiment_data is not None and batch_history:
        experiment_data.setdefault('batch_history', []).extend(batch_history)

    return total_loss / cnt, correct / cnt

def eval_one_epoch(model, test_loader, args):
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    Uses ``model.predict_forward(x)``, which must return a dict containing
    ``'pred'``.

    The test loss is accumulated per sample. If the output dict carries a
    ``'loss'`` entry it is used (treated as a per-batch mean, as in
    training); otherwise cross-entropy between ``pred`` and ``y`` is used.
    NOTE(review): the fallback assumes ``pred`` holds class logits and ``y``
    class indices - confirm against the model.

    Args:
        model: Object exposing ``backbone_model`` and ``predict_forward``.
        test_loader: Iterable of ``(x, y)`` batches.
        args: Needs ``device`` and ``dataset``.

    Returns:
        ``(mean_loss, top1_accuracy, top5_accuracy)`` when
        ``should_compute_top5(args.dataset)`` is true, otherwise
        ``(mean_loss, top1_accuracy)``.

    Raises:
        ZeroDivisionError: If ``test_loader`` yields no samples.
    """
    model.backbone_model.eval()
    total_loss = 0.0
    correct = 0
    correct_top5 = 0
    cnt = 0
    compute_top5 = should_compute_top5(args.dataset)

    with torch.no_grad():
        pbar = tqdm(test_loader)
        for x, y in pbar:
            x, y = x.to(args.device), y.to(args.device)
            batch_size = x.size(0)

            # Use predict_forward for evaluation (no PCN iterations)
            output_dict = model.predict_forward(x)
            pred = output_dict['pred']

            # Accumulate the loss as a per-sample sum so the final division
            # by cnt yields a proper mean (previously this stayed at 0).
            if 'loss' in output_dict and output_dict['loss'] is not None:
                total_loss += float(output_dict['loss']) * batch_size
            else:
                total_loss += nn.functional.cross_entropy(
                    pred, y.long(), reduction='sum').item()

            if compute_top5:
                # topk_accuracy results are divided by 100 here, i.e. they
                # are handled as percentages and converted back to counts.
                acc1, acc5 = topk_accuracy(pred, y, topk=(1, 5))
                correct += acc1.item() * batch_size / 100.0
                correct_top5 += acc5.item() * batch_size / 100.0
            else:
                preds = pred.argmax(dim=1)
                correct += (preds == y).sum().item()

            cnt += batch_size

    if compute_top5:
        return total_loss / cnt, correct / cnt, correct_top5 / cnt
    else:
        return total_loss / cnt, correct / cnt

def _save_curve_plot(train_values, test_values, title, ylabel, out_path):
    """Plot train/test curves against the logged-epoch index and save to ``out_path``."""
    epochs_logged = list(range(1, len(train_values) + 1))

    plt.figure()
    plt.plot(epochs_logged, train_values, label='train')
    plt.plot(epochs_logged, test_values, label='test')
    plt.title(title)
    plt.xlabel('Logged Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.savefig(out_path)
    plt.close()


def main():
    """Main training function.

    Parses arguments, builds the dataloaders and the PCN model, then trains
    for ``args.epochs`` epochs. After every epoch the model is evaluated,
    metrics are appended to ``metrics.log``, loss/accuracy plots are refreshed
    (unless ``args.train_only``) and the backbone weights are saved to
    ``model.pth``; all files go to ``args.log_dir``. Best test accuracies up
    to epochs 20, 50 and 100, and overall, are tracked and stored with the
    detailed metrics.
    """
    args = get_args()

    # The log directory is written to throughout the run; make sure it exists
    # before the first write so it is not mistaken for a file path.
    os.makedirs(args.log_dir, exist_ok=True)

    # Save args to yaml for experiment tracking
    save_args_to_yaml(args, args.log_dir)

    # Initialize detailed logging
    experiment_data = initialize_detailed_logging(args)

    # Get dataloaders
    dataloader_dict = get_dataloader(args)
    train_loader = dataloader_dict['train_loader']
    test_loader = dataloader_dict['test_loader']

    # norm_type is optional on args; resolve it once so model creation and
    # the summary line below agree.
    norm_type = getattr(args, 'norm_type', 'none')

    # Create PCN model using args directly to preserve all settings
    from models.factory import create_pcn_with_external_backbone
    model = create_pcn_with_external_backbone(
        backbone_name=args.backbone,
        model_type='pcn_jacobi',
        solver_type=args.solver_type,
        use_meta_pc=(args.energy_option == 'meta_pc'),
        num_classes=args.num_classes,
        T=args.T,
        eta=args.eta,
        norm_type=norm_type,
        # Pass loop_scheduler settings from args
        loop_scheduler=getattr(args, 'loop_scheduler', None),
        update_latent_rule=getattr(args, 'update_latent_rule', 'block_sweep_gs')
    )

    model.backbone_model.to(args.device)
    print("Model: {} | Solver: {} | Backbone modules: {} | Norm: {}".format(
        model.config.model_type, model.config.solver_type,
        len(model.backbone_model.backbone_module_list), norm_type))

    # Optimizer
    optimizer = torch.optim.AdamW(model.backbone_model.parameters(),
                                  lr=args.lr, weight_decay=args.weight_decay)

    # Per-epoch history used for plotting and the final summary
    train_losses, test_losses = [], []
    train_accs, test_accs = [], []
    start_time = time.time()

    # Best test accuracy (and its epoch) seen up to each milestone epoch
    milestones = (20, 50, 100)
    best_acc = {m: 0.0 for m in milestones}
    best_epoch = {m: 0 for m in milestones}
    best_acc_overall = 0.0
    best_epoch_overall = 0

    # Depends only on the dataset, so it is resolved once for the whole run.
    compute_top5 = should_compute_top5(args.dataset)
    metrics_log_path = os.path.join(args.log_dir, 'metrics.log')

    for epoch in range(1, args.epochs + 1):
        epoch_start_time = time.time()

        # Always train, but log/evaluate based on interval
        tr_loss, tr_acc = train_one_epoch(model, optimizer, train_loader, args, experiment_data, epoch)

        # DEPLOYMENT MODE: Always evaluate every epoch for accurate tracking
        # Original: should_log = (epoch % args.log_interval == 0) or (epoch == 1) or (epoch == args.epochs)
        should_log = True  # Force evaluation every epoch for deployment

        if should_log:
            if compute_top5:
                te_loss, te_acc, te_acc_top5 = eval_one_epoch(model, test_loader, args)
            else:
                te_loss, te_acc = eval_one_epoch(model, test_loader, args)
                te_acc_top5 = None
            elapsed_time = time.time() - epoch_start_time

            # Store results
            train_losses.append(tr_loss)
            test_losses.append(te_loss)
            train_accs.append(tr_acc)
            test_accs.append(te_acc)

            # Update best accuracy milestones
            if te_acc > best_acc_overall:
                best_acc_overall = te_acc
                best_epoch_overall = epoch

            for m in milestones:
                if epoch <= m and te_acc > best_acc[m]:
                    best_acc[m] = te_acc
                    best_epoch[m] = epoch

            # Log epoch metrics
            log_epoch_metrics(experiment_data, epoch, tr_loss, tr_acc, te_loss, te_acc, elapsed_time, te_acc_top5)

            # Log to file
            with open(metrics_log_path, 'a') as f:
                if compute_top5:
                    f.write("{}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\n".format(
                        epoch, tr_loss, tr_acc, te_loss, te_acc, te_acc_top5))
                else:
                    f.write("{}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\n".format(
                        epoch, tr_loss, tr_acc, te_loss, te_acc))

            # Enhanced output with best accuracy info
            if compute_top5:
                output_str = "Epoch {:03d} | Train Loss: {:.4f}, Acc: {:.4f} | Test Loss: {:.4f}, Acc: {:.4f}, Top5: {:.4f}".format(
                    epoch, tr_loss, tr_acc, te_loss, te_acc, te_acc_top5)
            else:
                output_str = "Epoch {:03d} | Train Loss: {:.4f}, Acc: {:.4f} | Test Loss: {:.4f}, Acc: {:.4f}".format(
                    epoch, tr_loss, tr_acc, te_loss, te_acc)

            # Show the best accuracy for the nearest milestone not yet passed
            for m in milestones:
                if epoch <= m:
                    output_str += " | Best@{}: {:.4f}".format(m, best_acc[m])
                    break

            output_str += " | Time: {:.1f}s".format(elapsed_time)
            print(output_str)

            # Plot and save (only if not train_only mode)
            if not args.train_only and train_losses:
                _save_curve_plot(train_losses, test_losses, 'Loss', 'Loss',
                                 os.path.join(args.log_dir, 'loss.png'))
                _save_curve_plot(train_accs, test_accs, 'Accuracy', 'Accuracy',
                                 os.path.join(args.log_dir, 'acc.png'))
        else:
            elapsed_time = time.time() - epoch_start_time
            # Log training-only metrics
            log_epoch_metrics(experiment_data, epoch, tr_loss, tr_acc, None, None, elapsed_time)
            print("Epoch {:03d} | Train Loss: {:.4f}, Acc: {:.4f} | (eval skipped)".format(
                epoch, tr_loss, tr_acc))

        # Save model every epoch (the same file is overwritten each time)
        torch.save(model.backbone_model.state_dict(), os.path.join(args.log_dir, 'model.pth'))

    # Create milestone data dictionary
    milestone_data = {
        'best_acc_20': best_acc[20],
        'best_acc_50': best_acc[50],
        'best_acc_100': best_acc[100],
        'best_acc_overall': best_acc_overall,
        'best_epoch_20': best_epoch[20],
        'best_epoch_50': best_epoch[50],
        'best_epoch_100': best_epoch[100],
        'best_epoch_overall': best_epoch_overall
    }

    # Save detailed metrics with milestone data
    save_detailed_metrics(experiment_data, args, milestone_data)

    total_time = time.time() - start_time
    print("\nTraining completed in {:.1f}s ({} epochs)".format(total_time, args.epochs))
    if train_accs:
        print("Final performance - Train: {:.4f}, Test: {:.4f}".format(train_accs[-1], test_accs[-1]))

    # Mark experiment as completed successfully
    print("EXPERIMENT_COMPLETED_SUCCESSFULLY")


if __name__ == '__main__':
    main()
