import numpy as np
import torch as tr
from torch import nn
from scipy.stats import ttest_ind, sem
from tqdm import tqdm
import gc

def generate_mixing(dim=5, sparsity=2.0, scale=1.0, add_drift=True):
  """Draw a random sparse mixing matrix and reorder its dimensions.

  Entries are standard normal, each kept independently with probability
  sparsity/dim (capped at 1). The matrix is then shifted by a multiple of the
  identity so that the smallest real part among its eigenvalues equals
  1/sqrt(dim) when add_drift is true, and 0 otherwise. Finally the dimensions
  are permuted (rows and columns together) so that index 0 is the dimension
  with the largest squared row sum and index 1 is, among the remaining
  dimensions, the one with the largest squared column sum.

  Args:
    dim: number of dimensions; the result is (dim, dim).
    sparsity: expected number of kept entries per row.
    scale: factor applied to the final matrix.
    add_drift: whether to add the 1/sqrt(dim) margin to the eigenvalue shift.

  Returns:
    A (dim, dim) numpy array.
  """
  all_coef = np.random.normal(size=(dim, dim))
  # "sparsity" is more like (1-sparsity). also, we divide by dim to make it commensurable
  # the probability is capped at 1 so that sparsity > dim means "fully dense" instead of raising
  keep_prob = min(1.0, sparsity/dim)
  mask = np.random.binomial(1, keep_prob, size=(dim, dim))
  # we don't want matrix to be symmetrical because that limits the structure
  coef = all_coef * mask
  eigval = np.real(np.linalg.eigvals(coef))
  drift = 1/np.sqrt(dim) if add_drift else 0
  # real part of eigenvalues should be nonnegative, to make sure we don't blow up
  coef -= (np.amin(eigval) - drift) * np.eye(dim)
  ## normalize the scale of the matrix
  #eigval_mag = np.abs(np.linalg.eigvals(coef))
  #coef /= np.amax(eigval_mag)
  if dim < 2:
    # nothing to reorder; picking two distinct dimensions is impossible
    return coef * scale
  # find the dimension that is most affected and put that first.
  # and put second the dimension with the strongest forward effect.
  backward_effects = np.sum(coef**2, axis=1)
  forward_effects = np.sum(coef**2, axis=0)
  indices = np.arange(dim)
  max_backward_idx = np.argmax(backward_effects)
  # don't want to pick the same dimension twice. -inf (rather than 0) guarantees
  # this even when every other forward effect is exactly zero.
  forward_effects[max_backward_idx] = -np.inf
  max_forward_idx = np.argmax(forward_effects)
  remaining_indices = indices[~np.isin(indices, [max_backward_idx, max_forward_idx])]
  indices = np.concatenate([[max_backward_idx], [max_forward_idx], remaining_indices])
  coef = coef[indices, :][:, indices]
  return coef * scale

# I may want to consider scaling the noise (inversely) by the number of dimensions, but that complicates operations.
def simulate_step(state, mixing, noise_scale, dt, nonlinearity):
  """Advance `state` by one unit of time, in place, using round(1/dt) sub-steps.

  Each sub-step draws fresh Gaussian noise of the same shape as `state`,
  computes the pull nonlinearity(state) @ mixing.T, and adds
  (noise - pull) * dt to `state`.

  Args:
    state: 2D array (rows x dimensions); modified in place.
    mixing: matrix whose transpose right-multiplies the transformed state.
    noise_scale: multiplier on the standard normal noise.
    dt: sub-step size; the number of sub-steps is round(1/dt).
    nonlinearity: callable applied to the state before mixing.

  Returns:
    The same `state` object that was passed in.
  """
  n_rows, n_cols = state.shape
  n_substeps = round(1/dt)
  for _ in range(n_substeps):
    # noise is drawn before the pull is evaluated, once per sub-step
    kick = noise_scale * np.random.normal(size=(n_rows, n_cols))
    pull = np.matmul(nonlinearity(state), mixing.T)
    state += dt * (kick - pull)
  return state

