import json
import numpy as np
import os
import pathlib
import random
from scipy import stats
import sys
import torch

import config
from models import *
from tasks import *
import trainer

# Positional CLI arguments: <task> <noise_str> <seed>.
# All three stay strings: `seed` in particular is spliced verbatim into the
# checkpoint paths and the output filename further down.
_argv = sys.argv
task, noise_str, seed = str(_argv[1]), str(_argv[2]), str(_argv[3])

# Seed every RNG the script touches (Python, NumPy, torch CPU and CUDA), in the
# same order as before, and pin cuDNN to deterministic, non-autotuned kernels.
_seed_value = int(seed)
for _seeder in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
    torch.cuda.manual_seed,
):
    _seeder(_seed_value)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Run configurations evaluated for each task. Names containing "no-noise" have
# no noise-level component in their checkpoint path; names containing
# "_biased" carry an extra "1.0_0.05_" path component.
configs = dict(
    spatial_navigation=[
        "noisy_unbiased",
        "noisy_biased",
        "no-noise_unbiased",
        "no-noise_biased",
    ],
    head_direction=["noisy_unbiased", "no-noise_unbiased"],
)

# Epoch of the checkpoint to load per task; kept as strings because they are
# interpolated directly into the checkpoint filename.
epochs = dict(spatial_navigation="2500", head_direction="20000")

# Fitted KDEs, keyed first by phase ("wake" / "sleep") and then by run label.
kdes = {phase: {} for phase in ("wake", "sleep")}

def _checkpoint_path(task_name, run_config, noise_level, seed_str):
    """Return the run directory (relative to ``offline_reactivation/``).

    Configs containing ``"no-noise"`` have no noise-level component; all
    others insert ``noise_level``. Configs containing ``"_biased"`` also
    carry the ``"1.0_0.05_"`` component.
    """
    path = f"{task_name}/{run_config}_"
    if "no-noise" not in run_config:
        path += f"{noise_level}_"
    if "_biased" in run_config:
        path += "1.0_0.05_"
    return path + seed_str


def _load_trained_model(path, epoch):
    """Load ``model_<epoch>.pt`` for ``path`` onto the CPU in eval mode."""
    # map_location="cpu": the script runs everything on CPU, so remap any
    # CUDA tensors stored in the pickle (including those held by model.task)
    # instead of failing on CPU-only hosts.
    # NOTE(review): torch.load unpickles a full module; only use trusted
    # checkpoints. Newer PyTorch defaults to weights_only=True, which may
    # reject full-module pickles - confirm against the installed version.
    model = torch.load(
        f"offline_reactivation/{path}/model_{epoch}.pt",
        map_location="cpu",
    )
    model.n_init = 512
    model.eval()
    model.device = 'cpu'
    model.to('cpu')
    return model


def _decode_kde(task_obj, rnn_output):
    """Fit a 1-D Gaussian KDE to ``task_obj.hd_cells.decode_hd(rnn_output[2])``."""
    decoded = task_obj.hd_cells.decode_hd(rnn_output[2]).cpu().detach().numpy()
    # Flatten every decoded value into a single (1, N) dataset.
    return stats.gaussian_kde(decoded.reshape(-1, 1).T)


def _make_random_baseline(sigma_rec):
    """Build a fresh HeadDirection task and an untrained ReLU RNN on it.

    Returns ``(task, net)``. ``sigma_rec`` sets the net's recurrent noise.
    """
    random_task = head_direction.HeadDirection(
        dimensionality='2D',
        init_hd='uniform',
        biased=False,
        drift_const=0.05,
        anchor_angle=0,
        dt=0.02,
        sigma=11.52,
        mu=0,
        use_hd_cells=True,
        hd_cells_num=512,
        hd_cells_angular_spread=np.pi/6,
        sequence_length=100,
        batch_size=200
    )
    random_net = rnn.RNN(
        task=random_task,
        n_in=1,
        n_rec=128,
        n_out=512,
        n_init=512,
        sigma_in=0,
        sigma_rec=sigma_rec,
        sigma_out=0,
        dt=0.5,
        tau=1,
        feedback_freq=0,
        bias=False,
        activation_fn="relu",
        device="cpu"
    )
    random_net.n_init = 512
    random_net.eval()
    return random_task, random_net


