import argparse
import time
import torch
import torch.nn
from torch.autograd import Variable
import torch.nn as nn
from torch import optim
import copy


from lm import repackage_hidden, LM_LSTM
import reader
import numpy as np

# Command-line options as (flag, value type, default, help text) rows.
# They are registered in this order, which is also the order shown by --help.
_CLI_OPTIONS = (
    ('--data', str, '/path/to/data', 'location of the data corpus'),
    ('--hidden_size', int, 1500, 'size of word embeddings'),
    ('--num_steps', int, 35, 'number of LSTM steps'),
    ('--num_layers', int, 2, 'number of LSTM layers'),
    ('--batch_size', int, 20, 'batch size'),
    ('--num_epochs', int, 200, 'number of epochs'),
    ('--dp_keep_prob', float, 0.35, 'dropout *keep* probability'),
    ('--initial_lr', float, 1.0, 'initial learning rate'),
    ('--save', str, '/path/to/models/lm_model_splitsgd.pt', 'path to save the final model'),
)

parser = argparse.ArgumentParser(description='Simplest LSTM-based language model in PyTorch')
for _flag, _value_type, _default, _help_text in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_value_type, default=_default, help=_help_text)

# Parsed once at import time; the functions below read settings from `args`.
args = parser.parse_args()

# Single loss object shared by the training and evaluation loops.
criterion = nn.CrossEntropyLoss()


def run_test_epoch(model, data):
    """Evaluate ``model`` on ``data`` and return its perplexity.

    The model is switched to eval mode (and left there) and fed every batch
    produced by ``reader.ptb_iterator(data, model.batch_size, model.num_steps)``.
    The hidden state is carried from one batch to the next after being passed
    through ``repackage_hidden``.  Each batch's cross-entropy loss is weighted
    by ``model.num_steps`` and the result is
    ``exp(sum(loss * num_steps) / sum(num_steps))``.

    The whole pass runs under ``torch.no_grad()``: nothing here calls
    ``backward()``, so no autograd graph is built and the model's gradients
    are left untouched.

    Args:
        model: network exposing ``init_hidden()``, ``batch_size``,
            ``num_steps`` and ``vocab_size``, and callable as
            ``model(inputs, hidden) -> (outputs, hidden)``.
        data: corpus handed unchanged to ``reader.ptb_iterator``.

    Returns:
        The perplexity as a NumPy float.

    Raises:
        ZeroDivisionError: if the iterator yields no batches.

    Note:
        Batches are moved to the GPU with ``.cuda()``, so CUDA is required.
    """
    model.eval()
    hidden = model.init_hidden()
    costs = 0.0
    iters = 0
    with torch.no_grad():
        for x, y in reader.ptb_iterator(data, model.batch_size, model.num_steps):
            # Cast to int64, swap the first two axes, then move to the GPU.
            inputs = torch.from_numpy(x.astype(np.int64)).transpose(0, 1).contiguous().cuda()
            targets = torch.from_numpy(y.astype(np.int64)).transpose(0, 1).contiguous().cuda()
            hidden = repackage_hidden(hidden)
            outputs, hidden = model(inputs, hidden)

            # One row of logits per target token, with targets flattened to 1-D.
            loss = criterion(outputs.view(-1, model.vocab_size), targets.view(-1))
            costs += loss.item() * model.num_steps
            iters += model.num_steps

    return np.exp(costs / iters)


def _build_model():
    """Create a fresh ``LM_LSTM`` from the command-line settings and move it to the GPU.

    Reads the module-level ``args`` and ``vocab_size``; the latter is assigned
    in this file's ``__main__`` block.
    """
    model = LM_LSTM(embedding_dim=args.hidden_size, num_steps=args.num_steps, batch_size=args.batch_size,
                    vocab_size=vocab_size, num_layers=args.num_layers, dp_keep_prob=args.dp_keep_prob)
    model.cuda()
    return model


