'''
Code for Invariant Risk Minimization (IRM) on Colored MNIST.
'''


import argparse
import numpy as np
import torch
from torchvision import datasets
from torch import nn, optim, autograd
import torch.nn.functional as F


# Command-line options for the optimiser, the penalty schedule and the
# number of restarts / training examples.
# NOTE(review): --grayscale_model is parsed but not read by the code visible
# in this file.
parser = argparse.ArgumentParser(description='Colored MNIST')
option_specs = (
  ('--l2_regularizer_weight', dict(type=float, default=0.001)),
  ('--lr', dict(type=float, default=0.001)),
  ('--n_restarts', dict(type=int, default=10)),
  ('--penalty_anneal_iters', dict(type=int, default=100)),
  ('--penalty_weight', dict(type=float, default=10000.0)),
  ('--steps', dict(type=int, default=1001)),
  ('--grayscale_model', dict(action='store_true')),
  ('--train_len', dict(type=int, default=2000)),
)
for option_name, option_kwargs in option_specs:
  parser.add_argument(option_name, **option_kwargs)
flags = parser.parse_args()

# Echo the parsed configuration, one flag per line, sorted by flag name.
print('Flags:')
for flag_name in sorted(vars(flags)):
  print("\t{}: {}".format(flag_name, getattr(flags, flag_name)))

