from abc import ABC, abstractmethod
from collections import defaultdict

import jax
import numpy as np
from scipy.sparse import csc_matrix, lil_matrix, triu

from fast_marl import FastMARLEnv
from graphexs import separable_power_law
from utils import get_diophantine_solutions


class FastGraphexEnv(FastMARLEnv, ABC):
    """
    Models a mean-field MARL problem in discrete time on a sparse interaction graph.

    The graph is either sampled from a graphex on every ``reset`` (``dataset='graphex'``)
    or loaded once, in ``__init__``, from the edge list
    ``real_network_data/<dataset>/out.<dataset>`` (``dataset`` in ``REAL_DATASETS``).
    Subclasses provide the degree/neighbourhood-dependent transitions and rewards.

    Graph attributes maintained by this class:
        num_agents: number of vertices with at least one edge.
        adj_matrix: sparse (CSC) adjacency matrix restricted to those vertices.
        degrees: 1-D array of row sums of ``adj_matrix``.
        alphas: latent graphex positions of the kept vertices; ``None`` for real
            networks and before the first graphex ``reset``.
    """

    # Real-world networks that can be selected through ``dataset=...``.
    REAL_DATASETS = ("prosper-loans", "petster-friendships-dog", "soc-pokec-relationships", "livemocha",
                     "flickr-growth", "loc-brightkite_edges", "facebook-wosn-wall", "hyves", )

    def __init__(self, graphex, observation_space, action_space, time_steps, mu_0, dataset='graphex',
                 time_var: float = 10, graphex_cutoff: float = 500, **kwargs):
        """
        Args:
            graphex: Graphex used when ``dataset == 'graphex'``. ``"separable"`` selects
                ``separable_power_law``; a callable ``W(x_row, y_row)`` is used directly.
            observation_space, action_space, time_steps, mu_0, **kwargs: Forwarded to
                ``FastMARLEnv.__init__``.
            dataset: ``'graphex'`` to resample a graph on every ``reset``, or the name of a
                real-world network in ``REAL_DATASETS`` to load once here.
            time_var: Height of the Poisson sampling window; it only scales the expected
                number of sampled points.
            graphex_cutoff: Width of the sampling window, i.e. upper bound on latent positions.

        Raises:
            NotImplementedError: If ``dataset`` is neither ``'graphex'`` nor in ``REAL_DATASETS``.
        """
        self.graphex = graphex
        self.dataset = dataset
        self.graphex_cutoff = graphex_cutoff
        self.time_var = time_var

        super().__init__(observation_space, action_space, time_steps, mu_0, **kwargs)

        self.alphas = None
        self.adj_matrix = None
        self.degrees = None

        if self.dataset != 'graphex':
            self._load_real_network(self.dataset)

    def _load_real_network(self, dataset):
        """Load a real-world edge list as an undirected graph and store it on the env.

        Sets ``num_agents``, ``adj_matrix`` (vertices without edges removed), ``degrees``
        and resets ``alphas`` to ``None``.

        Raises:
            NotImplementedError: If ``dataset`` is not in ``REAL_DATASETS``.
            ValueError: If the file contains no edges.
        """
        if dataset not in self.REAL_DATASETS:
            raise NotImplementedError

        path = f'real_network_data/{dataset}/out.{dataset}'
        adjacency_list = defaultdict(set)
        with open(path, 'r') as file:
            for line in file:
                # Skip comment/header lines
                if line.startswith('%'):
                    continue
                # split() copes with tabs (flickr, soc-pokec) and repeated spaces alike;
                # extra columns (e.g. weights, timestamps) are ignored.
                parts = line.split()
                if not parts:
                    continue
                vertex1, vertex2 = int(parts[0]), int(parts[1])
                # Store both directions: the graph is treated as undirected. Sets make
                # repeated edges a no-op.
                adjacency_list[vertex1].add(vertex2)
                adjacency_list[vertex2].add(vertex1)

        if not adjacency_list:
            raise ValueError(f"No edges found in {path}")

        max_vertex = max(adjacency_list)

        # Coordinate lists for the CSC constructor; rows/cols stay aligned because both
        # are filled in the same pass over the adjacency list.
        rows, cols = [], []
        for vertex, neighbors in adjacency_list.items():
            rows.extend([vertex] * len(neighbors))
            cols.extend(neighbors)

        adj_matrix = csc_matrix((np.ones(len(rows)), (rows, cols)), shape=(max_vertex + 1, max_vertex + 1))

        print(f"Loaded {max_vertex + 1} vertices", flush=True)

        # Vertex ids need not be contiguous. The matrix is symmetric, so a vertex has an
        # empty row iff it has an empty column; dropping the same index set on both axes
        # removes unused ids.
        keep = np.unique(adj_matrix.nonzero()[0])
        adj_matrix_clean = adj_matrix[keep][:, keep]

        print(f"Without zero degrees: {adj_matrix_clean.shape[0]} vertices", flush=True)

        self.num_agents = adj_matrix_clean.shape[0]
        self.alphas = None
        self.adj_matrix = adj_matrix_clean
        self.degrees = np.asarray(adj_matrix_clean.sum(axis=1)).ravel()

    def reset(self, seed=None):
        """Reset the environment, resampling the graph first when ``dataset == 'graphex'``.

        Args:
            seed: Optional seed. It is reduced modulo 2**32 and applied to NumPy's global
                RNG, which drives all graph sampling.

        Returns:
            Whatever ``FastMARLEnv.reset()`` returns.
        """
        if seed is not None:
            np.random.seed(seed % (2**32))

        if self.dataset == 'graphex':
            self._sample_graphex_graph()

        return super().reset()

    def _graphex_function(self):
        """Return the vmappable edge-probability function selected by ``self.graphex``.

        Raises:
            NotImplementedError: If ``self.graphex`` is neither ``"separable"`` nor callable.
        """
        if callable(self.graphex):
            return self.graphex
        if self.graphex == "separable":
            # here we use arbitrary sigma due to separability, otherwise insert estimated sigmas
            return separable_power_law
        raise NotImplementedError(f"Unsupported graphex: {self.graphex!r}")

    def _sample_graphex_graph(self, batch_size=1000):
        """Sample a graph from the graphex and store it on the environment.

        Latent positions come from a unit-intensity Poisson process on
        ``[0, graphex_cutoff] x [0, time_var]`` (only the first coordinate is kept). Each
        pair ``i < j`` is linked independently with probability ``W(x_i, x_j)``. Vertices
        without edges are dropped. Sets ``num_agents``, ``alphas``, ``adj_matrix`` and
        ``degrees``.

        Args:
            batch_size: Number of rows of the pairwise probability matrix evaluated at once.
        """
        graphex_func = self._graphex_function()

        # Poisson point process on the sampling window
        intensity = 1  # mean density of the Poisson process
        area = self.graphex_cutoff * self.time_var
        num_points = np.random.poisson(intensity * area)
        xx = self.graphex_cutoff * np.random.uniform(0, 1, num_points)

        edge_prob_fn = jax.vmap(graphex_func)
        # Start with empty arrays so concatenation also works when num_points == 0.
        rows = [np.empty(0, dtype=np.intp)]
        cols = [np.empty(0, dtype=np.intp)]

        for start in range(0, num_points, batch_size):
            end = min(start + batch_size, num_points)
            # x_batch[i, j] = xx[start + i], y_batch[i, j] = xx[j]
            x_batch = np.repeat(xx[start:end, np.newaxis], num_points, axis=1)
            y_batch = np.repeat(xx[np.newaxis, :], end - start, axis=0)
            edge_prob_batch = np.asarray(edge_prob_fn(x_batch, y_batch))

            # Fresh uniforms for every batch from the (seeded) NumPy RNG, so edge draws are
            # independent across batches and across resets.
            linked = np.random.uniform(size=(end - start, num_points)) < edge_prob_batch
            batch_rows, batch_cols = np.nonzero(linked)
            batch_rows = batch_rows + start

            # Keep only the strict upper triangle; it is mirrored below.
            upper = batch_cols > batch_rows
            rows.append(batch_rows[upper])
            cols.append(batch_cols[upper])

        rows = np.concatenate(rows)
        cols = np.concatenate(cols)
        upper_adj = csc_matrix((np.ones(rows.size), (rows, cols)), shape=(num_points, num_points))
        # Make the adjacency matrix symmetric
        adj_matrix = (upper_adj + upper_adj.T).tocsr()

        degrees = np.asarray(adj_matrix.sum(axis=1)).ravel()
        linked_nodes = degrees != 0
        adj_matrix = csc_matrix(adj_matrix[linked_nodes, :][:, linked_nodes])

        self.num_agents = np.count_nonzero(linked_nodes)
        self.alphas = xx[linked_nodes]
        self.adj_matrix = adj_matrix
        self.degrees = degrees[linked_nodes]

        print(f"Generated {self.num_agents} vertices", flush=True)

    def get_P_k(self, t, k, prob_Gs):
        """Mix ``get_P_k_G`` over neighbourhood count vectors for degree ``k``.

        Args:
            t: Time step, forwarded to ``get_P_k_G``.
            k: Degree; each count vector ``g`` is normalised to ``g / k`` (so ``k`` must be
                non-zero).
            prob_Gs: Weight of the n-th vector yielded by ``get_diophantine_solutions``
                (presumably all count vectors over states summing to ``k`` — see utils).

        Returns:
            Array of shape (|U|, |X|, |X|): weighted sum of transition matrices per action.
        """
        num_states = self.observation_space.n
        P = np.zeros((self.action_space.n, num_states, num_states))

        for n, g in enumerate(get_diophantine_solutions(k, num_states)):
            P += prob_Gs[n] * self.get_P_k_G(t, k, np.array(g) / k)

        return P

    def get_P_k_conditional(self, t, k, prob_Gs):
        """Like ``get_P_k`` but with neighbourhood weights conditioned on the own state.

        Args:
            t: Time step, forwarded to ``get_P_k_G``.
            k: Degree (non-zero; used to normalise count vectors).
            prob_Gs: ``prob_Gs[x][n]`` weights the n-th count vector for agents in state ``x``.

        Returns:
            Array of shape (|U|, |X|, |X|) whose slice ``[:, x, :]`` mixes the rows for
            state ``x`` with weights ``prob_Gs[x]``.
        """
        num_states = self.observation_space.n
        P = np.zeros((self.action_space.n, num_states, num_states))

        # Evaluate get_P_k_G once per count vector (instead of once per state and count
        # vector); each P[:, x, :] still accumulates its terms in the same order.
        for n, g in enumerate(get_diophantine_solutions(k, num_states)):
            P_k_G = self.get_P_k_G(t, k, np.array(g) / k)
            for x in range(num_states):
                P[:, x, :] += prob_Gs[x][n] * P_k_G[:, x, :]

        return P

    def get_R_k(self, t, k, prob_Gs):
        """Mix ``get_R_k_G`` over neighbourhood count vectors for degree ``k``.

        Args:
            t: Time step, forwarded to ``get_R_k_G``.
            k: Degree (non-zero; used to normalise count vectors).
            prob_Gs: Weight of the n-th vector yielded by ``get_diophantine_solutions``.

        Returns:
            Array of shape (|X|, |U|) of expected rewards.
        """
        num_states = self.observation_space.n
        R = np.zeros((num_states, self.action_space.n))

        for n, g in enumerate(get_diophantine_solutions(k, num_states)):
            R += prob_Gs[n] * self.get_R_k_G(t, k, np.array(g) / k)

        return R

    @abstractmethod
    def get_P_k_G(self, t, k, G):
        """Return joint transition matrices over actions U on X for a given degree and neighborhood."""

    @abstractmethod
    def get_R_k_G(self, t, k, G):
        """Return array X x U of expected rewards for a given degree and neighborhood."""

    @abstractmethod
    def get_P_high(self, t, mu):
        """Return joint transition matrices over actions U on X for high degrees and given mf."""

    @abstractmethod
    def get_R_high(self, t, mu):
        """Return array X x U of expected rewards for high degrees and given mf."""
