import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from abc import ABC, abstractmethod
import random

# ER: Erdos-Renyi random graph generators
class ErdosRenyiGraphGenerator:
    """Generate Erdos-Renyi G(n, p) random graphs and expose their adjacency matrices.

    Each call to ``getNewGraph`` samples a fresh graph with ``node_num`` nodes
    via ``nx.erdos_renyi_graph(node_num, p)`` and keeps it in ``self.graph``.
    """

    # Node-color palette used by ``draw``: integer color label -> matplotlib color name.
    _PALETTE = {
        0: 'red',
        1: 'green',
        2: 'blue',
        3: 'beige',
        4: 'yellow',
        5: 'pink',
        6: 'purple',
        7: 'brown',
        8: 'gray',
    }
    # Used for labels missing from _PALETTE so that ``draw`` always produces
    # exactly one color per node (a shorter list makes nx.draw fail).
    _FALLBACK_COLOR = 'cyan'

    def __init__(self, garph=None, node_num=20, p=0.15):
        """Store generation parameters and an optional starting graph.

        Args:
            garph: optional existing networkx graph to use as the current graph
                (the misspelled name is kept for backward compatibility).
            node_num: number of nodes in newly generated graphs.
            p: edge probability passed to ``nx.erdos_renyi_graph``.
        """
        self.node_num = node_num
        self.p = p
        self.graph = garph

    def get(self):
        """Return the adjacency matrix of the current graph.

        If no graph exists yet, a new one is generated first instead of
        failing inside networkx.
        """
        if self.graph is None:
            return self.getNewGraph()
        return nx.to_numpy_array(self.graph)

    def getNewGraph(self):
        """Sample a new graph, store it in ``self.graph`` and return its adjacency matrix."""
        self.graph = nx.erdos_renyi_graph(self.node_num, self.p)
        return nx.to_numpy_array(self.graph)

    def caculate(self):
        """Placeholder hook; currently does nothing (name kept for existing callers)."""
        pass

    @classmethod
    def _color_map(cls, colors):
        """Translate integer color labels into matplotlib color names, one per label."""
        return [cls._PALETTE.get(c, cls._FALLBACK_COLOR) for c in colors]

    def draw(self, colors):
        """Draw the current graph with nodes colored by ``colors`` and show the figure.

        Args:
            colors: iterable of integer color labels, presumably one per node in
                node order (nx.draw expects that length).
        """
        color_map = self._color_map(colors)
        pos = nx.spring_layout(self.graph)
        nx.draw(self.graph, node_color=color_map, with_labels=True, pos=pos)
        plt.show()

class SetGraphGenerator():
    """Serve adjacency matrices from a fixed, pre-built collection.

    Offers the same ``get``/``getNewGraph``/``draw`` methods as the random
    generators in this file, but picks from ``matrices`` instead of sampling
    a new graph.
    """

    # Node-color palette used by ``draw``: integer color label -> matplotlib color name.
    # Every label maps to a distinct color (label 6 must not collide with label 0).
    _PALETTE = {
        0: 'red',
        1: 'blue',
        2: 'green',
        3: 'yellow',
        4: 'pink',
        5: 'purple',
        6: 'orange',
        7: 'brown',
        8: 'gray',
        9: 'beige',
    }
    # Used for labels missing from _PALETTE so that ``draw`` always produces
    # exactly one color per node.
    _FALLBACK_COLOR = 'cyan'

    def __init__(self, matrices, ordered=False):
        """Store the matrix collection and the serving mode.

        Args:
            matrices: indexable collection of adjacency matrices (a list of
                arrays or a stacked NumPy array both work).
            ordered: if True, serve matrices cyclically in index order;
                otherwise pick one uniformly at random on each call.
        """
        self.graphs = matrices
        self.ordered = ordered
        self.i = 0  # next index to serve in ordered mode
        self.matrix = None  # last matrix served; drawn by draw()
        self.graph = None  # networkx graph built by the last draw() call

    def _next_matrix(self):
        """Pick the next matrix (cyclic if ordered, uniform random otherwise) and remember it."""
        if self.ordered:
            m = self.graphs[self.i]
            self.i = (self.i + 1) % len(self.graphs)
        else:
            # random.choice works on any indexable with len(), e.g. a stacked
            # NumPy array, which random.sample rejects as a non-Sequence.
            m = random.choice(self.graphs)
        self.matrix = m
        return m

    def get(self):
        """Serve and return the next matrix (advances the cursor like ``getNewGraph``)."""
        return self._next_matrix()

    def getNewGraph(self):
        """Serve and return the next matrix."""
        return self._next_matrix()

    @classmethod
    def _color_map(cls, colors):
        """Translate integer color labels into matplotlib color names, one per label."""
        return [cls._PALETTE.get(c, cls._FALLBACK_COLOR) for c in colors]

    def draw(self, colors):
        """Draw the most recently served matrix as a graph and show the figure.

        Args:
            colors: iterable of integer color labels, presumably one per node.

        Raises:
            RuntimeError: if no matrix has been served yet.
        """
        if self.matrix is None:
            raise RuntimeError('no graph has been served yet; call getNewGraph() first')
        self.graph = nx.from_numpy_array(np.asarray(self.matrix))
        color_map = self._color_map(colors)
        pos = nx.spring_layout(self.graph)
        nx.draw(self.graph, node_color=color_map, with_labels=True, pos=pos)
        plt.show()

