import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import pickle
from sklearn.metrics import mean_squared_error
from data_utils import H36MDataset
import config as cfg

def load_model(model, model_path):
    """Restore ``model``'s weights from a checkpoint file and report its epoch.

    The checkpoint is loaded with every tensor mapped to the CPU and must be a
    mapping that contains ``'model_state_dict'`` and ``'epoch'`` entries.

    Args:
        model: Module whose parameters are overwritten in place.
        model_path: Path of the checkpoint file to read with ``torch.load``.

    Returns:
        Tuple ``(model, epoch)``: the same module object that was passed in,
        and the epoch value stored in the checkpoint.
    """
    cpu = torch.device('cpu')
    # NOTE: depending on the installed PyTorch version, torch.load may unpickle
    # arbitrary objects -- only load checkpoints from trusted sources.
    ckpt = torch.load(model_path, map_location=cpu)
    model.load_state_dict(ckpt['model_state_dict'])
    saved_epoch = ckpt['epoch']
    print(f"Loaded model from epoch {saved_epoch}")
    return model, saved_epoch

def denormalize_sequence(sequence, scaler_path):
    """Undo feature scaling using a pickled scaler loaded from ``scaler_path``.

    ``sequence`` (NumPy array or torch tensor) is flattened to a 2-D
    ``(rows, n_features)`` matrix, passed through ``scaler.inverse_transform``
    and reshaped back to its original shape.

    ``n_features`` is taken from the scaler's ``n_features_in_`` attribute when
    it has one (fitted scikit-learn transformers expose it). Inputs whose last
    axis packs several frames, e.g. ``(batch, frames * features)``, are then
    split into per-frame rows the scaler can accept. This assumes a frame-major
    layout in which each frame's features are contiguous. Without that
    attribute, the width falls back to the length of the last axis, or to
    ``2 * cfg.NUM_JOINTS`` for 1-D input.

    Args:
        sequence: Normalized data as an ``np.ndarray`` or ``torch.Tensor``.
        scaler_path: Path to a pickled object exposing ``inverse_transform``.

    Returns:
        ``np.ndarray`` with the same shape as ``sequence``.
    """
    # SECURITY: pickle.load can execute arbitrary code from the file; only use
    # scaler files from a trusted source.
    with open(scaler_path, 'rb') as f:
        scaler = pickle.load(f)
    if isinstance(sequence, torch.Tensor):
        sequence = sequence.detach().cpu().numpy()
    sequence = np.asarray(sequence)
    original_shape = sequence.shape
    # Prefer the width the scaler was fitted with. Whenever the last axis
    # already matches it, this is identical to using original_shape[-1].
    n_features = getattr(scaler, 'n_features_in_', None)
    if n_features is None:
        n_features = original_shape[-1] if len(original_shape) > 1 else 2 * cfg.NUM_JOINTS
    flattened = sequence.reshape(-1, n_features)
    denormalized = scaler.inverse_transform(flattened)
    return denormalized.reshape(original_shape)

