 See: https://en.wikipedia.org/wiki/Q-learning
 """

-from collections import defaultdict
 import random
+from collections import defaultdict
+
+# Type alias for state
+type State = tuple[int, int]
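+# NOTE: the `type` alias statement requires Python 3.12 or newer.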

 # Hyperparameters for Q-Learning
 LEARNING_RATE = 0.1
 EPSILON_MIN = 0.01

 # Global Q-table to store state-action values
-q_table = defaultdict(lambda: defaultdict(float))
+q_table: dict[State, dict[int, float]] = defaultdict(lambda: defaultdict(float))

 # Environment variables for simple grid world
 SIZE = 4
 GOAL = (SIZE - 1, SIZE - 1)
 current_state = (0, 0)


-def get_q_value(state, action):
+def get_q_value(state: State, action: int) -> float:
     """
     Get Q-value for a given state-action pair.

+    >>> q_table.clear()
     >>> get_q_value((0, 0), 2)
     0.0
     """
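+    # Unseen state-action pairs read as 0.0 thanks to the nested defaultdicts.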
     return q_table[state][action]


-def get_best_action(state, available_actions):
+def get_best_action(state: State, available_actions: list[int]) -> int:
     """
     Get the action with maximum Q-value in the given state.

+    >>> q_table.clear()
     >>> q_table[(0, 0)][1] = 0.7
     >>> q_table[(0, 0)][2] = 0.7
     >>> q_table[(0, 0)][3] = 0.5
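+    >>> # Actions 1 and 2 tie at 0.7, so ties are broken uniformly at random.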
@@ -54,14 +59,18 @@ def get_best_action(state, available_actions):
     return random.choice(best)

-def choose_action(state, available_actions):
+def choose_action(state: State, available_actions: list[int]) -> int:
     """
     Choose action using epsilon-greedy policy.

-    >>> EPSILON = 0.0
+    >>> q_table.clear()
+    >>> old_epsilon = EPSILON
+    >>> # Doctests run against a copy of the module globals, so rebinding
+    >>> # EPSILON here would never reach choose_action; patch the module dict.
+    >>> choose_action.__globals__["EPSILON"] = 0.0
     >>> q_table[(0, 0)][1] = 1.0
     >>> q_table[(0, 0)][2] = 0.5
-    >>> choose_action((0, 0), [1, 2])
+    >>> result = choose_action((0, 0), [1, 2])
+    >>> choose_action.__globals__["EPSILON"] = old_epsilon  # Restore
+    >>> result
     1
     """
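+    # Explore a random action with probability EPSILON; otherwise exploit the
+    # greedy action from the Q-table.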
     global EPSILON
@@ -72,64 +81,84 @@ def choose_action(state, available_actions):
     return get_best_action(state, available_actions)


-def update(state, action, reward, next_state, next_available_actions, done=False):
+def update(
+    state: State,
+    action: int,
+    reward: float,
+    next_state: State,
+    next_available_actions: list[int],
+    done: bool = False,
+    alpha: float | None = None,
+    gamma: float | None = None,
+) -> None:
     """
     Perform Q-value update for a transition using the Q-learning rule.

     Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))

-    >>> LEARNING_RATE = 0.5
-    >>> DISCOUNT_FACTOR = 0.9
-    >>> update((0,0), 1, 1.0, (0,1), [1,2], done=True)
-    >>> get_q_value((0,0), 1)
+    >>> q_table.clear()
+    >>> update((0, 0), 1, 1.0, (0, 1), [1, 2], done=True, alpha=0.5, gamma=0.9)
+    >>> get_q_value((0, 0), 1)
     0.5
     """
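+    # Worked arithmetic for the doctest above (terminal step, so max_q_next
+    # is 0.0): new_q = (1 - 0.5) * 0.0 + 0.5 * (1.0 + 0.9 * 0.0) = 0.5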
-    global LEARNING_RATE, DISCOUNT_FACTOR
+    alpha = alpha if alpha is not None else LEARNING_RATE
+    gamma = gamma if gamma is not None else DISCOUNT_FACTOR
     max_q_next = (
         0.0
         if done or not next_available_actions
         else max(get_q_value(next_state, a) for a in next_available_actions)
     )
     old_q = get_q_value(state, action)
-    new_q = (1 - LEARNING_RATE) * old_q + LEARNING_RATE * (
-        reward + DISCOUNT_FACTOR * max_q_next
+    new_q = (1 - alpha) * old_q + alpha * (
+        reward + gamma * max_q_next
     )
     q_table[state][action] = new_q


-def get_policy():
+def get_policy() -> dict[State, int]:
     """
     Extract a deterministic policy from the Q-table.

-    >>> q_table[(1,2)][1] = 2.0
-    >>> q_table[(1,2)][2] = 1.0
-    >>> get_policy()[(1,2)]
+    >>> q_table.clear()
+    >>> q_table[(1, 2)][1] = 2.0
+    >>> q_table[(1, 2)][2] = 1.0
+    >>> get_policy()[(1, 2)]
     1
     """
-    policy = {}
+    policy: dict[State, int] = {}
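+    # For each state that has at least one estimate, keep the argmax action.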
     for s, a_dict in q_table.items():
         if a_dict:
             policy[s] = max(a_dict, key=a_dict.get)
     return policy


-def reset_env():
+def reset_env() -> State:
     """
     Reset the environment to initial state.
+
+    >>> reset_env()
+    (0, 0)
     """
     global current_state
     current_state = (0, 0)
     return current_state


-def get_available_actions_env():
+def get_available_actions_env() -> list[int]:
     """
     Get available actions in the current environment state.
     """
-    return [0, 1, 2, 3]
+    return [0, 1, 2, 3]  # 0: up, 1: right, 2: down, 3: left


-def step_env(action):
+def step_env(action: int) -> tuple[State, float, bool]:
     """
     Take a step in the environment with the given action.
     """
@@ -150,13 +179,13 @@ def step_env(action):
     return next_state, reward, done


-def run_q_learning():
+def run_q_learning() -> None:
     """
     Run Q-Learning on the simple grid world environment.
     """
     global EPSILON
     episodes = 200
-    for episode in range(episodes):
+    for _ in range(episodes):
         state = reset_env()
         done = False
         while not done:
@@ -178,3 +207,4 @@ def run_q_learning():

     doctest.testmod()
     run_q_learning()
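+    # Usage sketch: after training, inspect the learned greedy policy for the
+    # visited states with get_policy().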
+