from epyt_flow.simulation import ScenarioSimulator
from epyt.epanet import ToolkitConstants
import numpy as np
import utils as functions
from wds_utils import InpSetup

def run_quality_simulation_spike(
        f_inp_in, duration, hydraulic_step, f_msx_in=None, quality_step=None,
        pattern_step=None, sim_setup_fns=None, pattern=None, progress=True,
        injection_nodes=None, source_type=ToolkitConstants.EN_CONCEN,
        inp_setup_fns=None, **kwargs
):
    """Run a hydraulic + quality simulation with injection sources at given nodes.

    Parameters
    ----------
    f_inp_in
        Input network file; handed to ``InpSetup`` together with
        ``inp_setup_fns``. Whatever that context manager yields is used as
        the ``f_inp_in`` of the ``ScenarioSimulator``.
    duration
        Passed as ``simulation_duration``.
    hydraulic_step
        Passed as both ``hydraulic_time_step`` and ``reporting_time_step``.
    f_msx_in
        Optional MSX file. When given, the MSX time step is set to
        ``quality_step``, sources are added as ``'CL2'`` species injections
        and ``'CL2'`` bulk-species node sensors are placed on all nodes.
        When ``None``, plain quality sources are added instead.
    quality_step
        Quality time step; defaults to ``hydraulic_step``.
    pattern_step
        Time-pattern step; defaults to ``hydraulic_step * 6``.
    sim_setup_fns
        Optional iterable of callables, each called with the simulator
        before any source is added. ``None`` means no callbacks.
    pattern
        Source pattern. When ``None`` a default spike pattern of length
        ``duration // pattern_step`` is generated (see below).
    progress
        Forwarded as ``verbose`` to ``run_simulation``.
    injection_nodes
        Optional iterable of node IDs that receive a source. ``None`` (or an
        empty iterable) means no source is added at all.
    source_type
        Source type forwarded to the source-creation call.
    inp_setup_fns
        Optional list forwarded to ``InpSetup``. ``None`` means an empty list.
    **kwargs
        Forwarded unchanged to ``set_general_parameters``.

    Returns
    -------
    tuple
        ``(topo, res, pattern)``: the simulator's topology, the result of
        ``run_simulation`` and the pattern that was actually used.

    Notes
    -----
    The default pattern is zero everywhere except for a value of 1 at
    indices 2, 13 and 23, each followed by a value of 1e-10 at the next
    index. Spikes that do not fit into a short pattern are left out
    instead of raising ``IndexError``.
    """
    # Fresh containers per call: the defaults are never shared between calls,
    # even if a callee (e.g. InpSetup) were to modify the list it receives.
    if sim_setup_fns is None:
        sim_setup_fns = []
    if injection_nodes is None:
        injection_nodes = []
    if inp_setup_fns is None:
        inp_setup_fns = []

    if quality_step is None:
        quality_step = hydraulic_step
    if pattern_step is None:
        pattern_step = hydraulic_step * 6

    with InpSetup(f_inp_in, inp_setup_fns) as inp_file:
        with ScenarioSimulator(f_inp_in=inp_file, f_msx_in=f_msx_in) as sim:
            # General timing parameters; flow units are fixed to EN_CMH.
            sim.set_general_parameters(
                simulation_duration=duration,
                hydraulic_time_step=hydraulic_step,
                reporting_time_step=hydraulic_step,
                quality_time_step=quality_step,
                flow_units_id=ToolkitConstants.EN_CMH,
                **kwargs
            )

            sim.enable_chemical_analysis()
            sim.epanet_api.setTimePatternStep(pattern_step)

            if f_msx_in is not None:
                sim.epanet_api.setMSXTimeStep(quality_step)

            # User hooks run after the basic configuration and before the
            # injection sources are created.
            for fn in sim_setup_fns:
                fn(sim)

            if pattern is None:
                pattern_length = duration // pattern_step
                pattern = np.zeros(pattern_length)
                for spike_idx in (2, 13, 23):
                    if spike_idx < pattern_length:
                        pattern[spike_idx] = 1.0
                    # NOTE(review): the 1e-10 follow-up value is kept from the
                    # original pattern; presumably it keeps the source
                    # "active" with a negligible value -- confirm.
                    if spike_idx + 1 < pattern_length:
                        pattern[spike_idx + 1] = 1e-10

            for node_id in injection_nodes:
                if f_msx_in is not None:
                    sim.add_species_injection_source(
                        'CL2',
                        node_id=node_id,
                        pattern=pattern,
                        source_type=source_type
                    )
                else:
                    sim.add_quality_source(
                        node_id=node_id,
                        pattern=pattern,
                        source_type=source_type
                    )

            # Place flow sensors on every link and quality/pressure sensors
            # on every node known to the sensor configuration.
            sim.set_flow_sensors(sensor_locations=sim.sensor_config.links)
            sim.set_node_quality_sensors(sensor_locations=sim.sensor_config.nodes)
            sim.set_pressure_sensors(sensor_locations=sim.sensor_config.nodes)

            if f_msx_in is not None:
                sim.set_bulk_species_node_sensors(sensor_info={
                    "CL2": sim.sensor_config.nodes,
                })

            # Run the simulation and fetch the topology while the simulator
            # is still open.
            res = sim.run_simulation(verbose=progress)
            topo = sim.get_topology()

    return topo, res, pattern

