# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from abc import ABC, abstractmethod
from email.headerregistry import HeaderRegistry
from random import weibullvariate
from re import S
from stat import FILE_ATTRIBUTE_NOT_CONTENT_INDEXED
from turtle import forward
from typing import Tuple


from utils.euclidean import givens_rotations, givens_rotations_reverse
from utils.quaternion import quaternion_rotation 

import torch
from torch import nn


class Regularizer(nn.Module, ABC):
    """Abstract base class for the regularizers defined in this module.

    A concrete regularizer implements ``forward`` to turn a collection of
    tensors ("factors") into a penalty value.
    """

    @abstractmethod
    def forward(self, factors: Tuple[torch.Tensor, ...]):
        """Compute the penalty for ``factors``.

        Args:
            factors: The tensors to penalise. How many tensors are expected
                and what each of them represents is defined by the concrete
                subclass, hence the variable-length tuple annotation.

        Returns:
            Defined by the subclass. This base implementation has no body
            and returns ``None``.
        """

class F2(Regularizer):
    """Weighted sum of squared entries over all factors.

    Each factor ``f`` contributes ``weight * sum(f ** 2)``. The accumulated
    value is divided by the size of the first factor's leading dimension
    (presumably the batch size -- confirm against the caller).
    """

    def __init__(self, weight: float):
        super().__init__()
        # Coefficient applied to every factor's squared sum.
        self.weight = weight

    def forward(self, factors):
        """Return the penalty for ``factors``.

        Args:
            factors: Non-empty, indexable collection of tensors.

        Returns:
            The sum of ``weight * sum(f ** 2)`` over the factors, divided by
            ``factors[0].shape[0]``.
        """
        # The weight is applied per factor (not once at the end) so the
        # floating-point accumulation order stays the same.
        penalty = sum(self.weight * torch.sum(f.pow(2)) for f in factors)
        leading_dim = factors[0].shape[0]
        return penalty / leading_dim

class N3(Regularizer):
    """Weighted sum of cubed absolute values over all factors.

    Reference given by the original author: "Regularized complex embeddings",
    https://arxiv.org/pdf/1806.07297.pdf
    """

    def __init__(self, weight: float):
        super().__init__()
        # Coefficient applied to every factor's cubed-magnitude sum.
        self.weight = weight

    def forward(self, factors):
        """Return the penalty for ``factors``.

        Args:
            factors: Non-empty, indexable collection of tensors.

        Returns:
            The sum of ``weight * sum(|f| ** 3)`` over the factors, divided
            by ``factors[0].shape[0]``.
        """
        # The weight is applied per factor (not once at the end) so the
        # floating-point accumulation order stays the same.
        penalty = sum(
            self.weight * torch.sum(torch.abs(f).pow(3)) for f in factors
        )
        leading_dim = factors[0].shape[0]
        return penalty / leading_dim

class DURA_RESCAL(Regularizer):
    """Penalty over ``(h, r, t)`` where ``r`` is a batch of matrices.

    The value is ``weight * (sum(h^2) + sum(t^2) + sum((r^T h)^2) +
    sum((r t)^2)) / h.shape[0]``. Judging by the class name this is the DURA
    regularizer for RESCAL-style models -- confirm against the model code.
    """

    def __init__(self, weight: float):
        super().__init__()
        # Coefficient applied to the accumulated squared sums.
        self.weight = weight

    def forward(self, factors: Tuple[torch.Tensor, ...]):
        """Return the penalty for one batch.

        Args:
            factors: Exactly three tensors ``(h, r, t)``. ``torch.bmm``
                requires ``r`` to be 3-D; ``h`` and ``t`` are given a
                trailing singleton axis so they act as batched column
                vectors (so presumably ``h``/``t`` are ``(batch, dim)`` and
                ``r`` is ``(batch, dim, dim)`` -- confirm).

        Returns:
            A 0-dim tensor: the weighted total divided by ``h.shape[0]``.
        """
        h, r, t = factors

        # Plain squared magnitude of both entity tensors.
        entity_term = torch.sum(t ** 2 + h ** 2)

        # h is multiplied by the transposed relation matrices, t by the
        # matrices as given.
        h_mapped = torch.bmm(r.transpose(1, 2), h.unsqueeze(-1))
        t_mapped = torch.bmm(r, t.unsqueeze(-1))
        mapped_term = torch.sum(h_mapped ** 2 + t_mapped ** 2)

        norm = entity_term + mapped_term
        return self.weight * norm / h.shape[0]

