import importlib
from typing import Literal
import torch
import numpy as np
from data_utils.real_simulator_tmm import srmse_evaluate
from data_utils.utils import unscaler
import commons.semantic_loss as semloss
importlib.reload(semloss)


def evaluate_one_hot(decoded_material, mat_lay, n_mat=5):
    """Return the mean squared L2 norm of the material groups of decoded designs.

    All but the last ``mat_lay`` columns of each row are split into
    consecutive groups of ``n_mat`` values. Each group's sum of squares is
    computed, and the mean over all groups of all rows is returned. A group
    that is exactly one-hot contributes 1.0.

    Args:
        decoded_material: tensor with fewer than two dimensions (a single
            design, promoted to a batch of one) or a 2-D batch of designs.
        mat_lay: number of trailing columns to exclude from the metric.
            Presumably the per-layer thickness columns -- confirm with caller.
        n_mat: group size (values per one-hot group). Defaults to 5, the
            previously hard-coded value.

    Returns:
        The mean as a Python float.
    """
    if decoded_material.dim() < 2:
        decoded_material = decoded_material.unsqueeze(0)

    # Slice with an explicit end index: ``[:, :-0]`` selects nothing, which
    # made the result NaN when ``mat_lay`` was 0.
    n_group_cols = decoded_material.shape[1] - mat_lay
    values = decoded_material[:, :n_group_cols].reshape(-1, n_mat).square().sum(dim=1)
    return values.mean().item()

def evaluate_single_srmse(decoded_material, desired_material, metamat_config):
    """Score decoded designs against desired designs with ``srmse_evaluate``.

    Each input tensor with fewer than two dimensions is promoted to a batch
    of one. Each tensor is then moved to the CPU, detached, and copied into a
    NumPy array before both arrays and ``metamat_config`` are passed to
    ``srmse_evaluate``. Only the first of the six values it returns (the
    mean) is returned; the other statistics are discarded.
    """
    def _as_batch_array(tensor):
        # A single design becomes a batch of one.
        batched = tensor if len(tensor.shape) >= 2 else tensor.unsqueeze(0)
        # np.array copies, so srmse_evaluate never shares memory with the tensor.
        return np.array(batched.cpu().detach().numpy())

    decoded_arr = _as_batch_array(decoded_material)
    desired_arr = _as_batch_array(desired_material)

    (mean, _std, _ci, _rmse_samples, _ytrue_ypred, _rmse_waves) = srmse_evaluate(
        decoded_arr, desired_arr, metamat_config
    )
    return mean


def _snapshot_epoch(points, decoded_points, multi_rec_loss, multi_sem_loss, mat_layers):
    """Build the per-point history records for one optimisation epoch.

    Tensors are moved to the host once per epoch instead of once per point.
    ``unscaler`` is applied to a detached copy of ``decoded_points``, so the
    rescaling neither modifies nor extends the autograd graph.

    Args:
        points: latent points of this epoch, one row per point.
        decoded_points: ``vae.decode(points)``, one row per point.
        multi_rec_loss: per-point reconstruction loss.
        multi_sem_loss: per-point semantic loss, or ``None`` when no semantic
            loss was computed.
        mat_layers: value forwarded to ``unscaler`` and ``evaluate_one_hot``.

    Returns:
        A list of ``(idx, rec_loss, sem_loss, onehot, latent_point,
        decoded_point)`` tuples, one per point.
        - ``rec_loss`` is a Python float.
        - ``sem_loss`` is a Python float, or the string ``"None"`` when
          ``multi_sem_loss`` is ``None``.
        - ``latent_point`` and ``decoded_point`` are independent NumPy copies.
    """
    decoded = decoded_points.detach().clone()
    # Scale material to the nm range. unscaler's return value is ignored, as
    # everywhere else in this file -- it appears to operate in place.
    unscaler(decoded, mat_layers)
    decoded = decoded.cpu()

    latents = points.detach().cpu().numpy()
    decoded_np = decoded.numpy()
    rec_losses = multi_rec_loss.detach().cpu().reshape(-1).tolist()
    if multi_sem_loss is None:
        sem_losses = ["None"] * len(rec_losses)
    else:
        sem_losses = multi_sem_loss.detach().cpu().reshape(-1).tolist()

    return [
        (
            idx,
            rec_losses[idx],
            sem_losses[idx],
            evaluate_one_hot(decoded[idx], mat_lay=mat_layers),
            np.copy(latents[idx]),
            np.copy(decoded_np[idx]),
        )
        for idx in range(decoded.shape[0])
    ]


