import math

from approaches.tag.tag_utils import *


class TAG:
    """
    Implementation of our proposed TAG optimizer
    """
    def __init__(self, model, args, num_tasks, optim='rms', lr=None, b=5):
        """
        Gets all the necessary arguments for initialization
        :param model: Current model
        :param args: All arguments for experiment configuration
        :param num_tasks: Total number of tasks
        :param optim: Base optimizers to be used: {'rms':TAG-RMSProp,  'adagrad':TAG-Adagrad, 'adam': TAG-Adam}
        :param lr: Learning rate (eta)
        :param b: Hyperparameter for regulating alpha - high b value implies more focus on preventing forgetting
        """
        self.optim = optim
        self.args = args
        self.iters = 0
        self.model = model
        self.b = b
        self.weight_decay = 0.0
        if self.optim == 'adam':
            self.beta1, self.beta2 = 0.9, 0.999
        else:
            self.beta1, self.beta2 = 0.9, 0.99
        self.lr = lr
        self.alpha_add_ = {}
        self.v, self.v_t = {}, {}
        self.m, self.m_t = {}, {}
        self.m_t_norms = {}
        for task in range(num_tasks):
            self.v_t[task] = {}
            self.m_t[task] = {}
            self.m_t_norms[task] = {}
            self.alpha_add_[task] = {}
            for (name, param) in model.named_parameters():
                if task == 0:
                    self.v[name] = torch.zeros_like(param).to(args.device)
                    self.m[name] = torch.zeros_like(param).to(args.device)
                self.alpha_add_[task][name] = np.array([1])
                self.v_t[task][name] = torch.zeros_like(param).to(args.device)
                self.m_t[task][name] = torch.zeros_like(param).to(args.device)
                self.m_t_norms[task][name] = torch.zeros_like(param).to(args.device)

    def zero_grad(self):
        return self.model.zero_grad()

    def update_all(self, task_id):
        """
        Normalize the current task-based first moments (that will remain fixed)
        """
        for name, v in self.model.named_parameters():
            self.m_t_norms[task_id][name] = self.m_t[task_id][name].reshape(-1) / torch.norm(self.m_t[task_id][name])

    def update_naive(self, param_name, param_grad):
        """
        Use the naive-optimizer update
        :param param_name: Parameter identity
        :param param_grad: Gradient associated with the given parameter
        :return: New update to the given parameter
        """
        if self.optim == 'rms':
            self.v[param_name] = self.beta2 * self.v[param_name] + (1 - self.beta2) * param_grad ** 2
        else:
            self.v[param_name] += param_grad ** 2
        denom = torch.sqrt(self.v[param_name]) + 1e-8
        return - (self.lr * param_grad / denom)

    def update_tag(self, param_name, param_grad, task_id):
        """
        Update Task-based accumulated gradients, calculate alpha and return the new updates
        :param param_name: Parameter identity
        :param param_grad: Gradient associated with the given parameter
        :param task_id: Current task identity
        :return: New update to the given parameter
        """
        bias_corr1, bias_corr2 = 1, 1
        new_v = None
        DEVICE = self.args.device

        # Update task-based first moment
        # self.m_t[task_id][param_name] = self.beta1 * self.m_t[task_id][param_name] + (1 - self.beta1) * param_grad
        self.m_t[task_id][param_name] = (self.beta1 * self.m_t[task_id][param_name].to(DEVICE) + (1 - self.beta1) * param_grad).cpu()

        # Change numerator based on the optimizer
        if self.optim == 'adam':
            bias_corr1, bias_corr2 = 1 - self.beta1 ** (self.iters + 1), 1 - self.beta2 ** (self.iters + 1)
            numer = self.m_t[task_id][param_name] / bias_corr1
        else:
            numer = param_grad

        # Update task-based second moments based on the optimizer
        if self.optim == 'rms' or self.optim == 'adam':
            # self.v_t[task_id][param_name] = self.beta2 * self.v_t[task_id][param_name] + (1 - self.beta2) * param_grad ** 2
            self.v_t[task_id][param_name] = (self.beta2 * self.v_t[task_id][param_name].to(DEVICE) + (1 - self.beta2) * param_grad ** 2).cpu()
        else:
            # self.v_t[task_id][param_name] = self.v_t[task_id][param_name] + param_grad ** 2
            self.v_t[task_id][param_name] = (self.v_t[task_id][param_name].to(DEVICE) + param_grad ** 2).cpu()

        # Get new alphas by computing correlation using task-based first moments
        if task_id > 0:
            alpha_add = []
            for t in range(task_id):
                corr = torch.dot(self.m_t[task_id][param_name].reshape(-1) / torch.norm(self.m_t[task_id][param_name]),
                                 self.m_t_norms[t][param_name])
                alpha_add += [(-corr).cpu().numpy()]
            alpha_add += [-1.]
            alpha_add = torch.from_numpy(np.array(alpha_add)).to(DEVICE)
            alpha_add_ = torch.exp(self.b * alpha_add).float()
        else:
            alpha_add_ = torch.from_numpy(np.array([1.0] * (task_id + 1))).to(DEVICE)
        self.alpha_add_[task_id][param_name] = alpha_add_.cpu().numpy()

        # Concatenate all task-based second moments
        for t in range(task_id):
            new_v = self.v_t[t][param_name].unsqueeze(0).to(DEVICE) \
                if t == 0 \
                else torch.cat((new_v, self.v_t[t][param_name].unsqueeze(0).to(DEVICE)), dim=0)
        new_v = self.v_t[task_id][param_name].unsqueeze(0).to(DEVICE) \
            if new_v is None \
            else torch.cat((new_v, self.v_t[task_id][param_name].unsqueeze(0).to(DEVICE)), dim=0)

        # Compute inner product of alphas and task-based second moments using torch.einsum() function.
        # eq takes care of varying the dimensions of parameter variable with each layer.
        eq = {1: 'n,nh->h', 2: 'n,nhw->hw', 3: 'n,nhwc->hwc', 4: 'n,nhwvd->hwvd', 5: 'n,nhwzxc->hwzxc'}[len(param_grad.shape)]
        denom = (torch.sqrt(torch.einsum(eq, alpha_add_.float(), new_v)) / math.sqrt(bias_corr2)) + 1e-8

        return - (self.lr * numer / denom)

    def step(self, model, task_id, step):
        """
        Perform update over the parameters
        :param model: Current model
        :param task_id: Current task id (t)
        :param step: Current Step (n)
        :return:
        """
        self.iters = step
        state_dict = model.state_dict()
        for i, (name, param) in enumerate(state_dict.items()):
            if name.split('.')[-1] in ['running_mean', 'num_batches_tracked', 'running_var']:
                continue
            for n, v in model.named_parameters():
                if n == name:
                    break
            if v.grad is None:
                continue
            update = self.update_tag(name, v.grad, task_id)
            state_dict[name].data.copy_(param + update.reshape(param.shape))
        return state_dict


def store_alpha(tag_optimizer, task_id, iter, alpha_mean=None):
    """
    Collects alpha values for given task (t) and current step (n)
    :param tag_optimizer: Object of the class tag_opt()
    :param task_id: Current task identity
    :param iter: Current step in the epoch
    :return: alpha_mean: Dictionary with previous task ids as keys
    """
    for tau in tag_optimizer.alpha_add_[task_id]:
        alphas = tag_optimizer.alpha_add_[task_id][tau]
        if iter == 0:
            alpha_mean[tau] = alphas
        else:
            alpha_mean[tau] = (alpha_mean[tau] * iter + alphas) / (iter + 1)
    return alpha_mean
