updating for download

2025-04-12 15:37:37 +00:00
parent 4ea426689a
commit 9f63656268
3148 changed files with 4592393 additions and 48 deletions


@@ -1,48 +0,0 @@
2025-03-26 03:01:52,506 - INFO - ===== Resource Statistics =====
2025-03-26 03:01:52,506 - INFO - Physical CPU Cores: 28
2025-03-26 03:01:52,506 - INFO - Logical CPU Cores: 56
2025-03-26 03:01:52,507 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
2025-03-26 03:01:52,507 - INFO - No GPUs detected.
2025-03-26 03:01:52,507 - INFO - =================================
2025-03-26 03:01:52,507 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
2025-03-26 03:01:52,508 - INFO - Loading data from: data/MES2023Z.csv
2025-03-26 03:01:52,513 - ERROR - Unexpected error: Missing column provided to 'parse_dates': 'time'
2025-03-26 03:04:50,616 - INFO - ===== Resource Statistics =====
2025-03-26 03:04:50,616 - INFO - Physical CPU Cores: 28
2025-03-26 03:04:50,616 - INFO - Logical CPU Cores: 56
2025-03-26 03:04:50,616 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]%
2025-03-26 03:04:50,617 - INFO - No GPUs detected.
2025-03-26 03:04:50,617 - INFO - =================================
2025-03-26 03:04:50,617 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
2025-03-26 03:04:50,618 - INFO - Loading data from: data/MES2023Z.csv
2025-03-26 03:04:50,621 - ERROR - Unexpected error: Missing column provided to 'parse_dates': 'time'
2025-03-26 03:08:02,316 - INFO - ===== Resource Statistics =====
2025-03-26 03:08:02,316 - INFO - Physical CPU Cores: 28
2025-03-26 03:08:02,316 - INFO - Logical CPU Cores: 56
2025-03-26 03:08:02,317 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
2025-03-26 03:08:02,317 - INFO - No GPUs detected.
2025-03-26 03:08:02,317 - INFO - =================================
2025-03-26 03:08:02,317 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
2025-03-26 03:08:02,318 - INFO - Loading data from: data/MES2023Z.csv
2025-03-26 03:08:02,355 - INFO - Data columns after renaming: ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
2025-03-26 03:08:02,383 - INFO - Data loaded and sorted successfully.
2025-03-26 03:08:02,383 - INFO - Calculating technical indicators...
2025-03-26 03:08:02,448 - INFO - Technical indicators calculated successfully.
2025-03-26 03:08:02,464 - INFO - Starting parallel feature engineering with 54 workers...
2025-03-26 03:08:03,331 - INFO - Parallel feature engineering completed.
2025-03-26 03:08:03,341 - INFO - Training sequences shape: (676, 15, 17)
2025-03-26 03:08:03,342 - INFO - Validation sequences shape: (144, 15, 17)
2025-03-26 03:08:03,342 - INFO - Testing sequences shape: (146, 15, 17)
2025-03-26 03:08:03,342 - INFO - Starting LSTM hyperparameter optimization with Optuna using 54 parallel trials...
2025-03-26 03:22:04,033 - INFO - Best LSTM Hyperparameters: {'num_lstm_layers': 2, 'lstm_units': 64, 'dropout_rate': 0.13619292923712067, 'learning_rate': 0.0030545284525912166, 'optimizer': 'Nadam', 'decay': 9.615099767236892e-05}
2025-03-26 03:22:04,553 - INFO - Training best LSTM model with optimized hyperparameters...
2025-03-26 03:24:28,296 - INFO - Evaluating final LSTM model...
2025-03-26 03:24:29,722 - INFO - Test MSE: 0.3437
2025-03-26 03:24:29,722 - INFO - Test RMSE: 0.5862
2025-03-26 03:24:29,722 - INFO - Test MAE: 0.4561
2025-03-26 03:24:29,722 - INFO - Test R2 Score: 0.8620
2025-03-26 03:24:29,722 - INFO - Directional Accuracy: 0.2759
2025-03-26 03:24:30,013 - WARNING - You are saving your model as an HDF5 file via `model.save()` or `keras.saving.save_model(model)`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')` or `keras.saving.save_model(model, 'my_model.keras')`.
2025-03-26 03:24:30,121 - INFO - Saved best LSTM model and scaler objects.
2025-03-26 03:24:30,150 - INFO - Starting PPO training...
2025-03-26 05:47:15,571 - INFO - PPO training completed and model saved.
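The two failed runs above came from pd.read_csv(file_path, parse_dates=['time']) when the CSV did not expose a column named exactly 'time'; pandas raises ValueError: Missing column provided to 'parse_dates'. A minimal reproduction with a tolerant fallback, assuming a hypothetical data.csv:

import pandas as pd

try:
    df = pd.read_csv('data.csv', parse_dates=['time'])
except ValueError:
    # The header may differ in case or whitespace; locate a time-like column manually.
    df = pd.read_csv('data.csv')
    time_col = next((c for c in df.columns if c.strip().lower() in ('time', 'date')), None)
    if time_col is None:
        raise SystemExit('No time/date column found.')
    df[time_col] = pd.to_datetime(df[time_col])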


@@ -33,6 +33,8 @@ from multiprocessing import Pool, cpu_count
import threading
import time
VERSION = 1
# Suppress TensorFlow logs beyond errors
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


@@ -0,0 +1,748 @@
import os
import sys
import argparse
import numpy as np
import pandas as pd
import logging
from tabulate import tabulate
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import psutil
import GPUtil
import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import LSTM, Dense, Dropout, Bidirectional
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.losses import Huber
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam, Nadam
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import joblib
import optuna
from optuna.integration import KerasPruningCallback
import gym
from gym import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from multiprocessing import Pool, cpu_count
import threading
import time
# Suppress TensorFlow logs beyond errors
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# =============================================================================
# Resource Detection Functions
# =============================================================================
def get_cpu_info():
cpu_count_physical = psutil.cpu_count(logical=False) # Physical cores
cpu_count_logical = psutil.cpu_count(logical=True) # Logical cores
cpu_percent = psutil.cpu_percent(interval=1, percpu=True)
return {
'physical_cores': cpu_count_physical,
'logical_cores': cpu_count_logical,
'cpu_percent': cpu_percent
}
def get_gpu_info():
gpus = GPUtil.getGPUs()
gpu_info = []
for gpu in gpus:
gpu_info.append({
'id': gpu.id,
'name': gpu.name,
'load': gpu.load * 100, # Convert to percentage
'memory_total': gpu.memoryTotal,
'memory_used': gpu.memoryUsed,
'memory_free': gpu.memoryFree,
'temperature': gpu.temperature
})
return gpu_info
def configure_tensorflow(cpu_stats, gpu_stats):
logical_cores = cpu_stats['logical_cores']
os.environ["OMP_NUM_THREADS"] = str(logical_cores)
os.environ["TF_NUM_INTRAOP_THREADS"] = str(logical_cores)
os.environ["TF_NUM_INTEROP_THREADS"] = str(logical_cores)
if gpu_stats:
gpus = tf.config.list_physical_devices('GPU')
if gpus:
try:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
logging.info(f"Enabled memory growth for {len(gpus)} GPU(s).")
except RuntimeError as e:
logging.error(f"TensorFlow GPU configuration error: {e}")
else:
tf.config.threading.set_intra_op_parallelism_threads(logical_cores)
tf.config.threading.set_inter_op_parallelism_threads(logical_cores)
logging.info("Configured TensorFlow to use CPU with optimized thread settings.")
def monitor_resources(interval=60):
while True:
cpu = psutil.cpu_percent(interval=1, percpu=True)
gpu = get_gpu_info()
logging.info(f"CPU Usage per Core: {cpu}%")
if gpu:
for gpu_stat in gpu:
logging.info(f"GPU {gpu_stat['id']} - {gpu_stat['name']}: Load: {gpu_stat['load']}%, "
f"Memory Used: {gpu_stat['memory_used']}MB / {gpu_stat['memory_total']}MB, "
f"Temperature: {gpu_stat['temperature']}°C")
else:
logging.info("No GPUs detected.")
logging.info("-" * 50)
time.sleep(interval)
# =============================================================================
# Data Loading & Technical Indicators
# =============================================================================
def load_data(file_path):
logging.info(f"Loading data from: {file_path}")
try:
df = pd.read_csv(file_path, parse_dates=['time'])
except FileNotFoundError:
logging.error(f"File not found: {file_path}")
sys.exit(1)
except pd.errors.ParserError as e:
logging.error(f"Error parsing CSV file: {e}")
sys.exit(1)
except Exception as e:
logging.error(f"Unexpected error: {e}")
sys.exit(1)
rename_mapping = {
'time': 'Date',
'open': 'Open',
'high': 'High',
'low': 'Low',
'close': 'Close'
}
df.rename(columns=rename_mapping, inplace=True)
logging.info(f"Data columns after renaming: {df.columns.tolist()}")
df.sort_values('Date', inplace=True)
df.reset_index(drop=True, inplace=True)
logging.info("Data loaded and sorted successfully.")
return df
def compute_rsi(series, window=14):
delta = series.diff()
gain = delta.where(delta > 0, 0).rolling(window=window).mean()
loss = -delta.where(delta < 0, 0).rolling(window=window).mean()
RS = gain / (loss + 1e-9)
return 100 - (100 / (1 + RS))
def compute_macd(series, span_short=12, span_long=26, span_signal=9):
ema_short = series.ewm(span=span_short, adjust=False).mean()
ema_long = series.ewm(span=span_long, adjust=False).mean()
macd_line = ema_short - ema_long
signal_line = macd_line.ewm(span=span_signal, adjust=False).mean()
return macd_line - signal_line # histogram
def compute_obv(df):
signed_volume = (np.sign(df['Close'].diff()) * df['Volume']).fillna(0)
return signed_volume.cumsum()
def compute_adx(df, window=14):
df['H-L'] = df['High'] - df['Low']
df['H-Cp'] = (df['High'] - df['Close'].shift(1)).abs()
df['L-Cp'] = (df['Low'] - df['Close'].shift(1)).abs()
tr = df[['H-L','H-Cp','L-Cp']].max(axis=1)
tr_rolling = tr.rolling(window=window).mean()
adx_placeholder = tr_rolling / (df['Close'] + 1e-9)
df.drop(['H-L','H-Cp','L-Cp'], axis=1, inplace=True)
return adx_placeholder
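# Note: compute_adx above is a simplified volatility proxy (rolling true range divided by price), not Wilder's ADX.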
def compute_bollinger_bands(series, window=20, num_std=2):
sma = series.rolling(window=window).mean()
std = series.rolling(window=window).std()
upper = sma + num_std * std
lower = sma - num_std * std
bandwidth = (upper - lower) / (sma + 1e-9)
return upper, lower, bandwidth
def compute_mfi(df, window=14):
typical_price = (df['High'] + df['Low'] + df['Close']) / 3
money_flow = typical_price * df['Volume']
prev_tp = typical_price.shift(1)
flow_pos = money_flow.where(typical_price > prev_tp, 0)
flow_neg = money_flow.where(typical_price < prev_tp, 0)
pos_sum = flow_pos.rolling(window=window).sum()
neg_sum = flow_neg.rolling(window=window).sum()
mfi = 100 - (100 / (1 + pos_sum / (neg_sum + 1e-9)))
return mfi
def calculate_technical_indicators(df):
logging.info("Calculating technical indicators...")
df['RSI'] = compute_rsi(df['Close'], 14)
df['MACD'] = compute_macd(df['Close'])
df['OBV'] = compute_obv(df)
df['ADX'] = compute_adx(df)
up, lo, bw = compute_bollinger_bands(df['Close'], 20, 2)
df['BB_Upper'] = up
df['BB_Lower'] = lo
df['BB_Width'] = bw
df['MFI'] = compute_mfi(df, 14)
df['SMA_5'] = df['Close'].rolling(5).mean()
df['SMA_10'] = df['Close'].rolling(10).mean()
df['EMA_5'] = df['Close'].ewm(span=5, adjust=False).mean()
df['EMA_10'] = df['Close'].ewm(span=10, adjust=False).mean()
df['STDDEV_5'] = df['Close'].rolling(5).std()
df.dropna(inplace=True)
logging.info("Technical indicators calculated successfully.")
return df
# =============================================================================
# Argument Parsing
# =============================================================================
def parse_arguments():
parser = argparse.ArgumentParser(description='Futures Trading with LSTM Forecasting and PPO.')
parser.add_argument('csv_path', type=str,
help='Path to CSV data with columns [time, open, high, low, close, volume].')
parser.add_argument('--lstm_window_size', type=int, default=15,
help='Sequence window size for LSTM forecasting. Default=15.')
parser.add_argument('--ppo_total_timesteps', type=int, default=100000,
help='Total timesteps to train the PPO model. Default=100000.')
parser.add_argument('--n_trials_lstm', type=int, default=30,
help='Number of Optuna trials for LSTM hyperparameter tuning. Default=30.')
parser.add_argument('--preprocess_workers', type=int, default=None,
help='Number of worker processes for data preprocessing. Defaults to (logical cores - 2).')
parser.add_argument('--monitor_resources', action='store_true',
help='Enable real-time resource monitoring.')
parser.add_argument('--output_dir', type=str, default='output',
help='Directory where all output files will be saved.')
parser.add_argument('--action_mode', type=str, choices=['discrete', 'continuous'], default='discrete',
help='Select action space type: discrete (e.g., -5 to +5) or continuous (Box). Default=discrete.')
parser.add_argument('--max_contracts', type=int, default=5,
help='Maximum number of contracts to trade per action. Default=5.')
return parser.parse_args()
# =============================================================================
# LSTM Price Predictor
# =============================================================================
def build_lstm(input_shape, hyperparams):
model = Sequential()
num_layers = hyperparams['num_lstm_layers']
units = hyperparams['lstm_units']
drop = hyperparams['dropout_rate']
for i in range(num_layers):
return_seqs = (i < num_layers - 1)
if i == 0:
model.add(Bidirectional(LSTM(units, return_sequences=return_seqs, kernel_regularizer=l2(1e-4)),
input_shape=input_shape))
else:
model.add(Bidirectional(LSTM(units, return_sequences=return_seqs, kernel_regularizer=l2(1e-4))))
model.add(Dropout(drop))
model.add(Dense(1, activation='linear'))
opt_name = hyperparams['optimizer']
lr = hyperparams['learning_rate']
decay = hyperparams['decay']
if opt_name == 'Adam':
opt = Adam(learning_rate=lr, decay=decay)
elif opt_name == 'Nadam':
opt = Nadam(learning_rate=lr)
else:
opt = Adam(learning_rate=lr)
model.compile(loss=Huber(), optimizer=opt, metrics=['mae'])
return model
# =============================================================================
# Custom Gym Environment for Futures Trading with LSTM Forecasting
# =============================================================================
class FuturesTradingEnv(gym.Env):
"""
A custom OpenAI Gym environment for futures trading.
It integrates an LSTM price predictor for forecasting.
The environment tracks positions as contracts_held (can be negative for shorts).
Reward is defined as the change in mark-to-market profit (unrealized PnL)
minus transaction costs and includes a bonus based on the LSTM forecast.
The action space is either discrete (integer orders from -max_contracts to +max_contracts)
or continuous (a Box space whose value is rounded to the nearest integer).
"""
metadata = {'render.modes': ['human']}
def __init__(self, df, feature_columns, lstm_model, scaler_features, scaler_target,
window_size=15, transaction_cost=0.001, action_mode='discrete', max_contracts=5):
super(FuturesTradingEnv, self).__init__()
self.df = df.reset_index(drop=True)
self.feature_columns = feature_columns
self.lstm_model = lstm_model # Frozen LSTM model for forecasting
self.scaler_features = scaler_features
self.scaler_target = scaler_target
self.window_size = window_size
self.transaction_cost = transaction_cost
self.action_mode = action_mode
self.max_contracts = max_contracts
self.max_steps = len(df)
self.current_step = 0
# Futures position variables
self.contracts_held = 0 # positive for long, negative for short
self.entry_price = None # weighted average entry price
# Pre-calculate normalized features for observations
self.raw_features = df[feature_columns].values
# Define action space
if self.action_mode == 'discrete':
# Actions: integer orders from -max_contracts to +max_contracts
self.action_space = spaces.Discrete(2 * self.max_contracts + 1)
else:
# Continuous action: a real number in [-max_contracts, max_contracts]
self.action_space = spaces.Box(low=-self.max_contracts, high=self.max_contracts, shape=(1,), dtype=np.float32)
# Observation space: technical indicators + [normalized contracts_held, normalized unrealized PnL] + LSTM forecast
obs_len = len(feature_columns) + 2 + 1
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(obs_len,), dtype=np.float32)
# Lock for LSTM prediction (if used in multi-threaded settings)
self.lstm_lock = threading.Lock()
# Cache for the LSTM forecast for reward shaping
self.last_forecast = 0.0
def reset(self):
self.current_step = 0
self.contracts_held = 0
self.entry_price = None
self.last_forecast = 0.0
return self._get_obs()
def _get_obs(self):
# Normalize the raw features row by row
row = self.raw_features[self.current_step]
row_max = np.max(np.abs(row)) if np.max(np.abs(row)) != 0 else 1.0
row_norm = row / row_max
# Additional account info:
# Normalize contracts_held by max_contracts.
# Unrealized PnL: if entry_price exists, (current_price - entry_price)*contracts_held; else 0.
current_price = self.df.loc[self.current_step, 'Close']
pnl = (current_price - self.entry_price) * self.contracts_held if self.entry_price is not None else 0.0
additional = np.array([
self.contracts_held / self.max_contracts,
pnl # Consider normalizing pnl as needed
], dtype=np.float32)
# LSTM price forecast: predict next price if possible
if self.current_step < self.window_size:
forecast = 0.0
else:
seq = self.raw_features[self.current_step - self.window_size: self.current_step]
seq_scaled = self.scaler_features.transform(seq)
seq_scaled = np.expand_dims(seq_scaled, axis=0) # shape: (1, window_size, num_features)
with self.lstm_lock:
pred_scaled = self.lstm_model.predict(seq_scaled, verbose=0).flatten()[0]
pred_scaled = np.clip(pred_scaled, 0, 1)
unscaled = self.scaler_target.inverse_transform([[pred_scaled]])[0, 0]
# Forecast as relative difference from the current price
forecast = (unscaled - current_price) / (current_price + 1e-9)
# Cache the forecast for use in the reward function
self.last_forecast = forecast
obs = np.concatenate([row_norm, additional, [forecast]]).astype(np.float32)
return obs
def step(self, action):
# Mark-to-market against the previous bar's close; at step 0 there is no prior bar.
prev_price = self.df.loc[self.current_step - 1, 'Close'] if self.current_step > 0 else self.df.loc[self.current_step, 'Close']
prev_position = self.contracts_held
# Convert action to an integer number of contracts
if self.action_mode == 'discrete':
# Discrete action space: 0 corresponds to -max_contracts, last corresponds to +max_contracts.
action_int = action - self.max_contracts
else:
# For continuous, round to nearest integer
action_int = int(np.round(action[0]))
action_int = np.clip(action_int, -self.max_contracts, self.max_contracts)
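# e.g. in discrete mode with max_contracts=5: raw action 0 -> -5 (sell five), 5 -> 0 (hold), 10 -> +5 (buy five)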
current_price = self.df.loc[self.current_step, 'Close']
fee = self.transaction_cost * abs(action_int) * current_price
# Update position logic
if action_int != 0:
# If no current position, set new position and record entry price.
if self.contracts_held == 0:
self.contracts_held = action_int
self.entry_price = current_price
# If same sign, update weighted average entry price.
elif np.sign(self.contracts_held) == np.sign(action_int):
total_contracts = self.contracts_held + action_int
self.entry_price = (self.entry_price * self.contracts_held + current_price * action_int) / total_contracts
self.contracts_held = total_contracts
# If opposite sign, reduce/flip position.
else:
if abs(action_int) >= abs(self.contracts_held):
self.contracts_held = self.contracts_held + action_int # may flip sign
self.entry_price = current_price if self.contracts_held != 0 else None
else:
self.contracts_held = self.contracts_held + action_int
# Mark-to-market PnL: change from previous price * previous position
pnl_change = (current_price - prev_price) * prev_position
reward = pnl_change - fee
# Bonus reward based on forecast and the chosen action.
bonus_factor = 10.0 # Scale factor; tune as needed.
bonus = bonus_factor * (action_int * self.last_forecast)
reward += bonus
self.current_step += 1
done = (self.current_step >= self.max_steps - 1)
obs = self._get_obs()
return obs, reward, done, {}
def render(self, mode='human'):
current_price = self.df.loc[self.current_step, 'Close']
pnl = (current_price - self.entry_price) * self.contracts_held if self.entry_price is not None else 0.0
print(f"Step: {self.current_step}, Contracts Held: {self.contracts_held}, "
f"Entry Price: {self.entry_price}, Current Price: {current_price:.2f}, PnL: {pnl:.2f}")
# =============================================================================
# Placeholders for Live Deployment Functions
# =============================================================================
def get_live_data():
"""
Placeholder: Connect to a live data feed and return the latest market data.
"""
return None
def execute_order(action):
"""
Placeholder: Execute the trading order in a live environment.
"""
logging.info(f"Executing order: {action}")
def live_trading_loop(model, env, polling_interval=5):
obs = env.reset()
done = False
while not done:
live_data = get_live_data()
if live_data is not None:
pass # Update environment with live data as needed
action, _ = model.predict(obs, deterministic=True)
execute_order(action)
obs, reward, done, _ = env.step(action)
env.render()
time.sleep(polling_interval)
# =============================================================================
# Data Preprocessing with Parallelization
# =============================================================================
def parallel_feature_engineering(row):
"""
Placeholder function for additional feature engineering.
"""
return row
def feature_engineering_parallel(df, num_workers):
logging.info(f"Starting parallel feature engineering with {num_workers} workers...")
with Pool(processes=num_workers) as pool:
processed_rows = pool.map(parallel_feature_engineering, [row for _, row in df.iterrows()])
df_processed = pd.DataFrame(processed_rows)
logging.info("Parallel feature engineering completed.")
return df_processed
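# Note: with the identity placeholder above, this pass is a no-op scaffold; real per-row feature work would replace parallel_feature_engineering.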
# =============================================================================
# MAIN FUNCTION: LSTM Training + PPO for Futures Trading
# =============================================================================
def main():
args = parse_arguments()
csv_path = args.csv_path
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
lstm_window_size = args.lstm_window_size
ppo_total_timesteps = args.ppo_total_timesteps
n_trials_lstm = args.n_trials_lstm
preprocess_workers = args.preprocess_workers
enable_resource_monitor = args.monitor_resources
action_mode = args.action_mode
max_contracts = args.max_contracts
# -----------------------------
# Setup Logging
# -----------------------------
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(os.path.join(output_dir, "FuturesPPO.log")),
logging.StreamHandler(sys.stdout)
])
# -----------------------------
# Resource Detection & Logging
# -----------------------------
cpu_stats = get_cpu_info()
gpu_stats = get_gpu_info()
logging.info("===== Resource Statistics =====")
logging.info(f"Physical CPU Cores: {cpu_stats['physical_cores']}")
logging.info(f"Logical CPU Cores: {cpu_stats['logical_cores']}")
logging.info(f"CPU Usage per Core: {cpu_stats['cpu_percent']}%")
if gpu_stats:
for gpu in gpu_stats:
logging.info(f"GPU {gpu['id']} - {gpu['name']}: Load: {gpu['load']}%, Memory Used: {gpu['memory_used']}MB/{gpu['memory_total']}MB, Temperature: {gpu['temperature']}°C")
else:
logging.info("No GPUs detected.")
logging.info("=================================")
# -----------------------------
# Configure TensorFlow
# -----------------------------
configure_tensorflow(cpu_stats, gpu_stats)
# -----------------------------
# Start Resource Monitoring (Optional)
# -----------------------------
if enable_resource_monitor:
logging.info("Starting real-time resource monitoring...")
resource_monitor_thread = threading.Thread(target=monitor_resources, args=(60,), daemon=True)
resource_monitor_thread.start()
##########################################
# A) LSTM PART: LOAD, PREPROCESS & TUNE
##########################################
df = load_data(csv_path)
df = calculate_technical_indicators(df)
feature_columns = [
'SMA_5','SMA_10','EMA_5','EMA_10','STDDEV_5',
'RSI','MACD','ADX','OBV','Volume','Open','High','Low',
'BB_Upper','BB_Lower','BB_Width','MFI'
]
target_column = 'Close'
df = df[['Date'] + feature_columns + [target_column]].dropna()
# 2) Controlled Parallel Data Preprocessing
if preprocess_workers is None:
preprocess_workers = max(1, cpu_stats['logical_cores'] - 2)
else:
preprocess_workers = min(preprocess_workers, cpu_stats['logical_cores'])
df = feature_engineering_parallel(df, num_workers=preprocess_workers)
scaler_features = MinMaxScaler()
scaler_target = MinMaxScaler()
X_all = df[feature_columns].values
y_all = df[[target_column]].values
X_scaled = scaler_features.fit_transform(X_all)
y_scaled = scaler_target.fit_transform(y_all).flatten()
# 3) Create sequences for LSTM forecasting
def create_sequences(features, target, window_size):
X_seq, y_seq = [], []
for i in range(len(features) - window_size):
X_seq.append(features[i:i+window_size])
y_seq.append(target[i+window_size])
return np.array(X_seq), np.array(y_seq)
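# e.g. with window_size=15, rows i..i+14 form one input sequence and the scaled Close at row i+15 is its target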
X, y = create_sequences(X_scaled, y_scaled, lstm_window_size)
# 4) Split into train/val/test
train_size = int(len(X) * 0.7)
val_size = int(len(X) * 0.15)
test_size = len(X) - train_size - val_size
X_train, y_train = X[:train_size], y[:train_size]
X_val, y_val = X[train_size: train_size + val_size], y[train_size: train_size + val_size]
X_test, y_test = X[train_size + val_size:], y[train_size + val_size:]
logging.info(f"Training sequences shape: {X_train.shape}")
logging.info(f"Validation sequences shape: {X_val.shape}")
logging.info(f"Testing sequences shape: {X_test.shape}")
# 5) Define LSTM objective for hyperparameter tuning using Optuna
def lstm_objective(trial):
num_lstm_layers = trial.suggest_int('num_lstm_layers', 1, 3)
lstm_units = trial.suggest_categorical('lstm_units', [32, 64, 96, 128])
dropout_rate = trial.suggest_float('dropout_rate', 0.1, 0.5)
learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
optimizer_name = trial.suggest_categorical('optimizer', ['Adam', 'Nadam'])
decay = trial.suggest_float('decay', 0.0, 1e-4)
hyperparams = {
'num_lstm_layers': num_lstm_layers,
'lstm_units': lstm_units,
'dropout_rate': dropout_rate,
'learning_rate': learning_rate,
'optimizer': optimizer_name,
'decay': decay
}
model_ = build_lstm((X_train.shape[1], X_train.shape[2]), hyperparams)
early_stop = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
lr_reduce = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6)
cb_prune = KerasPruningCallback(trial, 'val_loss')
history = model_.fit(
X_train, y_train,
epochs=100,
batch_size=16,
validation_data=(X_val, y_val),
callbacks=[early_stop, lr_reduce, cb_prune],
verbose=0
)
val_mae = min(history.history['val_mae'])
return val_mae
logging.info(f"Starting LSTM hyperparameter optimization with {cpu_stats['logical_cores']-2} parallel trials...")
study_lstm = optuna.create_study(direction='minimize')
study_lstm.optimize(lstm_objective, n_trials=n_trials_lstm, n_jobs=cpu_stats['logical_cores']-2)
best_lstm_params = study_lstm.best_params
logging.info(f"Best LSTM Hyperparameters: {best_lstm_params}")
# 6) Train final LSTM (PricePredictorLSTM) with best hyperparameters
final_lstm = build_lstm((X_train.shape[1], X_train.shape[2]), best_lstm_params)
early_stop_final = EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True)
lr_reduce_final = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6)
logging.info("Training best LSTM model with optimized hyperparameters...")
final_lstm.fit(
X_train, y_train,
epochs=300,
batch_size=16,
validation_data=(X_val, y_val),
callbacks=[early_stop_final, lr_reduce_final],
verbose=1
)
# 7) Evaluate final LSTM
def evaluate_final_lstm(model, X_test, y_test):
logging.info("Evaluating final LSTM model...")
y_pred_scaled = model.predict(X_test).flatten()
y_pred_scaled = np.clip(y_pred_scaled, 0, 1)
y_pred = scaler_target.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
y_test_actual = scaler_target.inverse_transform(y_test.reshape(-1, 1)).flatten()
mse_ = mean_squared_error(y_test_actual, y_pred)
rmse_ = np.sqrt(mse_)
mae_ = mean_absolute_error(y_test_actual, y_pred)
r2_ = r2_score(y_test_actual, y_pred)
direction_actual = np.sign(np.diff(y_test_actual))
direction_pred = np.sign(np.diff(y_pred))
directional_accuracy = np.mean(direction_actual == direction_pred)
logging.info(f"Test MSE: {mse_:.4f}")
logging.info(f"Test RMSE: {rmse_:.4f}")
logging.info(f"Test MAE: {mae_:.4f}")
logging.info(f"Test R2 Score: {r2_:.4f}")
logging.info(f"Directional Accuracy: {directional_accuracy:.4f}")
plt.figure(figsize=(14, 7))
plt.plot(y_test_actual, label='Actual Price')
plt.plot(y_pred, label='Predicted Price')
plt.title('LSTM: Actual vs Predicted Closing Prices')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(output_dir, 'lstm_actual_vs_pred.png'))
plt.close()
table = []
limit = min(40, len(y_test_actual))
for i in range(limit):
table.append([i, round(y_test_actual[i], 2), round(y_pred[i], 2)])
headers = ["Index", "Actual Price", "Predicted Price"]
print("\nFirst 40 Actual vs. Predicted Prices:")
print(tabulate(table, headers=headers, tablefmt="pretty"))
return r2_, directional_accuracy
_r2, _diracc = evaluate_final_lstm(final_lstm, X_test, y_test)
# 8) Save final LSTM model and scalers
final_lstm.save(os.path.join(output_dir, 'best_lstm_model.h5'))
joblib.dump(scaler_features, os.path.join(output_dir, 'scaler_features.pkl'))
joblib.dump(scaler_target, os.path.join(output_dir, 'scaler_target.pkl'))
logging.info("Saved best LSTM model and scaler objects.")
##########################################
# B) PPO PART: SET UP FUTURES TRADING ENVIRONMENT
##########################################
env_params = {
'df': df,
'feature_columns': feature_columns,
'lstm_model': final_lstm, # Frozen LSTM for forecasting
'scaler_features': scaler_features,
'scaler_target': scaler_target,
'window_size': lstm_window_size,
'transaction_cost': 0.001,
'action_mode': action_mode,
'max_contracts': max_contracts
}
# Create the FuturesTradingEnv and wrap it for PPO training
env = FuturesTradingEnv(**env_params)
vec_env = DummyVecEnv([lambda: env])
# PPO hyperparameters (customize as needed)
ppo_hyperparams = {
'n_steps': 2048,
'batch_size': 64,
'gae_lambda': 0.95,
'gamma': 0.99,
'learning_rate': 3e-4,
'ent_coef': 0.0,
'verbose': 1
}
# -----------------------------
# Train PPO Model
# -----------------------------
logging.info("Starting PPO training...")
ppo_model = PPO('MlpPolicy', vec_env, **ppo_hyperparams)
ppo_model.learn(total_timesteps=ppo_total_timesteps)
ppo_model.save(os.path.join(output_dir, "best_ppo_model.zip"))
logging.info("PPO training completed and model saved.")
##########################################
# C) FINAL INFERENCE & (Optional) LIVE TRADING EXAMPLE
##########################################
# Evaluate the trained PPO model in the environment
obs = env.reset()
done = False
total_reward = 0.0
step_data = []
step_count = 0
while not done:
step_count += 1
action, _ = ppo_model.predict(obs, deterministic=True)
obs, reward, done, _ = env.step(action)
total_reward += reward
step_data.append({
"Step": step_count,
"Action": int(action) if action_mode=='discrete' else int(np.round(action[0])),
"Reward": reward,
"Contracts": env.contracts_held
})
final_pnl = (env.df.loc[env.current_step, 'Close'] - (env.entry_price if env.entry_price is not None else 0)) * env.contracts_held
print("\n=== Final PPO Inference ===")
print(f"Total Steps: {step_count}")
print(f"Final Contracts Held: {env.contracts_held}")
print(f"Final Estimated PnL: {final_pnl:.2f}")
print(f"Total Reward Sum: {total_reward:.2f}")
print("\nLast 15 Steps:")
last_n = step_data[-15:] if len(step_data) > 15 else step_data
print(tabulate(last_n, headers="keys", tablefmt="pretty"))
# OPTIONAL: Uncomment to run a live trading loop (requires implementation of live data feed and order execution)
# live_trading_loop(ppo_model, env)
if __name__ == "__main__":
main()
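A minimal invocation of the script above, assuming it is saved as futures_lstm_ppo.py (filename hypothetical; the flags are the argparse options defined above):

python futures_lstm_ppo.py data/MES2023Z.csv --lstm_window_size 15 --n_trials_lstm 30 --ppo_total_timesteps 100000 --action_mode discrete --max_contracts 5 --output_dir output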


@@ -0,0 +1,721 @@
import os
import sys
import argparse
import numpy as np
import pandas as pd
import logging
from tabulate import tabulate
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import psutil
import GPUtil
import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import LSTM, Dense, Dropout, Bidirectional
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.losses import Huber
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam, Nadam
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import joblib
import optuna
from optuna.integration import KerasPruningCallback
import gym
from gym import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from multiprocessing import Pool, cpu_count
import threading
import time
from dateutil import parser # For custom date parsing
# Suppress TensorFlow logs beyond errors
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# =============================================================================
# Custom Date Parser: handles EDT/EST then drops timezone info
# =============================================================================
def custom_date_parser(date_str):
tzinfos = {"EDT": -14400, "EST": -18000}
dt = parser.parse(date_str, tzinfos=tzinfos)
return dt.replace(tzinfo=None)
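# dateutil tzinfos values are UTC offsets in seconds: EDT = UTC-4 -> -14400, EST = UTC-5 -> -18000.
# e.g. custom_date_parser('2023-06-01 09:30:00 EDT') returns the naive datetime 2023-06-01 09:30:00.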
# =============================================================================
# Resource Detection Functions
# =============================================================================
def get_cpu_info():
cpu_count_physical = psutil.cpu_count(logical=False)
cpu_count_logical = psutil.cpu_count(logical=True)
cpu_percent = psutil.cpu_percent(interval=1, percpu=True)
return {
'physical_cores': cpu_count_physical,
'logical_cores': cpu_count_logical,
'cpu_percent': cpu_percent
}
def get_gpu_info():
gpus = GPUtil.getGPUs()
gpu_info = []
for gpu in gpus:
gpu_info.append({
'id': gpu.id,
'name': gpu.name,
'load': gpu.load * 100,
'memory_total': gpu.memoryTotal,
'memory_used': gpu.memoryUsed,
'memory_free': gpu.memoryFree,
'temperature': gpu.temperature
})
return gpu_info
def configure_tensorflow(cpu_stats, gpu_stats):
logical_cores = cpu_stats['logical_cores']
os.environ["OMP_NUM_THREADS"] = str(logical_cores)
os.environ["TF_NUM_INTRAOP_THREADS"] = str(logical_cores)
os.environ["TF_NUM_INTEROP_THREADS"] = str(logical_cores)
if gpu_stats:
gpus = tf.config.list_physical_devices('GPU')
if gpus:
try:
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
logging.info(f"Enabled memory growth for {len(gpus)} GPU(s).")
except RuntimeError as e:
logging.error(f"TensorFlow GPU configuration error: {e}")
else:
tf.config.threading.set_intra_op_parallelism_threads(logical_cores)
tf.config.threading.set_inter_op_parallelism_threads(logical_cores)
logging.info("Configured TensorFlow to use CPU with optimized thread settings.")
def monitor_resources(interval=60):
while True:
cpu = psutil.cpu_percent(interval=1, percpu=True)
gpu = get_gpu_info()
logging.info(f"CPU Usage per Core: {cpu}%")
if gpu:
for gpu_stat in gpu:
logging.info(f"GPU {gpu_stat['id']} - {gpu_stat['name']}: Load: {gpu_stat['load']}%, "
f"Memory Used: {gpu_stat['memory_used']}MB/{gpu_stat['memory_total']}MB, "
f"Temperature: {gpu_stat['temperature']}°C")
else:
logging.info("No GPUs detected.")
logging.info("-" * 50)
time.sleep(interval)
# =============================================================================
# Data Loading & Technical Indicators
# =============================================================================
def load_data(file_path):
"""
Loads data from a CSV or JSON file.
Expects a time column plus: open, high, low, close, volume.
Uses a custom date parser and flexible column mapping.
"""
logging.info(f"Loading data from: {file_path}")
try:
if file_path.lower().endswith('.csv'):
df = pd.read_csv(file_path, parse_dates=['time'], date_parser=custom_date_parser)
elif file_path.lower().endswith('.json'):
df = pd.read_json(file_path, convert_dates=False)
# Try to find a time/date column
time_col = None
for col in df.columns:
if col.strip().lower() in ['time', 'date']:
time_col = col
break
if time_col:
df[time_col] = df[time_col].apply(custom_date_parser)
else:
logging.error("No time column found in JSON data.")
sys.exit(1)
else:
logging.error("Unsupported file format. Please provide CSV or JSON.")
sys.exit(1)
except Exception as e:
logging.error(f"Error loading file: {e}")
sys.exit(1)
# Clean and rename columns
df.columns = [col.strip() for col in df.columns]
lower_cols = [col.lower() for col in df.columns]
required_fields = {
'time': ['time', 'date'],
'open': ['open'],
'high': ['high'],
'low': ['low'],
'close': ['close'],
'volume': ['volume', 'vol']
}
rename_mapping = {}
for canonical, alternatives in required_fields.items():
found = False
for alt in alternatives:
if alt in lower_cols:
orig_col = df.columns[lower_cols.index(alt)]
if canonical == 'time':
rename_mapping[orig_col] = 'Date'
else:
rename_mapping[orig_col] = canonical.capitalize()
found = True
break
if not found:
logging.error(f"Required column for '{canonical}' not found. Alternatives: {alternatives}")
sys.exit(1)
df.rename(columns=rename_mapping, inplace=True)
logging.info(f"Columns after renaming: {df.columns.tolist()}")
try:
df.sort_values('Date', inplace=True)
except KeyError:
logging.error("Column 'Date' not found after renaming. Check data format.")
sys.exit(1)
df.reset_index(drop=True, inplace=True)
logging.info("Data loaded and sorted successfully.")
return df
def compute_rsi(series, window=14):
delta = series.diff()
gain = delta.where(delta > 0, 0).rolling(window=window).mean()
loss = -delta.where(delta < 0, 0).rolling(window=window).mean()
RS = gain / (loss + 1e-9)
return 100 - (100 / (1 + RS))
def compute_macd(series, span_short=12, span_long=26, span_signal=9):
ema_short = series.ewm(span=span_short, adjust=False).mean()
ema_long = series.ewm(span=span_long, adjust=False).mean()
macd_line = ema_short - ema_long
signal_line = macd_line.ewm(span=span_signal, adjust=False).mean()
return macd_line - signal_line
def compute_obv(df):
signed_volume = (np.sign(df['Close'].diff()) * df['Volume']).fillna(0)
return signed_volume.cumsum()
def compute_adx(df, window=14):
df['H-L'] = df['High'] - df['Low']
df['H-Cp'] = (df['High'] - df['Close'].shift(1)).abs()
df['L-Cp'] = (df['Low'] - df['Close'].shift(1)).abs()
tr = df[['H-L','H-Cp','L-Cp']].max(axis=1)
tr_rolling = tr.rolling(window=window).mean()
adx_placeholder = tr_rolling / (df['Close'] + 1e-9)
df.drop(['H-L','H-Cp','L-Cp'], axis=1, inplace=True)
return adx_placeholder
def compute_bollinger_bands(series, window=20, num_std=2):
sma = series.rolling(window=window).mean()
std = series.rolling(window=window).std()
upper = sma + num_std * std
lower = sma - num_std * std
bandwidth = (upper - lower) / (sma + 1e-9)
return upper, lower, bandwidth
def compute_mfi(df, window=14):
typical_price = (df['High'] + df['Low'] + df['Close']) / 3
money_flow = typical_price * df['Volume']
prev_tp = typical_price.shift(1)
flow_pos = money_flow.where(typical_price > prev_tp, 0)
flow_neg = money_flow.where(typical_price < prev_tp, 0)
pos_sum = flow_pos.rolling(window=window).sum()
neg_sum = flow_neg.rolling(window=window).sum()
mfi = 100 - (100 / (1 + pos_sum / (neg_sum + 1e-9)))
return mfi
def calculate_technical_indicators(df):
logging.info("Calculating technical indicators...")
df['RSI'] = compute_rsi(df['Close'], 14)
df['MACD'] = compute_macd(df['Close'])
df['OBV'] = compute_obv(df)
df['ADX'] = compute_adx(df)
up, lo, bw = compute_bollinger_bands(df['Close'], 20, 2)
df['BB_Upper'] = up
df['BB_Lower'] = lo
df['BB_Width'] = bw
df['MFI'] = compute_mfi(df, 14)
df['SMA_5'] = df['Close'].rolling(5).mean()
df['SMA_10'] = df['Close'].rolling(10).mean()
df['EMA_5'] = df['Close'].ewm(span=5, adjust=False).mean()
df['EMA_10'] = df['Close'].ewm(span=10, adjust=False).mean()
df['STDDEV_5'] = df['Close'].rolling(5).std()
df.dropna(inplace=True)
logging.info("Technical indicators calculated successfully.")
return df
# =============================================================================
# Argument Parsing
# =============================================================================
def parse_arguments():
parser = argparse.ArgumentParser(description='Futures Trading with LSTM and PPO')
parser.add_argument('data_path', type=str,
help='Path to CSV or JSON file with columns [time, open, high, low, close, volume].')
parser.add_argument('--lstm_window_size', type=int, default=15,
help='Sequence window size for LSTM forecasting. Default=15.')
parser.add_argument('--ppo_total_timesteps', type=int, default=100000,
help='Total timesteps to train the PPO model. Default=100000.')
parser.add_argument('--n_trials_lstm', type=int, default=30,
help='Number of Optuna trials for LSTM hyperparameter tuning. Default=30.')
parser.add_argument('--preprocess_workers', type=int, default=None,
help='Number of worker processes for data preprocessing. Defaults to (logical cores - 2).')
parser.add_argument('--monitor_resources', action='store_true',
help='Enable real-time resource monitoring.')
parser.add_argument('--output_dir', type=str, default='output',
help='Directory where output files will be saved.')
parser.add_argument('--action_mode', type=str, choices=['discrete', 'continuous'], default='discrete',
help='Action space type. Default=discrete.')
parser.add_argument('--max_contracts', type=int, default=5,
help='Maximum number of contracts per action. Default=5.')
parser.add_argument('--initial_balance', type=float, default=10000,
help='Initial account balance. Default=10000.')
parser.add_argument('--stop_loss_points', type=float, default=20,
help='Stop loss in points. Default=20.')
parser.add_argument('--trailing_stop_points', type=float, default=5,
help='Trailing stop in points. Default=5.')
parser.add_argument('--min_risk_reward', type=float, default=2.0,
help='Minimum risk-reward ratio for trade entry. Default=2.0.')
return parser.parse_args()
# =============================================================================
# LSTM Price Predictor
# =============================================================================
def build_lstm(input_shape, hyperparams):
model = Sequential()
num_layers = hyperparams['num_lstm_layers']
units = hyperparams['lstm_units']
dropout = hyperparams['dropout_rate']
for i in range(num_layers):
return_sequences = (i < num_layers - 1)
if i == 0:
model.add(Bidirectional(LSTM(units, return_sequences=return_sequences, kernel_regularizer=l2(1e-4)),
input_shape=input_shape))
else:
model.add(Bidirectional(LSTM(units, return_sequences=return_sequences, kernel_regularizer=l2(1e-4))))
model.add(Dropout(dropout))
model.add(Dense(1, activation='linear'))
opt_name = hyperparams['optimizer']
lr = hyperparams['learning_rate']
decay = hyperparams['decay']
if opt_name == 'Adam':
opt = Adam(learning_rate=lr, decay=decay)
elif opt_name == 'Nadam':
opt = Nadam(learning_rate=lr)
else:
opt = Adam(learning_rate=lr)
model.compile(loss=Huber(), optimizer=opt, metrics=['mae'])
return model
# =============================================================================
# Custom Gym Environment for Futures Trading with LSTM Forecasting and Risk Management
# =============================================================================
class FuturesTradingEnv(gym.Env):
"""
A Gym environment that incorporates:
- LSTM forecasting (with forecast cached for reward shaping)
- Futures trading with risk management (stop loss, trailing stop, position sizing)
- Bonus reward when the action aligns with the LSTM forecast
It supports both discrete and continuous action spaces.
"""
metadata = {'render.modes': ['human']}
def __init__(self, df, feature_columns, lstm_model, scaler_features, scaler_target,
window_size=15, transaction_cost=0.60, action_mode='discrete', max_contracts=5,
initial_balance=10000, max_risk_per_trade=0.02, stop_loss_points=20,
trailing_stop_points=5, min_risk_reward=2.0, max_daily_loss=0.05):
super(FuturesTradingEnv, self).__init__()
self.df = df.reset_index(drop=True)
self.feature_columns = feature_columns
self.lstm_model = lstm_model
self.scaler_features = scaler_features
self.scaler_target = scaler_target
self.window_size = window_size
self.transaction_cost = transaction_cost
self.action_mode = action_mode
self.max_contracts = max_contracts
# Futures-specific parameters
self.multiplier = 5 # MES contract multiplier ($5 per index point)
self.tick_size = 0.25
# Risk management and account parameters
self.initial_balance = initial_balance
self.balance = initial_balance
self.max_risk_per_trade = max_risk_per_trade
self.stop_loss_points = stop_loss_points
self.trailing_stop_points = trailing_stop_points
self.min_risk_reward = min_risk_reward
self.max_daily_loss = max_daily_loss
self.daily_loss = 0.0
self.max_steps = len(df)
self.current_step = 0
self.contracts_held = 0
self.entry_price = None
self.peak_price = None
self.trade_history = []
self.daily_balances = [initial_balance]
self.current_day = 0
self.raw_features = df[feature_columns].values
# Define action space
if self.action_mode == 'discrete':
self.action_space = spaces.Discrete(2 * self.max_contracts + 1)
else:
self.action_space = spaces.Box(low=-self.max_contracts, high=self.max_contracts, shape=(1,), dtype=np.float32)
# Observation: normalized features, account info, and cached LSTM forecast
obs_len = len(feature_columns) + 4 # account info: balance ratio, normalized position, pnl, peak diff
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(obs_len + 1,), dtype=np.float32)
self.lstm_lock = threading.Lock()
# Cache for the forecast (used for reward shaping)
self.last_forecast = 0.0
def reset(self):
self.current_step = 0
self.contracts_held = 0
self.entry_price = None
self.peak_price = None
self.balance = self.initial_balance
self.daily_loss = 0.0
self.trade_history = []
self.daily_balances = [self.initial_balance]
self.current_day = 0
self.last_forecast = 0.0
return self._get_obs()
def _get_obs(self):
# Normalize current feature row
row = self.raw_features[self.current_step]
row_max = np.max(np.abs(row)) if np.max(np.abs(row)) != 0 else 1.0
row_norm = row / row_max
# Account info: balance ratio, normalized position, unrealized pnl, peak difference
current_price = self.df.loc[self.current_step, 'Close']
pnl = (current_price - self.entry_price) * self.contracts_held * self.multiplier if self.entry_price is not None else 0.0
peak_diff = (current_price - self.peak_price) if self.peak_price is not None else 0.0
additional = np.array([
self.balance / self.initial_balance,
self.contracts_held / self.max_contracts,
pnl,
peak_diff
], dtype=np.float32)
# LSTM forecast: predict next price difference (cached for reward shaping)
if self.current_step < self.window_size:
forecast = 0.0
else:
seq = self.raw_features[self.current_step - self.window_size:self.current_step]
seq_scaled = self.scaler_features.transform(seq)
seq_scaled = np.expand_dims(seq_scaled, axis=0)
with self.lstm_lock:
pred_scaled = self.lstm_model.predict(seq_scaled, verbose=0).flatten()[0]
# Inverse scale to obtain predicted price and subtract current price
predicted_price = self.scaler_target.inverse_transform([[pred_scaled]])[0, 0]
forecast = predicted_price - current_price
self.last_forecast = forecast
return np.concatenate([row_norm, additional, [forecast]]).astype(np.float32)
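# Observation layout: [normalized feature row..., balance ratio, position / max_contracts, unrealized PnL, peak diff, forecast]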
def _calculate_position_size(self, action_int):
# Calculate the maximum number of contracts you can trade based on risk
risk_amount = self.initial_balance * self.max_risk_per_trade
risk_per_contract = self.stop_loss_points * self.multiplier
max_contracts_risk = int(risk_amount / risk_per_contract) if risk_per_contract > 0 else 1
return int(min(abs(action_int), max(1, max_contracts_risk), self.max_contracts) * np.sign(action_int))
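# e.g. initial_balance=10000 and max_risk_per_trade=0.02 give a $200 risk budget; with stop_loss_points=20
# and multiplier=5, each contract risks $100, so the size is capped at 2 contracts regardless of the request.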
def step(self, action):
start_balance = self.balance
current_price = self.df.loc[self.current_step, 'Close']
prev_position = self.contracts_held
# Convert action to integer (discrete mapping)
if self.action_mode == 'discrete':
action_int = action - self.max_contracts
else:
action_int = int(np.round(action[0]))
action_int = np.clip(action_int, -self.max_contracts, self.max_contracts)
# Use the cached forecast for risk/reward decision
forecast = self.last_forecast
risk = self.stop_loss_points
min_required_reward = self.min_risk_reward * risk
# If forecast magnitude is insufficient, do not trade
if abs(forecast) < min_required_reward:
action_int = 0
logging.debug("No trade: forecast below risk-reward threshold.")
else:
# Only allow trades in the direction of the forecast
if np.sign(action_int) != np.sign(forecast):
action_int = 0
logging.debug("No trade: action direction does not match forecast.")
else:
# Scale action based on forecast strength and then calculate position size
max_forecast = risk * 2
scaling = min(1, abs(forecast) / max_forecast)
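# Scaling rises linearly with forecast strength and saturates once |forecast| reaches twice the stop distance.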
action_int = int(np.sign(action_int) * max(1, int(abs(action_int) * scaling)))
action_int = self._calculate_position_size(action_int)
logging.debug(f"Trade approved: adjusted action to {action_int} based on forecast {forecast:.2f}.")
# Check for stop loss or trailing stop if already in a position
if self.contracts_held != 0 and self.entry_price is not None:
price_change = current_price - self.entry_price
if self.contracts_held > 0:
self.peak_price = max(self.peak_price or current_price, current_price)
trail_diff = self.peak_price - current_price
if price_change <= -risk or trail_diff >= self.trailing_stop_points:
logging.debug("Stop loss / trailing stop triggered for long position.")
action_int = -self.contracts_held
else:
self.peak_price = min(self.peak_price or current_price, current_price)
trail_diff = current_price - self.peak_price
if price_change >= risk or trail_diff >= self.trailing_stop_points:
logging.debug("Stop loss / trailing stop triggered for short position.")
action_int = -self.contracts_held
fee = self.transaction_cost * abs(action_int)
# Update position logic
if action_int != 0:
if self.contracts_held == 0:
self.contracts_held = action_int
self.entry_price = current_price
self.peak_price = current_price
elif np.sign(self.contracts_held) == np.sign(action_int):
total_contracts = self.contracts_held + action_int
self.entry_price = (self.entry_price * self.contracts_held + current_price * action_int) / total_contracts
self.contracts_held = total_contracts
else:
if abs(action_int) >= abs(self.contracts_held):
self.contracts_held += action_int
self.entry_price = current_price if self.contracts_held != 0 else None
self.peak_price = current_price if self.contracts_held != 0 else None
else:
self.contracts_held += action_int
# Calculate mark-to-market PnL change from previous step (using previous step close)
prev_close = self.df.loc[self.current_step-1, 'Close'] if self.current_step > 0 else current_price
pnl_change = (current_price - prev_close) * prev_position * self.multiplier
# Base reward: pnl change minus fee
reward = pnl_change - fee
# Add bonus reward if action aligns with forecast (scaled by bonus factor)
bonus_factor = 10.0
reward += bonus_factor * (action_int * forecast)
self.balance += reward
self.daily_loss -= min(0, reward)
# Record trade if a position was closed
if self.contracts_held == 0 and prev_position != 0:
profit = self.balance - start_balance
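# The day index below assumes 1-minute bars (60 * 24 bars per calendar day).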
day = self.current_step // (60 * 24)
self.trade_history.append((start_balance, profit, day))
# Update daily balance if new day begins
current_day = self.current_step // (60 * 24)
if current_day > self.current_day:
self.daily_balances.append(self.balance)
self.daily_loss = 0.0
self.current_day = current_day
self.current_step += 1
done = (self.current_step >= self.max_steps - 1) or (self.daily_loss >= self.initial_balance * self.max_daily_loss)
return self._get_obs(), reward, done, {}
def render(self, mode='human'):
current_price = self.df.loc[self.current_step, 'Close']
pnl = (current_price - self.entry_price) * self.contracts_held * self.multiplier if self.entry_price is not None else 0.0
print(f"Step: {self.current_step}, Balance: {self.balance:.2f}, Contracts: {self.contracts_held}, PnL: {pnl:.2f}, Daily Loss: {self.daily_loss:.2f}")
def calculate_metrics(self):
if not self.trade_history or len(self.daily_balances) < 2:
return {
'avg_daily_return': 0.0,
'avg_return_per_trade': 0.0,
'win_rate': 0.0,
'avg_dollar_return_daily': 0.0,
'avg_dollar_return_per_trade': 0.0,
'sharpe_ratio': 0.0
}
daily_returns = [(self.daily_balances[i] - self.daily_balances[i-1]) / self.daily_balances[i-1]
for i in range(1, len(self.daily_balances))]
avg_daily_return = np.mean(daily_returns) * 100
avg_dollar_return_daily = np.mean([self.daily_balances[i] - self.daily_balances[i-1]
for i in range(1, len(self.daily_balances))])
trade_returns = [profit / start for start, profit, _ in self.trade_history]
avg_return_per_trade = np.mean(trade_returns) * 100
avg_dollar_return_per_trade = np.mean([profit for _, profit, _ in self.trade_history])
win_rate = len([r for r in trade_returns if r > 0]) / len(trade_returns) * 100
daily_returns_np = np.array(daily_returns)
sharpe_ratio = (np.mean(daily_returns_np) / np.std(daily_returns_np)) * np.sqrt(252) if np.std(daily_returns_np) != 0 else 0.0
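# Daily Sharpe annualized by sqrt(252) trading days, with the risk-free rate taken as zero.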
return {
'avg_daily_return': avg_daily_return,
'avg_return_per_trade': avg_return_per_trade,
'win_rate': win_rate,
'avg_dollar_return_daily': avg_dollar_return_daily,
'avg_dollar_return_per_trade': avg_dollar_return_per_trade,
'sharpe_ratio': sharpe_ratio
}
# =============================================================================
# Placeholders for Live Deployment Functions
# =============================================================================
def get_live_data():
return None
def execute_order(action):
logging.info(f"Executing order: {action}")
def live_trading_loop(model, env, polling_interval=5):
obs = env.reset()
done = False
while not done:
live_data = get_live_data()
if live_data is not None:
pass
action, _ = model.predict(obs, deterministic=True)
execute_order(action)
obs, reward, done, _ = env.step(action)
env.render()
time.sleep(polling_interval)
# =============================================================================
# Data Preprocessing with Parallelization
# =============================================================================
def parallel_feature_engineering(row):
return row
def feature_engineering_parallel(df, num_workers):
logging.info(f"Starting parallel feature engineering with {num_workers} workers...")
with Pool(processes=num_workers) as pool:
processed_rows = pool.map(parallel_feature_engineering, [row for _, row in df.iterrows()])
df_processed = pd.DataFrame(processed_rows)
logging.info("Parallel feature engineering completed.")
return df_processed
# =============================================================================
# MAIN FUNCTION: LSTM Training + PPO for Futures Trading
# =============================================================================
def main():
args = parse_arguments()
data_path = args.data_path
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
lstm_window_size = args.lstm_window_size
ppo_total_timesteps = args.ppo_total_timesteps
n_trials_lstm = args.n_trials_lstm
preprocess_workers = args.preprocess_workers
enable_resource_monitor = args.monitor_resources
action_mode = args.action_mode
max_contracts = args.max_contracts
# Load and process data (CSV or JSON)
df = load_data(data_path)
df = calculate_technical_indicators(df)
# Use numeric columns (exclude Date)
feature_columns = [col for col in df.columns if col != 'Date' and np.issubdtype(df[col].dtype, np.number)]
# Fit scalers on the numeric data to avoid issues with column names
scaler_features = MinMaxScaler().fit(df[feature_columns].values)
scaler_target = MinMaxScaler().fit(df[['Close']].values)
# Build a basic LSTM model (further tuning could be applied using Optuna)
final_lstm = build_lstm((lstm_window_size, len(feature_columns)), {
'num_lstm_layers': 2,
'lstm_units': 50,
'dropout_rate': 0.2,
'optimizer': 'Adam',
'learning_rate': 0.001,
'decay': 1e-6
})
# (Optional) You can add LSTM hyperparameter tuning here using Optuna
# Set up the trading environment
env_params = {
'df': df,
'feature_columns': feature_columns,
'lstm_model': final_lstm,
'scaler_features': scaler_features,
'scaler_target': scaler_target,
'window_size': lstm_window_size,
'transaction_cost': 0.60,
'action_mode': action_mode,
'max_contracts': max_contracts,
'initial_balance': args.initial_balance,
'stop_loss_points': args.stop_loss_points,
'trailing_stop_points': args.trailing_stop_points,
'min_risk_reward': args.min_risk_reward,
'max_daily_loss': 0.05
}
env = FuturesTradingEnv(**env_params)
vec_env = DummyVecEnv([lambda: env])
# PPO hyperparameters
ppo_hyperparams = {
'n_steps': 2048,
'batch_size': 64,
'gae_lambda': 0.95,
'gamma': 0.99,
'learning_rate': 3e-4,
'ent_coef': 0.0,
'verbose': 1
}
if enable_resource_monitor:
resource_monitor_thread = threading.Thread(target=monitor_resources, args=(60,), daemon=True)
resource_monitor_thread.start()
logging.info("Starting PPO training...")
ppo_model = PPO('MlpPolicy', vec_env, **ppo_hyperparams)
ppo_model.learn(total_timesteps=ppo_total_timesteps)
ppo_model.save(os.path.join(output_dir, "best_ppo_model.zip"))
logging.info("PPO training completed and model saved.")
# Inference on the environment
obs = env.reset()
done = False
total_reward = 0.0
step_data = []
while not done:
step_count = env.current_step + 1
action, _ = ppo_model.predict(obs, deterministic=True)
obs, reward, done, _ = env.step(action)
total_reward += reward
step_data.append({
"Step": step_count,
"Action": int(action) if action_mode == 'discrete' else int(np.round(action[0])),
"Reward": reward,
"Contracts": env.contracts_held
})
final_pnl = (env.df.loc[env.current_step, 'Close'] - (env.entry_price if env.entry_price else 0)) * env.contracts_held * env.multiplier
print("\n=== Final PPO Inference ===")
print(f"Total Steps: {step_count}")
print(f"Final Contracts Held: {env.contracts_held}")
print(f"Final Estimated PnL: {final_pnl:.2f}")
print(f"Total Reward Sum: {total_reward:.2f}")
print("\nLast 15 Steps:")
last_n = step_data[-15:] if len(step_data) > 15 else step_data
print(tabulate(last_n, headers="keys", tablefmt="pretty"))
metrics = env.calculate_metrics()
print("\n=== Trading Metrics ===")
print(f"Average Daily Return: {metrics['avg_daily_return']:.2f}%")
print(f"Average Return per Trade: {metrics['avg_return_per_trade']:.2f}%")
print(f"Win Rate: {metrics['win_rate']:.2f}%")
print(f"Average Dollar Return Daily: ${metrics['avg_dollar_return_daily']:.2f}")
print(f"Average Dollar Return per Trade: ${metrics['avg_dollar_return_per_trade']:.2f}")
print(f"Sharpe Ratio: {metrics['sharpe_ratio']:.2f}")
if __name__ == "__main__":
main()
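# Example invocation (a sketch; the script filename and flag spellings are
# assumed from the argparse dests used above, not confirmed):
#   python midas_ppo.py --data_path ../data/cleaned_MES_data.csv \
#       --output_dir models --lstm_window_size 15 --ppo_total_timesteps 100000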

View File

@@ -0,0 +1,257 @@
#!/usr/bin/env python3
"""
MidasWrapper.py
This script connects to IBKR via IB_insync, attempts to select a MES futures contract
using several variants, and then, based on the AI signal, places a market order and attaches
a stop-loss order.
Usage:
- Ensure that IBKR TWS or Gateway is running in paper trading mode (port 4002 used here).
- Make sure your IBKR account is subscribed to MES futures and the requested contract month is active.
- Example: py MidasWrapper.py
"""
import datetime
import time
import logging
from zoneinfo import ZoneInfo
from ib_insync import IB, Future, MarketOrder, Order, util
# --- Logger Setup ---
logging.basicConfig(level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S")
logger = logging.getLogger(__name__)
# --- Global Time Zone ---
TZ = ZoneInfo("US/Eastern")
# --- Utility Functions ---
def get_third_friday(year, month):
"""Return the third Friday of the given year/month as a datetime with TZ."""
fridays = []
for day in range(1, 32):
try:
d = datetime.date(year, month, day)
except ValueError:
break
if d.weekday() == 4:
fridays.append(d)
if len(fridays) >= 3:
dt = datetime.datetime.combine(fridays[2], datetime.time(16, 0))
elif fridays:
dt = datetime.datetime.combine(fridays[-1], datetime.time(16, 0))
else:
dt = datetime.datetime(year, month, 1, 16, 0)
return dt.replace(tzinfo=TZ)
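# Example: get_third_friday(2025, 3) -> 2025-03-21 16:00 US/Eastern
# (the Fridays of March 2025 fall on the 7th, 14th, 21st and 28th).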
def getMESContract(ib_conn, cm, contract_expiration):
"""
Try several MES contract definitions for the given contract month (cm)
and expiration date. Returns a tuple (contract, variant_desc) if found,
or (None, None) if no valid contract is returned.
"""
expiration_str = contract_expiration.strftime("%Y%m%d")
variants = []
# Variant 1: Full expiration date, exchange GLOBEX.
contract1 = Future(
symbol='MES',
lastTradeDateOrContractMonth=expiration_str,
exchange='GLOBEX',
currency='USD',
multiplier=5
)
contract1.includeExpired = True
variants.append(("Variant 1: full expiration, GLOBEX", contract1))
# Variant 2: Full expiration date, exchange CME.
contract2 = Future(
symbol='MES',
lastTradeDateOrContractMonth=expiration_str,
exchange='CME',
currency='USD',
multiplier=5
)
contract2.includeExpired = True
variants.append(("Variant 2: full expiration, CME", contract2))
# Variant 3: Contract month with tradingClass on GLOBEX.
contract3 = Future(
symbol='MES',
lastTradeDateOrContractMonth=cm,
exchange='GLOBEX',
currency='USD',
multiplier=5,
tradingClass='MES'
)
contract3.includeExpired = True
variants.append(("Variant 3: contract month, GLOBEX, tradingClass", contract3))
# Variant 4: Contract month with tradingClass on CME.
contract4 = Future(
symbol='MES',
lastTradeDateOrContractMonth=cm,
exchange='CME',
currency='USD',
multiplier=5,
tradingClass='MES'
)
contract4.includeExpired = True
variants.append(("Variant 4: contract month, CME, tradingClass", contract4))
# Variant 5: Contract month using computed localSymbol on GLOBEX.
month_codes = {1: 'F', 2: 'G', 3: 'H', 4: 'J', 5: 'K', 6: 'M', 7: 'N', 8: 'Q', 9: 'U', 10: 'V', 11: 'X', 12: 'Z'}
year_num = int(cm[:4])
month_num = int(cm[4:])
local_symbol = f"MES{month_codes.get(month_num, '')}{str(year_num)[-1]}"
contract5 = Future(
symbol='MES',
lastTradeDateOrContractMonth=cm,
localSymbol=local_symbol,
exchange='GLOBEX',
currency='USD',
multiplier=5
)
contract5.includeExpired = True
variants.append(("Variant 5: contract month, GLOBEX, localSymbol", contract5))
# Variant 6: Basic contract definition on GLOBEX using only contract month.
contract6 = Future(
symbol='MES',
lastTradeDateOrContractMonth=cm,
exchange='GLOBEX',
currency='USD',
multiplier=5
)
contract6.includeExpired = True
variants.append(("Variant 6: basic contract, GLOBEX", contract6))
# Variant 7: Basic contract definition on CME using only contract month.
contract7 = Future(
symbol='MES',
lastTradeDateOrContractMonth=cm,
exchange='CME',
currency='USD',
multiplier=5
)
contract7.includeExpired = True
variants.append(("Variant 7: basic contract, CME", contract7))
for variant_desc, contract in variants:
logger.info(f"Trying {variant_desc} for {cm} (expiration: {expiration_str})...")
details = ib_conn.reqContractDetails(contract)
if details:
logger.info(f"Success with {variant_desc}: {details[0].contract.localSymbol}")
return details[0].contract, variant_desc
else:
logger.info(f"{variant_desc} did not return any details.")
return None, None
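# Example: for cm="202503", Variant 5 computes local_symbol "MESH5"
# (H is the futures month code for March; 5 is the last digit of 2025).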
# --- Main Order Execution Function ---
def execute_future_with_stop(action_signal, quantity=1, stop_loss_points=20):
"""
Based on an AI signal (action_signal), this function:
- Determines order side (BUY if signal > 0, SELL if signal < 0).
- Selects an appropriate MES futures contract.
- Submits a market order and waits for its fill.
- Then submits a stop-loss order at a price offset by stop_loss_points.
"""
if action_signal == 0:
logger.info("Action signal is neutral. No trade will be executed.")
return
side = "BUY" if action_signal > 0 else "SELL"
stop_side = "SELL" if side == "BUY" else "BUY"
logger.info(f"Action signal received: {action_signal} -> {side} {quantity} contract(s)")
# Connect to IBKR (using a dedicated clientId and port for orders)
ib = IB()
try:
logger.info("Connecting to IBKR on 127.0.0.1:4002 with clientId 1...")
ib.connect('127.0.0.1', 4002, clientId=1)
except Exception as e:
logger.error(f"Error connecting to IBKR: {e}")
return
# Determine the contract month; here we use the current month.
now = datetime.datetime.now(TZ)
cm = now.strftime("%Y%m")
contract_expiration = get_third_friday(now.year, now.month)
contract, variant_used = getMESContract(ib, cm, contract_expiration)
if not contract:
logger.error("No valid MES contract found. Aborting order execution.")
logger.error("Please verify that your IBKR account is subscribed to MES futures and that the contract month is active.")
ib.disconnect()
return
logger.info(f"Selected contract: {contract.localSymbol} using {variant_used}")
# Subscribe to market data to get the current price.
ticker = ib.reqMktData(contract)
t0 = time.time()
while ticker.last is None and time.time() - t0 < 10:
ib.sleep(0.5)
if ticker.last is None:
logger.error("Market data not available. Aborting order execution.")
ib.disconnect()
return
current_price = float(ticker.last)
logger.info(f"Current market price for {contract.localSymbol}: {current_price}")
# Place the market order (parent order)
parent_order = MarketOrder(side, quantity)
parent_order.transmit = True # Transmit immediately.
parent_trade = ib.placeOrder(contract, parent_order)
logger.info(f"Placed market order: {side} {quantity} contract(s) for {contract.localSymbol}")
# Wait for the market order to fill.
while parent_trade.orderStatus.status not in ("Filled", "Cancelled"):
ib.sleep(0.5)
if parent_trade.orderStatus.status != "Filled":
logger.error(f"Market order not filled. Status: {parent_trade.orderStatus.status}")
ib.disconnect()
return
fill_price = parent_trade.orderStatus.avgFillPrice
logger.info(f"Market order filled at {fill_price}")
# Calculate stop-loss price.
# For BUY orders, the stop price is below fill price; for SELL orders, above.
if side == "BUY":
stop_price = fill_price - stop_loss_points
else:
stop_price = fill_price + stop_loss_points
logger.info(f"Placing stop loss order at {stop_price} for {quantity} contract(s)")
# Create and place the stop-loss order.
stop_order = Order()
stop_order.action = stop_side
stop_order.orderType = "STP"
stop_order.totalQuantity = quantity
stop_order.auxPrice = stop_price
stop_order.transmit = True
stop_trade = ib.placeOrder(contract, stop_order)
logger.info(f"Stop loss order placed: {stop_side} {quantity} contract(s) at {stop_price}")
# Optionally, wait a moment to see if the stop order is accepted.
ib.sleep(1)
ib.disconnect()
return parent_trade, stop_trade
# --- Main Execution ---
if __name__ == "__main__":
# Example usage:
# In practice, the action_signal would come from your AI/PPO system.
# A positive value means buy, negative means sell, zero means no action.
simulated_ai_signal = 1 # Change to -1 for a sell signal; 0 for no action.
quantity = 1 # Number of contracts to trade.
stop_loss_points = 20 # Stop-loss offset (price units).
execute_future_with_stop(simulated_ai_signal, quantity, stop_loss_points)

View File

@@ -0,0 +1,106 @@
2025-03-26 03:01:52,506 - INFO - ===== Resource Statistics =====
2025-03-26 03:01:52,506 - INFO - Physical CPU Cores: 28
2025-03-26 03:01:52,506 - INFO - Logical CPU Cores: 56
2025-03-26 03:01:52,507 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
2025-03-26 03:01:52,507 - INFO - No GPUs detected.
2025-03-26 03:01:52,507 - INFO - =================================
2025-03-26 03:01:52,507 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
2025-03-26 03:01:52,508 - INFO - Loading data from: data/MES2023Z.csv
2025-03-26 03:01:52,513 - ERROR - Unexpected error: Missing column provided to 'parse_dates': 'time'
2025-03-26 03:04:50,616 - INFO - ===== Resource Statistics =====
2025-03-26 03:04:50,616 - INFO - Physical CPU Cores: 28
2025-03-26 03:04:50,616 - INFO - Logical CPU Cores: 56
2025-03-26 03:04:50,616 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]%
2025-03-26 03:04:50,617 - INFO - No GPUs detected.
2025-03-26 03:04:50,617 - INFO - =================================
2025-03-26 03:04:50,617 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
2025-03-26 03:04:50,618 - INFO - Loading data from: data/MES2023Z.csv
2025-03-26 03:04:50,621 - ERROR - Unexpected error: Missing column provided to 'parse_dates': 'time'
2025-03-26 03:08:02,316 - INFO - ===== Resource Statistics =====
2025-03-26 03:08:02,316 - INFO - Physical CPU Cores: 28
2025-03-26 03:08:02,316 - INFO - Logical CPU Cores: 56
2025-03-26 03:08:02,317 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
2025-03-26 03:08:02,317 - INFO - No GPUs detected.
2025-03-26 03:08:02,317 - INFO - =================================
2025-03-26 03:08:02,317 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
2025-03-26 03:08:02,318 - INFO - Loading data from: data/MES2023Z.csv
2025-03-26 03:08:02,355 - INFO - Data columns after renaming: ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
2025-03-26 03:08:02,383 - INFO - Data loaded and sorted successfully.
2025-03-26 03:08:02,383 - INFO - Calculating technical indicators...
2025-03-26 03:08:02,448 - INFO - Technical indicators calculated successfully.
2025-03-26 03:08:02,464 - INFO - Starting parallel feature engineering with 54 workers...
2025-03-26 03:08:03,331 - INFO - Parallel feature engineering completed.
2025-03-26 03:08:03,341 - INFO - Training sequences shape: (676, 15, 17)
2025-03-26 03:08:03,342 - INFO - Validation sequences shape: (144, 15, 17)
2025-03-26 03:08:03,342 - INFO - Testing sequences shape: (146, 15, 17)
2025-03-26 03:08:03,342 - INFO - Starting LSTM hyperparameter optimization with Optuna using 54 parallel trials...
2025-03-26 03:22:04,033 - INFO - Best LSTM Hyperparameters: {'num_lstm_layers': 2, 'lstm_units': 64, 'dropout_rate': 0.13619292923712067, 'learning_rate': 0.0030545284525912166, 'optimizer': 'Nadam', 'decay': 9.615099767236892e-05}
2025-03-26 03:22:04,553 - INFO - Training best LSTM model with optimized hyperparameters...
2025-03-26 03:24:28,296 - INFO - Evaluating final LSTM model...
2025-03-26 03:24:29,722 - INFO - Test MSE: 0.3437
2025-03-26 03:24:29,722 - INFO - Test RMSE: 0.5862
2025-03-26 03:24:29,722 - INFO - Test MAE: 0.4561
2025-03-26 03:24:29,722 - INFO - Test R2 Score: 0.8620
2025-03-26 03:24:29,722 - INFO - Directional Accuracy: 0.2759
2025-03-26 03:24:30,013 - WARNING - You are saving your model as an HDF5 file via `model.save()` or `keras.saving.save_model(model)`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')` or `keras.saving.save_model(model, 'my_model.keras')`.
2025-03-26 03:24:30,121 - INFO - Saved best LSTM model and scaler objects.
2025-03-26 03:24:30,150 - INFO - Starting PPO training...
2025-03-26 05:47:15,571 - INFO - PPO training completed and model saved.
2025-04-11 22:15:50,927 - INFO - ===== Resource Statistics =====
2025-04-11 22:15:50,929 - INFO - Physical CPU Cores: 28
2025-04-11 22:15:50,929 - INFO - Logical CPU Cores: 56
2025-04-11 22:15:50,930 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
2025-04-11 22:15:50,930 - INFO - No GPUs detected.
2025-04-11 22:15:50,930 - INFO - =================================
2025-04-11 22:15:50,932 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
2025-04-11 22:15:50,932 - INFO - Loading data from: data/cleaned_MES_data.csv
2025-04-11 22:15:50,933 - ERROR - File not found: data/cleaned_MES_data.csv
2025-04-11 22:17:40,253 - INFO - ===== Resource Statistics =====
2025-04-11 22:17:40,253 - INFO - Physical CPU Cores: 28
2025-04-11 22:17:40,253 - INFO - Logical CPU Cores: 56
2025-04-11 22:17:40,253 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
2025-04-11 22:17:40,253 - INFO - No GPUs detected.
2025-04-11 22:17:40,254 - INFO - =================================
2025-04-11 22:17:40,254 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
2025-04-11 22:17:40,254 - INFO - Loading data from: ../data/cleaned_MES_data.csv
2025-04-11 22:17:40,373 - INFO - Data columns after renaming: ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
2025-04-11 22:17:40,392 - INFO - Data loaded and sorted successfully.
2025-04-11 22:17:40,392 - INFO - Calculating technical indicators...
2025-04-11 22:17:40,517 - INFO - Technical indicators calculated successfully.
2025-04-11 22:17:40,549 - INFO - Starting parallel feature engineering with 54 workers...
2025-04-11 22:18:25,291 - INFO - ===== Resource Statistics =====
2025-04-11 22:18:25,291 - INFO - Physical CPU Cores: 28
2025-04-11 22:18:25,291 - INFO - Logical CPU Cores: 56
2025-04-11 22:18:25,291 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
2025-04-11 22:18:25,291 - INFO - No GPUs detected.
2025-04-11 22:18:25,291 - INFO - =================================
2025-04-11 22:18:25,292 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
2025-04-11 22:18:25,292 - INFO - Loading data from: ../data/cleaned_MES_data.csv
2025-04-11 22:18:25,373 - INFO - Data columns after renaming: ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
2025-04-11 22:18:25,378 - INFO - Data loaded and sorted successfully.
2025-04-11 22:18:25,378 - INFO - Calculating technical indicators...
2025-04-11 22:18:25,440 - INFO - Technical indicators calculated successfully.
2025-04-11 22:18:25,456 - INFO - Starting parallel feature engineering with 54 workers...
2025-04-11 22:18:53,636 - INFO - Parallel feature engineering completed.
2025-04-11 22:18:53,925 - INFO - Training sequences shape: (45647, 15, 17)
2025-04-11 22:18:53,925 - INFO - Validation sequences shape: (9781, 15, 17)
2025-04-11 22:18:53,925 - INFO - Testing sequences shape: (9782, 15, 17)
2025-04-11 22:18:53,925 - INFO - Starting LSTM hyperparameter optimization with Optuna using 54 parallel trials...
2025-04-12 15:24:28,932 - INFO - ===== Resource Statistics =====
2025-04-12 15:24:28,932 - INFO - Physical CPU Cores: 28
2025-04-12 15:24:28,932 - INFO - Logical CPU Cores: 56
2025-04-12 15:24:28,932 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
2025-04-12 15:24:28,932 - INFO - No GPUs detected.
2025-04-12 15:24:28,932 - INFO - =================================
2025-04-12 15:24:28,933 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
2025-04-12 15:24:28,933 - INFO - Loading data from: ../data/cleaned_MES_data.csv
2025-04-12 15:24:29,013 - INFO - Data columns after renaming: ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
2025-04-12 15:24:29,018 - INFO - Data loaded and sorted successfully.
2025-04-12 15:24:29,018 - INFO - Calculating technical indicators...
2025-04-12 15:24:29,080 - INFO - Technical indicators calculated successfully.
2025-04-12 15:24:29,096 - INFO - Starting parallel feature engineering with 54 workers...
2025-04-12 15:24:56,873 - INFO - Parallel feature engineering completed.
2025-04-12 15:24:57,163 - INFO - Training sequences shape: (45647, 15, 17)
2025-04-12 15:24:57,163 - INFO - Validation sequences shape: (9781, 15, 17)
2025-04-12 15:24:57,163 - INFO - Testing sequences shape: (9782, 15, 17)
2025-04-12 15:24:57,163 - INFO - Starting LSTM hyperparameter optimization with Optuna using 54 parallel trials...

Binary file not shown.

View File

(Image file: 80 KiB before and after; preview not shown)

View File

@@ -0,0 +1,247 @@
<#
.Synopsis
Activate a Python virtual environment for the current PowerShell session.
.Description
Pushes the python executable for a virtual environment to the front of the
$Env:PATH environment variable and sets the prompt to signify that you are
in a Python virtual environment. Makes use of the command line switches as
well as the `pyvenv.cfg` file values present in the virtual environment.
.Parameter VenvDir
Path to the directory that contains the virtual environment to activate. The
default value for this is the parent of the directory that the Activate.ps1
script is located within.
.Parameter Prompt
The prompt prefix to display when this virtual environment is activated. By
default, this prompt is the name of the virtual environment folder (VenvDir)
surrounded by parentheses and followed by a single space (i.e. '(.venv) ').
.Example
Activate.ps1
Activates the Python virtual environment that contains the Activate.ps1 script.
.Example
Activate.ps1 -Verbose
Activates the Python virtual environment that contains the Activate.ps1 script,
and shows extra information about the activation as it executes.
.Example
Activate.ps1 -VenvDir C:\Users\MyUser\Common\.venv
Activates the Python virtual environment located in the specified location.
.Example
Activate.ps1 -Prompt "MyPython"
Activates the Python virtual environment that contains the Activate.ps1 script,
and prefixes the current prompt with the specified string (surrounded in
parentheses) while the virtual environment is active.
.Notes
On Windows, it may be required to enable this Activate.ps1 script by setting the
execution policy for the user. You can do this by issuing the following PowerShell
command:
PS C:\> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
For more information on Execution Policies:
https://go.microsoft.com/fwlink/?LinkID=135170
#>
Param(
[Parameter(Mandatory = $false)]
[String]
$VenvDir,
[Parameter(Mandatory = $false)]
[String]
$Prompt
)
<# Function declarations --------------------------------------------------- #>
<#
.Synopsis
Remove all shell session elements added by the Activate script, including the
addition of the virtual environment's Python executable from the beginning of
the PATH variable.
.Parameter NonDestructive
If present, do not remove this function from the global namespace for the
session.
#>
function global:deactivate ([switch]$NonDestructive) {
# Revert to original values
# The prior prompt:
if (Test-Path -Path Function:_OLD_VIRTUAL_PROMPT) {
Copy-Item -Path Function:_OLD_VIRTUAL_PROMPT -Destination Function:prompt
Remove-Item -Path Function:_OLD_VIRTUAL_PROMPT
}
# The prior PYTHONHOME:
if (Test-Path -Path Env:_OLD_VIRTUAL_PYTHONHOME) {
Copy-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME -Destination Env:PYTHONHOME
Remove-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME
}
# The prior PATH:
if (Test-Path -Path Env:_OLD_VIRTUAL_PATH) {
Copy-Item -Path Env:_OLD_VIRTUAL_PATH -Destination Env:PATH
Remove-Item -Path Env:_OLD_VIRTUAL_PATH
}
# Just remove the VIRTUAL_ENV altogether:
if (Test-Path -Path Env:VIRTUAL_ENV) {
Remove-Item -Path env:VIRTUAL_ENV
}
# Just remove VIRTUAL_ENV_PROMPT altogether.
if (Test-Path -Path Env:VIRTUAL_ENV_PROMPT) {
Remove-Item -Path env:VIRTUAL_ENV_PROMPT
}
# Just remove the _PYTHON_VENV_PROMPT_PREFIX altogether:
if (Get-Variable -Name "_PYTHON_VENV_PROMPT_PREFIX" -ErrorAction SilentlyContinue) {
Remove-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Scope Global -Force
}
# Leave deactivate function in the global namespace if requested:
if (-not $NonDestructive) {
Remove-Item -Path function:deactivate
}
}
<#
.Description
Get-PyVenvConfig parses the values from the pyvenv.cfg file located in the
given folder, and returns them in a map.
For each line in the pyvenv.cfg file, if that line can be parsed into exactly
two strings separated by `=` (with any amount of whitespace surrounding the =)
then it is considered a `key = value` line. The left hand string is the key,
the right hand is the value.
If the value starts with a `'` or a `"` then the first and last character is
stripped from the value before being captured.
.Parameter ConfigDir
Path to the directory that contains the `pyvenv.cfg` file.
#>
function Get-PyVenvConfig(
[String]
$ConfigDir
) {
Write-Verbose "Given ConfigDir=$ConfigDir, obtain values in pyvenv.cfg"
# Ensure the file exists, and issue a warning if it doesn't (but still allow the function to continue).
$pyvenvConfigPath = Join-Path -Resolve -Path $ConfigDir -ChildPath 'pyvenv.cfg' -ErrorAction Continue
# An empty map will be returned if no config file is found.
$pyvenvConfig = @{ }
if ($pyvenvConfigPath) {
Write-Verbose "File exists, parse `key = value` lines"
$pyvenvConfigContent = Get-Content -Path $pyvenvConfigPath
$pyvenvConfigContent | ForEach-Object {
$keyval = $PSItem -split "\s*=\s*", 2
if ($keyval[0] -and $keyval[1]) {
$val = $keyval[1]
# Remove extraneous quotations around a string value.
if ("'""".Contains($val.Substring(0, 1))) {
$val = $val.Substring(1, $val.Length - 2)
}
$pyvenvConfig[$keyval[0]] = $val
Write-Verbose "Adding Key: '$($keyval[0])'='$val'"
}
}
}
return $pyvenvConfig
}
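# Example: a pyvenv.cfg line such as  prompt = 'myenv'  yields
# $pyvenvConfig['prompt'] = 'myenv' (the surrounding quotes are stripped).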
<# Begin Activate script --------------------------------------------------- #>
# Determine the containing directory of this script
$VenvExecPath = Split-Path -Parent $MyInvocation.MyCommand.Definition
$VenvExecDir = Get-Item -Path $VenvExecPath
Write-Verbose "Activation script is located in path: '$VenvExecPath'"
Write-Verbose "VenvExecDir Fullname: '$($VenvExecDir.FullName)"
Write-Verbose "VenvExecDir Name: '$($VenvExecDir.Name)"
# Set values required in priority: CmdLine, ConfigFile, Default
# First, get the location of the virtual environment, it might not be
# VenvExecDir if specified on the command line.
if ($VenvDir) {
Write-Verbose "VenvDir given as parameter, using '$VenvDir' to determine values"
}
else {
Write-Verbose "VenvDir not given as a parameter, using parent directory name as VenvDir."
$VenvDir = $VenvExecDir.Parent.FullName.TrimEnd("\\/")
Write-Verbose "VenvDir=$VenvDir"
}
# Next, read the `pyvenv.cfg` file to determine any required value such
# as `prompt`.
$pyvenvCfg = Get-PyVenvConfig -ConfigDir $VenvDir
# Next, set the prompt from the command line, or the config file, or
# just use the name of the virtual environment folder.
if ($Prompt) {
Write-Verbose "Prompt specified as argument, using '$Prompt'"
}
else {
Write-Verbose "Prompt not specified as argument to script, checking pyvenv.cfg value"
if ($pyvenvCfg -and $pyvenvCfg['prompt']) {
Write-Verbose " Setting based on value in pyvenv.cfg='$($pyvenvCfg['prompt'])'"
$Prompt = $pyvenvCfg['prompt'];
}
else {
Write-Verbose " Setting prompt based on parent's directory's name. (Is the directory name passed to venv module when creating the virtual environment)"
Write-Verbose " Got leaf-name of $VenvDir='$(Split-Path -Path $venvDir -Leaf)'"
$Prompt = Split-Path -Path $venvDir -Leaf
}
}
Write-Verbose "Prompt = '$Prompt'"
Write-Verbose "VenvDir='$VenvDir'"
# Deactivate any currently active virtual environment, but leave the
# deactivate function in place.
deactivate -nondestructive
# Now set the environment variable VIRTUAL_ENV, used by many tools to determine
# that there is an activated venv.
$env:VIRTUAL_ENV = $VenvDir
if (-not $Env:VIRTUAL_ENV_DISABLE_PROMPT) {
Write-Verbose "Setting prompt to '$Prompt'"
# Set the prompt to include the env name
# Make sure _OLD_VIRTUAL_PROMPT is global
function global:_OLD_VIRTUAL_PROMPT { "" }
Copy-Item -Path function:prompt -Destination function:_OLD_VIRTUAL_PROMPT
New-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Description "Python virtual environment prompt prefix" -Scope Global -Option ReadOnly -Visibility Public -Value $Prompt
function global:prompt {
Write-Host -NoNewline -ForegroundColor Green "($_PYTHON_VENV_PROMPT_PREFIX) "
_OLD_VIRTUAL_PROMPT
}
$env:VIRTUAL_ENV_PROMPT = $Prompt
}
# Clear PYTHONHOME
if (Test-Path -Path Env:PYTHONHOME) {
Copy-Item -Path Env:PYTHONHOME -Destination Env:_OLD_VIRTUAL_PYTHONHOME
Remove-Item -Path Env:PYTHONHOME
}
# Add the venv to the PATH
Copy-Item -Path Env:PATH -Destination Env:_OLD_VIRTUAL_PATH
$Env:PATH = "$VenvExecDir$([System.IO.Path]::PathSeparator)$Env:PATH"

View File

@@ -0,0 +1,69 @@
# This file must be used with "source bin/activate" *from bash*
# you cannot run it directly
deactivate () {
# reset old environment variables
if [ -n "${_OLD_VIRTUAL_PATH:-}" ] ; then
PATH="${_OLD_VIRTUAL_PATH:-}"
export PATH
unset _OLD_VIRTUAL_PATH
fi
if [ -n "${_OLD_VIRTUAL_PYTHONHOME:-}" ] ; then
PYTHONHOME="${_OLD_VIRTUAL_PYTHONHOME:-}"
export PYTHONHOME
unset _OLD_VIRTUAL_PYTHONHOME
fi
# This should detect bash and zsh, which have a hash command that must
# be called to get it to forget past commands. Without forgetting
# past commands the $PATH changes we made may not be respected
if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then
hash -r 2> /dev/null
fi
if [ -n "${_OLD_VIRTUAL_PS1:-}" ] ; then
PS1="${_OLD_VIRTUAL_PS1:-}"
export PS1
unset _OLD_VIRTUAL_PS1
fi
unset VIRTUAL_ENV
unset VIRTUAL_ENV_PROMPT
if [ ! "${1:-}" = "nondestructive" ] ; then
# Self destruct!
unset -f deactivate
fi
}
# unset irrelevant variables
deactivate nondestructive
VIRTUAL_ENV="/home/midas/codeWS/Projects/MidasTechnologiesINC/MidasEngine/src/MidasAgent/PPO/wrappervenv"
export VIRTUAL_ENV
_OLD_VIRTUAL_PATH="$PATH"
PATH="$VIRTUAL_ENV/bin:$PATH"
export PATH
# unset PYTHONHOME if set
# this will fail if PYTHONHOME is set to the empty string (which is bad anyway)
# could use `if (set -u; : $PYTHONHOME) ;` in bash
if [ -n "${PYTHONHOME:-}" ] ; then
_OLD_VIRTUAL_PYTHONHOME="${PYTHONHOME:-}"
unset PYTHONHOME
fi
if [ -z "${VIRTUAL_ENV_DISABLE_PROMPT:-}" ] ; then
_OLD_VIRTUAL_PS1="${PS1:-}"
PS1="(wrappervenv) ${PS1:-}"
export PS1
VIRTUAL_ENV_PROMPT="(wrappervenv) "
export VIRTUAL_ENV_PROMPT
fi
# This should detect bash and zsh, which have a hash command that must
# be called to get it to forget past commands. Without forgetting
# past commands the $PATH changes we made may not be respected
if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then
hash -r 2> /dev/null
fi

View File

@@ -0,0 +1,26 @@
# This file must be used with "source bin/activate.csh" *from csh*.
# You cannot run it directly.
# Created by Davide Di Blasi <davidedb@gmail.com>.
# Ported to Python 3.3 venv by Andrew Svetlov <andrew.svetlov@gmail.com>
alias deactivate 'test $?_OLD_VIRTUAL_PATH != 0 && setenv PATH "$_OLD_VIRTUAL_PATH" && unset _OLD_VIRTUAL_PATH; rehash; test $?_OLD_VIRTUAL_PROMPT != 0 && set prompt="$_OLD_VIRTUAL_PROMPT" && unset _OLD_VIRTUAL_PROMPT; unsetenv VIRTUAL_ENV; unsetenv VIRTUAL_ENV_PROMPT; test "\!:*" != "nondestructive" && unalias deactivate'
# Unset irrelevant variables.
deactivate nondestructive
setenv VIRTUAL_ENV "/home/midas/codeWS/Projects/MidasTechnologiesINC/MidasEngine/src/MidasAgent/PPO/wrappervenv"
set _OLD_VIRTUAL_PATH="$PATH"
setenv PATH "$VIRTUAL_ENV/bin:$PATH"
set _OLD_VIRTUAL_PROMPT="$prompt"
if (! "$?VIRTUAL_ENV_DISABLE_PROMPT") then
set prompt = "(wrappervenv) $prompt"
setenv VIRTUAL_ENV_PROMPT "(wrappervenv) "
endif
alias pydoc python -m pydoc
rehash

View File

@@ -0,0 +1,69 @@
# This file must be used with "source <venv>/bin/activate.fish" *from fish*
# (https://fishshell.com/); you cannot run it directly.
function deactivate -d "Exit virtual environment and return to normal shell environment"
# reset old environment variables
if test -n "$_OLD_VIRTUAL_PATH"
set -gx PATH $_OLD_VIRTUAL_PATH
set -e _OLD_VIRTUAL_PATH
end
if test -n "$_OLD_VIRTUAL_PYTHONHOME"
set -gx PYTHONHOME $_OLD_VIRTUAL_PYTHONHOME
set -e _OLD_VIRTUAL_PYTHONHOME
end
if test -n "$_OLD_FISH_PROMPT_OVERRIDE"
set -e _OLD_FISH_PROMPT_OVERRIDE
# prevents error when using nested fish instances (Issue #93858)
if functions -q _old_fish_prompt
functions -e fish_prompt
functions -c _old_fish_prompt fish_prompt
functions -e _old_fish_prompt
end
end
set -e VIRTUAL_ENV
set -e VIRTUAL_ENV_PROMPT
if test "$argv[1]" != "nondestructive"
# Self-destruct!
functions -e deactivate
end
end
# Unset irrelevant variables.
deactivate nondestructive
set -gx VIRTUAL_ENV "/home/midas/codeWS/Projects/MidasTechnologiesINC/MidasEngine/src/MidasAgent/PPO/wrappervenv"
set -gx _OLD_VIRTUAL_PATH $PATH
set -gx PATH "$VIRTUAL_ENV/bin" $PATH
# Unset PYTHONHOME if set.
if set -q PYTHONHOME
set -gx _OLD_VIRTUAL_PYTHONHOME $PYTHONHOME
set -e PYTHONHOME
end
if test -z "$VIRTUAL_ENV_DISABLE_PROMPT"
# fish uses a function instead of an env var to generate the prompt.
# Save the current fish_prompt function as the function _old_fish_prompt.
functions -c fish_prompt _old_fish_prompt
# With the original prompt function renamed, we can override with our own.
function fish_prompt
# Save the return status of the last command.
set -l old_status $status
# Output the venv prompt; color taken from the blue of the Python logo.
printf "%s%s%s" (set_color 4B8BBE) "(wrappervenv) " (set_color normal)
# Restore the return status of the previous command.
echo "exit $old_status" | .
# Output the original/"old" prompt.
_old_fish_prompt
end
set -gx _OLD_FISH_PROMPT_OVERRIDE "$VIRTUAL_ENV"
set -gx VIRTUAL_ENV_PROMPT "(wrappervenv) "
end

View File

@@ -0,0 +1,8 @@
#!/home/midas/codeWS/Projects/MidasTechnologiesINC/MidasEngine/src/MidasAgent/PPO/wrappervenv/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from numpy.f2py.f2py2e import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

View File

@@ -0,0 +1,8 @@
#!/home/midas/codeWS/Projects/MidasTechnologiesINC/MidasEngine/src/MidasAgent/PPO/wrappervenv/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from numpy._configtool import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

View File

@@ -0,0 +1,8 @@
#!/home/midas/codeWS/Projects/MidasTechnologiesINC/MidasEngine/src/MidasAgent/PPO/wrappervenv/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

View File

@@ -0,0 +1,8 @@
#!/home/midas/codeWS/Projects/MidasTechnologiesINC/MidasEngine/src/MidasAgent/PPO/wrappervenv/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

View File

@@ -0,0 +1,8 @@
#!/home/midas/codeWS/Projects/MidasTechnologiesINC/MidasEngine/src/MidasAgent/PPO/wrappervenv/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

View File

@@ -0,0 +1 @@
python3

View File

@@ -0,0 +1 @@
/home/midas/.pyenv/versions/3.11.4/bin/python3

View File

@@ -0,0 +1 @@
python3

View File

@@ -0,0 +1,222 @@
# don't import any costly modules
import sys
import os
is_pypy = '__pypy__' in sys.builtin_module_names
def warn_distutils_present():
if 'distutils' not in sys.modules:
return
if is_pypy and sys.version_info < (3, 7):
# PyPy for 3.6 unconditionally imports distutils, so bypass the warning
# https://foss.heptapod.net/pypy/pypy/-/blob/be829135bc0d758997b3566062999ee8b23872b4/lib-python/3/site.py#L250
return
import warnings
warnings.warn(
"Distutils was imported before Setuptools, but importing Setuptools "
"also replaces the `distutils` module in `sys.modules`. This may lead "
"to undesirable behaviors or errors. To avoid these issues, avoid "
"using distutils directly, ensure that setuptools is installed in the "
"traditional way (e.g. not an editable install), and/or make sure "
"that setuptools is always imported before distutils."
)
def clear_distutils():
if 'distutils' not in sys.modules:
return
import warnings
warnings.warn("Setuptools is replacing distutils.")
mods = [
name
for name in sys.modules
if name == "distutils" or name.startswith("distutils.")
]
for name in mods:
del sys.modules[name]
def enabled():
"""
Allow selection of distutils by environment variable.
"""
which = os.environ.get('SETUPTOOLS_USE_DISTUTILS', 'local')
return which == 'local'
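# Example: exporting SETUPTOOLS_USE_DISTUTILS=stdlib makes enabled() return
# False, leaving the stdlib distutils in place instead of the local copy.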
def ensure_local_distutils():
import importlib
clear_distutils()
# With the DistutilsMetaFinder in place,
# perform an import to cause distutils to be
# loaded from setuptools._distutils. Ref #2906.
with shim():
importlib.import_module('distutils')
# check that submodules load as expected
core = importlib.import_module('distutils.core')
assert '_distutils' in core.__file__, core.__file__
assert 'setuptools._distutils.log' not in sys.modules
def do_override():
"""
Ensure that the local copy of distutils is preferred over stdlib.
See https://github.com/pypa/setuptools/issues/417#issuecomment-392298401
for more motivation.
"""
if enabled():
warn_distutils_present()
ensure_local_distutils()
class _TrivialRe:
def __init__(self, *patterns):
self._patterns = patterns
def match(self, string):
return all(pat in string for pat in self._patterns)
class DistutilsMetaFinder:
def find_spec(self, fullname, path, target=None):
# optimization: only consider top level modules and those
# found in the CPython test suite.
if path is not None and not fullname.startswith('test.'):
return
method_name = 'spec_for_{fullname}'.format(**locals())
method = getattr(self, method_name, lambda: None)
return method()
def spec_for_distutils(self):
if self.is_cpython():
return
import importlib
import importlib.abc
import importlib.util
try:
mod = importlib.import_module('setuptools._distutils')
except Exception:
# There are a couple of cases where setuptools._distutils
# may not be present:
# - An older Setuptools without a local distutils is
# taking precedence. Ref #2957.
# - Path manipulation during sitecustomize removes
# setuptools from the path but only after the hook
# has been loaded. Ref #2980.
# In either case, fall back to stdlib behavior.
return
class DistutilsLoader(importlib.abc.Loader):
def create_module(self, spec):
mod.__name__ = 'distutils'
return mod
def exec_module(self, module):
pass
return importlib.util.spec_from_loader(
'distutils', DistutilsLoader(), origin=mod.__file__
)
@staticmethod
def is_cpython():
"""
Suppress supplying distutils for CPython (build and tests).
Ref #2965 and #3007.
"""
return os.path.isfile('pybuilddir.txt')
def spec_for_pip(self):
"""
Ensure stdlib distutils when running under pip.
See pypa/pip#8761 for rationale.
"""
if self.pip_imported_during_build():
return
clear_distutils()
self.spec_for_distutils = lambda: None
@classmethod
def pip_imported_during_build(cls):
"""
Detect if pip is being imported in a build script. Ref #2355.
"""
import traceback
return any(
cls.frame_file_is_setup(frame) for frame, line in traceback.walk_stack(None)
)
@staticmethod
def frame_file_is_setup(frame):
"""
Return True if the indicated frame suggests a setup.py file.
"""
# some frames may not have __file__ (#2940)
return frame.f_globals.get('__file__', '').endswith('setup.py')
def spec_for_sensitive_tests(self):
"""
Ensure stdlib distutils when running select tests under CPython.
python/cpython#91169
"""
clear_distutils()
self.spec_for_distutils = lambda: None
sensitive_tests = (
[
'test.test_distutils',
'test.test_peg_generator',
'test.test_importlib',
]
if sys.version_info < (3, 10)
else [
'test.test_distutils',
]
)
for name in DistutilsMetaFinder.sensitive_tests:
setattr(
DistutilsMetaFinder,
f'spec_for_{name}',
DistutilsMetaFinder.spec_for_sensitive_tests,
)
DISTUTILS_FINDER = DistutilsMetaFinder()
def add_shim():
DISTUTILS_FINDER in sys.meta_path or insert_shim()
class shim:
def __enter__(self):
insert_shim()
def __exit__(self, exc, value, tb):
remove_shim()
def insert_shim():
sys.meta_path.insert(0, DISTUTILS_FINDER)
def remove_shim():
try:
sys.meta_path.remove(DISTUTILS_FINDER)
except ValueError:
pass

View File

@@ -0,0 +1 @@
__import__('_distutils_hack').do_override()

View File

@@ -0,0 +1 @@
import os; var = 'SETUPTOOLS_USE_DISTUTILS'; enabled = os.environ.get(var, 'local') == 'local'; enabled and __import__('_distutils_hack').add_shim();

View File

@@ -0,0 +1,25 @@
BSD 2-Clause License
Copyright (c) 2023, Ewald de Wit
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -0,0 +1,210 @@
Metadata-Version: 2.1
Name: eventkit
Version: 1.0.3
Summary: Event-driven data pipelines
Home-page: https://github.com/erdewit/eventkit
Author: Ewald R. de Wit
Author-email: ewald.de.wit@gmail.com
License: BSD
Keywords: python asyncio event driven data pipelines
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: BSD License
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3 :: Only
License-File: LICENSE
Requires-Dist: numpy
|Build| |PyVersion| |Status| |PyPiVersion| |License| |Docs|
Introduction
------------
The primary use cases of eventkit are
* to send events between loosely coupled components;
* to compose all kinds of event-driven data pipelines.
The interface is kept as Pythonic as possible,
with familiar names from Python and its libraries where possible.
For scheduling asyncio is used and there is seamless integration with it.
See the examples and the
`introduction notebook <https://github.com/erdewit/eventkit/tree/master/notebooks/eventkit_introduction.ipynb>`_
to get a true feel for the possibilities.
Installation
------------
::
pip3 install eventkit
Python_ version 3.6 or higher is required.
Examples
--------
**Create an event and connect two listeners**
.. code-block:: python
import eventkit as ev
def f(a, b):
print(a * b)
def g(a, b):
print(a / b)
event = ev.Event()
event += f
event += g
event.emit(10, 5)
**Create a simple pipeline**
.. code-block:: python
import eventkit as ev
event = (
ev.Sequence('abcde')
.map(str.upper)
.enumerate()
)
print(event.run()) # in Jupyter: await event.list()
Output::
[(0, 'A'), (1, 'B'), (2, 'C'), (3, 'D'), (4, 'E')]
**Create a pipeline to get a running average and standard deviation**
.. code-block:: python
import random
import eventkit as ev
source = ev.Range(1000).map(lambda i: random.gauss(0, 1))
event = source.array(500)[ev.ArrayMean, ev.ArrayStd].zip()
print(event.last().run()) # in Jupyter: await event.last()
Output::
[(0.00790957852672618, 1.0345673260655333)]
**Combine async iterators together**
.. code-block:: python
import asyncio
import eventkit as ev
async def ait(r):
for i in r:
await asyncio.sleep(0.1)
yield i
async def main():
async for t in ev.Zip(ait('XYZ'), ait('123')):
print(t)
asyncio.get_event_loop().run_until_complete(main()) # in Jupyter: await main()
Output::
('X', '1')
('Y', '2')
('Z', '3')
**Real-time video analysis pipeline**
.. code-block:: python
self.video = VideoStream(conf.CAM_ID)
scene = self.video | FaceTracker | SceneAnalyzer
lastScene = scene.aiter(skip_to_last=True)
async for frame, persons in lastScene:
...
`Full source code <https://github.com/erdewit/heartwave/blob/100e1a89d18756e141f9dcfbb73c55a1009debf4/heartwave/app.py#L88>`_
Distributed computing
---------------------
The `distex <https://github.com/erdewit/distex>`_ library provides a
``poolmap`` extension method to put multiple cores or machines to use:
.. code-block:: python
from distex import Pool
import eventkit as ev
import bz2
pool = Pool()
# await pool # un-comment in Jupyter
data = [b'A' * 1000000] * 1000
pipe = ev.Sequence(data).poolmap(pool, bz2.compress).map(len).mean().last()
print(pipe.run()) # in Jupyter: print(await pipe)
pool.shutdown()
Inspired by:
------------
* `Qt Signals & Slots <https://doc.qt.io/qt-5/signalsandslots.html>`_
* `itertools <https://docs.python.org/3/library/itertools.html>`_
* `aiostream <https://github.com/vxgmichel/aiostream>`_
* `Bacon <https://baconjs.github.io/index.html>`_
* `aioreactive <https://github.com/dbrattli/aioreactive>`_
* `Reactive extensions <http://reactivex.io/documentation/operators.html>`_
* `underscore.js <https://underscorejs.org>`_
* `.NET Events <https://docs.microsoft.com/en-us/dotnet/standard/events>`_
Documentation
-------------
The complete `API documentation <https://eventkit.readthedocs.io/en/latest/api.html>`_.
.. _Python: http://www.python.org
.. _`Interactive Brokers Python API`: http://interactivebrokers.github.io
.. |Build| image:: https://github.com/erdewit/eventkit/actions/workflows/test.yml/badge.svg?branch=master
:alt: Build
:target: https://github.com/erdewit/eventkit/actions
.. |PyPiVersion| image:: https://img.shields.io/pypi/v/eventkit.svg
:alt: PyPi
:target: https://pypi.python.org/pypi/eventkit
.. |PyVersion| image:: https://img.shields.io/badge/python-3.6+-blue.svg
:alt:
.. |Status| image:: https://img.shields.io/badge/status-stable-green.svg
:alt:
.. |License| image:: https://img.shields.io/badge/license-BSD-blue.svg
:alt:
.. |Docs| image:: https://readthedocs.org/projects/eventkit/badge/?version=latest
:alt: Documentation
:target: https://eventkit.readthedocs.io

View File

@@ -0,0 +1,50 @@
eventkit-1.0.3.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
eventkit-1.0.3.dist-info/LICENSE,sha256=AQjFGoH_Hjo3QMlS4NcbQdSPkhQtqRRJAX5exgWZZsc,1317
eventkit-1.0.3.dist-info/METADATA,sha256=5sTsyrTbHL4M_iikr626sWVt36trzZwbh1dV4WLWm1g,5441
eventkit-1.0.3.dist-info/RECORD,,
eventkit-1.0.3.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
eventkit-1.0.3.dist-info/top_level.txt,sha256=F6uIbfpq0pwfdya0QaNm3KncSabwqY3zCfX_Sy5K8NQ,15
eventkit/__init__.py,sha256=huK8A1TNx_ofW4m7mbQayY7vPclhO9NIZgWKSAqaCZw,998
eventkit/__pycache__/__init__.cpython-311.pyc,,
eventkit/__pycache__/event.cpython-311.pyc,,
eventkit/__pycache__/util.cpython-311.pyc,,
eventkit/__pycache__/version.cpython-311.pyc,,
eventkit/event.py,sha256=cFkqeOqe9Bzu7bUfW-NVVsSmIi46c4xC1MH_mEPkS2Q,41870
eventkit/ops/__init__.py,sha256=IepLb1S6t8pGGzDt9bpYG_CFfE4uBxPcFtuXlhdvGLE,23
eventkit/ops/__pycache__/__init__.cpython-311.pyc,,
eventkit/ops/__pycache__/aggregate.cpython-311.pyc,,
eventkit/ops/__pycache__/array.cpython-311.pyc,,
eventkit/ops/__pycache__/combine.cpython-311.pyc,,
eventkit/ops/__pycache__/create.cpython-311.pyc,,
eventkit/ops/__pycache__/misc.cpython-311.pyc,,
eventkit/ops/__pycache__/op.cpython-311.pyc,,
eventkit/ops/__pycache__/select.cpython-311.pyc,,
eventkit/ops/__pycache__/timing.cpython-311.pyc,,
eventkit/ops/__pycache__/transform.cpython-311.pyc,,
eventkit/ops/aggregate.py,sha256=6pWaPeG9xxaYvLfa3QVrTyRBCwpjPRCTTFdSUhUdJnw,3875
eventkit/ops/array.py,sha256=kiHny-DpR_68_fGur7UtLT2zs2UXj2zRQkWVgKa84RY,2424
eventkit/ops/combine.py,sha256=NOGLZqsoePxDvt90L26JpbUYkoWIgwQm_2eePh6I8KE,9242
eventkit/ops/create.py,sha256=VRVQrA_GM8HD8mIukUQE7IHLPyfGar_ghAbfRozBTwY,3302
eventkit/ops/misc.py,sha256=yQ_FmyVXYUpFQsyvHdgmI11T75UoHQzou-L9Vr-fkOk,587
eventkit/ops/op.py,sha256=uwjmTZY-cb0FpoVK3Ey31eSh2neTbbALGQAy8Vbo2k0,1711
eventkit/ops/select.py,sha256=i5AEzx_xDoGgSmhHeJbdUQWom28mGnCVIYCpuTyVslk,3461
eventkit/ops/timing.py,sha256=Cpc-jlzqLKub3AUkdKNi8s6Jry7OBzENVTGmlNMjuVY,6159
eventkit/ops/transform.py,sha256=GCJm0UkFjE0TRNFl4IlFUjOQ1DLS5veGdzDMibERRhk,8994
eventkit/util.py,sha256=COHtEwRtXm_H51UAbB4CxtRPaTylaDzWWLqAbKfF0sI,2134
eventkit/version.py,sha256=S8gNTliBVeYLnV8E5UQ8c_2kz3gycGdX9tqY_Ec8H2s,86
tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
tests/__pycache__/__init__.cpython-311.pyc,,
tests/__pycache__/aggregate_test.cpython-311.pyc,,
tests/__pycache__/combine_test.cpython-311.pyc,,
tests/__pycache__/create_test.cpython-311.pyc,,
tests/__pycache__/event_test.cpython-311.pyc,,
tests/__pycache__/select_test.cpython-311.pyc,,
tests/__pycache__/timing_test.cpython-311.pyc,,
tests/__pycache__/transform_test.cpython-311.pyc,,
tests/aggregate_test.py,sha256=31sQbR0EKbYoYVwFLWS5Aj6Cb0FUEvdsMfcXkq2XTGM,1654
tests/combine_test.py,sha256=c3_tB8c10AyuoSJE2MnkUt-VwwiEzGQrEVKM28ah0Tg,1904
tests/create_test.py,sha256=eduKfDW3b5hLnpqHRofs2MZkHMdTtVi3R0Fxr_MirpI,800
tests/event_test.py,sha256=DbuLAOXsDgeykb2A5l0__a_yLEYzAWgwr0Mxeo6uQf8,4059
tests/select_test.py,sha256=ZmfAoJdU3hccKzks3KjeIxrIMfzpU9mQs2tSZZGP-RU,1276
tests/timing_test.py,sha256=YPGkdLP85FvtH6kdib_mynrqDBbGrQbd3uwCRp0_24M,1609
tests/transform_test.py,sha256=89_-WWJTORiiGr90mN5L4m9eOwb7m29RhkwQKJFwLtM,5375

View File

@@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: bdist_wheel (0.40.0)
Root-Is-Purelib: true
Tag: py3-none-any

View File

@@ -0,0 +1,23 @@
"""Event-driven data pipelines."""
from .event import Event
from .ops.aggregate import (
All, Any, Count, Deque, Ema, List, Max, Mean, Min, Pairwise, Product,
Reduce, Sum)
from .ops.array import (
Array, ArrayAll, ArrayAny, ArrayMax, ArrayMean, ArrayMin, ArrayStd,
ArraySum)
from .ops.combine import (
AddableJoinOp, Chain, Concat, Fork, Merge, Switch, Zip, Ziplatest)
from .ops.create import (
Aiterate, Marble, Range, Repeat, Sequence, Timer, Timerange, Wait)
from .ops.misc import EndOnError, Errors
from .ops.op import Op
from .ops.select import (
Changes, DropWhile, Filter, Last, Skip, Take, TakeUntil, TakeWhile, Unique)
from .ops.timing import (Debounce, Delay, Sample, Throttle, Timeout)
from .ops.transform import (
Chainmap, Chunk, ChunkWith, Concatmap, Constant, Copy, Deepcopy, Emap,
Enumerate, Iterate, Map, Mergemap, Pack, Partial, PartialRight, Pluck,
Previous, Star, Switchmap, Timestamp)
from .version import __version__, __version_info__

File diff suppressed because it is too large.

View File

@@ -0,0 +1 @@
"""Event operators."""

View File

@@ -0,0 +1,159 @@
import itertools
import operator
from collections import deque
from .op import Op
from .transform import Iterate
from ..util import NO_VALUE
class Count(Iterate):
__slots__ = ()
def __init__(self, start=0, step=1, source=None):
it = itertools.count(start, step)
Iterate.__init__(self, it, source)
class Reduce(Op):
__slots__ = ('_func', '_initializer', '_prev')
def __init__(self, func, initializer=NO_VALUE, source=None):
Op.__init__(self, source)
self._func = func
self._initializer = initializer
self._prev = NO_VALUE
def on_source(self, arg):
if self._prev is NO_VALUE:
if self._initializer is NO_VALUE:
self._prev = arg
else:
self._prev = self._func(self._initializer, arg)
self.emit(self._prev)
else:
self._prev = self._func(self._prev, arg)
self.emit(self._prev)
class Min(Reduce):
__slots__ = ()
def __init__(self, source=None):
Reduce.__init__(self, min, float('inf'), source)
class Max(Reduce):
__slots__ = ()
def __init__(self, source=None):
Reduce.__init__(self, max, -float('inf'), source)
class Sum(Reduce):
__slots__ = ()
def __init__(self, start=0, source=None):
Reduce.__init__(self, operator.add, start, source)
class Product(Reduce):
__slots__ = ()
def __init__(self, start=1, source=None):
Reduce.__init__(self, operator.mul, start, source)
class Mean(Op):
__slots__ = ('_count', '_sum')
def __init__(self, source=None):
Op.__init__(self, source)
self._count = 0
self._sum = 0
def on_source(self, arg):
self._count += 1
self._sum += arg
self.emit(self._sum / self._count)
class Any(Reduce):
__slots__ = ()
def __init__(self, source=None):
Reduce.__init__(self, lambda prev, v: prev or bool(v), False, source)
class All(Reduce):
__slots__ = ()
def __init__(self, source=None):
Reduce.__init__(self, lambda prev, v: prev and bool(v), True, source)
class Ema(Op):
__slots__ = ('_f1', '_f2', '_prev')
def __init__(self, n=None, weight=None, source=None):
Op.__init__(self, source)
self._f1 = weight or 2.0 / (n + 1)
self._f2 = 1 - self._f1
self._prev = NO_VALUE
def on_source(self, *args):
if self._prev is NO_VALUE:
value = args
else:
value = [
self._f2 * p + self._f1 * a for p, a in zip(self._prev, args)]
self._prev = value
self.emit(*value)
class Pairwise(Op):
__slots__ = ('_prev', '_has_prev')
def __init__(self, source=None):
Op.__init__(self, source)
self._has_prev = False
def on_source(self, *args):
value = args[0] if len(args) == 1 else args if args else NO_VALUE
if self._has_prev:
self.emit(self._prev, value)
else:
self._has_prev = True
self._prev = value
class List(Op):
__slots__ = ('_values')
def __init__(self, source=None):
Op.__init__(self, source)
self._values = []
def on_source(self, *args):
self._values.append(
args[0] if len(args) == 1 else args if args else NO_VALUE)
def on_source_done(self, source):
self.emit(self._values)
Op.on_source_done(self, source)
class Deque(Op):
__slots__ = ('_count', '_q')
def __init__(self, count, source=None):
Op.__init__(self, source)
self._count = count
self._q = deque()
def on_source(self, *args):
self._q.append(
args[0] if len(args) == 1 else args if args else NO_VALUE)
if self._count and len(self._q) > self._count:
self._q.popleft()
self.emit(self._q)
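# Usage sketch, following the fluent style shown in the package README:
# ev.Sequence([1, 2, 3, 4]).mean().last().run() -- Mean emits the running
# means 1, 1.5, 2, 2.5; Last keeps only the final 2.5, so run() returns [2.5].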

View File

@@ -0,0 +1,126 @@
from collections import deque
import numpy as np
from .op import Op
from ..util import NO_VALUE
class Array(Op):
__slots__ = ('_count', '_q')
def __init__(self, count, source=None):
Op.__init__(self, source)
self._count = count
self._q = deque()
def on_source(self, *args):
self._q.append(
args[0] if len(args) == 1 else args if args else NO_VALUE)
if self._count and len(self._q) > self._count:
self._q.popleft()
self.emit(np.asarray(self._q))
def min(self) -> "ArrayMin": # type: ignore
"""
Minimum value.
"""
return ArrayMin(self)
def max(self) -> "ArrayMax": # type: ignore
"""
Maximum value.
"""
return ArrayMax(self)
def sum(self) -> "ArraySum": # type: ignore
"""
Summation.
"""
return ArraySum(self)
def prod(self) -> "ArrayProd":
"""
Product.
"""
return ArrayProd(self)
def mean(self) -> "ArrayMean": # type: ignore
"""
Mean value.
"""
return ArrayMean(self)
def std(self) -> "ArrayStd": # type: ignore
"""
Sample standard deviation.
"""
return ArrayStd(self)
def any(self) -> "ArrayAny": # type: ignore
"""
Test if any array value is true.
"""
return ArrayAny(self)
def all(self) -> "ArrayAll": # type: ignore
"""
Test if all array values are true.
"""
return ArrayAll(self)
class ArrayMin(Op):
__slots__ = ()
def on_source(self, arg):
self.emit(arg.min())
class ArrayMax(Op):
__slots__ = ()
def on_source(self, arg):
self.emit(arg.max())
class ArraySum(Op):
__slots__ = ()
def on_source(self, arg):
self.emit(arg.sum())
class ArrayProd(Op):
__slots__ = ()
def on_source(self, arg):
self.emit(arg.prod())
class ArrayMean(Op):
__slots__ = ()
def on_source(self, arg):
self.emit(arg.mean())
class ArrayStd(Op):
__slots__ = ()
def on_source(self, arg):
self.emit(arg.std(ddof=1) if len(arg) > 1 else np.nan)
class ArrayAny(Op):
__slots__ = ()
def on_source(self, arg):
self.emit(arg.any())
class ArrayAll(Op):
__slots__ = ()
def on_source(self, arg):
self.emit(arg.all())
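# Usage sketch, mirroring the README example: source.array(500) keeps a
# sliding window of up to 500 values, and source.array(500)[ev.ArrayMean,
# ev.ArrayStd].zip() emits (mean, std) tuples over that window.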

View File

@@ -0,0 +1,302 @@
import functools
from collections import defaultdict, deque
from typing import Deque, Optional
from .op import Op
from ..event import Event
from ..util import NO_VALUE
class Fork(list):
__slots__ = ()
def __init__(self):
list.__init__(self)
def join(self, joiner: "JoinOp"):
joiner._set_sources(*self)
self.clear()
return joiner
def concat(self) -> "Concat":
return self.join(Concat())
def merge(self) -> "Merge":
return self.join(Merge())
def switch(self) -> "Switch":
return self.join(Switch())
def zip(self) -> "Zip":
return self.join(Zip())
def ziplatest(self) -> "Ziplatest":
return self.join(Ziplatest())
def chain(self) -> "Chain":
return self.join(Chain())
class JoinOp(Op):
"""
Base class for join operators that combine the emits
from multiple source events.
"""
__slots__ = ('_sources',)
_sources: Deque[Event]
    def _set_sources(self, *sources):
raise NotImplementedError
class AddableJoinOp(JoinOp):
"""
Base class for join operators where new sources, produced by a
parent higher-order event, can be added dynamically.
"""
__slots__ = ('_parent',)
_parent: Optional[Event]
def __init__(self, *sources: Event):
JoinOp.__init__(self)
self._sources = deque()
self._parent = None
self._set_sources(*sources)
def _set_sources(self, *sources):
for source in sources:
source = Event.create(source)
self.add_source(source)
def add_source(self, source):
# note: the same source can be added multiple times
raise NotImplementedError
def set_parent(self, parent: Event):
self._parent = parent
if parent.done_event:
parent.done_event += self._on_parent_done
def on_source_done(self, source):
self._disconnect_from(source)
self._sources.remove(source)
if not self._sources and self._parent is None:
self.set_done()
def _on_parent_done(self, parent):
        parent.done_event -= self._on_parent_done
self._parent = None
if not self._sources:
self.set_done()
class Merge(AddableJoinOp):
__slots__ = ()
def add_source(self, source):
self._sources.append(source)
self._connect_from(source)
class Switch(AddableJoinOp):
__slots__ = ('_source2cb', '_active_source')
def __init__(self, *sources):
AddableJoinOp.__init__(self)
self._source2cb = {} # map from source to callback
self._active_source = None
self._set_sources(*sources)
def add_source(self, source):
self._sources.append(source)
cb = self._source2cb.get(source)
if not cb:
cb = functools.partial(self.on_source_s, source)
self._source2cb[source] = cb
source.connect(cb, done=self.on_source_done)
def _remove_source(self, source):
if source in self._sources:
self._sources.remove(source)
cb = self._source2cb.pop(source, None)
if cb:
source -= cb
def on_source_s(self, source, *args):
if source is not self._active_source:
self._remove_source(self._active_source)
self._active_source = source
self.emit(*args)
def on_source_done(self, source):
self._remove_source(source)
if not self._sources and self._parent is None:
self._active_source = None
self.set_done()
class Concat(AddableJoinOp):
__slots__ = ('_source2cb',)
def __init__(self, *sources):
AddableJoinOp.__init__(self)
self._source2cb = {} # map from source to callback
self._set_sources(*sources)
def add_source(self, source):
if source in self._sources:
return
self._sources.append(source)
cb = self._source2cb.get(source)
if not cb:
cb = functools.partial(self._on_source_s, source)
self._source2cb[source] = cb
source.connect(cb, done=self.on_source_done)
def _on_source_s(self, source, *args):
while self._sources and self._sources[0] is not source:
s = self._sources.popleft()
cb = self._source2cb.pop(s, None)
if cb:
s.disconnect(cb, done=self.on_source_done)
self.emit(*args)
def on_source_done(self, source):
cb = self._source2cb.pop(source)
source.disconnect(cb, done=self.on_source_done)
while source in self._sources:
self._sources.remove(source)
if not self._sources and self._parent is None:
self.set_done()
class Chain(AddableJoinOp):
__slots__ = ('_qq', '_source2cbs')
def __init__(self, *sources):
AddableJoinOp.__init__(self)
self._qq = deque()
self._source2cbs = defaultdict(list) # map from source to callbacks
self._set_sources(*sources)
def add_source(self, source):
if not self._sources:
self._connect_from(source)
else:
def cb(*args):
q.append(args)
q = deque()
self._qq.append(q)
source += cb
self._source2cbs[source].append(cb)
self._sources.append(source)
def on_source_done(self, source):
if source is not self._sources[0]:
return
self._disconnect_from(source)
self._sources.popleft()
while self._sources:
source = self._sources[0]
q = self._qq.popleft()
for args in q:
self.emit(*args)
for cb in self._source2cbs.pop(source, []):
source -= cb
if source.done():
self._sources.popleft()
continue
self._connect_from(source)
return
if not self._sources and self._parent is None:
self.set_done()
class Zip(JoinOp):
__slots__ = ('_results', '_source2cbs', '_num_ready')
def __init__(self, *sources):
JoinOp.__init__(self)
self._num_ready = 0 # number of sources with a pending result
self._source2cbs = defaultdict(list) # map from source to callbacks
if sources:
self._set_sources(*sources)
def _set_sources(self, *sources):
self._sources = deque(Event.create(s) for s in sources)
if any(s.done() for s in self._sources):
self.set_done()
return
self._results = [deque() for _ in self._sources]
for i, source in enumerate(self._sources):
cb = functools.partial(self._on_source_i, i)
source.connect(cb, self.on_source_error, self.on_source_done)
self._source2cbs[source].append(cb)
def _on_source_i(self, i, *args):
q = self._results[i]
if not q:
self._num_ready += 1
ready = self._num_ready == len(self._results)
else:
ready = False
q.append(args[0] if len(args) == 1 else args if args else NO_VALUE)
if ready:
tup = tuple(q.popleft() for q in self._results)
self._num_ready = sum(bool(q) for q in self._results)
self.emit(*tup)
def on_source_done(self, source):
self._sources.remove(source)
if not self._sources:
for source, cbs in self._source2cbs.items():
for cb in cbs:
source.disconnect(
cb, self.on_source_error, self.on_source_done)
self._source2cbs = None
self.set_done()
class Ziplatest(JoinOp):
__slots__ = ('_values', '_is_primed', '_source2cbs')
def __init__(self, *sources, partial=True):
JoinOp.__init__(self)
self._is_primed = partial
self._source2cbs = defaultdict(list) # map from source to callbacks
if sources:
self._set_sources(*sources)
def _set_sources(self, *sources):
sources = [Event.create(s) for s in sources]
self._sources = deque(s for s in sources if not s.done())
if not self._sources:
self.set_done()
return
self._values = [s.value() for s in sources]
for i, source in enumerate(self._sources):
cb = functools.partial(self._on_source_i, i)
source.connect(cb, self.on_source_error, self.on_source_done)
self._source2cbs[source].append(cb)
def _on_source_i(self, i, *args):
self._values[i] = \
args[0] if len(args) == 1 else args if args else NO_VALUE
if not self._is_primed:
self._is_primed = not any(r is NO_VALUE for r in self._values)
if self._is_primed:
self.emit(*self._values)
def on_source_done(self, source):
self._sources.remove(source)
if not self._sources:
for source, cbs in self._source2cbs.items():
for cb in cbs:
source.disconnect(
cb, self.on_source_error, self.on_source_done)
self._source2cbs = None
self.set_done()
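# A hypothetical sketch of Zip pairing two sources positionally, assuming
# Event and Zip from this package. Values queue per source until every
# source has a pending result:
a, b = Event(), Event()
zipped = Zip(a, b)
zipped += print
a.emit(1)
a.emit(2)        # queued until b catches up
b.emit('x')      # -> 1 x
b.emit('y')      # -> 2 y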

View File

@@ -0,0 +1,123 @@
import asyncio
import itertools
import time
from .op import Op
from ..event import Event
from ..util import NO_VALUE, get_event_loop, timerange
class Wait(Event):
__slots__ = ('_task',)
def __init__(self, future, name='wait'):
Event.__init__(self, name)
if future.done():
self._task = None
self.set_done()
else:
loop = get_event_loop()
self._task = asyncio.ensure_future(future, loop=loop)
            self._task.add_done_callback(self._on_task_done)
def _on_task_done(self, task):
try:
result = task.result()
except Exception as error:
result = NO_VALUE
self.error_event.emit(self, error)
self.emit(result)
self._task = None
self.set_done()
def __del__(self):
if self._task:
self._task.cancel()
class Aiterate(Event):
__slots__ = ('_task',)
def __init__(self, ait):
Event.__init__(self, ait.__qualname__)
loop = get_event_loop()
self._task = asyncio.ensure_future(self._looper(ait), loop=loop)
async def _looper(self, ait):
try:
async for args in ait:
self.emit(args)
except Exception as error:
self.error_event.emit(self, error)
self._task = None
self.set_done()
def __del__(self):
if self._task:
self._task.cancel()
class Sequence(Aiterate):
__slots__ = ()
def __init__(self, values, interval=0, times=None):
async def sequence():
t0 = time.time()
if times is not None:
for t, value in zip(times, values):
                    delay = max(0, t0 + t - time.time())
await asyncio.sleep(delay)
yield value
else:
for i, value in enumerate(values):
delay = max(0, i * interval + t0 - time.time())
await asyncio.sleep(delay)
yield value
Aiterate.__init__(self, sequence())
class Repeat(Sequence):
__slots__ = ()
def __init__(self, value, count, interval=0, times=None):
        Sequence.__init__(self, itertools.repeat(value, count), interval, times)
class Range(Sequence):
__slots__ = ()
def __init__(self, *args, interval=0, times=None):
Sequence.__init__(self, range(*args), interval, times)
class Timerange(Aiterate):
__slots__ = ()
def __init__(self, start=0, end=None, step=1):
Aiterate.__init__(self, timerange(start, end, step))
class Timer(Aiterate):
__slots__ = ()
def __init__(self, interval, count=None):
async def timer():
t0 = time.time()
i = 0
while count is None or i < count:
i += 1
delay = i * interval + t0 - time.time()
await asyncio.sleep(delay)
yield i * interval
Aiterate.__init__(self, timer())
class Marble(Op):
__slots__ = ()
def __init__(self, s, interval=0, times=None):
s = s.replace('_', '')
source = Event.sequence(s, interval, times) \
.filter(lambda c: c not in '- ') \
.takewhile(lambda c: c != '|')
Op.__init__(self, source)
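# A hypothetical usage sketch, assuming Sequence from this package; the
# creation ops schedule themselves on the asyncio loop, so a running loop
# is required:
import asyncio

async def main():
    seq = Sequence('abc', interval=0.1)   # emits 'a', 'b', 'c', 0.1 s apart
    seq += print
    await asyncio.sleep(0.5)              # give the scheduled emits time to run

asyncio.run(main())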

View File

@@ -0,0 +1,26 @@
from .op import Op
from ..event import Event
class Errors(Event):
__slots__ = ('_source',)
def __init__(self, source=None):
Event.__init__(self)
self._source = source
        if source is not None:
            if source.done():
                self.set_done()
            else:
                source.error_event += self.emit
class EndOnError(Op):
__slots__ = ()
def __init__(self, source=None):
Op.__init__(self, source)
    def on_source_error(self, source, error):
        self._disconnect_from(source)
        self.error_event.emit(source, error)
self.set_done()
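# A hypothetical sketch of Errors, assuming Event from this package.
# Errors simply re-emits whatever the source's error_event emits, here a
# (source, error) pair:
src = Event()
errors = Errors(src)
errors += print
src.error_event.emit(src, ValueError('boom'))   # -> src ValueError('boom')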

View File

@@ -0,0 +1,63 @@
from typing import Union
from ..event import Event
class Op(Event):
"""
Base functionality for operators.
The Observer pattern is implemented by the following three methods::
on_source(self, *args)
on_source_error(self, source, error)
on_source_done(self, source)
The default handlers will pass along source emits, errors and done events.
This makes ``Op`` also suitable as an identity operator.
"""
__slots__ = ()
def __init__(self, source: Union[Event, None] = None):
Event.__init__(self)
if source is not None:
self.set_source(source)
on_source = Event.emit
def on_source_error(self, source, error):
if len(self.error_event):
self.error_event.emit(source, error)
else:
Event.logger.exception(error)
def on_source_done(self, _source):
if self._source is not None:
self._disconnect_from(self._source)
self._source = None
self.set_done()
def set_source(self, source):
source = Event.create(source)
if self._source is None:
self._source = source
self._connect_from(source)
else:
self._source.set_source(source)
def _connect_from(self, source: Event):
if source.done():
self.set_done()
else:
source.connect(
self.on_source,
self.on_source_error,
self.on_source_done,
keep_ref=True)
def _disconnect_from(self, source: Event):
source.disconnect(
self.on_source,
self.on_source_error,
self.on_source_done)
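# A hypothetical sketch of the default behavior, assuming Event and Op
# from this package. A bare Op is an identity operator: on_source is
# Event.emit, and done status propagates through on_source_done:
src = Event()
ident = Op(src)
ident += print
src.emit(42)      # -> 42, passed through unchanged
src.set_done()    # ident is marked done as well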

View File

@@ -0,0 +1,145 @@
from .op import Op
from ..util import NO_VALUE
class Filter(Op):
__slots__ = ('_predicate',)
def __init__(self, predicate=bool, source=None):
Op.__init__(self, source)
self._predicate = predicate
def on_source(self, *args):
if self._predicate(*args):
self.emit(*args)
class Skip(Op):
__slots__ = ('_count', '_n')
def __init__(self, count=1, source=None):
Op.__init__(self, source)
self._count = count
self._n = 0
def on_source(self, *args):
self._n += 1
if self._n == self._count:
self._source -= self.on_source
self._source += self.emit
class Take(Op):
__slots__ = ('_count', '_n')
def __init__(self, count=1, source=None):
Op.__init__(self, source)
self._count = count
self._n = 0
def on_source(self, *args):
self._n += 1
if self._n <= self._count:
self.emit(*args)
if self._n == self._count:
self._disconnect_from(self._source)
self.set_done()
class TakeWhile(Op):
__slots__ = ('_predicate',)
def __init__(self, predicate=bool, source=None):
Op.__init__(self, source)
self._predicate = predicate
def on_source(self, *args):
if self._predicate(*args):
self.emit(*args)
else:
self.set_done()
self._disconnect_from(self._source)
class DropWhile(Op):
__slots__ = ('_predicate', '_drop')
def __init__(self, predicate=lambda x: not x, source=None):
Op.__init__(self, source)
self._predicate = predicate
self._drop = True
def on_source(self, *args):
if self._drop:
self._drop = self._predicate(*args)
if not self._drop:
self.emit(*args)
class TakeUntil(Op):
__slots__ = ('_notifier',)
def __init__(self, notifier, source=None):
Op.__init__(self, source)
self._notifier = notifier
notifier.connect(
self._on_notifier,
self.on_source_error,
self.on_source_done)
def _on_notifier(self, *args):
self.on_source_done(self._source)
def on_source_done(self, source):
Op.on_source_done(self, self._source)
self._notifier.disconnect(
self._on_notifier,
self.on_source_error,
self.on_source_done)
self._notifier = None
class Changes(Op):
__slots__ = ('_prev',)
def __init__(self, source=None):
Op.__init__(self, source)
self._prev = NO_VALUE
def on_source(self, *args):
if args != self._prev:
self.emit(*args)
self._prev = args
class Unique(Op):
__slots__ = ('_key', '_seen')
    def __init__(self, key=None, source=None):
Op.__init__(self, source)
self._key = key
self._seen = set()
    def on_source(self, *args):
        key = args if self._key is None else self._key(*args)
        if key not in self._seen:
            self._seen.add(key)
            self.emit(*args)
class Last(Op):
__slots__ = ('_last',)
def __init__(self, source=None):
Op.__init__(self, source)
self._last = NO_VALUE
def on_source(self, *args):
self._last = args
def on_source_done(self, source):
        if self._last is not NO_VALUE:
            self.emit(*self._last)
Op.on_source_done(self, source)
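# A hypothetical sketch chaining two selection ops, assuming Event,
# Filter and Take from this package:
src = Event()
first_two_even = Take(2, source=Filter(lambda x: x % 2 == 0, source=src))
first_two_even += print
for i in range(10):
    src.emit(i)   # prints 0 and 2, then Take disconnects and is done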

View File

@@ -0,0 +1,211 @@
from collections import deque
from .op import Op
from ..event import Event
from ..util import NO_VALUE, get_event_loop
class Delay(Op):
__slots__ = ('_delay',)
def __init__(self, delay, source=None):
Op.__init__(self, source)
self._delay = delay
def on_source(self, *args):
loop = get_event_loop()
loop.call_later(self._delay, self.emit, *args)
    def on_source_error(self, source, error):
        loop = get_event_loop()
        loop.call_later(self._delay, self.error_event.emit, source, error)
def on_source_done(self, source):
if self._source is not None:
self._disconnect_from(self._source)
self._source = None
loop = get_event_loop()
loop.call_later(self._delay, self.set_done)
class Timeout(Op):
__slots__ = ('_timeout', '_handle', '_last_time')
def __init__(self, timeout, source=None):
Op.__init__(self, source)
if source is not None and source.done():
return
self._timeout = timeout
loop = get_event_loop()
self._last_time = loop.time()
self._handle = None
self._schedule()
def on_source(self, *args):
loop = get_event_loop()
self._last_time = loop.time()
def on_source_done(self, source):
self._handle.cancel()
del self._handle
Op.on_source_done(self, source)
def _schedule(self):
loop = get_event_loop()
self._handle = loop.call_at(
self._last_time + self._timeout, self._on_timer)
def _on_timer(self):
loop = get_event_loop()
if loop.time() - self._last_time > self._timeout:
self.emit()
self.set_done()
else:
self._schedule()
class Debounce(Op):
__slots__ = ('_interval', '_on_first', '_handle', '_last_time')
def __init__(self, interval, on_first=False, source=None):
Op.__init__(self, source)
self._interval = interval
self._on_first = on_first
self._last_time = -float('inf')
self._handle = None
def on_source(self, *args):
loop = get_event_loop()
time = loop.time()
delta = time - self._last_time
self._last_time = time
if self._on_first:
if delta >= self._interval:
self.emit(*args)
else:
if self._handle:
self._handle.cancel()
self._handle = loop.call_at(
time + self._interval, self._delayed_emit, *args)
def _delayed_emit(self, *args):
self._handle = None
self.emit(*args)
if self._source is None:
self.set_done()
def on_source_done(self, source):
self._disconnect_from(source)
self._source = None
if not self._handle:
self.set_done()
class Throttle(Op):
__slots__ = (
'status_event', '_maximum', '_interval', '_cost_func',
'_q', '_time_q', '_cost_q', '_is_throttling')
def __init__(self, maximum, interval, cost_func=None, source=None):
Op.__init__(self, source)
self.status_event = Event('throttle_status')
"""
Sub event that emits ``True`` when throttling starts and ``False``
when throttling ends.
"""
self._maximum = maximum
self._interval = interval
self._cost_func = cost_func
self._q = deque() # deque of (args, cost) tuples
self._time_q = deque() # deque of previous emit times
self._cost_q = deque() # deque of costs of previous emits
self._is_throttling = False
def set_limit(self, maximum, interval):
"""
Dynamically update the ``maximum`` per ``interval`` limit.
"""
self._maximum = maximum
self._interval = interval
def on_source(self, *args):
cost = self._cost_func
if cost is not None:
cost = cost(*args)
self._q.append((args, cost))
self._try_emit()
def on_source_done(self, source):
self._disconnect_from(source)
self._source = None
if not self._q:
self.set_done()
self.status_event.set_done()
def _try_emit(self):
loop = get_event_loop()
t = loop.time()
q = self._q
times = self._time_q
costs = self._cost_q
# forget old emit times
while times and t - times[0] > self._interval:
times.popleft()
costs.popleft()
# emit values while not exceeding the limit
while q:
args, cost = q[0]
if self._cost_func:
cost = self._cost_func(*args)
total_cost = cost + sum(costs)
else:
cost = None
total_cost = 1 + len(costs)
            if self._maximum and total_cost > self._maximum:
break
args, cost = q.popleft()
times.append(t)
costs.append(cost)
self.emit(*args)
# update status and schedule new emits
if q:
if not self._is_throttling:
self.status_event.emit(True)
loop.call_at(times[0] + self._interval, self._try_emit)
elif self._is_throttling:
self.status_event.emit(False)
self._is_throttling = bool(q)
if not q and self._source is None:
self.set_done()
self.status_event.set_done()
class Sample(Op):
    __slots__ = ('_timer', '_value')
    def __init__(self, timer, source=None):
        Op.__init__(self, source)
        self._timer = timer
        self._value = NO_VALUE
timer.connect(
self._on_timer,
self.on_source_error,
self.on_source_done)
def on_source(self, *args):
self._value = args
def _on_timer(self, *args):
if self._value is not NO_VALUE:
self.emit(*self._value)
def on_source_done(self, source):
Op.on_source_done(self, self._source)
self._timer.disconnect(
self._on_timer,
self.on_source_error,
self.on_source_done)
self._timer = None
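# A hypothetical sketch of Debounce, assuming Event from this package and
# a running asyncio loop, since the timing ops schedule callbacks with
# call_at/call_later:
import asyncio

async def main():
    src = Event()
    deb = Debounce(0.2, source=src)   # emit only after 0.2 s of silence
    deb += print
    for v in (1, 2, 3):
        src.emit(v)                   # rapid burst: only 3 survives
    await asyncio.sleep(0.5)

asyncio.run(main())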

View File

@@ -0,0 +1,346 @@
import asyncio
import copy
import time
from collections import deque
from .combine import Chain, Concat, Merge, Switch
from .op import Op
from ..util import NO_VALUE, get_event_loop
class Constant(Op):
__slots__ = ('_constant',)
def __init__(self, constant, source=None):
Op.__init__(self, source)
self._constant = constant
def on_source(self, *args):
self.emit(self._constant)
class Iterate(Op):
__slots__ = ('_it',)
def __init__(self, it, source=None):
Op.__init__(self, source)
self._it = iter(it)
def on_source(self, *args):
try:
value = next(self._it)
self.emit(value)
except StopIteration:
self._disconnect_from(self._source)
self.set_done()
class Enumerate(Op):
__slots__ = ('_step', '_i')
def __init__(self, start=0, step=1, source=None):
Op.__init__(self, source)
self._i = start
self._step = step
def on_source(self, *args):
self.emit(
self._i,
args[0] if len(args) == 1 else args if args else NO_VALUE)
self._i += self._step
class Timestamp(Op):
__slots__ = ()
def on_source(self, *args):
self.emit(
time.time(),
args[0] if len(args) == 1 else args if args else NO_VALUE)
class Partial(Op):
__slots__ = ('_left_args',)
def __init__(self, *left_args, source=None):
Op.__init__(self, source)
self._left_args = left_args
def on_source(self, *args):
self.emit(*(self._left_args + args))
class PartialRight(Op):
__slots__ = ('_right_args',)
def __init__(self, *right_args, source=None):
Op.__init__(self, source)
self._right_args = right_args
def on_source(self, *args):
self.emit(*(args + self._right_args))
class Star(Op):
__slots__ = ()
def on_source(self, arg):
self.emit(*arg)
class Pack(Op):
__slots__ = ()
def on_source(self, *args):
self.emit(args)
class Pluck(Op):
__slots__ = ('_selections',)
def __init__(self, *selections, source=None):
Op.__init__(self, source)
self._selections = [] # list of [arg-index, *sub-attributes]
for sel in selections:
if type(sel) is int:
s = [sel]
else:
s = sel.split('.')
if s[0].isdigit():
s[0] = int(s[0])
elif s[0] == '':
s[0] = 0
else:
s.insert(0, 0)
self._selections.append(s)
def on_source(self, *args):
values = []
for s in self._selections:
try:
value = args[s[0]]
for attr in s[1:]:
value = getattr(value, attr)
except Exception:
value = NO_VALUE
values.append(value)
self.emit(*values)
class Previous(Op):
__slots__ = ('_count', '_q')
def __init__(self, count=1, source=None):
Op.__init__(self, source)
self._count = count
self._q = deque()
def on_source(self, *args):
self._q.append(args)
if len(self._q) > self._count:
self.emit(*self._q.popleft())
class Copy(Op):
__slots__ = ()
def on_source(self, *args):
self.emit(*(copy.copy(a) for a in args))
class Deepcopy(Op):
__slots__ = ()
def on_source(self, *args):
self.emit(*copy.deepcopy(args))
class Chunk(Op):
__slots__ = ('_size', '_list')
def __init__(self, size, source=None):
Op.__init__(self, source)
self._size = size
self._list = []
def on_source(self, *args):
self._list.append(
args[0] if len(args) == 1 else args if args else NO_VALUE)
if len(self._list) == self._size:
self.emit(self._list)
self._list = []
def on_source_done(self, source):
if self._list:
self.emit(self._list)
Op.on_source_done(self, self._source)
class ChunkWith(Op):
__slots__ = ('_timer', '_list', '_emit_empty')
def __init__(self, timer, emit_empty, source=None):
Op.__init__(self, source)
self._timer = timer
self._emit_empty = emit_empty
self._list = []
timer.connect(
self._on_timer,
self.on_source_error,
self.on_source_done)
def on_source(self, *args):
self._list.append(
args[0] if len(args) == 1 else args if args else NO_VALUE)
def _on_timer(self, *args):
if self._list or self._emit_empty:
self.emit(self._list)
self._list = []
def on_source_done(self, source):
if self._list:
self.emit(self._list)
self._list = None
if self._timer is not None:
self._timer.disconnect(
self._on_timer,
self.on_source_error,
self.on_source_done)
self._timer = None
Op.on_source_done(self, self._source)
class Map(Op):
__slots__ = (
'_func', '_timeout', '_ordered', '_task_limit', '_coro_q', '_tasks')
def __init__(
self, func, timeout=0, ordered=True, task_limit=None, source=None):
Op.__init__(self, source)
if source is not None and source.done():
return
self._func = func
self._timeout = timeout
self._ordered = ordered
self._task_limit = task_limit
self._coro_q = deque()
self._tasks = deque()
def on_source(self, *args):
obj = self._func(*args)
if hasattr(obj, '__await__'):
# function returns an awaitable
if not self._task_limit or len(self._tasks) < self._task_limit:
# schedule right away
self._create_task(obj)
else:
# queue for later
self._coro_q.append(obj)
else:
# regular function returns the result directly
self.emit(obj)
def on_source_done(self, source):
if not self._tasks:
# only end when no tasks are pending
Op.on_source_done(self, self._source)
self._source = None
def _create_task(self, coro):
# schedule a task to be run
if self._timeout:
coro = asyncio.wait_for(coro, self._timeout)
loop = get_event_loop()
task = asyncio.ensure_future(coro, loop=loop)
task.add_done_callback(self._on_task_done)
self._tasks.append(task)
def _on_task_done(self, task):
# handle task result
tasks = self._tasks
if self._ordered:
while tasks and tasks[0].done():
# remove task after emitting result
task = tasks[0]
self._emit_task(task)
task = tasks.popleft()
else:
# remove task after emitting result
self._emit_task(task)
tasks.remove(task)
# schedule pending awaitables from the queue
while self._coro_q and (
not self._task_limit or len(tasks) < self._task_limit):
self._create_task(self._coro_q.popleft())
# end when source has ended with no pending tasks
if not tasks and self._source is None:
Op.on_source_done(self, self._source)
def _emit_task(self, task):
try:
result = task.result()
except Exception as error:
result = NO_VALUE
            self.error_event.emit(self, error)
self.emit(result)
class Emap(Op):
__slots__ = ('_constr', '_joiner',)
def __init__(self, constr, joiner, source=None):
Op.__init__(self, source)
self._constr = constr
self._joiner = joiner
joiner.set_parent(source)
joiner.connect(
self.emit,
self.error_event.emit,
self._on_joiner_done)
def on_source(self, *args):
obj = self._constr(*args)
event = self.create(obj)
self._joiner.add_source(event)
def on_source_done(self, source):
pass
def _on_joiner_done(self, joiner):
joiner.disconnect(
self.emit,
self.error_event.emit,
self._on_joiner_done)
self._joiner = None
self.set_done()
class Mergemap(Emap):
__slots__ = ()
def __init__(self, constr, source=None):
Emap.__init__(self, constr, Merge(), source)
class Chainmap(Emap):
__slots__ = ()
def __init__(self, constr, source=None):
Emap.__init__(self, constr, Chain(), source)
class Concatmap(Emap):
__slots__ = ()
def __init__(self, constr, source=None):
Emap.__init__(self, constr, Concat(), source)
class Switchmap(Emap):
__slots__ = ()
def __init__(self, constr, source=None):
Emap.__init__(self, constr, Switch(), source)
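# A hypothetical sketch of Map with a coroutine function, assuming Event
# and Map from this package and a running asyncio loop for the spawned
# tasks:
import asyncio

async def slow_double(x):
    await asyncio.sleep(0.1)
    return 2 * x

async def main():
    src = Event()
    doubled = Map(slow_double, source=src)
    doubled += print
    for i in range(3):
        src.emit(i)           # three tasks run concurrently
    await asyncio.sleep(0.3)  # -> 0, 2, 4, emitted in source order

asyncio.run(main())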

View File

@@ -0,0 +1,78 @@
import asyncio
import datetime as dt
from typing import AsyncIterator
class _NoValue:
def __bool__(self):
return False
def __repr__(self):
return '<NoValue>'
__str__ = __repr__
NO_VALUE = _NoValue()
def get_event_loop():
"""Get asyncio event loop, running or not."""
return asyncio.get_event_loop_policy().get_event_loop()
main_event_loop = get_event_loop()
async def timerange(start=0, end=None, step: float = 1) \
-> AsyncIterator[dt.datetime]:
"""
Iterator that waits periodically until certain time points are
reached while yielding those time points.
Args:
start: Start time, can be specified as:
* ``datetime.datetime``.
* ``datetime.time``: Today is used as date.
* ``int`` or ``float``: Number of seconds relative to now.
Values will be quantized to the given step.
end: End time, can be specified as:
* ``datetime.datetime``.
* ``datetime.time``: Today is used as date.
* ``None``: No end limit.
step: Number of seconds, or ``datetime.timedelta``,
to space between values.
"""
tz = getattr(start, 'tzinfo', None)
now = dt.datetime.now(tz)
if isinstance(step, dt.timedelta):
delta = step
step = delta.total_seconds()
else:
delta = dt.timedelta(seconds=step)
t = start
if t == 0 or isinstance(t, (int, float)):
t = now + dt.timedelta(seconds=t)
# quantize to step
t = dt.datetime.fromtimestamp(
step * int(t.timestamp() / step))
elif isinstance(t, dt.time):
t = dt.datetime.combine(now.today(), t)
if t < now:
# t += delta
t -= ((t - now) // delta) * delta
if isinstance(end, dt.time):
end = dt.datetime.combine(now.today(), end)
elif isinstance(end, (int, float)):
end = now + dt.timedelta(seconds=end)
while end is None or t <= end:
now = dt.datetime.now(tz)
secs = (t - now).total_seconds()
await asyncio.sleep(secs)
yield t
t += delta
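# A hypothetical usage sketch: tick roughly once per second for five
# seconds, with the start aligned to a whole-second boundary by the
# quantization above:
import asyncio

async def main():
    async for t in timerange(0, end=5, step=1):
        print(t)    # datetime values spaced one second apart

asyncio.run(main())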

View File

@@ -0,0 +1,2 @@
__version_info__ = (1, 0, 3)
__version__ = '.'.join(str(v) for v in __version_info__)

View File

@@ -0,0 +1,25 @@
BSD 2-Clause License
Copyright (c) 2019, Ewald de Wit
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -0,0 +1,162 @@
Metadata-Version: 2.1
Name: ib-insync
Version: 0.9.86
Summary: Python sync/async framework for Interactive Brokers API
Home-page: https://github.com/erdewit/ib_insync
Author: Ewald R. de Wit
Author-email: ewald.de.wit@gmail.com
License: BSD
Keywords: ibapi tws asyncio jupyter interactive brokers async
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Topic :: Office/Business :: Financial :: Investment
Classifier: License :: OSI Approved :: BSD License
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3 :: Only
Requires-Python: >=3.6
License-File: LICENSE
Requires-Dist: eventkit
Requires-Dist: nest-asyncio
Requires-Dist: dataclasses ; python_version < "3.7"
Requires-Dist: backports.zoneinfo ; python_version < "3.9"
|Build| |Group| |PyVersion| |Status| |PyPiVersion| |CondaVersion| |License| |Downloads| |Docs|
Introduction
============
The goal of the IB-insync library is to make working with the
`Trader Workstation API <http://interactivebrokers.github.io/tws-api/>`_
from Interactive Brokers as easy as possible.
The main features are:
* An easy to use linear style of programming;
* An `IB component <https://ib-insync.readthedocs.io/api.html#module-ib_insync.ib>`_
that automatically keeps in sync with the TWS or IB Gateway application;
* A fully asynchronous framework based on
`asyncio <https://docs.python.org/3/library/asyncio.html>`_
and
`eventkit <https://github.com/erdewit/eventkit>`_
for advanced users;
* Interactive operation with live data in Jupyter notebooks.
Be sure to take a look at the
`notebooks <https://ib-insync.readthedocs.io/notebooks.html>`_,
the `recipes <https://ib-insync.readthedocs.io/recipes.html>`_
and the `API docs <https://ib-insync.readthedocs.io/api.html>`_.
Installation
------------
::
pip install ib_insync
Requirements:
* Python 3.6 or higher;
* A running TWS or IB Gateway application (version 1023 or higher).
Make sure the
`API port is enabled <https://interactivebrokers.github.io/tws-api/initial_setup.html>`_
and 'Download open orders on connection' is checked.
The ibapi package from IB is not needed.
Example
-------
This is a complete script to download historical data:
.. code-block:: python
from ib_insync import *
# util.startLoop() # uncomment this line when in a notebook
ib = IB()
ib.connect('127.0.0.1', 7497, clientId=1)
contract = Forex('EURUSD')
bars = ib.reqHistoricalData(
contract, endDateTime='', durationStr='30 D',
barSizeSetting='1 hour', whatToShow='MIDPOINT', useRTH=True)
# convert to pandas dataframe (pandas needs to be installed):
df = util.df(bars)
print(df)
Output::
date open high low close volume \
0 2019-11-19 23:15:00 1.107875 1.108050 1.107725 1.107825 -1
1 2019-11-20 00:00:00 1.107825 1.107925 1.107675 1.107825 -1
2 2019-11-20 01:00:00 1.107825 1.107975 1.107675 1.107875 -1
3 2019-11-20 02:00:00 1.107875 1.107975 1.107025 1.107225 -1
4 2019-11-20 03:00:00 1.107225 1.107725 1.107025 1.107525 -1
.. ... ... ... ... ... ...
705 2020-01-02 14:00:00 1.119325 1.119675 1.119075 1.119225 -1
Documentation
-------------
The complete `API documentation <https://ib-insync.readthedocs.io/api.html>`_.
`Changelog <https://ib-insync.readthedocs.io/changelog.html>`_.
Discussion
----------
The `insync user group <https://groups.io/g/insync>`_ is the place to discuss
IB-insync and anything related to it.
Disclaimer
----------
The software is provided on the conditions of the simplified BSD license.
This project is not affiliated with Interactive Brokers Group, Inc.
Good luck and enjoy,
:author: Ewald de Wit <ewald.de.wit@gmail.com>
.. _`Interactive Brokers Python API`: http://interactivebrokers.github.io
.. |Group| image:: https://img.shields.io/badge/groups.io-insync-green.svg
:alt: Join the user group
:target: https://groups.io/g/insync
.. |PyPiVersion| image:: https://img.shields.io/pypi/v/ib_insync.svg
:alt: PyPi
:target: https://pypi.python.org/pypi/ib_insync
.. |CondaVersion| image:: https://img.shields.io/conda/vn/conda-forge/ib-insync.svg
:alt: Conda
:target: https://anaconda.org/conda-forge/ib-insync
.. |PyVersion| image:: https://img.shields.io/badge/python-3.6+-blue.svg
:alt:
.. |Status| image:: https://img.shields.io/badge/status-beta-green.svg
:alt:
.. |License| image:: https://img.shields.io/badge/license-BSD-blue.svg
:alt:
.. |Docs| image:: https://img.shields.io/badge/Documentation-green.svg
:alt: Documentation
:target: https://ib-insync.readthedocs.io/api.html
.. |Downloads| image:: https://pepy.tech/badge/ib-insync
:alt: Number of downloads
:target: https://pepy.tech/project/ib-insync
.. |Build| image:: https://github.com/erdewit/ib_insync/actions/workflows/test.yml/badge.svg?branch=master
:target: https://github.com/erdewit/ib_insync/actions

View File

@@ -0,0 +1,36 @@
ib_insync-0.9.86.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
ib_insync-0.9.86.dist-info/LICENSE,sha256=wVXmemzM4v_VlNCzedi7qFodW6NAhMiDQaSAy10WxtU,1317
ib_insync-0.9.86.dist-info/METADATA,sha256=zgqOVqK0KlHbwOxnc_0XlLCC6V1pwlxkzknbpK5_Akg,5369
ib_insync-0.9.86.dist-info/RECORD,,
ib_insync-0.9.86.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
ib_insync-0.9.86.dist-info/WHEEL,sha256=2wepM1nk4DS4eFpYrW1TTqPcoGNfHhhO_i5m4cOimbo,92
ib_insync-0.9.86.dist-info/top_level.txt,sha256=b9ruIFE0O0uehMGcAzo9k5WBt1ZeHmukMQrF_Qaq0J4,10
ib_insync/__init__.py,sha256=bkuPG1U_EIGJ8WjiY2vxdDVpDisj7yr6aI3IsAgrcQ0,3611
ib_insync/__pycache__/__init__.cpython-311.pyc,,
ib_insync/__pycache__/client.cpython-311.pyc,,
ib_insync/__pycache__/connection.cpython-311.pyc,,
ib_insync/__pycache__/contract.cpython-311.pyc,,
ib_insync/__pycache__/decoder.cpython-311.pyc,,
ib_insync/__pycache__/flexreport.cpython-311.pyc,,
ib_insync/__pycache__/ib.cpython-311.pyc,,
ib_insync/__pycache__/ibcontroller.cpython-311.pyc,,
ib_insync/__pycache__/objects.cpython-311.pyc,,
ib_insync/__pycache__/order.cpython-311.pyc,,
ib_insync/__pycache__/ticker.cpython-311.pyc,,
ib_insync/__pycache__/util.cpython-311.pyc,,
ib_insync/__pycache__/version.cpython-311.pyc,,
ib_insync/__pycache__/wrapper.cpython-311.pyc,,
ib_insync/client.py,sha256=HYaYnWgViB6ohXbo473PNP5A0dp3jf7mhhnugOU6KCI,33172
ib_insync/connection.py,sha256=gXxcI_1xPw4XHtmAKpGN6grxHsEeoKiTSz6H6JteCAk,1654
ib_insync/contract.py,sha256=o9jSUcTwmiM8nP1asKq8mRXesUzbh42kOlbFk6KU87A,17102
ib_insync/decoder.py,sha256=sGy-JOyjfjOG0KshPetD5JMXHoy7fCYGDqSTb5f3m0w,40909
ib_insync/flexreport.py,sha256=91rYpgKMLMshpyW7_S4M-sBgA1c7S1hvXw3_i0tTwQQ,4009
ib_insync/ib.py,sha256=tfYCIMlBozXEfLsCV6NC4xIrHwfMn4CaNqDnybS6hPw,85631
ib_insync/ibcontroller.py,sha256=Mod0rBATMFeZEjad89LGqfJPkinIMJDXQpyrxMOots4,12510
ib_insync/objects.py,sha256=ktzQWpRSH6tUTcneEtHD5kunfoqEkllC6cWj2hK6J0Y,9817
ib_insync/order.py,sha256=8BSwY7iZ_Z7BzUvUz5zINN4O-4nAib3rZzwJs5ITHKc,12368
ib_insync/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
ib_insync/ticker.py,sha256=7pEKbvPGy9axCyZwOLm_6Syh1IPkGytpCJoQSQyVSd0,10642
ib_insync/util.py,sha256=hsclDyRUVyGk_SMOdHsNAeZssfikrRLhUsenDUs32WQ,16122
ib_insync/version.py,sha256=Ny-bTPSSnRRmaGV_vw1sp_cJarnJFQJJJELra6-cO98,108
ib_insync/wrapper.py,sha256=TtcpkQaXnAslGGWNTT7zcf5aObJUWGJUZQ7cNSF8sME,46713

View File

@@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: bdist_wheel (0.38.4)
Root-Is-Purelib: true
Tag: py3-none-any

View File

@@ -0,0 +1,75 @@
"""Python sync/async framework for Interactive Brokers API"""
import dataclasses
import sys
from eventkit import Event
from . import util
from .client import Client
from .contract import (
Bag, Bond, CFD, ComboLeg, Commodity, ContFuture, Contract,
ContractDescription, ContractDetails, Crypto, DeltaNeutralContract,
Forex, Future, FuturesOption, Index, MutualFund, Option, ScanData, Stock,
TagValue, Warrant)
from .flexreport import FlexError, FlexReport
from .ib import IB
from .ibcontroller import IBC, Watchdog
from .objects import (
AccountValue, BarData, BarDataList, CommissionReport, ConnectionStats,
DOMLevel, DepthMktDataDescription, Dividends, Execution, ExecutionFilter,
FamilyCode, Fill, FundamentalRatios, HistogramData, HistoricalNews,
HistoricalSchedule, HistoricalSession, HistoricalTick,
HistoricalTickBidAsk, HistoricalTickLast, MktDepthData, NewsArticle,
NewsBulletin, NewsProvider, NewsTick, OptionChain, OptionComputation,
PnL, PnLSingle, PortfolioItem,
Position, PriceIncrement, RealTimeBar, RealTimeBarList, ScanDataList,
ScannerSubscription, SmartComponent, SoftDollarTier, TickAttrib,
TickAttribBidAsk, TickAttribLast, TickByTickAllLast, TickByTickBidAsk,
TickByTickMidPoint, TickData, TradeLogEntry, WshEventData)
from .order import (
BracketOrder, ExecutionCondition, LimitOrder, MarginCondition, MarketOrder,
Order, OrderComboLeg, OrderCondition, OrderState, OrderStatus,
PercentChangeCondition, PriceCondition, StopLimitOrder, StopOrder,
TimeCondition, Trade, VolumeCondition)
from .ticker import Ticker
from .version import __version__, __version_info__
from .wrapper import RequestError, Wrapper
__all__ = [
'Event', 'util', 'Client',
'Bag', 'Bond', 'CFD', 'ComboLeg', 'Commodity', 'ContFuture', 'Contract',
'ContractDescription', 'ContractDetails', 'Crypto', 'DeltaNeutralContract',
'Forex', 'Future', 'FuturesOption', 'Index', 'MutualFund', 'Option',
'ScanData', 'Stock', 'TagValue', 'Warrant', 'FlexError', 'FlexReport',
'IB', 'IBC', 'Watchdog',
'AccountValue', 'BarData', 'BarDataList', 'CommissionReport',
'ConnectionStats', 'DOMLevel', 'DepthMktDataDescription', 'Dividends',
'Execution', 'ExecutionFilter', 'FamilyCode', 'Fill', 'FundamentalRatios',
'HistogramData', 'HistoricalNews', 'HistoricalTick',
'HistoricalTickBidAsk', 'HistoricalTickLast',
'HistoricalSchedule', 'HistoricalSession', 'MktDepthData',
'NewsArticle', 'NewsBulletin', 'NewsProvider', 'NewsTick', 'OptionChain',
'OptionComputation', 'PnL', 'PnLSingle', 'PortfolioItem', 'Position',
'PriceIncrement', 'RealTimeBar', 'RealTimeBarList', 'ScanDataList',
'ScannerSubscription', 'SmartComponent', 'SoftDollarTier', 'TickAttrib',
'TickAttribBidAsk', 'TickAttribLast', 'TickByTickAllLast', 'WshEventData',
'TickByTickBidAsk', 'TickByTickMidPoint', 'TickData', 'TradeLogEntry',
'BracketOrder', 'ExecutionCondition', 'LimitOrder', 'MarginCondition',
'MarketOrder', 'Order', 'OrderComboLeg', 'OrderCondition', 'OrderState',
'OrderStatus', 'PercentChangeCondition', 'PriceCondition',
'StopLimitOrder', 'StopOrder', 'TimeCondition', 'Trade', 'VolumeCondition',
'Ticker', '__version__', '__version_info__', 'RequestError', 'Wrapper'
]
# compatibility with old Object
for obj in locals().copy().values():
if dataclasses.is_dataclass(obj):
obj.dict = util.dataclassAsDict
obj.tuple = util.dataclassAsTuple
obj.update = util.dataclassUpdate
obj.nonDefaults = util.dataclassNonDefaults
del sys
del dataclasses

View File

@@ -0,0 +1,996 @@
"""Socket client for communicating with Interactive Brokers."""
import asyncio
import io
import logging
import math
import struct
import time
from collections import deque
from typing import Deque, List, Optional
from eventkit import Event
from .connection import Connection
from .contract import Contract
from .decoder import Decoder
from .objects import ConnectionStats, WshEventData
from .util import UNSET_DOUBLE, UNSET_INTEGER, dataclassAsTuple, getLoop, run
class Client:
"""
Replacement for ``ibapi.client.EClient`` that uses asyncio.
The client is fully asynchronous and has its own
event-driven networking code that replaces the
networking code of the standard EClient.
It also replaces the infinite loop of ``EClient.run()``
with the asyncio event loop. It can be used as a drop-in
replacement for the standard EClient as provided by IBAPI.
Compared to the standard EClient this client has the following
additional features:
* ``client.connect()`` will block until the client is ready to
serve requests; It is not necessary to wait for ``nextValidId``
to start requests as the client has already done that.
The reqId is directly available with :py:meth:`.getReqId()`.
* ``client.connectAsync()`` is a coroutine for connecting asynchronously.
* When blocking, ``client.connect()`` can be made to time out with
the timeout parameter (default 2 seconds).
* Optional ``wrapper.priceSizeTick(reqId, tickType, price, size)`` that
combines price and size instead of the two wrapper methods
priceTick and sizeTick.
* Automatic request throttling.
* Optional ``wrapper.tcpDataArrived()`` method;
If the wrapper has this method it is invoked directly after
a network packet has arrived.
A possible use is to timestamp all data in the packet with
the exact same time.
* Optional ``wrapper.tcpDataProcessed()`` method;
If the wrapper has this method it is invoked after the
network packet's data has been handled.
A possible use is to write or evaluate the newly arrived data in
one batch instead of item by item.
Parameters:
MaxRequests (int):
Throttle the number of requests to ``MaxRequests`` per
``RequestsInterval`` seconds. Set to 0 to disable throttling.
RequestsInterval (float):
Time interval (in seconds) for request throttling.
MinClientVersion (int):
Client protocol version.
MaxClientVersion (int):
Client protocol version.
Events:
* ``apiStart`` ()
* ``apiEnd`` ()
* ``apiError`` (errorMsg: str)
* ``throttleStart`` ()
* ``throttleEnd`` ()
"""
events = ('apiStart', 'apiEnd', 'apiError', 'throttleStart', 'throttleEnd')
MaxRequests = 45
RequestsInterval = 1
MinClientVersion = 157
MaxClientVersion = 176
(DISCONNECTED, CONNECTING, CONNECTED) = range(3)
def __init__(self, wrapper):
self.wrapper = wrapper
self.decoder = Decoder(wrapper, 0)
self.apiStart = Event('apiStart')
self.apiEnd = Event('apiEnd')
self.apiError = Event('apiError')
self.throttleStart = Event('throttleStart')
self.throttleEnd = Event('throttleEnd')
self._logger = logging.getLogger('ib_insync.client')
self.conn = Connection()
self.conn.hasData += self._onSocketHasData
self.conn.disconnected += self._onSocketDisconnected
# extra optional wrapper methods
self._priceSizeTick = getattr(wrapper, 'priceSizeTick', None)
self._tcpDataArrived = getattr(wrapper, 'tcpDataArrived', None)
self._tcpDataProcessed = getattr(wrapper, 'tcpDataProcessed', None)
self.host = ''
self.port = -1
self.clientId = -1
self.optCapab = ''
self.connectOptions = b''
self.reset()
def reset(self):
self.connState = Client.DISCONNECTED
self._apiReady = False
self._serverVersion = 0
self._data = b''
self._hasReqId = False
self._reqIdSeq = 0
self._accounts = []
self._startTime = time.time()
self._numBytesRecv = 0
self._numMsgRecv = 0
self._isThrottling = False
self._msgQ: Deque[str] = deque()
self._timeQ: Deque[float] = deque()
def serverVersion(self) -> int:
return self._serverVersion
def run(self):
loop = getLoop()
loop.run_forever()
def isConnected(self):
return self.connState == Client.CONNECTED
def isReady(self) -> bool:
"""Is the API connection up and running?"""
return self._apiReady
def connectionStats(self) -> ConnectionStats:
"""Get statistics about the connection."""
if not self.isReady():
raise ConnectionError('Not connected')
return ConnectionStats(
self._startTime,
time.time() - self._startTime,
self._numBytesRecv, self.conn.numBytesSent,
self._numMsgRecv, self.conn.numMsgSent)
def getReqId(self) -> int:
"""Get new request ID."""
if not self.isReady():
raise ConnectionError('Not connected')
newId = self._reqIdSeq
self._reqIdSeq += 1
return newId
def updateReqId(self, minReqId):
"""Update the next reqId to be at least ``minReqId``."""
self._reqIdSeq = max(self._reqIdSeq, minReqId)
def getAccounts(self) -> List[str]:
"""Get the list of account names that are under management."""
if not self.isReady():
raise ConnectionError('Not connected')
return self._accounts
def setConnectOptions(self, connectOptions: str):
"""
Set additional connect options.
Args:
connectOptions: Use "+PACEAPI" to use request-pacing built
into TWS/gateway 974+ (obsolete).
"""
self.connectOptions = connectOptions.encode()
def connect(
self, host: str, port: int, clientId: int,
timeout: Optional[float] = 2.0):
"""
Connect to a running TWS or IB gateway application.
Args:
host: Host name or IP address.
port: Port number.
clientId: ID number to use for this client; must be unique per
connection.
timeout: If establishing the connection takes longer than
``timeout`` seconds then the ``asyncio.TimeoutError`` exception
is raised. Set to 0 to disable timeout.
"""
run(self.connectAsync(host, port, clientId, timeout))
async def connectAsync(self, host, port, clientId, timeout=2.0):
try:
self._logger.info(
f'Connecting to {host}:{port} with clientId {clientId}...')
self.host = host
self.port = int(port)
self.clientId = int(clientId)
self.connState = Client.CONNECTING
timeout = timeout or None
await asyncio.wait_for(self.conn.connectAsync(host, port), timeout)
self._logger.info('Connected')
msg = b'API\0' + self._prefix(b'v%d..%d%s' % (
self.MinClientVersion, self.MaxClientVersion,
b' ' + self.connectOptions if self.connectOptions else b''))
self.conn.sendMsg(msg)
await asyncio.wait_for(self.apiStart, timeout)
self._logger.info('API connection ready')
except BaseException as e:
self.disconnect()
msg = f'API connection failed: {e!r}'
self._logger.error(msg)
self.apiError.emit(msg)
if isinstance(e, ConnectionRefusedError):
self._logger.error('Make sure API port on TWS/IBG is open')
raise
def disconnect(self):
"""Disconnect from IB connection."""
self._logger.info('Disconnecting')
self.connState = Client.DISCONNECTED
self.conn.disconnect()
self.reset()
def send(self, *fields, makeEmpty=True):
"""Serialize and send the given fields using the IB socket protocol."""
if not self.isConnected():
raise ConnectionError('Not connected')
msg = io.StringIO()
empty = (None, UNSET_INTEGER, UNSET_DOUBLE) if makeEmpty else (None,)
for field in fields:
typ = type(field)
if field in empty:
s = ''
elif typ is str:
s = field
            elif typ is int:
s = str(field)
elif typ is float:
s = 'Infinite' if field == math.inf else str(field)
elif typ is bool:
s = '1' if field else '0'
elif typ is list:
# list of TagValue
s = ''.join(f'{v.tag}={v.value};' for v in field)
elif isinstance(field, Contract):
c = field
s = '\0'.join(str(f) for f in (
c.conId, c.symbol, c.secType,
c.lastTradeDateOrContractMonth, c.strike,
c.right, c.multiplier, c.exchange,
c.primaryExchange, c.currency,
c.localSymbol, c.tradingClass))
else:
s = str(field)
msg.write(s)
msg.write('\0')
self.sendMsg(msg.getvalue())
def sendMsg(self, msg: str):
loop = getLoop()
t = loop.time()
times = self._timeQ
msgs = self._msgQ
while times and t - times[0] > self.RequestsInterval:
times.popleft()
if msg:
msgs.append(msg)
while msgs and (len(times) < self.MaxRequests or not self.MaxRequests):
msg = msgs.popleft()
self.conn.sendMsg(self._prefix(msg.encode()))
times.append(t)
if self._logger.isEnabledFor(logging.DEBUG):
self._logger.debug('>>> %s', msg[:-1].replace('\0', ','))
if msgs:
if not self._isThrottling:
self._isThrottling = True
self.throttleStart.emit()
self._logger.debug('Started to throttle requests')
loop.call_at(
times[0] + self.RequestsInterval,
self.sendMsg, None)
else:
if self._isThrottling:
self._isThrottling = False
self.throttleEnd.emit()
self._logger.debug('Stopped to throttle requests')
def _prefix(self, msg):
# prefix a message with its length
return struct.pack('>I', len(msg)) + msg
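    # Framing sketch: _prefix(b'API') == b'\x00\x00\x00\x03API', a 4-byte
    # big-endian length word followed by the payload. _onSocketHasData
    # below reverses this by reading the length word and waiting until the
    # complete message is buffered before decoding it.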
def _onSocketHasData(self, data):
debug = self._logger.isEnabledFor(logging.DEBUG)
if self._tcpDataArrived:
self._tcpDataArrived()
self._data += data
self._numBytesRecv += len(data)
while True:
if len(self._data) <= 4:
break
# 4 byte prefix tells the message length
msgEnd = 4 + struct.unpack('>I', self._data[:4])[0]
if len(self._data) < msgEnd:
# insufficient data for now
break
msg = self._data[4:msgEnd].decode(errors='backslashreplace')
self._data = self._data[msgEnd:]
fields = msg.split('\0')
fields.pop() # pop off last empty element
self._numMsgRecv += 1
if debug:
self._logger.debug('<<< %s', ','.join(fields))
if not self._serverVersion and len(fields) == 2:
# this concludes the handshake
version, _connTime = fields
self._serverVersion = int(version)
if self._serverVersion < self.MinClientVersion:
self._onSocketDisconnected(
'TWS/gateway version must be >= 972')
return
self.decoder.serverVersion = self._serverVersion
self.connState = Client.CONNECTED
self.startApi()
self.wrapper.connectAck()
self._logger.info(
f'Logged on to server version {self._serverVersion}')
else:
if not self._apiReady:
# snoop for nextValidId and managedAccounts response,
# when both are in then the client is ready
msgId = int(fields[0])
if msgId == 9:
_, _, validId = fields
self.updateReqId(int(validId))
self._hasReqId = True
elif msgId == 15:
_, _, accts = fields
self._accounts = [a for a in accts.split(',') if a]
if self._hasReqId and self._accounts:
self._apiReady = True
self.apiStart.emit()
# decode and handle the message
self.decoder.interpret(fields)
if self._tcpDataProcessed:
self._tcpDataProcessed()
def _onSocketDisconnected(self, msg):
wasReady = self.isReady()
if not self.isConnected():
self._logger.info('Disconnected.')
elif not msg:
msg = 'Peer closed connection.'
if not wasReady:
msg += f' clientId {self.clientId} already in use?'
if msg:
self._logger.error(msg)
self.apiError.emit(msg)
self.wrapper.setEventsDone()
if wasReady:
self.wrapper.connectionClosed()
self.reset()
if wasReady:
self.apiEnd.emit()
# client request methods
# the message type id is sent first, often followed by a version number
def reqMktData(
self, reqId, contract, genericTickList, snapshot,
regulatorySnapshot, mktDataOptions):
fields = [1, 11, reqId, contract]
if contract.secType == 'BAG':
legs = contract.comboLegs or []
fields += [len(legs)]
for leg in legs:
fields += [leg.conId, leg.ratio, leg.action, leg.exchange]
dnc = contract.deltaNeutralContract
if dnc:
fields += [True, dnc.conId, dnc.delta, dnc.price]
else:
fields += [False]
fields += [
genericTickList, snapshot, regulatorySnapshot, mktDataOptions]
self.send(*fields)
def cancelMktData(self, reqId):
self.send(2, 2, reqId)
def placeOrder(self, orderId, contract, order):
version = self.serverVersion()
fields = [
3, orderId,
contract,
contract.secIdType,
contract.secId,
order.action,
order.totalQuantity,
order.orderType,
order.lmtPrice,
order.auxPrice,
order.tif,
order.ocaGroup,
order.account,
order.openClose,
order.origin,
order.orderRef,
order.transmit,
order.parentId,
order.blockOrder,
order.sweepToFill,
order.displaySize,
order.triggerMethod,
order.outsideRth,
order.hidden]
if contract.secType == 'BAG':
legs = contract.comboLegs or []
fields += [len(legs)]
for leg in legs:
fields += [
leg.conId,
leg.ratio,
leg.action,
leg.exchange,
leg.openClose,
leg.shortSaleSlot,
leg.designatedLocation,
leg.exemptCode]
legs = order.orderComboLegs or []
fields += [len(legs)]
for leg in legs:
fields += [leg.price]
params = order.smartComboRoutingParams or []
fields += [len(params)]
for param in params:
fields += [param.tag, param.value]
fields += [
'',
order.discretionaryAmt,
order.goodAfterTime,
order.goodTillDate,
order.faGroup,
order.faMethod,
order.faPercentage,
order.faProfile,
order.modelCode,
order.shortSaleSlot,
order.designatedLocation,
order.exemptCode,
order.ocaType,
order.rule80A,
order.settlingFirm,
order.allOrNone,
order.minQty,
order.percentOffset,
order.eTradeOnly,
order.firmQuoteOnly,
order.nbboPriceCap,
order.auctionStrategy,
order.startingPrice,
order.stockRefPrice,
order.delta,
order.stockRangeLower,
order.stockRangeUpper,
order.overridePercentageConstraints,
order.volatility,
order.volatilityType,
order.deltaNeutralOrderType,
order.deltaNeutralAuxPrice]
if order.deltaNeutralOrderType:
fields += [
order.deltaNeutralConId,
order.deltaNeutralSettlingFirm,
order.deltaNeutralClearingAccount,
order.deltaNeutralClearingIntent,
order.deltaNeutralOpenClose,
order.deltaNeutralShortSale,
order.deltaNeutralShortSaleSlot,
order.deltaNeutralDesignatedLocation]
fields += [
order.continuousUpdate,
order.referencePriceType,
order.trailStopPrice,
order.trailingPercent,
order.scaleInitLevelSize,
order.scaleSubsLevelSize,
order.scalePriceIncrement]
if (0 < order.scalePriceIncrement < UNSET_DOUBLE):
fields += [
order.scalePriceAdjustValue,
order.scalePriceAdjustInterval,
order.scaleProfitOffset,
order.scaleAutoReset,
order.scaleInitPosition,
order.scaleInitFillQty,
order.scaleRandomPercent]
fields += [
order.scaleTable,
order.activeStartTime,
order.activeStopTime,
order.hedgeType]
if order.hedgeType:
fields += [order.hedgeParam]
fields += [
order.optOutSmartRouting,
order.clearingAccount,
order.clearingIntent,
order.notHeld]
dnc = contract.deltaNeutralContract
if dnc:
fields += [True, dnc.conId, dnc.delta, dnc.price]
else:
fields += [False]
fields += [order.algoStrategy]
if order.algoStrategy:
params = order.algoParams or []
fields += [len(params)]
for param in params:
fields += [param.tag, param.value]
fields += [
order.algoId,
order.whatIf,
order.orderMiscOptions,
order.solicited,
order.randomizeSize,
order.randomizePrice]
if order.orderType == 'PEG BENCH':
fields += [
order.referenceContractId,
order.isPeggedChangeAmountDecrease,
order.peggedChangeAmount,
order.referenceChangeAmount,
order.referenceExchangeId]
fields += [len(order.conditions)]
if order.conditions:
for cond in order.conditions:
fields += dataclassAsTuple(cond)
fields += [
order.conditionsIgnoreRth,
order.conditionsCancelOrder]
fields += [
order.adjustedOrderType,
order.triggerPrice,
order.lmtPriceOffset,
order.adjustedStopPrice,
order.adjustedStopLimitPrice,
order.adjustedTrailingAmount,
order.adjustableTrailingUnit,
order.extOperator,
order.softDollarTier.name,
order.softDollarTier.val,
order.cashQty,
order.mifid2DecisionMaker,
order.mifid2DecisionAlgo,
order.mifid2ExecutionTrader,
order.mifid2ExecutionAlgo,
order.dontUseAutoPriceForHedge,
order.isOmsContainer,
order.discretionaryUpToLimitPrice,
order.usePriceMgmtAlgo]
if version >= 158:
fields += [order.duration]
if version >= 160:
fields += [order.postToAts]
if version >= 162:
fields += [order.autoCancelParent]
if version >= 166:
fields += [order.advancedErrorOverride]
if version >= 169:
fields += [order.manualOrderTime]
if version >= 170:
if contract.exchange == 'IBKRATS':
fields += [order.minTradeQty]
if order.orderType == 'PEG BEST':
fields += [
order.minCompeteSize,
order.competeAgainstBestOffset]
if order.competeAgainstBestOffset == math.inf:
fields += [order.midOffsetAtWhole, order.midOffsetAtHalf]
elif order.orderType == 'PEG MID':
fields += [order.midOffsetAtWhole, order.midOffsetAtHalf]
self.send(*fields)
def cancelOrder(self, orderId, manualCancelOrderTime=''):
fields = [4, 1, orderId]
if self.serverVersion() >= 169:
fields += [manualCancelOrderTime]
self.send(*fields)
def reqOpenOrders(self):
self.send(5, 1)
def reqAccountUpdates(self, subscribe, acctCode):
self.send(6, 2, subscribe, acctCode)
def reqExecutions(self, reqId, execFilter):
self.send(
7, 3, reqId,
execFilter.clientId,
execFilter.acctCode,
execFilter.time,
execFilter.symbol,
execFilter.secType,
execFilter.exchange,
execFilter.side)
def reqIds(self, numIds):
self.send(8, 1, numIds)
def reqContractDetails(self, reqId, contract):
fields = [
9, 8, reqId,
contract,
contract.includeExpired,
contract.secIdType,
contract.secId]
if self.serverVersion() >= 176:
fields += [contract.issuerId]
self.send(*fields)
def reqMktDepth(
self, reqId, contract, numRows, isSmartDepth, mktDepthOptions):
self.send(
10, 5, reqId,
contract.conId,
contract.symbol,
contract.secType,
contract.lastTradeDateOrContractMonth,
contract.strike,
contract.right,
contract.multiplier,
contract.exchange,
contract.primaryExchange,
contract.currency,
contract.localSymbol,
contract.tradingClass,
numRows,
isSmartDepth,
mktDepthOptions)
def cancelMktDepth(self, reqId, isSmartDepth):
self.send(11, 1, reqId, isSmartDepth)
def reqNewsBulletins(self, allMsgs):
self.send(12, 1, allMsgs)
def cancelNewsBulletins(self):
self.send(13, 1)
def setServerLogLevel(self, logLevel):
self.send(14, 1, logLevel)
def reqAutoOpenOrders(self, bAutoBind):
self.send(15, 1, bAutoBind)
def reqAllOpenOrders(self):
self.send(16, 1)
def reqManagedAccts(self):
self.send(17, 1)
def requestFA(self, faData):
self.send(18, 1, faData)
def replaceFA(self, reqId, faData, cxml):
self.send(19, 1, faData, cxml, reqId)
def reqHistoricalData(
self, reqId, contract, endDateTime, durationStr, barSizeSetting,
whatToShow, useRTH, formatDate, keepUpToDate, chartOptions):
fields = [
20, reqId, contract, contract.includeExpired,
endDateTime, barSizeSetting, durationStr, useRTH,
whatToShow, formatDate]
if contract.secType == 'BAG':
legs = contract.comboLegs or []
fields += [len(legs)]
for leg in legs:
fields += [leg.conId, leg.ratio, leg.action, leg.exchange]
fields += [keepUpToDate, chartOptions]
self.send(*fields)
def exerciseOptions(
self, reqId, contract, exerciseAction,
exerciseQuantity, account, override):
self.send(
21, 2, reqId,
contract.conId,
contract.symbol,
contract.secType,
contract.lastTradeDateOrContractMonth,
contract.strike,
contract.right,
contract.multiplier,
contract.exchange,
contract.currency,
contract.localSymbol,
contract.tradingClass,
exerciseAction, exerciseQuantity, account, override)
def reqScannerSubscription(
self, reqId, subscription, scannerSubscriptionOptions,
scannerSubscriptionFilterOptions):
sub = subscription
self.send(
22, reqId,
sub.numberOfRows,
sub.instrument,
sub.locationCode,
sub.scanCode,
sub.abovePrice,
sub.belowPrice,
sub.aboveVolume,
sub.marketCapAbove,
sub.marketCapBelow,
sub.moodyRatingAbove,
sub.moodyRatingBelow,
sub.spRatingAbove,
sub.spRatingBelow,
sub.maturityDateAbove,
sub.maturityDateBelow,
sub.couponRateAbove,
sub.couponRateBelow,
sub.excludeConvertible,
sub.averageOptionVolumeAbove,
sub.scannerSettingPairs,
sub.stockTypeFilter,
scannerSubscriptionFilterOptions,
scannerSubscriptionOptions)
def cancelScannerSubscription(self, reqId):
self.send(23, 1, reqId)
def reqScannerParameters(self):
self.send(24, 1)
def cancelHistoricalData(self, reqId):
self.send(25, 1, reqId)
def reqCurrentTime(self):
self.send(49, 1)
def reqRealTimeBars(
self, reqId, contract, barSize, whatToShow,
useRTH, realTimeBarsOptions):
self.send(
50, 3, reqId, contract, barSize, whatToShow,
useRTH, realTimeBarsOptions)
def cancelRealTimeBars(self, reqId):
self.send(51, 1, reqId)
def reqFundamentalData(
self, reqId, contract, reportType, fundamentalDataOptions):
options = fundamentalDataOptions or []
self.send(
52, 2, reqId,
contract.conId,
contract.symbol,
contract.secType,
contract.exchange,
contract.primaryExchange,
contract.currency,
contract.localSymbol,
reportType, len(options), options)
def cancelFundamentalData(self, reqId):
self.send(53, 1, reqId)
def calculateImpliedVolatility(
self, reqId, contract, optionPrice, underPrice, implVolOptions):
self.send(
54, 3, reqId, contract, optionPrice, underPrice,
len(implVolOptions), implVolOptions)
def calculateOptionPrice(
self, reqId, contract, volatility, underPrice, optPrcOptions):
self.send(
55, 3, reqId, contract, volatility, underPrice,
len(optPrcOptions), optPrcOptions)
def cancelCalculateImpliedVolatility(self, reqId):
self.send(56, 1, reqId)
def cancelCalculateOptionPrice(self, reqId):
self.send(57, 1, reqId)
def reqGlobalCancel(self):
self.send(58, 1)
def reqMarketDataType(self, marketDataType):
self.send(59, 1, marketDataType)
def reqPositions(self):
self.send(61, 1)
def reqAccountSummary(self, reqId, groupName, tags):
self.send(62, 1, reqId, groupName, tags)
def cancelAccountSummary(self, reqId):
self.send(63, 1, reqId)
def cancelPositions(self):
self.send(64, 1)
def verifyRequest(self, apiName, apiVersion):
self.send(65, 1, apiName, apiVersion)
def verifyMessage(self, apiData):
self.send(66, 1, apiData)
def queryDisplayGroups(self, reqId):
self.send(67, 1, reqId)
def subscribeToGroupEvents(self, reqId, groupId):
self.send(68, 1, reqId, groupId)
def updateDisplayGroup(self, reqId, contractInfo):
self.send(69, 1, reqId, contractInfo)
def unsubscribeFromGroupEvents(self, reqId):
self.send(70, 1, reqId)
def startApi(self):
self.send(71, 2, self.clientId, self.optCapab)
def verifyAndAuthRequest(self, apiName, apiVersion, opaqueIsvKey):
self.send(72, 1, apiName, apiVersion, opaqueIsvKey)
def verifyAndAuthMessage(self, apiData, xyzResponse):
self.send(73, 1, apiData, xyzResponse)
def reqPositionsMulti(self, reqId, account, modelCode):
self.send(74, 1, reqId, account, modelCode)
def cancelPositionsMulti(self, reqId):
self.send(75, 1, reqId)
def reqAccountUpdatesMulti(self, reqId, account, modelCode, ledgerAndNLV):
self.send(76, 1, reqId, account, modelCode, ledgerAndNLV)
def cancelAccountUpdatesMulti(self, reqId):
self.send(77, 1, reqId)
def reqSecDefOptParams(
self, reqId, underlyingSymbol, futFopExchange,
underlyingSecType, underlyingConId):
self.send(
78, reqId, underlyingSymbol, futFopExchange,
underlyingSecType, underlyingConId)
def reqSoftDollarTiers(self, reqId):
self.send(79, reqId)
def reqFamilyCodes(self):
self.send(80)
def reqMatchingSymbols(self, reqId, pattern):
self.send(81, reqId, pattern)
def reqMktDepthExchanges(self):
self.send(82)
def reqSmartComponents(self, reqId, bboExchange):
self.send(83, reqId, bboExchange)
def reqNewsArticle(
self, reqId, providerCode, articleId, newsArticleOptions):
self.send(84, reqId, providerCode, articleId, newsArticleOptions)
def reqNewsProviders(self):
self.send(85)
def reqHistoricalNews(
self, reqId, conId, providerCodes, startDateTime, endDateTime,
totalResults, historicalNewsOptions):
self.send(
86, reqId, conId, providerCodes, startDateTime, endDateTime,
totalResults, historicalNewsOptions)
def reqHeadTimeStamp(
self, reqId, contract, whatToShow, useRTH, formatDate):
self.send(
87, reqId, contract, contract.includeExpired,
useRTH, whatToShow, formatDate)
def reqHistogramData(self, tickerId, contract, useRTH, timePeriod):
self.send(
88, tickerId, contract, contract.includeExpired,
useRTH, timePeriod)
def cancelHistogramData(self, tickerId):
self.send(89, tickerId)
def cancelHeadTimeStamp(self, reqId):
self.send(90, reqId)
def reqMarketRule(self, marketRuleId):
self.send(91, marketRuleId)
def reqPnL(self, reqId, account, modelCode):
self.send(92, reqId, account, modelCode)
def cancelPnL(self, reqId):
self.send(93, reqId)
def reqPnLSingle(self, reqId, account, modelCode, conid):
self.send(94, reqId, account, modelCode, conid)
def cancelPnLSingle(self, reqId):
self.send(95, reqId)
def reqHistoricalTicks(
self, reqId, contract, startDateTime, endDateTime,
numberOfTicks, whatToShow, useRth, ignoreSize, miscOptions):
self.send(
96, reqId, contract, contract.includeExpired,
startDateTime, endDateTime, numberOfTicks, whatToShow,
useRth, ignoreSize, miscOptions)
def reqTickByTickData(
self, reqId, contract, tickType, numberOfTicks, ignoreSize):
self.send(97, reqId, contract, tickType, numberOfTicks, ignoreSize)
def cancelTickByTickData(self, reqId):
self.send(98, reqId)
def reqCompletedOrders(self, apiOnly):
self.send(99, apiOnly)
def reqWshMetaData(self, reqId):
self.send(100, reqId)
def cancelWshMetaData(self, reqId):
self.send(101, reqId)
def reqWshEventData(self, reqId, data: WshEventData):
fields = [102, reqId, data.conId]
if self.serverVersion() >= 171:
fields += [
data.filter,
data.fillWatchlist,
data.fillPortfolio,
data.fillCompetitors]
if self.serverVersion() >= 173:
fields += [
data.startDate,
data.endDate,
data.totalLimit]
self.send(*fields, makeEmpty=False)
def cancelWshEventData(self, reqId):
self.send(103, reqId)
def reqUserInfo(self, reqId):
self.send(104, reqId)

View File

@@ -0,0 +1,62 @@
"""Event-driven socket connection."""
import asyncio
from eventkit import Event
from ib_insync.util import getLoop
class Connection(asyncio.Protocol):
"""
Event-driven socket connection.
Events:
* ``hasData`` (data: bytes):
Emits the received socket data.
* ``disconnected`` (msg: str):
Is emitted on socket disconnect, with an error message in case
of error, or an empty string in case of a normal disconnect.
"""
def __init__(self):
self.hasData = Event('hasData')
self.disconnected = Event('disconnected')
self.reset()
def reset(self):
self.transport = None
self.numBytesSent = 0
self.numMsgSent = 0
async def connectAsync(self, host, port):
if self.transport:
# wait until a previous connection is finished closing
self.disconnect()
await self.disconnected
self.reset()
loop = getLoop()
self.transport, _ = await loop.create_connection(
lambda: self, host, port)
def disconnect(self):
if self.transport:
self.transport.write_eof()
self.transport.close()
def isConnected(self):
return self.transport is not None
def sendMsg(self, msg):
if self.transport:
self.transport.write(msg)
self.numBytesSent += len(msg)
self.numMsgSent += 1
def connection_lost(self, exc):
self.transport = None
msg = str(exc) if exc else ''
self.disconnected.emit(msg)
def data_received(self, data):
self.hasData.emit(data)
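# Minimal usage sketch (illustrative, not part of the original file;
# assumes a TWS/gateway listening on 127.0.0.1:7497):
#
#     import asyncio
#     from ib_insync.connection import Connection
#
#     async def main():
#         conn = Connection()
#         conn.hasData += lambda data: print(f'received {len(data)} bytes')
#         conn.disconnected += lambda msg: print(f'disconnected: {msg}')
#         await conn.connectAsync('127.0.0.1', 7497)
#
#     asyncio.run(main())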

View File

@@ -0,0 +1,551 @@
"""Financial instrument types used by Interactive Brokers."""
import datetime as dt
from dataclasses import dataclass, field
from typing import List, NamedTuple, Optional
import ib_insync.util as util
@dataclass
class Contract:
"""
``Contract(**kwargs)`` can create any contract using keyword
arguments. To simplify working with contracts, there are also more
specialized contracts that take optional positional arguments.
Some examples::
Contract(conId=270639)
Stock('AMD', 'SMART', 'USD')
Stock('INTC', 'SMART', 'USD', primaryExchange='NASDAQ')
Forex('EURUSD')
CFD('IBUS30')
Future('ES', '20180921', 'GLOBEX')
Option('SPY', '20170721', 240, 'C', 'SMART')
Bond(secIdType='ISIN', secId='US03076KAA60')
Crypto('BTC', 'PAXOS', 'USD')
Args:
conId (int): The unique IB contract identifier.
symbol (str): The contract (or its underlying) symbol.
secType (str): The security type:
* 'STK' = Stock (or ETF)
* 'OPT' = Option
* 'FUT' = Future
* 'IND' = Index
* 'FOP' = Futures option
* 'CASH' = Forex pair
* 'CFD' = CFD
* 'BAG' = Combo
* 'WAR' = Warrant
* 'BOND' = Bond
* 'CMDTY' = Commodity
* 'NEWS' = News
* 'FUND' = Mutual fund
* 'CRYPTO' = Crypto currency
* 'EVENT' = Bet on an event
lastTradeDateOrContractMonth (str): The contract's last trading
day or contract month (for Options and Futures).
Strings with format YYYYMM will be interpreted as the
Contract Month whereas YYYYMMDD will be interpreted as
Last Trading Day.
strike (float): The option's strike price.
right (str): Put or Call.
Valid values are 'P', 'PUT', 'C', 'CALL', or '' for non-options.
multiplier (str): The instrument's multiplier (i.e. options, futures).
exchange (str): The destination exchange.
currency (str): The underlying's currency.
localSymbol (str): The contract's symbol within its primary exchange.
For options, this will be the OCC symbol.
primaryExchange (str): The contract's primary exchange.
For smart-routed contracts, used to identify the contract in case
of ambiguity. Should be the contract's native exchange,
e.g. ISLAND for MSFT. For exchanges whose name contains a period,
use only the part of the name before the period, e.g. ENEXT
for ENEXT.BE.
tradingClass (str): The trading class name for this contract.
Available in TWS contract description window as well.
For example, GBL Dec '13 future's trading class is "FGBL".
includeExpired (bool): If set to true, contract details requests
and historical data queries can be performed pertaining to
expired futures contracts. Expired options or other instrument
types are not available.
secIdType (str): Security identifier type. Examples for Apple:
* secIdType='ISIN', secId='US0378331005'
* secIdType='CUSIP', secId='037833100'
secId (str): Security identifier.
comboLegsDescription (str): Description of the combo legs.
comboLegs (List[ComboLeg]): The legs of a combined contract definition.
deltaNeutralContract (DeltaNeutralContract): Delta and underlying
price for Delta-Neutral combo orders.
"""
secType: str = ''
conId: int = 0
symbol: str = ''
lastTradeDateOrContractMonth: str = ''
strike: float = 0.0
right: str = ''
multiplier: str = ''
exchange: str = ''
primaryExchange: str = ''
currency: str = ''
localSymbol: str = ''
tradingClass: str = ''
includeExpired: bool = False
secIdType: str = ''
secId: str = ''
description: str = ''
issuerId: str = ''
comboLegsDescrip: str = ''
comboLegs: List['ComboLeg'] = field(default_factory=list)
deltaNeutralContract: Optional['DeltaNeutralContract'] = None
@staticmethod
def create(**kwargs) -> 'Contract':
"""
Create and return a specialized contract based on the given secType,
or a general Contract if secType is not given.
"""
secType = kwargs.get('secType', '')
cls = {
'': Contract,
'STK': Stock,
'OPT': Option,
'FUT': Future,
'CONTFUT': ContFuture,
'CASH': Forex,
'IND': Index,
'CFD': CFD,
'BOND': Bond,
'CMDTY': Commodity,
'FOP': FuturesOption,
'FUND': MutualFund,
'WAR': Warrant,
'IOPT': Warrant,
'BAG': Bag,
'CRYPTO': Crypto,
'NEWS': Contract,
'EVENT': Contract,
}.get(secType, Contract)
if cls is not Contract:
kwargs.pop('secType', '')
return cls(**kwargs)
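# Example (illustrative): create() dispatches on secType, so
#     Contract.create(secType='STK', symbol='AMD', exchange='SMART',
#                     currency='USD')
# returns a Stock instance, while an empty or unknown secType falls
# back to a plain Contract.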
def isHashable(self) -> bool:
"""
See if this contract can be hashed by conId.
Note: Bag contracts always get conId=28812380, so they're not hashable.
"""
return bool(
self.conId and self.conId != 28812380
and self.secType != 'BAG')
def __eq__(self, other):
return (
isinstance(other, Contract)
and (
self.conId and self.conId == other.conId
or util.dataclassAsDict(self) == util.dataclassAsDict(other)))
def __hash__(self):
if not self.isHashable():
raise ValueError(f'Contract {self} can\'t be hashed')
if self.secType == 'CONTFUT':
# CONTFUT gets the same conId as the front contract, invert it here
h = -self.conId
else:
h = self.conId
return h
def __repr__(self):
attrs = util.dataclassNonDefaults(self)
if self.__class__ is not Contract:
attrs.pop('secType', '')
clsName = self.__class__.__qualname__
kwargs = ', '.join(f'{k}={v!r}' for k, v in attrs.items())
return f'{clsName}({kwargs})'
__str__ = __repr__
class Stock(Contract):
def __init__(
self, symbol: str = '', exchange: str = '', currency: str = '',
**kwargs):
"""
Stock contract.
Args:
symbol: Symbol name.
exchange: Destination exchange.
currency: Underlying currency.
"""
Contract.__init__(
self, secType='STK', symbol=symbol,
exchange=exchange, currency=currency, **kwargs)
class Option(Contract):
def __init__(
self, symbol: str = '', lastTradeDateOrContractMonth: str = '',
strike: float = 0.0, right: str = '', exchange: str = '',
multiplier: str = '', currency: str = '', **kwargs):
"""
Option contract.
Args:
symbol: Symbol name.
lastTradeDateOrContractMonth: The option's last trading day
or contract month.
* YYYYMM format: To specify the contract month
* YYYYMMDD format: To specify last trading day
strike: The option's strike price.
right: Put or call option.
Valid values are 'P', 'PUT', 'C' or 'CALL'.
exchange: Destination exchange.
multiplier: The contract multiplier.
currency: Underlying currency.
"""
Contract.__init__(
self, 'OPT', symbol=symbol,
lastTradeDateOrContractMonth=lastTradeDateOrContractMonth,
strike=strike, right=right, exchange=exchange,
multiplier=multiplier, currency=currency, **kwargs)
class Future(Contract):
def __init__(
self, symbol: str = '', lastTradeDateOrContractMonth: str = '',
exchange: str = '', localSymbol: str = '', multiplier: str = '',
currency: str = '', **kwargs):
"""
Future contract.
Args:
symbol: Symbol name.
lastTradeDateOrContractMonth: The future's last trading day
or contract month.
* YYYYMM format: To specify the contract month
* YYYYMMDD format: To specify last trading day
exchange: Destination exchange.
localSymbol: The contract's symbol within its primary exchange.
multiplier: The contract multiplier.
currency: Underlying currency.
"""
Contract.__init__(
self, 'FUT', symbol=symbol,
lastTradeDateOrContractMonth=lastTradeDateOrContractMonth,
exchange=exchange, localSymbol=localSymbol,
multiplier=multiplier, currency=currency, **kwargs)
class ContFuture(Contract):
def __init__(
self, symbol: str = '', exchange: str = '', localSymbol: str = '',
multiplier: str = '', currency: str = '', **kwargs):
"""
Continuous future contract.
Args:
symbol: Symbol name.
exchange: Destination exchange.
localSymbol: The contract's symbol within its primary exchange.
multiplier: The contract multiplier.
currency: Underlying currency.
"""
Contract.__init__(
self, 'CONTFUT', symbol=symbol,
exchange=exchange, localSymbol=localSymbol,
multiplier=multiplier, currency=currency, **kwargs)
class Forex(Contract):
def __init__(
self, pair: str = '', exchange: str = 'IDEALPRO',
symbol: str = '', currency: str = '', **kwargs):
"""
Foreign exchange currency pair.
Args:
pair: Shortcut for specifying symbol and currency, like 'EURUSD'.
exchange: Destination exchange.
symbol: Base currency.
currency: Quote currency.
"""
if pair:
assert len(pair) == 6
symbol = symbol or pair[:3]
currency = currency or pair[3:]
Contract.__init__(
self, 'CASH', symbol=symbol,
exchange=exchange, currency=currency, **kwargs)
def __repr__(self):
attrs = util.dataclassNonDefaults(self)
attrs.pop('secType')
s = 'Forex('
if 'symbol' in attrs and 'currency' in attrs:
pair = attrs.pop('symbol')
pair += attrs.pop('currency')
s += "'" + pair + "'" + (", " if attrs else "")
s += ', '.join(f'{k}={v!r}' for k, v in attrs.items())
s += ')'
return s
__str__ = __repr__
def pair(self) -> str:
"""Short name of pair."""
return self.symbol + self.currency
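# Example (illustrative): Forex('EURUSD') sets symbol='EUR' and
# currency='USD', so Forex('EURUSD').pair() gives back 'EURUSD'.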
class Index(Contract):
def __init__(
self, symbol: str = '', exchange: str = '', currency: str = '',
**kwargs):
"""
Index.
Args:
symbol: Symbol name.
exchange: Destination exchange.
currency: Underlying currency.
"""
Contract.__init__(
self, 'IND', symbol=symbol,
exchange=exchange, currency=currency, **kwargs)
class CFD(Contract):
def __init__(
self, symbol: str = '', exchange: str = '', currency: str = '',
**kwargs):
"""
Contract For Difference.
Args:
symbol: Symbol name.
exchange: Destination exchange.
currency: Underlying currency.
"""
Contract.__init__(
self, 'CFD', symbol=symbol,
exchange=exchange, currency=currency, **kwargs)
class Commodity(Contract):
def __init__(
self, symbol: str = '', exchange: str = '', currency: str = '',
**kwargs):
"""
Commodity.
Args:
symbol: Symbol name.
exchange: Destination exchange.
currency: Underlying currency.
"""
Contract.__init__(
self, 'CMDTY', symbol=symbol,
exchange=exchange, currency=currency, **kwargs)
class Bond(Contract):
def __init__(self, **kwargs):
"""Bond."""
Contract.__init__(self, 'BOND', **kwargs)
class FuturesOption(Contract):
def __init__(
self, symbol: str = '', lastTradeDateOrContractMonth: str = '',
strike: float = 0.0, right: str = '', exchange: str = '',
multiplier: str = '', currency: str = '', **kwargs):
"""
Option on a futures contract.
Args:
symbol: Symbol name.
lastTradeDateOrContractMonth: The option's last trading day
or contract month.
* YYYYMM format: To specify the contract month
* YYYYMMDD format: To specify last trading day
strike: The option's strike price.
right: Put or call option.
Valid values are 'P', 'PUT', 'C' or 'CALL'.
exchange: Destination exchange.
multiplier: The contract multiplier.
currency: Underlying currency.
"""
Contract.__init__(
self, 'FOP', symbol=symbol,
lastTradeDateOrContractMonth=lastTradeDateOrContractMonth,
strike=strike, right=right, exchange=exchange,
multiplier=multiplier, currency=currency, **kwargs)
class MutualFund(Contract):
def __init__(self, **kwargs):
"""Mutual fund."""
Contract.__init__(self, 'FUND', **kwargs)
class Warrant(Contract):
def __init__(self, **kwargs):
"""Warrant option."""
Contract.__init__(self, 'WAR', **kwargs)
class Bag(Contract):
def __init__(self, **kwargs):
"""Bag contract."""
Contract.__init__(self, 'BAG', **kwargs)
class Crypto(Contract):
def __init__(
self, symbol: str = '', exchange: str = '', currency: str = '',
**kwargs):
"""
Crypto currency contract.
Args:
symbol: Symbol name.
exchange: Destination exchange.
currency: Underlying currency.
"""
Contract.__init__(
self, secType='CRYPTO', symbol=symbol,
exchange=exchange, currency=currency, **kwargs)
class TagValue(NamedTuple):
tag: str
value: str
@dataclass
class ComboLeg:
conId: int = 0
ratio: int = 0
action: str = ''
exchange: str = ''
openClose: int = 0
shortSaleSlot: int = 0
designatedLocation: str = ''
exemptCode: int = -1
@dataclass
class DeltaNeutralContract:
conId: int = 0
delta: float = 0.0
price: float = 0.0
class TradingSession(NamedTuple):
start: dt.datetime
end: dt.datetime
@dataclass
class ContractDetails:
contract: Optional[Contract] = None
marketName: str = ''
minTick: float = 0.0
orderTypes: str = ''
validExchanges: str = ''
priceMagnifier: int = 0
underConId: int = 0
longName: str = ''
contractMonth: str = ''
industry: str = ''
category: str = ''
subcategory: str = ''
timeZoneId: str = ''
tradingHours: str = ''
liquidHours: str = ''
evRule: str = ''
evMultiplier: int = 0
mdSizeMultiplier: int = 1 # obsolete
aggGroup: int = 0
underSymbol: str = ''
underSecType: str = ''
marketRuleIds: str = ''
secIdList: List[TagValue] = field(default_factory=list)
realExpirationDate: str = ''
lastTradeTime: str = ''
stockType: str = ''
minSize: float = 0.0
sizeIncrement: float = 0.0
suggestedSizeIncrement: float = 0.0
# minCashQtySize: float = 0.0
cusip: str = ''
ratings: str = ''
descAppend: str = ''
bondType: str = ''
couponType: str = ''
callable: bool = False
putable: bool = False
coupon: float = 0
convertible: bool = False
maturity: str = ''
issueDate: str = ''
nextOptionDate: str = ''
nextOptionType: str = ''
nextOptionPartial: bool = False
notes: str = ''
def tradingSessions(self) -> List[TradingSession]:
return self._parseSessions(self.tradingHours)
def liquidSessions(self) -> List[TradingSession]:
return self._parseSessions(self.liquidHours)
def _parseSessions(self, s: str) -> List[TradingSession]:
tz = util.ZoneInfo(self.timeZoneId)
sessions = []
for sess in s.split(';'):
if not sess or 'CLOSED' in sess:
continue
sessions.append(TradingSession(*[
dt.datetime.strptime(t, '%Y%m%d:%H%M').replace(tzinfo=tz)
for t in sess.split('-')]))
return sessions
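# Example (illustrative; the hours string is a made-up sample): given
# tradingHours='20230103:0930-20230103:1600;20230104:CLOSED' and
# timeZoneId='US/Eastern', _parseSessions() yields one TradingSession
# with timezone-aware start/end datetimes and skips the CLOSED day.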
@dataclass
class ContractDescription:
contract: Optional[Contract] = None
derivativeSecTypes: List[str] = field(default_factory=list)
@dataclass
class ScanData:
rank: int
contractDetails: ContractDetails
distance: str
benchmark: str
projection: str
legsStr: str

File diff suppressed because it is too large

View File

@@ -0,0 +1,127 @@
"""Access to account statement webservice."""
import logging
import time
import xml.etree.ElementTree as et
from contextlib import suppress
from urllib.request import urlopen
from ib_insync import util
from ib_insync.objects import DynamicObject
_logger = logging.getLogger('ib_insync.flexreport')
class FlexError(Exception):
pass
class FlexReport:
"""
To obtain a token:
* Login to web portal
* Go to Settings
* Click on "Configure Flex Web Service"
* Generate token
"""
data: bytes
root: et.Element
def __init__(self, token=None, queryId=None, path=None):
"""
Download a report by giving a valid ``token`` and ``queryId``,
or load from file by giving a valid ``path``.
"""
if token and queryId:
self.download(token, queryId)
elif path:
self.load(path)
def topics(self):
"""Get the set of topics that can be extracted from this report."""
return set(node.tag for node in self.root.iter() if node.attrib)
def extract(self, topic: str, parseNumbers=True) -> list:
"""
Extract items of given topic and return as list of objects.
The topic is a string like TradeConfirm, ChangeInDividendAccrual,
Order, etc.
"""
cls = type(topic, (DynamicObject,), {})
results = [cls(**node.attrib) for node in self.root.iter(topic)]
if parseNumbers:
for obj in results:
d = obj.__dict__
for k, v in d.items():
with suppress(ValueError):
d[k] = float(v)
d[k] = int(v)
return results
def df(self, topic: str, parseNumbers=True):
"""Same as extract but return the result as a pandas DataFrame."""
return util.df(self.extract(topic, parseNumbers))
def download(self, token, queryId):
"""Download report for the given ``token`` and ``queryId``."""
url = (
'https://gdcdyn.interactivebrokers.com'
f'/Universal/servlet/FlexStatementService.SendRequest?'
f't={token}&q={queryId}&v=3')
resp = urlopen(url)
data = resp.read()
root = et.fromstring(data)
elem = root.find('Status')
if elem is not None and elem.text == 'Success':
elem = root.find('ReferenceCode')
assert elem is not None
code = elem.text
elem = root.find('Url')
assert elem is not None
baseUrl = elem.text
_logger.info('Statement is being prepared...')
else:
elem = root.find('ErrorCode')
errorCode = elem.text if elem is not None else ''
elem = root.find('ErrorMessage')
errorMsg = elem.text if elem is not None else ''
raise FlexError(f'{errorCode}: {errorMsg}')
while True:
time.sleep(1)
url = f'{baseUrl}?q={code}&t={token}'
resp = urlopen(url)
self.data = resp.read()
self.root = et.fromstring(self.data)
if self.root[0].tag == 'code':
msg = self.root[0].text
if msg and msg.startswith('Statement generation in progress'):
_logger.info('still working...')
continue
else:
raise FlexError(msg)
break
_logger.info('Statement retrieved.')
def load(self, path):
"""Load report from XML file."""
with open(path, 'rb') as f:
self.data = f.read()
self.root = et.fromstring(self.data)
def save(self, path):
"""Save report to XML file."""
with open(path, 'wb') as f:
f.write(self.data)
if __name__ == '__main__':
util.logToConsole()
report = FlexReport('945692423458902392892687', '272555')
print(report.topics())
trades = report.extract('Trade')
print(trades)

File diff suppressed because it is too large

View File

@@ -0,0 +1,358 @@
"""Programmatic control over the TWS/gateway client software."""
import asyncio
import logging
import sys
from contextlib import suppress
from dataclasses import dataclass
from typing import ClassVar
from eventkit import Event
import ib_insync.util as util
from ib_insync.contract import Contract, Forex
from ib_insync.ib import IB
@dataclass
class IBC:
r"""
Programmatic control over starting and stopping TWS/Gateway
using IBC (https://github.com/IbcAlpha/IBC).
Args:
twsVersion (int): (required) The major version number for
TWS or gateway.
gateway (bool):
* True = gateway
* False = TWS
tradingMode (str): 'live' or 'paper'.
userid (str): IB account username. It is recommended to set the real
username/password in a secured IBC config file.
password (str): IB account password.
twsPath (str): Path to the TWS installation folder.
Defaults:
* Linux: ~/Jts
* OS X: ~/Applications
* Windows: C:\\Jts
twsSettingsPath (str): Path to the TWS settings folder.
Defaults:
* Linux: ~/Jts
* OS X: ~/Jts
* Windows: Not available
ibcPath (str): Path to the IBC installation folder.
Defaults:
* Linux: /opt/ibc
* OS X: /opt/ibc
* Windows: C:\\IBC
ibcIni (str): Path to the IBC configuration file.
Defaults:
* Linux: ~/ibc/config.ini
* OS X: ~/ibc/config.ini
* Windows: %%HOMEPATH%%\\Documents\\IBC\\config.ini
javaPath (str): Path to Java executable.
Default is to use the Java VM included with TWS/gateway.
fixuserid (str): FIX account user id (gateway only).
fixpassword (str): FIX account password (gateway only).
on2fatimeout (str): What to do if 2-factor authentication times
out. Can be 'restart' or 'exit'.
This is not intended to be run in a notebook.
To use IBC on Windows, the proactor (or quamash) event loop
must have been set:
.. code-block:: python
import asyncio
asyncio.set_event_loop(asyncio.ProactorEventLoop())
Example usage:
.. code-block:: python
ibc = IBC(976, gateway=True, tradingMode='live',
userid='edemo', password='demouser')
ibc.start()
IB.run()
"""
IbcLogLevel: ClassVar = logging.DEBUG
twsVersion: int = 0
gateway: bool = False
tradingMode: str = ''
twsPath: str = ''
twsSettingsPath: str = ''
ibcPath: str = ''
ibcIni: str = ''
javaPath: str = ''
userid: str = ''
password: str = ''
fixuserid: str = ''
fixpassword: str = ''
on2fatimeout: str = ''
def __post_init__(self):
self._isWindows = sys.platform == 'win32'
if not self.ibcPath:
self.ibcPath = '/opt/ibc' if not self._isWindows else 'C:\\IBC'
self._proc = None
self._monitor = None
self._logger = logging.getLogger('ib_insync.IBC')
def __enter__(self):
self.start()
return self
def __exit__(self, *_exc):
self.terminate()
def start(self):
"""Launch TWS/IBG."""
util.run(self.startAsync())
def terminate(self):
"""Terminate TWS/IBG."""
util.run(self.terminateAsync())
async def startAsync(self):
if self._proc:
return
self._logger.info('Starting')
# map from field names to cmd arguments; key=(UnixArg, WindowsArg)
args = dict(
twsVersion=('', ''),
gateway=('--gateway', '/Gateway'),
tradingMode=('--mode=', '/Mode:'),
twsPath=('--tws-path=', '/TwsPath:'),
twsSettingsPath=('--tws-settings-path=', ''),
ibcPath=('--ibc-path=', '/IbcPath:'),
ibcIni=('--ibc-ini=', '/Config:'),
javaPath=('--java-path=', '/JavaPath:'),
userid=('--user=', '/User:'),
password=('--pw=', '/PW:'),
fixuserid=('--fix-user=', '/FIXUser:'),
fixpassword=('--fix-pw=', '/FIXPW:'),
on2fatimeout=('--on2fatimeout=', '/On2FATimeout:'),
)
# create shell command
cmd = [
f'{self.ibcPath}\\scripts\\StartIBC.bat' if self._isWindows else
f'{self.ibcPath}/scripts/ibcstart.sh']
for k, v in util.dataclassAsDict(self).items():
arg = args[k][self._isWindows]
if v:
if arg.endswith('=') or arg.endswith(':'):
cmd.append(f'{arg}{v}')
elif arg:
cmd.append(arg)
else:
cmd.append(str(v))
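# Illustrative example (added commentary; the version number is
# hypothetical): on Linux, IBC(1019, gateway=True, tradingMode='paper')
# produces roughly cmd = ['/opt/ibc/scripts/ibcstart.sh', '1019',
# '--gateway', '--mode=paper', '--ibc-path=/opt/ibc'], since twsVersion
# maps to an empty arg and is appended bare.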
# run shell command
self._proc = await asyncio.create_subprocess_exec(
*cmd, stdout=asyncio.subprocess.PIPE)
self._monitor = asyncio.ensure_future(self.monitorAsync())
async def terminateAsync(self):
if not self._proc:
return
self._logger.info('Terminating')
if self._monitor:
self._monitor.cancel()
self._monitor = None
if self._isWindows:
import subprocess
subprocess.call(
['taskkill', '/F', '/T', '/PID', str(self._proc.pid)])
else:
with suppress(ProcessLookupError):
self._proc.terminate()
await self._proc.wait()
self._proc = None
async def monitorAsync(self):
while self._proc:
line = await self._proc.stdout.readline()
if not line:
break
self._logger.log(IBC.IbcLogLevel, line.strip().decode())
@dataclass
class Watchdog:
"""
Start, connect and watch over the TWS or gateway app and try to keep it
up and running. It is intended to be used in an event-driven
application that properly initializes itself upon (re-)connect.
It is not intended to be used in a notebook or in imperative-style code.
Do not expect Watchdog to magically shield you from reality. Do not use
Watchdog unless you understand what it does and doesn't do.
Args:
controller (IBC): (required) IBC instance.
ib (IB): (required) IB instance to be used. Do not connect this
instance as Watchdog takes care of that.
host (str): Used for connecting IB instance.
port (int): Used for connecting IB instance.
clientId (int): Used for connecting IB instance.
connectTimeout (float): Used for connecting IB instance.
readonly (bool): Used for connecting IB instance.
appStartupTime (float): Time (in seconds) that the app is given
to start up. Make sure that it is given ample time.
appTimeout (float): Timeout (in seconds) for network traffic idle time.
retryDelay (float): Time (in seconds) to restart app after a
previous failure.
probeContract (Contract): Contract to use for historical data
probe requests (default is EURUSD).
probeTimeout (float): Timeout (in seconds) for the probe request.
The idea is to wait until there is no traffic coming from the app for
a certain amount of time (the ``appTimeout`` parameter). This triggers
a historical request to be placed just to see if the app is still alive
and well. If it is, the watchdog continues; if not, it restarts the whole
app and reconnects. Restarting also occurs directly on errors 1100 and 100.
Example usage:
.. code-block:: python
def onConnected():
print(ib.accountValues())
ibc = IBC(974, gateway=True, tradingMode='paper')
ib = IB()
ib.connectedEvent += onConnected
watchdog = Watchdog(ibc, ib, port=4002)
watchdog.start()
ib.run()
Events:
* ``startingEvent`` (watchdog: :class:`.Watchdog`)
* ``startedEvent`` (watchdog: :class:`.Watchdog`)
* ``stoppingEvent`` (watchdog: :class:`.Watchdog`)
* ``stoppedEvent`` (watchdog: :class:`.Watchdog`)
* ``softTimeoutEvent`` (watchdog: :class:`.Watchdog`)
* ``hardTimeoutEvent`` (watchdog: :class:`.Watchdog`)
"""
events = [
'startingEvent', 'startedEvent', 'stoppingEvent', 'stoppedEvent',
'softTimeoutEvent', 'hardTimeoutEvent']
controller: IBC
ib: IB
host: str = '127.0.0.1'
port: int = 7497
clientId: int = 1
connectTimeout: float = 2
appStartupTime: float = 30
appTimeout: float = 20
retryDelay: float = 2
readonly: bool = False
account: str = ''
probeContract: Contract = Forex('EURUSD')
probeTimeout: float = 4
def __post_init__(self):
self.startingEvent = Event('startingEvent')
self.startedEvent = Event('startedEvent')
self.stoppingEvent = Event('stoppingEvent')
self.stoppedEvent = Event('stoppedEvent')
self.softTimeoutEvent = Event('softTimeoutEvent')
self.hardTimeoutEvent = Event('hardTimeoutEvent')
if not self.controller:
raise ValueError('No controller supplied')
if not self.ib:
raise ValueError('No IB instance supplied')
if self.ib.isConnected():
raise ValueError('IB instance must not be connected')
self._runner = None
self._logger = logging.getLogger('ib_insync.Watchdog')
def start(self):
self._logger.info('Starting')
self.startingEvent.emit(self)
self._runner = asyncio.ensure_future(self.runAsync())
return self._runner
def stop(self):
self._logger.info('Stopping')
self.stoppingEvent.emit(self)
self.ib.disconnect()
self._runner = None
async def runAsync(self):
def onTimeout(idlePeriod):
if not waiter.done():
waiter.set_result(None)
def onError(reqId, errorCode, errorString, contract):
if errorCode in {100, 1100} and not waiter.done():
waiter.set_exception(Warning(f'Error {errorCode}'))
def onDisconnected():
if not waiter.done():
waiter.set_exception(Warning('Disconnected'))
while self._runner:
try:
await self.controller.startAsync()
await asyncio.sleep(self.appStartupTime)
await self.ib.connectAsync(
self.host, self.port, self.clientId, self.connectTimeout,
self.readonly, self.account)
self.startedEvent.emit(self)
self.ib.setTimeout(self.appTimeout)
self.ib.timeoutEvent += onTimeout
self.ib.errorEvent += onError
self.ib.disconnectedEvent += onDisconnected
while self._runner:
waiter: asyncio.Future = asyncio.Future()
await waiter
# soft timeout, probe the app with a historical request
self._logger.debug('Soft timeout')
self.softTimeoutEvent.emit(self)
probe = self.ib.reqHistoricalDataAsync(
self.probeContract, '', '30 S', '5 secs',
'MIDPOINT', False)
bars = None
with suppress(asyncio.TimeoutError):
bars = await asyncio.wait_for(probe, self.probeTimeout)
if not bars:
self.hardTimeoutEvent.emit(self)
raise Warning('Hard timeout')
self.ib.setTimeout(self.appTimeout)
except ConnectionRefusedError:
pass
except Warning as w:
self._logger.warning(w)
except Exception as e:
self._logger.exception(e)
finally:
self.ib.timeoutEvent -= onTimeout
self.ib.errorEvent -= onError
self.ib.disconnectedEvent -= onDisconnected
await self.controller.terminateAsync()
self.stoppedEvent.emit(self)
if self._runner:
await asyncio.sleep(self.retryDelay)
if __name__ == '__main__':
ibc = IBC(1012, gateway=True, tradingMode='paper')
ib = IB()
app = Watchdog(ibc, ib, appStartupTime=15)
app.start()
IB.run()

View File

@@ -0,0 +1,499 @@
"""Object hierarchy."""
from dataclasses import dataclass, field
from datetime import date as date_, datetime
from typing import List, NamedTuple, Optional, Union
from eventkit import Event
from .contract import Contract, ScanData, TagValue
from .util import EPOCH, UNSET_DOUBLE, UNSET_INTEGER
nan = float('nan')
@dataclass
class ScannerSubscription:
numberOfRows: int = -1
instrument: str = ''
locationCode: str = ''
scanCode: str = ''
abovePrice: float = UNSET_DOUBLE
belowPrice: float = UNSET_DOUBLE
aboveVolume: int = UNSET_INTEGER
marketCapAbove: float = UNSET_DOUBLE
marketCapBelow: float = UNSET_DOUBLE
moodyRatingAbove: str = ''
moodyRatingBelow: str = ''
spRatingAbove: str = ''
spRatingBelow: str = ''
maturityDateAbove: str = ''
maturityDateBelow: str = ''
couponRateAbove: float = UNSET_DOUBLE
couponRateBelow: float = UNSET_DOUBLE
excludeConvertible: bool = False
averageOptionVolumeAbove: int = UNSET_INTEGER
scannerSettingPairs: str = ''
stockTypeFilter: str = ''
@dataclass
class SoftDollarTier:
name: str = ''
val: str = ''
displayName: str = ''
def __bool__(self):
return bool(self.name or self.val or self.displayName)
@dataclass
class Execution:
execId: str = ''
time: datetime = field(default=EPOCH)
acctNumber: str = ''
exchange: str = ''
side: str = ''
shares: float = 0.0
price: float = 0.0
permId: int = 0
clientId: int = 0
orderId: int = 0
liquidation: int = 0
cumQty: float = 0.0
avgPrice: float = 0.0
orderRef: str = ''
evRule: str = ''
evMultiplier: float = 0.0
modelCode: str = ''
lastLiquidity: int = 0
@dataclass
class CommissionReport:
execId: str = ''
commission: float = 0.0
currency: str = ''
realizedPNL: float = 0.0
yield_: float = 0.0
yieldRedemptionDate: int = 0
@dataclass
class ExecutionFilter:
clientId: int = 0
acctCode: str = ''
time: str = ''
symbol: str = ''
secType: str = ''
exchange: str = ''
side: str = ''
@dataclass
class BarData:
date: Union[date_, datetime] = EPOCH
open: float = 0.0
high: float = 0.0
low: float = 0.0
close: float = 0.0
volume: float = 0
average: float = 0.0
barCount: int = 0
@dataclass
class RealTimeBar:
time: datetime = EPOCH
endTime: int = -1
open_: float = 0.0
high: float = 0.0
low: float = 0.0
close: float = 0.0
volume: float = 0.0
wap: float = 0.0
count: int = 0
@dataclass
class TickAttrib:
canAutoExecute: bool = False
pastLimit: bool = False
preOpen: bool = False
@dataclass
class TickAttribBidAsk:
bidPastLow: bool = False
askPastHigh: bool = False
@dataclass
class TickAttribLast:
pastLimit: bool = False
unreported: bool = False
@dataclass
class HistogramData:
price: float = 0.0
count: int = 0
@dataclass
class NewsProvider:
code: str = ''
name: str = ''
@dataclass
class DepthMktDataDescription:
exchange: str = ''
secType: str = ''
listingExch: str = ''
serviceDataType: str = ''
aggGroup: int = UNSET_INTEGER
@dataclass
class PnL:
account: str = ''
modelCode: str = ''
dailyPnL: float = nan
unrealizedPnL: float = nan
realizedPnL: float = nan
@dataclass
class TradeLogEntry:
time: datetime
status: str = ''
message: str = ''
errorCode: int = 0
@dataclass
class PnLSingle:
account: str = ''
modelCode: str = ''
conId: int = 0
dailyPnL: float = nan
unrealizedPnL: float = nan
realizedPnL: float = nan
position: int = 0
value: float = nan
@dataclass
class HistoricalSession:
startDateTime: str = ''
endDateTime: str = ''
refDate: str = ''
@dataclass
class HistoricalSchedule:
startDateTime: str = ''
endDateTime: str = ''
timeZone: str = ''
sessions: List[HistoricalSession] = field(default_factory=list)
@dataclass
class WshEventData:
conId: int = UNSET_INTEGER
filter: str = ''
fillWatchlist: bool = False
fillPortfolio: bool = False
fillCompetitors: bool = False
startDate: str = ''
endDate: str = ''
totalLimit: int = UNSET_INTEGER
class AccountValue(NamedTuple):
account: str
tag: str
value: str
currency: str
modelCode: str
class TickData(NamedTuple):
time: datetime
tickType: int
price: float
size: float
class HistoricalTick(NamedTuple):
time: datetime
price: float
size: float
class HistoricalTickBidAsk(NamedTuple):
time: datetime
tickAttribBidAsk: TickAttribBidAsk
priceBid: float
priceAsk: float
sizeBid: float
sizeAsk: float
class HistoricalTickLast(NamedTuple):
time: datetime
tickAttribLast: TickAttribLast
price: float
size: float
exchange: str
specialConditions: str
class TickByTickAllLast(NamedTuple):
tickType: int
time: datetime
price: float
size: float
tickAttribLast: TickAttribLast
exchange: str
specialConditions: str
class TickByTickBidAsk(NamedTuple):
time: datetime
bidPrice: float
askPrice: float
bidSize: float
askSize: float
tickAttribBidAsk: TickAttribBidAsk
class TickByTickMidPoint(NamedTuple):
time: datetime
midPoint: float
class MktDepthData(NamedTuple):
time: datetime
position: int
marketMaker: str
operation: int
side: int
price: float
size: float
class DOMLevel(NamedTuple):
price: float
size: float
marketMaker: str
class PriceIncrement(NamedTuple):
lowEdge: float
increment: float
class PortfolioItem(NamedTuple):
contract: Contract
position: float
marketPrice: float
marketValue: float
averageCost: float
unrealizedPNL: float
realizedPNL: float
account: str
class Position(NamedTuple):
account: str
contract: Contract
position: float
avgCost: float
class Fill(NamedTuple):
contract: Contract
execution: Execution
commissionReport: CommissionReport
time: datetime
class OptionComputation(NamedTuple):
tickAttrib: int
impliedVol: Optional[float]
delta: Optional[float]
optPrice: Optional[float]
pvDividend: Optional[float]
gamma: Optional[float]
vega: Optional[float]
theta: Optional[float]
undPrice: Optional[float]
class OptionChain(NamedTuple):
exchange: str
underlyingConId: int
tradingClass: str
multiplier: str
expirations: List[str]
strikes: List[float]
class Dividends(NamedTuple):
past12Months: Optional[float]
next12Months: Optional[float]
nextDate: Optional[date_]
nextAmount: Optional[float]
class NewsArticle(NamedTuple):
articleType: int
articleText: str
class HistoricalNews(NamedTuple):
time: datetime
providerCode: str
articleId: str
headline: str
class NewsTick(NamedTuple):
timeStamp: int
providerCode: str
articleId: str
headline: str
extraData: str
class NewsBulletin(NamedTuple):
msgId: int
msgType: int
message: str
origExchange: str
class FamilyCode(NamedTuple):
accountID: str
familyCodeStr: str
class SmartComponent(NamedTuple):
bitNumber: int
exchange: str
exchangeLetter: str
class ConnectionStats(NamedTuple):
startTime: float
duration: float
numBytesRecv: int
numBytesSent: int
numMsgRecv: int
numMsgSent: int
class BarDataList(List[BarData]):
"""
List of :class:`.BarData` that also stores all request parameters.
Events:
* ``updateEvent``
(bars: :class:`.BarDataList`, hasNewBar: bool)
"""
reqId: int
contract: Contract
endDateTime: Union[datetime, date_, str, None]
durationStr: str
barSizeSetting: str
whatToShow: str
useRTH: bool
formatDate: int
keepUpToDate: bool
chartOptions: List[TagValue]
def __init__(self, *args):
super().__init__(*args)
self.updateEvent = Event('updateEvent')
def __eq__(self, other):
return self is other
def __hash__(self):
return id(self)
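# Note (added commentary): __eq__ and __hash__ use object identity
# because this mutable list subclass keeps growing as bars stream in;
# identity semantics keep it usable as a dict key and event source
# while its contents change. The same pattern is used below.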
class RealTimeBarList(List[RealTimeBar]):
"""
List of :class:`.RealTimeBar` that also stores all request parameters.
Events:
* ``updateEvent``
(bars: :class:`.RealTimeBarList`, hasNewBar: bool)
"""
reqId: int
contract: Contract
barSize: int
whatToShow: str
useRTH: bool
realTimeBarsOptions: List[TagValue]
def __init__(self, *args):
super().__init__(*args)
self.updateEvent = Event('updateEvent')
def __eq__(self, other):
return self is other
def __hash__(self):
return id(self)
class ScanDataList(List[ScanData]):
"""
List of :class:`.ScanData` that also stores all request parameters.
Events:
* ``updateEvent`` (:class:`.ScanDataList`)
"""
reqId: int
subscription: ScannerSubscription
scannerSubscriptionOptions: List[TagValue]
scannerSubscriptionFilterOptions: List[TagValue]
def __init__(self, *args):
super().__init__(*args)
self.updateEvent = Event('updateEvent')
def __eq__(self, other):
return self is other
def __hash__(self):
return id(self)
class DynamicObject:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
def __repr__(self):
clsName = self.__class__.__name__
kwargs = ', '.join(f'{k}={v!r}' for k, v in self.__dict__.items())
return f'{clsName}({kwargs})'
class FundamentalRatios(DynamicObject):
"""
See:
https://web.archive.org/web/20200725010343/https://interactivebrokers.github.io/tws-api/fundamental_ratios_tags.html
"""
pass

View File

@@ -0,0 +1,412 @@
"""Order types used by Interactive Brokers."""
from dataclasses import dataclass, field
from typing import ClassVar, List, NamedTuple, Set
from eventkit import Event
from .contract import Contract, TagValue
from .objects import Fill, SoftDollarTier, TradeLogEntry
from .util import UNSET_DOUBLE, UNSET_INTEGER, dataclassNonDefaults
@dataclass
class Order:
"""
Order for trading contracts.
https://interactivebrokers.github.io/tws-api/available_orders.html
"""
orderId: int = 0
clientId: int = 0
permId: int = 0
action: str = ''
totalQuantity: float = 0.0
orderType: str = ''
lmtPrice: float = UNSET_DOUBLE
auxPrice: float = UNSET_DOUBLE
tif: str = ''
activeStartTime: str = ''
activeStopTime: str = ''
ocaGroup: str = ''
ocaType: int = 0
orderRef: str = ''
transmit: bool = True
parentId: int = 0
blockOrder: bool = False
sweepToFill: bool = False
displaySize: int = 0
triggerMethod: int = 0
outsideRth: bool = False
hidden: bool = False
goodAfterTime: str = ''
goodTillDate: str = ''
rule80A: str = ''
allOrNone: bool = False
minQty: int = UNSET_INTEGER
percentOffset: float = UNSET_DOUBLE
overridePercentageConstraints: bool = False
trailStopPrice: float = UNSET_DOUBLE
trailingPercent: float = UNSET_DOUBLE
faGroup: str = ''
faProfile: str = ''
faMethod: str = ''
faPercentage: str = ''
designatedLocation: str = ''
openClose: str = "O"
origin: int = 0
shortSaleSlot: int = 0
exemptCode: int = -1
discretionaryAmt: float = 0.0
eTradeOnly: bool = False
firmQuoteOnly: bool = False
nbboPriceCap: float = UNSET_DOUBLE
optOutSmartRouting: bool = False
auctionStrategy: int = 0
startingPrice: float = UNSET_DOUBLE
stockRefPrice: float = UNSET_DOUBLE
delta: float = UNSET_DOUBLE
stockRangeLower: float = UNSET_DOUBLE
stockRangeUpper: float = UNSET_DOUBLE
randomizePrice: bool = False
randomizeSize: bool = False
volatility: float = UNSET_DOUBLE
volatilityType: int = UNSET_INTEGER
deltaNeutralOrderType: str = ''
deltaNeutralAuxPrice: float = UNSET_DOUBLE
deltaNeutralConId: int = 0
deltaNeutralSettlingFirm: str = ''
deltaNeutralClearingAccount: str = ''
deltaNeutralClearingIntent: str = ''
deltaNeutralOpenClose: str = ''
deltaNeutralShortSale: bool = False
deltaNeutralShortSaleSlot: int = 0
deltaNeutralDesignatedLocation: str = ''
continuousUpdate: bool = False
referencePriceType: int = UNSET_INTEGER
basisPoints: float = UNSET_DOUBLE
basisPointsType: int = UNSET_INTEGER
scaleInitLevelSize: int = UNSET_INTEGER
scaleSubsLevelSize: int = UNSET_INTEGER
scalePriceIncrement: float = UNSET_DOUBLE
scalePriceAdjustValue: float = UNSET_DOUBLE
scalePriceAdjustInterval: int = UNSET_INTEGER
scaleProfitOffset: float = UNSET_DOUBLE
scaleAutoReset: bool = False
scaleInitPosition: int = UNSET_INTEGER
scaleInitFillQty: int = UNSET_INTEGER
scaleRandomPercent: bool = False
scaleTable: str = ''
hedgeType: str = ''
hedgeParam: str = ''
account: str = ''
settlingFirm: str = ''
clearingAccount: str = ''
clearingIntent: str = ''
algoStrategy: str = ''
algoParams: List[TagValue] = field(default_factory=list)
smartComboRoutingParams: List[TagValue] = field(default_factory=list)
algoId: str = ''
whatIf: bool = False
notHeld: bool = False
solicited: bool = False
modelCode: str = ''
orderComboLegs: List['OrderComboLeg'] = field(default_factory=list)
orderMiscOptions: List[TagValue] = field(default_factory=list)
referenceContractId: int = 0
peggedChangeAmount: float = 0.0
isPeggedChangeAmountDecrease: bool = False
referenceChangeAmount: float = 0.0
referenceExchangeId: str = ''
adjustedOrderType: str = ''
triggerPrice: float = UNSET_DOUBLE
adjustedStopPrice: float = UNSET_DOUBLE
adjustedStopLimitPrice: float = UNSET_DOUBLE
adjustedTrailingAmount: float = UNSET_DOUBLE
adjustableTrailingUnit: int = 0
lmtPriceOffset: float = UNSET_DOUBLE
conditions: List['OrderCondition'] = field(default_factory=list)
conditionsCancelOrder: bool = False
conditionsIgnoreRth: bool = False
extOperator: str = ''
softDollarTier: SoftDollarTier = field(default_factory=SoftDollarTier)
cashQty: float = UNSET_DOUBLE
mifid2DecisionMaker: str = ''
mifid2DecisionAlgo: str = ''
mifid2ExecutionTrader: str = ''
mifid2ExecutionAlgo: str = ''
dontUseAutoPriceForHedge: bool = False
isOmsContainer: bool = False
discretionaryUpToLimitPrice: bool = False
autoCancelDate: str = ''
filledQuantity: float = UNSET_DOUBLE
refFuturesConId: int = 0
autoCancelParent: bool = False
shareholder: str = ''
imbalanceOnly: bool = False
routeMarketableToBbo: bool = False
parentPermId: int = 0
usePriceMgmtAlgo: bool = False
duration: int = UNSET_INTEGER
postToAts: int = UNSET_INTEGER
advancedErrorOverride: str = ''
manualOrderTime: str = ''
minTradeQty: int = UNSET_INTEGER
minCompeteSize: int = UNSET_INTEGER
competeAgainstBestOffset: float = UNSET_DOUBLE
midOffsetAtWhole: float = UNSET_DOUBLE
midOffsetAtHalf: float = UNSET_DOUBLE
def __repr__(self):
attrs = dataclassNonDefaults(self)
if self.__class__ is not Order:
attrs.pop('orderType', None)
if not self.softDollarTier:
attrs.pop('softDollarTier')
clsName = self.__class__.__qualname__
kwargs = ', '.join(
f'{k}={v!r}' for k, v in attrs.items())
return f'{clsName}({kwargs})'
__str__ = __repr__
def __eq__(self, other):
return self is other
def __hash__(self):
return id(self)
class LimitOrder(Order):
def __init__(self, action: str, totalQuantity: float, lmtPrice: float,
**kwargs):
Order.__init__(
self, orderType='LMT', action=action,
totalQuantity=totalQuantity, lmtPrice=lmtPrice, **kwargs)
class MarketOrder(Order):
def __init__(self, action: str, totalQuantity: float, **kwargs):
Order.__init__(
self, orderType='MKT', action=action,
totalQuantity=totalQuantity, **kwargs)
class StopOrder(Order):
def __init__(self, action: str, totalQuantity: float, stopPrice: float,
**kwargs):
Order.__init__(
self, orderType='STP', action=action,
totalQuantity=totalQuantity, auxPrice=stopPrice, **kwargs)
class StopLimitOrder(Order):
def __init__(self, action: str, totalQuantity: float, lmtPrice: float,
stopPrice: float, **kwargs):
Order.__init__(
self, orderType='STP LMT', action=action,
totalQuantity=totalQuantity, lmtPrice=lmtPrice,
auxPrice=stopPrice, **kwargs)
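# Example (illustrative): LimitOrder('BUY', 100, 50.0) is equivalent to
# Order(orderType='LMT', action='BUY', totalQuantity=100, lmtPrice=50.0);
# its repr omits orderType since the subclass already implies it.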
@dataclass
class OrderStatus:
orderId: int = 0
status: str = ''
filled: float = 0.0
remaining: float = 0.0
avgFillPrice: float = 0.0
permId: int = 0
parentId: int = 0
lastFillPrice: float = 0.0
clientId: int = 0
whyHeld: str = ''
mktCapPrice: float = 0.0
PendingSubmit: ClassVar[str] = 'PendingSubmit'
PendingCancel: ClassVar[str] = 'PendingCancel'
PreSubmitted: ClassVar[str] = 'PreSubmitted'
Submitted: ClassVar[str] = 'Submitted'
ApiPending: ClassVar[str] = 'ApiPending'
ApiCancelled: ClassVar[str] = 'ApiCancelled'
Cancelled: ClassVar[str] = 'Cancelled'
Filled: ClassVar[str] = 'Filled'
Inactive: ClassVar[str] = 'Inactive'
DoneStates: ClassVar[Set[str]] = {'Filled', 'Cancelled', 'ApiCancelled'}
ActiveStates: ClassVar[Set[str]] = {
'PendingSubmit', 'ApiPending', 'PreSubmitted', 'Submitted'}
@dataclass
class OrderState:
status: str = ''
initMarginBefore: str = ''
maintMarginBefore: str = ''
equityWithLoanBefore: str = ''
initMarginChange: str = ''
maintMarginChange: str = ''
equityWithLoanChange: str = ''
initMarginAfter: str = ''
maintMarginAfter: str = ''
equityWithLoanAfter: str = ''
commission: float = UNSET_DOUBLE
minCommission: float = UNSET_DOUBLE
maxCommission: float = UNSET_DOUBLE
commissionCurrency: str = ''
warningText: str = ''
completedTime: str = ''
completedStatus: str = ''
@dataclass
class OrderComboLeg:
price: float = UNSET_DOUBLE
@dataclass
class Trade:
"""
Trade keeps track of an order, its status and all its fills.
Events:
* ``statusEvent`` (trade: :class:`.Trade`)
* ``modifyEvent`` (trade: :class:`.Trade`)
* ``fillEvent`` (trade: :class:`.Trade`, fill: :class:`.Fill`)
* ``commissionReportEvent`` (trade: :class:`.Trade`,
fill: :class:`.Fill`, commissionReport: :class:`.CommissionReport`)
* ``filledEvent`` (trade: :class:`.Trade`)
* ``cancelEvent`` (trade: :class:`.Trade`)
* ``cancelledEvent`` (trade: :class:`.Trade`)
"""
events: ClassVar = (
'statusEvent', 'modifyEvent', 'fillEvent',
'commissionReportEvent', 'filledEvent',
'cancelEvent', 'cancelledEvent')
contract: Contract = field(default_factory=Contract)
order: Order = field(default_factory=Order)
orderStatus: 'OrderStatus' = field(default_factory=OrderStatus)
fills: List[Fill] = field(default_factory=list)
log: List[TradeLogEntry] = field(default_factory=list)
advancedError: str = ''
def __post_init__(self):
self.statusEvent = Event('statusEvent')
self.modifyEvent = Event('modifyEvent')
self.fillEvent = Event('fillEvent')
self.commissionReportEvent = Event('commissionReportEvent')
self.filledEvent = Event('filledEvent')
self.cancelEvent = Event('cancelEvent')
self.cancelledEvent = Event('cancelledEvent')
def isActive(self):
"""True if eligible for execution, false otherwise."""
return self.orderStatus.status in OrderStatus.ActiveStates
def isDone(self):
"""True if completely filled or cancelled, false otherwise."""
return self.orderStatus.status in OrderStatus.DoneStates
def filled(self):
"""Number of shares filled."""
fills = self.fills
if self.contract.secType == 'BAG':
# don't count fills for the leg contracts
fills = [f for f in fills if f.contract.secType == 'BAG']
return sum(f.execution.shares for f in fills)
def remaining(self):
"""Number of shares remaining to be filled."""
return self.order.totalQuantity - self.filled()
class BracketOrder(NamedTuple):
parent: Order
takeProfit: Order
stopLoss: Order
@dataclass
class OrderCondition:
@staticmethod
def createClass(condType):
d = {
1: PriceCondition,
3: TimeCondition,
4: MarginCondition,
5: ExecutionCondition,
6: VolumeCondition,
7: PercentChangeCondition}
return d[condType]
def And(self):
self.conjunction = 'a'
return self
def Or(self):
self.conjunction = 'o'
return self
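# Example (illustrative; the conId is hypothetical): conditions are
# chained with And()/Or() and attached to an order, e.g.
#     cond = PriceCondition(price=100.0, conId=12345, exch='SMART').And()
#     order.conditions = [cond]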
@dataclass
class PriceCondition(OrderCondition):
condType: int = 1
conjunction: str = 'a'
isMore: bool = True
price: float = 0.0
conId: int = 0
exch: str = ''
triggerMethod: int = 0
@dataclass
class TimeCondition(OrderCondition):
condType: int = 3
conjunction: str = 'a'
isMore: bool = True
time: str = ''
@dataclass
class MarginCondition(OrderCondition):
condType: int = 4
conjunction: str = 'a'
isMore: bool = True
percent: int = 0
@dataclass
class ExecutionCondition(OrderCondition):
condType: int = 5
conjunction: str = 'a'
secType: str = ''
exch: str = ''
symbol: str = ''
@dataclass
class VolumeCondition(OrderCondition):
condType: int = 6
conjunction: str = 'a'
isMore: bool = True
volume: int = 0
conId: int = 0
exch: str = ''
@dataclass
class PercentChangeCondition(OrderCondition):
condType: int = 7
conjunction: str = 'a'
isMore: bool = True
changePercent: float = 0.0
conId: int = 0
exch: str = ''

View File

@@ -0,0 +1,362 @@
"""Access to realtime market information."""
from dataclasses import dataclass, field
from datetime import datetime
from typing import ClassVar, List, Optional, Union
from eventkit import Event, Op
from ib_insync.contract import Contract
from ib_insync.objects import (
DOMLevel, Dividends, FundamentalRatios, MktDepthData,
OptionComputation, TickByTickAllLast, TickByTickBidAsk, TickByTickMidPoint,
TickData)
from ib_insync.util import dataclassRepr, isNan
nan = float('nan')
@dataclass
class Ticker:
"""
Current market data such as bid, ask, last price, etc. for a contract.
Streaming level-1 ticks of type :class:`.TickData` are stored in
the ``ticks`` list.
Streaming level-2 ticks of type :class:`.MktDepthData` are stored in the
``domTicks`` list. The order book (DOM) is available as lists of
:class:`.DOMLevel` in ``domBids`` and ``domAsks``.
Streaming tick-by-tick ticks are stored in ``tickByTicks``.
For options, the :class:`.OptionComputation` values for the bid, ask and
last price are stored in the ``bidGreeks``, ``askGreeks`` and
``lastGreeks`` attributes respectively. There is also ``modelGreeks``,
which conveys the greeks as calculated by Interactive Brokers' option model.
Events:
* ``updateEvent`` (ticker: :class:`.Ticker`)
"""
events: ClassVar = ('updateEvent',)
contract: Optional[Contract] = None
time: Optional[datetime] = None
marketDataType: int = 1
minTick: float = nan
bid: float = nan
bidSize: float = nan
bidExchange: str = ''
ask: float = nan
askSize: float = nan
askExchange: str = ''
last: float = nan
lastSize: float = nan
lastExchange: str = ''
prevBid: float = nan
prevBidSize: float = nan
prevAsk: float = nan
prevAskSize: float = nan
prevLast: float = nan
prevLastSize: float = nan
volume: float = nan
open: float = nan
high: float = nan
low: float = nan
close: float = nan
vwap: float = nan
low13week: float = nan
high13week: float = nan
low26week: float = nan
high26week: float = nan
low52week: float = nan
high52week: float = nan
bidYield: float = nan
askYield: float = nan
lastYield: float = nan
markPrice: float = nan
halted: float = nan
rtHistVolatility: float = nan
rtVolume: float = nan
rtTradeVolume: float = nan
rtTime: Optional[datetime] = None
avVolume: float = nan
tradeCount: float = nan
tradeRate: float = nan
volumeRate: float = nan
shortableShares: float = nan
indexFuturePremium: float = nan
futuresOpenInterest: float = nan
putOpenInterest: float = nan
callOpenInterest: float = nan
putVolume: float = nan
callVolume: float = nan
avOptionVolume: float = nan
histVolatility: float = nan
impliedVolatility: float = nan
dividends: Optional[Dividends] = None
fundamentalRatios: Optional[FundamentalRatios] = None
ticks: List[TickData] = field(default_factory=list)
tickByTicks: List[Union[
TickByTickAllLast, TickByTickBidAsk, TickByTickMidPoint]] = \
field(default_factory=list)
domBids: List[DOMLevel] = field(default_factory=list)
domAsks: List[DOMLevel] = field(default_factory=list)
domTicks: List[MktDepthData] = field(default_factory=list)
bidGreeks: Optional[OptionComputation] = None
askGreeks: Optional[OptionComputation] = None
lastGreeks: Optional[OptionComputation] = None
modelGreeks: Optional[OptionComputation] = None
auctionVolume: float = nan
auctionPrice: float = nan
auctionImbalance: float = nan
regulatoryImbalance: float = nan
bboExchange: str = ''
snapshotPermissions: int = 0
def __post_init__(self):
self.updateEvent = TickerUpdateEvent('updateEvent')
def __eq__(self, other):
return self is other
def __hash__(self):
return id(self)
__repr__ = dataclassRepr
__str__ = dataclassRepr
def hasBidAsk(self) -> bool:
"""See if this ticker has a valid bid and ask."""
return (
self.bid != -1 and not isNan(self.bid) and self.bidSize > 0
and self.ask != -1 and not isNan(self.ask) and self.askSize > 0)
def midpoint(self) -> float:
"""
Return average of bid and ask, or NaN if no valid bid and ask
are available.
"""
return (self.bid + self.ask) * 0.5 if self.hasBidAsk() else nan
def marketPrice(self) -> float:
"""
Return the first available of:
* the last price, if it lies within the current bid/ask or if no bid/ask is available;
* the bid/ask midpoint (average of bid and ask).
"""
if self.hasBidAsk():
if self.bid <= self.last <= self.ask:
price = self.last
else:
price = self.midpoint()
else:
price = self.last
return price
class TickerUpdateEvent(Event):
__slots__ = ()
def trades(self) -> "Tickfilter":
"""Emit trade ticks."""
return Tickfilter((4, 5, 48, 68, 71), self)
def bids(self) -> "Tickfilter":
"""Emit bid ticks."""
return Tickfilter((0, 1, 66, 69), self)
def asks(self) -> "Tickfilter":
"""Emit ask ticks."""
return Tickfilter((2, 3, 67, 70), self)
def bidasks(self) -> "Tickfilter":
"""Emit bid and ask ticks."""
return Tickfilter((0, 1, 66, 69, 2, 3, 67, 70), self)
def midpoints(self) -> "Tickfilter":
"""Emit midpoint ticks."""
return Midpoints((), self)
class Tickfilter(Op):
"""Tick filtering event operators that ``emit(time, price, size)``."""
__slots__ = ('_tickTypes',)
def __init__(self, tickTypes, source=None):
Op.__init__(self, source)
self._tickTypes = set(tickTypes)
def on_source(self, ticker):
for t in ticker.ticks:
if t.tickType in self._tickTypes:
self.emit(t.time, t.price, t.size)
def timebars(self, timer: Event) -> "TimeBars":
"""
Aggregate ticks into time bars, where the timing of new bars
is derived from a timer event.
Emits a completed :class:`Bar`.
This event stores a :class:`BarList` of all created bars in the
``bars`` property.
Args:
timer: Event for timing when a new bar starts.
"""
return TimeBars(timer, self)
def tickbars(self, count: int) -> "TickBars":
"""
Aggregate ticks into bars that have the same number of ticks.
Emits a completed :class:`Bar`.
This event stores a :class:`BarList` of all created bars in the
``bars`` property.
Args:
count: Number of ticks to use to form one bar.
"""
return TickBars(count, self)
def volumebars(self, volume: int) -> "VolumeBars":
"""
Aggregate ticks into bars that have the same volume.
Emits a completed :class:`Bar`.
This event stores a :class:`BarList` of all created bars in the
``bars`` property.
Args:
volume: Volume amount to accumulate into one bar.
"""
return VolumeBars(volume, self)
class Midpoints(Tickfilter):
__slots__ = ()
def on_source(self, ticker):
if ticker.ticks:
self.emit(ticker.time, ticker.midpoint(), 0)
@dataclass
class Bar:
time: Optional[datetime]
open: float = nan
high: float = nan
low: float = nan
close: float = nan
volume: int = 0
count: int = 0
class BarList(List[Bar]):
def __init__(self, *args):
super().__init__(*args)
self.updateEvent = Event('updateEvent')
def __eq__(self, other):
return self is other
def __hash__(self):
return id(self)
class TimeBars(Op):
__slots__ = ('_timer', 'bars',)
__doc__ = Tickfilter.timebars.__doc__
bars: BarList
def __init__(self, timer, source=None):
Op.__init__(self, source)
self._timer = timer
self._timer.connect(self._on_timer, None, self._on_timer_done)
self.bars = BarList()
def on_source(self, time, price, size):
if not self.bars:
return
bar = self.bars[-1]
if isNan(bar.open):
bar.open = bar.high = bar.low = price
bar.high = max(bar.high, price)
bar.low = min(bar.low, price)
bar.close = price
bar.volume += size
bar.count += 1
self.bars.updateEvent.emit(self.bars, False)
def _on_timer(self, time):
if self.bars:
bar = self.bars[-1]
if isNan(bar.close) and len(self.bars) > 1:
bar.open = bar.high = bar.low = bar.close = \
self.bars[-2].close
self.bars.updateEvent.emit(self.bars, True)
self.emit(bar)
self.bars.append(Bar(time))
def _on_timer_done(self, timer):
self._timer = None
self.set_done()
class TickBars(Op):
__slots__ = ('_count', 'bars')
__doc__ = Tickfilter.tickbars.__doc__
bars: BarList
def __init__(self, count, source=None):
Op.__init__(self, source)
self._count = count
self.bars = BarList()
def on_source(self, time, price, size):
if not self.bars or self.bars[-1].count == self._count:
bar = Bar(time, price, price, price, price, size, 1)
self.bars.append(bar)
else:
bar = self.bars[-1]
bar.high = max(bar.high, price)
bar.low = min(bar.low, price)
bar.close = price
bar.volume += size
bar.count += 1
if bar.count == self._count:
self.bars.updateEvent.emit(self.bars, True)
self.emit(self.bars)
class VolumeBars(Op):
__slots__ = ('_volume', 'bars')
__doc__ = Tickfilter.volumebars.__doc__
bars: BarList
def __init__(self, volume, source=None):
Op.__init__(self, source)
self._volume = volume
self.bars = BarList()
def on_source(self, time, price, size):
if not self.bars or self.bars[-1].volume >= self._volume:
bar = Bar(time, price, price, price, price, size, 1)
self.bars.append(bar)
else:
bar = self.bars[-1]
bar.high = max(bar.high, price)
bar.low = min(bar.low, price)
bar.close = price
bar.volume += size
bar.count += 1
if bar.volume >= self._volume:
self.bars.updateEvent.emit(self.bars, True)
self.emit(self.bars)
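
A usage sketch (editor's addition, not part of the diff): ``Ticker.updateEvent`` feeds the ``Tickfilter`` operators above, which in turn can be aggregated into bars. A minimal example, assuming a local TWS/Gateway listening on port 7497:

# Minimal sketch: filter bid/ask ticks for EURUSD and close a bar after
# every 10 ticks, printing each completed bar.
from ib_insync import IB, Forex

ib = IB()
ib.connect('127.0.0.1', 7497, clientId=1)   # assumed connection settings

ticker = ib.reqMktData(Forex('EURUSD'), '', False, False)

# bidasks() passes only bid/ask tick types; tickbars(10) emits the
# BarList each time a 10-tick bar completes.
bars = ticker.updateEvent.bidasks().tickbars(10)
bars += lambda barList: print('completed bar:', barList[-1])

ib.sleep(60)    # keep the event loop running while ticks arrive
ib.disconnect()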

View File

@@ -0,0 +1,550 @@
"""Utilities."""
import asyncio
import datetime as dt
import logging
import math
import signal
import sys
import time
from dataclasses import fields, is_dataclass
from typing import (
AsyncIterator, Awaitable, Callable, Iterator, List, Optional, Union)
import eventkit as ev
try:
from zoneinfo import ZoneInfo
except ImportError:
from backports.zoneinfo import ZoneInfo # type: ignore
globalErrorEvent = ev.Event()
"""
Event to emit global exceptions.
"""
EPOCH = dt.datetime(1970, 1, 1, tzinfo=dt.timezone.utc)
UNSET_INTEGER = 2 ** 31 - 1
UNSET_DOUBLE = sys.float_info.max
Time_t = Union[dt.time, dt.datetime]
def df(objs, labels: Optional[List[str]] = None):
"""
Create pandas DataFrame from the sequence of same-type objects.
Args:
labels: If supplied, retain only the given labels and drop the rest.
"""
import pandas as pd
from .objects import DynamicObject
if objs:
objs = list(objs)
obj = objs[0]
if is_dataclass(obj):
df = pd.DataFrame.from_records(dataclassAsTuple(o) for o in objs)
df.columns = [field.name for field in fields(obj)]
elif isinstance(obj, DynamicObject):
df = pd.DataFrame.from_records(o.__dict__ for o in objs)
else:
df = pd.DataFrame.from_records(objs)
if isinstance(obj, tuple):
_fields = getattr(obj, '_fields', None)
if _fields:
# assume it's a namedtuple
df.columns = _fields
else:
df = None
if labels and df is not None:
exclude = [label for label in df if label not in labels]
df = df.drop(exclude, axis=1)
return df
def dataclassAsDict(obj) -> dict:
"""
Return dataclass values as ``dict``.
This is a non-recursive variant of ``dataclasses.asdict``.
"""
if not is_dataclass(obj):
raise TypeError(f'Object {obj} is not a dataclass')
return {field.name: getattr(obj, field.name) for field in fields(obj)}
def dataclassAsTuple(obj) -> tuple:
"""
Return dataclass values as ``tuple``.
This is a non-recursive variant of ``dataclasses.astuple``.
"""
if not is_dataclass(obj):
raise TypeError(f'Object {obj} is not a dataclass')
return tuple(getattr(obj, field.name) for field in fields(obj))
def dataclassNonDefaults(obj) -> dict:
"""
For a ``dataclass`` instance get the fields that are different from the
default values and return as ``dict``.
"""
if not is_dataclass(obj):
raise TypeError(f'Object {obj} is not a dataclass')
values = [getattr(obj, field.name) for field in fields(obj)]
return {
field.name: value for field, value in zip(fields(obj), values)
if value != field.default
and value == value  # NaN compares unequal to itself, so NaN values are skipped
and not (isinstance(value, list) and value == [])}
def dataclassUpdate(obj, *srcObjs, **kwargs) -> object:
"""
Update fields of the given ``dataclass`` object from zero or more
``dataclass`` source objects and/or from keyword arguments.
"""
if not is_dataclass(obj):
raise TypeError(f'Object {obj} is not a dataclass')
for srcObj in srcObjs:
obj.__dict__.update(dataclassAsDict(srcObj))
obj.__dict__.update(**kwargs)
return obj
def dataclassRepr(obj) -> str:
"""
Provide a culled representation of the given ``dataclass`` instance,
showing only the fields with a non-default value.
"""
attrs = dataclassNonDefaults(obj)
clsName = obj.__class__.__qualname__
kwargs = ', '.join(f'{k}={v!r}' for k, v in attrs.items())
return f'{clsName}({kwargs})'
def isnamedtupleinstance(x):
"""From https://stackoverflow.com/a/2166841/6067848"""
t = type(x)
b = t.__bases__
if len(b) != 1 or b[0] != tuple:
return False
f = getattr(t, '_fields', None)
if not isinstance(f, tuple):
return False
return all(type(n) == str for n in f)
def tree(obj):
"""
Convert object to a tree of lists, dicts and simple values.
The result can be serialized to JSON.
"""
if isinstance(obj, (bool, int, float, str, bytes)):
return obj
elif isinstance(obj, (dt.date, dt.time)):
return obj.isoformat()
elif isinstance(obj, dict):
return {k: tree(v) for k, v in obj.items()}
elif isnamedtupleinstance(obj):
return {f: tree(getattr(obj, f)) for f in obj._fields}
elif isinstance(obj, (list, tuple, set)):
return [tree(i) for i in obj]
elif is_dataclass(obj):
return {obj.__class__.__qualname__: tree(dataclassNonDefaults(obj))}
else:
return str(obj)
def barplot(bars, title='', upColor='blue', downColor='red'):
"""
Create candlestick plot for the given bars. The bars can be given as
a DataFrame or as a list of bar objects.
"""
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
if isinstance(bars, pd.DataFrame):
ohlcTups = [
tuple(v) for v in bars[['open', 'high', 'low', 'close']].values]
elif bars and hasattr(bars[0], 'open_'):
ohlcTups = [(b.open_, b.high, b.low, b.close) for b in bars]
else:
ohlcTups = [(b.open, b.high, b.low, b.close) for b in bars]
fig, ax = plt.subplots()
ax.set_title(title)
ax.grid(True)
fig.set_size_inches(10, 6)
for n, (open_, high, low, close) in enumerate(ohlcTups):
if close >= open_:
color = upColor
bodyHi, bodyLo = close, open_
else:
color = downColor
bodyHi, bodyLo = open_, close
line = Line2D(
xdata=(n, n),
ydata=(low, bodyLo),
color=color,
linewidth=1)
ax.add_line(line)
line = Line2D(
xdata=(n, n),
ydata=(high, bodyHi),
color=color,
linewidth=1)
ax.add_line(line)
rect = Rectangle(
xy=(n - 0.3, bodyLo),
width=0.6,
height=bodyHi - bodyLo,
edgecolor=color,
facecolor=color,
alpha=0.4,
antialiased=True
)
ax.add_patch(rect)
ax.autoscale_view()
return fig
def allowCtrlC():
"""Allow Control-C to end program."""
signal.signal(signal.SIGINT, signal.SIG_DFL)
def logToFile(path, level=logging.INFO):
"""Create a log handler that logs to the given file."""
logger = logging.getLogger()
if logger.handlers:
logging.getLogger('ib_insync').setLevel(level)
else:
logger.setLevel(level)
formatter = logging.Formatter(
'%(asctime)s %(name)s %(levelname)s %(message)s')
handler = logging.FileHandler(path)
handler.setFormatter(formatter)
logger.addHandler(handler)
def logToConsole(level=logging.INFO):
"""Create a log handler that logs to the console."""
logger = logging.getLogger()
stdHandlers = [
h for h in logger.handlers
if type(h) is logging.StreamHandler and h.stream is sys.stderr]
if stdHandlers:
# if a standard stream handler already exists, use it and
# set the log level for the ib_insync namespace only
logging.getLogger('ib_insync').setLevel(level)
else:
# else create a new handler
logger.setLevel(level)
formatter = logging.Formatter(
'%(asctime)s %(name)s %(levelname)s %(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger.addHandler(handler)
def isNan(x: float) -> bool:
"""Not a number test."""
return x != x
def formatSI(n: float) -> str:
"""Format the integer or float n to 3 significant digits + SI prefix."""
s = ''
if n < 0:
n = -n
s += '-'
if type(n) is int and n < 1000:
s += str(n) + ' '  # append so the '-' sign of negative values is kept
elif n < 1e-22:
s = '0.00 '
else:
assert n < 9.99e26
log = int(math.floor(math.log10(n)))
i, j = divmod(log, 3)
for _try in range(2):
templ = '%.{}f'.format(2 - j)
val = templ % (n * 10 ** (-3 * i))
if val != '1000':
break
i += 1
j = 0
s += val + ' '
if i != 0:
s += 'yzafpnum kMGTPEZY'[i + 8]
return s
class timeit:
"""Context manager for timing."""
def __init__(self, title='Run'):
self.title = title
def __enter__(self):
self.t0 = time.time()
def __exit__(self, *_args):
print(self.title + ' took ' + formatSI(time.time() - self.t0) + 's')
def run(*awaitables: Awaitable, timeout: Optional[float] = None):
"""
By default run the event loop forever.
When awaitables (like Tasks, Futures or coroutines) are given then
run the event loop until each has completed and return their results.
An optional timeout (in seconds) can be given that will raise
asyncio.TimeoutError if the awaitables are not ready within the
timeout period.
"""
loop = getLoop()
if not awaitables:
if loop.is_running():
return
loop.run_forever()
result = None
if sys.version_info >= (3, 7):
all_tasks = asyncio.all_tasks(loop) # type: ignore
else:
all_tasks = asyncio.Task.all_tasks() # type: ignore
if all_tasks:
# cancel pending tasks
f = asyncio.gather(*all_tasks)
f.cancel()
try:
loop.run_until_complete(f)
except asyncio.CancelledError:
pass
else:
if len(awaitables) == 1:
future = awaitables[0]
else:
future = asyncio.gather(*awaitables)
if timeout:
future = asyncio.wait_for(future, timeout)
task = asyncio.ensure_future(future)
def onError(_):
task.cancel()
globalErrorEvent.connect(onError)
try:
result = loop.run_until_complete(task)
except asyncio.CancelledError as e:
raise globalErrorEvent.value() or e
finally:
globalErrorEvent.disconnect(onError)
return result
def _fillDate(time: Time_t) -> dt.datetime:
# use today if date is absent
if isinstance(time, dt.time):
t = dt.datetime.combine(dt.date.today(), time)
else:
t = time
return t
def schedule(time: Time_t, callback: Callable, *args):
"""
Schedule the callback to be run at the given time with
the given arguments.
This will return the Event Handle.
Args:
time: Time to run callback. If given as :py:class:`datetime.time`
then use today as date.
callback: Callable scheduled to run.
args: Arguments to call the callback with.
"""
t = _fillDate(time)
now = dt.datetime.now(t.tzinfo)
delay = (t - now).total_seconds()
loop = getLoop()
return loop.call_later(delay, callback, *args)
def sleep(secs: float = 0.02) -> bool:
"""
Wait for the given number of seconds while everything still keeps
processing in the background. Never use time.sleep().
Args:
secs (float): Time in seconds to wait.
"""
run(asyncio.sleep(secs))
return True
def timeRange(start: Time_t, end: Time_t, step: float) \
-> Iterator[dt.datetime]:
"""
Iterator that waits periodically until certain time points are
reached while yielding those time points.
Args:
start: Start time, can be specified as datetime.datetime,
or as datetime.time in which case today is used as the date
end: End time, can be specified as datetime.datetime,
or as datetime.time in which case today is used as the date
step (float): The number of seconds of each period
"""
assert step > 0
delta = dt.timedelta(seconds=step)
t = _fillDate(start)
tz = dt.timezone.utc if t.tzinfo else None
now = dt.datetime.now(tz)
while t < now:
t += delta
while t <= _fillDate(end):
waitUntil(t)
yield t
t += delta
def waitUntil(t: Time_t) -> bool:
"""
Wait until the given time t is reached.
Args:
t: The time t can be specified as datetime.datetime,
or as datetime.time in which case today is used as the date.
"""
now = dt.datetime.now(t.tzinfo)
secs = (_fillDate(t) - now).total_seconds()
run(asyncio.sleep(secs))
return True
async def timeRangeAsync(
start: Time_t,
end: Time_t,
step: float) -> AsyncIterator[dt.datetime]:
"""Async version of :meth:`timeRange`."""
assert step > 0
delta = dt.timedelta(seconds=step)
t = _fillDate(start)
tz = dt.timezone.utc if t.tzinfo else None
now = dt.datetime.now(tz)
while t < now:
t += delta
while t <= _fillDate(end):
await waitUntilAsync(t)
yield t
t += delta
async def waitUntilAsync(t: Time_t) -> bool:
"""Async version of :meth:`waitUntil`."""
now = dt.datetime.now(t.tzinfo)
secs = (_fillDate(t) - now).total_seconds()
await asyncio.sleep(secs)
return True
def patchAsyncio():
"""Patch asyncio to allow nested event loops."""
import nest_asyncio
nest_asyncio.apply()
def getLoop():
"""Get the asyncio event loop for the current thread."""
return asyncio.get_event_loop_policy().get_event_loop()
def startLoop():
"""Use nested asyncio event loop for Jupyter notebooks."""
patchAsyncio()
def useQt(qtLib: str = 'PyQt5', period: float = 0.01):
"""
Run combined Qt/asyncio event loop.
Args:
qtLib: Name of Qt library to use:
* PyQt5
* PyQt6
* PySide2
* PySide6
period: Period in seconds to poll Qt.
"""
def qt_step():
loop.call_later(period, qt_step)
if not stack:
qloop = qc.QEventLoop()
timer = qc.QTimer()
timer.timeout.connect(qloop.quit)
stack.append((qloop, timer))
qloop, timer = stack.pop()
timer.start(0)
qloop.exec() if qtLib == 'PyQt6' else qloop.exec_()
timer.stop()
stack.append((qloop, timer))
qApp.processEvents() # type: ignore
if qtLib not in ('PyQt5', 'PyQt6', 'PySide2', 'PySide6'):
raise RuntimeError(f'Unknown Qt library: {qtLib}')
from importlib import import_module
qc = import_module(qtLib + '.QtCore')
qw = import_module(qtLib + '.QtWidgets')
global qApp
qApp = (qw.QApplication.instance() # type: ignore
or qw.QApplication(sys.argv)) # type: ignore
loop = getLoop()
stack: list = []
qt_step()
def formatIBDatetime(t: Union[dt.date, dt.datetime, str, None]) -> str:
"""Format date or datetime to string that IB uses."""
if not t:
s = ''
elif isinstance(t, dt.datetime):
# convert to UTC timezone
t = t.astimezone(tz=dt.timezone.utc)
s = t.strftime('%Y%m%d %H:%M:%S UTC')
elif isinstance(t, dt.date):
t = dt.datetime(
t.year, t.month, t.day, 23, 59, 59).astimezone(tz=dt.timezone.utc)
s = t.strftime('%Y%m%d %H:%M:%S UTC')
else:
s = t
return s
def parseIBDatetime(s: str) -> Union[dt.date, dt.datetime]:
"""Parse string in IB date or datetime format to datetime."""
if len(s) == 8:
# YYYYmmdd
y = int(s[0:4])
m = int(s[4:6])
d = int(s[6:8])
t = dt.date(y, m, d)
elif s.isdigit():
t = dt.datetime.fromtimestamp(int(s), dt.timezone.utc)
elif s.count(' ') >= 2 and '-' not in s:
# 20221125 10:00:00 Europe/Amsterdam
s0, s1, s2 = s.split(' ', 2)
t = dt.datetime.strptime(s0 + s1, '%Y%m%d%H:%M:%S')
t = t.replace(tzinfo=ZoneInfo(s2))
else:
# YYYYmmdd HH:MM:SS
# or
# YYYY-mm-dd HH:MM:SS.0
ss = s.replace(' ', '').replace('-', '')[:16]
t = dt.datetime.strptime(ss, '%Y%m%d%H:%M:%S')
return t
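
A small connection-free sketch (editor's addition, not part of the diff) exercising the helpers above; the expected outputs follow directly from the code:

# Minimal sketch: round-trip IB's datetime strings and format a number
# with an SI prefix, using only functions defined in this module.
import datetime as dt
from ib_insync import util

t = dt.datetime(2025, 3, 26, 3, 8, 2, tzinfo=dt.timezone.utc)
print(util.formatIBDatetime(t))       # 20250326 03:08:02 UTC

print(util.parseIBDatetime('20250326'))      # date(2025, 3, 26)
print(util.parseIBDatetime('1742958482'))    # epoch seconds -> aware datetime
print(util.parseIBDatetime('20250326 03:08:02 Europe/Amsterdam'))

print(util.formatSI(4592393))         # 4.59 M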

View File

@@ -0,0 +1,4 @@
"""Version info."""
__version_info__ = (0, 9, 86)
__version__ = '.'.join(str(v) for v in __version_info__)

Some files were not shown because too many files have changed in this diff.