import cobra
import logging
import networkx as nx
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.data import Data
from torch_geometric.utils.convert import *
from torch_geometric.utils import degree, one_hot
from networkx.algorithms.shortest_paths import *
from contrabass import CobraMetabolicModel
from enum import Enum
import csv
from os import listdir
from os.path import isfile, join
from torch_geometric.graphgym.config import cfg


# Only let CRITICAL-level records from the 'cobra' logger through.
logging.getLogger('cobra').setLevel(logging.CRITICAL)
# Select 'glpk' as the solver in cobra's configuration object.
cobra.core.model.configuration.solver = 'glpk'
# Tolerance below which a reaction bound is treated as zero in this module.
CONST_EPSILON = 1e-08


class Direction(Enum):
    """Direction a reaction may carry flux in, as classified by reaction_direction()."""
    # Lower bound is >= 0 (after snapping near-zero bounds to 0); this also
    # covers reactions whose two bounds are both zero.
    FORWARD = 0
    # Lower bound is negative and upper bound is not positive.
    BACKWARD = 1
    # Lower bound is negative and upper bound is positive.
    REVERSIBLE = 2
    
    
def reaction_direction(reaction):
    """Classify a reaction as FORWARD, BACKWARD or REVERSIBLE from its bounds.

    Bounds lying strictly inside (-CONST_EPSILON, CONST_EPSILON) are treated
    as exactly zero before the classification.

    :param reaction: object exposing ``lower_bound`` and ``upper_bound``.
    :return: a ``Direction`` member.
    """
    upper, lower = (
        0 if -CONST_EPSILON < bound < CONST_EPSILON else bound
        for bound in (reaction.upper_bound, reaction.lower_bound)
    )

    # A non-negative lower bound means FORWARD. This also covers the case where
    # both bounds are zero, which is considered FORWARD by default.
    if lower >= 0:
        return Direction.FORWARD
    # From here on the lower bound is negative (or not comparable).
    if upper > 0:
        return Direction.REVERSIBLE
    return Direction.BACKWARD

    
def is_dead_reaction(reaction):
    """Tell whether both bounds of ``reaction`` are within CONST_EPSILON of zero."""
    upper_is_zero = abs(reaction.upper_bound) < CONST_EPSILON
    # The lower bound is only inspected when the upper bound is already ~0.
    return upper_is_zero and abs(reaction.lower_bound) < CONST_EPSILON


def blocked_reactions(cobra_model):
    """Return the result of cobra's ``find_blocked_reactions`` for ``cobra_model``.

    The result is passed unchanged to the caller; within this module it is
    handed to ``remove_reactions``.
    """
    find_blocked = cobra.flux_analysis.find_blocked_reactions
    return find_blocked(cobra_model)


def remove_blocked_reactions(cobra_model):
    """Drop blocked reactions, then metabolites left in at most one reaction.

    The model is modified in place and also returned for chaining.

    :param cobra_model: model exposing ``remove_reactions``,
        ``remove_metabolites`` and ``metabolites``.
    :return: the same ``cobra_model`` object.
    """
    cobra_model.remove_reactions(blocked_reactions(cobra_model))

    # Collect first, remove afterwards, so the metabolite collection is not
    # modified while it is being iterated.
    sparse_metabolites = [
        metabolite
        for metabolite in cobra_model.metabolites
        if len(metabolite.reactions) <= 1
    ]
    cobra_model.remove_metabolites(sparse_metabolites)
    return cobra_model


def _get_biomass_reaction(cobra_model):
    """Return the first reaction appearing in the model's linear objective.

    NOTE: only one objective reaction is expected; any further ones are
    ignored. Raises ``IndexError`` when the objective has no reactions.
    """
    objective_terms = cobra.util.solver.linear_reaction_coefficients(cobra_model)
    objective_reactions = list(objective_terms.keys())
    return objective_reactions[0]
    

def get_largetst_connected_components(model):
    """Return the node ids of the largest connected component of the network.

    The network is a graph with one node per metabolite id and one node per
    reaction id, and an edge between a reaction and each of its metabolites.
    Edge direction is discarded before the components are computed, so these
    are plain (undirected) connected components.

    :param model: object exposing ``reactions``; each reaction exposes ``id``
        and ``metabolites``.
    :return: list of metabolite and reaction ids in the largest component.
    """
    directed = nx.DiGraph()
    directed.add_edges_from(
        (metabolite.id, reaction.id)
        for reaction in model.reactions
        for metabolite in reaction.metabolites
    )

    # Ignore edge direction when looking for connected components.
    undirected = directed.to_undirected()
    components = nx.connected_components(undirected)
    return list(max(components, key=len))

            
