import torch
import numpy as np
import torch.optim as optim
import math

def drop_wave(x):
    """Evaluate a Drop-Wave style objective, offset by +1, for every row of ``x``.

    Per row, with ``s = sum(x_i ** 2)`` taken along ``dim=1``, the value is
    ``1 - (1 + cos(12 * sqrt(s))) / (0.5 * s + 2)``, which is 0 at the origin.

    Args:
        x: Tensor with at least two dimensions; coordinates are reduced
            along ``dim=1``.

    Returns:
        Tensor of objective values with ``dim=1`` reduced away.
    """
    squared_radius = x.pow(2).sum(dim=1)
    # Oscillating part: ripples whose phase grows with distance from the origin.
    ripple = torch.cos(12 * squared_radius.sqrt()) + 1
    # Damping part: grows with the squared radius and flattens distant ripples.
    damping = 0.5 * squared_radius + 2
    return 1 - ripple / damping


def ackley(x, a=20, b=0.2, c=2 * math.pi):
    """Evaluate the Ackley function for every row of ``x``.

    Per row, with ``D = x.shape[1]``:
    ``-a * exp(-b * sqrt(mean(x_i ** 2))) - exp(mean(cos(c * x_i))) + a + e``.

    Args:
        x: Tensor whose coordinates lie along ``dim=1``.
        a: Depth scale of the exponential bowl.
        b: Decay rate applied to the root-mean-square of the coordinates.
        c: Angular frequency of the cosine term.

    Returns:
        Tensor of objective values with ``dim=1`` reduced away.
    """
    dim = x.shape[1]
    mean_square = x.pow(2).sum(dim=1) / dim
    mean_cosine = torch.cos(c * x).sum(dim=1) / dim
    bowl = -a * torch.exp(-b * mean_square.sqrt())
    ripples = torch.exp(mean_cosine)
    # The constants a + e shift the value at the origin to zero.
    return bowl - ripples + a + math.e


def rosenbrock(x, a=1, b=100):
    """Evaluate the Rosenbrock function for every row of ``x``.

    Per row: ``sum_i b * (x_{i+1} - x_i ** 2) ** 2 + (a - x_i) ** 2`` over
    consecutive coordinate pairs along ``dim=1``.

    Args:
        x: Two-dimensional tensor; each row is one point.
        a: Target value for each leading coordinate.
        b: Weight of the curved-valley penalty.

    Returns:
        Tensor with one objective value per row.
    """
    leading = x[:, :-1]
    following = x[:, 1:]
    valley_penalty = b * (following - leading ** 2) ** 2
    offset_penalty = (a - leading) ** 2
    return (valley_penalty + offset_penalty).sum(dim=1)


def griewank(x):
    """Evaluate the Griewank function for every row of ``x``.

    Per row: ``sum(x_i ** 2) / 4000 - prod(cos(x_i / sqrt(i))) + 1`` with the
    coordinate index ``i`` running from 1 to ``x.shape[1]``.

    Args:
        x: Tensor whose coordinates lie along ``dim=1``.

    Returns:
        Tensor of objective values with ``dim=1`` reduced away.
    """
    n_dims = x.shape[1]
    # Build the 1..D index vector in double precision when the input is
    # double; a float32 index vector would round sqrt(i) to single precision
    # and leak that error into an otherwise float64 result. Every other input
    # dtype keeps the float32 index vector.
    index_dtype = torch.float64 if x.dtype == torch.float64 else torch.float32
    indices = torch.arange(1, n_dims + 1, dtype=index_dtype, device=x.device)
    quadratic = torch.sum(x ** 2, dim=1) / 4000
    oscillation = torch.prod(torch.cos(x / torch.sqrt(indices)), dim=1)
    return quadratic - oscillation + 1


def alpine_1(x):
    """Evaluate the Alpine-1 function for every row of ``x``.

    Per row: ``sum(|x_i * sin(x_i) + 0.1 * x_i|)`` along ``dim=1``.

    Args:
        x: Tensor whose coordinates lie along ``dim=1``.

    Returns:
        Tensor of objective values with ``dim=1`` reduced away.
    """
    per_coordinate = x * torch.sin(x) + 0.1 * x
    return per_coordinate.abs().sum(dim=1)

