# Standard imports
import argparse
import gc
import os
from pathlib import Path
from tqdm import tqdm
from sys import exit
import wandb
from matplotlib import pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn.utils
from torch.nn.utils import weight_norm

# Our imports
import utils

# =======
# Data
# =======

def get_batch_simple(num_samples=128, sample_len=20, one_hot='False', memory_len=10):
    """Generate a batch for the simple copy-memory task.

    Each input row holds ``memory_len`` random tokens (values 1..8) in its
    first ``memory_len`` positions and zeros elsewhere. The matching target
    row is zero except for its last ``memory_len`` positions, which hold the
    same tokens in the same order. Both rows have length
    ``sample_len + memory_len - 1``.

    Args:
        num_samples: Batch size.
        sample_len: Together with ``memory_len`` sets the sequence length
            (``sample_len + memory_len - 1``).
        one_hot: Whether to one-hot encode the inputs over 10 classes.
            Accepts a bool, or (as before) a string parsed with
            ``utils.str_to_bool``.
        memory_len: Number of tokens to memorize and reproduce.

    Returns:
        ``(x, y)`` in time-major layout (sequence first, batch second).
        ``x`` is float32 with shape ``(seq_len, num_samples, 10)`` when
        one-hot encoded, else ``(seq_len, num_samples, 1)``. ``y`` is int32
        with shape ``(seq_len, num_samples, 1)``. Note that ``copy_selective``
        returns int64 targets instead.
    """
    seq_len = sample_len + memory_len - 1

    X = np.zeros((num_samples, seq_len))
    data = np.random.randint(low=1, high=9, size=(num_samples, memory_len))
    X[:, :memory_len] = data
    # A start-of-output indicator token (9) used to be written here; it is
    # intentionally left out:
    # X[:, -(memory_len + 1)] = 9

    Y = np.zeros((num_samples, seq_len))
    # Use an explicit start index: `-memory_len:` would select the whole row
    # when memory_len == 0.
    Y[:, seq_len - memory_len:] = data

    # Accept a real bool directly; strings keep the original parsing path.
    if isinstance(one_hot, bool):
        use_one_hot = one_hot
    else:
        use_one_hot = utils.str_to_bool(one_hot)

    # (num_samples, seq_len) -> (seq_len, num_samples, 1)
    y_out = torch.tensor(Y).int().permute(1, 0).unsqueeze(-1)
    if use_one_hot:
        # Token values are used directly as class indices (0..9).
        x_out = F.one_hot(torch.tensor(X, dtype=torch.int64).permute(1, 0), 10).float()
    else:
        x_out = torch.tensor(X).float().permute(1, 0).unsqueeze(-1)
    return x_out, y_out

def copy_selective(sample_len=20, memory_len=10, garbage_len=5):
    """Generate one sample for the selective-copy task.

    Places ``memory_len`` random tokens (values 1..8) at distinct positions
    drawn from ``[0, sample_len - garbage_len - 1)``. They go into an
    otherwise-zero input of length ``sample_len + memory_len + garbage_len``.
    The target is zero except for its last ``memory_len`` positions, which
    hold the tokens in the same order they appear in the input.

    Args:
        sample_len: Controls the region tokens may be placed in and the total
            length.
        memory_len: Number of tokens to memorize and reproduce.
        garbage_len: Shrinks the placement region and extends the total
            length.

    Returns:
        ``(x, y)``: ``x`` is float32 with shape ``(seq_len, 1)`` and ``y`` is
        int64 with shape ``(seq_len, 1)``.

    Raises:
        ValueError: if ``sample_len - garbage_len - 1 < memory_len``. In that
            case there are not enough distinct positions for the tokens.
    """
    n_candidates = sample_len - garbage_len - 1
    if memory_len > n_candidates:
        raise ValueError(
            f"copy_selective needs at least memory_len={memory_len} candidate "
            f"positions, but sample_len - garbage_len - 1 = {n_candidates}"
        )

    seq_len = sample_len + memory_len + garbage_len
    data = np.random.randint(low=1, high=9, size=memory_len)
    # Sorting the distinct positions keeps the tokens in the input in the
    # same order as they must appear in the target.
    positions = np.sort(np.random.choice(n_candidates, memory_len, replace=False))

    X = np.zeros(seq_len)
    X[positions] = data
    # Optional indicator telling the model to start emitting tokens (disabled):
    # X[-(memory_len + 1)] = 9

    Y = np.zeros(seq_len)
    # Use an explicit start index: `-memory_len:` would select everything
    # when memory_len == 0.
    Y[seq_len - memory_len:] = data

    return torch.tensor(X).float().unsqueeze(-1), torch.tensor(Y).long().unsqueeze(-1)

def batch_copy_selective(num_samples=128, sample_len=20, memory_len=10, func=copy_selective, **func_kwargs):
    """Build a batch by calling a per-sample generator ``num_samples`` times.

    Args:
        num_samples: Batch size. It should be positive; otherwise
            ``torch.stack`` receives an empty list.
        sample_len: Passed to ``func`` as its first positional argument.
        memory_len: Passed to ``func`` as its second positional argument.
        func: Callable invoked as ``func(sample_len, memory_len, **func_kwargs)``
            that returns an ``(x, y)`` pair of tensors. Defaults to
            ``copy_selective``.
        **func_kwargs: Extra keyword arguments forwarded to ``func``, e.g.
            ``garbage_len=...`` for ``copy_selective``. When none are given,
            the call matches the original two-argument form.

    Returns:
        ``(batch_x, batch_y)``: the per-sample tensors stacked along a new
        dimension 1, so the batch is the second axis.
    """
    samples = [func(sample_len, memory_len, **func_kwargs) for _ in range(num_samples)]
    batch_x = torch.stack([x for x, _ in samples], dim=1)
    batch_y = torch.stack([y for _, y in samples], dim=1)
    return batch_x, batch_y
