import numpy as np
import torch

def evaluate_one_hot(decoded_material, mat_lay=10, num_mat=7):
    """Return how close a decoded design's material encoding is to one-hot.

    The trailing ``mat_lay`` entries of ``decoded_material`` are dropped. The
    remainder is split into rows of ``num_mat`` values, and the squared L2 norm
    of each row is averaged. An exactly one-hot row has squared norm 1; a
    uniform row over ``num_mat`` materials has ``1 / num_mat``.

    Args:
        decoded_material: Tensor laid out as ``[materials..., trailing...]``;
            in this file it is called with a single decoded point (a 1-D row).
            The material prefix length must be a multiple of ``num_mat``.
        mat_lay: Number of trailing entries to exclude. Must be >= 1:
            ``[:-0]`` would select nothing and yield NaN.
        num_mat: Number of values per row (materials per layer).

    Returns:
        The mean squared row norm, as a Python float.
    """
    material_rows = decoded_material[:-mat_lay].reshape(-1, num_mat)
    row_norms_sq = torch.square(material_rows).sum(dim=1)
    return row_norms_sq.mean().item()


def neural_adjoint_search(simulator, initial_point, desired_y, lr=0.05, epochs=200, sloss=None,
                          num_lay=10, num_mat=7):
    """Optimise a batch of design vectors so a differentiable simulator matches a target.

    The raw parameters are decoded into a design ``x_hat``:

    * the first ``num_lay * num_mat`` columns are grouped per layer, with a
      softmax over the ``num_mat`` materials of each layer;
    * the remaining columns are squashed with a sigmoid.

    The raw parameters are updated with Adam to minimise::

        mean_b MSE(simulator(x_hat)_b, desired_y)
        + (mean squared-norm of the per-layer material distributions - 1) ** 2
        + mean_b(-log wmc_b)        # only when ``sloss`` is given

    Args:
        simulator: Differentiable callable mapping the decoded ``(B, D)`` batch
            to a 2-D output that is compared elementwise with ``desired_y``.
        initial_point: ``(B, D)`` tensor of raw (pre-activation) parameters
            with ``D >= num_lay * num_mat``. It is cloned, never modified.
        desired_y: Target, broadcast against ``simulator(x_hat)``.
        lr: Adam learning rate.
        epochs: Number of optimisation steps.
        sloss: Optional semantic-loss callable, invoked as
            ``sloss(probabilities=..., output_wmc_per_sample=True)``. It must
            return ``(_, wmc)`` with one strictly positive value per point
            (presumably a weighted model count).
        num_lay: Number of layers, i.e. material one-hot groups.
        num_mat: Number of candidate materials per layer.

    Returns:
        A ``(history, losses)`` pair.

        ``history[e]`` holds one tuple per point,
        ``(idx, sim_loss, sem_loss, onehot, decoded_point)``:

        * the values describe the design evaluated at epoch ``e``, i.e.
          *before* that epoch's optimiser step;
        * ``sim_loss`` is the point's MSE;
        * ``sem_loss`` is ``-log(wmc)``, or the string ``"None"`` when
          ``sloss`` is not given;
        * ``onehot`` is the mean squared norm of the point's per-layer
          material distributions;
        * ``decoded_point`` is a NumPy copy of the point's row of ``x_hat``.

        ``losses[e]`` is the total scalar loss at epoch ``e``.
    """
    n_onehot = num_lay * num_mat
    params = initial_point.clone().detach().requires_grad_(True)
    optimizer = torch.optim.Adam([params], lr=lr)

    losses = []
    history = []
    for _ in range(epochs):
        optimizer.zero_grad()

        # Decode: softmax over materials within each layer; sigmoid squashes
        # the remaining (thickness) columns into (0, 1).
        one_hot_part = torch.softmax(params[:, :n_onehot].reshape(-1, num_lay, num_mat), dim=2)
        thickness_part = torch.sigmoid(params[:, n_onehot:])
        x_hat = torch.cat([one_hot_part.reshape(-1, n_onehot), thickness_part], dim=1)

        # A distribution's sum of squares is 1 only when it is one-hot, so this
        # term pushes the (batch- and layer-averaged) distributions towards one-hot.
        layer_sq_norms = (one_hot_part ** 2).sum(dim=2)  # (B, num_lay)
        one_hot_loss = (layer_sq_norms.mean() - 1).square()
        multi_sim_loss = (simulator(x_hat) - desired_y).square().mean(dim=1)
        loss = multi_sim_loss.mean() + one_hot_loss

        multi_sem_loss = None
        if sloss is not None:
            # Keep probabilities inside [0.005, 0.995] so -log(wmc) stays finite.
            soft_decoded = 0.005 + 0.99 * x_hat[:, :n_onehot]
            # Lay the material probabilities out as (num_vars, B), one column per
            # point. This needs a transpose: reshape(num_vars, B) of a (B, num_vars)
            # tensor would interleave values from different points when B > 1.
            # NOTE(review): (num_vars, B) layout inferred from the original call;
            # confirm against sloss's expected input.
            probabilities = soft_decoded.transpose(0, 1).contiguous()
            _, wmc = sloss(probabilities=probabilities, output_wmc_per_sample=True)
            multi_sem_loss = -torch.log(wmc)
            loss = loss + multi_sem_loss.mean()

        loss.backward()
        optimizer.step()

        # Move per-point metrics to the host in bulk rather than calling .item()
        # once per point.
        num_points = x_hat.shape[0]
        decoded = x_hat.detach().cpu().numpy()
        sim_values = multi_sim_loss.detach().cpu().tolist()
        onehot_values = layer_sq_norms.detach().mean(dim=1).cpu().tolist()
        if multi_sem_loss is not None:
            sem_values = multi_sem_loss.detach().reshape(-1).cpu().tolist()
        else:
            sem_values = ["None"] * num_points

        history.append([
            (idx, sim_values[idx], sem_values[idx], onehot_values[idx], decoded[idx].copy())
            for idx in range(num_points)
        ])
        losses.append(loss.item())

    return history, losses