# -*- coding: utf-8 -*-
"""Advanced_Stable_Baselines3_Trading_Agent_Marktechpost.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1hl3dX_Q-Eki2pxRAhil1Mt4RFwH9m8Ur
"""
9+
10+ !pip install stable - baselines3 [extra ] gymnasium pygame
11+ import numpy as np
12+ import gymnasium as gym
13+ from gymnasium import spaces
14+ import matplotlib .pyplot as plt
15+ from stable_baselines3 import PPO , A2C , DQN , SAC
16+ from stable_baselines3 .common .env_checker import check_env
17+ from stable_baselines3 .common .callbacks import BaseCallback
18+ from stable_baselines3 .common .vec_env import DummyVecEnv , VecNormalize
19+ from stable_baselines3 .common .evaluation import evaluate_policy
20+ from stable_baselines3 .common .monitor import Monitor
21+ import torch
22+
class TradingEnv(gym.Env):
    """Toy single-asset trading environment.

    Observation (Box, shape (5,), float32): normalized balance, share count,
    current price, 5-step mean price, and episode progress.
    Actions (Discrete(3)): 0 = hold, 1 = buy as many shares as affordable,
    2 = sell all shares.
    Reward: a small per-trade cost/bonus plus the normalized deviation of the
    portfolio value from the initial 1000.
    """

    def __init__(self, max_steps=200):
        super().__init__()
        self.max_steps = max_steps
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(5,), dtype=np.float32
        )
        self.reset()

    def reset(self, seed=None, options=None):
        # super().reset seeds self.np_random, which step() uses for noise.
        super().reset(seed=seed)
        self.current_step = 0
        self.balance = 1000.0
        self.shares = 0
        self.price = 100.0
        self.price_history = [self.price]
        return self._get_obs(), {}

    def _get_obs(self):
        # Mean of the last 5 prices once enough history exists, else the
        # current price, so the trend feature is always defined.
        if len(self.price_history) >= 5:
            price_trend = np.mean(self.price_history[-5:])
        else:
            price_trend = self.price
        return np.array([
            self.balance / 1000.0,
            self.shares / 10.0,
            self.price / 100.0,
            price_trend / 100.0,
            self.current_step / self.max_steps,
        ], dtype=np.float32)

    def step(self, action):
        self.current_step += 1
        # Slow sinusoidal drift plus Gaussian noise; price clipped to [50, 200].
        # Use the env's seeded RNG (self.np_random) rather than the global
        # np.random so reset(seed=...) actually makes episodes reproducible.
        trend = 0.001 * np.sin(self.current_step / 20)
        self.price *= (1 + trend + self.np_random.normal(0, 0.02))
        self.price = np.clip(self.price, 50, 200)
        self.price_history.append(self.price)
        reward = 0
        if action == 1 and self.balance >= self.price:
            # Buy: convert as much balance as possible into shares.
            shares_to_buy = int(self.balance / self.price)
            cost = shares_to_buy * self.price
            self.balance -= cost
            self.shares += shares_to_buy
            reward = -0.01  # small transaction cost
        elif action == 2 and self.shares > 0:
            # Sell: liquidate the whole position.
            revenue = self.shares * self.price
            self.balance += revenue
            self.shares = 0
            reward = 0.01
        portfolio_value = self.balance + self.shares * self.price
        reward += (portfolio_value - 1000) / 1000
        # Reaching max_steps is a time limit, not a terminal market state, so
        # the Gymnasium API calls for truncated=True rather than terminated.
        terminated = False
        truncated = self.current_step >= self.max_steps
        return self._get_obs(), reward, terminated, truncated, {"portfolio": portfolio_value}

    def render(self):
        print(f"Step: {self.current_step}, Balance: ${self.balance:.2f}, Shares: {self.shares}, Price: ${self.price:.2f}")
72+
class ProgressCallback(BaseCallback):
    """Records the mean episode reward every ``check_freq`` training steps."""

    def __init__(self, check_freq=1000, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.rewards = []  # mean episode reward at each checkpoint

    def _on_step(self):
        # Guard against an empty ep_info_buffer early in training, where
        # np.mean over zero episodes would produce NaN and a RuntimeWarning.
        if self.n_calls % self.check_freq == 0 and len(self.model.ep_info_buffer) > 0:
            mean_reward = np.mean([ep_info["r"] for ep_info in self.model.ep_info_buffer])
            self.rewards.append(mean_reward)
            if self.verbose:
                print(f"Steps: {self.n_calls}, Mean Reward: {mean_reward:.2f}")
        return True
85+
print("=" * 60)
print("Setting up custom trading environment...")
env = TradingEnv()
check_env(env, warn=True)  # validate Gymnasium API compliance before training
print("✓ Environment validation passed!")
env = Monitor(env)  # record episode returns so ep_info_buffer is populated
vec_env = DummyVecEnv([lambda: env])
# Normalize observations and rewards online; stats live in this wrapper.
vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True)
94+
print("\n" + "=" * 60)
print("Training multiple RL algorithms...")
# NOTE(review): both algorithms share the same VecNormalize wrapper, so the
# second one starts with statistics adapted by the first. Fine for a demo;
# use a fresh env per algorithm for a fair comparison.
algorithms = {
    "PPO": PPO("MlpPolicy", vec_env, verbose=0, learning_rate=3e-4, n_steps=2048),
    "A2C": A2C("MlpPolicy", vec_env, verbose=0, learning_rate=7e-4),
}
results = {}
for name, model in algorithms.items():
    print(f"\nTraining {name}...")
    callback = ProgressCallback(check_freq=2000, verbose=0)
    model.learn(total_timesteps=50000, callback=callback, progress_bar=True)
    results[name] = {"model": model, "rewards": callback.rewards}
    print(f"✓ {name} training complete!")
