"""
Manually change the auto correlation matrix in the cache to the covariance matrix.
"""

import os
# Both environment variables are set before the jax import further down.
# NOTE(review): presumably this stops XLA from preallocating GPU memory up
# front -- confirm against the JAX GPU memory allocation docs.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] ='false'
# os.environ['XLA_PYTHON_CLIENT_ALLOCATOR']='platform'
# NOTE(review): presumably lets TensorFlow grow its GPU memory on demand
# instead of reserving it all -- confirm against the TensorFlow docs.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

import sys
# Adds a path relative to the current working directory to the import search
# path; presumably this is where the `bnn_hmc` package imported below lives.
sys.path.append("google-research/")


from jax import vmap
from jax.interpreters import xla
from jax import lax
from jax.random import normal, PRNGKey, split
import numpy as onp
from jax import numpy as jnp
from bnn_hmc.utils import checkpoint_utils
from tqdm import tqdm
import itertools
import pickle

# Directory holding the pickled cache files this script reads and writes.
cache_dir = '/home/ooo123/data/cov_caches'

def get_cov_train_cache(match_type, chain_id, sample_id):
  """Load the cached statistics for one (chain, sample) pair.

  The `*_sep_cov` match types have no cache files of their own: they are
  read from the files of the corresponding `*_sep` match type.

  Args:
    match_type: cache family name used as the file-name prefix.
    chain_id: chain index embedded in the file name.
    sample_id: sample index embedded in the file name.

  Returns:
    Whatever `checkpoint_utils.load_checkpoint` returns for
    `<cache_dir>/<match_type>__<chain_id>__<sample_id>.pkl`.
  """
  # Match types that share their on-disk cache with another match type.
  shared_source = {
      'channel_wise_sep_cov': 'channel_wise_sep',
      'spatial_sep_cov': 'spatial_sep',
  }
  source_type = shared_source.get(match_type, match_type)
  file_name = f"{source_type}__{chain_id}__{sample_id}.pkl"
  return checkpoint_utils.load_checkpoint(os.path.join(cache_dir, file_name))

# Cache family to convert; switch by (un)commenting one of the assignments.
# match_type = 'spatial_sep_cov'
# match_type = 'feature'
match_type = 'spatial_joint'

# Every (chain_id, sample_id) pair in this grid is converted.
NUM_CHAINS = 3
NUM_SAMPLES = 260


def _centered(moment_sum, mean, count):
  """Turn an accumulated second-moment sum into a covariance matrix.

  Computes `moment_sum / count - outer(mean, mean)`.
  """
  return moment_sum / count - jnp.outer(mean, mean)


def _to_covariance(stats, separable):
  """Return a plain-dict copy of one cache entry holding covariances.

  Args:
    stats: mapping with a 'counter' entry plus either 'cov'/'mu' or, for
      separable caches, 'cov_H'/'mu_H' and 'cov_W'/'mu_W'.
    separable: True when the entry stores separate H and W statistics.

  Returns:
    A new dict in which every 'cov*' entry has been divided by 'counter'
    and centered with the outer product of the matching 'mu*' entry. All
    other entries are carried over unchanged.
  """
  # Copy first so the entry is a mutable dict in both layouts; the joint
  # layout already needed this, and it is harmless for ordinary dicts.
  stats = dict(stats)
  count = stats['counter']
  for suffix in (('_H', '_W') if separable else ('',)):
    stats['cov' + suffix] = _centered(
        stats['cov' + suffix], stats['mu' + suffix], count)
  return stats


print(f'Converting {match_type} to {match_type}_cov')

separable = 'sep' in match_type
jobs = itertools.product(range(NUM_CHAINS), range(NUM_SAMPLES))
# itertools.product has no len(), so pass the total explicitly to get a
# real progress bar with an ETA.
for chain_id, sample_id in tqdm(jobs, total=NUM_CHAINS * NUM_SAMPLES):
  state = get_cov_train_cache(match_type, chain_id, sample_id)
  # Convert EVERY entry. The separable path used to stop after the first
  # entry, leaving the rest as raw sums inside a file labelled `_cov`.
  state = {name: _to_covariance(stats, separable)
           for name, stats in state.items()}
  out_path = os.path.join(
      cache_dir, f"{match_type}_cov__{chain_id}__{sample_id}.pkl")
  with open(out_path, "wb") as f:
    pickle.dump(state, f)