#!/usr/bin/env python3

from torchtext.datasets import *
import utils
from utils import *
# Fetch the WikiText-2 and WikiText-103 corpora into ./corpus through
# torchtext, then the NLTK "punkt" tokenizer models; this all happens at
# import time, before the command line is parsed.
for _wikitext in (WikiText2, WikiText103):
    _wikitext.download('./corpus')
import nltk
nltk.download('punkt')

from models import *

import argparse
parser = argparse.ArgumentParser(description='Train / load a word embedding model.')

# NOTE: options are spelled with hyphens on the command line; argparse stores
# them under underscore names (e.g. --batch-size becomes args.batch_size).

# --- task selection ---------------------------------------------------------
parser.add_argument('mode',
                    nargs='+',
                    choices = ["cache","train","resume","val","analogy","sim","classify","craw","cluster","examples","test","vocab","full_eval","to_csv","redundancy"],
                    help="Specify the mode to run. "+
                    "You can specify multiple modes by including the keywords.")

parser.add_argument('--comment', type=str,
                    help="This string is simply added to the hyperparameter dict. "+
                    "It makes it possible to distinguish two runs with the same input command -- For example, "+
                    "it is useful to compare the old and new implementations: "+
                    "It starts writing to a new log file without overwriting the existing log file or the model file.")

# --- training hyperparameters -----------------------------------------------
parser.add_argument('--shuffle',    action="store_true",
                    help="if enabled, shuffle the training dataset in each epoch. For a large dataset this could be slow")
parser.add_argument('--subsampling',  action="store_true",
                    help="if enabled, training dataset is sub-sampled for frequently occurring words")
parser.add_argument('--subsampling-threshold',  type=float, default=0.00001,
                    help="value t in Mikolov et al. 'Distributed Representations of Words and Phrases and their Compositionality'. "+
                    "Smaller value prunes more aggressively.")
parser.add_argument('--epochs',     type=int,   default=5)
parser.add_argument('--start-epoch',     type=int,   default=0,
                    help="treat as if the training starts from this epoch (useful for resuming)")
parser.add_argument('--batch-size', type=int,   default=1000)
parser.add_argument('--lr',         type=float, default=0.01)
parser.add_argument('--context-before',  type=int,   default=2)
parser.add_argument('--context-after',   type=int,   default=2)
parser.add_argument('--embedding',  type=int,   default=50)
parser.add_argument('--initialization',  choices=["gaussian","logistic"], default="gaussian")
parser.add_argument('--negative-sample',  type=int,   default=5)
parser.add_argument('--linear-dim', type=int, default=128,
                    help="Specify the width of the additional linear layer. Only for CBOW2 and NNLM model.")
# the valid loss names are the keys of utils.losses
parser.add_argument('--loss',  type=str,   default="logsigmoiddot",
                    choices = utils.losses.keys(),
                    help="Loss/distance used for comparing the embedding space in the subclass of LastLayerBackwardMixin.")

parser.add_argument('--threads',  type=int,   default=4,
                    help="The number of threads to use while loading and caching the datasets.")

parser.add_argument('--path',       type=str, default="results/")

# --- model selection --------------------------------------------------------
# Every class in the models module that subclasses Trainer (Trainer itself
# excluded) is offered as a --model choice.
import models
import inspect
model_list = [ name for name, obj in vars(models).items() if (inspect.isclass(obj) and issubclass(obj, Trainer) and (obj is not Trainer)) ]
parser.add_argument('--model',      type=str, default="SkipGram", choices=model_list)
parser.add_argument('--optimizer',  type=str, default="RAdam")

# --- data -------------------------------------------------------------------
parser.add_argument('--min-occurrence', type=int, default=5,
                    help="threshold for a word to be included in a vocabulary. "+
                    "If a word does not appear more than this threshold, "+
                    "it is treated as <unk>.")
parser.add_argument('--train',  type=str, default='./corpus/Pride-and-Prejudice.txt',
                    help="specify the path for the training text dataset")
parser.add_argument('--valid',   type=str, default='',
                    help="specify the path for the validation text dataset")
parser.add_argument('--test',   type=str, default='',
                    help="specify the path for the test text dataset")
parser.add_argument('--force-reload', action="store_true")

