import torch
import numpy as np
import modules
from model.advection_model_mp import AdvectionModelMP
from model.advection_layer import AdvectionLayer
import networkx as nx
from itertools import count

def semi_lagrangian_mpnn_1d(
    initial_state, flow_field, dx, L, output_times, dt, control_inputs=None, 
    control_indices=None, interpolation='bilinear', progress=True, mixing_at_nodes=True, iterate=False, adaptive_steps=False
):
    """Run the message-passing semi-Lagrangian solver on a 1-D chain of nodes.

    The domain of length ``L`` is split into ``ceil(L / dx)`` nodes (the same
    count ``np.arange(0, L, dx)`` produces), linked by edges ``i -> i + 1``
    that all have length ``dx``.

    Args:
        initial_state: per-node initial values, forwarded unchanged to
            ``semi_lagrangian_mpnn``.
        flow_field: flow data whose ``shape[1]`` is taken as the number of
            time steps; the solver is asked for one step fewer.
        dx: node spacing and edge length.
        L: domain length.
        output_times, dt, control_inputs, control_indices, interpolation,
        progress, mixing_at_nodes, iterate, adaptive_steps: forwarded to
            ``semi_lagrangian_mpnn``; the message-passing budget is fixed
            at 5000 rounds.

    Returns:
        The solver prediction converted to a NumPy array.
    """
    num_nodes = int(np.ceil(L / dx))
    num_edges = num_nodes - 1
    horizon = flow_field.shape[1]

    # Chain topology: row 0 holds the source node, row 1 the next node.
    sources = torch.arange(0, num_edges)
    targets = torch.arange(1, num_nodes)
    chain_index = torch.stack((sources, targets))
    spacing = torch.ones(num_edges) * dx

    outputs = semi_lagrangian_mpnn(
        initial_state,
        flow_field,
        chain_index,
        spacing,
        output_times,
        dt,
        horizon - 1,
        control_inputs=control_inputs,
        control_indices=control_indices,
        interpolation=interpolation,
        max_msg_passing_rounds=5000,
        progress=progress,
        mixing_at_nodes=mixing_at_nodes,
        iterate=iterate,
        adaptive_steps=adaptive_steps,
    )
    prediction, _, _, _, _ = outputs
    return prediction.numpy()

def semi_lagrangian_mpnn_2d(
    initial_state, flow_field, dx, dy, output_times, dt, control_inputs=None, 
    control_indices=None, interpolation='bilinear', progress=True, mixing_at_nodes=True, iterate=False, adaptive_steps=False,
    max_msg_passing_rounds=5000
):
    """Run the message-passing semi-Lagrangian solver on a regular 2-D grid.

    The grid becomes an undirected 4-neighbour graph. Node ids follow the
    C-order (row-major) flattening of the first two axes of
    ``initial_state``: node ``r * Nx + c`` is array cell ``(r, c)``.

    Args:
        initial_state: array whose first two axes are ``(Ny, Nx)``; it is
            reshaped to ``(Ny * Nx, -1)`` so every node gets one row.
        flow_field: array whose last axis has length ``nsteps``; it is
            reshaped to ``(-1, nsteps)`` and the row of each edge's source
            node id becomes that edge's flow.
        dx: spacing between neighbouring columns (array axis 1).
        dy: spacing between neighbouring rows (array axis 0).
        output_times, dt, control_inputs, control_indices, interpolation,
        progress, mixing_at_nodes, iterate, adaptive_steps,
        max_msg_passing_rounds: forwarded to ``semi_lagrangian_mpnn``.

    Returns:
        ``(result, edge_flow)``: the solver prediction as a NumPy array, and
        the per-edge flow field that was handed to the solver.
    """
    Ny, Nx = initial_state.shape[:2]
    nsteps = flow_field.shape[-1]

    initial_state = initial_state.reshape(Ny * Nx, -1)
    flow_field = flow_field.reshape(-1, nsteps)

    # grid_2d_graph(m, n) labels nodes (i, j) with i < m and j < n. Passing
    # (Ny, Nx) makes node (r, c) coincide with array cell (r, c), matching
    # the row-major flattening above. (Nx, Ny) scrambled neighbours whenever
    # Nx != Ny.
    G = nx.grid_2d_graph(Ny, Nx, create_using=nx.Graph)
    # (n_edges, 2 endpoints, (row, col)); the reshape also keeps a 1x1 grid
    # (no edges) well-formed.
    edges = np.asarray(list(G.edges), dtype=np.int64).reshape(-1, 2, 2)

    # Node ids and edge lengths come from the same edge list, so they are
    # aligned by construction.
    edge_index = np.ascontiguousarray((edges[..., 0] * Nx + edges[..., 1]).T)
    flow_field = flow_field[edge_index[0]]  # flow at each edge's source node
    # An edge that changes the row index is vertical (dy); otherwise it
    # changes the column index and is horizontal (dx).
    edge_lengths = np.where(edges[:, 0, 0] != edges[:, 1, 0], dy, dx)

    result, _, _, _, _ = semi_lagrangian_mpnn(
        initial_state, flow_field, edge_index, edge_lengths, output_times, dt, nsteps-initial_state.shape[-1],
        control_inputs=control_inputs, control_indices=control_indices,
        interpolation=interpolation, max_msg_passing_rounds=max_msg_passing_rounds, progress=progress,
        mixing_at_nodes=mixing_at_nodes, iterate=iterate, adaptive_steps=adaptive_steps
    )
    return result.numpy(), flow_field


