from model import CNN
from butterflenet import ButterfLeNet1D
import utils

from torch.autograd import Variable
import torch
import torch.optim as optim
import torch.nn as nn

from sklearn.utils import shuffle
from gensim.models.keyedvectors import KeyedVectors
import numpy as np
import argparse
import copy


def _build_wv_matrix(data):
    """Return the word-embedding matrix for ``data["vocab"]`` built from word2vec.

    Row ``i`` is the pretrained vector of ``data["idx_to_word"][i]``, or a small
    uniform-random vector when that word is not in the pretrained vocabulary.
    Two extra rows are appended: a random one for UNK (index ``len(vocab)``) and
    an all-zero one for padding (index ``len(vocab) + 1``).
    """
    print("loading word2vec...")
    word_vectors = KeyedVectors.load_word2vec_format("GoogleNews-vectors-negative300.bin", binary=True)

    wv_matrix = []
    for i in range(len(data["vocab"])):
        word = data["idx_to_word"][i]
        # `in` and `[]` work on both gensim 3.x and 4.x KeyedVectors; the
        # `.vocab` attribute was removed (and `.word_vec` deprecated) in 4.0.
        if word in word_vectors:
            wv_matrix.append(word_vectors[word])
        else:
            wv_matrix.append(np.random.uniform(-0.01, 0.01, 300).astype("float32"))

    # one for UNK and one for zero padding
    wv_matrix.append(np.random.uniform(-0.01, 0.01, 300).astype("float32"))
    wv_matrix.append(np.zeros(300).astype("float32"))
    return np.array(wv_matrix)