def simulate_process(mixing, length, batch_size, noise_scale=0.1, dt=0.1, nonlinearity=(lambda x: x)):
  """Simulate `batch_size` independent runs of the process from a zero start.

  Args:
    mixing: square mixing matrix; its first axis sets the number of dimensions.
    length: number of recorded time points (the first one is all zeros).
    batch_size: number of independent runs.
    noise_scale, dt, nonlinearity: forwarded unchanged to simulate_step.

  Returns:
    Array of shape (batch_size, dim, length).
  """
  n_dims = mixing.shape[0]
  #if noise_scale is None:
  #  noise_scale = 1/np.sqrt(dim) # variance is additive
  sample = np.zeros((batch_size, n_dims, length))
  for step in tqdm(range(1, length)):
    # simulate_step works in place, so hand it a copy of the previous time point
    previous = sample[:, :, step-1].copy()
    sample[:, :, step] = simulate_step(previous, mixing, noise_scale, dt, nonlinearity)
  return sample

class LinearPredictor(nn.Module):
  """A single linear layer applied to a flattened (channels x time) window."""

  def __init__(self, in_channels, time_interval, out_channels):
    super().__init__()
    self.in_channels = in_channels
    self.time_interval = time_interval
    self.out_channels = out_channels
    # the window is flattened, so the layer sees in_channels*time_interval features
    self.model = nn.Linear(in_channels*time_interval, out_channels)

  def forward(self, input):
    """Flatten everything after the first axis and apply the linear layer."""
    n_rows = input.shape[0]
    n_features = self.in_channels*self.time_interval
    return self.model(input.reshape(n_rows, n_features))


class MLP(nn.Module):
  """Feed-forward network over a flattened (channels x time) window.

  The stack is Linear -> SELU -> Dropout, repeated once for the input layer and
  once more for each of the n_layers-1 extra hidden layers, followed by a final
  Linear projection to out_channels.
  """

  def __init__(self, in_channels, time_interval, out_channels, hidden_channels, dropout=1e-4, n_layers=2):
    super().__init__()
    self.in_channels = in_channels
    self.time_interval = time_interval
    self.out_channels = out_channels
    # input width of every hidden block: the flattened window first, then hidden width
    fan_ins = [in_channels*time_interval] + [hidden_channels] * max(n_layers-1, 0)
    stack = []
    for fan_in in fan_ins:
      stack += [nn.Linear(fan_in, hidden_channels), nn.SELU(), nn.Dropout(dropout)]
    stack.append(nn.Linear(hidden_channels, out_channels))
    self.model = nn.Sequential(*stack)

  def forward(self, input):
    """Flatten everything after the first axis and run it through the stack."""
    n_rows = input.shape[0]
    n_features = self.in_channels*self.time_interval
    return self.model(input.reshape(n_rows, n_features))


def _make_windows(series, time_interval):
  """Stack (past window, next increment) pairs from a (batch, channel, time) tensor."""
  steps = range(time_interval, series.shape[2])
  inputs = tr.cat([series[:, :, t-time_interval:t] for t in steps], dim=0)
  targets = tr.cat([series[:, :, t] - series[:, :, t-1] for t in steps], dim=0)
  return inputs, targets


