# Copyright © 2016-2023 Medical Image Analysis Laboratory, University Hospital
# Center and University of Lausanne (UNIL-CHUV), Switzerland
# This software is distributed under the open-source license Modified BSD.
"""Module for the assessment of the super-resolution reconstruction quality with a reference."""
from traits.api import *
from nipype.pipeline import engine as pe
from nipype.interfaces import utility as util
from nipype.interfaces.ants import RegistrationSynQuick, ApplyTransforms
import pymialsrtk.interfaces.postprocess as postprocess
import pymialsrtk.interfaces.preprocess as preprocess
import pymialsrtk.workflows.postproc_stage as postproc_stage
import logging
def create_srr_assessment_stage(
    p_do_multi_parameters=False,
    p_do_reconstruct_labels=False,
    p_input_srtv_node=None,
    p_verbose=False,
    p_openmp_number_of_cores=1,
    name="srr_assessment_stage",
):
    """Create an assessment workflow to compare a SR-reconstructed image and a reference target.

    The reference image and mask go through a post-processing stage and a
    field-of-view reduction (``ReduceFieldOfView``). The SDI image is
    registered to the reduced reference with ANTs ``RegistrationSynQuick``
    (transform type ``"r"``), and the resulting transform is applied to the
    SR image. The SR image is then masked and compared to the reference with
    ``ImageMetrics``.

    Parameters
    ----------
    p_do_multi_parameters : boolean
        Whether multiple SR are to be assessed with different TV parameters.
        When set, the per-reconstruction metrics are merged with a JoinNode.
        (default: `False`)
    p_do_reconstruct_labels : boolean
        Whether an SR labelmap is also provided. This adds the
        ``input_sr_labelmap`` input and resamples it onto the reference grid
        with nearest-neighbour interpolation.
        (default: `False`)
    p_input_srtv_node : string
        Name of the source node from which metrics must be merged (used as the
        ``joinsource`` of the JoinNode). Required when `p_do_multi_parameters`
        is set.
        (default: `None`)
    p_verbose : boolean
        Whether verbose logging is kept for this workflow.
        (default: `False`)
    p_openmp_number_of_cores : integer
        Number of threads used by the ANTs registration and resampling.
        (default : 1)
    name : string
        Name of workflow
        (default: "srr_assessment_stage")

    Inputs
    ------
    inputnode.input_ref_image : pathlike object or string representing a file
        Path to the ground truth image against which the SR will be evaluated.
    inputnode.input_ref_mask : pathlike object or string representing a file
        Path to the mask of the ground truth image.
    inputnode.input_ref_labelmap : pathlike object or string representing a file
        Path to the labelmap (tissue segmentation) of the ground truth image.
    inputnode.input_sr_image : pathlike object or string representing a file
        Path to the SR reconstructed image.
    inputnode.input_sdi_image : pathlike object or string representing a file
        Path to the SDI (interpolated image) used as input to the SR.
    inputnode.input_TV_parameters : dictionary
        Dictionary of parameters that were used for the TV reconstruction.
    inputnode.input_sr_labelmap : pathlike object or string representing a file
        Path to the SR labelmap (only present when `p_do_reconstruct_labels`
        is set).

    Outputs
    -------
    outputnode.output_metrics
        ``output_metrics`` of ``ImageMetrics``, or the ``output_csv`` of
        ``ConcatenateImageMetrics`` when `p_do_multi_parameters` is set.
    outputnode.output_metrics_labels
        Per-label counterpart: ``output_metrics_labels`` of ``ImageMetrics``,
        or ``output_csv_labels`` of ``ConcatenateImageMetrics`` when
        `p_do_multi_parameters` is set.

    Raises
    ------
    ValueError
        If `p_do_multi_parameters` is set but `p_input_srtv_node` is `None`.

    Example
    -------
    >>> from pymialsrtk.workflows import srr_assessment_stage as srr_assessment
    >>> srr_eval = srr_assessment.create_srr_assessment_stage()
    >>> srr_eval.inputs.inputnode.input_ref_image = 'sub-01_desc-GT_T2w.nii.gz'
    >>> srr_eval.inputs.inputnode.input_ref_mask = 'sub-01_desc-GT_mask.nii.gz'
    >>> srr_eval.inputs.inputnode.input_ref_labelmap = 'sub-01_desc-GT_labels.nii.gz'
    >>> srr_eval.inputs.inputnode.input_sr_image = 'sub-01_id-1_rec-SR_T2w.nii.gz'
    >>> srr_eval.inputs.inputnode.input_sdi_image = 'sub-01_id-1_desc-SDI_T2w.nii.gz'
    >>> srr_eval.inputs.inputnode.input_TV_parameters = {
    ...     'in_loop': '10',
    ...     'in_deltat': '0.01',
    ...     'in_lambda': '0.75',
    ...     'in_bregman_loop': '3',
    ...     'in_iter': '50',
    ...     'in_step_scale': '1',
    ...     'in_gamma': '1',
    ...     'in_inner_thresh': '1e-05',
    ...     'in_outer_thresh': '1e-06'
    ... }
    >>> srr_eval.run()  # doctest: +SKIP
    """
    # Validate before building anything: the JoinNode below needs the name of
    # the source node whose expansions it merges.
    if p_do_multi_parameters and p_input_srtv_node is None:
        raise ValueError(
            "p_input_srtv_node must be provided when "
            "p_do_multi_parameters is True"
        )

    srr_assessment_stage = pe.Workflow(name=name)

    if not p_verbose:
        # Removing verbose by removing the output altogether as we
        # cannot control the verbosity level in the ANTs call.
        # NOTE(review): level 0 is logging.NOTSET, which defers to the parent
        # logger's level rather than suppressing output — confirm this has
        # the intended effect.
        logging.getLogger(f"nipype.workflow.{name}").setLevel(0)

    # Set up a node to define all inputs required for the
    # assessment workflow
    input_fields = [
        "input_ref_image",
        "input_ref_mask",
        "input_ref_labelmap",
        "input_sr_image",
        "input_sdi_image",
        "input_TV_parameters",
    ]
    if p_do_reconstruct_labels:
        input_fields += ["input_sr_labelmap"]

    output_fields = ["output_metrics", "output_metrics_labels"]

    inputnode = pe.Node(
        interface=util.IdentityInterface(fields=input_fields), name="inputnode"
    )

    outputnode = pe.Node(
        interface=util.IdentityInterface(fields=output_fields),
        name="outputnode",
    )

    proc_reference = postproc_stage.create_postproc_stage(
        p_ga=None,
        p_do_anat_orientation=False,
        p_do_reconstruct_labels=False,
        p_verbose=p_verbose,
        name="proc_reference",
    )

    crop_reference = pe.Node(
        interface=preprocess.ReduceFieldOfView(), name="crop_reference"
    )

    registration_quick = pe.Node(
        interface=RegistrationSynQuick(
            num_threads=p_openmp_number_of_cores,
            transform_type="r",
            environ={"PATH": "/opt/conda/bin"},
            terminal_output="file_stderr",
        ),
        name="registration_quick",
    )

    apply_transform = pe.Node(
        interface=ApplyTransforms(
            num_threads=p_openmp_number_of_cores,
            environ={"PATH": "/opt/conda/bin"},
            terminal_output="file_stderr",
        ),
        name="apply_transform",
    )

    if p_do_reconstruct_labels:
        # Labelmaps hold discrete values, so they must be resampled with
        # nearest-neighbour interpolation. These options belong to the
        # ApplyTransforms interface: extra keyword arguments given to
        # pe.Node are not interface options and would be silently ignored.
        apply_transform_labels = pe.Node(
            interface=ApplyTransforms(
                num_threads=p_openmp_number_of_cores,
                environ={"PATH": "/opt/conda/bin"},
                terminal_output="file_stderr",
                interpolation="NearestNeighbor",
            ),
            name="apply_transform_labels",
        )

    mask_sr = pe.Node(interface=preprocess.MialsrtkMaskImage(), name="mask_sr")

    sr_image_metrics = pe.Node(
        postprocess.ImageMetrics(), name="sr_image_metrics"
    )

    if p_do_multi_parameters:
        concat_sr_image_metrics = pe.JoinNode(
            interface=postprocess.ConcatenateImageMetrics(),
            joinfield=["input_metrics", "input_metrics_labels"],
            joinsource=p_input_srtv_node,
            name="concat_sr_image_metrics",
        )

    # Reference: post-process, then reduce the field of view.
    srr_assessment_stage.connect(
        inputnode, "input_ref_image", proc_reference, "inputnode.input_image"
    )
    srr_assessment_stage.connect(
        inputnode, "input_ref_mask", proc_reference, "inputnode.input_mask"
    )
    srr_assessment_stage.connect(
        proc_reference,
        "outputnode.output_image",
        crop_reference,
        "input_image",
    )
    srr_assessment_stage.connect(
        proc_reference, "outputnode.output_mask", crop_reference, "input_mask"
    )
    srr_assessment_stage.connect(
        inputnode, "input_ref_labelmap", crop_reference, "input_label"
    )

    # The transform is estimated on the SDI image and then applied to the SR.
    srr_assessment_stage.connect(
        inputnode, "input_sdi_image", registration_quick, "moving_image"
    )
    srr_assessment_stage.connect(
        crop_reference, "output_image", registration_quick, "fixed_image"
    )

    srr_assessment_stage.connect(
        inputnode, "input_sr_image", apply_transform, "input_image"
    )
    srr_assessment_stage.connect(
        crop_reference, "output_image", apply_transform, "reference_image"
    )
    srr_assessment_stage.connect(
        registration_quick, "out_matrix", apply_transform, "transforms"
    )

    if p_do_reconstruct_labels:
        # NOTE(review): the output of apply_transform_labels is not connected
        # to any downstream node in this workflow — confirm whether
        # sr_image_metrics (or the outputnode) should consume it.
        srr_assessment_stage.connect(
            inputnode,
            "input_sr_labelmap",
            apply_transform_labels,
            "input_image",
        )
        srr_assessment_stage.connect(
            crop_reference,
            "output_image",
            apply_transform_labels,
            "reference_image",
        )
        srr_assessment_stage.connect(
            registration_quick,
            "out_matrix",
            apply_transform_labels,
            "transforms",
        )

    srr_assessment_stage.connect(
        apply_transform, "output_image", mask_sr, "in_file"
    )
    srr_assessment_stage.connect(
        crop_reference, "output_mask", mask_sr, "in_mask"
    )

    srr_assessment_stage.connect(
        mask_sr, "out_im_file", sr_image_metrics, "input_image"
    )
    srr_assessment_stage.connect(
        crop_reference, "output_image", sr_image_metrics, "input_ref_image"
    )
    srr_assessment_stage.connect(
        crop_reference, "output_label", sr_image_metrics, "input_ref_labelmap"
    )
    srr_assessment_stage.connect(
        crop_reference, "output_mask", sr_image_metrics, "input_ref_mask"
    )
    srr_assessment_stage.connect(
        inputnode,
        "input_TV_parameters",
        sr_image_metrics,
        "input_TV_parameters",
    )

    if p_do_multi_parameters:
        srr_assessment_stage.connect(
            sr_image_metrics,
            "output_metrics",
            concat_sr_image_metrics,
            "input_metrics",
        )
        srr_assessment_stage.connect(
            sr_image_metrics,
            "output_metrics_labels",
            concat_sr_image_metrics,
            "input_metrics_labels",
        )
        srr_assessment_stage.connect(
            concat_sr_image_metrics, "output_csv", outputnode, "output_metrics"
        )
        srr_assessment_stage.connect(
            concat_sr_image_metrics,
            "output_csv_labels",
            outputnode,
            "output_metrics_labels",
        )
    else:
        srr_assessment_stage.connect(
            sr_image_metrics, "output_metrics", outputnode, "output_metrics"
        )
        srr_assessment_stage.connect(
            sr_image_metrics,
            "output_metrics_labels",
            outputnode,
            "output_metrics_labels",
        )

    return srr_assessment_stage