# --- GS-related options (see --beta: GumbelSoftmax / BinConcrete) -----------
parser.add_argument('--annealing-schedule', type=str, default='exponential')
parser.add_argument('--annealing-max',      type=float, default=2.0, help="initial temperature")
parser.add_argument('--annealing-min',      type=float, default=0.5, help="final temperature")
parser.add_argument('--annealing-start',    type=float, default=0,   help="annealing begins in this epoch.")
parser.add_argument('--annealing-end',      type=float, default=5,   help="annealing ends in this epoch. This does not have to match the total epochs (could be longer or shorter)")
parser.add_argument('--straight-through', action="store_true",
                    help="If present, use the straight-through estimator")
parser.add_argument('--beta', type=float, default=1.0,
                    help="Multiplier for GumbelSoftmax / BinConcrete variational loss")
parser.add_argument('--initial-state', default="zeros",
                    choices = ["random","zeros","ones","half"],
                    help="Initial state in sequential BTL models")
parser.add_argument('--affine', action="store_true", default=False,
                    help="Whether to use the affine transformation in batchnorm")
parser.add_argument('--noise', action="store_true", default=False,
                    help="Use the gumbel noise in BinConcrete if present.")

print("parsing args...")
args = parser.parse_args()

# Copy the parsed options into a plain dict. postprocess_hyper() receives that
# copy; its return value is not used, so it presumably edits the dict in place.
hyper = vars(args).copy()

postprocess_hyper(hyper)


# --model choices come from model_list, i.e. names found in vars(models), so
# resolve the class on that module directly instead of eval()-ing the string
# in this script's namespace (which depends on what the star-imports export).
modelclass = getattr(models, hyper["model"])


print("loading data...")
train = modelclass.datasetclass(input=hyper["train"], **hyper)
# The validation set is built with the training vocabulary (vocab=train.vocab).
valid = modelclass.datasetclass(input=hyper["valid"], vocab=train.vocab, **hyper)

print("building model...")
model = modelclass(hyper, train)

import os
import subprocess
import sys


def _train_and_save():
    """Run ``model.loop(train, valid)`` under ``training_timer``, then save.

    Shared by the "train" and "resume" tasks. ``training_timer`` receives the
    model's ``time-memory.json`` path, the configured number of epochs and the
    value of ``compute_model_memory(model)``.
    """
    training_timer(model.local("time-memory.json"),
                   lambda: model.loop(train, valid), hyper["epochs"],
                   model_memory=compute_model_memory(model))
    model.save()


def _evaluate_split(name, data):
    """Return the loss and top-1/5/10 accuracy of ``model`` on ``data``.

    Shared by the "val" and "test" tasks. The caller is expected to have
    loaded the model, switched it to eval mode and entered ``torch.no_grad()``.

    Args:
        name: label of the split; only used in the progress messages.
        data: a dataset built with ``modelclass.datasetclass``.

    Returns:
        dict with the keys "loss", "acc1", "acc5" and "acc10".
    """
    stats = {}
    print("evaluating {} loss".format(name))
    stats["loss"] = model.evaluate(data, model.loss)
    print("evaluating {} top-1,5 and 10 accuracy".format(name))
    # Entries 0/4/9 are read as top-1/5/10, so accuracy(10, ...) presumably
    # yields one value per k = 1..10 -- TODO confirm against models.py.
    results = model.evaluate(
        data, lambda *batch: model.accuracy(10, *batch, reduction="sum"))
    stats["acc1"], stats["acc5"], stats["acc10"] = results[0], results[4], results[9]
    return stats


print("running task...")