def _fork_model(source_net, source_optimizer, lr):
    """Return a new ``(model, Adam optimizer)`` pair loaded with the source's weights and optimizer state."""
    net = _build_model()
    net.load_state_dict(source_net.state_dict())
    optimizer = optim.Adam(net.parameters(), lr=lr)
    optimizer.load_state_dict(source_optimizer.state_dict())
    return net, optimizer


def _train_step(model, optimizer, hidden, x, y):
    """Run one forward/backward/optimizer update on the numpy batch ``(x, y)``.

    Returns ``(loss, hidden)``: the loss tensor for the batch and the hidden
    state produced by the forward pass.
    """
    # Cast to int64, swap the first two axes, then move to the GPU.
    inputs = torch.from_numpy(x.astype(np.int64)).transpose(0, 1).contiguous().cuda()
    targets = torch.from_numpy(y.astype(np.int64)).transpose(0, 1).contiguous().cuda()
    hidden = repackage_hidden(hidden)
    optimizer.zero_grad()
    outputs, hidden = model(inputs, hidden)
    loss = criterion(outputs.view(-1, model.vocab_size), targets.view(-1))
    loss.backward()
    optimizer.step()
    return loss, hidden


def _snapshot_parameters(model):
    """Return ``{name: detached copy}`` for every named parameter of ``model``."""
    return {name: param.detach().clone() for name, param in model.named_parameters()}


