import json
import sys
import traceback
from pathlib import Path
from typing import Dict
from tqdm.auto import tqdm
import numpy as np

import torch

from datasets import DataLoader, zero_one
from slayer_model.kernels import Hat, MorletWavelet, DecayingExponentialKernel, KERNEL_DICT
from slayer_model.network import SRMNetwork
from slayer_model.utils import torch_to_cupy
from slayer_model.utils.losses import mse
from utils.metrics import R2Accumulator, R2AccumulatorPaper, RSEAccumulator

torch.set_default_dtype(torch.float32)

results_dir = Path(__file__).resolve().parent / "results"
results_dir_sswim = Path(__file__).resolve().parent.parent / "sswim" / "results"

dt = 1.
BATCH_SIZE_EVAL = 500
LOAD_N_BATCHES = int(sys.argv[2] if len(sys.argv) > 2 else 20)


def _build_layers(dim, n_hidden, sample_kernel):
    """Return the layer specification list passed to ``SRMNetwork``.

    Layer 1 maps ``dim`` inputs to ``n_hidden`` neurons. It uses the configured
    sample kernel (looked up in ``KERNEL_DICT``) as ``phi_k`` and a
    ``DecayingExponentialKernel`` as ``phi_q``. Layer 2 maps the ``n_hidden``
    neurons back to ``dim`` outputs with a ``Hat`` kernel.
    """
    return [
        {
            'n_neurons': n_hidden,
            'n_inputs': dim,
            'phi_k': KERNEL_DICT[sample_kernel](),
            'phi_q': DecayingExponentialKernel(),
            'dt': dt,
        },
        {
            'n_neurons': dim,
            'n_inputs': n_hidden,
            'phi_k': Hat(),
            'dt': dt,
        },
    ]


def _val_loss(net, data, transfer_func, fit_idcs):
    """Return the mean per-batch MSE of ``net`` on the validation split.

    Runs in eval mode without gradient tracking. The mean is taken over the
    batches actually yielded by ``data.iterate`` (at least 1, to avoid dividing
    by zero on an empty split).
    """
    net.eval()
    total = 0.0
    n_seen = 0
    with torch.no_grad():
        for x, y in data.iterate(batch_size=BATCH_SIZE_EVAL, target="val", transfer_func=transfer_func):
            # Only the fit time steps are compared against the targets.
            total += mse(net(x)[:, :, fit_idcs], y).item()
            n_seen += 1
    return total / max(1, n_seen)


def _train_with_early_stopping(net, data, regs, transfer_func, fit_idcs, *,
                               batch_size, initial_lr, final_lr, num_epochs, patience, min_delta):
    """Fine-tune ``net`` in place with Adam, cosine LR decay and early stopping.

    Each epoch makes one pass over the training split. The training loss is the
    MSE on the ``fit_idcs`` time steps plus the sum of ``regs`` times the
    neuron-wise regularisation weights. The epoch then measures plain MSE on the
    validation split. When the validation loss beats the best so far by more
    than ``min_delta``, a CPU copy of the state dict is kept. Training stops after
    ``patience`` consecutive epochs without such an improvement. The best kept
    weights (if any) are loaded back into ``net`` before returning.

    Returns:
        ``(epochs_run, best_epoch)``: the number of epochs started, and the
        1-based epoch whose weights were restored (0 if none improved).
    """
    optimizer = torch.optim.Adam(net.parameters(), lr=initial_lr)
    # CosineAnnealingLR: lr_t = final_lr + 0.5*(lr0 - final_lr)*(1 + cos(pi * t / T_max))
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=final_lr)

    best_val_loss = float("inf")
    epochs_no_improve = 0
    best_epoch = 0
    best_state = None
    epochs_run = 0

    for epoch in range(num_epochs):
        epochs_run = epoch + 1
        net.train()
        n_batches = data.get_n_batches(batch_size=batch_size, target="train")
        print(f"Epoch {epoch+1}/{num_epochs} — lr: {optimizer.param_groups[0]['lr']:.3e}")

        train_loss = 0.0
        n_seen = 0
        batches = data.iterate(batch_size=batch_size, target="train", transfer_func=transfer_func, load_n_batches=LOAD_N_BATCHES)
        for x, y in tqdm(batches, desc=f'Training ({epoch+1})', total=n_batches):
            output = net(x)[:, :, fit_idcs]
            loss = mse(output, y) + (regs * net.get_regularisation_weights(neuron_wise=True)).sum()
            train_loss += loss.item()
            n_seen += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Average over the batches actually processed, not the nominal count.
        print(f"Train Loss: {train_loss / max(1, n_seen):.6f}")

        val_loss = _val_loss(net, data, transfer_func, fit_idcs)
        print(f"Val Loss: {val_loss:.6f}", flush=True)

        if val_loss + min_delta < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            # Snapshot on CPU so the copy is not overwritten by later optimizer steps.
            best_state = {k: v.cpu().clone() for k, v in net.state_dict().items()}
            best_epoch = epoch + 1
            print(f"Validation loss improved; saving best model (val_loss={best_val_loss:.6f}).")
        else:
            epochs_no_improve += 1
            print(f"No improvement for {epochs_no_improve} epoch(s).")

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs (no improvement for {patience} epochs).")
            break

        scheduler.step()

    if best_state is not None:
        net.load_state_dict(best_state)
        print(f"Best model restored (val_loss={best_val_loss:.6f}).")
    else:
        print("No improvement observed during training; final model kept as-is.")

    return epochs_run, best_epoch


