Skip to content

Commit fb6c2fb

Browse files
authored
Merge pull request #90 from jusack/2drfft-for-nd
2D rfft for N-d arrays
2 parents 76f44a1 + 80eddc9 commit fb6c2fb

File tree

5 files changed

+37
-37
lines changed

5 files changed

+37
-37
lines changed

src/plan.jl

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,6 @@ function AbstractFFTs.plan_rfft(x::AbstractArray{T,N}, region; kwargs...)::FFTAP
6565
pinv = FFTAInvPlan{Complex{T},FFTN}()
6666
return FFTAPlan_re{Complex{T},FFTN}(tuple(g), region, FFT_FORWARD, pinv, size(x,region[]))
6767
elseif FFTN == 2
68-
if N !== 2
69-
throw(ArgumentError("2D real FFT only supported for 2D arrays"))
70-
end
7168
sort!(region)
7269
g1 = CallGraph{Complex{T}}(size(x,region[1]))
7370
g2 = CallGraph{Complex{T}}(size(x,region[2]))
@@ -85,9 +82,6 @@ function AbstractFFTs.plan_brfft(x::AbstractArray{T,N}, len, region; kwargs...):
8582
pinv = FFTAInvPlan{T,FFTN}()
8683
return FFTAPlan_re{T,FFTN}((g,), region, FFT_BACKWARD, pinv, len)
8784
elseif FFTN == 2
88-
if N !== 2
89-
throw(ArgumentError("2D real FFT only supported for 2D arrays"))
90-
end
9185
sort!(region)
9286
g1 = CallGraph{T}(len)
9387
g2 = CallGraph{T}(size(x,region[2]))
@@ -120,8 +114,8 @@ function LinearAlgebra.mul!(y::AbstractArray{U,N}, p::FFTAPlan_cx{T,1}, x::Abstr
120114
if size(p, 1) != size(x, p.region[])
121115
throw(DimensionMismatch("plan has size $(size(p, 1)), but input array has size $(size(x, p.region[])) along region $(p.region[])"))
122116
end
123-
Rpre = CartesianIndices(size(x)[1:p.region-1])
124-
Rpost = CartesianIndices(size(x)[p.region+1:end])
117+
Rpre = CartesianIndices(size(x)[1:p.region[]-1])
118+
Rpost = CartesianIndices(size(x)[p.region[]+1:end])
125119
for Ipre in Rpre
126120
for Ipost in Rpost
127121
@views fft!(y[Ipre,:,Ipost], x[Ipre,:,Ipost], 1, 1, p.dir, p.callgraph[1][1].type, p.callgraph[1], 1)
@@ -262,37 +256,43 @@ function Base.:*(p::FFTAPlan_re{T,1}, x::AbstractArray{T,N}) where {T<:Complex,
262256
throw(ArgumentError("only FFT_BACKWARD supported for complex arrays"))
263257
end
264258

265-
#### 2D plan 2D array
259+
#### 2D plan ND array
266260
##### Forward
267-
function Base.:*(p::FFTAPlan_re{Complex{T},2}, x::AbstractArray{T,2}) where {T<:Real}
261+
function Base.:*(p::FFTAPlan_re{Complex{T},2}, x::AbstractArray{T,N}) where {T<:Real, N}
268262
Base.require_one_based_indexing(x)
269263
if p.dir === FFT_FORWARD
270264
half_1 = 1:(p.flen ÷ 2 + 1)
271265
x_c = similar(x, Complex{T})
272266
copy!(x_c, x)
273267
y = similar(x_c)
274268
LinearAlgebra.mul!(y, complex(p), x_c)
275-
return y[half_1, :]
269+
return copy(selectdim(y, p.region[1], half_1))
276270
end
277271
throw(ArgumentError("only FFT_FORWARD supported for real arrays"))
278272
end
279273

280274
##### Backward
281-
function Base.:*(p::FFTAPlan_re{T,2}, x::AbstractArray{T,2}) where {T<:Complex}
275+
function Base.:*(p::FFTAPlan_re{T,2}, x::AbstractArray{T,N}) where {T<:Complex, N}
282276
Base.require_one_based_indexing(x)
283-
if size(p, 1) ÷ 2 + 1 != size(x, 1)
284-
throw(DimensionMismatch("real 2D plan has size $(size(p)). First dimension of input array should have size ($(size(p, 1) ÷ 2 + 1)), but has size $(size(x, 1))"))
277+
if size(p, 1) ÷ 2 + 1 != size(x, p.region[1])
278+
throw(DimensionMismatch("real 2D plan has size $(size(p)). First transform dimension of input array should have size ($(size(p, 1) ÷ 2 + 1)), but has size $(size(x, p.region[1]))"))
285279
end
286280
if p.dir === FFT_BACKWARD
281+
res_size = ntuple(i->ifelse(i==p.region[1], p.flen, size(x,i)), ndims(x))
287282
# for the inverse transformation we have to reconstruct the full array
288-
m, n = size(x)
289283
half_1 = 1:(p.flen ÷ 2 + 1)
290284
half_2 = half_1[end]+1:p.flen
291-
x_full = similar(x, p.flen, n)
292-
x_full[1:m, :] = x
293-
start_reverse = m - iseven(p.flen)
294-
map!(conj, view(x_full, (m + 1):p.flen, 1), view(x, start_reverse:-1:2, 1))
295-
map!(conj, view(x_full, half_2, 2:n), view(x, start_reverse:-1:2, n:-1:2))
285+
x_full = similar(x, res_size)
286+
# use first half as is
287+
copy!(selectdim(x_full, p.region[1], half_1), x)
288+
289+
# the second half in the first transform dimension is reversed and conjugated
290+
x_half_2 = selectdim(x_full, p.region[1], half_2) # view to the second half of x
291+
start_reverse = size(x, p.region[1]) - iseven(p.flen)
292+
293+
map!(conj, x_half_2, selectdim(x, p.region[1], start_reverse:-1:2))
294+
# for the 2D transform we have to reverse index 2:end of the same block in the second transform dimension as well
295+
reverse!(selectdim(x_half_2, p.region[2], 2:size(x, p.region[2])), dims=p.region[2])
296296

297297
y = similar(x_full)
298298
LinearAlgebra.mul!(y, complex(p), x_full)

test/argument_checking.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,3 @@ end
8282
end
8383
end
8484
end
85-
86-
@testset "2D real FFT only supported for 2D arrays" begin
87-
xr = zeros(2, 2, 2)
88-
xc = complex(xr)
89-
@test_throws ArgumentError("2D real FFT only supported for 2D arrays") plan_rfft(xr, 2:3)
90-
@test_throws ArgumentError("2D real FFT only supported for 2D arrays") plan_brfft(xc, 2, 2:3)
91-
end

test/onedim/real_backward.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ end
2929
@testset "1D plan, ND array. Size: $n" for n in 1:64
3030
x = randn(n, n + 1, n + 2)
3131

32+
@testset "round tripping with irfft, r=$r" for r in 1:3
33+
@test irfft(rfft(x, r), size(x,r), r) x
34+
end
35+
3236
@testset "against 1D array with mapslices, r=$r" for r in 1:3
3337
y = rfft(x, r)
3438
@test brfft(y, size(x, r), r) == mapslices(t -> brfft(t, size(x, r)), y; dims = r)

test/twodim/real_backward.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,19 @@ end
2727
@testset "2D plan, ND array. Size: $n" for n in 1:64
2828
x = randn(n, n + 1, n + 2)
2929

30-
@testset "against 1D array with mapslices, r=$r" for r in [[1,2], [1,3], [2,3]]
31-
# y = rfft(x, r)
32-
y = fft(x, r) # to produce y while tests are broken
33-
@test_broken brfft(y, size(x, r), r) == mapslices(t -> brfft(t, size(x, r)), y; dims = r)
30+
@testset "round trip with irfft, r=$r" for r in [[1,2], [1,3], [2,3]]
31+
@test x irfft(rfft(x,r), size(x,r[1]), r)
32+
end
33+
34+
@testset "against 2D array with mapslices, r=$r" for r in [[1,2], [1,3], [2,3]]
35+
y = rfft(x, r)
36+
@test brfft(y, size(x, r[1]), r) == mapslices(t -> brfft(t, size(x, r[1])), y; dims = r)
3437
end
3538
end
3639

3740
@testset "allocations" begin
3841
X = randn(256, 256)
3942
Y = rfft(X)
4043
brfft(Y, 256) # compile
41-
@test (@test_allocations brfft(Y, 256)) <= 68
44+
@test (@test_allocations brfft(Y, 256)) <= 82
4245
end

test/twodim/real_forward.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ using FFTA, Test
77
y_ref[1] = length(x)
88
@test y y_ref
99
x = randn(N,N)
10-
@test_broken rfft(x) rfft(reshape(x,1,N,N), [2,3])[1,:,:]
11-
@test_broken rfft(x) rfft(reshape(x,1,N,N,1), [2,3])[1,:,:,1]
12-
@test_broken rfft(x) rfft(reshape(x,1,1,N,N,1), [3,4])[1,1,:,:,1]
10+
@test rfft(x) rfft(reshape(x,1,N,N), [2,3])[1,:,:]
11+
@test rfft(x) rfft(reshape(x,1,N,N,1), [2,3])[1,:,:,1]
12+
@test rfft(x) rfft(reshape(x,1,1,N,N,1), [3,4])[1,1,:,:,1]
1313
@test size(rfft(x)) == (N÷2+1, N)
1414
end
1515

@@ -31,12 +31,12 @@ end
3131
x = randn(n, n + 1, n + 2)
3232

3333
@testset "against 1D array with mapslices, r=$r" for r in [[1,2], [1,3], [2,3]]
34-
@test_broken rfft(x, r) == mapslices(rfft, x; dims = r)
34+
@test rfft(x, r) == mapslices(rfft, x; dims = r)
3535
end
3636
end
3737

3838
@testset "allocations" begin
3939
X = randn(256, 256)
4040
rfft(X) # compile
41-
@test (@test_allocations rfft(X)) <= 62
41+
@test (@test_allocations rfft(X)) <= 63
4242
end

0 commit comments

Comments
 (0)