import json
import sys
import time
import traceback
from pathlib import Path
from typing import Dict
from tqdm.auto import tqdm
import numpy as np

import torch

from datasets import DataLoader, zero_one, z_score
from slayer_model.kernels import Hat, MorletWavelet, DecayingExponentialKernel, KERNEL_DICT
from slayer_model.network import SRMNetwork
from slayer_model.utils import torch_to_cupy
from slayer_model.utils.losses import mse
from utils.metrics import R2Accumulator, R2AccumulatorPaper, RSEAccumulator

# Make float32 the default floating-point dtype for newly created torch tensors.
torch.set_default_dtype(torch.float32)



# Time step: scales the randomly drawn kernel centres/widths in run_experiment
# and is passed as 'dt' to every layer specification.
dt = 1.
# Batch size used when iterating over the "val" and "test" splits.
BATCH_SIZE_EVAL = 500
# Forwarded as `load_n_batches` to the training data iterator. Taken from the
# second command-line argument when present, otherwise 12.
LOAD_N_BATCHES = int(sys.argv[2] if len(sys.argv) > 2 else 12)

def get_data_norm(norm_str):
    """Return the normaliser callable registered under ``norm_str``.

    Args:
        norm_str: Name of the normaliser; ``'z_score'`` and ``'zero_one'``
            are the only accepted values.

    Returns:
        The ``z_score`` or ``zero_one`` object imported from ``datasets``.

    Raises:
        ValueError: If ``norm_str`` is not one of the accepted names.
    """
    # Guard clauses keep the name lookups lazy: only the requested
    # normaliser is ever touched.
    if norm_str == 'z_score':
        return z_score
    if norm_str == 'zero_one':
        return zero_one
    raise ValueError(f"Data normaliser {norm_str} not found")


def _evaluate_split(net, data, target, fit_idcs, transfer_func):
    """Score ``net`` on one data split and return ``(r2, rse, r2_paper)``.

    Args:
        net: The trained network; it is switched to eval mode here.
        data: The data loader providing ``iterate`` and ``get_mean``.
        target: Name of the split to evaluate (``"val"`` or ``"test"``).
        fit_idcs: Indices along the last output axis that are scored.
        transfer_func: Callable that moves a loaded array to the device.

    Returns:
        The reduced values of ``R2Accumulator``, ``RSEAccumulator`` and
        ``R2AccumulatorPaper`` (in that order) over the whole split.
    """
    r2_acc = R2Accumulator()
    rse_acc = RSEAccumulator()
    r2_paper_acc = R2AccumulatorPaper(mean=data.get_mean(target))

    batches = data.iterate(batch_size=BATCH_SIZE_EVAL, target=target, transfer_func=transfer_func)

    net.eval()
    # Pure inference: without no_grad every forward pass would record an
    # autograd graph that is never used.
    with torch.no_grad():
        for x_eval, y_eval in batches:
            v_fit = net(x_eval)[:, :, fit_idcs]
            y_eval = torch_to_cupy(y_eval)
            v_fit = torch_to_cupy(v_fit)
            r2_acc.accumulate(y_eval, v_fit)
            rse_acc.accumulate(y_eval, v_fit)
            r2_paper_acc.accumulate(y_eval, v_fit)

    return r2_acc.reduce(), rse_acc.reduce(), r2_paper_acc.reduce()