def get_largest_connected_model(model):
    """Restrict ``model`` to its largest connected component.

    Reactions and metabolites whose id is not part of the largest connected
    component are removed. The model is modified in place and also returned.

    :param model: model exposing ``reactions``, ``metabolites``,
        ``remove_reactions`` and ``remove_metabolites``.
    :return: the same ``model`` object.
    """
    # A set gives constant-time membership tests; testing against the returned
    # list would cost a full scan for every reaction and metabolite.
    kept_ids = set(get_largetst_connected_components(model))

    # Build both removal lists before mutating the model.
    reactions_to_remove = [
        reaction for reaction in model.reactions if reaction.id not in kept_ids
    ]
    metabolites_to_remove = [
        metabolite
        for metabolite in model.metabolites
        if metabolite.id not in kept_ids
    ]

    model.remove_metabolites(metabolites_to_remove)
    model.remove_reactions(reactions_to_remove)
    return model


def open_reactions_flux(model):
    """Relax bounds that force a non-zero flux through a reaction.

    For every reaction except the objective one, a strictly positive lower
    bound or a strictly negative upper bound is reset to 0.0. The number of
    modified reactions is printed. The model is modified in place.

    :param model: model exposing ``reactions``.
    :return: the same ``model`` object.
    """
    objective_reaction = _get_biomass_reaction(model)
    opened = 0
    for reaction in model.reactions:
        if reaction == objective_reaction:
            # The objective reaction keeps its bounds.
            continue
        if reaction.lower_bound > 0.0:
            reaction.lower_bound = 0.0
        elif reaction.upper_bound < 0.0:
            reaction.upper_bound = 0.0
        else:
            # Zero flux is already allowed: nothing to open.
            continue
        opened += 1
    print(f"Opened {opened} reactions")
    return model


def preprocess_model(model):
    """Run the full clean-up pipeline on ``model`` and return the result.

    :param model: model accepted by each of the pipeline steps.
    :return: the model returned by the last step.
    """
    pipeline = (
        # Keep only the largest graph first, to remove sparse nodes.
        get_largest_connected_model,
        open_reactions_flux,
        remove_blocked_reactions,
        # Removing blocked reactions may split off smaller graphs, so the
        # largest component has to be extracted again.
        get_largest_connected_model,
    )
    for step in pipeline:
        model = step(model)
    return model


def edges_from_model(model, reactions_list, dictionary_reactions, dictionary_metabolites):
    """Build the directed metabolite/reaction edge lists of the graph.

    For each reaction, one edge is emitted per participating metabolite:

    - FORWARD: reactants -> reaction node, reaction node -> products.
    - BACKWARD: products -> reaction node, reaction node -> reactants.
    - REVERSIBLE: the FORWARD edges on the reaction node, preceded by the
      mirrored edges on the node registered under ``reverse_id``.

    :param model: unused; kept for interface compatibility.
    :param reactions_list: reactions to turn into edges.
    :param dictionary_reactions: maps reaction ids (and ``reverse_id`` of
        reversible reactions) to node indices.
    :param dictionary_metabolites: maps metabolite ids to node indices.
    :return: ``(source, target, edge_weight)``, three parallel lists; each
        entry of ``edge_weight`` is a one-element list holding
        ``abs(reaction.get_coefficient(metabolite))``.
    """
    source = []
    target = []
    edge_weight = []

    def add_edges(reaction, metabolites, r_id, into_reaction):
        # Append one edge per metabolite, pointing into or out of node r_id.
        for m in metabolites:
            m_id = dictionary_metabolites[m.id]
            if into_reaction:
                source.append(m_id)
                target.append(r_id)
            else:
                source.append(r_id)
                target.append(m_id)
            edge_weight.append([abs(reaction.get_coefficient(m))])

    for r in reactions_list:
        # The bounds do not change inside the loop, so classify only once.
        direction = reaction_direction(r)

        # If a reaction is reversible we first add the reverse reaction edges.
        if direction == Direction.REVERSIBLE:
            reverse_id = dictionary_reactions[r.reverse_id]
            add_edges(r, r.reactants, reverse_id, into_reaction=False)
            add_edges(r, r.products, reverse_id, into_reaction=True)

        r_id = dictionary_reactions[r.id]
        if direction == Direction.FORWARD or direction == Direction.REVERSIBLE:
            consumed, produced = r.reactants, r.products
        else:
            # Backward-only reaction: the roles of both sides are swapped.
            consumed, produced = r.products, r.reactants
        add_edges(r, consumed, r_id, into_reaction=True)
        add_edges(r, produced, r_id, into_reaction=False)

    return source, target, edge_weight


