import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


import torch
import torchvision
import torchvision.transforms as transforms


import torch
import math
import warnings

from torch.nn.init import _calculate_fan_in_and_fan_out



def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with N(mean, std^2) samples restricted to [a, b].

    Inverse-transform sampling (copied from the PyTorch implementation, see
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf):
    draw uniformly between the erf-space images of the bounds, map back through
    ``erfinv``, rescale to the requested mean/std and finally clamp to [a, b]
    to absorb rounding. Runs under ``torch.no_grad()``.

    Args:
        tensor: tensor to fill (modified in place).
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower cutoff.
        b: upper cutoff.

    Returns:
        ``tensor`` itself.
    """
    def standard_cdf(value):
        # Cumulative distribution function of the standard normal.
        return (1. + math.erf(value / math.sqrt(2.))) / 2.

    too_low = mean < a - 2 * std
    too_high = mean > b + 2 * std
    if too_low or too_high:
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        lower_cdf = standard_cdf((a - mean) / std)
        upper_cdf = standard_cdf((b - mean) / std)

        # erf(x / sqrt(2)) == 2 * cdf(x) - 1, so this range is the erf image of [a, b].
        tensor.uniform_(2 * lower_cdf - 1, 2 * upper_cdf - 1)
        tensor.erfinv_()

        # Undo the standardisation: scale by std * sqrt(2), shift by mean.
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Guard against values nudged just outside [a, b] by rounding.
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fill ``tensor`` in place with samples from a truncated normal distribution.

    Samples follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`, generated by inverse-CDF sampling and clamped to the bounds.
    Results are most accurate when :math:`a \leq \text{mean} \leq b`; a warning
    is emitted when the mean lies more than two standard deviations outside
    the interval.

    Args:
        tensor: an n-dimensional `torch.Tensor`, filled in place.
        mean: the mean of the normal distribution.
        std: the standard deviation of the normal distribution.
        a: the minimum cutoff value.
        b: the maximum cutoff value.

    Returns:
        ``tensor`` itself.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> _ = trunc_normal_(w, std=.02)
    """
    filled = _no_grad_trunc_normal_(tensor, mean, std, a, b)
    return filled


