import argparse
import time
import copy
import torch
import torch.nn
from torch.autograd import Variable
import torch.nn as nn
from torch import optim
import copy
import numpy as np


from lm import repackage_hidden, LM_LSTM
import reader
from utils import set_random_seed

# Command-line options as (flag, type, default, help) rows; they are registered
# on the parser in exactly this order.
_CLI_OPTIONS = (
    ('--data', str, '/path/to/data', 'location of the data corpus'),
    ('--hidden_size', int, 1500, 'size of word embeddings'),
    ('--num_steps', int, 35, 'number of LSTM steps'),
    ('--num_layers', int, 2, 'number of LSTM layers'),
    ('--batch_size', int, 20, 'batch size'),
    ('--num_epochs', int, 200, 'number of epochs'),
    ('--dp_keep_prob', float, 0.35, 'dropout *keep* probability'),
    ('--initial_lr', float, 1.0, 'initial learning rate'),
    ('--save', str, '/path/to/models/lm_model_splitsgd.pt', 'path to save the final model'),
)

parser = argparse.ArgumentParser(description='Simplest LSTM-based language model in PyTorch')
for _flag, _type, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)

# Parsed at import time: the functions below read ``args`` as a module-level global.
args = parser.parse_args()

# Single loss object shared by the training and evaluation code in this file.
criterion = nn.CrossEntropyLoss()


def run_test_epoch(model, data):
    """Evaluate ``model`` on ``data`` and return its perplexity.

    The model is switched to evaluation mode and temporarily run with a batch
    size of 1. Batches are produced by ``reader.ptb_iterator`` and the hidden
    state returned for one batch is fed (after ``repackage_hidden``) into the
    next one. No gradients are recorded during the pass.

    Args:
        model: language model exposing ``batch_size``, ``num_steps``,
            ``vocab_size`` and ``init_hidden()``, and callable as
            ``model(inputs, hidden) -> (outputs, hidden)``.
        data: corpus in whatever form ``reader.ptb_iterator`` accepts.

    Returns:
        ``exp`` of the mean per-batch loss given by the module-level
        ``criterion``.

    Raises:
        ZeroDivisionError: if ``data`` yields no batches.

    Note:
        ``model.batch_size`` is restored to the value it had on entry, even if
        evaluation fails. The model is left in eval mode; callers that resume
        training must call ``model.train()`` themselves.
    """
    original_batch_size = model.batch_size
    model.batch_size = 1
    model.eval()
    # Feed batches to wherever the model's weights live instead of assuming
    # the current CUDA device.
    device = next(model.parameters()).device
    costs = 0.0
    iters = 0
    try:
        hidden = model.init_hidden()
        # Pure inference: skip autograd bookkeeping for the whole pass.
        with torch.no_grad():
            for x, y in reader.ptb_iterator(data, model.batch_size, model.num_steps):
                inputs = torch.from_numpy(x.astype(np.int64)).transpose(0, 1).contiguous().to(device)
                targets = torch.from_numpy(y.astype(np.int64)).transpose(0, 1).contiguous().to(device)
                hidden = repackage_hidden(hidden)
                outputs, hidden = model(inputs, hidden)

                # Flatten targets to 1-D so they line up with the flattened logits.
                loss = criterion(outputs.view(-1, model.vocab_size), targets.view(-1))
                costs += loss.item() * model.num_steps
                iters += model.num_steps
    finally:
        model.batch_size = original_batch_size

    return np.exp(costs / iters)


def _build_model():
    """Create a new ``LM_LSTM`` from the command-line settings and move it to the GPU.

    Reads the module-level ``args`` and ``vocab_size`` globals.
    """
    net = LM_LSTM(embedding_dim=args.hidden_size, num_steps=args.num_steps, batch_size=args.batch_size,
                  vocab_size=vocab_size, num_layers=args.num_layers, dp_keep_prob=args.dp_keep_prob)
    net.cuda()
    return net


