import utils as functions
import numpy as np
import simulate
import torch
import torch_geometric
import os
import pickle
import modules
import wds_utils

def create_random_demand_profiles(avg_demand_min, avg_demand_max, n_nodes, n_times,
                                  rng=None, burn_in=100, node_noise_std=0.04):
    '''
    Make random, temporally smooth demand patterns for a WDS.

    A random walk shared by all nodes plus small per-node noise increments is
    integrated over ``burn_in + n_times`` steps (its absolute value is taken),
    and the first ``burn_in`` steps are discarded. The result is min-max
    rescaled over *all* node/time entries, so every value lies in
    ``[avg_demand_min, avg_demand_max]``. Unless the walk is constant, the
    smallest entry equals ``avg_demand_min`` and the largest equals
    ``avg_demand_max``. Note that the bounds apply to individual entries, not
    to the average demand across nodes.

    avg_demand_min: Lower bound of the demand values.
    avg_demand_max: Upper bound of the demand values.
    n_nodes: Number of nodes (rows of the output).
    n_times: Number of time steps (columns of the output).
    rng: Optional ``numpy.random.Generator`` / ``RandomState``. Defaults to
        the global ``numpy.random`` state, so ``np.random.seed`` still
        controls the output.
    burn_in: Number of leading random-walk steps that are discarded.
    node_noise_std: Standard deviation of the per-node noise increments.

    Returns an array of shape ``(n_nodes, n_times)``. If all raw values are
    identical (e.g. a single node and time step), every entry is
    ``avg_demand_min`` rather than NaN. An empty array is returned unchanged.
    '''
    rng = np.random if rng is None else rng
    n_total = burn_in + n_times
    # Draw order (node noise first, then shared walk) matches the original
    # implementation so seeded results are reproducible.
    node_noise = rng.normal(0, node_noise_std, size=(n_nodes, n_total))
    shared_walk = rng.normal(size=(1, n_total))
    dp = np.abs((shared_walk + node_noise).cumsum(1))[:, burn_in:]
    if dp.size == 0:
        # Nothing to rescale; min()/max() would raise on an empty array.
        return dp
    lo, hi = dp.min(), dp.max()
    span = hi - lo
    # Guard against 0/0 when the walk is constant.
    dp = (dp - lo) / span if span > 0 else np.zeros_like(dp)
    return dp * (avg_demand_max - avg_demand_min) + avg_demand_min

