import torch
import os
import torch.optim as optim
import numpy as np
import config as cfg
from model import EnhancedCVAE_NoResidualNoSE
from train import train
from data_utils import get_dataloaders
import matplotlib.pyplot as plt

def _log_losses(losses):
    """Return the natural log of ``losses`` with non-positive entries as NaN.

    Calling ``np.log`` directly on values <= 0 produces ``-inf``/``nan`` and a
    RuntimeWarning. Masking them first keeps the warning out of the training
    output; Matplotlib draws NaN entries as gaps in the line.
    """
    values = np.asarray(losses, dtype=float)
    logged = np.full(values.shape, np.nan)
    positive = values > 0
    logged[positive] = np.log(values[positive])
    return logged


def plot_training_curves(train_losses, val_losses):
    """Plot per-epoch losses and save them as ``loss_curves.png``.

    The figure has two panels: the raw train/validation losses on the left
    and their natural logarithm on the right. The image is written to
    ``cfg.RESULTS_DIR``, which is created if it does not exist yet.

    Args:
        train_losses: Sequence of training loss values, one per epoch.
        val_losses: Sequence of validation loss values, one per epoch.
    """
    # Ensure the target directory exists so savefig cannot fail on a fresh run.
    os.makedirs(cfg.RESULTS_DIR, exist_ok=True)

    plt.figure(figsize=(12, 5))
    try:
        # Left panel: losses on a linear scale.
        plt.subplot(1, 2, 1)
        plt.plot(train_losses, label='Train Loss')
        plt.plot(val_losses, label='Validation Loss')
        plt.title('Training and Validation Losses')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        # Right panel: log of the same losses.
        plt.subplot(1, 2, 2)
        plt.plot(_log_losses(train_losses), label='Log Train Loss')
        plt.plot(_log_losses(val_losses), label='Log Val Loss')
        plt.title('Log-scale Losses')
        plt.xlabel('Epoch')
        plt.ylabel('Log Loss')
        plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(cfg.RESULTS_DIR, 'loss_curves.png'))
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close()

def main():
    """Train the CVAE, save the final checkpoint, and export test results.

    Steps, in order: seed the RNGs, pick the device, build the data loaders,
    construct the model, optimizer and scheduler, run ``train``, write
    ``final_model.pth`` into ``cfg.RESULTS_DIR``, then run one test batch
    through the model and save the sample outputs and the loss curves.
    """
    # Set random seeds
    torch.manual_seed(cfg.SEED)
    np.random.seed(cfg.SEED)

    # Create the results directory before training starts: the final
    # checkpoint is written there, and a missing directory would otherwise
    # only surface after the whole training run has finished.
    os.makedirs(cfg.RESULTS_DIR, exist_ok=True)

    # Select device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Load dataset
    print('Loading data...')
    input_frames = 10
    output_frames = 10
    batch_size = 128
    train_loader, test_loader = get_dataloaders(
        data_root=cfg.PENN_ACTION_PATH,
        batch_size=batch_size,
        input_frames=input_frames,
        output_frames=output_frames
    )
    print(f"Data loaded. Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")

    # Initialize model
    print("Creating model...")
    input_dim = 13 * 2 * input_frames  # 13 joints, 2 coordinates per joint
    cond_dim = input_dim  # Condition is the input sequence itself
    # NOTE(review): the model is called below with the *target* sequence as
    # its first argument, so input_dim presumably has to equal
    # 13 * 2 * output_frames. That holds only while
    # input_frames == output_frames -- confirm against the model definition.
    model = EnhancedCVAE_NoResidualNoSE(
        input_dim=input_dim,
        cond_dim=cond_dim,
        latent_dim=64,
        hidden_dim=256
    ).to(device)

    # Print model parameter count
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model created with {total_params:,} parameters")

    # Define optimizer and learning rate scheduler
    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=10,
        T_mult=2,
        eta_min=1e-5
    )

    # Start training
    print("Starting training...")
    train_losses, val_losses = train(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        epochs=cfg.EPOCHS,
        kl_weight=cfg.KL_WEIGHT,
        save_path=cfg.CHECKPOINT_DIR,
        save_interval=cfg.SAVE_INTERVAL
    )

    # Save final model. The loss histories may be empty (e.g. zero epochs),
    # in which case None is stored instead of raising IndexError.
    final_path = os.path.join(cfg.RESULTS_DIR, 'final_model.pth')
    torch.save({
        'epoch': cfg.EPOCHS,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_losses[-1] if len(train_losses) else None,
        'val_loss': val_losses[-1] if len(val_losses) else None
    }, final_path)
    print(f'Training completed. Final model saved to {final_path}')

    # Generate test results from the first test batch, if there is one.
    print('Generating test results...')
    model.eval()
    first_batch = next(iter(test_loader), None)
    if first_batch is None:
        print('Test loader is empty; skipping sample generation.')
    else:
        test_inputs, test_targets = first_batch
        with torch.no_grad():
            test_inputs, test_targets = test_inputs.to(device), test_targets.to(device)
            reconstructions, _, _ = model(test_targets, test_inputs)
        save_results(test_inputs, test_targets, reconstructions)

    # Plotting needs no autograd context, so it runs outside no_grad.
    plot_training_curves(train_losses, val_losses)

