import numpy as np
import torch as tr
from torch import nn
from scipy.stats import ttest_ind, ttest_rel, ttest_1samp, sem, pearsonr
from tqdm import tqdm
from matplotlib import pyplot as plt
plt.ion()
plt.style.use("../style.mplstyle")
from cmcrameri import cm
import gc
import pickle as pkl
from ou_process import *

# --- Experiment configuration ----------------------------------------------
dim = 16            # size argument handed to generate_mixing
n_obs = 8           # coordinates [n_obs:] are screened as confounders below; also passed to training/control
time_interval = 4   # forwarded unchanged to train_dynamics and control_process


def nonlinearity(x):
  """Return 2 / (1 + exp(-x)) - 1, i.e. 2*sigmoid(x) - 1 (mathematically tanh(x / 2)).

  Squashes its input into [-1, 1] (in floating point). An alternative that was
  tried previously is the cubic x**3.
  """
  return 2/(1 + np.exp(-x)) - 1


# Global seeds so that the process screening below is reproducible.
np.random.seed(0)
tr.manual_seed(0)

# Draw random mixings until 256 of them pass two screens on a short pilot run
# (simulate_process(mixing, 1024, 1, ...)):
#   1. spread: the pilot's overall std must be a number in [0.01, 1];
#   2. confounding: some coordinate from index n_obs onward must have lag-1
#      increments correlated with the next increments of both coordinate 1
#      (named "action" here) and coordinate 0 ("reward"); the score is
#      sqrt(max |corr_action * corr_reward|) and must be >= 0.33.
# Accepted mixings are then re-simulated with simulate_process(mixing, 1024, 1024, ...).
# NOTE(review): there is no cap on attempts; a configuration where nothing
# passes will loop forever.
mixings = []
samples = []
while len(mixings) < 256:
  mixing = generate_mixing(dim)
  short_sample = simulate_process(mixing, 1024, 1, nonlinearity=nonlinearity)
  spread = np.std(short_sample)
  if np.isnan(spread) or spread > 1 or spread < 0.01:
    print(f"SKIPPING process with spread {spread:.3f}.")
    continue
  # Increments along the last axis; pair each increment with the following one.
  short_diffs = np.diff(short_sample, axis=-1)
  short_pasts = short_diffs[..., :-1]
  short_futures = short_diffs[..., 1:]
  # pearsonr broadcasts the (1, T) row against every confounder row.
  short_action_correlation = pearsonr(short_pasts[0, n_obs:, :], short_futures[0, [1], :], axis=-1).statistic
  short_reward_correlation = pearsonr(short_pasts[0, n_obs:, :], short_futures[0, [0], :], axis=-1).statistic
  short_correlation = np.sqrt(np.amax(np.abs(short_action_correlation * short_reward_correlation)))
  # pearsonr yields NaN for a constant coordinate and np.amax propagates it;
  # since `nan < 0.33` is False, NaN must be rejected explicitly or such
  # degenerate processes would be accepted.
  if np.isnan(short_correlation) or short_correlation < 0.33:
    print(f"SKIPPING process with spread {spread:.3f} and total confounding {short_correlation:.3f}.")
    continue
  print(f"ADDING #{len(mixings)+1} process with spread {spread:.3f} and total confounding {short_correlation:.3f}.")
  sample = simulate_process(mixing, 1024, 1024, nonlinearity=nonlinearity)
  mixings.append(mixing)
  samples.append(sample)