# BA: Barabasi-Albert preferential-attachment graph generators
class BarabasiAlbertGraphGenerator:
    """Generate Barabasi-Albert random graphs and expose their adjacency matrices.

    Each call to ``getNewGraph`` builds a fresh graph with
    ``nx.barabasi_albert_graph(node_num, m)`` and keeps it in ``self.graph``.
    """

    # Node-color palette used by ``draw``: integer color label -> matplotlib color name.
    # Labels start at 1 here, unlike the ER generator.
    _PALETTE = {
        1: 'red',
        2: 'blue',
        3: 'green',
        4: 'yellow',
        5: 'pink',
        6: 'purple',
        7: 'brown',
        8: 'gray',
        9: 'beige',
    }
    # Used for labels missing from _PALETTE (e.g. 0) so that ``draw`` always
    # produces exactly one color per node.
    _FALLBACK_COLOR = 'cyan'

    def __init__(self, garph=None, node_num=20, m=4):
        """Store generation parameters and an optional starting graph.

        Args:
            garph: optional existing networkx graph to use as the current graph
                (the misspelled name is kept for backward compatibility).
            node_num: number of nodes in newly generated graphs.
            m: attachment parameter passed to ``nx.barabasi_albert_graph``.
        """
        self.node_num = node_num
        self.m = m
        self.graph = garph

    def get(self):
        """Return the adjacency matrix of the current graph, generating one if none exists."""
        if self.graph is None:
            return self.getNewGraph()
        return nx.to_numpy_array(self.graph)

    def getNewGraph(self):
        """Build a new graph, store it in ``self.graph`` and return its adjacency matrix."""
        self.graph = nx.barabasi_albert_graph(self.node_num, self.m)
        return nx.to_numpy_array(self.graph)

    def caculate(self):
        """Placeholder hook; currently does nothing (name kept for existing callers)."""
        pass

    @classmethod
    def _color_map(cls, colors):
        """Translate integer color labels into matplotlib color names, one per label."""
        return [cls._PALETTE.get(c, cls._FALLBACK_COLOR) for c in colors]

    def draw(self, colors):
        """Draw the current graph with nodes colored by ``colors`` and show the figure.

        Args:
            colors: iterable of integer color labels, presumably one per node.
        """
        color_map = self._color_map(colors)
        pos = nx.spring_layout(self.graph)
        # The graph must actually be drawn before plt.show(), otherwise the
        # figure is empty.
        nx.draw(self.graph, node_color=color_map, with_labels=True, pos=pos)
        plt.show()