def _clone_training_state(net, optimizer, lr, mom):
    """Return a new (model, SGD optimizer) pair loaded with the state of ``net`` and ``optimizer``."""
    clone = _build_model()
    clone.load_state_dict(net.state_dict())
    clone_optimizer = optim.SGD(clone.parameters(), lr=lr, momentum=mom)
    # Copy the optimizer state as well, so the clone continues from the same point.
    clone_optimizer.load_state_dict(optimizer.state_dict())
    return clone, clone_optimizer


def _sgd_step(net, optimizer, hidden, x, y):
    """Run one forward/backward/update step on a single batch.

    ``x`` and ``y`` are the NumPy arrays of one ``reader.ptb_iterator`` batch.
    Returns ``(loss_value, new_hidden)`` with ``loss_value`` a Python float.
    """
    inputs = torch.from_numpy(x.astype(np.int64)).transpose(0, 1).contiguous().cuda()
    targets = torch.from_numpy(y.astype(np.int64)).transpose(0, 1).contiguous().cuda()
    hidden = repackage_hidden(hidden)
    optimizer.zero_grad()
    outputs, hidden = net(inputs, hidden)
    # Flatten targets to 1-D so they line up with the flattened logits.
    loss = criterion(outputs.view(-1, net.vocab_size), targets.view(-1))
    loss.backward()
    optimizer.step()
    return loss.item(), hidden


def _snapshot_parameters(net):
    """Return detached copies of the named parameters of ``net``."""
    return {name: param.detach().clone() for name, param in net.named_parameters()}


