import torch
import copy
import dl_utils.utils as utils

class RotationMatrix(torch.nn.Module):
    """
    Square matrix applied to (optionally centered and scaled) activations.

    Unless `identity_rot` is set, the weight is wrapped with
    torch.nn.utils.parametrizations.orthogonal, which is why the
    transpose of the weight is used as its inverse in `rot_inv`.
    """
    def __init__(self,
            size,
            identity_init=False,
            bias=False,
            double_rot=False,
            mu=0,
            sigma=1,
            identity_rot=False,
            **kwargs):
        """
        size: int
            the height and width of the rotation matrix
        identity_init: bool
            if true, will initialize the rotation matrix to the identity
            matrix.
        bias: bool
            if true, will include a shifting term in the rotation matrix
        double_rot: bool
            if true, applies 2 rotation matrices sequentially as
            opposed to 1. Ignored when identity_rot is true.
        mu: float or FloatTensor (size,)
            Used to center each feature dim of the activations.
        sigma: float or FloatTensor (size,)
            Used to scale each feature dim of the activations.
        identity_rot: bool
            if true, will always reset the rotation matrix to the
            identity. Used for debugging.
        kwargs:
            accepted and ignored so that callers can share one kwargs
            dict across the rotation classes.
        """
        super().__init__()
        self.identity_rot = identity_rot
        self.identity_init = identity_init
        self.double_rot = double_rot

        self._set_stat("mu", mu)
        self._set_stat("sigma", sigma)

        lin = torch.nn.Linear(size, size, bias=False)
        # Use the `size` argument here: the `size` property reads
        # self.rot_module, which does not exist yet at this point.
        # identity_rot also starts at the identity so that direct calls
        # to rot_forward/rot_inv (which skip the reset in forward) see
        # an identity matrix as well.
        if identity_init or identity_rot:
            lin.weight.data = torch.eye(
                size, dtype=lin.weight.data.dtype)

        # Shifting parameters
        if bias:
            self.bias = torch.nn.Parameter(
                torch.zeros(size, dtype=lin.weight.data.dtype))
        else:
            self.bias = 0

        if self.identity_rot:
            self.rot_module = lin
        else:
            # Orthogonal parameterization ensures that the weight is always
            # orthogonal
            self.rot_module = torch.nn.utils.parametrizations.orthogonal(lin)
            if self.double_rot:
                lin2 = torch.nn.Linear(size, size, bias=False)
                self.rot_module2 = \
                    torch.nn.utils.parametrizations.orthogonal(lin2)

    def _set_stat(self, name, value):
        """
        Stores a centering/scaling statistic. Plain python numbers are
        kept as attributes; anything else (i.e. tensors) is registered
        as a buffer so that it is part of the module state.
        """
        if isinstance(value, (int, float)):
            setattr(self, name, value)
        else:
            self.register_buffer(name, value)

    @property
    def weight(self):
        return self.rot_module.weight

    @property
    def size(self):
        return self.rot_module.weight.shape[0]

    @property
    def shape(self):
        return self.rot_module.weight.shape

    def rot_forward(self, h):
        """
        Normalizes h with mu and sigma, adds the bias, and multiplies
        by the rotation matrix (both matrices when a second rotation
        module exists).
        """
        # Dispatch at call time rather than rebinding methods on the
        # instance, so that no bound method is stored in the module's
        # attribute dict.
        if "rot_module2" in self._modules:
            return self.double_rot_forward(h)
        h = (h-self.mu)/self.sigma
        return torch.matmul(h+self.bias, self.rot_module.weight)

    def rot_inv(self, h):
        """
        Undoes `rot_forward` using the transpose of the weight as the
        inverse, then removes the bias and the normalization.
        """
        if "rot_module2" in self._modules:
            return self.double_rot_inv(h)
        h = torch.matmul(h, self.rot_module.weight.T)-self.bias
        h = h*self.sigma + self.mu
        return h

    def forward(self, h, inverse=False):
        """
        h: torch tensor (..., size)
        inverse: bool
            if true, applies `rot_inv` instead of `rot_forward`.
        """
        if self.identity_rot:
            weight = self.rot_module.weight
            # Debug mode: overwrite the weight with the identity on
            # every call, keeping the weight's dtype and device.
            weight.data = torch.eye(
                self.size, dtype=weight.dtype, device=weight.device)
        if inverse: return self.rot_inv(h)
        return self.rot_forward(h)

    def double_rot_forward(self, h):
        """
        Same as the single rotation forward, followed by a
        multiplication with the second rotation matrix.
        """
        h = (h-self.mu)/self.sigma
        return torch.matmul(
            torch.matmul(h+self.bias, self.rot_module.weight),
            self.rot_module2.weight)

    def double_rot_inv(self, h):
        """
        Inverts `double_rot_forward`: the second rotation is undone
        first, then the first one, then the bias and normalization.
        """
        h = torch.matmul(
            torch.matmul(h, self.rot_module2.weight.T),
            self.rot_module.weight.T)-self.bias
        h = h*self.sigma + self.mu
        return h

