import json
import numpy as np
import os
import pathlib
import random
from scipy import stats
import sys
import torch

import config
from models import *
from tasks import *
import trainer

# Command line: <task> <noise_str> <seed>
#   task      -- key into `configs` below (e.g. "spatial_navigation")
#   noise_str -- noise level exactly as it appears in run directory / output names
#   seed      -- integer seed; also used verbatim (as a string) in run paths
if len(sys.argv) < 4:
    sys.exit(f"usage: {sys.argv[0]} <task> <noise_str> <seed>")

# sys.argv entries are already str, so no conversion is needed for path building.
task, noise_str, seed = sys.argv[1:4]

try:
    seed_int = int(seed)
except ValueError:
    sys.exit(f"seed must be an integer, got {seed!r}")

# Seed every RNG the evaluation may touch so runs are reproducible.
random.seed(seed_int)
np.random.seed(seed_int)
torch.manual_seed(seed_int)
torch.cuda.manual_seed_all(seed_int)  # covers every GPU, including the current one
# Force deterministic cuDNN kernels; benchmark mode may pick different kernels per run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Configurations evaluated per task. Names containing "no-noise" are looked up
# in run directories that do not include noise_str.
configs = dict(
    spatial_navigation=[
        "noisy_unbiased",
        "noisy_biased",
        "no-noise_unbiased",
        "no-noise_biased",
    ],
    head_direction=["noisy_unbiased", "no-noise_unbiased"],
)

# Checkpoint suffix per task (presumably the training epoch), kept as a string
# because it is only interpolated into the checkpoint file name.
epochs = dict(spatial_navigation="2500", head_direction="20000")

# Run path -> averaged spread of the rollout positions; filled in below and
# written out as JSON.
entropy = {}

# Evaluate every configuration of `task` for this (noise_str, seed) pair:
# - roll each trained model out from random start positions with zero input;
# - measure how widely its nearest-place-cell positions spread once the first
#   T_SHIFT steps are dropped;
# - store the per-run averages in `entropy`, which is written to JSON at the end.

T_TEST = 2000    # time steps per rollout
T_SHIFT = 1000   # initial steps discarded before measuring the spread
N_TRAJ = 500     # independent rollouts per model


def _run_path(task_name, cfg_name, noise_level, seed_str):
    """Return the run directory (relative to ``offline_reactivation/``) of one config.

    Configs whose name contains ``"no-noise"`` omit the noise level from the
    directory name. Configs containing ``"_biased"`` carry an extra
    ``1.0_0.05_`` segment before the seed.
    """
    path = f"{task_name}/{cfg_name}_"
    if "no-noise" not in cfg_name:
        path += f"{noise_level}_"
    if "_biased" in cfg_name:
        path += "1.0_0.05_"
    return path + seed_str


def _rollout_positions(checkpoint):
    """Load ``checkpoint``, roll it out on CUDA and return nearest-cell positions.

    Returns a 3-D numpy array (presumably trajectory x time x coordinate,
    i.e. (N_TRAJ, T_TEST, 2); TODO confirm against ``get_nearest_cell_pos``).
    """
    # NOTE(review): torch.load unpickles the whole model object. PyTorch >= 2.6
    # defaults to weights_only=True, which rejects such checkpoints. On those
    # versions pass weights_only=False, and only do so for trusted files.
    model = torch.load(checkpoint)
    model.n_init = 512
    model.eval()
    model.device = 'cuda'
    model.to('cuda')
    task_obj = model.task
    # Rollouts use sqrt(2) times the stored sigma_rec (presumably the recurrent
    # noise level; confirm in models.py). A fresh model is loaded per call,
    # so the scaling never compounds.
    model.sigma_rec *= np.sqrt(2)

    # Zero external input. Start positions are uniform in [-1.1, 1.1) per
    # coordinate and are fed to the model as place-cell activations.
    inputs = torch.zeros(N_TRAJ, T_TEST, 2)
    start_xy = 2.2 * torch.rand(N_TRAJ, 1, 2) - 1.1
    start_pc = task_obj.place_cells.get_activation(start_xy)

    # Pure inference: skip autograd bookkeeping for the long rollout.
    with torch.no_grad():
        outputs = model(inputs.cuda(), start_pc.cuda())
        # outputs[2] is mapped to the nearest place-cell position (presumably
        # the predicted place-cell activity; confirm against forward()).
        xy = task_obj.place_cells.get_nearest_cell_pos(outputs[2].cpu())
    return xy.cpu().detach().numpy()


def _mean_spread(xy, t_shift):
    """Return the mean over axis 0 of the summed per-coordinate sample variance.

    Only steps at index ``t_shift`` and later are used. This is a vectorized
    equivalent of
    ``np.mean([np.trace(np.cov(xy[i, t_shift:].T)) for i in range(len(xy))])``.
    ``np.cov`` uses ``ddof=1`` and float64 arithmetic, so the same is done here.
    """
    tail = np.asarray(xy, dtype=np.float64)[:, t_shift:]
    return tail.var(axis=1, ddof=1).sum(axis=-1).mean()


if task not in configs:
    sys.exit(f"unknown task {task!r}; expected one of {list(configs)}")

# The loop variable is deliberately not named `config`: that would shadow the
# imported `config` module for the rest of the script.
for cfg_name in configs[task]:
    path = _run_path(task, cfg_name, noise_str, seed)
    print(path)

    noise_xy = _rollout_positions(f"offline_reactivation/{path}/model_{epochs[task]}.pt")

    entropy[path] = _mean_spread(noise_xy, T_SHIFT)
    print(entropy[path])

    # Keep the raw positions for one reference run only.
    if seed == "0" and noise_str == "0.0001":
        np.save(f"offline_reactivation/{task}/entropy_{cfg_name}_noise_xy_{noise_str}_{seed}.npy", noise_xy)
        print("Saved")

# `with` guarantees the JSON file is flushed and closed.
with open(f"offline_reactivation/{task}/entropy_{noise_str}_{seed}.json", "w") as fh:
    json.dump(entropy, fh, indent=4)