def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    """Initialise ``tensor`` in place so its variance is ``scale / denom``.

    ``denom`` is picked by ``mode`` from the fans reported by
    ``_calculate_fan_in_and_fan_out`` (which requires at least 2 dimensions).

    Args:
        tensor: weight tensor to fill in place.
        scale: numerator of the target variance.
        mode: ``'fan_in'``, ``'fan_out'`` or ``'fan_avg'`` (mean of both fans).
        distribution: ``'truncated_normal'`` (truncated to two std, with the
            std corrected so the *truncated* distribution has the target
            variance), ``'normal'`` or ``'uniform'``.

    Raises:
        ValueError: if ``mode`` or ``distribution`` is not one of the values above.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2
    else:
        # Previously fell through and crashed with UnboundLocalError on `denom`.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        # U(-b, b) has variance b^2 / 3.
        bound = math.sqrt(3 * variance)
        tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")


def lecun_normal_(tensor):
    """LeCun-normal initialisation, in place.

    Truncated normal with variance ``1 / fan_in`` (delegates to
    ``variance_scaling_``). Returns ``None``.
    """
    variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='truncated_normal')
    
    
 

def rot_grid_to_idx(dim):
    """Build gather indices and weight bins for rotation-aware attention mixing.

    For a ``dim x dim`` patch grid with ``n = dim**2`` patches, every ordered
    pair ``(P0, P)`` with ``P0 = (i0, j0)`` and ``P = (i, j)`` contributes three
    consecutive entries (row-major over ``(P0, P)``, so the triple for the
    matrix entry ``(I, J)`` starts at ``3 * (I * n + J)``). Each entry indexes
    the flattened ``n x n`` attention matrix:

    * ``P == P0``, or the helper patch ``M`` falls outside the grid: the direct
      entry ``(P0, P)`` three times, each with weight 1/3.
    * otherwise: entries ``(P, M)``, ``(M, P0)`` and ``(P0, P)``. Here
      ``M = P + step`` and ``step`` is the offset ``P - P0`` rotated by
      ``right90``, normalised, rounded, with diagonal directions snapped onto
      one axis. The weights are ``angle1 / 2``, ``angle2 / 2`` and
      ``1 - angle1 / 2 - angle2 / 2``, where ``angle1`` / ``angle2`` are the
      angles (as fractions of pi) between ``P0 - M`` and, respectively, the
      rotated offset and the raw offset.

    Args:
        dim: side length of the square patch grid.

    Returns:
        idxs: long tensor ``(3 * n**2,)`` of indices into the flattened n x n matrix.
        idxs_angles: long tensor ``(3 * n**2,)`` giving each entry's weight bin
            (weights rounded to 5 decimals).
        angles: sorted unique raw (unrounded) weight values; its length can
            differ from ``dimu``.
        dimu: number of distinct rounded weight values (number of bins).
    """
    n = dim ** 2
    idxs = torch.zeros((3 * n ** 2))
    angles = torch.zeros((3 * n ** 2))
    N = torch.arange(0, n).reshape(dim, dim)       # patch id of grid cell (i, j)
    N2 = torch.arange(0, n ** 2).reshape(n, n)     # flat id of matrix entry (I, J)
    # Maps an offset (di, dj) to (dj, -di); loop-invariant, so built once.
    right90 = torch.tensor([[0, 1], [-1, 0]]).float()

    def _fill_uniform(ci, i0, j0, i, j):
        # Degenerate triple: the direct entry (P0, P) three times, 1/3 each.
        direct = N2[N[i0, j0], N[i, j]]
        idxs[ci] = direct
        idxs[ci + 1] = direct
        idxs[ci + 2] = direct
        angles[ci] = 1 / 3
        angles[ci + 1] = 1 / 3
        angles[ci + 2] = 1 / 3

    ci = 0
    for i0 in range(0, dim):
        for j0 in range(0, dim):
            for i in range(0, dim):
                for j in range(0, dim):
                    if (i, j) == (i0, j0):
                        _fill_uniform(ci, i0, j0, i, j)
                        ci += 3
                        continue

                    vec = right90 @ torch.tensor([i - i0, j - j0]).float()
                    vecn = torch.round(vec / torch.norm(vec))
                    # Snap diagonal directions onto a single axis.
                    if torch.equal(vecn, torch.tensor([1, 1]).float()):
                        vecn = torch.tensor([1, 0])
                    elif torch.equal(vecn, torch.tensor([1, -1]).float()):
                        vecn = torch.tensor([0, -1])
                    elif torch.equal(vecn, torch.tensor([-1, 1]).float()):
                        vecn = torch.tensor([0, 1])
                    elif torch.equal(vecn, torch.tensor([-1, -1]).float()):
                        vecn = torch.tensor([-1, 0])
                    vect = vecn + torch.tensor([i, j])
                    idx = vect.long()

                    # Helper patch outside the grid: fall back to the uniform triple.
                    # (Previously detected via a raised Exception / IndexError caught
                    # by a bare `except:`; this check covers exactly those cases.)
                    if bool((idx < 0).any()) or bool((idx >= dim).any()):
                        _fill_uniform(ci, i0, j0, i, j)
                        ci += 3
                        continue

                    tri_start = N[i, j]
                    tri_middle = N[idx[0], idx[1]]
                    tri_end = N[i0, j0]
                    idxs[ci] = N2[tri_start, tri_middle]
                    idxs[ci + 1] = N2[tri_middle, tri_end]
                    idxs[ci + 2] = N2[N[i0, j0], N[i, j]]
                    vecf = (torch.tensor([i0, j0]) - vect).float()
                    vec_i = torch.tensor([i - i0, j - j0]).float()

                    unit_f = vecf / torch.norm(vecf)
                    angle1 = torch.acos((vec / torch.norm(vec)) @ unit_f) * 1 / math.pi  # maximal 1
                    angle2 = torch.acos((vec_i / torch.norm(vec_i)) @ unit_f) * 1 / math.pi

                    angles[ci] = angle1 / 2
                    angles[ci + 1] = angle2 / 2
                    angles[ci + 2] = 1 - angle1 / 2 - angle2 / 2
                    ci += 3

    # Bin the weights by their value rounded to 5 decimals.
    dist = torch.round(100000 * angles)
    angles = torch.unique(angles.flatten(), dim=-1)
    unique = torch.unique(dist.flatten(), dim=-1)

    dimu = unique.shape[0]
    idxs_angles = unique.shape[0] * torch.ones(dist.shape)
    for k in range(dimu):
        mask = (dist == unique[k])
        idxs_angles[mask] = k

    return idxs.long(), idxs_angles.long(), angles, dimu


#-----------------------------------------------------------------------------
#-----------------------------------------------------------------------------






#-----------------------------------------------------------------------------
#-----------------------------------------------------------------------------


class Rotation_Symmetry_Break(nn.Module):
    """Learnable, angle-binned re-weighting of a square attention map.

    Every entry of the ``(P x P)`` map (``P = num_patches``) is replaced by a
    weighted sum of three gathered entries chosen by ``rot_grid_to_idx``. The
    weight of each gathered entry is a per-head parameter selected by the
    entry's angle bin.
    """

    def __init__(self, num_patches, num_heads):
        super().__init__()
        self.num_patches = num_patches

        side = int(math.sqrt(num_patches))
        gather_idx, bin_idx, _, num_bins = rot_grid_to_idx(side)
        # Plain tensor attributes (not registered buffers).
        self.idxs = gather_idx
        # One learnable weight per (head, angle bin).
        self.angles = nn.Parameter(torch.Tensor(num_heads, num_bins), requires_grad=True)
        self.idxs_angles = bin_idx
        nn.init.kaiming_uniform_(self.angles, a=math.sqrt(5))

    def forward(self, x):
        # The 4-way unpack fixes rank 4; the last dim is used as the square side
        # and the second dim must equal the parameter's head count.
        batch, heads, _, side = x.size()

        # Gather three entries per output position (keeps rotation/translation
        # structure, breaks mirror symmetry).
        gathered = x.flatten(-2)[:, :, self.idxs]
        bin_ids = self.idxs_angles.flatten().to(x.device)
        per_entry = torch.index_select(self.angles, 1, bin_ids)
        weighted = gathered * per_entry.view(1, heads, 3 * self.num_patches ** 2)

        # Consecutive triples belong to the same output entry; collapse them.
        mixed = weighted.reshape((batch, heads, side, side, 3)).sum(-1)
        return mixed.reshape(x.shape)





#-----------------------------------------------------------------------------
#-----------------------------------------------------------------------------

    
def grid_to_idx(dim,features):
    """Assign every ordered pair of cells of a ``dim x dim`` grid to a bin id.

    The pair score is ``10000 * |P0 - P| * ((dim+1) - di) / ((dim+1) - dj)``,
    with ``(di, dj) = P - P0``. Scores are rounded, and identical rounded scores
    share a bin (bins are numbered in ascending score order). Each of the
    ``features`` copies is offset by ``feature * dimu`` so the copies index
    disjoint ranges.

    Returns:
        (idxs, dimu): ``idxs`` is a long tensor of shape ``(features, n, n)``
        with ``n = dim**2`` and cells ordered row-major; ``dimu`` is the number
        of distinct bins.
    """
    n = dim ** 2
    cells = [(r, c) for r in range(dim) for c in range(dim)]  # row-major cell order

    scores = []
    for i0, j0 in cells:
        for i, j in cells:
            scores.append(10000*math.sqrt((i0-i)**2 + (j0-j)**2) *((dim+1) -(i-i0)) / ((dim+1) -(j-j0)))
    dist = torch.round(torch.tensor(scores).reshape(n, n))

    # return_inverse gives, for every element, its position in the sorted uniques.
    unique, inverse = torch.unique(dist.flatten(), return_inverse=True)
    dimu = unique.shape[0]

    offsets = torch.arange(0, dimu * features, dimu).view(features, 1, 1)
    idxs = inverse.reshape(n, n).unsqueeze(0) + offsets
    return idxs.long(), dimu
    
#grid_to_idx(2,3)
    

class Weighted_Symmetry_Break(nn.Module):
    """ PI Symmetry breaking Linear('ish) layer.

    Applies ``self.triangle_sum`` along the token axis of every
    (batch, head, feature) slice of ``x``. Token 0 (the first row along
    dim -2) is passed through unchanged.
    """
    def __init__(self, num_patches, num_heads):
        super().__init__()
        self.num_patches = num_patches
        self.num_heads = num_heads

        # NOTE(review): ``mirror_grid_to_idx`` and ``trisum`` are not defined in
        # the visible part of this file — confirm they exist elsewhere.
        idxs, angles = mirror_grid_to_idx(int(math.sqrt(num_patches)))
        self.triangle_sum = trisum(num_patches, idxs, angles)

    def forward(self, x):
        # The 4-way unpack fixes rank 4: (batch, heads, tokens, features).
        # Bug fix: the old unpack ``bs, heads, ps, heads`` rebound ``heads`` to
        # the feature size, so the reshape below only worked when
        # heads == features and raised otherwise.
        bs, heads, ps, feat = x.size()

        # Put tokens last and fold (batch, heads, features) into one axis.
        y = x.transpose(-2, -1).flatten(end_dim=-2)
        # Transform the patch tokens only (token 0 is kept as-is).
        y = self.triangle_sum(y[:, 1:]).reshape(bs, heads, feat, ps - 1).transpose(-2, -1)
        y = torch.cat([x[:, :, 0:1, :], y], dim=-2)
        return y.reshape(x.shape)





#-------------Vit implementation ----------------
from functools import partial
import torch
import torch.nn as nn
#from algorithms.layers import to_2tuple, trunc_normal_, lecun_normal_

from itertools import repeat
import collections.abc


# From PyTorch internals
def _ntuple(n):
    """Return a converter that broadcasts a non-iterable to an ``n``-tuple.

    Iterables (including strings) are returned unchanged: not copied and not
    length-checked.
    """
    def parse(x):
        if not isinstance(x, collections.abc.Iterable):
            return (x,) * n
        return x
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)


def grid_avg_dist_to_idx(dim,features):
    """Bin every cell of a ``dim x dim`` grid by its rounded total distance.

    For each cell, the Euclidean distances to all cells (itself included) are
    summed and rounded. Cells with equal rounded totals share a bin (bins are
    numbered in ascending order). Feature column ``f`` adds ``f * num_bins``
    so the features index disjoint ranges.

    Returns:
        (idxs, num_bins): ``idxs`` is a long tensor of shape ``(n, features)``
        with ``n = dim**2`` and cells ordered row-major.
    """
    n = dim ** 2
    coords = [(r, c) for r in range(dim) for c in range(dim)]
    pair_dist = torch.tensor(
        [math.sqrt((i0 - i) ** 2 + (j0 - j) ** 2) for (i0, j0) in coords for (i, j) in coords]
    ).reshape(n, n)

    totals = torch.round(torch.sum(pair_dist, dim=-1))
    unique, inverse = torch.unique(totals, return_inverse=True)
    num_bins = unique.shape[0]

    offsets = torch.arange(0, num_bins * features, num_bins).repeat(n).reshape(n, features)
    idxs = inverse.unsqueeze(-1).repeat(1, features) + offsets
    return idxs.long(), num_bins
    
#print(grid_avg_dist_to_idx(5,3))

    
      
    
import torch.nn.functional as F

class Sym_Break_Linear(nn.Module):
    """ PI Symmetry breaking Linear('ish) layer.

    Applies a depthwise 2-D convolution to the patch tokens ``x[:, 1:]``. The
    kernel is ``(idim + 1) x (idim + 1)`` and is parameterised by a weight
    vector indexed through ``grid_dist_to_idx``, so kernel taps at equal
    distance from the kernel centre share a weight. This is followed by a
    ``Linear`` over the ``num_patches - 1`` token axis.
    """
    def __init__(self, num_patches, in_features):
        super().__init__()
        self.num_patches=num_patches
        # Side of the square patch grid.
        self.idim =  int(math.sqrt(num_patches)) 
        self.in_features = in_features
        #for weights
        self.kernel_size = self.idim + 1
        self.padding = self.idim //2
        # NOTE(review): grid_dist_to_idx documents that it is only correct for an
        # odd dim; kernel_size is even whenever idim is odd.
        idxsW, dimW = grid_dist_to_idx(dim = self.kernel_size) #int(math.sqrt(num_patches)))
        self.idxsW = idxsW
        self.dimW = dimW
        #for bias

        # One weight per (output channel, distance bin).
        self.weights = nn.Parameter(torch.Tensor(in_features,dimW))  
      #  self.bias = nn.Parameter(torch.Tensor(dim,in_features))
        
        # initialize weights and biases
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5)) # weight init
        self.proj = nn.Linear(num_patches-1,num_patches-1)

    def forward(self, x):
        #if self.mult is None:
        # The 4-way unpack fixes rank 4; presumably (batch, tokens, heads,
        # head_dim) — TODO confirm. Note: the local name `np` shadows numpy.
        bs, np, h,hd = x.size()
        # Expand the shared distance-bin weights into a (in_features, 1, k, k) kernel.
        W =  torch.index_select(self.weights , 1, self.idxsW.flatten().to(x.device)).reshape((self.in_features,1,self.kernel_size,self.kernel_size))
        # groups=h with a (in_features, 1, k, k) weight requires in_features to be a multiple of h.
        y = F.conv2d(x[:,1:].transpose(1,3).reshape(bs*hd,h,self.idim, self.idim), W, padding=self.padding, groups=h)
        #y = torch.einsum('hnj,bhjk->bhnk',W,x[:,1:].transpose(1,2).reshape( bs, h , np-1, hd))
        # This reshape needs conv output channels == h and spatial size == np - 1
        # (the padded conv keeps the size only when idim is even).
        y = self.proj(y.reshape(bs, hd,h,np-1)).transpose(1,3)
  
        # NOTE(review): the left operand is (bs, h, 1, hd) and ``y`` is
        # (bs, np-1, h, hd), so this cat along dim 2 only works when
        # h == np - 1, and then it mixes the head and token axes. The intended
        # output layout needs confirming.
        return   torch.cat([x[:,0:1].permute(0,2,1,3), y],dim=2) 
    

    # NOTE(review): class-level import; it binds ``torch`` as a class attribute
    # and looks accidental.
    import torch




 #-----------------------------------------------------------------------------
#--------------------------------------------------------------------------   

def grid_dist_to_idx_mirror(dim,feat=1):
    """Bin every ordered pair of cells of a ``dim x dim`` grid by distance.

    Euclidean distances are scaled by 1e5 and rounded. Pairs with equal rounded
    distance share a bin, and bins are numbered in ascending distance order
    (so self-pairs get bin 0). ``feat`` is accepted for signature
    compatibility but not used.

    Returns:
        (idxs, dimu): ``idxs`` is a long tensor of shape ``(n, n)`` with
        ``n = dim**2`` and cells ordered row-major; ``dimu`` is the number of bins.
    """
    n = dim ** 2
    coords = [(r, c) for r in range(dim) for c in range(dim)]
    pair_dist = torch.tensor(
        [math.sqrt((i0 - i) ** 2 + (j0 - j) ** 2) for (i0, j0) in coords for (i, j) in coords]
    ).reshape(n, n)

    dist = torch.round(100000 * pair_dist)
    unique, inverse = torch.unique(dist.flatten(), return_inverse=True)
    return inverse.reshape(dist.shape).long(), unique.shape[0]
    
 
            
        
#-----------------------------------------------------------------------------
#-----------------------------------------------------------------------------

#-----------------------------------------------------------------------------
#----------------------------------------------------------------------------- 

class Sym_Break_Had(nn.Module):
    """Element-wise (Hadamard) distance-binned scaling of an attention map.

    The patch-patch block ``x[:, :, 1:, 1:]`` is multiplied by a per-feature
    weight chosen by the distance bin of each patch pair (see
    ``grid_dist_to_idx_mirror``). Row 0 and column 0 pass through unchanged.
    ``interaction_length`` is accepted but not used.
    """
    def __init__(self, num_patches, feat, interaction_length=False):
        super().__init__()
        self.num_patches = num_patches
        self.feat = feat

        side = int(math.sqrt(num_patches))
        pair_bins, num_bins = grid_dist_to_idx_mirror(side, feat=feat)
        self.idxsW = pair_bins
        self.dimW = num_bins

        # One weight per (feature, distance bin).
        self.weights = nn.Parameter(torch.Tensor(feat, num_bins))
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))

    def forward(self, x):
        expanded = torch.index_select(self.weights, 1, self.idxsW.flatten().to(x.device))
        # The 4-way unpack fixes rank 4; dim 2 gives the token count (cls + patches).
        _, _, tokens, _ = x.size()

        scale = expanded.view(1, self.feat, tokens - 1, tokens - 1)
        inner = scale * x[:, :, 1:, 1:]

        upper = torch.cat([x[:, :, 0:1, 1:], inner], dim=2)
        return torch.cat([x[:, :, :, 0:1], upper], dim=-1)


class Sym_Break_Bias(nn.Module):
    """Additive distance-binned bias on an attention map.

    A per-feature bias, chosen by the distance bin of each patch pair (see
    ``grid_dist_to_idx_mirror``), is added to the patch-patch block
    ``x[:, :, 1:, 1:]``. Row 0 and column 0 pass through unchanged.
    ``interaction_length`` is accepted but not used.
    """
    def __init__(self, num_patches, feat, interaction_length=False):
        super().__init__()
        self.num_patches = num_patches
        self.feat = feat

        side = int(math.sqrt(num_patches))
        pair_bins, num_bins = grid_dist_to_idx_mirror(side, feat=feat)
        self.idxsW = pair_bins
        self.dimW = num_bins

        # One bias per (feature, distance bin).
        self.weights = nn.Parameter(torch.Tensor(feat, num_bins))
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))

    def forward(self, x):
        expanded = torch.index_select(self.weights, 1, self.idxsW.flatten().to(x.device))
        # The 4-way unpack fixes rank 4; dim 2 gives the token count (cls + patches).
        _, _, tokens, _ = x.size()

        bias = expanded.view(1, self.feat, tokens - 1, tokens - 1)
        inner = bias + x[:, :, 1:, 1:]

        upper = torch.cat([x[:, :, 0:1, 1:], inner], dim=2)
        return torch.cat([x[:, :, :, 0:1], upper], dim=-1)
    

#-----------------------------------------------------------------------------
#-----------------------------------------------------------------------------


class Mlp(nn.Module):
    """Two-layer perceptron: ``fc1 -> act -> fc2``.

    ``hidden_features`` and ``out_features`` default to ``in_features`` when
    they are falsy.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

def safe_sqrt(x):
    """Element-wise square root that returns 0 where ``x`` is negative (or NaN).

    Two fixes over the previous version:

    * The input to the root is masked *before* ``pow``. ``torch.where``
      backpropagates ``0 * d/dx sqrt(x)`` into the unselected branch, and for
      negative ``x`` that is ``0 * NaN = NaN``, so negative inputs used to get
      NaN gradients. They now get a zero gradient.
    * The zero fill is ``zeros_like(x)``, so it matches the input's dtype and
      device, instead of a default-dtype CPU tensor copied to the device.
    """
    nonneg = x >= 0.0
    zeros = torch.zeros_like(x)
    masked = torch.where(nonneg, x, zeros)
    return torch.where(nonneg, torch.pow(masked, 1 / 2), zeros)

    


def grid_dist_to_idx(dim = 5):
    """Bin the cells of a ``dim x dim`` kernel by distance to its centre cell.

    The centre is ``((dim - 1) // 2, (dim - 1) // 2)``, which is the true centre
    only when ``dim`` is odd. Distances are scaled by 1e6 and rounded. Cells
    with equal rounded distance share a bin, and bins are numbered in
    ascending distance order (the centre gets bin 0).

    Returns:
        (idxs, dimu): ``idxs`` is a long tensor of shape ``(dim, dim)``;
        ``dimu`` is the number of bins.
    """
    centre = (dim - 1) // 2
    flat = [math.sqrt((centre - r) ** 2 + (centre - c) ** 2)
            for r in range(dim) for c in range(dim)]
    dist = torch.round(1000000 * torch.tensor(flat).reshape(dim, dim))

    unique, inverse = torch.unique(dist.flatten(), return_inverse=True)
    return inverse.reshape(dist.shape).long(), unique.shape[0]
    
import torch.nn.functional as F

class Sym_Break_Linear_base(nn.Module):
    """Depthwise distance-binned convolution over patch tokens.

    The patch tokens ``x[:, 1:]`` are laid out on an ``idim x idim`` grid and
    convolved per channel with an ``(idim + 1) x (idim + 1)`` kernel. Kernel
    taps at equal distance from the kernel centre share a weight (see
    ``grid_dist_to_idx``). Token 0 passes through unchanged.
    """
    def __init__(self, num_patches, in_features):
        super().__init__()
        self.num_patches = num_patches
        self.idim = int(math.sqrt(num_patches))
        self.in_features = in_features

        # Kernel geometry.
        self.kernel_size = self.idim + 1
        self.padding = self.idim // 2
        tap_bins, num_bins = grid_dist_to_idx(dim=self.kernel_size)
        self.idxsW = tap_bins
        self.dimW = num_bins

        # One weight per (channel, distance bin).
        self.weights = nn.Parameter(torch.Tensor(in_features, num_bins))
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))

    def forward(self, x):
        # The 3-way unpack fixes rank 3: (batch, tokens, channels).
        batch, tokens, channels = x.size()
        kernel = torch.index_select(self.weights, 1, self.idxsW.flatten().to(x.device))
        kernel = kernel.reshape((self.in_features, 1, self.kernel_size, self.kernel_size))

        grid = x[:, 1:].transpose(1, 2).reshape(batch, channels, self.idim, self.idim)
        mixed = F.conv2d(grid.transpose(2, 3), kernel, padding=self.padding, groups=channels)
        mixed = mixed.transpose(2, 3).reshape(batch, channels, tokens - 1).transpose(1, 2)

        return torch.cat([x[:, 0:1], mixed], dim=1)
    