def _build_default_model(interpolation, mixing_at_nodes, max_msg_passing_rounds, progress, adaptive_steps):
    """Assemble the default AdvectionModelMP used when no model is supplied."""
    advection = modules.AdvectionModuleGridSampleDynamic(interpolation_mode=interpolation)
    masking = modules.MaskingModuleSigmoid()
    masking.mask_temp = 0.00001
    return AdvectionModelMP(
        AdvectionLayer(advection, mixing_at_nodes),
        masking,
        max_msg_passing_rounds=max_msg_passing_rounds,
        progress=progress,
        adaptive_steps=adaptive_steps,
    )


def semi_lagrangian_mpnn(
    initial_state, flow_field, edge_index, edge_lengths, output_times, dt, nsteps, model=None,
    edge_capacities=1., control_inputs=None, control_indices=None, interpolation='bilinear',
    max_msg_passing_rounds=300, progress=True, mixing_at_nodes=True, iterate=False, adaptive_steps=False,
    device='cpu'
):
    """Run a message-passing advection model on an arbitrary edge graph.

    Inputs are converted by ``prepare_model_inputs``. When ``model`` is None,
    a default ``AdvectionModelMP`` is built from ``interpolation``,
    ``mixing_at_nodes``, ``max_msg_passing_rounds``, ``progress`` and
    ``adaptive_steps``. When a model is passed in, those five arguments are
    not used. The model is moved to ``device`` before it is called.

    Returns:
        The 5-tuple returned by the model, ``(pred, edge_passes, <third
        output>, agg_time, aggs_all)``. The leading time slots of ``pred``
        are overwritten with the prepared initial state.
    """
    inputs = prepare_model_inputs(
        initial_state, flow_field, edge_index, edge_lengths, output_times, dt,
        nsteps, edge_capacities, control_inputs, control_indices, iterate, device
    )

    if model is not None:
        solver = model
    else:
        solver = _build_default_model(
            interpolation, mixing_at_nodes, max_msg_passing_rounds, progress, adaptive_steps
        )
    solver.to(device)

    pred, edge_passes, extra, agg_time, aggs_all = solver(**inputs)

    # Make the returned trajectory start exactly from the given initial frames.
    seed = inputs['x']
    pred[:, :seed.shape[1]] = seed

    return pred, edge_passes, extra, agg_time, aggs_all


