Skip to content
master
using Test
using DFTK
using Distributed
using JLD2
using LinearAlgebra
 5
 6if nworkers() == 1
 7    addprocs(5)
 8end
 9@show nworkers()
10
11@everywhere using SacerDOS
12
@testset "Distributed Checkpointing" begin
    # Packages (including JLD2) and the worker pool are set up once at file top level,
    # so no `using`/`addprocs` is repeated inside the testset.
    case = :mfast
    p = get_params(; case)
    g_params = gaussian_parameters(p)
    orders = [0, 1, 2]

    # Number of disregistries the simulated interrupted run has already finished.
    n_completed = 3

    # A fresh temporary directory keeps the test out of the working directory and
    # prevents a stale checkpoint from an earlier failed run from interfering.
    checkpoint_file = joinpath(mktempdir(), "dist_checkpoint_test.jld2")

    try
        # Reference: full computation with checkpointing disabled.
        terms_reference = SacerDOS.dos_gaussian_distributed(
            p.lattice, p.atoms, p.positions, p.Ecut, p.kgrid, g_params, p.xs;
            disregistries=p.disregistries, orders, δ=p.δ, checkpoint_path=nothing
        )

        # Simulate an interrupted run: evaluate only the first `n_completed`
        # disregistries on the workers and sum their contributions.
        disregistries_part1 = p.disregistries[1:n_completed]
        all_positions_part1 = [p.positions(d) for d in disregistries_part1]

        # Each task receives a (disregistry, positions) pair. The positions are
        # precomputed here and passed to the single-disregistry routine through a
        # constant closure that stands in for `p.positions`.
        map_fn = dp_tuple -> begin
            disregistry, positions = dp_tuple
            proxy_positions_fn = _ -> positions
            SacerDOS._dos_contribution_single_disregistry(
                disregistry, p.lattice, p.atoms, proxy_positions_fn, p.Ecut, p.kgrid,
                g_params, p.xs, orders, p.δ, nothing
            )
        end

        list_of_doses_part1 = pmap(map_fn, zip(disregistries_part1, all_positions_part1))
        accumulated_doses_partial = reduce(+, list_of_doses_part1)

        # NOTE(review): these keys are assumed to match what the resume path of
        # `dos_gaussian_distributed` reads — keep in sync with SacerDOS.
        jldopen(checkpoint_file, "w") do file
            file["completed_indices"] = collect(1:n_completed)
            file["accumulated_doses"] = accumulated_doses_partial
            file["disregistries"] = p.disregistries
        end
        @test isfile(checkpoint_file)

        # Resume from the checkpoint; with `cleanup_checkpoint=true` the file is
        # expected to be gone afterwards (checked below).
        terms_resumed = SacerDOS.dos_gaussian_distributed(
            p.lattice, p.atoms, p.positions, p.Ecut, p.kgrid, g_params, p.xs;
            disregistries=p.disregistries, orders, δ=p.δ,
            checkpoint_path=checkpoint_file, cleanup_checkpoint=true
        )

        @test terms_reference.dos ≈ terms_resumed.dos
        @test all(terms_reference.doses .≈ terms_resumed.doses)
        @test !isfile(checkpoint_file)
    finally
        # Remove the checkpoint even when a test above failed or threw.
        rm(checkpoint_file; force=true)
    end
end