import json
import sys
import traceback
from pathlib import Path
from typing import Dict
import time

from BACKEND import mem_summary, cp, to_gpu, to_cpu, np
from datasets import DataLoader, zero_one, z_score
from model.kernels import Hat, DecayingExponentialKernel, KERNEL_DICT
from model.layer import SRMLayerFixed
from model.utils import MeanStd, Fluct
from model.utils.enums import ReturnVs
from utils.eval import eval_set
from utils.metrics import sample_pairs
from utils.metrics.calibrate_sampling import calibrate_sampling_entropy
from utils.metrics.sampling_metrics import FourierMag, CosineDistance, PairwiseL2, FourierAngle, BandedFourierL2

# Time step: run_trial uses it as the spacing of the evaluation grid
# (t_max = T * dt) and as the scale factor for the kernel delays and widths.
dt = 1.
# Forwarded as ``num_spikes=`` to the spike-train, fitting and evaluation calls.
num_spikes = 100
# Forwarded as ``progress=`` to the layer routines; presumably toggles tqdm
# progress bars -- confirm against the layer implementation.
TQDM_PROGRESS = False


def get_normaliser(norm):
    """Build the normaliser object described by a config entry.

    ``norm`` is an indexable spec whose first element selects the type:
    ``'F'`` yields ``Fluct(c=norm[1])``; any other tag yields
    ``MeanStd(mu=norm[1], sig=norm[2])``.
    """
    wants_fluct = norm[0] == 'F'
    return Fluct(c=norm[1]) if wants_fluct else MeanStd(mu=norm[1], sig=norm[2])

def get_metric(metrics, x, y, OL, num_samples, min_norm):
    """Resolve the pair of sampling metrics requested by the config.

    Args:
        metrics: either the string ``'best'`` or a two-element sequence of
            metric names ``(d_in_name, d_out_name)``.
        x, y: data forwarded to ``calibrate_sampling_entropy`` when
            ``metrics == 'best'``; unused otherwise.
        OL: length used to build the two Fourier band index arrays.
        num_samples, min_norm: forwarded to ``calibrate_sampling_entropy``.

    Returns:
        For ``'best'``, whatever ``calibrate_sampling_entropy`` returns (the
        caller in this file unpacks it as ``(d_in, d_out)``). Otherwise the
        tuple ``(d_in, d_out)`` of metric objects matching the two names.

    Raises:
        ValueError: if a requested name matches no candidate metric.
    """
    # NOTE(review): ``OL / 2`` is true division, so both band arrays are
    # floating-point -- confirm that BandedFourierL2 expects that.
    def low_band():
        return cp.arange(OL / 2 + 1)

    def high_band():
        return OL - cp.arange(OL / 2)

    # The candidate bank is always built, even when names are given explicitly.
    candidates = [
        PairwiseL2(),
        CosineDistance(),
        FourierMag(),
        FourierAngle(),
        BandedFourierL2(low_band()),
        BandedFourierL2(high_band()),
    ]

    def lookup(label):
        # First try the string form of each candidate, then the two aliases.
        hit = next((cand for cand in candidates if label == cand.__str__()), None)
        if hit is not None:
            return hit
        if label == 'BandedFourierLow':
            return BandedFourierL2(low_band())
        if label == 'BandedFourierHigh':
            return BandedFourierL2(high_band())
        raise ValueError(f"Metric {label} not found")

    if metrics == 'best':
        return calibrate_sampling_entropy(X=x, Y=y, metrics=candidates, num_samples=num_samples, min_norm=min_norm)

    d_in_name, d_out_name = metrics
    return lookup(d_in_name), lookup(d_out_name)


def get_data_norm(norm_str):
    """Map a data-normaliser name from the config to its function.

    Returns ``zero_one`` for ``'zero_one'`` and ``z_score`` for ``'z_score'``.

    Raises:
        ValueError: for any other name.
    """
    # Guard clauses; each name is only looked up on the branch that returns it.
    if norm_str == 'zero_one':
        return zero_one
    if norm_str == 'z_score':
        return z_score
    raise ValueError(f"Data normaliser {norm_str} not found")


