import torch
import torch.nn as nn

import numpy as np

import argparse
import os

import pickle

from timit_loader import TIMIT

import torchdiffeq
#from torchdiffeq import odeint_adjoint as odeint
from torchdiffeq import odeint as odeint

parser = argparse.ArgumentParser(description='Sequence Modeling - TIMIT')

parser.add_argument('--cpu',
                    action='store_true',
                    default=False,
                    help='use CPU')
parser.add_argument('--epochs',
                    type=int,
                    default=1200,
                    help='number of epochs to train (default: 1200)')
# NOTE(review): ``--optimizer`` is parsed but not read by main(); the
# optimizer actually used is selected by ``--opt`` below.
parser.add_argument('--optimizer',
                    type=str,
                    default='Adam',
                    choices=['Adam', 'RMSprop', 'RMSAdam'],
                    help='optimizer: choices Adam, RMSprop and RMSAdam')
parser.add_argument('--wd',
                    type=float,
                    default=0.0,
                    help='weight decay value (default: 0.0)')
parser.add_argument('--momentum',
                    type=float,
                    default=0.0,
                    help='momentum value (default: 0.0)')
parser.add_argument('--lr',
                    type=float,
                    default=1e-3,
                    help='learning rate (default: 0.001)')
# NOTE(review): no learning-rate scheduler is built in this script, so
# ``--lr_decay`` currently has no effect.
parser.add_argument('--lr_decay',
                    type=float,
                    default=0.5,
                    help='learning rate decay value (default: 0.5)')
parser.add_argument('--clip',
                    type=float,
                    default=15,
                    help='gradient clip, -1 means no clip (default: 15)')
parser.add_argument('--hidden_size', type=int, default=224)
parser.add_argument('--hidden_layers', type=int, default=1)
parser.add_argument('--batch_size',
                    type=int,
                    default=128,
                    help='batch size (default: 128)')
parser.add_argument('--logsdir',
                    type=str,
                    default='logs/timit_task/',
                    help='folder for storing logs')
parser.add_argument('--seed', type=int, default=5544)
# Restrict to the optimizers main() can actually build, so an unsupported
# value is rejected up front instead of failing later with an unbound name.
parser.add_argument('--opt',
                    type=str,
                    default='Adam',
                    choices=['Adam', 'RMSprop'],
                    help='optimizer used for training: Adam or RMSprop')

# Lipschitz RNN hyperparameters (forwarded to Model / LipschitzRNN_ODE).
parser.add_argument('--eps', type=float, default=1)
parser.add_argument('--alpha', type=float, default=1.0)
parser.add_argument('--beta', type=float, default=0.7)
parser.add_argument('--eta', type=float, default=0.1)
parser.add_argument('--gamma', type=float, default=0.0)
parser.add_argument('--gamma_W', type=float, default=0.0001)
parser.add_argument('--init', type=float, default=1)
parser.add_argument('--method', type=str, default='midpoint')

args = parser.parse_args()
print(args)

# Get GPU/CPU as device. CUDA availability is checked before the capability
# query because ``get_device_capability`` raises when no CUDA device exists.
# The capability is a (major, minor) tuple; comparing it against (3, 5) lets
# compute capability 3.5 itself qualify. (Comparing the major number alone
# with 3.5 effectively required major >= 4.)
if (not args.cpu and torch.cuda.is_available()
        and torch.cuda.get_device_capability(torch.device('cuda')) >= (3, 5)):
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
args.device = device


# Fixed network parameters.
m = 129  # input/output feature size: Model maps m -> k (U) and k -> m (V)
k = args.hidden_size  # hidden-state size

# Set the seed of PRNG manually for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)


