import torch
import numpy as np
from data_utils.real_simulator_tmm import srmse_evaluate
from data_utils.utils import unscaler

def evaluate_one_hot(decoded_material, mat_lay, num_mat=5):
    """Measure how close the material-selection part of a design is to one-hot.

    The trailing ``mat_lay`` columns are dropped. The remaining columns are
    grouped into rows of ``num_mat`` entries. The function returns the mean,
    over those rows, of each row's sum of squared entries. For a row of
    probabilities (non-negative, summing to 1) this value is 1 exactly when
    the row is one-hot, and smaller otherwise.

    Args:
        decoded_material: 1-D ``(features,)`` or 2-D ``(batch, features)``
            torch tensor. A 1-D input is treated as a batch of one.
        mat_lay: number of trailing columns to exclude. In this file these are
            the per-layer thickness columns.
        num_mat: size of each one-hot group (number of candidate materials).
            Defaults to 5, the value that was previously hard-coded.

    Returns:
        float: mean per-group sum of squares.
    """
    if len(decoded_material.shape) < 2:
        decoded_material = decoded_material.unsqueeze(0)

    # Use an explicit end index. `[:, :-mat_lay]` would produce an empty slice
    # (and therefore a NaN mean) when mat_lay == 0.
    n_onehot = decoded_material.shape[1] - mat_lay
    values = decoded_material[:, :n_onehot].reshape(-1, num_mat).square().sum(dim=1)
    return torch.mean(values).item()

def evaluate_single_srmse(decoded_material, desired_material, metamat_config):
    """Score a decoded design against a reference design with ``srmse_evaluate``.

    A 1-D input tensor is promoted to a batch of one. Each tensor is then
    moved to the CPU, detached from autograd, converted to a NumPy array and
    copied before being passed to ``srmse_evaluate``.

    Args:
        decoded_material: design produced by the search (torch tensor).
        desired_material: reference design (torch tensor).
        metamat_config: configuration forwarded unchanged to ``srmse_evaluate``.

    Returns:
        The first of the six values returned by ``srmse_evaluate`` (``mean``).
        The remaining five are discarded.
    """
    def as_batched_array(tensor):
        # Promote 1-D input to a single-row batch, then convert to NumPy.
        batched = tensor.unsqueeze(0) if len(tensor.shape) < 2 else tensor
        return np.array(batched.cpu().detach().numpy())

    decoded_np = as_batched_array(decoded_material)
    desired_np = as_batched_array(desired_material)

    mean, _std, _ci, _rmse_samples, _ytrue_ypred, _rmse_waves = srmse_evaluate(
        decoded_np,
        desired_np,
        metamat_config,
    )
    return mean

def _decode_design(params, num_lay, num_mat):
    """Map unconstrained parameters to a relaxed design vector.

    The first ``num_lay * num_mat`` columns are reshaped to
    ``(batch, num_lay, num_mat)`` and soft-maxed over the material axis. The
    remaining columns are squashed into ``(-1, 1)`` with ``tanh``.

    Returns:
        tuple: ``(one_hot_part, x_hat)``.
        ``one_hot_part`` holds the ``(batch, num_lay, num_mat)`` softmax
        probabilities. ``x_hat`` is the flattened probabilities concatenated
        with the squashed remaining columns.
    """
    n_onehot = num_lay * num_mat
    one_hot_part = torch.softmax(params[:, :n_onehot].reshape(-1, num_lay, num_mat), dim=2)
    thickness_part = torch.tanh(params[:, n_onehot:])
    x_hat = torch.cat([one_hot_part.reshape(-1, n_onehot), thickness_part], dim=1)
    return one_hot_part, x_hat


def _record_epoch(x_hat, multi_sim_loss, multi_sem_loss, num_lay):
    """Build the per-point history rows for one epoch.

    Returns:
        list: one ``(idx, smse, sem_loss, onehot, decoded_point)`` tuple per
        point. ``decoded_point`` is a NumPy copy of the design after
        ``unscaler``. ``sem_loss`` is the string ``"None"`` when no semantic
        loss is in use.
    """
    # These values are for reporting only, so work on a detached copy rather
    # than the grad-tracking x_hat. unscaler's return value is ignored here (as
    # in the original code), so it presumably modifies its argument in place.
    # TODO confirm.
    decoded = x_hat.detach().clone()
    unscaler(decoded, n_layer=num_lay)

    rows = []
    for idx in range(decoded.shape[0]):
        smse = multi_sim_loss[idx].item()
        point_sem = multi_sem_loss[idx].item() if multi_sem_loss is not None else "None"
        onehot = evaluate_one_hot(decoded[idx], mat_lay=num_lay)
        decoded_point = np.copy(decoded[idx].cpu().numpy())
        rows.append((idx, smse, point_sem, onehot, decoded_point))
    return rows