def run_trial(name, results_dir, seed, config):
    """Run one seeded trial: build, fit, save and evaluate a two-layer model.

    Steps, in order:
      1. Read hyper-parameters from ``config`` and build the ``DataLoader``.
      2. Build a "sample" layer and a "fit" layer (both ``SRMLayerFixed``).
      3. Timed section: choose the sampling metrics, sample pairs, initialise
         and normalise the sample layer, then fit the fit layer.
      4. Save both layers' parameters and ``best_regs`` under
         ``results_dir / name``.
      5. Evaluate on the "val" and "test" targets.

    Args:
        name: experiment name; used as the sub-directory of ``results_dir``.
        results_dir: base output directory; must support the ``/`` operator.
        seed: seed passed to the data loader, the pair sampling and both
            layers; also embedded in the output file names.
        config: nested dict. Keys read here: ``data.horizon``,
            ``data.dataset``, ``data.data_norm``, ``model.n_hidden``,
            ``model.sample_kernel``, ``model.normaliser``,
            ``model.objective``, ``train.batch_size``,
            ``train.size_init_batch``, ``train.metrics`` and the optional
            ``train.reg`` (defaults to ``0.``).

    Returns:
        dict with float entries ``r2_test``, ``rse_test``, ``r2_val``,
        ``rse_val``, the string forms of the chosen metrics under ``d_in`` /
        ``d_out``, and ``time``: the ``time.perf_counter`` difference across
        the timed section (data loading, layer construction, saving and
        evaluation are outside it).
    """
    results_seed = {}

    # Magic call to ensure CuPy is initialised before any CuPy code runs
    cp.linalg.svd(cp.eye(4))
    _ = cp.linalg.eigh(cp.eye(1))

    H = config['data']['horizon']
    dataset = config['data']['dataset']
    n_hidden = config['model']['n_hidden']
    sample_kernel = config['model']['sample_kernel']
    normaliser = get_normaliser(config["model"]["normaliser"])
    weight_objective = config['model']['objective']
    batch_size = config['train']['batch_size']
    size_init_batch = config['train']['size_init_batch']
    reg = config['train']['reg'] if 'reg' in config['train'] else 0.
    metrics = config['train']['metrics']
    data_norm = get_data_norm(config['data']['data_norm'])

    # A positive ``reg`` from the config requests no search (n_reg_search=0);
    # otherwise 32 candidates are requested from fit_l2_minibatch below.
    if reg > 0:
        n_reg_search = 0
    else:
        n_reg_search = 32

    data = DataLoader(dataset=dataset, prediction_horizon=H, shuffle=True, seed=seed, force_recreate=True,
                      normaliser=data_norm, reduced=True)
    # The initial batch is fetched through to_cpu and moved to the GPU further down.
    x, y = data.get_first_batch(batch_size=size_init_batch, target="train", transfer_func=to_cpu)
    # Variables setup
    N, dim, _ = x.shape
    OL, T = data.get_ol_t()
    # Integer indices OL .. T-1; presumably the part of the window that is fitted -- TODO confirm.
    fit_idcs = cp.arange(OL, T, dtype=cp.int32)
    tmax = T * dt
    # Evaluation time grid from 0 (inclusive) to tmax (exclusive) in steps of dt.
    ts = cp.arange(0, tmax, dt, dtype=cp.float32)

    x = to_gpu(x)
    y = to_gpu(y)

    sig_min = 5
    sig_max = 50
    n_distinct = 10

    # Network setup
    # k_sig cycles through n_distinct evenly spaced values from sig_min to sig_max.
    k_sig = (cp.arange(n_hidden) % n_distinct) / (n_distinct - 1) * (sig_max - sig_min) + sig_min
    # k_tau: n_hidden evenly spaced values from 0 to OL/2 (if OL > H) or H + OL/2 (otherwise).
    k_tau = cp.linspace(0, OL / 2 if OL > H else H + OL / 2, n_hidden, dtype=cp.float32)
    # q_sig: the same value sig_min for every neuron.
    q_sig = cp.full_like(k_sig, fill_value=sig_min)

    sample_layer = SRMLayerFixed(
        n_neurons=n_hidden,
        n_inputs=dim,
        in_weights=None,
        phi_k=KERNEL_DICT[sample_kernel](),
        phi_q=DecayingExponentialKernel(),
        seed=seed,
        k_delays=k_tau * dt,
        k_widths=k_sig * dt,
        q_widths=q_sig * dt,
        debug=False
    )
    fit_layer = SRMLayerFixed(
        n_neurons=dim,
        n_inputs=n_hidden,
        out_weights=None,
        phi_k=Hat(),
        seed=seed,
        debug=False
    )

    # BEGIN S-SWIM

    cp.cuda.Stream.null.synchronize()  # Ensure GPU is idle before starting
    start = time.perf_counter()

    # Sampling Metric
    d_in, d_out = get_metric(metrics=metrics, x=x, y=y, OL=OL, num_samples=1000, min_norm=1e-3)

    # Sample pairs for weight computation
    pair_idcs = sample_pairs(X=x, Y=y, k=n_hidden, d_in=d_in, d_out=d_out, min_norm=1e-3, plot=False, demean_x=True,
                             demean_y=True, top_k=False, seed=seed)

    # Sampling Layer weights
    sample_layer.init_swim(s_in=x, pair_idcs=pair_idcs, t_max=tmax, dt=dt, max_it=200, eig_tol=1e-6,
                           norm_sign=False, objective=weight_objective, solver='cupy', compute_batch_size=500)
    sample_layer.normalise_weights(s_in=x, t_max=tmax, dt=dt, fit_out_weights=True, normaliser=normaliser,
                                   slice_size=500, silence_correction=True)

    # Fit Layer Temporal params
    trains_sample = sample_layer.compute_full_trains(x, ts, return_vs=ReturnVs.NO, num_spikes=num_spikes, pad=25, progress=TQDM_PROGRESS)
    # x is not used again below; dropping the reference presumably lets its GPU memory be reclaimed.
    x = None
    fit_layer.tau_correlation(trains_sample, f=y, t_max=tmax, dt=dt, demean_f=True, OL=OL, plot=False,
                              flip=False)
    fit_layer.sig_search(s_in=trains_sample, f=y, dt=dt, eval_ts=ts, fit_idcs=fit_idcs, n_sig=30, n_threads=32,
                         progress=TQDM_PROGRESS, plot=False, alpha=2)

    # Same for y and the initial-batch spike trains: neither is used again below.
    y = None
    trains_sample = None

    # Fit Layer weights
    # The loader itself is passed in here (together with the sample layer as ``pre``).
    best_regs = fit_layer.fit_l2_minibatch(data=data, pre=sample_layer, eval_ts=ts, fit_idcs=fit_idcs, batch_size=batch_size,
                               n_reg_search=n_reg_search, reg_min=-5, reg_max=0.5, reg=reg, n_threads=32, neuron_wise_best_reg=True,
                               num_spikes=num_spikes, pad=50, progress=TQDM_PROGRESS)

    cp.cuda.Stream.null.synchronize()  # Wait until GPU has finished
    end = time.perf_counter()

    # END S-SWIM

    # Project helper, called with this function's locals; presumably reports and frees memory -- TODO confirm.
    mem_summary(free=True, vardict=locals())

    # Saving weights
    sample_layer.save_params(results_dir / name / f"sample_{seed}.npz")
    fit_layer.save_params(results_dir / name / f"fit_{seed}.npz")
    # ``.get()`` presumably copies the array to host memory before np.savez -- TODO confirm.
    np.savez(results_dir / name / f"best_regs_{seed}.npz", best_regs=best_regs.get())

    # Eval
    print("Eval")
    data_val = data.iterate(batch_size=batch_size, target="val")

    r2_val, rse_val = eval_set(data_val, sample_layer=sample_layer, fit_layer=fit_layer, ts=ts, fit_idcs=fit_idcs,
                               num_spikes=num_spikes,
                               pad=50)

    print(f"R2 val: {r2_val}")
    print(f"RSE val: {rse_val}")

    data_test = data.iterate(batch_size=batch_size, target="test")

    r2_test, rse_test = eval_set(data_test, sample_layer=sample_layer, fit_layer=fit_layer, ts=ts, fit_idcs=fit_idcs,
                                 num_spikes=num_spikes,
                                 pad=50)

    print(f"R2 test: {r2_test}")
    print(f"RSE test: {rse_test}")

    # Convert the scores to plain Python floats so the caller can dump them as JSON.
    results_seed["r2_test"] = float(r2_test.get())
    results_seed["rse_test"] = float(rse_test.get())
    results_seed["r2_val"] = float(r2_val.get())
    results_seed["rse_val"] = float(rse_val.get())
    results_seed["d_in"] = d_in.__str__()
    results_seed["d_out"] = d_out.__str__()
    results_seed["time"] = end - start
    return results_seed