def train(data, params, sgdr=False, offline=False):
    """Train a CNN text classifier and return the best snapshot by dev accuracy.

    Args:
        data: dataset dict. Reads the train/dev/test splits plus "vocab",
            "classes", "word_to_idx" and "idx_to_word". ``data["train_x"]`` and
            ``data["train_y"]`` are reshuffled (reassigned) every epoch.
        params: hyper-parameter dict, forwarded as ``CNN(**params)``. When
            ``params["MODEL"] != "rand"`` this function stores the pretrained
            embedding matrix in ``params["WV_MATRIX"]``.
        sgdr: if True, run 3 training phases; each phase after the first builds
            a fresh CNN and loads the previous phase's parameters into it
            (``load_params(..., requires_grad=True)``).
        offline: if True, append one final phase that loads the previous model
            with ``load_params(..., requires_grad=False)`` and keeps training
            whatever parameters remain trainable.

    Returns:
        A deep copy of the model at the epoch with the highest dev accuracy.
        NOTE(review): the running maximum is reset at the start of every phase,
        so the snapshot effectively comes from the last phase; confirm this is
        intended. If no snapshot was ever taken, the final model is returned.
    """
    if params["MODEL"] != "rand":
        params["WV_MATRIX"] = _build_wv_matrix(data)

    sgdr_steps = 3 if sgdr else 1
    if offline:
        sgdr_steps += 1

    best_model = None
    for sgdr_iter in range(sgdr_steps):
        is_offline = offline and sgdr_iter == sgdr_steps - 1

        if is_offline:
            print("OFFLINE EVALUATION")
            newmodel = CNN(**params).cuda(params["GPU"])
            newmodel.load_params(model, requires_grad=False)
            model = newmodel
        elif sgdr_iter > 0:
            print("SGDR RESTART")
            newmodel = CNN(**params).cuda(params["GPU"])
            newmodel.load_params(model, requires_grad=True)
            model = newmodel
        else:
            model = CNN(**params).cuda(params["GPU"])

        # Materialise the trainable parameters as a list: a bare `filter`
        # iterator would be exhausted by the optimizer constructor, leaving
        # nothing for gradient clipping below.
        parameters = [p for p in model.parameters() if p.requires_grad]
        optimizer = optim.Adadelta(parameters, params["LEARNING_RATE"])
        criterion = nn.CrossEntropyLoss()

        pre_dev_acc = 0
        max_dev_acc = 0
        max_test_acc = 0
        max_dev_acc_offline = 0
        max_test_acc_offline = 0
        p_improvement = 0
        patience = 3
        for e in range(params["EPOCH"]):
            data["train_x"], data["train_y"] = shuffle(data["train_x"], data["train_y"])

            for i in range(0, len(data["train_x"]), params["BATCH_SIZE"]):
                batch_range = min(params["BATCH_SIZE"], len(data["train_x"]) - i)

                # Pad every sentence to MAX_SENT_LEN with index VOCAB_SIZE + 1.
                batch_x = [[data["word_to_idx"][w] for w in sent] +
                           [params["VOCAB_SIZE"] + 1] * (params["MAX_SENT_LEN"] - len(sent))
                           for sent in data["train_x"][i:i + batch_range]]
                batch_y = [data["classes"].index(c) for c in data["train_y"][i:i + batch_range]]

                batch_x = Variable(torch.LongTensor(batch_x)).cuda(params["GPU"])
                batch_y = Variable(torch.LongTensor(batch_y)).cuda(params["GPU"])

                optimizer.zero_grad()
                model.train()
                pred = model(batch_x)
                loss = criterion(pred, batch_y)
                loss.backward()
                nn.utils.clip_grad_norm_(parameters, max_norm=params["NORM_LIMIT"])
                optimizer.step()

            dev_acc = test(data, model, params, mode="dev")
            test_acc = test(data, model, params)
            print("epoch:", e + 1, "/ dev_acc:", dev_acc, "/ test_acc:", test_acc)

            # pre_dev_acc only moves on improvement, so this counts epochs
            # without a new best dev accuracy.
            if params["EARLY_STOPPING"] and dev_acc <= pre_dev_acc:
                p_improvement += 1
                if p_improvement >= patience:
                    print("early stopping by dev_acc!")
                    break
            else:
                p_improvement = 0
                pre_dev_acc = dev_acc

            # The offline phase tracks its own maxima so both can be reported.
            if is_offline:
                if dev_acc > max_dev_acc_offline:
                    max_dev_acc_offline = dev_acc
                    max_test_acc_offline = test_acc
                    best_model = copy.deepcopy(model)
            elif dev_acc > max_dev_acc:
                max_dev_acc = dev_acc
                max_test_acc = test_acc
                best_model = copy.deepcopy(model)

        print("max dev acc:", max_dev_acc, "test acc:", max_test_acc)
        print("offline test acc:", max_test_acc_offline)

    # No snapshot is taken when EPOCH == 0 or dev accuracy never exceeds 0;
    # fall back to the final model instead of raising UnboundLocalError.
    return best_model if best_model is not None else model


def test(data, model, params, mode="test"):
    model.eval()

    correct = 0.0
    total = 0.0

    if mode == "dev":
        x, y = data["dev_x"], data["dev_y"]
    elif mode == "test":
        x, y = data["test_x"], data["test_y"]

    for i in range(0, len(x), params["BATCH_SIZE"]):
        batch_range = min(params["BATCH_SIZE"], len(x) - i)
        batch_x = [
            [data["word_to_idx"][w] if w in data["vocab"] else params["VOCAB_SIZE"] for w in sent] + [params["VOCAB_SIZE"] + 1] * (params["MAX_SENT_LEN"] - len(sent))
                    for sent in x[i:i + batch_range]]

        batch_x = Variable(torch.LongTensor(batch_x)).cuda(params["GPU"])
        batch_y = [data["classes"].index(c) for c in y[i:i + batch_range]]

        pred = np.argmax(model(batch_x).cpu().data.numpy(), axis=1)
        total += len(pred)
        correct += sum([1 if p == batch_y else 0 for p, batch_y in zip(pred, batch_y)])

    acc = correct / total
    return acc


