254 lines
9.6 KiB
Plaintext
254 lines
9.6 KiB
Plaintext
import threading
|
|
import time
|
|
import logging
|
|
from ibapi.client import EClient
|
|
from ibapi.wrapper import EWrapper
|
|
from ibapi.contract import Contract
|
|
from ibapi.order import Order
|
|
import pickle
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Logging Configuration: Set up logging for debugging and operational insights.
|
|
# -----------------------------------------------------------------------------
|
|
# Root logger configuration: INFO and above, each record prefixed with a
# timestamp and level name, e.g. "2024-01-01 12:00:00,000 INFO: message".
_LOG_FORMAT = '%(asctime)s %(levelname)s: %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Helper Functions
|
|
# -----------------------------------------------------------------------------
|
|
def USStockContract(symbol, *, exchange="SMART", currency="USD", primaryExchange=None):
    """
    Create and return an IBKR stock contract.

    With the default arguments this builds a USD-denominated stock ("STK")
    routed through IB's SMART router, exactly as before.

    :param symbol: Underlying stock symbol (e.g., "AAPL")
    :param exchange: Destination exchange (keyword-only). "SMART" lets IB
        choose the route.
    :param currency: Trading currency of the listing (keyword-only,
        default "USD").
    :param primaryExchange: Optional listing exchange (keyword-only, e.g.
        "NASDAQ"). For SMART-routed contracts IB uses it to resolve
        ambiguous symbols. Left unset when None/empty.
    :return: Configured IBKR Contract object for a stock
    """
    contract = Contract()
    contract.symbol = symbol
    contract.secType = "STK"
    contract.currency = currency
    contract.exchange = exchange
    if primaryExchange:
        contract.primaryExchange = primaryExchange
    return contract