def train_dynamics(sample, n_obs, time_interval=1,
                   model_type='linear', mlp_dim=256,
                   n_epochs=128, batch_size=4096, learning_rate=1e-3,
                   device='cpu'):
  """Fit a one-step increment predictor on the first n_obs channels of `sample`.

  The first 90% of the time axis is used for training and the rest for testing.
  Inputs are windows of `time_interval` past values; targets are the change from
  the last value in the window to the next value. Errors are divided by the
  per-channel standard deviation of the data, and reported losses are RMS errors
  relative to the RMS error of always predicting zero change on the training set.

  Args:
    sample: numpy array indexed as (simulation, channel, time).
    n_obs: number of leading channels the model sees and predicts.
    time_interval: number of past time points in each input window.
    model_type: 'linear' (LinearPredictor) or 'mlp' (MLP).
    mlp_dim: hidden width when model_type is 'mlp'.
    n_epochs: number of passes over the training windows.
    batch_size: number of windows per optimizer step.
    learning_rate: Adam learning rate.
    device: torch device used for the data and for training.

  Returns:
    (model, train_losses, test_losses, test_noise). The model is in eval mode on
    the CPU. The loss lists hold one relative RMS value per epoch. test_noise is
    the per-channel (unscaled) RMS test error after the last epoch, or all-NaN
    if n_epochs is 0.

  Raises:
    ValueError: if model_type is neither 'linear' nor 'mlp'.
  """
  if model_type not in ('linear', 'mlp'):
    raise ValueError(f"unknown model_type {model_type!r}; expected 'linear' or 'mlp'")
  data = tr.from_numpy(sample[:, :n_obs, :]).float().to(device)
  _, dim, length = data.shape
  train_length = int(0.9*length)
  # windows never straddle the train/test boundary: each split is windowed separately
  train_inputs, train_outputs = _make_windows(data[:, :, :train_length], time_interval)
  test_inputs, test_outputs = _make_windows(data[:, :, train_length:], time_interval)
  dim_scales = tr.std(data, dim=(0, 2))
  # baseline prediction is "no change", so its error is just the scaled target itself
  baseline_loss = np.sqrt(tr.mean((train_outputs / dim_scales)**2).item())
  if model_type == 'linear':
    model = LinearPredictor(in_channels=n_obs, time_interval=time_interval, out_channels=n_obs).to(device)
  else:
    model = MLP(in_channels=n_obs, time_interval=time_interval, out_channels=n_obs,
                hidden_channels=mlp_dim).to(device)
  model.train()
  optim = tr.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-8)
  progress = tqdm(range(n_epochs))
  train_losses, test_losses = [], []
  # defined up front so that n_epochs=0 still returns something
  test_noise = np.full(dim, np.nan)
  epoch_length = train_inputs.shape[0]
  for epoch in progress:
    # fresh random ordering of all training windows each epoch
    indices = np.random.permutation(epoch_length)
    batches = tr.from_numpy(indices).to(device).split(batch_size)
    #if epoch == 0: print(f"Number of batches: {len(batches)}.")
    total_loss = []
    for batch in batches:
      batch_input = train_inputs[batch, :, :]
      batch_target = train_outputs[batch, :]
      prediction = model(batch_input)
      loss = tr.mean(((prediction - batch_target) / dim_scales)**2)
      loss.backward()
      optim.step()
      optim.zero_grad()
      total_loss.append(loss.item())
    # evaluate without dropout and without building an autograd graph
    model.eval()
    with tr.no_grad():
      test_prediction = model(test_inputs)
      test_residual = test_prediction - test_outputs
      test_loss = np.sqrt(tr.mean((test_residual / dim_scales)**2).item()) / baseline_loss
      test_noise = tr.mean(test_residual**2, dim=0).sqrt().cpu().numpy()
    model.train()
    train_loss = np.sqrt(np.mean(total_loss)) / baseline_loss
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    progress.set_description(
      f"Train Loss: {train_loss:.5f} | Test Loss: {test_loss:.5f} | Noise: {np.mean(test_noise):.5f}" )
  model.eval().to('cpu')
  return model, train_losses, test_losses, test_noise

