Analysis script: Pairwise grid visualization of parameter space.

This script loads pre-computed parameter sweep results and creates a pairwise grid (corner plot) showing:

- All 2D projections of parameter combinations
- 1D marginal distributions (histograms)
- Three-way classification: Feasible, Converged-Infeasible, Non-converged

The pairwise grid eliminates projection ambiguity by showing each dimension pair separately with proper marginal distributions.

Usage:

python analyze_vdef_pairwise.py

Requirements:

  • parameter_sweep_cases.sql (generated by example_vdef_mdao.py)

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec

from paroto.utils.io_utils import load_results_from_sqlite


def load_results(filename="parameter_sweep_cases.sql"):
    """Read parameter sweep results from an SQLite case database.

    Parameters
    ----------
    filename : str
        Location of the SQLite database file.

    Returns
    -------
    dict
        Whatever ``load_results_from_sqlite`` returns; its ``"frequency"``
        entry is used to report how many evaluations were loaded.
    """
    print(f"Loading results from SQLite: {filename}")
    data = load_results_from_sqlite(filename)
    count = len(data["frequency"])
    print(f"  Successfully loaded {count} evaluations")
    return data


def _classify_points(results, tolerance):
    """Split sweep points into the three plotted categories.

    Returns boolean masks ``(feasible, converged_infeasible, non_converged)``.
    A point is feasible when it converged, ``power_error < tolerance`` and
    ``breakdown_margin > 0``. Constraint values are coerced to float so that
    ``None`` entries become NaN, which fails both comparisons (the point then
    counts as infeasible) instead of raising ``TypeError``.
    """
    converged = np.asarray(results["converged"], dtype=bool)
    power_ok = np.asarray(results["power_error"], dtype=float) < tolerance
    breakdown_ok = np.asarray(results["breakdown_margin"], dtype=float) > 0
    constraints_met = power_ok & breakdown_ok
    return converged & constraints_met, converged & ~constraints_met, ~converged


def _finite_range(data):
    """Return ``(min, max)`` over the finite entries of *data*, or ``None`` if none are finite."""
    finite = data[np.isfinite(data)]
    if finite.size == 0:
        return None
    return float(finite.min()), float(finite.max())


def _plot_marginal(ax, data, masks, include_non_converged, bins, show_legend):
    """Draw overlaid per-category histograms of *data* on a diagonal cell.

    All categories share one set of bin edges computed from the finite values
    of *data*. The overlaid bars therefore line up and are directly comparable.
    Drawing is skipped entirely when *data* has no finite values.
    """
    feasible, conv_infeasible, non_converged = masks
    finite = data[np.isfinite(data)]
    if finite.size == 0:
        return
    edges = np.histogram_bin_edges(finite, bins=bins)

    # Back-to-front draw order: non-converged, converged-infeasible, feasible.
    layers = [
        (
            non_converged if include_non_converged else None,
            dict(color="gray", alpha=0.4, edgecolor="dimgray", linewidth=1, label="Non-conv."),
        ),
        (
            conv_infeasible,
            dict(color="orange", alpha=0.5, edgecolor="red", linewidth=1, label="Conv-Inf."),
        ),
        (
            feasible,
            dict(color="lightblue", alpha=0.7, edgecolor="green", linewidth=1.5, label="Feasible"),
        ),
    ]
    for mask, style in layers:
        if mask is None:
            continue
        values = data[mask]
        values = values[np.isfinite(values)]  # hist cannot bin NaN/inf
        if values.size:
            ax.hist(values, bins=edges, **style)

    ax.set_xlim(float(finite.min()), float(finite.max()))
    # Only add a legend when something labelled was actually drawn.
    if show_legend and ax.get_legend_handles_labels()[0]:
        ax.legend(loc="upper right", fontsize=8, framealpha=0.9)


