#!/usr/bin/env python3
"""Basic RandOpt for toy experiments"""

import argparse
from datetime import datetime
import json
import os
import random
from stat import FILE_ATTRIBUTE_INTEGRITY_STREAM

import numpy as np
import torch

import matplotlib.pyplot as plt
import copy

from toy_expts import datasets
from toy_expts import models
from toy_expts import pretrain
from toy_expts import posttrain
from toy_expts import eval as eval_module

def parse_args(argv=None):
    """Parse command-line options for the toy RandOpt experiment.

    Args:
        argv: List of argument strings to parse. ``None`` is treated as an
            empty list, so every option takes its default (``sys.argv`` is NOT
            consulted); script entry points must pass ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the parsed options plus derived fields:
        ``top_k_list`` -- descending, de-duplicated list of top-k sizes;
        ``top_k`` -- the largest entry of ``top_k_list``;
        ``device`` -- a ``torch.device`` resolved from ``--device``.
    """
    p = argparse.ArgumentParser(description="Toy Expt v3", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Datasets
    p.add_argument("--pretrain_dataset", type=str, default=None)
    p.add_argument("--posttrain_dataset", type=str, default=None)
    p.add_argument("--test_dataset", type=str, default=None)
    p.add_argument("--res_x", type=float, default=0.1)

    # Pretraining
    p.add_argument("--pretrain_bsz", type=int, default=256)
    p.add_argument("--posttrain_dataset_sz", type=int, default=1024)
    p.add_argument("--pretrain_iters", type=int, default=1000)
    p.add_argument("--pretraining_lr", type=float, default=0.001)
    p.add_argument("--test_bsz", type=int, default=256)

    # Model
    p.add_argument("--width", type=int, default=128)
    p.add_argument("--depth", type=int, default=5)
    p.add_argument("--ctx_sz", type=int, default=50)
    p.add_argument("--fut_sz", type=int, default=100)
    p.add_argument("--init_ctx_sz", type=float, default=10)
    p.add_argument("--pos_encoding_dim", type=int, default=0)

    # Post-training (RandOpt)
    p.add_argument("--sigma", type=float, default=0.01)
    p.add_argument("--population_size", type=int, default=300)
    p.add_argument("--top_k", type=int, default=30)
    p.add_argument("--top_k_ratios", type=str, default=None)
    p.add_argument("--weighted_ensemble", action="store_true", default=False)
    p.add_argument("--ensemble_temperature", type=float, default=0.001)

    # Misc
    p.add_argument("--logging_dir", type=str, default="log")
    p.add_argument("--device", type=str, default="0")
    p.add_argument("--global_seed", type=int, default=42)
    p.add_argument("--skip_exp2", action="store_true")

    args = p.parse_args([] if argv is None else argv)

    # Post-process args
    if args.top_k_ratios:
        ratios = [float(r) for r in args.top_k_ratios.split(",")]
        # Turn each ratio into a count clamped to [1, population_size]; a ratio
        # above 1 would otherwise request more survivors than candidates.
        args.top_k_list = sorted(
            {min(args.population_size, max(1, int(r * args.population_size))) for r in ratios},
            reverse=True,
        )
        args.top_k = args.top_k_list[0]
    else:
        args.top_k_list = [args.top_k]

    # Honor --device instead of discarding it: "cpu" forces the CPU, a bare
    # index such as "1" selects cuda:1, and any other string (e.g. "cuda",
    # "cuda:1") goes to torch.device. Without CUDA, fall back to the CPU.
    requested = args.device.strip().lower()
    if requested == "cpu" or not torch.cuda.is_available():
        args.device = torch.device("cpu")
    elif requested.isdigit():
        args.device = torch.device(f"cuda:{requested}")
    else:
        args.device = torch.device(requested)

    return args


def set_seed(seed):
    """Seed the Python, NumPy, PyTorch (CPU) and PyTorch CUDA generators.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


def setup_logging(args):
    """Create the logging directory and write the run configuration to it.

    Writes ``args.json`` (indented JSON of ``vars(args)``) inside
    ``args.logging_dir``, creating the directory and any parents if needed.
    ``args.device`` is stringified in the saved copy only; the caller's
    namespace is not modified.

    Args:
        args: Parsed argparse namespace with ``logging_dir`` and ``device``.

    Returns:
        The logging directory path as a string.
    """
    logging_dir = str(args.logging_dir)
    os.makedirs(logging_dir, exist_ok=True)

    # Copy so the caller keeps its torch.device; json cannot encode that type.
    args_dict = dict(vars(args))
    args_dict['device'] = str(args.device)
    with open(os.path.join(logging_dir, "args.json"), "w", encoding="utf-8") as f:
        json.dump(args_dict, f, indent=4)

    return logging_dir


def create_model(args, dim_in):
    """Build a ``models.Net`` from the CLI options and initialize its weights.

    Args:
        args: Namespace supplying ``width``, ``depth``, ``device`` and
            ``pos_encoding_dim``.
        dim_in: Input dimensionality of the network.

    Returns:
        The freshly initialized network (output dimension fixed at 1).
    """
    net = models.Net(
        width=args.width,
        depth=args.depth,
        dim_in=dim_in,
        dim_out=1,
        device=args.device,
        pos_encoding_dim=args.pos_encoding_dim,
    )
    net.init_weights()
    return net


def main(args):
    """Run the toy RandOpt experiment end to end.

    Steps:
    1. Seed the RNGs and save the configuration.
    2. Build the base model and, if requested, pretrain it.
    3. Report and plot its test predictions.
    4. Unless ``--skip_exp2`` is set, plot random weight perturbations.
    5. Select the top-k perturbations with ``posttrain.RandOpt``.
    6. Plot and save the ensemble comparisons.

    Args:
        args: Namespace from ``parse_args``. ``args.logging_dir`` is replaced
            by the directory returned from ``setup_logging``.
    """
    set_seed(args.global_seed)

    print(f"{'='*60}\nToy Expt 1\n{'='*60}")
    print(f"Population: {args.population_size} | Top-K: {args.top_k_list}")

    args.logging_dir = setup_logging(args)

    # Create and pretrain base model
    base_model = create_model(args, dim_in=args.ctx_sz + 1)
    if args.pretrain_iters > 0:
        base_model = pretrain.pretrain_base_model(base_model, args.pretrain_dataset, args)

    test_dataset = datasets.load_data(args.test_bsz, args.test_dataset, args)
    # The loss is computed on the test split, so label it with that dataset.
    print(f"Base model loss on {args.test_dataset}: {eval_module.eval_model(base_model, test_dataset, args):.4f}")

    _plot_base_predictions(base_model, test_dataset, args)

    # RandOpt perturbation experiment
    if not args.skip_exp2:
        print(f"\n{'='*60}\nRANDOPT EXPERIMENT\n{'='*60}")
        _plot_perturbation_samples(base_model, test_dataset, args)
    else:
        print(f"\n{'='*60}\nSkipping RandOpt plots (--skip_exp2)\n{'='*60}")

    # Ensemble evaluation
    print(f"\n{'='*60}\nENSEMBLE EVALUATION\n{'='*60}")
    print(f"Strategy: Select top-{args.top_k} models based on post-training train set MSE")

    sigma = args.sigma
    print(f"\n--- Ensemble with σ={sigma} ---")
    top_k_models = posttrain.RandOpt(base_model, args.posttrain_dataset, args,
        n_samples=args.population_size, sigma=sigma, top_k=args.top_k
    )

    _plot_ensemble_comparison(base_model, top_k_models, sigma, test_dataset, args)
    _plot_combined(base_model, top_k_models, sigma, test_dataset, args)

    # NOTE: _print_summary expects a per-sigma results dict, which this
    # pipeline does not build, so no summary is printed here.


def _plot_base_predictions(base_model, test_dataset, args):
    """Plot the base model's rollout against ground truth for 3 test samples."""
    ctx_x, ctx_y, fut_x, fut_y = test_dataset
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    for i in range(3):
        y_preds = base_model.AR_rollout(ctx_y[[i]], args.fut_sz)
        mse_base = eval_module.compute_mse(y_preds[0].detach().cpu().numpy(), fut_y[i].unsqueeze(0).cpu().numpy())

        axs[i].scatter(ctx_x[i].cpu().numpy(), ctx_y[i].cpu().numpy(), label='ctx', c='b', s=10)
        axs[i].plot(fut_x[i].cpu().numpy(), fut_y[i].cpu().numpy(), label='gt fut', linestyle='-', color='k')
        axs[i].plot(fut_x[i].cpu().numpy(), y_preds[0].detach().cpu().numpy(), 'r-', linewidth=2.5, label=f'Base pred (MSE={mse_base:.4f})')
        axs[i].set_ylim([-2, 2])
        axs[i].legend()
    plt.show()
    plt.close(fig)


def _plot_perturbation_samples(base_model, test_dataset, args):
    """Overlay 10 perturbed copies of ``base_model`` on 3 test samples."""
    ctx_x, ctx_y, fut_x, fut_y = test_dataset
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    for i in range(3):
        axs[i].plot(fut_x[i].cpu().numpy(), fut_y[i].cpu().numpy(), label='gt fut', linestyle='-', color='k')
        axs[i].scatter(ctx_x[i].cpu().numpy(), ctx_y[i].cpu().numpy(), label='ctx', c='b', s=10)

        for j in range(10):
            # perturb_weights presumably mutates in place (TODO confirm), hence
            # a fresh deep copy of the base model for every perturbation.
            perturbed = copy.deepcopy(base_model)
            perturbed.perturb_weights(j, args.sigma)

            y_preds = perturbed.AR_rollout(ctx_y[[i]], args.fut_sz)

            label = 'pred' if j == 0 else None
            axs[i].plot(fut_x[i].cpu().numpy(), y_preds[0].detach().cpu().numpy(), label=label, linestyle='-', linewidth=0.5, color='k', alpha=0.3)

        axs[i].set_ylim([-2, 2])
        axs[i].legend()
    plt.show()
    plt.close(fig)


def _plot_ensemble_comparison(base_model, top_k_models, sigma, test_dataset, args):
    """Plot the ensemble prediction for 3 test samples and save the figure."""
    ctx_x, ctx_y, fut_x, fut_y = test_dataset
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    for i in range(3):
        eval_module.plot_ensemble_prediction(base_model, top_k_models, sigma, ctx_x[i], ctx_y[i], fut_x[i], fut_y[i], args, axs[i], weighted=args.weighted_ensemble, temperature=args.ensemble_temperature)
    # Save BEFORE show: interactive/inline backends close the figure in
    # plt.show(). Assumes _save_fig writes the current pyplot figure (TODO
    # confirm); otherwise it would save a fresh blank figure.
    plt.tight_layout()
    eval_module._save_fig(args, f"ensemble_comparison_sigma{args.sigma}")
    plt.show()
    plt.close(fig)


def _ensemble_average(preds, mses, weighted, temperature):
    """Combine per-member predictions into a single ensemble prediction.

    Args:
        preds: Array of shape (n_members, ...), one prediction per member.
        mses: Array of shape (n_members,), each member's MSE.
        weighted: If True and there is more than one member, weight members by
            ``softmax(-mse / temperature)``; otherwise use the plain mean.
        temperature: Softmax temperature; smaller values favour low-MSE members.

    Returns:
        Array shaped like a single member prediction.
    """
    preds = np.asarray(preds)
    if weighted and len(preds) > 1:
        logits = -np.asarray(mses, dtype=float) / temperature
        # Subtract the max before exponentiating for numerical stability.
        weights = np.exp(logits - np.max(logits))
        weights /= weights.sum()
        # Broadcast weights over all trailing prediction axes, not just one.
        weights = weights.reshape((-1,) + (1,) * (preds.ndim - 1))
        return np.sum(weights * preds, axis=0)
    return np.mean(preds, axis=0)


def _plot_combined(base_model, top_k_models, sigma, test_dataset, args):
    """Plot perturbations, top-k members, base model and ensemble; save it."""
    ctx_x, ctx_y, fut_x, fut_y = test_dataset
    fig, axs = plt.subplots(1, 3, figsize=(36, 8))
    for i in range(3):
        fut_x_np = fut_x[i].cpu().numpy()
        fut_y_np = fut_y[[i]].cpu().numpy()

        base_preds = base_model.AR_rollout(ctx_y[[i]], args.fut_sz)[0]
        mse_base = eval_module.compute_mse(base_preds.detach().cpu().numpy(), fut_y_np)

        # Background cloud of random perturbations of the base model.
        for j in range(100):
            perturbed = copy.deepcopy(base_model)
            perturbed.perturb_weights(j, args.sigma)

            y_preds = perturbed.AR_rollout(ctx_y[[i]], args.fut_sz)

            label = 'Perturbations' if j == 0 else None
            axs[i].plot(fut_x_np, y_preds[0].detach().cpu().numpy(), label=label, linestyle='-', linewidth=1.0, color='k', alpha=0.2)

        # Rebuild each selected member from its first tuple entry, which is
        # passed to perturb_weights like the loop index above (presumably a seed).
        member_preds, member_mses = [], []
        for j, model in enumerate(top_k_models):
            perturbed = copy.deepcopy(base_model)
            perturbed.perturb_weights(model[0], sigma)

            y_preds = perturbed.AR_rollout(ctx_y[[i]], args.fut_sz)
            pred_np = y_preds[0].detach().cpu().numpy()
            label = 'Top-K' if j == 0 else None
            axs[i].plot(fut_x_np, pred_np, label=label, linestyle='-', linewidth=1.0, color='g', alpha=1.0)

            member_mses.append(eval_module.compute_mse(y_preds.detach().cpu().numpy(), fut_y_np))
            member_preds.append(pred_np)

        ensemble_preds = _ensemble_average(
            np.array(member_preds), np.array(member_mses),
            args.weighted_ensemble, args.ensemble_temperature,
        )
        mse_ens = eval_module.compute_mse(ensemble_preds, fut_y_np)

        axs[i].plot(ctx_x[i].cpu().numpy(), ctx_y[i].cpu().numpy(), 'b-', label='Context', linewidth=5.0)
        axs[i].plot(fut_x_np, fut_y[i].cpu().numpy(), label='Ground truth', linewidth=5.0, linestyle='-', color='k')
        axs[i].plot(fut_x_np, base_preds.detach().cpu().numpy(), 'r-', linewidth=5.0, label=f'Base model (MSE={mse_base:.2f})')
        axs[i].plot(fut_x_np, ensemble_preds, 'g-', linewidth=5.0, label=f'Ensemble (MSE={mse_ens:.2f})', zorder=13)
        axs[i].set_ylim([-2, 2])
        axs[i].set_xticks([])
        axs[i].set_yticks([])
        axs[i].legend(ncol=2, fontsize=20, loc="upper left")

    # Save before show (see _plot_ensemble_comparison).
    eval_module._save_fig(args, f"{args.pretrain_dataset}_{args.posttrain_dataset}_{args.test_dataset}")
    plt.show()
    plt.close(fig)


def _print_summary(args, results):
    """Print a per-sigma table comparing the pretrained model with the ensemble.

    Args:
        args: Namespace providing ``population_size`` and ``top_k``.
        results: Mapping ``sigma -> dict``. Each dict must provide
            ``'mse_pretrained'``, ``'mse_selected'``,
            ``'improvement_selected_pct'`` and ``'n_better_than_base'``.
            An empty mapping prints the header and a notice instead of a best-sigma line.
    """
    print(f"\n{'='*60}")

    print(f"(n={args.population_size}, top_k={args.top_k})\n{'='*60}")

    print(f"{'σ':<10} {'Pretrained':<14} {'Ensemble':<14} {'Improvement':<12} {'Better/Total':<12}")
    print("-" * 62)
    for sigma, res in results.items():
        imp = res['improvement_selected_pct']
        imp_str = f"+{imp:.1f}%" if imp > 0 else f"{imp:.1f}%"
        print(f"{sigma:<10} {res['mse_pretrained']:<14.6f} {res['mse_selected']:<14.6f} {imp_str:<12} {res['n_better_than_base']}/{args.population_size:<11}")

    if not results:
        # max() below raises on an empty mapping; report instead of crashing.
        print("\n⚠ No results to summarize.")
        return

    best_sigma = max(results, key=lambda s: results[s]['improvement_selected_pct'])
    best_imp = results[best_sigma]['improvement_selected_pct']

    if best_imp > 0:
        print(f"\n★ Best σ={best_sigma}: +{best_imp:.2f}% improvement")
    else:
        print(f"\n⚠ No improvement. Best σ={best_sigma}: {best_imp:.2f}%")


if __name__ == "__main__":
    import sys

    # parse_args(None) parses an empty list and never reads sys.argv, so the
    # real command line must be forwarded explicitly or every flag is ignored.
    args = parse_args(sys.argv[1:])
    main(args)