def prepare_model_inputs(
    initial_state, flow_field, edge_index, edge_lengths, output_times, dt, 
    nsteps, edge_capacities=1., control_inputs=None, control_indices=None, 
    iterate=False, device='cpu'
):
    """Convert raw simulation data into the keyword arguments of the solver model.

    Args:
        initial_state: per-node initial values. It is promoted to a 3-D tensor
            by appending trailing singleton axes: ``(n,) -> (n, 1, 1)`` and
            ``(n, t) -> (n, t, 1)``.
        flow_field: per-edge flow rows; only the first ``n_edges`` rows are used.
        edge_index: 2-D integer array whose second dimension is the edge count.
        edge_lengths: per-edge lengths, forwarded to
            ``modules.compute_backward_transit_times_fast``.
        output_times: only ``max(output_times)`` is used. It caps the
            traversal times at ``max(output_times) * dt`` before they are
            converted to steps.
        dt: time step.
        nsteps: number of solver steps. ``nsteps * dt`` also replaces NaN
            traversal times.
        edge_capacities: scalar or per-edge factor multiplied into the flows.
        control_inputs: optional boundary values.
        control_indices: node indices that ``control_inputs`` apply to.
            Boundary data is ignored unless this is given.
        iterate: if True the model gets ``Tau=1, n_steps=nsteps``; otherwise
            ``Tau=nsteps, n_steps=1``.
        device: torch device on which every returned tensor is placed.

    Returns:
        dict with keys ``x``, ``edge_index``, ``flows``, ``delay_steps``,
        ``xs_map``, ``sl_mask``, ``Tau``, ``n_steps``, ``boundary_values``
        and ``boundary_index``.
    """
    dtype = torch.get_default_dtype()

    def to_float_tensor(value):
        return torch.as_tensor(value, dtype=dtype).to(device)

    _, n_edges = edge_index.shape
    flow_field_graph = flow_field[:n_edges]

    traversal_times, selfloop_mask, xs_map = modules.compute_backward_transit_times_fast(
        edge_lengths, flow_field_graph, dt
    )
    # NaN transit times become the full horizon nsteps * dt. (np.nan_to_num
    # also maps +/-inf to the largest finite floats.)
    traversal_times = np.nan_to_num(traversal_times, nan=nsteps*dt)

    flow_field_graph = to_float_tensor(flow_field_graph)
    edge_capacities = to_float_tensor(edge_capacities)
    initial_state = to_float_tensor(initial_state)
    traversal_times = to_float_tensor(traversal_times)
    xs_map = to_float_tensor(xs_map)
    selfloop_mask = to_float_tensor(selfloop_mask)
    edge_index = torch.as_tensor(edge_index).to(device)

    if control_indices is not None:
        # Keep the boundary indices on `device`, like every other tensor handed
        # to the model.
        boundary_index = torch.tensor(control_indices).to(device)
        boundary_input = torch.tensor(control_inputs, dtype=dtype).to(device)
        if boundary_input.ndim < 2:
            boundary_input = boundary_input.unsqueeze(0)
        if boundary_input.ndim < 3:
            boundary_input = boundary_input.unsqueeze(-1)
    else:
        boundary_index = boundary_input = None

    if edge_capacities.ndim == 1:
        # Per-edge capacities -> (n_edges, 1) so they broadcast across each
        # edge's row of flow_field_graph.
        edge_capacities = edge_capacities.unsqueeze(1)
    if initial_state.ndim < 2:
        initial_state = initial_state.unsqueeze(-1)
    if initial_state.ndim < 3:
        initial_state = initial_state.unsqueeze(-1)

    # Cap delays at the last requested output time, then express them in steps.
    delay_steps = traversal_times.clamp(max=max(output_times)*dt) / dt

    if iterate:
        tau, n_steps = 1, nsteps
    else:
        tau, n_steps = nsteps, 1

    return {
        'x' : initial_state,
        'edge_index' : edge_index,
        'flows' : flow_field_graph * edge_capacities,
        'delay_steps' : delay_steps,
        'xs_map' : xs_map,
        'sl_mask' : selfloop_mask,
        'Tau' : tau,
        'n_steps' : n_steps,
        'boundary_values' : boundary_input,
        'boundary_index' : boundary_index,
    }