def _plot_scatter(ax, x, y, masks, include_non_converged, energy, energy_norm):
    """Draw one lower-triangle cell: *y* against *x*, coloured by energy.

    Non-converged points are grey crosses. Converged-infeasible points are
    red-edged triangles and feasible points are green-edged circles; both are
    coloured by *energy* using the shared ``energy_norm`` (``vmin``/``vmax``).
    """
    feasible, conv_infeasible, non_converged = masks

    if include_non_converged and non_converged.any():
        ax.scatter(
            x[non_converged],
            y[non_converged],
            c="dimgray",
            marker="x",
            s=30,
            alpha=0.5,
            linewidth=1.5,
        )

    if conv_infeasible.any():
        ax.scatter(
            x[conv_infeasible],
            y[conv_infeasible],
            c=energy[conv_infeasible],
            cmap="viridis",
            marker="^",
            s=35,
            alpha=0.6,
            edgecolors="red",
            linewidth=0.8,
            **energy_norm,
        )

    if feasible.any():
        ax.scatter(
            x[feasible],
            y[feasible],
            c=energy[feasible],
            cmap="viridis",
            marker="o",
            s=40,
            alpha=0.8,
            edgecolors="green",
            linewidth=0.8,
            **energy_norm,
        )

    # Limits from finite values only; NaN/inf would make set_xlim/set_ylim raise.
    x_range = _finite_range(x)
    if x_range is not None:
        ax.set_xlim(*x_range)
    y_range = _finite_range(y)
    if y_range is not None:
        ax.set_ylim(*y_range)
    ax.grid(True, alpha=0.2, linewidth=0.5)


def _label_cell(ax, i, j, variables):
    """Put axis labels on the outer edges of the grid and hide inner tick labels."""
    last = len(variables) - 1
    if i == last:
        x_name, x_unit = variables[j][:2]
        ax.set_xlabel(f"{x_name} ({x_unit})", fontsize=10)
    else:
        ax.tick_params(labelbottom=False)

    if j == 0 and i > 0:
        y_name, y_unit = variables[i][:2]
        ax.set_ylabel(f"{y_name} ({y_unit})", fontsize=10)
    elif i != j:
        # Diagonal cells already have no y ticks (histogram counts are hidden).
        ax.tick_params(labelleft=False)

    ax.tick_params(labelsize=8)


