import torch
import networkx as nx
from qiskit import QuantumCircuit
from PIL import Image

import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure

import io

def figToPIL(fig:Figure)->Image.Image:
    '''
    Render a Matplotlib figure to an in-memory PNG and return it as a PIL image.

    The figure is always closed, even if saving fails, so figures created only
    for conversion do not accumulate in pyplot's figure registry.

    Arguments
    ---------
    fig: Figure
        The figure to render. It is closed by this function and should not be
        used afterwards.

    Returns
    -------
    Image.Image
        The figure rendered as PNG (tight bounding box) and opened with PIL.
    '''
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # Release the figure even when savefig raises.
        plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)
    # Image.open is lazy; decode now so a corrupt render is reported here
    # rather than at the image's first use.
    img.load()
    return img

def plot_graph_cut(G:torch.Tensor, basis:int, ax:Axes|None=None, 
                   title:str|None=None)->Figure|None:
    '''
    Plots the graph specified by the adjacency matrix G with the vertices 
    partitioned according to the basis.
    For a graph of n vertices, the basis is interpreted as an n-bit number
    where bit i corresponds to vertex i. Vertices whose bit is 1 are drawn
    pink, and those whose bit is 0 are drawn light blue. Edges crossing the
    cut (between the two partitions) are drawn fully opaque; edges between
    vertices in the same partition are drawn faded (alpha 0.2).

    It is assumed that the graph is undirected, unweighted and there are no 
    edges starting and ending at the same vertex. Only the upper triangle of
    G is read, and an entry counts as an edge only if it equals 1.

    Arguments
    ---------
    G: torch.Tensor
        The n x n adjacency matrix of the graph to render.
    basis: int
        The basis to partition the graph vertices with. Anything accepted by
        int() (e.g. a 0-d integer tensor) is also allowed.
    ax: Axes, optional
        The Axes to plot the graph in. If None, the graph is plotted in a new 
        4x4 inch figure which is returned.
    title: str, optional
        The plot title
    
    Returns
    -------
    None if ax is not None
    Figure if ax is None
    '''
    n = len(G)  # Number of vertices
    # Normalize once so the bit tests below work on a plain int, even when the
    # caller passes an element obtained by iterating over a tensor.
    basis = int(basis)
    graph = nx.Graph()

    # Vertex i is coloured by bit i of the basis.
    node_colors = ['pink' if (basis >> i) & 1 else 'lightblue' for i in range(n)]
    for i, color in enumerate(node_colors):
        graph.add_node(i, color=color)

    # Walk only the upper triangle: the graph is undirected, so each edge is
    # visited once. Cut edges are opaque, same-partition edges are faded.
    edges = []
    edge_alphas = []
    for i in range(n):
        for j in range(i+1, n):
            if G[i, j] == 1:
                graph.add_edge(i, j)
                edges.append((i, j))
                edge_alphas.append(1.0 if node_colors[i] != node_colors[j] else 0.2)

    pos = nx.circular_layout(graph)
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    else:
        fig = None

    nx.draw_networkx_nodes(graph, pos, node_color=node_colors, linewidths=1.5, ax=ax)

    # One draw call per edge so each edge gets its own transparency.
    for edge, alpha in zip(edges, edge_alphas):
        nx.draw_networkx_edges(graph, pos, edgelist=[edge], edge_color='black',
                               alpha=alpha, ax=ax)

    nx.draw_networkx_labels(graph, pos, font_size=12, font_color='black', ax=ax)

    ax.axis('off')

    if title is not None:
        ax.set_title(title)
    return fig

