# Source code for pymialsrtk.workflows.recon_stage

# Copyright © 2016-2023 Medical Image Analysis Laboratory, University Hospital
# Center and University of Lausanne (UNIL-CHUV), Switzerland
# This software is distributed under the open-source license Modified BSD.

"""Module for the reconstruction stage of the super-resolution reconstruction pipeline."""

from traits.api import *

from nipype.interfaces import utility as util
from nipype.pipeline import engine as pe

import pymialsrtk.workflows.recon_labelmap_stage as recon_labelmap_stage

import pymialsrtk.interfaces.reconstruction as reconstruction
import pymialsrtk.interfaces.postprocess as postprocess
import pymialsrtk.interfaces.utils as utils


# TV parameters understood by ``create_recon_stage``, as
# (MialsrtkTVSuperResolution input name, key in ``p_paramTV``, default value).
# The order matters: it is the order in which the node inputs are set and,
# when ``p_do_multi_parameters`` is enabled, the order of the node iterables
# (which nipype uses when expanding the node).
_TV_PARAMETERS = (
    ("in_lambda", "lambdaTV", 0.75),
    ("in_deltat", "deltatTV", 0.01),
    ("in_iter", "num_iterations", 50),
    ("in_loop", "num_primal_dual_loops", 10),
    ("in_bregman_loop", "num_bregman_loops", 3),
    ("in_step_scale", "step_scale", 1),
    ("in_gamma", "gamma", 1),
)


def _resolve_tv_parameters(p_paramTV, as_lists=False):
    """Return the TV parameters as ``(node_input_name, value)`` pairs.

    Parameters missing from ``p_paramTV`` fall back to the defaults listed
    in ``_TV_PARAMETERS``.

    Parameters
    ----------
    p_paramTV : dictionary or None
        User-provided TV parameters. ``None`` is treated as an empty
        dictionary, i.e. all defaults are used.
    as_lists : boolean
        If ``True``, every value that is not already a list is wrapped into
        a one-element list, as required for node iterables.

    Returns
    -------
    list of tuple
        ``(node_input_name, value)`` pairs, ordered as in ``_TV_PARAMETERS``.
    """
    user_params = p_paramTV if p_paramTV is not None else {}
    resolved = []
    for input_name, key, default in _TV_PARAMETERS:
        value = user_params.get(key, default)
        if as_lists and not isinstance(value, list):
            value = [value]
        resolved.append((input_name, value))
    return resolved


