import plotly.graph_objects as go
from typing import Dict, Any, List
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
import networkx as nx
from io import BytesIO
import base64

class BaseVisualizer:
    """Common interface for dataset-specific visualizers.

    The base class carries no state and no behaviour of its own; its only
    method raises, so concrete visualizers are expected to override it.
    """

    def create_visualizations(
        self,
        outputs: Dict[str, Any],
        metrics: Dict[str, float],
        state: Any,
    ) -> Dict[str, go.Figure]:
        """Build the figures for one evaluation step.

        Args:
            outputs: Model outputs keyed by name.
            metrics: Scalar metrics keyed by name.
            state: Caller-supplied state object; not interpreted here.

        Returns:
            Never returns in the base class.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError

class StarVisualizer(BaseVisualizer):
    """Animated per-position bar charts for the ``eval_{i}`` splits.

    Every call to :meth:`create_visualizations` appends one frame per
    available split, so the figures returned over time form an animation
    whose frames are named after the ``state`` passed in.
    """

    def __init__(self, path_lengths: List[int]):
        """Record the path lengths and create empty frame histories.

        Args:
            path_lengths: ``path_lengths[i]`` is the number of
                ``eval_{i}/ntp_probs_{j}`` metric entries read for split
                ``eval_{i}``.
        """
        self.val_frames = {'temp': []}
        self.val_acc_frames = {'temp': [], 'ntp': [], 'probs': []}
        self.val_graph_frames = []
        self.graphs_logged = False
        self.path_lengths = path_lengths
        # Frame history keyed by (split, response_type). One list per split
        # keeps frames of different splits out of each other's animations
        # (splits may have different path lengths) and also serves the
        # 'train' split, which previously relied on an attribute that was
        # never created.
        self._acc_frames = {}

    def create_visualizations(
        self,
        outputs: Dict[str, Any],
        metrics: Dict[str, float],
        state: Any,
        dataset_samples: Dict[str, List[Dict[str, Any]]] = None,
    ) -> Dict[str, Dict[str, Dict[str, go.Figure]]]:
        """Build all figures for the current step.

        Args:
            outputs: Model outputs; not used by this visualizer.
            metrics: Metrics dict containing ``eval_{i}/ntp_probs_{j}`` keys.
            state: Value used as the name of the frames added in this call.
            dataset_samples: Accepted for interface compatibility; not used.

        Returns:
            ``{'ntp_probs': {split: {'probs': figure}}}`` with one entry per
            split whose metrics are present.
        """
        return {
            'ntp_probs': self._create_ntp_probs(metrics, state),
        }

    def _create_ntp_probs(
        self, metrics: Dict[str, float], state: Any
    ) -> Dict[str, Dict[str, go.Figure]]:
        """Create one probability bar chart per evaluated split.

        A split is skipped when its ``eval_{i}/ntp_probs_0`` key is absent.

        Raises:
            KeyError: If a split has ``ntp_probs_0`` but lacks a later
                position below its path length.
        """
        figures = {}
        for index, path_length in enumerate(self.path_lengths):
            split = f"eval_{index}"
            if f"{split}/ntp_probs_0" not in metrics:
                continue
            position_probs = [
                metrics[f"{split}/ntp_probs_{j}"] for j in range(path_length)
            ]
            figures[split] = {
                'probs': self._create_single_path_accuracies(
                    position_probs, state, split, 'probs'
                ),
            }
        return figures

    def _create_single_path_accuracies(
        self,
        position_accuracies: List[float],
        state: Any,
        split: str,
        response_type: str,
    ) -> go.Figure:
        """Append a frame for ``split`` and return the animated bar chart.

        Args:
            position_accuracies: One value per path position, plotted on a
                y-axis fixed to the range 0..1.
            state: Used as the name of the new frame.
            split: Split identifier (e.g. ``'eval_0'``); selects the history.
            response_type: Second part of the history key (e.g. ``'probs'``).

        Returns:
            A figure showing the newest bars, carrying every frame recorded
            so far for this ``(split, response_type)`` pair.
        """
        bar = go.Bar(
            x=list(range(len(position_accuracies))),
            y=position_accuracies,
            text=[f'{acc:.2%}' for acc in position_accuracies],
            textposition='auto',
        )
        layout = go.Layout(
            title='Node Accuracy by Position',
            yaxis=dict(range=[0, 1], tickformat='.0%'),
            xaxis_title="Position in Path",
            yaxis_title="Accuracy",
            showlegend=False
        )
        frame = go.Frame(data=[bar], name=state)
        # setdefault creates the history on first use, so any split name works.
        history = self._acc_frames.setdefault((split, response_type), [])
        history.append(frame)
        return go.Figure(data=[frame.data[0]], frames=history, layout=layout)

def get_visualizer(dataset_type: str, *args: Any, **kwargs: Any) -> BaseVisualizer:
    """Instantiate the visualizer registered for ``dataset_type``.

    Args:
        dataset_type: Registry key; ``'star'`` is the only one defined here.
        *args: Positional arguments forwarded to the visualizer constructor.
        **kwargs: Keyword arguments forwarded to the visualizer constructor
            (for example ``path_lengths=[...]`` for ``'star'``).

    Returns:
        A new visualizer instance.

    Raises:
        ValueError: If ``dataset_type`` is not a registered key.
    """
    visualizers = {
        'star': StarVisualizer,
    }
    if dataset_type not in visualizers:
        raise ValueError(f"Unknown dataset type: {dataset_type}")
    # StarVisualizer requires ``path_lengths``; without forwarding the
    # constructor arguments this factory could never build it.
    return visualizers[dataset_type](*args, **kwargs)