class RelaxedRotationMatrix(RotationMatrix):
    """
    This module is similar to the RotationMatrix, it will however relax
    the orthonormal constraint on the rotation matrix, constraining it to
    only be an invertible matrix. This is done by using a diagonal
    matrix with non-zero values before the rotation matrix.
    """
    def __init__(self, eps=1e-10, rot_first=False, *args, **kwargs):
        """
        eps: float
            a small value to ensure no division by 0
        rot_first: bool
            debugging tool. if true, will apply the orthogonal rotation
            matrix before the scaling matrix.
        args, kwargs:
            forwarded to RotationMatrix. Note that `eps` and
            `rot_first` come first positionally, so the RotationMatrix
            arguments are best passed by keyword.
        """
        super().__init__(*args, **kwargs)
        self.diag = torch.nn.Parameter(torch.ones(self.size).float())
        self.eps = eps
        self.rot_first = rot_first

    def _safe_diag(self):
        """
        Returns the diagonal entries pushed away from zero by eps in
        the direction of their sign.
        """
        sign = torch.sign(self.diag)
        # torch.sign returns 0 for an exactly-zero entry, which would
        # leave that entry at 0 and make diag_inv divide by zero.
        # Treat zeros as positive so they are shifted by +eps.
        sign = torch.where(sign == 0, torch.ones_like(sign), sign)
        return self.diag + self.eps*sign

    def diag_forward(self, h):
        """Multiplies h by the diagonal scaling matrix."""
        return torch.matmul(h, torch.diag(self._safe_diag()))

    def diag_inv(self, h):
        """Multiplies h by the inverse of the diagonal scaling matrix."""
        return torch.matmul(h, torch.diag(1/self._safe_diag()))

    def rot_first_forward(self, h, inverse=False):
        """
        Rotation followed by scaling (or the matching inverse). The
        bias is added here in addition to the one inside rot_forward;
        the inverse branch subtracts both again.
        """
        if inverse: return self.rot_inv(self.diag_inv(h))-self.bias
        return self.diag_forward(self.rot_forward(h+self.bias))

    def scale_first_forward(self, h, inverse=False):
        """
        Scaling followed by rotation (or the matching inverse). The
        bias is added here in addition to the one inside rot_forward;
        the inverse branch subtracts both again.
        """
        if inverse: return self.diag_inv(self.rot_inv(h))-self.bias
        return self.rot_forward(self.diag_forward(h+self.bias))

    def forward(self, h, inverse=False):
        """
        h: torch tensor (..., size)
        inverse: bool
            if true, inverts the transformation.
        """
        if self.rot_first:
            return self.rot_first_forward(h, inverse=inverse)
        else:
            return self.scale_first_forward(h, inverse=inverse)

    def unit_forward(self, h, inverse=False):
        """
        Same as forward but without the diagonal scaling, i.e. only
        the bias shift and the rotation are applied.
        """
        if inverse: return self.rot_inv(h)-self.bias
        return self.rot_forward(h+self.bias)

