# file: prism/utils/model_analysis.py
from pathlib import Path

import torch
from torchinfo import summary
from torchview import draw_graph


def save_model_summary(model, input_data, output_dir, filename, depth=5):
    """Write a torchinfo summary and the model's ``repr`` to a text file.

    The report is saved as ``<output_dir>/<filename>.txt``; missing parent
    directories are created. This is a best-effort helper: any exception is
    reported with a printed warning and is not propagated to the caller.

    Args:
        model: Model to summarize; the device of its first parameter is the
            device passed to ``summary``.
        input_data: Example input passed through to ``summary``.
        output_dir: Directory in which the report is written.
        filename: Report name without extension (``.txt`` is appended).
        depth: Nesting depth passed through to ``summary``.

    Returns:
        None.
    """
    try:
        try:
            device = next(model.parameters()).device
        except StopIteration:
            # Parameter-free model: fall back to the input tensor's device,
            # or to the CPU when the input is not a single tensor.
            device = input_data.device if isinstance(input_data, torch.Tensor) else 'cpu'

        model_summary = summary(
            model=model,
            input_data=input_data,
            device=device,
            verbose=0,
            depth=depth,
            row_settings=["hide_recursive_layers"],
        )
        full_summary = f"{model_summary}\n\n{model!r}"

        output_path = Path(output_dir) / f"{filename}.txt"
        output_path.parent.mkdir(parents=True, exist_ok=True)
        # The summary table contains non-ASCII tree characters, so pin UTF-8
        # instead of relying on the platform default encoding (e.g. cp1252),
        # which would make the write fail and the report silently go missing.
        output_path.write_text(full_summary, encoding='utf-8')

    except Exception as e:
        # Deliberately broad: diagnostics must never abort the caller.
        print(f"  [Warning] Could not generate summary for {filename}: {e}")


def save_model_graph(model, input_data, output_dir, filename, graph_depth=5):
    """Render the model's computation graph to an SVG file.

    The graph is built with ``draw_graph`` and rendered below ``output_dir``;
    missing parent directories are created. This is a best-effort helper: any
    exception is reported with a printed warning and is not propagated to the
    caller.

    Args:
        model: Model to draw; the device of its first parameter is the device
            passed to ``draw_graph``.
        input_data: Example input passed through to ``draw_graph``.
        output_dir: Directory in which the image is written.
        filename: Image name, with or without a trailing ``.svg``. Other dots
            in the name (e.g. ``net_v1.5``) are preserved.
        graph_depth: Nesting depth passed through to ``draw_graph``.

    Returns:
        None.
    """
    try:
        try:
            device = next(model.parameters()).device
        except StopIteration:
            # Parameter-free model: fall back to the input tensor's device,
            # or to the CPU when the input is not a single tensor.
            device = input_data.device if isinstance(input_data, torch.Tensor) else 'cpu'

        model_graph = draw_graph(
            model=model,
            input_data=input_data,
            depth=graph_depth,
            device=device,
            expand_nested=True,
            save_graph=False,
        )

        output_path = Path(output_dir) / filename
        output_path.parent.mkdir(parents=True, exist_ok=True)
        # The renderer appends ".svg" itself, so drop only an explicit ".svg"
        # suffix. Stripping any suffix would truncate dotted names such as
        # "net_v1.5" to "net_v1" and let different models overwrite each other.
        if output_path.suffix.lower() == '.svg':
            render_target = output_path.with_suffix('')
        else:
            render_target = output_path
        model_graph.visual_graph.render(render_target, format='svg', cleanup=True)

    except Exception as e:
        # Deliberately broad: diagnostics must never abort the caller.
        print(f"  [Warning] Could not generate graph for {filename}: {e}")
