using Distributed
using JLD2
using Printf
using ProgressMeter

"""
    dos_gaussian_distributed(lattice, atoms, positions_fn, Ecut, kgrid, params, xs;
                             disregistries, orders=[0], δ=zero(eltype(xs)),
                             n_bands_diag=nothing, checkpoint_path=nothing,
                             cleanup_checkpoint=false)

Distributed (multi-process) version of the DOS calculation. Every entry of
`disregistries` is an independent job that a `Distributed` worker evaluates with
`_dos_contribution_single_disregistry`. The master process sums the per-job
contributions and normalizes them by `prod(kgrid) * length(disregistries)`.

`_dos_contribution_single_disregistry`, `positions_fn` and everything they use must be
available on every worker process (e.g. loaded with `@everywhere`), because the job
loop runs there.

# Arguments
- `lattice`, `atoms`, `Ecut`, `kgrid`, `params`, `xs`: forwarded unchanged to
  `_dos_contribution_single_disregistry`. `xs` also fixes the length of every
  accumulated curve, and `prod(kgrid)` enters the normalization.
- `positions_fn`: called on the worker as `positions_fn(disregistry)`. Its result is
  passed to `_dos_contribution_single_disregistry` through a constant function.

# Keywords
- `disregistries`: non-empty collection indexable by `1:length(disregistries)`. One job
  runs per entry.
- `orders`, `δ`, `n_bands_diag`: forwarded to `_dos_contribution_single_disregistry`.
  One curve is accumulated for each value in `0:maximum(orders)`.
- `checkpoint_path`: if not `nothing`, the completed job indices and the running sums are
  written to this JLD2 file after every finished job. An existing file is resumed from.
- `cleanup_checkpoint`: delete the checkpoint file once the calculation has finished.

# Returns
A named tuple `(; dos, doses)`. `doses` holds the normalized accumulated curves, one per
value in `0:maximum(orders)`, and `dos` is their element-wise sum.

# Throws
- `ArgumentError` if `disregistries` is empty. Also thrown if the checkpoint at
  `checkpoint_path` was written for different `disregistries`, or stores sums whose
  shape does not match `orders` and `xs`.
- `CompositeException` of `CapturedException`s if any job throws on a worker. Jobs that
  succeeded are still accumulated (and checkpointed) first, so a rerun with the same
  `checkpoint_path` only retries the failed ones.
"""
function dos_gaussian_distributed(lattice, atoms, positions_fn, Ecut, kgrid, params, xs;
                                  disregistries, orders=[0], δ=zero(eltype(xs)),
                                  n_bands_diag=nothing, checkpoint_path=nothing,
                                  cleanup_checkpoint=false)
    n_dis = length(disregistries)
    # An empty run would divide by zero in the normalization and silently return NaNs.
    n_dis > 0 || throw(ArgumentError("`disregistries` must not be empty"))
    n_doses = length(0:maximum(orders))

    if checkpoint_path !== nothing && isfile(checkpoint_path)
        @printf "Checkpoint file found. Resuming...\n"
        completed_indices, accumulated_doses =
            _load_dos_checkpoint(checkpoint_path, disregistries, n_doses, length(xs))
        @printf "%d / %d jobs already completed.\n" length(completed_indices) n_dis
    else
        completed_indices = Set{Int}()
        accumulated_doses = [zeros(length(xs)) for _ in 1:n_doses]
    end

    indices_to_run = [i for i in 1:n_dis if i ∉ completed_indices]
    n_jobs_to_run = length(indices_to_run)

    if n_jobs_to_run == 0
        @printf "All jobs already completed according to checkpoint.\n"
    else
        worker_ids = workers()
        # Room for every job plus one stop sentinel per worker, so enqueueing never
        # blocks and the sentinels can be queued right behind the jobs.
        jobs_capacity = n_jobs_to_run + length(worker_ids)
        jobs_channel = RemoteChannel(() -> Channel{Tuple{Int,Any}}(jobs_capacity))
        # Each job yields exactly one result: its doses or a captured worker exception.
        results_channel = RemoteChannel(() -> Channel{Tuple{Int,Any}}(n_jobs_to_run))

        function worker_task(jobs, results)
            while true
                job_index, disregistry = take!(jobs)
                job_index == -1 && break  # stop sentinel: no more work for this worker

                # Send failures back instead of letting them escape. With `remote_do` an
                # uncaught error is only printed on the worker, and the master would
                # then wait forever for the missing result.
                payload = try
                    positions = positions_fn(disregistry)
                    proxy_pos_fn = _ -> positions
                    doses_contrib = _dos_contribution_single_disregistry(
                        disregistry, lattice, atoms, proxy_pos_fn, Ecut, kgrid, params, xs,
                        orders, δ, n_bands_diag
                    )
                    convert(Vector{Vector{Float64}}, doses_contrib)
                catch err
                    CapturedException(err, catch_backtrace())
                end
                put!(results, (job_index, payload))
            end
        end

        for p in worker_ids
            remote_do(worker_task, p, jobs_channel, results_channel)
        end

        @printf "Dispatching %d jobs to %d workers...\n" n_jobs_to_run length(worker_ids)
        for i in indices_to_run
            put!(jobs_channel, (i, disregistries[i]))
        end
        # Jobs are taken in FIFO order, so every real job is handed out before any
        # worker reaches its stop sentinel.
        for _ in worker_ids
            put!(jobs_channel, (-1, nothing))
        end

        failed_indices = Int[]
        failures = Any[]
        @showprogress "Jobs Completed" for _ in 1:n_jobs_to_run
            job_index, payload = take!(results_channel)
            if payload isa CapturedException
                # Keep collecting, so every successful job is still accumulated and
                # checkpointed before the failures are reported.
                push!(failed_indices, job_index)
                push!(failures, payload)
            else
                accumulated_doses .+= payload
                push!(completed_indices, job_index)
                if checkpoint_path !== nothing
                    _save_dos_checkpoint(checkpoint_path, completed_indices,
                                         accumulated_doses, disregistries)
                end
            end
        end

        if !isempty(failures)
            @printf("%d / %d jobs failed (disregistry indices: %s).\n",
                    length(failures), n_jobs_to_run, join(sort!(failed_indices), ", "))
            throw(CompositeException(failures))
        end
    end

    # Final normalization over all k-points of all disregistries.
    n_kpoints_total = prod(kgrid) * n_dis
    normalization = 1 / n_kpoints_total
    final_doses = accumulated_doses .* normalization
    total_dos = sum(final_doses)

    if cleanup_checkpoint && checkpoint_path !== nothing && isfile(checkpoint_path)
        rm(checkpoint_path)
        @printf "Calculation complete. Checkpoint file cleaned up.\n"
    end

    return (; dos=total_dos, doses=final_doses)