def make_max_cut_plot(G:torch.Tensor, bases:torch.Tensor, 
                      probs:torch.Tensor|None=None,
                      title:str="Graph Cuts",fig_scale:int=3)->Image.Image:
    '''
    Draw one cut of the graph G per basis state in a two-row grid and return
    the result as a PIL image.

    Panels are filled row by row. When the number of bases is odd, the spare
    panel is left blank.

    Arguments
    ---------
    G: torch.Tensor
        Adjacency matrix of the graph, passed unchanged to plot_graph_cut.
    bases: torch.Tensor
        Basis states to draw, one panel per entry, in order.
    probs: torch.Tensor, optional
        If given, probs[i] is shown (4 decimals) in the title of panel i.
    title: str
        Overall figure title.
    fig_scale: int
        Width and height, in inches, allotted to each panel.

    Returns
    -------
    Image.Image
        The rendered grid of cuts.
    '''
    n_plots = len(bases)
    # Round up so every basis gets a panel when the count is odd, and keep at
    # least one column so plt.subplots never receives ncols=0.
    ncols = max(1, (n_plots + 1) // 2)
    figsize = (fig_scale*ncols, fig_scale*2)
    # squeeze=False guarantees a 2-D array of Axes regardless of ncols.
    fig, axs = plt.subplots(2, ncols, figsize=figsize, squeeze=False)
    flat_axs = axs.flatten()
    for i, basis in enumerate(bases):
        label = f'Basis State: {basis:4d}'
        if probs is not None:
            label += f'\nProbability: {probs[i]:0.4f}'
        plot_graph_cut(G, basis, flat_axs[i], label)
    # Blank out any panel that received no basis.
    for ax in flat_axs[n_plots:]:
        ax.axis('off')
    fig.suptitle(title)
    fig.tight_layout()
    return figToPIL(fig)

def quantum_circuit_plot(gate_list:torch.Tensor,
                         angles:torch.Tensor,
                         file:str|None=None,
                         plot_barriers:bool=True,
                         plot_id:bool=False
                         )->Image.Image:
    '''
    Build a layered Qiskit circuit from integer gate codes and return its
    Matplotlib drawing as a PIL image.

    Gate codes, read from gate_list[layer, qubit]:
        0        identity (only added when plot_id is True)
        1, 2, 3  RX, RY, RZ on that qubit with angle angles[layer, qubit]
        k >= 4   CNOT targeting that qubit. The control index is c = k-4,
                 shifted up by one when c >= qubit so that the target itself
                 is skipped.
    A barrier is inserted between consecutive layers (not after the last).

    Arguments
    ---------
    gate_list: torch.Tensor
        Integer gate codes indexed [layer, qubit]. At least the region covered
        by angles.shape is read.
    angles: torch.Tensor
        Tensor of shape (num_layers, num_qubits). An entry is only read when
        the matching gate is a rotation.
    file: str, optional
        Forwarded to QuantumCircuit.draw as filename.
    plot_barriers: bool
        Whether the barriers between layers are drawn.
    plot_id: bool
        Whether identity gates are added (and therefore drawn).

    Returns
    -------
    Image.Image
        The rendered circuit diagram.

    Raises
    ------
    ValueError
        If a gate code is negative. Such a code would otherwise produce a
        negative control index that silently wraps around to another qubit.
    '''
    num_layers, num_qubits = angles.shape

    qc = QuantumCircuit(num_qubits)
    for layer in range(num_layers):
        for qbit in range(num_qubits):
            gate = gate_list[layer, qbit].item()
            if gate < 0:
                raise ValueError(
                    f'Invalid gate code {gate} at layer {layer}, qubit {qbit}')
            if gate == 0:
                if plot_id:
                    qc.id(qbit)
            elif gate == 1:
                qc.rx(angles[layer, qbit].item(), qbit)
            elif gate == 2:
                qc.ry(angles[layer, qbit].item(), qbit)
            elif gate == 3:
                qc.rz(angles[layer, qbit].item(), qbit)
            else:
                # Skip over the target qubit when mapping code -> control.
                cbit = gate-4 if gate-4 < qbit else gate-3
                qc.cx(cbit, qbit)
        if layer != num_layers-1:
            qc.barrier()
    return figToPIL(qc.draw('mpl', filename=file, plot_barriers=plot_barriers,
                            justify='none'))

def quantum_super_circuit_plot(gate_list:torch.Tensor,
                               angles:torch.Tensor,
                               super_circuit_structure:torch.Tensor,
                               total_qubits:int,
                               file:str|None=None,
                               plot_barriers:bool=True,
                               plot_id:bool=False,
                               )->Image.Image:
    '''
    Build a Qiskit circuit by placing several copies of a layered subcircuit
    onto subsets of a larger register, and return its drawing as a PIL image.

    Every subcircuit uses the same gate codes, gate_list[layer, t_id]. Only
    the angles, angles[i, layer, t_id], and the qubit mapping,
    super_circuit_structure[i, t_id], differ per subcircuit i.

    Gate codes:
        0        identity (only added when plot_id is True)
        1, 2, 3  RX, RY, RZ on the mapped target qubit
        k >= 4   CNOT onto the mapped target. The local control index is
                 c = k-4, shifted up by one when c >= t_id so that the target
                 itself is skipped, then mapped through
                 super_circuit_structure.
    A barrier is appended after every subcircuit, including the last.

    Arguments
    ---------
    gate_list: torch.Tensor
        Integer gate codes indexed [layer, t_id].
    angles: torch.Tensor
        Tensor of shape (num_subcircuits, num_layers, num_support).
    super_circuit_structure: torch.Tensor
        super_circuit_structure[i, t_id] is the qubit index in the full
        register used as local qubit t_id of subcircuit i.
    total_qubits: int
        Size of the full register.
    file: str, optional
        Forwarded to QuantumCircuit.draw as filename.
    plot_barriers: bool
        Whether the barriers are drawn.
    plot_id: bool
        Whether identity gates are added (and therefore drawn).

    Returns
    -------
    Image.Image
        The rendered circuit diagram.

    Raises
    ------
    ValueError
        If a gate code is negative.
    '''
    num_subcircuits, num_layers, num_support = angles.shape
    qc = QuantumCircuit(total_qubits)
    for i in range(num_subcircuits):
        for layer in range(num_layers):
            for t_id in range(num_support):
                gate = gate_list[layer, t_id].item()
                if gate < 0:
                    raise ValueError(
                        f'Invalid gate code {gate} at layer {layer}, '
                        f'support index {t_id}')
                target_qbit = super_circuit_structure[i, t_id].item()
                if gate == 0:
                    if plot_id:
                        qc.id(target_qbit)
                elif gate == 1:
                    qc.rx(angles[i, layer, t_id].item(), target_qbit)
                elif gate == 2:
                    qc.ry(angles[i, layer, t_id].item(), target_qbit)
                elif gate == 3:
                    qc.rz(angles[i, layer, t_id].item(), target_qbit)
                else:
                    c_id = gate-4 if gate-4 < t_id else gate-3
                    # .item() so Qiskit receives a plain int qubit index (as
                    # for target_qbit), not a 0-d tensor.
                    cbit = super_circuit_structure[i, c_id].item()
                    qc.cx(cbit, target_qbit)
        qc.barrier()
    return figToPIL(qc.draw('mpl', filename=file, plot_barriers=plot_barriers,
                            justify='none'))
