master
1using JLD2
2using LazyStack
3using PaddedViews
4using Printf
5
# Gaussian and its derivatives.
7
# Width σ shared by the Gaussian profile and all of its derivatives below.
# Heuristic chosen so the width makes sense: `gaussian_factor` ≈ 0.1 gives a sharp
# profile (close to χ), ≈ 1.0 a smooth one.
_gaussian_width(L, gaussian_factor, δx) = 0.5 / δx * L * gaussian_factor

# Envelope exp(-(u/σ)²/2) for an offset `u` from the centre.
_gaussian_envelope(u, σ) = exp(-(u / σ)^2 / 2)

"""
    gaussian(α, L, center; gaussian_factor=1.0, δx)

Return the function `x ↦ α / (√(2π) σ) ⋅ exp(-((x - center) / σ)² / 2)`, a Gaussian
centred at `center` (integrating to `α` when `σ > 0`), with
`σ = 0.5 / δx * L * gaussian_factor`.
"""
function gaussian(α, L, center; gaussian_factor=1.0, δx)
    σ = _gaussian_width(L, gaussian_factor, δx)
    return function (x)
        u = x - center
        return α / (√(2π) * σ) * _gaussian_envelope(u, σ)
    end
end

"""
    ∂gaussian(α, L, center; gaussian_factor=1.0, δx)

Return the first derivative with respect to `x` of `gaussian` with the same arguments.
"""
function ∂gaussian(α, L, center; gaussian_factor=1.0, δx)
    σ = _gaussian_width(L, gaussian_factor, δx)
    return function (x)
        u = x - center
        return -α / (√(2π) * σ^3) * u * _gaussian_envelope(u, σ)
    end
end

"""
    ∂²gaussian(α, L, center; gaussian_factor=1.0, δx)

Return the second derivative with respect to `x` of `gaussian` with the same arguments.
"""
function ∂²gaussian(α, L, center; gaussian_factor=1.0, δx)
    σ = _gaussian_width(L, gaussian_factor, δx)
    return function (x)
        u = x - center
        return α / (√(2π) * σ^5) * _gaussian_envelope(u, σ) * (u^2 - σ^2)
    end
end

"""
    ∂³gaussian(α, L, center; gaussian_factor=1.0, δx)

Return the third derivative with respect to `x` of `gaussian` with the same arguments.
"""
function ∂³gaussian(α, L, center; gaussian_factor=1.0, δx)
    σ = _gaussian_width(L, gaussian_factor, δx)
    return function (x)
        u = x - center
        return -α / (√(2π) * σ^7) * _gaussian_envelope(u, σ) * u * (u^2 - 3σ^2)
    end
end

"""
    ∂⁴gaussian(α, L, center; gaussian_factor=1.0, δx)

Return the fourth derivative with respect to `x` of `gaussian` with the same arguments.
"""
function ∂⁴gaussian(α, L, center; gaussian_factor=1.0, δx)
    σ = _gaussian_width(L, gaussian_factor, δx)
    return function (x)
        u = x - center
        return α / (√(2π) * σ^9) * _gaussian_envelope(u, σ) * (u^4 - 6 * σ^2 * u^2 + 3σ^4)
    end
end
30
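# In-place helpers fn_n! (f2_2!, f3_3!, f4_4!): n! times divided differences of the Gaussian profile, with coincident-eigenvalue limits.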
32
"""
    f2_2!(result, λi, λj, δ, i, j, gaussians, ∂gaussians, ∂²gaussians)

Overwrite `result` with `2 (g_j - g_i - g′_i (λj - λi)) / (λj - λi)^2`, where `g_n`, `g′_n`
and `g″_n` stand for `gaussians[n]`, `∂gaussians[n]` and `∂²gaussians[n]` (in this file: the
profile and its derivatives evaluated at eigenvalue `λs[n]`). This is `2!` times the divided
difference `g[λi, λi, λj]`. When `|λj - λi| ≤ δ` the quotient is replaced by its limit `g″_i`.
Returns `result`.
"""
function f2_2!(result, λi, λj, δ, i, j, gaussians, ∂gaussians, ∂²gaussians)
    λij = λj - λi
    if abs(λij) > δ
        @. result = 2 * (gaussians[j] - gaussians[i] - ∂gaussians[i] * λij) / λij^2
    else
        # Coincident eigenvalues: use the derivative limit.
        copy!(result, ∂²gaussians[i])
    end
    return result
end

"""
    f3_3!(result, r1, r2, λi, λj, λk, δ, i, j, k,
          gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians)

Overwrite `result` with `3!` times the divided difference `g[λi, λi, λj, λk]` (notation as
in `f2_2!`). `r1` and `r2` are scratch buffers of the same size as `result` and are
overwritten.
- `|λk - λj| > δ`: recursion `3 (f2_2(i, k) - f2_2(i, j)) / (λk - λj)`.
- otherwise, `|λj - λi| > δ`: closed form with `λk` taken equal to `λj`.
- otherwise: the limit `g‴_i`.
Returns `result`.
"""
function f3_3!(result, r1, r2, λi, λj, λk, δ, i, j, k,
               gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians)
    λjk = λk - λj
    λij = λj - λi
    if abs(λjk) > δ
        f2_2!(r1, λi, λk, δ, i, k, gaussians, ∂gaussians, ∂²gaussians)
        f2_2!(r2, λi, λj, δ, i, j, gaussians, ∂gaussians, ∂²gaussians)
        @. result = 3 * (r1 - r2) / λjk
    elseif abs(λij) > δ
        # λk ≈ λj: closed form of 3! g[λi, λi, λj, λj].
        @. result = -12 * (gaussians[j] - gaussians[i] -
                           0.5 * (∂gaussians[j] + ∂gaussians[i]) * λij) / λij^3
    else
        # All three eigenvalues coincide within δ.
        copy!(result, ∂³gaussians[i])
    end
    return result
end

"""
    f4_4!(result, r1, r11, r12, r2, r21, r22, λi, λj, λk, λl, δ, i, j, k, l,
          gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians, ∂⁴gaussians)

Overwrite `result` with `4!` times the divided difference `g[λi, λi, λj, λk, λl]` (notation
as in `f2_2!`). `r1`, `r11`, `r12`, `r2`, `r21`, `r22` are scratch buffers of the same size
as `result` and are overwritten. For `|λl - λk| > δ` the value is built recursively from
`f3_3!`. Otherwise `λl` is treated as equal to `λk` and one of the closed-form limits is
used. Every path writes `result`. Returns `nothing`.
"""
function f4_4!(result, r1, r11, r12, r2, r21, r22, λi, λj, λk, λl, δ, i, j, k, l,
               gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians, ∂⁴gaussians)
    λkl = λl - λk
    if abs(λkl) > δ
        # Generic case: recurse on the last two eigenvalues.
        f3_3!(r1, r11, r12, λi, λj, λl, δ, i, j, l,
              gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians)
        f3_3!(r2, r21, r22, λi, λj, λk, δ, i, j, k,
              gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians)
        @. result = 4 * (r1 - r2) / λkl
        return nothing
    end

    # From here on |λl - λk| ≤ δ, so λl is treated as equal to λk. The tests below use
    # `≤ δ` as the exact complement of `> δ` and end in an unconditional `else`. Strict
    # `< δ` tests can never hold for the default δ = 0, so coincident eigenvalues
    # (e.g. k == l) would match no branch and leave stale data in `result`.
    λij = λj - λi
    λik = λk - λi
    λjk = λk - λj
    if abs(λik) > δ && abs(λjk) > δ
        # λi and λj both differ from λk ≈ λl.
        f2_2!(r1, λi, λk, δ, i, k, gaussians, ∂gaussians, ∂²gaussians)
        f2_2!(r2, λi, λj, δ, i, j, gaussians, ∂gaussians, ∂²gaussians)
        @. result = (-48 * ((gaussians[k] - gaussians[i] -
                             0.5 * (∂gaussians[k] + ∂gaussians[i]) * λik) / (λjk * λik^3)) -
                     12 * ((r1 - r2) / λjk^2))
    elseif abs(λik) ≤ δ && abs(λij) > δ
        # λi ≈ λk ≈ λl, λj distinct (guarded on λij, the denominator used here).
        f2_2!(r1, λi, λj, δ, i, j, gaussians, ∂gaussians, ∂²gaussians)
        @. result = (4 * ∂³gaussians[i] / (λi - λj) -
                     12 * (∂²gaussians[i] - r1) / (λi - λj)^2)
    elseif abs(λij) > δ
        # λj ≈ λk ≈ λl, λi distinct.
        @. result = (72 * (gaussians[j] - gaussians[i] -
                           0.5 * (∂gaussians[j] + ∂gaussians[i]) * λij) / λij^4 -
                     12 * (∂gaussians[j] - ∂gaussians[i] - ∂²gaussians[j] * λij) / λij^3)
    else
        # All four eigenvalues coincide within δ.
        copy!(result, ∂⁴gaussians[i])
    end
    return nothing
end
91
"""
    _precompute_operator_matrices(basis, kpt, ψkX, ∂v2, ∂²v2)

Build three `nb × nb` matrices over the `nb = size(ψkX, 2)` columns of `ψkX` at k-point
`kpt`, each with element type `eltype(ψkX)`:
- `K[i, j] = kinetic1d(basis, ψkX, kpt, i, j)`
- `X[i, j]` from `potential1d` with `∂v2`
- `X_sq[i, j]` from `potential1d` with `∂²v2`

NOTE(review): these are presumably band matrix elements of the kinetic operator and of the
first/second potential variations; confirm against `kinetic1d`/`potential1d`.

Returns the named tuple `(; K, X, X_sq)`.
"""
function _precompute_operator_matrices(basis, kpt, ψkX, ∂v2, ∂²v2)
    T = eltype(ψkX)
    nb = size(ψkX, 2)
    K, X, X_sq = (zeros(T, nb, nb) for _ in 1:3)

    # Scratch buffers handed to every `potential1d` call so it does not allocate its own.
    ui = zeros(T, basis.fft_size)
    uj = zeros(T, basis.fft_size)

    # Row index outer, column index inner (same visiting order as a nested i/j loop).
    for i in 1:nb, j in 1:nb
        K[i, j] = kinetic1d(basis, ψkX, kpt, i, j)
        X[i, j] = potential1d(basis, ψkX, kpt, ∂v2, i, j; ui, uj)
        X_sq[i, j] = potential1d(basis, ψkX, kpt, ∂²v2, i, j; ui, uj)
    end
    return (; K, X, X_sq)
end
112
"""
    calculate_order2(λs, K, X, X_sq, gaussians_info, δ)

Return the second-order DOS correction at one k-point.

`λs` are the band eigenvalues. `K`, `X` and `X_sq` are band-space matrices (in this file,
those built by `_precompute_operator_matrices`). `gaussians_info` is a named tuple with
fields `gaussians`, `∂gaussians`, `∂²gaussians`, `∂³gaussians` and `∂⁴gaussians`, each
indexed by band. `δ` is the eigenvalue-coincidence threshold forwarded to `f3_3!` and
`f4_4!`.

The result has the element type and length of `gaussians_info.gaussians[1]`.
"""
function calculate_order2(λs, K, X, X_sq, gaussians_info, δ)
    (; gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians, ∂⁴gaussians) = gaussians_info
    n_bands = length(λs)
    dos_order2 = zeros(eltype(first(gaussians)), length(first(gaussians)))

    # Scratch buffers for the in-place fn_n! helpers, allocated once and reused below:
    # two f3_3! evaluations (result + 2 scratch each) and one f4_4! (result + 6 scratch).
    tmp_f3_r1 = similar(dos_order2)
    tmp_f3_r11 = similar(dos_order2)
    tmp_f3_r12 = similar(dos_order2)
    tmp_f3_r2 = similar(dos_order2)
    tmp_f3_r21 = similar(dos_order2)
    tmp_f3_r22 = similar(dos_order2)
    tmp_f4_r = similar(dos_order2)
    tmp_f4_r1 = similar(dos_order2)
    tmp_f4_r11 = similar(dos_order2)
    tmp_f4_r12 = similar(dos_order2)
    tmp_f4_r2 = similar(dos_order2)
    tmp_f4_r21 = similar(dos_order2)
    tmp_f4_r22 = similar(dos_order2)

    for m in 1:n_bands
        λm = λs[m]
        # Diagonal X_sq term. `real.` keeps the whole broadcast fused, so no temporary
        # array is allocated per band.
        dos_order2 .-= real.(∂²gaussians[m] .* X_sq[m, m] ./ 8)

        for n in 1:n_bands
            λn = λs[n]

            # Two-band term: f3_3! on eigenvalues (λm, λm, λn) and (λm, λn, λn).
            f3_3!(tmp_f3_r1, tmp_f3_r11, tmp_f3_r12, λm, λm, λn, δ, m, m, n,
                  gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians)
            f3_3!(tmp_f3_r2, tmp_f3_r21, tmp_f3_r22, λm, λn, λn, δ, m, n, n,
                  gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians)
            dos_order2 .-= (1/(4*6)) .* (2 .* tmp_f3_r1 .- tmp_f3_r2) .* abs2(X[m, n])

            for p in 1:n_bands
                λp = λs[p]

                # Three-band term.
                f3_3!(tmp_f3_r1, tmp_f3_r11, tmp_f3_r12, λm, λn, λp, δ, m, n, p,
                      gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians)
                term3_expr = real(2 * K[n, p] * X_sq[m, n] * K[p, m] -
                                  K[m, n] * X_sq[n, p] * K[p, m])
                dos_order2 .-= (1/(4*6)) .* tmp_f3_r1 .* term3_expr

                for q in 1:n_bands
                    λq = λs[q]

                    # Four-band term.
                    f4_4!(tmp_f4_r, tmp_f4_r1, tmp_f4_r11, tmp_f4_r12, tmp_f4_r2, tmp_f4_r21,
                          tmp_f4_r22, λm, λn, λp, λq, δ, m, n, p, q,
                          gaussians, ∂gaussians, ∂²gaussians, ∂³gaussians, ∂⁴gaussians)

                    K_mn = K[m, n]; K_np = K[n, p]; K_pq = K[p, q]; K_qm = K[q, m]
                    X_mn = X[m, n]; X_np = X[n, p]; X_pq = X[p, q]; X_qm = X[q, m]

                    term4_expr = real(
                        (X_mn * K_np) * (K_pq * X_qm) +
                        (X_mn * K_pq) * (K_np * X_qm) +
                        (K_mn * X_pq) * (X_np * K_qm) +
                        (K_mn * X_np) * (X_pq * K_qm) -
                        2 * real( (K_mn * X_np) * (K_pq * X_qm) ) -
                        2 * real( (X_mn * K_pq) * (X_np * K_qm) ) +
                        2 * real( (X_mn * K_qm) * (K_np * X_pq - X_np * K_pq) )
                    )
                    dos_order2 .+= (1/(4*4*6)) .* tmp_f4_r .* term4_expr
                end
            end
        end
    end
    return dos_order2
end
183
"""
    _dos_contribution_single_disregistry(disregistry, lattice, atoms, positions_fn,
                                         Ecut, kgrid, params, xs, orders, δ, n_bands_diag)

Compute the DOS contributions of a single disregistry, summed over the k-points of its
basis but not normalised.

Returns a vector `doses` of length `maximum(orders) + 1`. `doses[o + 1]` holds the
order-`o` term sampled at the points `xs`. Orders 0, 1 and 2 are computed when present in
`orders`; any other entry stays zero. `δ` is the eigenvalue-coincidence threshold passed to
the fn_n! helpers, and `n_bands_diag` is forwarded as `n_bands` to `diagonalize`.
"""
function _dos_contribution_single_disregistry(disregistry, lattice, atoms, positions_fn,
                                              Ecut, kgrid, params, xs, orders, δ, n_bands_diag)
    doses = [zeros(length(xs)) for _ in 0:maximum(orders)]

    # Profile `shape` (gaussian or one of its derivatives) evaluated at eigenvalue λ,
    # one value per grid point in `xs`.
    sample(shape, λ) = [shape(params.α, params.δ, x; params.gaussian_factor, params.δx)(λ)
                        for x in xs]

    positions = positions_fn(disregistry)
    model = Model(lattice, atoms, positions;
                  terms=[Kinetic(), AtomicLocal()], n_electrons=length(atoms))
    basis = PlaneWaveBasis(model; Ecut, kgrid)
    data = diagonalize([basis]; n_bands=n_bands_diag)
    ∂v2 = compute_δv2(basis)
    ∂²v2 = compute_δ²v2(basis)

    perturbative = any(o -> o > 0, orders)
    for ik in eachindex(basis.kpoints)
        kpt = basis.kpoints[ik]
        ψkX = data.X[1][ik]
        λs = data.λ[1][ik]
        n_bands = length(λs)

        gaussians_info = (;
            gaussians = [sample(gaussian, λ) for λ in λs],
            ∂gaussians = [sample(∂gaussian, λ) for λ in λs],
            ∂²gaussians = [sample(∂²gaussian, λ) for λ in λs],
            ∂³gaussians = [sample(∂³gaussian, λ) for λ in λs],
            ∂⁴gaussians = [sample(∂⁴gaussian, λ) for λ in λs]
        )

        if 0 ∈ orders
            doses[1] .+= sum(gaussians_info.gaussians)
        end
        # Orders ≥ 1 need the operator matrices; skip them when only order 0 is wanted.
        perturbative || continue

        mats = _precompute_operator_matrices(basis, kpt, ψkX, ∂v2, ∂²v2)

        if 1 ∈ orders
            buf = similar(first(gaussians_info.gaussians))
            # Accumulate this k-point separately before adding it to doses[2].
            acc = zeros(eltype(doses[2]), length(xs))
            for i in 1:n_bands, j in 1:n_bands
                weight = imag(mats.K[i, j] * mats.X[j, i])
                f2_2!(buf, λs[i], λs[j], δ, i, j, gaussians_info.gaussians,
                      gaussians_info.∂gaussians, gaussians_info.∂²gaussians)
                acc .-= 0.5 .* weight .* buf
            end
            doses[2] .+= acc
        end

        if 2 ∈ orders
            doses[3] .+= calculate_order2(λs, mats.K, mats.X, mats.X_sq, gaussians_info, δ)
        end
    end
    return doses
end
248
249
# Accumulate the DOS one disregistry at a time.
# Read `(last_completed_disregistry, total_doses)` from the checkpoint at `path`.
# Throws ArgumentError when the checkpoint was written for different `disregistries`,
# or when its accumulator shape differs from `template_doses` (different `orders` or
# `xs`): resuming from such a file would silently mix unrelated results.
function _read_dos_checkpoint(path, disregistries, template_doses)
    last_completed, saved_doses, saved_disregistries = jldopen(path, "r") do file
        saved = haskey(file, "disregistries") ? file["disregistries"] : nothing
        (file["last_completed_disregistry"], file["total_doses"], saved)
    end
    if saved_disregistries !== nothing && !isequal(saved_disregistries, disregistries)
        throw(ArgumentError("checkpoint $path was written for different `disregistries`; " *
                            "delete it or pass the original values"))
    end
    same_shape = length(saved_doses) == length(template_doses) &&
                 all(length(s) == length(t) for (s, t) in zip(saved_doses, template_doses))
    if !same_shape
        throw(ArgumentError("checkpoint $path does not match the requested `orders`/`xs`; " *
                            "delete it or restore the original settings"))
    end
    return last_completed, saved_doses
end

# Write the checkpoint to a temporary file, then move it over `path`. An interruption
# during the (slow) JLD2 write then cannot truncate the existing checkpoint.
function _write_dos_checkpoint(path, last_completed, total_doses, disregistries)
    tmp_path = string(path, ".tmp")
    jldopen(tmp_path, "w") do file
        file["last_completed_disregistry"] = last_completed
        file["total_doses"] = total_doses
        file["disregistries"] = disregistries # for consistency checks on resume
    end
    mv(tmp_path, path; force=true)
    return nothing
end

"""
    dos_gaussian(lattice, atoms, positions_fn, Ecut, kgrid, params, xs;
                 disregistries, orders=[0], δ=zero(eltype(xs)), n_bands_diag=nothing,
                 checkpoint_path=nothing, cleanup_checkpoint=false)

Accumulate the Gaussian-smeared DOS over `disregistries`, one disregistry at a time, and
average the result.

Each disregistry's per-order contributions come from `_dos_contribution_single_disregistry`.
They are summed, then divided by `prod(kgrid) * length(disregistries)`.

# Keywords
- `disregistries`: disregistry values to sample; must be non-empty.
- `orders`: perturbation orders to compute (forwarded to
  `_dos_contribution_single_disregistry`).
- `δ`: threshold below which two eigenvalues are treated as equal.
- `n_bands_diag`: forwarded as `n_bands` to `diagonalize`.
- `checkpoint_path`: if given, progress is saved after each disregistry, and an existing
  file at that path is used to resume.
- `cleanup_checkpoint`: delete the checkpoint file once the calculation has finished.

# Returns
A named tuple `(; dos, doses)`. `doses[o + 1]` is the normalised order-`o` contribution
and `dos` is their sum.

# Throws
`ArgumentError` if `disregistries` is empty, or if an existing checkpoint does not match
the requested `disregistries`, `orders` or `xs`.
"""
function dos_gaussian(lattice, atoms, positions_fn, Ecut, kgrid, params, xs;
                      disregistries, orders=[0], δ=zero(eltype(xs)), n_bands_diag=nothing,
                      checkpoint_path=nothing, cleanup_checkpoint=false)
    n_disregistries = length(disregistries)
    # With no disregistries the normalisation below would be 1/0 and yield NaNs.
    n_disregistries > 0 || throw(ArgumentError("`disregistries` must not be empty"))

    start_index = 1
    total_doses = [zeros(length(xs)) for _ in 0:maximum(orders)]

    if checkpoint_path !== nothing && isfile(checkpoint_path)
        @printf "Checkpoint file found at %s. Resuming calculation.\n" checkpoint_path
        last_completed, saved_doses = _read_dos_checkpoint(checkpoint_path, disregistries,
                                                           total_doses)
        start_index = last_completed + 1
        total_doses = saved_doses
    end

    if start_index <= n_disregistries
        for i in start_index:n_disregistries
            disregistry = disregistries[i]

            @printf "Calculating for disregistry %d/%d (value: %.4f)\n" i n_disregistries disregistry

            doses_contrib = _dos_contribution_single_disregistry(disregistry, lattice, atoms,
                                                                 positions_fn, Ecut, kgrid,
                                                                 params, xs, orders, δ,
                                                                 n_bands_diag)
            total_doses .+= doses_contrib

            # Save a checkpoint after each step so an interrupted run can resume.
            if checkpoint_path !== nothing
                _write_dos_checkpoint(checkpoint_path, i, total_doses, disregistries)
            end
        end
    else
        @printf "Calculation already completed according to checkpoint.\n"
    end

    # Final normalization: average over disregistries and k-points.
    # NOTE(review): assumes each disregistry contributes `prod(kgrid)` equally weighted
    # k-points; confirm this holds if the basis reduces k-points by symmetry.
    n_kpoints_per_disregistry = prod(kgrid)
    normalization = 1 / (n_kpoints_per_disregistry * n_disregistries)

    final_doses = total_doses .* normalization
    total_dos = sum(final_doses)

    if cleanup_checkpoint && checkpoint_path !== nothing && isfile(checkpoint_path)
        rm(checkpoint_path)
        @printf "Calculation complete. Checkpoint file cleaned up.\n"
    end

    return (; dos=total_dos, doses=final_doses)
end