master
1using Test
2using DFTK
3using Distributed
4using LinearAlgebra
5
# `nworkers()` reports 1 when only the master process is running, so this spawns
# 5 local workers unless the session was already started with workers attached.
nworkers() == 1 && addprocs(5)
@show nworkers()

# Load the package on every process so that closures shipped by `pmap` can resolve it.
@everywhere using SacerDOS
12
# Checks that `SacerDOS.dos_gaussian_distributed` resumed from a checkpoint file gives
# the same result as an uninterrupted run:
#   1. compute the reference result without any checkpoint,
#   2. compute the contributions of the first few disregistries by hand and store them
#      in a checkpoint file, imitating an interrupted run,
#   3. resume from that file and compare against the reference.
@testset "Distributed Checkpointing" begin
    # Worker processes are spawned and SacerDOS is loaded on all of them by the
    # top-level setup of this file; only the packages used locally are loaded here.
    using JLD2
    using SacerDOS

    case = :mfast
    p = get_params(; case)
    g_params = gaussian_parameters(p)
    orders = [0, 1, 2]
    checkpoint_file = "dist_checkpoint_test.jld2"
    # Disregistries treated as finished at the moment of the simulated interruption.
    # Single source of truth for both the partial computation and the stored indices.
    completed_indices = 1:3

    # Full computation (no checkpointing involved).
    terms_reference = SacerDOS.dos_gaussian_distributed(
        p.lattice, p.atoms, p.positions, p.Ecut, p.kgrid, g_params, p.xs;
        disregistries=p.disregistries, orders, δ=p.δ, checkpoint_path=nothing
    )

    # Everything that may create the checkpoint file lives inside `try` so that the
    # file is removed even when the partial or the resumed computation throws.
    try
        # Simulate interrupted computation.
        disregistries_part1 = p.disregistries[completed_indices]
        # Positions are evaluated on the master; each worker then receives a constant
        # function returning the precomputed positions instead of `p.positions`.
        all_positions_part1 = [p.positions(d) for d in disregistries_part1]

        map_fn = dp_tuple -> begin
            disregistry, positions = dp_tuple
            proxy_positions_fn = _ -> positions
            SacerDOS._dos_contribution_single_disregistry(
                disregistry, p.lattice, p.atoms, proxy_positions_fn, p.Ecut, p.kgrid,
                g_params, p.xs, orders, p.δ, nothing
            )
        end

        list_of_doses_part1 = pmap(map_fn, zip(disregistries_part1, all_positions_part1))
        accumulated_doses_partial = reduce(+, list_of_doses_part1)

        # Write the checkpoint under the keys the resumed run is given to read.
        jldopen(checkpoint_file, "w") do file
            file["completed_indices"] = collect(completed_indices)
            file["accumulated_doses"] = accumulated_doses_partial
            file["disregistries"] = p.disregistries
        end
        @test isfile(checkpoint_file)

        # Resumed computation.
        terms_resumed = SacerDOS.dos_gaussian_distributed(
            p.lattice, p.atoms, p.positions, p.Ecut, p.kgrid, g_params, p.xs;
            disregistries=p.disregistries, orders, δ=p.δ,
            checkpoint_path=checkpoint_file, cleanup_checkpoint=true
        )

        @test terms_reference.dos ≈ terms_resumed.dos
        @test all(terms_reference.doses .≈ terms_resumed.doses)
        # `cleanup_checkpoint=true` is expected to have deleted the file already.
        @test !isfile(checkpoint_file)
    finally
        # `force=true` makes this a no-op when the file is already gone.
        rm(checkpoint_file; force=true)
    end
end