def run_quality_simulation(inp_file, duration, hydraulic_step, quality_step, msx_file=None):
    """Simulate a network with a generated "wavy" pattern and collect graph data.

    Parameters
    ----------
    inp_file
        Network input file, forwarded to ``run_quality_simulation_spike``.
    duration
        Simulation duration, forwarded as ``duration``.
    hydraulic_step
        Hydraulic step, forwarded as ``hydraulic_step``; also used to derive
        the pattern length (``duration // hydraulic_step``).
    quality_step
        Quality step, forwarded as ``quality_step``.
    msx_file
        Optional MSX file; forwarded as ``f_msx_in`` only when not ``None``.

    Returns
    -------
    dict
        Keys: ``'topo'``, ``'data'`` (node quality, transposed),
        ``'source_nodes'``, ``'source_mask'``, ``'reservoir_ids'``,
        ``'lengths'``, ``'diameters'``, ``'roughness'``, ``'flows'``
        (flows transposed and divided by 3600), ``'flow_velocity'``
        (transposed), ``'boundary_index'`` and ``'n_edges'``.

    Notes
    -----
    NOTE(review): no ``injection_nodes`` are passed to
    ``run_quality_simulation_spike``, so its empty default applies and the
    generated pattern is not attached to any node -- confirm this is intended.

    NOTE(review): the pattern length is derived from ``hydraulic_step``
    while ``run_quality_simulation_spike`` defaults its pattern step to
    ``hydraulic_step * 6`` -- confirm the pattern covers the intended span.
    """
    pattern_length = duration // hydraulic_step
    n_waves = 3
    seed = 42
    pattern = functions.create_wavy_pattern(pattern_length, n_waves, hydraulic_step, seed)

    sim_kwargs = {
        'duration': duration,
        'hydraulic_step': hydraulic_step,
        'quality_step': quality_step,
        'pattern': pattern,
    }
    if msx_file is not None:
        sim_kwargs['f_msx_in'] = msx_file

    # The returned pattern is not needed here.
    topo, res, _ = run_quality_simulation_spike(inp_file, **sim_kwargs)
    data = res.get_data_nodes_quality()
    edge_flows = res.get_data_flows()

    # Reservoirs are treated as the source/boundary nodes of the network.
    sources_at = topo.get_all_reservoirs()
    sources = [n in sources_at for n in topo.get_all_nodes()]
    # Two separate lookups so the two returned entries never alias each other.
    reservoir_ids = functions.get_node_index(topo, sources_at)
    boundary_index = functions.get_node_index(topo, sources_at)

    lengths = functions.get_edge_attribute(topo, 'length')
    roughness = functions.get_edge_attribute(topo, 'roughness_coeff')
    diameters = functions.get_edge_attribute(topo, 'diameter')
    flow_velocity = functions.flow_to_velocity(topo, edge_flows)

    return {
        'topo': topo,
        'data': data.T,
        'source_nodes': sources_at,
        'source_mask': sources,
        'reservoir_ids': reservoir_ids,
        'lengths': lengths,
        'diameters': diameters,
        'roughness': roughness,
        # Presumably converts per-hour flows (the simulation uses EN_CMH)
        # to per-second values -- TODO confirm.
        'flows': edge_flows.T / 60. / 60.,
        'flow_velocity': flow_velocity.T,
        'boundary_index': boundary_index,
        'n_edges': edge_flows.shape[1]
    }