class Attention_base(nn.Module):
    """Symmetrised multi-head attention whose values are the raw input tokens.

    Only queries and keys are projected (``qkv`` outputs ``2 * dim``) and then
    passed through ``Sym_Break_Linear_base``. The score matrix is symmetrised
    (``S + S^T``) before the softmax. ``use_soft`` is accepted but not used.
    """

    def __init__(self, dim, num_patches, num_heads=8, qkv_bias=False, qk_scale=None, use_soft=True):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.qkv_SymBreak = Sym_Break_Linear_base(num_patches, dim * 2)
        # SiT / SiT* variants additionally apply Rotation_Symmetry_Break to the scores.
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        head_dim = C // self.num_heads

        qk = self.qkv_SymBreak(self.qkv(x))
        qk = qk.reshape(B, N, 2, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k = qk[0], qk[1]
        # Values are the unprojected input, split into heads.
        values = x.reshape(B, N, self.num_heads, head_dim).transpose(1, 2)

        scores = (q @ k.transpose(-1, -2)) * self.scale
        scores = scores + scores.transpose(-2, -1)
        weights = scores.softmax(dim=-1)

        out = (weights @ values).transpose(1, 2).reshape(B, N, C)
        return self.proj(out)



class Attention_Basic(nn.Module):
    """Symmetrised multi-head attention over externally supplied q, k and v.

    ``q``, ``k`` and ``v`` are split into heads; the score matrix is
    symmetrised (``S + S^T``) before the softmax, and the result goes through
    an output projection. ``num_patches``, ``qkv_bias`` and ``use_soft`` are
    accepted but not used.
    """

    def __init__(self, dim, num_patches, num_heads=8, qkv_bias=False, qk_scale=None, use_soft=True):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.proj = nn.Linear(dim, dim)

    def forward(self, q, k, v):
        B, N, C = v.shape
        head_dim = C // self.num_heads
        # Split each input into heads: (B, N, C) -> (B, heads, N, head_dim).
        q, k, v = (t.reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3) for t in (q, k, v))

        scores = (q @ k.transpose(-1, -2)) * self.scale
        scores = scores + scores.transpose(-2, -1)
        weights = scores.softmax(dim=-1)

        merged = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(merged)
    

    