# Fit an MLP and a linear dynamics model to every accepted sample. The RNGs are
# re-seeded before each sample so that each fit is reproducible on its own.
# NOTE(review): device='mps' assumes an Apple-silicon PyTorch build.
models_mlp, trains_mlp, tests_mlp, noises_mlp = [], [], [], []
models_linear, trains_linear, tests_linear, noises_linear = [], [], [], []
for i, sample in enumerate(samples):
  print(f"Learning sample {i+1} of {len(mixings)}.")
  np.random.seed(0)
  tr.manual_seed(0)

  model_mlp, train_mlp, test_mlp, noise_mlp = train_dynamics(
    sample, n_obs, time_interval, model_type='mlp',
    learning_rate=2e-5, n_epochs=512, device='mps')
  for bucket, item in zip((models_mlp, trains_mlp, tests_mlp, noises_mlp),
                          (model_mlp, train_mlp, test_mlp, noise_mlp)):
    bucket.append(item)

  model_linear, train_linear, test_linear, noise_linear = train_dynamics(
    sample, n_obs, time_interval, model_type='linear',
    learning_rate=1e-3, n_epochs=32, device='mps')
  for bucket, item in zip((models_linear, trains_linear, tests_linear, noises_linear),
                          (model_linear, train_linear, test_linear, noise_linear)):
    bucket.append(item)


# Roll out every trained MLP model under each controller. For sample i, every
# controller uses the same seeds i*n_batches + j (j < n_batches), so runs of
# different controllers on one sample share their random draws.
all_samples = []
all_states_control = []
all_states_naive = []
all_states_novel = []
all_states_cmsm = []
all_states_empirical = []
length = 16
batch_size = 1
n_batches = 1

# Calibration grids per sensitivity model. all_states_<model>[i][k] is the
# rollout for grid value k, which is the index evaluate_best_calibration returns.
novel_log_gammas = [0.05, 0.1, 0.2, 0.3, 0.4]
cmsm_log_gammas = [0.25, 0.5, 1.0, 1.5, 2.0]
empirical_quantiles = [0.125, 0.250, 0.500]


def _rollout_batches(model, mixing, sample_index, **control_kwargs):
  """Run control_process for n_batches seeds on one sample and stack the results.

  The shared arguments (n_obs, length, batch_size, model_noise=None,
  nonlinearity, time_interval) come from the module-level configuration;
  `control_kwargs` carries the controller-specific options and is forwarded
  unchanged. Element [0] of each control_process result is kept and the
  batches are concatenated along axis 0.
  """
  return np.concatenate([
    control_process(model, mixing, n_obs, length=length, batch_size=batch_size,
                    model_noise=None, nonlinearity=nonlinearity,
                    time_interval=time_interval,
                    seed=sample_index*n_batches + j, **control_kwargs)[0]
    for j in range(n_batches) ], axis=0)


# NOTE: `noise` is unpacked but unused; rollouts pass model_noise=None.
for i, (model, mixing, noise) in enumerate(zip(models_mlp, mixings, noises_mlp)):
  print(f"Controlling sample {i+1} of {len(mixings)}.")
  all_samples.append(samples[i])
  states_control = _rollout_batches(model, mixing, i, sensitivity_model='pass')
  all_states_control.append(states_control)
  states_naive = _rollout_batches(model, mixing, i, log_gamma=0.)
  all_states_naive.append(states_naive)
  all_states_novel.append([])
  all_states_cmsm.append([])
  all_states_empirical.append([])
  for log_gamma in novel_log_gammas:
    states_novel = _rollout_batches(model, mixing, i, log_gamma=log_gamma,
                                    sensitivity_model='novel')
    all_states_novel[-1].append(states_novel)
  for log_gamma in cmsm_log_gammas:
    states_cmsm = _rollout_batches(model, mixing, i, log_gamma=log_gamma,
                                   sensitivity_model='cmsm')
    all_states_cmsm[-1].append(states_cmsm)
  for quantile in empirical_quantiles:
    # The 'empirical' model receives its quantile through the log_gamma argument.
    states_empirical = _rollout_batches(model, mixing, i, log_gamma=quantile,
                                        sensitivity_model='empirical')
    all_states_empirical[-1].append(states_empirical)