def inp_to_graph_data(
        inp_file, injection_pattern, injection_nodes, n_seconds, dt, quality_dt=None,
        pattern_dt=None, f_msx_in='networks/ltown.msx', sim_setup_fns=None,
        inp_setup_fns=None, out_file=None, **kwargs
):
    """Simulate a network and package the result as graph-style arrays.

    Parameters
    ----------
    inp_file
        Network input file, forwarded to ``run_quality_simulation_spike``.
    injection_pattern
        Forwarded as ``pattern``.
    injection_nodes
        Node IDs that receive the injection; also used to compute
        ``'boundary_index'``.
    n_seconds
        Forwarded as ``duration``.
    dt
        Forwarded as ``hydraulic_step``.
    quality_dt
        Forwarded as ``quality_step``.
    pattern_dt
        Forwarded as ``pattern_step``.
    f_msx_in
        MSX file, forwarded as ``f_msx_in``. When not ``None`` the node data
        is read via ``get_data_bulk_species_node_concentration``; otherwise
        via ``get_data_nodes_quality``.
    sim_setup_fns, inp_setup_fns
        Optional lists forwarded under the same names; ``None`` means an
        empty list.
    out_file
        When not ``None``, the simulation result is saved to this file.
    **kwargs
        Extra entries merged into the keyword arguments of
        ``run_quality_simulation_spike``; they override the entries built
        from the named parameters above.

    Returns
    -------
    dict
        Keys: ``'edge_index'``, ``'epanet_result'`` (node data, transposed),
        ``'flow_field'`` (result of ``flow_to_velocity``, transposed),
        ``'edge_lengths'``, ``'edge_diameter'``, ``'boundary_index'``,
        ``'boundary_values'`` (node data at the boundary indices,
        transposed), ``'topology'`` and ``'res'``.
    """
    # Hand fresh lists to the callee so no default list is shared between
    # calls; a caller-supplied list is passed through unchanged.
    if sim_setup_fns is None:
        sim_setup_fns = []
    if inp_setup_fns is None:
        inp_setup_fns = []

    sim_kwargs = {
        'duration': n_seconds,
        'hydraulic_step': dt,
        'quality_step': quality_dt,
        'pattern': injection_pattern,
        'pattern_step': pattern_dt,
        'f_msx_in': f_msx_in,
        'sim_setup_fns': sim_setup_fns,
        'inp_setup_fns': inp_setup_fns,
        'injection_nodes': injection_nodes
    }
    sim_kwargs.update(kwargs)

    # The returned pattern is not needed here.
    topo, res, _ = run_quality_simulation_spike(inp_file, **sim_kwargs)
    if f_msx_in is not None:
        data = res.get_data_bulk_species_node_concentration()
    else:
        data = res.get_data_nodes_quality()
    edge_flows = res.get_data_flows()

    edge_index = functions.make_edge_index(topo)
    true_flows = functions.flow_to_velocity(topo, edge_flows).T
    lengths = functions.get_edge_attribute(topo, 'length')
    # Divided by 10 and then by 100 (i.e. by 1000 overall); presumably a
    # millimetre -> metre conversion -- TODO confirm the input unit.
    diameter = functions.get_edge_attribute(topo, 'diameter')/10/100
    boundary_index = functions.get_node_index(topo, injection_nodes)

    if out_file is not None:
        res.save_to_file(out_file)

    return {
        'edge_index': edge_index,
        'epanet_result': data.T,
        'flow_field': true_flows,
        'edge_lengths': lengths,
        'edge_diameter': diameter,
        'boundary_index': boundary_index,
        'boundary_values': data[:, boundary_index].T,
        'topology': topo,
        'res': res
    }