"""Generate thermal stress analysis data from heat analysis data."""
import argparse
import glob
import pathlib
import random

import femio
import numpy as np


# Not referenced by the code within this script.
RAW_DIRECTORY = pathlib.Path('data') / 'raw'
# Poisson's ratio assigned to the single material 'MAT_ALL' of every
# generated thermal stress model.
POISSON_RATIO = 0.25
# Exact number of TEMPERATURE steps a HEAT result must contain to be used.
N_TIMESTEP = 11


def main():
    """Command-line entry point: generate thermal stress data sets.

    Every ``heat.res.0.100`` file found recursively under the given input
    directories marks a FrontISTR HEAT result directory. A
    ``ThermalGenerator`` is built for each such directory (each directory is
    processed once, even if reachable from several inputs), and
    ``--n-repetition`` randomized thermal stress inputs are written for it.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'input_directories',
        type=pathlib.Path,
        nargs='+',
        help='Input FrontISTR HEAT results directories')
    parser.add_argument(
        '--n-repetition',
        '-n',
        type=int,
        default=3,
        help='The number of repetition.')
    args = parser.parse_args()

    res_files = set()
    for input_directory in args.input_directories:
        # Escape the user-supplied directory so characters such as '[' or '*'
        # in its name are matched literally; only the suffix is a pattern.
        pattern = str(
            pathlib.Path(glob.escape(str(input_directory)))
            / '**/heat.res.0.100')
        res_files.update(glob.glob(pattern, recursive=True))

    # Sort for a reproducible processing order: set iteration order of str
    # depends on per-process hash randomization. Each generator consumes the
    # global numpy random state, so the order decides which directory gets
    # which random coefficients.
    for res_file in sorted(res_files):
        thermal_generator = ThermalGenerator(pathlib.Path(res_file).parent)
        thermal_generator.generate(args.n_repetition)

    return


class ThermalGenerator:
    """Build FrontISTR static thermal stress inputs from a HEAT result.

    The HEAT result in ``input_data_directory`` is read once at construction
    time and converted into a static analysis model. The model's nodal
    ``CNT_TEMPERATURE`` is the last ``TEMPERATURE`` step. Each call to
    :meth:`generate` writes variants of that model that differ only in a
    randomly drawn linear thermal expansion coefficient.

    Parameters
    ----------
    input_data_directory: pathlib.Path or str
        Directory containing FrontISTR HEAT results, including a time series
        of TEMPERATURE. Outputs are written to sub-directories of it.
    seed: int, optional
        If given, seeds both the ``random`` and ``numpy.random`` global
        states (after the input has been read).
    thermal_expansion_scale: float, optional
        Half-width of the uniform range ``[-scale, scale)`` from which each
        expansion coefficient component is drawn.
    """

    def __init__(
            self, input_data_directory, *,
            seed=None, thermal_expansion_scale=1.e-3):
        # Normalize to Path: _generate joins sub-paths with the '/' operator,
        # which a plain str does not support.
        self.input_data_directory = pathlib.Path(input_data_directory)
        self.seed = seed
        self.thermal_expansion_scale = thermal_expansion_scale
        self.fem_data = self._convert_fem_data()

        if self.seed is not None:
            random.seed(self.seed)
            np.random.seed(self.seed)
        return

    def _convert_fem_data(self):
        """Read the HEAT result and turn it into a static stress model.

        Returns
        -------
        femio.FEMData
            Model with static-analysis settings, spring constraints, a single
            section over all elements, cleared materials, and nodal
            ``CNT_TEMPERATURE`` set to the last TEMPERATURE step.

        Raises
        ------
        ValueError
            If INITIAL_TEMPERATURE or TEMPERATURE is missing, or if the number
            of TEMPERATURE steps differs from ``N_TIMESTEP``.
        """
        print(f"Processing: {self.input_data_directory}")
        fem_data = femio.FEMData.read_directory(
            'fistr', self.input_data_directory, time_series=True)

        # Switch the analysis to a static solve that outputs nodal strain
        # and stress.
        fem_data.settings = {
            'solution_type': 'STATIC',
            'output_res': 'NSTRAIN,ON\nNSTRESS,ON\n',
            'output_vis': 'NSTRAIN,ON\nNSTRESS,ON\n'}

        # Replace the heat boundary conditions with weak springs (1e-4) on
        # three IDs taken from the first row of elements.data.
        # NOTE(review): presumably these are the first element's nodes and
        # the springs only suppress rigid-body motion -- confirm.
        fem_data.constraints.reset()
        spring_ids = fem_data.elements.data[0, :3]
        fem_data.constraints['spring'] = femio.FEMAttribute(
            'SPRING', ids=spring_ids, data=np.ones((3, 3)) * 1e-4)

        # One element group / section covering every element. The material
        # itself is filled in per repetition by _generate.
        fem_data.element_groups = {'ALL': fem_data.elements.ids}
        fem_data.sections.update_data(
            'MAT_ALL', {'TYPE': 'SOLID', 'EGRP': 'ALL'}, allow_overwrite=True)
        fem_data.materials.reset()

        if 'INITIAL_TEMPERATURE' not in fem_data.nodal_data:
            raise ValueError(
                f"No INITIAL_TEMPERATURE for: {self.input_data_directory}")
        if 'TEMPERATURE' not in fem_data.nodal_data:
            raise ValueError(
                f"No TEMPERATURE for: {self.input_data_directory}")
        temperatures = fem_data.nodal_data.get_attribute_data('TEMPERATURE')
        # The check is for an exact count, so the message must cover both
        # too few and too many steps.
        if len(temperatures) != N_TIMESTEP:
            raise ValueError(
                "Unexpected number of timesteps for: "
                f"{self.input_data_directory}\n"
                f"({N_TIMESTEP} expected but {len(temperatures)})")

        # The final time step's temperature field drives the stress analysis.
        fem_data.nodal_data.update_data(
            fem_data.nodes.ids, {'CNT_TEMPERATURE': temperatures[-1]})
        return fem_data

    def generate(self, n_repetition):
        """Generate thermal stress analysis data. The thermal expansion
        coefficient tensor is generated at random for each repetition.

        Parameters
        ----------
        n_repetition: int
            The number of repetitions. Repetition ``i`` is written to
            ``<input_data_directory>/thermal_rep<i>/``. Nothing is written
            when this is zero or negative.
        """
        for i_repetition in range(n_repetition):
            self._generate(i_repetition)
        return

    def _generate(self, i_repetition):
        """Draw a random expansion coefficient and write one model variant.

        Parameters
        ----------
        i_repetition: int
            Index used to name the output directory ``thermal_rep<i>``.
        """
        # One row of six components, each uniform in [-scale, scale).
        # NOTE(review): the key name suggests the six independent components
        # of a symmetric tensor -- confirm the ordering femio expects.
        ltec = (np.random.rand(1, 6) * 2. - 1.) \
            * self.thermal_expansion_scale
        self.fem_data.materials.update_data(
            'MAT_ALL', {
                'Young_modulus': np.array([[1.]]),
                'Poisson_ratio': np.array([[POISSON_RATIO]]),
                'linear_thermal_expansion_coefficient_full': ltec},
            allow_overwrite=True)
        output_directory = self.input_data_directory \
            / f"thermal_rep{i_repetition}"

        self.fem_data.write('fistr', output_directory / 'thermal')
        return


# Run the command-line entry point only when executed as a script, not on
# import.
if __name__ == '__main__':
    main()
