-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.jl
More file actions
128 lines (108 loc) · 3.8 KB
/
utils.jl
File metadata and controls
128 lines (108 loc) · 3.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
@testset "Grid graphs" begin
    using DecisionFocusedLearningBenchmarks.Utils
    using DecisionFocusedLearningBenchmarks.Utils: count_edges, get_path, index_to_coord
    using Graphs
    using StableRNGs

    h, w = 4, 7
    # Seeded RNG: the shortest-path length bound below depends on the random costs,
    # so an unseeded draw would make this test non-reproducible.
    rng = StableRNG(0)
    costs = rand(rng, h, w)

    # One sub-testset per graph variant so failures report which variant broke.
    @testset "acyclic=$acyclic" for acyclic in (true, false)
        g = grid_graph(costs; acyclic=acyclic)
        @test nv(g) == h * w
        @test ne(g) == count_edges(h, w; acyclic)
        # Every edge must join cells at Chebyshev distance 1 (8-neighbourhood), and the
        # stored weight g.weights[v2, v1] must equal the destination cell cost.
        # NOTE(review): assumes the weights matrix is indexed [dst, src] — confirm.
        @test all(edges(g)) do e
            v1, v2 = src(e), dst(e)
            i1, j1 = index_to_coord(v1, h, w)
            i2, j2 = index_to_coord(v2, h, w)
            is_neighbor = max(abs(i1 - i2), abs(j1 - j2)) == 1
            has_dst_cost = g.weights[v2, v1] == costs[v2]
            return is_neighbor && has_dst_cost
        end
        # Shortest path from the first to the last vertex must have a length
        # between max(h, w) and h + w.
        path = get_path(dijkstra_shortest_paths(g, 1).parents, 1, nv(g))
        @test max(h, w) <= length(path) <= h + w
    end
end
@testset "DataSample" begin
    using DecisionFocusedLearningBenchmarks
    using StableRNGs
    using StatsBase:
        ZScoreTransform,
        UnitRangeTransform,
        fit,
        transform,
        transform!,
        reconstruct,
        reconstruct!

    rng = StableRNG(1234)

    # Build a sample whose `x` (10×5), `θ` (length 5) and `y` (length 10) are drawn from
    # the shared seeded `rng`, with a fixed `instance` string.
    function random_sample()
        return DataSample(;
            x=randn(rng, 10, 5),
            θ=rand(rng, 5),
            y=rand(rng, 10),
            instance="this is an instance",
        )
    end

    # Assert pairwise that every checked field other than the features `x` is equal
    # between `got_samples` and `want_samples`.
    function test_non_feature_fields_preserved(got_samples, want_samples)
        for (got, want) in zip(got_samples, want_samples)
            @test got.θ == want.θ
            @test got.y == want.y
            @test got.context == want.context
            @test got.instance == want.instance
        end
    end

    # Copy each sample's feature matrix so later checks can detect in-place mutation.
    snapshot_features(samples) = [copy(d.x) for d in samples]
    function features_unchanged(samples, snapshot)
        return all(d.x == x0 for (d, x0) in zip(samples, snapshot))
    end

    # Construction and `show` output.
    sample = random_sample()
    @test sample isa DataSample
    io = IOBuffer()
    show(io, sample)
    s = String(take!(io))
    @test occursin("DataSample(", s)
    @test occursin("θ_true", s)
    @test occursin("y_true", s)
    @test occursin("instance=\"this is an instance\"", s)
    @test propertynames(sample) == (:x, :θ, :y, :context, :extra, :instance)

    # Small dataset shared by the transform tests below.
    N = 5
    dataset = [random_sample() for _ in 1:N]

    # fit
    zt = fit(ZScoreTransform, dataset; dims=2)
    @test zt isa ZScoreTransform
    ut = fit(UnitRangeTransform, dataset; dims=2)
    @test ut isa UnitRangeTransform

    # transform (non-mutating): new samples, features changed, other fields kept.
    dataset_features = snapshot_features(dataset)
    dataset_zt = transform(zt, dataset)
    @test length(dataset_zt) == length(dataset)
    @test all(d -> d isa DataSample, dataset_zt)
    test_non_feature_fields_preserved(dataset_zt, dataset)
    @test dataset_zt[1].x != dataset[1].x
    # The input dataset itself must not have been modified.
    @test features_unchanged(dataset, dataset_features)

    # transform! (mutating) on a deep copy, so `dataset` stays the reference.
    dataset_copy = deepcopy(dataset)
    original_x = copy(dataset_copy[1].x)
    transform!(ut, dataset_copy)
    @test dataset_copy[1].x != original_x
    test_non_feature_fields_preserved(dataset_copy, dataset)

    # reconstruct (non-mutating): round trip should recover the original features.
    dataset_zt_features = snapshot_features(dataset_zt)
    dataset_reconstructed = reconstruct(zt, dataset_zt)
    @test length(dataset_reconstructed) == length(dataset)
    for (got, want) in zip(dataset_reconstructed, dataset)
        @test got.x ≈ want.x atol=1e-10
    end
    test_non_feature_fields_preserved(dataset_reconstructed, dataset)
    # The transformed input must not have been modified by `reconstruct`.
    @test features_unchanged(dataset_zt, dataset_zt_features)

    # reconstruct! (mutating) should restore the original features in place.
    reconstruct!(zt, dataset_zt)
    for (got, want) in zip(dataset_zt, dataset)
        @test got.x ≈ want.x atol=1e-10
    end
end
@testset "Maximizers" begin
    using DecisionFocusedLearningBenchmarks.Utils: TopKMaximizer

    # k = 3: the maximizer marks the three largest entries with 1 and the rest with 0.
    maximizer = TopKMaximizer(3)
    cases = (
        [1, 2, 3, 4, 5] => [0, 0, 1, 1, 1],
        [5, 4, 3, 2, 1] => [1, 1, 1, 0, 0],
    )
    for (scores, expected) in cases
        @test maximizer(scores) == expected
    end

    # An input shorter than k is rejected with an AssertionError.
    @test_throws AssertionError maximizer([1, 2])
end