import warnings

from dataset_5_layer.configs import MetamatDsConfig
# Silence every warning category. NOTE(review): this modifies the global
# warnings filter list at import time, so it also hides warnings raised by
# any other code running in the same process that imports this module.
warnings.filterwarnings('ignore')

import numpy as np
import datetime

# Timestamp captured once, when this module is first imported.
# (A `global` declaration is meaningless at module scope, so none is used.)
date = datetime.datetime.now()
import h5py


def resc_ang(in_theta):
  """Return *in_theta* divided by 45 (presumably mapping +/-45 degrees to +/-1).

  Works element-wise on numpy arrays as well as on scalars.
  """
  divisor = 45
  return in_theta / divisor


def resc_psi(in_psi):
  """Return *in_psi* divided by 90 (scalar or numpy array, element-wise)."""
  divisor = 90
  return in_psi / divisor


def resc_delt(in_delt):
  """Return *in_delt* divided by 90 (scalar or numpy array, element-wise)."""
  divisor = 90
  return in_delt / divisor


def readin_data(filename, nmat, nang, nlay, nwave):
  """Read a generator HDF5 file and split its ``data`` matrix into fields.

  Every row of the ``data`` dataset follows the generator column layout::

      [angle | materials | thickness | rp | rs | tp | ts | psi | delta]

  with widths ``nang``, ``nmat`` per layer (``nlay`` layers), ``nlay``,
  ``nwave * nang`` for each of rp/rs/tp/ts/psi, and all remaining columns
  for delta.

  Args:
    filename: path to the HDF5 file, as a ``str`` or any ``os.PathLike``.
    nmat: number of material columns per layer.
    nang: number of angle columns.
    nlay: number of layers; must be in 1..5.
    nwave: number of wavelengths per angle.

  Returns:
    ``(ml1, ml2, ml3, ml4, ml5, th, theta, rp, rs, tp, ts, psi, delta)``.
    ``ml<k>`` is ``None`` for every layer ``k > nlay``. ``psi`` and
    ``delta`` are rescaled with ``resc_psi`` / ``resc_delt``; all other
    fields are returned unscaled.

  Raises:
    ValueError: if ``nlay`` is outside 1..5. The return tuple has only five
      material slots, so more layers would misalign every later field.
    KeyError: if the file has no ``data`` entry.
  """
  if not 1 <= nlay <= 5:
    raise ValueError(f'nlay must be between 1 and 5, got {nlay}')

  from pathlib import Path  # local import: also accept plain str paths

  # The context manager closes the file even if reading fails.
  with h5py.File(Path(filename).absolute().as_posix(), 'r') as f:
    arrd = np.array(f['data'])

  nspec = nwave * nang  # column width of each spectral block
  col = 0

  def take(width):
    # Return the next `width` columns and advance the cursor past them.
    nonlocal col
    block = arrd[:, col:col + width]
    col += width
    return block

  theta = take(nang)
  # Layers that do not exist are reported as None.
  mats = [take(nmat) for _ in range(nlay)] + [None] * (5 - nlay)
  th = take(nlay)
  rp = take(nspec)
  rs = take(nspec)
  tp = take(nspec)
  ts = take(nspec)
  psi = take(nspec)
  delta = arrd[:, col:]  # delta takes every remaining column

  # Rescale the input data. Only psi and delta are rescaled here;
  # angles and thicknesses are returned as stored.
  psi = resc_psi(psi)
  delta = resc_delt(delta)

  ml1, ml2, ml3, ml4, ml5 = mats
  return (ml1, ml2, ml3, ml4, ml5, th, theta, rp, rs, tp, ts, psi, delta)


# Example: read the data directly with the readin_data function.
# Set num_mat, num_ang, num_lay, num_wave, etc. to the values you chose in the generator program.
# num_mat = 5
# num_ang = 3
# num_lay = 5
# num_wave = 200
# file_name = cfg.h5_data_file
###CHANGE FOR DIFFERENT NUM OF LAYERS

# (ml1,ml2,ml3,ml4,ml5,th,ang,rp,rs,tp,ts,psi,delta) = readin_data(os.path.join('data',file_name),num_mat,num_ang,num_lay,num_wave)



def get_x_rt_data(cfg: MetamatDsConfig, invd_num_test: int, verbose=True):
  """Load the dataset described by *cfg* and split it into train/test sets.

  Inputs ``x`` are the material columns of the first ``cfg.num_lay`` layers
  followed by the thickness columns. Targets ``y`` are rp, rs, tp and ts
  concatenated column-wise. Rows are not shuffled: the last
  *invd_num_test* rows form the test set (used for the inverse computation
  phase) and the remaining rows form the training set.

  Args:
    cfg: dataset configuration; reads ``dataset_path``, ``num_mat``,
      ``num_ang``, ``num_lay`` and ``num_wave``.
    invd_num_test: number of trailing rows held out for testing. ``0``
      yields an empty test set. Values >= the number of rows put every row
      in the test set.
    verbose: print progress messages when True.

  Returns:
    ``(X_train, X_test, y_train, y_test)``.

  Raises:
    ValueError: if ``cfg.num_lay`` is not in 1..5 or *invd_num_test* is
      negative.
  """
  if cfg.num_lay not in (1, 2, 3, 4, 5):
    raise ValueError(f'unsupported number of layers: {cfg.num_lay}')
  if invd_num_test < 0:
    raise ValueError(f'invd_num_test must be non-negative, got {invd_num_test}')

  if verbose:
    print('loading data: ', cfg.dataset_path)

  (ml1, ml2, ml3, ml4, ml5, th, _, rp, rs, tp, ts, _, _) = readin_data(
      cfg.dataset_path, cfg.num_mat, cfg.num_ang, cfg.num_lay, cfg.num_wave)

  # readin_data returns None for layers beyond num_lay; keep only real ones.
  mats = [ml1, ml2, ml3, ml4, ml5][:cfg.num_lay]
  x = np.concatenate(mats + [th], axis=1)
  y_r_t = np.concatenate([rp, rs, tp, ts], axis=1)

  # Use an explicit split index. Negative slicing breaks for 0 held-out
  # rows, because x[:-0] selects nothing and x[-0:] selects everything.
  # Clamping at 0 puts every row in the test set when invd_num_test
  # exceeds the row count.
  split = max(x.shape[0] - invd_num_test, 0)
  X_train, X_test = x[:split], x[split:]
  y_train, y_test = y_r_t[:split], y_r_t[split:]

  if verbose:
    print(" ...done!")

  return X_train, X_test, y_train, y_test