Skip to content
master
  1using JLD2
  2using LazyStack
  3using PaddedViews
  4using Printf
  5
  6# Gaussian and its derivatives.
  7
  8function gaussian(α, L, center; gaussian_factor=1.0, δx)
  9    # Heuristic to have width that make sense.
 10    # 0.1: sharp, close to χ; 1.0 smooth.
 11    σ = 0.5 / δx * L * gaussian_factor
 12    x -> α / ((2π) * σ) * exp(-((x - center) / σ)^2 / 2)
 13end
 14function ∂gaussian(α, L, center; gaussian_factor=1.0, δx)
 15    σ = 0.5 / δx * L * gaussian_factor
 16    x -> -α / ((2π) * σ^3) * (x - center) * exp(-((x - center) / σ)^2 / 2)
 17end
 18function ∂²gaussian(α, L, center; gaussian_factor=1.0, δx)
 19    σ = 0.5 / δx * L * gaussian_factor
 20    x -> α / ((2π) * σ^5) * exp(-((x - center) / σ)^2 / 2) * ((x - center)^2 - σ^2)
 21end
 22function ∂³gaussian(α, L, center; gaussian_factor=1.0, δx)
 23    σ = 0.5 / δx * L * gaussian_factor
 24    x -> -α / ((2π) * σ^7) * exp(-((x - center) / σ)^2 / 2) * (x - center) * ((x - center)^2 - 3σ^2)
 25end
 26function ∂⁴gaussian(α, L, center; gaussian_factor=1.0, δx)
 27    σ = 0.5 / δx * L * gaussian_factor
 28    x -> α / ((2π) * σ^9) * exp(-((x - center) / σ)^2 / 2) * ((x - center)^4 - 6 * σ^2 * (x - center)^2 + 3σ^4)
 29end
 30
 31# In-place fn_m functions.
 32
 33function f2_2!(result, λi, λj, δ, i, j, gaussians, ∂gaussians, ∂²gaussians)
 34    λij = λj - λi
 35    if norm(λij) > δ
 36        @. result = 2 * (gaussians[j] - gaussians[i] - ∂gaussians[i] * λij) / λij^2
 37    else
 38        # Use derivative formula.
 39        copy!(result, ∂²gaussians[i])
 40    end
 41end
 42
 43function f3_3!(result, r1, r2, λi, λj, λk, δ, i, j, k,
 44               gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians)
 45    λjk = λk - λj
 46    λij = λj - λi
 47    if norm(λjk) > δ
 48        f2_2!(r1, λi, λk, δ, i, k, gaussians, ∂gaussians, ∂²gaussians)
 49        f2_2!(r2, λi, λj, δ, i, j, gaussians, ∂gaussians, ∂²gaussians)
 50        @. result = 3 * (r1 - r2) / λjk
 51    elseif norm(λij) > δ
 52        @. result = -12 * (gaussians[j] - gaussians[i] - 0.5 * (∂gaussians[j] + ∂gaussians[i]) * λij) / λij^3
 53    else
 54        copy!(result, ∂³gaussians[i])
 55    end
 56end
 57
 58function f4_4!(result, r1, r11, r12, r2, r21, r22, λi, λj, λk, λl, δ, i, j, k, l,
 59               gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians, ∂⁴gaussians)
 60    λkl = λl - λk
 61    if norm(λkl) > δ
 62        f3_3!(r1, r11, r12, λi, λj, λl, δ, i, j, l, gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians)
 63        f3_3!(r2, r21, r22, λi, λj, λk, δ, i, j, k, gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians)
 64        @. result = 4 * (r1 - r2) / λkl
 65        return nothing
 66    end
 67    λik = λk - λi
 68    λjk = λk - λj
 69    if norm(λjk) > δ && norm(λkl) < δ && norm(λik) > δ
 70        f2_2!(r1, λi, λk, δ, i, k, gaussians, ∂gaussians, ∂²gaussians)
 71        f2_2!(r2, λi, λj, δ, i, j, gaussians, ∂gaussians, ∂²gaussians)
 72        @. result = -48 * ((gaussians[k] - gaussians[i] - 0.5 * (∂gaussians[k] + ∂gaussians[i]) * (λk - λi)) / (λjk * (λk - λi)^3)) - 12 * ((r1 - r2) / (λk - λj)^2)
 73        return nothing
 74    end
 75    λjl = λl - λj
 76    if norm(λjl) > δ && norm(λik) < δ && norm(λkl) < δ
 77        f2_2!(r1, λi, λj, δ, i, j, gaussians, ∂gaussians, ∂²gaussians)
 78        @. result = 4 * ∂³gaussians[i] / (λi - λj) - 12 * (∂²gaussians[i] - r1) / (λi - λj)^2
 79        return nothing
 80    end
 81    λij = λj - λi
 82    if norm(λij) > δ && norm(λjk) < δ && norm(λkl) < δ
 83        @. result = 72 * (gaussians[j] - gaussians[i] - 0.5 * (∂gaussians[j] + ∂gaussians[i]) * λij) / λij^4 - 12 * (∂gaussians[j] - ∂gaussians[i] - ∂²gaussians[j] * λij) / λij^3
 84        return nothing
 85    end
 86    if norm(λij) < δ && norm(λjk) < δ && norm(λkl) < δ
 87        copy!(result, ∂⁴gaussians[i])
 88        return nothing
 89    end
 90end
 91
 92function _precompute_operator_matrices(basis, kpt, ψkX, ∂v2, ∂²v2)
 93    n_bands = size(ψkX, 2)
 94    T = eltype(ψkX)
 95    K = zeros(T, n_bands, n_bands)
 96    X = zeros(T, n_bands, n_bands)
 97    X_sq = zeros(T, n_bands, n_bands)
 98
 99    # Pre-allocate buffers for FFTs to reduce memory allocations inside the loop.