def run_experiment(config: Dict):
    """Train and evaluate one SRM network per seed and write the results to disk.

    For every seed in ``config['data']['seeds']`` the function builds the data
    loader, initialises a two-layer ``SRMNetwork``, trains it with Adam under a
    cosine learning-rate schedule with early stopping on the validation loss,
    restores the best weights, saves them to
    ``<results_dir>/<name>/weights_<seed>.pt`` and scores the model on the
    "val" and "test" splits. The config, the per-seed metrics and their
    averages over all seeds are dumped to ``<results_dir>/<name>.json``.
    ``results_dir`` is resolved relative to the directory of this file.

    Args:
        config: Experiment description with the keys ``name``, ``results_dir``,
            ``data`` (``seeds``, ``horizon``, ``dataset``, ``data_norm``),
            ``model`` (``n_hidden``, ``sample_kernel``, ``mu_u``, ``xi``) and
            ``train`` (``batch_size``, ``initial_lr``, ``final_lr``,
            ``num_epochs``, ``patience``, ``min_delta``, ``lambda_reg``,
            ``size_init_batch``).

    Raises:
        ValueError: If the seed list is empty or the normaliser name is unknown.
        RuntimeError: If no CUDA device is available.
    """
    seeds = config['data']['seeds']
    if not seeds:
        # Averaging over zero runs is undefined; fail before any work is done
        # instead of hitting a ZeroDivisionError at the very end.
        raise ValueError("config['data']['seeds'] must contain at least one seed")

    results = {'config': config}
    results_dir = Path(__file__).resolve().parent / config['results_dir']
    name = config['name']

    # Hyperparameters do not depend on the seed, so they are read only once.
    horizon = config['data']['horizon']
    dataset = config['data']['dataset']
    n_hidden = config['model']['n_hidden']
    sample_kernel = config['model']['sample_kernel']
    mu_u = config['model']['mu_u']
    xi = config['model']['xi']
    batch_size = config['train']['batch_size']
    initial_lr = config['train']['initial_lr']
    num_epochs = config['train']['num_epochs']
    patience = config['train']['patience']
    min_delta = config['train']['min_delta']
    final_lr = config['train']['final_lr']
    lambda_reg = config['train']['lambda_reg']
    size_init_batch = config['train']['size_init_batch']
    data_norm = get_data_norm(config['data']['data_norm'])

    # The timing code below calls torch.cuda.synchronize(), so a GPU is
    # mandatory; check before any expensive data loading happens.
    if not torch.cuda.is_available():
        raise RuntimeError("No GPU detected.")
    device = 'cuda'

    # Create the output locations up front so that a missing directory cannot
    # make the run fail only after training has finished.
    weights_dir = results_dir / name
    weights_dir.mkdir(parents=True, exist_ok=True)

    def transfer_func(arr):
        # Convert a loaded array into a tensor living on the training device.
        return torch.tensor(arr).to(device=device)

    for seed in seeds:
        print("-" * 25)
        print(f"Starting experiment: {name}/{seed}")
        print("-" * 25)
        results_seed = {}

        data = DataLoader(dataset=dataset, prediction_horizon=horizon, shuffle=True, seed=seed, force_recreate=True,
                          normaliser=data_norm, reduced=False)
        torch.manual_seed(seed)
        print("Starting initialization")
        x, y = data.get_first_batch(batch_size=size_init_batch, target="train", transfer_func=transfer_func)
        dim, T = data.get_dim_t()
        # Only the last `horizon` time steps of the output are fitted/scored.
        fit_idcs = np.arange(T - horizon, T, dtype=int)

        # Random kernel parameters of the hidden layer (drawn after seeding).
        k_centers = torch.rand(size=(n_hidden,)) * (15 * dt)
        k_widths = torch.rand(size=(n_hidden,)) * (10 * dt) + 5 * dt
        q_widths = torch.rand(size=(n_hidden,)) * (10 * dt) + 5 * dt

        layers = [
            {
                'n_neurons': n_hidden,
                'n_inputs': dim,
                'phi_k': KERNEL_DICT[sample_kernel](),
                'phi_q': DecayingExponentialKernel(),
                'dt': dt,
                'k_centers': k_centers,
                'k_widths': k_widths,
                'q_widths': q_widths,
            },
            {
                'n_neurons': dim,
                'n_inputs': n_hidden,
                'phi_k': Hat(),
                'dt': dt,
            },
        ]

        net = SRMNetwork(layers)
        net = net.to(device)
        net.eval()

        # --- Data-driven initialisation (timed separately from training) ---
        torch.cuda.synchronize()
        start = time.perf_counter()

        with torch.no_grad():
            net.init_fluct_rg(x, mu_u, xi)

        # The initialisation batch is no longer needed; release it.
        del x
        del y

        torch.cuda.synchronize()
        end = time.perf_counter()
        print(f"Initialization time: {end-start:.3f}s")

        print("Finished initialization, starting training")

        # --- Optimizer + Cosine scheduler ---
        optimizer = torch.optim.Adam(net.parameters(), lr=initial_lr)
        # CosineAnnealingLR: lr_t = final_lr + 0.5*(lr0 - final_lr)*(1 + cos(pi * t / T_max))
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=final_lr)

        # --- Early stopping bookkeeping ---
        best_val_loss = float("inf")
        epochs_no_improve = 0
        best_epoch = 0
        best_state = None
        epoch = -1  # keeps `epoch + 1` well defined if the loop never runs

        torch.cuda.synchronize()
        start = time.perf_counter()

        # --- Training loop ---
        for epoch in range(num_epochs):
            net.train()
            train_loss = 0.0
            n_train_batches = data.get_n_batches(batch_size=batch_size, target="train")
            print(f"Epoch {epoch+1}/{num_epochs} — lr: {optimizer.param_groups[0]['lr']:.3e}")

            train_batches = data.iterate(batch_size=batch_size, target="train", transfer_func=transfer_func,
                                         load_n_batches=LOAD_N_BATCHES)
            for x, y in tqdm(train_batches, desc=f'Training ({epoch+1})', total=n_train_batches):
                # Forward pass (select fit indices)
                output = net(x)[:, :, fit_idcs]

                # Data term plus weight regularisation
                loss = mse(output, y) + lambda_reg * net.get_regularisation_weights(neuron_wise=False)
                train_loss += loss.item()

                # Backward + step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Average train loss
            avg_train_loss = train_loss / max(1, n_train_batches)
            print(f"Train Loss: {avg_train_loss:.6f}")

            # --- Validation ---
            net.eval()
            with torch.no_grad():
                val_loss_total = 0.0
                n_val_batches = data.get_n_batches(batch_size=BATCH_SIZE_EVAL, target="val")
                for x, y in data.iterate(batch_size=BATCH_SIZE_EVAL, target="val", transfer_func=transfer_func):
                    output = net(x)[:, :, fit_idcs]
                    loss = mse(output, y)
                    val_loss_total += loss.item()

                # Average val loss
                val_loss = val_loss_total / max(1, n_val_batches)
                print(f"Val Loss: {val_loss:.6f}", flush=True)

            # --- Early stopping logic ---
            if val_loss + min_delta < best_val_loss:
                best_val_loss = val_loss
                epochs_no_improve = 0
                # Snapshot on the CPU so the best weights do not occupy GPU memory.
                best_state = {k: v.cpu().clone() for k, v in net.state_dict().items()}
                best_epoch = epoch + 1
                print(f"Validation loss improved; saving best model (val_loss={best_val_loss:.6f}).")
            else:
                epochs_no_improve += 1
                print(f"No improvement for {epochs_no_improve} epoch(s).")

            if epochs_no_improve >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs (no improvement for {patience} epochs).")
                break

            # Step scheduler at epoch end
            scheduler.step()

        torch.cuda.synchronize()
        end = time.perf_counter()

        # --- Restore best model weights (if any) ---
        if best_state is not None:
            net.load_state_dict(best_state)
            print(f"Best model restored (val_loss={best_val_loss:.6f}).")
        else:
            print("No improvement observed during training; final model kept as-is.")

        # Save model weights
        torch.save(net.state_dict(), weights_dir / f"weights_{seed}.pt")

        # Validation evaluation
        r2_val, rse_val, r2_p_val = _evaluate_split(net, data, "val", fit_idcs, transfer_func)
        print(f"R2P val: {r2_p_val}")
        print(f"R2 val: {r2_val}")
        print(f"RSE val: {rse_val}")

        # Test evaluation
        r2_test, rse_test, r2_p_test = _evaluate_split(net, data, "test", fit_idcs, transfer_func)
        print(f"R2P test: {r2_p_test}")
        print(f"R2 test: {r2_test}")
        print(f"RSE test: {rse_test}")

        results_seed["epoch"] = epoch + 1
        results_seed["best_epoch"] = best_epoch
        results_seed["r2_test"] = float(r2_test.get())
        results_seed["rse_test"] = float(rse_test.get())
        results_seed["r2_val"] = float(r2_val.get())
        results_seed["rse_val"] = float(rse_val.get())
        results_seed["r2p_val"] = float(r2_p_val.get())
        results_seed["r2p_test"] = float(r2_p_test.get())
        # Wall-clock time of the training loop only (initialisation excluded).
        results_seed["time"] = end - start

        results[seed] = results_seed

    # Compute average over seeds
    averages = {}
    metrics = ["epoch", "best_epoch", "r2_test", "rse_test", "r2_val", "rse_val", "r2p_val", "r2p_test", "time"]
    for metric in metrics:
        values = [results[seed][metric] for seed in seeds]
        averages[metric] = sum(values) / len(values)
    results["averages"] = averages

    with open(results_dir / f"{name}.json", "w") as res_file:
        json.dump(results, res_file)


# Script entry: the first command-line argument is the path of a JSON file
# holding the experiment configuration.
config_path = sys.argv[1]
config = json.loads(Path(config_path).read_text())

try:
    run_experiment(config)
except Exception:
    # Report the failure with its full traceback, then let the script end
    # normally rather than propagating the exception.
    traceback.print_exc()