def rastrigin(x):
    """Evaluate the Rastrigin function for every row of ``x``.

    Per row, with ``n = x.shape[1]``:
    ``10 * n + sum(x_i ** 2 - 10 * cos(2 * pi * x_i))``.

    Args:
        x: Tensor whose coordinates lie along ``dim=1``.

    Returns:
        Tensor of objective values with ``dim=1`` reduced away.
    """
    dim = x.shape[1]
    bowl = x.pow(2)
    ripples = 10 * torch.cos(2 * torch.pi * x)
    return 10 * dim + (bowl - ripples).sum(dim=1)

def xin_she_yang_1(x):
    """Evaluate ``sum(|x_i|) * exp(-sum(sin(x_i ** 2)))`` over the last dimension.

    Args:
        x: Tensor whose coordinates lie along the last dimension.

    Returns:
        Tensor of objective values with the last dimension reduced away.
    """
    abs_total = x.abs().sum(dim=-1)
    sine_total = torch.sin(x.pow(2)).sum(dim=-1)
    return abs_total * torch.exp(-sine_total)

def salomon(x):
    """Evaluate the Salomon function for every row of ``x``.

    Per row, with ``r`` the Euclidean norm along ``dim=1``:
    ``1 - cos(2 * pi * r) + 0.1 * r``.

    Args:
        x: Tensor whose coordinates lie along ``dim=1``.

    Returns:
        Tensor of objective values with ``dim=1`` reduced away.
    """
    radius = x.pow(2).sum(dim=1).sqrt()
    return 1 - torch.cos(2 * torch.pi * radius) + 0.1 * radius


def schaffer_f7(x):
    """Evaluate a cyclic Schaffer-F7 style objective for every row of ``x``.

    Each coordinate is paired with its right-hand neighbour along ``dim=1``,
    and the last coordinate is paired with the first. With
    ``s = sqrt(x_i ** 2 + x_j ** 2)`` for each pair, the row value is
    ``mean(sqrt(s) * (1 + sin(50 * s ** 0.2) ** 2))`` over the
    ``x.shape[1]`` pairs.

    Args:
        x: Tensor whose coordinates lie along ``dim=1``.

    Returns:
        Tensor of objective values with ``dim=1`` reduced away.
    """
    # Shifting left by one lines each coordinate up with its successor; the
    # first coordinate wraps around to sit beside the last one.
    successor = torch.roll(x, shifts=-1, dims=1)
    pair_norm = torch.sqrt(x ** 2 + successor ** 2)
    modulation = 1 + torch.sin(50 * pair_norm ** 0.2) ** 2
    per_pair = pair_norm ** 0.5 * modulation
    return per_pair.sum(dim=1) / x.shape[1]


def expanded_schaffer_f6(x):
    """Evaluate the expanded Schaffer-F6 objective for every row of ``x``.

    Consecutive coordinates along ``dim=1`` are paired, plus one closing pair
    made of the last and first coordinate. With
    ``r = sqrt(x_i ** 2 + x_j ** 2)`` for each pair, the row value is
    ``sum(0.5 + (sin(r) ** 2 - 0.5) / (1 + 0.001 * r ** 2) ** 2)``.

    Args:
        x: Two-dimensional tensor; each row is one point.

    Returns:
        Tensor with one objective value per row.
    """
    adjacent = (x[:, :-1] ** 2 + x[:, 1:] ** 2).sqrt()
    # Closing pair (last, first) turns the chain of pairs into a cycle.
    closing = (x[:, -1] ** 2 + x[:, 0] ** 2).sqrt().reshape(-1, 1)
    radius = torch.cat([adjacent, closing], dim=1)
    numerator = torch.sin(radius) ** 2 - 0.5
    denominator = (1 + 0.001 * radius ** 2) ** 2
    return (0.5 + numerator / denominator).sum(dim=1)