class SetBarabasiAlbertGraphGenerator():
    """Serve adjacency matrices from a fixed, pre-built collection.

    Judging by the name this is meant for Barabasi-Albert graphs; the class
    itself does not check which model produced ``matrices``. It offers the
    same ``get``/``getNewGraph``/``draw`` methods as the random generators.
    """

    # Node-color palette used by ``draw``: integer color label -> matplotlib color name.
    # Every label maps to a distinct color (label 6 must not collide with label 0).
    _PALETTE = {
        0: 'red',
        1: 'blue',
        2: 'green',
        3: 'yellow',
        4: 'pink',
        5: 'purple',
        6: 'orange',
        7: 'brown',
        8: 'gray',
        9: 'beige',
    }
    # Used for labels missing from _PALETTE so that ``draw`` always produces
    # exactly one color per node.
    _FALLBACK_COLOR = 'cyan'

    def __init__(self, matrices, ordered=False):
        """Store the matrix collection and the serving mode.

        Args:
            matrices: indexable collection of adjacency matrices (a list of
                arrays or a stacked NumPy array both work).
            ordered: if True, serve matrices cyclically in index order;
                otherwise pick one uniformly at random on each call.
        """
        self.graphs = matrices
        self.ordered = ordered
        self.i = 0  # next index to serve in ordered mode
        self.matrix = None  # last matrix served; drawn by draw()
        self.graph = None  # networkx graph built by the last draw() call

    def _next_matrix(self):
        """Pick the next matrix (cyclic if ordered, uniform random otherwise) and remember it."""
        if self.ordered:
            m = self.graphs[self.i]
            self.i = (self.i + 1) % len(self.graphs)
        else:
            # random.choice works on any indexable with len(), e.g. a stacked
            # NumPy array, which random.sample rejects as a non-Sequence.
            m = random.choice(self.graphs)
        self.matrix = m
        return m

    def get(self):
        """Serve and return the next matrix (advances the cursor like ``getNewGraph``)."""
        return self._next_matrix()

    def getNewGraph(self):
        """Serve and return the next matrix."""
        return self._next_matrix()

    @classmethod
    def _color_map(cls, colors):
        """Translate integer color labels into matplotlib color names, one per label."""
        return [cls._PALETTE.get(c, cls._FALLBACK_COLOR) for c in colors]

    def draw(self, colors):
        """Draw the most recently served matrix as a graph and show the figure.

        Args:
            colors: iterable of integer color labels, presumably one per node.

        Raises:
            RuntimeError: if no matrix has been served yet.
        """
        if self.matrix is None:
            raise RuntimeError('no graph has been served yet; call getNewGraph() first')
        self.graph = nx.from_numpy_array(np.asarray(self.matrix))
        color_map = self._color_map(colors)
        pos = nx.spring_layout(self.graph)
        nx.draw(self.graph, node_color=color_map, with_labels=True, pos=pos)
        plt.show()

# WS: Watts-Strogatz small-world graph generators
class WattsStrogatzGraphGenerator:
    """Generate Watts-Strogatz random graphs and expose their adjacency matrices.

    Each call to ``getNewGraph`` builds a fresh graph with
    ``nx.watts_strogatz_graph(node_num, k, p)`` and keeps it in ``self.graph``.
    """

    # Node-color palette used by ``draw``: integer color label -> matplotlib color name.
    # Labels start at 1 here, unlike the ER generator.
    _PALETTE = {
        1: 'red',
        2: 'blue',
        3: 'green',
        4: 'yellow',
        5: 'pink',
        6: 'purple',
        7: 'brown',
        8: 'gray',
        9: 'beige',
    }
    # Used for labels missing from _PALETTE (e.g. 0) so that ``draw`` always
    # produces exactly one color per node.
    _FALLBACK_COLOR = 'cyan'

    def __init__(self, garph=None, node_num=20, k=4, p=0.7):
        """Store generation parameters and an optional starting graph.

        Args:
            garph: optional existing networkx graph to use as the current graph
                (the misspelled name is kept for backward compatibility).
            node_num: number of nodes in newly generated graphs.
            k: neighbor count passed to ``nx.watts_strogatz_graph``.
            p: rewiring probability passed to ``nx.watts_strogatz_graph``.
        """
        self.node_num = node_num
        self.k = k
        self.p = p
        self.graph = garph

    def get(self):
        """Return the adjacency matrix of the current graph, generating one if none exists."""
        if self.graph is None:
            return self.getNewGraph()
        return nx.to_numpy_array(self.graph)

    def getNewGraph(self):
        """Build a new graph, store it in ``self.graph`` and return its adjacency matrix."""
        self.graph = nx.watts_strogatz_graph(self.node_num, self.k, self.p)
        return nx.to_numpy_array(self.graph)

    def caculate(self):
        """Placeholder hook; currently does nothing (name kept for existing callers)."""
        pass

    @classmethod
    def _color_map(cls, colors):
        """Translate integer color labels into matplotlib color names, one per label."""
        return [cls._PALETTE.get(c, cls._FALLBACK_COLOR) for c in colors]

    def draw(self, colors):
        """Draw the current graph with nodes colored by ``colors`` and show the figure.

        Args:
            colors: iterable of integer color labels, presumably one per node.
        """
        color_map = self._color_map(colors)
        pos = nx.spring_layout(self.graph)
        # The graph must actually be drawn before plt.show(), otherwise the
        # figure is empty.
        nx.draw(self.graph, node_color=color_map, with_labels=True, pos=pos)
        plt.show()


