import os
import torch
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from model import EnhancedCVAE_NoResidualNoSE
from data_utils import H36MDataset
from train import train, validate, save_checkpoint
import config as cfg
import matplotlib.pyplot as plt
from loss_utils import loss_function

def _log_or_nan(values):
    """Return the elementwise natural log of ``values`` with NaN for entries <= 0.

    Plain ``np.log`` emits RuntimeWarnings on zero/negative losses and yields
    -inf/NaN; returning NaN quietly leaves a gap in the plotted line instead.
    """
    arr = np.asarray(values, dtype=float)
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(arr > 0, np.log(arr), np.nan)


def plot_training_curves(train_losses, val_losses, save_path=None):
    """Save a two-panel figure of per-epoch train/validation losses.

    The left panel shows the raw losses and the right panel shows their
    natural log. Non-positive losses appear as gaps in the log panel.

    Args:
        train_losses: sequence of training losses, one entry per epoch.
        val_losses: sequence of validation losses, one entry per epoch.
        save_path: output image path; defaults to
            ``<cfg.RESULTS_DIR>/loss_curves.png``.
    """
    if save_path is None:
        save_path = os.path.join(cfg.RESULTS_DIR, 'loss_curves.png')
    fig, (ax_lin, ax_log) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        ax_lin.plot(train_losses, label='Train Loss')
        ax_lin.plot(val_losses, label='Validation Loss')
        ax_lin.set_title('Training and Validation Losses')
        ax_lin.set_xlabel('Epoch')
        ax_lin.set_ylabel('Loss')
        ax_lin.legend()

        ax_log.plot(_log_or_nan(train_losses), label='Log Train Loss')
        ax_log.plot(_log_or_nan(val_losses), label='Log Val Loss')
        ax_log.set_title('Log-scale Losses')
        ax_log.set_xlabel('Epoch')
        ax_log.set_ylabel('Log Loss')
        ax_log.legend()

        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Always release the figure, even if savefig fails (e.g. a missing directory),
        # so repeated calls do not accumulate open figures.
        plt.close(fig)