|
|
|
|
def makeLimitOrder(action, quantity, limitPrice):
    """
    Create and return a limit order.

    :param action: "BUY" or "SELL"
    :param quantity: Number of shares to trade; must be positive.
    :param limitPrice: The limit price for the order; must be positive.
    :return: Configured Order object
    :raises ValueError: if action, quantity or limitPrice is invalid. A
        malformed order is rejected locally instead of being sent to IB.
    """
    if action not in ("BUY", "SELL"):
        raise ValueError(f"action must be 'BUY' or 'SELL', got {action!r}")
    # `not x > 0` (rather than `x <= 0`) also rejects NaN.
    if quantity is None or not quantity > 0:
        raise ValueError(f"quantity must be positive, got {quantity!r}")
    if limitPrice is None or not limitPrice > 0:
        raise ValueError(f"limitPrice must be positive, got {limitPrice!r}")

    order = Order()
    order.action = action
    order.orderType = "LMT"
    order.totalQuantity = quantity
    order.lmtPrice = limitPrice
    # Some ibapi releases default these two attributes to True, and recent
    # TWS/Gateway versions reject such orders (errors 10268/10269). Clearing
    # them is harmless on versions that already default to False.
    order.eTradeOnly = False
    order.firmQuoteOnly = False
    return order
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# PPO Model Integration (Placeholder)
|
|
# -----------------------------------------------------------------------------
|
|
class PPOModel:
    """
    Wrapper around a pickled PPO model.

    NOTE: ``get_action`` does not consult ``self.model`` yet; it applies a
    placeholder price-threshold rule until real inference is wired in.
    """

    def __init__(self, model_path, buy_threshold=150, sell_threshold=170):
        """
        Load a pre-trained PPO model from disk.

        Loading is best-effort: on any failure the error is logged and
        ``self.model`` is set to None.

        :param model_path: Path to the saved PPO model file.
        :param buy_threshold: Placeholder rule; prices strictly below this
            yield "BUY" (default 150).
        :param sell_threshold: Placeholder rule; prices strictly above this
            yield "SELL" (default 170).
        """
        self.buy_threshold = buy_threshold
        self.sell_threshold = sell_threshold
        try:
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file. Only load model files from a trusted source.
            with open(model_path, 'rb') as f:
                self.model = pickle.load(f)
            logging.info("PPO model loaded successfully.")
        except Exception as e:
            # Broad on purpose: unpickling can fail in many ways (missing file,
            # corrupt data, missing classes). The caller keeps running on the
            # placeholder rule.
            logging.error(f"Error loading PPO model: {e}")
            self.model = None

    def get_action(self, observation):
        """
        Get an action from the PPO model given the current observation.

        :param observation: A numeric observation (e.g., current stock price)
        :return: Action decision as a string: "BUY", "SELL" or "HOLD".
            Non-positive, NaN, infinite or non-comparable observations always
            yield "HOLD".
        """
        try:
            # Replace the placeholder logic below with your actual model inference.
            price = observation
            # IB reports "price unavailable" as -1 (and may send 0). Those
            # values would otherwise fall under buy_threshold and trigger a
            # BUY. NaN fails both comparisons; +inf fails the upper bound.
            if not (0 < price < float("inf")):
                return "HOLD"
            if price < self.buy_threshold:
                return "BUY"
            if price > self.sell_threshold:
                return "SELL"
            return "HOLD"
        except Exception as e:
            # E.g. TypeError for None/str observations: fail safe to HOLD.
            logging.error(f"Error in model inference: {e}")
            return "HOLD"
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Trading Bot Class
|
|
# -----------------------------------------------------------------------------
|
|
class AIDrivenIBBot(EWrapper, EClient):
    """
    IB API client + wrapper that streams prices for one stock.

    On every last-trade tick it asks the PPO model for an action, and it
    places at most one limit order per instance.
    """

    # ibapi TickTypeEnum ids treated as the "current trade price":
    # LAST = 4, DELAYED_LAST = 68. Other ticks (bid, ask, high, low, close,
    # open, ...) are logged but never drive trading decisions.
    TRADE_PRICE_TICK_TYPES = (4, 68)

    # Ticker id of the single market-data subscription made by this bot.
    MARKET_DATA_REQ_ID = 1001

    def __init__(self, ppo_model, stock_symbol="AAPL", order_quantity=100):
        """
        Initialize the trading bot with the PPO model and set up required variables.

        :param ppo_model: An instance of PPOModel (any object exposing
            get_action(price) -> "BUY" | "SELL" | "HOLD").
        :param stock_symbol: The stock symbol to trade (default is "AAPL")
        :param order_quantity: Shares per order (default 100, the previously
            hard-coded size).
        """
        EWrapper.__init__(self)
        EClient.__init__(self, self)
        self.ppo_model = ppo_model
        self.stock_symbol = stock_symbol
        self.order_quantity = order_quantity
        self.nextValidOrderId = None
        self.latest_price = None
        # Set once an order has been sent. Nothing in this class resets it,
        # so each bot instance places at most one order.
        self.trade_executed = False

    # ------------------------
    # IB API Callbacks
    # ------------------------
    def error(self, reqId, errorCode, errorString, advancedOrderRejectJson="", *_extra):
        """
        Handle error and notification messages received from IB Gateway.

        ibapi 10.x calls this with a fourth ``advancedOrderRejectJson``
        argument, which made the old 3-argument override raise TypeError.
        The default value plus ``*_extra`` accept both the older and newer
        call shapes.

        NOTE(review): some very recent ibapi releases reportedly insert an
        ``errorTime`` argument before ``errorCode``. With those, the logged
        fields would be shifted. Confirm against the installed version.
        """
        # IB documents 21xx codes as warnings/notifications (e.g. 2104
        # "Market data farm connection is OK"), not failures.
        is_notice = isinstance(errorCode, int) and 2100 <= errorCode < 2200
        log = logging.warning if is_notice else logging.error
        message = f"Error. ReqId: {reqId}, Code: {errorCode}, Msg: {errorString}"
        if advancedOrderRejectJson:
            message += f", AdvancedOrderReject: {advancedOrderRejectJson}"
        log(message)

    def nextValidId(self, orderId):
        """
        Callback for receiving the next valid order ID. Starts the market data request.
        """
        logging.info(f"Next valid order ID: {orderId}")
        self.nextValidOrderId = orderId
        self.start_data_stream()

    def tickPrice(self, reqId, tickType, price, attrib):
        """
        Callback for receiving live price updates.

        Every tick is logged. Only positive last-trade prices for this bot's
        subscription update ``latest_price`` and trigger a model evaluation.
        """
        logging.info(f"Tick Price. ReqId: {reqId}, TickType: {tickType}, Price: {price}")
        if reqId != self.MARKET_DATA_REQ_ID:
            return
        # Previously every tick type (bid/ask/high/low/close/open) overwrote
        # latest_price and could trigger a trade. Only last-trade ticks count.
        if tickType not in self.TRADE_PRICE_TICK_TYPES:
            return
        # IB sends -1 (or 0) when no price is available; never trade on that.
        if price is None or price <= 0:
            logging.info(f"Ignoring unavailable price {price} for tick type {tickType}.")
            return
        self.latest_price = price
        self.evaluate_ai_decision()

    def tickSize(self, reqId, tickType, size):
        """
        Optional: Handle tick size data if needed. Currently only logged.
        """
        logging.info(f"Tick Size. ReqId: {reqId}, TickType: {tickType}, Size: {size}")

    # ------------------------
    # Custom Methods
    # ------------------------
    def start_data_stream(self):
        """
        Request live market data for the specified stock.

        Failures are logged, not raised.
        """
        try:
            # Set market data type: 1 = live, 2 = frozen, 3 = delayed.
            self.reqMarketDataType(1)
            contract = USStockContract(self.stock_symbol)
            self.reqMktData(self.MARKET_DATA_REQ_ID, contract, "", False, False, [])
            logging.info(f"Started market data stream for {self.stock_symbol}.")
        except Exception as e:
            logging.error(f"Error starting data stream: {e}")

    def evaluate_ai_decision(self):
        """
        Evaluate the current market data with the PPO model and execute a trade if needed.

        Does nothing until a price has been received. Exceptions from the
        model or the order path are logged, not raised, so they cannot
        kill the API message thread.
        """
        if self.latest_price is None:
            return
        try:
            # Get action from PPO model using the latest price as observation.
            action = self.ppo_model.get_action(self.latest_price)
            logging.info(f"Model action: {action} for price: {self.latest_price}")

            # Risk management: only one trade per bot instance (see trade_executed).
            if self.trade_executed:
                logging.info("Trade already executed for this signal window. Waiting for new data.")
            elif action in ("BUY", "SELL"):
                self.execute_trade(action)
            else:
                logging.info("Action HOLD: No trade executed.")
        except Exception as e:
            logging.error(f"Error evaluating AI decision: {e}")

    def execute_trade(self, action):
        """
        Execute a trade after performing thorough checks.

        Places a limit order for ``order_quantity`` shares at ``latest_price``.
        On success it consumes the order ID and sets ``trade_executed``.

        :param action: "BUY" or "SELL"
        """
        if self.nextValidOrderId is None:
            logging.error("Cannot execute trade: Order ID not available.")
            return
        if self.latest_price is None:
            logging.error("Cannot execute trade: no price available for the limit price.")
            return
        # ibapi's placeOrder() does not raise when disconnected; it only
        # reports through error(). Without this check the bot would log
        # "Placed ..." and mark the trade as done although nothing was sent.
        if not self.isConnected():
            logging.error("Cannot execute trade: not connected to IB.")
            return

        # --- Risk Management & Validation ---
        # Here you would include additional checks such as:
        # - Account balance verification
        # - Margin checks
        # - Trade frequency limitations
        # - Any other custom risk metrics
        # NOTE: a SELL is sent without checking the current position, so it
        # may open a short position.

        try:
            contract = USStockContract(self.stock_symbol)
            quantity = self.order_quantity
            # In a real scenario the limit might be derived from the current price and desired slippage.
            limitPrice = self.latest_price
            order = makeLimitOrder(action, quantity, limitPrice)
            self.placeOrder(self.nextValidOrderId, contract, order)
            logging.info(f"Placed {action} order for {quantity} shares of {self.stock_symbol} at {limitPrice}.")
            # Consume the order ID and block further trades from this instance.
            self.nextValidOrderId += 1
            self.trade_executed = True
        except Exception as e:
            logging.error(f"Error executing trade: {e}")
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Helper Function to Run the Bot in a Separate Thread
|
|
# -----------------------------------------------------------------------------
|
|
def run_loop(app):
    """
    Run the IB API message loop.

    Intended as a thread target: blocks until ``app.run()`` returns. Any
    exception is logged together with its traceback instead of
    propagating, so the thread exits quietly.

    :param app: Object exposing a blocking ``run()`` method (the IB client).
    """
    try:
        app.run()
    except Exception:
        # logging.exception keeps the traceback, which the previous
        # f-string message discarded, making loop crashes hard to diagnose.
        logging.exception("Error in run loop")
    else:
        logging.info("IB API message loop exited.")
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Main Execution Block
|
|
# -----------------------------------------------------------------------------
|
|
if __name__ == "__main__":
    # Path to your pre-trained PPO model (update with your actual model file)
    model_path = "ppo_model.pkl"
    model = PPOModel(model_path)

    # Initialize the trading bot with the PPO model and desired stock symbol.
    bot = AIDrivenIBBot(model, stock_symbol="AAPL")

    # Connect to IB Gateway. Update IP, port, and clientId as required.
    # ibapi's connect() reports socket failures through the error() callback
    # instead of raising, so check the connection state explicitly rather
    # than sleeping for two minutes against a dead connection.
    bot.connect("127.0.0.1", 4002, clientId=202)
    if not bot.isConnected():
        logging.error("Could not connect to IB Gateway at 127.0.0.1:4002; exiting.")
        raise SystemExit(1)

    # Start the IB API message loop in a separate thread.
    api_thread = threading.Thread(target=run_loop, args=(bot,), daemon=True)
    api_thread.start()

    # Keep the program running for a desired duration (e.g., 120 seconds for testing).
    try:
        time.sleep(120)
    except KeyboardInterrupt:
        logging.info("Interrupted by user.")
    finally:
        bot.disconnect()
        # Give the message loop a bounded chance to notice the disconnect and exit.
        api_thread.join(timeout=5)
        logging.info("Disconnected from IB Gateway.")
|
|
|