import json
import numpy as np
import os
import pathlib
import random
from scipy import stats
import sys
import torch

import config
from models import *
from tasks import *
import trainer

# Command-line arguments: <task> <noise level> <seed>. All three are kept as
# strings because they are spliced into checkpoint/output paths below.
task, noise_str, seed = map(str, (sys.argv[1], sys.argv[2], sys.argv[3]))

# Seed every RNG the evaluation touches (Python, NumPy, torch CPU and CUDA)
# so sampling and KDE resampling are reproducible for a given seed.
_seed_value = int(seed)
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
    torch.cuda.manual_seed,
):
    _seed_fn(_seed_value)

# Request deterministic cuDNN kernels and turn off cuDNN autotuning.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Model configurations evaluated for each task. Configs containing
# "no-noise" and "_biased" get different checkpoint path layouts below.
configs = dict(
    spatial_navigation=[
        "noisy_unbiased",
        "noisy_biased",
        "no-noise_unbiased",
        "no-noise_biased",
    ],
    head_direction=["noisy_unbiased", "no-noise_unbiased"],
)

# Training epoch of the checkpoint to load per task (a string, because it is
# only ever used inside the checkpoint file name).
epochs = dict(spatial_navigation="2500", head_direction="20000")

# Fitted KDEs keyed by checkpoint path: "wake" holds KDEs of activity on the
# task's test batch, "sleep" holds KDEs of activity under zero input.
kdes = {phase: {} for phase in ("wake", "sleep")}

# Number of sleep trajectories and timesteps per sleep trajectory.
N_SLEEP_TRIALS = 200
T_SLEEP = 1000

# Parsed with float() rather than eval(): the value comes straight from the
# command line, and eval() would execute arbitrary expressions. Parsing it
# once up front also fails fast on a malformed argument.
noise_level = float(noise_str)


def _load_model(checkpoint_file):
    """Load a pickled model checkpoint and prepare it for CPU evaluation."""
    # map_location="cpu" lets checkpoints saved from a GPU load on a CPU-only
    # machine; the model is moved to the CPU right afterwards anyway.
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which
    # refuses to unpickle whole nn.Module objects; pass weights_only=False there.
    model = torch.load(checkpoint_file, map_location="cpu")
    model.n_init = 512
    model.eval()
    model.device = 'cpu'
    model.to('cpu')
    return model


def _sleep_inputs(task_obj):
    """Return (zero input drive, encoded random initial head directions).

    The drive has shape (N_SLEEP_TRIALS, T_SLEEP, 1). Initial angles are drawn
    uniformly from [-pi, pi) and encoded with ``task_obj.hd_cells.get_activation``.
    """
    drive = torch.zeros(N_SLEEP_TRIALS, T_SLEEP, 1)
    init_angle = torch.pi * 2 * torch.rand(N_SLEEP_TRIALS, 1, 1) - torch.pi
    return drive, task_obj.hd_cells.get_activation(init_angle)


def _decoded_kde(task_obj, outputs):
    """Fit a 1-D Gaussian KDE to the head directions decoded from outputs[2]."""
    # NOTE(review): outputs[2] is assumed to be the cell readout that
    # decode_hd expects; confirm against the model's forward().
    decoded = task_obj.hd_cells.decode_hd(outputs[2]).cpu().detach().numpy()
    return stats.gaussian_kde(decoded.reshape(-1, 1).T)


# For every config: load its checkpoint, run it on the task's test batch
# ("wake") and on zero input from random initial states ("sleep"), then fit
# KDEs of the decoded head direction.
# NOTE(review): decoding always goes through task_obj.hd_cells, including for
# spatial_navigation; confirm that task exposes hd_cells.
for cfg in configs[task]:  # not `config`: that would shadow the imported module
    trained_with_noise = "no-noise" not in cfg

    path = f"{task}/{cfg}_"
    if trained_with_noise:
        path += f"{noise_str}_"
    if "_biased" in cfg:
        path += "1.0_0.05_"
    path += seed
    print(path)

    model = _load_model(f"offline_reactivation/{path}/model_{epochs[task]}.pt")
    task_obj = model.task

    # Keep this call order: each step may consume the seeded RNG streams.
    test_data = task_obj.get_test_batch()
    sleep_drive, sleep_init = _sleep_inputs(task_obj)

    # Inference only: no_grad avoids recording an autograd graph over every
    # timestep of every trajectory.
    with torch.no_grad():
        h_data = model(test_data["data"], test_data["init_state"])
        if trained_with_noise:
            # Sleep run with sigma_rec multiplied by sqrt(2).
            model.sigma_rec *= np.sqrt(2)
            h_noise = model(sleep_drive, sleep_init)
        else:
            # One sleep run with the checkpoint's own sigma_rec, then one with
            # sigma_rec = sqrt(2 * noise level) from the command line.
            h_noise = model(sleep_drive, sleep_init)
            model.sigma_rec = np.sqrt(2 * noise_level)
            h_noise_noisy = model(sleep_drive, sleep_init)

    # Insertion order into kdes["sleep"] matters: the KL stage iterates it
    # and draws resamples from the seeded NumPy RNG in that order.
    kdes["wake"][path] = _decoded_kde(task_obj, h_data)
    print("Wake done")
    kdes["sleep"][path] = _decoded_kde(task_obj, h_noise)
    print("Sleep done")
    if not trained_with_noise:
        kdes["sleep"]["noisy_" + path] = _decoded_kde(task_obj, h_noise_noisy)
        print("Sleep noisy done")

print("Computed KDEs")

# Monte Carlo sample size for every KL estimate. Defined once here rather than
# inside the sleep loop, so the uniform baseline below never depends on that
# loop having run.
n = 2500


def _mc_kl(p_pdf, q_pdf, clip=None):
    """Monte Carlo estimate of KL(p || q) from densities evaluated at samples of p.

    Args:
        p_pdf: density of p at points sampled from p.
        q_pdf: density of q at the same points.
        clip: optional (low, high) bounds applied to the ratio p/q before the log.

    Returns:
        The mean of log(p_pdf / q_pdf) over the samples.
    """
    ratio = p_pdf / q_pdf
    if clip is not None:
        ratio = ratio.clip(*clip)
    return np.log(ratio).mean()


# kde_kls[sleep_key][wake_key] = estimated KL(sleep KDE || wake KDE).
# These KDE-vs-KDE estimates are unclipped, so they can come out as inf if a
# wake density underflows to 0 at a sampled point.
kde_kls = {key: dict() for key in kdes["sleep"]}

for sk, sleep in kdes["sleep"].items():
    print("Sleep:", sk)
    points = sleep.resample(n)
    s_pdf = sleep.pdf(points)
    for wk, wake in kdes["wake"].items():
        print("Wake:", wk)
        kde_kls[sk][wk] = _mc_kl(s_pdf, wake.pdf(points))
        print(kde_kls[sk][wk])

# Baseline: KL(uniform on [-pi, pi] || wake KDE). Only this estimate clips
# the density ratio, to [1e-7, 1e7].
kde_kls["uniform"] = dict()
u_s = stats.uniform(loc=-np.pi, scale=2 * np.pi)
points = u_s.rvs((n, 1)).T
u_s_pdf = u_s.pdf(points)
for wk, wake in kdes["wake"].items():
    kde_kls["uniform"][wk] = _mc_kl(u_s_pdf, wake.pdf(points), clip=(1e-7, 1e7))

print(kde_kls)
# Close the output file deterministically instead of leaking the handle.
with open(f"offline_reactivation/{task}/kl_{noise_str}_{seed}.json", "w") as out_file:
    json.dump(kde_kls, out_file, indent=4)