class DURA_QuatE(Regularizer):
    """Penalty over ``(h, r, t)`` using ``quaternion_rotation``.

    The value is ``weight * sum(hr^2 + rt^2 + h^2 + t^2) / h.shape[0]``,
    where ``hr`` and ``rt`` are the results of ``quaternion_rotation`` applied
    to ``h`` and (with ``transpose=True``) to ``t``. Judging by the class
    name this is the DURA regularizer for QuatE-style models -- confirm
    against the model code.
    """

    def __init__(self, weight: float):
        super().__init__()
        # Coefficient applied to the accumulated squared sum.
        self.weight = weight

    def forward(self, factors: Tuple[torch.Tensor, ...]):
        """Return the penalty for one batch.

        Args:
            factors: Exactly three tensors ``(h, r, t)``. All four squared
                terms are added elementwise, so the helper's outputs must be
                broadcast-compatible with ``h`` and ``t``.

        Returns:
            The weighted squared sum divided by ``h.shape[0]``.
        """
        h, r, t = factors

        # NOTE(review): the exact meaning of ``right`` / ``transpose`` is
        # defined in utils.quaternion; presumably the second call applies the
        # inverse (conjugate) rotation to the tail -- confirm there.
        hr = quaternion_rotation(r, h, right=True)
        rt = quaternion_rotation(r, t, right=True, transpose=True)

        norm = torch.sum(hr ** 2 + rt ** 2 + h ** 2 + t ** 2)
        return self.weight * norm / h.shape[0]

class DURA_W(Regularizer):
    """Penalty over ``(h, r, t)`` with an elementwise relation factor.

    The value is ``weight * (0.5 * sum(h^2 + t^2) + 1.5 * sum(h^2 r^2 +
    t^2 r^2)) / h.shape[0]``.
    """

    def __init__(self, weight: float):
        super().__init__()
        # Coefficient applied to the accumulated squared sums.
        self.weight = weight

    def forward(self, factors: Tuple[torch.Tensor, ...]):
        """Return the penalty for one batch.

        Args:
            factors: Exactly three tensors ``(h, r, t)``. ``r`` is combined
                with ``h`` and ``t`` by elementwise multiplication, so the
                three must be broadcast-compatible.

        Returns:
            A 0-dim tensor: the weighted total divided by ``h.shape[0]``.
        """
        h, r, t = factors
        h_sq, r_sq, t_sq = h ** 2, r ** 2, t ** 2

        # The entity-only term is weighted 0.5 and the relation-coupled term
        # 1.5; both coefficients are fixed by this class.
        entity_term = 0.5 * torch.sum(t_sq + h_sq)
        coupled_term = 1.5 * torch.sum(h_sq * r_sq + t_sq * r_sq)

        norm = entity_term + coupled_term
        return self.weight * norm / h.shape[0]

class DURA_UniBi_2(Regularizer):
    """Penalty over ``(h, Rot_u, Rel_s, Rot_v, t)`` using ``givens_rotations``.

    The value is ``weight * sum(suh^2 + svt^2 + h^2 + t^2) / h.shape[0]``,
    where ``suh`` / ``svt`` are the rotated head / tail multiplied by
    ``Rel_s``.
    """

    def __init__(self, weight: float):
        super().__init__()
        # Coefficient applied to the accumulated squared sum.
        self.weight = weight

    def forward(self, factors):
        """Return the penalty for one batch.

        Args:
            factors: Exactly five tensors ``(h, Rot_u, Rel_s, Rot_v, t)``.
                Per the original author's note they arrive already processed
                by the caller.

        Returns:
            The weighted squared sum divided by ``h.shape[0]``.
        """
        head, rot_u, scale, rot_v, tail = factors

        # NOTE(review): the meaning of ``transpose=True`` is defined in
        # utils.euclidean; presumably it applies the inverse rotation.
        scaled_head = scale * givens_rotations(rot_u, head)
        scaled_tail = scale * givens_rotations(rot_v, tail, transpose=True)

        squared = scaled_head ** 2 + scaled_tail ** 2 + head ** 2 + tail ** 2
        return self.weight * torch.sum(squared) / head.shape[0]
    
class DURA_UniBi_3(Regularizer):
    """Penalty over ``(h, Rot_u, Rel_s, Rot_v, t)`` using ``quaternion_rotation``.

    The value is ``weight * sum(suh^2 + svt^2 + h^2 + t^2) / h.shape[0]``,
    where ``suh`` / ``svt`` are the rotated head / tail multiplied by
    ``Rel_s``.
    """

    def __init__(self, weight: float):
        super().__init__()
        # Coefficient applied to the accumulated squared sum.
        self.weight = weight

    def forward(self, factors):
        """Return the penalty for one batch.

        Args:
            factors: Exactly five tensors ``(h, Rot_u, Rel_s, Rot_v, t)``.
                Per the original author's note they arrive already processed
                by the caller.

        Returns:
            The weighted squared sum divided by ``h.shape[0]``.
        """
        head, rot_u, scale, rot_v, tail = factors

        # NOTE(review): the meaning of ``right`` / ``transpose`` is defined
        # in utils.quaternion; presumably ``transpose=True`` applies the
        # inverse rotation.
        scaled_head = scale * quaternion_rotation(rot_u, head, right=True)
        scaled_tail = scale * quaternion_rotation(
            rot_v, tail, right=True, transpose=True
        )

        squared = scaled_head ** 2 + scaled_tail ** 2 + head ** 2 + tail ** 2
        return self.weight * torch.sum(squared) / head.shape[0]