def _sleep_inputs(random_task, batch_size=200, t_test=1000):
    """Return (all-zero input, hd-cell encoding of uniform random angles in [-pi, pi))."""
    noise = torch.zeros(batch_size, t_test, 1)
    noise_init = torch.pi * 2 * torch.rand(batch_size, 1, 1) - torch.pi
    return noise, random_task.hd_cells.get_activation(noise_init)


# For every run config: fit a "wake" KDE to the trained model's decoded test
# output, then fit wake and "sleep" (zero-input) KDEs for an untrained random
# baseline. The call order (and therefore RNG consumption) is the same for the
# noisy and no-noise cases.
# NOTE(review): decoding uses t.hd_cells even for spatial_navigation - confirm
# that task exposes hd_cells.
# NOTE(review): the random baseline is rebuilt every iteration and stored under
# a fixed "random_<label>" key, so with several configs of the same kind only
# the last baseline is kept.
for run_config in configs[task]:  # not `config`: that would shadow the module
    is_noisy = "no-noise" not in run_config
    label = "noisy" if is_noisy else "no-noise"

    path = _checkpoint_path(task, run_config, noise_str, seed)
    print(path)
    model = _load_trained_model(path, epochs[task])
    t = model.task

    test_data = t.get_test_batch()
    with torch.no_grad():  # inference only; outputs are detached anyway
        h_data = model(test_data["data"], test_data["init_state"])
    kdes["wake"][path] = _decode_kde(t, h_data)

    print("Wake done")

    # Noisy runs give the baseline the trained model's recurrent noise level.
    r_t, random_net = _make_random_baseline(model.sigma_rec if is_noisy else 0)

    r_test_data = r_t.get_test_batch()
    with torch.no_grad():
        h_random = random_net(r_test_data["data"], r_test_data["init_state"])
    kdes["wake"][f"random_{label}"] = _decode_kde(r_t, h_random)

    print(f"Wake random {label} done")

    noise, noise_init_pc = _sleep_inputs(r_t)
    if is_noisy:
        # Sleep phase runs with recurrent noise scaled up by sqrt(2).
        random_net.sigma_rec *= np.sqrt(2)
    with torch.no_grad():
        h_noise = random_net(noise, noise_init_pc)
    kdes["sleep"][f"random_{label}"] = _decode_kde(r_t, h_noise)

    print(f"Sleep random {label} done")

print("Computed KDEs")

# Monte Carlo estimate of KL(sleep || wake) for every (sleep, wake) KDE pair:
# draw samples from the sleep KDE and average log(p_sleep / p_wake) over them.
# resample() draws from NumPy's global RNG, which was seeded at startup.
n = 2500  # samples drawn per sleep KDE
kde_kls = {key: dict() for key in kdes["sleep"]}

for sk, sleep in kdes["sleep"].items():
    print("Sleep:", sk)
    points = sleep.resample(n)
    s_pdf = sleep.pdf(points)  # reused for every wake KDE
    for wk, wake in kdes["wake"].items():
        print("Wake:", wk)
        w_pdf = wake.pdf(points)
        # A wake density of exactly 0 at any sample yields inf, which
        # json.dump writes as the non-standard token `Infinity`.
        kde_kls[sk][wk] = float(np.log(s_pdf / w_pdf).mean())
        print(kde_kls[sk][wk])

print(kde_kls)
out_path = pathlib.Path(
    f"offline_reactivation/{task}/randomnets/kl_{noise_str}_{seed}.json"
)
# Create randomnets/ if needed so the run does not die at the very end, and
# close the file deterministically.
out_path.parent.mkdir(parents=True, exist_ok=True)
with open(out_path, "w") as f:
    json.dump(kde_kls, f, indent=4)
