import argparse
import copy
import os
import json
import tempfile
from typing import Dict, List, Set, Tuple
import torch
from torch import cuda, optim
from torch.optim import lr_scheduler
from torch.utils.tensorboard.writer import SummaryWriter

from importers.ea_ra_kgc import EaRaKgcData
from AlignKGC.data_loader import data_loader
from AlignKGC import losses, models, evaluate, utils
from AlignKGC.utils import get_filter, has_cuda
from AlignKGC.kb import kb


class AlignKgcBaseTrainer(object):
    """Base class for all AlignKGC trainers.

    Builds the KGC train/valid/per-language test folds, the scoring model,
    the optimizer and the LR scheduler from keyword arguments (typically the
    parsed namespace of ``get_base_argparser()`` plus a ``meta`` entry). It
    also provides the shared training loop, evaluation, checkpointing and
    best-dev model selection.

    ``self.step`` and ``self.ranker`` are used here but not defined in this
    class; presumably subclasses provide them.
    """

    def __init__(self, **kwargs) -> None:
        """Configure the trainer from keyword arguments.

        Using kwargs[argname] means compulsory; using
        kwargs.get(argname, defval) means optional with default. As a
        policy, keep defaults here to a minimum and inject them through the
        argparser instead.

        Unless ``eval_only`` is set, a fresh run directory is created under
        ``model_base``. It holds checkpoints, TensorBoard logs and
        ``log.txt``.
        """
        # Configuration snapshot (without ``meta``) that terminate() writes
        # to commit.json.
        self.argv = copy.deepcopy(kwargs)
        del self.argv["meta"]

        super(AlignKgcBaseTrainer, self).__init__()
        self.has_cuda = has_cuda()
        self.torchdev = torch.device("cuda:" + str(kwargs["cudadev"]))

        # Simple configs, typically paths and scalar values.
        self.meta : EaRaKgcData = kwargs["meta"]
        self.ea_percent : int = kwargs["ea_percent"]
        self.ra_percent : int = kwargs["ra_percent"]
        self.dataset_root : str = \
            self.meta.combined_ea_ra_path(self.ea_percent, self.ra_percent)
        self.resume_from_save: bool = "resume_from_save" in kwargs and \
            kwargs["resume_from_save"] != 0
        self.max_epochs : int = kwargs["max_epochs"]
        self.eval_only : bool = kwargs["eval_only"]

        # One model directory per trainer class. Each training run gets its
        # own temporary subdirectory below it.
        self.model_base : str = \
            os.path.join(self.dataset_root, "model_" + self.__class__.__name__)
        if not os.path.isdir(self.model_base):
            print("creating", self.model_base)
            os.mkdir(self.model_base, mode=0o755)

        if self.eval_only:
            self.save_directory : str = None
            self.tflogs_dir : str = None
            self.tbwriter : SummaryWriter = None
        else:
            self.save_directory = tempfile.mkdtemp(dir=self.model_base)
            os.chmod(self.save_directory, 0o755)
            self.tflogs_dir = self.save_directory
            self.tbwriter = SummaryWriter(log_dir=self.save_directory)
            utils.duplicate_stdout(os.path.join(self.save_directory, "log.txt"))

        # Training / evaluation hyperparameters.
        self.verbose : int = kwargs["verbose"]
        self.batch_size : int = kwargs["batch_size"]
        self.negative_sample_count : int = kwargs["negative_sample_count"]
        # A falsy value (0/None) makes init_folds derive a default.
        self.eval_batch_size : int = kwargs["eval_batch_size"]
        self.eval_every_x_mini_batches : int = \
            kwargs["eval_every_x_mini_batches"]
        self.gradient_clip : float = kwargs["gradient_clip"]
        self.regloss_coeff : float = kwargs["regloss_coeff"]
        self.ealoss_coeff : float = kwargs["ealoss_coeff"]
        self.raloss_coeff : float = kwargs["raloss_coeff"]
        if kwargs["hooks"]:
            self.hooks = json.loads(kwargs["hooks"])
        else:
            self.hooks = list()
        self.loss = getattr(losses, kwargs["loss"])()

        # Best validation score seen so far, plus the matching test scores.
        self.best_mrr_on_valid : Dict = None

        # KGC train/dev/test folds and dependent members, filled in by
        # init_folds(). eval_batch_size is deliberately NOT reset here.
        # Resetting it used to discard the user-supplied value.
        self.dltrain : data_loader = None
        self.dlvalid : data_loader = None
        self.dltestmap : Dict[str, Tuple[str, data_loader]] = None
        self.scoring_function = None
        self.optim : optim = None
        self.scheduler : lr_scheduler.ReduceLROnPlateau = None
        self.filtmap : Dict[str, List] = None

        self.init_folds(**kwargs)

        # These may be best to init in subclasses.
        # self.ent_aligns : dict[int,int] = kwargs["ent_aligns"]
        # (maps g_ent_id -> g_ent_id)
        # self.rel_aligns = kwargs["rel_aligns"]
        # self.rel_implies = kwargs["rel_implies"]

    def init_folds(self, **kwargs):
        """Load the KGC folds and build the model, optimizer and LR scheduler.

        Reads ``train.txt`` and ``valid.txt`` from ``self.dataset_root``.
        For every language in ``self.meta.langs`` it also reads the test files
        ``test_<lang>.txt``, ``<lang>_f_test.txt`` and ``<lang>_o_test.txt``.
        When ``kwargs["filter"]`` is truthy, it additionally loads
        ``filters_<lang>.txt`` per language for filtered evaluation.
        """
        # Passed as first_zero to every data_loader. It is True whenever the
        # oov_entity option is present at all, even when it is 0.
        first_zero_val = kwargs["oov_entity"] is not None

        ktrain = kb(os.path.join(self.dataset_root, 'train.txt'))
        if kwargs["oov_entity"] and "<OOV>" not in ktrain.entity_map:
            # Append an <OOV> entity id right after the training entities.
            ktrain.entity_map["<OOV>"] = len(ktrain.entity_map)
            ktrain.nonoov_entity_count = ktrain.entity_map["<OOV>"] + 1

        self.dltrain : data_loader = \
            data_loader(ktrain, self.has_cuda, loss=self.loss,
                        flag_add_reverse=kwargs["inverse"],
                        first_zero=first_zero_val)
        # Valid/test kbs reuse the training entity and relation maps.
        kvalid = kb(os.path.join(self.dataset_root, 'valid.txt'),
                    em=ktrain.entity_map, rm=ktrain.relation_map,
                    add_unknowns=not kwargs["oov_entity"],
                    nonoov_entity_count=ktrain.nonoov_entity_count)
        self.dlvalid : data_loader = \
            data_loader(kvalid, self.has_cuda, loss=self.loss,
                        first_zero=first_zero_val)

        # key = test file name; val = (lang, kb)
        ktestmap : Dict[str, Tuple[str, kb]] = dict()
        for lang in self.meta.langs:
            for langname in ["test_" + lang + ".txt",
                             lang + "_f_test.txt",
                             lang + "_o_test.txt"]:
                langkpath = os.path.join(self.dataset_root, langname)
                print("loading", langname)
                langkb = kb(langkpath, em=ktrain.entity_map,
                            rm=ktrain.relation_map,
                            add_unknowns=not kwargs["oov_entity"],
                            nonoov_entity_count=ktrain.nonoov_entity_count)
                ktestmap[langname] = (lang, langkb)
        # key = test file name; val = (lang, data_loader over that kb)
        self.dltestmap : Dict[str, Tuple[str, data_loader]] = dict()
        for langname, (lang, langkb) in ktestmap.items():
            dlkb = data_loader(langkb, self.has_cuda, loss=self.loss,
                               first_zero=first_zero_val)
            self.dltestmap[langname] = (lang, dlkb)

        # key = lang; val = filter loaded from filters_<lang>.txt
        self.filtmap = dict()
        if kwargs["filter"]:
            for lang in self.meta.langs:
                filtname = "filters_" + lang + ".txt"
                filtpath = os.path.join(self.dataset_root, filtname)
                print("loading", filtname)
                self.filtmap[lang] = get_filter(
                    filtpath, em=ktrain.entity_map, rm=ktrain.relation_map,
                    add_unknowns=not kwargs["oov_entity"],
                    nonoov_entity_count=ktrain.nonoov_entity_count)

        model_arguments = json.loads(kwargs["model_arguments"])
        model_arguments['entity_count'] = len(ktrain.entity_map)
        if kwargs["regularizer"]:
            print("Using reg ", kwargs["regularizer"])
            model_arguments['reg'] = kwargs["regularizer"]
        if kwargs["inverse"]:
            # Inverse relations double the relation vocabulary.
            model_arguments['relation_count'] = len(ktrain.relation_map) * 2
            model_arguments['flag_add_reverse'] = kwargs["inverse"]
            model_arguments['flag_avg_scores'] = kwargs["avg_scores"]
        else:
            model_arguments['relation_count'] = len(ktrain.relation_map)
        model_arguments['batch_norm'] = kwargs["batch_norm"]
        print("model_arguments", model_arguments)

        self.scoring_function = getattr(models, kwargs["model"])(**model_arguments)
        if self.has_cuda:
            self.scoring_function = self.scoring_function.cuda()
        self.regularizer = self.scoring_function.regularizer
        # Models without a flag_add_reverse attribute default to 0.
        self.flag_add_reverse = getattr(self.scoring_function,
                                        "flag_add_reverse", 0)

        self.optim = getattr(torch.optim, kwargs["optimizer"])(
            self.scoring_function.parameters(), lr=kwargs["learning_rate"])
        # Driven by validation MRR in start(); 'max' because higher is better.
        self.scheduler = lr_scheduler.ReduceLROnPlateau(
            self.optim, 'max', patience=2, verbose=True)
        if not self.eval_batch_size:
            self.eval_batch_size = max(
                50,
                self.batch_size * 2 * self.negative_sample_count
                // len(ktrain.entity_map))

    def find_best_dev_kgc_model(self) -> str :
        """Search ``model_base`` for the run with the best dev (valid) MRR.

        Every subdirectory whose ``commit.json`` records a ``valid_score``
        is a candidate. Runs are compared on ``valid_score["m"]["mrr"]``.
        Subdirectories are scanned in sorted order, so ties resolve
        deterministically to the first such run.

        Returns:
            Path of the best run directory, or None if there is no
            candidate.
        """
        best_dev_mrr = float("-inf")
        argmax_dev_path = None
        argmax_dev_perf = None
        for runname in sorted(os.listdir(self.model_base)):
            runpath = os.path.join(self.model_base, runname)
            perfpath = os.path.join(runpath, "commit.json")
            if not os.path.isfile(perfpath):
                continue
            print("\tscan dev perf", runname)
            with open(perfpath, "r", encoding="utf-8") as perf_file:
                perfs = json.load(perf_file)
            # Runs that terminated without any validation have no scores.
            if not perfs.get("valid_score"):
                continue
            dev_mrr = perfs["valid_score"]["m"]["mrr"]
            if dev_mrr > best_dev_mrr:
                best_dev_mrr = dev_mrr
                argmax_dev_perf = perfs
                argmax_dev_path = runpath
        if argmax_dev_perf is not None:
            print("best dev perf", argmax_dev_path,
                  "macro_e2_mrr", argmax_dev_perf["valid_score"]["e2"]["mrr"],
                  "macro_m_mrr", argmax_dev_perf["valid_score"]["m"]["mrr"])
        return argmax_dev_path

    def do_eval_kgc_only(self):
        """Load the best-dev model under ``model_base`` and evaluate it on
        every test fold.

        Returns:
            dict mapping test file name to the score dict returned by
            ``evaluate.evaluate``.

        Raises:
            FileNotFoundError: if no validated run exists under
                ``model_base``.
        """
        assert self.eval_only, "Flag eval_only not set"
        best_model_dir = self.find_best_dev_kgc_model()
        if best_model_dir is None:
            raise FileNotFoundError(
                "no run with a validated commit.json under " + self.model_base)
        best_model_path = os.path.join(best_model_dir, "best_valid_model.pt")
        # Map to CPU when CUDA is unavailable so GPU checkpoints still load.
        best_model = torch.load(best_model_path,
                                map_location=None if self.has_cuda else "cpu")
        self.scoring_function.load_state_dict(best_model["model_weights"])
        # Evaluate in inference mode, as start() does before its evaluations.
        self.scoring_function.eval()
        test_scores_per_lang = dict()
        sum_e2_mrr = 0.
        for langname, (langx, langkb) in self.dltestmap.items():
            test_score = evaluate.evaluate(langname, self.ranker,
                langkb.kb, self.eval_batch_size,
                verbose=self.verbose, hooks=self.hooks,
                filt=self.filtmap[langx])
            test_scores_per_lang[langname] = test_score
            sum_e2_mrr += test_score["e2"]["mrr"]
        if self.dltestmap:
            print("macro e2 mrr", sum_e2_mrr / len(self.dltestmap))
        # TODO collect MR or H@k value for each query separately and store
        # TODO in a collection for significance tests
        return test_scores_per_lang

    def start(self, steps: int=None, batch_count: List[int]=None,
              mb_start: int=None):
        """Entry point for training.

        Args:
            steps: Mini-batch index to stop at (exclusive). Defaults to
                ``max_epochs`` epochs over the training facts.
            batch_count: ``[n, k]``. Validation and test evaluation (and a
                checkpoint attempt) happen after every k groups of n
                mini-batches. Defaults to
                ``[eval_every_x_mini_batches // 20, 20]``.
            mb_start: First mini-batch index. Defaults to 0, or to the value
                returned by load_state when resuming.
        """
        assert not self.eval_only, "Cannot train in eval_only mode"
        if steps is None:
            steps = int(self.max_epochs * self.dltrain.kb.facts.shape[0] /
                        self.batch_size)
        print("steps=%d, eval_batch_size=%d" % (steps, self.eval_batch_size))
        if batch_count is None:
            batch_count = [self.eval_every_x_mini_batches // 20, 20]
        print("batch_count", batch_count)
        if mb_start is None:
            if self.resume_from_save:
                # NOTE(review): resume_from_save is a bool (see __init__), yet
                # load_state hands it to torch.load as a file. Resuming looks
                # broken as written; confirm the intended checkpoint path.
                mb_start = self.load_state(self.resume_from_save)
            else:
                mb_start = 0
        print("mb_start", mb_start)
        # Named to avoid shadowing the imported ``losses`` module.
        window_losses = []
        count = 0
        print("Starting training")
        for batchx in range(mb_start, steps):
            various_losses = self.step(batchx)
            window_losses.append(various_losses["total"])
            self.tbwriter.add_scalars("loss", various_losses, batchx)
            self.tbwriter.flush()
            if len(window_losses) >= batch_count[0]:
                count += 1
                window_losses = []
                if count == batch_count[1]:
                    self.scoring_function.eval()
                    valid_score = evaluate.evaluate("valid", self.ranker,
                        self.dlvalid.kb, self.eval_batch_size,
                        verbose=self.verbose, hooks=self.hooks)
                    self.tbwriter.add_scalar("valid_e2_mrr",
                        valid_score["e2"]["mrr"], batchx)
                    self.tbwriter.flush()
                    test_scores_per_lang = dict()
                    for langname, (langx, langkb) in self.dltestmap.items():
                        test_scores_per_lang[langname] = evaluate.evaluate(
                            langname, self.ranker, langkb.kb,
                            self.eval_batch_size, verbose=self.verbose,
                            hooks=self.hooks, filt=self.filtmap[langx])
                    self.scoring_function.train()
                    # LR scheduler driven by validation MRR.
                    self.scheduler.step(valid_score['m']['mrr'])
                    count = 0
                    self.save_state(batchx, valid_score, test_scores_per_lang)
        self.terminate()

    def terminate(self):
        """Write the run configuration, plus the best validation/test scores
        if any validation ran, to ``commit.json`` in the run directory."""
        commit_dict = copy.deepcopy(self.argv)
        if self.best_mrr_on_valid is not None:
            commit_dict.update(self.best_mrr_on_valid)
        else:
            print("no validation was run; commit.json holds config only")
        commit_path = os.path.join(self.save_directory, "commit.json")
        with open(commit_path, "w") as commit_file:
            json.dump(commit_dict, commit_file)
        print("wrote", commit_path)

    def save_state(self, mini_batches, valid_score, test_scores):
        """Record scores and checkpoint the model if validation MRR is the
        best (or tied for best) so far.

        The checkpoint goes to ``best_valid_model.pt`` in the run directory.
        It holds model and optimizer state, the entity/relation maps and
        the scores.
        """
        state = dict()
        state['mini_batches'] = mini_batches
        state['epoch'] = mini_batches * self.batch_size / self.dltrain.kb.facts.shape[0]
        state['model_name'] = type(self.scoring_function).__name__
        state['model_weights'] = self.scoring_function.state_dict()
        state['optimizer_state'] = self.optim.state_dict()
        state['optimizer_name'] = type(self.optim).__name__
        state['entity_map'] = self.dltrain.kb.entity_map
        state['reverse_entity_map'] = self.dltrain.kb.reverse_entity_map
        state['relation_map'] = self.dltrain.kb.relation_map
        state['reverse_relation_map'] = self.dltrain.kb.reverse_relation_map
        state['nonoov_entity_count'] = self.dltrain.kb.nonoov_entity_count
        state["valid_score"] = valid_score
        state['test_scores'] = test_scores

        if not self.best_mrr_on_valid or \
                state['valid_score']['m']['mrr'] >= \
                    self.best_mrr_on_valid["valid_score"]["m"]["mrr"]:
            best_name = os.path.join(self.save_directory, "best_valid_model.pt")
            self.best_mrr_on_valid = {
                "valid_score" : copy.deepcopy(valid_score),
                "test_scores" : copy.deepcopy(test_scores)
            }
            if os.path.exists(best_name):
                os.remove(best_name)
            torch.save(state, best_name)
            print("saved state to", best_name)

    def load_state(self, state_file):
        """Restore model and optimizer state from a checkpoint written by
        save_state.

        Returns:
            The saved mini-batch index.
        """
        state = torch.load(state_file,
                           map_location=None if self.has_cuda else "cpu")
        if state['model_name'] != type(self.scoring_function).__name__:
            utils.colored_print('yellow', 'model name in saved file %s is different from the name of current model %s' %
                                (state['model_name'], type(self.scoring_function).__name__))
        self.scoring_function.load_state_dict(state['model_weights'])
        if state['optimizer_name'] != type(self.optim).__name__:
            utils.colored_print('yellow', ('optimizer name in saved file %s is different from the name of current '+
                                          'optimizer %s') %
                                (state['optimizer_name'], type(self.optim).__name__))
        self.optim.load_state_dict(state['optimizer_state'])
        return state['mini_batches']

    def get_lrid_to_soset_map(self):
        """Prepare map from lrid to SO-pairs(rel). INCOMPLETE.

        For every language, reads ``kgs/<lang>-train.tsv`` under
        ``self.meta.dir``. It groups (subject, object) pairs by the relation
        id from ``meta.lang_rel_do_prefix(lid, rid)``. Subject and object are
        mapped through the training entity map.
        """
        lrid_to_soset : Dict[int, List[Tuple[int,int]]] = dict()
        for lang in self.meta.langs:
            lid = self.meta.lang_to_lid(lang)
            lang_kgc_train_path = os.path.join(self.meta.dir,
                                               "kgs/" + lang + "-train.tsv")
            for (sid, rid, oid) in self.meta.read_tsv_int_path(lang_kgc_train_path):
                lrid = self.meta.lang_rel_do_prefix(lid, rid)
                assert type(lrid) is int
                gsid = self.dltrain.kb.entity_map[str(sid)]
                assert type(gsid) is int
                goid = self.dltrain.kb.entity_map[str(oid)]
                assert type(goid) is int
                lrid_to_soset.setdefault(lrid, []).append((gsid, goid))
        print("lrid_to_soset", len(lrid_to_soset))
        return lrid_to_soset

    def ra_loss_hard(self, meta: "EaRaKgcData",
                     prid_to_sos: Dict[int, List[Tuple[int,int]]],
                     dojaccard: bool):
        """Replacement for get_rel_align_imply and rel_alignment_loss.
        INCOMPLETE.

        Reads relation ids from ``meta.ra_path(self.ra_percent)``. For each
        id and each language pair (lang1, lang2) with lang1 < lang2, it counts
        how often both prefixed ids occur in ``prid_to_sos``. No loss is
        computed yet (always returns 0), and ``dojaccard`` is unused.
        """
        ra_loss_ans : float = 0
        rids: List[int] = list()
        print("opening", meta.ra_path(self.ra_percent))
        for [rid] in meta.read_tsv_int_path(meta.ra_path(self.ra_percent)):
            rids.append(rid)  # no prefix
        num_tried, num_found = 0, 0
        for rid in rids:
            for lang1 in meta.langs:
                ll1 = meta.lang_to_lid(lang1)
                for lang2 in meta.langs:
                    if lang1 >= lang2:
                        continue
                    ll2 = meta.lang_to_lid(lang2)
                    ll1rid = meta.lang_rel_do_prefix(ll1, rid)
                    ll2rid = meta.lang_rel_do_prefix(ll2, rid)
                    num_tried += 1
                    if ll1rid in prid_to_sos and ll2rid in prid_to_sos:
                        num_found += 1
        print("rids", len(rids), "found", num_found, "of", num_tried)
        return ra_loss_ans


def get_base_argparser():
    """Prepare and return an argparser with command line arguments shared
    across all AlignKGC variations.

    Integer on/off switches (``--oov_entity``, ``--filter``, ``--inverse``,
    ``--avg_scores``, ...) are parsed with ``type=int``. This makes a
    command-line ``0`` falsy, as the trainer's truthiness checks expect.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dbp5l", required=True,
                        help="/path/to/DBP-5L/")
    parser.add_argument("--eval_only", default=False, action='store_true')
    parser.add_argument('--loss', default="crossentropy_loss_AllNeg_subsample",
                        help="loss function name as in losses.py")
    parser.add_argument('--model', default="complex", required=False,
                        help="model name as in models.py")
    parser.add_argument('--model_arguments', default=
                        '{"embedding_dim":180, "batch_norm":1, "unit_reg":0}',
                        help="model arguments as in __init__ of "
                        "model (Excluding entity and relation count) "
                        "This is a json string", required=False)
    parser.add_argument('--optimizer', default='Adagrad',
                        help="optimizer class name in torch.optim")
    parser.add_argument('--learning_rate', type=float, default=0.8)
    parser.add_argument('--regularizer', default=2.0, type=float,
                        choices=[2.0, 3.0], help="regularizer norm")
    parser.add_argument('--regloss_coeff', type=float, default=0.02)
    parser.add_argument("--ealoss_coeff", type=float, default=50.0)
    parser.add_argument("--raloss_coeff", type=float, default=5.)
    parser.add_argument("--ea_percent", type=int, default=20)
    parser.add_argument("--ra_percent", type=int, default=20)
    parser.add_argument('--gradient_clip', type=float)
    parser.add_argument('--max_epochs', type=int, default=70)
    parser.add_argument('--batch_size', type=int, default=500)
    parser.add_argument('--eval_every_x_mini_batches', type=int, default=1000)
    parser.add_argument('--eval_batch_size', type=int, default=0)
    parser.add_argument('--negative_sample_count', type=int, default=2000)
    parser.add_argument('--resume_from_save', type=int, default=0)
    parser.add_argument('--oov_entity', type=int, default=1,
                        help="1 to add an <OOV> entity to the entity map")
    parser.add_argument('-q', '--verbose', type=int, default=0)
    parser.add_argument('-z', '--debug', type=int, default=0)
    parser.add_argument('-k', '--hooks', default="[]",
                        help="JSON list of evaluation hooks")
    parser.add_argument('-bn', '--batch_norm', type=int, default=0)
    parser.add_argument('-msg', '--message', required=False)
    parser.add_argument('-f', '--filter', type=int, default=1,
                        help="1 to load filters_<lang>.txt for evaluation")
    parser.add_argument('-inv', '--inverse', type=int, default=0,
                        help="1 to add inverse relations")
    # type=int: without it "-avg 0" arrives as the truthy string "0".
    parser.add_argument('-avg', '--avg_scores', type=int, default=0,
                        help="flag_avg_scores for the model (used with --inverse)")
    parser.add_argument("--seed", type=int, default=41,
                        help="seed for numpy and torch random numbers")
    parser.add_argument("--multiseed", type=int, default=1,
                        help="If >1, use seed to generate multiseed seeds "
                        "and run multiple times saving to different model dirs")
    parser.add_argument("--cudadev", type=int, default=0,
                        help="Ordinal number of cuda device 0/1/2 etc.")
    return parser