def visualize_skeleton(joints_2d, frame_idx=0, ax=None, color='blue', alpha=1.0, label=None,
                       connections=None):
    """Draw one frame of a 2D skeleton as line segments on a Matplotlib axis.

    Args:
        joints_2d: Joint coordinates in one of these layouts:
            ``(frames, joints, 2)`` or ``(frames, joints * 2)``, in which case
            frame ``frame_idx`` is drawn; or ``(joints, 2)`` / flat
            ``(joints * 2,)``, drawn as-is. The chosen frame is reshaped to
            ``(-1, 2)`` (x, y) pairs.
        frame_idx: Frame to draw when ``joints_2d`` holds several frames.
        ax: Axis to draw on. A new 8x8-inch figure is created when ``None``.
        color: Line/marker color for every bone.
        alpha: Line/marker opacity for every bone.
        label: If given, an empty proxy line carrying this label is added so
            the skeleton appears in ``ax.legend()``.
        connections: Iterable of ``(start, end)`` joint-index pairs to connect.
            Defaults to a 15-bone layout that links joint 0 to five chains of
            three joints. That layout needs at least 16 joints.

    Returns:
        The axis that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))
    if connections is None:
        connections = [
            (0, 1), (1, 2), (2, 3),
            (0, 4), (4, 5), (5, 6),
            (0, 7), (7, 8), (8, 9),
            (0, 10), (10, 11), (11, 12),
            (0, 13), (13, 14), (14, 15),
        ]
    ndim = len(joints_2d.shape)
    # A 2-D array whose last axis is not 2 is a (frames, joints * 2) sequence.
    # frame_idx must select a row here rather than being silently ignored.
    if ndim == 3 or (ndim == 2 and joints_2d.shape[-1] != 2):
        joints = joints_2d[frame_idx]
    else:
        joints = joints_2d
    joints = joints.reshape(-1, 2)
    for start, end in connections:
        ax.plot([joints[start, 0], joints[end, 0]],
                [joints[start, 1], joints[end, 1]],
                color=color, alpha=alpha, marker='o', linewidth=2)
    if label:
        ax.plot([], [], color=color, label=label)
    ax.set_aspect('equal')
    # invert_yaxis() toggles the direction, so calling it unconditionally would
    # flip the axis back when a second skeleton is overlaid on the same ax.
    # The inverted y-axis presumably reflects image-style (y-down) coordinates.
    if not ax.yaxis_inverted():
        ax.invert_yaxis()
    return ax

def visualize_prediction(input_seq, true_seq, pred_seq, sample_idx=0, save_path=None):
    """Show the final frame of the input, ground-truth and predicted poses side by side.

    Each ``*_seq`` argument is indexed with ``sample_idx`` and reshaped to
    ``(-1, cfg.NUM_JOINTS, 2)``. The last frame of each is drawn in its own
    panel (blue / green / red), with a legend and grid. When ``save_path`` is
    truthy the figure is saved there. The figure is then displayed with
    ``plt.show()``.
    """
    panel_specs = [
        (input_seq, 'blue', 'Input', 'Last Input Frame'),
        (true_seq, 'green', 'Ground Truth', 'Last Ground Truth Frame'),
        (pred_seq, 'red', 'Prediction', 'Last Predicted Frame'),
    ]
    _, panels = plt.subplots(1, 3, figsize=(15, 5))
    for panel, (seq, colour, legend_text, heading) in zip(panels, panel_specs):
        frames = seq[sample_idx].reshape(-1, cfg.NUM_JOINTS, 2)
        visualize_skeleton(frames, frame_idx=-1, ax=panel, color=colour, label=legend_text)
        panel.set_title(heading)
    for panel in panels:
        panel.legend()
        panel.grid(True)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
        print(f"Visualization saved to {save_path}")
    plt.show()

def evaluate_model(model, test_loader, device, results_dir, scaler_path, num_visualizations=5):
    """Run ``model`` over ``test_loader``, report MSE, and save plots and metrics.

    For every ``(input_seq, target_seq)`` batch, the model is called as
    ``model(target_seq, input_seq)``. The first of its three outputs is
    compared against ``target_seq``.

    The reported MSE is the mean squared error over *all* elements of the
    test set. It is computed on the tensors exactly as the loader yields them,
    i.e. before denormalization. Summing squared errors and dividing by the
    total element count weights every element equally, so a smaller final
    batch does not skew the average (a mean of per-batch means would).

    Side effects: creates ``<results_dir>/visualizations`` and saves up to
    ``num_visualizations`` comparison plots there via
    ``visualize_prediction``. It also writes
    ``<results_dir>/evaluation_metrics.txt``.

    Args:
        model: Module returning a 3-tuple whose first item is the prediction.
        test_loader: Iterable of ``(input_seq, target_seq)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.
        results_dir: Directory receiving plots and the metrics file.
        scaler_path: Pickled scaler passed to ``denormalize_sequence``.
        num_visualizations: Maximum number of samples to plot (default 5).

    Returns:
        ``(avg_mse, inputs, targets, predictions)``. The three arrays are the
        per-batch arrays stacked along axis 0 and then denormalized.

    Raises:
        ValueError: If ``test_loader`` yields no data.
    """
    model.eval()
    inputs, predictions, targets = [], [], []
    squared_error_sum = 0.0
    element_count = 0
    with torch.no_grad():
        for input_seq, target_seq in test_loader:
            input_seq, target_seq = input_seq.to(device), target_seq.to(device)
            # NOTE(review): the target is fed to the model alongside the input,
            # so this looks like a reconstruction score rather than a forecast
            # from input_seq alone -- confirm this is intended.
            recon_batch, _, _ = model(target_seq, input_seq)
            squared_error_sum += torch.sum((recon_batch - target_seq) ** 2).item()
            element_count += target_seq.numel()
            inputs.append(input_seq.cpu().numpy())
            predictions.append(recon_batch.cpu().numpy())
            targets.append(target_seq.cpu().numpy())
    if element_count == 0:
        raise ValueError("test_loader yielded no data; cannot compute MSE")
    avg_mse = squared_error_sum / element_count
    print(f"Average MSE: {avg_mse:.4f}")

    vis_dir = os.path.join(results_dir, 'visualizations')
    os.makedirs(vis_dir, exist_ok=True)
    all_inputs_denorm = denormalize_sequence(np.vstack(inputs), scaler_path)
    all_predictions_denorm = denormalize_sequence(np.vstack(predictions), scaler_path)
    all_targets_denorm = denormalize_sequence(np.vstack(targets), scaler_path)
    for i in range(min(num_visualizations, len(all_inputs_denorm))):
        save_path = os.path.join(vis_dir, f'sample_{i}.png')
        visualize_prediction(all_inputs_denorm, all_targets_denorm, all_predictions_denorm,
                             sample_idx=i, save_path=save_path)
    with open(os.path.join(results_dir, 'evaluation_metrics.txt'), 'w', encoding='utf-8') as f:
        f.write(f"Average MSE: {avg_mse:.4f}\n")
    return avg_mse, all_inputs_denorm, all_targets_denorm, all_predictions_denorm

def generate_samples(model, input_seq, num_samples=5, device='cpu', scaler_path=None):
    """Draw ``num_samples`` decoder outputs conditioned on ``input_seq``.

    Each draw samples ``z ~ N(0, I)`` with shape ``(batch, cfg.LATENT_DIM)`` and
    calls ``model.decode(z, input_seq)``. A 1-D ``input_seq`` is treated as a
    batch of one. Non-tensor input is converted with ``torch.FloatTensor``.

    Args:
        model: Module exposing ``decode(z, condition)``.
        input_seq: Conditioning input (tensor or array-like).
        num_samples: Number of latent draws.
        device: Device for the conditioning input and latent vectors.
        scaler_path: If truthy, outputs are passed through
            ``denormalize_sequence`` with this scaler.

    Returns:
        ``np.ndarray`` stacking the draws along a new leading axis, i.e.
        ``(num_samples, *decode_output_shape)``.
    """
    model.eval()
    if not isinstance(input_seq, torch.Tensor):
        input_seq = torch.FloatTensor(input_seq)
    input_seq = input_seq.to(device)
    if len(input_seq.shape) == 1:
        input_seq = input_seq.unsqueeze(0)
    batch_size = input_seq.size(0)
    draws = []
    with torch.no_grad():
        for _ in range(num_samples):
            # Sampled on the CPU generator and then moved, so seeded runs stay
            # reproducible regardless of device.
            z = torch.randn(batch_size, cfg.LATENT_DIM).to(device)
            draws.append(model.decode(z, input_seq).cpu().numpy())
    samples = np.array(draws)
    if scaler_path and draws:
        # denormalize_sequence flattens to (rows, features) before
        # inverse-transforming. One call over the stacked draws therefore
        # yields the same values as one call per draw (assuming each decode
        # output has at least 2 dims), but reads the pickled scaler once
        # instead of num_samples times.
        samples = denormalize_sequence(samples, scaler_path)
    return samples