end

# Read `(completed_indices, accumulated_doses)` from a checkpoint written by
# `_save_dos_checkpoint`, rejecting checkpoints that belong to a different run.
function _load_dos_checkpoint(path, disregistries, n_doses, n_xs)
    completed, accumulated = jldopen(path, "r") do file
        # Only compare when the checkpoint actually recorded its disregistries.
        if haskey(file, "disregistries") && !isequal(file["disregistries"], disregistries)
            throw(ArgumentError("checkpoint `$path` was written for different " *
                                "`disregistries`; delete it or use another " *
                                "`checkpoint_path`"))
        end
        (file["completed_indices"], file["accumulated_doses"])
    end
    if length(accumulated) != n_doses || any(d -> length(d) != n_xs, accumulated)
        throw(ArgumentError("checkpoint `$path` stores sums whose shape does not match " *
                            "`orders` and `xs`; delete it or use another " *
                            "`checkpoint_path`"))
    end
    # Concrete return types keep the caller type-stable despite the untyped file reads.
    return Set{Int}(completed),
           convert(Vector{Vector{Float64}}, accumulated)::Vector{Vector{Float64}}
end

# Write the checkpoint to a temporary file first and then move it over `path`, so an
# interruption during the (possibly slow) write leaves the previous checkpoint intact.
function _save_dos_checkpoint(path, completed_indices, accumulated_doses, disregistries)
    tmp_path = string(path, ".tmp")
    jldopen(tmp_path, "w") do file
        file["completed_indices"] = completed_indices
        file["accumulated_doses"] = accumulated_doses
        file["disregistries"] = disregistries  # lets a resumed run verify it matches
    end
    mv(tmp_path, path; force=true)
    return nothing
end