class SetWattsStrogatzGraphGenerator():
    """Serve adjacency matrices from a fixed, pre-built collection.

    Judging by the name this is meant for Watts-Strogatz graphs; the class
    itself does not check which model produced ``matrices``. It offers the
    same ``get``/``getNewGraph``/``draw`` methods as the random generators.
    """

    # Node-color palette used by ``draw``: integer color label -> matplotlib color name.
    # Every label maps to a distinct color (label 6 must not collide with label 0).
    _PALETTE = {
        0: 'red',
        1: 'blue',
        2: 'green',
        3: 'yellow',
        4: 'pink',
        5: 'purple',
        6: 'orange',
        7: 'brown',
        8: 'gray',
        9: 'beige',
    }
    # Used for labels missing from _PALETTE so that ``draw`` always produces
    # exactly one color per node.
    _FALLBACK_COLOR = 'cyan'

    def __init__(self, matrices, ordered=False):
        """Store the matrix collection and the serving mode.

        Args:
            matrices: indexable collection of adjacency matrices (a list of
                arrays or a stacked NumPy array both work).
            ordered: if True, serve matrices cyclically in index order;
                otherwise pick one uniformly at random on each call.
        """
        self.graphs = matrices
        self.ordered = ordered
        self.i = 0  # next index to serve in ordered mode
        self.matrix = None  # last matrix served; drawn by draw()
        self.graph = None  # networkx graph built by the last draw() call

    def _next_matrix(self):
        """Pick the next matrix (cyclic if ordered, uniform random otherwise) and remember it."""
        if self.ordered:
            m = self.graphs[self.i]
            self.i = (self.i + 1) % len(self.graphs)
        else:
            # random.choice works on any indexable with len(), e.g. a stacked
            # NumPy array, which random.sample rejects as a non-Sequence.
            m = random.choice(self.graphs)
        self.matrix = m
        return m

    def get(self):
        """Serve and return the next matrix (advances the cursor like ``getNewGraph``)."""
        return self._next_matrix()

    def getNewGraph(self):
        """Serve and return the next matrix."""
        return self._next_matrix()

    @classmethod
    def _color_map(cls, colors):
        """Translate integer color labels into matplotlib color names, one per label."""
        return [cls._PALETTE.get(c, cls._FALLBACK_COLOR) for c in colors]

    def draw(self, colors):
        """Draw the most recently served matrix as a graph and show the figure.

        Args:
            colors: iterable of integer color labels, presumably one per node.

        Raises:
            RuntimeError: if no matrix has been served yet.
        """
        if self.matrix is None:
            raise RuntimeError('no graph has been served yet; call getNewGraph() first')
        self.graph = nx.from_numpy_array(np.asarray(self.matrix))
        color_map = self._color_map(colors)
        pos = nx.spring_layout(self.graph)
        nx.draw(self.graph, node_color=color_map, with_labels=True, pos=pos)
        plt.show()