import time

import numpy as np
import torch
from torch import Tensor
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

import utils
from approaches.agem import agem_utils
from approaches.agem.buffer import Buffer


class Appr(object):
    """A-GEM (Averaged Gradient Episodic Memory) continual-learning approach.

    After training on a task, a fraction (``buffer_percent``) of that task's
    training set is stored in an episodic ``Buffer``. While training later
    tasks, the gradient of the current mini-batch is compared with the
    gradient on a batch drawn from the buffer. When they conflict (negative
    dot product), the current gradient is projected so that it no longer
    points against the buffer gradient.
    """

    def __init__(self, model: nn.Module, device: str,
                 epochs_max: int, patience_max: int,
                 lr: float, lr_min: float, lr_factor: float,
                 buffer_size: int, buffer_percent: float):
        super().__init__()

        self.device = device
        self.model = model.to(device)
        self.nepochs = epochs_max
        self.lr_patience = patience_max
        self.lr = lr
        self.lr_min = lr_min
        self.lr_factor = lr_factor

        self.ce = nn.CrossEntropyLoss()
        self.clipgrad = 1000

        # Episodic memory of samples from already-trained tasks.
        self.buffer_size = buffer_size
        self.buffer_percent = buffer_percent
        self.buffer = Buffer(self.buffer_size, self.device)

        # Element count of every model parameter, in ``parameters()`` order;
        # used to flatten gradients into / unflatten them from single vectors.
        self.grad_dims = [param.numel() for param in self.model.parameters()]
        total_numel = int(sum(self.grad_dims))
        # Flat gradient of the current batch (grad_xy) and of the buffer batch
        # (grad_er). Zero-initialised instead of uninitialised memory.
        self.grad_xy = torch.zeros(total_numel, device=self.device)
        self.grad_er = torch.zeros(total_numel, device=self.device)
    # enddef

    def _get_optimizer(self, lr: float):
        """Return a plain SGD optimizer over all model parameters with learning rate ``lr``."""
        return torch.optim.SGD(self.model.parameters(), lr=lr)
    # enddef

    def project(self, gxy: torch.Tensor, ger: torch.Tensor) -> torch.Tensor:
        """Remove from ``gxy`` its component along ``ger``.

        Returns ``gxy - (gxy . ger / ger . ger) * ger``. The result is
        orthogonal to ``ger``. Callers only invoke this when
        ``gxy . ger < 0``, which implies ``ger`` is non-zero, so the
        division is safe.
        """
        corr = torch.dot(gxy, ger) / torch.dot(ger, ger)
        return gxy - corr * ger
    # enddef

    def store_grad(self, params, grads, grad_dims):
        """Copy the current parameter gradients, flattened, into ``grads`` (in place).

        Args:
            params: callable returning the parameter iterator (e.g.
                ``model.parameters``); it is *called* here.
            grads: 1-D tensor of length ``sum(grad_dims)`` that receives the
                concatenated gradients. It is zeroed first, so parameters
                whose ``.grad`` is ``None`` contribute zeros.
            grad_dims: number of elements of each parameter, in iteration order.
        """
        grads.fill_(0.0)
        begin = 0
        with torch.no_grad():
            for param, numel in zip(params(), grad_dims):
                end = begin + numel
                if param.grad is not None:
                    grads[begin:end].copy_(param.grad.detach().reshape(-1))
                # endif
                # Advance even when the gradient is missing, to keep slots aligned.
                begin = end
            # endfor
        # endwith
    # enddef

    def overwrite_grad(self, params, newgrad, grad_dims):
        """Overwrite parameter gradients (in place) with slices of a flat vector.

        Args:
            params: callable returning the parameter iterator (e.g.
                ``model.parameters``); it is *called* here.
            newgrad: 1-D tensor of length ``sum(grad_dims)`` holding the
                gradient to install (e.g. the projected gradient).
            grad_dims: number of elements of each parameter, in iteration order.

        Parameters whose ``.grad`` is ``None`` are left untouched.
        """
        begin = 0
        with torch.no_grad():
            for param, numel in zip(params(), grad_dims):
                end = begin + numel
                if param.grad is not None:
                    param.grad.copy_(newgrad[begin:end].reshape(param.grad.shape))
                # endif
                begin = end
            # endfor
        # endwith
    # enddef

    def train(self, t: int, dl_train: DataLoader, dl_val: DataLoader):
        """Train on task ``t`` with early stopping, then add samples of ``t`` to the buffer.

        The learning rate is multiplied by ``1 / lr_factor`` by
        ``ReduceLROnPlateau`` when the validation loss plateaus. Training
        stops early once the validation loss has not improved for
        ``lr_patience`` consecutive epochs while the learning rate is already
        at or below ``lr_min``. Otherwise it runs ``nepochs`` epochs. The
        weights with the best validation loss are restored at the end.
        """
        best_loss = np.inf
        # Snapshot/restore through the same helper module.
        best_model = agem_utils.get_model(self.model)
        patience = 0

        self.optimizer = self._get_optimizer(self.lr)
        # NOTE(review): ``verbose`` is deprecated in recent PyTorch releases.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                         mode='min',
                                                         factor=1.0 / self.lr_factor,
                                                         patience=max(self.lr_patience - 1, 0),
                                                         min_lr=self.lr_min,
                                                         verbose=True,
                                                         )
        # Used to report times per mini-batch; guarded against an empty loader.
        num_batches = max(len(dl_train), 1)

        for e in range(self.nepochs):
            # Train
            clock0 = time.time()
            self.train_epoch(t, dl_train)
            clock1 = time.time()
            train_loss, train_acc = self.eval(t, dl_train)
            clock2 = time.time()

            # Times are milliseconds per mini-batch (train step / train-set eval).
            print('| Epoch {:3d}, time={:5.1f}ms/{:5.1f}ms | Train: loss={:.3f}, acc={:5.1f}% |'.format(
                e + 1,
                1000 * (clock1 - clock0) / num_batches,
                1000 * (clock2 - clock1) / num_batches,
                train_loss, 100 * train_acc), end='')

            # Valid
            valid_loss, valid_acc = self.eval(t, dl_val)
            print(' Valid: loss={:.3f}, acc={:5.1f}% |'.format(valid_loss, 100 * valid_acc), end='')

            if e == 0 or valid_loss < best_loss:
                best_loss = valid_loss
                best_model = agem_utils.get_model(self.model)
                patience = 0
                print(' *', end='')
            elif utils.get_current_lr(self.optimizer) <= self.lr_min:
                # Stagnant epochs only count once the scheduler cannot lower the lr further.
                patience += 1
            else:
                patience = 0
            # endif

            if patience >= self.lr_patience:
                print()  # terminate this epoch's progress line before stopping
                break
            # endif

            scheduler.step(valid_loss)
            print()
        # endfor

        # Restore best
        agem_utils.set_model_(self.model, best_model)

        self._add_task_to_buffer(t, dl_train.dataset)
    # enddef

    def _add_task_to_buffer(self, t: int, dataset) -> None:
        """Store the first ``buffer_percent`` fraction of ``dataset`` in the buffer, tagged with task ``t``."""
        samples_per_task = int(len(dataset) * self.buffer_percent)
        print('samples_per_task: ', samples_per_task)
        if samples_per_task <= 0:
            # DataLoader rejects batch_size=0; there is nothing to store for this task.
            return
        # endif

        # NOTE(review): no shuffling, so these are the *first* samples of the
        # dataset; if it is ordered (e.g. by class) the buffer will be biased.
        loader = DataLoader(dataset, batch_size=samples_per_task)
        inputs, targets = next(iter(loader))

        self.buffer.add_data(
            input=inputs.to(self.device),
            labels=targets.to(self.device),
            # Sized by the samples actually loaded, not by the requested count.
            task_labels=torch.full((len(targets),), t, dtype=torch.long, device=self.device),
        )
    # enddef

    def train_epoch(self, t: int, dl_train: DataLoader):
        """Run one A-GEM training epoch on task ``t``.

        For every mini-batch the gradient ``g`` of the current loss is
        computed. If the buffer is not empty, the gradient ``g_ref`` on a
        buffered batch is computed too. When ``g . g_ref < 0``, the update
        uses ``g`` with its component along ``g_ref`` removed. Otherwise
        ``g`` is used unchanged. Gradients are then clipped to
        ``self.clipgrad`` and an optimizer step is taken.
        """
        self.model.train()

        for batch in dl_train:
            batch = [bat.to(self.device) if bat is not None else None for bat in batch]
            x, targets = batch

            # Gradient on the current mini-batch.
            self.optimizer.zero_grad()
            output, _ = self.model(t, x)
            loss = self.ce(output, targets)
            loss.backward()

            if not self.buffer.is_empty():
                self.store_grad(self.model.parameters, self.grad_xy, self.grad_dims)
                buf_inputs, buf_labels, _buf_task_labels = self.buffer.get_data(self.buffer_size)

                # Zero the existing gradient tensors instead of dropping them
                # (``set_to_none`` defaults to True since PyTorch 2.0). Otherwise a
                # parameter that gets a gradient from the current batch but not
                # from the buffer batch would keep grad=None. overwrite_grad would
                # then skip it, losing its current-task update.
                self.model.zero_grad(set_to_none=False)

                # NOTE(review): buffered samples may come from earlier tasks but are
                # forwarded with the current task id ``t``; the buffer's task labels
                # are unused. Confirm this matches the model's head layout.
                buf_output, _ = self.model(t, buf_inputs)
                penalty = self.ce(buf_output, buf_labels)
                penalty.backward()
                self.store_grad(self.model.parameters, self.grad_er, self.grad_dims)

                dot_prod = torch.dot(self.grad_xy, self.grad_er)
                if dot_prod.item() < 0:
                    g_tilde = self.project(gxy=self.grad_xy, ger=self.grad_er)
                    self.overwrite_grad(self.model.parameters, g_tilde, self.grad_dims)
                else:
                    self.overwrite_grad(self.model.parameters, self.grad_xy, self.grad_dims)
                # endif
            # endif

            # ``clip_grad_norm`` (no underscore) is deprecated; use the in-place variant.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clipgrad)
            self.optimizer.step()
        # endfor
    # enddef

    def eval(self, t: int, dl: DataLoader):
        """Return ``(mean_loss, accuracy)`` of the model on task ``t`` over ``dl``.

        Per-batch mean losses are re-weighted by batch size, so both values are
        per-sample averages. The model is left in eval mode. Raises
        ``ZeroDivisionError`` if ``dl`` yields no samples.
        """
        total_loss = 0.0
        total_acc = 0.0
        total_num = 0

        self.model.eval()
        with torch.no_grad():
            for batch in dl:
                batch = [bat.to(self.device) if bat is not None else None for bat in batch]
                x, targets = batch
                real_b = x.size(0)

                output, _ = self.model(t, x)
                loss = self.ce(output, targets)
                _, pred = output.max(1)

                total_loss += loss.item() * real_b
                total_acc += (pred == targets).float().sum().item()
                total_num += real_b
            # endfor
        # endwith

        return total_loss / total_num, total_acc / total_num
    # enddef