def generate_pyg_data(model):
    """Convert a metabolic model into a PyTorch Geometric ``Data`` graph.

    Node indices are assigned to all metabolites first, then to reactions in
    model order; a reversible reaction takes two consecutive indices (its own
    id, then its ``reverse_id``).

    :param model: model exposing ``reactions``, ``metabolites`` and
        ``exchanges``.
    :return: tuple ``(data, dictionary_reactions, dictionary_metabolites,
        dictionary_aggr_rxn)`` where the dictionaries map ids to node indices
        and ``dictionary_aggr_rxn`` numbers reactions with each reversible
        reaction counted once.
    """
    # Layout of the 5-entry feature row of every node:
    #   0: is objective reaction (biomass)
    #   1: is exchange reaction (0=no, 1=yes)
    #   2: is reaction node (0=no, 1=yes)
    #   3: is metabolite node (0=no, 1=yes)
    #   4: is a reversible reaction (0=no, 1=yes)
    BIOMASS_INDEX = 0
    EXCHANGE_INDEX = 1

    reactions_list = list(model.reactions)
    metabolites_list = list(model.metabolites)
    n_metabolites = len(metabolites_list)
    reversible_flags = [
        reaction_direction(rxn) == Direction.REVERSIBLE for rxn in reactions_list
    ]

    # --- node index assignment -------------------------------------------
    dictionary_metabolites = {
        met.id: node for node, met in enumerate(metabolites_list)
    }
    dictionary_reactions = {}
    next_node = n_metabolites
    for rxn, is_reversible in zip(reactions_list, reversible_flags):
        dictionary_reactions[rxn.id] = next_node
        next_node += 1
        if is_reversible:
            # The reverse reaction gets a node of its own.
            dictionary_reactions[rxn.reverse_id] = next_node
            next_node += 1

    # --- per-node attributes: metabolite nodes ---------------------------
    features = [[0, 0, 0, 1, 0] for _ in metabolites_list]  # node features
    y = [[0] for _ in metabolites_list]  # node labels
    reactions_mask = [False] * n_metabolites  # True on reaction nodes only
    index_aggr = [0] * n_metabolites  # partner node of reversible reactions

    # --- per-node attributes: reaction nodes -----------------------------
    # Like dictionary_reactions but with reversible reactions collapsed.
    dictionary_aggr_rxn = {}
    collapsed_id = n_metabolites
    for rxn, is_reversible in zip(reactions_list, reversible_flags):
        dictionary_aggr_rxn[rxn.id] = collapsed_id
        collapsed_id += 1
        if is_reversible:
            # Two nodes, each pointing at the index of the other one.
            features.extend(([0, 0, 1, 0, 1], [0, 0, 1, 0, 1]))
            index_aggr.extend(
                (dictionary_reactions[rxn.reverse_id], dictionary_reactions[rxn.id])
            )
            y.extend(([0], [0]))
            reactions_mask.extend((True, True))
        else:
            features.append([0, 0, 1, 0, 0])
            index_aggr.append(0)
            y.append([0])
            reactions_mask.append(True)

    # The length of aggregated reactions must match the output vector.
    assert n_metabolites + len(reactions_list) == collapsed_id

    # --- flags -------------------------------------------------------------
    # Only the node registered under the biomass reaction's own id is flagged.
    # TODO: what if biomass reversible?
    biomass = _get_biomass_reaction(model)
    features[dictionary_reactions[biomass.id]][BIOMASS_INDEX] = 1

    for exchange in model.exchanges:
        features[dictionary_reactions[exchange.id]][EXCHANGE_INDEX] = 1
        if reaction_direction(exchange) == Direction.REVERSIBLE:
            features[dictionary_reactions[exchange.reverse_id]][EXCHANGE_INDEX] = 1

    # --- edges and final object -------------------------------------------
    source, target, edge_weight = edges_from_model(
        model,
        reactions_list,
        dictionary_reactions,
        dictionary_metabolites,
    )
    edge_index = torch.tensor([source, target], dtype=torch.long)
    edge_weight = torch.tensor(edge_weight, dtype=torch.float)

    # The same weight tensor is exposed under both edge_attr and edge_weight.
    data = Data(
        x=torch.tensor(features, dtype=torch.float),
        y=torch.tensor(y, dtype=torch.uint8),
        edge_index=edge_index,
        edge_attr=edge_weight,
        edge_weight=edge_weight,
        reactions_mask=torch.tensor(reactions_mask).type(torch.bool).squeeze(),
        index_aggr=torch.tensor(index_aggr).squeeze(),
    )
    return data, dictionary_reactions, dictionary_metabolites, dictionary_aggr_rxn

    