def run_experiment(config: Dict):
    """Run one trial per configured seed and write the aggregated results as JSON.

    Args:
        config: experiment configuration. Keys read here: ``results_dir``
            (resolved relative to this file's directory), ``name`` and
            ``data.seeds``; the whole dict is also forwarded to ``run_trial``
            and stored in the output under ``'config'``.

    Side effects:
        Creates ``<results_dir>/<name>/`` if it is missing and writes
        ``<results_dir>/<name>.json`` containing the config, one entry per
        seed and an ``'averages'`` entry.

    Each average is taken over the seeds whose value for that metric is
    finite. If there is no such seed (empty seed list, or every run produced
    NaN/inf) the average is NaN rather than a ``ZeroDivisionError``, so the
    per-seed results are still written.
    """
    # Function-scope import keeps this block self-contained; the per-seed
    # values are plain Python floats, so the stdlib finiteness check suffices.
    import math

    results = {'config': config}
    results_dir = Path(__file__).resolve().parent / config['results_dir']
    name = config['name']
    seeds = config['data']['seeds']

    # Neither np.savez (used by run_trial) nor open() below creates missing
    # directories, so make sure the output location exists up front.
    (results_dir / name).mkdir(parents=True, exist_ok=True)

    for seed in seeds:
        print("-" * 25)
        print(f"Starting experiment: {name}/{seed}")
        print("-" * 25)

        results[seed] = run_trial(name, results_dir, seed, config)

    # Compute average over seeds, ignoring non-finite entries
    averages = {}
    for metric in ("r2_test", "rse_test", "r2_val", "rse_val", "time"):
        finite = [
            results[seed][metric]
            for seed in seeds
            if math.isfinite(results[seed][metric])
        ]
        averages[metric] = sum(finite) / len(finite) if finite else math.nan
    results["averages"] = averages

    with open(results_dir / f"{name}.json", "w") as res_file:
        json.dump(results, res_file)


# Script entry point: ``python <this file> <config.json>``.
# Guarded so that importing this module (e.g. to reuse run_trial) does not
# start an experiment.
if __name__ == "__main__":
    if len(sys.argv) < 2:
        # Clear usage message instead of a bare IndexError; exits non-zero.
        sys.exit(f"usage: {sys.argv[0]} <config.json>")

    config_path = sys.argv[1]
    with open(config_path, "r") as config_file:
        config = json.load(config_file)

    try:
        run_experiment(config)
    except Exception:
        # Keep the full traceback in the log, and report the failure through
        # the exit status so schedulers and shell pipelines can see it.
        traceback.print_exc()
        sys.exit(1)