# Requested modes run in the fixed order of the `if` statements below, not in
# command-line order. "full_eval" switches on val, analogy, sim, classify,
# craw, cluster and examples (but not test, to_csv, redundancy or vocab).
try:
    if "cache" in args.mode:
        # Dummy task: building train/valid above already cached the data.
        pass

    if "train" in args.mode:
        print("task: training...")
        _train_and_save()

    if "resume" in args.mode:
        print("task: resume...")
        model.load()  # continue from the previously saved model
        _train_and_save()

    if "val" in args.mode or "full_eval" in args.mode:
        print("task: val...")
        model.load()
        model.eval()
        with torch.no_grad():
            statistics = {"valid": _evaluate_split("valid", valid)}
        append_save(statistics, model.local("performance.json"))

    if "analogy" in args.mode or "full_eval" in args.mode:
        print("task: analogy...")
        model.load()
        model.eval()
        import analogy

        with torch.no_grad():
            add_analogy_statistics = analogy.evaluate_analogy(model, valid, multiplicative=False)

        append_save({"analogy": add_analogy_statistics}, model.local("performance.json"))

    if "sim" in args.mode or "full_eval" in args.mode:
        print("task: word similarity...")
        model.load()
        model.eval()
        WORD_SIMILARITY_DATASETS = ["bruni_men", "radinsky_mturk", "luong_rare", "sim999",
                                    "ws353_relatedness", "ws353_similarity", "ws353",]
        similarity_statistics = {}
        import similarity
        with torch.no_grad():
            for dataset_name in WORD_SIMILARITY_DATASETS:
                dataset = similarity.similarity_dataset(dataset_name)
                similarity_statistics[dataset_name] = similarity.evaluate_similarity(
                    dataset, model, valid, normalize=True, ignore_oov=True)
                print(f"dataset {dataset_name} correlation: {similarity_statistics[dataset_name]}")

        append_save({"word_similarity": similarity_statistics}, model.local("performance.json"))

    if "classify" in args.mode or "full_eval" in args.mode:
        print("task: text classification...")
        if not os.path.exists("classification_datasets"):
            # check=True turns a failed download into an exception instead of
            # silently continuing without the datasets.
            subprocess.run(['./download_classification_dataset.sh'], check=True)
        model.load()
        model.eval()
        CLASSIFICATION_DATASETS = [
            "twenty_ng_sci",
            "twenty_ng_reli",
            "twenty_ng_sport",
            "twenty_ng_comp",
            "movie_sentiment"
        ]
        classification_statistics = {}
        import classification
        with torch.no_grad():
            for dataset_name in CLASSIFICATION_DATASETS:
                # last argument: whether "test" is also a requested mode
                # (presumably selects the test split -- confirm in classification.py)
                result = classification.eval_classification_dataset(dataset_name, valid, model, "test" in args.mode)
                classification_statistics[dataset_name] = result
                print(f"dataset {dataset_name} classification accuracy: {classification_statistics[dataset_name]}")

        append_save({"text_classification": classification_statistics}, model.local("performance.json"))

    # "craw": plot 2-D projections of the word embeddings along each phrase
    # (PCA below; the t-SNE variant is kept commented out).
    if "craw" in args.mode or "full_eval" in args.mode:
        print("task: word craw plot...")
        model.load()
        model.eval()
        import word_craw

        def rep(adj, noun, repeat=1):
            """Return [[adj] * repeat + [noun], [adj]].

            e.g. rep("red", "apple", 2) == [["red", "red", "apple"], ["red"]]
            """
            return [[adj] * repeat + [noun], [adj]]

        phrases = [
            # [["red","red","apple"], ["red"]]
            rep("regular","habit",8)+rep("routine","habit",8),
            # rep("red","apple"),
            # rep("white","bear"),
            rep("free","gift"),
            rep("poisonous","venom")+rep("venomous","poison"),
            rep("armed","gunman")+rep("armed","gunmen"),
            rep("unexpected","surprise")+rep("surprising","surprise"),
            rep("male","king")+rep("female","queen"),
            rep("german","volkswagen")+rep("italian","ferrari"),
            [["adult","male","cattle"], ["ox"]],
            [["italian","luxury","sports","car","manufacturer"],["ferrari"]], # https://en.wikipedia.org/wiki/Ferrari
            [["long","thin","solid","cylindrical","pasta"],["spaghetti"]],    # https://en.wikipedia.org/wiki/Spaghetti
            [["dark","irish","dry","stout"],["guinness"]],                   # https://en.wikipedia.org/wiki/Guinness
            [["quadrupedal","ruminant","mammal"], ["sheep"]],             # https://en.wikipedia.org/wiki/Sheep
            [["southernmost","continent"],["antarctica"]],               # https://en.wikipedia.org/wiki/Antarctica
            [["edible","red","berry"],["tomato"]],                        # https://en.wikipedia.org/wiki/Tomato
            [["denim","pants"],["jeans"]],
        ]
        # post_compositional_phrase = [
        #     ["meet together", "meet"],
        #     ["speak aloud"],
        #     ["fetch back"]
        # ]
        with torch.no_grad():
            # word_craw.plot_phrases(phrases, model, valid, 50, transform_model="tsne")
            word_craw.plot_phrases(phrases, model, valid, 100, transform_model="pca")

    if "cluster" in args.mode or "full_eval" in args.mode:
        print("task: word cluster plot...")
        model.load()
        model.eval()
        import word_cluster
        # Analogy categories to plot, each once per projection method.
        CLUSTER_CATEGORIES = [
            ['capital-common-countries'],
            ['capital-world'],
            ['city-in-state'],
            ['family'],
            ['gram8-plural'],
            ['capital-common-countries', 'city-in-state', 'family'],
        ]
        with torch.no_grad():
            for transform_model in ("tsne", "pca"):
                # alternative: word_cluster.plot_analogy_cluster(model, valid, 100, [], transform_model=transform_model)
                for categories in CLUSTER_CATEGORIES:
                    word_cluster.plot_analogy_cluster(model, valid, 0, categories, transform_model=transform_model)

    if "examples" in args.mode or "full_eval" in args.mode:
        model.load()
        model.eval()
        with torch.no_grad():
            # nearest neighbours of a fixed list of probe words
            stat = model.print_knn(
                ["car", "motorcycle", "bicycle",
                 "king", "queen", "prince", "princess",
                 "sea","lake","river","pond", "island","mountain","hill","valley","forest","woods",
                 "apple","grape","orange","muscat",
                 "potato","carrot","onion",
                 "garlic","pepper","cumin","oregano","jalapeno",
                 "wine","sake","coke","pepsi","water","gingerale",
                 "meat","steak","hamburger","salad","sushi","grill",
                 "spaghetti","noodle","ramen",
                 "run","flee","escape","jump","dance","wave",
                 "speak","yell","murmur","shout",
                 "discuss","argue","claim","prove",
                 "give","provide","produce",
                 "write","read",
                 "lamborghini", "ferrari", "maserati", "fiat", "renault", "bmw", "mercedes", "audi", "toyota", "honda", "mazda", "nissan", "subaru", "ford", "gm", "chevrolet",
                 "suzuki", "kawasaki", "ducati", "yamaha",],
                5, valid)
            append_save(stat, model.local("knn_example.json"))

            # four parallel word lists: column i forms one analogy quadruple
            stat = model.print_analogy(
                ["man",   "red",    "run",      "japan",  "japan", "japan"],
                ["king",  "apple",  "running",  "tokyo",  "tokyo", "tokyo"],
                ["woman", "yellow", "swim",     "france", "italy", "russia"],
                ["queen", "banana", "swimming", "paris",  "rome",  "moscow"],
                5, valid)
            append_save(stat, model.local("analogy_example.json"))

        # import dataset
        # analogy = dataset.analogy_dataset()
        # import random
        # for _ in range(5):
        #     print_analogy(model,*random.choice(analogy),5,valid)

    if "test" in args.mode:
        print("task: test...")
        print("really, only use this method when you need to evaluate it for the paper")
        print("do not use it for tuning!")
        test  = modelclass.datasetclass(input=hyper["test"], vocab=train.vocab,**hyper)
        model.load()
        model.eval()
        # Same metrics as the "val" task, stored under the "test" key.
        with torch.no_grad():
            statistics = {"test": _evaluate_split("test", test)}
        append_save(statistics, model.local("performance.json"))

    if "to_csv" in args.mode:
        model.load()
        model.eval()
        import to_csv
        # Keyword arguments for each export (topk_include / topk_exclude,
        # presumably frequency-rank bounds); the final {} uses to_csv's defaults.
        CSV_EXPORTS = [
            {"topk_include": 1000},
            {"topk_include": 2000},
            {"topk_exclude": 1000,  "topk_include": 2000},
            {"topk_include": 4000},
            {"topk_exclude": 2000,  "topk_include": 4000},
            {"topk_include": 8000},
            {"topk_exclude": 4000,  "topk_include": 8000},
            {"topk_include": 16000},
            {"topk_exclude": 8000,  "topk_include": 16000},
            {"topk_include": 32000},
            {"topk_exclude": 16000, "topk_include": 32000},
            {"topk_exclude": 24000, "topk_include": 32000},
            {},
        ]
        for export_kwargs in CSV_EXPORTS:
            to_csv.to_csv(model, valid, **export_kwargs)

    if "redundancy" in args.mode:
        model.load()
        model.eval()
        import redundancy
        redundancy.redundancy(model,valid,32000)

    if "vocab" in args.mode:
        print("task: vocab...")
        print("vocabulary:",train.vocab_size)
        # print every entry (previously sys.exit() sat inside the loop, so
        # only the first key was ever printed)
        for key in train.vocab:
            print(key)
except Exception:
    import stacktrace
    print_all_gpu_variables()
    stacktrace.format()
    # Exit non-zero so shell scripts / batch runners can tell the run failed.
    sys.exit(1)
