#!/usr/bin/env python3
"""
Evaluation and plotting code for basic RandOpt for toy experiments
"""

import copy
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from toy_expts_v4.models import positional_encoding
from toy_expts_v4 import datasets

# Evenly spaced grid of 100 x-values covering [-10, 10], both endpoints included.
X_RANGE = np.linspace(-10, 10, 100)


def _save_fig(args, name, timestamp=None):
    """Write the current pyplot figure to ``args.logging_dir`` as PNG and PDF.

    Both files share the stem ``{name}_{timestamp}``. When ``timestamp`` is
    falsy, the current local time formatted as ``%Y%m%d_%H%M%S`` is used
    instead. The PNG is written at 150 dpi; the PDF uses pyplot's defaults.
    """
    if timestamp:
        stamp = timestamp
    else:
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    stem = f"{args.logging_dir}/{name}_{stamp}"
    plt.savefig(stem + ".png", dpi=150)
    plt.savefig(stem + ".pdf")


def _set_ylim(ax, y_true, margin=0.2):
    """Fit the y-axis of ``ax`` to the value range of ``y_true`` plus padding.

    The lower and upper limits are ``y_true.min()`` and ``y_true.max()``,
    each pushed outwards by ``margin`` times the distance between them.
    """
    lower = y_true.min()
    upper = y_true.max()
    pad = (upper - lower) * margin
    limits = [lower - pad, upper + pad]
    ax.set_ylim(limits)


def eval_model(model, dataset, args):
    """Return the model's loss on one dataset, with gradients disabled.

    ``dataset`` must unpack into exactly four items
    ``(ctx_x, ctx_y, fut_x, fut_y)``. Only ``ctx_y`` and the first column of
    ``fut_y`` (selected as ``fut_y[:, [0]]``, which keeps that axis) are
    passed to ``model.compute_loss``. The model is switched to eval mode
    first and is not switched back. ``args`` is accepted but not used.
    """
    _, context_y, _, future_y = dataset

    model.eval()
    with torch.no_grad():
        first_future = future_y[:, [0]]
        return model.compute_loss(context_y, first_future)


def compute_mse(y_pred, y_true):
    """Return the mean of the squared element-wise difference of the inputs.

    The difference ``y_pred - y_true`` follows the operands' own broadcasting
    rules, and the result's ``.mean()`` method is used, so any array type
    providing both operations is accepted.
    """
    squared_error = (y_pred - y_true) ** 2
    return squared_error.mean()


