Skip to content

Commit 19c0a66

Browse files
authored
Add files via upload
1 parent 29fd44b commit 19c0a66

1 file changed

Lines changed: 187 additions & 0 deletions

File tree

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
# -*- coding: utf-8 -*-
2+
"""Advanced_Stable_Baselines3_Trading_Agent_Marktechpost.ipynb
3+
4+
Automatically generated by Colab.
5+
6+
Original file is located at
7+
https://colab.research.google.com/drive/1hl3dX_Q-Eki2pxRAhil1Mt4RFwH9m8Ur
8+
"""
9+
10+
!pip install stable-baselines3[extra] gymnasium pygame
11+
import numpy as np
12+
import gymnasium as gym
13+
from gymnasium import spaces
14+
import matplotlib.pyplot as plt
15+
from stable_baselines3 import PPO, A2C, DQN, SAC
16+
from stable_baselines3.common.env_checker import check_env
17+
from stable_baselines3.common.callbacks import BaseCallback
18+
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
19+
from stable_baselines3.common.evaluation import evaluate_policy
20+
from stable_baselines3.common.monitor import Monitor
21+
import torch
22+
23+
class TradingEnv(gym.Env):
    """Toy single-asset trading environment with a noisy sinusoidal price.

    Observation (Box, shape (5,), float32): normalized balance, share count,
    current price, short-term price trend, and episode progress.
    Actions (Discrete(3)): 0 = hold, 1 = buy as many whole shares as the
    balance allows, 2 = sell the entire position.
    Reward: a small fixed transaction cost/bonus on trades, plus the
    portfolio's fractional gain/loss relative to the initial $1000 each step.
    """

    def __init__(self, max_steps=200):
        super().__init__()
        self.max_steps = max_steps
        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(5,), dtype=np.float32
        )
        self.reset()

    def reset(self, seed=None, options=None):
        # super().reset seeds self.np_random, which drives the price noise in step().
        super().reset(seed=seed)
        self.current_step = 0
        self.balance = 1000.0
        self.shares = 0
        self.price = 100.0
        self.price_history = [self.price]
        return self._get_obs(), {}

    def _get_obs(self):
        # Mean of the last 5 prices as a crude trend signal; fall back to the
        # current price until enough history has accumulated.
        price_trend = np.mean(self.price_history[-5:]) if len(self.price_history) >= 5 else self.price
        return np.array([
            self.balance / 1000.0,
            self.shares / 10.0,
            self.price / 100.0,
            price_trend / 100.0,
            self.current_step / self.max_steps
        ], dtype=np.float32)

    def step(self, action):
        self.current_step += 1
        # Slow sinusoidal drift plus Gaussian noise, with the price clamped
        # to [50, 200] so observations stay bounded in practice.
        trend = 0.001 * np.sin(self.current_step / 20)
        # BUG FIX: use the env's seeded RNG (self.np_random, set up by
        # gym.Env.reset) instead of the global numpy RNG, so that
        # reset(seed=...) actually makes rollouts reproducible.
        self.price *= (1 + trend + self.np_random.normal(0, 0.02))
        self.price = np.clip(self.price, 50, 200)
        self.price_history.append(self.price)
        reward = 0
        if action == 1 and self.balance >= self.price:
            # Buy: spend as much of the balance as possible on whole shares;
            # the -0.01 acts as a flat transaction cost.
            shares_to_buy = int(self.balance / self.price)
            cost = shares_to_buy * self.price
            self.balance -= cost
            self.shares += shares_to_buy
            reward = -0.01
        elif action == 2 and self.shares > 0:
            # Sell: liquidate the whole position; +0.01 nudges the agent to
            # realize gains rather than hold forever.
            revenue = self.shares * self.price
            self.balance += revenue
            self.shares = 0
            reward = 0.01
        portfolio_value = self.balance + self.shares * self.price
        # Dense shaping: fractional portfolio return vs. the initial $1000.
        reward += (portfolio_value - 1000) / 1000
        # NOTE(review): this is a time limit, so Gymnasium convention would set
        # truncated rather than terminated; kept as-is because the script's
        # rollout loops below only check the terminated flag.
        terminated = self.current_step >= self.max_steps
        truncated = False
        return self._get_obs(), reward, terminated, truncated, {"portfolio": portfolio_value}

    def render(self):
        print(f"Step: {self.current_step}, Balance: ${self.balance:.2f}, Shares: {self.shares}, Price: ${self.price:.2f}")