def evaluate_results(states, naives=None):
  """Return, per sample, the best relative cost reduction over the naive controller.

  For sample k, each calibrated run in `states[k]` is scored by
  mean(naive[:, 0, :]**2 - run[:, 0, :]**2). The best score, floored at 0
  (so an empty list, or one with no improvement, yields 0), is divided by
  mean(naive[:, 0, :]**2).

  Args:
    states: one list of 3-D state arrays per sample, one array per
      calibration setting. Only index 0 of axis 1 is scored (presumably the
      reward coordinate; confirm against control_process).
    naives: optional per-sample naive state arrays, paired with `states` by
      position. Defaults to the module-level `all_states_naive`, looked up
      at call time, which keeps existing one-argument calls working.

  Returns:
    1-D np.ndarray with one ratio per sample (length = shorter of the two inputs).
  """
  if naives is None:
    naives = all_states_naive
  ratios = []
  for naive, runs in zip(naives, states):
    gains = [np.mean(naive[:,0,:]**2 - run[:,0,:]**2) for run in runs]
    ratios.append(np.amax(gains + [0]) / np.mean(naive[:,0,:]**2))
  return np.array(ratios)

def evaluate_best_calibration(states):
  """Return, for each sample, the grid index of the lowest-cost calibration.

  A run's cost is mean(run[:, 0, :]**2); ties resolve to the first index
  (np.argmin semantics).

  Args:
    states: one list of 3-D state arrays per sample, ordered like the
      calibration grid that produced them.

  Returns:
    np.ndarray of indices, one per sample.
  """
  best_indices = []
  for runs in states:
    costs = [np.mean(run[:, 0, :]**2) for run in runs]
    best_indices.append(np.argmin(costs))
  return np.array(best_indices)

# Per sample: best relative improvement over the naive controller, for each sensitivity model.
res_novel, res_cmsm, res_empirical = [
  evaluate_results(states)
  for states in (all_states_novel, all_states_cmsm, all_states_empirical)]

# Per sample: index of the lowest-cost setting within each calibration grid.
# These check whether the gamma/quantile grids are well positioned (an optimum
# sitting at either end of a grid would suggest extending it). So far they
# appear to be, surprisingly well.
cal_novel, cal_cmsm, cal_empirical = [
  evaluate_best_calibration(states)
  for states in (all_states_novel, all_states_cmsm, all_states_empirical)]


def persist(filename):
  """Pickle this script's inputs, trained models and results to `filename`.

  Everything is read from module-level variables at call time, so call this
  only after the training, control and evaluation sections have run.

  Model weights are stored as state dicts with every tensor moved to CPU.
  The models are trained with device='mps' (presumably left there). A
  pickled non-CPU tensor is restored onto the device it was saved from, which
  would make the file unloadable on machines without that device.

  Args:
    filename: path of the pickle file to (over)write.
  """
  def _cpu_state_dict(model):
    # Keep the state_dict's own mapping object (and its `_metadata`, which
    # load_state_dict consults) and only swap each tensor for a CPU copy;
    # .cpu() returns the tensor itself when it already lives on CPU.
    state = model.state_dict()
    for key in list(state):
      state[key] = state[key].cpu()
    return state

  # NOTE(review): trains_*/tests_*/noises_* are stored as returned by
  # train_dynamics; if they contain device tensors they are not moved here.
  data = {
    'mixings': mixings,
    'models_mlp': [_cpu_state_dict(model) for model in models_mlp],
    'trains_mlp': trains_mlp,
    'tests_mlp': tests_mlp,
    'noises_mlp': noises_mlp,
    'models_linear': [_cpu_state_dict(model) for model in models_linear],
    'trains_linear': trains_linear,
    'tests_linear': tests_linear,
    'noises_linear': noises_linear,
    'samples': all_samples,
    'states_control': all_states_control,
    'states_naive': all_states_naive,
    'states_novel': all_states_novel,
    'states_cmsm': all_states_cmsm,
    'states_empirical': all_states_empirical,
    'res_novel': res_novel,
    'res_cmsm': res_cmsm,
    'res_empirical': res_empirical,
    'cal_novel': cal_novel,
    'cal_cmsm': cal_cmsm,
    'cal_empirical': cal_empirical,
  }
  with open(filename, 'wb') as f:
    pkl.dump(data, f)