def masked_loss(lossfunc, logits, y, lens):
    """Compute ``lossfunc`` over the first ``lens[i]`` time steps of each sequence.

    Args:
        lossfunc: callable taking ``(predictions, targets)``; both are passed
            as 1-D tensors holding only the unmasked entries.
        logits: tensor indexed as ``logits[batch, time, feature]``.
        y: targets. They are masked with a mask shaped like ``logits``, so
            presumably ``y`` has the same shape as ``logits``.
        lens: valid length of each sequence (tensor or sequence of ints),
            one entry per batch row. Lengths beyond the time dimension select
            every step.

    Returns:
        Whatever ``lossfunc`` returns for the selected entries.
    """
    n_steps = logits.shape[1]
    lens = torch.as_tensor(lens, device=logits.device)
    # One vectorized comparison instead of a Python loop of per-row slice
    # assignments. Slicing with a device-tensor bound forces a host/device
    # sync per sequence.
    steps = torch.arange(n_steps, device=logits.device)
    mask = (steps.unsqueeze(0) < lens.unsqueeze(1)).unsqueeze(-1)
    mask = mask.expand_as(logits)
    logits_masked = torch.masked_select(logits, mask)
    y_masked = torch.masked_select(y, mask)
    return lossfunc(logits_masked, y_masked)


def gaussian_init_(n_units, std=1):
    """Draw an ``n_units x n_units`` matrix of i.i.d. normal samples.

    The entries have mean 0 and standard deviation ``std / n_units``. The
    divisor is ``n_units`` itself, not its square root.

    Returns:
        A float32 tensor of shape ``(n_units, n_units)``.
    """
    scale = std / n_units
    normal = torch.distributions.Normal(torch.Tensor([0]), torch.Tensor([scale]))
    # The distribution has batch shape (1,), so samples come out as
    # (n_units, n_units, 1); drop that trailing axis.
    samples = normal.sample((n_units, n_units))
    return samples.squeeze(-1)


def get_device():
    """Return the CUDA device if at least one GPU is visible, else the CPU.

    Prints a one-line message saying which device was chosen.
    """
    has_gpu = torch.cuda.device_count() > 0
    if not has_gpu:
        print("Using the CPU")
        return torch.device('cpu')
    print("Connected to a GPU")
    return torch.device('cuda')


def which_device(model):
    """Return the device of ``model``'s first parameter.

    Raises:
        StopIteration: if the model has no parameters.
    """
    params = model.parameters()
    first_param = next(params)
    return first_param.device


class LipschitzRNN_ODE(nn.Module):
    """The derivative of the continuous-time RNN, to plug into an integrator.

    ``forward`` returns

        dh/dt = alpha * h @ A + tanh(h @ W + z)

    with

        A = beta * (B - B^T) + (1 - beta) * (B + B^T) - gamma * I
        W = beta * (C - C^T) + (1 - beta) * (C + C^T) - gamma_W * I

    where B and C are trainable square matrices.

    The caller sets ``self.z`` (the term added inside the tanh) and
    ``self.i`` (the index of the current input step) before each
    integration. A and W are rebuilt only when ``self.i == 0`` and are reused
    while ``self.i`` is non-zero.
    """

    def __init__(self, n_units, beta, gamma, gamma_W, alpha, init_std):
        super(LipschitzRNN_ODE, self).__init__()
        self.device = get_device()

        self.gamma = gamma
        self.gamma_W = gamma_W

        self.beta = beta
        self.alpha = alpha

        self.tanh = nn.Tanh()

        # Placeholder input term; the caller overwrites it before integrating.
        self.z = torch.zeros(n_units)
        self.C = nn.Parameter(gaussian_init_(n_units, std=init_std))
        self.B = nn.Parameter(gaussian_init_(n_units, std=init_std))
        # Plain tensor attribute (neither parameter nor buffer), so
        # ``Module.to`` does not move it; forward() re-homes it next to B.
        self.I = torch.eye(n_units).to(self.device)
        self.i = 0

    def _combine(self, M, shift):
        """Return beta * (M - M^T) + (1 - beta) * (M + M^T) - shift * I."""
        M_t = M.transpose(1, 0)
        return self.beta * (M - M_t) + (1 - self.beta) * (M + M_t) - shift * self.I

    def forward(self, t, h):
        """dh/dt as a function of time and h(t); ``t`` itself is not used."""
        if self.i == 0:
            # get_device() can pick CUDA while the module's parameters live on
            # the CPU (e.g. ``--cpu`` on a GPU machine). Keep I on B's device
            # to avoid a device-mismatch error; this moves it at most once.
            if self.I.device != self.B.device:
                self.I = self.I.to(self.B.device)
            self.A = self._combine(self.B, self.gamma)
            self.W = self._combine(self.C, self.gamma_W)

        return self.alpha * torch.matmul(
            h, self.A) + self.tanh(torch.matmul(h, self.W) + self.z)


