import numpy as np
import torch
from torch_scatter import scatter_min
from tqdm.auto import tqdm

# Module-level debug flag. Nothing in this module chunk reads it; presumably
# consulted elsewhere — verify before removing or changing.
DEBUG = True

class AdvectionModelMP(torch.nn.Module):
    """Advect node signals forward in time by repeated message passing.

    Each outer step pads the node signal with ``Tau`` new time columns and
    runs up to ``max_msg_passing_rounds`` rounds of ``mp_layer``. The outputs
    of all rounds are summed into the extended signal. An edge stops taking
    part in the rounds once the delay accumulated for it by
    :meth:`aggregated_time` exceeds ``Tau + 1`` in every time column.
    """

    def __init__(
            self, mp_layer, mask_op, max_msg_passing_rounds=25, constant_history=False,
            progress=True, adaptive_steps=False, **kwargs
        ):
        """
        Args:
            mp_layer: callable invoked once per round as
                ``mp_layer(agg, edge_index, flows, warp_map, sl_mask,
                edge_mask, xs_map=..., **kwargs)`` and returning the new
                ``agg`` tensor.
            mask_op: stored as ``self.mask_op``; not used within this class.
            max_msg_passing_rounds: upper bound on rounds per outer step.
                More rounds are required in areas of low flow to reach high
                accuracy.
            constant_history: if True, keep a fixed-length window of time
                columns between outer steps instead of growing the signal.
            progress: if True, show a tqdm progress bar per outer step.
            adaptive_steps: if True, recompute ``Tau`` at every outer step
                from the smallest absolute delay.
            **kwargs: forwarded to ``torch.nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        self.max_msg_passing_rounds = max_msg_passing_rounds
        self.mask_op = mask_op
        self.advection_mpnn = mp_layer
        self.constant_history = constant_history
        self.progress = progress
        self.adaptive_steps = adaptive_steps
        self.debug_info = {}

    def aggregated_time(self, n_nodes, Tau, edge_mask, edge_index, delay_steps, agg_time=None):
        """Accumulate per-edge delay ("time") across message-passing rounds.

        First call (``agg_time is None``):
          - every masked edge starts at ``|delay_steps|``;
          - node values start at ``Tau + 1`` and are reduced with
            ``scatter_min``. Positive-delay edges contribute to their
            receiver; non-positive-delay edges contribute ``-delay`` to
            their sender.

        Later calls:
          - the previous node values are pushed one hop along the masked
            edges, with the direction chosen by the delay sign;
          - the results are added to the running per-edge totals.

        Args:
            n_nodes: number of nodes in the graph.
            Tau: number of time columns added per outer step.
            edge_mask: boolean mask over all edges; only masked edges are updated.
            edge_index: ``(2, n_edges)`` tensor of (sender, receiver) indices.
            delay_steps: per-edge delays, 2-D with time columns along dim 1.
            agg_time: ``None`` or the ``(agg_time, node_time)`` tuple
                returned by the previous call.

        Returns:
            ``(agg_time, node_time)`` with shapes ``(n_edges, n_time)`` and
            ``(n_nodes, n_time)``. The incoming ``agg_time`` tensor is NOT
            modified, so tuples returned by successive calls stay independent.
        """
        snd, rec = edge_index[:, edge_mask]
        time = delay_steps[edge_mask]
        n_edges = len(edge_mask)
        n_time = time.shape[1]
        device = time.device

        if agg_time is None:
            agg_time = torch.zeros((n_edges, n_time), device=device)
            agg_time[edge_mask] = time.abs()
            # Initialise with the maximum delay; nodes no edge reaches keep it
            # (e.g. root nodes such as reservoirs). The 1e10 offset removes
            # edges of the "wrong" sign from each min-reduction.
            node_time = torch.zeros((n_nodes, n_time), device=device) + Tau + 1
            node_time, _ = scatter_min(time + (time <= 0) * 1e10, rec, dim=0, out=node_time)
            node_time, _ = scatter_min(-time + (time > 0) * 1e10, snd, dim=0, out=node_time)
        else:
            prev_agg_time, prev_node_time = agg_time
            # Copy instead of updating in place: callers keep earlier results
            # (e.g. as a per-round history), which must not be overwritten.
            agg_time = prev_agg_time.clone()

            # NOTE(review): the sign convention here (time < 0 -> receiver)
            # is the opposite of the first-call branch (time > 0 -> receiver).
            # Kept as-is; confirm this is intentional.
            new_node_time = torch.zeros_like(prev_node_time) + Tau + 1
            new_node_time, _ = scatter_min(
                prev_node_time[snd] + (time >= 0) * 1e10, rec, dim=0, out=new_node_time
            )
            # Both reductions read the previous round's node values.
            node_time, _ = scatter_min(
                prev_node_time[rec] + (time < 0) * 1e10, snd, dim=0, out=new_node_time
            )
            agg_time[edge_mask] += node_time[snd] * (time >= 0)
            agg_time[edge_mask] += node_time[rec] * (time < 0)

        return (agg_time, node_time)

    def create_initial_edge_mask(self, edge_index):
        """Return an all-True boolean mask with one entry per edge."""
        return torch.ones(edge_index.shape[1], dtype=torch.bool, device=edge_index.device)

    def set_flow_field(self, flows, delay_steps):
        """Store the flow field; delays are stored negated."""
        self.flows = flows
        self.delay_steps = -delay_steps

    def set_step_flow_params(self, step):
        """Cache the first ``step`` columns (dim 1) of the per-edge arrays."""
        self._step_flows = self.flows[:, :step]
        self._step_delay_steps = self.delay_steps[:, :step]
        self._step_sl_mask = self.sl_mask[:, :step]
        self._step_xs_map = None
        if self.xs_map is not None:
            self._step_xs_map = self.xs_map[:, :step]

    def set_sl_mask(self, sl_mask):
        """Store the ``sl_mask`` passed on to the message-passing layer."""
        self.sl_mask = sl_mask

    def forward(
            self, x, edge_index, flows, delay_steps, sl_mask, Tau, n_steps=1, boundary_values=None, boundary_index=None, **kwargs
        ):
        """Advance ``x`` by ``n_steps`` outer steps of ``Tau`` time columns each.

        Args:
            x: node signal. The padding below extends dim -2, and ``x.shape[:2]``
               is read as ``(n_nodes, time)``. So this presumably expects a 3-D
               ``(n_nodes, time, channels)`` tensor — confirm with callers.
            edge_index: ``(2, n_edges)`` sender/receiver indices.
            flows, delay_steps, sl_mask: per-edge arrays with time along
               dim 1; ``delay_steps`` is stored negated.
            Tau: time columns added per outer step (recomputed when
               ``adaptive_steps`` is set).
            n_steps: number of outer steps.
            boundary_values, boundary_index: optional Dirichlet data. In this
               method, row ``k`` of ``boundary_values`` is written to node
               ``boundary_index[k]``.
            **kwargs: optional ``xs_map`` (per-edge array or None); the rest is
               forwarded to the message-passing layer.

        Returns:
            ``(xT, edge_passes, edges_active, agg_times, aggs_all)``:
              - ``xT``: the final signal;
              - ``edge_passes``, ``edges_active``: placeholders (never updated);
              - ``agg_times``: every round's ``aggregated_time`` result;
              - ``aggs_all``: the detached per-round signals of the last
                outer step.

        Raises:
            ValueError: if ``Tau`` drops below 1 while steps remain, since no
                progress could be made.
        """
        xT = x.clone()
        n_nodes, cT = x.shape[:2]
        edge_passes = 0
        edges_active = []
        agg_times = []
        aggs_all = []
        self.node_time_prev = None

        self.set_flow_field(flows, delay_steps)
        self.set_sl_mask(sl_mask)
        # xs_map is optional; set_step_flow_params already handles None.
        self.xs_map = kwargs.pop('xs_map', None)

        # Progress is tracked explicitly: in constant-history mode x keeps a
        # fixed length, so x.shape[1] cannot be used as the loop condition.
        produced_steps = cT
        window_start = 0  # absolute time index of x[:, 0] (constant history)
        output_steps = cT + Tau * n_steps

        while produced_steps < output_steps:
            if self.adaptive_steps:
                Tau = int(self.delay_steps.abs().min())
            if Tau < 1:
                raise ValueError(f"Tau must be >= 1 to make progress, got {Tau}")

            self.set_step_flow_params(cT + Tau)
            # NOTE(review): in constant-history mode the per-edge arrays are
            # still sliced from time 0, not from window_start — confirm.
            time_active = self.create_initial_edge_mask(edge_index)
            agg_time = None

            # Extend the time axis (dim -2) by Tau zero columns.
            agg1 = torch.nn.functional.pad(x, (0, 0, 0, Tau, 0, 0))
            aggs = agg1.clone()

            if boundary_values is not None:
                assert boundary_index is not None
                start_idx = window_start if self.constant_history else 0
                agg1[boundary_index] = boundary_values[:, start_idx:start_idx + agg1.shape[1]]

            # The 1e12 scale saturates the sigmoid, giving an (effectively)
            # binary mask along the time axis.
            total_ts = cT + Tau
            sigm_mask = torch.sigmoid(
                (
                    torch.linspace(-1, 1, total_ts, device=agg1.device)
                    - cT / total_ts + Tau / total_ts
                ) * 1e12
            ).unsqueeze(-1)

            aggs_all = [agg1.detach()]
            pbar = tqdm(total=100) if self.progress else None

            for i in range(self.max_msg_passing_rounds):
                if not time_active.any():
                    if pbar is not None:
                        pbar.set_description(f'Total iterations: {i}')
                    break

                warp_map = self._step_delay_steps
                agg1 = self.advection_mpnn(
                    agg1, edge_index, self._step_flows, warp_map, self._step_sl_mask,
                    time_active, xs_map=self._step_xs_map, **kwargs
                )
                # Per-edge accumulated delay, used to retire edges early.
                agg_time = self.aggregated_time(n_nodes, Tau, time_active, edge_index, warp_map, agg_time)

                # Dirichlet boundary: the value at injection nodes is known, so
                # the propagated contribution there is zeroed.
                # TODO: if mass is injected it should mix at the boundary too.
                if boundary_index is not None:
                    agg1[boundary_index] *= 0

                agg1 = agg1 * sigm_mask

                # Keep an edge active while any time column is still within Tau+1.
                time_active = torch.logical_and(
                    time_active, (agg_time[0].abs() <= (Tau + 1)).any(1)
                )

                if pbar is not None:
                    pbar.n = np.round((agg_time[0].abs() >= Tau).float().mean().item() * 100., decimals=2)
                    pbar.refresh()

                agg_times.append(agg_time)
                aggs_all.append(agg1.detach())
                aggs += agg1
                # TODO: count edge passes (fix shape issue): (agg_time[0] <= Tau).int()

            if pbar is not None:
                pbar.close()

            xT = aggs
            if boundary_values is not None:
                xT[boundary_index] = boundary_values[:, start_idx:start_idx + xT.shape[1]]

            if self.constant_history:
                # Slide the fixed-length window forward by the Tau new columns.
                x = xT[:, Tau:]
                window_start += Tau
            else:
                x = xT
                cT = cT + Tau
            produced_steps += Tau

        return xT, edge_passes, edges_active, agg_times, aggs_all
    