72+
73+
class ProgressCallback(BaseCallback):
    """Samples the mean episode reward every `check_freq` training steps.

    The sampled values are collected in `self.rewards` so training curves can
    be plotted after `model.learn(...)` returns.
    """

    def __init__(self, check_freq=1000, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq  # sampling interval in environment steps
        self.rewards = []             # mean episode reward at each checkpoint

    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            # BUG FIX: ep_info_buffer is empty until the first episode
            # finishes, and np.mean([]) emits a RuntimeWarning and yields NaN.
            # Skip the checkpoint until at least one episode has completed.
            if len(self.model.ep_info_buffer) > 0:
                mean_reward = np.mean([ep_info["r"] for ep_info in self.model.ep_info_buffer])
                self.rewards.append(mean_reward)
                if self.verbose:
                    print(f"Steps: {self.n_calls}, Mean Reward: {mean_reward:.2f}")
        return True
85+
86+
# --- Build and validate the custom trading environment ---
print("=" * 60)
print("Setting up custom trading environment...")

env = TradingEnv()
# check_env warns (or raises) if the env violates the Gymnasium API contract.
check_env(env, warn=True)
print("✓ Environment validation passed!")

# Wrap for SB3: Monitor records episode statistics, DummyVecEnv vectorizes the
# single env, and VecNormalize maintains running stats to normalize
# observations and rewards during training.
env = Monitor(env)
vec_env = VecNormalize(DummyVecEnv([lambda: env]), norm_obs=True, norm_reward=True)
94+
95+
# --- Train two on-policy algorithms on the shared normalized env ---
print("\n" + "=" * 60)
print("Training multiple RL algorithms...")

algorithms = {
    "PPO": PPO("MlpPolicy", vec_env, verbose=0, learning_rate=3e-4, n_steps=2048),
    "A2C": A2C("MlpPolicy", vec_env, verbose=0, learning_rate=7e-4),
}

results = {}
for algo_name, algo_model in algorithms.items():
    print(f"\nTraining {algo_name}...")
    # Silent callback (verbose=0): it only records the reward curve.
    progress = ProgressCallback(check_freq=2000, verbose=0)
    algo_model.learn(total_timesteps=50000, callback=progress, progress_bar=True)
    results[algo_name] = {"model": algo_model, "rewards": progress.rewards}
    print(f"✓ {algo_name} training complete!")
108+
109+
# --- Evaluate the trained models ---
print("\n" + "=" * 60)
print("Evaluating trained models...")
# Raw (unnormalized) env; also used by the plotting rollouts further below.
eval_env = Monitor(TradingEnv())
# BUG FIX: the models were trained on VecNormalize-wrapped observations, so
# they must be evaluated through the same wrapper — feeding them raw
# observations puts the inputs on a different scale than anything seen during
# training. Freeze the running stats and report raw (unnormalized) rewards.
vec_env.training = False
vec_env.norm_reward = False
for name, data in results.items():
    mean_reward, std_reward = evaluate_policy(data["model"], vec_env, n_eval_episodes=20, deterministic=True)
    results[name]["eval_mean"] = mean_reward
    results[name]["eval_std"] = std_reward
    print(f"{name}: Mean Reward = {mean_reward:.2f} +/- {std_reward:.2f}")
117+
118+
# --- Visualization: 2x2 summary figure ---
print("\n" + "=" * 60)
print("Generating visualizations...")
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Top-left panel: reward curves recorded by the callback during training.
ax = axes[0, 0]
for algo, info in results.items():
    ax.plot(info["rewards"], label=algo, linewidth=2)
ax.set_xlabel("Training Checkpoints (x1000 steps)")
ax.set_ylabel("Mean Episode Reward")
ax.set_title("Training Progress Comparison")
ax.legend()
ax.grid(True, alpha=0.3)
129+
130+
# Top-right panel: evaluation reward per algorithm with std-dev error bars.
ax = axes[0, 1]
names = list(results.keys())
eval_stats = [(results[algo]["eval_mean"], results[algo]["eval_std"]) for algo in names]
means = [m for m, _ in eval_stats]
stds = [s for _, s in eval_stats]
ax.bar(names, means, yerr=stds, capsize=10, alpha=0.7, color=['#1f77b4', '#ff7f0e'])
ax.set_ylabel("Mean Reward")
ax.set_title("Evaluation Performance (20 episodes)")
ax.grid(True, alpha=0.3, axis='y')
139+
# Bottom-left panel: portfolio value over one greedy episode of the best model.
ax = axes[1, 0]
# Select the winner once instead of recomputing the argmax for the title.
best_plot_name = max(results.items(), key=lambda x: x[1]["eval_mean"])[0]
best_model = results[best_plot_name]["model"]
obs = eval_env.reset()[0]
portfolio_values = [1000]
for _ in range(200):
    # BUG FIX: the model was trained on VecNormalize-scaled observations, so
    # raw observations must be normalized before predict(); also cast the
    # numpy action to a plain int for the Discrete action space.
    action, _ = best_model.predict(vec_env.normalize_obs(obs), deterministic=True)
    obs, reward, done, truncated, info = eval_env.step(int(action))
    portfolio_values.append(info.get("portfolio", portfolio_values[-1]))
    # BUG FIX: stop on truncation as well as termination (Gymnasium API).
    if done or truncated:
        break
ax.plot(portfolio_values, linewidth=2, color='green')
ax.axhline(y=1000, color='red', linestyle='--', label='Initial Value')
ax.set_xlabel("Steps")
ax.set_ylabel("Portfolio Value ($)")
ax.set_title(f"Best Model ({best_plot_name}) Episode")
ax.legend()
ax.grid(True, alpha=0.3)
156+
157+
# Bottom-right panel: distribution of actions taken by the best model.
ax = axes[1, 1]
obs = eval_env.reset()[0]
actions = []
for _ in range(200):
    # BUG FIX: normalize observations the same way training did, and store
    # actions as plain ints so list.count() below compares reliably (numpy
    # scalar equality would otherwise be relied on implicitly).
    action, _ = best_model.predict(vec_env.normalize_obs(obs), deterministic=True)
    actions.append(int(action))
    obs, _, done, truncated, _ = eval_env.step(int(action))
    # BUG FIX: stop on truncation as well as termination (Gymnasium API).
    if done or truncated:
        break
action_names = ['Hold', 'Buy', 'Sell']
action_counts = [actions.count(i) for i in range(3)]
ax.pie(action_counts, labels=action_names, autopct='%1.1f%%', startangle=90, colors=['#ff9999', '#66b3ff', '#99ff99'])
ax.set_title("Action Distribution (Best Model)")
plt.tight_layout()
plt.savefig('sb3_advanced_results.png', dpi=150, bbox_inches='tight')
print("✓ Visualizations saved as 'sb3_advanced_results.png'")
plt.show()
174+
175+
# --- Persist the best model and the normalization statistics ---
print("\n" + "=" * 60)
print("Saving and loading models...")
best_name = max(results.items(), key=lambda x: x[1]["eval_mean"])[0]
best_model = results[best_name]["model"]
best_model.save(f"best_trading_model_{best_name}")
# The VecNormalize running stats must be saved too, or a reloaded model would
# see observations on the wrong scale at inference time.
vec_env.save("vec_normalize.pkl")
# BUG FIX: reload with the class that actually trained the model; the
# original hard-coded PPO.load, which is wrong whenever A2C wins.
loaded_model = type(best_model).load(f"best_trading_model_{best_name}")
print(f"✓ Best model ({best_name}) saved and loaded successfully!")
print("\n" + "=" * 60)
print("TUTORIAL COMPLETE!")
print(f"Best performing algorithm: {best_name}")
print(f"Final evaluation score: {results[best_name]['eval_mean']:.2f}")
print("=" * 60)

0 commit comments

Comments
 (0)