def save_results(inputs, targets, reconstructions, num_samples=50):
    """Write per-sample input/target/reconstruction arrays to ``cfg.RESULTS_DIR``.

    Each sample ``i`` is stored as ``sample_{i}.npz`` with the keys
    ``input_sequence``, ``target_sequence`` and ``reconstruction``. If
    ``scaler.pkl`` exists in ``cfg.RESULTS_DIR`` and can be applied, all three
    arrays are passed through ``scaler.inverse_transform`` first. Otherwise
    the values are saved as given and a warning is printed.

    Args:
        inputs: Tensor whose first axis indexes samples.
        targets: Tensor whose first axis indexes samples.
        reconstructions: Tensor whose first axis indexes samples.
        num_samples: Maximum number of samples to write.
    """
    import pickle  # not imported at module level

    os.makedirs(cfg.RESULTS_DIR, exist_ok=True)

    # Convert to NumPy once, before denormalization is attempted. Rebinding
    # the tensor names inside the try block would make the fallback call
    # .cpu() on an ndarray if denormalization failed partway through.
    raw_arrays = [
        tensor.detach().cpu().numpy()
        for tensor in (inputs, targets, reconstructions)
    ]

    def denormalize(data, scaler):
        # Flatten to 2-D over the last axis, invert, then restore the shape.
        original_shape = data.shape
        flattened = data.reshape(-1, original_shape[-1])
        denormalized = scaler.inverse_transform(flattened)
        return denormalized.reshape(original_shape)

    scaler_path = os.path.join(cfg.RESULTS_DIR, 'scaler.pkl')
    try:
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # scaler files produced by this project's own pipeline.
        with open(scaler_path, 'rb') as f:
            scaler = pickle.load(f)
        arrays = [denormalize(array, scaler) for array in raw_arrays]
    except Exception as e:
        # Best effort: keep the normalized values, but report why.
        print(f'Warning: could not denormalize results ({e!r}); '
              'saving values without denormalization.')
        arrays = raw_arrays

    inputs, targets, reconstructions = arrays

    # The batch may hold fewer than num_samples items; report the real count.
    saved = max(0, min(num_samples, inputs.shape[0]))
    for i in range(saved):
        np.savez(
            os.path.join(cfg.RESULTS_DIR, f'sample_{i}.npz'),
            input_sequence=inputs[i],
            target_sequence=targets[i],
            reconstruction=reconstructions[i]
        )
    print(f'Saved {saved} sample results to {cfg.RESULTS_DIR}')

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