def plot_ensemble_prediction(base_model, top_k_models, sigma, ctx_x, ctx_y, fut_x, fut_y, args, ax, weighted=True, temperature=0.001):
    """Plot every top-k member's rollout and the combined ensemble on ``ax``.

    For each entry of ``top_k_models`` a deep copy of ``base_model`` is
    perturbed with ``perturb_weights(entry[0], sigma)`` and rolled out for
    ``args.fut_sz`` steps from ``ctx_y``. Each rollout is drawn as a thin
    green line. The member predictions are then combined into one ensemble
    curve, drawn on top together with the context points and the
    ground-truth future.

    Args:
        base_model: Object exposing ``perturb_weights`` and ``AR_rollout``;
            it is deep-copied once per member before being perturbed.
        top_k_models: Iterable of indexable entries. Only element 0 of each
            entry is used (presumably a seed or noise identifier; confirm
            against the model class).
        sigma: Forwarded unchanged to ``perturb_weights``.
        ctx_x, ctx_y: Context tensors, scatter-plotted as given. ``ctx_y``
            gets a leading axis via ``unsqueeze(0)`` before the rollout.
        fut_x, fut_y: Future tensors; ``fut_y`` is the ground truth that
            every MSE in this function is measured against.
        args: Namespace providing ``fut_sz`` and, when ``weighted`` is true,
            ``top_k``.
        ax: Axes-like object providing ``plot``, ``scatter``, ``set_ylim``
            and ``legend``.
        weighted: If true and ``args.top_k > 1``, members are combined with
            softmax weights of ``-MSE / temperature``; otherwise with a
            plain mean.
        temperature: Softmax temperature; smaller values concentrate the
            weight on the lowest-MSE member.

    Note:
        The ensemble weights come from each member's MSE on ``fut_y``
        itself, i.e. the same targets the ensemble MSE in the legend is
        reported on. The y-axis is fixed to ``[-2, 2]``.
    """
    # Tensor -> NumPy conversions do not depend on the member; do them once.
    fut_x_np = fut_x.cpu().numpy()
    fut_y_np = fut_y.cpu().numpy()
    fut_y_batched = fut_y.unsqueeze(0).cpu().numpy()

    member_mses = []
    member_preds = []
    for i, model in enumerate(top_k_models):
        perturbed = copy.deepcopy(base_model)
        perturbed.perturb_weights(model[0], sigma)

        # Inference only: the rollout is detached right away, so skip the
        # autograd bookkeeping, as eval_model does.
        with torch.no_grad():
            y_preds = perturbed.AR_rollout(ctx_y.unsqueeze(0), args.fut_sz)
        y_preds_np = y_preds.detach().cpu().numpy()

        label = 'top k preds' if i == 0 else None
        ax.plot(fut_x_np, y_preds_np[0], label=label, linestyle='-', linewidth=0.5, color='g', alpha=0.3)

        member_mses.append(compute_mse(y_preds_np, fut_y_batched))
        member_preds.append(y_preds_np[0])

    top_k_preds = np.array(member_preds)
    top_k_mses = np.array(member_mses)

    if weighted and args.top_k > 1:
        # Softmax over negative MSE; subtracting the max keeps exp() from
        # overflowing without changing the normalised weights.
        neg_mses = -top_k_mses / temperature
        weights = np.exp(neg_mses - np.max(neg_mses))
        weights /= weights.sum()
        ensemble_preds = np.sum(weights[:, None] * top_k_preds, axis=0)
    else:
        ensemble_preds = np.mean(top_k_preds, axis=0)

    mse_ens = compute_mse(ensemble_preds, fut_y_batched)

    ax.scatter(ctx_x.cpu().numpy(), ctx_y.cpu().numpy(), label='ctx', c='b', s=10)
    ax.plot(fut_x_np, fut_y_np, label='gt fut', linestyle='-', color='k')
    ax.plot(fut_x_np, ensemble_preds, 'g-', linewidth=2.5, label=f'Ensemble (MSE={mse_ens:.4f})', zorder=13)
    ax.set_ylim([-2, 2])
    ax.legend()


def _print_ensemble_results(args, sigma, n_samples, top_k, n_test, mse_pre, mse_ens, imp):
    """Print a summary comparing the pretrained model with the ensemble.

    Args:
        args: Not used; kept so existing call sites keep working.
        sigma: Value shown in the header as ``σ``.
        n_samples: Value shown in the header as ``n``.
        top_k: Value shown in the header as ``top_k``.
        n_test: Not used; kept so existing call sites keep working.
        mse_pre: MSE of the pretrained (base) model, printed with 6 decimals.
        mse_ens: MSE of the ensemble, printed with 6 decimals.
        imp: Improvement in percent. Strictly positive values are reported
            as BETTER; zero and negative values are reported as WORSE.
    """
    rule = '=' * 60
    improved = imp > 0

    print(f"\n{rule}")
    print(f"ENSEMBLE RESULTS (σ={sigma}, n={n_samples}, top_k={top_k}):")
    print(f"  Base MSE: {mse_pre:.6f}")

    # Negative numbers already carry their sign; only positives need a '+'.
    sign = '+' if improved else ''
    print(f"\n  Test MSE: Pretrained={mse_pre:.6f}, Ensemble={mse_ens:.6f} ({sign}{imp:.2f}%)")

    status = "✓" if improved else "✗"
    result = "BETTER" if improved else "WORSE"
    print(f"  {status} Ensemble is {abs(imp):.2f}% {result} than pretrained")
    print(rule)