def add_labels(data, bass_model, dictionary_reactions):
    """Set ``data.y`` to 1 on the nodes of the essential reactions.

    :param data: graph object whose ``y`` is indexable by node index.
    :param bass_model: object exposing ``compute_essential_reactions()`` and
        ``essential_reactions()``.
    :param dictionary_reactions: maps reaction ids (and ``reverse_id`` of
        reversible reactions) to node indices.
    :return: tuple ``(data, er_ids)`` with the ids of the essential reactions.
    """
    bass_model.compute_essential_reactions()
    essential = bass_model.essential_reactions()
    er_ids = [reaction.id for reaction in essential]

    for reaction in essential:
        data.y[dictionary_reactions[reaction.id]] = 1.0
        if reaction_direction(reaction) != Direction.REVERSIBLE:
            continue
        # A reversible reaction is labelled on its reverse node as well.
        data.y[dictionary_reactions[reaction.reverse_id]] = 1.0

    return data, er_ids


def generate_connection_vector(data, nxGraph):
    """Flag every node that can be reached from the objective node.

    The objective node is the single node whose first feature equals 1.

    :param data: graph object exposing ``x`` and ``num_nodes``.
    :param nxGraph: networkx directed graph whose nodes are the integers
        ``0 .. data.num_nodes - 1``.
    :return: float tensor of shape ``(data.num_nodes, 1)`` holding 1 where a
        path from the objective node exists, 0 elsewhere.
    :raises: whatever ``Tensor.item()`` raises when the number of nodes
        flagged as objective is not exactly one.
    """
    obj_id = (data.x[:, 0] == 1).nonzero(as_tuple=True)[0].item()

    # One traversal from the objective node replaces a separate path search
    # per node. A node always has a path to itself, so the objective node is
    # added explicitly (descendants() leaves the source out).
    reachable = nx.descendants(nxGraph, obj_id)
    reachable.add(obj_id)

    con = [1 if node in reachable else 0 for node in range(data.num_nodes)]
    return torch.tensor(con, dtype=torch.float).view(-1, 1)


def compute_degree(data, in_degree=True):
    """Return the in- or out-degree of every node as a column vector.

    :param data: graph object exposing ``edge_index`` and ``num_nodes``.
    :param in_degree: count edge targets (row 1 of ``edge_index``) when true,
        edge sources (row 0) otherwise.
    :return: float tensor of shape ``(data.num_nodes, 1)``.
    """
    endpoint_row = 1 if in_degree else 0
    deg = degree(data.edge_index[endpoint_row], data.num_nodes, dtype=torch.float)
    # degree() yields one value per node, so the result is always reshaped to
    # a column; there is no case in which the node features should be
    # returned in its place.
    return deg.view(-1, 1)


def add_network_features(data):
    """Append four structural columns to ``data.x`` and return ``data``.

    The appended columns are, in order: in-degree equals 1, out-degree equals
    1, reachable from the objective node, reachable from the objective node
    in the reversed graph.

    :param data: graph object exposing ``x``, ``edge_index`` and ``num_nodes``.
    :return: the same ``data`` object with the widened ``x``.
    """
    graph = to_networkx(data)
    reachable = generate_connection_vector(data, graph)
    reachable_reversed = generate_connection_vector(data, graph.reverse())

    # Only whether a degree is exactly one is kept, not the degree itself.
    single_in = (compute_degree(data, in_degree=True) == 1).long()
    single_out = (compute_degree(data, in_degree=False) == 1).long()

    extra_columns = torch.cat(
        [single_in, single_out, reachable, reachable_reversed], dim=-1
    )
    data.x = torch.cat([data.x, extra_columns], dim=-1)
    return data


def generate_model_data(MODEL, BIOMASS):
    bass_model = CobraMetabolicModel(MODEL)
    model = bass_model.model()
    model.objective = BIOMASS
    data, dictionary_reactions, dictionary_metabolites = generate_pyg_data(model)
    data, er_ids = add_labels(data, bass_model, dictionary_reactions)
    data = add_network_features(data)