class Block_Basic(nn.Module):
	"""Pre-norm residual attention block built on ``Attention_Basic``.

	NOTE(review): ``Attention_Basic.forward`` requires three tensors
	``(q, k, v)``, but ``forward`` below passes only two, so calling this
	block raises ``TypeError``. Which input should feed q, k and v cannot be
	determined from this file — confirm before use.
	"""
	def __init__(self, dim, num_heads, num_patches, mlp_ratio=1., qkv_bias=False, qk_scale=None,
				 act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_soft =True):
		super().__init__()
		self.norm1 = norm_layer(dim)
		self.norm2 = norm_layer(dim)
		# NOTE: ``mlp_ratio`` and ``self.act`` are not used by ``forward``.
		self.act= act_layer()
		self.attn = Attention_Basic(dim, num_patches = num_patches, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, use_soft = use_soft)


	def forward(self, x):
		# ``x`` is indexed as a pair: x[0] is normalised by norm2, x[1] is the
		# residual stream (presumably a (v, x) tuple — confirm against callers).
		v = x[0]
		x = x[1]
		x = x +self.attn(self.norm1(x),self.norm2(v))

		return x       
    
    
    
class Block_base(nn.Module):
    """Pre-norm residual block: ``x + Attention_base(norm1(x))``.

    ``mlp_ratio`` and ``self.act`` are kept for interface compatibility but are
    not used by ``forward``.
    """

    def __init__(self, dim, num_heads, num_patches, mlp_ratio=1., qkv_bias=False, qk_scale=None,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_soft=True):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.act = act_layer()
        self.attn = Attention_base(dim, num_patches=num_patches, num_heads=num_heads,
                                   qkv_bias=qkv_bias, qk_scale=qk_scale, use_soft=use_soft)

    def forward(self, x):
        update = self.attn(self.norm1(x))
        return x + update
    
    
    