class Model(nn.Module):
    """Lipschitz RNN: one ODE integration over ``[0, eps]`` per input step.

    Each input step ``x_t`` is projected by ``U`` (m -> k) into the ODE's
    input term. The hidden state is integrated with ``odeint`` using
    ``LipschitzRNN_ODE`` and decoded by ``V`` (k -> m).

    Args:
        m: input/output feature size.
        k: hidden-state size.
        init_std: ``init_std`` forwarded to ``LipschitzRNN_ODE``.
        alpha, beta, gamma, gamma_W: forwarded to ``LipschitzRNN_ODE``.
        eta: stored on the instance; not used by this class.
        eps: length of the integration interval for each input step.
        method: solver name passed to ``odeint``.
    """

    def __init__(self,
                 m,
                 k,
                 init_std=6,
                 alpha=0.1,
                 beta=0.85,
                 eta=0.1,
                 gamma=0.1,
                 gamma_W=0.0001,
                 eps=0.001,
                 method='midpoint'):
        super(Model, self).__init__()

        self.m = m
        self.k = k

        self.tanh = nn.Tanh()

        self.gamma = gamma
        self.gamma_W = gamma_W

        self.beta = beta
        self.alpha = alpha
        self.eps = eps
        self.eta = eta

        self.method = method

        self.func = LipschitzRNN_ODE(k, beta, gamma, gamma_W, alpha, init_std)

        self.U = nn.Linear(m, k)
        self.V = nn.Linear(k, m)

        self.loss_func = nn.MSELoss()

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the readout ``V``; ``U`` keeps its default init."""
        nn.init.kaiming_normal_(self.V.weight.data, nonlinearity="relu")
        nn.init.constant_(self.V.bias.data, 0)

    def forward(self, inputs):
        """Run the recurrence over ``inputs`` and return per-step readouts.

        Args:
            inputs: tensor unbound along dim 1 (time); its dim 0 is the batch.
                Each step's features presumably have size ``m`` (fed to ``U``).

        Returns:
            ``V(h_t)`` for every step, stacked along dim 1.
        """
        # Allocate the state on the inputs' device instead of hard-coding
        # CUDA, so the model also runs on the CPU.
        h = torch.zeros(inputs.shape[0], self.k, device=inputs.device)
        # The integration interval is the same for every step; build it once.
        t_span = torch.tensor([0, self.eps],
                              dtype=torch.float32,
                              device=inputs.device)
        outputs = []
        for step, x_t in enumerate(torch.unbind(inputs, dim=1)):
            self.func.z = self.U(x_t)
            # step == 0 makes the ODE function rebuild its A/W matrices;
            # later steps reuse them.
            self.func.i = step
            # odeint returns the state at every time in t_span; keep the last.
            h = odeint(self.func, h, t_span, method=self.method)[-1]
            outputs.append(self.V(h))

        return torch.stack(outputs, dim=1)

    def loss(self, logits, y, len_batch):
        """MSE over the first ``len_batch[i]`` steps of each sequence."""
        return masked_loss(self.loss_func, logits, y, len_batch)


def load_timit_data():
    """Build the TIMIT train, validation and test DataLoaders.

    All three read from ``data/`` and shuffle. Training uses
    ``args.batch_size``; test and validation use fixed batch sizes of 400 and
    192. The comment in main() expects each of those to be a single batch.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    loader_kwargs = {'num_workers': 1, 'pin_memory': True}

    def make_loader(mode, batch_size):
        # One shuffled DataLoader over the given TIMIT split.
        return torch.utils.data.DataLoader(TIMIT('data/', mode=mode),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           **loader_kwargs)

    train_loader = make_loader("train", args.batch_size)
    test_loader = make_loader("test", 400)
    val_loader = make_loader("val", 192)

    return train_loader, val_loader, test_loader


