import json
import numpy as np
import os
import pathlib
import random
from scipy import stats
import sys
import torch

import config
from models import *
from tasks import *
import trainer

# Positional command-line arguments: task name, noise level and random seed.
# All three are kept as text because they are embedded in file names below.
task, noise_str, seed = str(sys.argv[1]), str(sys.argv[2]), str(sys.argv[3])

# Seed every random number generator used by the script with the same value.
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
    torch.cuda.manual_seed,
):
    _seed_rng(int(seed))

# Request deterministic cuDNN kernels and switch off its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Names of the trained model variants that are analysed for each task.
configs = dict(
    spatial_navigation=[
        "noisy_unbiased",
        "noisy_biased",
        "no-noise_unbiased",
        "no-noise_biased",
    ],
    head_direction=[
        "noisy_unbiased",
        "no-noise_unbiased",
    ],
)

# Epoch tag (as text) of the checkpoint file that is loaded for each task.
epochs = dict(
    spatial_navigation="2500",
    head_direction="20000",
)

# Kernel density estimates collected by the script, keyed first by state
# ("wake" or "sleep") and then by model identifier.
kdes = {state: {} for state in ("wake", "sleep")}

# Build a "wake" KDE for every trained model variant of the selected task.
# The loop variable is deliberately not named `config`, so the imported
# `config` module is no longer shadowed by a string.
for cfg_name in configs[task]:
    # Checkpoint directory: "<task>/<variant>_[<noise>_][1.0_0.05_]<seed>".
    # Noise-free variants carry no noise level in their directory name.
    # NOTE(review): "1.0_0.05_" presumably encodes the bias settings used at
    # training time -- confirm against the training script.
    path = f"{task}/{cfg_name}_"
    if "no-noise" not in cfg_name:
        path += f"{noise_str}_"
    if "_biased" in cfg_name:
        path += "1.0_0.05_"
    path += seed
    print(path)

    # map_location lets checkpoints saved on a GPU be loaded on CPU-only
    # hosts; the model is run on the CPU below in any case.
    # NOTE(review): newer PyTorch releases default torch.load to
    # weights_only=True, which rejects fully pickled modules -- confirm the
    # installed version before upgrading.
    model = torch.load(
        f"offline_reactivation/{path}/model_{epochs[task]}.pt",
        map_location="cpu",
    )
    model.n_init = 512
    model.eval()
    t = model.task
    model.device = 'cpu'
    model.to('cpu')

    # Run the trained model on a test batch produced by its own task object.
    test_data = t.get_test_batch()
    h_data = model(test_data["data"], test_data["init_state"])

    # Map the third element of the model output to the position of the
    # nearest place cell, then flatten everything into a list of 2-D points.
    data_xy = t.place_cells.get_nearest_cell_pos(h_data[2]).cpu().detach().numpy()

    # gaussian_kde expects the data as (n_dims, n_points), hence the transpose.
    wake = stats.gaussian_kde(data_xy.reshape(-1, 2).T)
    kdes["wake"][path] = wake
    print("Wake done")

# Settings of the spatial-navigation task that drives and decodes the
# freshly constructed networks further down; every value is passed by keyword.
_r_t_settings = {
    "box_width": 2.2,
    "box_height": 2.2,
    "border_region": 0.03,
    "border_slow_factor": 0.25,
    "init_pos": 'uniform',
    "biased": False,
    "biased_ratio": 1.0,
    "drift_const": 0.05,
    "anchor_point": np.array([0, 0]),
    "dt": 0.02,
    "sigma": 11.52,
    "b": 0.26 * np.pi,
    "mu": 0,
    "use_place_cells": True,
    "place_cells_num": 512,
    "place_cells_sigma": 0.2,
    "place_cells_surround_scale": 2,
    "place_cells_dog": False,
    "sequence_length": 100,
    "batch_size": 200,
}

r_t = spatial_navigation.SpatialNavigation(**_r_t_settings)