class RandomHydraulicsDataset:
    '''
    Lazily simulated dataset of randomized WDS hydraulics/water-quality runs.

    For each index ``idx`` a simulation is run once with freshly randomized
    demand patterns, pipe lengths and pipe diameters (see ``_generate_next``).
    The simulation output is post-processed into a dict of arrays, cached as
    ``<data_path>/<idx>.pkl``, and re-loaded from that file on subsequent
    ``get(idx)`` calls.
    '''

    def __init__(self, inp_file, n_seconds, cT, Tau, dt, quality_dt, pattern_dt,
                 diameter_bounds, length_bounds, average_demand_bounds,
                 pattern, sources_at, data_path='./dataset', equal_demand_at_nodes=False,
                 **sim_kwargs):
        '''
        inp_file: The inp file used for the simulation.
        n_seconds: The number of seconds that the simulation runs.
            (NOTE(review): previously documented as "n_seconds // Tau + 1 is
            the number of steps returned"; presumably n_seconds // dt + 1 is
            meant. Confirm against simulate.inp_to_graph_data.)
        cT: (see Tau)
        Tau: cT corresponds to our (known) history and Tau to the number of
            time steps that we want to predict. cT and Tau have no effect on
            the simulated data; they only affect the size of the x and y
            output variables. x has shape [num_nodes, cT] and y has shape
            [num_nodes, cT+Tau]. They also set the normalization constant
            max_delay_steps = (cT + Tau) * 2.
            (Note: thus the whole simulation output is stored in y.
            cT and Tau are a little bit unnecessary to have; we should think
            about removing them and only keeping n_seconds.)
        dt: Hydraulic time step, which is also the reporting time step.
        quality_dt: Quality time step used by EPANET.
            (NOTE(review): stored on the instance but not forwarded to
            simulate.inp_to_graph_data by this class. Pass it via sim_kwargs
            if the simulator needs it; confirm.)
        pattern_dt: Pattern time step used by EPANET. The random demand
            pattern has n_seconds // pattern_dt entries.
        diameter_bounds: (low, high) bounds for uniformly sampled pipe
            diameters (divided by 1000 for normalization, i.e. presumably mm).
        length_bounds: (low, high) bounds for uniformly sampled pipe lengths.
        average_demand_bounds: (low, high) bounds of the random demand
            patterns (divided by 1000 for normalization).
        pattern: The injection pattern (potentially also allow to randomize this).
        sources_at: The chlorine injection nodes.
        data_path: Root directory for cached samples. Samples are stored in
            <data_path>/<inp file name without extension>[/equal_demands].
        equal_demand_at_nodes: If True, all junctions share one demand pattern.
        sim_kwargs: Other simulation setup parameters, forwarded to
            simulate.inp_to_graph_data. The optional keys 'sim_setup_fns' and
            'inp_setup_fns' are appended to the internally generated setup
            functions instead of being forwarded verbatim.
        '''
        self.topology = functions.read_inp(inp_file)
        self.diameter_bounds = diameter_bounds
        self.length_bounds = length_bounds
        self.average_demand_bounds = average_demand_bounds
        self.inp_file = inp_file
        self.n_seconds = n_seconds
        self.dt = dt
        self.quality_dt = quality_dt
        self.pattern_dt = pattern_dt
        self.sim_kwargs = sim_kwargs
        self.pattern = pattern
        self.sources_at = sources_at
        self.cT = cT
        self.Tau = Tau
        # splitext strips the extension whatever its length; the former
        # `[:-4]` silently assumed a 4-character ".inp" suffix.
        self.wds_name = os.path.splitext(os.path.basename(self.inp_file))[0]
        self.data_path = os.path.join(data_path, self.wds_name)
        # Rank of each link's diameter in the original inp file; used to keep
        # the ordering of randomized diameters consistent with the original.
        self.diameter_order = self.topology.query_link_attribute('diameter').values.argsort()
        self.equal_demand_at_nodes = equal_demand_at_nodes
        if self.equal_demand_at_nodes:
            self.data_path = os.path.join(self.data_path, 'equal_demands')
        os.makedirs(self.data_path, exist_ok=True)
        self.initialize_variable_maxima()

    def initialize_variable_maxima(self):
        '''Compute the normalization constants exposed via add_normalization_values.'''
        # convert from mm to m
        self.min_diameter, self.max_diameter = np.divide(self.diameter_bounds, 1000)
        # convert from L to m^3 (only the upper demand bound is needed)
        _, max_demand = np.divide(self.average_demand_bounds, 1000)
        self.max_edge_capacity = (self.max_diameter / 2.)**2*np.pi
        self.min_edge_capacity = (self.min_diameter / 2.)**2*np.pi
        self.max_edge_length = self.length_bounds[1]
        # Upper bound: every node drawing the maximum demand at once.
        self.max_flow = max_demand * self.topology.num_nodes
        self.max_vel = self.max_flow / self.min_edge_capacity
        self.max_delay_steps = (self.cT + self.Tau) * 2
        self.max_travel_time = self.max_delay_steps * self.dt

    @property
    def n_junctions(self):
        return self.topology.num_junctions

    @property
    def n_nodes(self):
        return self.topology.num_nodes

    @property
    def n_links(self):
        return self.topology.num_links

    @property
    def pattern_length(self):
        '''Number of demand-pattern steps covering the simulation.'''
        return int(self.n_seconds // self.pattern_dt)

    def _generate_next(self):
        '''
        Draw one random configuration.

        Returns (dp, lengths, diameters):
        - dp has shape (n_junctions, pattern_length) and holds the demand patterns;
        - lengths has one value per link;
        - diameters has one value per link.
        '''
        dp = create_random_demand_profiles(
            *self.average_demand_bounds, self.n_junctions, self.pattern_length
        )
        if self.equal_demand_at_nodes:
            dp = dp[:1].repeat(self.n_junctions, axis=0)
        lengths = np.random.uniform(*self.length_bounds, size=self.n_links)
        diameters = np.random.uniform(*self.diameter_bounds, size=self.n_links)
        # To generate more plausible networks, keep larger pipes large
        # and smaller pipes small (large at reservoir, small at leaf nodes)
        ordered_dias = np.zeros_like(diameters)
        ordered_dias[self.diameter_order] = sorted(diameters)
        return dp, lengths, ordered_dias

    def get_sim_fn(self, sources_at=None):
        '''
        Return a list with one setup function that writes freshly drawn
        random junction demand patterns into a loaded simulation.

        Only the demand part of a new _generate_next() draw is used here.
        `sources_at` is currently unused (see the TODO below).
        '''
        dp, _, _ = self._generate_next()

        def _set_params(sim):
            functions.set_sim_demands(
                sim, base=[1.0] * self.n_junctions, pattern=dp,
                nodelist=self.topology.junction_name_list
            )
            # TODO: (Fix) This overwrites the injection demands!
            #    If the injection mode is SETPOINT and injections are made
            #    at a reservoir, then this works, otherwise
            # if sources_at is not None:
            #     functions.set_sim_demands(
            #     sim, base=[-1.0] * self.n_junctions, pattern=dp.sum(0),
            #     nodelist=sources_at
            # )
            return
        return [_set_params]

    def get_inp_setup_fn(self):
        '''
        So far pipe lengths and diameters cannot be changed once EPANET has
        loaded the inp file. Here we create a function that modifies the
        `in_inp` file and saves the result to the `out_inp` file.

        Only the lengths/diameters of a new _generate_next() draw are used.
        '''
        _, lengths, diameters = self._generate_next()

        def _set_params(in_inp, out_inp):
            wds_utils.set_pipe_lengths(in_inp, lengths, out_inp)
            # Read back out_inp so the new lengths are kept.
            wds_utils.set_pipe_diameters(out_inp, diameters, out_inp)

        return [_set_params]

    def add_normalization_values(self, sample):
        '''Store the dataset-wide normalization constants in `sample` (in place) and return it.'''
        sample['max_diameter'] = self.max_diameter
        sample['max_edge_capacity'] = self.max_edge_capacity
        sample['max_edge_length'] = self.max_edge_length
        sample['max_vel'] = self.max_vel
        sample['max_delay_steps'] = self.max_delay_steps
        sample['max_flow'] = self.max_flow
        return sample

    @staticmethod
    def _split_sim_kwargs(sim_kwargs):
        '''
        Separate user-supplied setup functions from the remaining simulator kwargs.

        Returns (sim_setup_fns, inp_setup_fns, remaining_kwargs). The input
        dict is not modified. Both function collections are returned as lists.
        '''
        remaining = dict(sim_kwargs)
        sim_setup_fns = list(remaining.pop('sim_setup_fns', []))
        inp_setup_fns = list(remaining.pop('inp_setup_fns', []))
        return sim_setup_fns, inp_setup_fns, remaining

    def get(self, idx):
        '''
        Return sample `idx`, simulating and caching it on first access.

        The returned dict contains the post-processed simulator output plus
        flow_field_graph, edge_capacity, traversal_times, selfloop_mask,
        xs_map, Tau, cT, dt and the normalization constants from
        add_normalization_values.
        '''
        sample_path = os.path.join(self.data_path, f'{idx}.pkl')

        # Load a pickle file with the data if it has been simulated before.
        # NOTE: unpickling can execute arbitrary code; only point data_path
        # at directories written by this class.
        if os.path.isfile(sample_path):
            with open(sample_path, 'rb') as f:
                return self.add_normalization_values(pickle.load(f))

        # The user's setup functions must not also be forwarded through
        # **kwargs, otherwise inp_to_graph_data receives the same keyword twice.
        extra_sim_fns, extra_inp_fns, sim_kwargs = self._split_sim_kwargs(self.sim_kwargs)
        sim_setup_fns = self.get_sim_fn(self.sources_at) + extra_sim_fns
        inp_setup_fns = self.get_inp_setup_fn() + extra_inp_fns
        graph_data = simulate.inp_to_graph_data(
            self.inp_file, self.pattern, self.sources_at, self.n_seconds, self.dt,
            progress=False, sim_setup_fns=sim_setup_fns, inp_setup_fns=inp_setup_fns,
            **sim_kwargs
        )
        # NOTE(review): this is d**2 * pi, whereas initialize_variable_maxima
        # uses the circular area (d / 2)**2 * pi for max/min_edge_capacity.
        # If 'edge_diameter' really is a diameter, this is 4x the pipe
        # cross-section. Left unchanged so cached samples stay consistent; confirm.
        crossection = graph_data['edge_diameter']**2 * np.pi

        _, n_edges = graph_data['edge_index'].shape
        # Keep only the rows belonging to real edges.
        flow_field_graph = graph_data.pop('flow_field')[:n_edges]

        traversal_times, selfloop_mask, xs_map = modules.compute_backward_transit_times_fast(
            graph_data['edge_lengths'], flow_field_graph, self.dt
        )

        graph_data.pop('topology')
        graph_data.pop('res')

        sample = dict(
            **graph_data,
            flow_field_graph=flow_field_graph,
            edge_capacity=crossection,
            traversal_times=traversal_times,
            selfloop_mask=selfloop_mask,
            xs_map=xs_map,
            Tau=self.Tau,
            cT=self.cT,
            dt=self.dt,
        )

        # Save atomically: write to a private temp file and move it into
        # place, so an interrupted run never leaves a truncated .pkl behind
        # that the cache check above would later fail to load.
        tmp_path = f'{sample_path}.{os.getpid()}.tmp'
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(sample, f)
            os.replace(tmp_path, sample_path)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

        return self.add_normalization_values(sample)
    
class TorchRandomHydraulicsDataset(torch_geometric.data.Dataset):
    '''
    ``torch_geometric`` view of a ``RandomHydraulicsDataset``.

    Sample ``i`` of this dataset is sample ``i + idx_offset`` of the wrapped
    hydraulics dataset, converted into a ``torch_geometric.data.Data`` object
    and moved to ``device``. Floating-point fields use
    ``torch.get_default_dtype()``.
    '''

    def __init__(self, size, hydraulics_ds, device='cpu', idx_offset=0, **sim_kwargs):
        '''
        size: Number of samples reported by ``len()``.
        hydraulics_ds: The wrapped dataset (needs ``get(idx)``, ``dt``, ``cT``, ``Tau``).
        device: Device every returned ``Data`` object is moved to.
        idx_offset: Added to every requested index before delegating.
        sim_kwargs: Accepted for call-site compatibility; not used here.
        '''
        super().__init__()
        self.hydraulics_ds = hydraulics_ds
        self.device = device
        # NOTE(review): this directory is created but never read by this class.
        self.data_path = './dataset'
        self.size = size
        self.idx_offset = idx_offset
        os.makedirs(self.data_path, exist_ok=True)

    def len(self):
        '''Number of samples, as given at construction.'''
        return self.size

    @property
    def dt(self):
        '''Time step of the wrapped hydraulics dataset.'''
        return self.hydraulics_ds.dt

    @property
    def cT(self):
        '''Number of history steps placed in ``x``.'''
        return self.hydraulics_ds.cT

    @property
    def Tau(self):
        '''Number of prediction steps appended after the history in ``y``.'''
        return self.hydraulics_ds.Tau

    def get(self, idx):
        '''
        Build the graph for sample ``idx`` (shifted by ``idx_offset``).

        - x is the first cT columns of 'epanet_result'; y is the first cT + Tau columns.
        - NaN traversal times become Tau * dt.
        - delay_steps = traversal_times / dt, clipped at cT + Tau.
        - flows = flow_field_graph * edge_capacity, where edge_capacity is
          turned into a column vector first if it is 1-D.
        '''
        sample = self.hydraulics_ds.get(idx + self.idx_offset)
        float_dtype = torch.get_default_dtype()

        def to_float(value):
            return torch.as_tensor(value, dtype=float_dtype)

        capacity = sample['edge_capacity']
        if capacity.ndim == 1:
            # Column vector so it broadcasts across the second axis of
            # flow_field_graph (presumably time steps — confirm).
            capacity = np.expand_dims(capacity, axis=1)
        if sample['boundary_values'].ndim < 2:
            sample['boundary_values'] = np.expand_dims(sample['boundary_values'], axis=0)

        flows = sample['flow_field_graph'] * capacity

        history_len = self.cT
        horizon_len = self.cT + self.Tau
        result = sample['epanet_result']
        x = result[:, :history_len]
        y = result[:, :horizon_len]

        traversal_times = np.nan_to_num(sample['traversal_times'], nan=self.Tau * self.dt)
        delay_steps = (traversal_times / self.dt).clip(max=horizon_len)

        # Normalization constants (as written by
        # RandomHydraulicsDataset.add_normalization_values), in fixed order.
        norm_keys = (
            'max_diameter', 'max_edge_capacity', 'max_edge_length',
            'max_vel', 'max_delay_steps', 'max_flow',
        )
        graph = torch_geometric.data.Data(
            x=to_float(x),
            y=to_float(y),
            edge_index=torch.as_tensor(sample['edge_index']),
            edge_lengths=to_float(sample['edge_lengths']),
            edge_diameter=to_float(sample['edge_diameter']),
            edge_capacity=to_float(sample['edge_capacity']),
            traversal_times=to_float(traversal_times),
            xs_map=to_float(sample['xs_map']),
            flows=to_float(flows),
            sl_mask=to_float(sample['selfloop_mask']),
            boundary_values=to_float(sample['boundary_values']),
            boundary_index=torch.as_tensor(sample['boundary_index']),
            delay_steps=to_float(delay_steps),
            flow_field=to_float(sample['flow_field_graph']),
            **{key: to_float(sample[key]) for key in norm_keys},
        )
        return graph.to(self.device)