# -*- coding: utf-8 -*-
# !/usr/bin/python

import sys
import time
import torch
sys.path.append("..")
import random
from baselines.basic_trainer import BasicTrainer
from baselines.utils import RANDOM


class EWCTrainer(BasicTrainer):
    """Task-stream trainer with an Elastic Weight Consolidation (EWC) penalty.

    Before training on a new task, gradients are computed on the stored
    exemplars of the previously finished task.  The squared gradients
    (``fisher``) and a snapshot of the parameter values (``optpar``) are kept
    per task.  While training later tasks, the quadratic penalty
    ``ewc_reg * sum(fisher * (p - optpar) ** 2)`` over all earlier tasks is
    back-propagated right before every optimizer step.

    Parameters whose name contains ``"plm_model"`` are excluded from both the
    stored statistics and the penalty.
    """

    def __init__(self, args, model_save_path):
        """Initialise the base trainer and the empty continual-learning state.

        Args:
            args: Run configuration, forwarded to ``BasicTrainer``.  ``train``
                reads ``task_num``, ``epoch``, ``epoch_eval``, ``batch_size``,
                ``accumulation_step``, ``ewc_reg``, ``clip_grad``,
                ``max_patience`` and ``memory_size`` from it.
            model_save_path: Forwarded unchanged to ``BasicTrainer``.
        """
        super(EWCTrainer, self).__init__(args, model_save_path)

        # CL component
        self.past_task_id = -1        # id of the most recently started task
        self.observed_task_ids = []   # task ids in the order they were started
        self.memory_data = {}         # task id -> exemplars sampled from its train set

        self.fisher = {}              # task id -> {param name: squared gradient}
        self.optpar = {}              # task id -> {param name: parameter snapshot}

    def _ewc_consolidate(self, task_id):
        """Store squared gradients and parameter snapshots for ``task_id``.

        Runs one pass over the exemplars kept for ``task_id`` without
        optimizer updates, then records, for every non-``plm_model`` parameter
        that received a gradient, a clone of its value and of its squared
        gradient.
        """
        self.model.train()
        start_time = time.time()
        self.optimizer.zero_grad()
        replay_examples = self.memory_data[task_id]

        # NOTE(review): relies on train_one_epoch(optimize_step=False) leaving
        # the accumulated gradients in p.grad -- confirm in BasicTrainer.
        replay_report_loss = self.train_one_epoch(replay_examples, optimize_step=False)
        print("\nReplay Task {}, replay_report_loss {}, Time {}".format(task_id, replay_report_loss, time.time() - start_time))

        optpar = {}
        fisher = {}
        for n, p in self.model.named_parameters():
            if p.grad is None or "plm_model" in n:
                continue
            optpar[n] = p.data.clone()
            fisher[n] = p.grad.data.clone().pow(2)
        self.optpar[task_id] = optpar
        self.fisher[task_id] = fisher
        torch.cuda.empty_cache()

    def _ewc_backward_penalty(self):
        """Back-propagate the EWC penalty of all earlier tasks, if there is one.

        Does nothing while only one task has been observed.  A parameter is
        skipped when it currently has no gradient, belongs to ``plm_model``,
        or has no stored statistics for a given task (it received no gradient
        during that task's consolidation pass).  If no parameter contributes,
        no backward pass is run.
        """
        if len(self.observed_task_ids) <= 1:
            return

        # Same filter the consolidation step uses; evaluated once per call.
        active = [(n, p) for n, p in self.model.named_parameters()
                  if p.grad is not None and "plm_model" not in n]

        reg_loss = None
        # Every observed task except the one currently being trained.
        for task_id in self.observed_task_ids[:-1]:
            fisher = self.fisher[task_id]
            optpar = self.optpar[task_id]
            for n, p in active:
                if n not in fisher:
                    # No statistics were recorded for this parameter.
                    continue
                term = (fisher[n] * (p - optpar[n]).pow(2)).sum()
                reg_loss = term if reg_loss is None else reg_loss + term
        torch.cuda.empty_cache()

        if reg_loss is None:
            # Nothing to regularise; avoids calling .backward() on a plain number.
            return
        (self.args.ewc_reg * reg_loss).backward()

    def train(self):
        """Train on every task in order and evaluate after each one.

        For each task: consolidate the previous task (from the second task
        on), train for up to ``args.epoch`` epochs with gradient accumulation
        and the EWC penalty, evaluate on the dev split from epoch
        ``args.epoch_eval`` onwards with patience-based early stopping, sample
        exemplars into ``memory_data``, reload the saved model and evaluate on
        the test split.

        Returns:
            The tuple ``(self.avg_acc_list, self.whole_acc_list,
            self.bwt_list, self.fwt_list)``.  These lists are not filled in
            this class; presumably ``eval_task_stream`` maintains them.
        """
        for i in range(self.args.task_num):
            best_result = {"acc": 0.0, "epoch": 0}
            examples = self.task_controller.task_list[i]["train"]

            n_epochs = self.args.epoch
            epoch_eval = self.args.epoch_eval

            patience = 0

            if i != self.past_task_id:
                if len(self.observed_task_ids) > 0:
                    finished_task_id = self.observed_task_ids[self.past_task_id]
                    assert finished_task_id != i
                    self._ewc_consolidate(finished_task_id)
                self.observed_task_ids.append(i)
                self.past_task_id = i

            self.memory_data[i] = []

            for epoch in range(n_epochs):
                self.model.train()
                epoch_begin = time.time()
                random.shuffle(examples)  # in place: reorders the task's train list
                report_loss, example_num = 0.0, 0
                self.optimizer.zero_grad()

                total = len(examples)
                batch_size = self.args.batch_size
                for cnt, st in enumerate(range(0, total, batch_size)):
                    ed = min(st + batch_size, total)

                    report_loss, example_num, loss = self.train_one_batch(examples[st:ed], report_loss, example_num)

                    torch.cuda.empty_cache()
                    loss.backward()

                    # Step every `accumulation_step` batches and on the last batch.
                    if (cnt + 1) % self.args.accumulation_step == 0 or ed == total:
                        self._ewc_backward_penalty()

                        if self.args.clip_grad > 0.:
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_grad)
                        self.optimizer.step()
                        self.optimizer.zero_grad()

                # Guard against a task with no training examples.
                mean_loss = report_loss / example_num if example_num else 0.0
                print("\nTask {}, Epoch Train {}, Loss {}, Time {}".format(i, epoch, mean_loss, time.time() - epoch_begin))

                if epoch < epoch_eval:
                    continue

                start_time = time.time()
                dev_acc, _beam_acc, (_right, _wrong, _), _write_data = self.epoch_acc(self.task_controller.task_list[i]["dev"])
                print('Evaluation: \tEpoch: %d\tTime: %.4f\tDev acc: %.4f\n' % (epoch, time.time() - start_time, dev_acc))

                if dev_acc >= best_result['acc']:
                    best_result['acc'], best_result['epoch'] = dev_acc, epoch
                    self.save(self.model, name="model.bin")
                    patience = 0
                else:
                    patience += 1

                if patience > self.args.max_patience:
                    break

            # Keep a sample of this task's training data for later consolidation.
            self.memory_data[i] = list(RANDOM(examples=examples,
                                              memory_size=self.args.memory_size))

            self.load(self.model)
            start_time = time.time()
            test_acc, _beam_acc, (_right, _wrong, _), _write_data = self.epoch_acc(self.task_controller.task_list[i]["test"])
            print('Evaluation: \tTime: %.4f\tTest acc: %.4f\n' % (time.time() - start_time, test_acc))

            self.first_acc_list[i] = test_acc
            self.eval_task_stream(i, test_acc)

        return self.avg_acc_list, self.whole_acc_list, self.bwt_list, self.fwt_list
