fixed griffin's retrieval code. Doesn't retrieve accurate data for some reason
This commit is contained in:
2
IB_Gateway/.env
Normal file
2
IB_Gateway/.env
Normal file
@@ -0,0 +1,2 @@
|
||||
jacobmardian
|
||||
Griffinisgay123
|
||||
329
IB_Gateway/data.py
Normal file
329
IB_Gateway/data.py
Normal file
@@ -0,0 +1,329 @@
|
||||
import json
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from zoneinfo import ZoneInfo
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import asyncio # For new event loops in threads
|
||||
from ib_insync import IB, Future, util
|
||||
|
||||
# Try to import tabulate for a nice summary table
|
||||
try:
|
||||
from tabulate import tabulate
|
||||
HAS_TABULATE = True
|
||||
except ImportError:
|
||||
HAS_TABULATE = False
|
||||
|
||||
# --- CONFIGURATION ---
|
||||
CONFIG = {
|
||||
"MAX_WORKERS": 4, # Number of threads (concurrent workers)
|
||||
"REQUEST_DELAY": 0.2, # Delay in seconds between historical data requests
|
||||
"BASE_CLIENT_ID": 100, # Starting clientId; each thread gets a unique clientId
|
||||
}
|
||||
|
||||
# --- Custom Colored Logger Setup ---
|
||||
class ColoredFormatter(logging.Formatter):
|
||||
COLORS = {
|
||||
logging.DEBUG: "\033[36m", # Cyan
|
||||
logging.INFO: "\033[32m", # Green
|
||||
logging.WARNING: "\033[33m", # Yellow
|
||||
logging.ERROR: "\033[31m", # Red
|
||||
logging.CRITICAL: "\033[41m", # Red background
|
||||
}
|
||||
RESET = "\033[0m"
|
||||
|
||||
def format(self, record):
|
||||
color = self.COLORS.get(record.levelno, self.RESET)
|
||||
record.msg = f"{color}{record.msg}{self.RESET}"
|
||||
return super().format(record)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.INFO)
|
||||
handler = logging.StreamHandler()
|
||||
formatter = ColoredFormatter("%(asctime)s [%(levelname)s] %(message)s", "%Y-%m-%d %H:%M:%S")
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
# --- Global Time Zone ---
|
||||
TZ = ZoneInfo("US/Eastern")
|
||||
|
||||
# --- Functions ---
|
||||
def get_third_friday(year, month):
|
||||
"""Return the third Friday of the given year/month as a datetime with TZ."""
|
||||
fridays = []
|
||||
for day in range(1, 32):
|
||||
try:
|
||||
d = datetime.date(year, month, day)
|
||||
except ValueError:
|
||||
break
|
||||
if d.weekday() == 4:
|
||||
fridays.append(d)
|
||||
if len(fridays) >= 3:
|
||||
dt = datetime.datetime.combine(fridays[2], datetime.time(16, 0))
|
||||
elif fridays:
|
||||
dt = datetime.datetime.combine(fridays[-1], datetime.time(16, 0))
|
||||
else:
|
||||
dt = datetime.datetime(year, month, 1, 16, 0)
|
||||
return dt.replace(tzinfo=TZ)
|
||||
|
||||
def generate_contract_months(start_date, end_date):
|
||||
"""Generate a sorted list of contract month strings ('YYYYMM') between start_date and end_date."""
|
||||
months = [3, 6, 9, 12]
|
||||
result = []
|
||||
for year in range(start_date.year, end_date.year + 2):
|
||||
for m in months:
|
||||
dt = datetime.datetime(year, m, 1, tzinfo=TZ)
|
||||
if dt <= end_date:
|
||||
result.append(f"{year}{m:02d}")
|
||||
return sorted(set(result))
|
||||
|
||||
def get_data_chunk(ib_conn, contract, end_dt, duration_str="1 W"):
|
||||
"""Request a chunk of historical data for the given contract ending at end_dt using ib_conn."""
|
||||
try:
|
||||
bars = ib_conn.reqHistoricalData(
|
||||
contract,
|
||||
endDateTime=end_dt.strftime("%Y%m%d %H:%M:%S"),
|
||||
durationStr=duration_str,
|
||||
barSizeSetting="5 mins",
|
||||
whatToShow="TRADES",
|
||||
useRTH=False,
|
||||
formatDate=1
|
||||
)
|
||||
return bars
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving data for {contract.localSymbol} ending at {end_dt}: {e}")
|
||||
return None
|
||||
|
||||
def getMESContract(ib_conn, cm, contract_expiration):
|
||||
"""
|
||||
Try several MES contract definitions for the given contract month (cm) and expiration date.
|
||||
Return a tuple (contract, variant_used) or (None, None) if none match.
|
||||
"""
|
||||
expiration_str = contract_expiration.strftime("%Y%m%d")
|
||||
variants = []
|
||||
|
||||
# Variant 1: Full expiration date, exchange GLOBEX.
|
||||
contract1 = Future(
|
||||
symbol='MES',
|
||||
lastTradeDateOrContractMonth=expiration_str,
|
||||
exchange='GLOBEX',
|
||||
currency='USD',
|
||||
multiplier=5
|
||||
)
|
||||
contract1.includeExpired = True
|
||||
variants.append(("Variant 1: full expiration, GLOBEX", contract1))
|
||||
|
||||
# Variant 2: Full expiration date, exchange CME.
|
||||
contract2 = Future(
|
||||
symbol='MES',
|
||||
lastTradeDateOrContractMonth=expiration_str,
|
||||
exchange='CME',
|
||||
currency='USD',
|
||||
multiplier=5
|
||||
)
|
||||
contract2.includeExpired = True
|
||||
variants.append(("Variant 2: full expiration, CME", contract2))
|
||||
|
||||
# Variant 3: Contract month, exchange GLOBEX, add tradingClass.
|
||||
contract3 = Future(
|
||||
symbol='MES',
|
||||
lastTradeDateOrContractMonth=cm,
|
||||
exchange='GLOBEX',
|
||||
currency='USD',
|
||||
multiplier=5,
|
||||
tradingClass='MES'
|
||||
)
|
||||
contract3.includeExpired = True
|
||||
variants.append(("Variant 3: contract month, GLOBEX, tradingClass", contract3))
|
||||
|
||||
# Variant 4: Contract month, exchange CME, add tradingClass.
|
||||
contract4 = Future(
|
||||
symbol='MES',
|
||||
lastTradeDateOrContractMonth=cm,
|
||||
exchange='CME',
|
||||
currency='USD',
|
||||
multiplier=5,
|
||||
tradingClass='MES'
|
||||
)
|
||||
contract4.includeExpired = True
|
||||
variants.append(("Variant 4: contract month, CME, tradingClass", contract4))
|
||||
|
||||
# Variant 5: Contract month, exchange GLOBEX, with computed localSymbol.
|
||||
month_codes = {1:'F', 2:'G', 3:'H', 4:'J', 5:'K', 6:'M', 7:'N', 8:'Q', 9:'U', 10:'V', 11:'X', 12:'Z'}
|
||||
year = int(cm[:4])
|
||||
month = int(cm[4:])
|
||||
local_symbol = f"MES{month_codes.get(month, '')}{str(year)[-1]}"
|
||||
contract5 = Future(
|
||||
symbol='MES',
|
||||
lastTradeDateOrContractMonth=cm,
|
||||
localSymbol=local_symbol,
|
||||
exchange='GLOBEX',
|
||||
currency='USD',
|
||||
multiplier=5
|
||||
)
|
||||
contract5.includeExpired = True
|
||||
variants.append(("Variant 5: contract month, GLOBEX, localSymbol", contract5))
|
||||
|
||||
for variant_desc, contract in variants:
|
||||
logger.info(f"Trying {variant_desc} for {cm} (expiration: {expiration_str})...")
|
||||
details = ib_conn.reqContractDetails(contract)
|
||||
if details:
|
||||
logger.info(f"Success with {variant_desc}: {details[0].contract}")
|
||||
return details[0].contract, variant_desc
|
||||
else:
|
||||
logger.info(f"{variant_desc} did not return any details.")
|
||||
return None, None
|
||||
|
||||
def process_month(cm, start_date, end_date, request_delay, client_id):
|
||||
"""
|
||||
Process a single contract month:
|
||||
- Create a new event loop and IB connection (with unique client_id).
|
||||
- Retrieve a valid MES contract.
|
||||
- Request historical data in chunks.
|
||||
- Return a summary dict and list of retrieved bars.
|
||||
"""
|
||||
# Create and set a new event loop in this thread
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
summary = {
|
||||
"Month": cm,
|
||||
"Status": "",
|
||||
"Variant Used": "",
|
||||
"LocalSymbol": "",
|
||||
"Bars Retrieved": 0,
|
||||
"Reason": ""
|
||||
}
|
||||
bars_list = []
|
||||
try:
|
||||
ib_conn = IB()
|
||||
ib_conn.connect('127.0.0.1', 4002, clientId=client_id)
|
||||
except Exception as e:
|
||||
summary["Status"] = "Skipped"
|
||||
summary["Reason"] = f"Connection error: {e}"
|
||||
logger.error(f"Client {client_id}: Connection error for month {cm}: {e}")
|
||||
return summary, bars_list
|
||||
|
||||
year = int(cm[:4])
|
||||
month = int(cm[4:])
|
||||
contract_expiration = get_third_friday(year, month)
|
||||
contract, variant_used = getMESContract(ib_conn, cm, contract_expiration)
|
||||
if not contract:
|
||||
summary["Status"] = "Skipped"
|
||||
summary["Reason"] = "No valid MES contract found"
|
||||
logger.warning(f"Client {client_id}: No valid MES contract found for {cm}. Skipping.")
|
||||
ib_conn.disconnect()
|
||||
return summary, bars_list
|
||||
|
||||
summary["Status"] = "Processed"
|
||||
summary["Variant Used"] = variant_used
|
||||
summary["LocalSymbol"] = contract.localSymbol
|
||||
logger.info(f"Client {client_id}: Processing contract {contract.localSymbol} for month {cm} (expiration approx: {contract_expiration.strftime('%Y-%m-%d %H:%M:%S %Z')})")
|
||||
|
||||
contract_end = min(end_date, contract_expiration)
|
||||
current_end = contract_end
|
||||
chunk_duration = datetime.timedelta(weeks=1)
|
||||
|
||||
while current_end > start_date:
|
||||
logger.info(f"Client {client_id}: Requesting {contract.localSymbol} data ending at {current_end.strftime('%Y-%m-%d %H:%M:%S %Z')}")
|
||||
bars = get_data_chunk(ib_conn, contract, current_end, duration_str="1 W")
|
||||
time.sleep(request_delay) # Rate limiting
|
||||
if bars is None:
|
||||
logger.error(f"Client {client_id}: Error retrieving chunk; moving to next week.")
|
||||
current_end -= chunk_duration
|
||||
continue
|
||||
if not bars:
|
||||
logger.info(f"Client {client_id}: No data returned for period ending at {current_end.strftime('%Y-%m-%d %H:%M:%S %Z')}; ending requests.")
|
||||
break
|
||||
for bar in bars:
|
||||
bar_time = bar.date.astimezone(TZ)
|
||||
if start_date <= bar_time <= end_date:
|
||||
bars_list.append({
|
||||
'date': bar_time.strftime("%Y-%m-%d %H:%M:%S %Z"),
|
||||
'open': bar.open,
|
||||
'high': bar.high,
|
||||
'low': bar.low,
|
||||
'close': bar.close,
|
||||
'volume': bar.volume,
|
||||
'contract': contract.localSymbol
|
||||
})
|
||||
earliest_time = min(bar.date.astimezone(TZ) for bar in bars)
|
||||
new_end = earliest_time - datetime.timedelta(seconds=1)
|
||||
if new_end >= current_end:
|
||||
break
|
||||
current_end = new_end
|
||||
|
||||
summary["Bars Retrieved"] = len(bars_list)
|
||||
if len(bars_list) == 0:
|
||||
summary["Reason"] = "No data returned for this contract period"
|
||||
ib_conn.disconnect()
|
||||
return summary, bars_list
|
||||
|
||||
# --- Main Execution ---
|
||||
if __name__ == "__main__":
|
||||
overall_start_time = time.time()
|
||||
logger.info("Starting retrieval process...")
|
||||
|
||||
end_date = datetime.datetime.now(TZ)
|
||||
start_date = end_date - datetime.timedelta(days=3*365)
|
||||
logger.info(f"Retrieving MES data from {start_date.strftime('%Y-%m-%d %H:%M:%S %Z')} to {end_date.strftime('%Y-%m-%d %H:%M:%S %Z')}")
|
||||
|
||||
contract_months = generate_contract_months(start_date, end_date)
|
||||
logger.info(f"Contract months to process: {contract_months}")
|
||||
|
||||
all_month_summaries = []
|
||||
all_bars = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=CONFIG["MAX_WORKERS"]) as executor:
|
||||
futures = []
|
||||
for i, cm in enumerate(contract_months):
|
||||
client_id = CONFIG["BASE_CLIENT_ID"] + i
|
||||
futures.append(executor.submit(process_month, cm, start_date, end_date, CONFIG["REQUEST_DELAY"], client_id))
|
||||
|
||||
for future in as_completed(futures):
|
||||
summary, bars = future.result()
|
||||
all_month_summaries.append(summary)
|
||||
all_bars.extend(bars)
|
||||
|
||||
final_data = sorted(all_bars, key=lambda x: x['date'])
|
||||
expected_bars = int((end_date - start_date).total_seconds() / (5 * 60))
|
||||
|
||||
output_filename = "mes_5min_data.json"
|
||||
if final_data:
|
||||
try:
|
||||
with open(output_filename, "w") as f:
|
||||
json.dump(final_data, f, indent=4)
|
||||
logger.info(f"Data successfully saved to {output_filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing to JSON file: {e}")
|
||||
else:
|
||||
logger.info("No data retrieved. File not saved.")
|
||||
|
||||
# Build final summary table.
|
||||
summary_rows = []
|
||||
for s in all_month_summaries:
|
||||
summary_rows.append([s["Month"], s["Status"], s["Variant Used"], s["LocalSymbol"], s["Bars Retrieved"], s["Reason"]])
|
||||
|
||||
headers = ["Contract Month", "Status", "Variant Used", "LocalSymbol", "Bars Retrieved", "Reason"]
|
||||
if HAS_TABULATE:
|
||||
table = tabulate(summary_rows, headers=headers, tablefmt="grid")
|
||||
else:
|
||||
header_line = " | ".join(headers)
|
||||
separator = "-" * len(header_line)
|
||||
table = header_line + "\n" + separator + "\n"
|
||||
for row in summary_rows:
|
||||
table += " | ".join(str(item) for item in row) + "\n"
|
||||
|
||||
logger.info("\n" + table)
|
||||
processed = len([s for s in all_month_summaries if s["Status"] == "Processed"])
|
||||
skipped = len([s for s in all_month_summaries if s["Status"] == "Skipped"])
|
||||
logger.info(f"Total contract months processed: {len(contract_months)}")
|
||||
logger.info(f" Processed: {processed} Skipped: {skipped}")
|
||||
logger.info(f"Total bars retrieved: {len(final_data)} (expected ~{expected_bars})")
|
||||
|
||||
overall_end_time = time.time()
|
||||
runtime_sec = overall_end_time - overall_start_time
|
||||
runtime_str = str(datetime.timedelta(seconds=int(runtime_sec)))
|
||||
logger.info(f"Total runtime: {runtime_str}")
|
||||
|
||||
237
IB_Gateway/data.py.bak
Normal file
237
IB_Gateway/data.py.bak
Normal file
@@ -0,0 +1,237 @@
|
||||
import json
|
||||
import datetime
|
||||
from ib_insync import IB, Future, util
|
||||
|
||||
def get_third_friday(year, month):
|
||||
"""
|
||||
Returns the third Friday of the given year and month as a datetime,
|
||||
which is a common expiration for futures.
|
||||
"""
|
||||
fridays = []
|
||||
for day in range(1, 32):
|
||||
try:
|
||||
d = datetime.date(year, month, day)
|
||||
except ValueError:
|
||||
break
|
||||
if d.weekday() == 4: # Friday
|
||||
fridays.append(d)
|
||||
if len(fridays) >= 3:
|
||||
return datetime.datetime.combine(fridays[2], datetime.time(16, 0)) # 4:00 PM
|
||||
elif fridays:
|
||||
return datetime.datetime.combine(fridays[-1], datetime.time(16, 0))
|
||||
else:
|
||||
return datetime.datetime(year, month, 1, 16, 0)
|
||||
|
||||
def generate_contract_months(start_date, end_date):
|
||||
"""
|
||||
Generate a sorted list of contract month strings (format "YYYYMM")
|
||||
that might be active between start_date and end_date.
|
||||
MES futures are listed for quarters (Mar, Jun, Sep, Dec).
|
||||
"""
|
||||
months = [3, 6, 9, 12]
|
||||
result = []
|
||||
for year in range(start_date.year, end_date.year + 2):
|
||||
for m in months:
|
||||
dt = datetime.datetime(year, m, 1)
|
||||
if dt <= end_date:
|
||||
result.append(f"{year}{m:02d}")
|
||||
return sorted(set(result))
|
||||
|
||||
def get_data_chunk(contract, end_dt, duration_str="1 W"):
|
||||
"""
|
||||
Request a chunk of historical data for the given contract ending at end_dt.
|
||||
Returns a list of bars (or None on error).
|
||||
"""
|
||||
try:
|
||||
bars = ib.reqHistoricalData(
|
||||
contract,
|
||||
endDateTime=end_dt.strftime("%Y%m%d %H:%M:%S"),
|
||||
durationStr=duration_str,
|
||||
barSizeSetting="5 mins",
|
||||
whatToShow="TRADES",
|
||||
useRTH=True, # Regular trading hours only
|
||||
formatDate=1
|
||||
)
|
||||
return bars
|
||||
except Exception as e:
|
||||
print(f"Error retrieving data chunk for {contract.localSymbol} ending at {end_dt}: {e}")
|
||||
return None
|
||||
|
||||
def getMESContract(cm, contract_expiration):
|
||||
"""
|
||||
Attempt multiple MES contract definitions for the given contract month (cm) and expiration date.
|
||||
Returns the first valid contract (as returned by reqContractDetails) or None.
|
||||
"""
|
||||
expiration_str = contract_expiration.strftime("%Y%m%d")
|
||||
variants = []
|
||||
|
||||
# Variant 1: Use full expiration date (YYYYMMDD), exchange GLOBEX, minimal fields.
|
||||
contract1 = Future(
|
||||
symbol='MES',
|
||||
lastTradeDateOrContractMonth=expiration_str,
|
||||
exchange='GLOBEX',
|
||||
currency='USD',
|
||||
multiplier=5
|
||||
)
|
||||
contract1.includeExpired = True
|
||||
variants.append(("Variant 1: full expiration date, GLOBEX", contract1))
|
||||
|
||||
# Variant 2: Full expiration date, exchange CME.
|
||||
contract2 = Future(
|
||||
symbol='MES',
|
||||
lastTradeDateOrContractMonth=expiration_str,
|
||||
exchange='CME',
|
||||
currency='USD',
|
||||
multiplier=5
|
||||
)
|
||||
contract2.includeExpired = True
|
||||
variants.append(("Variant 2: full expiration date, CME", contract2))
|
||||
|
||||
# Variant 3: Use contract month (YYYYMM), exchange GLOBEX, add tradingClass.
|
||||
contract3 = Future(
|
||||
symbol='MES',
|
||||
lastTradeDateOrContractMonth=cm,
|
||||
exchange='GLOBEX',
|
||||
currency='USD',
|
||||
multiplier=5,
|
||||
tradingClass='MES'
|
||||
)
|
||||
contract3.includeExpired = True
|
||||
variants.append(("Variant 3: contract month, GLOBEX, tradingClass MES", contract3))
|
||||
|
||||
# Variant 4: Use contract month, exchange CME, add tradingClass.
|
||||
contract4 = Future(
|
||||
symbol='MES',
|
||||
lastTradeDateOrContractMonth=cm,
|
||||
exchange='CME',
|
||||
currency='USD',
|
||||
multiplier=5,
|
||||
tradingClass='MES'
|
||||
)
|
||||
contract4.includeExpired = True
|
||||
variants.append(("Variant 4: contract month, CME, tradingClass MES", contract4))
|
||||
|
||||
# Variant 5: Use contract month, exchange GLOBEX, with a computed localSymbol.
|
||||
month_codes = {1:'F', 2:'G', 3:'H', 4:'J', 5:'K', 6:'M', 7:'N', 8:'Q', 9:'U', 10:'V', 11:'X', 12:'Z'}
|
||||
year = int(cm[:4])
|
||||
month = int(cm[4:])
|
||||
local_symbol = f"MES{month_codes.get(month, '')}{str(year)[-1]}"
|
||||
contract5 = Future(
|
||||
symbol='MES',
|
||||
lastTradeDateOrContractMonth=cm,
|
||||
localSymbol=local_symbol,
|
||||
exchange='GLOBEX',
|
||||
currency='USD',
|
||||
multiplier=5
|
||||
)
|
||||
contract5.includeExpired = True
|
||||
variants.append(("Variant 5: contract month, GLOBEX, localSymbol", contract5))
|
||||
|
||||
# Try each variant.
|
||||
for variant_desc, contract in variants:
|
||||
print(f"Trying {variant_desc} for contract month {cm} (expiration: {expiration_str})...")
|
||||
details = ib.reqContractDetails(contract)
|
||||
if details:
|
||||
print(f"Success with {variant_desc}: found contract details: {details[0].contract}")
|
||||
return details[0].contract
|
||||
else:
|
||||
print(f"{variant_desc} did not return contract details.")
|
||||
return None
|
||||
|
||||
# --- Main Script ---
|
||||
|
||||
# Connect to IB Gateway (ensure your account is active and market data is subscribed)
|
||||
ib = IB()
|
||||
try:
|
||||
ib.connect('127.0.0.1', 4002, clientId=1)
|
||||
except Exception as e:
|
||||
print(f"Connection error: {e}")
|
||||
exit(1)
|
||||
|
||||
# Define overall desired time range: last 3 years up until today.
|
||||
# We'll use naive datetime objects (local time)
|
||||
end_date = datetime.datetime.now()
|
||||
start_date = end_date - datetime.timedelta(days=3*365)
|
||||
|
||||
# Generate contract month strings (e.g., "202303", "202306", etc.)
|
||||
contract_months = generate_contract_months(start_date, end_date)
|
||||
print("Contract months to process:", contract_months)
|
||||
|
||||
all_data = []
|
||||
|
||||
# Process each contract month
|
||||
for cm in contract_months:
|
||||
year = int(cm[:4])
|
||||
month = int(cm[4:])
|
||||
# Compute the contract expiration date (using third Friday)
|
||||
contract_expiration = get_third_friday(year, month)
|
||||
|
||||
# Try to obtain a valid MES contract using our diagnostic function.
|
||||
mes_contract = getMESContract(cm, contract_expiration)
|
||||
if not mes_contract:
|
||||
print(f"*** No valid MES contract found for {cm}. Skipping this month. ***")
|
||||
continue
|
||||
|
||||
print(f"Processing contract {mes_contract.localSymbol} (approx expiration: {contract_expiration})")
|
||||
|
||||
# Determine the effective data period for this contract.
|
||||
contract_end = min(end_date, contract_expiration)
|
||||
current_end = contract_end
|
||||
|
||||
contract_data = []
|
||||
chunk_duration = datetime.timedelta(weeks=1)
|
||||
|
||||
# Request data in weekly chunks until we reach start_date.
|
||||
while current_end > start_date:
|
||||
print(f" Requesting {mes_contract.localSymbol} data ending at {current_end}")
|
||||
bars = get_data_chunk(mes_contract, current_end, duration_str="1 W")
|
||||
if bars is None:
|
||||
print(" Error retrieving chunk; moving to next week.")
|
||||
current_end -= chunk_duration
|
||||
continue
|
||||
if len(bars) == 0:
|
||||
print(" No data returned for this period; ending requests for this contract.")
|
||||
break
|
||||
for bar in bars:
|
||||
# Remove timezone info from bar.date so we can compare with naive datetimes.
|
||||
bar_time = bar.date.replace(tzinfo=None)
|
||||
if start_date <= bar_time <= end_date:
|
||||
contract_data.append({
|
||||
'date': bar_time.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
'open': bar.open,
|
||||
'high': bar.high,
|
||||
'low': bar.low,
|
||||
'close': bar.close,
|
||||
'volume': bar.volume,
|
||||
'contract': mes_contract.localSymbol
|
||||
})
|
||||
earliest_time = min(bar.date.replace(tzinfo=None) for bar in bars)
|
||||
new_end = earliest_time - datetime.timedelta(seconds=1)
|
||||
if new_end >= current_end:
|
||||
break
|
||||
current_end = new_end
|
||||
|
||||
print(f" Retrieved {len(contract_data)} bars for contract {mes_contract.localSymbol}")
|
||||
all_data.extend(contract_data)
|
||||
|
||||
# Remove duplicate bars (if any) based on timestamp.
|
||||
unique_data = {d['date']: d for d in all_data}
|
||||
final_data = sorted(unique_data.values(), key=lambda x: x['date'])
|
||||
|
||||
expected_bars = ((end_date - start_date).total_seconds() / (5 * 60))
|
||||
if len(final_data) < expected_bars * 0.9:
|
||||
print("Warning: Retrieved data appears significantly less than expected.")
|
||||
|
||||
output_filename = "mes_5min_data.json"
|
||||
if final_data:
|
||||
try:
|
||||
with open(output_filename, "w") as f:
|
||||
json.dump(final_data, f, indent=4)
|
||||
print(f"Data successfully saved to {output_filename}")
|
||||
except Exception as e:
|
||||
print(f"Error writing to JSON file: {e}")
|
||||
else:
|
||||
print("No data retrieved. File not saved.")
|
||||
|
||||
ib.disconnect()
|
||||
|
||||
1436915
IB_Gateway/mes_5min_data.json
Normal file
1436915
IB_Gateway/mes_5min_data.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,11 @@
|
||||
ibapi
|
||||
numpy>=1.19.2
|
||||
pandas>=1.1.5
|
||||
|
||||
eventkit==1.0.3
|
||||
ib-insync==0.9.86
|
||||
ibapi @ file:///home/midas/codeWS/dependencies/IBJts/source/pythonclient
|
||||
nest-asyncio==1.6.0
|
||||
numpy==2.2.4
|
||||
pandas==2.2.3
|
||||
python-dateutil==2.9.0.post0
|
||||
pytz==2025.2
|
||||
six==1.17.0
|
||||
tabulate==0.9.0
|
||||
tzdata==2025.2
|
||||
|
||||
@@ -1,253 +0,0 @@
|
||||
import threading
|
||||
import time
|
||||
import logging
|
||||
from ibapi.client import EClient
|
||||
from ibapi.wrapper import EWrapper
|
||||
from ibapi.contract import Contract
|
||||
from ibapi.order import Order
|
||||
import pickle
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging Configuration: Set up logging for debugging and operational insights.
|
||||
# -----------------------------------------------------------------------------
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s %(levelname)s: %(message)s')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helper Functions
|
||||
# -----------------------------------------------------------------------------
|
||||
def USStockContract(symbol):
|
||||
"""
|
||||
Create and return an IBKR stock contract.
|
||||
|
||||
:param symbol: Underlying stock symbol (e.g., "AAPL")
|
||||
:return: Configured IBKR Contract object for a stock
|
||||
"""
|
||||
contract = Contract()
|
||||
contract.symbol = symbol
|
||||
contract.secType = "STK"
|
||||
contract.currency = "USD"
|
||||
contract.exchange = "SMART"
|
||||
return contract
|
||||
|
||||
def makeLimitOrder(action, quantity, limitPrice):
|
||||
"""
|
||||
Create and return a limit order.
|
||||
|
||||
:param action: "BUY" or "SELL"
|
||||
:param quantity: Number of shares to trade
|
||||
:param limitPrice: The limit price for the order
|
||||
:return: Configured Order object
|
||||
"""
|
||||
order = Order()
|
||||
order.action = action
|
||||
order.orderType = "LMT"
|
||||
order.totalQuantity = quantity
|
||||
order.lmtPrice = limitPrice
|
||||
return order
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# PPO Model Integration (Placeholder)
|
||||
# -----------------------------------------------------------------------------
|
||||
class PPOModel:
|
||||
def __init__(self, model_path):
|
||||
"""
|
||||
Load a pre-trained PPO model from disk.
|
||||
|
||||
:param model_path: Path to the saved PPO model file.
|
||||
"""
|
||||
try:
|
||||
with open(model_path, 'rb') as f:
|
||||
self.model = pickle.load(f)
|
||||
logging.info("PPO model loaded successfully.")
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading PPO model: {e}")
|
||||
self.model = None
|
||||
|
||||
def get_action(self, observation):
|
||||
"""
|
||||
Get an action from the PPO model given the current observation.
|
||||
|
||||
:param observation: A numeric observation (e.g., current stock price)
|
||||
:return: Action decision as a string (e.g., "BUY", "SELL", "HOLD")
|
||||
"""
|
||||
try:
|
||||
# Replace the dummy logic below with your actual model inference.
|
||||
price = observation
|
||||
# For demonstration: Buy if price is below a threshold, otherwise hold.
|
||||
if price < 150:
|
||||
return "BUY"
|
||||
elif price > 170:
|
||||
return "SELL"
|
||||
else:
|
||||
return "HOLD"
|
||||
except Exception as e:
|
||||
logging.error(f"Error in model inference: {e}")
|
||||
return "HOLD"
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Trading Bot Class
|
||||
# -----------------------------------------------------------------------------
|
||||
class AIDrivenIBBot(EWrapper, EClient):
|
||||
def __init__(self, ppo_model, stock_symbol="AAPL"):
|
||||
"""
|
||||
Initialize the trading bot with the PPO model and set up required variables.
|
||||
|
||||
:param ppo_model: An instance of PPOModel containing the trained model.
|
||||
:param stock_symbol: The stock symbol to trade (default is "AAPL")
|
||||
"""
|
||||
EClient.__init__(self, self)
|
||||
self.ppo_model = ppo_model
|
||||
self.stock_symbol = stock_symbol
|
||||
self.nextValidOrderId = None
|
||||
self.latest_price = None
|
||||
# Flag to avoid repeated trades within the same signal window.
|
||||
self.trade_executed = False
|
||||
|
||||
# ------------------------
|
||||
# IB API Callbacks
|
||||
# ------------------------
|
||||
def error(self, reqId, errorCode, errorString):
|
||||
"""
|
||||
Handle error messages received from IB Gateway.
|
||||
"""
|
||||
logging.error(f"Error. ReqId: {reqId}, Code: {errorCode}, Msg: {errorString}")
|
||||
|
||||
def nextValidId(self, orderId):
|
||||
"""
|
||||
Callback for receiving the next valid order ID. Starts the market data request.
|
||||
"""
|
||||
logging.info(f"Next valid order ID: {orderId}")
|
||||
self.nextValidOrderId = orderId
|
||||
self.start_data_stream()
|
||||
|
||||
def tickPrice(self, reqId, tickType, price, attrib):
|
||||
"""
|
||||
Callback for receiving live price updates.
|
||||
"""
|
||||
logging.info(f"Tick Price. ReqId: {reqId}, TickType: {tickType}, Price: {price}")
|
||||
# For simplicity, we assume tickType corresponds to the current trade price.
|
||||
self.latest_price = price
|
||||
self.evaluate_ai_decision()
|
||||
|
||||
def tickSize(self, reqId, tickType, size):
|
||||
"""
|
||||
Optional: Handle tick size data if needed.
|
||||
"""
|
||||
logging.info(f"Tick Size. ReqId: {reqId}, TickType: {tickType}, Size: {size}")
|
||||
|
||||
# ------------------------
|
||||
# Custom Methods
|
||||
# ------------------------
|
||||
def start_data_stream(self):
|
||||
"""
|
||||
Request live market data for the specified stock.
|
||||
"""
|
||||
try:
|
||||
# Set market data type: 1 = live, 2 = frozen, 3 = delayed.
|
||||
self.reqMarketDataType(1)
|
||||
contract = USStockContract(self.stock_symbol)
|
||||
# The reqId here is arbitrary; ensure it's unique.
|
||||
self.reqMktData(1001, contract, "", False, False, [])
|
||||
logging.info(f"Started market data stream for {self.stock_symbol}.")
|
||||
except Exception as e:
|
||||
logging.error(f"Error starting data stream: {e}")
|
||||
|
||||
def evaluate_ai_decision(self):
|
||||
"""
|
||||
Evaluate the current market data with the PPO model and execute a trade if needed.
|
||||
"""
|
||||
if self.latest_price is not None:
|
||||
try:
|
||||
# Get action from PPO model using the latest price as observation.
|
||||
action = self.ppo_model.get_action(self.latest_price)
|
||||
logging.info(f"Model action: {action} for price: {self.latest_price}")
|
||||
|
||||
# Risk management: Only execute a trade if no trade has been executed recently.
|
||||
if not self.trade_executed:
|
||||
if action == "BUY":
|
||||
self.execute_trade("BUY")
|
||||
elif action == "SELL":
|
||||
self.execute_trade("SELL")
|
||||
else:
|
||||
logging.info("Action HOLD: No trade executed.")
|
||||
else:
|
||||
logging.info("Trade already executed for this signal window. Waiting for new data.")
|
||||
except Exception as e:
|
||||
logging.error(f"Error evaluating AI decision: {e}")
|
||||
|
||||
def execute_trade(self, action):
|
||||
"""
|
||||
Execute a trade after performing thorough checks.
|
||||
|
||||
:param action: "BUY" or "SELL"
|
||||
"""
|
||||
if self.nextValidOrderId is None:
|
||||
logging.error("Cannot execute trade: Order ID not available.")
|
||||
return
|
||||
|
||||
# --- Risk Management & Validation ---
|
||||
# Here you would include additional checks such as:
|
||||
# - Account balance verification
|
||||
# - Margin checks
|
||||
# - Trade frequency limitations
|
||||
# - Any other custom risk metrics
|
||||
|
||||
try:
|
||||
contract = USStockContract(self.stock_symbol)
|
||||
# For shares, the quantity could be defined as per your risk management rules.
|
||||
quantity = 100 # Example: trade 100 shares
|
||||
# Set a limit price. In a real scenario, this might be derived from the current price and desired slippage.
|
||||
limitPrice = self.latest_price
|
||||
order = makeLimitOrder(action, quantity, limitPrice)
|
||||
# Place the order using the next valid order ID.
|
||||
self.placeOrder(self.nextValidOrderId, contract, order)
|
||||
logging.info(f"Placed {action} order for {quantity} shares of {self.stock_symbol} at {limitPrice}.")
|
||||
# Update the order ID and mark that a trade has been executed to avoid duplicate trades.
|
||||
self.nextValidOrderId += 1
|
||||
self.trade_executed = True
|
||||
except Exception as e:
|
||||
logging.error(f"Error executing trade: {e}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Helper Function to Run the Bot in a Separate Thread
|
||||
# -----------------------------------------------------------------------------
|
||||
def run_loop(app):
|
||||
"""
|
||||
Run the IB API message loop.
|
||||
"""
|
||||
try:
|
||||
app.run()
|
||||
except Exception as e:
|
||||
logging.error(f"Error in run loop: {e}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Main Execution Block
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
# Path to your pre-trained PPO model (update with your actual model file)
|
||||
model_path = "ppo_model.pkl"
|
||||
model = PPOModel(model_path)
|
||||
|
||||
# Initialize the trading bot with the PPO model and desired stock symbol.
|
||||
bot = AIDrivenIBBot(model, stock_symbol="AAPL")
|
||||
|
||||
# Connect to IB Gateway. Update IP, port, and clientId as required.
|
||||
bot.connect("127.0.0.1", 4002, clientId=202)
|
||||
|
||||
# Start the IB API message loop in a separate thread.
|
||||
api_thread = threading.Thread(target=run_loop, args=(bot,), daemon=True)
|
||||
api_thread.start()
|
||||
|
||||
# Keep the program running for a desired duration (e.g., 120 seconds for testing).
|
||||
try:
|
||||
time.sleep(120)
|
||||
except KeyboardInterrupt:
|
||||
logging.info("Interrupted by user.")
|
||||
finally:
|
||||
bot.disconnect()
|
||||
logging.info("Disconnected from IB Gateway.")
|
||||
|
||||
Reference in New Issue
Block a user