#!/usr/bin/env python3
"""
Post-training code for RandOpt toy experiments
"""

import copy
import time
import numpy as np
import torch
from toy_expts_v4 import datasets
from toy_expts_v4 import eval as eval_module
from toy_expts_v4.models import positional_encoding


def RandOpt(base_model, posttrain_dataset, args, n_samples=10, sigma=0.01, top_k=5, weighted=True, temperature=1.0):
    """Post-train model with RandOpt on the post-training train set.

    Draws ``n_samples`` randomly perturbed copies of ``base_model`` (one per
    integer seed in ``range(n_samples)``), rolls each copy out
    autoregressively on the post-training data, scores it by MSE against the
    future targets, and returns the ``top_k`` lowest-MSE candidates.

    Args:
        base_model: Model to perturb. It is deep-copied for every candidate and
            is never modified itself. Must provide ``perturb_weights(seed,
            sigma)``, ``eval()`` and ``AR_rollout(ctx_y, fut_sz)``.
        posttrain_dataset: Dataset identifier forwarded to
            ``datasets.load_data``.
        args: Config namespace. ``args.posttrain_dataset_sz`` and
            ``args.fut_sz`` are read here; ``args`` is also forwarded to
            ``datasets.load_data``.
        n_samples: Number of perturbed candidates to evaluate.
        sigma: Perturbation scale passed to ``perturb_weights``.
        top_k: Maximum number of candidates to return. Non-positive values
            return an empty list.
        weighted: Not used by this function (presumably consumed by an
            ensembling step elsewhere - TODO confirm).
        temperature: Not used by this function (presumably consumed by an
            ensembling step elsewhere - TODO confirm).

    Returns:
        A list of at most ``top_k`` ``(seed, mse)`` tuples in ascending MSE
        order. Candidates whose MSE is NaN (e.g. a diverged rollout) rank
        after every real score. Equal scores keep ascending seed order
        because the sort is stable.
    """
    print(f"\n{'='*60}\nPOST-TRAINING MODEL\n{'='*60}")

    t0 = time.perf_counter()

    # Only the context targets and future targets are needed for scoring.
    _, ctx_y, _, fut_y = datasets.load_data(args.posttrain_dataset_sz, posttrain_dataset, args)

    # Loop-invariant: the targets and the rollout horizon are the same for every candidate.
    fut_y_np = fut_y.cpu().numpy()
    fut_sz = args.fut_sz

    model_scores = []  # (seed, mse) per evaluated candidate
    for seed in range(n_samples):
        perturbed = copy.deepcopy(base_model)  # base_model itself is never perturbed
        perturbed.perturb_weights(seed, sigma)
        perturbed.eval()

        with torch.no_grad():
            y_preds = perturbed.AR_rollout(ctx_y, fut_sz)
            mse = eval_module.compute_mse(y_preds.cpu().numpy(), fut_y_np)

        model_scores.append((seed, mse))
        # Release this copy before the next deepcopy so two copies never coexist.
        del perturbed

    def _rank_key(entry):
        # NaN compares False against everything, which makes a plain sort on
        # the score order-dependent; rank NaN scores after every real score.
        score = entry[1]
        is_nan = bool(np.isnan(score))
        return (is_nan, 0.0 if is_nan else score)

    model_scores.sort(key=_rank_key)
    # Clamp so a negative top_k does not slice off only the tail of the list.
    top_k_models = model_scores[:max(top_k, 0)]

    print(f"Completed in {time.perf_counter() - t0:.2f}s")

    return top_k_models