def neural_adjoint_search(simulator, initial_point, desired_x, desired_y, metamat_config, lr=0.05, epochs=200, sloss=None, verbose=False):
    """Optimise a batch of candidate designs through a differentiable simulator.

    Each row of ``initial_point`` is one candidate's unconstrained parameters.
    Every epoch:

    1. The rows are decoded (softmax over materials for the first 25 columns,
       ``tanh`` for the rest).
    2. The decoded designs are passed through ``simulator`` and compared to
       ``desired_y``.
    3. Adam updates the parameters.

    The loss is the mean simulator MSE plus a one-hot penalty. When ``sloss``
    is given, the mean semantic loss ``-log(wmc)`` is added as well.

    Args:
        simulator: differentiable callable mapping decoded designs
            ``(num_points, features)`` to predictions comparable with
            ``desired_y``.
        initial_point: 2-D parameter tensor, one row per candidate. It is
            cloned, so the caller's tensor is not modified.
        desired_x: reference design. Only used when ``verbose`` is set, to
            report the real SRMSE of the final best point.
        desired_y: target simulator output.
        metamat_config: forwarded to ``evaluate_single_srmse`` when ``verbose``.
        lr: Adam learning rate.
        epochs: number of optimisation steps.
        sloss: optional semantic-loss callable. It is called as
            ``sloss(probabilities=..., output_wmc_per_sample=True)`` and must
            return ``(_, wmc)`` with one WMC value per point.
        verbose: if True, print the best point of every epoch and, at the end,
            the real SRMSE of the final best point.

    Returns:
        tuple: ``(history, losses)``.
        ``history[e]`` is a list of ``(idx, smse, sem_loss, onehot,
        decoded_point)`` tuples for epoch ``e``. These describe the designs
        *before* that epoch's optimiser step. ``losses[e]`` is the scalar total
        loss of epoch ``e``.
    """
    num_lay = 5
    num_mat = 5
    n_onehot = num_lay * num_mat

    params = initial_point.clone().detach()
    params.requires_grad = True
    optimizer = torch.optim.Adam([params], lr=lr)

    losses = []
    history = []
    for epoch in range(epochs):
        optimizer.zero_grad()

        one_hot_part, x_hat = _decode_design(params, num_lay, num_mat)

        # For softmax rows, the sum of squared probabilities is 1 only for a
        # one-hot row, so (mean - 1)^2 pushes the rows toward one-hot.
        one_hot_loss = ((one_hot_part ** 2).sum(dim=2).mean() - 1).square()
        multi_sim_loss = (simulator(x_hat) - desired_y).square().mean(dim=1)
        loss = multi_sim_loss.mean() + one_hot_loss

        multi_sem_loss = None
        if sloss is not None:
            # Map the probabilities into [0.0001, 0.9999], presumably to keep
            # -log(wmc) finite.
            soft_decoded = 0.0001 + 0.9998 * x_hat[:, :n_onehot]
            # sloss receives a (num_lay * num_mat, num_points) layout: variables
            # on dim 0, points on dim 1. This needs a transpose. A reshape to
            # that shape would interleave values from different points.
            probabilities = soft_decoded.transpose(0, 1).contiguous()
            _, wmc = sloss(probabilities=probabilities, output_wmc_per_sample=True)
            multi_sem_loss = -torch.log(wmc)  # one value per point
            loss = loss + multi_sem_loss.mean()

        loss.backward()
        optimizer.step()

        epoch_data = _record_epoch(x_hat, multi_sim_loss, multi_sem_loss, num_lay)
        history.append(epoch_data)
        losses.append(loss.item())

        if verbose and epoch_data:
            best = min(epoch_data, key=lambda row: row[1])
            print(f"Epoch [{epoch + 1}/{epochs}]   SMSE: {best[1]:.5f}, One hot acc: {best[3]}, Best point idx: {best[0]}")

    if verbose and history and history[-1]:
        best = min(history[-1], key=lambda row: row[1])
        desired_x_nmrange = desired_x.clone().detach()
        unscaler(desired_x_nmrange, n_layer=num_lay)
        srmse = evaluate_single_srmse(torch.tensor(best[4]), desired_x_nmrange, metamat_config)
        print(f"Real Srmse of best point: {srmse}")

    return history, losses