# MPPI following TD-MPC (https://arxiv.org/pdf/2203.04955)
# note that hidden confounding doesn't only come from the hidden dimensions,
# but also from the frequency mismatch between the control and the finer-grained process
def control_process(model, mixing, n_obs, length, batch_size,
                    log_gamma=0.0, sensitivity_model='novel',
                    noise_scale=0.1, dt=0.1, nonlinearity=(lambda x: x),
                    time_interval=1, horizon=16, temperature=0.01, n_iterations=5,
                    n_trajectories=512, n_reward_trajectories=64, n_top_trajectories=32,
                    model_noise=None, warmup=16, seed=None):
  """Simulate the true process while a sampling-based planner sets dimension 1.

  After an uncontrolled warm-up, each step plans with the learned `model`:
  candidate action sequences are rolled out, scored by the mean squared value
  of observed dimension 0 (lower is better), optionally re-weighted by a
  sensitivity model, and the best candidates are used to refit a Gaussian plan.
  The first entry of the plan mean is written into dimension 1 of the true
  state before the true process is advanced. The planner reads observed
  channel 0 and writes observed channel 1, so it needs n_obs >= 2.

  Args:
    model: callable mapping a float tensor (N, n_obs, time_interval) to a
      tensor reshapeable to (N, n_obs) of predicted increments.
    mixing: square mixing matrix of the true process.
    n_obs: number of leading state dimensions visible to the planner.
    length: number of controlled steps (warm-up steps are added on top).
    batch_size: number of independent controlled runs.
    log_gamma: log of the sensitivity parameter gamma; for 'empirical' it is
      used directly as the reward quantile.
    sensitivity_model: 'novel', 'cmsm', 'empirical', or 'pass' (no planning,
      the process just runs freely).
    noise_scale, dt, nonlinearity: forwarded to simulate_step.
    time_interval: number of past time points fed to the model.
    horizon: number of model steps per rollout.
    temperature: softmax temperature on the normalized scores.
    n_iterations: number of plan refinement rounds per step.
    n_trajectories: number of candidate action sequences.
    n_reward_trajectories: number of noisy rollouts per candidate.
    n_top_trajectories: number of best candidates kept for refitting.
    model_noise: per-channel standard deviation of the rollout innovation;
      defaults to noise_scale for every observed channel.
    warmup: number of uncontrolled initial steps (time_interval is added).
    seed: if given, seeds numpy's global random state.

  Returns:
    (states, actions): states has shape (batch_size, dim, total_length) and
    actions has shape (batch_size, 2, total_length - time_interval), where row
    0 holds the plan mean and row 1 the plan standard deviation at each step.

  Raises:
    ValueError: if sensitivity_model is not one of the supported names.
  """
  if seed is not None:
    np.random.seed(seed)
  if model_noise is None:
    model_noise = np.ones(n_obs) * noise_scale
  if sensitivity_model not in ('novel', 'cmsm', 'empirical', 'pass'):
    raise ValueError(
      f"unknown sensitivity_model {sensitivity_model!r}; "
      "expected 'novel', 'cmsm', 'empirical' or 'pass'")
  warmup += time_interval
  length += warmup # makes argument more intuitive
  gamma = np.exp(log_gamma)
  dim = mixing.shape[0]
  states = np.zeros((batch_size, dim, length))
  for t in range(1, warmup):
    state = states[:, :, t-1].copy()
    simulate_step(state, mixing, noise_scale, dt, nonlinearity)
    states[:, :, t] = state
  trajectories = np.zeros((batch_size, n_trajectories, n_reward_trajectories, n_obs, time_interval+horizon))
  actions = np.zeros((batch_size, 2, length-time_interval))
  progress = tqdm(range(warmup, length))
  for t in progress:
    if sensitivity_model != 'pass':
      ## PLANNING
      weight_means = []
      gamma_vars = []
      trajectories[..., :time_interval] = states[:, None, None, :n_obs, t-time_interval:t]
      plan_mean, plan_var = None, None
      for i in range(n_iterations):
        for k in range(horizon):
          if plan_mean is not None:
            # for the second n_trajectories dimension, we repeat the actions for conditional sampling
            # NOTE(review): plan entry k is written at time index time_interval+k-1, but the plan is
            # refit below from indices time_interval onward — looks like a one-step shift; confirm intended.
            noise = np.sqrt(plan_var[:, None, k]) * np.random.normal(size=(batch_size, n_trajectories))
            trajectories[:, :, :, 1, time_interval+k-1] = (plan_mean[:, None, k] + noise)[:, :, None]
          model_input = trajectories[..., k:time_interval+k].reshape(batch_size*n_trajectories*n_reward_trajectories, n_obs, time_interval)
          # inference only: skip building an autograd graph for every rollout
          with tr.no_grad():
            output = model(tr.from_numpy(model_input).float())
          prediction = output.detach().numpy().reshape(batch_size, n_trajectories, n_reward_trajectories, n_obs)
          innovation = model_noise * np.random.normal(size=(batch_size, n_trajectories, n_reward_trajectories, n_obs))
          trajectories[..., time_interval+k] = trajectories[..., time_interval+k-1] + prediction + innovation
        future_outcomes = trajectories[:, :, :, 0, time_interval:] # try to minimize the first dimension
        # actions are identical along the second n_trajectories dimension
        future_actions = trajectories[:, :, 0, 1, time_interval:] # control the second dimension
        if plan_mean is None:
          first_iter_actions = future_actions.copy()
        # distance of every current candidate to every first-iteration candidate, in units of action spread
        action_differences = future_actions[:, :, None, :] - first_iter_actions[:, None, :, :]
        action_scale = np.std(first_iter_actions, axis=(1,2))[:, None, None]
        action_distances = np.linalg.norm(action_differences, ord=2, axis=-1) / action_scale
        if sensitivity_model == 'novel':
          lower_gamma_expectation = np.maximum(1e-3,
            np.mean(np.power(gamma, -action_distances), axis=-1) )
          upper_gamma_expectation = np.minimum(1e+3,
            np.mean(np.power(gamma, +action_distances), axis=-1) )
        elif sensitivity_model == 'cmsm':
          lower_gamma_expectation = (1/gamma) * np.ones((batch_size, n_trajectories))
          upper_gamma_expectation = gamma * np.ones((batch_size, n_trajectories))
        rewards = np.mean(future_outcomes**2, axis=-1) # actually negative rewards. (batch_size x n_trajectories x n_reward_trajectories)
        if sensitivity_model == 'empirical':
          # in this alternative formulation (baseline), log_gamma represents the quantile we should take
          reward_bounds = np.quantile(rewards, log_gamma, axis=-1)
          gamma_vars.append(0)
        else:
          if log_gamma > 0:
            quantile_threshold = (upper_gamma_expectation - 1) / (upper_gamma_expectation - lower_gamma_expectation)
            #reward_quantile = np.array([ [
            #  np.quantile(rewards[b, j, :], quantile_threshold[b, j])
            #  for j in range(n_trajectories) ] for b in range(batch_size) ])
            # faster version incoming, since np.quantile can't batch along second argument
            sorted_rewards = np.sort(rewards, axis=-1)
            quantile_index = np.round(quantile_threshold * (n_reward_trajectories-1)).astype(int)
            reward_quantile = np.take_along_axis(sorted_rewards, quantile_index[..., None], axis=-1)[..., 0]
          else:
            reward_quantile = np.zeros((batch_size, n_trajectories))
          assert reward_quantile.shape == (batch_size, n_trajectories)
          # rollouts at or above the quantile get the upper weight, the rest the lower weight
          gamma_weights = np.where( rewards >= reward_quantile[:, :, None],
            upper_gamma_expectation[:, :, None], lower_gamma_expectation[:, :, None] )
          # weights should be close to 1 but they aren't always, due to numerical instability and discretization
          gamma_weights /= np.mean(gamma_weights, axis=-1)[..., None]
          gamma_vars.append( np.var(gamma_weights, axis=-1) )
          reward_bounds = np.mean(rewards * gamma_weights, axis=-1)
        # keep the candidates with the smallest (pessimistic) cost
        top_trajectory_indices = np.argsort(reward_bounds, axis=1)[:, :n_top_trajectories]
        top_rewards = np.take_along_axis(reward_bounds, top_trajectory_indices, axis=1)
        top_trajectories = np.take_along_axis(trajectories, top_trajectory_indices[..., None, None, None], axis=1)
        top_actions = top_trajectories[:, :, 0, 1, time_interval:]
        norm_rewards = top_rewards / np.mean(top_rewards, axis=-1)[..., None]
        weights = np.exp(-norm_rewards / temperature) # we normalize for temperature, but NOT gamma
        weights /= np.mean(weights, axis=1)[:, None]
        # adjust temperature to target a certain weight variance? hold off on this additional complexity
        weight_means.append(np.var(weights)) # despite the name, this collects weight variances
        plan_mean = np.mean(weights[..., None] * top_actions, axis=1)
        plan_var = np.mean(weights[..., None] * (top_actions - plan_mean[:, None, :])**2, axis=1)
      ## ACTING
      actions[:, 0, t-time_interval] = plan_mean[:, 0]
      actions[:, 1, t-time_interval] = np.sqrt(plan_var[:, 0])
      # I use stochastic action for planning, but ultimately I just pick the mean
      action = plan_mean[:, 0] #+ np.sqrt(plan_var[:, 0]) * np.random.normal(size=(batch_size))
      states[:, 1, t-1] = action # for now, we don't hold the action for the whole tick interval
      progress.set_description(
        f"Weight uncertainty {np.sqrt(np.mean(weight_means)):.5f} | "
        + f"Action spread {np.std(plan_mean):.3f} | "
        + f"Action uncertainty {np.sqrt(np.mean(plan_var)):.3f} | "
        + f"Partial identification uncertainty {np.sqrt(np.mean(gamma_vars)):.3f}" )
    state = states[:, :, t-1].copy()
    simulate_step(state, mixing, noise_scale, dt, nonlinearity)
    states[:, :, t] = state
  return states, actions