# Final accuracies, one entry appended per completed restart.
final_train_accs, final_test_accs = [], []
for restart in range(flags.n_restarts):
  print("Restart", restart)

  # Load MNIST, make train/val splits, and shuffle train set examples.
  # Examples [0, train_dataset_end) form the training pool; everything from
  # train_dataset_end onwards is held out as the validation/test split.
  train_dataset_end = 40000
  mnist = datasets.MNIST('~/datasets/mnist', train=True, download=True)
  mnist_val = (mnist.data[train_dataset_end:], mnist.targets[train_dataset_end:])
  # Subsample flags.train_len evenly spaced training examples.
  # endpoint=False keeps every index strictly below train_dataset_end;
  # with the default endpoint=True the last index would be train_dataset_end
  # itself, i.e. the first held-out example would leak into the training set.
  train_inds = np.linspace(0, train_dataset_end, flags.train_len,
                           endpoint=False, dtype=int)
  mnist.data = mnist.data[train_inds]
  mnist.targets = mnist.targets[train_inds]
  mnist_train = (mnist.data[:train_dataset_end], mnist.targets[:train_dataset_end])

  # Shuffle images and labels in place with the *same* permutation: the NumPy
  # RNG state is saved before the first shuffle and restored before the second.
  # .numpy() shares memory with the tensor, so the tensors are shuffled too.
  rng_state = np.random.get_state()
  np.random.shuffle(mnist_train[0].numpy())
  np.random.set_state(rng_state)
  np.random.shuffle(mnist_train[1].numpy())

  # Build environments
  def make_environment(images, labels, e, device=None):
    '''
    Build one Colored-MNIST environment from raw digit images and labels.

    Args:
      images: tensor reshapeable to (N, 28, 28); it is subsampled to 14x14.
      labels: digit labels of length N; digits < 5 map to binary label 1.
      e: probability with which the colour disagrees with the (noisy) label.
      device: where the returned tensors are placed. Defaults to CUDA when it
        is available and to the CPU otherwise.

    Returns:
      dict with
        'images': float tensor (N, 3, 14, 14), pixel values divided by 255;
                  one of the first two channels is zeroed per example and the
                  third channel is all zeros.
        'labels': float tensor of binary labels with a trailing unit axis
                  ((N, 1) for 1-D input labels).
    '''
    if device is None:
      device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def torch_bernoulli(p, size):
      return (torch.rand(size) < p).float()

    def torch_xor(a, b):
      return (a - b).abs()  # Assumes both inputs are either 0 or 1

    # 2x subsample for computational convenience
    images = images.reshape((-1, 28, 28))[:, ::2, ::2]
    # Assign a binary label based on the digit; flip label with probability 0.25
    labels = (labels < 5).float()
    labels = torch_xor(labels, torch_bernoulli(0.25, len(labels)))
    # Assign a color based on the label; flip the color with probability e
    colors = torch_xor(labels, torch_bernoulli(e, len(labels)))
    # Apply the color to the image by zeroing out the other color channel
    images = torch.stack([images, images], dim=1)
    images[torch.arange(len(images)), (1 - colors).long(), :, :] *= 0
    # Add a third channel of zeros for blue
    images = torch.cat([images, torch.zeros_like(images[:, :1])], dim=1)
    return {
      'images': (images.float() / 255.).to(device),
      'labels': labels[:, None].to(device)
    }

  # Environments, in order: even-indexed training examples (e=0.1),
  # odd-indexed training examples (e=0.4), and the held-out split (e=0.9).
  # The last entry is the one the training loop reports as "test".
  train_images, train_labels = mnist_train
  val_images, val_labels = mnist_val
  envs = [
    make_environment(env_images, env_labels, color_flip_prob)
    for env_images, env_labels, color_flip_prob in (
      (train_images[::2], train_labels[::2], 0.1),
      (train_images[1::2], train_labels[1::2], 0.4),
      (val_images, val_labels, 0.9),
    )
  ]
  
  # Define convnet
  # Note: very slightly different from DRM since IRM subsamples image inputs
  # Also, no feature normalization here since IRM does not use it
  class ConvNet(nn.Module):
    '''
    Two-convolution network producing a single logit per example.

    conv(3->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU -> 2x2 max-pool
    -> flatten -> linear(1600 -> feature_size) -> ReLU -> linear(-> 1).

    The flattened size of 1600 is hard-coded; it matches 64 channels of 5x5,
    which is what a 14x14 input yields after the two 3x3 convolutions and
    the pooling step.
    '''

    def __init__(self, feature_size):
      super().__init__()
      self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1)
      self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
      self.feature_size = feature_size
      self.fc1 = nn.Linear(1600, feature_size)
      self.final_layer = nn.Linear(feature_size, 1)

    def forward(self, x):
      '''Return raw (pre-sigmoid) logits, one per row of the batch `x`.'''
      hidden = F.relu(self.conv1(x))
      hidden = F.relu(self.conv2(hidden))
      pooled = F.max_pool2d(hidden, 2)
      features = F.relu(self.fc1(torch.flatten(pooled, 1)))
      return self.final_layer(features)

  # Fresh model for this restart; placed on the GPU when CUDA is available,
  # otherwise left on the CPU instead of failing on the .cuda() call.
  model = ConvNet(feature_size=128).to(
    'cuda' if torch.cuda.is_available() else 'cpu')

  # Define loss function helpers
  def mean_nll(logits, y):
    '''Mean binary cross-entropy between raw `logits` and 0/1 targets `y`.'''
    return F.binary_cross_entropy_with_logits(logits, y, reduction='mean')

  def mean_accuracy(logits, y):
    '''
    Fraction of examples whose thresholded logit matches the target.

    A logit strictly greater than 0 is predicted as class 1, anything else
    as class 0; a prediction counts as correct when it is within 1e-2 of `y`.
    '''
    predicted = torch.gt(logits, 0.).float()
    is_correct = torch.abs(predicted - y) < 1e-2
    return is_correct.float().mean()

  def penalty(logits, y):
    '''
    IRM gradient penalty for one environment.

    The logits are multiplied by a dummy scalar `scale` fixed at 1.0; the
    penalty is the squared gradient of the mean NLL with respect to that
    scalar. `create_graph=True` keeps the gradient differentiable so the
    penalty itself can be back-propagated into the model parameters.

    Args:
      logits: raw model outputs for the environment.
      y: matching 0/1 targets.

    Returns:
      0-d tensor holding the squared gradient norm.
    '''
    # Create the dummy scale on the same device as the logits rather than
    # unconditionally on CUDA, so the helper works wherever the model runs.
    scale = torch.tensor(1., device=logits.device, requires_grad=True)
    loss = mean_nll(logits * scale, y)
    grad = autograd.grad(loss, [scale], create_graph=True)[0]
    return torch.sum(grad**2)

  # Train loop

  def pretty_print(*values):
    '''
    Print `values` as one row of left-aligned, fixed-width columns.

    Strings are printed as given; every other value is rendered with
    np.array2string using 5 fixed decimal places. Each cell is padded to
    13 characters and cells are separated by three spaces.
    '''
    col_width = 13
    cells = []
    for value in values:
      if isinstance(value, str):
        text = value
      else:
        text = np.array2string(value, precision=5, floatmode='fixed')
      cells.append(text.ljust(col_width))
    print("   ".join(cells))

  optimizer = optim.Adam(model.parameters(), lr=flags.lr)

  pretty_print('step', 'train nll', 'train acc', 'train penalty', 'test acc')

  # The last environment is held out: it only ever contributes an accuracy.
  train_envs, test_env = envs[:-1], envs[-1]

  for step in range(flags.steps):
    # Risk, accuracy and IRM penalty on each training environment.
    for env in train_envs:
      logits = model(env['images'])
      # Add a dimension for the logits
      logits = logits.unsqueeze(1) if logits.dim() == 1 else logits
      env['nll'] = mean_nll(logits, env['labels'])
      env['acc'] = mean_accuracy(logits, env['labels'])
      env['penalty'] = penalty(logits, env['labels'])

    # Only the accuracy of the held-out environment is reported, so evaluate
    # it without building an autograd graph and without the double-backward
    # penalty. It is measured before this step's parameter update.
    with torch.no_grad():
      test_logits = model(test_env['images'])
      test_logits = (test_logits.unsqueeze(1)
                     if test_logits.dim() == 1 else test_logits)
      test_env['acc'] = mean_accuracy(test_logits, test_env['labels'])

    train_nll = torch.stack([env['nll'] for env in train_envs]).mean()
    train_acc = torch.stack([env['acc'] for env in train_envs]).mean()
    train_penalty = torch.stack([env['penalty'] for env in train_envs]).mean()

    # Squared L2 norm of all parameters; summing the per-parameter tensors
    # keeps the result on whatever device the model lives on.
    weight_norm = sum(w.norm().pow(2) for w in model.parameters())

    loss = train_nll.clone()
    loss += flags.l2_regularizer_weight * weight_norm
    # The penalty weight is 1.0 during the first penalty_anneal_iters steps
    # and flags.penalty_weight afterwards.
    penalty_weight = (flags.penalty_weight
        if step >= flags.penalty_anneal_iters else 1.0)
    loss += penalty_weight * train_penalty
    if penalty_weight > 1.0:
      # Rescale the entire loss to keep gradients in a reasonable range
      loss /= penalty_weight

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    test_acc = test_env['acc']
    if step % 100 == 0:
      pretty_print(
        np.int32(step),
        train_nll.detach().cpu().numpy(),
        train_acc.detach().cpu().numpy(),
        train_penalty.detach().cpu().numpy(),
        test_acc.detach().cpu().numpy()
      )

  # Record the accuracies from the last training step of this restart.
  final_train_accs.append(train_acc.detach().cpu().numpy())
  final_test_accs.append(test_acc.detach().cpu().numpy())

  # Running summary over all restarts completed so far.
  train_acc_history = np.array(final_train_accs)
  test_acc_history = np.array(final_test_accs)
  print('Final train acc (mean/std across restarts so far):')
  print(np.mean(train_acc_history), np.std(train_acc_history))
  print('Final test acc (mean/std across restarts so far):')
  print(np.mean(test_acc_history), np.std(test_acc_history))

  # Save results (as percentages); the file is rewritten after every restart.
  np.savez('results_irm.npz',
           train_success_all_irm=100 * train_acc_history,
           test_success_all_irm=100 * test_acc_history)