def create_recon_stage(
    p_paramTV,
    p_use_manual_masks,
    p_do_multi_parameters=False,
    p_do_nlm_denoising=False,
    p_do_reconstruct_labels=False,
    p_do_refine_hr_mask=False,
    p_skip_svr=False,
    p_sub_ses="",
    p_verbose=False,
    name="recon_stage",
):
    """Create a super-resolution reconstruction workflow.

    Parameters
    ----------
    p_paramTV : dictionary
        Dictionary of TV parameters. Recognized keys, with the default used
        when a key is absent: ``lambdaTV`` (0.75), ``deltatTV`` (0.01),
        ``num_iterations`` (50), ``num_primal_dual_loops`` (10),
        ``num_bregman_loops`` (3), ``step_scale`` (1) and ``gamma`` (1).
        With ``p_do_multi_parameters=True`` each value may be a list of
        values to explore. ``None`` is accepted and means "all defaults".
    p_use_manual_masks : boolean
        Whether masks were done manually.
    p_do_multi_parameters : boolean
        Perform super-resolution reconstruction with a set of multiple
        parameters (the TV values become node iterables).
        (default: `False`)
    p_do_nlm_denoising : boolean
        Whether to proceed to non-local mean denoising. When enabled, the
        initial reconstruction uses the denoised images and the SDI is
        recomputed from the original images. (default: `False`)
    p_do_reconstruct_labels : boolean
        Whether we are also reconstructing label maps. (default: `False`)
    p_do_refine_hr_mask : boolean
        Whether to do high-resolution mask refinement (by intersection).
        Otherwise the HR mask is obtained by binarizing the SDI.
        (default: `False`)
    p_skip_svr : boolean
        Whether slice-to-volume registration (SVR) should be skipped.
        (default: `False`)
    p_sub_ses : string
        String describing subject-session information (default: '')
    p_verbose : boolean
        Whether verbosity should be enabled (default: `False`)
    name : string
        Name of workflow (default: "recon_stage")

    Inputs
    ------
    input_images : list of items which are a pathlike object or string representing a file
        Input T2w images
    input_images_nlm : list of items which are a pathlike object or string representing a file
        Input denoised T2w images, only present (and required) if
        `p_do_nlm_denoising=True`
    input_masks : list of items which are a pathlike object or string representing a file
        Input mask images
    stacks_order : list of integer
        Order of stacks in the reconstruction
    input_labels : list of items which are a pathlike object or string representing a file
        Input label maps, only present if `p_do_reconstruct_labels=True`

    Outputs
    -------
    output_sr : pathlike object or string representing a file
        SR reconstructed image
    output_sdi : pathlike object or string representing a file
        SDI image
    output_hr_mask : pathlike object or string representing a file
        SRR mask
    output_transforms : list of items which are a pathlike object or string representing a file
        Estimated transformation parameters
    output_json_path : pathlike object or string representing a file
        Path to the json sidecar of the SR reconstruction
    output_sr_png : pathlike object or string representing a file
        Path to the PNG of the SR reconstruction
    output_TV_parameters : dictionary
        Parameters used for TV reconstruction
    output_labelmap : pathlike object or string representing a file
        Reconstructed label map, only present if
        `p_do_reconstruct_labels=True`
    output_TV_params : list of tuple
        The ``(input_name, values)`` iterables of the TV node, only present
        if `p_do_multi_parameters=True`

    Returns
    -------
    recon_stage : nipype Workflow
        The reconstruction workflow.
    tv_node_name : string
        Name of the TV super-resolution node inside the workflow
        (``"srtkTVSuperResolution"``).

    Example
    -------
    >>> from pymialsrtk.workflows import recon_stage as rec
    >>> p_paramTV = {"lambdaTV": 0.75, "deltatTV": 0.01}
    >>> recon_stage, tv_node_name = rec.create_recon_stage(
    ...     p_paramTV, p_use_manual_masks=False, p_do_nlm_denoising=False
    ... )
    >>> recon_stage.inputs.inputnode.input_images = [
    ...     'sub-01_run-1_T2w.nii.gz', 'sub-01_run-2_T2w.nii.gz'
    ... ]
    >>> recon_stage.inputs.inputnode.input_masks = [
    ...     'sub-01_run-1_T2w_mask.nii.gz', 'sub-01_run-2_T2w_mask.nii.gz'
    ... ]
    >>> recon_stage.inputs.inputnode.stacks_order = [2, 1]
    >>> recon_stage.run()  # doctest: +SKIP
    """
    recon_stage = pe.Workflow(name=name)

    # Set up a node to define all inputs required for the
    # reconstruction workflow
    input_fields = ["input_images", "input_masks", "stacks_order"]
    if p_do_nlm_denoising:
        input_fields += ["input_images_nlm"]
    if p_do_reconstruct_labels:
        input_fields += ["input_labels"]

    inputnode = pe.Node(
        interface=util.IdentityInterface(fields=input_fields), name="inputnode"
    )

    output_fields = [
        "output_sr",
        "output_sdi",
        "output_hr_mask",
        "output_transforms",
        "output_json_path",
        "output_sr_png",
        "output_TV_parameters",
    ]
    if p_do_reconstruct_labels:
        output_fields += ["output_labelmap"]
    if p_do_multi_parameters:
        output_fields += ["output_TV_params"]

    outputnode = pe.Node(
        interface=util.IdentityInterface(fields=output_fields),
        name="outputnode",
    )

    srtkImageReconstruction = pe.Node(
        interface=reconstruction.MialsrtkImageReconstruction(
            sub_ses=p_sub_ses, skip_svr=p_skip_svr, verbose=p_verbose
        ),
        name="srtkImageReconstruction",
    )

    if p_do_nlm_denoising:
        sdiComputation = pe.Node(
            interface=reconstruction.MialsrtkSDIComputation(
                sub_ses=p_sub_ses, verbose=p_verbose
            ),
            name="sdiComputation",
        )

    srtkTVSuperResolution = pe.Node(
        interface=reconstruction.MialsrtkTVSuperResolution(
            sub_ses=p_sub_ses,
            use_manual_masks=p_use_manual_masks,
            verbose=p_verbose,
        ),
        name="srtkTVSuperResolution",
    )

    # Missing TV parameters fall back to their defaults. In multi-parameter
    # mode every value is turned into a list so it can serve as an iterable.
    tv_parameters = _resolve_tv_parameters(
        p_paramTV, as_lists=p_do_multi_parameters
    )
    if p_do_multi_parameters:
        srtkTVSuperResolution.iterables = tv_parameters
    else:
        for input_name, value in tv_parameters:
            setattr(srtkTVSuperResolution.inputs, input_name, value)

    if p_do_refine_hr_mask:
        srtkHRMask = pe.Node(
            interface=postprocess.MialsrtkRefineHRMaskByIntersection(
                verbose=p_verbose
            ),
            name="srtkHRMask",
        )
    else:
        srtkHRMask = pe.Node(
            interface=postprocess.BinarizeImage(), name="srtkHRMask"
        )

    if p_do_reconstruct_labels:
        recon_labels_stage = recon_labelmap_stage.create_recon_labelmap_stage(
            p_sub_ses=p_sub_ses, p_verbose=p_verbose
        )

    # Initial reconstruction: motion estimation and SDI.
    recon_stage.connect(
        inputnode,
        ("input_masks", utils.sort_ascending),
        srtkImageReconstruction,
        "input_masks",
    )
    recon_stage.connect(
        inputnode, "stacks_order", srtkImageReconstruction, "stacks_order"
    )

    if p_do_nlm_denoising:
        # Motion is estimated on the denoised images, while the SDI fed to
        # the TV reconstruction is recomputed from the original images.
        recon_stage.connect(
            inputnode,
            ("input_images_nlm", utils.sort_ascending),
            srtkImageReconstruction,
            "input_images",
        )
        recon_stage.connect(
            inputnode, "stacks_order", sdiComputation, "stacks_order"
        )
        recon_stage.connect(
            inputnode,
            ("input_images", utils.sort_ascending),
            sdiComputation,
            "input_images",
        )
        recon_stage.connect(
            inputnode,
            ("input_masks", utils.sort_ascending),
            sdiComputation,
            "input_masks",
        )
        recon_stage.connect(
            srtkImageReconstruction,
            ("output_transforms", utils.sort_ascending),
            sdiComputation,
            "input_transforms",
        )
        recon_stage.connect(
            srtkImageReconstruction,
            "output_sdi",
            sdiComputation,
            "input_reference",
        )
        recon_stage.connect(
            sdiComputation, "output_sdi", srtkTVSuperResolution, "input_sdi"
        )
    else:
        recon_stage.connect(
            inputnode,
            ("input_images", utils.sort_ascending),
            srtkImageReconstruction,
            "input_images",
        )
        recon_stage.connect(
            srtkImageReconstruction,
            "output_sdi",
            srtkTVSuperResolution,
            "input_sdi",
        )

    # TV super-resolution.
    recon_stage.connect(
        inputnode,
        ("input_images", utils.sort_ascending),
        srtkTVSuperResolution,
        "input_images",
    )
    recon_stage.connect(
        srtkImageReconstruction,
        ("output_transforms", utils.sort_ascending),
        srtkTVSuperResolution,
        "input_transforms",
    )
    recon_stage.connect(
        inputnode,
        ("input_masks", utils.sort_ascending),
        srtkTVSuperResolution,
        "input_masks",
    )
    recon_stage.connect(
        inputnode, "stacks_order", srtkTVSuperResolution, "stacks_order"
    )

    # Optional label-map reconstruction, reusing the estimated transforms.
    if p_do_reconstruct_labels:
        recon_stage.connect(
            inputnode,
            ("input_labels", utils.sort_ascending),
            recon_labels_stage,
            "inputnode.input_labels",
        )
        recon_stage.connect(
            inputnode,
            ("input_masks", utils.sort_ascending),
            recon_labels_stage,
            "inputnode.input_masks",
        )
        recon_stage.connect(
            srtkImageReconstruction,
            ("output_transforms", utils.sort_ascending),
            recon_labels_stage,
            "inputnode.input_transforms",
        )
        recon_stage.connect(
            srtkImageReconstruction,
            "output_sdi",
            recon_labels_stage,
            "inputnode.input_reference",
        )
        recon_stage.connect(
            inputnode,
            "stacks_order",
            recon_labels_stage,
            "inputnode.stacks_order",
        )
        recon_stage.connect(
            recon_labels_stage,
            "outputnode.output_labelmap",
            outputnode,
            "output_labelmap",
        )

    # High-resolution mask: refined by intersection, or a binarized SDI.
    if p_do_refine_hr_mask:
        recon_stage.connect(
            inputnode,
            ("input_images", utils.sort_ascending),
            srtkHRMask,
            "input_images",
        )
        recon_stage.connect(
            inputnode,
            ("input_masks", utils.sort_ascending),
            srtkHRMask,
            "input_masks",
        )
        recon_stage.connect(
            srtkImageReconstruction,
            ("output_transforms", utils.sort_ascending),
            srtkHRMask,
            "input_transforms",
        )
        recon_stage.connect(
            srtkImageReconstruction, "output_sdi", srtkHRMask, "input_sr"
        )
    else:
        recon_stage.connect(
            srtkImageReconstruction, "output_sdi", srtkHRMask, "input_image"
        )

    # Workflow outputs.
    if p_do_nlm_denoising:
        recon_stage.connect(
            sdiComputation, "output_sdi", outputnode, "output_sdi"
        )
    else:
        recon_stage.connect(
            srtkImageReconstruction, "output_sdi", outputnode, "output_sdi"
        )

    recon_stage.connect(
        srtkImageReconstruction,
        "output_transforms",
        outputnode,
        "output_transforms",
    )
    recon_stage.connect(
        srtkTVSuperResolution, "output_sr", outputnode, "output_sr"
    )
    recon_stage.connect(
        srtkHRMask, "output_srmask", outputnode, "output_hr_mask"
    )
    recon_stage.connect(
        srtkTVSuperResolution,
        "output_json_path",
        outputnode,
        "output_json_path",
    )
    recon_stage.connect(
        srtkTVSuperResolution, "output_sr_png", outputnode, "output_sr_png"
    )
    recon_stage.connect(
        srtkTVSuperResolution,
        "output_TV_parameters",
        outputnode,
        "output_TV_parameters",
    )

    if p_do_multi_parameters:
        # Expose the explored parameter grid alongside the results.
        outputnode.inputs.output_TV_params = tv_parameters

    return recon_stage, srtkTVSuperResolution.name