def split_training(data, test_data, init_lr):
    """Train an ``LM_LSTM`` with a split-based learning-rate schedule.

    The procedure runs 40 rounds.  Each round:

    1. Trains the main network for 9 full passes over the training batches
       with Adam, printing the training loss, training perplexity and the
       perplexity on ``test_data`` after every pass.
    2. Forks the network and its optimizer state into two copies.  The first
       ``w * window`` batches and the following ``w * window`` batches are cut
       into ``w`` windows each; for every window index, copy 1 trains on its
       window and copy 2 on the matching window of the second half.  After
       each window pair, the dot product of the two copies' parameter
       displacements over that window is recorded for every parameter tensor.
    3. If the share of negative dot products is at least ``q``, multiplies
       the learning rate by ``gamma``.
    4. Builds the next main network as the ``beta``-weighted average of the
       two copies, gives it a fresh Adam optimizer at the current learning
       rate, and prints its perplexity on ``test_data``.

    Args:
        data: training corpus handed to ``reader.ptb_iterator``.
        test_data: corpus passed to ``run_test_epoch`` for the reported
            perplexities.
        init_lr: initial Adam learning rate.

    Returns:
        The main network after the last round.

    Note:
        Depends on the module-level ``args``, ``criterion`` and
        ``vocab_size`` (the last is set in the ``__main__`` block), and
        requires CUDA.
    """
    num_rounds = 40        # outer rounds of (training passes + one diagnostic)
    epochs_per_round = 9   # training passes before each diagnostic
    q = 0.25               # proportion of negative dot products that triggers a decay
    gamma = 0.5            # learning-rate decay factor
    w = 4                  # windows per half in the diagnostic
    beta = 0.5             # weight of copy 1 when averaging the two copies
    lr = init_lr

    net_splitsgd = _build_model()
    optimizer_splitsgd = optim.Adam(net_splitsgd.parameters(), lr=lr)

    # Materialise the batches once so the diagnostic can index into them.
    batches = list(reader.ptb_iterator(data, net_splitsgd.batch_size, net_splitsgd.num_steps))
    window_len = len(batches) // (2 * w)
    print('l', window_len)
    real_epoch = 0
    for _ in range(num_rounds):
        # ---- regular training passes on the main network ----
        for _ in range(epochs_per_round):
            running_loss = 0
            costs = 0.0
            iters = 0
            hidden = net_splitsgd.init_hidden()
            net_splitsgd.train()
            for x, y in batches:
                loss, hidden = _train_step(net_splitsgd, optimizer_splitsgd, hidden, x, y)
                loss_value = loss.item()
                running_loss += loss_value
                costs += loss_value * net_splitsgd.num_steps
                iters += net_splitsgd.num_steps
            real_epoch += 1
            train_perplexity = np.exp(costs / iters)
            test_perplexity = run_test_epoch(net_splitsgd, test_data)
            print('epoch: %d, training loss: %f, train perplexity: %f, test perplexity: %f, learning rate: %f' %
                  (real_epoch, running_loss, train_perplexity, test_perplexity, lr))

        # ---- diagnostic: two copies trained on disjoint halves of the data ----
        net1, optimizer1 = _fork_model(net_splitsgd, optimizer_splitsgd, lr)
        net2, optimizer2 = _fork_model(net_splitsgd, optimizer_splitsgd, lr)

        # Parameter values at the start of the current window, per copy.
        start1 = _snapshot_parameters(net1)
        start2 = _snapshot_parameters(net2)

        hidden1 = net1.init_hidden()
        hidden2 = net2.init_hidden()

        dot_prod = []
        net1.train()
        net2.train()
        for i in range(w):
            first = i * window_len
            for x, y in batches[first:first + window_len]:
                _, hidden1 = _train_step(net1, optimizer1, hidden1, x, y)
            second = (w + i) * window_len
            for x, y in batches[second:second + window_len]:
                _, hidden2 = _train_step(net2, optimizer2, hidden2, x, y)

            # One dot product per parameter tensor: how aligned were the two
            # copies' movements over this window?
            current2 = dict(net2.named_parameters())
            for name, param1 in net1.named_parameters():
                delta1 = param1.detach() - start1[name]
                delta2 = current2[name].detach() - start2[name]
                dot_prod.append(torch.sum(delta1 * delta2))

            start1 = _snapshot_parameters(net1)
            start2 = _snapshot_parameters(net2)

        # Count once; every .item() forces a host/device synchronisation.
        num_negative = sum(1 for value in dot_prod if value.item() < 0)
        stationarity = num_negative >= q * len(dot_prod)
        if stationarity:
            lr = lr * gamma

        # A new model is constructed (rather than reusing the old one) and then
        # fully overwritten with the averaged weights.
        net_splitsgd = _build_model()
        state1 = net1.state_dict()
        state2 = net2.state_dict()
        with torch.no_grad():
            # state_dict tensors reference the module's own storage, so this
            # overwrites net2's weights in place; net2 is discarded afterwards.
            for name, tensor1 in state1.items():
                if name in state2:
                    state2[name].copy_(beta * tensor1 + (1 - beta) * state2[name])

        net_splitsgd.load_state_dict(state2)
        # Fresh optimizer: Adam's running statistics are not carried over.
        optimizer_splitsgd = optim.Adam(net_splitsgd.parameters(), lr=lr)
        real_epoch += 1

        test_perplexity = run_test_epoch(net_splitsgd, test_data)

        print('D -> epoch: %d, test perplexity: %f, stationarity: %s, negative dot products: %i out of %i, '
              'learning rate: %f' % (real_epoch, test_perplexity, bool(stationarity),
                                     num_negative, len(dot_prod), lr))
    return net_splitsgd


if __name__ == "__main__":
    # Load the corpus splits and vocabulary tables from the --data location.
    raw_data = reader.ptb_raw_data(data_path=args.data)
    train_data, valid_data, test_data, word_to_id, id_2_word = raw_data

    # split_training() reads `vocab_size` as a module-level global, so it must
    # be assigned before training starts.
    vocab_size = len(word_to_id)
    print('Vocabulary size: {}'.format(vocab_size))
    lr = args.initial_lr

    print("########## Training ##########################")
    net_splitsgd = split_training(train_data, test_data,  args.initial_lr)

    print("########## Testing ##########################")
    final_perplexity = run_test_epoch(net_splitsgd, test_data)
    print('Test Perplexity: {:8.2f}'.format(final_perplexity))

    # Persist the whole model object (not just its state_dict) to --save.
    with open(args.save, 'wb') as f:
        torch.save(net_splitsgd, f)

    print("########## Done! ##########################")