class Rotation(torch.nn.Module):
    """
    Thin wrapper that picks between RotationMatrix and
    RelaxedRotationMatrix based on a single `relaxed` flag, so callers
    do not have to choose the class themselves.
    """
    def __init__(self,
            size,
            identity_init=False,
            bias=False,
            relaxed=False,
            double_rot=False,
            eps=1e-8,
            **kwargs):
        """
        size: int
            the height and width of the rotation matrix
        identity_init: bool
            if true, the rotation matrix starts as the identity matrix.
        bias: bool
            if true, a shifting term is included in the rotation matrix
        relaxed: bool
            if true, wraps a RelaxedRotationMatrix, otherwise a
            RotationMatrix.
        double_rot: bool
            if true, 2 rotation matrices are applied sequentially
            instead of 1.
        eps: float
            handed to the wrapped matrix class.
        kwargs:
            accepted but not forwarded to the wrapped matrix.
        """
        super().__init__()
        self.size = size
        self.identity_init = identity_init
        self.bias = bias
        self.relaxed = relaxed
        self.eps = eps
        self.double_rot = double_rot

        matrix_cls = RelaxedRotationMatrix if self.relaxed else RotationMatrix
        self.rot_mtx = matrix_cls(
            size=self.size,
            identity_init=self.identity_init,
            bias=self.bias,
            eps=self.eps,
            double_rot=self.double_rot,
        )

    def forward(self, *args, **kwargs):
        """Delegates directly to the wrapped rotation matrix module."""
        return self.rot_mtx(*args, **kwargs)

