import plotly.graph_objects as go
import matplotlib.pyplot as plt
from utils import *
import plotly.io as pio


class PaintClass:
    """Draw a community-coloured graph with Plotly and save it as a PNG.

    Parameters
    ----------
    g :
        Graph whose vertex attributes (``g.vs[...]``) are written during
        painting; presumably an ``igraph.Graph`` — confirm with callers.
    graph : str
        Dataset name. Selects the ratio presets in ``__init__`` and names the
        output directory ``Figure/GD/<graph>/``.
    communities :
        Object exposing ``.membership``: one community label per vertex.
    paint_paras : dict
        Drawing options. Always read: ``line_width``, ``node_size_others``,
        ``text_size``, ``highlighted_node_YN``, ``Allnode_id_YN``,
        ``legend_YN``, ``paint_style``, ``figure_size`` (width, height).
        For graphs not in ``_PRESET_GRAPHS`` also read: ``distance``,
        ``top_ratio``, ``dis_threshold``.
    """

    # Datasets that get fixed inter/intra ratios instead of the
    # distance-based parameters taken from ``paint_paras``.
    _PRESET_GRAPHS = ('polbooks', 'dolphins', 'lesmis', 'karate')

    def __init__(self, g, graph, communities, paint_paras):
        self.g = g
        self.graph = graph
        self.communities = communities
        self.line_width = paint_paras['line_width']
        self.node_size_others = paint_paras['node_size_others']
        self.text_size = paint_paras['text_size']
        # NOTE(review): highlighted_node_YN and legend_YN are stored but not
        # read by painting_GD_new; they may be used elsewhere.
        self.highlighted_node_YN = paint_paras['highlighted_node_YN']
        self.Allnode_id_YN = paint_paras['Allnode_id_YN']
        self.legend_YN = paint_paras['legend_YN']
        self.style = paint_paras['paint_style']
        self.figure_size = paint_paras['figure_size']
        # Only one of the two attribute groups below exists on an instance.
        if graph in self._PRESET_GRAPHS:  # lesmis: inter 13, intra 2
            self.inter_ratio = 13
            self.intra_ratio = 2
        else:
            self.distance = paint_paras['distance']
            self.top_ratio = paint_paras['top_ratio']
            self.dis_threshold = paint_paras['dis_threshold']

    @staticmethod
    def _community_colors(labels, cmap_name='tab20'):
        """Return one ``'rgba(r, g, b, a)'`` string per entry of ``labels``.

        Distinct labels are ranked in sorted order and colour ``k`` is sampled
        at ``k / max(3, K)`` of the colour map (``K`` = number of distinct
        labels). Ranking lets non-contiguous labels such as {0, 2, 5} each get
        a colour instead of indexing past the palette; for contiguous labels
        0..K-1 the rank equals the label. With K > 20 the 20-colour ``tab20``
        map necessarily repeats colours.
        """
        distinct = sorted(set(labels))
        rank = {label: k for k, label in enumerate(distinct)}
        cmap = plt.get_cmap(cmap_name)
        # Dividing by at least 3 keeps the 2-community case on the same
        # colours as the first two of the 3-community case.
        denom = max(3, len(distinct))
        palette = []
        for k in range(len(distinct)):
            r, g, b, a = cmap(k / denom)
            palette.append(f'rgba({int(r * 255)}, {int(g * 255)}, {int(b * 255)}, {a})')
        return [palette[rank[label]] for label in labels]

    def _set_node_styles(self, g_new):
        """Set the ``shape``, ``size`` and ``text`` vertex attributes of ``self.g``.

        With ``Allnode_id_YN`` every vertex is labelled with its index, except
        vertices of degree 0 in ``g_new``, which get size 0 and no label.
        Otherwise all vertices share ``node_size_others`` and have no label.
        """
        n = self.g.vcount()
        self.g.vs['shape'] = ['circle'] * n
        if self.Allnode_id_YN:
            # One degree query for all vertices instead of one call per vertex.
            # Assumes g_new uses the same vertex ids as self.g — TODO confirm.
            degrees = g_new.degree()
            self.g.vs['size'] = [0 if degrees[i] == 0 else self.node_size_others for i in range(n)]
            self.g.vs['text'] = ['' if degrees[i] == 0 else str(i) for i in range(n)]
        else:
            self.g.vs['size'] = [self.node_size_others] * n
            self.g.vs['text'] = [''] * n

    def painting_GD_new(self, g_new):
        """Render ``self.g`` coloured by community and write it to a PNG file.

        Side effects: sets the ``community``, ``name`` (only when missing),
        ``color``, ``shape``, ``size`` and ``text`` vertex attributes on
        ``self.g``, then writes an image via ``pio.write_image`` to a path
        produced by ``generate_incremental_filename`` under
        ``Figure/GD/<graph>/``.

        Parameters
        ----------
        g_new :
            Graph passed to ``graph_positions_old`` (together with
            ``self.style``) to obtain the layout; also supplies the degrees
            used to hide isolated vertices when ``Allnode_id_YN`` is set.
        """
        # graph_positions_old is not defined in this file (presumably from the
        # `utils` wildcard import). Its result is indexed as
        # position[0] -> (node xs, node ys), position[1] -> (edge xs, edge ys).
        position = graph_positions_old(g_new, self.style)
        self.g.vs['community'] = self.communities.membership

        # Fall back to vertex indices as names when none are present.
        if 'name' not in self.g.vs.attributes():
            self.g.vs['name'] = [str(i) for i in range(self.g.vcount())]

        self.g.vs['color'] = self._community_colors(self.g.vs['community'])
        self._set_node_styles(g_new)

        edge_trace = go.Scatter(
            x=position[1][0], y=position[1][1],
            line=dict(width=self.line_width, color='#888'),
            hoverinfo='none',
            mode='lines',
            showlegend=False,
        )

        node_trace = go.Scatter(
            x=position[0][0], y=position[0][1],
            mode='markers+text',
            hoverinfo='text',
            marker=dict(
                showscale=False,
                color=self.g.vs['color'],
                size=self.g.vs['size'],
                symbol=self.g.vs['shape'],
                # Scalar so the outline colour applies to every marker; a
                # one-element list would only cover the first point.
                line=dict(color='#888'),
            ),
            text=self.g.vs['text'],
            textfont=dict(
                size=self.text_size,
                color='black',
                family='Arial Black',  # Arial Black gives a bold look
            ),
            textposition='middle center',
            showlegend=False,
        )

        fig = go.Figure(data=[edge_trace, node_trace], layout=go.Layout(
            showlegend=True,
            hovermode='closest',
            margin=dict(b=5, l=5, r=5, t=5),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            plot_bgcolor='white',
            paper_bgcolor='white',
            legend=dict(x=0.05, y=0.05),
        ))

        # generate_incremental_filename is not defined in this file
        # (presumably from `utils`); it yields the output path.
        base_filename = f"{self.graph}_graph_plot"
        filename = generate_incremental_filename(f"Figure/GD/{self.graph}/", base_filename)

        # figure_size is read as (width, height) in pixels.
        pio.write_image(fig, filename, format='png', width=self.figure_size[0], height=self.figure_size[1])