def split_training(data, test_data, init_lr, mom=0.9, num_rounds=40, t1=9, q=0.25, gamma=0.5, w=4, beta=0.5):
    """Train an ``LM_LSTM`` with SGD and a splitting diagnostic that decides when to decay the learning rate.

    Each round consists of:

    1. ``t1`` ordinary epochs over all training batches, each followed by an
       evaluation on ``test_data``.
    2. A diagnostic: two copies of the current model (and optimizer state) are
       trained on disjoint batches, ``w`` windows each. After every window the
       dot product between the two copies' parameter updates is recorded for
       each named parameter.
    3. If at least a fraction ``q`` of those dot products is negative, the
       learning rate is multiplied by ``gamma``.
    4. The two copies are blended (``beta * net1 + (1 - beta) * net2``) into a
       new model, which gets a new SGD optimizer (so its momentum state starts
       empty) and is evaluated on ``test_data``.

    Reads the module-level ``args``, ``vocab_size`` and ``criterion`` globals.

    Args:
        data: training corpus in the form ``reader.ptb_iterator`` accepts.
        test_data: corpus passed to ``run_test_epoch`` for evaluation.
        init_lr: initial learning rate.
        mom: SGD momentum.
        num_rounds: number of train-then-diagnose rounds.
        t1: ordinary training epochs per round.
        q: fraction of negative dot products at or above which the learning
            rate is decayed.
        gamma: multiplicative learning-rate decay factor.
        w: number of diagnostic windows per model copy.
        beta: weight given to the first copy when blending the two copies.

    Returns:
        The model produced by the last round (or the untrained initial model
        if ``num_rounds`` is 0).
    """
    print('q', q)
    lr = init_lr

    net_splitsgd = _build_model()
    optimizer_splitsgd = optim.SGD(net_splitsgd.parameters(), lr=lr, momentum=mom)

    batches = list(reader.ptb_iterator(data, net_splitsgd.batch_size, net_splitsgd.num_steps))
    print('batches', len(batches))
    # Each copy sees w windows of this many batches: the first half of the
    # batch list feeds net1, the second half feeds net2.
    window_len = len(batches) // (2 * w)
    print('l', window_len)
    real_epoch = 0

    for _ in range(num_rounds):
        # ---- Phase 1: ordinary training epochs ----
        for _ in range(t1):
            running_loss = 0
            costs = 0.0
            iters = 0
            hidden = net_splitsgd.init_hidden()
            net_splitsgd.train()
            for x, y in batches:
                batch_loss, hidden = _sgd_step(net_splitsgd, optimizer_splitsgd, hidden, x, y)
                running_loss += batch_loss
                costs += batch_loss * net_splitsgd.num_steps
                iters += net_splitsgd.num_steps
            real_epoch += 1
            train_perplexity = np.exp(costs / iters)
            test_perplexity = run_test_epoch(net_splitsgd, test_data)
            print('epoch: %d, training loss: %f, train perplexity: %f, test perplexity: %f, learning rate: %f' %
                  (real_epoch, running_loss, train_perplexity, test_perplexity, lr))

        # ---- Phase 2: splitting diagnostic ----
        net1, optimizer1 = _clone_training_state(net_splitsgd, optimizer_splitsgd, lr, mom)
        net2, optimizer2 = _clone_training_state(net_splitsgd, optimizer_splitsgd, lr, mom)

        # Parameter values at the start of the current window, per copy.
        start1 = _snapshot_parameters(net1)
        start2 = _snapshot_parameters(net2)

        hidden1 = net1.init_hidden()
        hidden2 = net2.init_hidden()

        dot_prod = []
        net1.train()
        net2.train()
        for i in range(w):
            for x, y in batches[i * window_len:(i + 1) * window_len]:
                _, hidden1 = _sgd_step(net1, optimizer1, hidden1, x, y)
            for x, y in batches[(w + i) * window_len:(w + i + 1) * window_len]:
                _, hidden2 = _sgd_step(net2, optimizer2, hidden2, x, y)

            end1 = _snapshot_parameters(net1)
            end2 = _snapshot_parameters(net2)
            for name in end1:
                delta1 = end1[name] - start1[name]
                delta2 = end2[name] - start2[name]
                dot_prod.append(torch.sum(delta1 * delta2).item())

            # The end of this window is the start of the next one.
            start1, start2 = end1, end2

        num_negative = sum(d < 0 for d in dot_prod)
        # With no dot products (too few batches to fill a window) there is no
        # evidence of stationarity, so the learning rate is left alone.
        stationarity = bool(dot_prod) and num_negative >= q * len(dot_prod)
        if stationarity:
            lr = lr * gamma

        # ---- Phase 3: blend the two copies into the next model ----
        net_splitsgd = _build_model()
        params1 = net1.state_dict()
        averaged = {
            name: beta * params1[name] + (1 - beta) * value if name in params1 else value
            for name, value in net2.state_dict().items()
        }
        net_splitsgd.load_state_dict(averaged)
        optimizer_splitsgd = optim.SGD(net_splitsgd.parameters(), lr=lr, momentum=mom)
        real_epoch += 1

        test_perplexity = run_test_epoch(net_splitsgd, test_data)
        print('D -> epoch: %d, test perplexity: %f, stationarity: %s, negative dot products: %i out of %i, '
              'learning rate: %f' % (real_epoch, test_perplexity, stationarity, num_negative, len(dot_prod), lr))
    return net_splitsgd


if __name__ == "__main__":
    # Seed first, before any data loading or model construction.
    set_random_seed(666)

    # Unpack the five values returned by reader.ptb_raw_data for the --data location.
    train_data, valid_data, test_data, word_to_id, id_2_word = reader.ptb_raw_data(data_path=args.data)

    # split_training() reads ``vocab_size`` as a module-level global, so it
    # must be assigned here before training starts.
    vocab_size = len(word_to_id)
    print('Vocabulary size: {}'.format(vocab_size))
    lr = args.initial_lr

    print("########## Training ##########################")
    net_splitsgd = split_training(train_data, test_data, args.initial_lr, mom=0.9)

    print("########## Testing ##########################")
    final_perplexity = run_test_epoch(net_splitsgd, test_data)
    print('Test Perplexity: {:8.2f}'.format(final_perplexity))

    # The whole model object (not only its state dict) is written to --save.
    with open(args.save, 'wb') as model_file:
        torch.save(net_splitsgd, model_file)
    print("########## Done! ##########################")