def _evaluate_split(net, data, target, transfer_func, fit_idcs):
    """Compute and print (R2, RSE, R2-paper) of ``net`` on split ``target``.

    Runs in eval mode under ``torch.no_grad()`` so no autograd graph is built.
    Predictions and targets are converted with ``torch_to_cupy`` before being
    fed to the accumulators. ``R2AccumulatorPaper`` is given
    ``data.get_mean(target)``.

    Returns:
        ``(r2, rse, r2p)`` exactly as returned by the accumulators' ``reduce()``.
    """
    r2a = R2Accumulator()
    rsea = RSEAccumulator()
    r2a_p = R2AccumulatorPaper(mean=data.get_mean(target))

    net.eval()
    with torch.no_grad():
        for x, y in data.iterate(batch_size=BATCH_SIZE_EVAL, target=target, transfer_func=transfer_func):
            v_fit = net(x)[:, :, fit_idcs]
            y = torch_to_cupy(y)
            v_fit = torch_to_cupy(v_fit)
            for acc in (r2a, rsea, r2a_p):
                acc.accumulate(y, v_fit)

    r2 = r2a.reduce()
    rse = rsea.reduce()
    r2p = r2a_p.reduce()
    print(f"R2P {target}: {r2p}")
    print(f"R2 {target}: {r2}")
    print(f"RSE {target}: {rse}")
    return r2, rse, r2p