def _build_optimizer(model):
    """Create the optimizer named by ``args.opt`` ('Adam' or 'RMSprop')."""
    # ``--wd`` and ``--momentum`` are now honored; their 0.0 defaults
    # reproduce the previous hard-coded behavior.
    if args.opt == 'RMSprop':
        return torch.optim.RMSprop(model.parameters(),
                                   lr=args.lr,
                                   alpha=0.99,
                                   momentum=args.momentum,
                                   weight_decay=args.wd)
    if args.opt == 'Adam':
        return torch.optim.Adam(model.parameters(),
                                lr=args.lr,
                                weight_decay=args.wd)
    raise ValueError("Unsupported --opt {!r}; expected 'Adam' or "
                     "'RMSprop'".format(args.opt))


def _evaluate(model, loader):
    """Return the masked loss of the last batch yielded by ``loader``.

    Run this under ``torch.no_grad()``. The val/test loaders are expected to
    hold a single batch. NOTE(review): with more batches only the last one
    counts. Returns None if the loader is empty.
    """
    loss = None
    for batch_x, batch_y, len_batch in loader:
        batch_x, batch_y, len_batch = batch_x.to(device), batch_y.to(
            device), len_batch.to(device)
        logits = model(batch_x)
        loss = model.loss(logits, batch_y, len_batch)
    return loss


def _save_history(val_history, test_history):
    """Pickle the per-epoch val/test losses into ``args.logsdir``."""
    os.makedirs(args.logsdir, exist_ok=True)
    fname = 'LipschitzRNN_beta{}_gamma{}_eps{}_init{}_seed{}_loss.pkl'.format(
        str(args.beta), str(args.gamma), str(args.eps), str(args.init),
        str(args.seed))
    data = {'val': val_history, 'test': test_history}
    with open(os.path.join(args.logsdir, fname), "wb") as f:
        pickle.dump(data, f)


def main():
    """Train the Lipschitz RNN on TIMIT, evaluating after every epoch.

    Prints per-epoch train/val/test losses. The test loss reported as "Best"
    is the one at the epoch with the lowest validation loss. The histories
    are pickled to ``args.logsdir``.

    Returns:
        ``(best_validation, best_test)`` as floats.
    """
    train_loader, val_loader, test_loader = load_timit_data()

    # Model and optimizers
    model = Model(m,
                  k,
                  init_std=args.init,
                  alpha=args.alpha,
                  beta=args.beta,
                  eta=args.eta,
                  gamma=args.gamma,
                  gamma_W=args.gamma_W,
                  eps=args.eps,
                  method=args.method).to(device)
    model.train()

    print('**** Setup ****')
    n_params = sum(p.numel() for p in model.parameters())
    print('Total params: %.2fk ; %.2fM' %
          (n_params * 10**-3, n_params * 10**-6))
    print('************')

    opt = _build_optimizer(model)

    best_test = 1e7
    best_validation = 1e7
    val_history = []
    test_history = []
    for ep in range(1, args.epochs + 1):

        for batch_x, batch_y, len_batch in train_loader:
            batch_x, batch_y, len_batch = batch_x.to(device), batch_y.to(
                device), len_batch.to(device)

            opt.zero_grad()

            logits = model(batch_x)
            loss = model.loss(logits, batch_y, len_batch)
            loss.backward()

            if args.clip > 0:
                nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            opt.step()

        # ``loss`` is the last training batch's loss of this epoch.
        print("Epoch {}, LR {:.5f} \tLoss: {:.2f} ".format(
            ep, opt.param_groups[0]['lr'], loss))

        model.eval()
        with torch.no_grad():
            loss_test = _evaluate(model, test_loader)
            loss_val = _evaluate(model, val_loader)

            # Model selection on validation loss; record the matching test loss.
            if loss_val < best_validation:
                best_validation = loss_val.item()
                best_test = loss_test.item()

        print()
        print("Val:  Loss: {:.2f}\tBest: {:.2f}".format(
            loss_val, best_validation))
        print("Test: Loss: {:.2f}\tBest: {:.2f}".format(loss_test, best_test))
        print()
        val_history.append(loss_val.item())
        test_history.append(loss_test.item())
        model.train()

    _save_history(val_history, test_history)
    return best_validation, best_test


# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()