def search_point_multi(vae, vae_simulator, initial_points, desired_y, desired_x, metamat_config, epochs=100, learning_rate=0.001, sloss=None,):
    """Optimise several latent points at once towards a desired response.

    Each point is parameterised as ``initial_point + bias``, and only the
    bias is trained, using Adam. The loss is the mean over points of the
    squared error between ``vae_simulator(points)`` and ``desired_y``.
    When ``sloss`` is given, the mean of ``-log(wmc)`` over points is added,
    where ``wmc`` comes from the semantic loss on the decoded material
    probabilities.

    Args:
        vae: model exposing ``.device`` and ``.decode(points)``.
        vae_simulator: callable mapping latent points to a prediction that
            is compared against ``desired_y``.
        initial_points: 2-D array-like of starting latent points, one row per
            point. It is copied and never modified.
        desired_y: target response, broadcast against the simulator output.
        desired_x: unused. Kept only for interface compatibility.
        metamat_config: configuration object; ``num_lay`` is read from it.
        epochs: number of optimisation steps.
        learning_rate: Adam learning rate.
        sloss: optional semantic-loss callable. It is called with
            ``probabilities=`` and ``output_wmc_per_sample=True`` and must
            return a pair whose second item is the per-point WMC.

    Returns:
        A list with one entry per epoch. Each entry is a list of
        ``(idx, rec_loss, sem_loss, onehot, latent_point, decoded_point)``
        tuples, recorded before that epoch's optimizer step. ``decoded_point``
        has been passed through ``unscaler``.
    """
    mat_layers = metamat_config.num_lay
    # NOTE(review): hard-coded layout of the material one-hot block (5 layers
    # x 5 materials); confirm it matches metamat_config.num_lay.
    n_mat = 5
    n_lay = 5
    n_onehot_cols = n_lay * n_mat

    start = torch.as_tensor(initial_points, device=vae.device).detach().clone()
    bias = torch.zeros((start.shape[0], start.shape[1]), device=vae.device, requires_grad=True)
    # Only the bias is trainable. The starting points never require grad, so
    # listing them in the optimizer had no effect.
    optimizer = torch.optim.Adam([bias], lr=learning_rate)

    history = []
    for _ in range(epochs):
        optimizer.zero_grad()

        points = start + bias
        decoded_points = vae.decode(points)

        multi_rec_loss = (vae_simulator(points) - desired_y).square().sum(dim=1)
        loss = multi_rec_loss.mean()

        multi_sem_loss = None
        if sloss is not None:
            # Presumably keeps probabilities away from exactly 0/1 (assuming
            # decoded values lie in [0, 1]) so that log(wmc) stays finite.
            soft_decoded = 0.0001 + 0.9998 * decoded_points
            # Material columns only (thickness excluded), laid out as
            # (n_onehot_cols, num_points). This must be a transpose: a reshape
            # to that shape would interleave values from different points
            # whenever num_points > 1.
            probabilities = soft_decoded[:, :n_onehot_cols].transpose(0, 1).contiguous()
            _, wmc = sloss(probabilities=probabilities, output_wmc_per_sample=True)
            # One loss value per point.
            multi_sem_loss = -torch.log(wmc)
            loss = loss + multi_sem_loss.mean()

        loss.backward()

        epoch_data = _snapshot_epoch(points, decoded_points, multi_rec_loss, multi_sem_loss, mat_layers)
        optimizer.step()
        history.append(epoch_data)

    return history