100    ui = zeros(T, basis.fft_size)
101    uj = zeros(T, basis.fft_size)
102
103    for i in 1:n_bands
104        for j in 1:n_bands
105            K[i, j] = kinetic1d(basis, ψkX, kpt, i, j)
106            X[i, j] = potential1d(basis, ψkX, kpt, ∂v2, i, j; ui, uj)
107            X_sq[i, j] = potential1d(basis, ψkX, kpt, ∂²v2, i, j; ui, uj)
108        end
109    end
110    (; K, X, X_sq)
111end
112
"""
    calculate_order2(λs, K, X, X_sq, gaussians_info, δ)

Second-order DOS contribution at a single k-point.

# Arguments
- `λs`: band eigenvalues at this k-point.
- `K`, `X`, `X_sq`: `n_bands × n_bands` matrices as produced by
  `_precompute_operator_matrices`.
- `gaussians_info`: named tuple with fields `gaussians`, `∂gaussians`, …, `∂⁴gaussians`,
  each holding one sampled array per band.
- `δ`: degeneracy tolerance forwarded to `f3_3!` / `f4_4!`.

# Returns
A vector with the element type and length of `first(gaussians_info.gaussians)`.

The four nested band loops make the cost O(n_bands⁴ · length of each sampled array).
"""
function calculate_order2(λs, K, X, X_sq, gaussians_info, δ)
    (; gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians, ∂⁴gaussians) = gaussians_info
    n_bands = length(λs)
    dos_order2 = zeros(eltype(first(gaussians)), length(first(gaussians)))

    c3 = 1 / (4 * 6)      # prefactor of the three-band (f3) terms
    c4 = 1 / (4 * 4 * 6)  # prefactor of the four-band (f4) term

    # Scratch for the in-place kernels: each f3_3! needs a result plus two work arrays;
    # f4_4! needs a result plus six (it evaluates two f3_3! internally).
    tmp_f3_r1, tmp_f3_r11, tmp_f3_r12 = similar(dos_order2), similar(dos_order2), similar(dos_order2)
    tmp_f3_r2, tmp_f3_r21, tmp_f3_r22 = similar(dos_order2), similar(dos_order2), similar(dos_order2)
    tmp_f4_r = similar(dos_order2)
    tmp_f4_r1, tmp_f4_r11, tmp_f4_r12 = similar(dos_order2), similar(dos_order2), similar(dos_order2)
    tmp_f4_r2, tmp_f4_r21, tmp_f4_r22 = similar(dos_order2), similar(dos_order2), similar(dos_order2)

    for m in 1:n_bands
        λm = λs[m]
        # Fused broadcast: same values as real(∂²g .* X_sq[m, m] ./ 8), without temporaries.
        dos_order2 .-= ∂²gaussians[m] .* real(X_sq[m, m]) ./ 8

        for n in 1:n_bands
            λn = λs[n]
            K_mn, X_mn = K[m, n], X[m, n]

            f3_3!(tmp_f3_r1, tmp_f3_r11, tmp_f3_r12, λm, λm, λn, δ, m, m, n,
                  gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians)
            f3_3!(tmp_f3_r2, tmp_f3_r21, tmp_f3_r22, λm, λn, λn, δ, m, n, n,
                  gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians)
            dos_order2 .-= c3 .* (2 .* tmp_f3_r1 .- tmp_f3_r2) .* abs2(X_mn)

            for p in 1:n_bands
                λp = λs[p]
                K_np, X_np, K_pm = K[n, p], X[n, p], K[p, m]

                f3_3!(tmp_f3_r1, tmp_f3_r11, tmp_f3_r12, λm, λn, λp, δ, m, n, p,
                      gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians)
                term3_expr = real(2 * K_np * X_sq[m, n] * K_pm - K_mn * X_sq[n, p] * K_pm)
                dos_order2 .-= c3 .* tmp_f3_r1 .* term3_expr

                for q in 1:n_bands
                    λq = λs[q]
                    K_pq, X_pq = K[p, q], X[p, q]
                    K_qm, X_qm = K[q, m], X[q, m]

                    f4_4!(tmp_f4_r, tmp_f4_r1, tmp_f4_r11, tmp_f4_r12, tmp_f4_r2, tmp_f4_r21,
                          tmp_f4_r22, λm, λn, λp, λq, δ, m, n, p, q,
                          gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians, ∂⁴gaussians)

                    term4_expr = real(
                        (X_mn * K_np) * (K_pq * X_qm) +
                        (X_mn * K_pq) * (K_np * X_qm) +
                        (K_mn * X_pq) * (X_np * K_qm) +
                        (K_mn * X_np) * (X_pq * K_qm) -
                        2 * real((K_mn * X_np) * (K_pq * X_qm)) -
                        2 * real((X_mn * K_pq) * (X_np * K_qm)) +
                        2 * real((X_mn * K_qm) * (K_np * X_pq - X_np * K_pq))
                    )
                    dos_order2 .+= c4 .* tmp_f4_r .* term4_expr
                end
            end
        end
    end
    return dos_order2
