import logging

import math
import re
import os
import numpy as np
from collections import deque


import torch
from torch import optim
import torch.nn.functional as F

from data_process.icews.loader import DataLoader
from model.dynamic_fewshot import DyFewShot

from tensorboardX import SummaryWriter

from code_utils.config import PATHS
import json
from datetime import datetime




class Trainer:
    """Train and evaluate a ``DyFewShot`` model on batches from ``DataLoader``.

    Responsibilities: optimizer/scheduler setup, checkpoint naming, saving and
    loading model weights, the margin-based training loop, and rank-based
    evaluation (Hits@{1,3,5,10} and MRR).
    """

    # Matches the "_epoch_<N>" part of a checkpoint name so it can be swapped
    # for a tag such as "model_bestMrr" or "results".
    _EPOCH_TAG_RE = re.compile(r"(?is)_epoch_[\d]+")

    def __init__(self, args):
        self.data_loader = DataLoader(args)
        device = torch.device("cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")
        ent_emb, rel_emb = self.data_loader.load_pretrained_emb(name=args.emb_name,
                                                                graph_mode=args.graph_mode,
                                                                dataset_name=args.dataset)
        args.ent_num = len(self.data_loader.symb2id['ent2id'])
        args.rel_num = len(self.data_loader.symb2id['rel2id'])
        self.args = args
        self.model = DyFewShot(args, ent_emb=ent_emb, rel_emb=rel_emb, device=device)

        # Materialize as a list: a bare ``filter`` object would be exhausted
        # as soon as the optimizer consumes it.
        self.parameters = [p for p in self.model.parameters() if p.requires_grad]

        # When resuming from a checkpoint, continue with a 100x smaller LR.
        lr = args.lr * 0.01 if self.args.checkpoint > 0 else args.lr

        self.optim = optim.Adam(self.parameters, lr=lr,
                                weight_decay=args.weight_decay)
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optim, milestones=[8000], gamma=0.1)

        self.log_every = args.log
        self.margin = args.margin
        self.eval_every = args.eval
        self.max_epochs = args.epochs
        self.epoch_nums = 0

        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        # [:-4] strips the ".pth" suffix from the checkpoint-style name.
        self.writer = SummaryWriter(logdir='logs/' + current_time + '_' + self.create_model_name(0)[:-4])

    def create_model_name(self, epoch_nums=None):
        """Return the checkpoint file name encoding hyper-parameters and epoch.

        Args:
            epoch_nums: epoch to embed in the name; defaults to the current
                ``self.epoch_nums``.

        Returns:
            A string of the form ``lr_..._epoch_<N>_r<run>.pth``.
        """
        if epoch_nums is None:
            epoch_nums = self.epoch_nums
        a = self.args
        return ('lr_%.4f_wd_%.2f_dp_%.2f_tune_%s_shot_%d_b_%d_et_%s_ed_%d_meta_%s_seq_%s_reg_%s%.2f_margin_%d_epoch_%d_r%d.pth'
                % (a.lr, a.weight_decay, a.enc_dropout, a.finetune, a.shots,
                   a.batch_size, a.emb_name, a.emb_dim, a.meta_type,
                   a.seq_encoder, a.norm, a.alpha, a.margin,
                   epoch_nums, a.run_num))

    def _tagged_model_name(self, tag):
        """Return the current model name with ``_epoch_<N>`` replaced by ``_<tag>``."""
        return self._EPOCH_TAG_RE.sub("_%s" % tag, self.create_model_name())

    def save(self, model_name=None, path=None):
        """Save the model ``state_dict``.

        Args:
            model_name: optional tag (e.g. ``'model_bestMrr'``) that replaces
                the ``_epoch_<N>`` part of the generated name; when ``None``
                the current epoch is used.
            path: target directory prefix (concatenated, so it should end
                with a separator); defaults to the data loader's history
                path, which is created if missing.
        """
        if not path:
            path = self.data_loader.hist_path + PATHS.SAVE_PATH
            os.makedirs(path, exist_ok=True)
            print(path)
        if model_name is None:
            model_name = self.create_model_name()
        else:
            model_name = self._tagged_model_name(model_name)

        torch.save(self.model.state_dict(), path + model_name)

    def load(self, epoch_nums=None, model_name=None, path=None):
        """Load model weights saved by :meth:`save`.

        Args:
            epoch_nums: epoch of the checkpoint to load (ignored when
                ``model_name`` is given).
            model_name: optional tag used instead of ``_epoch_<N>``.
            path: directory prefix; defaults to the data loader's history path.
        """
        if not path:
            path = self.data_loader.hist_path + PATHS.SAVE_PATH
            print(path)
        if model_name is None:
            model_name = self.create_model_name(epoch_nums)
        else:
            model_name = self._tagged_model_name(model_name)

        print(path, model_name)
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict copies the values into the model's own parameters.
        state_dict = torch.load(path + model_name, map_location='cpu')
        self.model.load_state_dict(state_dict)

    def save_results(self, results):
        """Write the list of validation results to a JSON file.

        When resuming from a checkpoint, results already stored in the file
        are read first and the new ones are appended to them.
        """
        path = self.data_loader.hist_path + PATHS.SAVE_PATH
        os.makedirs(path, exist_ok=True)
        results_file = path + self._tagged_model_name('results')[:-4] + '.json'
        print(results_file)
        if self.args.checkpoint > 0 and os.path.exists(results_file):
            # Read (not truncate!) the previous results, then extend them.
            with open(results_file, 'r') as fp:
                prev = json.load(fp)
            prev.extend(results)
            results = prev

        with open(results_file, 'w') as fp:
            json.dump(results, fp)

    def loss_regularizer(self):
        """Return the sum of L1 (``args.norm == 'l1'``) or L2 norms of all model parameters."""
        p = 1 if self.args.norm == 'l1' else 2
        return sum(W.norm(p) for W in self.model.parameters())

    def train(self):
        """Run the margin-ranking training loop with periodic evaluation.

        Evaluates on the validation split every ``eval_every`` steps, saves a
        checkpoint each time, and additionally saves ``model_bestMrr`` when
        the validation MRR improves. Stops after ``max_epochs`` steps or when
        the data loader is exhausted, then writes the collected results.

        Raises:
            ValueError: if ``args.meta_type`` is neither ``'protonet'`` nor
                contains ``'match'``.
        """
        logging.info('START TRAINING...')

        meta_type = self.args.meta_type
        if meta_type != 'protonet' and 'match' not in meta_type:
            raise ValueError('Unsupported meta_type: %r' % (meta_type,))

        best_mrr = 0.0

        losses = deque([], self.eval_every)
        margins = deque([], self.log_every)
        results = []
        if self.args.checkpoint > 0:
            self.load(self.args.checkpoint)
            self.epoch_nums += self.args.checkpoint

        for sample in self.data_loader.load('train'):
            self.optim.zero_grad()
            dists = self.model.loss(sample)

            # First half: query scores, second half: false (negative) scores.
            n = dists.shape[0] // 2
            query_scores, false_scores = dists[:n], dists[n:]
            margin_ = query_scores - false_scores
            if meta_type == 'protonet':
                # Penalize unless query scores are at least `margin` BELOW false scores.
                loss = F.relu(self.margin + margin_).mean()
            else:
                # Penalize unless query scores are at least `margin` ABOVE false scores.
                loss = F.relu(self.margin - margin_).mean()

            loss = loss + self.args.alpha * self.loss_regularizer()

            margins.append(margin_.mean().item())
            losses.append(loss.item())

            loss.backward()
            self.optim.step()

            if self.epoch_nums != 0 and self.epoch_nums % 5 == 0:
                print('epoc %d: loss %f' % (self.epoch_nums, np.mean(losses)))

            if self.epoch_nums != 0 and self.epoch_nums % self.eval_every == 0:

                print('epoc %d: loss %f' % (self.epoch_nums, np.mean(losses)))
                hits10, hits5, hits3, hits1, mrr, _ = self.eval('val')
                print("hist1: %f, hit3: %f, hit10: %f, map: %f"%(hits1, hits3, hits10, mrr))
                self.writer.add_scalar('HITS10', hits10, self.epoch_nums)
                self.writer.add_scalar('HITS5', hits5, self.epoch_nums)
                self.writer.add_scalar('HITS1', hits1, self.epoch_nums)
                self.writer.add_scalar('MAP', mrr, self.epoch_nums)
                results.append([hits10, hits5, hits1, mrr])
                self.save()

                if mrr > best_mrr:
                    self.save(model_name='model_bestMrr')
                    # Track the best MRR (previously hits10 was stored by mistake).
                    best_mrr = mrr

            if self.epoch_nums % self.log_every == 0:
                self.writer.add_scalar('Avg_batch_loss', np.mean(losses), self.epoch_nums)

            self.epoch_nums += 1
            self.scheduler.step()
            if self.epoch_nums > self.max_epochs:
                self.save()
                break
        self.save_results(results)

    def eval(self, mode):
        """Evaluate ranking quality on the ``mode`` split.

        For each sample the candidate at index 0 is ranked by ascending score
        (presumably index 0 is the positive candidate — confirm against the
        data loader). Assumes scores are 1-D or shaped ``(n, 1)``.

        Returns:
            ``(hits10, hits5, hits3, hits1, mrr, ranks)`` where the first five
            are means over samples and ``ranks`` is the list of 1-based ranks.
        """
        self.model.eval()
        ranks = []
        try:
            # NOTE(review): no torch.no_grad() here on purpose; model.loss may
            # rely on autograd (e.g. when finetuning) — confirm before adding.
            for sample in self.data_loader.load(mode):
                scores = self.model.loss(sample)
                # .cpu() so this also works when the model lives on CUDA.
                scores = scores.detach().cpu().numpy()
                order = np.argsort(scores, axis=0).reshape(-1)
                ranks.append(int(np.flatnonzero(order == 0)[0]) + 1)
        finally:
            # Always restore training mode, even if evaluation fails.
            self.model.train()

        rank_arr = np.asarray(ranks)
        hits10 = np.mean(rank_arr <= 10)
        hits5 = np.mean(rank_arr <= 5)
        hits3 = np.mean(rank_arr <= 3)
        hits1 = np.mean(rank_arr <= 1)
        mrr = np.mean(1.0 / rank_arr)

        logging.critical('HITS10: {:.3f}'.format(hits10))
        logging.critical('HITS5: {:.3f}'.format(hits5))
        logging.critical('HITS3: {:.3f}'.format(hits3))
        logging.critical('HITS1: {:.3f}'.format(hits1))
        logging.critical('MAP: {:.3f}'.format(mrr))

        return hits10, hits5, hits3, hits1, mrr, ranks

    def test_(self, test_setting, path=None):
        """Load a saved model and evaluate it on the val and test splits.

        Args:
            test_setting: dict with ``'epoch_nums'`` and ``'prefix'``; when
                ``prefix`` is not ``None`` it is used as the model-name tag.
            path: optional checkpoint directory prefix.

        Returns:
            ``(test_results, val_results)`` as returned by :meth:`eval`.
        """
        epoch_num = test_setting['epoch_nums']
        prefix = test_setting['prefix']
        if prefix is None:
            self.load(epoch_num, path=path)
        else:
            self.load(model_name=prefix, path=path)
        logging.info('Pre-trained model loaded')

        val_res = self.eval(mode='val')
        test_res = self.eval(mode='test')
        return test_res, val_res