def run_experiment(config: Dict):
    """Fine-tune SSWIM-initialised SRM networks for every seed and record metrics.

    ``config`` keys read:
        ``name``;
        ``data``: ``seeds``, ``horizon``, ``dataset``;
        ``model``: ``n_hidden``, ``sample_kernel``;
        ``train``: ``batch_size``, ``initial_lr``, ``num_epochs``, ``patience``,
        ``min_delta``, ``final_lr``.

    For each seed, the network is built from
    ``results_dir_sswim/<name>/{sample,fit}_<seed>.npz`` and regularised with
    ``best_regs_<seed>.npz``. It is then trained with early stopping and saved
    to ``results_dir/<name>/weights_<seed>.pt``. Finally it is evaluated on the
    val and test splits. All per-seed metrics and their averages over seeds are
    written to ``results_dir/<name>.json``.
    """
    results = {'config': config}
    name = config['name']
    seeds = config['data']['seeds']
    metrics = ["epoch", "best_epoch", "r2_test", "rse_test", "r2_val", "rse_val", "r2p_val", "r2p_test"]

    # Create the output folder up front so a bad path fails before any training,
    # not at torch.save after a full training run.
    weights_dir = results_dir / name
    weights_dir.mkdir(parents=True, exist_ok=True)
    sswim_dir = results_dir_sswim / name

    for seed in seeds:
        print("-" * 25)
        print(f"Starting experiment: {name}/{seed}")
        print("-" * 25)

        # Read every hyper-parameter first so a malformed config fails fast.
        horizon = config['data']['horizon']
        dataset = config['data']['dataset']
        n_hidden = config['model']['n_hidden']
        sample_kernel = config['model']['sample_kernel']
        train_cfg = config['train']
        batch_size = train_cfg['batch_size']
        initial_lr = train_cfg['initial_lr']
        num_epochs = train_cfg['num_epochs']
        patience = train_cfg['patience']
        min_delta = train_cfg['min_delta']
        final_lr = train_cfg['final_lr']

        data = DataLoader(dataset=dataset, prediction_horizon=horizon, shuffle=True, seed=seed, force_recreate=True, normaliser=zero_one)
        torch.manual_seed(seed)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if device == 'cpu':
            print("WARNING: No GPU detected.")

        def transfer_func(arr, _device=device):
            return torch.tensor(arr).to(device=_device)

        dim, T = data.get_dim_t()
        # Only the last `horizon` time steps of the network output are scored.
        fit_idcs = np.arange(T - horizon, T, dtype=int)

        net = SRMNetwork(_build_layers(dim, n_hidden, sample_kernel))

        print("Loading weights")
        net.load_from_sswim(sample_layer=sswim_dir / f"sample_{seed}.npz", fit_layer=sswim_dir / f"fit_{seed}.npz")
        # NpzFile keeps the archive open; close it once the array is read.
        with np.load(sswim_dir / f"best_regs_{seed}.npz") as npz:
            best_regs = npz["best_regs"]
        regs = torch.from_numpy(best_regs).to(device)

        net = net.to(device)
        net.eval()

        print("Finished initialization, starting training")
        epochs_run, best_epoch = _train_with_early_stopping(
            net, data, regs, transfer_func, fit_idcs,
            batch_size=batch_size, initial_lr=initial_lr, final_lr=final_lr,
            num_epochs=num_epochs, patience=patience, min_delta=min_delta,
        )

        torch.save(net.state_dict(), weights_dir / f"weights_{seed}.pt")

        r2_val, rse_val, r2_p_val = _evaluate_split(net, data, "val", transfer_func, fit_idcs)
        r2_test, rse_test, r2_p_test = _evaluate_split(net, data, "test", transfer_func, fit_idcs)

        results[seed] = {
            "epoch": epochs_run,
            "best_epoch": best_epoch,
            # `.get()` copies the reduced device scalars to host.
            "r2_test": float(r2_test.get()),
            "rse_test": float(rse_test.get()),
            "r2_val": float(r2_val.get()),
            "rse_val": float(rse_val.get()),
            "r2p_val": float(r2_p_val.get()),
            "r2p_test": float(r2_p_test.get()),
        }

    # Average every metric over seeds; an empty seed list yields no averages
    # instead of a ZeroDivisionError that would discard the results file.
    averages = {}
    if seeds:
        for metric in metrics:
            values = [results[seed][metric] for seed in seeds]
            averages[metric] = sum(values) / len(values)
    results["averages"] = averages

    with open(results_dir / f"{name}.json", "w") as res_file:
        json.dump(results, res_file)


# Script entry: argv[1] is the path to a JSON experiment config
# (argv[2], optional, is consumed by LOAD_N_BATCHES at module load).
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} CONFIG_JSON [LOAD_N_BATCHES]")
config_path = sys.argv[1]
with open(config_path, "r", encoding="utf-8") as config_file:
    config = json.load(config_file)
try:
    run_experiment(config)
except Exception:
    # Show the full traceback, then exit non-zero so shells / job schedulers
    # see the failure instead of a silent success.
    traceback.print_exc()
    sys.exit(1)
