import torch
import numpy as np
import src.assets.simulator.multi_layer_model as real_simulator
import src.utils.data as utils

def argmax_material(tensor, num_lay=10, num_mat=7):
    """Harden the per-layer material scores of ``tensor`` into one-hot vectors.

    The first ``num_lay * num_mat`` columns are read as ``num_lay`` groups of
    ``num_mat`` scores. Each group is replaced by a one-hot vector marking its
    argmax. All remaining columns are passed through unchanged.

    Args:
        tensor: 1-D or 2-D tensor. A 1-D input is treated as a batch of one.
        num_lay: number of layers (groups) in the leading one-hot section.
        num_mat: number of materials (scores) per layer.

    Returns:
        A 2-D tensor with the same number of columns as the (batched) input.
        It is 2-D even when ``tensor`` was 1-D. The dtype follows
        ``torch.cat`` type promotion between the integer one-hot part and the
        untouched tail.
    """
    batch = tensor.unsqueeze(0) if len(tensor.shape) == 1 else tensor
    split_at = num_lay * num_mat

    scores = batch[:, :split_at].reshape(-1, num_lay, num_mat)
    winners = scores.argmax(dim=2)
    hardened = torch.nn.functional.one_hot(winners, num_classes=num_mat)
    flat_hardened = hardened.reshape(-1, split_at)

    tail = batch[:, split_at:]
    return torch.cat((flat_hardened, tail), dim=1)

def evaluate_one_hot(decoded_material, mat_lay=10, num_mat=7):
    """Score how close a single decoded vector is to a clean one-hot encoding.

    The last ``mat_lay`` entries are dropped. The rest are grouped into rows
    of ``num_mat`` values. Each row's squared values are summed, and the mean
    of those sums is returned. For a probability row, the sum of squares
    equals 1 only when the row is exactly one-hot.

    Args:
        decoded_material: 1-D tensor (assumed layout: material scores followed
            by ``mat_lay`` trailing entries — TODO confirm against the decoder).
        mat_lay: number of trailing entries to exclude (also the layer count).
        num_mat: number of materials per layer.

    Returns:
        The mean per-layer sum of squares, as a Python float.
    """
    material_section = decoded_material[:-mat_lay]
    per_layer = material_section.reshape(-1, num_mat)
    sum_of_squares = per_layer.square().sum(dim=1)
    return sum_of_squares.mean().item()

def search_point_multi(vae, vae_simulator, initial_point, desired_y, desired_x, epochs=100, learning_rate=0.001, sloss=None):
    """Optimise a batch of latent points so their simulated spectra match ``desired_y``.

    Each epoch does the following:

    1. Decode the current latent points with ``vae``.
    2. Score them with ``vae_simulator`` against ``desired_y``, using squared
       error summed over dim 1 and averaged over points.
    3. If ``sloss`` is given, add a semantic term ``-log(wmc)``.
    4. Back-propagate, record per-point diagnostics, then take one Adam step.

    Args:
        vae: model exposing ``.device`` and ``.decode(points)``.
        vae_simulator: callable mapping latent points to predicted spectra.
        initial_point: starting latent points, 2-D ``(num_points, latent_dim)``.
            It is cloned, so the caller's tensor is not modified.
        desired_y: target spectra, broadcast against ``vae_simulator(points)``.
        desired_x: unused. Kept for interface compatibility.
        epochs: number of optimisation steps.
        learning_rate: Adam learning rate.
        sloss: optional semantic-loss callable. It is invoked as
            ``sloss(probabilities=..., output_wmc_per_sample=True)`` and must
            return ``(_, wmc)`` with one weighted model count per point. When
            ``None``, no semantic term is added.

    Returns:
        A list with one entry per epoch. Each entry is a list of per-point tuples
        ``(idx, rec_loss, sem_loss, onehot, latent_point, decoded_point)``,
        recorded *before* that epoch's optimiser step. ``sem_loss`` is the
        string ``"None"`` when ``sloss`` is ``None``. ``latent_point`` and
        ``decoded_point`` are independent NumPy copies.
    """
    initial_point = initial_point.clone().detach()
    initial_point.requires_grad = True

    # NOTE(review): ``bias`` receives exactly the same gradient as
    # ``initial_point``, so the two parameters move together (with Adam this
    # roughly doubles the effective step). Kept for behavioural parity.
    bias = torch.zeros((initial_point.shape[0], initial_point.shape[1]), device=vae.device, requires_grad=True)
    optimizer = torch.optim.Adam([initial_point, bias], lr=learning_rate)

    n_lay = 10
    n_mat = 7

    history = []
    for _ in range(epochs):
        optimizer.zero_grad()

        points = initial_point + bias
        decoded_points = vae.decode(points)
        # Map decoded values into [0.005, 0.995]. This assumes decode outputs
        # lie in [0, 1], so that -log(wmc) stays finite.
        soft_decoded = 0.005 + 0.99 * decoded_points

        # Reconstruction loss: the simulator works directly on the latent points.
        multi_rec_loss = (vae_simulator(points) - desired_y).square().sum(dim=1)
        loss = multi_rec_loss.mean()

        multi_sem_loss = None
        if sloss is not None:
            # Keep only the material columns (drop thicknesses) and lay them out
            # as (n_lay * n_mat, num_points), the layout this code hands to
            # ``sloss``. A plain ``reshape`` to that shape is only a transpose
            # for a single point; with several points it would mix values across
            # points. ``contiguous`` keeps the memory layout a reshape produced.
            material_probs = soft_decoded[:, :n_lay * n_mat].transpose(0, 1).contiguous()
            _, wmc = sloss(probabilities=material_probs, output_wmc_per_sample=True)

            # One semantic loss per point.
            multi_sem_loss = -torch.log(wmc)
            loss = loss + multi_sem_loss.mean()

        loss.backward()

        # Move diagnostics to the host once per epoch instead of syncing per point.
        points_np = points.detach().cpu().numpy()
        decoded_np = decoded_points.detach().cpu().numpy()
        rec_values = multi_rec_loss.detach().cpu().tolist()
        sem_values = None
        if multi_sem_loss is not None:
            sem_values = multi_sem_loss.detach().reshape(-1).cpu().tolist()

        epoch_data = []
        for idx in range(decoded_points.shape[0]):
            point_sem = sem_values[idx] if sem_values is not None else "None"
            epoch_data.append((
                idx,
                rec_values[idx],
                point_sem,
                evaluate_one_hot(decoded_points[idx]),
                np.copy(points_np[idx]),
                np.copy(decoded_np[idx]),
            ))

        optimizer.step()
        history.append(epoch_data)

    return history