def create_pairwise_grid(
    results,
    tolerance=0.10,
    output_path="pairwise_grid_visualization.png",
    show=True,
):
    """Create pairwise grid visualization with marginal distributions.

    Lower-triangle cells show every 2D projection of frequency, HV voltage,
    gap and energy per pulse. Diagonal cells show 1D marginal histograms, and
    upper-triangle cells are hidden. Points are split into Feasible,
    Converged-Infeasible and Non-converged. Converged points are coloured by
    energy per pulse on a single colour scale shared by all cells.

    Parameters
    ----------
    results : dict
        Results dictionary from the parameter sweep. It is read for the keys
        ``frequency``, ``hv_voltage``, ``gap``, ``energy_per_pulse``,
        ``converged``, ``power_error`` and ``breakdown_margin``.
    tolerance : float, optional
        A converged point is power-feasible when ``power_error < tolerance``.
        Default 0.10.
    output_path : str, optional
        Path of the PNG file the figure is saved to.
    show : bool, optional
        If True (default), call ``plt.show()`` after saving. Otherwise the
        figure is closed to release its memory.
    """
    print("\n" + "=" * 70)
    print("  Creating Pairwise Grid Visualization")
    print("=" * 70)

    masks = _classify_points(results, tolerance)
    feasible_mask, converged_infeasible_mask, non_converged_mask = masks

    n_total = len(results["frequency"])
    n_feasible = int(np.count_nonzero(feasible_mask))
    # Guard against an empty sweep instead of dividing by zero.
    pct_feasible = 100.0 * n_feasible / n_total if n_total else 0.0

    print("\nData summary:")
    print(f"  Total points: {n_total}")
    print(f"  Feasible: {n_feasible} ({pct_feasible:.1f}%)")
    print(f"  Converged-Infeasible: {int(np.count_nonzero(converged_infeasible_mask))}")
    print(f"  Non-converged: {int(np.count_nonzero(non_converged_mask))}")

    if n_total == 0:
        print("\nNo evaluations to plot; skipping figure.")
        return

    # (name, unit, values in display units, draw non-converged points?)
    # Unit conversions follow the labels: Hz -> kHz, V -> kV, m -> mm.
    # Non-converged points are not drawn for energy (same as before);
    # presumably their energy is not meaningful -- confirm against the sweep.
    variables = [
        ("Frequency", "kHz", np.asarray(results["frequency"], dtype=float) / 1000, True),
        ("HV Voltage", "kV", np.asarray(results["hv_voltage"], dtype=float) / 1000, True),
        ("Gap", "mm", np.asarray(results["gap"], dtype=float) * 1000, True),
        ("Energy/pulse", "J", np.asarray(results["energy_per_pulse"], dtype=float), False),
    ]
    n_vars = len(variables)
    energy = variables[-1][2]

    # One colour scale for every scatter call. Without vmin/vmax, each call
    # would autoscale to its own subset, so equal colours would not mean
    # equal energy across cells or categories.
    energy_range = _finite_range(energy[~non_converged_mask])
    energy_norm = (
        {} if energy_range is None else {"vmin": energy_range[0], "vmax": energy_range[1]}
    )

    fig = plt.figure(figsize=(16, 14))
    fig.suptitle(
        f"Pairwise Grid: Parameter Space Exploration\n"
        f"Feasible: {n_feasible}/{n_total} ({pct_feasible:.1f}%)",
        fontsize=16,
        fontweight="bold",
        y=0.995,
    )
    gs = GridSpec(n_vars, n_vars, figure=fig, hspace=0.05, wspace=0.05)

    for i, (_, _, y_data, y_non_conv) in enumerate(variables):
        for j, (x_name, x_unit, x_data, x_non_conv) in enumerate(variables):
            ax = fig.add_subplot(gs[i, j])

            if i == j:
                # Diagonal: 1D marginal distribution (histogram)
                _plot_marginal(ax, x_data, masks, x_non_conv, bins=20, show_legend=(i == 0))
                ax.set_yticks([])
                ax.text(
                    0.5,
                    0.5,
                    f"{x_name}\n({x_unit})",
                    transform=ax.transAxes,
                    ha="center",
                    va="center",
                    fontsize=11,
                    fontweight="bold",
                    bbox=dict(boxstyle="round,pad=0.5", facecolor="white", alpha=0.8),
                )
            elif i > j:
                # Lower triangle: 2D scatter plots
                _plot_scatter(
                    ax,
                    x_data,
                    y_data,
                    masks,
                    x_non_conv and y_non_conv,
                    energy,
                    energy_norm,
                )
            else:
                # Upper triangle: hide
                ax.axis("off")

            _label_cell(ax, i, j, variables)

    fig.tight_layout(rect=[0, 0, 1, 0.99])

    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    print(f"\nPairwise grid saved to: {output_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)

    print("\n" + "=" * 70)
    print("  Pairwise Grid Analysis Complete!")
    print("=" * 70)
    print("\nKey features:")
    print("  - Diagonal: 1D marginal distributions for each parameter")
    print("  - Lower triangle: All 2D parameter combinations")
    print("  - Color-coded by feasibility (Green=Feasible, Red=Infeasible, Gray=Non-conv)")
    print("  - Point color intensity shows Energy per pulse")
    print("\nBenefits:")
    print("  - No projection ambiguity - each 2D slice is shown separately")
    print("  - Marginal distributions reveal parameter preferences")
    print("  - Easy to spot correlations and constraint boundaries")


def main():
    """Run the analysis: read the stored sweep, then draw the pairwise grid."""
    sweep = load_results()  # defaults to parameter_sweep_cases.sql
    create_pairwise_grid(sweep)


if __name__ == "__main__":
    import sys
    import traceback

    try:
        main()
    except FileNotFoundError as e:
        # Missing input: report it without a traceback.
        print(f"\nError: {e}")
        print("If the sweep database is missing, generate it with example_vdef_mdao.py.")
        # Non-zero exit so shells/CI can detect the failure.
        sys.exit(1)
    except Exception as e:
        print(f"\nError: {e}")
        traceback.print_exc()
        sys.exit(1)