def _random_net_sleep_kde(sigma_rec, sigma_rec_scale=None):
    """Simulate a freshly constructed RNN on zero input and fit a KDE.

    The network is built on the module-level task ``r_t``, driven with an
    all-zero input for 1000 steps from 200 start positions drawn uniformly
    from [-1.1, 1.1) in each coordinate, and the third element of its output
    is mapped to nearest-place-cell positions.

    Args:
        sigma_rec: Value passed to the ``sigma_rec`` argument of ``rnn.RNN``.
        sigma_rec_scale: Optional factor multiplied into ``net.sigma_rec``
            after construction and before the simulation; ``None`` leaves
            the attribute untouched.

    Returns:
        A ``(net, kde)`` tuple: the network and a ``scipy.stats.gaussian_kde``
        fitted to the decoded 2-D positions.
    """
    net = rnn.RNN(
        task=r_t,
        n_in=2,
        n_rec=512,
        n_out=512,
        n_init=512,
        sigma_in=0,
        sigma_rec=sigma_rec,
        sigma_out=0,
        dt=0.2,
        tau=1,
        feedback_freq=0,
        bias=False,
        activation_fn="relu",
        device="cpu"
    )
    net.n_init = 512
    net.eval()

    # Zero input for the whole run; only the initial state differs per trial.
    t_test = 1000
    noise = torch.zeros(200, t_test, 2)
    noise_init = 2.2 * torch.rand(200, 1, 2) - 1.1
    noise_init_pc = r_t.place_cells.get_activation(noise_init)

    # The rescaling is applied to the attribute after construction rather
    # than folded into the constructor argument, to keep the exact order of
    # operations this script has always used.
    if sigma_rec_scale is not None:
        net.sigma_rec *= sigma_rec_scale

    h_noise = net(noise, noise_init_pc)
    noise_xy = r_t.place_cells.get_nearest_cell_pos(h_noise[2]).cpu().detach().numpy()

    # gaussian_kde expects the data as (n_dims, n_points), hence the transpose.
    return net, stats.gaussian_kde(noise_xy.reshape(-1, 2).T)


# Random network without recurrent noise.
random_nn_net, sleep_r_nn = _random_net_sleep_kde(sigma_rec=0)
kdes["sleep"]["random_no-noise"] = sleep_r_nn

print("Sleep random no-noise done")

# Random network with recurrent noise derived from the command-line noise
# level.
# NOTE(security): eval() executes whatever expression was passed on the
# command line. It is kept so that expression-style arguments keep working;
# switch to float(noise_str) if only plain numbers are ever supplied.
# NOTE(review): the extra sqrt(2) factor presumably doubles the noise
# variance for the sleep simulation -- confirm against the RNN class.
random_n_net, sleep_r_n = _random_net_sleep_kde(
    sigma_rec=np.sqrt(eval(noise_str)),
    sigma_rec_scale=np.sqrt(2),
)
kdes["sleep"]["random_noisy"] = sleep_r_n

print("Sleep random noisy done")

print("Computed KDEs")

# Monte-Carlo estimate of KL(sleep || wake) for every sleep/wake pair: draw
# samples from the sleep KDE and average the log density ratio over them.
kde_kls = {key: dict() for key in kdes["sleep"]}

# Number of samples drawn from each sleep KDE.
n = 2500

for sk, sleep in kdes["sleep"].items():
    print("Sleep:", sk)
    # The same sample set is reused for every wake KDE, so the sleep density
    # only has to be evaluated once per sleep distribution.
    points = sleep.resample(n)
    s_pdf = sleep.pdf(points)
    for wk, wake in kdes["wake"].items():
        print("Wake:", wk)
        w_pdf = wake.pdf(points)
        kde_kls[sk][wk] = np.log(s_pdf / w_pdf).mean()
        print(kde_kls[sk][wk])

print(kde_kls)

# Create the output directory if needed, so the run does not fail after all
# the work above, and close the file deterministically once it is written.
out_dir = f"offline_reactivation/{task}/randomnets"
os.makedirs(out_dir, exist_ok=True)
with open(f"{out_dir}/kl_{noise_str}_{seed}.json", "w") as out_file:
    json.dump(kde_kls, out_file, indent=4)
