"""
src/trading/env.py
This module defines a custom OpenAI Gym environment for /MES futures trading on a 5minute timeframe.
The environment includes logic for:
- Tracking positions (long/short)
- Evaluating positions every 5 minutes
- Implementing ATRbased trailing stops for risk management
- Allowing a single contract initially with hooks for multicontract trading in the future
"""
import gym
from gym import spaces
import numpy as np
import pandas as pd
import logging
import threading


class FuturesTradingEnv(gym.Env):
    """
    Custom Gym environment for futures trading.
    """

    metadata = {'render.modes': ['human']}

    def __init__(self, data, lstm_model=None, config=None):
        """
        Initialize the trading environment.

        Parameters:
        - data (pandas.DataFrame): Preprocessed market data.
        - lstm_model: Pre-trained LSTM model for price forecasting (optional).
        - config (dict): Configuration parameters for the environment.
        """
        super().__init__()
        self.logger = logging.getLogger(__name__)
        self.data = data.reset_index(drop=True)
        self.lstm_model = lstm_model
        self.config = config if config is not None else {}
        # Define action space: discrete orders ranging from -max_contracts to +max_contracts.
        self.max_contracts = self.config.get('max_contracts', 1)
        self.action_space = spaces.Discrete(2 * self.max_contracts + 1)
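        # Example: with max_contracts=1 the action space is Discrete(3), and
        # step() decodes the indices {0, 1, 2} to signed orders {-1, 0, +1}.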
        # Define observation space: technical indicators plus current position and forecast.
        self.tech_indicator_cols = [col for col in self.data.columns if col not in ['Date', 'Close']]
        obs_dim = len(self.tech_indicator_cols) + 2  # +1 for normalized position, +1 for forecast.
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(obs_dim,), dtype=np.float32)
        # Trading variables
        self.current_step = 0
        self.contracts_held = 0
        self.entry_price = None
        # ATR-based trailing stop parameters
        self.atr_period = self.config.get('atr_period', 14)
        self.trailing_stop_multiplier = self.config.get('trailing_stop_multiplier', 1.5)
        self.trailing_stop = None
        # Lock for thread-safe LSTM predictions
        self.lstm_lock = threading.Lock()

    def _get_observation(self):
        """
        Construct the observation vector from the current market data.

        Returns:
        - observation (numpy.array): Observation vector for the current step.
        """
        row = self.data.iloc[self.current_step]
        tech_indicators = row[self.tech_indicator_cols].values.astype(np.float32)
        norm_position = np.array([self.contracts_held / self.max_contracts], dtype=np.float32)
        if self.lstm_model is not None and self.current_step >= self.config.get('window_size', 15):
            # For proper forecasting, construct an input sequence; here we use a placeholder.
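            # NOTE: a real forecast would feed the trailing window of indicator
            # rows through the LSTM. A hedged sketch (the exact input shape
            # depends on how the model was trained, so this is an assumption):
            #   window = self.config.get('window_size', 15)
            #   seq = self.data[self.tech_indicator_cols].iloc[
            #       self.current_step - window:self.current_step].values
            #   with self.lstm_lock:
            #       forecast = self.lstm_model.predict(seq[None, ...]).reshape(1).astype(np.float32)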
            forecast = np.array([0.001], dtype=np.float32)
        else:
            forecast = np.array([0.0], dtype=np.float32)
        observation = np.concatenate([tech_indicators, norm_position, forecast])
        return observation

    def reset(self):
        """
        Reset the environment to an initial state.

        Returns:
        - observation (numpy.array): The initial observation.
        """
        self.current_step = 0
        self.contracts_held = 0
        self.entry_price = None
        self.trailing_stop = None
        return self._get_observation()

    def step(self, action):
        """
        Execute one time step in the environment.

        Parameters:
        - action (int): Action taken by the agent (discrete order adjustment).

        Returns:
        - observation (numpy.array): Next state observation.
        - reward (float): Reward obtained from the action.
        - done (bool): Whether the episode has ended.
        - info (dict): Additional environment info.
        """
        # Decode the discrete action index {0, ..., 2 * max_contracts} into a
        # signed order size in {-max_contracts, ..., +max_contracts}.
        trade_adjustment = action - self.max_contracts
        # Clamp the order so the resulting position stays within +/- max_contracts
        # (this keeps the normalized position feature in [-1, 1]).
        trade_adjustment = int(np.clip(trade_adjustment,
                                       -self.max_contracts - self.contracts_held,
                                       self.max_contracts - self.contracts_held))
        current_price = self.data.iloc[self.current_step]['Close']
        transaction_cost = self.config.get('transaction_cost', 0.001) * abs(trade_adjustment) * current_price
        reward = 0.0
        atr = self.data.iloc[self.current_step].get('ATR', 0.01)
        stop_offset = self.trailing_stop_multiplier * atr
        if trade_adjustment != 0:
            if self.contracts_held == 0:
                # Opening a new position: record the entry price and seed the
                # ATR-based stop on the appropriate side of the market.
                self.contracts_held = trade_adjustment
                self.entry_price = current_price
                self.trailing_stop = current_price - stop_offset if trade_adjustment > 0 else current_price + stop_offset
            else:
                prev_position = self.contracts_held
                self.contracts_held += trade_adjustment
                if self.contracts_held == 0:
                    # Fully closed: realize the PnL and clear position state.
                    reward += (current_price - self.entry_price) * prev_position
                    self.entry_price = None
                    self.trailing_stop = None
                elif np.sign(self.contracts_held) != np.sign(prev_position):
                    # Reversed: realize the PnL on the old side and treat the
                    # remainder as a fresh entry at the current price.
                    reward += (current_price - self.entry_price) * prev_position
                    self.entry_price = current_price
                    self.trailing_stop = current_price - stop_offset if self.contracts_held > 0 else current_price + stop_offset
                elif np.sign(trade_adjustment) == np.sign(prev_position):
                    # Scaling in: average the entry price over old and new contracts.
                    self.entry_price = (self.entry_price * prev_position + current_price * trade_adjustment) / self.contracts_held
                else:
                    # Partial close: realize the PnL on the closed contracts; the
                    # entry price of the remaining contracts is unchanged.
                    reward += (current_price - self.entry_price) * (-trade_adjustment)
            reward -= transaction_cost
        else:
            if self.contracts_held != 0 and self.entry_price is not None:
                # Unrealized mark-to-market reward while holding.
                reward = (current_price - self.entry_price) * self.contracts_held
        # ATR-based trailing stop logic: ratchet the stop in the favorable
        # direction only, then exit if price has crossed it.
        if self.trailing_stop is not None and self.contracts_held != 0:
            if self.contracts_held > 0:
                self.trailing_stop = max(self.trailing_stop, current_price - stop_offset)
                if current_price < self.trailing_stop:
                    self.logger.info("ATR trailing stop triggered for long position.")
                    reward += (current_price - self.entry_price) * self.contracts_held
                    self.contracts_held = 0
                    self.entry_price = None
                    self.trailing_stop = None
            else:
                self.trailing_stop = min(self.trailing_stop, current_price + stop_offset)
                if current_price > self.trailing_stop:
                    self.logger.info("ATR trailing stop triggered for short position.")
                    reward += (self.entry_price - current_price) * abs(self.contracts_held)
                    self.contracts_held = 0
                    self.entry_price = None
                    self.trailing_stop = None
        self.current_step += 1
        done = self.current_step >= len(self.data) - 1
        observation = self._get_observation()
        info = {
            'contracts_held': self.contracts_held,
            'entry_price': self.entry_price
        }
        return observation, reward, done, info

    def render(self, mode='human'):
        """
        Render the current state.
        """
        current_price = self.data.iloc[self.current_step]['Close']
        print(f"Step: {self.current_step}, Price: {current_price:.4f}, "
              f"Contracts Held: {self.contracts_held}, Entry Price: {self.entry_price}")