end
183
"""
    _dos_contribution_single_disregistry(disregistry, lattice, atoms, positions_fn,
                                         Ecut, kgrid, params, xs, orders, δ, n_bands_diag)

Un-normalized Gaussian-smeared DOS contributions for one disregistry value, summed over
the k-points of the basis.

Returns a vector `doses` with `maximum(orders) + 1` entries, each of length `length(xs)`.
For every `o` in `orders`, `doses[o + 1]` holds the order-`o` contribution. Only orders 0,
1 and 2 are implemented; the entries of any other order stay zero.

# Arguments
- `positions_fn`: maps `disregistry` to the atomic positions passed to `Model`.
- `params`: provides `α`, `δ`, `gaussian_factor` and `δx` for the Gaussian kernels.
- `xs`: points at which the DOS is sampled; one Gaussian is centred at each of them.
- `δ`: tolerance below which two eigenvalues are treated as degenerate in the f_n kernels.
- `n_bands_diag`: forwarded as `n_bands` to `diagonalize`.
"""
function _dos_contribution_single_disregistry(disregistry, lattice, atoms, positions_fn,
                                               Ecut, kgrid, params, xs, orders, δ, n_bands_diag)
    doses = [zeros(length(xs)) for _ in 0:maximum(orders)]

    # Build the closures of g and its first four derivatives once, centred at every grid
    # point `ix`, instead of rebuilding them for every (eigenvalue, grid point) pair.
    # NOTE(review): `params.δ` is passed as the width scale `L` of the Gaussians; it is
    # unrelated to the degeneracy tolerance argument `δ`. Confirm this is intended.
    centred_at_xs(ctor) =
        [ctor(params.α, params.δ, ix; params.gaussian_factor, params.δx) for ix in xs]
    g_fns, ∂g_fns, ∂²g_fns, ∂³g_fns, ∂⁴g_fns =
        map(centred_at_xs, (gaussian, ∂gaussian, ∂²gaussian, ∂³gaussian, ∂⁴gaussian))
    # One array per eigenvalue: every grid-point closure evaluated at that eigenvalue.
    sample_at(fns, eigvals) = [[f(λ) for f in fns] for λ in eigvals]

    model = Model(lattice, atoms, positions_fn(disregistry);
                  terms=[Kinetic(), AtomicLocal()], n_electrons=length(atoms))
    basis = PlaneWaveBasis(model; Ecut, kgrid)
    data = diagonalize([basis]; n_bands=n_bands_diag)

    ∂v2 = compute_δv2(basis)
    ∂²v2 = compute_δ²v2(basis)

    for ik in eachindex(basis.kpoints)
        kpt = basis.kpoints[ik]
        ψkX = data.X[1][ik]
        λs = data.λ[1][ik]
        n_bands = length(λs)

        gaussians_info = (;
            gaussians   = sample_at(g_fns, λs),
            ∂gaussians  = sample_at(∂g_fns, λs),
            ∂²gaussians = sample_at(∂²g_fns, λs),
            ∂³gaussians = sample_at(∂³g_fns, λs),
            ∂⁴gaussians = sample_at(∂⁴g_fns, λs)
        )

        # Order 0: plain smeared DOS, Σ_n g(λ_n).
        if 0 in orders
            doses[1] .+= sum(gaussians_info.gaussians)
        end

        # The higher orders need the operator matrices at this k-point.
        any(>(0), orders) || continue
        mats = _precompute_operator_matrices(basis, kpt, ψkX, ∂v2, ∂²v2)

        if 1 in orders
            tmp_f2 = similar(first(gaussians_info.gaussians))
            dos_order1_k = zeros(eltype(doses[2]), length(xs))
            for i in 1:n_bands, j in 1:n_bands
                term = imag(mats.K[i, j] * mats.X[j, i])
                f2_2!(tmp_f2, λs[i], λs[j], δ, i, j, gaussians_info.gaussians,
                      gaussians_info.∂gaussians, gaussians_info.∂²gaussians)
                dos_order1_k .-= 0.5 .* term .* tmp_f2
            end
            doses[2] .+= dos_order1_k
        end

        if 2 in orders
            doses[3] .+= calculate_order2(λs, mats.K, mats.X, mats.X_sq, gaussians_info, δ)
        end
    end
    return doses
