"""Analyze and visualize the VedEf problem graph structure.

This script introspects the problem graph to show:
- Nodes organized by hierarchy level
- Node types (parameters, models, constraints)
- Connections between levels
- Graph statistics
"""

import importlib.util
from collections import defaultdict, deque
from pathlib import Path

from paroto.viz import extract_hierarchy

# Import setup_vedef_problem from the setup file.
# The setup script lives next to this file and its name starts with a digit,
# so it cannot be named in a regular `import` statement; it is loaded from its
# path with importlib instead.
spec = importlib.util.spec_from_file_location(
    "problem_module", Path(__file__).parent / "1_setup_vedef_problem.py"
)
problem_module = importlib.util.module_from_spec(spec)
# Runs the setup script's top-level code as a side effect of importing this module.
spec.loader.exec_module(problem_module)
# Bind the setup function to a module-level name; main() calls it.
setup_vedef_problem = problem_module.setup_vedef_problem


def print_section(title, char="="):
    """Print a section banner: blank line, rule, title, rule.

    Args:
        title: Text shown between the two rules.
        char: String repeated 80 times to draw each rule.
    """
    rule = char * 80
    # The leading empty field combined with sep="\n" produces the blank line
    # that precedes the banner.
    print("", rule, title, rule, sep="\n")


def print_nodes_by_level(graph):
    """Print every node grouped by hierarchy level and, within a level, by kind.

    Levels come from ``graph.get_nodes_by_level()`` and are listed in sorted
    order. Inside a level each node is classified from its ``"type"`` entry in
    ``graph.nodes``: types containing ``"param"`` are parameters, the exact
    type ``"constraint"`` is a constraint, types containing ``"output"`` are
    models/outputs, and anything else is listed under "OTHER". Each non-empty
    category is printed sorted by node label.

    Args:
        graph: Object exposing ``get_nodes_by_level()`` and a ``nodes`` mapping
            of node id to a dict with ``"type"`` and ``"label"`` (plus
            ``"subsystem"`` for nodes classified as models/outputs).
    """
    print_section("NODES BY HIERARCHY LEVEL")

    rule = "-" * 80

    def label_of(entry):
        # Entries are (node_id, label, node_type) tuples.
        return entry[1]

    def with_type(_node_id, label, node_type):
        return f"{label:30s} [{node_type}]"

    def with_subsystem(node_id, label, _node_type):
        # The subsystem is only looked up for nodes that are actually printed
        # as models/outputs.
        subsystem = graph.nodes[node_id]["subsystem"]
        return f"{label:30s} (from: {subsystem})"

    def label_only(_node_id, label, _node_type):
        return f"{label:30s}"

    level_map = graph.get_nodes_by_level()
    for level in sorted(level_map):
        members = level_map[level]
        print(f"\n{rule}")
        print(f"LEVEL {level} ({len(members)} nodes)")
        print(rule)

        # Bucket this level's nodes by kind; the branch order decides which
        # category wins when a type string matches more than one test.
        parameters, outputs, constraints, remaining = [], [], [], []
        for node_id in members:
            info = graph.nodes[node_id]
            kind = info["type"]
            entry = (node_id, info["label"], kind)
            if "param" in kind:
                parameters.append(entry)
            elif kind == "constraint":
                constraints.append(entry)
            elif "output" in kind:
                outputs.append(entry)
            else:
                remaining.append(entry)

        # Fixed display order: heading, collected entries, line renderer.
        categories = (
            ("PARAMETERS", parameters, with_type),
            ("MODELS/OUTPUTS", outputs, with_subsystem),
            ("CONSTRAINTS", constraints, label_only),
            ("OTHER", remaining, with_type),
        )
        for heading, entries, render in categories:
            if not entries:
                continue
            print(f"\n  [{heading}] ({len(entries)}):")
            for entry in sorted(entries, key=label_of):
                print(f"     - {render(*entry)}")


def print_connections_by_level(graph):
    """Print graph edges grouped by (source level, target level).

    Edges whose source or target node has a negative ``"level"`` are skipped.
    Groups are printed in sorted order of their level pair; inside a group the
    (source label, target label) pairs are sorted and only the first 15 are
    listed, followed by a count of the ones left out.

    Args:
        graph: Object exposing ``edges`` (iterable of ``(source, target)``
            pairs) and a ``nodes`` mapping of node id to a dict with
            ``"level"`` and ``"label"``.
    """
    print_section("CONNECTIONS BY LEVEL")

    rule = "-" * 80
    max_shown = 15

    grouped = {}
    for src, dst in graph.edges:
        level_pair = (graph.nodes[src]["level"], graph.nodes[dst]["level"])
        # Negative levels mark unconnected nodes; such edges are not reported.
        if any(level < 0 for level in level_pair):
            continue
        label_pair = (graph.nodes[src]["label"], graph.nodes[dst]["label"])
        grouped.setdefault(level_pair, []).append(label_pair)

    for (src_lvl, tgt_lvl), pairs in sorted(grouped.items()):
        print(f"\n{rule}")
        print(f"Level {src_lvl} -> Level {tgt_lvl} ({len(pairs)} connections)")
        print(rule)

        ordered = sorted(pairs)
        for src_label, dst_label in ordered[:max_shown]:
            print(f"  {src_label:25s} -> {dst_label}")

        hidden = len(pairs) - max_shown
        if hidden > 0:
            print(f"  ... and {hidden} more connections")


