Example: Direct MDAO problem setup for plasma torch parameter space exploration.

Warning

Equations are AI generated, not ready for production.

This example demonstrates how to:

  1. Set up the problem directly in Python (no STONE config)

  2. Use RepetitivePulseBreakdownModel with:

     • Pulse duration effects (ionization avalanche kinetics: e + M → e + e + M+)

     • Memory effects from repeated pulsing

     • Field enhancement factors

  3. Explore the feasible parameter space with power and breakdown constraints

  4. Visualize the solution space

  5. Run parameter sweeps in parallel, using Python multiprocessing (default) or OpenMDAO's DOEDriver under MPI

Design variables (swept, not optimized):

  • Frequency: 5 - 200 kHz

  • HV voltage (breakdown initiation): 5 - 80 kV

  • Gap distance: 6 - 30 mm
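The sweep covers these three variables on a full-factorial grid. The sketch below rebuilds that grid with the same ranges and level counts as run_parameter_sweep() (10 frequencies, 15 HV voltages, 8 gaps) so the number of model evaluations and the grid spacing can be checked before a long run; it does not evaluate the model.

```python
from itertools import product

import numpy as np

# Same ranges and level counts as run_parameter_sweep()
frequency_range = np.linspace(5e3, 200e3, 10)  # Hz
hv_voltage_range = np.linspace(5e3, 80e3, 15)  # V
gap_range = np.linspace(6e-3, 30e-3, 8)  # m

# Full-factorial design: every combination of the three level sets
param_cases = list(product(frequency_range, hv_voltage_range, gap_range))

# Spacing between neighbouring levels, in the units printed by the example
freq_step_khz = (frequency_range[1] - frequency_range[0]) / 1e3
volt_step_kv = (hv_voltage_range[1] - hv_voltage_range[0]) / 1e3
gap_step_mm = (gap_range[1] - gap_range[0]) * 1e3

print(f"Total evaluations: {len(param_cases)}")
print(f"Level spacing: {freq_step_khz:.2f} kHz, {volt_step_kv:.2f} kV, {gap_step_mm:.2f} mm")
```

With these settings the grid has 1200 points, spaced about 21.7 kHz, 5.4 kV and 3.4 mm apart, which is worth keeping in mind when reading feasibility boundaries off the plots.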

Constraints:

  • Total power deposited = 25 kW

  • HV voltage ≥ breakdown voltage (pulse-duration and frequency dependent)

  • Generator HV operating window: f · V² · t_pulse / Z ≤ P_max (Z = 50 Ω, t_pulse = 20 ns, P_max = 1 kW)
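The operating window can be evaluated by hand to see which part of the (f, V) plane it leaves open. The sketch assumes the window is the time-averaged power of rectangular HV pulses into a matched load, P_avg = f · V² · t_pulse / Z, which has units of watts; Z = 50 Ω, t_pulse = 20 ns and P_max = 1 kW are the values listed above. The helper names are illustrative only and are not part of paroto.

```python
import math

# Values from the constraint list; the rectangular-pulse power model is an assumption
Z_OHM = 50.0  # generator / load impedance (ohm)
T_PULSE_S = 20e-9  # HV pulse width used in the window (s)
P_MAX_W = 1e3  # maximum average HV power (W)


def hv_window_power(freq_hz: float, hv_voltage_v: float) -> float:
    """Average HV power (W): pulses per second times V^2/Z * t_pulse per pulse."""
    energy_per_pulse_j = hv_voltage_v**2 / Z_OHM * T_PULSE_S
    return freq_hz * energy_per_pulse_j


def max_hv_voltage(freq_hz: float) -> float:
    """Largest HV amplitude (V) for which hv_window_power() stays <= P_MAX_W."""
    return math.sqrt(P_MAX_W * Z_OHM / (freq_hz * T_PULSE_S))


for freq in (5e3, 50e3, 200e3):
    v_max_kv = max_hv_voltage(freq) / 1e3
    print(f"{freq / 1e3:5.0f} kHz -> V_HV <= {v_max_kv:5.1f} kV")
```

Under this assumption the window admits about 22 kV at 5 kHz but only about 3.5 kV at 200 kHz, below the 5 kV lower bound of the sweep, so the high-frequency corner would be excluded by this constraint alone.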

Fixed parameters:

  • Mass flow = 160 kg CH4/d

  • HV pulse duration = 1 µs (affects breakdown voltage)

  • Sustainer voltage = 300 V (independent of HV voltage)

  • Sustainer pulse duration = computed to match the 25 kW target power
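In the model the sustainer pulse duration is found iteratively by a BalanceComp, because the arc current comes from the TorchDesignGroup. If the arc voltage and current are simply held fixed, the balance V_arc · I_arc · t_sus = P_target / f has a closed-form solution, which is handy for sanity-checking solver results. The sketch below does that, taking the rough 300 V / 4500 A operating point quoted in the code comments as an assumption, and also converts the fixed mass flow to SI units.

```python
# Fixed inputs from this example
MASS_FLOW_KG_PER_DAY = 160.0
TARGET_POWER_W = 25e3

# Assumed arc operating point (rough values from the code comments)
ARC_VOLTAGE_V = 300.0
ARC_CURRENT_A = 4500.0

mass_flow_kg_s = MASS_FLOW_KG_PER_DAY / 86400.0  # kg/d -> kg/s


def sustainer_pulse_duration(freq_hz: float,
                             arc_voltage_v: float = ARC_VOLTAGE_V,
                             arc_current_a: float = ARC_CURRENT_A,
                             target_power_w: float = TARGET_POWER_W) -> float:
    """Pulse length t (s) solving arc_voltage * arc_current * t = target_power / freq."""
    required_energy_j = target_power_w / freq_hz
    return required_energy_j / (arc_voltage_v * arc_current_a)


print(f"Mass flow: {mass_flow_kg_s:.6f} kg/s")
for freq in (5e3, 200e3):
    t_sus_us = sustainer_pulse_duration(freq) * 1e6
    print(f"{freq / 1e3:5.0f} kHz: E = {TARGET_POWER_W / freq:.3f} J, t_sus = {t_sus_us:.3f} µs")
```

The results (about 3.7 µs at 5 kHz and 0.093 µs at 200 kHz) bracket the range the BalanceComp has to cover; both lie inside its 10 ns to 1 ms bounds.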

Outputs:

  • Energy per pulse (calculated)

  • Breakdown voltage (pulse-duration and frequency dependent)

  • Feasible parameter space visualization
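The feasible parameter space is not a separate model output; it is derived from the recorded outputs. The sketch below mirrors the three-way classification in visualize_parameter_space() (feasible, converged but infeasible, non-converged) as a small standalone function. The four sample points are made-up numbers for illustration, and the sketch assumes, as the example does, that a positive breakdown margin means the HV voltage exceeds the breakdown voltage.

```python
import numpy as np


def classify_points(converged, power_error, breakdown_margin, hv_window_ok, tolerance=0.10):
    """Split sweep points into feasible / converged-infeasible / non-converged masks."""
    converged = np.asarray(converged, dtype=bool)
    # Comparisons with NaN (non-converged cases) are False, so NaN never counts as feasible
    constraints_met = (
        (np.asarray(power_error) < tolerance)
        & (np.asarray(breakdown_margin) > 0)
        & (np.asarray(hv_window_ok) >= 1.0)
    )
    feasible = converged & constraints_met
    converged_infeasible = converged & ~constraints_met
    non_converged = ~converged
    return feasible, converged_infeasible, non_converged


# Illustrative values: point 1 is feasible, point 2 misses the power target,
# point 3 fails the breakdown constraint, point 4 did not converge
feasible, infeasible, failed = classify_points(
    converged=[True, True, True, False],
    power_error=[0.02, 0.25, 0.05, np.nan],
    breakdown_margin=[1500.0, 800.0, -200.0, np.nan],
    hv_window_ok=[1.0, 1.0, 1.0, np.nan],
)
print("feasible:", feasible)
print("converged but infeasible:", infeasible)
print("non-converged:", failed)
```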

Physics:

  • Short pulse durations require a higher voltage for the ionization avalanche to develop

  • High pulse rates (many pulses per molecule) reduce the breakdown voltage via memory effects

  • Breakdown transitions from single-pulse to DC-like behavior as frequency increases
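The source of RepetitivePulseBreakdownModel is not shown here. As a point of reference only, the sketch below evaluates a common empirical uniform-field DC breakdown fit for air, V_b ≈ A·(pd) + B·√(pd) kV with p in bar and d in cm, reusing A_paschen = 24.4 and B_paschen = 6.73 from the model configuration. Reading those two parameters this way is an assumption; the sketch ignores pulse-duration, memory and field-enhancement effects, and the fit is for air rather than methane.

```python
import math

# Assumed interpretation of the model parameters A_paschen and B_paschen
# as coefficients of an empirical uniform-field fit for air:
#   V_b [kV] = A * (p d) + B * sqrt(p d),  p in bar, d in cm
A_KV = 24.4
B_KV = 6.73


def dc_breakdown_kv(pressure_bar: float, gap_cm: float) -> float:
    """Static (DC) breakdown voltage in kV; no pulse, memory or enhancement effects."""
    pd = pressure_bar * gap_cm
    return A_KV * pd + B_KV * math.sqrt(pd)


P_BAR = 101325.0 / 1e5  # the example's 101325 Pa expressed in bar
for gap_mm in (6.0, 18.0, 30.0):
    v_b = dc_breakdown_kv(P_BAR, gap_mm / 10.0)
    print(f"gap {gap_mm:4.1f} mm -> V_b ~ {v_b:5.1f} kV")
```

Taken at face value, this baseline gives roughly 20 kV at 6 mm and 86 kV at 30 mm, so the widest gaps would exceed the 80 kV sweep limit before any pulse-duration or memory correction; the actual model output may differ.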

Usage:

Standard execution (uses Python multiprocessing by default):

python example_vdef_mdao.py

MPI-based parallel execution (Linux/Unix only, requires mpi4py and petsc4py):

mpiexec -n $(( $(nproc) - 1 )) python example_vdef_mdao.py

Note:

Parallel execution uses Python's multiprocessing module by default, which also works on Windows; MPI with PETSc is not easily available on Windows.

import numpy as np
import openmdao.api as om

from paroto.mdao.group import TorchDesignGroup
from paroto.models.ablation import SimpleAblationModel
from paroto.models.breakdown_voltage import RepetitivePulseBreakdownModel
from paroto.models.gas_properties import GasHeatCapacityModel
from paroto.utils.io_utils import write_results_to_sqlite


def setup_problem_with_power_constraint():
    r"""Set up OpenMDAO problem with 25 kW power constraint.

    .. warning::
       Equations are AI generated, not ready for production.

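    Uses RepetitivePulseBreakdownModel, which accounts for:

    - Pulse duration effects (ionization avalanche time)
    - Memory effects from repeated pulsing
    - Field enhancement factors

    The problem includes the following key equations:

    Required energy per pulse (from the target power):

    .. math::

       E_{req} = \frac{P_{target}}{f_{pulse}}

    Delivered energy per sustainer pulse; the balance component adjusts
    :math:`t_{sus}` until :math:`E_{pulse} = E_{req}`:

    .. math::

       E_{pulse} = V_{arc} I_{arc} t_{sus}

    Power constraint error:

    .. math::

       \epsilon_{power} = \frac{P_{gas} - P_{target}}{P_{target}}

    where:

    - :math:`E_{req}` is the energy required per pulse (J)
    - :math:`E_{pulse}` is the energy delivered per sustainer pulse (J)
    - :math:`V_{arc}`, :math:`I_{arc}` are the arc voltage (V) and arc current (A)
    - :math:`t_{sus}` is the sustainer pulse duration (s)
    - :math:`P_{gas}` is the power transferred to gas (W)
    - :math:`f_{pulse}` is the pulse frequency (Hz)
    - :math:`P_{target}` = 25,000 W is the target power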

    Returns
    -------
    prob : om.Problem
        Configured problem ready to run
    """
    # Create problem with torch design group
    prob = om.Problem()

    # Configure models
    breakdown_model = RepetitivePulseBreakdownModel(
        model_parameters={
            "A_paschen": 24.4,
            "B_paschen": 6.73,
            "field_enhancement_factor": 1.0,
            "ionization_time": 1e-7,
            "pulse_amplitude_factor": 4.0,
            "decay_constant": 1.0,
        }
    )
    ablation_model = SimpleAblationModel(
        model_parameters={"threshold_fraction": 0.8, "ablation_coefficient": 1.0e-5}
    )
    gas_heat_capacity = GasHeatCapacityModel(
        model_parameters={"cp_coeff_a": 1700.0, "cp_coeff_b": 9.08, "cp_coeff_c": -0.002}
    )

    # Create torch design group with models
    prob.model = TorchDesignGroup(
        max_current_density=1.0e6,
        breakdown_model=breakdown_model,
        ablation_model=ablation_model,
        gas_density=0.717,
        gas_viscosity=1.1e-5,
        gas_heat_capacity=gas_heat_capacity,
    )

    # Add required energy per pulse calculation (from target power and frequency)
    prob.model.add_subsystem(
        "required_energy",
        om.ExecComp(
            "required_energy_per_pulse = target_power / (pulse_frequency + 1e-10)",
            required_energy_per_pulse={"units": "J"},
            target_power={"units": "W", "val": 25000.0},
            pulse_frequency={"units": "Hz"},
        ),
        promotes=["pulse_frequency"],
    )

    # Add actual energy per pulse calculation
    prob.model.add_subsystem(
        "actual_energy",
        om.ExecComp(
            "energy_per_pulse = arc_voltage * arc_current * sustainer_pulse_duration",
            energy_per_pulse={"units": "J"},
            arc_voltage={"units": "V"},
            arc_current={"units": "A"},
            sustainer_pulse_duration={"units": "s"},
        ),
        promotes=["arc_voltage", "arc_current"],
    )

    # Use BalanceComp to iteratively solve for sustainer_pulse_duration
    # Balance: actual_energy = required_energy by adjusting sustainer_pulse_duration
    # Expected range, assuming I ≈ 4500 A and V ≈ 300 V (V*I ≈ 1.35 MW):
    # At 5 kHz:   E_pulse = 5 J     → t ≈ 3.7 µs
    # At 200 kHz: E_pulse = 0.125 J → t ≈ 0.093 µs
    # The 10 µs initial guess lies above this range, on the long-pulse side
    balance = prob.model.add_subsystem("balance", om.BalanceComp())
    balance.add_balance(
        "sustainer_pulse_duration",
        units="s",
        eq_units="J",
        lhs_name="energy_per_pulse",
        rhs_name="required_energy_per_pulse",
        val=10e-6,  # Initial guess: 10 µs
        lower=0.01e-6,  # 0.01 µs (10 ns) minimum - allow very short pulses at high freq
        upper=1e-3,  # 1 ms maximum
        ref=10e-6,  # Scaling: this value maps to 1 in scaled space
        ref0=1e-6,  # Scaling: this value maps to 0 in scaled space
    )

    # Connect balance output to system
    prob.model.connect("balance.sustainer_pulse_duration", "sustainer_pulse_duration")
    prob.model.connect("balance.sustainer_pulse_duration", "actual_energy.sustainer_pulse_duration")

    # Connect energies to balance
    prob.model.connect("actual_energy.energy_per_pulse", "balance.energy_per_pulse")
    prob.model.connect(
        "required_energy.required_energy_per_pulse", "balance.required_energy_per_pulse"
    )

    # Add power check (should equal target power)
    prob.model.add_subsystem(
        "power_check",
        om.ExecComp(
            "power_error = (power_to_gas - target_power) / target_power",
            power_error={"val": 0.0},
            power_to_gas={"units": "W"},
            target_power={"units": "W", "val": 25000.0},
        ),
        promotes=["power_to_gas"],
    )

    # Set input defaults to resolve promotion ambiguities (after all subsystems added)
    prob.model.set_input_defaults("pulse_frequency", val=1000.0, units="Hz")
    prob.model.set_input_defaults("pulse_duration", val=1e-6, units="s")
    prob.model.set_input_defaults("hv_voltage", val=10000.0, units="V")
    prob.model.set_input_defaults("sustainer_voltage", val=300.0, units="V")
    prob.model.set_input_defaults("gap_distance", val=0.01, units="m")
    # sustainer_pulse_duration is set by the balance component

    # Add nonlinear solver to solve for sustainer_pulse_duration via balance
    # Newton solver using the components' declared partial derivatives
    prob.model.nonlinear_solver = om.NewtonSolver(maxiter=20, solve_subsystems=True)
    prob.model.nonlinear_solver.options["err_on_non_converge"] = False
    prob.model.nonlinear_solver.options["atol"] = 1e-6
    prob.model.nonlinear_solver.options["rtol"] = 1e-6
    prob.model.nonlinear_solver.linesearch = om.BoundsEnforceLS()
    prob.model.linear_solver = om.DirectSolver()

    return prob


def evaluate_single_case(case_params):
    """Evaluate a single parameter case.

    Parameters
    ----------
    case_params : tuple
        (frequency, hv_voltage, gap_distance) parameter values

    Returns
    -------
    dict or None
        Results dictionary if successful, None if failed
    """
    freq, hv_volt, gap = case_params

    try:
        # Create a fresh problem for this case
        prob = setup_problem_with_power_constraint()
        prob.setup()

        # Set fixed parameters
        mass_flow = 160 / 86400  # kg/d to kg/s
        prob.set_val("pulse_duration", 1e-6)
        prob.set_val("sustainer_voltage", 300.0)
        prob.set_val("mass_flow", mass_flow)
        prob.set_val("gas_properties_pressure", 101325.0)
        prob.set_val("preheat_temperature", 300.0)
        prob.set_val("electrode_radius", 0.005)
        prob.set_val("graphite_sublimation_temp", 3900.0)
        prob.set_val("graphite_heat_capacity", 710.0)
        prob.set_val("torch_length", 0.1)
        prob.set_val("torch_diameter", 0.02)

        # Set design variables for this case
        prob.set_val("pulse_frequency", freq)
        prob.set_val("hv_voltage", hv_volt)
        prob.set_val("gap_distance", gap)

        # Run the model
        prob.run_model()

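        # Treat reaching maxiter as non-convergence (relies on the solver's private
        # _iter_count attribute, which may change between OpenMDAO versions)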
        solver = prob.model.nonlinear_solver
        if hasattr(solver, "_iter_count") and hasattr(solver, "options"):
            if solver._iter_count >= solver.options["maxiter"]:
                # Solver reached max iterations without converging
                return None

        # Extract results and check for NaN/invalid values
        results = {
            "frequency": freq,
            "hv_voltage": hv_volt,
            "gap": gap,
            "energy_per_pulse": prob.get_val("actual_energy.energy_per_pulse")[0],
            "power_to_gas": prob.get_val("power_to_gas")[0],
            "power_error": abs(prob.get_val("power_check.power_error")[0]),
            "breakdown_voltage": prob.get_val("breakdown_voltage")[0],
            "breakdown_margin": prob.get_val("breakdown_margin")[0],
            "arc_current": prob.get_val("arc_current")[0],
            "sustainer_pulse_duration": prob.get_val("balance.sustainer_pulse_duration")[0],
            "operating_window_power": prob.get_val("operating_window_power")[0],
            "hv_operating_window_satisfied": prob.get_val("hv_operating_window_satisfied")[0],
        }

        # Check for NaN or infinite values (sign of convergence failure)
        for key, value in results.items():
            if key in ["frequency", "hv_voltage", "gap"]:
                continue  # Skip input parameters
            if not np.isfinite(value):
                return None

        return results
    except Exception:
        # Return None for failed cases (exceptions, NaN, etc.)
        return None


def run_parameter_sweep(use_multiprocessing=True, n_procs=None):
    r"""Run a full-factorial parameter space exploration.

    .. warning::
       Equations are AI generated, not ready for production.

    Evaluates the model over a grid of parameters, either with Python
    multiprocessing (one fresh problem per case) or with OpenMDAO's
    DOEDriver running in parallel under MPI:

    .. math::

       \mathcal{P} = \{(f, V_{HV}, d) : f \in [5, 200]\text{ kHz},
                      V_{HV} \in [5, 80]\text{ kV},
                      d \in [6, 30]\text{ mm}\}

    where:

    - :math:`f` is the pulse frequency
    - :math:`V_{HV}` is the HV voltage (breakdown initiation)
    - :math:`d` is the gap distance

    Feasibility criteria (applied in ``visualize_parameter_space``):

    .. math::

       |\epsilon_{power}| < 0.10

    together with a positive breakdown margin and a satisfied HV operating
    window (``hv_operating_window_satisfied`` >= 1).

    Parameters
    ----------
    use_multiprocessing : bool, optional
        Use Python multiprocessing for parallel execution (default: True).
        If False, uses MPI (requires mpiexec and petsc4py).
    n_procs : int, optional
        Number of worker processes (multiprocessing mode only). If None, uses
        half of the available cores (at least 1).

    Returns
    -------
    results : dict
        Dictionary containing sweep results
    """
    import os

    print("\n" + "=" * 70)
    print("  VDEF MDAO Example - Parameter Space Exploration")
    if use_multiprocessing:
        print("  (Parallel execution via multiprocessing)")
    else:
        print("  (Parallel execution via MPI)")
    print("=" * 70)

    # Convert mass flow rate
    mass_flow = 160 / 86400  # kg/d to kg/s

    # Define parameter ranges (adjusted for 25 kW target)
    frequency_range = np.linspace(5000, 200000, 10)  # Hz (5-200 kHz)
    hv_voltage_range = np.linspace(5000, 80000, 15)  # V (5-80 kV HV voltage)
    gap_range = np.linspace(0.006, 0.030, 8)  # m (6-30 mm)

    # Fixed parameters
    sustainer_voltage_fixed = 300.0  # V (constant, independent of HV voltage)
    hv_pulse_duration = 1e-6  # 1 µs (HV pulse duration)

    print("\nParameter ranges:")
    print(
        f"  Frequency: {frequency_range.min() / 1000:.0f} - {frequency_range.max() / 1000:.0f} kHz"
    )
    print(
        f"  HV Voltage: {hv_voltage_range.min() / 1000:.0f} - "
        f"{hv_voltage_range.max() / 1000:.0f} kV"
    )
    print(f"  Gap:       {gap_range.min() * 1000:.1f} - {gap_range.max() * 1000:.1f} mm")
    print(f"  Sustainer Voltage: {sustainer_voltage_fixed:.0f} V (fixed)")
    print(f"  HV Pulse Duration: {hv_pulse_duration * 1e6:.1f} µs (fixed)")
    print("  Target power: 25.0 kW")
    print(f"  Mass flow: {mass_flow:.6f} kg/s (fixed)")

    # Create problem with DOEDriver
    prob = setup_problem_with_power_constraint()

    # Add design variables for DOE
    prob.model.add_design_var(
        "pulse_frequency", lower=frequency_range.min(), upper=frequency_range.max()
    )
    prob.model.add_design_var(
        "hv_voltage", lower=hv_voltage_range.min(), upper=hv_voltage_range.max()
    )
    prob.model.add_design_var("gap_distance", lower=gap_range.min(), upper=gap_range.max())

    # DOEDriver does not optimize; the objective is declared so power_to_gas is recorded
    prob.model.add_objective("power_to_gas")

    # Determine number of processors
    # Use half of available cores to avoid memory exhaustion (each worker loads full OpenMDAO model)
    if n_procs is None:
        total_cores = os.cpu_count() if os.cpu_count() is not None else 1
        n_procs = max(1, total_cores // 2)  # Use half of available cores

    if use_multiprocessing:
        print(f"\nUsing {n_procs} processors for parallel execution via multiprocessing")
    else:
        print("\nUsing MPI for parallel execution")

    # Create DOE driver with full factorial design
    prob.driver = om.DOEDriver(
        om.FullFactorialGenerator(
            levels={
                "pulse_frequency": len(frequency_range),
                "hv_voltage": len(hv_voltage_range),
                "gap_distance": len(gap_range),
            }
        )
    )

    # Only enable OpenMDAO's parallel execution for MPI mode
    if not use_multiprocessing:
        prob.driver.options["run_parallel"] = True
        prob.driver.options["procs_per_model"] = 1

    # Create persistent SQLite database for case recording
    recorder_file = "parameter_sweep_cases.sql"

    # Add recorder to capture all cases
    recorder = om.SqliteRecorder(recorder_file)
    prob.driver.add_recorder(recorder)
    prob.driver.recording_options["includes"] = [
        "pulse_frequency",
        "hv_voltage",
        "gap_distance",
        "power_to_gas",
        "power_check.power_error",
        "actual_energy.energy_per_pulse",
        "breakdown_voltage",
        "breakdown_margin",
        "arc_current",
        "balance.sustainer_pulse_duration",
        "operating_window_power",
        "hv_operating_window_satisfied",
    ]
    prob.driver.recording_options["record_objectives"] = True
    prob.driver.recording_options["record_constraints"] = False
    prob.driver.recording_options["record_desvars"] = True

    # Setup and set fixed parameters
    prob.setup()

    prob.set_val("pulse_duration", hv_pulse_duration)
    prob.set_val("sustainer_voltage", sustainer_voltage_fixed)
    prob.set_val("mass_flow", mass_flow)
    prob.set_val("gas_properties_pressure", 101325.0)
    prob.set_val("preheat_temperature", 300.0)
    prob.set_val("electrode_radius", 0.005)
    prob.set_val("graphite_sublimation_temp", 3900.0)
    prob.set_val("graphite_heat_capacity", 710.0)
    prob.set_val("torch_length", 0.1)
    prob.set_val("torch_diameter", 0.02)

    total_evals = len(frequency_range) * len(hv_voltage_range) * len(gap_range)
    print("\nRunning parallel parameter sweep...")
    print(f"  Total evaluations: {total_evals}")

    # Initialize results dictionary with convergence tracking
    results = {
        "frequency": [],
        "hv_voltage": [],
        "gap": [],
        "energy_per_pulse": [],
        "power_to_gas": [],
        "power_error": [],
        "breakdown_voltage": [],
        "breakdown_margin": [],
        "arc_current": [],
        "sustainer_pulse_duration": [],
        "operating_window_power": [],
        "hv_operating_window_satisfied": [],
        "converged": [],  # Track convergence status
    }

    successful_count = 0
    failed_count = 0

    if use_multiprocessing:
        # Python multiprocessing path (default; also works on Windows)
        from concurrent.futures import ProcessPoolExecutor
        from itertools import product

        # Generate all parameter combinations
        param_cases = list(product(frequency_range, hv_voltage_range, gap_range))

        print(f"  Using ProcessPoolExecutor with {n_procs} workers")

        # Run cases in parallel
        with ProcessPoolExecutor(max_workers=n_procs) as executor:
            case_results = list(executor.map(evaluate_single_case, param_cases))

        # Process results - store all cases including failed ones
        for i, result in enumerate(case_results):
            freq, hv_volt, gap = param_cases[i]
            if result is not None:
                # Converged case
                for key, value in result.items():
                    results[key].append(value)
                results["converged"].append(True)
                successful_count += 1
            else:
                # Non-converged case - store parameters but mark as failed
                results["frequency"].append(freq)
                results["hv_voltage"].append(hv_volt)
                results["gap"].append(gap)
                results["energy_per_pulse"].append(np.nan)
                results["power_to_gas"].append(np.nan)
                results["power_error"].append(np.nan)
                results["breakdown_voltage"].append(np.nan)
                results["breakdown_margin"].append(np.nan)
                results["arc_current"].append(np.nan)
                results["sustainer_pulse_duration"].append(np.nan)
                results["operating_window_power"].append(np.nan)
                results["hv_operating_window_satisfied"].append(np.nan)
                results["converged"].append(False)
                failed_count += 1
    else:
        # Use MPI-based DOE driver (Linux/Unix with PETSc)
        print(f"  Recording to: {recorder_file}")
        prob.run_driver()

        print("\nExtracting results from case recorder...")

        # Extract results from recorder
        cr = om.CaseReader(recorder_file)
        cases = cr.list_cases("driver")

        for case_id in cases:
            try:
                case = cr.get_case(case_id)

                # Extract design variables
                freq = case.get_val("pulse_frequency")[0]
                hv_volt = case.get_val("hv_voltage")[0]
                gap = case.get_val("gap_distance")[0]

                # Extract outputs
                power_to_gas = case.get_val("power_to_gas")[0]
                power_error = case.get_val("power_check.power_error")[0]
                energy_per_pulse = case.get_val("actual_energy.energy_per_pulse")[0]
                breakdown_voltage_val = case.get_val("breakdown_voltage")[0]
                breakdown_margin_val = case.get_val("breakdown_margin")[0]
                arc_current = case.get_val("arc_current")[0]
                sustainer_pulse_duration = case.get_val("balance.sustainer_pulse_duration")[0]
                operating_window_power = case.get_val("operating_window_power")[0]
                hv_operating_window_satisfied = case.get_val("hv_operating_window_satisfied")[0]

                # Store results
                results["frequency"].append(freq)
                results["hv_voltage"].append(hv_volt)
                results["gap"].append(gap)
                results["energy_per_pulse"].append(energy_per_pulse)
                results["power_to_gas"].append(power_to_gas)
                results["power_error"].append(abs(power_error))
                results["breakdown_voltage"].append(breakdown_voltage_val)
                results["breakdown_margin"].append(breakdown_margin_val)
                results["arc_current"].append(arc_current)
                results["sustainer_pulse_duration"].append(sustainer_pulse_duration)
                results["operating_window_power"].append(operating_window_power)
                results["hv_operating_window_satisfied"].append(hv_operating_window_satisfied)
                results["converged"].append(True)

                successful_count += 1

            except Exception as e:
                failed_count += 1
                if failed_count <= 5:
                    print(f"  Warning: Failed to extract case {case_id}: {str(e)}")

    print(f"\nCompleted {successful_count + failed_count} evaluations")
    print(f"  Successful: {successful_count}")
    print(f"  Failed: {failed_count}")

    # Convert to numpy arrays
    for key in results:
        results[key] = np.array(results[key])

    # In multiprocessing mode, write results to SQLite manually
    # This is necessary because ProcessPoolExecutor bypasses OpenMDAO's recorder
    if use_multiprocessing:
        print(f"\nWriting results to SQLite database: {recorder_file}")
        write_results_to_sqlite(results, recorder_file)

    print(f"Results saved to SQLite database: {recorder_file}")

    return results


def visualize_parameter_space(results):
    """Visualize the feasible parameter space.

    .. warning::
       Equations are AI generated, not ready for production.

    Parameters
    ----------
    results : dict
        Results from parameter sweep
    """
    print("\n" + "-" * 70)
    print("  Generating visualizations...")
    print("-" * 70)

    # Create three-way classification: feasible, converged-infeasible, non-converged
    tolerance = 0.10
    converged_mask = np.array(results["converged"], dtype=bool)

    # For converged points, check feasibility
    power_feasible = np.array(results["power_error"]) < tolerance
    breakdown_feasible = np.array(results["breakdown_margin"]) > 0
    hv_feasible = np.array(results["hv_operating_window_satisfied"]) >= 1.0
    constraints_met = power_feasible & breakdown_feasible & hv_feasible

    # Three categories
    feasible_mask = converged_mask & constraints_met
    converged_infeasible_mask = converged_mask & ~constraints_met
    non_converged_mask = ~converged_mask

    n_total = len(results["frequency"])
    n_feasible = np.sum(feasible_mask)
    n_converged_infeasible = np.sum(converged_infeasible_mask)
    n_non_converged = np.sum(non_converged_mask)
    n_converged = n_feasible + n_converged_infeasible

    print("\nConvergence and Feasibility analysis:")
    print(f"  Total points evaluated: {n_total}")
    print(f"  Converged: {n_converged} ({n_converged / n_total * 100:.1f}%)")
    print(f"    - Feasible (meets constraints): {n_feasible} ({n_feasible / n_total * 100:.1f}%)")
    print(f"    - Infeasible (violates constraints): {n_converged_infeasible}")
    print(f"  Non-converged: {n_non_converged} ({n_non_converged / n_total * 100:.1f}%)")

    # Detailed constraint violation analysis (for converged points)
    if n_converged > 0:
        print("\n  Constraint violation breakdown (converged points only):")
        n_power_violated = np.sum(converged_mask & ~power_feasible)
        n_breakdown_violated = np.sum(converged_mask & ~breakdown_feasible)
        n_hv_violated = np.sum(converged_mask & ~hv_feasible)
        print(f"    - Power constraint violated: {n_power_violated}")
        print(f"    - Breakdown constraint violated: {n_breakdown_violated}")
        print(f"    - HV operating window violated: {n_hv_violated}")

    if n_feasible == 0:
        print("\n  WARNING: No feasible points found.")

        # Check if we have any results at all
        if n_total == 0:
            print("\n  ERROR: No successful evaluations. Check parameter ranges and model setup.")
            print("\n" + "=" * 70)
            return

        print("\n" + "-" * 70)
        print("  Closest Points Analysis")
        print("-" * 70)

        # Analyze the distribution of power errors (NaN marks non-converged cases)
        power_errors = results["power_error"]
        power_actual = results["power_to_gas"] / 1000  # kW
        target_power = 25.0  # kW

        if np.any(np.isfinite(power_errors)):
            print("\nPower error statistics (converged points):")
            print(f"  Min error:    {np.nanmin(power_errors) * 100:.1f}%")
            print(f"  Median error: {np.nanmedian(power_errors) * 100:.1f}%")
            print(f"  Max error:    {np.nanmax(power_errors) * 100:.1f}%")
            print("\nActual power range:")
            print(f"  Min power: {np.nanmin(power_actual):.2f} kW (target: {target_power} kW)")
            print(f"  Max power: {np.nanmax(power_actual):.2f} kW (target: {target_power} kW)")
        else:
            print("\n  No power error data available - all evaluations failed.")
            print("\n" + "=" * 70)
            return

        # Find the N closest points (np.argsort places NaN entries last)
        n_closest = min(10, n_converged)
        closest_indices = np.argsort(power_errors)[:n_closest]

        print(f"\n{n_closest} Closest points to target power:")
        print("-" * 125)
        print(
            f"{'#':<3} {'Freq (kHz)':<12} {'HV (kV)':<10} {'Gap (mm)':<10} "
            f"{'Power (kW)':<12} {'Error (%)':<10} {'V_bd (V)':<10} {'Margin (V)':<12}"
        )
        print("-" * 125)

        for i, idx in enumerate(closest_indices, 1):
            freq = results["frequency"][idx] / 1000  # Convert to kHz
            volt = results["hv_voltage"][idx] / 1000  # Convert to kV
            gap = results["gap"][idx] * 1000  # Convert to mm
            power = results["power_to_gas"][idx] / 1000  # Convert to kW
            error = results["power_error"][idx] * 100  # Convert to %
            v_bd = results["breakdown_voltage"][idx]
            margin = results["breakdown_margin"][idx]

            print(
                f"{i:<3} {freq:<12.1f} {volt:<10.1f} {gap:<10.2f} "
                f"{power:<12.2f} {error:<10.1f} {v_bd:<10.0f} {margin:<12.0f}"
            )

        print("-" * 125)

        # Provide diagnostic suggestions
        print("\nDiagnostic suggestions:")
        avg_power = np.nanmean(power_actual)
        if avg_power < target_power * 0.5:
            print(
                "  -> Power is too low. Consider: increasing voltage, pulse duration, "
                "decreasing frequency, or increasing gap"
            )
        elif avg_power > target_power * 1.5:
            print(
                "  -> Power is too high. Consider: decreasing voltage, pulse duration, "
                "increasing frequency, or decreasing gap"
            )
        else:
            print("  -> Power range overlaps target. Consider: refining grid resolution")

        best_idx = closest_indices[0]
        best_error = power_errors[best_idx] * 100
        if best_error < 20:
            print(f"  -> Best point is within {best_error:.1f}% - close to feasible!")
            print("    Try refining the search around the closest points shown above.")

        print("\n" + "=" * 70)
        return

    # Convert arrays for convenience
    freq_array = np.array(results["frequency"])
    volt_array = np.array(results["hv_voltage"])
    gap_array = np.array(results["gap"])
    energy_array = np.array(results["energy_per_pulse"])
    current_array = np.array(results["arc_current"])

    # Extract feasible points
    freq_feas = freq_array[feasible_mask] / 1000  # Convert to kHz
    volt_feas = volt_array[feasible_mask] / 1000  # Convert to kV
    gap_feas = gap_array[feasible_mask] * 1000  # Convert to mm
    energy_feas = energy_array[feasible_mask]
    current_feas = current_array[feasible_mask]

    # Extract converged but infeasible points
    freq_conv_infeas = freq_array[converged_infeasible_mask] / 1000
    volt_conv_infeas = volt_array[converged_infeasible_mask] / 1000
    gap_conv_infeas = gap_array[converged_infeasible_mask] * 1000
    energy_conv_infeas = energy_array[converged_infeasible_mask]
    current_conv_infeas = current_array[converged_infeasible_mask]

    # Extract non-converged points
    freq_non_conv = freq_array[non_converged_mask] / 1000
    volt_non_conv = volt_array[non_converged_mask] / 1000
    gap_non_conv = gap_array[non_converged_mask] * 1000
    # Non-converged points have NaN for outputs, so we don't use them for coloring

    print("\nFeasible parameter ranges:")
    print(f"  Frequency: {freq_feas.min():.1f} - {freq_feas.max():.1f} kHz")
    print(f"  HV Voltage: {volt_feas.min():.1f} - {volt_feas.max():.1f} kV")
    print(f"  Gap:       {gap_feas.min():.1f} - {gap_feas.max():.1f} mm")
    print(f"  Energy/pulse: {energy_feas.min():.3f} - {energy_feas.max():.3f} J")
    print(f"  Arc current:  {current_feas.min():.1f} - {current_feas.max():.1f} A")

    # Import matplotlib only in main process for visualization (avoid memory issues in workers)
    import matplotlib.pyplot as plt

    # Create figure with subplots
    fig = plt.figure(figsize=(16, 10))
    fig.suptitle(
        f"Plasma Torch Parameter Space (25 kW Constraint) - "
        f"Converged: {n_converged}/{n_total}, Feasible: {n_feasible}/{n_total} "
        f"({n_feasible / n_total * 100:.1f}%)",
        fontsize=16,
        fontweight="bold",
    )

    # Plot 1: Frequency vs Voltage, colored by Energy per pulse
    ax1 = fig.add_subplot(2, 3, 1)
    # Plot non-converged points (crosses) - gray, no color mapping
    if len(freq_non_conv) > 0:
        ax1.scatter(
            freq_non_conv,
            volt_non_conv,
            c="dimgray",
            marker="x",
            s=100,
            alpha=0.7,
            linewidth=2.5,
            label=f"Non-converged ({len(freq_non_conv)})",
        )
    # Plot converged but infeasible points (triangles with red edge)
    if len(freq_conv_infeas) > 0:
        ax1.scatter(
            freq_conv_infeas,
            volt_conv_infeas,
            c=energy_conv_infeas,
            cmap="viridis",
            marker="^",
            s=80,
            alpha=0.8,
            edgecolors="red",
            linewidth=1.2,
            label=f"Conv. Infeasible ({len(freq_conv_infeas)})",
        )
    # Plot feasible points (circles with green edge)
    scatter1 = ax1.scatter(
        freq_feas,
        volt_feas,
        c=energy_feas,
        cmap="viridis",
        marker="o",
        s=80,
        alpha=0.9,
        edgecolors="green",
        linewidth=1.2,
        label=f"Feasible ({len(freq_feas)})",
    )
    ax1.set_xlabel("Frequency (kHz)", fontsize=11)
    ax1.set_ylabel("HV Voltage (kV)", fontsize=11)
    ax1.set_title("Frequency vs HV Voltage", fontweight="bold")
    ax1.grid(True, alpha=0.3)
    ax1.legend(loc="best", fontsize=8)
    cbar1 = plt.colorbar(scatter1, ax=ax1)
    cbar1.set_label("Energy/pulse (J)", fontsize=10)

    # Plot 2: Frequency vs Gap, colored by Energy per pulse
    ax2 = fig.add_subplot(2, 3, 2)
    # Plot non-converged points
    if len(freq_non_conv) > 0:
        ax2.scatter(
            freq_non_conv,
            gap_non_conv,
            c="dimgray",
            marker="x",
            s=100,
            alpha=0.7,
            linewidth=2.5,
            label="Non-converged",
        )
    # Plot converged but infeasible points
    if len(freq_conv_infeas) > 0:
        ax2.scatter(
            freq_conv_infeas,
            gap_conv_infeas,
            c=energy_conv_infeas,
            cmap="viridis",
            marker="^",
            s=80,
            alpha=0.8,
            edgecolors="red",
            linewidth=1.2,
            label="Conv. Infeasible",
        )
    # Plot feasible points
    scatter2 = ax2.scatter(
        freq_feas,
        gap_feas,
        c=energy_feas,
        cmap="viridis",
        marker="o",
        s=80,
        alpha=0.9,
        edgecolors="green",
        linewidth=1.2,
        label="Feasible",
    )
    ax2.set_xlabel("Frequency (kHz)", fontsize=11)
    ax2.set_ylabel("Gap Distance (mm)", fontsize=11)
    ax2.set_title("Frequency vs Gap", fontweight="bold")
    ax2.grid(True, alpha=0.3)
    ax2.legend(loc="best", fontsize=8)
    cbar2 = plt.colorbar(scatter2, ax=ax2)
    cbar2.set_label("Energy/pulse (J)", fontsize=10)

    # Plot 3: Voltage vs Gap, colored by Energy per pulse
    ax3 = fig.add_subplot(2, 3, 3)
    # Plot non-converged points
    if len(volt_non_conv) > 0:
        ax3.scatter(
            volt_non_conv,
            gap_non_conv,
            c="dimgray",
            marker="x",
            s=100,
            alpha=0.7,
            linewidth=2.5,
            label="Non-converged",
        )
    # Plot converged but infeasible points
    if len(volt_conv_infeas) > 0:
        ax3.scatter(
            volt_conv_infeas,
            gap_conv_infeas,
            c=energy_conv_infeas,
            cmap="viridis",
            marker="^",
            s=80,
            alpha=0.8,
            edgecolors="red",
            linewidth=1.2,
            label="Conv. Infeasible",
        )
    # Plot feasible points
    scatter3 = ax3.scatter(
        volt_feas,
        gap_feas,
        c=energy_feas,
        cmap="viridis",
        marker="o",
        s=80,
        alpha=0.9,
        edgecolors="green",
        linewidth=1.2,
        label="Feasible",
    )
    ax3.set_xlabel("HV Voltage (kV)", fontsize=11)
    ax3.set_ylabel("Gap Distance (mm)", fontsize=11)
    ax3.set_title("HV Voltage vs Gap", fontweight="bold")
    ax3.grid(True, alpha=0.3)
    ax3.legend(loc="best", fontsize=8)
    cbar3 = plt.colorbar(scatter3, ax=ax3)
    cbar3.set_label("Energy/pulse (J)", fontsize=10)

    # Plot 4: Frequency vs Voltage, colored by Arc Current
    ax4 = fig.add_subplot(2, 3, 4)
    # Plot non-converged points
    if len(freq_non_conv) > 0:
        ax4.scatter(
            freq_non_conv,
            volt_non_conv,
            c="dimgray",
            marker="x",
            s=100,
            alpha=0.7,
            linewidth=2.5,
            label="Non-converged",
        )
    # Plot converged but infeasible points
    if len(freq_conv_infeas) > 0:
        ax4.scatter(
            freq_conv_infeas,
            volt_conv_infeas,
            c=current_conv_infeas,
            cmap="plasma",
            marker="^",
            s=80,
            alpha=0.8,
            edgecolors="red",
            linewidth=1.2,
            label="Conv. Infeasible",
        )
    # Plot feasible points
    scatter4 = ax4.scatter(
        freq_feas,
        volt_feas,
        c=current_feas,
        cmap="plasma",
        marker="o",
        s=80,
        alpha=0.9,
        edgecolors="green",
        linewidth=1.2,
        label="Feasible",
    )
    ax4.set_xlabel("Frequency (kHz)", fontsize=11)
    ax4.set_ylabel("HV Voltage (kV)", fontsize=11)
    ax4.set_title("Frequency vs HV Voltage (Arc Current)", fontweight="bold")
    ax4.grid(True, alpha=0.3)
    ax4.legend(loc="best", fontsize=8)
    cbar4 = plt.colorbar(scatter4, ax=ax4)
    cbar4.set_label("Arc Current (A)", fontsize=10)

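    # Plot 5: Energy per pulse distribution for converged points only
    # (non-converged cases have no valid energy to plot). Both histograms use
    # the same bin edges so their bars are directly comparable.
    ax5 = fig.add_subplot(2, 3, 5)
    energy_all = np.concatenate([np.ravel(energy_conv_infeas), np.ravel(energy_feas)])
    energy_bins = np.histogram_bin_edges(energy_all[np.isfinite(energy_all)], bins=30)
    # Plot converged infeasible distribution
    if len(energy_conv_infeas) > 0:
        ax5.hist(
            energy_conv_infeas,
            bins=energy_bins,
            color="orange",
            alpha=0.5,
            edgecolor="darkred",
            label=f"Conv. Infeasible ({len(energy_conv_infeas)})",
            linewidth=1.5,
        )
    # Plot feasible distribution
    if len(energy_feas) > 0:
        ax5.hist(
            energy_feas,
            bins=energy_bins,
            color="steelblue",
            alpha=0.7,
            edgecolor="black",
            label=f"Feasible ({len(energy_feas)})",
            linewidth=1.5,
        )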
    ax5.set_xlabel("Energy per Pulse (J)", fontsize=11)
    ax5.set_ylabel("Number of Cases", fontsize=11)
    ax5.set_title("Energy per Pulse Distribution", fontweight="bold")
    ax5.grid(True, alpha=0.3, axis="y")
    # Add mean line for feasible points
    if len(energy_feas) > 0:
        ax5.axvline(
            energy_feas.mean(),
            color="darkblue",
            linestyle="--",
            linewidth=2,
            label=f"Feasible Mean: {energy_feas.mean():.1f} J",
        )
    ax5.legend(fontsize=8)

    # Plot 6: 3D scatter (Frequency, Voltage, Gap)
    ax6 = fig.add_subplot(2, 3, 6, projection="3d")
    # Plot non-converged points
    if len(freq_non_conv) > 0:
        ax6.scatter(
            freq_non_conv,
            volt_non_conv,
            gap_non_conv,
            c="dimgray",
            marker="x",
            s=80,
            alpha=0.7,
            linewidth=2.5,
            label="Non-converged",
        )
    # Plot converged but infeasible points
    if len(freq_conv_infeas) > 0:
        ax6.scatter(
            freq_conv_infeas,
            volt_conv_infeas,
            gap_conv_infeas,
            c=energy_conv_infeas,
            cmap="viridis",
            marker="^",
            s=50,
            alpha=0.8,
            edgecolors="red",
            linewidth=1.2,
            label="Conv. Infeasible",
        )
    # Plot feasible points
    scatter6 = ax6.scatter(
        freq_feas,
        volt_feas,
        gap_feas,
        c=energy_feas,
        cmap="viridis",
        marker="o",
        s=50,
        alpha=0.9,
        edgecolors="green",
        linewidth=1.2,
        label="Feasible",
    )
    ax6.set_xlabel("Frequency (kHz)", fontsize=10)
    ax6.set_ylabel("HV Voltage (kV)", fontsize=10)
    ax6.set_zlabel("Gap (mm)", fontsize=10)
    ax6.set_title("3D Parameter Space", fontweight="bold")
    ax6.legend(loc="best", fontsize=8)
    cbar6 = plt.colorbar(scatter6, ax=ax6, shrink=0.7)
    cbar6.set_label("Energy/pulse (J)", fontsize=9)

    plt.tight_layout()

    # Save figure
    output_path = "parameter_space_visualization.png"
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    print(f"\nVisualization saved to: {output_path}")

    # Show plot
    plt.show()

    print("\n" + "=" * 70)
    print("  Analysis complete!")
    print("=" * 70)
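

# Hypothetical helper (a sketch, not called anywhere in this script): every
# panel above repeats the same three-category scatter pattern. This version
# assumes boolean `converged` and `feasible` masks aligned with the raw sweep
# arrays and derives the three categories from them. It also passes one shared
# Normalize to both colored scatters, so the colorbar built from the feasible
# points applies to the infeasible triangles as well. In the panels above, each
# colored scatter call autoscales its colors independently.
def scatter_by_status(ax, x, y, color, converged, feasible, cmap="viridis"):
    """Plot non-converged, converged-infeasible and feasible points on ``ax``.

    Returns the feasible-point collection, e.g. for ``plt.colorbar(sc, ax=ax)``.
    """
    import numpy as np
    from matplotlib.colors import Normalize

    x, y, color = np.asarray(x), np.asarray(y), np.asarray(color, dtype=float)
    converged = np.asarray(converged, dtype=bool)
    # A point only counts as feasible if it also converged
    feasible = np.asarray(feasible, dtype=bool) & converged
    infeasible = converged & ~feasible

    # Shared color scale over all converged points, ignoring NaN/inf values
    colored = color[converged]
    colored = colored[np.isfinite(colored)]
    norm = Normalize(vmin=colored.min(), vmax=colored.max()) if colored.size else None

    if (~converged).any():
        ax.scatter(
            x[~converged], y[~converged], c="dimgray", marker="x", s=100,
            alpha=0.7, linewidth=2.5, label=f"Non-converged ({(~converged).sum()})",
        )
    if infeasible.any():
        ax.scatter(
            x[infeasible], y[infeasible], c=color[infeasible], cmap=cmap, norm=norm,
            marker="^", s=80, alpha=0.8, edgecolors="red", linewidth=1.2,
            label=f"Conv. Infeasible ({infeasible.sum()})",
        )
    return ax.scatter(
        x[feasible], y[feasible], c=color[feasible], cmap=cmap, norm=norm,
        marker="o", s=80, alpha=0.9, edgecolors="green", linewidth=1.2,
        label=f"Feasible ({feasible.sum()})",
    )


# Example usage (array and mask names are hypothetical):
#     sc = scatter_by_status(ax1, freq, volt, energy, converged, feasible)
#     plt.colorbar(sc, ax=ax1).set_label("Energy/pulse (J)")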


def generate_n2_diagram():
    """Generate an N2 diagram for the MDAO problem.

    Creates an interactive HTML N2 diagram (``n2_diagram.html``) showing the
    connections between all subsystems and components in the torch design
    problem. The model is run once at a representative operating point first,
    so the diagram displays computed variable values rather than defaults.
    """
    print("\n" + "=" * 70)
    print("  Generating N2 Diagram")
    print("=" * 70)

    # Set up the problem
    prob = setup_problem_with_power_constraint()
    prob.setup()

    # Run the model once to populate values for better visualization
    print("\n  Running model to populate values...")
    prob.set_val("pulse_duration", 1e-6)  # s (1 µs HV pulse)
    prob.set_val("sustainer_voltage", 500.0)  # V
    prob.set_val("mass_flow", 160 / 86400)  # kg/s (converted from 160 kg/d)
    prob.set_val("gas_properties_pressure", 101325.0)  # Pa (atmospheric)
    prob.set_val("preheat_temperature", 300.0)  # K
    prob.set_val("electrode_radius", 0.005)  # m (5 mm)
    prob.set_val("graphite_sublimation_temp", 3900.0)  # K
    prob.set_val("graphite_heat_capacity", 710.0)  # J/(kg·K)
    prob.set_val("torch_length", 0.1)  # m (100 mm)
    prob.set_val("torch_diameter", 0.02)  # m (20 mm)
    prob.set_val("pulse_frequency", 50000.0)  # Hz (50 kHz)
    prob.set_val("hv_voltage", 20000.0)  # V (20 kV)
    prob.set_val("gap_distance", 0.01)  # m (10 mm)

    prob.run_model()

    # Generate N2 diagram with detailed information
    # The N2 diagram will show:
    # - Component hierarchy and connections
    # - Variable values and units
    # - Solver information
    om.n2(
        prob,
        outfile="n2_diagram.html",
        show_browser=True,
        title="Plasma Torch MDAO System - Detailed N2 Diagram",
    )

    print("\nN2 diagram generated successfully!")
    print("  Output file: n2_diagram.html")
    print("  The diagram should open in your default web browser.")
    print("\nInteractive features:")
    print("  - Left-click a component in the model tree to zoom in; right-click to collapse it")
    print("  - Hover over matrix cells to highlight the connections between variables")
    print("  - Use the toolbar to toggle between different views and show node details")
    print("  - Click the help button in the toolbar for a description of the controls")
    print("\n" + "=" * 70)
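

# Hypothetical helper (a sketch, not called by this script): om.n2 accepts
# show_browser=False, which writes the HTML file without trying to open a web
# browser. That may be preferable on headless machines or in CI runs. The
# function name and default output file name here are illustrative assumptions.
def write_n2_headless(prob, outfile="n2_diagram.html", title=None):
    """Write an N2 diagram for a set-up Problem without opening a browser."""
    import openmdao.api as om

    om.n2(prob, outfile=outfile, show_browser=False, title=title)
    return outfile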


def main(n_procs=None):
    """Run the parallel DOEDriver parameter sweep and visualize the results."""
    # Run parameter sweep
    results = run_parameter_sweep(n_procs=n_procs)

    # Visualize results
    visualize_parameter_space(results)

    print("\nNext steps:")
    print("  1. Examine the feasible parameter space visualizations")
    print("  2. Select optimal operating point based on constraints")
    print("  3. Consider adding arc mobility model for refined analysis")
    print("  4. Validate with experimental data if available")

    print("\nParallel execution notes:")
    print("  - Default: cases run in parallel with Python multiprocessing (works on Windows)")
    print("  - Linux/Unix: MPI is also possible via 'mpiexec -n <num_procs> python example_vdef_mdao.py'")
    print("  - Under MPI, DOEDriver distributes the cases across the available ranks")
    print()


if __name__ == "__main__":
    import sys
    import traceback

    try:
        # Pass --n2 to generate only the N2 diagram instead of running the sweep
        if len(sys.argv) > 1 and sys.argv[1] == "--n2":
            generate_n2_diagram()
        else:
            main(n_procs=4)
    except Exception as e:
        print(f"\nError: {e}")
        traceback.print_exc()
        # Exit with a non-zero status so shells and CI jobs register the failure
        sys.exit(1)