
import numpy as np

from ...fem_attribute import FEMAttribute
from ...fem_data import FEMData
from ...util import string_parser as st


class TIMONData(FEMData):
    """FEMData subclass that reads TIMON output files.

    The mesh is read from a ``.msh`` file; result data are optionally read
    from ``.hou``, ``.bou``, ``.fou``, ``.rou2`` and ``.lbt`` files.
    """

    # Layout descriptions of the result blocks in each file type:
    #   'names'      -- attribute names, in column order;
    #   'dimensions' -- number of columns each attribute occupies, consumed
    #                   left to right by parse_array_and_dict;
    #   'n_line'     -- number of text lines per node/element record, for
    #                   the readers that use it (the BOU solid reader and the
    #                   ROU2 reader use fixed layouts instead).
    DICT_HOU_GLOBAL_NODE_DATA = {
        'names':
        [
            'pressure_start_shrinkage',
            'specific_volume_start_shrinkage',
            'average_temperature_start_shrinkage',
            'time_start_shrinkage',
            'shrinkage',
            'gradient_temperature_mold',
            'shrinkage_mold',
        ],
        'dimensions': [1, 1, 1, 1, 1, 1, 1],
        'n_line': 2
    }
    DICT_HOU_STEP_NODE_DATA = {
        'names':
        [
            'pressure',
            'specific_volume',
            'average_temperature',
            'max_temperature',
            'thickness_flow_layer',
        ],
        'dimensions': [1, 1, 1, 1, 1],
        'n_line': 2
    }
    DICT_HOU_STEP_SOLID_DATA = {
        'names':
        [
            'viscosity',
            'shear_velocity',
            'shear_stress',
            'flow_velocity',
        ],
        'dimensions': [1, 1, 1, 3],
        'n_line': 2
    }
    DICT_BOU_GLOBAL_SOLID_DATA = {
        'names':
        [
            'fiber_orientation_tensor',
            'fiber_orientation_vector',
            'fiber_velocity',
            'skin_fiber_orientation_vector',
        ],
        'dimensions': [6, 3, 3, 3],
        'n_line': 2
    }
    DICT_FOU_GLOBAL_NODE_DATA = {
        'names':
        [
            'flow_front_time',
            'flow_stop_time',
            'temp',
        ],
        'dimensions': [1, 1, 2],
        'n_line': 2
    }
    DICT_FOU_STEP_NODE_DATA = {
        'names':
        [
            'fou_pressure',
            'fou_average_velocity_flow_layer',
            'fou_center_temperature',
            'fou_thickness_average_temperature',
            'fou_flow_front_temperature'
        ],
        'dimensions': [1, 1, 1, 1, 1],
        'n_line': 2
    }
    DICT_FOU_STEP_SOLID_DATA = {
        'names':
        [
            'fou_thickness_flow_layer',
            'fou_viscosity',
            'fou_shear_speed',
            'fou_shear_stress',
            'fou_flow_velocity'
        ],
        'dimensions': [1, 1, 1, 1, 3],
        'n_line': 3
    }
    DICT_ROU2_GLOBAL_NODE_DATA = {
        'names':
        [
            'temperature_difference',
        ],
        'dimensions': [1],
        'n_line': 1
    }

    @classmethod
    def read_files(cls, file_names, read_mesh_only=False, time_series=False):
        """Read TIMON files and build a TIMONData object.

        Args:
            file_names: list of str
                File names.
            read_mesh_only: bool, optional [False]
                If true, read mesh (nodes and elements) and ignore
                material data, results and so on.
            time_series: bool, optional [False]
                Currently unused; accepted for interface compatibility.

        Returns:
            TIMONData
                The loaded object. Its ``elements`` holds only the SOLID
                ('tet') elements; BEAM elements read from the mesh are
                discarded.
        """
        obj = cls()
        obj.file_names = file_names

        str_data = obj._read_str_data()
        print('Parsing data')
        obj._read_msh(str_data['msh'])
        obj.remove_useless_nodes()

        if not read_mesh_only:
            # Read hou file
            if str_data.get('hou') is not None:
                obj._read_ou_file(
                    str_data['hou'],
                    dict_global_node_data=obj.DICT_HOU_GLOBAL_NODE_DATA,
                    dict_step_node_data=obj.DICT_HOU_STEP_NODE_DATA,
                    dict_step_solid_data=obj.DICT_HOU_STEP_SOLID_DATA)

            # Read bou file
            if str_data.get('bou') is not None:
                obj._read_ou_file(
                    str_data['bou'],
                    dict_global_solid_data=obj.DICT_BOU_GLOBAL_SOLID_DATA)

            # Read fou file
            if str_data.get('fou') is not None:
                obj._read_ou_file(
                    str_data['fou'],
                    dict_global_node_data=obj.DICT_FOU_GLOBAL_NODE_DATA,
                    dict_step_node_data=obj.DICT_FOU_STEP_NODE_DATA,
                    dict_step_solid_data=obj.DICT_FOU_STEP_SOLID_DATA)

            # Read rou2 file
            if str_data.get('rou2') is not None:
                obj._read_rou2(
                    str_data['rou2'],
                    dict_global_node_data=obj.DICT_ROU2_GLOBAL_NODE_DATA)

            # Read lbt file
            if str_data.get('lbt') is not None:
                obj._read_ideas_universal(str_data['lbt'], names=(
                    'inflow_gate', 'flow_length', 'flow_length_by_thickness'))

        # Keep only the solid elements, then drop nodes no longer referenced.
        obj.elements = obj.elements['tet']
        obj.remove_useless_nodes()
        obj.settings['solution_type'] = 'TIMON'

        return obj

    def _read_rou2(self, str_data, dict_global_node_data):
        """Read the ROU2 file content into ``self.nodal_data``.

        The first line holds the number of data rows that follow. From each
        of those whitespace-delimited rows, column 1 is stored as a
        one-column attribute named ``dict_global_node_data['names'][0]``.
        """
        n_row = str_data.iloc[[0]].to_values(
            data_type=int, to_rank1=True)[0]
        # [:, 1, None] selects column 1 but keeps the array 2D, shape (n, 1).
        values = str_data.iloc[1:n_row + 1].to_values(r'\s+')[:, 1, None]
        name = dict_global_node_data['names'][0]
        # NOTE(review): rows are paired with self.nodes.ids by position;
        # assumes the file lists the same nodes in the same order -- confirm.
        self.nodal_data.update({
            name: FEMAttribute(name, ids=self.nodes.ids, data=values)})

    def _read_str_data(self):
        """Read each file matching one of the supported extensions.

        Only the ``.msh`` file is mandatory; the others are requested with
        ``mandatory=False``.

        Returns:
            dict with keys = string of extension, values = StringSeries
            object.
        """
        str_data = {
            'msh': self._read_files(r'\.msh', strip=True),
            'hou': self._read_files(r'\.hou', strip=True, mandatory=False),
            'bou': self._read_files(r'\.bou', strip=True, mandatory=False),
            'fou': self._read_files(r'\.fou', strip=True, mandatory=False),
            'lbt': self._read_files(r'\.lbt', strip=True, mandatory=False),
            'rou2': self._read_files(r'\.rou2', strip=True, mandatory=False),
        }
        return str_data

    def _read_msh(self, string_series):
        """Read nodes and elements from the MSH file content.

        Sets ``self.nodes`` and ``self.elements``. ``self.elements`` always
        has a 'tet' entry (4-node SOLID elements, None if that section
        declares zero elements) and gets a 'beam' entry (2-node BEAM
        elements) only when the BEAM section is non-empty.
        """
        self.nodes = self._read_nodes(string_series)
        self.elements = {
            'tet': self._read_element(string_series, 'SOLID', 4),
        }
        beam_data = self._read_element(string_series, 'BEAM', 2)
        if beam_data is not None:
            self.elements['beam'] = beam_data

    def _read_nodes(self, string_series):
        """Read the node section of the MSH file content.

        Line 1 (0-based) holds the node count. The following lines hold one
        node each, whitespace-delimited: column 0 is the ID and the
        remaining columns are the node data (presumably coordinates).
        """
        n_node = string_series.iloc[[1]].to_values(
            r'\s+', data_type=int)[0, 0]
        first_line = 2
        node_lines = string_series.iloc[first_line:first_line + n_node]
        return node_lines.to_fem_attribute(
            'NODE', 0, slice(1, None), delimiter=r'\s+')

    def _read_element(self, string_series, keyword, n_node_per_element):
        """Read one element section from the MSH file content.

        The section starts at a line consisting only of ``keyword``. The
        next line holds the element count, followed by two lines per
        element: a line whose column 1 is the element ID, then a line whose
        first ``n_node_per_element`` integer columns are stored as the
        element data.

        Returns:
            FEMAttribute, or None if the section declares zero elements.

        Raises:
            ValueError: if ``keyword`` does not occur exactly once.
        """
        candidate_header_indices = string_series.indices_matches(
            fr"^\s*{keyword}\s*$")
        if len(candidate_header_indices) != 1:
            raise ValueError(
                f"{len(candidate_header_indices)} {keyword} keywords found. "
                'The format may be wrong.')
        header_index = candidate_header_indices[0]

        n_element = string_series.iloc[[header_index + 1]].to_values(
            r'\s+', data_type=int)[0, 0]
        if n_element == 0:
            return None
        start = header_index + 2
        end = start + n_element * 2
        # Even offsets: ID lines; odd offsets: node lines.
        element_ids = string_series.iloc[start:end:2].split_vertical_all(
            delimiter=r'\s+')[1].to_values(data_type=int, to_rank1=True)
        element_data = string_series.iloc[start + 1:end:2].to_values(
            delimiter=r'\s+', data_type=int)[:, :n_node_per_element]
        return FEMAttribute(
            'ELEMENT', ids=element_ids, data=element_data)

    def _read_ou_file(
            self, string_series, *,
            dict_global_node_data=None, dict_global_solid_data=None,
            dict_step_node_data=None, dict_step_solid_data=None):
        """Read global and last-step data from a hou/bou/fou file.

        Layout dicts left as None are skipped. The global results are
        merged into ``self.nodal_data`` / ``self.elemental_data`` before the
        last STEP section is read and merged.
        """
        global_nodal, global_elemental = self._read_global_data(
            string_series,
            dict_global_node_data=dict_global_node_data,
            dict_global_solid_data=dict_global_solid_data)
        self.nodal_data.update(global_nodal)
        self.elemental_data.update(global_elemental)

        step_nodal, step_elemental = self._read_step_data(
            string_series, step=-1,
            dict_step_node_data=dict_step_node_data,
            dict_step_solid_data=dict_step_solid_data)
        self.nodal_data.update(step_nodal)
        self.elemental_data.update(step_elemental)

    def _read_global_data(
            self, string_series, *,
            dict_global_node_data=None, dict_global_solid_data=None):
        """Read step-independent nodal or solid data from an *OU file.

        At most one of the two layouts may be given.

        Returns:
            (nodal_data, elemental_data): two dicts of FEMAttribute; the
            one not requested is empty.

        Raises:
            NotImplementedError: if both layouts are given.
        """
        if dict_global_node_data is None and dict_global_solid_data is None:
            return {}, {}
        if dict_global_node_data is not None \
                and dict_global_solid_data is not None:
            # Previously fell through and returned None, which broke the
            # tuple unpacking in the caller with an opaque TypeError.
            raise NotImplementedError(
                'Reading global node and solid data from the same file '
                'is not supported.')

        if dict_global_node_data is not None:
            # Records start right after the first "NODE " header line.
            node_header_index = string_series.indices_matches(
                r"^\s*NODE\s")[0]
            # NOTE(review): assumes the file lists exactly the nodes kept
            # after remove_useless_nodes(), in the same order -- confirm.
            nodal_array = self._read_interleaved_records(
                string_series, node_header_index + 1,
                len(self.nodes.ids), dict_global_node_data['n_line'])
            nodal_data = self.parse_array_and_dict(
                nodal_array, self.nodes.ids, dict_global_node_data)
            return nodal_data, {}

        # Solid data uses a fixed layout: data start at line 6 and each
        # element spans 4 lines, the first 3 of which are joined into one
        # row. NOTE(review): 'n_line' in the layout dict is not used here;
        # confirm the 6-line header and 4-line stride against the format.
        start = 6
        end = start + len(self.elements['tet'].ids) * 4
        array = st.StringSeries.connect_all(
            [string_series[start + i:end:4] for i in range(3)],
            delimiter=' ').to_values(delimiter=r'\s+')
        elemental_data = self.parse_array_and_dict(
            array, self.elements['tet'].ids, dict_global_solid_data)
        return {}, elemental_data

    def _read_step_data(
            self, string_series, *,
            step=-1, dict_step_node_data=None, dict_step_solid_data=None):
        """Read nodal and solid data of one STEP section of an *OU file.

        Args:
            string_series: StringSeries
                Content of the file.
            step: int, optional [-1]
                Index of the STEP section to read; negative values count
                from the end as in Python indexing.
            dict_step_node_data, dict_step_solid_data: dict, optional
                Layouts of the nodal and solid blocks; give both or neither.

        Returns:
            (nodal_data, elemental_data): two dicts of FEMAttribute.

        Raises:
            NotImplementedError: if only one of the layouts is given.
            IndexError: if ``step`` is out of range.
        """
        if dict_step_node_data is None and dict_step_solid_data is None:
            return {}, {}
        if dict_step_node_data is None or dict_step_solid_data is None:
            raise NotImplementedError(
                'Step node and solid data must be read together.')

        step_indices = string_series.indices_matches(r"^\s*STEP\s*")
        n_step = len(step_indices)
        # Normalise ``step`` to a non-negative position (IndexError when out
        # of range). The original only special-cased step == -1, so the
        # last section addressed by a non-negative index raised IndexError.
        position = range(n_step)[step]
        step_start = step_indices[position]
        if position + 1 < n_step:
            step_end = step_indices[position + 1]
        else:
            step_end = len(string_series)
        string_step = string_series[step_start:step_end]

        node_header_index = string_step.indices_matches(
            r"^\s*NODE\s*")[0]
        nodal_array = self._read_interleaved_records(
            string_step, node_header_index + 1, len(self.nodes.ids),
            dict_step_node_data['n_line'])
        nodal_data = self.parse_array_and_dict(
            nodal_array, self.nodes.ids, dict_step_node_data)

        solid_header_index = string_step.indices_matches(
            r"^\s*SOLID\s*")[0]
        elemental_array = self._read_interleaved_records(
            string_step, solid_header_index + 1,
            len(self.elements['tet'].ids), dict_step_solid_data['n_line'])
        elemental_data = self.parse_array_and_dict(
            elemental_array, self.elements['tet'].ids, dict_step_solid_data)

        return nodal_data, elemental_data

    @staticmethod
    def _read_interleaved_records(
            string_series, first_line, n_record, n_line):
        """Parse ``n_record`` records of ``n_line`` lines each into an array.

        Records start at position ``first_line``. The first line of every
        record is skipped; the remaining ``n_line - 1`` lines are parsed as
        whitespace-delimited values and joined column-wise, so row ``k`` of
        the result holds record ``k``.

        Raises:
            ValueError: if ``n_line`` < 2 (no data lines per record).
        """
        if n_line < 2:
            raise ValueError(f"n_line must be at least 2, got {n_line}")
        end = first_line + n_record * n_line
        return np.concatenate([
            string_series[first_line + offset:end:n_line].to_values(
                delimiter=r'\s+')
            for offset in range(1, n_line)], axis=1)

    def parse_array_and_dict(self, array, ids, dict_):
        """Split the columns of ``array`` into named FEMAttributes.

        Args:
            array: numpy.ndarray
                2D array whose rows correspond to ``ids``.
            ids: array-like
                IDs attached to every produced attribute.
            dict_: dict
                Layout with 'names' and 'dimensions'; attribute ``i`` takes
                the next ``dimensions[i]`` columns, left to right.

        Returns:
            dict mapping each name to its FEMAttribute.
        """
        out_data = {}
        cum_dim = 0
        for name, dim in zip(dict_['names'], dict_['dimensions']):
            out_data[name] = FEMAttribute(
                name, ids, array[:, cum_dim:cum_dim + dim])
            cum_dim += dim
        return out_data