def print_graph_statistics(graph):
    """Print summary counts for the graph.

    Reports the total number of nodes and edges, the number of hierarchy
    levels, how many nodes have a non-negative ``"level"`` (connected) versus
    the rest (unconnected), and per-type and per-level node counts.

    Args:
        graph: Object exposing ``nodes`` (mapping of node id to a dict with
            ``"level"``), ``edges`` (sized collection),
            ``get_nodes_by_type()`` and ``get_nodes_by_level()``.
    """
    print_section("GRAPH STATISTICS")

    node_total = len(graph.nodes)
    edge_total = len(graph.edges)

    type_groups = graph.get_nodes_by_type()
    level_groups = graph.get_nodes_by_level()
    level_total = len(level_groups)

    # A node counts as connected when its level is non-negative.
    connected = len([info for info in graph.nodes.values() if info["level"] >= 0])
    unconnected = node_total - connected

    print(
        f"\n  Total Nodes:       {node_total:>6d}",
        f"  Total Edges:       {edge_total:>6d}",
        f"  Hierarchy Levels:  {level_total:>6d}",
        f"  Connected Nodes:   {connected:>6d}",
        f"  Unconnected Nodes: {unconnected:>6d}",
        sep="\n",
    )

    print("\n  Nodes by Type:")
    for node_type in sorted(type_groups):
        count = len(type_groups[node_type])
        print(f"    {node_type:20s}: {count:>4d}")

    print("\n  Nodes per Level:")
    for level in sorted(level_groups):
        count = len(level_groups[level])
        print(f"    Level {level}: {count:>4d} nodes")


def print_primary_parameter_propagation(graph):
    """Show how each primary parameter propagates through levels.

    For every entry of ``graph.primary_params`` that is also a key of
    ``graph.nodes``, follows ``graph.edges`` forward (source -> target)
    breadth-first and prints the labels of all reached nodes grouped by their
    ``"level"`` value. At most five labels are listed per level; any remainder
    is summarized as a count.

    Args:
        graph: Object exposing ``primary_params`` (iterable of node ids),
            ``nodes`` (mapping of node id to a dict with ``"label"`` and
            ``"level"``) and ``edges`` (iterable of ``(source, target)``
            pairs).
    """
    print_section("PRIMARY PARAMETER PROPAGATION", char="-")

    max_shown = 5

    # Index outgoing edges once so the traversal does not rescan the full edge
    # list for every visited node. Targets keep the order in which their edges
    # appear in graph.edges, so the visiting order is the same as a per-node
    # scan of the edge list.
    successors = defaultdict(list)
    for source, target in graph.edges:
        successors[source].append(target)

    for param in sorted(graph.primary_params):
        if param not in graph.nodes:
            continue

        param_label = graph.nodes[param]["label"]
        print(f"\n  [{param_label}]:")

        # Breadth-first walk over everything reachable from the parameter
        # (directly and indirectly). Nodes are marked when first enqueued,
        # which yields the same first-visit order as marking on dequeue while
        # keeping duplicates out of the queue.
        seen = {param}
        queue = deque([param])
        levels_reached = defaultdict(list)

        while queue:
            current = queue.popleft()
            info = graph.nodes[current]

            bucket = levels_reached[info["level"]]
            if current != param:  # Don't include the parameter itself
                bucket.append(info["label"])

            # .get() avoids inserting empty entries into the shared index.
            for target in successors.get(current, ()):
                if target not in seen:
                    seen.add(target)
                    queue.append(target)

        # Print propagation, one line per level that reached at least one node.
        for level in sorted(levels_reached):
            labels = levels_reached[level]
            if not labels:
                continue
            line = f"     Level {level}: {', '.join(labels[:max_shown])}"
            hidden = len(labels) - max_shown
            if hidden > 0:
                line += f" ... (+{hidden} more)"
            print(line)


def main():
    """Build the VedEf problem, extract its hierarchy graph and print all reports."""
    banner = "=" * 80
    print(banner, "VedEf Problem Graph Analysis", banner, sep="\n")

    # Setup problem
    print("\n>> Setting up VedEf problem...")
    prob = setup_vedef_problem()
    print(">> Problem configured")

    # Extract hierarchy, rooted at the primary design parameters
    print(">> Extracting parameter hierarchy...")
    primary_params = [
        "G_UMAX_OUT",  # Generator voltage (High voltage)
        "G_e",  # Interelectrode gap distance
        "TP_D_OUT",  # Torch outlet diameter
        "G_F",  # PRF - Pulse Frequency
        "G_Ep",  # Energy per pulse
    ]
    graph = extract_hierarchy(prob, primary_params)
    print(">> Hierarchy extracted")

    # Print analyses, in presentation order
    reports = (
        print_graph_statistics,
        print_nodes_by_level,
        print_connections_by_level,
        print_primary_parameter_propagation,
    )
    for report in reports:
        report(graph)

    print_section("ANALYSIS COMPLETE", char="=")
    print("\n>> Graph structure successfully analyzed!")
    print()


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()