def main():
    """Train the CVAE end to end, then save a final checkpoint, sample outputs and loss curves.

    All hyper-parameters and paths come from the ``config`` module (``cfg``).

    Side effects:
        - creates ``cfg.CHECKPOINT_DIR`` and ``cfg.RESULTS_DIR``;
        - writes ``final_model.pth`` into the checkpoint directory;
        - writes ``loss_curves.png`` and per-sample ``.npz`` files into the
          results directory.

    ``train()`` is given ``save_path``/``save_interval`` and presumably writes
    periodic checkpoints as well; confirm against ``train.py``.
    """
    # Seed torch and numpy so shuffling, weight init and sampling are repeatable.
    torch.manual_seed(cfg.SEED)
    np.random.seed(cfg.SEED)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    print('Loading data...')
    train_dataset = H36MDataset(data_root=cfg.DATA_ROOT, input_frames=cfg.INPUT_FRAMES, output_frames=cfg.OUTPUT_FRAMES, train=True)
    test_dataset = H36MDataset(data_root=cfg.DATA_ROOT, input_frames=cfg.INPUT_FRAMES, output_frames=cfg.OUTPUT_FRAMES, train=False)
    train_loader = DataLoader(train_dataset, batch_size=cfg.BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=cfg.BATCH_SIZE, shuffle=False)
    print(f'Data loaded. Train batches: {len(train_loader)}, Test batches: {len(test_loader)}')

    print('Creating model...')
    model = EnhancedCVAE_NoResidualNoSE(input_dim=cfg.TARGET_DIM, cond_dim=cfg.CONDITION_DIM, latent_dim=cfg.LATENT_DIM, hidden_dim=cfg.HIDDEN_DIM).to(device)
    param_count = sum(p.numel() for p in model.parameters())
    print(f'Model created with {param_count:,} parameters')
    optimizer = optim.Adam(model.parameters(), lr=cfg.LEARNING_RATE, weight_decay=cfg.WEIGHT_DECAY)
    # Warm restarts: the first cycle lasts 10 epochs and each later cycle doubles in length.
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-5)
    os.makedirs(cfg.CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(cfg.RESULTS_DIR, exist_ok=True)

    print('Starting training...')
    train_losses, val_losses = train(model=model, train_loader=train_loader, test_loader=test_loader, optimizer=optimizer, scheduler=scheduler,
                                     device=device, epochs=cfg.EPOCHS, kl_weight=cfg.KL_WEIGHT, save_path=cfg.CHECKPOINT_DIR,
                                     save_interval=cfg.SAVE_INTERVAL)

    final_path = os.path.join(cfg.CHECKPOINT_DIR, 'final_model.pth')
    torch.save({
        'epoch': cfg.EPOCHS,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Saved so a resumed run continues the warm-restart schedule rather than restarting it.
        'scheduler_state_dict': scheduler.state_dict(),
        # Empty histories (e.g. EPOCHS == 0) would otherwise raise IndexError here.
        'train_loss': train_losses[-1] if train_losses else None,
        'val_loss': val_losses[-1] if val_losses else None,
    }, final_path)
    print(f'Training completed. Final model saved to {final_path}')

    # Plotting needs only the loss histories, so do it before the sample export
    # (which can fail or be skipped) and outside any no_grad context.
    plot_training_curves(train_losses, val_losses)

    print('Generating test results...')
    first_batch = next(iter(test_loader), None)
    if first_batch is None:
        print('Test set is empty; skipping sample export.')
        return
    model.eval()
    with torch.no_grad():
        test_inputs, test_targets = first_batch
        test_inputs, test_targets = test_inputs.to(device), test_targets.to(device)
        # The model is called with the ground-truth targets (same call as in evaluate_model),
        # so these are reconstructions of the targets, not prior-only predictions.
        reconstructions, _, _ = model(test_targets, test_inputs)
        save_results(test_inputs, test_targets, reconstructions)

def _as_numpy(data):
    """Return ``data`` as a NumPy array; torch tensors are detached and moved to CPU first."""
    if isinstance(data, torch.Tensor):
        return data.detach().cpu().numpy()
    return np.asarray(data)


def _denormalize(scaler, data):
    """Apply ``scaler.inverse_transform`` over the last axis of ``data``, keeping its shape."""
    shape = data.shape
    flat = data.reshape(-1, shape[-1])
    return np.asarray(scaler.inverse_transform(flat)).reshape(shape)


def save_results(inputs, targets, reconstructions, num_samples=50, results_dir=None):
    """Write up to ``num_samples`` (input, target, reconstruction) triples as ``.npz`` files.

    If ``<results_dir>/scaler.pkl`` exists, it is unpickled and its
    ``inverse_transform`` is applied to all three arrays. Each array is
    flattened over every axis but the last, transformed, and reshaped back.
    Denormalization is all-or-nothing. If the scaler is missing, unreadable,
    or fails on any array, the normalized arrays are saved unchanged. Failures
    other than a missing file also emit a warning.

    Args:
        inputs, targets, reconstructions: batch-first tensors or arrays;
            sample ``i`` is taken as ``x[i]``.
        num_samples: maximum number of samples to write.
        results_dir: output directory; defaults to ``cfg.RESULTS_DIR``.

    Returns:
        The number of sample files actually written.
    """
    import pickle
    import warnings

    if results_dir is None:
        results_dir = cfg.RESULTS_DIR
    os.makedirs(results_dir, exist_ok=True)

    # Convert exactly once, outside any try: previously the fallback path called
    # .cpu() on values that might already have been turned into numpy arrays.
    arrays = [_as_numpy(x) for x in (inputs, targets, reconstructions)]

    scaler_path = os.path.join(results_dir, 'scaler.pkl')
    scaler = None
    try:
        # NOTE(review): pickle.load can execute arbitrary code. scaler.pkl must
        # come from a trusted source (here it is read from our own results dir).
        with open(scaler_path, 'rb') as f:
            scaler = pickle.load(f)
    except FileNotFoundError:
        pass  # no scaler saved: keep the data normalized
    except Exception as err:  # corrupt or incompatible pickle: stay best-effort, but say so
        warnings.warn(f'Could not load scaler from {scaler_path} ({err!r}); saving normalized data.')

    if scaler is not None:
        try:
            # Builds a new list, so `arrays` is only replaced if every array succeeds.
            arrays = [_denormalize(scaler, a) for a in arrays]
        except Exception as err:
            warnings.warn(f'Could not denormalize results ({err!r}); saving normalized data.')

    inputs_np, targets_np, recon_np = arrays
    count = max(0, min(num_samples, inputs_np.shape[0]))
    for i in range(count):
        np.savez(os.path.join(results_dir, f'sample_{i}.npz'), input_sequence=inputs_np[i], target_sequence=targets_np[i], reconstruction=recon_np[i])
    # Report what was actually written, which may be fewer than num_samples.
    print(f'Saved {count} sample results to {results_dir}')
    return count

def evaluate_model(model, test_loader, device, kl_weight=None):
    """Compute and print the mean per-batch loss of ``model`` over ``test_loader``.

    Each batch is unpacked as ``(inputs, targets)``. The model is called as
    ``model(targets, inputs)`` and scored with ``loss_function``.

    Args:
        model: CVAE returning ``(reconstructions, mu, log_var)``.
        test_loader: iterable of ``(inputs, targets)`` batches; it does not
            need to support ``len()``.
        device: device the batches are moved to.
        kl_weight: weight of the KL term; defaults to ``cfg.KL_WEIGHT``.

    Returns:
        The average of the per-batch losses, or ``float('nan')`` if the loader
        yields no batches.
    """
    if kl_weight is None:
        kl_weight = cfg.KL_WEIGHT
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            reconstructions, mu, log_var = model(targets, inputs)
            loss, _, _, _ = loss_function(reconstructions, targets, mu, log_var, kl_weight)
            total_loss += loss.item()
            num_batches += 1
    # NOTE(review): this is an unweighted mean over batches. If loss_function
    # returns a per-batch mean, a short final batch is over-weighted. Confirm.
    if num_batches == 0:
        # Previously this divided by len(test_loader) == 0 and raised ZeroDivisionError.
        print('Test Loss: n/a (empty test loader)')
        return float('nan')
    avg_loss = total_loss / num_batches
    print(f'Test Loss: {avg_loss:.4f}')
    return avg_loss

def generate_samples(model, conditions, device, num_samples=1, latent_dim=None):
    """Decode ``num_samples`` standard-normal latent draws for every condition row.

    Args:
        model: object with ``eval()`` and ``decode(z, conditions)``.
        conditions: conditioning tensor whose first dimension is the batch;
            it is moved to ``device``.
        device: device on which noise is drawn and decoding runs.
        num_samples: number of draws per condition; must be >= 1.
        latent_dim: size of each latent vector; defaults to ``cfg.LATENT_DIM``.

    Returns:
        The decoder outputs stacked along dim 1. Assuming ``decode`` returns
        batch-first tensors, the shape is ``(batch, num_samples, ...)``.

    Raises:
        ValueError: if ``num_samples < 1``. Previously ``torch.stack`` failed
        on the empty list with an opaque RuntimeError.
    """
    if num_samples < 1:
        raise ValueError(f'num_samples must be >= 1, got {num_samples}')
    if latent_dim is None:
        latent_dim = cfg.LATENT_DIM
    model.eval()
    conditions = conditions.to(device)
    batch_size = conditions.size(0)  # loop-invariant
    with torch.no_grad():
        # Draw noise directly on the target device instead of allocating on the CPU and copying.
        # (On CUDA this uses the device RNG, so seeded draws differ from the old CPU-then-copy path.)
        results = [
            model.decode(torch.randn(batch_size, latent_dim, device=device), conditions)
            for _ in range(num_samples)
        ]
    return torch.stack(results, dim=1)

# Run the training pipeline only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