class RecurrentRotation(Rotation):
    """
    This module performs a series of rotations swapping part of the
    result out with new input. This module is different than the normal
    rotation as it expects a sequence of inputs.
    """
    def __init__(self, *args, **kwargs):
        """
        size: int
            the size D of a single input vector. The wrapped rotation
            matrix is built with size 2*D because it acts on the
            concatenation of an input vector and the recurrent state.
        identity_init: bool
            if true, will initialize the rotation matrix to the identity
            matrix.
        bias: bool
            if true, will include a shifting term in the rotation matrix
        """
        if args and "size" not in kwargs:
            # size was argued positionally: double it in place instead
            # of also adding a `size` keyword, which would make the
            # parent constructor receive the argument twice.
            args = (2*args[0],) + tuple(args[1:])
        else:
            kwargs["size"] = 2*kwargs["size"]
        super().__init__(*args, **kwargs)

    def forward(self, x, nloops, inverse=False, rs=None, **kwargs):
        """
        Assumes that x is a sequence and performs n recurrent loops on
        the vectors in x in the forward pass. It then inverts these
        recurrent rotations in the backwards pass.

        Args:
            x: torch tensor (B,L,D)
                when inverse is true, x is instead handed to
                `invert_forward` as its `h` argument.
            nloops: torch tensor (B,) or None
                the 0-based step index, per batch entry, at which the
                state is captured as final_h. If None, defaults to
                the last step, L-1, for every entry.
            inverse: bool
                if true, inverts the recurrent loops.
            rs: list of tensors or None
                only used when inverse is true.
            kwargs:
                if inverting, must argue the inverse function's key
                word arguments. see `invert_forward` for details.
        Returns:
            if not inverse:
                rs: list of tensors [(B,D), ...]
                    the resulting rotated vectors existing in the input
                    portion of the recurrent state.
                hs: list of tensors [(B,D), ...]
                    the resulting rotated vectors existing in the state
                    portion of the recurrent state.
                final_h: tensor (B,D)
                    the state portion of the recurrent state at the
                    argued step index for that entry in the batch. Use
                    this vector to perform causal interventions.
            if inverse:
                the return value of `invert_forward`
        """
        if inverse:
            return self.invert_forward(h=x,rs=rs,nloops=nloops,**kwargs)

        B,L,D = x.shape
        if nloops is None:
            # Loop indices run from 0 to L-1, so the last step is L-1.
            nloops = torch.full(
                (B,), L-1, dtype=torch.long, device=x.device)
        rs = []
        hs = []
        h = torch.zeros_like(x[:,0])
        final_h = torch.zeros_like(h)
        for i in range(L):
            # Rotate the concatenation of the new input and the state,
            # then split the result back into the two halves.
            v = torch.cat([x[:,i],h],dim=-1)
            v = self.rot_mtx.rot_forward(v)
            r,h = v[:,:D],v[:,D:]
            rs.append(r)
            hs.append(h)
            # Capture the state for the entries whose index is this step
            idx = nloops==i
            final_h[idx] = h[idx]
        return rs, hs, final_h

    def invert_forward(self, h, rs, nloops,  **kwargs):
        """
        This function inverts the forward pass. Each entry in the
        batch receives its argued state at step index nloops while
        looping backwards over the steps.

        Args:
            h: torch tensor (B,D)
                the final_h from the forward pass. Likely intervened
                upon at this point.
            rs: list of torch tensors [(B,D), ...]
                the input portion of the recurrent state from the forward
                pass.
            nloops: torch tensor (B,) or None
                the 0-based step index at which h is inserted for each
                batch entry. If None, defaults to the last step,
                len(rs)-1.
        Returns:
            xs: torch tensor (B,len(rs),D)
            h: torch tensor (B,D)
                the state portion left over after the first step has
                been inverted.
        """
        B,D = h.shape
        if nloops is None:
            nloops = torch.full(
                (B,), len(rs)-1, dtype=torch.long, device=h.device)
        og_h = h
        h = torch.zeros_like(h)
        xs = []
        for i in reversed(range(len(rs))):
            # Insert the argued state at the step it was captured from
            idx = nloops==i
            h[idx] = og_h[idx]
            v = torch.cat([rs[i],h], dim=-1)
            v = self.rot_mtx.rot_inv(v)
            x, h = v[:,:D],v[:,D:]
            xs.append(x)
        xs = torch.stack(list(reversed(xs)),dim=1)
        return xs, h

class FixedMask(torch.nn.Module):
    """
    Non-learned swap mask: the first `n_units` entries are taken from
    the source vector and the remaining entries from the base vector.
    """
    def __init__(self, size, n_units=1):
        """
        size: int
            the length of the hidden state vector
        n_units: int
            how many leading units are swapped
        """
        super().__init__()
        self.size = size
        self.n_units = n_units
        # Kept so that this class exposes the same attribute as the
        # learned mask; it is not used by this class.
        self.temperature = None
        swap = torch.zeros(self.size).float()
        swap[:self.n_units].fill_(1)
        self.register_buffer("mask", swap)

    def get_boundary_mask(self):
        """Returns the fixed 0/1 mask buffer."""
        return self.mask

    def forward(self, base, source):
        """
        base: torch tensor (B,H)
            the vector that receives units from `source`
        source: torch tensor (B,H)
            the vector that donates its masked units

        Returns:
            base: torch tensor (B,H)
                `base` with the masked units replaced by those of
                `source`
        """
        keep = 1-self.mask
        return self.mask*source + keep*base