end
248
249
# Load a checkpoint written by `_dos_gaussian_save_checkpoint` after checking that it
# belongs to the same run: same `disregistries`, and the same number and length of DOS
# orders as `template_doses`. Returns `(start_index, total_doses)`.
function _dos_gaussian_load_checkpoint(path, disregistries, template_doses)
    return jldopen(path, "r") do file
        if haskey(file, "disregistries") && file["disregistries"] != disregistries
            throw(ArgumentError("checkpoint $path was written for different disregistries"))
        end
        saved = file["total_doses"]
        compatible = length(saved) == length(template_doses) &&
                     all(length(s) == length(t) for (s, t) in zip(saved, template_doses))
        compatible || throw(ArgumentError(
            "checkpoint $path does not match the requested `orders`/`xs` sizes"))
        (file["last_completed_disregistry"] + 1, saved)
    end
end

# Write the running totals to `path`. The data goes to a temporary file first, which is
# then moved over `path`, so an interruption during the write cannot truncate the
# previous checkpoint.
function _dos_gaussian_save_checkpoint(path, last_completed, total_doses, disregistries)
    tmp_path = string(path, ".tmp")
    jldopen(tmp_path, "w") do file
        file["last_completed_disregistry"] = last_completed
        file["total_doses"] = total_doses
        file["disregistries"] = disregistries  # checked on resume
    end
    mv(tmp_path, path; force=true)
    return nothing