def main():
    """Parse command-line options, build the vocabulary, then train or test.

    With ``--mode train`` a model is trained (and saved through
    ``utils.save_model`` when ``--save_model`` is given); any other mode loads
    a saved model through ``utils.load_model`` and prints its test accuracy.
    """
    parser = argparse.ArgumentParser(description="-----[CNN-classifier]-----")

    def add_switch(flag, help_text):
        # Boolean switches: off by default, turned on by their presence.
        parser.add_argument(flag, default=False, action='store_true', help=help_text)

    parser.add_argument("--mode", default="train", help="train: train (with test) a model / test: test saved models")
    parser.add_argument("--model", default="rand", help="available models: rand, static, non-static, multichannel")
    parser.add_argument("--dataset", default="TREC", help="available datasets: MR, TREC, SST1, SST2, SUBJ, CR, MPQA")
    add_switch("--save_model", "whether saving model or not")
    add_switch("--early_stopping", "whether to apply early stopping")
    parser.add_argument("--epoch", default=100, type=int, help="number of max epoch")
    parser.add_argument("--learning_rate", default=1.0, type=float, help="learning rate")
    parser.add_argument("--gpu", default=-1, type=int, help="the number of gpu to be used")
    add_switch("--kop", "whether to use K-operation or not")
    add_switch("--tied", "whether to tie K-operation or not")
    add_switch("--warm_start", "whether to warm start K-operation as conv or not")
    add_switch("--fc", "whether to use fully connected or not")
    add_switch("--sgdr", "whether to use SGDR or not")
    add_switch("--offline", "whether or not to perform offline eval at the end of training")
    opts = parser.parse_args()

    reader = getattr(utils, f"read_{opts.dataset}")
    data = reader()

    # The vocabulary and the maximum sentence length span all three splits.
    all_sents = data["train_x"] + data["dev_x"] + data["test_x"]
    vocab = sorted({w for sent in all_sents for w in sent})
    data["vocab"] = vocab
    data["classes"] = sorted(set(data["train_y"]))
    data["word_to_idx"] = {w: i for i, w in enumerate(vocab)}
    data["idx_to_word"] = dict(enumerate(vocab))

    params = {
        "MODEL": opts.model,
        "DATASET": opts.dataset,
        "SAVE_MODEL": opts.save_model,
        "EARLY_STOPPING": opts.early_stopping,
        "EPOCH": opts.epoch,
        "LEARNING_RATE": opts.learning_rate,
        "KOP": opts.kop,
        "TIED": opts.tied,
        "WARM_START": opts.warm_start,
        "FC": opts.fc,
        "MAX_SENT_LEN": max(len(sent) for sent in all_sents),
        "BATCH_SIZE": 50,
        "WORD_DIM": 300,
        "VOCAB_SIZE": len(vocab),
        "CLASS_SIZE": len(data["classes"]),
        "FILTERS": [3, 4, 5],
        "FILTER_NUM": [100, 100, 100],
        "DROPOUT_PROB": 0.5,
        "NORM_LIMIT": 1,
        "GPU": opts.gpu,
        "SGDR": opts.sgdr
    }

    bar = "=" * 20
    print(f"{bar}INFORMATION{bar}")
    for key in ("MODEL", "DATASET", "VOCAB_SIZE", "EPOCH", "LEARNING_RATE",
                "EARLY_STOPPING", "SAVE_MODEL", "SGDR"):
        print(f"{key}:", params[key])
    print(f"{bar}INFORMATION{bar}")

    if opts.mode != "train":
        model = utils.load_model(params).cuda(params["GPU"])
        test_acc = test(data, model, params)
        print("test acc:", test_acc)
        return

    print(f"{bar}TRAINING STARTED{bar}")
    model = train(data, params, sgdr=opts.sgdr, offline=opts.offline)
    if params["SAVE_MODEL"]:
        utils.save_model(model, params)
    print(f"{bar}TRAINING FINISHED{bar}")


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