108+
print("\n" + "=" * 60)
print("Evaluating trained models...")
# NOTE(review): this eval env is unnormalized while training used
# VecNormalize; absolute eval scores are therefore not directly comparable
# to training rewards — confirm whether normalized evaluation is intended.
eval_env = Monitor(TradingEnv())
for name, data in results.items():
    mean_reward, std_reward = evaluate_policy(
        data["model"], eval_env, n_eval_episodes=20, deterministic=True
    )
    results[name]["eval_mean"] = mean_reward
    results[name]["eval_std"] = std_reward
    print(f"{name}: Mean Reward = {mean_reward:.2f} +/- {std_reward:.2f}")
117+
print("\n" + "=" * 60)
print("Generating visualizations...")
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Top-left: mean episode reward at each callback checkpoint.
ax = axes[0, 0]
for name, data in results.items():
    ax.plot(data["rewards"], label=name, linewidth=2)
# Checkpoints are every 2000 steps (ProgressCallback check_freq=2000),
# not 1000 as the original label claimed.
ax.set_xlabel("Training Checkpoints (x2000 steps)")
ax.set_ylabel("Mean Episode Reward")
ax.set_title("Training Progress Comparison")
ax.legend()
ax.grid(True, alpha=0.3)
129+
# Top-right: mean +/- std of the 20-episode evaluation per algorithm.
ax = axes[0, 1]
names = list(results.keys())
means = [results[n]["eval_mean"] for n in names]
stds = [results[n]["eval_std"] for n in names]
ax.bar(names, means, yerr=stds, capsize=10, alpha=0.7, color=['#1f77b4', '#ff7f0e'])
ax.set_ylabel("Mean Reward")
ax.set_title("Evaluation Performance (20 episodes)")
ax.grid(True, alpha=0.3, axis='y')
138+
# Bottom-left: portfolio value over one deterministic episode of the best
# model. Compute the best entry once instead of re-running max() for the title.
ax = axes[1, 0]
best_name = max(results.items(), key=lambda x: x[1]["eval_mean"])[0]
best_model = results[best_name]["model"]
obs = eval_env.reset()[0]
portfolio_values = [1000]
for _ in range(200):
    action, _ = best_model.predict(obs, deterministic=True)
    obs, reward, done, truncated, info = eval_env.step(action)
    portfolio_values.append(info.get("portfolio", portfolio_values[-1]))
    # Honor both halves of the Gymnasium 5-tuple, not just `terminated`.
    if done or truncated:
        break
ax.plot(portfolio_values, linewidth=2, color='green')
ax.axhline(y=1000, color='red', linestyle='--', label='Initial Value')
ax.set_xlabel("Steps")
ax.set_ylabel("Portfolio Value ($)")
ax.set_title(f"Best Model ({best_name}) Episode")
ax.legend()
ax.grid(True, alpha=0.3)
156+
# Bottom-right: how often the best model holds/buys/sells in one episode.
ax = axes[1, 1]
obs = eval_env.reset()[0]
actions = []
for _ in range(200):
    action, _ = best_model.predict(obs, deterministic=True)
    # predict() returns a numpy scalar; store plain ints so list.count()
    # below compares reliably.
    actions.append(int(action))
    obs, _, done, truncated, _ = eval_env.step(action)
    if done or truncated:
        break
action_names = ['Hold', 'Buy', 'Sell']
action_counts = [actions.count(i) for i in range(3)]
ax.pie(action_counts, labels=action_names, autopct='%1.1f%%', startangle=90,
       colors=['#ff9999', '#66b3ff', '#99ff99'])
ax.set_title("Action Distribution (Best Model)")

plt.tight_layout()
plt.savefig('sb3_advanced_results.png', dpi=150, bbox_inches='tight')
print("✓ Visualizations saved as 'sb3_advanced_results.png'")
plt.show()
174+
print("\n" + "=" * 60)
print("Saving and loading models...")
best_name = max(results.items(), key=lambda x: x[1]["eval_mean"])[0]
best_model = results[best_name]["model"]
best_model.save(f"best_trading_model_{best_name}")
vec_env.save("vec_normalize.pkl")  # persist VecNormalize statistics too
# Load with the class that actually trained the model; the original
# hard-coded PPO.load, which is wrong whenever A2C is the best performer.
loaded_model = type(best_model).load(f"best_trading_model_{best_name}")
print(f"✓ Best model ({best_name}) saved and loaded successfully!")
print("\n" + "=" * 60)
print("TUTORIAL COMPLETE!")
print(f"Best performing algorithm: {best_name}")
print(f"Final evaluation score: {results[best_name]['eval_mean']:.2f}")
print("=" * 60)