# -*- coding: utf-8 -*-
# !/usr/bin/python

import sys
import time
import torch
import torch.nn as nn
sys.path.append("..")
import random
from baselines.basic_trainer import BasicTrainer


class HATTrainer(BasicTrainer):
    """Sequential-task trainer using HAT (Hard Attention to the Task).

    Each task ``t`` owns one learnable row ``self.model.mask_module_list[t]``
    of length ``hidden_size``. After task ``t`` is trained, its row is turned
    into ``sigmoid(smax * row)`` and folded (element-wise max) into a
    cumulative mask ``self.mask_pre``. On later tasks, gradients of the
    ``decoder_lstm`` weights are multiplied by ``1 - mask_pre``, which damps
    updates to units claimed by earlier tasks.
    """

    def __init__(self, args, model_save_path):
        super(HATTrainer, self).__init__(args, model_save_path)

        # NOTE(review): ``lamb`` (the HAT mask-regulariser weight) is kept for
        # compatibility but is not read anywhere in this class.
        self.lamb = 0.75
        # Upper bound of the annealed sigmoid scale ``s``.
        self.smax = 400

        # Cumulative (max over finished tasks) mask, 1-D of length hidden_size.
        self.mask_pre = None
        # Parameter name -> gradient multiplier ``1 - mask_pre`` (broadcast).
        self.mask_back = None

        # One mask row per task.
        # NOTE(review): this parameter is created on CPU and only after
        # BasicTrainer.__init__ has run. If the optimizer or the device
        # placement is set up there, it may be neither optimised nor on the
        # model's device; confirm against BasicTrainer.
        self.model.mask_module_list = nn.Parameter(torch.rand(args.task_num, args.hidden_size))

    def train(self):
        """Train on every task in order and evaluate each one on its test split.

        Returns:
            ``(avg_acc_list, whole_acc_list, bwt_list, fwt_list)`` as maintained
            by ``BasicTrainer.eval_task_stream``.
        """
        for i in range(self.args.task_num):
            self._hat_train_task(i)

            # Restore the best dev checkpoint *before* deriving the task mask.
            # That way the mask protecting task i is computed from the weights
            # that are actually kept and tested, not from the last-epoch weights.
            self.load(self.model)
            self._hat_update_masks(i)

            start_time = time.time()
            test_acc, _, (_, _, _), _ = self.epoch_acc(
                self.task_controller.task_list[i]["test"])
            print('Evaluation: \tTime: %.4f\tTest acc: %.4f\n' % (time.time() - start_time, test_acc))

            self.first_acc_list[i] = test_acc
            self.eval_task_stream(i, test_acc)

        return self.avg_acc_list, self.whole_acc_list, self.bwt_list, self.fwt_list

    def _hat_train_task(self, i):
        """Run the epoch loop for task ``i``, saving the best model by dev accuracy.

        Dev evaluation starts at epoch ``args.epoch_eval``. Training stops
        early once ``patience`` exceeds ``args.max_patience``.
        """
        best_result = {"acc": 0.0, "epoch": 0}
        examples = self.task_controller.task_list[i]["train"]
        n_epochs = self.args.epoch
        epoch_eval = self.args.epoch_eval
        patience = 0

        for epoch in range(n_epochs):
            self._hat_train_epoch(i, epoch, examples)

            if epoch < epoch_eval:
                continue

            start_time = time.time()
            dev_acc, _, (_, _, _), _ = self.epoch_acc(
                self.task_controller.task_list[i]["dev"])
            print('Evaluation: \tEpoch: %d\tTime: %.4f\tDev acc: %.4f\n' % (epoch, time.time() - start_time, dev_acc))

            if dev_acc >= best_result['acc']:
                best_result['acc'], best_result['epoch'] = dev_acc, epoch
                self.save(self.model, name="model.bin")
                patience = 0
            else:
                patience += 1

            if patience > self.args.max_patience:
                break

    def _hat_train_epoch(self, i, epoch, examples):
        """Train one epoch of task ``i`` over ``examples`` (shuffled in place)."""
        self.model.train()
        epoch_begin = time.time()
        random.shuffle(examples)

        n_examples = len(examples)
        batch_size = self.args.batch_size
        # Ceiling division, so the last (possibly partial) batch is counted.
        # The old floor division divided by zero when a task had fewer examples
        # than one batch, and let ``s`` overshoot ``smax`` on the last batch.
        num_batches = max(1, (n_examples + batch_size - 1) // batch_size)
        report_loss, example_num = 0.0, 0

        self.optimizer.zero_grad()
        for itera, st in enumerate(range(0, n_examples, batch_size)):
            ed = min(st + batch_size, n_examples)
            # Linearly anneal s from about 1/smax up to exactly smax on the last batch.
            s = (self.smax - 1 / self.smax) * (itera + 1) / num_batches + 1 / self.smax

            report_loss, example_num, loss = self.train_one_batch(examples[st:ed], report_loss, example_num)
            loss.backward()

            # Gradient accumulation: step every accumulation_step batches and
            # always on the final batch of the epoch.
            if (itera + 1) % self.args.accumulation_step == 0 or ed == n_examples:
                self._hat_optimizer_step(i, s)

        # max(..., 1) avoids ZeroDivisionError when the task has no examples.
        print("\nTask {}, Epoch Train {}, Loss {}, Time {}".format(
            i, epoch, report_loss / max(example_num, 1), time.time() - epoch_begin))

    def _hat_optimizer_step(self, i, s):
        """Apply HAT gradient surgery for task ``i`` at scale ``s``, then step the optimizer."""
        if i > 0:
            # Scale gradients of previously masked weights by 1 - mask_pre.
            for n, p in self.model.named_parameters():
                if n in self.mask_back and p.grad is not None:
                    p.grad.data *= self.mask_back[n]

        # Compensate the mask parameters' gradients for the annealed scale s.
        # The cosh argument is clamped to [-50, 50], presumably to keep cosh
        # finite for large s * p.
        for n, p in self.model.named_parameters():
            if n.startswith('mask') and p.grad is not None:
                num = torch.cosh(torch.clamp(s * p.data, -50, 50)) + 1
                den = torch.cosh(p.data) + 1
                p.grad.data *= self.smax / s * num / den

        if self.args.clip_grad > 0.:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_grad)
        self.optimizer.step()
        self.optimizer.zero_grad()

    def _hat_update_masks(self, i):
        """Fold task ``i``'s mask into ``mask_pre`` and rebuild ``mask_back``."""
        mask_ass = torch.sigmoid(self.smax * self.model.mask_module_list[i]).detach()
        if i == 0:
            self.mask_pre = mask_ass
        else:
            self.mask_pre = torch.max(self.mask_pre, mask_ass)

        self.mask_back = {}
        for n, p in self.model.named_parameters():
            if n in ('decoder_lstm.weight_ih', 'decoder_lstm.weight_hh'):
                # The (1, hidden_size) row is broadcast over the weight's rows,
                # so each column is scaled by 1 - mask.
                # NOTE(review): this requires the weight's column count to equal
                # hidden_size; confirm for weight_ih.
                vals = self.mask_pre.view(1, -1).expand_as(p)
                # Keep the multiplier on the same device as the gradient it scales.
                self.mask_back[n] = (1 - vals).to(p.device)

    def train_one_batch(self, examples, report_loss, example_num):
        """Compute the loss for one batch and update the running loss totals.

        ``self.model.forward`` is expected to return a pair of per-example
        log-scores (sketch, logical form); the loss is their negation.

        Returns:
            ``(report_loss, example_num, loss)``:
            - ``report_loss``: the running sum of per-example losses.
            - ``example_num``: the running example count.
            - ``loss``: the batch-mean loss tensor, ready for ``backward()``.
        """
        score = self.model.forward(examples)
        loss_sketch = -score[0]
        loss_lf = -score[1]

        # Summed (not averaged) losses, used only for logging.
        report_loss += loss_sketch.sum().item() + loss_lf.sum().item()
        example_num += len(examples)

        loss = torch.mean(loss_lf) + torch.mean(loss_sketch)
        return report_loss, example_num, loss