@@ -48,10 +48,15 @@ def pull(self, arm_index: int) -> int:
4848 Pull an arm of the bandit.
4949
5050 Args:
51- arm : The arm to pull.
51+ arm_index : The arm to pull.
5252
5353 Returns:
5454 The reward for the arm.
55+
56+ Example:
57+ >>> bandit = Bandit([0.1, 0.5, 0.9])
58+ >>> isinstance(bandit.pull(0), int)
59+ True
5560 """
5661 rng = np .random .default_rng ()
5762 return 1 if rng .random () < self .probabilities [arm_index ] else 0
@@ -86,6 +91,11 @@ def select_arm(self):
8691
8792 Returns:
8893 The index of the arm to pull.
94+
95+ Example:
96+ >>> strategy = EpsilonGreedy(epsilon=0.1, k=3)
97+ >>> 0 <= strategy.select_arm() < 3
98+ True
8999 """
90100 rng = np .random .default_rng ()
91101
@@ -101,6 +111,12 @@ def update(self, arm_index: int, reward: int):
101111 Args:
102112 arm_index: The index of the arm to pull.
103113 reward: The reward for the arm.
114+
115+ Example:
116+ >>> strategy = EpsilonGreedy(epsilon=0.1, k=3)
117+ >>> strategy.update(0, 1)
118+ >>> strategy.counts[0] == 1
119+ True
104120 """
105121 self .counts [arm_index ] += 1
106122 n = self .counts [arm_index ]
@@ -135,6 +151,11 @@ def select_arm(self):
135151
136152 Returns:
137153 The index of the arm to pull.
154+
155+ Example:
156+ >>> strategy = UCB(k=3)
157+ >>> 0 <= strategy.select_arm() < 3
158+ True
138159 """
139160 if self .total_counts < self .k :
140161 return self .total_counts
@@ -149,6 +170,12 @@ def update(self, arm_index: int, reward: int):
149170 Args:
150171 arm_index: The index of the arm to pull.
151172 reward: The reward for the arm.
173+
174+ Example:
175+ >>> strategy = UCB(k=3)
176+ >>> strategy.update(0, 1)
177+ >>> strategy.counts[0] == 1
178+ True
152179 """
153180 self .counts [arm_index ] += 1
154181 self .total_counts += 1
@@ -184,6 +211,11 @@ def select_arm(self):
184211 Returns:
185212 The index of the arm to pull based on the Thompson Sampling strategy
186213 which relies on the Beta distribution.
214+
215+ Example:
216+ >>> strategy = ThompsonSampling(k=3)
217+ >>> 0 <= strategy.select_arm() < 3
218+ True
187219 """
188220 rng = np .random .default_rng ()
189221
@@ -199,6 +231,12 @@ def update(self, arm_index: int, reward: int):
199231 Args:
200232 arm_index: The index of the arm to pull.
201233 reward: The reward for the arm.
234+
235+ Example:
236+ >>> strategy = ThompsonSampling(k=3)
237+ >>> strategy.update(0, 1)
238+ >>> strategy.successes[0] == 1
239+ True
202240 """
203241 if reward == 1 :
204242 self .successes [arm_index ] += 1
@@ -210,7 +248,7 @@ def update(self, arm_index: int, reward: int):
210248class RandomStrategy :
211249 """
212250 A class for choosing totally random at each round to give
213- a better comparison with the other optimisedstrategies .
251+ a better comparison with the other optimised strategies .
214252 """
215253
216254 def __init__ (self , k : int ):
@@ -228,6 +266,11 @@ def select_arm(self):
228266
229267 Returns:
230268 The index of the arm to pull.
269+
270+ Example:
271+ >>> strategy = RandomStrategy(k=3)
272+ >>> 0 <= strategy.select_arm() < 3
273+ True
231274 """
232275 rng = np .random .default_rng ()
233276 return rng .integers (self .k )
@@ -239,6 +282,10 @@ def update(self, arm_index: int, reward: int):
239282 Args:
240283 arm_index: The index of the arm to pull.
241284 reward: The reward for the arm.
285+
286+ Example:
287+ >>> strategy = RandomStrategy(k=3)
288+ >>> strategy.update(0, 1)
242289 """
243290
244291
@@ -268,6 +315,11 @@ def select_arm(self):
268315
269316 Returns:
270317 The index of the arm to pull.
318+
319+ Example:
320+ >>> strategy = GreedyStrategy(k=3)
321+ >>> 0 <= strategy.select_arm() < 3
322+ True
271323 """
272324 return np .argmax (self .values )
273325
@@ -278,6 +330,12 @@ def update(self, arm_index: int, reward: int):
278330 Args:
279331 arm_index: The index of the arm to pull.
280332 reward: The reward for the arm.
333+
334+ Example:
335+ >>> strategy = GreedyStrategy(k=3)
336+ >>> strategy.update(0, 1)
337+ >>> strategy.counts[0] == 1
338+ True
281339 """
282340 self .counts [arm_index ] += 1
283341 n = self .counts [arm_index ]
@@ -329,4 +387,6 @@ def test_mab_strategies():
329387
330388
331389if __name__ == "__main__" :
390+ import doctest
391+ doctest .testmod ()
332392 test_mab_strategies ()
0 commit comments