"""
src/trading/env.py

This module defines a custom OpenAI Gym environment for trading /MES futures on a 5-minute timeframe.

The environment includes logic for:

- Tracking positions (long/short)
- Evaluating positions on every 5-minute bar
- Implementing ATR-based trailing stops for risk management
- Trading a single contract initially, with hooks for multi-contract trading in the future
"""
|
||
|
||
import gym
|
||
from gym import spaces
|
||
import numpy as np
|
||
import pandas as pd
|
||
import logging
|
||
import threading
|
||
|
||
class FuturesTradingEnv(gym.Env):
    """
    Custom Gym environment for futures trading.

    Each call to ``step`` consumes one row (bar) of ``data``. The discrete
    action is an order adjustment in contracts; the resulting position is
    capped at +/- ``max_contracts``. Open positions are protected by an
    ATR-based stop that ratchets in the position's favour.

    Per-step reward:
    - trade step: realised PnL of any quantity closed by the trade, minus
      transaction cost;
    - hold step with an open position: unrealised PnL relative to the
      entry price;
    - stop-out: realised PnL of the closed position (counted once).
    """

    metadata = {'render.modes': ['human']}

    def __init__(self, data, lstm_model=None, config=None):
        """
        Initialize the trading environment.

        Parameters:
        - data (pandas.DataFrame): Preprocessed market data. Must contain a
          'Close' column. An optional 'ATR' column sizes the stop distance
          (0.01 is used when it is missing or NaN). Every column other than
          'Date' and 'Close' is used as an observation feature and must be
          convertible to float32.
        - lstm_model: Pre-trained LSTM model for price forecasting (optional).
          Currently only gates a placeholder forecast value.
        - config (dict): Configuration parameters. Keys read here:
          'max_contracts' (default 1), 'atr_period' (default 14, stored but
          not used by this class), 'trailing_stop_multiplier' (default 1.5),
          'transaction_cost' (default 0.001; cost = rate * |contracts| * price),
          'window_size' (default 15).

        Raises:
        - ValueError: if 'max_contracts' is less than 1.
        """
        super(FuturesTradingEnv, self).__init__()
        self.logger = logging.getLogger(__name__)
        self.data = data.reset_index(drop=True)
        self.lstm_model = lstm_model
        self.config = config if config is not None else {}

        # Action space: discrete orders ranging from -max_contracts to
        # +max_contracts; action index `a` means an adjustment of
        # `a - max_contracts` contracts.
        self.max_contracts = int(self.config.get('max_contracts', 1))
        if self.max_contracts < 1:
            # The position feature divides by max_contracts, and an empty
            # order range makes the environment meaningless.
            raise ValueError(f"max_contracts must be >= 1, got {self.max_contracts}")
        self.action_space = spaces.Discrete(2 * self.max_contracts + 1)

        # Observation space: technical indicators plus current position and forecast.
        self.tech_indicator_cols = [col for col in self.data.columns if col not in ['Date', 'Close']]
        obs_dim = len(self.tech_indicator_cols) + 2  # +1 for normalized position, +1 for forecast.
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(obs_dim,), dtype=np.float32)

        # Trading variables
        self.current_step = 0
        self.contracts_held = 0
        self.entry_price = None

        # ATR-based trailing stop parameters
        self.atr_period = self.config.get('atr_period', 14)
        self.trailing_stop_multiplier = self.config.get('trailing_stop_multiplier', 1.5)
        self.trailing_stop = None

        # Lock reserved for thread-safe LSTM predictions (not used yet).
        self.lstm_lock = threading.Lock()

    def _get_observation(self):
        """
        Construct the observation vector from the current market data.

        Returns:
        - observation (numpy.array): float32 vector of the technical
          indicators, the position divided by max_contracts (within [-1, 1]),
          and a forecast value.
        """
        row = self.data.iloc[self.current_step]
        tech_indicators = row[self.tech_indicator_cols].values.astype(np.float32)
        norm_position = np.array([self.contracts_held / self.max_contracts], dtype=np.float32)

        if self.lstm_model is not None and self.current_step >= self.config.get('window_size', 15):
            # TODO: build a proper input sequence and query the model; placeholder value for now.
            forecast = np.array([0.001], dtype=np.float32)
        else:
            forecast = np.array([0.0], dtype=np.float32)

        return np.concatenate([tech_indicators, norm_position, forecast])

    def reset(self):
        """
        Reset the environment to an initial state.

        Returns:
        - observation (numpy.array): The initial observation.
        """
        self.current_step = 0
        self._flatten()
        return self._get_observation()

    def step(self, action):
        """
        Execute one time step in the environment.

        Parameters:
        - action (int): Action index in [0, 2 * max_contracts]; the requested
          order is `action - max_contracts` contracts. The order is reduced
          so the resulting position stays within +/- max_contracts.

        Returns:
        - observation (numpy.array): Next state observation.
        - reward (float): Reward obtained from the action (see class docstring).
        - done (bool): Whether the episode has ended.
        - info (dict): 'contracts_held' and 'entry_price' after the step.
        """
        row = self.data.iloc[self.current_step]
        current_price = row['Close']

        # Enforce the position cap; only the quantity actually traded is charged.
        requested = int(action) - self.max_contracts
        target = max(-self.max_contracts, min(self.max_contracts, self.contracts_held + requested))
        trade_adjustment = target - self.contracts_held

        reward = 0.0
        if trade_adjustment != 0:
            reward += self._apply_trade(trade_adjustment, current_price, row)
            cost_rate = self.config.get('transaction_cost', 0.001)
            reward -= cost_rate * abs(trade_adjustment) * current_price

        stop_pnl = self._check_trailing_stop(current_price)
        if stop_pnl is not None:
            # Stop-out: credit the realised PnL once (not on top of unrealised PnL).
            reward += stop_pnl
        elif trade_adjustment == 0 and self.contracts_held != 0 and self.entry_price is not None:
            # Holding: reward is the unrealised PnL relative to the entry price.
            reward += (current_price - self.entry_price) * self.contracts_held

        if self.contracts_held != 0:
            self._ratchet_trailing_stop(current_price, row)

        self.current_step += 1
        done = self.current_step >= len(self.data) - 1
        observation = self._get_observation()

        info = {
            'contracts_held': self.contracts_held,
            'entry_price': self.entry_price,
        }
        return observation, float(reward), done, info

    def _flatten(self):
        """Clear all position state (flat book, no entry price, no stop)."""
        self.contracts_held = 0
        self.entry_price = None
        self.trailing_stop = None

    def _stop_distance(self, row):
        """Return the stop distance from price: multiplier * ATR (ATR falls back to 0.01 if missing/NaN)."""
        atr = row.get('ATR', 0.01)
        if pd.isna(atr):
            # A NaN ATR would make the stop NaN, and NaN comparisons never trigger.
            atr = 0.01
        return self.trailing_stop_multiplier * atr

    def _initial_stop(self, price, position, row):
        """Return a stop below `price` for a long position, above it for a short one."""
        distance = self._stop_distance(row)
        return price - distance if position > 0 else price + distance

    def _apply_trade(self, trade_adjustment, current_price, row):
        """
        Apply a non-zero order to the position and return the realised PnL.

        - Opening from flat: entry at current price, fresh stop.
        - Adding in the same direction: entry becomes the weighted average.
        - Reducing: PnL on the closed quantity is realised; entry is unchanged.
        - Closing or flipping: the whole old position is realised. A flip
          opens the remainder at the current price with a fresh stop.
        """
        prev = self.contracts_held
        new = prev + trade_adjustment
        realized = 0.0

        if prev == 0:
            self.entry_price = current_price
            self.trailing_stop = self._initial_stop(current_price, new, row)
        elif prev * new > 0:
            if abs(new) > abs(prev):
                self.entry_price = (self.entry_price * prev + current_price * trade_adjustment) / new
            else:
                realized = (current_price - self.entry_price) * (prev - new)
        else:
            realized = (current_price - self.entry_price) * prev
            if new == 0:
                self.entry_price = None
                self.trailing_stop = None
            else:
                self.entry_price = current_price
                self.trailing_stop = self._initial_stop(current_price, new, row)

        self.contracts_held = new
        return realized

    def _check_trailing_stop(self, current_price):
        """Close the position if price crossed the stop; return the realised PnL, or None if not triggered."""
        if self.trailing_stop is None or self.contracts_held == 0:
            return None
        if self.contracts_held > 0 and current_price < self.trailing_stop:
            self.logger.info("ATR trailing stop triggered for long position.")
        elif self.contracts_held < 0 and current_price > self.trailing_stop:
            self.logger.info("ATR trailing stop triggered for short position.")
        else:
            return None
        # Same formula for both sides: contracts_held carries the sign.
        realized = (current_price - self.entry_price) * self.contracts_held
        self._flatten()
        return realized

    def _ratchet_trailing_stop(self, current_price, row):
        """Move the stop toward price (never away): up for longs, down for shorts."""
        candidate = self._initial_stop(current_price, self.contracts_held, row)
        if self.trailing_stop is None:
            self.trailing_stop = candidate
        elif self.contracts_held > 0:
            self.trailing_stop = max(self.trailing_stop, candidate)
        else:
            self.trailing_stop = min(self.trailing_stop, candidate)

    def render(self, mode='human'):
        """
        Render the current state.
        """
        current_price = self.data.iloc[self.current_step]['Close']
        print(f"Step: {self.current_step}, Price: {current_price:.4f}, Contracts Held: {self.contracts_held}, Entry Price: {self.entry_price}")
|
||
|