class VisionTransformer(nn.Module):
    """Class-token transformer over pre-embedded patch tokens.

    ``forward`` expects tokens of width ``embed_dim``. It prepends a learnable
    class token, runs ``depth`` ``Block_base`` layers and a final LayerNorm,
    then returns ``[projf(patch_scores), projf2(cls)]`` concatenated to width
    ``2 * (embed_dim // 2)``. The patch scores come from a small MLP applied
    per patch token, so the number of patch tokens must equal ``embed_dim``
    for the reshape to succeed. ``in_chans`` is accepted but not used.
    """

    def __init__(self, num_patches, in_chans=3, embed_dim=64, depth=4,
                 num_heads=8, mlp_ratio=4., qkv_bias=True, qk_scale=None, use_soft=True):
        super().__init__()
        self.embed_dim = embed_dim
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.proj = nn.Linear(embed_dim, 16)
        self.proj2 = nn.Linear(16, 1)
        self.projf = nn.Linear(embed_dim, embed_dim // 2)
        self.projf2 = nn.Linear(embed_dim, embed_dim // 2)
        layers = [
            Block_base(dim=embed_dim, num_heads=num_heads, num_patches=num_patches,
                       mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                       norm_layer=norm_layer, act_layer=nn.GELU, use_soft=use_soft)
            for _ in range(depth)
        ]
        self.blocks = nn.Sequential(*layers)
        self.norm = norm_layer(embed_dim)

        # Weight init
        trunc_normal_(self.cls_token, std=.02)
        self.apply(_init_vit_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def forward(self, x):
        batch = x.size(0)
        tokens = torch.cat((self.cls_token.expand(batch, -1, -1), x), dim=1)
        tokens = self.norm(self.blocks(tokens))

        cls_out = tokens[:, 0]
        # One scalar per patch token, reshaped to the width of the class token.
        patch_scores = self.proj2(F.gelu(self.proj(tokens[:, 1:]))).reshape(cls_out.shape)
        return torch.cat([self.projf(F.gelu(patch_scores)), self.projf2(F.gelu(cls_out))], dim=-1)

    



        
class Siet(nn.Module):
    """Channel projection + ``Sym_Break_Linear_Block`` + ``VisionTransformer``.

    ``forward`` projects the input channels to ``embed_dim``, runs the graph
    block over the image, reshapes the result to
    ``(bs, num_patches, embed_dim)`` and feeds it to the transformer.
    ``action_dim`` is stored but not used in this class, and ``use_soft`` is
    not forwarded to ``VisionTransformer``. ``get_patches_mini`` and
    ``get_patches_flat`` are not called by ``forward``.
    """
    def __init__(self, img_size=64, action_dim=15, patch_size=8,  patch_size_local=12, in_chans=9, embed_dim=64, depth=2,
                num_heads=8, mlp_ratio=4., qkv_bias=True, qk_scale=None, use_soft = True):
        super().__init__()
        self.action_dim =action_dim
        self.patch_size = patch_size
        self.patch_size_l= patch_size_local
        self.embed_dim = embed_dim
       # self.embed_dim_local = embed_dim // 2
        self.num_patches =( img_size //patch_size)**2

        self.graph_block = Sym_Break_Linear_Block(img_size=img_size, patch_size = patch_size,in_features =embed_dim,num_patches =self.num_patches, num_heads=num_heads)
      
        self.vit_global = VisionTransformer(num_patches=self.num_patches ,in_chans=embed_dim, embed_dim=embed_dim, depth=depth,
             num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale)
            
       # self.mlp = Mlp(in_chans, hidden_features=embed_dim, out_features=embed_dim )
        
       # self.mlp2 = Mlp(embed_dim, hidden_features=embed_dim, out_features=action_dim )
        # Per-pixel channel projection in_chans -> embed_dim.
        self.proj = nn.Linear(in_chans, self.embed_dim)
        #self.proj2 = nn.Linear( (self.num_patches_local+1)*self.num_head_local ,embed_dim)
        # Gather indices used only by get_patches_mini (plain tensor, not a buffer).
        self.patch_idxs = self.create_patching_idxs(dim = img_size, ps = patch_size_local, ps2 =patch_size) #.flatten() 
        
    
    def create_patching_idxs(self,dim = 32, ps= 3*4, ps2 = 4):
        """Return flat indices of ``ps x ps`` windows in a padded, flattened image.

        The indices address a ``(dim + 2*ps2) x (dim + 2*ps2)`` grid flattened
        row-major. Window ``(i, j)``, for ``i, j < dim // ps2``, covers rows
        ``i:i+ps`` and columns ``j:j+ps``. The result is a 1-D long tensor of
        length ``(dim // ps2)**2 * ps**2``.

        NOTE(review): window origins advance by 1 (``i``), not by ``ps2`` —
        confirm this is intended.
        """
        pad_dim = ps2 #int((ps-1)/2) 
        image_dim_pad = (2*pad_dim  + dim)
        arr = torch.tensor(list(range(image_dim_pad**2))).reshape((image_dim_pad,image_dim_pad)).long()
        idxs = torch.zeros((dim//ps2,dim//ps2,ps,ps)).long()
        for i in range(dim//ps2):
            for j in range(dim//ps2):
                    idxs[i,j,:,:] =  arr[i:i +ps ,j:j+ps]

        return idxs.flatten()
        
    def get_patches_mini(self,x):
        """Gather local windows from a 2x-subsampled, zero-padded image.

        ``x`` is unpacked as ``(bs, h, w, c)``. The four stride-2 sub-grids
        are stacked on the channel axis (4*c channels at roughly half the
        resolution), zero-padded by ``patch_size_l // 3`` on every side and
        gathered with ``self.patch_idxs``.

        NOTE(review): ``patch_idxs`` indexes an ``(img_size + 2*patch_size)``-wide
        grid, but the padded tensor here is ``(h/2 + 2*(patch_size_l//3))``
        wide. For the defaults (64, 8, 12) that is 80 vs 40, so the
        ``index_select`` would go out of range. The reshape also uses
        ``patch_size`` as the window-grid size, which matches
        ``img_size // patch_size`` only for the defaults. Confirm before use.
        """
        bs, h,w, c = x.shape
        padding = int((self.patch_size_l)//3) # padding = patch_size_l // 3 (4 for the default 12)
        x = torch.cat([x[:,::2,::2],x[:,1::2,1::2],x[:,1::2,::2],x[:,::2,1::2]],dim=-1)
        y = F.pad(x.permute(0,3,1,2), (padding, padding, padding, padding), mode='constant', value=0)
        y = torch.index_select(y.flatten(-2), 2, self.patch_idxs.to(x.device))
        #print(y.shape)
        y =y.reshape(bs,4*c,self.patch_size,self.patch_size,self.patch_size_l,self.patch_size_l)
        
        return y.permute(0,2,3,4,5,1).reshape((-1,  self.patch_size_l*self.patch_size_l,4*c)) #
        
    def get_patches_flat(self, x, ps):
        """Split ``x`` of shape ``(bs, h, w, c)`` into non-overlapping ``ps x ps`` patches.

        Returns a tensor of shape ``(bs * (h//ps) * (w//ps), ps*ps, c)``, with
        patches in row-major order and pixels row-major within each patch.
        """
        bs,  h, w,c = x.size()    
        patches = x.unfold(1, ps,   ps).permute(0,1,4, 2, 3)
        patches = patches.unfold(3,  ps,  ps).permute(0, 1, 3, 2,5,4)  
        return patches.reshape((-1, ps**2,c)) #

    
    def forward(self, x):
        # Presumably x is channels-first (bs, in_chans, h, w) — the permute below
        # moves channels last so ``proj`` (in_chans -> embed_dim) applies per pixel.
        x = x.permute(0,2,3,1)
        bs, h,w,c = x.shape
        x = self.proj(x)
        x = x.permute(0,3,1,2)

        x =  self.graph_block(x).reshape(bs, self.num_patches,self.embed_dim )

        x = self.vit_global(x)
        x =x.flatten(1)
    
        return x
    
    
    
    

class Rotation_Symmetry_Break2(nn.Module):
    """Re-weight the inner block of a square per-head matrix with shared weights.

    For an input ``x`` of shape ``(bs, heads, ps, ps)``, row 0 and column 0 pass
    through unchanged. Every entry of the remaining ``(ps-1) x (ps-1)`` block is
    replaced by a sum of three terms. Each term is a value gathered from that
    block via ``self.idxs``, scaled by a learned per-head weight taken from
    ``self.angles`` via ``self.idxs_angles``. Both index tables come from
    ``rot_grid_to_idx``. Per the original author's note, the indexing is meant
    to keep rotation/translation invariance while breaking mirror invariance.

    NOTE(review): ``forward`` assumes ``ps - 1 == num_patches`` and that
    ``idxs`` / ``idxs_angles`` each hold ``3 * num_patches ** 2`` entries.
    """

    def __init__(self, num_patches, num_heads):
        super().__init__()
        self.num_patches = num_patches
        grid_side = int(math.sqrt(num_patches))
        idxs, idxs_angles, _, dim_a = rot_grid_to_idx(grid_side)
        self.idxs = idxs
        self.idxs_angles = idxs_angles
        # One learnable weight per (head, angle slot).
        self.angles = nn.Parameter(torch.empty(num_heads, dim_a), requires_grad=True)
        nn.init.kaiming_uniform_(self.angles, a=math.sqrt(5))

    def forward(self, x):
        batch, n_heads, _, side = x.size()
        inner = side - 1

        # Gather the inner block (first row/column excluded) through the index table.
        gathered = x[:, :, 1:, 1:].flatten(-2)[:, :, self.idxs]
        slot_ids = self.idxs_angles.flatten().to(x.device)
        weights = self.angles.index_select(1, slot_ids)
        weights = weights.view(1, n_heads, 3 * self.num_patches ** 2)

        # Three weighted contributions per inner cell, summed.
        mixed = (gathered * weights).reshape((batch, n_heads, inner, inner, 3)).sum(-1)

        # Reattach the untouched first row, then the untouched first column.
        lower = torch.cat([x[:, :, 0:1, 1:], mixed], dim=2)
        out = torch.cat([x[:, :, :, 0:1], lower], dim=-1)
        return out.reshape(x.shape)
    
 
  

    
class Sym_Break_Linear_Block(nn.Module):
    """Symmetric-kernel conv stem followed by two windowed attention layers.

    Every depthwise kernel is expanded at run time from a small per-channel
    weight table through an index grid: ``grid_dist_to_idx`` for the conv
    kernels and ``grid_dist_to_idx2`` for the pooling weights ``W2``. Kernel
    taps that share an index share a weight.

    ``forward`` pipeline, for an input of shape ``(bs, f, img_size, img_size)``:

    1. Stem: depthwise conv (``weights0``) -> ``proj`` -> ReLU -> 2x2 max-pool.
    2. Two attention layers over non-overlapping ``patch_size // 2`` windows.
       Queries and keys are depthwise convs (``weights``/``weights1``, then
       ``weightsa``/``weights1a``) followed by Linear projections. Values are
       the current features, and each layer outputs ``v + attn(q, k, v)``.
       The first layer uses ``attn`` and the second uses ``attn2``.
    3. Two heads concatenated on the last axis: a per-pixel MLP
       (``norm_f`` -> ``projf`` -> ``projf2``) and a per-window ``W2``-weighted
       sum followed by ``projfinv``.

    Args:
        img_size: side length of the square input map.
        patch_size: window side before pooling (``patch_size // 2`` after).
        in_features: channel count ``f`` of the input.
        num_patches: number of output tokens per image.
        num_heads: forwarded to ``Attention_Basic``.

    NOTE(review): the final reshapes in ``forward`` only line up for specific
    sizes. For even sizes they require
    ``num_patches == (img_size // patch_size) ** 2`` and
    ``in_features == patch_size ** 2`` — confirm against callers.
    """

    def __init__(self, img_size, patch_size, in_features, num_patches, num_heads):
        super().__init__()
        self.idim = img_size
        self.patch_size = patch_size
        self.in_features = in_features
        self.num_patches = num_patches

        # Kernel sizes and paddings: "...0" is the stem conv, the rest are the q/k convs.
        self.kernel_size = self.patch_size + 1
        self.kernel_size0 = self.patch_size // 2 + 1
        self.padding0 = self.patch_size // 4
        self.padding = (self.kernel_size - 1) // 2

        # Index grids tying kernel taps to shared weight slots.
        idxsW0, dimW0 = grid_dist_to_idx(dim=self.kernel_size0)
        self.idxsW0 = idxsW0
        self.dimW0 = dimW0
        idxsW, dimW = grid_dist_to_idx(dim=self.kernel_size)
        self.idxsW = idxsW
        self.dimW = dimW

        # Per-channel weight tables. The creation order below is deliberate:
        # it keeps seeded initialisation identical to earlier versions.
        self.weights = self._kaiming_table(in_features, dimW)    # q kernel, layer 1
        self.weights0 = self._kaiming_table(in_features, dimW0)  # stem kernel
        self.weights1 = self._kaiming_table(in_features, dimW)   # k kernel, layer 1
        self.weightsa = self._kaiming_table(in_features, dimW)   # q kernel, layer 2
        self.weights1a = self._kaiming_table(in_features, dimW)  # k kernel, layer 2

        idxsW2, dimW2 = grid_dist_to_idx2(dim=self.patch_size // 2)
        self.idxsW2 = idxsW2
        self.dimW2 = dimW2
        self.weights2 = self._kaiming_table(in_features, dimW2)  # window pooling weights

        self.proj = nn.Linear(in_features, in_features, bias=False)
        self.projk = nn.Linear(in_features, in_features, bias=False)
        self.projq = nn.Linear(in_features, in_features, bias=False)
        self.norm = nn.LayerNorm(3 * in_features)

        self.projk2 = nn.Linear(in_features, in_features, bias=False)
        self.projq2 = nn.Linear(in_features, in_features, bias=False)
        self.norm2 = nn.LayerNorm(3 * in_features)

        self.attn = Attention_Basic(in_features, num_patches=num_patches, num_heads=num_heads, qkv_bias=False)
        self.attn2 = Attention_Basic(in_features, num_patches=num_patches, num_heads=num_heads, qkv_bias=False)

        self.norm_f = nn.LayerNorm(in_features, eps=1e-6)
        self.projfinv = nn.Linear(in_features, in_features // 2)
        self.projf = nn.Linear(in_features, 16)
        self.projf2 = nn.Linear(16, 2)

    @staticmethod
    def _kaiming_table(rows, cols):
        """Return a ``(rows, cols)`` Parameter initialised Kaiming-uniform (a=sqrt(5))."""
        table = nn.Parameter(torch.empty(rows, cols))
        nn.init.kaiming_uniform_(table, a=math.sqrt(5))
        return table

    def get_patches_flat(self, x, ps):
        """Cut ``x`` (indexed ``(bs, h, w, c)``) into ``ps x ps`` tiles: ``(-1, ps**2, c)``."""
        bs, h, w, c = x.size()
        patches = x.unfold(1, ps, ps).permute(0, 1, 4, 2, 3)
        patches = patches.unfold(3, ps, ps).permute(0, 1, 3, 2, 5, 4)
        return patches.reshape((-1, ps ** 2, c))

    def reconstruct_image(self, patches, ps=4, img_dim=32):
        """Inverse of ``get_patches_flat`` for square images of side ``img_dim``.

        ``patches`` is ``(bs, num_windows, ps**2, f)``; returns ``(bs, img_dim, img_dim, f)``.
        """
        xy_dim = img_dim // ps
        bs, num_p, num_pn, f = patches.size()
        patches = patches.reshape(bs, num_p, ps, ps, f)
        img = patches.reshape((bs, xy_dim, xy_dim, ps, ps, f)).permute(0, 1, 3, 2, 4, 5)
        img = img.reshape((bs, img_dim, img_dim, f))
        return img

    def _expand_kernel(self, weights, idxs, size, device):
        """Gather weight slots ``idxs`` into a depthwise kernel ``(in_features, 1, size, size)``."""
        taps = torch.index_select(weights, 1, idxs.flatten().to(device))
        return taps.reshape((self.in_features, 1, size, size))

    def _windowed_attention(self, x, f, Wq, Wk, projq, projk, norm, attn):
        """One local attention layer over non-overlapping ``patch_size // 2`` windows.

        ``x`` is ``(bs, f, img_size // 2, img_size // 2)``. Returns
        ``v + attn(q, k, v)``, where q/k/v are ``(bs * n_windows, window**2, f)``.
        """
        bs = x.shape[0]
        side = self.idim // 2
        grid = x.reshape(bs, f, side, side)
        q = projq(F.conv2d(grid, Wq, padding=self.padding, groups=f).permute(0, 2, 3, 1))
        k = projk(F.conv2d(grid, Wk, padding=self.padding, groups=f).permute(0, 2, 3, 1))
        tokens = torch.cat([q, k, x.permute(0, 2, 3, 1)], dim=-1)
        qkv = norm(self.get_patches_flat(tokens, ps=self.patch_size // 2))
        q, k, v = qkv[:, :, :f], qkv[:, :, f:2 * f], qkv[:, :, 2 * f:]
        return v + attn(q, k, v)

    def forward(self, x):
        """Map ``x`` of shape ``(bs, f, img_size, img_size)`` to ``(bs, num_patches, 2 * (f // 2))``."""
        bs, f, _, _ = x.size()
        device = x.device
        side = self.idim // 2
        win = self.patch_size // 2

        W0 = self._expand_kernel(self.weights0, self.idxsW0, self.kernel_size0, device)
        W = self._expand_kernel(self.weights, self.idxsW, self.kernel_size, device)
        W1 = self._expand_kernel(self.weights1, self.idxsW, self.kernel_size, device)
        Wa = self._expand_kernel(self.weightsa, self.idxsW, self.kernel_size, device)
        W1a = self._expand_kernel(self.weights1a, self.idxsW, self.kernel_size, device)
        # (1, win**2, f): one pooling weight per window position and channel.
        W2 = torch.index_select(self.weights2, 1, self.idxsW2.flatten().to(device))
        W2 = W2.reshape((1, self.in_features, win ** 2)).permute(0, 2, 1)

        # Stem: symmetric depthwise conv -> channel mixing -> ReLU -> 2x2 max-pool.
        y = F.conv2d(x.reshape(bs, f, self.idim, self.idim), W0, padding=self.padding0, groups=f)
        y = F.relu(self.proj(y.permute(0, 2, 3, 1))).permute(0, 3, 1, 2)
        x = F.max_pool2d(y, kernel_size=2, stride=2)

        # First windowed attention layer, then stitch the windows back into a map.
        x = self._windowed_attention(x, f, W, W1, self.projq, self.projk, self.norm, self.attn)
        n_windows = (side // win) ** 2
        x = self.reconstruct_image(x.reshape(bs, n_windows, win ** 2, f), ps=win, img_dim=side)
        x = x.permute(0, 3, 1, 2)

        # Second layer: uses its own q/k projections, norm AND attention module
        # (reusing self.attn here left attn2 unused and never trained).
        x = self._windowed_attention(x, f, Wa, W1a, self.projq2, self.projk2, self.norm2, self.attn2)

        # Window-level head: W2-weighted sum over window positions.
        x_inv = self.projfinv((F.gelu(x) * W2).sum(1))
        # Pixel-level head.
        x = self.projf2(F.gelu(self.projf(F.gelu(self.norm_f(x)))))
        x = x.reshape(bs, self.num_patches, f // 2)
        return torch.cat([x, x_inv.reshape(bs, self.num_patches, f // 2)], dim=-1)
    

    
def grid_dist_to_idx2(dim=4):
    """Label each cell of a ``dim x dim`` grid by its distance class to the centre.

    The centre is ``((dim - 1) / 2, (dim - 1) / 2)``, so for even ``dim`` it lies
    between cells. Euclidean distances are stored as float32, multiplied by
    1e6 and rounded before being grouped, so cells whose rounded distances are
    equal share a class.

    Returns:
        ``(idxs, n_classes)``. ``idxs`` is a ``(dim, dim)`` int64 tensor giving,
        for each cell, the rank of its distance among the distinct distances
        (0 = closest to the centre). ``n_classes`` is the number of distinct
        distances.
    """
    offsets = torch.arange(dim, dtype=torch.float64) - (dim - 1) / 2
    # Square root taken in float64 and then cast, matching a per-cell
    # math.sqrt result stored into a float32 tensor.
    radius = torch.sqrt(offsets.unsqueeze(1) ** 2 + offsets.unsqueeze(0) ** 2).float()
    keys = torch.round(1000000 * radius)
    distinct, labels = torch.unique(keys.flatten(), sorted=True, return_inverse=True)
    return labels.reshape(dim, dim), distinct.shape[0]
    
