Skip to content
master
 1using Distributed
 2using JLD2
 3using Printf
 4using ProgressMeter
 5
 6# Distributed equivalent
 7function dos_gaussian_distributed(lattice, atoms, positions_fn, Ecut, kgrid, params, xs;
 8                                  disregistries, orders=[0], δ=zero(eltype(xs)),
 9                                  n_bands_diag=nothing, checkpoint_path=nothing,
10                                  cleanup_checkpoint=false)
11
12    n_dis = length(disregistries)
13    completed_indices = Set{Int}()
14    accumulated_doses = [zeros(length(xs)) for _ in 0:maximum(orders)]
15
16    if checkpoint_path !== nothing && isfile(checkpoint_path)
17        @printf "Checkpoint file found. Resuming...\n"
18        jldopen(checkpoint_path, "r") do file
19            completed_indices = file["completed_indices"]
20            accumulated_doses = file["accumulated_doses"]
21        end
22        @printf "%d / %d jobs already completed.\n" length(completed_indices) n_dis
23    end
24
25    indices_to_run = [i for i in 1:n_dis if i  completed_indices]
26    n_jobs_to_run = length(indices_to_run)
27
28    if n_jobs_to_run == 0
29        @printf "All jobs already completed according to checkpoint.\n"
30    else
31        jobs_channel = RemoteChannel(() -> Channel{Tuple{Int,Any}}(n_jobs_to_run))
32        results_channel = RemoteChannel(() -> Channel{Tuple{Int,Vector{Vector{Float64}}}}(n_jobs_to_run))
33
34        function worker_task(jobs, results)
35            while true
36                job_index, disregistry = take!(jobs)
37                if job_index == -1  # end of work
38                    break
39                end
40
41                # Single-disregistry computation
42                positions = positions_fn(disregistry)
43                proxy_pos_fn = _ -> positions
44                doses_contrib = _dos_contribution_single_disregistry(
45                    disregistry, lattice, atoms, proxy_pos_fn, Ecut, kgrid, params, xs,
46                    orders, δ, n_bands_diag
47                )
48
49                put!(results, (job_index, doses_contrib))
50            end
51        end
52
53        for p in workers()
54            remote_do(worker_task, p, jobs_channel, results_channel)
55        end
56
57        @printf "Dispatching %d jobs to %d workers...\n" n_jobs_to_run nworkers()
58        for i in indices_to_run
59            put!(jobs_channel, (i, disregistries[i]))
60        end
61
62        @showprogress "Jobs Completed" for _ in 1:n_jobs_to_run
63            job_index, doses_contrib = take!(results_channel)
64
65            accumulated_doses .+= doses_contrib
66            push!(completed_indices, job_index)
67
68            if checkpoint_path !== nothing
69                jldopen(checkpoint_path, "w") do file
70                    file["completed_indices"] = completed_indices
71                    file["accumulated_doses"] = accumulated_doses
72                    file["disregistries"] = disregistries # For consistency
73                end
74            end
75        end
76
77        for _ in workers()
78            put!(jobs_channel, (-1, nothing))
79        end
80    end
81
82    # Final normalization.
83    n_disregistries = length(disregistries)
84    n_kpoints_total = prod(kgrid) * n_disregistries
85    normalization = 1 / n_kpoints_total
86    final_doses = accumulated_doses .* normalization
87    total_dos = sum(final_doses)
88
89    if cleanup_checkpoint && checkpoint_path !== nothing && isfile(checkpoint_path)
90        rm(checkpoint_path)
91        @printf "Calculation complete. Checkpoint file cleaned up.\n"
92    end
93
94    (; dos=total_dos, doses=final_doses)
95end