import json
import numpy as np
import os
import pathlib
import random
from scipy import stats
import sys
import torch

import config
from models import *
from tasks import *
import trainer

# Command-line arguments: TASK NOISE_STR SEED.
#   TASK      -- key into `configs` / `epochs` below.
#   NOISE_STR -- noise tag spliced into the checkpoint path of the noisy variants.
#   SEED      -- integer seed; kept as a string because it is also part of paths.
if len(sys.argv) < 4:
    sys.exit("usage: python <script> TASK NOISE_STR SEED")

task = sys.argv[1]
noise_str = sys.argv[2]
seed = sys.argv[3]
_seed_value = int(seed)  # parse once; raises ValueError if SEED is not an integer

# Seed every RNG the evaluation touches so the random initial states are reproducible.
random.seed(_seed_value)
np.random.seed(_seed_value)
torch.manual_seed(_seed_value)
torch.cuda.manual_seed_all(_seed_value)  # also covers the current device
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Model variants evaluated for each task; each name is the prefix of a
# checkpoint directory under offline_reactivation/<task>/.
configs = dict(
    spatial_navigation=[
        "noisy_unbiased",
        "noisy_biased",
        "no-noise_unbiased",
        "no-noise_biased",
    ],
    head_direction=["noisy_unbiased", "no-noise_unbiased"],
)

# Training epoch of the checkpoint to load for each task (used as a filename suffix).
epochs = dict(spatial_navigation="2500", head_direction="20000")

# Collected results: checkpoint path -> mean late-time variance of the decoded output.
entropy = {}

def _model_path(task_name, variant, noise_tag, seed_str):
    """Return the checkpoint directory (relative to offline_reactivation/) for one variant.

    "no-noise" variants omit the noise tag; "_biased" variants carry the fixed
    "1.0_0.05_" suffix; every path ends with the seed string.
    """
    path = f"{task_name}/{variant}_"
    if "no-noise" not in variant:
        path += f"{noise_tag}_"
    if "_biased" in variant:
        path += "1.0_0.05_"
    return path + seed_str


def _load_model(path, epoch):
    """Load the pickled model at `path`/model_<epoch>.pt onto CUDA and prepare it for evaluation."""
    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
    # checkpoints. On PyTorch >= 2.6 the default weights_only=True will reject a
    # full pickled model; passing weights_only=False may be required there.
    model = torch.load(f"offline_reactivation/{path}/model_{epoch}.pt")
    model.n_init = 512
    model.eval()
    model.device = 'cuda'
    model.to('cuda')
    # Scale recurrent noise by sqrt(2); if sigma_rec is a standard deviation this
    # doubles the noise variance relative to training -- confirm intent.
    model.sigma_rec *= np.sqrt(2)
    return model


def _late_variance(model, n_trials=500, t_test=1000, t_shift=500):
    """Run `model` on zero input from random initial headings; return mean decoded variance.

    Each of `n_trials` trials starts from an angle drawn uniformly in [-pi, pi)
    and runs for `t_test` steps with all-zero input. The decoded trajectory is
    reduced per trial with np.cov over steps >= `t_shift`, and all entries are
    averaged into one float.
    """
    # NOTE(review): both tasks decode through `model.task.hd_cells`; confirm that
    # is intended for spatial_navigation too.
    hd_cells = model.task.hd_cells
    inputs = torch.zeros(n_trials, t_test, 1)
    init_angles = 2 * np.pi * torch.rand(n_trials, 1, 1) - np.pi
    init_activation = hd_cells.get_activation(init_angles)

    # Inference only: no autograd graph is needed (the result is detached anyway).
    with torch.no_grad():
        outputs = model(inputs.cuda(), init_activation.cuda())

    # assumes decoded is (n_trials, t_test[, d]) -- TODO confirm against decode_hd.
    decoded = hd_cells.decode_hd(outputs[2].cpu()).cpu().detach().numpy()
    # NOTE(review): if decode_hd returns angles, np.cov ignores the +-pi
    # wrap-around -- confirm a linear variance is intended.
    per_trial = np.array([np.cov(trial[t_shift:].T) for trial in decoded])
    return float(per_trial.mean())


if task not in configs:
    raise SystemExit(f"unknown task {task!r}; expected one of {sorted(configs)}")

# `variant` (not `config`) so the imported `config` module is not shadowed.
for variant in configs[task]:
    path = _model_path(task, variant, noise_str, seed)
    print(path)
    model = _load_model(path, epochs[task])
    entropy[path] = _late_variance(model)
    print(entropy[path])

with open(f"offline_reactivation/{task}/entropy_{noise_str}_{seed}.json", "w") as out_file:
    json.dump(entropy, out_file, indent=4)