def xin_she_yang_3(x):
    """Evaluate a Xin-She Yang style objective, offset by +1, for every row of ``x``.

    Per row:
    ``1 + (exp(-sum((x_i / 15) ** 10)) - 2 * exp(-sum(x_i ** 2))) * prod(cos(x_i) ** 2)``,
    which is 0 at the origin.

    Args:
        x: Tensor whose coordinates lie along ``dim=1``.

    Returns:
        Tensor of objective values with ``dim=1`` reduced away.
    """
    wide_bump = torch.exp(-(x / 15).pow(10).sum(dim=1))
    narrow_bump = torch.exp(-x.pow(2).sum(dim=1))
    envelope = wide_bump - 2 * narrow_bump
    cos_sq_product = torch.cos(x).pow(2).prod(dim=1)
    return 1 + envelope * cos_sq_product

def xin_she_yang_5(x):
    """Evaluate a Xin-She Yang style objective, offset by +1, for every row of ``x``.

    Per row:
    ``1 + (sum(sin(x_i) ** 2) - exp(-sum(x_i ** 2))) * exp(-sum(sin(sqrt(|x_i|)) ** 2))``,
    which is 0 at the origin.

    Args:
        x: Tensor whose coordinates lie along ``dim=1``.

    Returns:
        Tensor of objective values with ``dim=1`` reduced away.
    """
    sin_sq_sum = torch.sum(torch.sin(x) ** 2, dim=1)
    # The Gaussian term must decay with distance from the origin. With a
    # positive exponent the objective is unbounded below and overflows to
    # inf/NaN for inputs of even moderate norm.
    gaussian = torch.exp(-torch.sum(x ** 2, dim=1))
    damping = torch.exp(-torch.sum(torch.sin(x.abs() ** 0.5) ** 2, dim=1))
    return 1 + (sin_sq_sum - gaussian) * damping

def happy_cat(x, alpha=0.25):  # global minimum at x_i = -1 for every coordinate
    """Evaluate the HappyCat function for every row of ``x``.

    Per row, with ``n = x.shape[1]``:
    ``0.5 + ((sum(x_i ** 2) - n) ** 2) ** alpha + (0.5 * sum(x_i ** 2) + sum(x_i)) / n``.

    Args:
        x: Tensor whose coordinates lie along ``dim=1``.
        alpha: Exponent applied to the squared deviation of ``sum(x_i ** 2)``
            from ``n``.

    Returns:
        Tensor of objective values with ``dim=1`` reduced away.
    """
    dim = x.shape[1]
    square_sum = torch.sum(x ** 2, dim=1)
    linear_sum = torch.sum(x, dim=1)
    # Deviation of the squared norm from n; zero on the sphere of radius sqrt(n).
    shell_gap = square_sum - dim
    penalty = (shell_gap ** 2) ** alpha
    tilt = (1 / dim) * (0.5 * square_sum + linear_sum)
    return 0.5 + penalty + tilt

def hg_bat(x, alpha=0.5):  # global minimum at x_i = -1 for every coordinate
    """Evaluate the HGBat function for every row of ``x``.

    Per row, with ``n = x.shape[1]``:
    ``0.5 + |sum(x_i ** 2) ** 2 - sum(x_i) ** 2| ** alpha
    + (0.5 * sum(x_i ** 2) + sum(x_i)) / n``.

    Args:
        x: Tensor whose coordinates lie along ``dim=1``.
        alpha: Exponent applied to the absolute difference term.

    Returns:
        Tensor of objective values with ``dim=1`` reduced away.
    """
    dim = x.shape[1]
    square_sum = torch.sum(x ** 2, dim=1)
    linear_sum = torch.sum(x, dim=1)
    gap = torch.abs(square_sum ** 2 - linear_sum ** 2)
    tilt = (1 / dim) * (0.5 * square_sum + linear_sum)
    return 0.5 + gap ** alpha + tilt

