import numpy as np
import torch

# Default compute device, chosen once at import time: CUDA when available, otherwise CPU.
DEVICE=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Callable applied inside fun1. It is None here, so it must be assigned by the
# caller before fun1 is used; otherwise fun1 fails with a TypeError.
f1w=None

def fun1(x, bias=None):
    """Evaluate ``f1w(sin(x - bias)) ** 2`` reshaped to ``(-1, x.shape[1])``.

    NOTE(review): x is presumably (batch, n, dim); confirm against callers.
    The output layout depends on what the module-level callable ``f1w``
    returns. ``f1w`` must be assigned before this function is called.

    Args:
        x: input tensor.
        bias: optional shift, flattened with ``view(-1)`` before subtraction.

    Returns:
        Tensor of squared transformed values, viewed as ``(-1, x.shape[1])``.
    """
    shifted = x
    if bias is not None:
        shifted = x - bias.view(-1)
    transformed = f1w(torch.sin(shifted))
    return torch.pow(transformed, 2).view(-1, x.shape[1])


def fun2(x, bias=None):
    """Return the L1 norm of ``x - bias`` taken along dim 2.

    Args:
        x: input tensor with at least three dimensions.
        bias: optional shift, flattened with ``view(-1)`` before subtraction.

    Returns:
        ``sum(|x - bias|)`` over dim 2.
    """
    diff = x if bias is None else x - bias.view(-1)
    return torch.abs(diff).sum(dim=2)
    
def fun3(x, bias=None):
    """Return ``sum|z_i + z_{i+1}| + sum|z_i|`` along dim 2, with ``z = x - bias``.

    Args:
        x: input tensor with at least three dimensions.
        bias: optional shift, subtracted with normal broadcasting.

    Returns:
        Tensor reduced over dim 2.
    """
    z = x - bias if bias is not None else x
    # Sums of neighbouring coordinates along the last (dim 2) axis.
    neighbour_sums = z[:, :, :-1] + z[:, :, 1:]
    return torch.abs(neighbour_sums).sum(dim=2) + torch.abs(z).sum(dim=2)

def fun4(x, b=None):  # checked
    """Sphere function: ``sum((x - b) ** 2)`` along dim 2.

    Args:
        x: input tensor with at least three dimensions.
        b: optional shift, subtracted with normal broadcasting.

    Returns:
        Tensor reduced over dim 2.
    """
    shifted = x if b is None else x - b
    return shifted.pow(2).sum(dim=2)


def fun5(x, b=None):  # checked
    """Return the max-abs (L-infinity) norm of ``x - b`` along dim 2.

    Args:
        x: input tensor with at least three dimensions.
        b: optional shift, subtracted with normal broadcasting.

    Returns:
        Tensor of per-row maxima of ``|x - b|`` over dim 2.
    """
    shifted = x if b is None else x - b
    return shifted.abs().max(dim=2).values
 
def fun6(x, b=None):  # checked
    """Rosenbrock function along dim 2, with ``z = x - b``.

    Computes ``sum(100 * (z_i^2 - z_{i+1})^2 + (z_i - 1)^2)`` over consecutive
    coordinate pairs. The minimum is at ``z == 1``, i.e. ``x == b + 1``.

    Args:
        x: input tensor with at least three dimensions.
        b: optional shift, subtracted with normal broadcasting.

    Returns:
        Tensor reduced over dim 2.
    """
    z = x if b is None else x - b
    head = z[:, :, :-1]
    tail = z[:, :, 1:]
    valley = 100 * torch.pow(torch.pow(head, 2) - tail, 2)
    offset = torch.pow(head - 1, 2)
    return (valley + offset).sum(dim=2)



def fun7(x, b=None):  # checked
    """Rastrigin function along dim 2, with ``z = x - b``.

    Computes ``sum(z^2 - 10 * cos(2*pi*z) + 10)``. The global minimum is 0
    at ``z == 0``.

    Args:
        x: input tensor with at least three dimensions.
        b: optional shift, subtracted with normal broadcasting.

    Returns:
        Tensor reduced over dim 2, on the same device as ``x``.
    """
    z = x if b is None else x - b
    # Use a plain Python scalar exponent. A 0-dim tensor pinned to a global
    # device cannot be combined with inputs on a different device (e.g. CPU
    # input on a CUDA machine).
    sc = torch.sum(torch.pow(z, 2) - 10 * torch.cos(2 * np.pi * z) + 10, dim=2)
    return sc


def fun8(x, b=None):  # checked
    """Griewank function along dim 2, with ``z = x - b``.

    Computes ``sum(z_i^2) / 4000 - prod(cos(z_i / sqrt(i))) + 1``, where ``i``
    runs from 1 to ``x.shape[2]``. The global minimum is 0 at ``z == 0``.

    Args:
        x: input tensor with at least three dimensions.
        b: optional shift, subtracted with normal broadcasting.

    Returns:
        Tensor reduced over dim 2, on the same device as ``x``.
    """
    z = x if b is None else x - b
    dim = x.shape[2]
    # Build the 1-based index vector directly on z's device and in z's float
    # dtype. This avoids device mismatches and float32 rounding for float64 input.
    idx_dtype = z.dtype if z.is_floating_point() else torch.get_default_dtype()
    idx = torch.arange(1, dim + 1, device=z.device, dtype=idx_dtype).view(1, 1, dim)
    sc = (
        torch.sum(torch.pow(z, 2) / 4000, dim=2)
        - torch.prod(torch.cos(z / torch.sqrt(idx)), dim=2)
        + 1
    )
    return sc




def fun9(x, b=None):  # checked
    """Ackley function along dim 2, with ``z = x - b`` and ``d = x.shape[2]``.

    Computes ``-20*exp(-0.2*sqrt(mean(z^2))) - exp(mean(cos(2*pi*z))) + 20 + e``.
    The global minimum is 0 at ``z == 0``.

    Args:
        x: input tensor with at least three dimensions.
        b: optional shift, subtracted with normal broadcasting.

    Returns:
        Tensor reduced over dim 2.
    """
    z = x if b is None else x - b
    inv_dim = 1 / x.shape[2]
    rms = torch.sqrt(inv_dim * torch.sum(torch.pow(z, 2), dim=2))
    # Use the exact 2*pi, as fun7 does. The previous 2*3.14 skewed every value
    # away from the optimum.
    mean_cos = inv_dim * torch.sum(torch.cos(2 * np.pi * z), dim=2)
    sc = -20 * torch.exp(-0.2 * rms) - torch.exp(mean_cos) + 20 + np.e
    return sc


if __name__ == '__main__':
    # Smoke test: squared Euclidean distance between two small integer vectors.
    first = torch.from_numpy(np.array([[1, 2, 3]])).view(-1)
    second = torch.from_numpy(np.array([[-1, -2, -3]])).view(-1)
    print(second.shape[0])
    squared_distance = torch.sum(torch.pow(first - second, torch.tensor(2)))
    print(squared_distance)