Image deconvolution (noiseless)
Let \(\mathbf{u}^*\) represent a noiseless and distortionless image; likewise let
\[ \large \mathbf{b} = \mbox{Op}(\mathbf{u}^*) + \boldsymbol{\eta} \]
represent the observed image, where
\(\mbox{Op}(\mathbf{u})\) represents the forward model; in particular, let \( \mbox{Op}(\mathbf{u}) = h *\mathbf{u} \), i.e. 2D convolution.
\(\boldsymbol{\eta}\) is additive Gaussian noise (for this particular example, \(\sigma^2_{\boldsymbol{\eta}} = 0\)).
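For the above case, the MAP (maximum a posteriori) estimator is given by
\[ \large \hat{\mathbf{u}} = \arg\min_{\mathbf{u}} \; F(\mathbf{u}), \qquad F(\mathbf{u}) = \frac{1}{2} \| \mbox{Op}(\mathbf{u}) - \mathbf{b} \|_2^2 \qquad (1)\]
Taking \( \mbox{Op}(\mathbf{u}) = h *\mathbf{u} \), then \( F(\mathbf{u}) = \frac{1}{2} \| h*\mathbf{u} - \mathbf{b} \|_2^2 \), and
the gradient of \(F(\mathbf{u})\) is given by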
\[ \large \nabla F(\mathbf{u}) = g*h*\mathbf{u} - g*\mathbf{b} \qquad (2)\] where
\( g = {\cal{A}}\{h\}\) is the adjoint of filter \(h\) (i.e. \(g=\)
`np.flip(np.flip(h,1),0)`).
In the frequency domain, (2) becomes
\[ \large {\cal{F}}\{ \nabla F(\mathbf{u}) \} = H_F^* \odot H_F\odot U_F - H_F^* \odot B_F \qquad (3)\] where
\({\cal{F}}\{\cdot\}\) represents the direct 2D Fourier transform,
\(H_F = {\cal{F}}\{h\}\) is the (zero-padded) 2D Fourier transform of \(h\),
\(H_F^* = {\operatorname{conj}}(H_F)\), the complex conjugate of \(H_F\),
\(U_F = {\cal{F}}\{\mathbf{u}\}\) and \(B_F = {\cal{F}}\{\mathbf{b}\}\), and
\(\odot\) represents Hadamard (element-wise) multiplication.
Convolution in the frequency domain:
Recall that
boundary condition
See this explanation
Simulation setup
The collapsible items below include the code to generate a synthetic problem to test different algorithms that target the solution of (1).
1. Load F2O
# Generic F2O imports
import F2O.F2O_utils as F2O
import F2O.fwOp.fwOperator as fwOp
# F2O 'utils' for reading / ploting images
from F2O.imgUtils.image_utils import ImgPlot, ImgRead
# F2O 'utils' for (i) adding noise, (ii) collecting image metrics and (iii) applying the forward model
from F2O.imgUtils.image_utils import ImgMetrics, ImgApplyFwNoise
from F2O.noise.noiseModels import noiseModels
# Other packages
import matplotlib.pylab as PLT
import numpy as np
# If you get an error while loading the F2O, then
#
# * Exit Jupyter
# * Go to the F2O root dir, and execute
# export PYTHONPATH=$PYTHONPATH:`pwd`
# * Relaunch Jupyter
---------------------------------------------------------------------------
ModuleNotFoundError Traceback (most recent call last)
/tmp/ipykernel_377/1396103323.py in <module>
1 # Generic F2O imports
----> 2 import F2O.F2O_utils as F2O
3 import F2O.fwOp.fwOperator as fwOp
4
5 # F2O 'utils' for reading / ploting images
ModuleNotFoundError: No module named 'F2O'
2. Load test images (SIPI Image database)
# Imports needed to read images from a URL address
import requests
from io import BytesIO

# Define an 'F2O-image' object
testImgs = ImgRead()
# If present, disable JAX: it has limited support for convolve2d and
# sometimes cannot determine the best cuDNN convolution algorithm.
testImgs.f2oJax.have_jax = False

# Seconds to wait on connect / between received bytes; without a timeout a
# stalled server would hang the notebook cell forever.
httpTimeout = 60

# Test images from the SIPI image database
fname = {0: requests.get('http://sipi.usc.edu/database/misc/5.2.10.tiff', timeout=httpTimeout),    # bridge (grayscale)
         1: requests.get('http://sipi.usc.edu/database/misc/boat.512.tiff', timeout=httpTimeout),  # boat (grayscale)
         2: requests.get('http://sipi.usc.edu/database/misc/4.2.03.tiff', timeout=httpTimeout),    # mandrill (color)
         }

# Register the images actually used, with their read mode ('g' / 'c').
# Fail fast on an HTTP error: otherwise the HTML error page would be handed
# to the image reader and fail later with an obscure decoding error.
for idx, mode in ((1, 'g'), (2, 'c')):
    resp = fname.get(idx)
    resp.raise_for_status()
    testImgs.list.append([BytesIO(resp.content), mode])

u = testImgs.readListImgs()  # read list of images, normalize them between 0 and 1

pltImg = ImgPlot()
pltImg.plotNImgs(u, len(u), None, 5)
3. Configure the forward operator
3.1 Generate filters
# Forward-operator helper used to build the blur kernels
kernel = fwOp.conv2DOp()

# Blur filters: a 9x9 Gaussian (second argument 10.0, presumably its spread)
# and a 5x5 moving average
H = [kernel.gauss2D((9, 9), 10.0),   # Gaussian
     kernel.average((5, 5))]         # Average

# Human-readable labels, keyed by the filter's index in H
Hlabel = dict(enumerate(['Gaussian', 'Average']))
3.2 Set Op(.)
# One spatial-domain forward operator (2D convolution) per blur kernel in H.
Op = []
for l, kern in enumerate(H):
    op = fwOp.fwOp(linOp=F2O.f2oDef.fAx_conv2D,  # Op is 2D convolution
                   A=kern)                       # set the kernel
    Op.append(op)
    op.label = Hlabel.get(l)
    op.vecFlag = True  # input / output data are assumed / forced to be vectorized
    # JAX's convolve2d has limited support (only boundary='fill', fillvalue=0)
    # and sometimes cannot determine the best cuDNN convolution algorithm.
    op.f2oJax.have_jax = False
    # Boundary handling, following scipy.signal.convolve2d naming:
    # 'fill' (zero-pad), 'wrap' (circular) or 'symm' (symmetric reflection,
    # noted as the default). When solving a deconvolution problem this
    # setting should be kept, otherwise the restoration quality decreases.
    op.boundary = 'symm'
4. Apply forward model (include noise model)
4.1 Set parameters for a noiseless (Gaussian) example
# Noise level: 0.0 makes this a noiseless example.
sigma = 0.0

# Configure a zero-mean Gaussian noise model with the level above,
# then select the corresponding noise-adding routine.
noise = noiseModels()
noise.model = F2O.f2oDef.noise_Gaussian
noise.mean, noise.sigma = 0., sigma
addNoise = noise.sel_NoiseModel()
4.2 Apply forward model
# Apply every forward operator in Op (plus the noise model) to all images in u.
# NOTE: u can be either a single data array or a list of data arrays.
applyModel = ImgApplyFwNoise(u, Op, noise)

# Collect image metrics (handy when performing comparisons) and keep the
# images for display.
applyModel.computeMetrics = True
applyModel.displayImgs = True

# b[k][l]: observed image k under operator l; metrics: quality metrics of the
# observed images; pltImg: plotting helper holding the images to show.
b, metrics, pltImg = applyModel.obsImg()
4.3 Show resulting images
# For each test image: the original plus one observed version per operator
# (hence len(Op) + 1 panels), captioned with the text prepared by obsImg().
for k in range(len(u)):
    pltImg.plotNImgs(pltImg.imgShow[k],
                     len(Op) + 1,
                     pltImg.txtN[k])
GD – spatial-based solution
1. Load F2O's GD (spatial) routine
from F2O.F2O_sptl import gd as GD
2. Set the arguments that define the optimization problem
# Solver arguments for the spatial-domain GD run.
args = F2O.argsF2O()
args.f2oJax.have_jax = False   # force-disable JAX support
args.verbose = False

# Cost: F(x) = 0.5|| Op(x) - b ||_2^2, where Op(.) is linear
args.fCostClass = args.f2oDef.cost_L2_lin
# Presumably selects the spatial (not frequency) formulation -- confirm in F2O.
args.freqSol = False
3. Call the routine to solve the problem
# Step-size policies. NOTE: indexed below by the *image* index k, so
# image 0 uses BBv1 and image 1 uses lagged Cauchy.
ssPolicy = [F2O.f2oDef.ss_BBv1,
            F2O.f2oDef.ss_CauchyLagged]
nIter = 20                 # number of GD iterations per problem

x = []                     # x[k][l]: restoration of image k under operator l
gdStats = []               # gdStats[k][l]: statistics returned by GD
recMetrics = ImgMetrics()  # quality metrics of the restored images
nProblems = len(u) * len(Op)
for k in range(len(u)):
    x.append([])
    gdStats.append([])
    recMetrics.appendEmpty()
    for l in range(len(Op)):
        # This option should match the one used when generating the
        # (synthetic) observed image; otherwise reconstruction quality decreases.
        Op[l].boundary = 'symm'
        if args.verbose:
            print('Solving problem {:d} out of {:d}'.format(k*len(Op)+l+1, nProblems))
        # NOTE(review): 'ssPoliciy' looks like a misspelling of 'ssPolicy'. If
        # argsF2O does not define this exact name, the assignment silently
        # creates a new attribute and the policy is ignored -- confirm in F2O.
        args.ssPoliciy = ssPolicy[k]
        sol = GD(Op[l], b[k][l], nIter, args)
        x[k].append(sol[0])
        gdStats[k].append(sol[1])
        recMetrics.computeAll(u[k], x[k][l], k)
4. Show results and statistics
# For every (image k, operator l) pair: show original / observed / restored
# images with their quality metrics, then plot the cost against time.
for k in range(len(u)):
    for l in range(len(Op)):
        txtRec = [
            'Original',
            'Observed (H {}) \n PSNR: {:1.2f} \n SNR: {:1.2f} \n MSE: {:.2e} \n SSIM: {:1.2f} '.format(
                Op[l].label,
                metrics.valPsnr[k][l], metrics.valSnr[k][l],
                metrics.valMse[k][l], metrics.valSsim[k][l]),
            'Restore (GD) \n PSNR: {:1.2f} \n SNR: {:1.2f} \n MSE: {:.2e} \n SSIM: {:1.2f}'.format(
                recMetrics.valPsnr[k][l], recMetrics.valSnr[k][l],
                recMetrics.valMse[k][l], recMetrics.valSsim[k][l]),
        ]
        imgShow = [u[k], b[k][l], x[k][l]]
        pltImg.plotNImgs(imgShow, 3, txtRec, winSize=7)

        fig = PLT.figure(figsize=(24, 16))
        ax1 = fig.add_subplot(2, 1, 1)
        # Per the axis labels: column 2 of the stats is time, column 0 the cost.
        # NOTE: use ax1.plot(gdStats[k][l][:,0], ...) to plot iterations instead of time
        ax1.plot(gdStats[k][l][:, 2], gdStats[k][l][:, 0],
                 label=r'$\alpha_k$ : {0}'.format(args.f2oDef.ss_list[ssPolicy[k]]))
        ax1.legend(loc='upper right', fontsize=20)
        # Single '\|' per side: '\|\|' would render as four vertical bars.
        ax1.set_ylabel(r'$f(\mathbf{u}) = \frac{1}{2} \|$Op$(\mathbf{u}) - \mathbf{b} \|_2^2$', fontsize=20)
        ax1.set_xlabel('Time', fontsize=20)
5. Additional comments
blah
blah 2
GD – frequency-based solution
1. Load F2O's GD (frequency) routine
from F2O.F2O_freq import gd as fGD
2. Set parameters and forward operator in frequency
# Solver arguments for the frequency-domain GD run.
argsF = F2O.argsF2O()
argsF.f2oJax.have_jax = True   # force using JAX support
argsF.verbose = False
# Cost: F(x) = 0.5|| Op(x) - b ||_2^2, where Op(.) is linear
argsF.fCostClass = argsF.f2oDef.cost_L2_lin
# Pad the data (symmetric mode) before moving to the frequency domain.
argsF.padFlag = True
argsF.padMode = 'symmetric'

# Frequency-domain forward operators, one per blur kernel in H.
freqOp = []
for l, kern in enumerate(H):
    fop = fwOp.fwOp_f(linOp=F2O.f2oDef.fAx_conv2D,  # Op is 2D convolution
                      A=kern)                       # set the kernel
    freqOp.append(fop)
    fop.label = Hlabel.get(l)
    fop.vecFlag = True           # input / output data are assumed / forced to be vectorized
    fop.f2oJax.have_jax = True   # JAX
3. Call the routine to solve the problem
# Step-size policies. NOTE: indexed below by the *image* index k, so
# image 0 uses a constant step and image 1 uses lagged Cauchy.
ssPolicy = [F2O.f2oDef.ss_Cte,
            F2O.f2oDef.ss_CauchyLagged]
nIter = 20                 # number of GD iterations per problem

x = []                     # x[k][l]: restoration of image k under operator l
gdStats = []               # gdStats[k][l]: statistics returned by fGD
recMetrics = ImgMetrics()  # quality metrics of the restored images
nProblems = len(u) * len(freqOp)
for k in range(len(u)):
    x.append([])
    gdStats.append([])
    recMetrics.appendEmpty()
    # Iterate over the operators actually passed to fGD.
    for l in range(len(freqOp)):
        argsF.cleanShpVars()  # needed here since the 'argsF' variable is being reused.
        if argsF.verbose:
            print('Solving problem {:d} out of {:d}'.format(k*len(freqOp)+l+1, nProblems))
        # NOTE(review): 'ssPoliciy' looks like a misspelling of 'ssPolicy'. If
        # argsF2O does not define this exact name, the assignment silently
        # creates a new attribute and the policy is ignored -- confirm in F2O.
        argsF.ssPoliciy = ssPolicy[k]
        # Presumably the constant step size (for ss_Cte) and a multiplicative
        # step factor -- TODO confirm their exact meaning in F2O.
        argsF.ssCte = 3.5e-1
        argsF.mulCte = 0.3
        sol = fGD(freqOp[l], b[k][l], nIter, argsF)  # NOTE: using "Op" (spatial) would flag an error
        # Clip the restoration to the [0, 1] range the input images are normalized to.
        x[k].append(np.clip(sol[0], 0.0, 1.0))
        gdStats[k].append(sol[1])
        recMetrics.computeAll(u[k], x[k][l], k)
4. Show results and statistics
# For every (image k, operator l) pair: show original / observed / restored
# (frequency-domain GD) images with their metrics, then plot cost against time.
for k in range(len(u)):
    for l in range(len(Op)):
        txtRec = [
            'Original',
            'Observed (H {}) \n PSNR: {:1.2f} \n SNR: {:1.2f} \n MSE: {:.2e} \n SSIM: {:1.2f} '.format(
                Op[l].label,
                metrics.valPsnr[k][l], metrics.valSnr[k][l],
                metrics.valMse[k][l], metrics.valSsim[k][l]),
            'Restore (GD) \n PSNR: {:1.2f} \n SNR: {:1.2f} \n MSE: {:.2e} \n SSIM: {:1.2f}'.format(
                recMetrics.valPsnr[k][l], recMetrics.valSnr[k][l],
                recMetrics.valMse[k][l], recMetrics.valSsim[k][l]),
        ]
        imgShow = [u[k], b[k][l], x[k][l]]
        pltImg.plotNImgs(imgShow, 3, txtRec, winSize=7)

        fig = PLT.figure(figsize=(24, 16))
        ax1 = fig.add_subplot(2, 1, 1)
        # Per the axis labels: column 2 of the stats is time, column 0 the cost.
        # NOTE: use ax1.plot(gdStats[k][l][:,0], ...) to plot iterations instead of time
        ax1.plot(gdStats[k][l][:, 2], gdStats[k][l][:, 0],
                 label=r'$\alpha_k$ : {0}'.format(argsF.f2oDef.ss_list[ssPolicy[k]]))
        ax1.legend(loc='upper right', fontsize=20)
        # Single '\|' per side: '\|\|' would render as four vertical bars.
        ax1.set_ylabel(r'$f(\mathbf{u}) = \frac{1}{2} \|$Op$(\mathbf{u}) - \mathbf{b} \|_2^2$', fontsize=20)
        ax1.set_xlabel('Time', fontsize=20)
5. Additional comments
blah
blah 2