import time
import numpy as np
import torch
from envs.util import ErdosRenyiGraphGenerator
import networkx as nx
import matplotlib.pyplot as plt
import copy


class ColorSystemGenerator:
    """Factory wrapper that builds :class:`ColorSystem` environments."""

    @staticmethod
    def get(max_steps, node_num, graph_generator, test=False):
        """Create a new ``ColorSystem``.

        All four arguments are forwarded unchanged to the ``ColorSystem``
        constructor.

        Returns:
            The freshly constructed ``ColorSystem`` instance.
        """
        env = ColorSystem(
            max_steps=max_steps,
            node_num=node_num,
            graph_generator=graph_generator,
            test=test,
        )
        return env


class ColorSystem:
    """Sequential graph-colouring environment with ``reset``/``step`` methods.

    Each call to :meth:`step` picks one node.  An uncoloured node receives the
    first colour in its list of still-accessible colours; if that list is
    empty a brand-new colour is opened (and penalised in the reward).

    ``self.state`` is a ``(6, node_num)`` float tensor with one column per
    node and the following rows:

    * row 0 -- assigned colour, ``-1`` while the node is uncoloured
    * row 1 -- degree in the residual graph (edges of coloured nodes removed)
    * row 2 -- number of colours currently accessible to the node
    * row 3 -- number of already-coloured neighbours
    * row 4 -- residual degree divided by the maximum residual degree
    * row 5 -- number of distinct colours used by the node's neighbours

    The graph generator must provide ``get()`` (used when ``test`` is true)
    and ``getNewGraph()`` (used otherwise).  Both are assumed to return a
    square ``node_num x node_num`` adjacency matrix -- TODO confirm against
    the generator implementation.
    """

    class action_space():
        """Discrete action set with one action per node index."""

        def __init__(self, num_actions):
            self.n_node = num_actions
            self.node_actions = torch.arange(self.n_node)

        def sample(self):
            """Return one uniformly random node index as a length-1 array."""
            node = np.random.choice(self.node_actions, 1)
            return node

    class observation_space():
        """Minimal shape holder: ``shape == [n, n_observables]``."""

        def __init__(self, n, n_observables):
            self.shape = [n, n_observables]

    def __init__(self,
                 max_steps,
                 node_num,
                 graph_generator=None,
                 test=False):
        """Configure the environment; call :meth:`reset` before stepping.

        Args:
            max_steps: Number of ``step`` calls allowed per episode.
            node_num: Number of nodes in every generated graph.
            graph_generator: Object exposing ``get()`` and ``getNewGraph()``.
                When ``None`` an Erdos-Renyi generator with ``node_num``
                nodes and edge probability 0.15 is created.
            test: If true, ``reset`` calls ``get()`` on the generator,
                otherwise ``getNewGraph()``.
        """
        self.test = test
        self.max_steps = max_steps
        self.current_step = 0

        self.node_num = node_num
        self.init_color_num = 1
        self.color_num = self.init_color_num

        if graph_generator is not None:
            self.graph = graph_generator
        else:
            # Size the default graph from node_num so it always matches the
            # state matrix (it used to be hard-coded to 20 nodes).
            self.graph = ErdosRenyiGraphGenerator(node_num=node_num, p=0.15)

        # The instance attributes below intentionally shadow the nested
        # classes of the same name; the lookup on the right-hand side still
        # resolves to the class because the instance attribute is not set yet.
        self.action_space = self.action_space(self.node_num)
        self.observation_space = self.observation_space(6, self.node_num)

    def reset(self):
        """Start a new episode.

        Returns:
            A tuple ``(state, matrix)`` holding the initial feature tensor and
            the adjacency tensor.  Note that :meth:`step` returns the two
            stacked into a single NumPy array instead.
        """
        self.current_step = 0
        self.color_num = self.init_color_num
        # Every node starts with the same list of initial colours.
        self.accessible_colors = [list(range(self.color_num))
                                  for _ in range(self.node_num)]
        if self.test:
            adjacency = self.graph.get()
        else:
            adjacency = self.graph.getNewGraph()
        self.matrix = torch.as_tensor(adjacency)
        # Residual adjacency: edges of coloured nodes are removed from this
        # copy while ``self.matrix`` keeps the original graph.
        self.update_adj = self.matrix.clone()
        self.state = self._reset_state()
        return (self.state, self.matrix)

    def show_Color(self):
        """Print the colour row and every node's accessible colour list."""
        print("state[0]", self.state[0])
        for i, colors in enumerate(self.accessible_colors):
            print("accessible_colors[{}]:{}".format(i, colors))

    @staticmethod
    def _degree_centrality(degrees):
        """Return ``degrees`` scaled by their maximum (by 1 if all are zero)."""
        max_degree = degrees.max()
        if max_degree == 0:
            max_degree = 1
        return degrees / max_degree

    @staticmethod
    def _as_index(node):
        """Convert an int, 0-d/length-1 array or tensor into a Python int."""
        if torch.is_tensor(node):
            return int(node.item())
        return int(np.asarray(node).item())

    def _observation(self):
        """Stack the feature rows on top of the adjacency matrix."""
        return np.vstack((self.state, self.matrix))

    def _reset_state(self):
        """Build the initial ``(6, node_num)`` feature tensor."""
        state = torch.zeros((self.observation_space.shape[0], self.node_num),
                            dtype=torch.float32)

        # Row 0: node colour, -1 means "not coloured yet".
        state[0] = -1

        # Row 1: degree of every node in the initial graph.
        degrees = torch.sum(self.matrix, dim=1)
        state[1] = degrees

        # Row 2: number of colours accessible to every node.
        state[2] = self.color_num

        # Row 4: degree centrality of the initial graph.
        state[4] = self._degree_centrality(degrees)

        # Rows 3 and 5 (coloured neighbours / distinct neighbour colours)
        # stay at zero because nothing is coloured yet.
        return state

    def _refresh_neighbor_features(self, state):
        """Recompute rows 3 and 5 of ``state`` from its colour row."""
        for n in range(self.node_num):
            neighbors = torch.where(self.matrix[n] != 0)[0]
            neighbor_colors = state[0, neighbors]
            used = neighbor_colors[neighbor_colors != -1]
            # Row 3: how many neighbours already carry a colour.
            state[3, n] = used.numel()
            # Row 5: how many distinct colours those neighbours use.
            state[5, n] = torch.unique(used).numel()

    def step(self, node):
        """Colour ``node`` and advance the episode by one step.

        Args:
            node: Index of the node to colour.  A Python int, a NumPy scalar
                or length-1 array (as produced by ``action_space.sample``) or
                a single-element tensor is accepted.

        Returns:
            A tuple ``(observation, reward, done)``.  ``observation`` is the
            feature rows stacked on top of the adjacency matrix as a NumPy
            array.  ``reward`` is divided by ``node_num`` and equals ``1`` for
            reusing an existing colour, ``init_color_num - color_num`` when a
            new colour had to be opened, and ``-color_num`` when the node was
            already coloured (in which case the state is left untouched).

        Raises:
            NotImplementedError: If called after ``max_steps`` steps.
        """
        self.current_step += 1

        if self.current_step > self.max_steps:
            print("Return done")
            raise NotImplementedError(
                "step() called after max_steps was reached; call reset()")

        # Normalise once so list indexing below works for array-like actions.
        node = self._as_index(node)

        # Re-selecting a coloured node is penalised and changes nothing.
        if self.state[0, node] != -1:
            reward = -self.color_num
            done = self.current_step == self.max_steps
            return self._observation(), reward / self.node_num, done

        new_state = self.state.clone()
        palette = self.accessible_colors[node]
        penalty = not palette

        if palette:
            # Reuse the first accessible colour.  Capture it up front so that
            # mutating the lists below cannot change which colour is removed.
            color = palette[0]
            new_state[0, node] = color
            for nb in torch.where(self.matrix[node] != 0)[0].tolist():
                if color in self.accessible_colors[nb]:
                    self.accessible_colors[nb].remove(color)
                    new_state[2, nb] = len(self.accessible_colors[nb])
        else:
            # No colour is accessible: open a new one.  It becomes accessible
            # to every node that is not adjacent to ``node``.
            color = self.color_num
            new_state[0, node] = color
            self.color_num += 1
            for other in torch.where(self.matrix[node] == 0)[0].tolist():
                self.accessible_colors[other].append(color)
                new_state[2, other] += 1

        # Refresh the neighbour-based features on every colouring step and
        # from the *new* colouring, so the node coloured just now is counted.
        self._refresh_neighbor_features(new_state)

        # Remove the coloured node's edges from the residual graph first, so
        # that rows 1 and 4 are both derived from the same adjacency.
        self.update_adj[node] = 0
        self.update_adj[:, node] = 0
        current_degrees = self.update_adj.sum(dim=1)
        new_state[1] = current_degrees
        new_state[4] = self._degree_centrality(current_degrees)

        if penalty:
            reward = self.init_color_num - self.color_num
        else:
            reward = 1

        # The episode ends when the step budget is used up or all nodes are
        # coloured.
        done = bool(self.current_step == self.max_steps
                    or (new_state[0] != -1).all())

        self.state = new_state

        return self._observation(), reward / self.node_num, done

    def get_allowed_action_states(self):
        """Placeholder; always returns ``-1``."""
        return -1