end

"""
    dos_gaussian(lattice, atoms, positions_fn, Ecut, kgrid, params, xs;
                 disregistries, orders=[0], δ=zero(eltype(xs)), n_bands_diag=nothing,
                 checkpoint_path=nothing, cleanup_checkpoint=false)

Gaussian-smeared DOS averaged over `disregistries`, computed one disregistry at a time.

The per-order contributions of every disregistry are accumulated and finally divided by
`prod(kgrid) * length(disregistries)`.

# Keywords
- `disregistries`: non-empty collection of disregistry values to average over.
- `orders`: perturbative orders to compute; only 0, 1 and 2 are supported.
- `δ`: degeneracy tolerance forwarded to the f_n kernels.
- `n_bands_diag`: number of bands passed to the diagonalization.
- `checkpoint_path`: if given, the running totals are saved there after every disregistry,
  and an existing file is used to resume. The checkpoint must come from a run with the same
  `disregistries` and the same `orders`/`xs` sizes.
- `cleanup_checkpoint`: delete the checkpoint file once the calculation completes.

# Returns
Named tuple `(; dos, doses)`, where `doses[o + 1]` is the normalized order-`o`
contribution and `dos = sum(doses)`.

# Throws
- `ArgumentError` if `disregistries` is empty, if `orders` is empty or contains an
  unsupported order, or if an existing checkpoint does not match this run.
"""
function dos_gaussian(lattice, atoms, positions_fn, Ecut, kgrid, params, xs;
                      disregistries, orders=[0], δ=zero(eltype(xs)), n_bands_diag=nothing,
                      checkpoint_path=nothing, cleanup_checkpoint=false)
    isempty(disregistries) && throw(ArgumentError("`disregistries` must not be empty"))
    (!isempty(orders) && all(o -> o in 0:2, orders)) ||
        throw(ArgumentError("`orders` must be a non-empty subset of 0:2, got $orders"))

    n_disregistries = length(disregistries)
    start_index = 1
    total_doses = [zeros(length(xs)) for _ in 0:maximum(orders)]

    if checkpoint_path !== nothing && isfile(checkpoint_path)
        @printf "Checkpoint file found at %s. Resuming calculation.\n" checkpoint_path
        start_index, total_doses =
            _dos_gaussian_load_checkpoint(checkpoint_path, disregistries, total_doses)
    end

    if start_index <= n_disregistries
        for i in start_index:n_disregistries
            disregistry = disregistries[i]

            @printf "Calculating for disregistry %d/%d (value: %.4f)\n" i n_disregistries disregistry

            doses_contrib = _dos_contribution_single_disregistry(disregistry, lattice, atoms,
                                                                 positions_fn, Ecut, kgrid,
                                                                 params, xs, orders, δ,
                                                                 n_bands_diag)
            total_doses .+= doses_contrib

            # Save a checkpoint after each disregistry.
            if checkpoint_path !== nothing
                _dos_gaussian_save_checkpoint(checkpoint_path, i, total_doses, disregistries)
            end
        end
    else
        @printf "Calculation already completed according to checkpoint.\n"
    end

    # Final normalization: average over the k-grid and over the disregistries.
    # NOTE(review): this assumes every one of the prod(kgrid) k-points appears in the basis
    # with equal weight (no symmetry reduction). Confirm.
    normalization = 1 / (prod(kgrid) * n_disregistries)
    final_doses = total_doses .* normalization
    total_dos = sum(final_doses)

    if cleanup_checkpoint && checkpoint_path !== nothing && isfile(checkpoint_path)
        rm(checkpoint_path)
        @printf "Calculation complete. Checkpoint file cleaned up.\n"
    end

    return (; dos=total_dos, doses=final_doses)
end