def schwefel(x):  # global minimum near x_i = 420.98 for every coordinate
    """Evaluate the Schwefel function for every row of ``x``.

    Per row, with ``n = x.shape[1]``:
    ``418.9829 * n - sum(x_i * sin(sqrt(|x_i|)))``.

    Args:
        x: Tensor whose coordinates lie along ``dim=1``.

    Returns:
        Tensor of objective values with ``dim=1`` reduced away.
    """
    dim = x.shape[1]
    wave = x * torch.sin(x.abs().sqrt())
    return 418.9829 * dim - wave.sum(dim=1)

# Experiment results, filled in by the sweep below. Keys are
# (function name, bound, lr) tuples, plus a fourth 'final point' element
# for the entries that hold the final iterate.
min_values = {}


# Benchmark objectives to sweep, in evaluation order.
functions = [
    drop_wave,
    ackley,
    rosenbrock,
    griewank,
    alpine_1,
    rastrigin,
    xin_she_yang_1,
    salomon,
    schaffer_f7,
    expanded_schaffer_f6,
    xin_she_yang_3,
    xin_she_yang_5,
    happy_cat,
    hg_bat,
    schwefel,
]


# Per-coordinate location of the global minimiser for the functions whose
# optimum is not at the origin; functions not listed here are treated as
# having it at 0.0.
aberrant_global_min = {
    'happy_cat': -1.0,
    'hg_bat': -1.0,
    'schwefel': 420.98,
}


steps = 10000  # maximum optimiser iterations per learning rate
bound = 100    # starting points are drawn uniformly from [-bound, bound)

# Learning rates 0.05, 0.10, ..., 5.00, tried in increasing order.
lr_list = [step_index * 0.05 for step_index in range(1, 101)]


# Use the GPU when one is present; otherwise run the same experiment on CPU
# instead of crashing on an unconditional .cuda() call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_points = 10000       # starting points optimised in parallel per run
num_dims = 1000          # dimensionality of every point
convergence_radius = 5   # a run succeeds once any point is this close to the optimum

# For every objective, try learning rates in increasing order and stop at the
# first one for which some point gets within `convergence_radius` of the
# known global minimiser.
for function in functions:
    # Coordinate value shared by every dimension of the known global
    # minimiser; the origin unless the function is listed as an exception.
    global_min_value = aberrant_global_min.get(function.__name__, 0.0)

    for lr in lr_list:
        random_point = np.random.uniform(-bound, bound, (num_points, num_dims))
        x = torch.from_numpy(random_point).to(device).requires_grad_(True)
        optimizer = optim.Adam([x], lr=lr)

        min_loss_value = float('inf')
        best_input = None  # point that produced the lowest loss seen so far
        best_input_norm = None  # Euclidean norm of that point
        min_value = float('inf')  # smallest distance to the global minimiser seen so far

        for iteration in range(steps):
            optimizer.zero_grad()
            loss = function(x)  # one loss value per point

            # Record the best loss together with the point that produced it.
            # This has to happen before optimizer.step(), which updates x in
            # place and would otherwise pair the loss with a different point.
            current_min_loss, min_loss_index = torch.min(loss.detach(), dim=0)
            if current_min_loss.item() < min_loss_value:
                min_loss_value = current_min_loss.item()
                best_input = x.detach()[min_loss_index].clone()
                best_input_norm = best_input.norm().item()

            # Rows are independent, so summing the losses yields each row's
            # own gradient.
            loss.sum().backward()
            optimizer.step()

            # Distance bookkeeping needs no gradient; keep it as a plain float
            # so no autograd graph or device tensor ends up in the results.
            with torch.no_grad():
                closest = (x - global_min_value).norm(dim=1).min().item()
            min_value = min(min_value, closest)
            if min_value < convergence_radius:
                break

        min_values[(function.__name__, bound, lr)] = (min_value, min_loss_value, best_input_norm)
        min_values[(function.__name__, bound, lr, 'final point')] = x.detach().cpu().numpy()
        print((function.__name__, bound, lr), (min_value, min_loss_value, best_input_norm))

        # The smallest working learning rate has been found for this function.
        if min_value < convergence_radius:
            break
        
        
    

