import numpy as np
from tqdm import tqdm
from dataset_5_layer.configs import MetamatDsConfig
from copy import deepcopy
from dataset_5_layer.data_utils import BB_metals as bb
from dataset_5_layer.data_utils import TMM_numba as tmm
from dataset_5_layer.data_utils import dielectric_materials as di

import scipy.stats as st

# Thicknesses must be expressed in nm
def srmse_evaluate(x_pred, x_actual, cfg: MetamatDsConfig):
    """Score predicted designs by the RMSE between simulated spectra.

    Each pair ``(x, y)`` taken in lockstep from ``x_pred`` / ``x_actual`` is
    passed through ``generate_fcn`` to obtain the simulated response of the
    predicted and of the reference design, and the squared error between the
    two responses is reduced in two ways (see Returns).

    The simulation set-up is fixed here: wavelengths come from ``cfg.wave``,
    the candidate layer materials are, in index order, Ag, Al2O3, ITO, Ni and
    TiO2, the angles of incidence are 25, 45 and 65 degrees from normal, the
    superstrate is vacuum and the substrate is a Cauchy-Urbach glass model.

    Args:
        x_pred: iterable of predicted design vectors, each in the layout
            expected by ``generate_fcn`` (material scores followed by layer
            thicknesses; the file-level note asks for thicknesses in nm).
        x_actual: iterable of reference design vectors, same layout.
        cfg: configuration providing ``wave`` here, plus whatever
            ``generate_fcn`` reads from it.

    Returns:
        A 6-tuple ``(mean, std, ci, rmse_samples, (ytrue, ypred), rmse_waves)``:

        * ``mean`` / ``std``: mean and (population) standard deviation of
          ``rmse_samples``.
        * ``ci``: ``scipy.stats.t.interval(0.95, ...)`` centred on ``mean``.
        * ``rmse_samples``: one value per sample; the RMSE is computed
          separately for each of the four response rows (rp, rs, tp, ts)
          and then averaged.
        * ``ytrue`` / ``ypred``: the flattened simulated responses, one row
          per sample.
        * ``rmse_waves``: per sample, the RMSE over wavelength for every
          (response row, angle) combination, flattened to one row.

        The stacked arrays are passed through ``np.squeeze``, so with a
        single sample the leading sample axis is dropped.
    """
    wave = cfg.wave  # e.g. np.linspace(450, 950, 200) * 1E-9
    ag = bb.nk_material('Ag', wave)
    ni = bb.nk_material('Ni', wave)
    al2o3 = di.nk_material('al2o3', wave)
    tio2 = di.nk_material('tio2', wave)
    ito = di.nk_material('ito', wave)
    # glass model for substrate (based on 2-parameter cauchy fit of slides in lab)
    gl = di.nk_Cauchy_Urbach(wave, 1.55, 0.005)
    void = np.ones(wave.size)  # vacuum

    # Row order matters: generate_fcn indexes this array with the argmax of
    # the material scores in each design vector.
    materials = np.array([ag, al2o3, ito, ni, tio2])

    ang = np.array([25., 45., 65.])  # angles to calculate for each system [deg from normal]
    n_super = void  # material to be used for superstrate
    n_subst = gl  # material to be used for substrate

    rmse_samples = []
    rmse_waves = []
    ytrue = []
    ypred = []

    for x, y in zip(x_pred, x_actual):
        pred = generate_fcn(wave, n_subst, n_super, materials, x, ang, cfg)
        true = generate_fcn(wave, n_subst, n_super, materials, y, ang, cfg)

        # se has one row per response (rp, rs, tp, ts) and one column per
        # (angle, wavelength) pair, angle-major.
        se = (pred - true) ** 2

        # Split the columns back into (angle, wavelength) using the actual
        # grid sizes instead of hard-coded ones, so any cfg.wave length works.
        se_waves = np.reshape(se, (-1, ang.size, wave.size))
        rmse_per_angle = np.sqrt(np.mean(se_waves, axis=2))
        rmse_waves.append(np.reshape(rmse_per_angle, (1, -1)))

        # RMSE computed separately per response row, then averaged.
        rmse_samples.append(np.mean(np.sqrt(np.mean(se, axis=1))))

        ytrue.append(np.reshape(true, (1, -1)))
        ypred.append(np.reshape(pred, (1, -1)))

    rmse_waves = np.squeeze(np.array(rmse_waves))
    rmse_samples = np.array(rmse_samples)

    mean = np.mean(rmse_samples)
    std = np.std(rmse_samples)
    # NOTE(review): scale is the standard deviation, not the standard error
    # (std / sqrt(n)), so this is wider than a confidence interval of the
    # mean. Kept as-is to avoid changing reported numbers - confirm intent.
    ci = st.t.interval(0.95, df=len(rmse_samples) - 1, loc=mean, scale=std)

    ytrue = np.squeeze(np.array(ytrue))
    ypred = np.squeeze(np.array(ypred))

    return mean, std, ci, rmse_samples, (ytrue, ypred), rmse_waves


def generate_fcn(wave, n_subst, n_super, materials, x, ang, cfg: MetamatDsConfig):
    """Simulate the four response spectra of a single multilayer design.

    The design vector ``x`` is split at ``cfg.th_index``. The leading part is
    reshaped into rows of ``cfg.num_mat`` scores, and the argmax of each row
    picks the row of ``materials`` used for the corresponding layer. The
    trailing part is handed to the TMM routines as the per-layer thickness
    vector.

    Args:
        wave: 1-D array of wavelengths.
        n_subst: substrate index, one value per wavelength.
        n_super: superstrate index, one value per wavelength.
        materials: 2-D array, one row of per-wavelength indices per material.
        x: design vector (material scores followed by thicknesses).
        ang: 1-D array of angles of incidence.
        cfg: configuration providing ``th_index`` and ``num_mat``.

    Returns:
        Array of shape ``(4, ang.size * wave.size)``. Rows hold, in order,
        ``tmm.reflect_amp`` with flag 1 and 0, then ``tmm.trans_amp`` with
        flag 1 and 0 (named rp, rs, tp, ts in the original code). Columns are
        angle-major: all wavelengths of the first angle, then the next angle.
    """
    thicknesses = x[cfg.th_index:]
    n_wave = wave.size

    # One material index per layer, from the score part of the vector.
    scores = np.reshape(x[:cfg.th_index], (-1, cfg.num_mat))
    chosen = np.argmax(scores, axis=1)

    # Per-layer complex index table; rows are filled through views, and zip
    # stops at the shorter of (layers, chosen materials).
    layer_nk = np.zeros((thicknesses.size, n_wave), dtype=complex)
    for nk_row, mat_idx in zip(layer_nk, chosen):
        nk_row[:] = materials[mat_idx, :]

    spectra = np.zeros((4, n_wave * ang.size))
    for a_idx, theta in enumerate(ang):
        base = a_idx * n_wave
        for w_idx in range(n_wave):
            col = base + w_idx
            # Arguments shared by all four TMM calls at this grid point.
            stack = (theta, wave[w_idx], layer_nk[:, w_idx], thicknesses,
                     n_super[w_idx], n_subst[w_idx])
            spectra[0, col] = tmm.reflect_amp(1, *stack)
            spectra[1, col] = tmm.reflect_amp(0, *stack)
            spectra[2, col] = tmm.trans_amp(1, *stack)
            spectra[3, col] = tmm.trans_amp(0, *stack)

    return spectra