class BoundlessMask(torch.nn.Module):
    """
    Learned soft swap mask built from sigmoids sharpened by a
    temperature. Either one parameter per unit or two edge parameters.
    """
    def __init__(self,
                 size,
                 temperature=0.01,
                 full_boundary=False,
                 split_start=True,
                ):
        """
        size: int
            the size of the hidden state vector
        temperature: float
            divides the sigmoid inputs when the mask is computed
        full_boundary: bool
            if true, one parameter is created per unit of the swap
            mask; otherwise two parameters are used as the mask edges.
        split_start: bool
            if true, the initial parameters are chosen so that the
            first half of the per-unit parameters is negative (full
            boundary) or the lower edge sits at (size+1)//2 (edges).
            if false, the per-unit parameters all start at 1 and the
            edges start at -1 and size+1.
        """
        super().__init__()
        self.size = size
        self.temperature = temperature
        self.split_start = split_start
        self.register_buffer("indices", torch.arange(self.size).float())
        if full_boundary:
            start = torch.ones(size)
            if self.split_start:
                start[:size//2] *= -1
            self.boundaries = torch.nn.Parameter(start)
            mask_fxn = self.full_boundary_mask
        else:
            lower = (size+1)//2 if self.split_start else -1
            self.boundaries = torch.nn.Parameter(
                torch.FloatTensor([lower, size+1]))
            mask_fxn = self.edges_boundary_mask
        # The mask function is selected once at construction time
        self.get_boundary_mask = mask_fxn

    def full_boundary_mask(self):
        """Sigmoid of each per-unit parameter over the temperature."""
        scaled = self.boundaries/self.temperature
        return torch.sigmoid(scaled)

    def edges_boundary_mask(self):
        """
        Squared product of a rising sigmoid at the lower edge and a
        falling sigmoid at the upper edge, evaluated at every index.
        """
        lower, upper = self.boundaries
        rising = torch.sigmoid((self.indices - lower) / self.temperature)
        falling = torch.sigmoid((upper - self.indices) / self.temperature)
        return (rising * falling)**2

    def forward(self, base, source):
        """
        base: torch tensor (B,H)
            the vector that receives units from `source`
        source: torch tensor (B,H)
            the vector that donates its masked units

        Returns:
            base: torch tensor (B,H)
                the mask-weighted blend of `source` and `base`
        """
        gate = self.get_boundary_mask()
        return gate*source + (1-gate)*base

class CausalInterchange(torch.nn.Module):
    """
    Rotates a base and a source vector, swaps the masked units of the
    rotated base with those of the rotated source, and rotates the
    result back with the base rotation.
    """
    def __init__(self,
            size,
            temperature=0.01,
            fixed=None,
            full_boundary=False,
            double_rot=False,
            mu=0,
            sigma=1,
            relaxed=False,
            rot_first=False,
            sep_rot=False,
            rot_bias=False,
            identity_init=False,
            identity_rot=False,
            *args, **kwargs):
        """
        Args:
            size: int
                the size of the distributed vectors
            temperature: float
                used by the learned boundary mask. lower means more
                defined boundaries
            fixed: int or None
                if a truthy integer is argued, a fixed swap mask with
                that many units is used. Otherwise the mask is learned.
            full_boundary: bool
                if true, the learned mask has one parameter per index.
                otherwise it has two parameters denoting the edges.
            double_rot: bool
                if true, 2 rotation matrices are applied sequentially
                instead of 1.
            mu: float or FloatTensor (size,)
                Used to center each feature dim of the activations.
            sigma: float or FloatTensor (size,)
                Used to scale each feature dim of the activations.
            relaxed: bool
                if true, relaxes the orthonormal constraint on the
                rotation matrix to only be an invertible matrix.
            rot_first: bool
                only applies if relaxed is true. if true, the
                orthogonal rotation matrix is applied before the
                scaling matrix.
            sep_rot: bool
                if true, a separate rotation matrix is used for the
                source activations
            rot_bias: bool
                if true, a shifting term is included in the rotation
                matrix
            identity_init: bool
                if true, the rotation matrix starts as the identity
                matrix.
            identity_rot: bool
                if true, the rotation matrix is always reset to the
                identity. Used for debugging.
        """
        super().__init__()
        self.size = size
        self.mu = mu
        self.sigma = sigma
        self.double_rot = double_rot
        self.relaxed = relaxed
        self.rot_first = rot_first
        self.sep_rot = sep_rot
        self.rot_bias = rot_bias
        self.identity_init = identity_init
        self.identity_rot = identity_rot

        base_rot, source_rot = self.get_rot_mtxs()
        self.rot_mtx = base_rot
        self.source_mtx = source_rot

        if fixed:
            swapper = FixedMask(size, n_units=fixed)
        else:
            swapper = BoundlessMask(
                size, temperature, full_boundary=full_boundary,
            )
        self.swap_module = swapper

    def get_rot_mtxs(self):
        """
        Builds the rotation module for the base and the one for the
        source. The source module is a deep copy of the base module
        when sep_rot is true, otherwise it is the very same module.

        Returns:
            rot_mtx: torch Module
            source_mtx: torch Module
        """
        settings = dict(
            size=self.size,
            identity_init=self.identity_init,
            bias=self.rot_bias,
            rot_first=self.rot_first,
            double_rot=self.double_rot,
            mu=self.mu,
            sigma=self.sigma,
            identity_rot=self.identity_rot,
        )
        # The recurrent subclass takes priority over the relaxed flag
        if isinstance(self, RecurRotCausalInterchange):
            builder = RecurrentRotation
        elif self.relaxed:
            builder = RelaxedRotationMatrix
        else:
            builder = RotationMatrix
        rot_mtx = builder(**settings)

        source_mtx = copy.deepcopy(rot_mtx) if self.sep_rot else rot_mtx
        return rot_mtx, source_mtx

    def _swap(self, rot_base_h, rot_intr_h):
        """
        Blends the rotated base and source vectors using the current
        mask of the swap module.
        """
        mask = self.swap_module.get_boundary_mask()
        kept = (1-mask)*rot_base_h
        donated = mask*rot_intr_h
        return kept+donated

    def forward(self, base, source):
        """
        base: torch tensor (B,H)
            the vector that will receive new neurons
        source: torch tensor (B,H)
            the vector that will give neurons

        Returns:
            new_h: torch tensor (B,H)
                the base vector after the interchange, mapped back
                through the inverse of the base rotation
        """
        rotated_base = self.rot_mtx(base)
        rotated_source = self.source_mtx(source)
        mixed = self._swap(rotated_base, rotated_source)
        return self.rot_mtx(mixed, inverse=True)

    def unit_forward(self, base, source):
        """
        Use this function to see if the relaxed rotation performance is
        the same with unit axes. Relies on the `unit_forward` method of
        the rotation modules.

        Args:
            base: torch tensor (B,H)
                the vector that will receive new neurons
            source: torch tensor (B,H)
                the vector that will give neurons
        Returns:
            new_h: torch tensor (B,H)
                the base vector after the interchange
        """
        rotated_base = self.rot_mtx.unit_forward(base)
        rotated_source = self.source_mtx.unit_forward(source)
        mixed = self._swap(rotated_base, rotated_source)
        return self.rot_mtx.unit_forward(mixed, inverse=True)

class RecurRotCausalInterchange(CausalInterchange):
    def __init__(self, *args, **kwargs):
        """
        Similar to the CausalInterchange class, but accepts full sequence
        arguments.

        Args:
            See CausalInterchange for details
        """
        # The parent constructor already builds the recurrent rotation
        # modules: get_rot_mtxs checks for this subclass, and
        # RecurrentRotation doubles the size itself. Nothing needs to
        # be rebuilt or resized here.
        super().__init__(*args, **kwargs)

    def forward(self, base, source, base_idxs, source_idxs):
        """
        base: torch tensor (B,S,H)
            the sequence that will be intervened upon
        source: torch tensor (B,S,H)
            the sequence that will be harvested to intervene upon the
            base
        base_idxs: torch long tensor (B,)
            the index in the base sequence to intervene upon
        source_idxs: torch long tensor (B,)
            the index in the source sequence to intervene upon

        Returns:
            intrv: torch tensor (B,S,H)
                the intervened sequence
            h: torch tensor (B,H)
                the leftover recurrent state returned by the inverse
                pass of the rotation module
        """
        base_rs, _, rot_base_h = self.rot_mtx(base, nloops=base_idxs)
        _,_,rot_source_h = self.source_mtx(source,nloops=source_idxs)

        mask = self.swap_module.get_boundary_mask()
        masked_base = (1-mask)*rot_base_h
        masked_intr = mask*rot_source_h
        rot_swapped = masked_base+masked_intr

        intrv, h = self.rot_mtx(
            rot_swapped, rs=base_rs, nloops=base_idxs, inverse=True)
        # Wherever this mask is 1 the original base vectors are kept.
        # NOTE(review): presumably the helper marks the positions
        # after base_idxs -- confirm against utils.get_mask_past_idx.
        mask = utils.get_mask_past_idx(base.shape[:-1], base_idxs).long()
        mask = mask.unsqueeze(-1)
        intrv = (1-mask)*intrv + mask*base
        return intrv, h

def load_alignment(path, ret_config=False, map_location=None):
    """
    This is a helper function to load saved DAS checkpoint.

    The checkpoint is expected to hold a "config" dict with the
    CausalInterchange keyword arguments and a "state_dict" entry. If
    the state dict fails to load, missing `source_mtx.*` entries are
    filled in and loading is attempted once more.

    Args:
        path: str
        ret_config: bool
            if true, will return the configuration dict alongside the
            das module.
        map_location: optional
            forwarded to torch.load. The default, None, keeps the
            torch.load default behavior.
    Returns:
        das_modu: CausalInterchange module
            can access the rotation module using `das_modu.rot_mtx`
        conf: dict
            only returned if ret_config is true. Holds "config" and,
            if present in the checkpoint, "meta_config".
    """
    checkpt = torch.load(path, map_location=map_location)
    das_modu = CausalInterchange(**checkpt["config"])
    state = checkpt["state_dict"]
    try:
        das_modu.load_state_dict(state)
    except RuntimeError:
        print("Failed to load alignment state dict, attempting fix")
        for name,p in das_modu.named_parameters():
            print(name)
        for k,v in das_modu.state_dict().items():
            if k.split(".")[0]=="rot_mtx":
                kk = "source_mtx" + k[len("rot_mtx"):]
                # Only fill in entries the checkpoint lacks, and prefer
                # the checkpoint's own rot_mtx values over the values
                # of the freshly constructed module.
                if kk not in state:
                    state[kk] = state.get(k, v)
        das_modu.load_state_dict(state)
        print("Fix succeeded!")

    if ret_config:
        conf = {"config": checkpt["config"]}
        if "meta_config" in checkpt:
            conf["meta_config"] = checkpt["meta_config"]
        return das_modu, conf
    return das_modu

if __name__=="__main__":
    # Round trip sanity check: rotate a random batch and invert the
    # rotation, then print how far the result is from the input.
    size = 5
    seq_len = 10
    identity_init = False
    bias = False
    relaxed = False
    rand_bias = False
    rand_diag = True
    rot_first = True

    bsize = 3
    x = torch.rand(bsize,size)
    nloops = torch.randint(1,seq_len, (bsize,)).long()

    rot = Rotation(
        size=size,
        identity_init=identity_init,
        bias=bias,
        relaxed=relaxed,
        double_rot=True,
        rot_first=rot_first)

    if rand_bias and bias:
        # The bias Parameter lives on the wrapped matrix module;
        # `rot.bias` is only the boolean flag stored by Rotation.
        rot.rot_mtx.bias.data = torch.rand_like(rot.rot_mtx.bias.data)

    with torch.no_grad():
        rot_x = rot(x)
        fx = rot(rot_x, inverse=True)
    print("og:", x[0])
    print("trans:", fx[0])
    # Note: the printed value is the mean absolute difference
    print("MSE:", torch.mean(torch.abs(x-fx)).item())

