updating for download
This commit is contained in:
@@ -1,48 +0,0 @@
|
||||
2025-03-26 03:01:52,506 - INFO - ===== Resource Statistics =====
|
||||
2025-03-26 03:01:52,506 - INFO - Physical CPU Cores: 28
|
||||
2025-03-26 03:01:52,506 - INFO - Logical CPU Cores: 56
|
||||
2025-03-26 03:01:52,507 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
|
||||
2025-03-26 03:01:52,507 - INFO - No GPUs detected.
|
||||
2025-03-26 03:01:52,507 - INFO - =================================
|
||||
2025-03-26 03:01:52,507 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
|
||||
2025-03-26 03:01:52,508 - INFO - Loading data from: data/MES2023Z.csv
|
||||
2025-03-26 03:01:52,513 - ERROR - Unexpected error: Missing column provided to 'parse_dates': 'time'
|
||||
2025-03-26 03:04:50,616 - INFO - ===== Resource Statistics =====
|
||||
2025-03-26 03:04:50,616 - INFO - Physical CPU Cores: 28
|
||||
2025-03-26 03:04:50,616 - INFO - Logical CPU Cores: 56
|
||||
2025-03-26 03:04:50,616 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]%
|
||||
2025-03-26 03:04:50,617 - INFO - No GPUs detected.
|
||||
2025-03-26 03:04:50,617 - INFO - =================================
|
||||
2025-03-26 03:04:50,617 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
|
||||
2025-03-26 03:04:50,618 - INFO - Loading data from: data/MES2023Z.csv
|
||||
2025-03-26 03:04:50,621 - ERROR - Unexpected error: Missing column provided to 'parse_dates': 'time'
|
||||
2025-03-26 03:08:02,316 - INFO - ===== Resource Statistics =====
|
||||
2025-03-26 03:08:02,316 - INFO - Physical CPU Cores: 28
|
||||
2025-03-26 03:08:02,316 - INFO - Logical CPU Cores: 56
|
||||
2025-03-26 03:08:02,317 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
|
||||
2025-03-26 03:08:02,317 - INFO - No GPUs detected.
|
||||
2025-03-26 03:08:02,317 - INFO - =================================
|
||||
2025-03-26 03:08:02,317 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
|
||||
2025-03-26 03:08:02,318 - INFO - Loading data from: data/MES2023Z.csv
|
||||
2025-03-26 03:08:02,355 - INFO - Data columns after renaming: ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
|
||||
2025-03-26 03:08:02,383 - INFO - Data loaded and sorted successfully.
|
||||
2025-03-26 03:08:02,383 - INFO - Calculating technical indicators...
|
||||
2025-03-26 03:08:02,448 - INFO - Technical indicators calculated successfully.
|
||||
2025-03-26 03:08:02,464 - INFO - Starting parallel feature engineering with 54 workers...
|
||||
2025-03-26 03:08:03,331 - INFO - Parallel feature engineering completed.
|
||||
2025-03-26 03:08:03,341 - INFO - Training sequences shape: (676, 15, 17)
|
||||
2025-03-26 03:08:03,342 - INFO - Validation sequences shape: (144, 15, 17)
|
||||
2025-03-26 03:08:03,342 - INFO - Testing sequences shape: (146, 15, 17)
|
||||
2025-03-26 03:08:03,342 - INFO - Starting LSTM hyperparameter optimization with Optuna using 54 parallel trials...
|
||||
2025-03-26 03:22:04,033 - INFO - Best LSTM Hyperparameters: {'num_lstm_layers': 2, 'lstm_units': 64, 'dropout_rate': 0.13619292923712067, 'learning_rate': 0.0030545284525912166, 'optimizer': 'Nadam', 'decay': 9.615099767236892e-05}
|
||||
2025-03-26 03:22:04,553 - INFO - Training best LSTM model with optimized hyperparameters...
|
||||
2025-03-26 03:24:28,296 - INFO - Evaluating final LSTM model...
|
||||
2025-03-26 03:24:29,722 - INFO - Test MSE: 0.3437
|
||||
2025-03-26 03:24:29,722 - INFO - Test RMSE: 0.5862
|
||||
2025-03-26 03:24:29,722 - INFO - Test MAE: 0.4561
|
||||
2025-03-26 03:24:29,722 - INFO - Test R2 Score: 0.8620
|
||||
2025-03-26 03:24:29,722 - INFO - Directional Accuracy: 0.2759
|
||||
2025-03-26 03:24:30,013 - WARNING - You are saving your model as an HDF5 file via `model.save()` or `keras.saving.save_model(model)`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')` or `keras.saving.save_model(model, 'my_model.keras')`.
|
||||
2025-03-26 03:24:30,121 - INFO - Saved best LSTM model and scaler objects.
|
||||
2025-03-26 03:24:30,150 - INFO - Starting PPO training...
|
||||
2025-03-26 05:47:15,571 - INFO - PPO training completed and model saved.
|
||||
Binary file not shown.
@@ -33,6 +33,8 @@ from multiprocessing import Pool, cpu_count
|
||||
import threading
|
||||
import time
|
||||
|
||||
VERSION = 1
|
||||
|
||||
# Suppress TensorFlow logs beyond errors
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
||||
|
||||
|
||||
748
src/MidasAgent/PPO/LSTMPPOattempt2.py
Normal file
748
src/MidasAgent/PPO/LSTMPPOattempt2.py
Normal file
@@ -0,0 +1,748 @@
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import logging
|
||||
from tabulate import tabulate
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import psutil
|
||||
import GPUtil
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.models import Sequential, load_model
|
||||
from tensorflow.keras.layers import LSTM, Dense, Dropout, Bidirectional
|
||||
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
|
||||
from tensorflow.keras.losses import Huber
|
||||
from tensorflow.keras.regularizers import l2
|
||||
from tensorflow.keras.optimizers import Adam, Nadam
|
||||
|
||||
from sklearn.preprocessing import MinMaxScaler
|
||||
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
|
||||
import joblib
|
||||
import optuna
|
||||
from optuna.integration import KerasPruningCallback
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
|
||||
from multiprocessing import Pool, cpu_count
|
||||
import threading
|
||||
import time
|
||||
|
||||
# Suppress TensorFlow logs beyond errors
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
||||
|
||||
# =============================================================================
|
||||
# Resource Detection Functions
|
||||
# =============================================================================
|
||||
def get_cpu_info():
|
||||
cpu_count_physical = psutil.cpu_count(logical=False) # Physical cores
|
||||
cpu_count_logical = psutil.cpu_count(logical=True) # Logical cores
|
||||
cpu_percent = psutil.cpu_percent(interval=1, percpu=True)
|
||||
return {
|
||||
'physical_cores': cpu_count_physical,
|
||||
'logical_cores': cpu_count_logical,
|
||||
'cpu_percent': cpu_percent
|
||||
}
|
||||
|
||||
def get_gpu_info():
|
||||
gpus = GPUtil.getGPUs()
|
||||
gpu_info = []
|
||||
for gpu in gpus:
|
||||
gpu_info.append({
|
||||
'id': gpu.id,
|
||||
'name': gpu.name,
|
||||
'load': gpu.load * 100, # Convert to percentage
|
||||
'memory_total': gpu.memoryTotal,
|
||||
'memory_used': gpu.memoryUsed,
|
||||
'memory_free': gpu.memoryFree,
|
||||
'temperature': gpu.temperature
|
||||
})
|
||||
return gpu_info
|
||||
|
||||
def configure_tensorflow(cpu_stats, gpu_stats):
|
||||
logical_cores = cpu_stats['logical_cores']
|
||||
os.environ["OMP_NUM_THREADS"] = str(logical_cores)
|
||||
os.environ["TF_NUM_INTRAOP_THREADS"] = str(logical_cores)
|
||||
os.environ["TF_NUM_INTEROP_THREADS"] = str(logical_cores)
|
||||
|
||||
if gpu_stats:
|
||||
gpus = tf.config.list_physical_devices('GPU')
|
||||
if gpus:
|
||||
try:
|
||||
for gpu in gpus:
|
||||
tf.config.experimental.set_memory_growth(gpu, True)
|
||||
logging.info(f"Enabled memory growth for {len(gpus)} GPU(s).")
|
||||
except RuntimeError as e:
|
||||
logging.error(f"TensorFlow GPU configuration error: {e}")
|
||||
else:
|
||||
tf.config.threading.set_intra_op_parallelism_threads(logical_cores)
|
||||
tf.config.threading.set_inter_op_parallelism_threads(logical_cores)
|
||||
logging.info("Configured TensorFlow to use CPU with optimized thread settings.")
|
||||
|
||||
def monitor_resources(interval=60):
|
||||
while True:
|
||||
cpu = psutil.cpu_percent(interval=1, percpu=True)
|
||||
gpu = get_gpu_info()
|
||||
logging.info(f"CPU Usage per Core: {cpu}%")
|
||||
if gpu:
|
||||
for gpu_stat in gpu:
|
||||
logging.info(f"GPU {gpu_stat['id']} - {gpu_stat['name']}: Load: {gpu_stat['load']}%, "
|
||||
f"Memory Used: {gpu_stat['memory_used']}MB / {gpu_stat['memory_total']}MB, "
|
||||
f"Temperature: {gpu_stat['temperature']}°C")
|
||||
else:
|
||||
logging.info("No GPUs detected.")
|
||||
logging.info("-" * 50)
|
||||
time.sleep(interval)
|
||||
|
||||
# =============================================================================
|
||||
# Data Loading & Technical Indicators
|
||||
# =============================================================================
|
||||
def load_data(file_path):
|
||||
logging.info(f"Loading data from: {file_path}")
|
||||
try:
|
||||
df = pd.read_csv(file_path, parse_dates=['time'])
|
||||
except FileNotFoundError:
|
||||
logging.error(f"File not found: {file_path}")
|
||||
sys.exit(1)
|
||||
except pd.errors.ParserError as e:
|
||||
logging.error(f"Error parsing CSV file: {e}")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
logging.error(f"Unexpected error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
rename_mapping = {
|
||||
'time': 'Date',
|
||||
'open': 'Open',
|
||||
'high': 'High',
|
||||
'low': 'Low',
|
||||
'close': 'Close'
|
||||
}
|
||||
df.rename(columns=rename_mapping, inplace=True)
|
||||
logging.info(f"Data columns after renaming: {df.columns.tolist()}")
|
||||
df.sort_values('Date', inplace=True)
|
||||
df.reset_index(drop=True, inplace=True)
|
||||
logging.info("Data loaded and sorted successfully.")
|
||||
return df
|
||||
|
||||
def compute_rsi(series, window=14):
|
||||
delta = series.diff()
|
||||
gain = delta.where(delta > 0, 0).rolling(window=window).mean()
|
||||
loss = -delta.where(delta < 0, 0).rolling(window=window).mean()
|
||||
RS = gain / (loss + 1e-9)
|
||||
return 100 - (100 / (1 + RS))
|
||||
|
||||
def compute_macd(series, span_short=12, span_long=26, span_signal=9):
|
||||
ema_short = series.ewm(span=span_short, adjust=False).mean()
|
||||
ema_long = series.ewm(span=span_long, adjust=False).mean()
|
||||
macd_line = ema_short - ema_long
|
||||
signal_line = macd_line.ewm(span=span_signal, adjust=False).mean()
|
||||
return macd_line - signal_line # histogram
|
||||
|
||||
def compute_obv(df):
|
||||
signed_volume = (np.sign(df['Close'].diff()) * df['Volume']).fillna(0)
|
||||
return signed_volume.cumsum()
|
||||
|
||||
def compute_adx(df, window=14):
|
||||
df['H-L'] = df['High'] - df['Low']
|
||||
df['H-Cp'] = (df['High'] - df['Close'].shift(1)).abs()
|
||||
df['L-Cp'] = (df['Low'] - df['Close'].shift(1)).abs()
|
||||
tr = df[['H-L','H-Cp','L-Cp']].max(axis=1)
|
||||
tr_rolling = tr.rolling(window=window).mean()
|
||||
adx_placeholder = tr_rolling / (df['Close'] + 1e-9)
|
||||
df.drop(['H-L','H-Cp','L-Cp'], axis=1, inplace=True)
|
||||
return adx_placeholder
|
||||
|
||||
def compute_bollinger_bands(series, window=20, num_std=2):
|
||||
sma = series.rolling(window=window).mean()
|
||||
std = series.rolling(window=window).std()
|
||||
upper = sma + num_std * std
|
||||
lower = sma - num_std * std
|
||||
bandwidth = (upper - lower) / (sma + 1e-9)
|
||||
return upper, lower, bandwidth
|
||||
|
||||
def compute_mfi(df, window=14):
|
||||
typical_price = (df['High'] + df['Low'] + df['Close']) / 3
|
||||
money_flow = typical_price * df['Volume']
|
||||
prev_tp = typical_price.shift(1)
|
||||
flow_pos = money_flow.where(typical_price > prev_tp, 0)
|
||||
flow_neg = money_flow.where(typical_price < prev_tp, 0)
|
||||
pos_sum = flow_pos.rolling(window=window).sum()
|
||||
neg_sum = flow_neg.rolling(window=window).sum()
|
||||
mfi = 100 - (100 / (1 + pos_sum / (neg_sum + 1e-9)))
|
||||
return mfi
|
||||
|
||||
def calculate_technical_indicators(df):
|
||||
logging.info("Calculating technical indicators...")
|
||||
df['RSI'] = compute_rsi(df['Close'], 14)
|
||||
df['MACD'] = compute_macd(df['Close'])
|
||||
df['OBV'] = compute_obv(df)
|
||||
df['ADX'] = compute_adx(df)
|
||||
|
||||
up, lo, bw = compute_bollinger_bands(df['Close'], 20, 2)
|
||||
df['BB_Upper'] = up
|
||||
df['BB_Lower'] = lo
|
||||
df['BB_Width'] = bw
|
||||
|
||||
df['MFI'] = compute_mfi(df, 14)
|
||||
df['SMA_5'] = df['Close'].rolling(5).mean()
|
||||
df['SMA_10'] = df['Close'].rolling(10).mean()
|
||||
df['EMA_5'] = df['Close'].ewm(span=5, adjust=False).mean()
|
||||
df['EMA_10'] = df['Close'].ewm(span=10, adjust=False).mean()
|
||||
df['STDDEV_5'] = df['Close'].rolling(5).std()
|
||||
|
||||
df.dropna(inplace=True)
|
||||
logging.info("Technical indicators calculated successfully.")
|
||||
return df
|
||||
|
||||
# =============================================================================
|
||||
# Argument Parsing
|
||||
# =============================================================================
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description='Futures Trading with LSTM Forecasting and PPO.')
|
||||
parser.add_argument('csv_path', type=str,
|
||||
help='Path to CSV data with columns [time, open, high, low, close, volume].')
|
||||
parser.add_argument('--lstm_window_size', type=int, default=15,
|
||||
help='Sequence window size for LSTM forecasting. Default=15.')
|
||||
parser.add_argument('--ppo_total_timesteps', type=int, default=100000,
|
||||
help='Total timesteps to train the PPO model. Default=100000.')
|
||||
parser.add_argument('--n_trials_lstm', type=int, default=30,
|
||||
help='Number of Optuna trials for LSTM hyperparameter tuning. Default=30.')
|
||||
parser.add_argument('--preprocess_workers', type=int, default=None,
|
||||
help='Number of worker processes for data preprocessing. Defaults to (logical cores - 2).')
|
||||
parser.add_argument('--monitor_resources', action='store_true',
|
||||
help='Enable real-time resource monitoring.')
|
||||
parser.add_argument('--output_dir', type=str, default='output',
|
||||
help='Directory where all output files will be saved.')
|
||||
parser.add_argument('--action_mode', type=str, choices=['discrete', 'continuous'], default='discrete',
|
||||
help='Select action space type: discrete (e.g., -5 to +5) or continuous (Box). Default=discrete.')
|
||||
parser.add_argument('--max_contracts', type=int, default=5,
|
||||
help='Maximum number of contracts to trade per action. Default=5.')
|
||||
return parser.parse_args()
|
||||
|
||||
# =============================================================================
|
||||
# LSTM Price Predictor (renamed from LSTM part)
|
||||
# =============================================================================
|
||||
def build_lstm(input_shape, hyperparams):
|
||||
model = Sequential()
|
||||
num_layers = hyperparams['num_lstm_layers']
|
||||
units = hyperparams['lstm_units']
|
||||
drop = hyperparams['dropout_rate']
|
||||
for i in range(num_layers):
|
||||
return_seqs = (i < num_layers - 1)
|
||||
if i == 0:
|
||||
model.add(Bidirectional(LSTM(units, return_sequences=return_seqs, kernel_regularizer=l2(1e-4)),
|
||||
input_shape=input_shape))
|
||||
else:
|
||||
model.add(Bidirectional(LSTM(units, return_sequences=return_seqs, kernel_regularizer=l2(1e-4))))
|
||||
model.add(Dropout(drop))
|
||||
model.add(Dense(1, activation='linear'))
|
||||
|
||||
opt_name = hyperparams['optimizer']
|
||||
lr = hyperparams['learning_rate']
|
||||
decay = hyperparams['decay']
|
||||
if opt_name == 'Adam':
|
||||
opt = Adam(learning_rate=lr, decay=decay)
|
||||
elif opt_name == 'Nadam':
|
||||
opt = Nadam(learning_rate=lr)
|
||||
else:
|
||||
opt = Adam(learning_rate=lr)
|
||||
|
||||
model.compile(loss=Huber(), optimizer=opt, metrics=['mae'])
|
||||
return model
|
||||
|
||||
# =============================================================================
|
||||
# Custom Gym Environment for Futures Trading with LSTM Forecasting
|
||||
# =============================================================================
|
||||
class FuturesTradingEnv(gym.Env):
|
||||
"""
|
||||
A custom OpenAI Gym environment for futures trading.
|
||||
It integrates an LSTM price predictor for forecasting.
|
||||
The environment tracks positions as contracts_held (can be negative for shorts).
|
||||
Reward is defined as the change in mark-to-market profit (unrealized PnL)
|
||||
minus transaction costs and includes a bonus based on the LSTM forecast.
|
||||
The action space can be either discrete (e.g., -max_contracts ... +max_contracts)
|
||||
or continuous (Box space) which is then rounded to an integer.
|
||||
"""
|
||||
metadata = {'render.modes': ['human']}
|
||||
|
||||
def __init__(self, df, feature_columns, lstm_model, scaler_features, scaler_target,
|
||||
window_size=15, transaction_cost=0.001, action_mode='discrete', max_contracts=5):
|
||||
super(FuturesTradingEnv, self).__init__()
|
||||
self.df = df.reset_index(drop=True)
|
||||
self.feature_columns = feature_columns
|
||||
self.lstm_model = lstm_model # Frozen LSTM model for forecasting
|
||||
self.scaler_features = scaler_features
|
||||
self.scaler_target = scaler_target
|
||||
self.window_size = window_size
|
||||
self.transaction_cost = transaction_cost
|
||||
self.action_mode = action_mode
|
||||
self.max_contracts = max_contracts
|
||||
|
||||
self.max_steps = len(df)
|
||||
self.current_step = 0
|
||||
|
||||
# Futures position variables
|
||||
self.contracts_held = 0 # positive for long, negative for short
|
||||
self.entry_price = None # weighted average entry price
|
||||
|
||||
# Pre-calculate normalized features for observations
|
||||
self.raw_features = df[feature_columns].values
|
||||
|
||||
# Define action space
|
||||
if self.action_mode == 'discrete':
|
||||
# Actions: integer orders from -max_contracts to +max_contracts
|
||||
self.action_space = spaces.Discrete(2 * self.max_contracts + 1)
|
||||
else:
|
||||
# Continuous action: a real number in [-max_contracts, max_contracts]
|
||||
self.action_space = spaces.Box(low=-self.max_contracts, high=self.max_contracts, shape=(1,), dtype=np.float32)
|
||||
|
||||
# Observation space: technical indicators + [normalized contracts_held, normalized unrealized PnL] + LSTM forecast
|
||||
obs_len = len(feature_columns) + 2 + 1
|
||||
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(obs_len,), dtype=np.float32)
|
||||
|
||||
# Lock for LSTM prediction (if used in multi-threaded settings)
|
||||
self.lstm_lock = threading.Lock()
|
||||
|
||||
# Cache for the LSTM forecast for reward shaping
|
||||
self.last_forecast = 0.0
|
||||
|
||||
def reset(self):
|
||||
self.current_step = 0
|
||||
self.contracts_held = 0
|
||||
self.entry_price = None
|
||||
self.last_forecast = 0.0
|
||||
return self._get_obs()
|
||||
|
||||
def _get_obs(self):
|
||||
# Normalize the raw features row by row
|
||||
row = self.raw_features[self.current_step]
|
||||
row_max = np.max(np.abs(row)) if np.max(np.abs(row)) != 0 else 1.0
|
||||
row_norm = row / row_max
|
||||
|
||||
# Additional account info:
|
||||
# Normalize contracts_held by max_contracts.
|
||||
# Unrealized PnL: if entry_price exists, (current_price - entry_price)*contracts_held; else 0.
|
||||
current_price = self.df.loc[self.current_step, 'Close']
|
||||
pnl = (current_price - self.entry_price) * self.contracts_held if self.entry_price is not None else 0.0
|
||||
|
||||
additional = np.array([
|
||||
self.contracts_held / self.max_contracts,
|
||||
pnl # Consider normalizing pnl as needed
|
||||
], dtype=np.float32)
|
||||
|
||||
# LSTM price forecast: predict next price if possible
|
||||
if self.current_step < self.window_size:
|
||||
forecast = 0.0
|
||||
else:
|
||||
seq = self.raw_features[self.current_step - self.window_size: self.current_step]
|
||||
seq_scaled = self.scaler_features.transform(seq)
|
||||
seq_scaled = np.expand_dims(seq_scaled, axis=0) # shape: (1, window_size, num_features)
|
||||
with self.lstm_lock:
|
||||
pred_scaled = self.lstm_model.predict(seq_scaled, verbose=0).flatten()[0]
|
||||
pred_scaled = np.clip(pred_scaled, 0, 1)
|
||||
unscaled = self.scaler_target.inverse_transform([[pred_scaled]])[0, 0]
|
||||
# Forecast as relative difference from the current price
|
||||
forecast = (unscaled - current_price) / (current_price + 1e-9)
|
||||
|
||||
# Cache the forecast for use in the reward function
|
||||
self.last_forecast = forecast
|
||||
|
||||
obs = np.concatenate([row_norm, additional, [forecast]]).astype(np.float32)
|
||||
return obs
|
||||
|
||||
def step(self, action):
|
||||
prev_price = self.df.loc[self.current_step, 'Close']
|
||||
prev_position = self.contracts_held
|
||||
|
||||
# Convert action to an integer number of contracts
|
||||
if self.action_mode == 'discrete':
|
||||
# Discrete action space: 0 corresponds to -max_contracts, last corresponds to +max_contracts.
|
||||
action_int = action - self.max_contracts
|
||||
else:
|
||||
# For continuous, round to nearest integer
|
||||
action_int = int(np.round(action[0]))
|
||||
action_int = np.clip(action_int, -self.max_contracts, self.max_contracts)
|
||||
|
||||
current_price = self.df.loc[self.current_step, 'Close']
|
||||
fee = self.transaction_cost * abs(action_int) * current_price
|
||||
|
||||
# Update position logic
|
||||
if action_int != 0:
|
||||
# If no current position, set new position and record entry price.
|
||||
if self.contracts_held == 0:
|
||||
self.contracts_held = action_int
|
||||
self.entry_price = current_price
|
||||
# If same sign, update weighted average entry price.
|
||||
elif np.sign(self.contracts_held) == np.sign(action_int):
|
||||
total_contracts = self.contracts_held + action_int
|
||||
self.entry_price = (self.entry_price * self.contracts_held + current_price * action_int) / total_contracts
|
||||
self.contracts_held = total_contracts
|
||||
# If opposite sign, reduce/flip position.
|
||||
else:
|
||||
if abs(action_int) >= abs(self.contracts_held):
|
||||
self.contracts_held = self.contracts_held + action_int # may flip sign
|
||||
self.entry_price = current_price if self.contracts_held != 0 else None
|
||||
else:
|
||||
self.contracts_held = self.contracts_held + action_int
|
||||
|
||||
# Mark-to-market PnL: change from previous price * previous position
|
||||
pnl_change = (current_price - prev_price) * prev_position
|
||||
reward = pnl_change - fee
|
||||
|
||||
# Bonus reward based on forecast and the chosen action.
|
||||
bonus_factor = 10.0 # Scale factor; tune as needed.
|
||||
bonus = bonus_factor * (action_int * self.last_forecast)
|
||||
reward += bonus
|
||||
|
||||
self.current_step += 1
|
||||
done = (self.current_step >= self.max_steps - 1)
|
||||
obs = self._get_obs()
|
||||
return obs, reward, done, {}
|
||||
|
||||
def render(self, mode='human'):
|
||||
current_price = self.df.loc[self.current_step, 'Close']
|
||||
pnl = (current_price - self.entry_price) * self.contracts_held if self.entry_price is not None else 0.0
|
||||
print(f"Step: {self.current_step}, Contracts Held: {self.contracts_held}, "
|
||||
f"Entry Price: {self.entry_price}, Current Price: {current_price:.2f}, PnL: {pnl:.2f}")
|
||||
|
||||
# =============================================================================
|
||||
# Placeholders for Live Deployment Functions
|
||||
# =============================================================================
|
||||
def get_live_data():
|
||||
"""
|
||||
Placeholder: Connect to a live data feed and return the latest market data.
|
||||
"""
|
||||
return None
|
||||
|
||||
def execute_order(action):
|
||||
"""
|
||||
Placeholder: Execute the trading order in a live environment.
|
||||
"""
|
||||
logging.info(f"Executing order: {action}")
|
||||
|
||||
def live_trading_loop(model, env, polling_interval=5):
|
||||
obs = env.reset()
|
||||
done = False
|
||||
while not done:
|
||||
live_data = get_live_data()
|
||||
if live_data is not None:
|
||||
pass # Update environment with live data as needed
|
||||
|
||||
action, _ = model.predict(obs, deterministic=True)
|
||||
execute_order(action)
|
||||
obs, reward, done, _ = env.step(action)
|
||||
env.render()
|
||||
time.sleep(polling_interval)
|
||||
|
||||
# =============================================================================
|
||||
# Data Preprocessing with Parallelization
|
||||
# =============================================================================
|
||||
def parallel_feature_engineering(row):
|
||||
"""
|
||||
Placeholder function for additional feature engineering.
|
||||
"""
|
||||
return row
|
||||
|
||||
def feature_engineering_parallel(df, num_workers):
|
||||
logging.info(f"Starting parallel feature engineering with {num_workers} workers...")
|
||||
with Pool(processes=num_workers) as pool:
|
||||
processed_rows = pool.map(parallel_feature_engineering, [row for _, row in df.iterrows()])
|
||||
df_processed = pd.DataFrame(processed_rows)
|
||||
logging.info("Parallel feature engineering completed.")
|
||||
return df_processed
|
||||
|
||||
# =============================================================================
|
||||
# MAIN FUNCTION: LSTM Training + PPO for Futures Trading
|
||||
# =============================================================================
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
csv_path = args.csv_path
|
||||
output_dir = args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
lstm_window_size = args.lstm_window_size
|
||||
ppo_total_timesteps = args.ppo_total_timesteps
|
||||
n_trials_lstm = args.n_trials_lstm
|
||||
preprocess_workers = args.preprocess_workers
|
||||
enable_resource_monitor = args.monitor_resources
|
||||
action_mode = args.action_mode
|
||||
max_contracts = args.max_contracts
|
||||
|
||||
# -----------------------------
|
||||
# Setup Logging
|
||||
# -----------------------------
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler(os.path.join(output_dir, "FuturesPPO.log")),
|
||||
logging.StreamHandler(sys.stdout)
|
||||
])
|
||||
|
||||
# -----------------------------
|
||||
# Resource Detection & Logging
|
||||
# -----------------------------
|
||||
cpu_stats = get_cpu_info()
|
||||
gpu_stats = get_gpu_info()
|
||||
logging.info("===== Resource Statistics =====")
|
||||
logging.info(f"Physical CPU Cores: {cpu_stats['physical_cores']}")
|
||||
logging.info(f"Logical CPU Cores: {cpu_stats['logical_cores']}")
|
||||
logging.info(f"CPU Usage per Core: {cpu_stats['cpu_percent']}%")
|
||||
if gpu_stats:
|
||||
for gpu in gpu_stats:
|
||||
logging.info(f"GPU {gpu['id']} - {gpu['name']}: Load: {gpu['load']}%, Memory Used: {gpu['memory_used']}MB/{gpu['memory_total']}MB, Temperature: {gpu['temperature']}°C")
|
||||
else:
|
||||
logging.info("No GPUs detected.")
|
||||
logging.info("=================================")
|
||||
|
||||
# -----------------------------
|
||||
# Configure TensorFlow
|
||||
# -----------------------------
|
||||
configure_tensorflow(cpu_stats, gpu_stats)
|
||||
|
||||
# -----------------------------
|
||||
# Start Resource Monitoring (Optional)
|
||||
# -----------------------------
|
||||
if enable_resource_monitor:
|
||||
logging.info("Starting real-time resource monitoring...")
|
||||
resource_monitor_thread = threading.Thread(target=monitor_resources, args=(60,), daemon=True)
|
||||
resource_monitor_thread.start()
|
||||
|
||||
##########################################
|
||||
# A) LSTM PART: LOAD, PREPROCESS & TUNE
|
||||
##########################################
|
||||
df = load_data(csv_path)
|
||||
df = calculate_technical_indicators(df)
|
||||
|
||||
feature_columns = [
|
||||
'SMA_5','SMA_10','EMA_5','EMA_10','STDDEV_5',
|
||||
'RSI','MACD','ADX','OBV','Volume','Open','High','Low',
|
||||
'BB_Upper','BB_Lower','BB_Width','MFI'
|
||||
]
|
||||
target_column = 'Close'
|
||||
df = df[['Date'] + feature_columns + [target_column]].dropna()
|
||||
|
||||
# 2) Controlled Parallel Data Preprocessing
|
||||
if preprocess_workers is None:
|
||||
preprocess_workers = max(1, cpu_stats['logical_cores'] - 2)
|
||||
else:
|
||||
preprocess_workers = min(preprocess_workers, cpu_stats['logical_cores'])
|
||||
df = feature_engineering_parallel(df, num_workers=preprocess_workers)
|
||||
|
||||
scaler_features = MinMaxScaler()
|
||||
scaler_target = MinMaxScaler()
|
||||
|
||||
X_all = df[feature_columns].values
|
||||
y_all = df[[target_column]].values
|
||||
|
||||
X_scaled = scaler_features.fit_transform(X_all)
|
||||
y_scaled = scaler_target.fit_transform(y_all).flatten()
|
||||
|
||||
# 3) Create sequences for LSTM forecasting
|
||||
def create_sequences(features, target, window_size):
|
||||
X_seq, y_seq = [], []
|
||||
for i in range(len(features) - window_size):
|
||||
X_seq.append(features[i:i+window_size])
|
||||
y_seq.append(target[i+window_size])
|
||||
return np.array(X_seq), np.array(y_seq)
|
||||
|
||||
X, y = create_sequences(X_scaled, y_scaled, lstm_window_size)
|
||||
|
||||
# 4) Split into train/val/test
|
||||
train_size = int(len(X) * 0.7)
|
||||
val_size = int(len(X) * 0.15)
|
||||
test_size = len(X) - train_size - val_size
|
||||
|
||||
X_train, y_train = X[:train_size], y[:train_size]
|
||||
X_val, y_val = X[train_size: train_size + val_size], y[train_size: train_size + val_size]
|
||||
X_test, y_test = X[train_size + val_size:], y[train_size + val_size:]
|
||||
|
||||
logging.info(f"Training sequences shape: {X_train.shape}")
|
||||
logging.info(f"Validation sequences shape: {X_val.shape}")
|
||||
logging.info(f"Testing sequences shape: {X_test.shape}")
|
||||
|
||||
# 5) Define LSTM objective for hyperparameter tuning using Optuna
|
||||
def lstm_objective(trial):
|
||||
num_lstm_layers = trial.suggest_int('num_lstm_layers', 1, 3)
|
||||
lstm_units = trial.suggest_categorical('lstm_units', [32, 64, 96, 128])
|
||||
dropout_rate = trial.suggest_float('dropout_rate', 0.1, 0.5)
|
||||
learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True)
|
||||
optimizer_name = trial.suggest_categorical('optimizer', ['Adam', 'Nadam'])
|
||||
decay = trial.suggest_float('decay', 0.0, 1e-4)
|
||||
|
||||
hyperparams = {
|
||||
'num_lstm_layers': num_lstm_layers,
|
||||
'lstm_units': lstm_units,
|
||||
'dropout_rate': dropout_rate,
|
||||
'learning_rate': learning_rate,
|
||||
'optimizer': optimizer_name,
|
||||
'decay': decay
|
||||
}
|
||||
|
||||
model_ = build_lstm((X_train.shape[1], X_train.shape[2]), hyperparams)
|
||||
early_stop = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
|
||||
lr_reduce = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6)
|
||||
cb_prune = KerasPruningCallback(trial, 'val_loss')
|
||||
|
||||
history = model_.fit(
|
||||
X_train, y_train,
|
||||
epochs=100,
|
||||
batch_size=16,
|
||||
validation_data=(X_val, y_val),
|
||||
callbacks=[early_stop, lr_reduce, cb_prune],
|
||||
verbose=0
|
||||
)
|
||||
val_mae = min(history.history['val_mae'])
|
||||
return val_mae
|
||||
|
||||
logging.info(f"Starting LSTM hyperparameter optimization with {cpu_stats['logical_cores']-2} parallel trials...")
|
||||
study_lstm = optuna.create_study(direction='minimize')
|
||||
study_lstm.optimize(lstm_objective, n_trials=n_trials_lstm, n_jobs=cpu_stats['logical_cores']-2)
|
||||
best_lstm_params = study_lstm.best_params
|
||||
logging.info(f"Best LSTM Hyperparameters: {best_lstm_params}")
|
||||
|
||||
# 6) Train final LSTM (PricePredictorLSTM) with best hyperparameters
|
||||
final_lstm = build_lstm((X_train.shape[1], X_train.shape[2]), best_lstm_params)
|
||||
early_stop_final = EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True)
|
||||
lr_reduce_final = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6)
|
||||
logging.info("Training best LSTM model with optimized hyperparameters...")
|
||||
final_lstm.fit(
|
||||
X_train, y_train,
|
||||
epochs=300,
|
||||
batch_size=16,
|
||||
validation_data=(X_val, y_val),
|
||||
callbacks=[early_stop_final, lr_reduce_final],
|
||||
verbose=1
|
||||
)
|
||||
|
||||
# 7) Evaluate final LSTM
|
||||
def evaluate_final_lstm(model, X_test, y_test):
|
||||
logging.info("Evaluating final LSTM model...")
|
||||
y_pred_scaled = model.predict(X_test).flatten()
|
||||
y_pred_scaled = np.clip(y_pred_scaled, 0, 1)
|
||||
y_pred = scaler_target.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
|
||||
y_test_actual = scaler_target.inverse_transform(y_test.reshape(-1, 1)).flatten()
|
||||
|
||||
mse_ = mean_squared_error(y_test_actual, y_pred)
|
||||
rmse_ = np.sqrt(mse_)
|
||||
mae_ = mean_absolute_error(y_test_actual, y_pred)
|
||||
r2_ = r2_score(y_test_actual, y_pred)
|
||||
|
||||
direction_actual = np.sign(np.diff(y_test_actual))
|
||||
direction_pred = np.sign(np.diff(y_pred))
|
||||
directional_accuracy = np.mean(direction_actual == direction_pred)
|
||||
|
||||
logging.info(f"Test MSE: {mse_:.4f}")
|
||||
logging.info(f"Test RMSE: {rmse_:.4f}")
|
||||
logging.info(f"Test MAE: {mae_:.4f}")
|
||||
logging.info(f"Test R2 Score: {r2_:.4f}")
|
||||
logging.info(f"Directional Accuracy: {directional_accuracy:.4f}")
|
||||
|
||||
plt.figure(figsize=(14, 7))
|
||||
plt.plot(y_test_actual, label='Actual Price')
|
||||
plt.plot(y_pred, label='Predicted Price')
|
||||
plt.title('LSTM: Actual vs Predicted Closing Prices')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.savefig(os.path.join(output_dir, 'lstm_actual_vs_pred.png'))
|
||||
plt.close()
|
||||
|
||||
table = []
|
||||
limit = min(40, len(y_test_actual))
|
||||
for i in range(limit):
|
||||
table.append([i, round(y_test_actual[i], 2), round(y_pred[i], 2)])
|
||||
headers = ["Index", "Actual Price", "Predicted Price"]
|
||||
print("\nFirst 40 Actual vs. Predicted Prices:")
|
||||
print(tabulate(table, headers=headers, tablefmt="pretty"))
|
||||
return r2_, directional_accuracy
|
||||
|
||||
_r2, _diracc = evaluate_final_lstm(final_lstm, X_test, y_test)
|
||||
|
||||
# 8) Save final LSTM model and scalers
|
||||
final_lstm.save(os.path.join(output_dir, 'best_lstm_model.h5'))
|
||||
joblib.dump(scaler_features, os.path.join(output_dir, 'scaler_features.pkl'))
|
||||
joblib.dump(scaler_target, os.path.join(output_dir, 'scaler_target.pkl'))
|
||||
logging.info("Saved best LSTM model and scaler objects.")
|
||||
|
||||
##########################################
|
||||
# B) PPO PART: SET UP FUTURES TRADING ENVIRONMENT
|
||||
##########################################
|
||||
env_params = {
|
||||
'df': df,
|
||||
'feature_columns': feature_columns,
|
||||
'lstm_model': final_lstm, # Frozen LSTM for forecasting
|
||||
'scaler_features': scaler_features,
|
||||
'scaler_target': scaler_target,
|
||||
'window_size': lstm_window_size,
|
||||
'transaction_cost': 0.001,
|
||||
'action_mode': action_mode,
|
||||
'max_contracts': max_contracts
|
||||
}
|
||||
|
||||
# Create the FuturesTradingEnv and wrap it for PPO training
|
||||
env = FuturesTradingEnv(**env_params)
|
||||
vec_env = DummyVecEnv([lambda: env])
|
||||
|
||||
# PPO hyperparameters (customize as needed)
|
||||
ppo_hyperparams = {
|
||||
'n_steps': 2048,
|
||||
'batch_size': 64,
|
||||
'gae_lambda': 0.95,
|
||||
'gamma': 0.99,
|
||||
'learning_rate': 3e-4,
|
||||
'ent_coef': 0.0,
|
||||
'verbose': 1
|
||||
}
|
||||
|
||||
# -----------------------------
|
||||
# Train PPO Model
|
||||
# -----------------------------
|
||||
logging.info("Starting PPO training...")
|
||||
ppo_model = PPO('MlpPolicy', vec_env, **ppo_hyperparams)
|
||||
ppo_model.learn(total_timesteps=ppo_total_timesteps)
|
||||
ppo_model.save(os.path.join(output_dir, "best_ppo_model.zip"))
|
||||
logging.info("PPO training completed and model saved.")
|
||||
|
||||
##########################################
|
||||
# C) FINAL INFERENCE & (Optional) LIVE TRADING EXAMPLE
|
||||
##########################################
|
||||
# Evaluate the trained PPO model in the environment
|
||||
obs = env.reset()
|
||||
done = False
|
||||
total_reward = 0.0
|
||||
step_data = []
|
||||
step_count = 0
|
||||
|
||||
while not done:
|
||||
step_count += 1
|
||||
action, _ = ppo_model.predict(obs, deterministic=True)
|
||||
obs, reward, done, _ = env.step(action)
|
||||
total_reward += reward
|
||||
step_data.append({
|
||||
"Step": step_count,
|
||||
"Action": int(action) if action_mode=='discrete' else int(np.round(action[0])),
|
||||
"Reward": reward,
|
||||
"Contracts": env.contracts_held
|
||||
})
|
||||
|
||||
final_pnl = (env.df.loc[env.current_step, 'Close'] - (env.entry_price if env.entry_price is not None else 0)) * env.contracts_held
|
||||
print("\n=== Final PPO Inference ===")
|
||||
print(f"Total Steps: {step_count}")
|
||||
print(f"Final Contracts Held: {env.contracts_held}")
|
||||
print(f"Final Estimated PnL: {final_pnl:.2f}")
|
||||
print(f"Total Reward Sum: {total_reward:.2f}")
|
||||
|
||||
print("\nLast 15 Steps:")
|
||||
last_n = step_data[-15:] if len(step_data) > 15 else step_data
|
||||
print(tabulate(last_n, headers="keys", tablefmt="pretty"))
|
||||
|
||||
# OPTIONAL: Uncomment to run a live trading loop (requires implementation of live data feed and order execution)
|
||||
# live_trading_loop(ppo_model, env)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
721
src/MidasAgent/PPO/MidasPPO.py
Normal file
721
src/MidasAgent/PPO/MidasPPO.py
Normal file
@@ -0,0 +1,721 @@
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import logging
|
||||
from tabulate import tabulate
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import psutil
|
||||
import GPUtil
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.models import Sequential, load_model
|
||||
from tensorflow.keras.layers import LSTM, Dense, Dropout, Bidirectional
|
||||
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
|
||||
from tensorflow.keras.losses import Huber
|
||||
from tensorflow.keras.regularizers import l2
|
||||
from tensorflow.keras.optimizers import Adam, Nadam
|
||||
|
||||
from sklearn.preprocessing import MinMaxScaler
|
||||
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
|
||||
import joblib
|
||||
import optuna
|
||||
from optuna.integration import KerasPruningCallback
|
||||
|
||||
import gym
|
||||
from gym import spaces
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
|
||||
from multiprocessing import Pool, cpu_count
|
||||
import threading
|
||||
import time
|
||||
from dateutil import parser # For custom date parsing
|
||||
|
||||
# Suppress TensorFlow logs beyond errors
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
||||
|
||||
# =============================================================================
|
||||
# Custom Date Parser: handles EDT/EST then drops timezone info
|
||||
# =============================================================================
|
||||
def custom_date_parser(date_str):
|
||||
tzinfos = {"EDT": -14400, "EST": -18000}
|
||||
dt = parser.parse(date_str, tzinfos=tzinfos)
|
||||
return dt.replace(tzinfo=None)
|
||||
|
||||
# =============================================================================
|
||||
# Resource Detection Functions
|
||||
# =============================================================================
|
||||
def get_cpu_info():
|
||||
cpu_count_physical = psutil.cpu_count(logical=False)
|
||||
cpu_count_logical = psutil.cpu_count(logical=True)
|
||||
cpu_percent = psutil.cpu_percent(interval=1, percpu=True)
|
||||
return {
|
||||
'physical_cores': cpu_count_physical,
|
||||
'logical_cores': cpu_count_logical,
|
||||
'cpu_percent': cpu_percent
|
||||
}
|
||||
|
||||
def get_gpu_info():
|
||||
gpus = GPUtil.getGPUs()
|
||||
gpu_info = []
|
||||
for gpu in gpus:
|
||||
gpu_info.append({
|
||||
'id': gpu.id,
|
||||
'name': gpu.name,
|
||||
'load': gpu.load * 100,
|
||||
'memory_total': gpu.memoryTotal,
|
||||
'memory_used': gpu.memoryUsed,
|
||||
'memory_free': gpu.memoryFree,
|
||||
'temperature': gpu.temperature
|
||||
})
|
||||
return gpu_info
|
||||
|
||||
def configure_tensorflow(cpu_stats, gpu_stats):
|
||||
logical_cores = cpu_stats['logical_cores']
|
||||
os.environ["OMP_NUM_THREADS"] = str(logical_cores)
|
||||
os.environ["TF_NUM_INTRAOP_THREADS"] = str(logical_cores)
|
||||
os.environ["TF_NUM_INTEROP_THREADS"] = str(logical_cores)
|
||||
|
||||
if gpu_stats:
|
||||
gpus = tf.config.list_physical_devices('GPU')
|
||||
if gpus:
|
||||
try:
|
||||
for gpu in gpus:
|
||||
tf.config.experimental.set_memory_growth(gpu, True)
|
||||
logging.info(f"Enabled memory growth for {len(gpus)} GPU(s).")
|
||||
except RuntimeError as e:
|
||||
logging.error(f"TensorFlow GPU configuration error: {e}")
|
||||
else:
|
||||
tf.config.threading.set_intra_op_parallelism_threads(logical_cores)
|
||||
tf.config.threading.set_inter_op_parallelism_threads(logical_cores)
|
||||
logging.info("Configured TensorFlow to use CPU with optimized thread settings.")
|
||||
|
||||
def monitor_resources(interval=60):
|
||||
while True:
|
||||
cpu = psutil.cpu_percent(interval=1, percpu=True)
|
||||
gpu = get_gpu_info()
|
||||
logging.info(f"CPU Usage per Core: {cpu}%")
|
||||
if gpu:
|
||||
for gpu_stat in gpu:
|
||||
logging.info(f"GPU {gpu_stat['id']} - {gpu_stat['name']}: Load: {gpu_stat['load']}%, "
|
||||
f"Memory Used: {gpu_stat['memory_used']}MB/{gpu_stat['memory_total']}MB, "
|
||||
f"Temperature: {gpu_stat['temperature']}°C")
|
||||
else:
|
||||
logging.info("No GPUs detected.")
|
||||
logging.info("-" * 50)
|
||||
time.sleep(interval)
|
||||
|
||||
# =============================================================================
|
||||
# Data Loading & Technical Indicators
|
||||
# =============================================================================
|
||||
def load_data(file_path):
|
||||
"""
|
||||
Loads data from a CSV or JSON file.
|
||||
Expects a time column plus: open, high, low, close, volume.
|
||||
Uses a custom date parser and flexible column mapping.
|
||||
"""
|
||||
logging.info(f"Loading data from: {file_path}")
|
||||
try:
|
||||
if file_path.lower().endswith('.csv'):
|
||||
df = pd.read_csv(file_path, parse_dates=['time'], date_parser=custom_date_parser)
|
||||
elif file_path.lower().endswith('.json'):
|
||||
df = pd.read_json(file_path, convert_dates=False)
|
||||
# Try to find a time/date column
|
||||
time_col = None
|
||||
for col in df.columns:
|
||||
if col.strip().lower() in ['time', 'date']:
|
||||
time_col = col
|
||||
break
|
||||
if time_col:
|
||||
df[time_col] = df[time_col].apply(custom_date_parser)
|
||||
else:
|
||||
logging.error("No time column found in JSON data.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
logging.error("Unsupported file format. Please provide CSV or JSON.")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
logging.error(f"Error loading file: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Clean and rename columns
|
||||
df.columns = [col.strip() for col in df.columns]
|
||||
lower_cols = [col.lower() for col in df.columns]
|
||||
|
||||
required_fields = {
|
||||
'time': ['time', 'date'],
|
||||
'open': ['open'],
|
||||
'high': ['high'],
|
||||
'low': ['low'],
|
||||
'close': ['close'],
|
||||
'volume': ['volume', 'vol']
|
||||
}
|
||||
|
||||
rename_mapping = {}
|
||||
for canonical, alternatives in required_fields.items():
|
||||
found = False
|
||||
for alt in alternatives:
|
||||
if alt in lower_cols:
|
||||
orig_col = df.columns[lower_cols.index(alt)]
|
||||
if canonical == 'time':
|
||||
rename_mapping[orig_col] = 'Date'
|
||||
else:
|
||||
rename_mapping[orig_col] = canonical.capitalize()
|
||||
found = True
|
||||
break
|
||||
if not found:
|
||||
logging.error(f"Required column for '{canonical}' not found. Alternatives: {alternatives}")
|
||||
sys.exit(1)
|
||||
|
||||
df.rename(columns=rename_mapping, inplace=True)
|
||||
logging.info(f"Columns after renaming: {df.columns.tolist()}")
|
||||
try:
|
||||
df.sort_values('Date', inplace=True)
|
||||
except KeyError:
|
||||
logging.error("Column 'Date' not found after renaming. Check data format.")
|
||||
sys.exit(1)
|
||||
df.reset_index(drop=True, inplace=True)
|
||||
logging.info("Data loaded and sorted successfully.")
|
||||
return df
|
||||
|
||||
def compute_rsi(series, window=14):
|
||||
delta = series.diff()
|
||||
gain = delta.where(delta > 0, 0).rolling(window=window).mean()
|
||||
loss = -delta.where(delta < 0, 0).rolling(window=window).mean()
|
||||
RS = gain / (loss + 1e-9)
|
||||
return 100 - (100 / (1 + RS))
|
||||
|
||||
def compute_macd(series, span_short=12, span_long=26, span_signal=9):
|
||||
ema_short = series.ewm(span=span_short, adjust=False).mean()
|
||||
ema_long = series.ewm(span=span_long, adjust=False).mean()
|
||||
macd_line = ema_short - ema_long
|
||||
signal_line = macd_line.ewm(span=span_signal, adjust=False).mean()
|
||||
return macd_line - signal_line
|
||||
|
||||
def compute_obv(df):
|
||||
signed_volume = (np.sign(df['Close'].diff()) * df['Volume']).fillna(0)
|
||||
return signed_volume.cumsum()
|
||||
|
||||
def compute_adx(df, window=14):
|
||||
df['H-L'] = df['High'] - df['Low']
|
||||
df['H-Cp'] = (df['High'] - df['Close'].shift(1)).abs()
|
||||
df['L-Cp'] = (df['Low'] - df['Close'].shift(1)).abs()
|
||||
tr = df[['H-L','H-Cp','L-Cp']].max(axis=1)
|
||||
tr_rolling = tr.rolling(window=window).mean()
|
||||
adx_placeholder = tr_rolling / (df['Close'] + 1e-9)
|
||||
df.drop(['H-L','H-Cp','L-Cp'], axis=1, inplace=True)
|
||||
return adx_placeholder
|
||||
|
||||
def compute_bollinger_bands(series, window=20, num_std=2):
|
||||
sma = series.rolling(window=window).mean()
|
||||
std = series.rolling(window=window).std()
|
||||
upper = sma + num_std * std
|
||||
lower = sma - num_std * std
|
||||
bandwidth = (upper - lower) / (sma + 1e-9)
|
||||
return upper, lower, bandwidth
|
||||
|
||||
def compute_mfi(df, window=14):
|
||||
typical_price = (df['High'] + df['Low'] + df['Close']) / 3
|
||||
money_flow = typical_price * df['Volume']
|
||||
prev_tp = typical_price.shift(1)
|
||||
flow_pos = money_flow.where(typical_price > prev_tp, 0)
|
||||
flow_neg = money_flow.where(typical_price < prev_tp, 0)
|
||||
pos_sum = flow_pos.rolling(window=window).sum()
|
||||
neg_sum = flow_neg.rolling(window=window).sum()
|
||||
mfi = 100 - (100 / (1 + pos_sum / (neg_sum + 1e-9)))
|
||||
return mfi
|
||||
|
||||
def calculate_technical_indicators(df):
|
||||
logging.info("Calculating technical indicators...")
|
||||
df['RSI'] = compute_rsi(df['Close'], 14)
|
||||
df['MACD'] = compute_macd(df['Close'])
|
||||
df['OBV'] = compute_obv(df)
|
||||
df['ADX'] = compute_adx(df)
|
||||
up, lo, bw = compute_bollinger_bands(df['Close'], 20, 2)
|
||||
df['BB_Upper'] = up
|
||||
df['BB_Lower'] = lo
|
||||
df['BB_Width'] = bw
|
||||
df['MFI'] = compute_mfi(df, 14)
|
||||
df['SMA_5'] = df['Close'].rolling(5).mean()
|
||||
df['SMA_10'] = df['Close'].rolling(10).mean()
|
||||
df['EMA_5'] = df['Close'].ewm(span=5, adjust=False).mean()
|
||||
df['EMA_10'] = df['Close'].ewm(span=10, adjust=False).mean()
|
||||
df['STDDEV_5'] = df['Close'].rolling(5).std()
|
||||
df.dropna(inplace=True)
|
||||
logging.info("Technical indicators calculated successfully.")
|
||||
return df
|
||||
|
||||
# =============================================================================
|
||||
# Argument Parsing
|
||||
# =============================================================================
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description='Futures Trading with LSTM and PPO')
|
||||
parser.add_argument('data_path', type=str,
|
||||
help='Path to CSV or JSON file with columns [time, open, high, low, close, volume].')
|
||||
parser.add_argument('--lstm_window_size', type=int, default=15,
|
||||
help='Sequence window size for LSTM forecasting. Default=15.')
|
||||
parser.add_argument('--ppo_total_timesteps', type=int, default=100000,
|
||||
help='Total timesteps to train the PPO model. Default=100000.')
|
||||
parser.add_argument('--n_trials_lstm', type=int, default=30,
|
||||
help='Number of Optuna trials for LSTM hyperparameter tuning. Default=30.')
|
||||
parser.add_argument('--preprocess_workers', type=int, default=None,
|
||||
help='Number of worker processes for data preprocessing. Defaults to (logical cores - 2).')
|
||||
parser.add_argument('--monitor_resources', action='store_true',
|
||||
help='Enable real-time resource monitoring.')
|
||||
parser.add_argument('--output_dir', type=str, default='output',
|
||||
help='Directory where output files will be saved.')
|
||||
parser.add_argument('--action_mode', type=str, choices=['discrete', 'continuous'], default='discrete',
|
||||
help='Action space type. Default=discrete.')
|
||||
parser.add_argument('--max_contracts', type=int, default=5,
|
||||
help='Maximum number of contracts per action. Default=5.')
|
||||
parser.add_argument('--initial_balance', type=float, default=10000,
|
||||
help='Initial account balance. Default=10000.')
|
||||
parser.add_argument('--stop_loss_points', type=float, default=20,
|
||||
help='Stop loss in points. Default=20.')
|
||||
parser.add_argument('--trailing_stop_points', type=float, default=5,
|
||||
help='Trailing stop in points. Default=5.')
|
||||
parser.add_argument('--min_risk_reward', type=float, default=2.0,
|
||||
help='Minimum risk-reward ratio for trade entry. Default=2.0.')
|
||||
return parser.parse_args()
|
||||
|
||||
# =============================================================================
|
||||
# LSTM Price Predictor
|
||||
# =============================================================================
|
||||
def build_lstm(input_shape, hyperparams):
|
||||
model = Sequential()
|
||||
num_layers = hyperparams['num_lstm_layers']
|
||||
units = hyperparams['lstm_units']
|
||||
dropout = hyperparams['dropout_rate']
|
||||
for i in range(num_layers):
|
||||
return_sequences = (i < num_layers - 1)
|
||||
if i == 0:
|
||||
model.add(Bidirectional(LSTM(units, return_sequences=return_sequences, kernel_regularizer=l2(1e-4)),
|
||||
input_shape=input_shape))
|
||||
else:
|
||||
model.add(Bidirectional(LSTM(units, return_sequences=return_sequences, kernel_regularizer=l2(1e-4))))
|
||||
model.add(Dropout(dropout))
|
||||
model.add(Dense(1, activation='linear'))
|
||||
opt_name = hyperparams['optimizer']
|
||||
lr = hyperparams['learning_rate']
|
||||
decay = hyperparams['decay']
|
||||
if opt_name == 'Adam':
|
||||
opt = Adam(learning_rate=lr, decay=decay)
|
||||
elif opt_name == 'Nadam':
|
||||
opt = Nadam(learning_rate=lr)
|
||||
else:
|
||||
opt = Adam(learning_rate=lr)
|
||||
model.compile(loss=Huber(), optimizer=opt, metrics=['mae'])
|
||||
return model
|
||||
|
||||
# =============================================================================
|
||||
# Custom Gym Environment for Futures Trading with LSTM Forecasting and Risk Management
|
||||
# =============================================================================
|
||||
class FuturesTradingEnv(gym.Env):
|
||||
"""
|
||||
A Gym environment that incorporates:
|
||||
- LSTM forecasting (with forecast cached for reward shaping)
|
||||
- Futures trading with risk management (stop loss, trailing stop, position sizing)
|
||||
- Bonus reward when the action aligns with the LSTM forecast
|
||||
It supports both discrete and continuous action spaces.
|
||||
"""
|
||||
metadata = {'render.modes': ['human']}
|
||||
|
||||
def __init__(self, df, feature_columns, lstm_model, scaler_features, scaler_target,
|
||||
window_size=15, transaction_cost=0.60, action_mode='discrete', max_contracts=5,
|
||||
initial_balance=10000, max_risk_per_trade=0.02, stop_loss_points=20,
|
||||
trailing_stop_points=5, min_risk_reward=2.0, max_daily_loss=0.05):
|
||||
super(FuturesTradingEnv, self).__init__()
|
||||
self.df = df.reset_index(drop=True)
|
||||
self.feature_columns = feature_columns
|
||||
self.lstm_model = lstm_model
|
||||
self.scaler_features = scaler_features
|
||||
self.scaler_target = scaler_target
|
||||
self.window_size = window_size
|
||||
self.transaction_cost = transaction_cost
|
||||
self.action_mode = action_mode
|
||||
self.max_contracts = max_contracts
|
||||
|
||||
# Futures-specific parameters
|
||||
self.multiplier = 5 # e.g. contract multiplier
|
||||
self.tick_size = 0.25
|
||||
|
||||
# Risk management and account parameters
|
||||
self.initial_balance = initial_balance
|
||||
self.balance = initial_balance
|
||||
self.max_risk_per_trade = max_risk_per_trade
|
||||
self.stop_loss_points = stop_loss_points
|
||||
self.trailing_stop_points = trailing_stop_points
|
||||
self.min_risk_reward = min_risk_reward
|
||||
self.max_daily_loss = max_daily_loss
|
||||
self.daily_loss = 0.0
|
||||
|
||||
self.max_steps = len(df)
|
||||
self.current_step = 0
|
||||
self.contracts_held = 0
|
||||
self.entry_price = None
|
||||
self.peak_price = None
|
||||
self.trade_history = []
|
||||
self.daily_balances = [initial_balance]
|
||||
self.current_day = 0
|
||||
|
||||
self.raw_features = df[feature_columns].values
|
||||
|
||||
# Define action space
|
||||
if self.action_mode == 'discrete':
|
||||
self.action_space = spaces.Discrete(2 * self.max_contracts + 1)
|
||||
else:
|
||||
self.action_space = spaces.Box(low=-self.max_contracts, high=self.max_contracts, shape=(1,), dtype=np.float32)
|
||||
|
||||
# Observation: normalized features, account info, and cached LSTM forecast
|
||||
obs_len = len(feature_columns) + 4 # account info: balance ratio, normalized position, pnl, peak diff
|
||||
self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(obs_len + 1,), dtype=np.float32)
|
||||
|
||||
self.lstm_lock = threading.Lock()
|
||||
# Cache for the forecast (used for reward shaping)
|
||||
self.last_forecast = 0.0
|
||||
|
||||
def reset(self):
|
||||
self.current_step = 0
|
||||
self.contracts_held = 0
|
||||
self.entry_price = None
|
||||
self.peak_price = None
|
||||
self.balance = self.initial_balance
|
||||
self.daily_loss = 0.0
|
||||
self.trade_history = []
|
||||
self.daily_balances = [self.initial_balance]
|
||||
self.current_day = 0
|
||||
self.last_forecast = 0.0
|
||||
return self._get_obs()
|
||||
|
||||
def _get_obs(self):
|
||||
# Normalize current feature row
|
||||
row = self.raw_features[self.current_step]
|
||||
row_max = np.max(np.abs(row)) if np.max(np.abs(row)) != 0 else 1.0
|
||||
row_norm = row / row_max
|
||||
|
||||
# Account info: balance ratio, normalized position, unrealized pnl, peak difference
|
||||
current_price = self.df.loc[self.current_step, 'Close']
|
||||
pnl = (current_price - self.entry_price) * self.contracts_held * self.multiplier if self.entry_price is not None else 0.0
|
||||
peak_diff = (current_price - self.peak_price) if self.peak_price is not None else 0.0
|
||||
additional = np.array([
|
||||
self.balance / self.initial_balance,
|
||||
self.contracts_held / self.max_contracts,
|
||||
pnl,
|
||||
peak_diff
|
||||
], dtype=np.float32)
|
||||
|
||||
# LSTM forecast: predict next price difference (cached for reward shaping)
|
||||
if self.current_step < self.window_size:
|
||||
forecast = 0.0
|
||||
else:
|
||||
seq = self.raw_features[self.current_step - self.window_size:self.current_step]
|
||||
seq_scaled = self.scaler_features.transform(seq)
|
||||
seq_scaled = np.expand_dims(seq_scaled, axis=0)
|
||||
with self.lstm_lock:
|
||||
pred_scaled = self.lstm_model.predict(seq_scaled, verbose=0).flatten()[0]
|
||||
# Inverse scale to obtain predicted price and subtract current price
|
||||
predicted_price = self.scaler_target.inverse_transform([[pred_scaled]])[0, 0]
|
||||
forecast = predicted_price - current_price
|
||||
self.last_forecast = forecast
|
||||
|
||||
return np.concatenate([row_norm, additional, [forecast]]).astype(np.float32)
|
||||
|
||||
def _calculate_position_size(self, action_int):
|
||||
# Calculate the maximum number of contracts you can trade based on risk
|
||||
risk_amount = self.initial_balance * self.max_risk_per_trade
|
||||
risk_per_contract = self.stop_loss_points * self.multiplier
|
||||
max_contracts_risk = int(risk_amount / risk_per_contract) if risk_per_contract > 0 else 1
|
||||
return min(abs(action_int), max(1, max_contracts_risk), self.max_contracts) * np.sign(action_int)
|
||||
|
||||
def step(self, action):
|
||||
start_balance = self.balance
|
||||
current_price = self.df.loc[self.current_step, 'Close']
|
||||
prev_position = self.contracts_held
|
||||
|
||||
# Convert action to integer (discrete mapping)
|
||||
if self.action_mode == 'discrete':
|
||||
action_int = action - self.max_contracts
|
||||
else:
|
||||
action_int = int(np.round(action[0]))
|
||||
action_int = np.clip(action_int, -self.max_contracts, self.max_contracts)
|
||||
|
||||
# Use the cached forecast for risk/reward decision
|
||||
forecast = self.last_forecast
|
||||
risk = self.stop_loss_points
|
||||
min_required_reward = self.min_risk_reward * risk
|
||||
|
||||
# If forecast magnitude is insufficient, do not trade
|
||||
if abs(forecast) < min_required_reward:
|
||||
action_int = 0
|
||||
logging.debug("No trade: forecast below risk-reward threshold.")
|
||||
else:
|
||||
# Only allow trades in the direction of the forecast
|
||||
if np.sign(action_int) != np.sign(forecast):
|
||||
action_int = 0
|
||||
logging.debug("No trade: action direction does not match forecast.")
|
||||
else:
|
||||
# Scale action based on forecast strength and then calculate position size
|
||||
max_forecast = risk * 2
|
||||
scaling = min(1, abs(forecast) / max_forecast)
|
||||
action_int = int(np.sign(action_int) * max(1, int(abs(action_int) * scaling)))
|
||||
action_int = self._calculate_position_size(action_int)
|
||||
logging.debug(f"Trade approved: adjusted action to {action_int} based on forecast {forecast:.2f}.")
|
||||
|
||||
# Check for stop loss or trailing stop if already in a position
|
||||
if self.contracts_held != 0 and self.entry_price is not None:
|
||||
price_change = current_price - self.entry_price
|
||||
if self.contracts_held > 0:
|
||||
self.peak_price = max(self.peak_price or current_price, current_price)
|
||||
trail_diff = self.peak_price - current_price
|
||||
if price_change <= -risk or trail_diff >= self.trailing_stop_points:
|
||||
logging.debug("Stop loss / trailing stop triggered for long position.")
|
||||
action_int = -self.contracts_held
|
||||
else:
|
||||
self.peak_price = min(self.peak_price or current_price, current_price)
|
||||
trail_diff = current_price - self.peak_price
|
||||
if price_change >= risk or trail_diff >= self.trailing_stop_points:
|
||||
logging.debug("Stop loss / trailing stop triggered for short position.")
|
||||
action_int = -self.contracts_held
|
||||
|
||||
fee = self.transaction_cost * abs(action_int)
|
||||
# Update position logic
|
||||
if action_int != 0:
|
||||
if self.contracts_held == 0:
|
||||
self.contracts_held = action_int
|
||||
self.entry_price = current_price
|
||||
self.peak_price = current_price
|
||||
elif np.sign(self.contracts_held) == np.sign(action_int):
|
||||
total_contracts = self.contracts_held + action_int
|
||||
self.entry_price = (self.entry_price * self.contracts_held + current_price * action_int) / total_contracts
|
||||
self.contracts_held = total_contracts
|
||||
else:
|
||||
if abs(action_int) >= abs(self.contracts_held):
|
||||
self.contracts_held += action_int
|
||||
self.entry_price = current_price if self.contracts_held != 0 else None
|
||||
self.peak_price = current_price if self.contracts_held != 0 else None
|
||||
else:
|
||||
self.contracts_held += action_int
|
||||
|
||||
# Calculate mark-to-market PnL change from previous step (using previous step close)
|
||||
prev_close = self.df.loc[self.current_step-1, 'Close'] if self.current_step > 0 else current_price
|
||||
pnl_change = (current_price - prev_close) * prev_position * self.multiplier
|
||||
# Base reward: pnl change minus fee
|
||||
reward = pnl_change - fee
|
||||
# Add bonus reward if action aligns with forecast (scaled by bonus factor)
|
||||
bonus_factor = 10.0
|
||||
reward += bonus_factor * (action_int * forecast)
|
||||
|
||||
self.balance += reward
|
||||
self.daily_loss -= min(0, reward)
|
||||
|
||||
# Record trade if a position was closed
|
||||
if self.contracts_held == 0 and prev_position != 0:
|
||||
profit = self.balance - start_balance
|
||||
day = self.current_step // (60 * 24)
|
||||
self.trade_history.append((start_balance, profit, day))
|
||||
|
||||
# Update daily balance if new day begins
|
||||
current_day = self.current_step // (60 * 24)
|
||||
if current_day > self.current_day:
|
||||
self.daily_balances.append(self.balance)
|
||||
self.daily_loss = 0.0
|
||||
self.current_day = current_day
|
||||
|
||||
self.current_step += 1
|
||||
done = (self.current_step >= self.max_steps - 1) or (self.daily_loss >= self.initial_balance * self.max_daily_loss)
|
||||
return self._get_obs(), reward, done, {}
|
||||
|
||||
def render(self, mode='human'):
|
||||
current_price = self.df.loc[self.current_step, 'Close']
|
||||
pnl = (current_price - self.entry_price) * self.contracts_held * self.multiplier if self.entry_price is not None else 0.0
|
||||
print(f"Step: {self.current_step}, Balance: {self.balance:.2f}, Contracts: {self.contracts_held}, PnL: {pnl:.2f}, Daily Loss: {self.daily_loss:.2f}")
|
||||
|
||||
def calculate_metrics(self):
|
||||
if not self.trade_history or len(self.daily_balances) < 2:
|
||||
return {
|
||||
'avg_daily_return': 0.0,
|
||||
'avg_return_per_trade': 0.0,
|
||||
'win_rate': 0.0,
|
||||
'avg_dollar_return_daily': 0.0,
|
||||
'avg_dollar_return_per_trade': 0.0,
|
||||
'sharpe_ratio': 0.0
|
||||
}
|
||||
daily_returns = [(self.daily_balances[i] - self.daily_balances[i-1]) / self.daily_balances[i-1]
|
||||
for i in range(1, len(self.daily_balances))]
|
||||
avg_daily_return = np.mean(daily_returns) * 100
|
||||
avg_dollar_return_daily = np.mean([self.daily_balances[i] - self.daily_balances[i-1]
|
||||
for i in range(1, len(self.daily_balances))])
|
||||
trade_returns = [profit / start for start, profit, _ in self.trade_history]
|
||||
avg_return_per_trade = np.mean(trade_returns) * 100
|
||||
avg_dollar_return_per_trade = np.mean([profit for _, profit, _ in self.trade_history])
|
||||
win_rate = len([r for r in trade_returns if r > 0]) / len(trade_returns) * 100
|
||||
daily_returns_np = np.array(daily_returns)
|
||||
sharpe_ratio = (np.mean(daily_returns_np) / np.std(daily_returns_np)) * np.sqrt(252) if np.std(daily_returns_np) != 0 else 0.0
|
||||
|
||||
return {
|
||||
'avg_daily_return': avg_daily_return,
|
||||
'avg_return_per_trade': avg_return_per_trade,
|
||||
'win_rate': win_rate,
|
||||
'avg_dollar_return_daily': avg_dollar_return_daily,
|
||||
'avg_dollar_return_per_trade': avg_dollar_return_per_trade,
|
||||
'sharpe_ratio': sharpe_ratio
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# Placeholders for Live Deployment Functions
|
||||
# =============================================================================
|
||||
def get_live_data():
|
||||
return None
|
||||
|
||||
def execute_order(action):
|
||||
logging.info(f"Executing order: {action}")
|
||||
|
||||
def live_trading_loop(model, env, polling_interval=5):
|
||||
obs = env.reset()
|
||||
done = False
|
||||
while not done:
|
||||
live_data = get_live_data()
|
||||
if live_data is not None:
|
||||
pass
|
||||
action, _ = model.predict(obs, deterministic=True)
|
||||
execute_order(action)
|
||||
obs, reward, done, _ = env.step(action)
|
||||
env.render()
|
||||
time.sleep(polling_interval)
|
||||
|
||||
# =============================================================================
|
||||
# Data Preprocessing with Parallelization
|
||||
# =============================================================================
|
||||
def parallel_feature_engineering(row):
|
||||
return row
|
||||
|
||||
def feature_engineering_parallel(df, num_workers):
|
||||
logging.info(f"Starting parallel feature engineering with {num_workers} workers...")
|
||||
with Pool(processes=num_workers) as pool:
|
||||
processed_rows = pool.map(parallel_feature_engineering, [row for _, row in df.iterrows()])
|
||||
df_processed = pd.DataFrame(processed_rows)
|
||||
logging.info("Parallel feature engineering completed.")
|
||||
return df_processed
|
||||
|
||||
# =============================================================================
|
||||
# MAIN FUNCTION: LSTM Training + PPO for Futures Trading
|
||||
# =============================================================================
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
data_path = args.data_path
|
||||
output_dir = args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
lstm_window_size = args.lstm_window_size
|
||||
ppo_total_timesteps = args.ppo_total_timesteps
|
||||
n_trials_lstm = args.n_trials_lstm
|
||||
preprocess_workers = args.preprocess_workers
|
||||
enable_resource_monitor = args.monitor_resources
|
||||
action_mode = args.action_mode
|
||||
max_contracts = args.max_contracts
|
||||
|
||||
# Load and process data (CSV or JSON)
|
||||
df = load_data(data_path)
|
||||
df = calculate_technical_indicators(df)
|
||||
# Use numeric columns (exclude Date)
|
||||
feature_columns = [col for col in df.columns if col != 'Date' and np.issubdtype(df[col].dtype, np.number)]
|
||||
|
||||
# Fit scalers on the numeric data to avoid issues with column names
|
||||
scaler_features = MinMaxScaler().fit(df[feature_columns].values)
|
||||
scaler_target = MinMaxScaler().fit(df[['Close']].values)
|
||||
|
||||
# Build a basic LSTM model (further tuning could be applied using Optuna)
|
||||
final_lstm = build_lstm((lstm_window_size, len(feature_columns)), {
|
||||
'num_lstm_layers': 2,
|
||||
'lstm_units': 50,
|
||||
'dropout_rate': 0.2,
|
||||
'optimizer': 'Adam',
|
||||
'learning_rate': 0.001,
|
||||
'decay': 1e-6
|
||||
})
|
||||
|
||||
# (Optional) You can add LSTM hyperparameter tuning here using Optuna
|
||||
|
||||
# Set up the trading environment
|
||||
env_params = {
|
||||
'df': df,
|
||||
'feature_columns': feature_columns,
|
||||
'lstm_model': final_lstm,
|
||||
'scaler_features': scaler_features,
|
||||
'scaler_target': scaler_target,
|
||||
'window_size': lstm_window_size,
|
||||
'transaction_cost': 0.60,
|
||||
'action_mode': action_mode,
|
||||
'max_contracts': max_contracts,
|
||||
'initial_balance': args.initial_balance,
|
||||
'stop_loss_points': args.stop_loss_points,
|
||||
'trailing_stop_points': args.trailing_stop_points,
|
||||
'min_risk_reward': args.min_risk_reward,
|
||||
'max_daily_loss': 0.05
|
||||
}
|
||||
env = FuturesTradingEnv(**env_params)
|
||||
vec_env = DummyVecEnv([lambda: env])
|
||||
|
||||
# PPO hyperparameters
|
||||
ppo_hyperparams = {
|
||||
'n_steps': 2048,
|
||||
'batch_size': 64,
|
||||
'gae_lambda': 0.95,
|
||||
'gamma': 0.99,
|
||||
'learning_rate': 3e-4,
|
||||
'ent_coef': 0.0,
|
||||
'verbose': 1
|
||||
}
|
||||
|
||||
if enable_resource_monitor:
|
||||
resource_monitor_thread = threading.Thread(target=monitor_resources, args=(60,), daemon=True)
|
||||
resource_monitor_thread.start()
|
||||
|
||||
logging.info("Starting PPO training...")
|
||||
ppo_model = PPO('MlpPolicy', vec_env, **ppo_hyperparams)
|
||||
ppo_model.learn(total_timesteps=ppo_total_timesteps)
|
||||
ppo_model.save(os.path.join(output_dir, "best_ppo_model.zip"))
|
||||
logging.info("PPO training completed and model saved.")
|
||||
|
||||
# Inference on the environment
|
||||
obs = env.reset()
|
||||
done = False
|
||||
total_reward = 0.0
|
||||
step_data = []
|
||||
|
||||
while not done:
|
||||
step_count = env.current_step + 1
|
||||
action, _ = ppo_model.predict(obs, deterministic=True)
|
||||
obs, reward, done, _ = env.step(action)
|
||||
total_reward += reward
|
||||
step_data.append({
|
||||
"Step": step_count,
|
||||
"Action": int(action) if action_mode == 'discrete' else int(np.round(action[0])),
|
||||
"Reward": reward,
|
||||
"Contracts": env.contracts_held
|
||||
})
|
||||
|
||||
final_pnl = (env.df.loc[env.current_step, 'Close'] - (env.entry_price if env.entry_price else 0)) * env.contracts_held * env.multiplier
|
||||
print("\n=== Final PPO Inference ===")
|
||||
print(f"Total Steps: {step_count}")
|
||||
print(f"Final Contracts Held: {env.contracts_held}")
|
||||
print(f"Final Estimated PnL: {final_pnl:.2f}")
|
||||
print(f"Total Reward Sum: {total_reward:.2f}")
|
||||
print("\nLast 15 Steps:")
|
||||
last_n = step_data[-15:] if len(step_data) > 15 else step_data
|
||||
print(tabulate(last_n, headers="keys", tablefmt="pretty"))
|
||||
|
||||
metrics = env.calculate_metrics()
|
||||
print("\n=== Trading Metrics ===")
|
||||
print(f"Average Daily Return: {metrics['avg_daily_return']:.2f}%")
|
||||
print(f"Average Return per Trade: {metrics['avg_return_per_trade']:.2f}%")
|
||||
print(f"Win Rate: {metrics['win_rate']:.2f}%")
|
||||
print(f"Average Dollar Return Daily: ${metrics['avg_dollar_return_daily']:.2f}")
|
||||
print(f"Average Dollar Return per Trade: ${metrics['avg_dollar_return_per_trade']:.2f}")
|
||||
print(f"Sharpe Ratio: {metrics['sharpe_ratio']:.2f}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
257
src/MidasAgent/PPO/MidasWrapper.py
Normal file
257
src/MidasAgent/PPO/MidasWrapper.py
Normal file
@@ -0,0 +1,257 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
MidasWrapper.py
|
||||
|
||||
This script connects to IBKR via IB_insync, attempts to select a MES futures contract
|
||||
using several variants, and then, based on the AI signal, places a market order and attaches
|
||||
a stop-loss order.
|
||||
|
||||
Usage:
|
||||
- Ensure that IBKR TWS or Gateway is running in paper trading mode (port 4002 used here).
|
||||
- Make sure your IBKR account is subscribed to MES futures and the requested contract month is active.
|
||||
- Example: py MidasWrapper.py
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import time
|
||||
import logging
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from ib_insync import IB, Future, MarketOrder, Order, util
|
||||
|
||||
# --- Logger Setup ---
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Global Time Zone ---
|
||||
TZ = ZoneInfo("US/Eastern")
|
||||
|
||||
# --- Utility Functions ---
|
||||
def get_third_friday(year, month):
|
||||
"""Return the third Friday of the given year/month as a datetime with TZ."""
|
||||
fridays = []
|
||||
for day in range(1, 32):
|
||||
try:
|
||||
d = datetime.date(year, month, day)
|
||||
except ValueError:
|
||||
break
|
||||
if d.weekday() == 4:
|
||||
fridays.append(d)
|
||||
if len(fridays) >= 3:
|
||||
dt = datetime.datetime.combine(fridays[2], datetime.time(16, 0))
|
||||
elif fridays:
|
||||
dt = datetime.datetime.combine(fridays[-1], datetime.time(16, 0))
|
||||
else:
|
||||
dt = datetime.datetime(year, month, 1, 16, 0)
|
||||
return dt.replace(tzinfo=TZ)
|
||||
|
||||
def getMESContract(ib_conn, cm, contract_expiration):
|
||||
"""
|
||||
Try several MES contract definitions for the given contract month (cm)
|
||||
and expiration date. Returns a tuple (contract, variant_desc) if found,
|
||||
or (None, None) if no valid contract is returned.
|
||||
"""
|
||||
expiration_str = contract_expiration.strftime("%Y%m%d")
|
||||
variants = []
|
||||
|
||||
# Variant 1: Full expiration date, exchange GLOBEX.
|
||||
contract1 = Future(
|
||||
symbol='MES',
|
||||
lastTradeDateOrContractMonth=expiration_str,
|
||||
exchange='GLOBEX',
|
||||
currency='USD',
|
||||
multiplier=5
|
||||
)
|
||||
contract1.includeExpired = True
|
||||
variants.append(("Variant 1: full expiration, GLOBEX", contract1))
|
||||
|
||||
# Variant 2: Full expiration date, exchange CME.
|
||||
contract2 = Future(
|
||||
symbol='MES',
|
||||
lastTradeDateOrContractMonth=expiration_str,
|
||||
exchange='CME',
|
||||
currency='USD',
|
||||
multiplier=5
|
||||
)
|
||||
contract2.includeExpired = True
|
||||
variants.append(("Variant 2: full expiration, CME", contract2))
|
||||
|
||||
# Variant 3: Contract month with tradingClass on GLOBEX.
|
||||
contract3 = Future(
|
||||
symbol='MES',
|
||||
lastTradeDateOrContractMonth=cm,
|
||||
exchange='GLOBEX',
|
||||
currency='USD',
|
||||
multiplier=5,
|
||||
tradingClass='MES'
|
||||
)
|
||||
contract3.includeExpired = True
|
||||
variants.append(("Variant 3: contract month, GLOBEX, tradingClass", contract3))
|
||||
|
||||
# Variant 4: Contract month with tradingClass on CME.
|
||||
contract4 = Future(
|
||||
symbol='MES',
|
||||
lastTradeDateOrContractMonth=cm,
|
||||
exchange='CME',
|
||||
currency='USD',
|
||||
multiplier=5,
|
||||
tradingClass='MES'
|
||||
)
|
||||
contract4.includeExpired = True
|
||||
variants.append(("Variant 4: contract month, CME, tradingClass", contract4))
|
||||
|
||||
# Variant 5: Contract month using computed localSymbol on GLOBEX.
|
||||
month_codes = {1: 'F', 2: 'G', 3: 'H', 4: 'J', 5: 'K', 6: 'M', 7: 'N', 8: 'Q', 9: 'U', 10: 'V', 11: 'X', 12: 'Z'}
|
||||
year_num = int(cm[:4])
|
||||
month_num = int(cm[4:])
|
||||
local_symbol = f"MES{month_codes.get(month_num, '')}{str(year_num)[-1]}"
|
||||
contract5 = Future(
|
||||
symbol='MES',
|
||||
lastTradeDateOrContractMonth=cm,
|
||||
localSymbol=local_symbol,
|
||||
exchange='GLOBEX',
|
||||
currency='USD',
|
||||
multiplier=5
|
||||
)
|
||||
contract5.includeExpired = True
|
||||
variants.append(("Variant 5: contract month, GLOBEX, localSymbol", contract5))
|
||||
|
||||
# Variant 6: Basic contract definition on GLOBEX using only contract month.
|
||||
contract6 = Future(
|
||||
symbol='MES',
|
||||
lastTradeDateOrContractMonth=cm,
|
||||
exchange='GLOBEX',
|
||||
currency='USD',
|
||||
multiplier=5
|
||||
)
|
||||
contract6.includeExpired = True
|
||||
variants.append(("Variant 6: basic contract, GLOBEX", contract6))
|
||||
|
||||
# Variant 7: Basic contract definition on CME using only contract month.
|
||||
contract7 = Future(
|
||||
symbol='MES',
|
||||
lastTradeDateOrContractMonth=cm,
|
||||
exchange='CME',
|
||||
currency='USD',
|
||||
multiplier=5
|
||||
)
|
||||
contract7.includeExpired = True
|
||||
variants.append(("Variant 7: basic contract, CME", contract7))
|
||||
|
||||
for variant_desc, contract in variants:
|
||||
logger.info(f"Trying {variant_desc} for {cm} (expiration: {expiration_str})...")
|
||||
details = ib_conn.reqContractDetails(contract)
|
||||
if details:
|
||||
logger.info(f"Success with {variant_desc}: {details[0].contract.localSymbol}")
|
||||
return details[0].contract, variant_desc
|
||||
else:
|
||||
logger.info(f"{variant_desc} did not return any details.")
|
||||
return None, None
|
||||
|
||||
# --- Main Order Execution Function ---
|
||||
def execute_future_with_stop(action_signal, quantity=1, stop_loss_points=20):
|
||||
"""
|
||||
Based on an AI signal (action_signal), this function:
|
||||
- Determines order side (BUY if signal > 0, SELL if signal < 0).
|
||||
- Selects an appropriate MES futures contract.
|
||||
- Submits a market order and waits for its fill.
|
||||
- Then submits a stop-loss order at a price offset by stop_loss_points.
|
||||
"""
|
||||
if action_signal == 0:
|
||||
logger.info("Action signal is neutral. No trade will be executed.")
|
||||
return
|
||||
|
||||
side = "BUY" if action_signal > 0 else "SELL"
|
||||
stop_side = "SELL" if side == "BUY" else "BUY"
|
||||
logger.info(f"Action signal received: {action_signal} -> {side} {quantity} contract(s)")
|
||||
|
||||
# Connect to IBKR (using a dedicated clientId and port for orders)
|
||||
ib = IB()
|
||||
try:
|
||||
logger.info("Connecting to IBKR on 127.0.0.1:4002 with clientId 1...")
|
||||
ib.connect('127.0.0.1', 4002, clientId=1)
|
||||
except Exception as e:
|
||||
logger.error(f"Error connecting to IBKR: {e}")
|
||||
return
|
||||
|
||||
# Determine contract month – here we use the current month.
|
||||
now = datetime.datetime.now(TZ)
|
||||
cm = now.strftime("%Y%m")
|
||||
contract_expiration = get_third_friday(now.year, now.month)
|
||||
contract, variant_used = getMESContract(ib, cm, contract_expiration)
|
||||
if not contract:
|
||||
logger.error("No valid MES contract found. Aborting order execution.")
|
||||
logger.error("Please verify that your IBKR account is subscribed to MES futures and that the contract month is active.")
|
||||
ib.disconnect()
|
||||
return
|
||||
|
||||
logger.info(f"Selected contract: {contract.localSymbol} using {variant_used}")
|
||||
|
||||
# Subscribe to market data to get the current price.
|
||||
ticker = ib.reqMktData(contract)
|
||||
t0 = time.time()
|
||||
while ticker.last is None and time.time() - t0 < 10:
|
||||
ib.sleep(0.5)
|
||||
if ticker.last is None:
|
||||
logger.error("Market data not available. Aborting order execution.")
|
||||
ib.disconnect()
|
||||
return
|
||||
|
||||
current_price = float(ticker.last)
|
||||
logger.info(f"Current market price for {contract.localSymbol}: {current_price}")
|
||||
|
||||
# Place the market order (parent order)
|
||||
parent_order = MarketOrder(side, quantity)
|
||||
parent_order.transmit = True # Transmit immediately.
|
||||
parent_trade = ib.placeOrder(contract, parent_order)
|
||||
logger.info(f"Placed market order: {side} {quantity} contract(s) for {contract.localSymbol}")
|
||||
|
||||
# Wait for the market order to fill.
|
||||
while parent_trade.orderStatus.status not in ("Filled", "Cancelled"):
|
||||
ib.sleep(0.5)
|
||||
if parent_trade.orderStatus.status != "Filled":
|
||||
logger.error(f"Market order not filled. Status: {parent_trade.orderStatus.status}")
|
||||
ib.disconnect()
|
||||
return
|
||||
|
||||
fill_price = parent_trade.orderStatus.avgFillPrice
|
||||
logger.info(f"Market order filled at {fill_price}")
|
||||
|
||||
# Calculate stop-loss price.
|
||||
# For BUY orders, the stop price is below fill price; for SELL orders, above.
|
||||
if side == "BUY":
|
||||
stop_price = fill_price - stop_loss_points
|
||||
else:
|
||||
stop_price = fill_price + stop_loss_points
|
||||
|
||||
logger.info(f"Placing stop loss order at {stop_price} for {quantity} contract(s)")
|
||||
|
||||
# Create and place the stop-loss order.
|
||||
stop_order = Order()
|
||||
stop_order.action = stop_side
|
||||
stop_order.orderType = "STP"
|
||||
stop_order.totalQuantity = quantity
|
||||
stop_order.auxPrice = stop_price
|
||||
stop_order.transmit = True
|
||||
|
||||
stop_trade = ib.placeOrder(contract, stop_order)
|
||||
logger.info(f"Stop loss order placed: {stop_side} {quantity} contract(s) at {stop_price}")
|
||||
|
||||
# Optionally, wait a moment to see if the stop order is accepted.
|
||||
ib.sleep(1)
|
||||
ib.disconnect()
|
||||
return parent_trade, stop_trade
|
||||
|
||||
# --- Main Execution ---
|
||||
if __name__ == "__main__":
|
||||
# Example usage:
|
||||
# In practice, the action_signal would come from your AI/PPO system.
|
||||
# A positive value means buy, negative means sell, zero means no action.
|
||||
simulated_ai_signal = 1 # Change to -1 for a sell signal; 0 for no action.
|
||||
quantity = 1 # Number of contracts to trade.
|
||||
stop_loss_points = 20 # Stop-loss offset (price units).
|
||||
|
||||
execute_future_with_stop(simulated_ai_signal, quantity, stop_loss_points)
|
||||
|
||||
106
src/MidasAgent/PPO/output/FuturesPPO.log
Normal file
106
src/MidasAgent/PPO/output/FuturesPPO.log
Normal file
@@ -0,0 +1,106 @@
|
||||
2025-03-26 03:01:52,506 - INFO - ===== Resource Statistics =====
|
||||
2025-03-26 03:01:52,506 - INFO - Physical CPU Cores: 28
|
||||
2025-03-26 03:01:52,506 - INFO - Logical CPU Cores: 56
|
||||
2025-03-26 03:01:52,507 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
|
||||
2025-03-26 03:01:52,507 - INFO - No GPUs detected.
|
||||
2025-03-26 03:01:52,507 - INFO - =================================
|
||||
2025-03-26 03:01:52,507 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
|
||||
2025-03-26 03:01:52,508 - INFO - Loading data from: data/MES2023Z.csv
|
||||
2025-03-26 03:01:52,513 - ERROR - Unexpected error: Missing column provided to 'parse_dates': 'time'
|
||||
2025-03-26 03:04:50,616 - INFO - ===== Resource Statistics =====
|
||||
2025-03-26 03:04:50,616 - INFO - Physical CPU Cores: 28
|
||||
2025-03-26 03:04:50,616 - INFO - Logical CPU Cores: 56
|
||||
2025-03-26 03:04:50,616 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]%
|
||||
2025-03-26 03:04:50,617 - INFO - No GPUs detected.
|
||||
2025-03-26 03:04:50,617 - INFO - =================================
|
||||
2025-03-26 03:04:50,617 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
|
||||
2025-03-26 03:04:50,618 - INFO - Loading data from: data/MES2023Z.csv
|
||||
2025-03-26 03:04:50,621 - ERROR - Unexpected error: Missing column provided to 'parse_dates': 'time'
|
||||
2025-03-26 03:08:02,316 - INFO - ===== Resource Statistics =====
|
||||
2025-03-26 03:08:02,316 - INFO - Physical CPU Cores: 28
|
||||
2025-03-26 03:08:02,316 - INFO - Logical CPU Cores: 56
|
||||
2025-03-26 03:08:02,317 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
|
||||
2025-03-26 03:08:02,317 - INFO - No GPUs detected.
|
||||
2025-03-26 03:08:02,317 - INFO - =================================
|
||||
2025-03-26 03:08:02,317 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
|
||||
2025-03-26 03:08:02,318 - INFO - Loading data from: data/MES2023Z.csv
|
||||
2025-03-26 03:08:02,355 - INFO - Data columns after renaming: ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
|
||||
2025-03-26 03:08:02,383 - INFO - Data loaded and sorted successfully.
|
||||
2025-03-26 03:08:02,383 - INFO - Calculating technical indicators...
|
||||
2025-03-26 03:08:02,448 - INFO - Technical indicators calculated successfully.
|
||||
2025-03-26 03:08:02,464 - INFO - Starting parallel feature engineering with 54 workers...
|
||||
2025-03-26 03:08:03,331 - INFO - Parallel feature engineering completed.
|
||||
2025-03-26 03:08:03,341 - INFO - Training sequences shape: (676, 15, 17)
|
||||
2025-03-26 03:08:03,342 - INFO - Validation sequences shape: (144, 15, 17)
|
||||
2025-03-26 03:08:03,342 - INFO - Testing sequences shape: (146, 15, 17)
|
||||
2025-03-26 03:08:03,342 - INFO - Starting LSTM hyperparameter optimization with Optuna using 54 parallel trials...
|
||||
2025-03-26 03:22:04,033 - INFO - Best LSTM Hyperparameters: {'num_lstm_layers': 2, 'lstm_units': 64, 'dropout_rate': 0.13619292923712067, 'learning_rate': 0.0030545284525912166, 'optimizer': 'Nadam', 'decay': 9.615099767236892e-05}
|
||||
2025-03-26 03:22:04,553 - INFO - Training best LSTM model with optimized hyperparameters...
|
||||
2025-03-26 03:24:28,296 - INFO - Evaluating final LSTM model...
|
||||
2025-03-26 03:24:29,722 - INFO - Test MSE: 0.3437
|
||||
2025-03-26 03:24:29,722 - INFO - Test RMSE: 0.5862
|
||||
2025-03-26 03:24:29,722 - INFO - Test MAE: 0.4561
|
||||
2025-03-26 03:24:29,722 - INFO - Test R2 Score: 0.8620
|
||||
2025-03-26 03:24:29,722 - INFO - Directional Accuracy: 0.2759
|
||||
2025-03-26 03:24:30,013 - WARNING - You are saving your model as an HDF5 file via `model.save()` or `keras.saving.save_model(model)`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')` or `keras.saving.save_model(model, 'my_model.keras')`.
|
||||
2025-03-26 03:24:30,121 - INFO - Saved best LSTM model and scaler objects.
|
||||
2025-03-26 03:24:30,150 - INFO - Starting PPO training...
|
||||
2025-03-26 05:47:15,571 - INFO - PPO training completed and model saved.
|
||||
2025-04-11 22:15:50,927 - INFO - ===== Resource Statistics =====
|
||||
2025-04-11 22:15:50,929 - INFO - Physical CPU Cores: 28
|
||||
2025-04-11 22:15:50,929 - INFO - Logical CPU Cores: 56
|
||||
2025-04-11 22:15:50,930 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
|
||||
2025-04-11 22:15:50,930 - INFO - No GPUs detected.
|
||||
2025-04-11 22:15:50,930 - INFO - =================================
|
||||
2025-04-11 22:15:50,932 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
|
||||
2025-04-11 22:15:50,932 - INFO - Loading data from: data/cleaned_MES_data.csv
|
||||
2025-04-11 22:15:50,933 - ERROR - File not found: data/cleaned_MES_data.csv
|
||||
2025-04-11 22:17:40,253 - INFO - ===== Resource Statistics =====
|
||||
2025-04-11 22:17:40,253 - INFO - Physical CPU Cores: 28
|
||||
2025-04-11 22:17:40,253 - INFO - Logical CPU Cores: 56
|
||||
2025-04-11 22:17:40,253 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
|
||||
2025-04-11 22:17:40,253 - INFO - No GPUs detected.
|
||||
2025-04-11 22:17:40,254 - INFO - =================================
|
||||
2025-04-11 22:17:40,254 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
|
||||
2025-04-11 22:17:40,254 - INFO - Loading data from: ../data/cleaned_MES_data.csv
|
||||
2025-04-11 22:17:40,373 - INFO - Data columns after renaming: ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
|
||||
2025-04-11 22:17:40,392 - INFO - Data loaded and sorted successfully.
|
||||
2025-04-11 22:17:40,392 - INFO - Calculating technical indicators...
|
||||
2025-04-11 22:17:40,517 - INFO - Technical indicators calculated successfully.
|
||||
2025-04-11 22:17:40,549 - INFO - Starting parallel feature engineering with 54 workers...
|
||||
2025-04-11 22:18:25,291 - INFO - ===== Resource Statistics =====
|
||||
2025-04-11 22:18:25,291 - INFO - Physical CPU Cores: 28
|
||||
2025-04-11 22:18:25,291 - INFO - Logical CPU Cores: 56
|
||||
2025-04-11 22:18:25,291 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
|
||||
2025-04-11 22:18:25,291 - INFO - No GPUs detected.
|
||||
2025-04-11 22:18:25,291 - INFO - =================================
|
||||
2025-04-11 22:18:25,292 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
|
||||
2025-04-11 22:18:25,292 - INFO - Loading data from: ../data/cleaned_MES_data.csv
|
||||
2025-04-11 22:18:25,373 - INFO - Data columns after renaming: ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
|
||||
2025-04-11 22:18:25,378 - INFO - Data loaded and sorted successfully.
|
||||
2025-04-11 22:18:25,378 - INFO - Calculating technical indicators...
|
||||
2025-04-11 22:18:25,440 - INFO - Technical indicators calculated successfully.
|
||||
2025-04-11 22:18:25,456 - INFO - Starting parallel feature engineering with 54 workers...
|
||||
2025-04-11 22:18:53,636 - INFO - Parallel feature engineering completed.
|
||||
2025-04-11 22:18:53,925 - INFO - Training sequences shape: (45647, 15, 17)
|
||||
2025-04-11 22:18:53,925 - INFO - Validation sequences shape: (9781, 15, 17)
|
||||
2025-04-11 22:18:53,925 - INFO - Testing sequences shape: (9782, 15, 17)
|
||||
2025-04-11 22:18:53,925 - INFO - Starting LSTM hyperparameter optimization with Optuna using 54 parallel trials...
|
||||
2025-04-12 15:24:28,932 - INFO - ===== Resource Statistics =====
|
||||
2025-04-12 15:24:28,932 - INFO - Physical CPU Cores: 28
|
||||
2025-04-12 15:24:28,932 - INFO - Logical CPU Cores: 56
|
||||
2025-04-12 15:24:28,932 - INFO - CPU Usage per Core: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]%
|
||||
2025-04-12 15:24:28,932 - INFO - No GPUs detected.
|
||||
2025-04-12 15:24:28,932 - INFO - =================================
|
||||
2025-04-12 15:24:28,933 - INFO - Configured TensorFlow to use CPU with optimized thread settings.
|
||||
2025-04-12 15:24:28,933 - INFO - Loading data from: ../data/cleaned_MES_data.csv
|
||||
2025-04-12 15:24:29,013 - INFO - Data columns after renaming: ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
|
||||
2025-04-12 15:24:29,018 - INFO - Data loaded and sorted successfully.
|
||||
2025-04-12 15:24:29,018 - INFO - Calculating technical indicators...
|
||||
2025-04-12 15:24:29,080 - INFO - Technical indicators calculated successfully.
|
||||
2025-04-12 15:24:29,096 - INFO - Starting parallel feature engineering with 54 workers...
|
||||
2025-04-12 15:24:56,873 - INFO - Parallel feature engineering completed.
|
||||
2025-04-12 15:24:57,163 - INFO - Training sequences shape: (45647, 15, 17)
|
||||
2025-04-12 15:24:57,163 - INFO - Validation sequences shape: (9781, 15, 17)
|
||||
2025-04-12 15:24:57,163 - INFO - Testing sequences shape: (9782, 15, 17)
|
||||
2025-04-12 15:24:57,163 - INFO - Starting LSTM hyperparameter optimization with Optuna using 54 parallel trials...
|
||||
BIN
src/MidasAgent/PPO/output/best_ppo_model.zip
Normal file
BIN
src/MidasAgent/PPO/output/best_ppo_model.zip
Normal file
Binary file not shown.
|
Before Width: | Height: | Size: 80 KiB After Width: | Height: | Size: 80 KiB |
247
src/MidasAgent/PPO/wrappervenv/bin/Activate.ps1
Normal file
247
src/MidasAgent/PPO/wrappervenv/bin/Activate.ps1
Normal file
@@ -0,0 +1,247 @@
|
||||
<#
|
||||
.Synopsis
|
||||
Activate a Python virtual environment for the current PowerShell session.
|
||||
|
||||
.Description
|
||||
Pushes the python executable for a virtual environment to the front of the
|
||||
$Env:PATH environment variable and sets the prompt to signify that you are
|
||||
in a Python virtual environment. Makes use of the command line switches as
|
||||
well as the `pyvenv.cfg` file values present in the virtual environment.
|
||||
|
||||
.Parameter VenvDir
|
||||
Path to the directory that contains the virtual environment to activate. The
|
||||
default value for this is the parent of the directory that the Activate.ps1
|
||||
script is located within.
|
||||
|
||||
.Parameter Prompt
|
||||
The prompt prefix to display when this virtual environment is activated. By
|
||||
default, this prompt is the name of the virtual environment folder (VenvDir)
|
||||
surrounded by parentheses and followed by a single space (ie. '(.venv) ').
|
||||
|
||||
.Example
|
||||
Activate.ps1
|
||||
Activates the Python virtual environment that contains the Activate.ps1 script.
|
||||
|
||||
.Example
|
||||
Activate.ps1 -Verbose
|
||||
Activates the Python virtual environment that contains the Activate.ps1 script,
|
||||
and shows extra information about the activation as it executes.
|
||||
|
||||
.Example
|
||||
Activate.ps1 -VenvDir C:\Users\MyUser\Common\.venv
|
||||
Activates the Python virtual environment located in the specified location.
|
||||
|
||||
.Example
|
||||
Activate.ps1 -Prompt "MyPython"
|
||||
Activates the Python virtual environment that contains the Activate.ps1 script,
|
||||
and prefixes the current prompt with the specified string (surrounded in
|
||||
parentheses) while the virtual environment is active.
|
||||
|
||||
.Notes
|
||||
On Windows, it may be required to enable this Activate.ps1 script by setting the
|
||||
execution policy for the user. You can do this by issuing the following PowerShell
|
||||
command:
|
||||
|
||||
PS C:\> Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser
|
||||
|
||||
For more information on Execution Policies:
|
||||
https://go.microsoft.com/fwlink/?LinkID=135170
|
||||
|
||||
#>
|
||||
Param(
|
||||
[Parameter(Mandatory = $false)]
|
||||
[String]
|
||||
$VenvDir,
|
||||
[Parameter(Mandatory = $false)]
|
||||
[String]
|
||||
$Prompt
|
||||
)
|
||||
|
||||
<# Function declarations --------------------------------------------------- #>
|
||||
|
||||
<#
|
||||
.Synopsis
|
||||
Remove all shell session elements added by the Activate script, including the
|
||||
addition of the virtual environment's Python executable from the beginning of
|
||||
the PATH variable.
|
||||
|
||||
.Parameter NonDestructive
|
||||
If present, do not remove this function from the global namespace for the
|
||||
session.
|
||||
|
||||
#>
|
||||
function global:deactivate ([switch]$NonDestructive) {
|
||||
# Revert to original values
|
||||
|
||||
# The prior prompt:
|
||||
if (Test-Path -Path Function:_OLD_VIRTUAL_PROMPT) {
|
||||
Copy-Item -Path Function:_OLD_VIRTUAL_PROMPT -Destination Function:prompt
|
||||
Remove-Item -Path Function:_OLD_VIRTUAL_PROMPT
|
||||
}
|
||||
|
||||
# The prior PYTHONHOME:
|
||||
if (Test-Path -Path Env:_OLD_VIRTUAL_PYTHONHOME) {
|
||||
Copy-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME -Destination Env:PYTHONHOME
|
||||
Remove-Item -Path Env:_OLD_VIRTUAL_PYTHONHOME
|
||||
}
|
||||
|
||||
# The prior PATH:
|
||||
if (Test-Path -Path Env:_OLD_VIRTUAL_PATH) {
|
||||
Copy-Item -Path Env:_OLD_VIRTUAL_PATH -Destination Env:PATH
|
||||
Remove-Item -Path Env:_OLD_VIRTUAL_PATH
|
||||
}
|
||||
|
||||
# Just remove the VIRTUAL_ENV altogether:
|
||||
if (Test-Path -Path Env:VIRTUAL_ENV) {
|
||||
Remove-Item -Path env:VIRTUAL_ENV
|
||||
}
|
||||
|
||||
# Just remove VIRTUAL_ENV_PROMPT altogether.
|
||||
if (Test-Path -Path Env:VIRTUAL_ENV_PROMPT) {
|
||||
Remove-Item -Path env:VIRTUAL_ENV_PROMPT
|
||||
}
|
||||
|
||||
# Just remove the _PYTHON_VENV_PROMPT_PREFIX altogether:
|
||||
if (Get-Variable -Name "_PYTHON_VENV_PROMPT_PREFIX" -ErrorAction SilentlyContinue) {
|
||||
Remove-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Scope Global -Force
|
||||
}
|
||||
|
||||
# Leave deactivate function in the global namespace if requested:
|
||||
if (-not $NonDestructive) {
|
||||
Remove-Item -Path function:deactivate
|
||||
}
|
||||
}
|
||||
|
||||
<#
|
||||
.Description
|
||||
Get-PyVenvConfig parses the values from the pyvenv.cfg file located in the
|
||||
given folder, and returns them in a map.
|
||||
|
||||
For each line in the pyvenv.cfg file, if that line can be parsed into exactly
|
||||
two strings separated by `=` (with any amount of whitespace surrounding the =)
|
||||
then it is considered a `key = value` line. The left hand string is the key,
|
||||
the right hand is the value.
|
||||
|
||||
If the value starts with a `'` or a `"` then the first and last character is
|
||||
stripped from the value before being captured.
|
||||
|
||||
.Parameter ConfigDir
|
||||
Path to the directory that contains the `pyvenv.cfg` file.
|
||||
#>
|
||||
function Get-PyVenvConfig(
|
||||
[String]
|
||||
$ConfigDir
|
||||
) {
|
||||
Write-Verbose "Given ConfigDir=$ConfigDir, obtain values in pyvenv.cfg"
|
||||
|
||||
# Ensure the file exists, and issue a warning if it doesn't (but still allow the function to continue).
|
||||
$pyvenvConfigPath = Join-Path -Resolve -Path $ConfigDir -ChildPath 'pyvenv.cfg' -ErrorAction Continue
|
||||
|
||||
# An empty map will be returned if no config file is found.
|
||||
$pyvenvConfig = @{ }
|
||||
|
||||
if ($pyvenvConfigPath) {
|
||||
|
||||
Write-Verbose "File exists, parse `key = value` lines"
|
||||
$pyvenvConfigContent = Get-Content -Path $pyvenvConfigPath
|
||||
|
||||
$pyvenvConfigContent | ForEach-Object {
|
||||
$keyval = $PSItem -split "\s*=\s*", 2
|
||||
if ($keyval[0] -and $keyval[1]) {
|
||||
$val = $keyval[1]
|
||||
|
||||
# Remove extraneous quotations around a string value.
|
||||
if ("'""".Contains($val.Substring(0, 1))) {
|
||||
$val = $val.Substring(1, $val.Length - 2)
|
||||
}
|
||||
|
||||
$pyvenvConfig[$keyval[0]] = $val
|
||||
Write-Verbose "Adding Key: '$($keyval[0])'='$val'"
|
||||
}
|
||||
}
|
||||
}
|
||||
return $pyvenvConfig
|
||||
}
|
||||
|
||||
|
||||
<# Begin Activate script --------------------------------------------------- #>
|
||||
|
||||
# Determine the containing directory of this script
|
||||
$VenvExecPath = Split-Path -Parent $MyInvocation.MyCommand.Definition
|
||||
$VenvExecDir = Get-Item -Path $VenvExecPath
|
||||
|
||||
Write-Verbose "Activation script is located in path: '$VenvExecPath'"
|
||||
Write-Verbose "VenvExecDir Fullname: '$($VenvExecDir.FullName)"
|
||||
Write-Verbose "VenvExecDir Name: '$($VenvExecDir.Name)"
|
||||
|
||||
# Set values required in priority: CmdLine, ConfigFile, Default
|
||||
# First, get the location of the virtual environment, it might not be
|
||||
# VenvExecDir if specified on the command line.
|
||||
if ($VenvDir) {
|
||||
Write-Verbose "VenvDir given as parameter, using '$VenvDir' to determine values"
|
||||
}
|
||||
else {
|
||||
Write-Verbose "VenvDir not given as a parameter, using parent directory name as VenvDir."
|
||||
$VenvDir = $VenvExecDir.Parent.FullName.TrimEnd("\\/")
|
||||
Write-Verbose "VenvDir=$VenvDir"
|
||||
}
|
||||
|
||||
# Next, read the `pyvenv.cfg` file to determine any required value such
|
||||
# as `prompt`.
|
||||
$pyvenvCfg = Get-PyVenvConfig -ConfigDir $VenvDir
|
||||
|
||||
# Next, set the prompt from the command line, or the config file, or
|
||||
# just use the name of the virtual environment folder.
|
||||
if ($Prompt) {
|
||||
Write-Verbose "Prompt specified as argument, using '$Prompt'"
|
||||
}
|
||||
else {
|
||||
Write-Verbose "Prompt not specified as argument to script, checking pyvenv.cfg value"
|
||||
if ($pyvenvCfg -and $pyvenvCfg['prompt']) {
|
||||
Write-Verbose " Setting based on value in pyvenv.cfg='$($pyvenvCfg['prompt'])'"
|
||||
$Prompt = $pyvenvCfg['prompt'];
|
||||
}
|
||||
else {
|
||||
Write-Verbose " Setting prompt based on parent's directory's name. (Is the directory name passed to venv module when creating the virtual environment)"
|
||||
Write-Verbose " Got leaf-name of $VenvDir='$(Split-Path -Path $venvDir -Leaf)'"
|
||||
$Prompt = Split-Path -Path $venvDir -Leaf
|
||||
}
|
||||
}
|
||||
|
||||
Write-Verbose "Prompt = '$Prompt'"
|
||||
Write-Verbose "VenvDir='$VenvDir'"
|
||||
|
||||
# Deactivate any currently active virtual environment, but leave the
|
||||
# deactivate function in place.
|
||||
deactivate -nondestructive
|
||||
|
||||
# Now set the environment variable VIRTUAL_ENV, used by many tools to determine
|
||||
# that there is an activated venv.
|
||||
$env:VIRTUAL_ENV = $VenvDir
|
||||
|
||||
if (-not $Env:VIRTUAL_ENV_DISABLE_PROMPT) {
|
||||
|
||||
Write-Verbose "Setting prompt to '$Prompt'"
|
||||
|
||||
# Set the prompt to include the env name
|
||||
# Make sure _OLD_VIRTUAL_PROMPT is global
|
||||
function global:_OLD_VIRTUAL_PROMPT { "" }
|
||||
Copy-Item -Path function:prompt -Destination function:_OLD_VIRTUAL_PROMPT
|
||||
New-Variable -Name _PYTHON_VENV_PROMPT_PREFIX -Description "Python virtual environment prompt prefix" -Scope Global -Option ReadOnly -Visibility Public -Value $Prompt
|
||||
|
||||
function global:prompt {
|
||||
Write-Host -NoNewline -ForegroundColor Green "($_PYTHON_VENV_PROMPT_PREFIX) "
|
||||
_OLD_VIRTUAL_PROMPT
|
||||
}
|
||||
$env:VIRTUAL_ENV_PROMPT = $Prompt
|
||||
}
|
||||
|
||||
# Clear PYTHONHOME
|
||||
if (Test-Path -Path Env:PYTHONHOME) {
|
||||
Copy-Item -Path Env:PYTHONHOME -Destination Env:_OLD_VIRTUAL_PYTHONHOME
|
||||
Remove-Item -Path Env:PYTHONHOME
|
||||
}
|
||||
|
||||
# Add the venv to the PATH
|
||||
Copy-Item -Path Env:PATH -Destination Env:_OLD_VIRTUAL_PATH
|
||||
$Env:PATH = "$VenvExecDir$([System.IO.Path]::PathSeparator)$Env:PATH"
|
||||
69
src/MidasAgent/PPO/wrappervenv/bin/activate
Normal file
69
src/MidasAgent/PPO/wrappervenv/bin/activate
Normal file
@@ -0,0 +1,69 @@
|
||||
# This file must be used with "source bin/activate" *from bash*
|
||||
# you cannot run it directly
|
||||
|
||||
deactivate () {
|
||||
# reset old environment variables
|
||||
if [ -n "${_OLD_VIRTUAL_PATH:-}" ] ; then
|
||||
PATH="${_OLD_VIRTUAL_PATH:-}"
|
||||
export PATH
|
||||
unset _OLD_VIRTUAL_PATH
|
||||
fi
|
||||
if [ -n "${_OLD_VIRTUAL_PYTHONHOME:-}" ] ; then
|
||||
PYTHONHOME="${_OLD_VIRTUAL_PYTHONHOME:-}"
|
||||
export PYTHONHOME
|
||||
unset _OLD_VIRTUAL_PYTHONHOME
|
||||
fi
|
||||
|
||||
# This should detect bash and zsh, which have a hash command that must
|
||||
# be called to get it to forget past commands. Without forgetting
|
||||
# past commands the $PATH changes we made may not be respected
|
||||
if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then
|
||||
hash -r 2> /dev/null
|
||||
fi
|
||||
|
||||
if [ -n "${_OLD_VIRTUAL_PS1:-}" ] ; then
|
||||
PS1="${_OLD_VIRTUAL_PS1:-}"
|
||||
export PS1
|
||||
unset _OLD_VIRTUAL_PS1
|
||||
fi
|
||||
|
||||
unset VIRTUAL_ENV
|
||||
unset VIRTUAL_ENV_PROMPT
|
||||
if [ ! "${1:-}" = "nondestructive" ] ; then
|
||||
# Self destruct!
|
||||
unset -f deactivate
|
||||
fi
|
||||
}
|
||||
|
||||
# unset irrelevant variables
|
||||
deactivate nondestructive
|
||||
|
||||
VIRTUAL_ENV="/home/midas/codeWS/Projects/MidasTechnologiesINC/MidasEngine/src/MidasAgent/PPO/wrappervenv"
|
||||
export VIRTUAL_ENV
|
||||
|
||||
_OLD_VIRTUAL_PATH="$PATH"
|
||||
PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
export PATH
|
||||
|
||||
# unset PYTHONHOME if set
|
||||
# this will fail if PYTHONHOME is set to the empty string (which is bad anyway)
|
||||
# could use `if (set -u; : $PYTHONHOME) ;` in bash
|
||||
if [ -n "${PYTHONHOME:-}" ] ; then
|
||||
_OLD_VIRTUAL_PYTHONHOME="${PYTHONHOME:-}"
|
||||
unset PYTHONHOME
|
||||
fi
|
||||
|
||||
if [ -z "${VIRTUAL_ENV_DISABLE_PROMPT:-}" ] ; then
|
||||
_OLD_VIRTUAL_PS1="${PS1:-}"
|
||||
PS1="(wrappervenv) ${PS1:-}"
|
||||
export PS1
|
||||
VIRTUAL_ENV_PROMPT="(wrappervenv) "
|
||||
export VIRTUAL_ENV_PROMPT
|
||||
fi
|
||||
|
||||
# This should detect bash and zsh, which have a hash command that must
|
||||
# be called to get it to forget past commands. Without forgetting
|
||||
# past commands the $PATH changes we made may not be respected
|
||||
if [ -n "${BASH:-}" -o -n "${ZSH_VERSION:-}" ] ; then
|
||||
hash -r 2> /dev/null
|
||||
fi
|
||||
26
src/MidasAgent/PPO/wrappervenv/bin/activate.csh
Normal file
26
src/MidasAgent/PPO/wrappervenv/bin/activate.csh
Normal file
@@ -0,0 +1,26 @@
|
||||
# This file must be used with "source bin/activate.csh" *from csh*.
|
||||
# You cannot run it directly.
|
||||
# Created by Davide Di Blasi <davidedb@gmail.com>.
|
||||
# Ported to Python 3.3 venv by Andrew Svetlov <andrew.svetlov@gmail.com>
|
||||
|
||||
alias deactivate 'test $?_OLD_VIRTUAL_PATH != 0 && setenv PATH "$_OLD_VIRTUAL_PATH" && unset _OLD_VIRTUAL_PATH; rehash; test $?_OLD_VIRTUAL_PROMPT != 0 && set prompt="$_OLD_VIRTUAL_PROMPT" && unset _OLD_VIRTUAL_PROMPT; unsetenv VIRTUAL_ENV; unsetenv VIRTUAL_ENV_PROMPT; test "\!:*" != "nondestructive" && unalias deactivate'
|
||||
|
||||
# Unset irrelevant variables.
|
||||
deactivate nondestructive
|
||||
|
||||
setenv VIRTUAL_ENV "/home/midas/codeWS/Projects/MidasTechnologiesINC/MidasEngine/src/MidasAgent/PPO/wrappervenv"
|
||||
|
||||
set _OLD_VIRTUAL_PATH="$PATH"
|
||||
setenv PATH "$VIRTUAL_ENV/bin:$PATH"
|
||||
|
||||
|
||||
set _OLD_VIRTUAL_PROMPT="$prompt"
|
||||
|
||||
if (! "$?VIRTUAL_ENV_DISABLE_PROMPT") then
|
||||
set prompt = "(wrappervenv) $prompt"
|
||||
setenv VIRTUAL_ENV_PROMPT "(wrappervenv) "
|
||||
endif
|
||||
|
||||
alias pydoc python -m pydoc
|
||||
|
||||
rehash
|
||||
69
src/MidasAgent/PPO/wrappervenv/bin/activate.fish
Normal file
69
src/MidasAgent/PPO/wrappervenv/bin/activate.fish
Normal file
@@ -0,0 +1,69 @@
|
||||
# This file must be used with "source <venv>/bin/activate.fish" *from fish*
|
||||
# (https://fishshell.com/); you cannot run it directly.
|
||||
|
||||
function deactivate -d "Exit virtual environment and return to normal shell environment"
|
||||
# reset old environment variables
|
||||
if test -n "$_OLD_VIRTUAL_PATH"
|
||||
set -gx PATH $_OLD_VIRTUAL_PATH
|
||||
set -e _OLD_VIRTUAL_PATH
|
||||
end
|
||||
if test -n "$_OLD_VIRTUAL_PYTHONHOME"
|
||||
set -gx PYTHONHOME $_OLD_VIRTUAL_PYTHONHOME
|
||||
set -e _OLD_VIRTUAL_PYTHONHOME
|
||||
end
|
||||
|
||||
if test -n "$_OLD_FISH_PROMPT_OVERRIDE"
|
||||
set -e _OLD_FISH_PROMPT_OVERRIDE
|
||||
# prevents error when using nested fish instances (Issue #93858)
|
||||
if functions -q _old_fish_prompt
|
||||
functions -e fish_prompt
|
||||
functions -c _old_fish_prompt fish_prompt
|
||||
functions -e _old_fish_prompt
|
||||
end
|
||||
end
|
||||
|
||||
set -e VIRTUAL_ENV
|
||||
set -e VIRTUAL_ENV_PROMPT
|
||||
if test "$argv[1]" != "nondestructive"
|
||||
# Self-destruct!
|
||||
functions -e deactivate
|
||||
end
|
||||
end
|
||||
|
||||
# Unset irrelevant variables.
|
||||
deactivate nondestructive
|
||||
|
||||
set -gx VIRTUAL_ENV "/home/midas/codeWS/Projects/MidasTechnologiesINC/MidasEngine/src/MidasAgent/PPO/wrappervenv"
|
||||
|
||||
set -gx _OLD_VIRTUAL_PATH $PATH
|
||||
set -gx PATH "$VIRTUAL_ENV/bin" $PATH
|
||||
|
||||
# Unset PYTHONHOME if set.
|
||||
if set -q PYTHONHOME
|
||||
set -gx _OLD_VIRTUAL_PYTHONHOME $PYTHONHOME
|
||||
set -e PYTHONHOME
|
||||
end
|
||||
|
||||
if test -z "$VIRTUAL_ENV_DISABLE_PROMPT"
|
||||
# fish uses a function instead of an env var to generate the prompt.
|
||||
|
||||
# Save the current fish_prompt function as the function _old_fish_prompt.
|
||||
functions -c fish_prompt _old_fish_prompt
|
||||
|
||||
# With the original prompt function renamed, we can override with our own.
|
||||
function fish_prompt
|
||||
# Save the return status of the last command.
|
||||
set -l old_status $status
|
||||
|
||||
# Output the venv prompt; color taken from the blue of the Python logo.
|
||||
printf "%s%s%s" (set_color 4B8BBE) "(wrappervenv) " (set_color normal)
|
||||
|
||||
# Restore the return status of the previous command.
|
||||
echo "exit $old_status" | .
|
||||
# Output the original/"old" prompt.
|
||||
_old_fish_prompt
|
||||
end
|
||||
|
||||
set -gx _OLD_FISH_PROMPT_OVERRIDE "$VIRTUAL_ENV"
|
||||
set -gx VIRTUAL_ENV_PROMPT "(wrappervenv) "
|
||||
end
|
||||
8
src/MidasAgent/PPO/wrappervenv/bin/f2py
Executable file
8
src/MidasAgent/PPO/wrappervenv/bin/f2py
Executable file
@@ -0,0 +1,8 @@
|
||||
#!/home/midas/codeWS/Projects/MidasTechnologiesINC/MidasEngine/src/MidasAgent/PPO/wrappervenv/bin/python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
import sys
|
||||
from numpy.f2py.f2py2e import main
|
||||
if __name__ == '__main__':
|
||||
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
|
||||
sys.exit(main())
|
||||
8
src/MidasAgent/PPO/wrappervenv/bin/numpy-config
Executable file
8
src/MidasAgent/PPO/wrappervenv/bin/numpy-config
Executable file
@@ -0,0 +1,8 @@
|
||||
#!/home/midas/codeWS/Projects/MidasTechnologiesINC/MidasEngine/src/MidasAgent/PPO/wrappervenv/bin/python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
import sys
|
||||
from numpy._configtool import main
|
||||
if __name__ == '__main__':
|
||||
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
|
||||
sys.exit(main())
|
||||
8
src/MidasAgent/PPO/wrappervenv/bin/pip
Executable file
8
src/MidasAgent/PPO/wrappervenv/bin/pip
Executable file
@@ -0,0 +1,8 @@
|
||||
#!/home/midas/codeWS/Projects/MidasTechnologiesINC/MidasEngine/src/MidasAgent/PPO/wrappervenv/bin/python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
import sys
|
||||
from pip._internal.cli.main import main
|
||||
if __name__ == '__main__':
|
||||
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
|
||||
sys.exit(main())
|
||||
8
src/MidasAgent/PPO/wrappervenv/bin/pip3
Executable file
8
src/MidasAgent/PPO/wrappervenv/bin/pip3
Executable file
@@ -0,0 +1,8 @@
|
||||
#!/home/midas/codeWS/Projects/MidasTechnologiesINC/MidasEngine/src/MidasAgent/PPO/wrappervenv/bin/python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
import sys
|
||||
from pip._internal.cli.main import main
|
||||
if __name__ == '__main__':
|
||||
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
|
||||
sys.exit(main())
|
||||
8
src/MidasAgent/PPO/wrappervenv/bin/pip3.11
Executable file
8
src/MidasAgent/PPO/wrappervenv/bin/pip3.11
Executable file
@@ -0,0 +1,8 @@
|
||||
#!/home/midas/codeWS/Projects/MidasTechnologiesINC/MidasEngine/src/MidasAgent/PPO/wrappervenv/bin/python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import re
|
||||
import sys
|
||||
from pip._internal.cli.main import main
|
||||
if __name__ == '__main__':
|
||||
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
|
||||
sys.exit(main())
|
||||
1
src/MidasAgent/PPO/wrappervenv/bin/python
Symbolic link
1
src/MidasAgent/PPO/wrappervenv/bin/python
Symbolic link
@@ -0,0 +1 @@
|
||||
python3
|
||||
1
src/MidasAgent/PPO/wrappervenv/bin/python3
Symbolic link
1
src/MidasAgent/PPO/wrappervenv/bin/python3
Symbolic link
@@ -0,0 +1 @@
|
||||
/home/midas/.pyenv/versions/3.11.4/bin/python3
|
||||
1
src/MidasAgent/PPO/wrappervenv/bin/python3.11
Symbolic link
1
src/MidasAgent/PPO/wrappervenv/bin/python3.11
Symbolic link
@@ -0,0 +1 @@
|
||||
python3
|
||||
Binary file not shown.
@@ -0,0 +1,222 @@
|
||||
# don't import any costly modules
|
||||
import sys
|
||||
import os
|
||||
|
||||
|
||||
is_pypy = '__pypy__' in sys.builtin_module_names
|
||||
|
||||
|
||||
def warn_distutils_present():
|
||||
if 'distutils' not in sys.modules:
|
||||
return
|
||||
if is_pypy and sys.version_info < (3, 7):
|
||||
# PyPy for 3.6 unconditionally imports distutils, so bypass the warning
|
||||
# https://foss.heptapod.net/pypy/pypy/-/blob/be829135bc0d758997b3566062999ee8b23872b4/lib-python/3/site.py#L250
|
||||
return
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"Distutils was imported before Setuptools, but importing Setuptools "
|
||||
"also replaces the `distutils` module in `sys.modules`. This may lead "
|
||||
"to undesirable behaviors or errors. To avoid these issues, avoid "
|
||||
"using distutils directly, ensure that setuptools is installed in the "
|
||||
"traditional way (e.g. not an editable install), and/or make sure "
|
||||
"that setuptools is always imported before distutils."
|
||||
)
|
||||
|
||||
|
||||
def clear_distutils():
|
||||
if 'distutils' not in sys.modules:
|
||||
return
|
||||
import warnings
|
||||
|
||||
warnings.warn("Setuptools is replacing distutils.")
|
||||
mods = [
|
||||
name
|
||||
for name in sys.modules
|
||||
if name == "distutils" or name.startswith("distutils.")
|
||||
]
|
||||
for name in mods:
|
||||
del sys.modules[name]
|
||||
|
||||
|
||||
def enabled():
|
||||
"""
|
||||
Allow selection of distutils by environment variable.
|
||||
"""
|
||||
which = os.environ.get('SETUPTOOLS_USE_DISTUTILS', 'local')
|
||||
return which == 'local'
|
||||
|
||||
|
||||
def ensure_local_distutils():
|
||||
import importlib
|
||||
|
||||
clear_distutils()
|
||||
|
||||
# With the DistutilsMetaFinder in place,
|
||||
# perform an import to cause distutils to be
|
||||
# loaded from setuptools._distutils. Ref #2906.
|
||||
with shim():
|
||||
importlib.import_module('distutils')
|
||||
|
||||
# check that submodules load as expected
|
||||
core = importlib.import_module('distutils.core')
|
||||
assert '_distutils' in core.__file__, core.__file__
|
||||
assert 'setuptools._distutils.log' not in sys.modules
|
||||
|
||||
|
||||
def do_override():
|
||||
"""
|
||||
Ensure that the local copy of distutils is preferred over stdlib.
|
||||
|
||||
See https://github.com/pypa/setuptools/issues/417#issuecomment-392298401
|
||||
for more motivation.
|
||||
"""
|
||||
if enabled():
|
||||
warn_distutils_present()
|
||||
ensure_local_distutils()
|
||||
|
||||
|
||||
class _TrivialRe:
|
||||
def __init__(self, *patterns):
|
||||
self._patterns = patterns
|
||||
|
||||
def match(self, string):
|
||||
return all(pat in string for pat in self._patterns)
|
||||
|
||||
|
||||
class DistutilsMetaFinder:
|
||||
def find_spec(self, fullname, path, target=None):
|
||||
# optimization: only consider top level modules and those
|
||||
# found in the CPython test suite.
|
||||
if path is not None and not fullname.startswith('test.'):
|
||||
return
|
||||
|
||||
method_name = 'spec_for_{fullname}'.format(**locals())
|
||||
method = getattr(self, method_name, lambda: None)
|
||||
return method()
|
||||
|
||||
def spec_for_distutils(self):
|
||||
if self.is_cpython():
|
||||
return
|
||||
|
||||
import importlib
|
||||
import importlib.abc
|
||||
import importlib.util
|
||||
|
||||
try:
|
||||
mod = importlib.import_module('setuptools._distutils')
|
||||
except Exception:
|
||||
# There are a couple of cases where setuptools._distutils
|
||||
# may not be present:
|
||||
# - An older Setuptools without a local distutils is
|
||||
# taking precedence. Ref #2957.
|
||||
# - Path manipulation during sitecustomize removes
|
||||
# setuptools from the path but only after the hook
|
||||
# has been loaded. Ref #2980.
|
||||
# In either case, fall back to stdlib behavior.
|
||||
return
|
||||
|
||||
class DistutilsLoader(importlib.abc.Loader):
|
||||
def create_module(self, spec):
|
||||
mod.__name__ = 'distutils'
|
||||
return mod
|
||||
|
||||
def exec_module(self, module):
|
||||
pass
|
||||
|
||||
return importlib.util.spec_from_loader(
|
||||
'distutils', DistutilsLoader(), origin=mod.__file__
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_cpython():
|
||||
"""
|
||||
Suppress supplying distutils for CPython (build and tests).
|
||||
Ref #2965 and #3007.
|
||||
"""
|
||||
return os.path.isfile('pybuilddir.txt')
|
||||
|
||||
def spec_for_pip(self):
|
||||
"""
|
||||
Ensure stdlib distutils when running under pip.
|
||||
See pypa/pip#8761 for rationale.
|
||||
"""
|
||||
if self.pip_imported_during_build():
|
||||
return
|
||||
clear_distutils()
|
||||
self.spec_for_distutils = lambda: None
|
||||
|
||||
@classmethod
|
||||
def pip_imported_during_build(cls):
|
||||
"""
|
||||
Detect if pip is being imported in a build script. Ref #2355.
|
||||
"""
|
||||
import traceback
|
||||
|
||||
return any(
|
||||
cls.frame_file_is_setup(frame) for frame, line in traceback.walk_stack(None)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def frame_file_is_setup(frame):
|
||||
"""
|
||||
Return True if the indicated frame suggests a setup.py file.
|
||||
"""
|
||||
# some frames may not have __file__ (#2940)
|
||||
return frame.f_globals.get('__file__', '').endswith('setup.py')
|
||||
|
||||
def spec_for_sensitive_tests(self):
|
||||
"""
|
||||
Ensure stdlib distutils when running select tests under CPython.
|
||||
|
||||
python/cpython#91169
|
||||
"""
|
||||
clear_distutils()
|
||||
self.spec_for_distutils = lambda: None
|
||||
|
||||
sensitive_tests = (
|
||||
[
|
||||
'test.test_distutils',
|
||||
'test.test_peg_generator',
|
||||
'test.test_importlib',
|
||||
]
|
||||
if sys.version_info < (3, 10)
|
||||
else [
|
||||
'test.test_distutils',
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
for name in DistutilsMetaFinder.sensitive_tests:
|
||||
setattr(
|
||||
DistutilsMetaFinder,
|
||||
f'spec_for_{name}',
|
||||
DistutilsMetaFinder.spec_for_sensitive_tests,
|
||||
)
|
||||
|
||||
|
||||
DISTUTILS_FINDER = DistutilsMetaFinder()
|
||||
|
||||
|
||||
def add_shim():
|
||||
DISTUTILS_FINDER in sys.meta_path or insert_shim()
|
||||
|
||||
|
||||
class shim:
|
||||
def __enter__(self):
|
||||
insert_shim()
|
||||
|
||||
def __exit__(self, exc, value, tb):
|
||||
remove_shim()
|
||||
|
||||
|
||||
def insert_shim():
|
||||
sys.meta_path.insert(0, DISTUTILS_FINDER)
|
||||
|
||||
|
||||
def remove_shim():
|
||||
try:
|
||||
sys.meta_path.remove(DISTUTILS_FINDER)
|
||||
except ValueError:
|
||||
pass
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1 @@
|
||||
__import__('_distutils_hack').do_override()
|
||||
@@ -0,0 +1 @@
|
||||
import os; var = 'SETUPTOOLS_USE_DISTUTILS'; enabled = os.environ.get(var, 'local') == 'local'; enabled and __import__('_distutils_hack').add_shim();
|
||||
@@ -0,0 +1 @@
|
||||
pip
|
||||
@@ -0,0 +1,25 @@
|
||||
BSD 2-Clause License
|
||||
|
||||
Copyright (c) 2023, Ewald de Wit
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
@@ -0,0 +1,210 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: eventkit
|
||||
Version: 1.0.3
|
||||
Summary: Event-driven data pipelines
|
||||
Home-page: https://github.com/erdewit/eventkit
|
||||
Author: Ewald R. de Wit
|
||||
Author-email: ewald.de.wit@gmail.com
|
||||
License: BSD
|
||||
Keywords: python asyncio event driven data pipelines
|
||||
Classifier: Development Status :: 4 - Beta
|
||||
Classifier: Intended Audience :: Science/Research
|
||||
Classifier: Intended Audience :: Developers
|
||||
Classifier: License :: OSI Approved :: BSD License
|
||||
Classifier: Programming Language :: Python :: 3.6
|
||||
Classifier: Programming Language :: Python :: 3.7
|
||||
Classifier: Programming Language :: Python :: 3.8
|
||||
Classifier: Programming Language :: Python :: 3.9
|
||||
Classifier: Programming Language :: Python :: 3.10
|
||||
Classifier: Programming Language :: Python :: 3.11
|
||||
Classifier: Programming Language :: Python :: 3.12
|
||||
Classifier: Programming Language :: Python :: 3 :: Only
|
||||
License-File: LICENSE
|
||||
Requires-Dist: numpy
|
||||
|
||||
|Build| |PyVersion| |Status| |PyPiVersion| |License| |Docs|
|
||||
|
||||
Introduction
|
||||
------------
|
||||
|
||||
The primary use cases of eventkit are
|
||||
|
||||
* to send events between loosely coupled components;
|
||||
* to compose all kinds of event-driven data pipelines.
|
||||
|
||||
The interface is kept as Pythonic as possible,
|
||||
with familiar names from Python and its libraries where possible.
|
||||
For scheduling asyncio is used and there is seamless integration with it.
|
||||
|
||||
See the examples and the
|
||||
`introduction notebook <https://github.com/erdewit/eventkit/tree/master/notebooks/eventkit_introduction.ipynb>`_
|
||||
to get a true feel for the possibilities.
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
::
|
||||
|
||||
pip3 install eventkit
|
||||
|
||||
Python_ version 3.6 or higher is required.
|
||||
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
**Create an event and connect two listeners**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import eventkit as ev
|
||||
|
||||
def f(a, b):
|
||||
print(a * b)
|
||||
|
||||
def g(a, b):
|
||||
print(a / b)
|
||||
|
||||
event = ev.Event()
|
||||
event += f
|
||||
event += g
|
||||
event.emit(10, 5)
|
||||
|
||||
**Create a simple pipeline**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import eventkit as ev
|
||||
|
||||
event = (
|
||||
ev.Sequence('abcde')
|
||||
.map(str.upper)
|
||||
.enumerate()
|
||||
)
|
||||
|
||||
print(event.run()) # in Jupyter: await event.list()
|
||||
|
||||
Output::
|
||||
|
||||
[(0, 'A'), (1, 'B'), (2, 'C'), (3, 'D'), (4, 'E')]
|
||||
|
||||
**Create a pipeline to get a running average and standard deviation**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import random
|
||||
import eventkit as ev
|
||||
|
||||
source = ev.Range(1000).map(lambda i: random.gauss(0, 1))
|
||||
|
||||
event = source.array(500)[ev.ArrayMean, ev.ArrayStd].zip()
|
||||
|
||||
print(event.last().run()) # in Jupyter: await event.last()
|
||||
|
||||
Output::
|
||||
|
||||
[(0.00790957852672618, 1.0345673260655333)]
|
||||
|
||||
**Combine async iterators together**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
import eventkit as ev
|
||||
|
||||
async def ait(r):
|
||||
for i in r:
|
||||
await asyncio.sleep(0.1)
|
||||
yield i
|
||||
|
||||
async def main():
|
||||
async for t in ev.Zip(ait('XYZ'), ait('123')):
|
||||
print(t)
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(main()) # in Jupyter: await main()
|
||||
|
||||
Output::
|
||||
|
||||
('X', '1')
|
||||
('Y', '2')
|
||||
('Z', '3')
|
||||
|
||||
**Real-time video analysis pipeline**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
self.video = VideoStream(conf.CAM_ID)
|
||||
scene = self.video | FaceTracker | SceneAnalyzer
|
||||
lastScene = scene.aiter(skip_to_last=True)
|
||||
async for frame, persons in lastScene:
|
||||
...
|
||||
|
||||
`Full source code <https://github.com/erdewit/heartwave/blob/100e1a89d18756e141f9dcfbb73c55a1009debf4/heartwave/app.py#L88>`_
|
||||
|
||||
Distributed computing
|
||||
---------------------
|
||||
|
||||
The `distex <https://github.com/erdewit/distex>`_ library provides a
|
||||
``poolmap`` extension method to put multiple cores or machines to use:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from distex import Pool
|
||||
import eventkit as ev
|
||||
import bz2
|
||||
|
||||
pool = Pool()
|
||||
# await pool # un-comment in Jupyter
|
||||
data = [b'A' * 1000000] * 1000
|
||||
|
||||
pipe = ev.Sequence(data).poolmap(pool, bz2.compress).map(len).mean().last()
|
||||
|
||||
print(pipe.run()) # in Jupyter: print(await pipe)
|
||||
pool.shutdown()
|
||||
|
||||
|
||||
Inspired by:
|
||||
------------
|
||||
|
||||
* `Qt Signals & Slots <https://doc.qt.io/qt-5/signalsandslots.html>`_
|
||||
* `itertools <https://docs.python.org/3/library/itertools.html>`_
|
||||
* `aiostream <https://github.com/vxgmichel/aiostream>`_
|
||||
* `Bacon <https://baconjs.github.io/index.html>`_
|
||||
* `aioreactive <https://github.com/dbrattli/aioreactive>`_
|
||||
* `Reactive extensions <http://reactivex.io/documentation/operators.html>`_
|
||||
* `underscore.js <https://underscorejs.org>`_
|
||||
* `.NET Events <https://docs.microsoft.com/en-us/dotnet/standard/events>`_
|
||||
|
||||
Documentation
|
||||
-------------
|
||||
|
||||
The complete `API documentation <https://eventkit.readthedocs.io/en/latest/api.html>`_.
|
||||
|
||||
|
||||
|
||||
.. _Python: http://www.python.org
|
||||
.. _`Interactive Brokers Python API`: http://interactivebrokers.github.io
|
||||
|
||||
.. |Build| image:: https://github.com/erdewit/eventkit/actions/workflows/test.yml/badge.svg?branch=master
|
||||
:alt: Build
|
||||
:target: https://github.com/erdewit/eventkit/actions
|
||||
|
||||
.. |PyPiVersion| image:: https://img.shields.io/pypi/v/eventkit.svg
|
||||
:alt: PyPi
|
||||
:target: https://pypi.python.org/pypi/eventkit
|
||||
|
||||
|
||||
.. |PyVersion| image:: https://img.shields.io/badge/python-3.6+-blue.svg
|
||||
:alt:
|
||||
|
||||
.. |Status| image:: https://img.shields.io/badge/status-stable-green.svg
|
||||
:alt:
|
||||
|
||||
.. |License| image:: https://img.shields.io/badge/license-BSD-blue.svg
|
||||
:alt:
|
||||
|
||||
.. |Docs| image:: https://readthedocs.org/projects/eventkit/badge/?version=latest
|
||||
:alt: Documentation
|
||||
:target: https://eventkit.readthedocs.io
|
||||
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
eventkit-1.0.3.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
eventkit-1.0.3.dist-info/LICENSE,sha256=AQjFGoH_Hjo3QMlS4NcbQdSPkhQtqRRJAX5exgWZZsc,1317
|
||||
eventkit-1.0.3.dist-info/METADATA,sha256=5sTsyrTbHL4M_iikr626sWVt36trzZwbh1dV4WLWm1g,5441
|
||||
eventkit-1.0.3.dist-info/RECORD,,
|
||||
eventkit-1.0.3.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
|
||||
eventkit-1.0.3.dist-info/top_level.txt,sha256=F6uIbfpq0pwfdya0QaNm3KncSabwqY3zCfX_Sy5K8NQ,15
|
||||
eventkit/__init__.py,sha256=huK8A1TNx_ofW4m7mbQayY7vPclhO9NIZgWKSAqaCZw,998
|
||||
eventkit/__pycache__/__init__.cpython-311.pyc,,
|
||||
eventkit/__pycache__/event.cpython-311.pyc,,
|
||||
eventkit/__pycache__/util.cpython-311.pyc,,
|
||||
eventkit/__pycache__/version.cpython-311.pyc,,
|
||||
eventkit/event.py,sha256=cFkqeOqe9Bzu7bUfW-NVVsSmIi46c4xC1MH_mEPkS2Q,41870
|
||||
eventkit/ops/__init__.py,sha256=IepLb1S6t8pGGzDt9bpYG_CFfE4uBxPcFtuXlhdvGLE,23
|
||||
eventkit/ops/__pycache__/__init__.cpython-311.pyc,,
|
||||
eventkit/ops/__pycache__/aggregate.cpython-311.pyc,,
|
||||
eventkit/ops/__pycache__/array.cpython-311.pyc,,
|
||||
eventkit/ops/__pycache__/combine.cpython-311.pyc,,
|
||||
eventkit/ops/__pycache__/create.cpython-311.pyc,,
|
||||
eventkit/ops/__pycache__/misc.cpython-311.pyc,,
|
||||
eventkit/ops/__pycache__/op.cpython-311.pyc,,
|
||||
eventkit/ops/__pycache__/select.cpython-311.pyc,,
|
||||
eventkit/ops/__pycache__/timing.cpython-311.pyc,,
|
||||
eventkit/ops/__pycache__/transform.cpython-311.pyc,,
|
||||
eventkit/ops/aggregate.py,sha256=6pWaPeG9xxaYvLfa3QVrTyRBCwpjPRCTTFdSUhUdJnw,3875
|
||||
eventkit/ops/array.py,sha256=kiHny-DpR_68_fGur7UtLT2zs2UXj2zRQkWVgKa84RY,2424
|
||||
eventkit/ops/combine.py,sha256=NOGLZqsoePxDvt90L26JpbUYkoWIgwQm_2eePh6I8KE,9242
|
||||
eventkit/ops/create.py,sha256=VRVQrA_GM8HD8mIukUQE7IHLPyfGar_ghAbfRozBTwY,3302
|
||||
eventkit/ops/misc.py,sha256=yQ_FmyVXYUpFQsyvHdgmI11T75UoHQzou-L9Vr-fkOk,587
|
||||
eventkit/ops/op.py,sha256=uwjmTZY-cb0FpoVK3Ey31eSh2neTbbALGQAy8Vbo2k0,1711
|
||||
eventkit/ops/select.py,sha256=i5AEzx_xDoGgSmhHeJbdUQWom28mGnCVIYCpuTyVslk,3461
|
||||
eventkit/ops/timing.py,sha256=Cpc-jlzqLKub3AUkdKNi8s6Jry7OBzENVTGmlNMjuVY,6159
|
||||
eventkit/ops/transform.py,sha256=GCJm0UkFjE0TRNFl4IlFUjOQ1DLS5veGdzDMibERRhk,8994
|
||||
eventkit/util.py,sha256=COHtEwRtXm_H51UAbB4CxtRPaTylaDzWWLqAbKfF0sI,2134
|
||||
eventkit/version.py,sha256=S8gNTliBVeYLnV8E5UQ8c_2kz3gycGdX9tqY_Ec8H2s,86
|
||||
tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
tests/__pycache__/__init__.cpython-311.pyc,,
|
||||
tests/__pycache__/aggregate_test.cpython-311.pyc,,
|
||||
tests/__pycache__/combine_test.cpython-311.pyc,,
|
||||
tests/__pycache__/create_test.cpython-311.pyc,,
|
||||
tests/__pycache__/event_test.cpython-311.pyc,,
|
||||
tests/__pycache__/select_test.cpython-311.pyc,,
|
||||
tests/__pycache__/timing_test.cpython-311.pyc,,
|
||||
tests/__pycache__/transform_test.cpython-311.pyc,,
|
||||
tests/aggregate_test.py,sha256=31sQbR0EKbYoYVwFLWS5Aj6Cb0FUEvdsMfcXkq2XTGM,1654
|
||||
tests/combine_test.py,sha256=c3_tB8c10AyuoSJE2MnkUt-VwwiEzGQrEVKM28ah0Tg,1904
|
||||
tests/create_test.py,sha256=eduKfDW3b5hLnpqHRofs2MZkHMdTtVi3R0Fxr_MirpI,800
|
||||
tests/event_test.py,sha256=DbuLAOXsDgeykb2A5l0__a_yLEYzAWgwr0Mxeo6uQf8,4059
|
||||
tests/select_test.py,sha256=ZmfAoJdU3hccKzks3KjeIxrIMfzpU9mQs2tSZZGP-RU,1276
|
||||
tests/timing_test.py,sha256=YPGkdLP85FvtH6kdib_mynrqDBbGrQbd3uwCRp0_24M,1609
|
||||
tests/transform_test.py,sha256=89_-WWJTORiiGr90mN5L4m9eOwb7m29RhkwQKJFwLtM,5375
|
||||
@@ -0,0 +1,5 @@
|
||||
Wheel-Version: 1.0
|
||||
Generator: bdist_wheel (0.40.0)
|
||||
Root-Is-Purelib: true
|
||||
Tag: py3-none-any
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
eventkit
|
||||
tests
|
||||
@@ -0,0 +1,23 @@
|
||||
"""Event-driven data pipelines."""
|
||||
|
||||
from .event import Event
|
||||
from .ops.aggregate import (
|
||||
All, Any, Count, Deque, Ema, List, Max, Mean, Min, Pairwise, Product,
|
||||
Reduce, Sum)
|
||||
from .ops.array import (
|
||||
Array, ArrayAll, ArrayAny, ArrayMax, ArrayMean, ArrayMin, ArrayStd,
|
||||
ArraySum)
|
||||
from .ops.combine import (
|
||||
AddableJoinOp, Chain, Concat, Fork, Merge, Switch, Zip, Ziplatest)
|
||||
from .ops.create import (
|
||||
Aiterate, Marble, Range, Repeat, Sequence, Timer, Timerange, Wait)
|
||||
from .ops.misc import EndOnError, Errors
|
||||
from .ops.op import Op
|
||||
from .ops.select import (
|
||||
Changes, DropWhile, Filter, Last, Skip, Take, TakeUntil, TakeWhile, Unique)
|
||||
from .ops.timing import (Debounce, Delay, Sample, Throttle, Timeout)
|
||||
from .ops.transform import (
|
||||
Chainmap, Chunk, ChunkWith, Concatmap, Constant, Copy, Deepcopy, Emap,
|
||||
Enumerate, Iterate, Map, Mergemap, Pack, Partial, PartialRight, Pluck,
|
||||
Previous, Star, Switchmap, Timestamp)
|
||||
from .version import __version__, __version_info__
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1 @@
|
||||
"""Event operators."""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,159 @@
|
||||
import itertools
|
||||
import operator
|
||||
from collections import deque
|
||||
|
||||
from .op import Op
|
||||
from .transform import Iterate
|
||||
from ..util import NO_VALUE
|
||||
|
||||
|
||||
class Count(Iterate):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, start=0, step=1, source=None):
|
||||
it = itertools.count(start, step)
|
||||
Iterate.__init__(self, it, source)
|
||||
|
||||
|
||||
class Reduce(Op):
|
||||
__slots__ = ('_func', '_initializer', '_prev')
|
||||
|
||||
def __init__(self, func, initializer=NO_VALUE, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._func = func
|
||||
self._initializer = initializer
|
||||
self._prev = NO_VALUE
|
||||
|
||||
def on_source(self, arg):
|
||||
if self._prev is NO_VALUE:
|
||||
if self._initializer is NO_VALUE:
|
||||
self._prev = arg
|
||||
else:
|
||||
self._prev = self._func(self._initializer, arg)
|
||||
self.emit(self._prev)
|
||||
else:
|
||||
self._prev = self._func(self._prev, arg)
|
||||
self.emit(self._prev)
|
||||
|
||||
|
||||
class Min(Reduce):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, source=None):
|
||||
Reduce.__init__(self, min, float('inf'), source)
|
||||
|
||||
|
||||
class Max(Reduce):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, source=None):
|
||||
Reduce.__init__(self, max, -float('inf'), source)
|
||||
|
||||
|
||||
class Sum(Reduce):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, start=0, source=None):
|
||||
Reduce.__init__(self, operator.add, start, source)
|
||||
|
||||
|
||||
class Product(Reduce):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, start=1, source=None):
|
||||
Reduce.__init__(self, operator.mul, start, source)
|
||||
|
||||
|
||||
class Mean(Op):
|
||||
__slots__ = ('_count', '_sum')
|
||||
|
||||
def __init__(self, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._count = 0
|
||||
self._sum = 0
|
||||
|
||||
def on_source(self, arg):
|
||||
self._count += 1
|
||||
self._sum += arg
|
||||
self.emit(self._sum / self._count)
|
||||
|
||||
|
||||
class Any(Reduce):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, source=None):
|
||||
Reduce.__init__(self, lambda prev, v: prev or bool(v), False, source)
|
||||
|
||||
|
||||
class All(Reduce):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, source=None):
|
||||
Reduce.__init__(self, lambda prev, v: prev and bool(v), True, source)
|
||||
|
||||
|
||||
class Ema(Op):
|
||||
__slots__ = ('_f1', '_f2', '_prev')
|
||||
|
||||
def __init__(self, n=None, weight=None, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._f1 = weight or 2.0 / (n + 1)
|
||||
self._f2 = 1 - self._f1
|
||||
self._prev = NO_VALUE
|
||||
|
||||
def on_source(self, *args):
|
||||
if self._prev is NO_VALUE:
|
||||
value = args
|
||||
else:
|
||||
value = [
|
||||
self._f2 * p + self._f1 * a for p, a in zip(self._prev, args)]
|
||||
self._prev = value
|
||||
self.emit(*value)
|
||||
|
||||
|
||||
class Pairwise(Op):
|
||||
__slots__ = ('_prev', '_has_prev')
|
||||
|
||||
def __init__(self, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._has_prev = False
|
||||
|
||||
def on_source(self, *args):
|
||||
value = args[0] if len(args) == 1 else args if args else NO_VALUE
|
||||
if self._has_prev:
|
||||
self.emit(self._prev, value)
|
||||
else:
|
||||
self._has_prev = True
|
||||
self._prev = value
|
||||
|
||||
|
||||
class List(Op):
|
||||
__slots__ = ('_values')
|
||||
|
||||
def __init__(self, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._values = []
|
||||
|
||||
def on_source(self, *args):
|
||||
self._values.append(
|
||||
args[0] if len(args) == 1 else args if args else NO_VALUE)
|
||||
|
||||
def on_source_done(self, source):
|
||||
self.emit(self._values)
|
||||
Op.on_source_done(self, source)
|
||||
|
||||
|
||||
class Deque(Op):
|
||||
__slots__ = ('_count', '_q')
|
||||
|
||||
def __init__(self, count, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._count = count
|
||||
self._q = deque()
|
||||
|
||||
def on_source(self, *args):
|
||||
self._q.append(
|
||||
args[0] if len(args) == 1 else args if args else NO_VALUE)
|
||||
if self._count and len(self._q) > self._count:
|
||||
self._q.popleft()
|
||||
self.emit(self._q)
|
||||
@@ -0,0 +1,126 @@
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .op import Op
|
||||
from ..util import NO_VALUE
|
||||
|
||||
|
||||
class Array(Op):
|
||||
__slots__ = ('_count', '_q')
|
||||
|
||||
def __init__(self, count, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._count = count
|
||||
self._q = deque()
|
||||
|
||||
def on_source(self, *args):
|
||||
self._q.append(
|
||||
args[0] if len(args) == 1 else args if args else NO_VALUE)
|
||||
if self._count and len(self._q) > self._count:
|
||||
self._q.popleft()
|
||||
self.emit(np.asarray(self._q))
|
||||
|
||||
def min(self) -> "ArrayMin": # type: ignore
|
||||
"""
|
||||
Minimum value.
|
||||
"""
|
||||
return ArrayMin(self)
|
||||
|
||||
def max(self) -> "ArrayMax": # type: ignore
|
||||
"""
|
||||
Maximum value.
|
||||
"""
|
||||
return ArrayMax(self)
|
||||
|
||||
def sum(self) -> "ArraySum": # type: ignore
|
||||
"""
|
||||
Summation.
|
||||
"""
|
||||
return ArraySum(self)
|
||||
|
||||
def prod(self) -> "ArrayProd":
|
||||
"""
|
||||
Product.
|
||||
"""
|
||||
return ArrayProd(self)
|
||||
|
||||
def mean(self) -> "ArrayMean": # type: ignore
|
||||
"""
|
||||
Mean value.
|
||||
"""
|
||||
return ArrayMean(self)
|
||||
|
||||
def std(self) -> "ArrayStd": # type: ignore
|
||||
"""
|
||||
Sample standard deviation.
|
||||
"""
|
||||
return ArrayStd(self)
|
||||
|
||||
def any(self) -> "ArrayAny": # type: ignore
|
||||
"""
|
||||
Test if any array value is true.
|
||||
"""
|
||||
return ArrayAny(self)
|
||||
|
||||
def all(self) -> "ArrayAll": # type: ignore
|
||||
"""
|
||||
Test if all array values are true.
|
||||
"""
|
||||
return ArrayAll(self)
|
||||
|
||||
|
||||
class ArrayMin(Op):
|
||||
__slots__ = ()
|
||||
|
||||
def on_source(self, arg):
|
||||
self.emit(arg.min())
|
||||
|
||||
|
||||
class ArrayMax(Op):
|
||||
__slots__ = ()
|
||||
|
||||
def on_source(self, arg):
|
||||
self.emit(arg.max())
|
||||
|
||||
|
||||
class ArraySum(Op):
|
||||
__slots__ = ()
|
||||
|
||||
def on_source(self, arg):
|
||||
self.emit(arg.sum())
|
||||
|
||||
|
||||
class ArrayProd(Op):
|
||||
__slots__ = ()
|
||||
|
||||
def on_source(self, arg):
|
||||
self.emit(arg.prod())
|
||||
|
||||
|
||||
class ArrayMean(Op):
|
||||
__slots__ = ()
|
||||
|
||||
def on_source(self, arg):
|
||||
self.emit(arg.mean())
|
||||
|
||||
|
||||
class ArrayStd(Op):
|
||||
__slots__ = ()
|
||||
|
||||
def on_source(self, arg):
|
||||
self.emit(arg.std(ddof=1) if len(arg) > 1 else np.nan)
|
||||
|
||||
|
||||
class ArrayAny(Op):
|
||||
__slots__ = ()
|
||||
|
||||
def on_source(self, arg):
|
||||
self.emit(arg.any())
|
||||
|
||||
|
||||
class ArrayAll(Op):
|
||||
__slots__ = ()
|
||||
|
||||
def on_source(self, arg):
|
||||
self.emit(arg.all())
|
||||
@@ -0,0 +1,302 @@
|
||||
import functools
|
||||
from collections import defaultdict, deque
|
||||
from typing import Deque, Optional
|
||||
|
||||
from .op import Op
|
||||
from ..event import Event
|
||||
from ..util import NO_VALUE
|
||||
|
||||
|
||||
class Fork(list):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self):
|
||||
list.__init__(self)
|
||||
|
||||
def join(self, joiner: "JoinOp"):
|
||||
joiner._set_sources(*self)
|
||||
self.clear()
|
||||
return joiner
|
||||
|
||||
def concat(self) -> "Concat":
|
||||
return self.join(Concat())
|
||||
|
||||
def merge(self) -> "Merge":
|
||||
return self.join(Merge())
|
||||
|
||||
def switch(self) -> "Switch":
|
||||
return self.join(Switch())
|
||||
|
||||
def zip(self) -> "Zip":
|
||||
return self.join(Zip())
|
||||
|
||||
def ziplatest(self) -> "Ziplatest":
|
||||
return self.join(Ziplatest())
|
||||
|
||||
def chain(self) -> "Chain":
|
||||
return self.join(Chain())
|
||||
|
||||
|
||||
class JoinOp(Op):
|
||||
"""
|
||||
Base class for join operators that combine the emits
|
||||
from multiple source events.
|
||||
"""
|
||||
|
||||
__slots__ = ('_sources',)
|
||||
|
||||
_sources: Deque[Event]
|
||||
|
||||
def _set_sources(self, sources):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AddableJoinOp(JoinOp):
|
||||
"""
|
||||
Base class for join operators where new sources, produced by a
|
||||
parent higher-order event, can be added dynamically.
|
||||
"""
|
||||
|
||||
__slots__ = ('_parent',)
|
||||
|
||||
_parent: Optional[Event]
|
||||
|
||||
def __init__(self, *sources: Event):
|
||||
JoinOp.__init__(self)
|
||||
self._sources = deque()
|
||||
self._parent = None
|
||||
self._set_sources(*sources)
|
||||
|
||||
def _set_sources(self, *sources):
|
||||
for source in sources:
|
||||
source = Event.create(source)
|
||||
self.add_source(source)
|
||||
|
||||
def add_source(self, source):
|
||||
# note: the same source can be added multiple times
|
||||
raise NotImplementedError
|
||||
|
||||
def set_parent(self, parent: Event):
|
||||
self._parent = parent
|
||||
if parent.done_event:
|
||||
parent.done_event += self._on_parent_done
|
||||
|
||||
def on_source_done(self, source):
|
||||
self._disconnect_from(source)
|
||||
self._sources.remove(source)
|
||||
if not self._sources and self._parent is None:
|
||||
self.set_done()
|
||||
|
||||
def _on_parent_done(self, parent):
|
||||
parent -= self._on_parent_done
|
||||
self._parent = None
|
||||
if not self._sources:
|
||||
self.set_done()
|
||||
|
||||
|
||||
class Merge(AddableJoinOp):
|
||||
__slots__ = ()
|
||||
|
||||
def add_source(self, source):
|
||||
self._sources.append(source)
|
||||
self._connect_from(source)
|
||||
|
||||
|
||||
class Switch(AddableJoinOp):
|
||||
__slots__ = ('_source2cb', '_active_source')
|
||||
|
||||
def __init__(self, *sources):
|
||||
AddableJoinOp.__init__(self)
|
||||
self._source2cb = {} # map from source to callback
|
||||
self._active_source = None
|
||||
self._set_sources(*sources)
|
||||
|
||||
def add_source(self, source):
|
||||
self._sources.append(source)
|
||||
cb = self._source2cb.get(source)
|
||||
if not cb:
|
||||
cb = functools.partial(self.on_source_s, source)
|
||||
self._source2cb[source] = cb
|
||||
source.connect(cb, done=self.on_source_done)
|
||||
|
||||
def _remove_source(self, source):
|
||||
if source in self._sources:
|
||||
self._sources.remove(source)
|
||||
cb = self._source2cb.pop(source, None)
|
||||
if cb:
|
||||
source -= cb
|
||||
|
||||
def on_source_s(self, source, *args):
|
||||
if source is not self._active_source:
|
||||
self._remove_source(self._active_source)
|
||||
self._active_source = source
|
||||
self.emit(*args)
|
||||
|
||||
def on_source_done(self, source):
|
||||
self._remove_source(source)
|
||||
if not self._sources and self._parent is None:
|
||||
self._active_source = None
|
||||
self.set_done()
|
||||
|
||||
|
||||
class Concat(AddableJoinOp):
|
||||
__slots__ = ('_source2cb',)
|
||||
|
||||
def __init__(self, *sources):
|
||||
AddableJoinOp.__init__(self)
|
||||
self._source2cb = {} # map from source to callback
|
||||
self._set_sources(*sources)
|
||||
|
||||
def add_source(self, source):
|
||||
if source in self._sources:
|
||||
return
|
||||
self._sources.append(source)
|
||||
cb = self._source2cb.get(source)
|
||||
if not cb:
|
||||
cb = functools.partial(self._on_source_s, source)
|
||||
self._source2cb[source] = cb
|
||||
source.connect(cb, done=self.on_source_done)
|
||||
|
||||
def _on_source_s(self, source, *args):
|
||||
while self._sources and self._sources[0] is not source:
|
||||
s = self._sources.popleft()
|
||||
cb = self._source2cb.pop(s, None)
|
||||
if cb:
|
||||
s.disconnect(cb, done=self.on_source_done)
|
||||
self.emit(*args)
|
||||
|
||||
def on_source_done(self, source):
|
||||
cb = self._source2cb.pop(source)
|
||||
source.disconnect(cb, done=self.on_source_done)
|
||||
while source in self._sources:
|
||||
self._sources.remove(source)
|
||||
if not self._sources and self._parent is None:
|
||||
self.set_done()
|
||||
|
||||
|
||||
class Chain(AddableJoinOp):
|
||||
__slots__ = ('_qq', '_source2cbs')
|
||||
|
||||
def __init__(self, *sources):
|
||||
AddableJoinOp.__init__(self)
|
||||
self._qq = deque()
|
||||
self._source2cbs = defaultdict(list) # map from source to callbacks
|
||||
self._set_sources(*sources)
|
||||
|
||||
def add_source(self, source):
|
||||
if not self._sources:
|
||||
self._connect_from(source)
|
||||
else:
|
||||
def cb(*args):
|
||||
q.append(args)
|
||||
q = deque()
|
||||
self._qq.append(q)
|
||||
source += cb
|
||||
self._source2cbs[source].append(cb)
|
||||
self._sources.append(source)
|
||||
|
||||
def on_source_done(self, source):
|
||||
if source is not self._sources[0]:
|
||||
return
|
||||
self._disconnect_from(source)
|
||||
self._sources.popleft()
|
||||
while self._sources:
|
||||
source = self._sources[0]
|
||||
q = self._qq.popleft()
|
||||
for args in q:
|
||||
self.emit(*args)
|
||||
for cb in self._source2cbs.pop(source, []):
|
||||
source -= cb
|
||||
if source.done():
|
||||
self._sources.popleft()
|
||||
continue
|
||||
self._connect_from(source)
|
||||
return
|
||||
if not self._sources and self._parent is None:
|
||||
self.set_done()
|
||||
|
||||
|
||||
class Zip(JoinOp):
|
||||
__slots__ = ('_results', '_source2cbs', '_num_ready')
|
||||
|
||||
def __init__(self, *sources):
|
||||
JoinOp.__init__(self)
|
||||
self._num_ready = 0 # number of sources with a pending result
|
||||
self._source2cbs = defaultdict(list) # map from source to callbacks
|
||||
if sources:
|
||||
self._set_sources(*sources)
|
||||
|
||||
def _set_sources(self, *sources):
|
||||
self._sources = deque(Event.create(s) for s in sources)
|
||||
if any(s.done() for s in self._sources):
|
||||
self.set_done()
|
||||
return
|
||||
self._results = [deque() for _ in self._sources]
|
||||
for i, source in enumerate(self._sources):
|
||||
cb = functools.partial(self._on_source_i, i)
|
||||
source.connect(cb, self.on_source_error, self.on_source_done)
|
||||
self._source2cbs[source].append(cb)
|
||||
|
||||
def _on_source_i(self, i, *args):
|
||||
q = self._results[i]
|
||||
if not q:
|
||||
self._num_ready += 1
|
||||
ready = self._num_ready == len(self._results)
|
||||
else:
|
||||
ready = False
|
||||
q.append(args[0] if len(args) == 1 else args if args else NO_VALUE)
|
||||
if ready:
|
||||
tup = tuple(q.popleft() for q in self._results)
|
||||
self._num_ready = sum(bool(q) for q in self._results)
|
||||
self.emit(*tup)
|
||||
|
||||
def on_source_done(self, source):
|
||||
self._sources.remove(source)
|
||||
if not self._sources:
|
||||
for source, cbs in self._source2cbs.items():
|
||||
for cb in cbs:
|
||||
source.disconnect(
|
||||
cb, self.on_source_error, self.on_source_done)
|
||||
self._source2cbs = None
|
||||
self.set_done()
|
||||
|
||||
|
||||
class Ziplatest(JoinOp):
|
||||
__slots__ = ('_values', '_is_primed', '_source2cbs')
|
||||
|
||||
def __init__(self, *sources, partial=True):
|
||||
JoinOp.__init__(self)
|
||||
self._is_primed = partial
|
||||
self._source2cbs = defaultdict(list) # map from source to callbacks
|
||||
if sources:
|
||||
self._set_sources(*sources)
|
||||
|
||||
def _set_sources(self, *sources):
|
||||
sources = [Event.create(s) for s in sources]
|
||||
self._sources = deque(s for s in sources if not s.done())
|
||||
if not self._sources:
|
||||
self.set_done()
|
||||
return
|
||||
self._values = [s.value() for s in sources]
|
||||
for i, source in enumerate(self._sources):
|
||||
cb = functools.partial(self._on_source_i, i)
|
||||
source.connect(cb, self.on_source_error, self.on_source_done)
|
||||
self._source2cbs[source].append(cb)
|
||||
|
||||
def _on_source_i(self, i, *args):
|
||||
self._values[i] = \
|
||||
args[0] if len(args) == 1 else args if args else NO_VALUE
|
||||
if not self._is_primed:
|
||||
self._is_primed = not any(r is NO_VALUE for r in self._values)
|
||||
if self._is_primed:
|
||||
self.emit(*self._values)
|
||||
|
||||
def on_source_done(self, source):
|
||||
self._sources.remove(source)
|
||||
if not self._sources:
|
||||
for source, cbs in self._source2cbs.items():
|
||||
for cb in cbs:
|
||||
source.disconnect(
|
||||
cb, self.on_source_error, self.on_source_done)
|
||||
self._source2cbs = None
|
||||
self.set_done()
|
||||
@@ -0,0 +1,123 @@
|
||||
import asyncio
|
||||
import itertools
|
||||
import time
|
||||
|
||||
from .op import Op
|
||||
from ..event import Event
|
||||
from ..util import NO_VALUE, get_event_loop, timerange
|
||||
|
||||
|
||||
class Wait(Event):
|
||||
__slots__ = ('_task',)
|
||||
|
||||
def __init__(self, future, name='wait'):
|
||||
Event.__init__(self, name)
|
||||
if future.done():
|
||||
self._task = None
|
||||
self.set_done()
|
||||
else:
|
||||
loop = get_event_loop()
|
||||
self._task = asyncio.ensure_future(future, loop=loop)
|
||||
future.add_done_callback(self._on_task_done)
|
||||
|
||||
def _on_task_done(self, task):
|
||||
try:
|
||||
result = task.result()
|
||||
except Exception as error:
|
||||
result = NO_VALUE
|
||||
self.error_event.emit(self, error)
|
||||
self.emit(result)
|
||||
self._task = None
|
||||
self.set_done()
|
||||
|
||||
def __del__(self):
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
|
||||
|
||||
class Aiterate(Event):
|
||||
__slots__ = ('_task',)
|
||||
|
||||
def __init__(self, ait):
|
||||
Event.__init__(self, ait.__qualname__)
|
||||
loop = get_event_loop()
|
||||
self._task = asyncio.ensure_future(self._looper(ait), loop=loop)
|
||||
|
||||
async def _looper(self, ait):
|
||||
try:
|
||||
async for args in ait:
|
||||
self.emit(args)
|
||||
except Exception as error:
|
||||
self.error_event.emit(self, error)
|
||||
self._task = None
|
||||
self.set_done()
|
||||
|
||||
def __del__(self):
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
|
||||
|
||||
class Sequence(Aiterate):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, values, interval=0, times=None):
|
||||
async def sequence():
|
||||
t0 = time.time()
|
||||
if times is not None:
|
||||
for t, value in zip(times, values):
|
||||
delay = max(0, time.time() + t - t0)
|
||||
await asyncio.sleep(delay)
|
||||
yield value
|
||||
else:
|
||||
for i, value in enumerate(values):
|
||||
delay = max(0, i * interval + t0 - time.time())
|
||||
await asyncio.sleep(delay)
|
||||
yield value
|
||||
Aiterate.__init__(self, sequence())
|
||||
|
||||
|
||||
class Repeat(Sequence):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, value, count, interval=0, times=None):
|
||||
Sequence.__init__(self, itertools.repeat(count), interval, times)
|
||||
|
||||
|
||||
class Range(Sequence):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, *args, interval=0, times=None):
|
||||
Sequence.__init__(self, range(*args), interval, times)
|
||||
|
||||
|
||||
class Timerange(Aiterate):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, start=0, end=None, step=1):
|
||||
Aiterate.__init__(self, timerange(start, end, step))
|
||||
|
||||
|
||||
class Timer(Aiterate):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, interval, count=None):
|
||||
async def timer():
|
||||
t0 = time.time()
|
||||
i = 0
|
||||
while count is None or i < count:
|
||||
i += 1
|
||||
delay = i * interval + t0 - time.time()
|
||||
await asyncio.sleep(delay)
|
||||
yield i * interval
|
||||
Aiterate.__init__(self, timer())
|
||||
|
||||
|
||||
class Marble(Op):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, s, interval=0, times=None):
|
||||
s = s.replace('_', '')
|
||||
source = Event.sequence(s, interval, times) \
|
||||
.filter(lambda c: c not in '- ') \
|
||||
.takewhile(lambda c: c != '|')
|
||||
Op.__init__(self, source)
|
||||
@@ -0,0 +1,26 @@
|
||||
from .op import Op
|
||||
from ..event import Event
|
||||
|
||||
|
||||
class Errors(Event):
|
||||
__slots__ = ('_source',)
|
||||
|
||||
def __init__(self, source=None):
|
||||
Event.__init__(self)
|
||||
self._source = source
|
||||
if source is not None and source.done():
|
||||
self.set_done()
|
||||
else:
|
||||
source.error_event += self.emit
|
||||
|
||||
|
||||
class EndOnError(Op):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, source=None):
|
||||
Op.__init__(self, source)
|
||||
|
||||
def on_source_error(self, error):
|
||||
self.disconnect_from(self._source)
|
||||
self.error_event.emit(error)
|
||||
self.set_done()
|
||||
@@ -0,0 +1,63 @@
|
||||
from typing import Union
|
||||
|
||||
from ..event import Event
|
||||
|
||||
|
||||
class Op(Event):
|
||||
"""
|
||||
Base functionality for operators.
|
||||
|
||||
The Observer pattern is implemented by the following three methods::
|
||||
|
||||
on_source(self, *args)
|
||||
on_source_error(self, source, error)
|
||||
on_source_done(self, source)
|
||||
|
||||
The default handlers will pass along source emits, errors and done events.
|
||||
This makes ``Op`` also suitable as an identity operator.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, source: Union[Event, None] = None):
|
||||
Event.__init__(self)
|
||||
if source is not None:
|
||||
self.set_source(source)
|
||||
|
||||
on_source = Event.emit
|
||||
|
||||
def on_source_error(self, source, error):
|
||||
if len(self.error_event):
|
||||
self.error_event.emit(source, error)
|
||||
else:
|
||||
Event.logger.exception(error)
|
||||
|
||||
def on_source_done(self, _source):
|
||||
if self._source is not None:
|
||||
self._disconnect_from(self._source)
|
||||
self._source = None
|
||||
self.set_done()
|
||||
|
||||
def set_source(self, source):
|
||||
source = Event.create(source)
|
||||
if self._source is None:
|
||||
self._source = source
|
||||
self._connect_from(source)
|
||||
else:
|
||||
self._source.set_source(source)
|
||||
|
||||
def _connect_from(self, source: Event):
|
||||
if source.done():
|
||||
self.set_done()
|
||||
else:
|
||||
source.connect(
|
||||
self.on_source,
|
||||
self.on_source_error,
|
||||
self.on_source_done,
|
||||
keep_ref=True)
|
||||
|
||||
def _disconnect_from(self, source: Event):
|
||||
source.disconnect(
|
||||
self.on_source,
|
||||
self.on_source_error,
|
||||
self.on_source_done)
|
||||
@@ -0,0 +1,145 @@
|
||||
from .op import Op
|
||||
from ..util import NO_VALUE
|
||||
|
||||
|
||||
class Filter(Op):
|
||||
__slots__ = ('_predicate',)
|
||||
|
||||
def __init__(self, predicate=bool, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._predicate = predicate
|
||||
|
||||
def on_source(self, *args):
|
||||
if self._predicate(*args):
|
||||
self.emit(*args)
|
||||
|
||||
|
||||
class Skip(Op):
|
||||
__slots__ = ('_count', '_n')
|
||||
|
||||
def __init__(self, count=1, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._count = count
|
||||
self._n = 0
|
||||
|
||||
def on_source(self, *args):
|
||||
self._n += 1
|
||||
if self._n == self._count:
|
||||
self._source -= self.on_source
|
||||
self._source += self.emit
|
||||
|
||||
|
||||
class Take(Op):
|
||||
__slots__ = ('_count', '_n')
|
||||
|
||||
def __init__(self, count=1, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._count = count
|
||||
self._n = 0
|
||||
|
||||
def on_source(self, *args):
|
||||
self._n += 1
|
||||
if self._n <= self._count:
|
||||
self.emit(*args)
|
||||
if self._n == self._count:
|
||||
self._disconnect_from(self._source)
|
||||
self.set_done()
|
||||
|
||||
|
||||
class TakeWhile(Op):
|
||||
__slots__ = ('_predicate',)
|
||||
|
||||
def __init__(self, predicate=bool, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._predicate = predicate
|
||||
|
||||
def on_source(self, *args):
|
||||
if self._predicate(*args):
|
||||
self.emit(*args)
|
||||
else:
|
||||
self.set_done()
|
||||
self._disconnect_from(self._source)
|
||||
|
||||
|
||||
class DropWhile(Op):
|
||||
__slots__ = ('_predicate', '_drop')
|
||||
|
||||
def __init__(self, predicate=lambda x: not x, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._predicate = predicate
|
||||
self._drop = True
|
||||
|
||||
def on_source(self, *args):
|
||||
if self._drop:
|
||||
self._drop = self._predicate(*args)
|
||||
if not self._drop:
|
||||
self.emit(*args)
|
||||
|
||||
|
||||
class TakeUntil(Op):
|
||||
__slots__ = ('_notifier',)
|
||||
|
||||
def __init__(self, notifier, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._notifier = notifier
|
||||
notifier.connect(
|
||||
self._on_notifier,
|
||||
self.on_source_error,
|
||||
self.on_source_done)
|
||||
|
||||
def _on_notifier(self, *args):
|
||||
self.on_source_done(self._source)
|
||||
|
||||
def on_source_done(self, source):
|
||||
Op.on_source_done(self, self._source)
|
||||
self._notifier.disconnect(
|
||||
self._on_notifier,
|
||||
self.on_source_error,
|
||||
self.on_source_done)
|
||||
self._notifier = None
|
||||
|
||||
|
||||
class Changes(Op):
|
||||
__slots__ = ('_prev',)
|
||||
|
||||
def __init__(self, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._prev = NO_VALUE
|
||||
|
||||
def on_source(self, *args):
|
||||
if args != self._prev:
|
||||
self.emit(*args)
|
||||
self._prev = args
|
||||
|
||||
|
||||
class Unique(Op):
|
||||
__slots__ = ('_key', '_seen')
|
||||
|
||||
def __init__(self, key, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._key = key
|
||||
self._seen = set()
|
||||
|
||||
def on_source(self, *args):
|
||||
if self._key is None:
|
||||
new = args not in self._seen
|
||||
else:
|
||||
new = self._key(*args) not in self._seen
|
||||
self._seen.add(args)
|
||||
if new:
|
||||
self.emit(*args)
|
||||
|
||||
|
||||
class Last(Op):
|
||||
__slots__ = ('_last',)
|
||||
|
||||
def __init__(self, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._last = NO_VALUE
|
||||
|
||||
def on_source(self, *args):
|
||||
self._last = args
|
||||
|
||||
def on_source_done(self, source):
|
||||
self.emit(*self._last)
|
||||
Op.on_source_done(self, source)
|
||||
@@ -0,0 +1,211 @@
|
||||
from collections import deque
|
||||
|
||||
from .op import Op
|
||||
from ..event import Event
|
||||
from ..util import NO_VALUE, get_event_loop
|
||||
|
||||
|
||||
class Delay(Op):
|
||||
__slots__ = ('_delay',)
|
||||
|
||||
def __init__(self, delay, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._delay = delay
|
||||
|
||||
def on_source(self, *args):
|
||||
loop = get_event_loop()
|
||||
loop.call_later(self._delay, self.emit, *args)
|
||||
|
||||
def on_source_error(self, error):
|
||||
loop = get_event_loop()
|
||||
loop.call_later(self._delay, self.error_event.emit, error)
|
||||
|
||||
def on_source_done(self, source):
|
||||
if self._source is not None:
|
||||
self._disconnect_from(self._source)
|
||||
self._source = None
|
||||
loop = get_event_loop()
|
||||
loop.call_later(self._delay, self.set_done)
|
||||
|
||||
|
||||
class Timeout(Op):
|
||||
__slots__ = ('_timeout', '_handle', '_last_time')
|
||||
|
||||
def __init__(self, timeout, source=None):
|
||||
Op.__init__(self, source)
|
||||
if source is not None and source.done():
|
||||
return
|
||||
self._timeout = timeout
|
||||
loop = get_event_loop()
|
||||
self._last_time = loop.time()
|
||||
self._handle = None
|
||||
self._schedule()
|
||||
|
||||
def on_source(self, *args):
|
||||
loop = get_event_loop()
|
||||
self._last_time = loop.time()
|
||||
|
||||
def on_source_done(self, source):
|
||||
self._handle.cancel()
|
||||
del self._handle
|
||||
Op.on_source_done(self, source)
|
||||
|
||||
def _schedule(self):
|
||||
loop = get_event_loop()
|
||||
self._handle = loop.call_at(
|
||||
self._last_time + self._timeout, self._on_timer)
|
||||
|
||||
def _on_timer(self):
|
||||
loop = get_event_loop()
|
||||
if loop.time() - self._last_time > self._timeout:
|
||||
self.emit()
|
||||
self.set_done()
|
||||
else:
|
||||
self._schedule()
|
||||
|
||||
|
||||
class Debounce(Op):
|
||||
__slots__ = ('_interval', '_on_first', '_handle', '_last_time')
|
||||
|
||||
def __init__(self, interval, on_first=False, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._interval = interval
|
||||
self._on_first = on_first
|
||||
self._last_time = -float('inf')
|
||||
self._handle = None
|
||||
|
||||
def on_source(self, *args):
|
||||
loop = get_event_loop()
|
||||
time = loop.time()
|
||||
delta = time - self._last_time
|
||||
self._last_time = time
|
||||
if self._on_first:
|
||||
if delta >= self._interval:
|
||||
self.emit(*args)
|
||||
else:
|
||||
if self._handle:
|
||||
self._handle.cancel()
|
||||
self._handle = loop.call_at(
|
||||
time + self._interval, self._delayed_emit, *args)
|
||||
|
||||
def _delayed_emit(self, *args):
|
||||
self._handle = None
|
||||
self.emit(*args)
|
||||
if self._source is None:
|
||||
self.set_done()
|
||||
|
||||
def on_source_done(self, source):
|
||||
self._disconnect_from(source)
|
||||
self._source = None
|
||||
if not self._handle:
|
||||
self.set_done()
|
||||
|
||||
|
||||
class Throttle(Op):
|
||||
__slots__ = (
|
||||
'status_event', '_maximum', '_interval', '_cost_func',
|
||||
'_q', '_time_q', '_cost_q', '_is_throttling')
|
||||
|
||||
def __init__(self, maximum, interval, cost_func=None, source=None):
|
||||
Op.__init__(self, source)
|
||||
self.status_event = Event('throttle_status')
|
||||
"""
|
||||
Sub event that emits ``True`` when throttling starts and ``False``
|
||||
when throttling ends.
|
||||
"""
|
||||
self._maximum = maximum
|
||||
self._interval = interval
|
||||
self._cost_func = cost_func
|
||||
self._q = deque() # deque of (args, cost) tuples
|
||||
self._time_q = deque() # deque of previous emit times
|
||||
self._cost_q = deque() # deque of costs of previous emits
|
||||
self._is_throttling = False
|
||||
|
||||
def set_limit(self, maximum, interval):
|
||||
"""
|
||||
Dynamically update the ``maximum`` per ``interval`` limit.
|
||||
"""
|
||||
self._maximum = maximum
|
||||
self._interval = interval
|
||||
|
||||
def on_source(self, *args):
|
||||
cost = self._cost_func
|
||||
if cost is not None:
|
||||
cost = cost(*args)
|
||||
self._q.append((args, cost))
|
||||
self._try_emit()
|
||||
|
||||
def on_source_done(self, source):
|
||||
self._disconnect_from(source)
|
||||
self._source = None
|
||||
if not self._q:
|
||||
self.set_done()
|
||||
self.status_event.set_done()
|
||||
|
||||
def _try_emit(self):
|
||||
loop = get_event_loop()
|
||||
t = loop.time()
|
||||
q = self._q
|
||||
times = self._time_q
|
||||
costs = self._cost_q
|
||||
|
||||
# forget old emit times
|
||||
while times and t - times[0] > self._interval:
|
||||
times.popleft()
|
||||
costs.popleft()
|
||||
|
||||
# emit values while not exceeding the limit
|
||||
while q:
|
||||
args, cost = q[0]
|
||||
if self._cost_func:
|
||||
cost = self._cost_func(*args)
|
||||
total_cost = cost + sum(costs)
|
||||
else:
|
||||
cost = None
|
||||
total_cost = 1 + len(costs)
|
||||
if self._maximum and total_cost >= self._maximum:
|
||||
break
|
||||
args, cost = q.popleft()
|
||||
times.append(t)
|
||||
costs.append(cost)
|
||||
self.emit(*args)
|
||||
|
||||
# update status and schedule new emits
|
||||
if q:
|
||||
if not self._is_throttling:
|
||||
self.status_event.emit(True)
|
||||
loop.call_at(times[0] + self._interval, self._try_emit)
|
||||
elif self._is_throttling:
|
||||
self.status_event.emit(False)
|
||||
self._is_throttling = bool(q)
|
||||
|
||||
if not q and self._source is None:
|
||||
self.set_done()
|
||||
self.status_event.set_done()
|
||||
|
||||
|
||||
class Sample(Op):
|
||||
__slots__ = ('_timer',)
|
||||
|
||||
def __init__(self, timer, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._timer = timer
|
||||
timer.connect(
|
||||
self._on_timer,
|
||||
self.on_source_error,
|
||||
self.on_source_done)
|
||||
|
||||
def on_source(self, *args):
|
||||
self._value = args
|
||||
|
||||
def _on_timer(self, *args):
|
||||
if self._value is not NO_VALUE:
|
||||
self.emit(*self._value)
|
||||
|
||||
def on_source_done(self, source):
|
||||
Op.on_source_done(self, self._source)
|
||||
self._timer.disconnect(
|
||||
self._on_timer,
|
||||
self.on_source_error,
|
||||
self.on_source_done)
|
||||
self._timer = None
|
||||
@@ -0,0 +1,346 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
from .combine import Chain, Concat, Merge, Switch
|
||||
from .op import Op
|
||||
from ..util import NO_VALUE, get_event_loop
|
||||
|
||||
|
||||
class Constant(Op):
|
||||
__slots__ = ('_constant',)
|
||||
|
||||
def __init__(self, constant, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._constant = constant
|
||||
|
||||
def on_source(self, *args):
|
||||
self.emit(self._constant)
|
||||
|
||||
|
||||
class Iterate(Op):
|
||||
__slots__ = ('_it',)
|
||||
|
||||
def __init__(self, it, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._it = iter(it)
|
||||
|
||||
def on_source(self, *args):
|
||||
try:
|
||||
value = next(self._it)
|
||||
self.emit(value)
|
||||
except StopIteration:
|
||||
self._disconnect_from(self._source)
|
||||
self.set_done()
|
||||
|
||||
|
||||
class Enumerate(Op):
|
||||
__slots__ = ('_step', '_i')
|
||||
|
||||
def __init__(self, start=0, step=1, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._i = start
|
||||
self._step = step
|
||||
|
||||
def on_source(self, *args):
|
||||
self.emit(
|
||||
self._i,
|
||||
args[0] if len(args) == 1 else args if args else NO_VALUE)
|
||||
self._i += self._step
|
||||
|
||||
|
||||
class Timestamp(Op):
|
||||
__slots__ = ()
|
||||
|
||||
def on_source(self, *args):
|
||||
self.emit(
|
||||
time.time(),
|
||||
args[0] if len(args) == 1 else args if args else NO_VALUE)
|
||||
|
||||
|
||||
class Partial(Op):
|
||||
__slots__ = ('_left_args',)
|
||||
|
||||
def __init__(self, *left_args, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._left_args = left_args
|
||||
|
||||
def on_source(self, *args):
|
||||
self.emit(*(self._left_args + args))
|
||||
|
||||
|
||||
class PartialRight(Op):
|
||||
__slots__ = ('_right_args',)
|
||||
|
||||
def __init__(self, *right_args, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._right_args = right_args
|
||||
|
||||
def on_source(self, *args):
|
||||
self.emit(*(args + self._right_args))
|
||||
|
||||
|
||||
class Star(Op):
|
||||
__slots__ = ()
|
||||
|
||||
def on_source(self, arg):
|
||||
self.emit(*arg)
|
||||
|
||||
|
||||
class Pack(Op):
|
||||
__slots__ = ()
|
||||
|
||||
def on_source(self, *args):
|
||||
self.emit(args)
|
||||
|
||||
|
||||
class Pluck(Op):
|
||||
__slots__ = ('_selections',)
|
||||
|
||||
def __init__(self, *selections, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._selections = [] # list of [arg-index, *sub-attributes]
|
||||
for sel in selections:
|
||||
if type(sel) is int:
|
||||
s = [sel]
|
||||
else:
|
||||
s = sel.split('.')
|
||||
if s[0].isdigit():
|
||||
s[0] = int(s[0])
|
||||
elif s[0] == '':
|
||||
s[0] = 0
|
||||
else:
|
||||
s.insert(0, 0)
|
||||
self._selections.append(s)
|
||||
|
||||
def on_source(self, *args):
|
||||
values = []
|
||||
for s in self._selections:
|
||||
try:
|
||||
value = args[s[0]]
|
||||
for attr in s[1:]:
|
||||
value = getattr(value, attr)
|
||||
except Exception:
|
||||
value = NO_VALUE
|
||||
values.append(value)
|
||||
self.emit(*values)
|
||||
|
||||
|
||||
class Previous(Op):
|
||||
__slots__ = ('_count', '_q')
|
||||
|
||||
def __init__(self, count=1, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._count = count
|
||||
self._q = deque()
|
||||
|
||||
def on_source(self, *args):
|
||||
self._q.append(args)
|
||||
if len(self._q) > self._count:
|
||||
self.emit(*self._q.popleft())
|
||||
|
||||
|
||||
class Copy(Op):
|
||||
__slots__ = ()
|
||||
|
||||
def on_source(self, *args):
|
||||
self.emit(*(copy.copy(a) for a in args))
|
||||
|
||||
|
||||
class Deepcopy(Op):
|
||||
__slots__ = ()
|
||||
|
||||
def on_source(self, *args):
|
||||
self.emit(*copy.deepcopy(args))
|
||||
|
||||
|
||||
class Chunk(Op):
|
||||
__slots__ = ('_size', '_list')
|
||||
|
||||
def __init__(self, size, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._size = size
|
||||
self._list = []
|
||||
|
||||
def on_source(self, *args):
|
||||
self._list.append(
|
||||
args[0] if len(args) == 1 else args if args else NO_VALUE)
|
||||
if len(self._list) == self._size:
|
||||
self.emit(self._list)
|
||||
self._list = []
|
||||
|
||||
def on_source_done(self, source):
|
||||
if self._list:
|
||||
self.emit(self._list)
|
||||
Op.on_source_done(self, self._source)
|
||||
|
||||
|
||||
class ChunkWith(Op):
|
||||
__slots__ = ('_timer', '_list', '_emit_empty')
|
||||
|
||||
def __init__(self, timer, emit_empty, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._timer = timer
|
||||
self._emit_empty = emit_empty
|
||||
self._list = []
|
||||
timer.connect(
|
||||
self._on_timer,
|
||||
self.on_source_error,
|
||||
self.on_source_done)
|
||||
|
||||
def on_source(self, *args):
|
||||
self._list.append(
|
||||
args[0] if len(args) == 1 else args if args else NO_VALUE)
|
||||
|
||||
def _on_timer(self, *args):
|
||||
if self._list or self._emit_empty:
|
||||
self.emit(self._list)
|
||||
self._list = []
|
||||
|
||||
def on_source_done(self, source):
|
||||
if self._list:
|
||||
self.emit(self._list)
|
||||
self._list = None
|
||||
if self._timer is not None:
|
||||
self._timer.disconnect(
|
||||
self._on_timer,
|
||||
self.on_source_error,
|
||||
self.on_source_done)
|
||||
self._timer = None
|
||||
Op.on_source_done(self, self._source)
|
||||
|
||||
|
||||
class Map(Op):
|
||||
__slots__ = (
|
||||
'_func', '_timeout', '_ordered', '_task_limit', '_coro_q', '_tasks')
|
||||
|
||||
def __init__(
|
||||
self, func, timeout=0, ordered=True, task_limit=None, source=None):
|
||||
Op.__init__(self, source)
|
||||
if source is not None and source.done():
|
||||
return
|
||||
self._func = func
|
||||
self._timeout = timeout
|
||||
self._ordered = ordered
|
||||
self._task_limit = task_limit
|
||||
self._coro_q = deque()
|
||||
self._tasks = deque()
|
||||
|
||||
def on_source(self, *args):
|
||||
obj = self._func(*args)
|
||||
if hasattr(obj, '__await__'):
|
||||
# function returns an awaitable
|
||||
if not self._task_limit or len(self._tasks) < self._task_limit:
|
||||
# schedule right away
|
||||
self._create_task(obj)
|
||||
else:
|
||||
# queue for later
|
||||
self._coro_q.append(obj)
|
||||
else:
|
||||
# regular function returns the result directly
|
||||
self.emit(obj)
|
||||
|
||||
def on_source_done(self, source):
|
||||
if not self._tasks:
|
||||
# only end when no tasks are pending
|
||||
Op.on_source_done(self, self._source)
|
||||
self._source = None
|
||||
|
||||
def _create_task(self, coro):
|
||||
# schedule a task to be run
|
||||
if self._timeout:
|
||||
coro = asyncio.wait_for(coro, self._timeout)
|
||||
loop = get_event_loop()
|
||||
task = asyncio.ensure_future(coro, loop=loop)
|
||||
task.add_done_callback(self._on_task_done)
|
||||
self._tasks.append(task)
|
||||
|
||||
def _on_task_done(self, task):
|
||||
# handle task result
|
||||
tasks = self._tasks
|
||||
if self._ordered:
|
||||
while tasks and tasks[0].done():
|
||||
# remove task after emitting result
|
||||
task = tasks[0]
|
||||
self._emit_task(task)
|
||||
task = tasks.popleft()
|
||||
else:
|
||||
# remove task after emitting result
|
||||
self._emit_task(task)
|
||||
tasks.remove(task)
|
||||
|
||||
# schedule pending awaitables from the queue
|
||||
while self._coro_q and (
|
||||
not self._task_limit or len(tasks) < self._task_limit):
|
||||
self._create_task(self._coro_q.popleft())
|
||||
|
||||
# end when source has ended with no pending tasks
|
||||
if not tasks and self._source is None:
|
||||
Op.on_source_done(self, self._source)
|
||||
|
||||
def _emit_task(self, task):
|
||||
try:
|
||||
result = task.result()
|
||||
except Exception as error:
|
||||
result = NO_VALUE
|
||||
self.error_event.emit(error)
|
||||
self.emit(result)
|
||||
|
||||
|
||||
class Emap(Op):
|
||||
__slots__ = ('_constr', '_joiner',)
|
||||
|
||||
def __init__(self, constr, joiner, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._constr = constr
|
||||
self._joiner = joiner
|
||||
joiner.set_parent(source)
|
||||
joiner.connect(
|
||||
self.emit,
|
||||
self.error_event.emit,
|
||||
self._on_joiner_done)
|
||||
|
||||
def on_source(self, *args):
|
||||
obj = self._constr(*args)
|
||||
event = self.create(obj)
|
||||
self._joiner.add_source(event)
|
||||
|
||||
def on_source_done(self, source):
|
||||
pass
|
||||
|
||||
def _on_joiner_done(self, joiner):
|
||||
joiner.disconnect(
|
||||
self.emit,
|
||||
self.error_event.emit,
|
||||
self._on_joiner_done)
|
||||
self._joiner = None
|
||||
self.set_done()
|
||||
|
||||
|
||||
class Mergemap(Emap):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, constr, source=None):
|
||||
Emap.__init__(self, constr, Merge(), source)
|
||||
|
||||
|
||||
class Chainmap(Emap):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, constr, source=None):
|
||||
Emap.__init__(self, constr, Chain(), source)
|
||||
|
||||
|
||||
class Concatmap(Emap):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, constr, source=None):
|
||||
Emap.__init__(self, constr, Concat(), source)
|
||||
|
||||
|
||||
class Switchmap(Emap):
|
||||
__slots__ = ()
|
||||
|
||||
def __init__(self, constr, source=None):
|
||||
Emap.__init__(self, constr, Switch(), source)
|
||||
@@ -0,0 +1,78 @@
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
from typing import AsyncIterator
|
||||
|
||||
|
||||
class _NoValue:
|
||||
def __bool__(self):
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
return '<NoValue>'
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
|
||||
NO_VALUE = _NoValue()
|
||||
|
||||
|
||||
def get_event_loop():
|
||||
"""Get asyncio event loop, running or not."""
|
||||
return asyncio.get_event_loop_policy().get_event_loop()
|
||||
|
||||
|
||||
main_event_loop = get_event_loop()
|
||||
|
||||
|
||||
async def timerange(start=0, end=None, step: float = 1) \
|
||||
-> AsyncIterator[dt.datetime]:
|
||||
"""
|
||||
Iterator that waits periodically until certain time points are
|
||||
reached while yielding those time points.
|
||||
|
||||
Args:
|
||||
start: Start time, can be specified as:
|
||||
|
||||
* ``datetime.datetime``.
|
||||
* ``datetime.time``: Today is used as date.
|
||||
* ``int`` or ``float``: Number of seconds relative to now.
|
||||
Values will be quantized to the given step.
|
||||
end: End time, can be specified as:
|
||||
|
||||
* ``datetime.datetime``.
|
||||
* ``datetime.time``: Today is used as date.
|
||||
* ``None``: No end limit.
|
||||
step: Number of seconds, or ``datetime.timedelta``,
|
||||
to space between values.
|
||||
"""
|
||||
tz = getattr(start, 'tzinfo', None)
|
||||
now = dt.datetime.now(tz)
|
||||
if isinstance(step, dt.timedelta):
|
||||
delta = step
|
||||
step = delta.total_seconds()
|
||||
else:
|
||||
delta = dt.timedelta(seconds=step)
|
||||
t = start
|
||||
if t == 0 or isinstance(t, (int, float)):
|
||||
t = now + dt.timedelta(seconds=t)
|
||||
# quantize to step
|
||||
t = dt.datetime.fromtimestamp(
|
||||
step * int(t.timestamp() / step))
|
||||
elif isinstance(t, dt.time):
|
||||
t = dt.datetime.combine(now.today(), t)
|
||||
|
||||
if t < now:
|
||||
# t += delta
|
||||
t -= ((t - now) // delta) * delta
|
||||
|
||||
if isinstance(end, dt.time):
|
||||
end = dt.datetime.combine(now.today(), end)
|
||||
elif isinstance(end, (int, float)):
|
||||
end = now + dt.timedelta(seconds=end)
|
||||
|
||||
while end is None or t <= end:
|
||||
now = dt.datetime.now(tz)
|
||||
secs = (t - now).total_seconds()
|
||||
await asyncio.sleep(secs)
|
||||
yield t
|
||||
t += delta
|
||||
@@ -0,0 +1,2 @@
|
||||
__version_info__ = (1, 0, 3)
|
||||
__version__ = '.'.join(str(v) for v in __version_info__)
|
||||
@@ -0,0 +1 @@
|
||||
pip
|
||||
@@ -0,0 +1,25 @@
|
||||
BSD 2-Clause License
|
||||
|
||||
Copyright (c) 2019, Ewald de Wit
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
@@ -0,0 +1,162 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: ib-insync
|
||||
Version: 0.9.86
|
||||
Summary: Python sync/async framework for Interactive Brokers API
|
||||
Home-page: https://github.com/erdewit/ib_insync
|
||||
Author: Ewald R. de Wit
|
||||
Author-email: ewald.de.wit@gmail.com
|
||||
License: BSD
|
||||
Keywords: ibapi tws asyncio jupyter interactive brokers async
|
||||
Classifier: Development Status :: 4 - Beta
|
||||
Classifier: Intended Audience :: Developers
|
||||
Classifier: Topic :: Office/Business :: Financial :: Investment
|
||||
Classifier: License :: OSI Approved :: BSD License
|
||||
Classifier: Programming Language :: Python :: 3.6
|
||||
Classifier: Programming Language :: Python :: 3.7
|
||||
Classifier: Programming Language :: Python :: 3.8
|
||||
Classifier: Programming Language :: Python :: 3.9
|
||||
Classifier: Programming Language :: Python :: 3.10
|
||||
Classifier: Programming Language :: Python :: 3.11
|
||||
Classifier: Programming Language :: Python :: 3 :: Only
|
||||
Requires-Python: >=3.6
|
||||
License-File: LICENSE
|
||||
Requires-Dist: eventkit
|
||||
Requires-Dist: nest-asyncio
|
||||
Requires-Dist: dataclasses ; python_version < "3.7"
|
||||
Requires-Dist: backports.zoneinfo ; python_version < "3.9"
|
||||
|
||||
|Build| |Group| |PyVersion| |Status| |PyPiVersion| |CondaVersion| |License| |Downloads| |Docs|
|
||||
|
||||
Introduction
|
||||
============
|
||||
|
||||
The goal of the IB-insync library is to make working with the
|
||||
`Trader Workstation API <http://interactivebrokers.github.io/tws-api/>`_
|
||||
from Interactive Brokers as easy as possible.
|
||||
|
||||
The main features are:
|
||||
|
||||
* An easy to use linear style of programming;
|
||||
* An `IB component <https://ib-insync.readthedocs.io/api.html#module-ib_insync.ib>`_
|
||||
that automatically keeps in sync with the TWS or IB Gateway application;
|
||||
* A fully asynchonous framework based on
|
||||
`asyncio <https://docs.python.org/3/library/asyncio.html>`_
|
||||
and
|
||||
`eventkit <https://github.com/erdewit/eventkit>`_
|
||||
for advanced users;
|
||||
* Interactive operation with live data in Jupyter notebooks.
|
||||
|
||||
Be sure to take a look at the
|
||||
`notebooks <https://ib-insync.readthedocs.io/notebooks.html>`_,
|
||||
the `recipes <https://ib-insync.readthedocs.io/recipes.html>`_
|
||||
and the `API docs <https://ib-insync.readthedocs.io/api.html>`_.
|
||||
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
::
|
||||
|
||||
pip install ib_insync
|
||||
|
||||
Requirements:
|
||||
|
||||
* Python 3.6 or higher;
|
||||
* A running TWS or IB Gateway application (version 1023 or higher).
|
||||
Make sure the
|
||||
`API port is enabled <https://interactivebrokers.github.io/tws-api/initial_setup.html>`_
|
||||
and 'Download open orders on connection' is checked.
|
||||
|
||||
The ibapi package from IB is not needed.
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
This is a complete script to download historical data:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ib_insync import *
|
||||
# util.startLoop() # uncomment this line when in a notebook
|
||||
|
||||
ib = IB()
|
||||
ib.connect('127.0.0.1', 7497, clientId=1)
|
||||
|
||||
contract = Forex('EURUSD')
|
||||
bars = ib.reqHistoricalData(
|
||||
contract, endDateTime='', durationStr='30 D',
|
||||
barSizeSetting='1 hour', whatToShow='MIDPOINT', useRTH=True)
|
||||
|
||||
# convert to pandas dataframe (pandas needs to be installed):
|
||||
df = util.df(bars)
|
||||
print(df)
|
||||
|
||||
Output::
|
||||
|
||||
date open high low close volume \
|
||||
0 2019-11-19 23:15:00 1.107875 1.108050 1.107725 1.107825 -1
|
||||
1 2019-11-20 00:00:00 1.107825 1.107925 1.107675 1.107825 -1
|
||||
2 2019-11-20 01:00:00 1.107825 1.107975 1.107675 1.107875 -1
|
||||
3 2019-11-20 02:00:00 1.107875 1.107975 1.107025 1.107225 -1
|
||||
4 2019-11-20 03:00:00 1.107225 1.107725 1.107025 1.107525 -1
|
||||
.. ... ... ... ... ... ...
|
||||
705 2020-01-02 14:00:00 1.119325 1.119675 1.119075 1.119225 -1
|
||||
|
||||
|
||||
Documentation
|
||||
-------------
|
||||
|
||||
The complete `API documentation <https://ib-insync.readthedocs.io/api.html>`_.
|
||||
|
||||
`Changelog <https://ib-insync.readthedocs.io/changelog.html>`_.
|
||||
|
||||
Discussion
|
||||
----------
|
||||
|
||||
The `insync user group <https://groups.io/g/insync>`_ is the place to discuss
|
||||
IB-insync and anything related to it.
|
||||
|
||||
Disclaimer
|
||||
----------
|
||||
|
||||
The software is provided on the conditions of the simplified BSD license.
|
||||
|
||||
This project is not affiliated with Interactive Brokers Group, Inc.'s.
|
||||
|
||||
Good luck and enjoy,
|
||||
|
||||
:author: Ewald de Wit <ewald.de.wit@gmail.com>
|
||||
|
||||
.. _`Interactive Brokers Python API`: http://interactivebrokers.github.io
|
||||
|
||||
.. |Group| image:: https://img.shields.io/badge/groups.io-insync-green.svg
|
||||
:alt: Join the user group
|
||||
:target: https://groups.io/g/insync
|
||||
|
||||
.. |PyPiVersion| image:: https://img.shields.io/pypi/v/ib_insync.svg
|
||||
:alt: PyPi
|
||||
:target: https://pypi.python.org/pypi/ib_insync
|
||||
|
||||
.. |CondaVersion| image:: https://img.shields.io/conda/vn/conda-forge/ib-insync.svg
|
||||
:alt: Conda
|
||||
:target: https://anaconda.org/conda-forge/ib-insync
|
||||
|
||||
.. |PyVersion| image:: https://img.shields.io/badge/python-3.6+-blue.svg
|
||||
:alt:
|
||||
|
||||
.. |Status| image:: https://img.shields.io/badge/status-beta-green.svg
|
||||
:alt:
|
||||
|
||||
.. |License| image:: https://img.shields.io/badge/license-BSD-blue.svg
|
||||
:alt:
|
||||
|
||||
.. |Docs| image:: https://img.shields.io/badge/Documentation-green.svg
|
||||
:alt: Documentation
|
||||
:target: https://ib-insync.readthedocs.io/api.html
|
||||
|
||||
.. |Downloads| image:: https://pepy.tech/badge/ib-insync
|
||||
:alt: Number of downloads
|
||||
:target: https://pepy.tech/project/ib-insync
|
||||
|
||||
.. |Build| image:: https://github.com/erdewit/ib_insync/actions/workflows/test.yml/badge.svg?branch=master
|
||||
:target: https://github.com/erdewit/ib_insync/actions
|
||||
@@ -0,0 +1,36 @@
|
||||
ib_insync-0.9.86.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
||||
ib_insync-0.9.86.dist-info/LICENSE,sha256=wVXmemzM4v_VlNCzedi7qFodW6NAhMiDQaSAy10WxtU,1317
|
||||
ib_insync-0.9.86.dist-info/METADATA,sha256=zgqOVqK0KlHbwOxnc_0XlLCC6V1pwlxkzknbpK5_Akg,5369
|
||||
ib_insync-0.9.86.dist-info/RECORD,,
|
||||
ib_insync-0.9.86.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
ib_insync-0.9.86.dist-info/WHEEL,sha256=2wepM1nk4DS4eFpYrW1TTqPcoGNfHhhO_i5m4cOimbo,92
|
||||
ib_insync-0.9.86.dist-info/top_level.txt,sha256=b9ruIFE0O0uehMGcAzo9k5WBt1ZeHmukMQrF_Qaq0J4,10
|
||||
ib_insync/__init__.py,sha256=bkuPG1U_EIGJ8WjiY2vxdDVpDisj7yr6aI3IsAgrcQ0,3611
|
||||
ib_insync/__pycache__/__init__.cpython-311.pyc,,
|
||||
ib_insync/__pycache__/client.cpython-311.pyc,,
|
||||
ib_insync/__pycache__/connection.cpython-311.pyc,,
|
||||
ib_insync/__pycache__/contract.cpython-311.pyc,,
|
||||
ib_insync/__pycache__/decoder.cpython-311.pyc,,
|
||||
ib_insync/__pycache__/flexreport.cpython-311.pyc,,
|
||||
ib_insync/__pycache__/ib.cpython-311.pyc,,
|
||||
ib_insync/__pycache__/ibcontroller.cpython-311.pyc,,
|
||||
ib_insync/__pycache__/objects.cpython-311.pyc,,
|
||||
ib_insync/__pycache__/order.cpython-311.pyc,,
|
||||
ib_insync/__pycache__/ticker.cpython-311.pyc,,
|
||||
ib_insync/__pycache__/util.cpython-311.pyc,,
|
||||
ib_insync/__pycache__/version.cpython-311.pyc,,
|
||||
ib_insync/__pycache__/wrapper.cpython-311.pyc,,
|
||||
ib_insync/client.py,sha256=HYaYnWgViB6ohXbo473PNP5A0dp3jf7mhhnugOU6KCI,33172
|
||||
ib_insync/connection.py,sha256=gXxcI_1xPw4XHtmAKpGN6grxHsEeoKiTSz6H6JteCAk,1654
|
||||
ib_insync/contract.py,sha256=o9jSUcTwmiM8nP1asKq8mRXesUzbh42kOlbFk6KU87A,17102
|
||||
ib_insync/decoder.py,sha256=sGy-JOyjfjOG0KshPetD5JMXHoy7fCYGDqSTb5f3m0w,40909
|
||||
ib_insync/flexreport.py,sha256=91rYpgKMLMshpyW7_S4M-sBgA1c7S1hvXw3_i0tTwQQ,4009
|
||||
ib_insync/ib.py,sha256=tfYCIMlBozXEfLsCV6NC4xIrHwfMn4CaNqDnybS6hPw,85631
|
||||
ib_insync/ibcontroller.py,sha256=Mod0rBATMFeZEjad89LGqfJPkinIMJDXQpyrxMOots4,12510
|
||||
ib_insync/objects.py,sha256=ktzQWpRSH6tUTcneEtHD5kunfoqEkllC6cWj2hK6J0Y,9817
|
||||
ib_insync/order.py,sha256=8BSwY7iZ_Z7BzUvUz5zINN4O-4nAib3rZzwJs5ITHKc,12368
|
||||
ib_insync/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
||||
ib_insync/ticker.py,sha256=7pEKbvPGy9axCyZwOLm_6Syh1IPkGytpCJoQSQyVSd0,10642
|
||||
ib_insync/util.py,sha256=hsclDyRUVyGk_SMOdHsNAeZssfikrRLhUsenDUs32WQ,16122
|
||||
ib_insync/version.py,sha256=Ny-bTPSSnRRmaGV_vw1sp_cJarnJFQJJJELra6-cO98,108
|
||||
ib_insync/wrapper.py,sha256=TtcpkQaXnAslGGWNTT7zcf5aObJUWGJUZQ7cNSF8sME,46713
|
||||
@@ -0,0 +1,5 @@
|
||||
Wheel-Version: 1.0
|
||||
Generator: bdist_wheel (0.38.4)
|
||||
Root-Is-Purelib: true
|
||||
Tag: py3-none-any
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
ib_insync
|
||||
@@ -0,0 +1,75 @@
|
||||
"""Python sync/async framework for Interactive Brokers API"""
|
||||
|
||||
import dataclasses
|
||||
import sys
|
||||
|
||||
from eventkit import Event
|
||||
|
||||
from . import util
|
||||
from .client import Client
|
||||
from .contract import (
|
||||
Bag, Bond, CFD, ComboLeg, Commodity, ContFuture, Contract,
|
||||
ContractDescription, ContractDetails, Crypto, DeltaNeutralContract,
|
||||
Forex, Future, FuturesOption, Index, MutualFund, Option, ScanData, Stock,
|
||||
TagValue, Warrant)
|
||||
from .flexreport import FlexError, FlexReport
|
||||
from .ib import IB
|
||||
from .ibcontroller import IBC, Watchdog
|
||||
from .objects import (
|
||||
AccountValue, BarData, BarDataList, CommissionReport, ConnectionStats,
|
||||
DOMLevel, DepthMktDataDescription, Dividends, Execution, ExecutionFilter,
|
||||
FamilyCode, Fill, FundamentalRatios, HistogramData, HistoricalNews,
|
||||
HistoricalSchedule, HistoricalSession, HistoricalTick,
|
||||
HistoricalTickBidAsk, HistoricalTickLast, MktDepthData, NewsArticle,
|
||||
NewsBulletin, NewsProvider, NewsTick, OptionChain, OptionComputation,
|
||||
PnL, PnLSingle, PortfolioItem,
|
||||
Position, PriceIncrement, RealTimeBar, RealTimeBarList, ScanDataList,
|
||||
ScannerSubscription, SmartComponent, SoftDollarTier, TickAttrib,
|
||||
TickAttribBidAsk, TickAttribLast, TickByTickAllLast, TickByTickBidAsk,
|
||||
TickByTickMidPoint, TickData, TradeLogEntry, WshEventData)
|
||||
from .order import (
|
||||
BracketOrder, ExecutionCondition, LimitOrder, MarginCondition, MarketOrder,
|
||||
Order, OrderComboLeg, OrderCondition, OrderState, OrderStatus,
|
||||
PercentChangeCondition, PriceCondition, StopLimitOrder, StopOrder,
|
||||
TimeCondition, Trade, VolumeCondition)
|
||||
from .ticker import Ticker
|
||||
from .version import __version__, __version_info__
|
||||
from .wrapper import RequestError, Wrapper
|
||||
|
||||
__all__ = [
|
||||
'Event', 'util', 'Client',
|
||||
'Bag', 'Bond', 'CFD', 'ComboLeg', 'Commodity', 'ContFuture', 'Contract',
|
||||
'ContractDescription', 'ContractDetails', 'Crypto', 'DeltaNeutralContract',
|
||||
'Forex', 'Future', 'FuturesOption', 'Index', 'MutualFund', 'Option',
|
||||
'ScanData', 'Stock', 'TagValue', 'Warrant', 'FlexError', 'FlexReport',
|
||||
'IB', 'IBC', 'Watchdog',
|
||||
'AccountValue', 'BarData', 'BarDataList', 'CommissionReport',
|
||||
'ConnectionStats', 'DOMLevel', 'DepthMktDataDescription', 'Dividends',
|
||||
'Execution', 'ExecutionFilter', 'FamilyCode', 'Fill', 'FundamentalRatios',
|
||||
'HistogramData', 'HistoricalNews', 'HistoricalTick',
|
||||
'HistoricalTickBidAsk', 'HistoricalTickLast',
|
||||
'HistoricalSchedule', 'HistoricalSession', 'MktDepthData',
|
||||
'NewsArticle', 'NewsBulletin', 'NewsProvider', 'NewsTick', 'OptionChain',
|
||||
'OptionComputation', 'PnL', 'PnLSingle', 'PortfolioItem', 'Position',
|
||||
'PriceIncrement', 'RealTimeBar', 'RealTimeBarList', 'ScanDataList',
|
||||
'ScannerSubscription', 'SmartComponent', 'SoftDollarTier', 'TickAttrib',
|
||||
'TickAttribBidAsk', 'TickAttribLast', 'TickByTickAllLast', 'WshEventData',
|
||||
'TickByTickBidAsk', 'TickByTickMidPoint', 'TickData', 'TradeLogEntry',
|
||||
'BracketOrder', 'ExecutionCondition', 'LimitOrder', 'MarginCondition',
|
||||
'MarketOrder', 'Order', 'OrderComboLeg', 'OrderCondition', 'OrderState',
|
||||
'OrderStatus', 'PercentChangeCondition', 'PriceCondition',
|
||||
'StopLimitOrder', 'StopOrder', 'TimeCondition', 'Trade', 'VolumeCondition',
|
||||
'Ticker', '__version__', '__version_info__', 'RequestError', 'Wrapper'
|
||||
]
|
||||
|
||||
|
||||
# compatibility with old Object
|
||||
for obj in locals().copy().values():
|
||||
if dataclasses.is_dataclass(obj):
|
||||
obj.dict = util.dataclassAsDict
|
||||
obj.tuple = util.dataclassAsTuple
|
||||
obj.update = util.dataclassUpdate
|
||||
obj.nonDefaults = util.dataclassNonDefaults
|
||||
|
||||
del sys
|
||||
del dataclasses
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,996 @@
|
||||
"""Socket client for communicating with Interactive Brokers."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
import math
|
||||
import struct
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Deque, List, Optional
|
||||
|
||||
from eventkit import Event
|
||||
|
||||
from .connection import Connection
|
||||
from .contract import Contract
|
||||
from .decoder import Decoder
|
||||
from .objects import ConnectionStats, WshEventData
|
||||
from .util import UNSET_DOUBLE, UNSET_INTEGER, dataclassAsTuple, getLoop, run
|
||||
|
||||
|
||||
class Client:
|
||||
"""
|
||||
Replacement for ``ibapi.client.EClient`` that uses asyncio.
|
||||
|
||||
The client is fully asynchronous and has its own
|
||||
event-driven networking code that replaces the
|
||||
networking code of the standard EClient.
|
||||
It also replaces the infinite loop of ``EClient.run()``
|
||||
with the asyncio event loop. It can be used as a drop-in
|
||||
replacement for the standard EClient as provided by IBAPI.
|
||||
|
||||
Compared to the standard EClient this client has the following
|
||||
additional features:
|
||||
|
||||
* ``client.connect()`` will block until the client is ready to
|
||||
serve requests; It is not necessary to wait for ``nextValidId``
|
||||
to start requests as the client has already done that.
|
||||
The reqId is directly available with :py:meth:`.getReqId()`.
|
||||
|
||||
* ``client.connectAsync()`` is a coroutine for connecting asynchronously.
|
||||
|
||||
* When blocking, ``client.connect()`` can be made to time out with
|
||||
the timeout parameter (default 2 seconds).
|
||||
|
||||
* Optional ``wrapper.priceSizeTick(reqId, tickType, price, size)`` that
|
||||
combines price and size instead of the two wrapper methods
|
||||
priceTick and sizeTick.
|
||||
|
||||
* Automatic request throttling.
|
||||
|
||||
* Optional ``wrapper.tcpDataArrived()`` method;
|
||||
If the wrapper has this method it is invoked directly after
|
||||
a network packet has arrived.
|
||||
A possible use is to timestamp all data in the packet with
|
||||
the exact same time.
|
||||
|
||||
* Optional ``wrapper.tcpDataProcessed()`` method;
|
||||
If the wrapper has this method it is invoked after the
|
||||
network packet's data has been handled.
|
||||
A possible use is to write or evaluate the newly arrived data in
|
||||
one batch instead of item by item.
|
||||
|
||||
Parameters:
|
||||
MaxRequests (int):
|
||||
Throttle the number of requests to ``MaxRequests`` per
|
||||
``RequestsInterval`` seconds. Set to 0 to disable throttling.
|
||||
RequestsInterval (float):
|
||||
Time interval (in seconds) for request throttling.
|
||||
MinClientVersion (int):
|
||||
Client protocol version.
|
||||
MaxClientVersion (int):
|
||||
Client protocol version.
|
||||
|
||||
Events:
|
||||
* ``apiStart`` ()
|
||||
* ``apiEnd`` ()
|
||||
* ``apiError`` (errorMsg: str)
|
||||
* ``throttleStart`` ()
|
||||
* ``throttleEnd`` ()
|
||||
"""
|
||||
|
||||
events = ('apiStart', 'apiEnd', 'apiError', 'throttleStart', 'throttleEnd')
|
||||
|
||||
MaxRequests = 45
|
||||
RequestsInterval = 1
|
||||
|
||||
MinClientVersion = 157
|
||||
MaxClientVersion = 176
|
||||
|
||||
(DISCONNECTED, CONNECTING, CONNECTED) = range(3)
|
||||
|
||||
def __init__(self, wrapper):
|
||||
self.wrapper = wrapper
|
||||
self.decoder = Decoder(wrapper, 0)
|
||||
self.apiStart = Event('apiStart')
|
||||
self.apiEnd = Event('apiEnd')
|
||||
self.apiError = Event('apiError')
|
||||
self.throttleStart = Event('throttleStart')
|
||||
self.throttleEnd = Event('throttleEnd')
|
||||
self._logger = logging.getLogger('ib_insync.client')
|
||||
|
||||
self.conn = Connection()
|
||||
self.conn.hasData += self._onSocketHasData
|
||||
self.conn.disconnected += self._onSocketDisconnected
|
||||
|
||||
# extra optional wrapper methods
|
||||
self._priceSizeTick = getattr(wrapper, 'priceSizeTick', None)
|
||||
self._tcpDataArrived = getattr(wrapper, 'tcpDataArrived', None)
|
||||
self._tcpDataProcessed = getattr(wrapper, 'tcpDataProcessed', None)
|
||||
|
||||
self.host = ''
|
||||
self.port = -1
|
||||
self.clientId = -1
|
||||
self.optCapab = ''
|
||||
self.connectOptions = b''
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.connState = Client.DISCONNECTED
|
||||
self._apiReady = False
|
||||
self._serverVersion = 0
|
||||
self._data = b''
|
||||
self._hasReqId = False
|
||||
self._reqIdSeq = 0
|
||||
self._accounts = []
|
||||
self._startTime = time.time()
|
||||
self._numBytesRecv = 0
|
||||
self._numMsgRecv = 0
|
||||
self._isThrottling = False
|
||||
self._msgQ: Deque[str] = deque()
|
||||
self._timeQ: Deque[float] = deque()
|
||||
|
||||
def serverVersion(self) -> int:
|
||||
return self._serverVersion
|
||||
|
||||
def run(self):
|
||||
loop = getLoop()
|
||||
loop.run_forever()
|
||||
|
||||
def isConnected(self):
|
||||
return self.connState == Client.CONNECTED
|
||||
|
||||
def isReady(self) -> bool:
|
||||
"""Is the API connection up and running?"""
|
||||
return self._apiReady
|
||||
|
||||
def connectionStats(self) -> ConnectionStats:
|
||||
"""Get statistics about the connection."""
|
||||
if not self.isReady():
|
||||
raise ConnectionError('Not connected')
|
||||
return ConnectionStats(
|
||||
self._startTime,
|
||||
time.time() - self._startTime,
|
||||
self._numBytesRecv, self.conn.numBytesSent,
|
||||
self._numMsgRecv, self.conn.numMsgSent)
|
||||
|
||||
def getReqId(self) -> int:
|
||||
"""Get new request ID."""
|
||||
if not self.isReady():
|
||||
raise ConnectionError('Not connected')
|
||||
newId = self._reqIdSeq
|
||||
self._reqIdSeq += 1
|
||||
return newId
|
||||
|
||||
def updateReqId(self, minReqId):
|
||||
"""Update the next reqId to be at least ``minReqId``."""
|
||||
self._reqIdSeq = max(self._reqIdSeq, minReqId)
|
||||
|
||||
def getAccounts(self) -> List[str]:
|
||||
"""Get the list of account names that are under management."""
|
||||
if not self.isReady():
|
||||
raise ConnectionError('Not connected')
|
||||
return self._accounts
|
||||
|
||||
def setConnectOptions(self, connectOptions: str):
|
||||
"""
|
||||
Set additional connect options.
|
||||
|
||||
Args:
|
||||
connectOptions: Use "+PACEAPI" to use request-pacing built
|
||||
into TWS/gateway 974+ (obsolete).
|
||||
"""
|
||||
self.connectOptions = connectOptions.encode()
|
||||
|
||||
def connect(
|
||||
self, host: str, port: int, clientId: int,
|
||||
timeout: Optional[float] = 2.0):
|
||||
"""
|
||||
Connect to a running TWS or IB gateway application.
|
||||
|
||||
Args:
|
||||
host: Host name or IP address.
|
||||
port: Port number.
|
||||
clientId: ID number to use for this client; must be unique per
|
||||
connection.
|
||||
timeout: If establishing the connection takes longer than
|
||||
``timeout`` seconds then the ``asyncio.TimeoutError`` exception
|
||||
is raised. Set to 0 to disable timeout.
|
||||
"""
|
||||
run(self.connectAsync(host, port, clientId, timeout))
|
||||
|
||||
async def connectAsync(self, host, port, clientId, timeout=2.0):
|
||||
try:
|
||||
self._logger.info(
|
||||
f'Connecting to {host}:{port} with clientId {clientId}...')
|
||||
self.host = host
|
||||
self.port = int(port)
|
||||
self.clientId = int(clientId)
|
||||
self.connState = Client.CONNECTING
|
||||
timeout = timeout or None
|
||||
await asyncio.wait_for(self.conn.connectAsync(host, port), timeout)
|
||||
self._logger.info('Connected')
|
||||
msg = b'API\0' + self._prefix(b'v%d..%d%s' % (
|
||||
self.MinClientVersion, self.MaxClientVersion,
|
||||
b' ' + self.connectOptions if self.connectOptions else b''))
|
||||
self.conn.sendMsg(msg)
|
||||
await asyncio.wait_for(self.apiStart, timeout)
|
||||
self._logger.info('API connection ready')
|
||||
except BaseException as e:
|
||||
self.disconnect()
|
||||
msg = f'API connection failed: {e!r}'
|
||||
self._logger.error(msg)
|
||||
self.apiError.emit(msg)
|
||||
if isinstance(e, ConnectionRefusedError):
|
||||
self._logger.error('Make sure API port on TWS/IBG is open')
|
||||
raise
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnect from IB connection."""
|
||||
self._logger.info('Disconnecting')
|
||||
self.connState = Client.DISCONNECTED
|
||||
self.conn.disconnect()
|
||||
self.reset()
|
||||
|
||||
def send(self, *fields, makeEmpty=True):
|
||||
"""Serialize and send the given fields using the IB socket protocol."""
|
||||
if not self.isConnected():
|
||||
raise ConnectionError('Not connected')
|
||||
|
||||
msg = io.StringIO()
|
||||
empty = (None, UNSET_INTEGER, UNSET_DOUBLE) if makeEmpty else (None,)
|
||||
for field in fields:
|
||||
typ = type(field)
|
||||
if field in empty:
|
||||
s = ''
|
||||
elif typ is str:
|
||||
s = field
|
||||
elif type is int:
|
||||
s = str(field)
|
||||
elif typ is float:
|
||||
s = 'Infinite' if field == math.inf else str(field)
|
||||
elif typ is bool:
|
||||
s = '1' if field else '0'
|
||||
elif typ is list:
|
||||
# list of TagValue
|
||||
s = ''.join(f'{v.tag}={v.value};' for v in field)
|
||||
elif isinstance(field, Contract):
|
||||
c = field
|
||||
s = '\0'.join(str(f) for f in (
|
||||
c.conId, c.symbol, c.secType,
|
||||
c.lastTradeDateOrContractMonth, c.strike,
|
||||
c.right, c.multiplier, c.exchange,
|
||||
c.primaryExchange, c.currency,
|
||||
c.localSymbol, c.tradingClass))
|
||||
else:
|
||||
s = str(field)
|
||||
msg.write(s)
|
||||
msg.write('\0')
|
||||
self.sendMsg(msg.getvalue())
|
||||
|
||||
def sendMsg(self, msg: str):
|
||||
loop = getLoop()
|
||||
t = loop.time()
|
||||
times = self._timeQ
|
||||
msgs = self._msgQ
|
||||
while times and t - times[0] > self.RequestsInterval:
|
||||
times.popleft()
|
||||
if msg:
|
||||
msgs.append(msg)
|
||||
while msgs and (len(times) < self.MaxRequests or not self.MaxRequests):
|
||||
msg = msgs.popleft()
|
||||
self.conn.sendMsg(self._prefix(msg.encode()))
|
||||
times.append(t)
|
||||
if self._logger.isEnabledFor(logging.DEBUG):
|
||||
self._logger.debug('>>> %s', msg[:-1].replace('\0', ','))
|
||||
if msgs:
|
||||
if not self._isThrottling:
|
||||
self._isThrottling = True
|
||||
self.throttleStart.emit()
|
||||
self._logger.debug('Started to throttle requests')
|
||||
loop.call_at(
|
||||
times[0] + self.RequestsInterval,
|
||||
self.sendMsg, None)
|
||||
else:
|
||||
if self._isThrottling:
|
||||
self._isThrottling = False
|
||||
self.throttleEnd.emit()
|
||||
self._logger.debug('Stopped to throttle requests')
|
||||
|
||||
def _prefix(self, msg):
|
||||
# prefix a message with its length
|
||||
return struct.pack('>I', len(msg)) + msg
|
||||
|
||||
def _onSocketHasData(self, data):
|
||||
debug = self._logger.isEnabledFor(logging.DEBUG)
|
||||
if self._tcpDataArrived:
|
||||
self._tcpDataArrived()
|
||||
|
||||
self._data += data
|
||||
self._numBytesRecv += len(data)
|
||||
|
||||
while True:
|
||||
if len(self._data) <= 4:
|
||||
break
|
||||
# 4 byte prefix tells the message length
|
||||
msgEnd = 4 + struct.unpack('>I', self._data[:4])[0]
|
||||
if len(self._data) < msgEnd:
|
||||
# insufficient data for now
|
||||
break
|
||||
msg = self._data[4:msgEnd].decode(errors='backslashreplace')
|
||||
self._data = self._data[msgEnd:]
|
||||
fields = msg.split('\0')
|
||||
fields.pop() # pop off last empty element
|
||||
self._numMsgRecv += 1
|
||||
|
||||
if debug:
|
||||
self._logger.debug('<<< %s', ','.join(fields))
|
||||
|
||||
if not self._serverVersion and len(fields) == 2:
|
||||
# this concludes the handshake
|
||||
version, _connTime = fields
|
||||
self._serverVersion = int(version)
|
||||
if self._serverVersion < self.MinClientVersion:
|
||||
self._onSocketDisconnected(
|
||||
'TWS/gateway version must be >= 972')
|
||||
return
|
||||
self.decoder.serverVersion = self._serverVersion
|
||||
self.connState = Client.CONNECTED
|
||||
self.startApi()
|
||||
self.wrapper.connectAck()
|
||||
self._logger.info(
|
||||
f'Logged on to server version {self._serverVersion}')
|
||||
else:
|
||||
if not self._apiReady:
|
||||
# snoop for nextValidId and managedAccounts response,
|
||||
# when both are in then the client is ready
|
||||
msgId = int(fields[0])
|
||||
if msgId == 9:
|
||||
_, _, validId = fields
|
||||
self.updateReqId(int(validId))
|
||||
self._hasReqId = True
|
||||
elif msgId == 15:
|
||||
_, _, accts = fields
|
||||
self._accounts = [a for a in accts.split(',') if a]
|
||||
if self._hasReqId and self._accounts:
|
||||
self._apiReady = True
|
||||
self.apiStart.emit()
|
||||
|
||||
# decode and handle the message
|
||||
self.decoder.interpret(fields)
|
||||
|
||||
if self._tcpDataProcessed:
|
||||
self._tcpDataProcessed()
|
||||
|
||||
def _onSocketDisconnected(self, msg):
|
||||
wasReady = self.isReady()
|
||||
if not self.isConnected():
|
||||
self._logger.info('Disconnected.')
|
||||
elif not msg:
|
||||
msg = 'Peer closed connection.'
|
||||
if not wasReady:
|
||||
msg += f' clientId {self.clientId} already in use?'
|
||||
if msg:
|
||||
self._logger.error(msg)
|
||||
self.apiError.emit(msg)
|
||||
self.wrapper.setEventsDone()
|
||||
if wasReady:
|
||||
self.wrapper.connectionClosed()
|
||||
self.reset()
|
||||
if wasReady:
|
||||
self.apiEnd.emit()
|
||||
|
||||
# client request methods
|
||||
# the message type id is sent first, often followed by a version number
|
||||
|
||||
def reqMktData(
|
||||
self, reqId, contract, genericTickList, snapshot,
|
||||
regulatorySnapshot, mktDataOptions):
|
||||
fields = [1, 11, reqId, contract]
|
||||
|
||||
if contract.secType == 'BAG':
|
||||
legs = contract.comboLegs or []
|
||||
fields += [len(legs)]
|
||||
for leg in legs:
|
||||
fields += [leg.conId, leg.ratio, leg.action, leg.exchange]
|
||||
|
||||
dnc = contract.deltaNeutralContract
|
||||
if dnc:
|
||||
fields += [True, dnc.conId, dnc.delta, dnc.price]
|
||||
else:
|
||||
fields += [False]
|
||||
|
||||
fields += [
|
||||
genericTickList, snapshot, regulatorySnapshot, mktDataOptions]
|
||||
self.send(*fields)
|
||||
|
||||
def cancelMktData(self, reqId):
|
||||
self.send(2, 2, reqId)
|
||||
|
||||
def placeOrder(self, orderId, contract, order):
|
||||
version = self.serverVersion()
|
||||
fields = [
|
||||
3, orderId,
|
||||
contract,
|
||||
contract.secIdType,
|
||||
contract.secId,
|
||||
order.action,
|
||||
order.totalQuantity,
|
||||
order.orderType,
|
||||
order.lmtPrice,
|
||||
order.auxPrice,
|
||||
order.tif,
|
||||
order.ocaGroup,
|
||||
order.account,
|
||||
order.openClose,
|
||||
order.origin,
|
||||
order.orderRef,
|
||||
order.transmit,
|
||||
order.parentId,
|
||||
order.blockOrder,
|
||||
order.sweepToFill,
|
||||
order.displaySize,
|
||||
order.triggerMethod,
|
||||
order.outsideRth,
|
||||
order.hidden]
|
||||
|
||||
if contract.secType == 'BAG':
|
||||
legs = contract.comboLegs or []
|
||||
fields += [len(legs)]
|
||||
for leg in legs:
|
||||
fields += [
|
||||
leg.conId,
|
||||
leg.ratio,
|
||||
leg.action,
|
||||
leg.exchange,
|
||||
leg.openClose,
|
||||
leg.shortSaleSlot,
|
||||
leg.designatedLocation,
|
||||
leg.exemptCode]
|
||||
|
||||
legs = order.orderComboLegs or []
|
||||
fields += [len(legs)]
|
||||
for leg in legs:
|
||||
fields += [leg.price]
|
||||
|
||||
params = order.smartComboRoutingParams or []
|
||||
fields += [len(params)]
|
||||
for param in params:
|
||||
fields += [param.tag, param.value]
|
||||
|
||||
fields += [
|
||||
'',
|
||||
order.discretionaryAmt,
|
||||
order.goodAfterTime,
|
||||
order.goodTillDate,
|
||||
order.faGroup,
|
||||
order.faMethod,
|
||||
order.faPercentage,
|
||||
order.faProfile,
|
||||
order.modelCode,
|
||||
order.shortSaleSlot,
|
||||
order.designatedLocation,
|
||||
order.exemptCode,
|
||||
order.ocaType,
|
||||
order.rule80A,
|
||||
order.settlingFirm,
|
||||
order.allOrNone,
|
||||
order.minQty,
|
||||
order.percentOffset,
|
||||
order.eTradeOnly,
|
||||
order.firmQuoteOnly,
|
||||
order.nbboPriceCap,
|
||||
order.auctionStrategy,
|
||||
order.startingPrice,
|
||||
order.stockRefPrice,
|
||||
order.delta,
|
||||
order.stockRangeLower,
|
||||
order.stockRangeUpper,
|
||||
order.overridePercentageConstraints,
|
||||
order.volatility,
|
||||
order.volatilityType,
|
||||
order.deltaNeutralOrderType,
|
||||
order.deltaNeutralAuxPrice]
|
||||
|
||||
if order.deltaNeutralOrderType:
|
||||
fields += [
|
||||
order.deltaNeutralConId,
|
||||
order.deltaNeutralSettlingFirm,
|
||||
order.deltaNeutralClearingAccount,
|
||||
order.deltaNeutralClearingIntent,
|
||||
order.deltaNeutralOpenClose,
|
||||
order.deltaNeutralShortSale,
|
||||
order.deltaNeutralShortSaleSlot,
|
||||
order.deltaNeutralDesignatedLocation]
|
||||
|
||||
fields += [
|
||||
order.continuousUpdate,
|
||||
order.referencePriceType,
|
||||
order.trailStopPrice,
|
||||
order.trailingPercent,
|
||||
order.scaleInitLevelSize,
|
||||
order.scaleSubsLevelSize,
|
||||
order.scalePriceIncrement]
|
||||
|
||||
if (0 < order.scalePriceIncrement < UNSET_DOUBLE):
|
||||
fields += [
|
||||
order.scalePriceAdjustValue,
|
||||
order.scalePriceAdjustInterval,
|
||||
order.scaleProfitOffset,
|
||||
order.scaleAutoReset,
|
||||
order.scaleInitPosition,
|
||||
order.scaleInitFillQty,
|
||||
order.scaleRandomPercent]
|
||||
|
||||
fields += [
|
||||
order.scaleTable,
|
||||
order.activeStartTime,
|
||||
order.activeStopTime,
|
||||
order.hedgeType]
|
||||
|
||||
if order.hedgeType:
|
||||
fields += [order.hedgeParam]
|
||||
|
||||
fields += [
|
||||
order.optOutSmartRouting,
|
||||
order.clearingAccount,
|
||||
order.clearingIntent,
|
||||
order.notHeld]
|
||||
|
||||
dnc = contract.deltaNeutralContract
|
||||
if dnc:
|
||||
fields += [True, dnc.conId, dnc.delta, dnc.price]
|
||||
else:
|
||||
fields += [False]
|
||||
|
||||
fields += [order.algoStrategy]
|
||||
if order.algoStrategy:
|
||||
params = order.algoParams or []
|
||||
fields += [len(params)]
|
||||
for param in params:
|
||||
fields += [param.tag, param.value]
|
||||
|
||||
fields += [
|
||||
order.algoId,
|
||||
order.whatIf,
|
||||
order.orderMiscOptions,
|
||||
order.solicited,
|
||||
order.randomizeSize,
|
||||
order.randomizePrice]
|
||||
|
||||
if order.orderType == 'PEG BENCH':
|
||||
fields += [
|
||||
order.referenceContractId,
|
||||
order.isPeggedChangeAmountDecrease,
|
||||
order.peggedChangeAmount,
|
||||
order.referenceChangeAmount,
|
||||
order.referenceExchangeId]
|
||||
|
||||
fields += [len(order.conditions)]
|
||||
if order.conditions:
|
||||
for cond in order.conditions:
|
||||
fields += dataclassAsTuple(cond)
|
||||
fields += [
|
||||
order.conditionsIgnoreRth,
|
||||
order.conditionsCancelOrder]
|
||||
|
||||
fields += [
|
||||
order.adjustedOrderType,
|
||||
order.triggerPrice,
|
||||
order.lmtPriceOffset,
|
||||
order.adjustedStopPrice,
|
||||
order.adjustedStopLimitPrice,
|
||||
order.adjustedTrailingAmount,
|
||||
order.adjustableTrailingUnit,
|
||||
order.extOperator,
|
||||
order.softDollarTier.name,
|
||||
order.softDollarTier.val,
|
||||
order.cashQty,
|
||||
order.mifid2DecisionMaker,
|
||||
order.mifid2DecisionAlgo,
|
||||
order.mifid2ExecutionTrader,
|
||||
order.mifid2ExecutionAlgo,
|
||||
order.dontUseAutoPriceForHedge,
|
||||
order.isOmsContainer,
|
||||
order.discretionaryUpToLimitPrice,
|
||||
order.usePriceMgmtAlgo]
|
||||
|
||||
if version >= 158:
|
||||
fields += [order.duration]
|
||||
if version >= 160:
|
||||
fields += [order.postToAts]
|
||||
if version >= 162:
|
||||
fields += [order.autoCancelParent]
|
||||
if version >= 166:
|
||||
fields += [order.advancedErrorOverride]
|
||||
if version >= 169:
|
||||
fields += [order.manualOrderTime]
|
||||
if version >= 170:
|
||||
if contract.exchange == 'IBKRATS':
|
||||
fields += [order.minTradeQty]
|
||||
if order.orderType == 'PEG BEST':
|
||||
fields += [
|
||||
order.minCompeteSize,
|
||||
order.competeAgainstBestOffset]
|
||||
if order.competeAgainstBestOffset == math.inf:
|
||||
fields += [order.midOffsetAtWhole, order.midOffsetAtHalf]
|
||||
elif order.orderType == 'PEG MID':
|
||||
fields += [order.midOffsetAtWhole, order.midOffsetAtHalf]
|
||||
|
||||
self.send(*fields)
|
||||
|
||||
def cancelOrder(self, orderId, manualCancelOrderTime=''):
|
||||
fields = [4, 1, orderId]
|
||||
if self.serverVersion() >= 169:
|
||||
fields += [manualCancelOrderTime]
|
||||
self.send(*fields)
|
||||
|
||||
def reqOpenOrders(self):
|
||||
self.send(5, 1)
|
||||
|
||||
def reqAccountUpdates(self, subscribe, acctCode):
|
||||
self.send(6, 2, subscribe, acctCode)
|
||||
|
||||
def reqExecutions(self, reqId, execFilter):
|
||||
self.send(
|
||||
7, 3, reqId,
|
||||
execFilter.clientId,
|
||||
execFilter.acctCode,
|
||||
execFilter.time,
|
||||
execFilter.symbol,
|
||||
execFilter.secType,
|
||||
execFilter.exchange,
|
||||
execFilter.side)
|
||||
|
||||
def reqIds(self, numIds):
|
||||
self.send(8, 1, numIds)
|
||||
|
||||
def reqContractDetails(self, reqId, contract):
|
||||
fields = [
|
||||
9, 8, reqId,
|
||||
contract,
|
||||
contract.includeExpired,
|
||||
contract.secIdType,
|
||||
contract.secId]
|
||||
if self.serverVersion() >= 176:
|
||||
fields += [contract.issuerId]
|
||||
self.send(*fields)
|
||||
|
||||
def reqMktDepth(
|
||||
self, reqId, contract, numRows, isSmartDepth, mktDepthOptions):
|
||||
self.send(
|
||||
10, 5, reqId,
|
||||
contract.conId,
|
||||
contract.symbol,
|
||||
contract.secType,
|
||||
contract.lastTradeDateOrContractMonth,
|
||||
contract.strike,
|
||||
contract.right,
|
||||
contract.multiplier,
|
||||
contract.exchange,
|
||||
contract.primaryExchange,
|
||||
contract.currency,
|
||||
contract.localSymbol,
|
||||
contract.tradingClass,
|
||||
numRows,
|
||||
isSmartDepth,
|
||||
mktDepthOptions)
|
||||
|
||||
def cancelMktDepth(self, reqId, isSmartDepth):
|
||||
self.send(11, 1, reqId, isSmartDepth)
|
||||
|
||||
def reqNewsBulletins(self, allMsgs):
|
||||
self.send(12, 1, allMsgs)
|
||||
|
||||
def cancelNewsBulletins(self):
|
||||
self.send(13, 1)
|
||||
|
||||
def setServerLogLevel(self, logLevel):
|
||||
self.send(14, 1, logLevel)
|
||||
|
||||
def reqAutoOpenOrders(self, bAutoBind):
|
||||
self.send(15, 1, bAutoBind)
|
||||
|
||||
def reqAllOpenOrders(self):
|
||||
self.send(16, 1)
|
||||
|
||||
def reqManagedAccts(self):
|
||||
self.send(17, 1)
|
||||
|
||||
def requestFA(self, faData):
|
||||
self.send(18, 1, faData)
|
||||
|
||||
def replaceFA(self, reqId, faData, cxml):
|
||||
self.send(19, 1, faData, cxml, reqId)
|
||||
|
||||
def reqHistoricalData(
|
||||
self, reqId, contract, endDateTime, durationStr, barSizeSetting,
|
||||
whatToShow, useRTH, formatDate, keepUpToDate, chartOptions):
|
||||
fields = [
|
||||
20, reqId, contract, contract.includeExpired,
|
||||
endDateTime, barSizeSetting, durationStr, useRTH,
|
||||
whatToShow, formatDate]
|
||||
|
||||
if contract.secType == 'BAG':
|
||||
legs = contract.comboLegs or []
|
||||
fields += [len(legs)]
|
||||
for leg in legs:
|
||||
fields += [leg.conId, leg.ratio, leg.action, leg.exchange]
|
||||
|
||||
fields += [keepUpToDate, chartOptions]
|
||||
self.send(*fields)
|
||||
|
||||
def exerciseOptions(
|
||||
self, reqId, contract, exerciseAction,
|
||||
exerciseQuantity, account, override):
|
||||
self.send(
|
||||
21, 2, reqId,
|
||||
contract.conId,
|
||||
contract.symbol,
|
||||
contract.secType,
|
||||
contract.lastTradeDateOrContractMonth,
|
||||
contract.strike,
|
||||
contract.right,
|
||||
contract.multiplier,
|
||||
contract.exchange,
|
||||
contract.currency,
|
||||
contract.localSymbol,
|
||||
contract.tradingClass,
|
||||
exerciseAction, exerciseQuantity, account, override)
|
||||
|
||||
def reqScannerSubscription(
|
||||
self, reqId, subscription, scannerSubscriptionOptions,
|
||||
scannerSubscriptionFilterOptions):
|
||||
sub = subscription
|
||||
self.send(
|
||||
22, reqId,
|
||||
sub.numberOfRows,
|
||||
sub.instrument,
|
||||
sub.locationCode,
|
||||
sub.scanCode,
|
||||
sub.abovePrice,
|
||||
sub.belowPrice,
|
||||
sub.aboveVolume,
|
||||
sub.marketCapAbove,
|
||||
sub.marketCapBelow,
|
||||
sub.moodyRatingAbove,
|
||||
sub.moodyRatingBelow,
|
||||
sub.spRatingAbove,
|
||||
sub.spRatingBelow,
|
||||
sub.maturityDateAbove,
|
||||
sub.maturityDateBelow,
|
||||
sub.couponRateAbove,
|
||||
sub.couponRateBelow,
|
||||
sub.excludeConvertible,
|
||||
sub.averageOptionVolumeAbove,
|
||||
sub.scannerSettingPairs,
|
||||
sub.stockTypeFilter,
|
||||
scannerSubscriptionFilterOptions,
|
||||
scannerSubscriptionOptions)
|
||||
|
||||
def cancelScannerSubscription(self, reqId):
|
||||
self.send(23, 1, reqId)
|
||||
|
||||
def reqScannerParameters(self):
|
||||
self.send(24, 1)
|
||||
|
||||
def cancelHistoricalData(self, reqId):
|
||||
self.send(25, 1, reqId)
|
||||
|
||||
def reqCurrentTime(self):
|
||||
self.send(49, 1)
|
||||
|
||||
def reqRealTimeBars(
|
||||
self, reqId, contract, barSize, whatToShow,
|
||||
useRTH, realTimeBarsOptions):
|
||||
self.send(
|
||||
50, 3, reqId, contract, barSize, whatToShow,
|
||||
useRTH, realTimeBarsOptions)
|
||||
|
||||
def cancelRealTimeBars(self, reqId):
|
||||
self.send(51, 1, reqId)
|
||||
|
||||
def reqFundamentalData(
|
||||
self, reqId, contract, reportType, fundamentalDataOptions):
|
||||
options = fundamentalDataOptions or []
|
||||
self.send(
|
||||
52, 2, reqId,
|
||||
contract.conId,
|
||||
contract.symbol,
|
||||
contract.secType,
|
||||
contract.exchange,
|
||||
contract.primaryExchange,
|
||||
contract.currency,
|
||||
contract.localSymbol,
|
||||
reportType, len(options), options)
|
||||
|
||||
def cancelFundamentalData(self, reqId):
|
||||
self.send(53, 1, reqId)
|
||||
|
||||
def calculateImpliedVolatility(
|
||||
self, reqId, contract, optionPrice, underPrice, implVolOptions):
|
||||
self.send(
|
||||
54, 3, reqId, contract, optionPrice, underPrice,
|
||||
len(implVolOptions), implVolOptions)
|
||||
|
||||
def calculateOptionPrice(
|
||||
self, reqId, contract, volatility, underPrice, optPrcOptions):
|
||||
self.send(
|
||||
55, 3, reqId, contract, volatility, underPrice,
|
||||
len(optPrcOptions), optPrcOptions)
|
||||
|
||||
def cancelCalculateImpliedVolatility(self, reqId):
|
||||
self.send(56, 1, reqId)
|
||||
|
||||
def cancelCalculateOptionPrice(self, reqId):
|
||||
self.send(57, 1, reqId)
|
||||
|
||||
def reqGlobalCancel(self):
|
||||
self.send(58, 1)
|
||||
|
||||
def reqMarketDataType(self, marketDataType):
|
||||
self.send(59, 1, marketDataType)
|
||||
|
||||
def reqPositions(self):
|
||||
self.send(61, 1)
|
||||
|
||||
def reqAccountSummary(self, reqId, groupName, tags):
|
||||
self.send(62, 1, reqId, groupName, tags)
|
||||
|
||||
def cancelAccountSummary(self, reqId):
|
||||
self.send(63, 1, reqId)
|
||||
|
||||
def cancelPositions(self):
|
||||
self.send(64, 1)
|
||||
|
||||
def verifyRequest(self, apiName, apiVersion):
|
||||
self.send(65, 1, apiName, apiVersion)
|
||||
|
||||
def verifyMessage(self, apiData):
|
||||
self.send(66, 1, apiData)
|
||||
|
||||
def queryDisplayGroups(self, reqId):
|
||||
self.send(67, 1, reqId)
|
||||
|
||||
def subscribeToGroupEvents(self, reqId, groupId):
|
||||
self.send(68, 1, reqId, groupId)
|
||||
|
||||
def updateDisplayGroup(self, reqId, contractInfo):
|
||||
self.send(69, 1, reqId, contractInfo)
|
||||
|
||||
def unsubscribeFromGroupEvents(self, reqId):
|
||||
self.send(70, 1, reqId)
|
||||
|
||||
def startApi(self):
|
||||
self.send(71, 2, self.clientId, self.optCapab)
|
||||
|
||||
def verifyAndAuthRequest(self, apiName, apiVersion, opaqueIsvKey):
|
||||
self.send(72, 1, apiName, apiVersion, opaqueIsvKey)
|
||||
|
||||
def verifyAndAuthMessage(self, apiData, xyzResponse):
|
||||
self.send(73, 1, apiData, xyzResponse)
|
||||
|
||||
def reqPositionsMulti(self, reqId, account, modelCode):
|
||||
self.send(74, 1, reqId, account, modelCode)
|
||||
|
||||
def cancelPositionsMulti(self, reqId):
|
||||
self.send(75, 1, reqId)
|
||||
|
||||
def reqAccountUpdatesMulti(self, reqId, account, modelCode, ledgerAndNLV):
|
||||
self.send(76, 1, reqId, account, modelCode, ledgerAndNLV)
|
||||
|
||||
def cancelAccountUpdatesMulti(self, reqId):
|
||||
self.send(77, 1, reqId)
|
||||
|
||||
def reqSecDefOptParams(
|
||||
self, reqId, underlyingSymbol, futFopExchange,
|
||||
underlyingSecType, underlyingConId):
|
||||
self.send(
|
||||
78, reqId, underlyingSymbol, futFopExchange,
|
||||
underlyingSecType, underlyingConId)
|
||||
|
||||
def reqSoftDollarTiers(self, reqId):
|
||||
self.send(79, reqId)
|
||||
|
||||
def reqFamilyCodes(self):
|
||||
self.send(80)
|
||||
|
||||
def reqMatchingSymbols(self, reqId, pattern):
|
||||
self.send(81, reqId, pattern)
|
||||
|
||||
def reqMktDepthExchanges(self):
|
||||
self.send(82)
|
||||
|
||||
def reqSmartComponents(self, reqId, bboExchange):
|
||||
self.send(83, reqId, bboExchange)
|
||||
|
||||
def reqNewsArticle(
|
||||
self, reqId, providerCode, articleId, newsArticleOptions):
|
||||
self.send(84, reqId, providerCode, articleId, newsArticleOptions)
|
||||
|
||||
def reqNewsProviders(self):
|
||||
self.send(85)
|
||||
|
||||
def reqHistoricalNews(
|
||||
self, reqId, conId, providerCodes, startDateTime, endDateTime,
|
||||
totalResults, historicalNewsOptions):
|
||||
self.send(
|
||||
86, reqId, conId, providerCodes, startDateTime, endDateTime,
|
||||
totalResults, historicalNewsOptions)
|
||||
|
||||
def reqHeadTimeStamp(
|
||||
self, reqId, contract, whatToShow, useRTH, formatDate):
|
||||
self.send(
|
||||
87, reqId, contract, contract.includeExpired,
|
||||
useRTH, whatToShow, formatDate)
|
||||
|
||||
def reqHistogramData(self, tickerId, contract, useRTH, timePeriod):
|
||||
self.send(
|
||||
88, tickerId, contract, contract.includeExpired,
|
||||
useRTH, timePeriod)
|
||||
|
||||
def cancelHistogramData(self, tickerId):
|
||||
self.send(89, tickerId)
|
||||
|
||||
def cancelHeadTimeStamp(self, reqId):
|
||||
self.send(90, reqId)
|
||||
|
||||
def reqMarketRule(self, marketRuleId):
|
||||
self.send(91, marketRuleId)
|
||||
|
||||
def reqPnL(self, reqId, account, modelCode):
|
||||
self.send(92, reqId, account, modelCode)
|
||||
|
||||
def cancelPnL(self, reqId):
|
||||
self.send(93, reqId)
|
||||
|
||||
def reqPnLSingle(self, reqId, account, modelCode, conid):
|
||||
self.send(94, reqId, account, modelCode, conid)
|
||||
|
||||
def cancelPnLSingle(self, reqId):
|
||||
self.send(95, reqId)
|
||||
|
||||
def reqHistoricalTicks(
|
||||
self, reqId, contract, startDateTime, endDateTime,
|
||||
numberOfTicks, whatToShow, useRth, ignoreSize, miscOptions):
|
||||
self.send(
|
||||
96, reqId, contract, contract.includeExpired,
|
||||
startDateTime, endDateTime, numberOfTicks, whatToShow,
|
||||
useRth, ignoreSize, miscOptions)
|
||||
|
||||
def reqTickByTickData(
|
||||
self, reqId, contract, tickType, numberOfTicks, ignoreSize):
|
||||
self.send(97, reqId, contract, tickType, numberOfTicks, ignoreSize)
|
||||
|
||||
def cancelTickByTickData(self, reqId):
|
||||
self.send(98, reqId)
|
||||
|
||||
def reqCompletedOrders(self, apiOnly):
|
||||
self.send(99, apiOnly)
|
||||
|
||||
def reqWshMetaData(self, reqId):
|
||||
self.send(100, reqId)
|
||||
|
||||
def cancelWshMetaData(self, reqId):
|
||||
self.send(101, reqId)
|
||||
|
||||
def reqWshEventData(self, reqId, data: WshEventData):
|
||||
fields = [102, reqId, data.conId]
|
||||
if self.serverVersion() >= 171:
|
||||
fields += [
|
||||
data.filter,
|
||||
data.fillWatchlist,
|
||||
data.fillPortfolio,
|
||||
data.fillCompetitors]
|
||||
if self.serverVersion() >= 173:
|
||||
fields += [
|
||||
data.startDate,
|
||||
data.endDate,
|
||||
data.totalLimit]
|
||||
self.send(*fields, makeEmpty=False)
|
||||
|
||||
def cancelWshEventData(self, reqId):
|
||||
self.send(103, reqId)
|
||||
|
||||
def reqUserInfo(self, reqId):
|
||||
self.send(104, reqId)
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Event-driven socket connection."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from eventkit import Event
|
||||
|
||||
from ib_insync.util import getLoop
|
||||
|
||||
|
||||
class Connection(asyncio.Protocol):
|
||||
"""
|
||||
Event-driven socket connection.
|
||||
|
||||
Events:
|
||||
* ``hasData`` (data: bytes):
|
||||
Emits the received socket data.
|
||||
* ``disconnected`` (msg: str):
|
||||
Is emitted on socket disconnect, with an error message in case
|
||||
of error, or an empty string in case of a normal disconnect.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.hasData = Event('hasData')
|
||||
self.disconnected = Event('disconnected')
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.transport = None
|
||||
self.numBytesSent = 0
|
||||
self.numMsgSent = 0
|
||||
|
||||
async def connectAsync(self, host, port):
|
||||
if self.transport:
|
||||
# wait until a previous connection is finished closing
|
||||
self.disconnect()
|
||||
await self.disconnected
|
||||
self.reset()
|
||||
loop = getLoop()
|
||||
self.transport, _ = await loop.create_connection(
|
||||
lambda: self, host, port)
|
||||
|
||||
def disconnect(self):
|
||||
if self.transport:
|
||||
self.transport.write_eof()
|
||||
self.transport.close()
|
||||
|
||||
def isConnected(self):
|
||||
return self.transport is not None
|
||||
|
||||
def sendMsg(self, msg):
|
||||
if self.transport:
|
||||
self.transport.write(msg)
|
||||
self.numBytesSent += len(msg)
|
||||
self.numMsgSent += 1
|
||||
|
||||
def connection_lost(self, exc):
|
||||
self.transport = None
|
||||
msg = str(exc) if exc else ''
|
||||
self.disconnected.emit(msg)
|
||||
|
||||
def data_received(self, data):
|
||||
self.hasData.emit(data)
|
||||
@@ -0,0 +1,551 @@
|
||||
"""Financial instrument types used by Interactive Brokers."""
|
||||
|
||||
import datetime as dt
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, NamedTuple, Optional
|
||||
|
||||
import ib_insync.util as util
|
||||
|
||||
|
||||
@dataclass
|
||||
class Contract:
|
||||
"""
|
||||
``Contract(**kwargs)`` can create any contract using keyword
|
||||
arguments. To simplify working with contracts, there are also more
|
||||
specialized contracts that take optional positional arguments.
|
||||
Some examples::
|
||||
|
||||
Contract(conId=270639)
|
||||
Stock('AMD', 'SMART', 'USD')
|
||||
Stock('INTC', 'SMART', 'USD', primaryExchange='NASDAQ')
|
||||
Forex('EURUSD')
|
||||
CFD('IBUS30')
|
||||
Future('ES', '20180921', 'GLOBEX')
|
||||
Option('SPY', '20170721', 240, 'C', 'SMART')
|
||||
Bond(secIdType='ISIN', secId='US03076KAA60')
|
||||
Crypto('BTC', 'PAXOS', 'USD')
|
||||
|
||||
Args:
|
||||
conId (int): The unique IB contract identifier.
|
||||
symbol (str): The contract (or its underlying) symbol.
|
||||
secType (str): The security type:
|
||||
|
||||
* 'STK' = Stock (or ETF)
|
||||
* 'OPT' = Option
|
||||
* 'FUT' = Future
|
||||
* 'IND' = Index
|
||||
* 'FOP' = Futures option
|
||||
* 'CASH' = Forex pair
|
||||
* 'CFD' = CFD
|
||||
* 'BAG' = Combo
|
||||
* 'WAR' = Warrant
|
||||
* 'BOND' = Bond
|
||||
* 'CMDTY' = Commodity
|
||||
* 'NEWS' = News
|
||||
* 'FUND' = Mutual fund
|
||||
* 'CRYPTO' = Crypto currency
|
||||
* 'EVENT' = Bet on an event
|
||||
lastTradeDateOrContractMonth (str): The contract's last trading
|
||||
day or contract month (for Options and Futures).
|
||||
Strings with format YYYYMM will be interpreted as the
|
||||
Contract Month whereas YYYYMMDD will be interpreted as
|
||||
Last Trading Day.
|
||||
strike (float): The option's strike price.
|
||||
right (str): Put or Call.
|
||||
Valid values are 'P', 'PUT', 'C', 'CALL', or '' for non-options.
|
||||
multiplier (str): The instrument's multiplier (i.e. options, futures).
|
||||
exchange (str): The destination exchange.
|
||||
currency (str): The underlying's currency.
|
||||
localSymbol (str): The contract's symbol within its primary exchange.
|
||||
For options, this will be the OCC symbol.
|
||||
primaryExchange (str): The contract's primary exchange.
|
||||
For smart routed contracts, used to define contract in case
|
||||
of ambiguity. Should be defined as native exchange of contract,
|
||||
e.g. ISLAND for MSFT. For exchanges which contain a period in name,
|
||||
will only be part of exchange name prior to period, i.e. ENEXT
|
||||
for ENEXT.BE.
|
||||
tradingClass (str): The trading class name for this contract.
|
||||
Available in TWS contract description window as well.
|
||||
For example, GBL Dec '13 future's trading class is "FGBL".
|
||||
includeExpired (bool): If set to true, contract details requests
|
||||
and historical data queries can be performed pertaining to
|
||||
expired futures contracts. Expired options or other instrument
|
||||
types are not available.
|
||||
secIdType (str): Security identifier type. Examples for Apple:
|
||||
|
||||
* secIdType='ISIN', secId='US0378331005'
|
||||
* secIdType='CUSIP', secId='037833100'
|
||||
secId (str): Security identifier.
|
||||
comboLegsDescription (str): Description of the combo legs.
|
||||
comboLegs (List[ComboLeg]): The legs of a combined contract definition.
|
||||
deltaNeutralContract (DeltaNeutralContract): Delta and underlying
|
||||
price for Delta-Neutral combo orders.
|
||||
"""
|
||||
|
||||
secType: str = ''
|
||||
conId: int = 0
|
||||
symbol: str = ''
|
||||
lastTradeDateOrContractMonth: str = ''
|
||||
strike: float = 0.0
|
||||
right: str = ''
|
||||
multiplier: str = ''
|
||||
exchange: str = ''
|
||||
primaryExchange: str = ''
|
||||
currency: str = ''
|
||||
localSymbol: str = ''
|
||||
tradingClass: str = ''
|
||||
includeExpired: bool = False
|
||||
secIdType: str = ''
|
||||
secId: str = ''
|
||||
description: str = ''
|
||||
issuerId: str = ''
|
||||
comboLegsDescrip: str = ''
|
||||
comboLegs: List['ComboLeg'] = field(default_factory=list)
|
||||
deltaNeutralContract: Optional['DeltaNeutralContract'] = None
|
||||
|
||||
@staticmethod
|
||||
def create(**kwargs) -> 'Contract':
|
||||
"""
|
||||
Create and a return a specialized contract based on the given secType,
|
||||
or a general Contract if secType is not given.
|
||||
"""
|
||||
secType = kwargs.get('secType', '')
|
||||
cls = {
|
||||
'': Contract,
|
||||
'STK': Stock,
|
||||
'OPT': Option,
|
||||
'FUT': Future,
|
||||
'CONTFUT': ContFuture,
|
||||
'CASH': Forex,
|
||||
'IND': Index,
|
||||
'CFD': CFD,
|
||||
'BOND': Bond,
|
||||
'CMDTY': Commodity,
|
||||
'FOP': FuturesOption,
|
||||
'FUND': MutualFund,
|
||||
'WAR': Warrant,
|
||||
'IOPT': Warrant,
|
||||
'BAG': Bag,
|
||||
'CRYPTO': Crypto,
|
||||
'NEWS': Contract,
|
||||
'EVENT': Contract,
|
||||
}.get(secType, Contract)
|
||||
if cls is not Contract:
|
||||
kwargs.pop('secType', '')
|
||||
return cls(**kwargs)
|
||||
|
||||
def isHashable(self) -> bool:
|
||||
"""
|
||||
See if this contract can be hashed by conId.
|
||||
|
||||
Note: Bag contracts always get conId=28812380, so they're not hashable.
|
||||
"""
|
||||
return bool(
|
||||
self.conId and self.conId != 28812380
|
||||
and self.secType != 'BAG')
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
isinstance(other, Contract)
|
||||
and (
|
||||
self.conId and self.conId == other.conId
|
||||
or util.dataclassAsDict(self) == util.dataclassAsDict(other)))
|
||||
|
||||
def __hash__(self):
|
||||
if not self.isHashable():
|
||||
raise ValueError(f'Contract {self} can\'t be hashed')
|
||||
if self.secType == 'CONTFUT':
|
||||
# CONTFUT gets the same conId as the front contract, invert it here
|
||||
h = -self.conId
|
||||
else:
|
||||
h = self.conId
|
||||
return h
|
||||
|
||||
def __repr__(self):
|
||||
attrs = util.dataclassNonDefaults(self)
|
||||
if self.__class__ is not Contract:
|
||||
attrs.pop('secType', '')
|
||||
clsName = self.__class__.__qualname__
|
||||
kwargs = ', '.join(f'{k}={v!r}' for k, v in attrs.items())
|
||||
return f'{clsName}({kwargs})'
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
|
||||
class Stock(Contract):
|
||||
|
||||
def __init__(
|
||||
self, symbol: str = '', exchange: str = '', currency: str = '',
|
||||
**kwargs):
|
||||
"""
|
||||
Stock contract.
|
||||
|
||||
Args:
|
||||
symbol: Symbol name.
|
||||
exchange: Destination exchange.
|
||||
currency: Underlying currency.
|
||||
"""
|
||||
Contract.__init__(
|
||||
self, secType='STK', symbol=symbol,
|
||||
exchange=exchange, currency=currency, **kwargs)
|
||||
|
||||
|
||||
class Option(Contract):
|
||||
|
||||
def __init__(
|
||||
self, symbol: str = '', lastTradeDateOrContractMonth: str = '',
|
||||
strike: float = 0.0, right: str = '', exchange: str = '',
|
||||
multiplier: str = '', currency: str = '', **kwargs):
|
||||
"""
|
||||
Option contract.
|
||||
|
||||
Args:
|
||||
symbol: Symbol name.
|
||||
lastTradeDateOrContractMonth: The option's last trading day
|
||||
or contract month.
|
||||
|
||||
* YYYYMM format: To specify last month
|
||||
* YYYYMMDD format: To specify last trading day
|
||||
strike: The option's strike price.
|
||||
right: Put or call option.
|
||||
Valid values are 'P', 'PUT', 'C' or 'CALL'.
|
||||
exchange: Destination exchange.
|
||||
multiplier: The contract multiplier.
|
||||
currency: Underlying currency.
|
||||
"""
|
||||
Contract.__init__(
|
||||
self, 'OPT', symbol=symbol,
|
||||
lastTradeDateOrContractMonth=lastTradeDateOrContractMonth,
|
||||
strike=strike, right=right, exchange=exchange,
|
||||
multiplier=multiplier, currency=currency, **kwargs)
|
||||
|
||||
|
||||
class Future(Contract):
|
||||
|
||||
def __init__(
|
||||
self, symbol: str = '', lastTradeDateOrContractMonth: str = '',
|
||||
exchange: str = '', localSymbol: str = '', multiplier: str = '',
|
||||
currency: str = '', **kwargs):
|
||||
"""
|
||||
Future contract.
|
||||
|
||||
Args:
|
||||
symbol: Symbol name.
|
||||
lastTradeDateOrContractMonth: The option's last trading day
|
||||
or contract month.
|
||||
|
||||
* YYYYMM format: To specify last month
|
||||
* YYYYMMDD format: To specify last trading day
|
||||
exchange: Destination exchange.
|
||||
localSymbol: The contract's symbol within its primary exchange.
|
||||
multiplier: The contract multiplier.
|
||||
currency: Underlying currency.
|
||||
"""
|
||||
Contract.__init__(
|
||||
self, 'FUT', symbol=symbol,
|
||||
lastTradeDateOrContractMonth=lastTradeDateOrContractMonth,
|
||||
exchange=exchange, localSymbol=localSymbol,
|
||||
multiplier=multiplier, currency=currency, **kwargs)
|
||||
|
||||
|
||||
class ContFuture(Contract):
|
||||
|
||||
def __init__(
|
||||
self, symbol: str = '', exchange: str = '', localSymbol: str = '',
|
||||
multiplier: str = '', currency: str = '', **kwargs):
|
||||
"""
|
||||
Continuous future contract.
|
||||
|
||||
Args:
|
||||
symbol: Symbol name.
|
||||
exchange: Destination exchange.
|
||||
localSymbol: The contract's symbol within its primary exchange.
|
||||
multiplier: The contract multiplier.
|
||||
currency: Underlying currency.
|
||||
"""
|
||||
Contract.__init__(
|
||||
self, 'CONTFUT', symbol=symbol,
|
||||
exchange=exchange, localSymbol=localSymbol,
|
||||
multiplier=multiplier, currency=currency, **kwargs)
|
||||
|
||||
|
||||
class Forex(Contract):
|
||||
|
||||
def __init__(
|
||||
self, pair: str = '', exchange: str = 'IDEALPRO',
|
||||
symbol: str = '', currency: str = '', **kwargs):
|
||||
"""
|
||||
Foreign exchange currency pair.
|
||||
|
||||
Args:
|
||||
pair: Shortcut for specifying symbol and currency, like 'EURUSD'.
|
||||
exchange: Destination exchange.
|
||||
symbol: Base currency.
|
||||
currency: Quote currency.
|
||||
"""
|
||||
if pair:
|
||||
assert len(pair) == 6
|
||||
symbol = symbol or pair[:3]
|
||||
currency = currency or pair[3:]
|
||||
Contract.__init__(
|
||||
self, 'CASH', symbol=symbol,
|
||||
exchange=exchange, currency=currency, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
attrs = util.dataclassNonDefaults(self)
|
||||
attrs.pop('secType')
|
||||
s = 'Forex('
|
||||
if 'symbol' in attrs and 'currency' in attrs:
|
||||
pair = attrs.pop('symbol')
|
||||
pair += attrs.pop('currency')
|
||||
s += "'" + pair + "'" + (", " if attrs else "")
|
||||
s += ', '.join(f'{k}={v!r}' for k, v in attrs.items())
|
||||
s += ')'
|
||||
return s
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
def pair(self) -> str:
|
||||
"""Short name of pair."""
|
||||
return self.symbol + self.currency
|
||||
|
||||
|
||||
class Index(Contract):
|
||||
|
||||
def __init__(
|
||||
self, symbol: str = '', exchange: str = '', currency: str = '',
|
||||
**kwargs):
|
||||
"""
|
||||
Index.
|
||||
|
||||
Args:
|
||||
symbol: Symbol name.
|
||||
exchange: Destination exchange.
|
||||
currency: Underlying currency.
|
||||
"""
|
||||
Contract.__init__(
|
||||
self, 'IND', symbol=symbol,
|
||||
exchange=exchange, currency=currency, **kwargs)
|
||||
|
||||
|
||||
class CFD(Contract):
|
||||
|
||||
def __init__(
|
||||
self, symbol: str = '', exchange: str = '', currency: str = '',
|
||||
**kwargs):
|
||||
"""
|
||||
Contract For Difference.
|
||||
|
||||
Args:
|
||||
symbol: Symbol name.
|
||||
exchange: Destination exchange.
|
||||
currency: Underlying currency.
|
||||
"""
|
||||
Contract.__init__(
|
||||
self, 'CFD', symbol=symbol,
|
||||
exchange=exchange, currency=currency, **kwargs)
|
||||
|
||||
|
||||
class Commodity(Contract):
|
||||
|
||||
def __init__(
|
||||
self, symbol: str = '', exchange: str = '', currency: str = '',
|
||||
**kwargs):
|
||||
"""
|
||||
Commodity.
|
||||
|
||||
Args:
|
||||
symbol: Symbol name.
|
||||
exchange: Destination exchange.
|
||||
currency: Underlying currency.
|
||||
"""
|
||||
Contract.__init__(
|
||||
self, 'CMDTY', symbol=symbol,
|
||||
exchange=exchange, currency=currency, **kwargs)
|
||||
|
||||
|
||||
class Bond(Contract):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Bond."""
|
||||
Contract.__init__(self, 'BOND', **kwargs)
|
||||
|
||||
|
||||
class FuturesOption(Contract):
|
||||
|
||||
def __init__(
|
||||
self, symbol: str = '', lastTradeDateOrContractMonth: str = '',
|
||||
strike: float = 0.0, right: str = '', exchange: str = '',
|
||||
multiplier: str = '', currency: str = '', **kwargs):
|
||||
"""
|
||||
Option on a futures contract.
|
||||
|
||||
Args:
|
||||
symbol: Symbol name.
|
||||
lastTradeDateOrContractMonth: The option's last trading day
|
||||
or contract month.
|
||||
|
||||
* YYYYMM format: To specify last month
|
||||
* YYYYMMDD format: To specify last trading day
|
||||
strike: The option's strike price.
|
||||
right: Put or call option.
|
||||
Valid values are 'P', 'PUT', 'C' or 'CALL'.
|
||||
exchange: Destination exchange.
|
||||
multiplier: The contract multiplier.
|
||||
currency: Underlying currency.
|
||||
"""
|
||||
Contract.__init__(
|
||||
self, 'FOP', symbol=symbol,
|
||||
lastTradeDateOrContractMonth=lastTradeDateOrContractMonth,
|
||||
strike=strike, right=right, exchange=exchange,
|
||||
multiplier=multiplier, currency=currency, **kwargs)
|
||||
|
||||
|
||||
class MutualFund(Contract):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Mutual fund."""
|
||||
Contract.__init__(self, 'FUND', **kwargs)
|
||||
|
||||
|
||||
class Warrant(Contract):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Warrant option."""
|
||||
Contract.__init__(self, 'WAR', **kwargs)
|
||||
|
||||
|
||||
class Bag(Contract):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Bag contract."""
|
||||
Contract.__init__(self, 'BAG', **kwargs)
|
||||
|
||||
|
||||
class Crypto(Contract):
|
||||
|
||||
def __init__(
|
||||
self, symbol: str = '', exchange: str = '', currency: str = '',
|
||||
**kwargs):
|
||||
"""
|
||||
Crypto currency contract.
|
||||
|
||||
Args:
|
||||
symbol: Symbol name.
|
||||
exchange: Destination exchange.
|
||||
currency: Underlying currency.
|
||||
"""
|
||||
Contract.__init__(
|
||||
self, secType='CRYPTO', symbol=symbol,
|
||||
exchange=exchange, currency=currency, **kwargs)
|
||||
|
||||
|
||||
class TagValue(NamedTuple):
|
||||
tag: str
|
||||
value: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComboLeg:
|
||||
conId: int = 0
|
||||
ratio: int = 0
|
||||
action: str = ''
|
||||
exchange: str = ''
|
||||
openClose: int = 0
|
||||
shortSaleSlot: int = 0
|
||||
designatedLocation: str = ''
|
||||
exemptCode: int = -1
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeltaNeutralContract:
|
||||
conId: int = 0
|
||||
delta: float = 0.0
|
||||
price: float = 0.0
|
||||
|
||||
|
||||
class TradingSession(NamedTuple):
|
||||
start: dt.datetime
|
||||
end: dt.datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContractDetails:
|
||||
contract: Optional[Contract] = None
|
||||
marketName: str = ''
|
||||
minTick: float = 0.0
|
||||
orderTypes: str = ''
|
||||
validExchanges: str = ''
|
||||
priceMagnifier: int = 0
|
||||
underConId: int = 0
|
||||
longName: str = ''
|
||||
contractMonth: str = ''
|
||||
industry: str = ''
|
||||
category: str = ''
|
||||
subcategory: str = ''
|
||||
timeZoneId: str = ''
|
||||
tradingHours: str = ''
|
||||
liquidHours: str = ''
|
||||
evRule: str = ''
|
||||
evMultiplier: int = 0
|
||||
mdSizeMultiplier: int = 1 # obsolete
|
||||
aggGroup: int = 0
|
||||
underSymbol: str = ''
|
||||
underSecType: str = ''
|
||||
marketRuleIds: str = ''
|
||||
secIdList: List[TagValue] = field(default_factory=list)
|
||||
realExpirationDate: str = ''
|
||||
lastTradeTime: str = ''
|
||||
stockType: str = ''
|
||||
minSize: float = 0.0
|
||||
sizeIncrement: float = 0.0
|
||||
suggestedSizeIncrement: float = 0.0
|
||||
# minCashQtySize: float = 0.0
|
||||
cusip: str = ''
|
||||
ratings: str = ''
|
||||
descAppend: str = ''
|
||||
bondType: str = ''
|
||||
couponType: str = ''
|
||||
callable: bool = False
|
||||
putable: bool = False
|
||||
coupon: float = 0
|
||||
convertible: bool = False
|
||||
maturity: str = ''
|
||||
issueDate: str = ''
|
||||
nextOptionDate: str = ''
|
||||
nextOptionType: str = ''
|
||||
nextOptionPartial: bool = False
|
||||
notes: str = ''
|
||||
|
||||
def tradingSessions(self) -> List[TradingSession]:
|
||||
return self._parseSessions(self.tradingHours)
|
||||
|
||||
def liquidSessions(self) -> List[TradingSession]:
|
||||
return self._parseSessions(self.liquidHours)
|
||||
|
||||
def _parseSessions(self, s: str) -> List[TradingSession]:
|
||||
tz = util.ZoneInfo(self.timeZoneId)
|
||||
sessions = []
|
||||
for sess in s.split(';'):
|
||||
if not sess or 'CLOSED' in sess:
|
||||
continue
|
||||
sessions.append(TradingSession(*[
|
||||
dt.datetime.strptime(t, '%Y%m%d:%H%M').replace(tzinfo=tz)
|
||||
for t in sess.split('-')]))
|
||||
return sessions
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContractDescription:
|
||||
contract: Optional[Contract] = None
|
||||
derivativeSecTypes: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScanData:
|
||||
rank: int
|
||||
contractDetails: ContractDetails
|
||||
distance: str
|
||||
benchmark: str
|
||||
projection: str
|
||||
legsStr: str
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,127 @@
|
||||
"""Access to account statement webservice."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import xml.etree.ElementTree as et
|
||||
from contextlib import suppress
|
||||
from urllib.request import urlopen
|
||||
|
||||
from ib_insync import util
|
||||
from ib_insync.objects import DynamicObject
|
||||
|
||||
_logger = logging.getLogger('ib_insync.flexreport')
|
||||
|
||||
|
||||
class FlexError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FlexReport:
|
||||
"""
|
||||
To obtain a token:
|
||||
|
||||
* Login to web portal
|
||||
* Go to Settings
|
||||
* Click on "Configure Flex Web Service"
|
||||
* Generate token
|
||||
"""
|
||||
|
||||
data: bytes
|
||||
root: et.Element
|
||||
|
||||
def __init__(self, token=None, queryId=None, path=None):
|
||||
"""
|
||||
Download a report by giving a valid ``token`` and ``queryId``,
|
||||
or load from file by giving a valid ``path``.
|
||||
"""
|
||||
if token and queryId:
|
||||
self.download(token, queryId)
|
||||
elif path:
|
||||
self.load(path)
|
||||
|
||||
def topics(self):
|
||||
"""Get the set of topics that can be extracted from this report."""
|
||||
return set(node.tag for node in self.root.iter() if node.attrib)
|
||||
|
||||
def extract(self, topic: str, parseNumbers=True) -> list:
|
||||
"""
|
||||
Extract items of given topic and return as list of objects.
|
||||
|
||||
The topic is a string like TradeConfirm, ChangeInDividendAccrual,
|
||||
Order, etc.
|
||||
"""
|
||||
cls = type(topic, (DynamicObject,), {})
|
||||
results = [cls(**node.attrib) for node in self.root.iter(topic)]
|
||||
if parseNumbers:
|
||||
for obj in results:
|
||||
d = obj.__dict__
|
||||
for k, v in d.items():
|
||||
with suppress(ValueError):
|
||||
d[k] = float(v)
|
||||
d[k] = int(v)
|
||||
return results
|
||||
|
||||
def df(self, topic: str, parseNumbers=True):
|
||||
"""Same as extract but return the result as a pandas DataFrame."""
|
||||
return util.df(self.extract(topic, parseNumbers))
|
||||
|
||||
def download(self, token, queryId):
|
||||
"""Download report for the given ``token`` and ``queryId``."""
|
||||
url = (
|
||||
'https://gdcdyn.interactivebrokers.com'
|
||||
f'/Universal/servlet/FlexStatementService.SendRequest?'
|
||||
f't={token}&q={queryId}&v=3')
|
||||
resp = urlopen(url)
|
||||
data = resp.read()
|
||||
|
||||
root = et.fromstring(data)
|
||||
elem = root.find('Status')
|
||||
if elem is not None and elem.text == 'Success':
|
||||
elem = root.find('ReferenceCode')
|
||||
assert elem is not None
|
||||
code = elem.text
|
||||
elem = root.find('Url')
|
||||
assert elem is not None
|
||||
baseUrl = elem.text
|
||||
_logger.info('Statement is being prepared...')
|
||||
else:
|
||||
elem = root.find('ErrorCode')
|
||||
errorCode = elem.text if elem is not None else ''
|
||||
elem = root.find('ErrorMessage')
|
||||
errorMsg = elem.text if elem is not None else ''
|
||||
raise FlexError(f'{errorCode}: {errorMsg}')
|
||||
|
||||
while True:
|
||||
time.sleep(1)
|
||||
url = f'{baseUrl}?q={code}&t={token}'
|
||||
resp = urlopen(url)
|
||||
self.data = resp.read()
|
||||
self.root = et.fromstring(self.data)
|
||||
if self.root[0].tag == 'code':
|
||||
msg = self.root[0].text
|
||||
if msg and msg.startswith('Statement generation in progress'):
|
||||
_logger.info('still working...')
|
||||
continue
|
||||
else:
|
||||
raise FlexError(msg)
|
||||
break
|
||||
_logger.info('Statement retrieved.')
|
||||
|
||||
def load(self, path):
|
||||
"""Load report from XML file."""
|
||||
with open(path, 'rb') as f:
|
||||
self.data = f.read()
|
||||
self.root = et.fromstring(self.data)
|
||||
|
||||
def save(self, path):
|
||||
"""Save report to XML file."""
|
||||
with open(path, 'wb') as f:
|
||||
f.write(self.data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
util.logToConsole()
|
||||
report = FlexReport('945692423458902392892687', '272555')
|
||||
print(report.topics())
|
||||
trades = report.extract('Trade')
|
||||
print(trades)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,358 @@
|
||||
"""Programmatic control over the TWS/gateway client software."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
from eventkit import Event
|
||||
|
||||
import ib_insync.util as util
|
||||
from ib_insync.contract import Contract, Forex
|
||||
from ib_insync.ib import IB
|
||||
|
||||
|
||||
@dataclass
|
||||
class IBC:
|
||||
r"""
|
||||
Programmatic control over starting and stopping TWS/Gateway
|
||||
using IBC (https://github.com/IbcAlpha/IBC).
|
||||
|
||||
Args:
|
||||
twsVersion (int): (required) The major version number for
|
||||
TWS or gateway.
|
||||
gateway (bool):
|
||||
* True = gateway
|
||||
* False = TWS
|
||||
tradingMode (str): 'live' or 'paper'.
|
||||
userid (str): IB account username. It is recommended to set the real
|
||||
username/password in a secured IBC config file.
|
||||
password (str): IB account password.
|
||||
twsPath (str): Path to the TWS installation folder.
|
||||
Defaults:
|
||||
|
||||
* Linux: ~/Jts
|
||||
* OS X: ~/Applications
|
||||
* Windows: C:\\Jts
|
||||
twsSettingsPath (str): Path to the TWS settings folder.
|
||||
Defaults:
|
||||
|
||||
* Linux: ~/Jts
|
||||
* OS X: ~/Jts
|
||||
* Windows: Not available
|
||||
ibcPath (str): Path to the IBC installation folder.
|
||||
Defaults:
|
||||
|
||||
* Linux: /opt/ibc
|
||||
* OS X: /opt/ibc
|
||||
* Windows: C:\\IBC
|
||||
ibcIni (str): Path to the IBC configuration file.
|
||||
Defaults:
|
||||
|
||||
* Linux: ~/ibc/config.ini
|
||||
* OS X: ~/ibc/config.ini
|
||||
* Windows: %%HOMEPATH%%\\Documents\IBC\\config.ini
|
||||
javaPath (str): Path to Java executable.
|
||||
Default is to use the Java VM included with TWS/gateway.
|
||||
fixuserid (str): FIX account user id (gateway only).
|
||||
fixpassword (str): FIX account password (gateway only).
|
||||
on2fatimeout (str): What to do if 2-factor authentication times
|
||||
out; Can be 'restart' or 'exit'.
|
||||
|
||||
This is not intended to be run in a notebook.
|
||||
|
||||
To use IBC on Windows, the proactor (or quamash) event loop
|
||||
must have been set:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
asyncio.set_event_loop(asyncio.ProactorEventLoop())
|
||||
|
||||
Example usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
ibc = IBC(976, gateway=True, tradingMode='live',
|
||||
userid='edemo', password='demouser')
|
||||
ibc.start()
|
||||
IB.run()
|
||||
"""
|
||||
|
||||
IbcLogLevel: ClassVar = logging.DEBUG
|
||||
|
||||
twsVersion: int = 0
|
||||
gateway: bool = False
|
||||
tradingMode: str = ''
|
||||
twsPath: str = ''
|
||||
twsSettingsPath: str = ''
|
||||
ibcPath: str = ''
|
||||
ibcIni: str = ''
|
||||
javaPath: str = ''
|
||||
userid: str = ''
|
||||
password: str = ''
|
||||
fixuserid: str = ''
|
||||
fixpassword: str = ''
|
||||
on2fatimeout: str = ''
|
||||
|
||||
def __post_init__(self):
|
||||
self._isWindows = sys.platform == 'win32'
|
||||
if not self.ibcPath:
|
||||
self.ibcPath = '/opt/ibc' if not self._isWindows else 'C:\\IBC'
|
||||
self._proc = None
|
||||
self._monitor = None
|
||||
self._logger = logging.getLogger('ib_insync.IBC')
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *_exc):
|
||||
self.terminate()
|
||||
|
||||
def start(self):
|
||||
"""Launch TWS/IBG."""
|
||||
util.run(self.startAsync())
|
||||
|
||||
def terminate(self):
|
||||
"""Terminate TWS/IBG."""
|
||||
util.run(self.terminateAsync())
|
||||
|
||||
async def startAsync(self):
|
||||
if self._proc:
|
||||
return
|
||||
self._logger.info('Starting')
|
||||
|
||||
# map from field names to cmd arguments; key=(UnixArg, WindowsArg)
|
||||
args = dict(
|
||||
twsVersion=('', ''),
|
||||
gateway=('--gateway', '/Gateway'),
|
||||
tradingMode=('--mode=', '/Mode:'),
|
||||
twsPath=('--tws-path=', '/TwsPath:'),
|
||||
twsSettingsPath=('--tws-settings-path=', ''),
|
||||
ibcPath=('--ibc-path=', '/IbcPath:'),
|
||||
ibcIni=('--ibc-ini=', '/Config:'),
|
||||
javaPath=('--java-path=', '/JavaPath:'),
|
||||
userid=('--user=', '/User:'),
|
||||
password=('--pw=', '/PW:'),
|
||||
fixuserid=('--fix-user=', '/FIXUser:'),
|
||||
fixpassword=('--fix-pw=', '/FIXPW:'),
|
||||
on2fatimeout=('--on2fatimeout=', '/On2FATimeout:'),
|
||||
)
|
||||
|
||||
# create shell command
|
||||
cmd = [
|
||||
f'{self.ibcPath}\\scripts\\StartIBC.bat' if self._isWindows else
|
||||
f'{self.ibcPath}/scripts/ibcstart.sh']
|
||||
for k, v in util.dataclassAsDict(self).items():
|
||||
arg = args[k][self._isWindows]
|
||||
if v:
|
||||
if arg.endswith('=') or arg.endswith(':'):
|
||||
cmd.append(f'{arg}{v}')
|
||||
elif arg:
|
||||
cmd.append(arg)
|
||||
else:
|
||||
cmd.append(str(v))
|
||||
|
||||
# run shell command
|
||||
self._proc = await asyncio.create_subprocess_exec(
|
||||
*cmd, stdout=asyncio.subprocess.PIPE)
|
||||
self._monitor = asyncio.ensure_future(self.monitorAsync())
|
||||
|
||||
async def terminateAsync(self):
|
||||
if not self._proc:
|
||||
return
|
||||
self._logger.info('Terminating')
|
||||
if self._monitor:
|
||||
self._monitor.cancel()
|
||||
self._monitor = None
|
||||
if self._isWindows:
|
||||
import subprocess
|
||||
subprocess.call(
|
||||
['taskkill', '/F', '/T', '/PID', str(self._proc.pid)])
|
||||
else:
|
||||
with suppress(ProcessLookupError):
|
||||
self._proc.terminate()
|
||||
await self._proc.wait()
|
||||
self._proc = None
|
||||
|
||||
async def monitorAsync(self):
|
||||
while self._proc:
|
||||
line = await self._proc.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
self._logger.log(IBC.IbcLogLevel, line.strip().decode())
|
||||
|
||||
|
||||
@dataclass
|
||||
class Watchdog:
|
||||
"""
|
||||
Start, connect and watch over the TWS or gateway app and try to keep it
|
||||
up and running. It is intended to be used in an event-driven
|
||||
application that properly initializes itself upon (re-)connect.
|
||||
|
||||
It is not intended to be used in a notebook or in imperative-style code.
|
||||
Do not expect Watchdog to magically shield you from reality. Do not use
|
||||
Watchdog unless you understand what it does and doesn't do.
|
||||
|
||||
Args:
|
||||
controller (IBC): (required) IBC instance.
|
||||
ib (IB): (required) IB instance to be used. Do not connect this
|
||||
instance as Watchdog takes care of that.
|
||||
host (str): Used for connecting IB instance.
|
||||
port (int): Used for connecting IB instance.
|
||||
clientId (int): Used for connecting IB instance.
|
||||
connectTimeout (float): Used for connecting IB instance.
|
||||
readonly (bool): Used for connecting IB instance.
|
||||
appStartupTime (float): Time (in seconds) that the app is given
|
||||
to start up. Make sure that it is given ample time.
|
||||
appTimeout (float): Timeout (in seconds) for network traffic idle time.
|
||||
retryDelay (float): Time (in seconds) to restart app after a
|
||||
previous failure.
|
||||
probeContract (Contract): Contract to use for historical data
|
||||
probe requests (default is EURUSD).
|
||||
probeTimeout (float); Timeout (in seconds) for the probe request.
|
||||
|
||||
The idea is to wait until there is no traffic coming from the app for
|
||||
a certain amount of time (the ``appTimeout`` parameter). This triggers
|
||||
a historical request to be placed just to see if the app is still alive
|
||||
and well. If yes, then continue, if no then restart the whole app
|
||||
and reconnect. Restarting will also occur directly on errors 1100 and 100.
|
||||
|
||||
Example usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def onConnected():
|
||||
print(ib.accountValues())
|
||||
|
||||
ibc = IBC(974, gateway=True, tradingMode='paper')
|
||||
ib = IB()
|
||||
ib.connectedEvent += onConnected
|
||||
watchdog = Watchdog(ibc, ib, port=4002)
|
||||
watchdog.start()
|
||||
ib.run()
|
||||
|
||||
Events:
|
||||
* ``startingEvent`` (watchdog: :class:`.Watchdog`)
|
||||
* ``startedEvent`` (watchdog: :class:`.Watchdog`)
|
||||
* ``stoppingEvent`` (watchdog: :class:`.Watchdog`)
|
||||
* ``stoppedEvent`` (watchdog: :class:`.Watchdog`)
|
||||
* ``softTimeoutEvent`` (watchdog: :class:`.Watchdog`)
|
||||
* ``hardTimeoutEvent`` (watchdog: :class:`.Watchdog`)
|
||||
"""
|
||||
|
||||
events = [
|
||||
'startingEvent', 'startedEvent', 'stoppingEvent', 'stoppedEvent',
|
||||
'softTimeoutEvent', 'hardTimeoutEvent']
|
||||
|
||||
controller: IBC
|
||||
ib: IB
|
||||
host: str = '127.0.0.1'
|
||||
port: int = 7497
|
||||
clientId: int = 1
|
||||
connectTimeout: float = 2
|
||||
appStartupTime: float = 30
|
||||
appTimeout: float = 20
|
||||
retryDelay: float = 2
|
||||
readonly: bool = False
|
||||
account: str = ''
|
||||
probeContract: Contract = Forex('EURUSD')
|
||||
probeTimeout: float = 4
|
||||
|
||||
def __post_init__(self):
|
||||
self.startingEvent = Event('startingEvent')
|
||||
self.startedEvent = Event('startedEvent')
|
||||
self.stoppingEvent = Event('stoppingEvent')
|
||||
self.stoppedEvent = Event('stoppedEvent')
|
||||
self.softTimeoutEvent = Event('softTimeoutEvent')
|
||||
self.hardTimeoutEvent = Event('hardTimeoutEvent')
|
||||
if not self.controller:
|
||||
raise ValueError('No controller supplied')
|
||||
if not self.ib:
|
||||
raise ValueError('No IB instance supplied')
|
||||
if self.ib.isConnected():
|
||||
raise ValueError('IB instance must not be connected')
|
||||
self._runner = None
|
||||
self._logger = logging.getLogger('ib_insync.Watchdog')
|
||||
|
||||
def start(self):
|
||||
self._logger.info('Starting')
|
||||
self.startingEvent.emit(self)
|
||||
self._runner = asyncio.ensure_future(self.runAsync())
|
||||
return self._runner
|
||||
|
||||
def stop(self):
|
||||
self._logger.info('Stopping')
|
||||
self.stoppingEvent.emit(self)
|
||||
self.ib.disconnect()
|
||||
self._runner = None
|
||||
|
||||
async def runAsync(self):
|
||||
|
||||
def onTimeout(idlePeriod):
|
||||
if not waiter.done():
|
||||
waiter.set_result(None)
|
||||
|
||||
def onError(reqId, errorCode, errorString, contract):
|
||||
if errorCode in {100, 1100} and not waiter.done():
|
||||
waiter.set_exception(Warning(f'Error {errorCode}'))
|
||||
|
||||
def onDisconnected():
|
||||
if not waiter.done():
|
||||
waiter.set_exception(Warning('Disconnected'))
|
||||
|
||||
while self._runner:
|
||||
try:
|
||||
await self.controller.startAsync()
|
||||
await asyncio.sleep(self.appStartupTime)
|
||||
await self.ib.connectAsync(
|
||||
self.host, self.port, self.clientId, self.connectTimeout,
|
||||
self.readonly, self.account)
|
||||
self.startedEvent.emit(self)
|
||||
self.ib.setTimeout(self.appTimeout)
|
||||
self.ib.timeoutEvent += onTimeout
|
||||
self.ib.errorEvent += onError
|
||||
self.ib.disconnectedEvent += onDisconnected
|
||||
|
||||
while self._runner:
|
||||
waiter: asyncio.Future = asyncio.Future()
|
||||
await waiter
|
||||
# soft timeout, probe the app with a historical request
|
||||
self._logger.debug('Soft timeout')
|
||||
self.softTimeoutEvent.emit(self)
|
||||
probe = self.ib.reqHistoricalDataAsync(
|
||||
self.probeContract, '', '30 S', '5 secs',
|
||||
'MIDPOINT', False)
|
||||
bars = None
|
||||
with suppress(asyncio.TimeoutError):
|
||||
bars = await asyncio.wait_for(probe, self.probeTimeout)
|
||||
if not bars:
|
||||
self.hardTimeoutEvent.emit(self)
|
||||
raise Warning('Hard timeout')
|
||||
self.ib.setTimeout(self.appTimeout)
|
||||
|
||||
except ConnectionRefusedError:
|
||||
pass
|
||||
except Warning as w:
|
||||
self._logger.warning(w)
|
||||
except Exception as e:
|
||||
self._logger.exception(e)
|
||||
finally:
|
||||
self.ib.timeoutEvent -= onTimeout
|
||||
self.ib.errorEvent -= onError
|
||||
self.ib.disconnectedEvent -= onDisconnected
|
||||
await self.controller.terminateAsync()
|
||||
self.stoppedEvent.emit(self)
|
||||
if self._runner:
|
||||
await asyncio.sleep(self.retryDelay)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ibc = IBC(1012, gateway=True, tradingMode='paper')
|
||||
ib = IB()
|
||||
app = Watchdog(ibc, ib, appStartupTime=15)
|
||||
app.start()
|
||||
IB.run()
|
||||
@@ -0,0 +1,499 @@
|
||||
"""Object hierarchy."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date as date_, datetime
|
||||
from typing import List, NamedTuple, Optional, Union
|
||||
|
||||
from eventkit import Event
|
||||
|
||||
from .contract import Contract, ScanData, TagValue
|
||||
from .util import EPOCH, UNSET_DOUBLE, UNSET_INTEGER
|
||||
|
||||
nan = float('nan')
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScannerSubscription:
|
||||
numberOfRows: int = -1
|
||||
instrument: str = ''
|
||||
locationCode: str = ''
|
||||
scanCode: str = ''
|
||||
abovePrice: float = UNSET_DOUBLE
|
||||
belowPrice: float = UNSET_DOUBLE
|
||||
aboveVolume: int = UNSET_INTEGER
|
||||
marketCapAbove: float = UNSET_DOUBLE
|
||||
marketCapBelow: float = UNSET_DOUBLE
|
||||
moodyRatingAbove: str = ''
|
||||
moodyRatingBelow: str = ''
|
||||
spRatingAbove: str = ''
|
||||
spRatingBelow: str = ''
|
||||
maturityDateAbove: str = ''
|
||||
maturityDateBelow: str = ''
|
||||
couponRateAbove: float = UNSET_DOUBLE
|
||||
couponRateBelow: float = UNSET_DOUBLE
|
||||
excludeConvertible: bool = False
|
||||
averageOptionVolumeAbove: int = UNSET_INTEGER
|
||||
scannerSettingPairs: str = ''
|
||||
stockTypeFilter: str = ''
|
||||
|
||||
|
||||
@dataclass
|
||||
class SoftDollarTier:
|
||||
name: str = ''
|
||||
val: str = ''
|
||||
displayName: str = ''
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.name or self.val or self.displayName)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Execution:
|
||||
execId: str = ''
|
||||
time: datetime = field(default=EPOCH)
|
||||
acctNumber: str = ''
|
||||
exchange: str = ''
|
||||
side: str = ''
|
||||
shares: float = 0.0
|
||||
price: float = 0.0
|
||||
permId: int = 0
|
||||
clientId: int = 0
|
||||
orderId: int = 0
|
||||
liquidation: int = 0
|
||||
cumQty: float = 0.0
|
||||
avgPrice: float = 0.0
|
||||
orderRef: str = ''
|
||||
evRule: str = ''
|
||||
evMultiplier: float = 0.0
|
||||
modelCode: str = ''
|
||||
lastLiquidity: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommissionReport:
|
||||
execId: str = ''
|
||||
commission: float = 0.0
|
||||
currency: str = ''
|
||||
realizedPNL: float = 0.0
|
||||
yield_: float = 0.0
|
||||
yieldRedemptionDate: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionFilter:
|
||||
clientId: int = 0
|
||||
acctCode: str = ''
|
||||
time: str = ''
|
||||
symbol: str = ''
|
||||
secType: str = ''
|
||||
exchange: str = ''
|
||||
side: str = ''
|
||||
|
||||
|
||||
@dataclass
|
||||
class BarData:
|
||||
date: Union[date_, datetime] = EPOCH
|
||||
open: float = 0.0
|
||||
high: float = 0.0
|
||||
low: float = 0.0
|
||||
close: float = 0.0
|
||||
volume: float = 0
|
||||
average: float = 0.0
|
||||
barCount: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class RealTimeBar:
|
||||
time: datetime = EPOCH
|
||||
endTime: int = -1
|
||||
open_: float = 0.0
|
||||
high: float = 0.0
|
||||
low: float = 0.0
|
||||
close: float = 0.0
|
||||
volume: float = 0.0
|
||||
wap: float = 0.0
|
||||
count: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class TickAttrib:
|
||||
canAutoExecute: bool = False
|
||||
pastLimit: bool = False
|
||||
preOpen: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class TickAttribBidAsk:
|
||||
bidPastLow: bool = False
|
||||
askPastHigh: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class TickAttribLast:
|
||||
pastLimit: bool = False
|
||||
unreported: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class HistogramData:
|
||||
price: float = 0.0
|
||||
count: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewsProvider:
|
||||
code: str = ''
|
||||
name: str = ''
|
||||
|
||||
|
||||
@dataclass
|
||||
class DepthMktDataDescription:
|
||||
exchange: str = ''
|
||||
secType: str = ''
|
||||
listingExch: str = ''
|
||||
serviceDataType: str = ''
|
||||
aggGroup: int = UNSET_INTEGER
|
||||
|
||||
|
||||
@dataclass
|
||||
class PnL:
|
||||
account: str = ''
|
||||
modelCode: str = ''
|
||||
dailyPnL: float = nan
|
||||
unrealizedPnL: float = nan
|
||||
realizedPnL: float = nan
|
||||
|
||||
|
||||
@dataclass
|
||||
class TradeLogEntry:
|
||||
time: datetime
|
||||
status: str = ''
|
||||
message: str = ''
|
||||
errorCode: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class PnLSingle:
|
||||
account: str = ''
|
||||
modelCode: str = ''
|
||||
conId: int = 0
|
||||
dailyPnL: float = nan
|
||||
unrealizedPnL: float = nan
|
||||
realizedPnL: float = nan
|
||||
position: int = 0
|
||||
value: float = nan
|
||||
|
||||
|
||||
@dataclass
|
||||
class HistoricalSession:
|
||||
startDateTime: str = ''
|
||||
endDateTime: str = ''
|
||||
refDate: str = ''
|
||||
|
||||
|
||||
@dataclass
|
||||
class HistoricalSchedule:
|
||||
startDateTime: str = ''
|
||||
endDateTime: str = ''
|
||||
timeZone: str = ''
|
||||
sessions: List[HistoricalSession] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WshEventData:
|
||||
conId: int = UNSET_INTEGER
|
||||
filter: str = ''
|
||||
fillWatchlist: bool = False
|
||||
fillPortfolio: bool = False
|
||||
fillCompetitors: bool = False
|
||||
startDate: str = ''
|
||||
endDate: str = ''
|
||||
totalLimit: int = UNSET_INTEGER
|
||||
|
||||
|
||||
class AccountValue(NamedTuple):
|
||||
account: str
|
||||
tag: str
|
||||
value: str
|
||||
currency: str
|
||||
modelCode: str
|
||||
|
||||
|
||||
class TickData(NamedTuple):
|
||||
time: datetime
|
||||
tickType: int
|
||||
price: float
|
||||
size: float
|
||||
|
||||
|
||||
class HistoricalTick(NamedTuple):
|
||||
time: datetime
|
||||
price: float
|
||||
size: float
|
||||
|
||||
|
||||
class HistoricalTickBidAsk(NamedTuple):
|
||||
time: datetime
|
||||
tickAttribBidAsk: TickAttribBidAsk
|
||||
priceBid: float
|
||||
priceAsk: float
|
||||
sizeBid: float
|
||||
sizeAsk: float
|
||||
|
||||
|
||||
class HistoricalTickLast(NamedTuple):
|
||||
time: datetime
|
||||
tickAttribLast: TickAttribLast
|
||||
price: float
|
||||
size: float
|
||||
exchange: str
|
||||
specialConditions: str
|
||||
|
||||
|
||||
class TickByTickAllLast(NamedTuple):
|
||||
tickType: int
|
||||
time: datetime
|
||||
price: float
|
||||
size: float
|
||||
tickAttribLast: TickAttribLast
|
||||
exchange: str
|
||||
specialConditions: str
|
||||
|
||||
|
||||
class TickByTickBidAsk(NamedTuple):
|
||||
time: datetime
|
||||
bidPrice: float
|
||||
askPrice: float
|
||||
bidSize: float
|
||||
askSize: float
|
||||
tickAttribBidAsk: TickAttribBidAsk
|
||||
|
||||
|
||||
class TickByTickMidPoint(NamedTuple):
|
||||
time: datetime
|
||||
midPoint: float
|
||||
|
||||
|
||||
class MktDepthData(NamedTuple):
|
||||
time: datetime
|
||||
position: int
|
||||
marketMaker: str
|
||||
operation: int
|
||||
side: int
|
||||
price: float
|
||||
size: float
|
||||
|
||||
|
||||
class DOMLevel(NamedTuple):
|
||||
price: float
|
||||
size: float
|
||||
marketMaker: str
|
||||
|
||||
|
||||
class PriceIncrement(NamedTuple):
|
||||
lowEdge: float
|
||||
increment: float
|
||||
|
||||
|
||||
class PortfolioItem(NamedTuple):
|
||||
contract: Contract
|
||||
position: float
|
||||
marketPrice: float
|
||||
marketValue: float
|
||||
averageCost: float
|
||||
unrealizedPNL: float
|
||||
realizedPNL: float
|
||||
account: str
|
||||
|
||||
|
||||
class Position(NamedTuple):
|
||||
account: str
|
||||
contract: Contract
|
||||
position: float
|
||||
avgCost: float
|
||||
|
||||
|
||||
class Fill(NamedTuple):
|
||||
contract: Contract
|
||||
execution: Execution
|
||||
commissionReport: CommissionReport
|
||||
time: datetime
|
||||
|
||||
|
||||
class OptionComputation(NamedTuple):
|
||||
tickAttrib: int
|
||||
impliedVol: Optional[float]
|
||||
delta: Optional[float]
|
||||
optPrice: Optional[float]
|
||||
pvDividend: Optional[float]
|
||||
gamma: Optional[float]
|
||||
vega: Optional[float]
|
||||
theta: Optional[float]
|
||||
undPrice: Optional[float]
|
||||
|
||||
|
||||
class OptionChain(NamedTuple):
|
||||
exchange: str
|
||||
underlyingConId: int
|
||||
tradingClass: str
|
||||
multiplier: str
|
||||
expirations: List[str]
|
||||
strikes: List[float]
|
||||
|
||||
|
||||
class Dividends(NamedTuple):
|
||||
past12Months: Optional[float]
|
||||
next12Months: Optional[float]
|
||||
nextDate: Optional[date_]
|
||||
nextAmount: Optional[float]
|
||||
|
||||
|
||||
class NewsArticle(NamedTuple):
|
||||
articleType: int
|
||||
articleText: str
|
||||
|
||||
|
||||
class HistoricalNews(NamedTuple):
|
||||
time: datetime
|
||||
providerCode: str
|
||||
articleId: str
|
||||
headline: str
|
||||
|
||||
|
||||
class NewsTick(NamedTuple):
|
||||
timeStamp: int
|
||||
providerCode: str
|
||||
articleId: str
|
||||
headline: str
|
||||
extraData: str
|
||||
|
||||
|
||||
class NewsBulletin(NamedTuple):
|
||||
msgId: int
|
||||
msgType: int
|
||||
message: str
|
||||
origExchange: str
|
||||
|
||||
|
||||
class FamilyCode(NamedTuple):
|
||||
accountID: str
|
||||
familyCodeStr: str
|
||||
|
||||
|
||||
class SmartComponent(NamedTuple):
|
||||
bitNumber: int
|
||||
exchange: str
|
||||
exchangeLetter: str
|
||||
|
||||
|
||||
class ConnectionStats(NamedTuple):
|
||||
startTime: float
|
||||
duration: float
|
||||
numBytesRecv: int
|
||||
numBytesSent: int
|
||||
numMsgRecv: int
|
||||
numMsgSent: int
|
||||
|
||||
|
||||
class BarDataList(List[BarData]):
|
||||
"""
|
||||
List of :class:`.BarData` that also stores all request parameters.
|
||||
|
||||
Events:
|
||||
|
||||
* ``updateEvent``
|
||||
(bars: :class:`.BarDataList`, hasNewBar: bool)
|
||||
"""
|
||||
|
||||
reqId: int
|
||||
contract: Contract
|
||||
endDateTime: Union[datetime, date_, str, None]
|
||||
durationStr: str
|
||||
barSizeSetting: str
|
||||
whatToShow: str
|
||||
useRTH: bool
|
||||
formatDate: int
|
||||
keepUpToDate: bool
|
||||
chartOptions: List[TagValue]
|
||||
|
||||
def __init__(self, *args):
|
||||
super().__init__(*args)
|
||||
self.updateEvent = Event('updateEvent')
|
||||
|
||||
def __eq__(self, other):
|
||||
return self is other
|
||||
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
|
||||
|
||||
class RealTimeBarList(List[RealTimeBar]):
|
||||
"""
|
||||
List of :class:`.RealTimeBar` that also stores all request parameters.
|
||||
|
||||
Events:
|
||||
|
||||
* ``updateEvent``
|
||||
(bars: :class:`.RealTimeBarList`, hasNewBar: bool)
|
||||
"""
|
||||
|
||||
reqId: int
|
||||
contract: Contract
|
||||
barSize: int
|
||||
whatToShow: str
|
||||
useRTH: bool
|
||||
realTimeBarsOptions: List[TagValue]
|
||||
|
||||
def __init__(self, *args):
|
||||
super().__init__(*args)
|
||||
self.updateEvent = Event('updateEvent')
|
||||
|
||||
def __eq__(self, other):
|
||||
return self is other
|
||||
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
|
||||
|
||||
class ScanDataList(List[ScanData]):
|
||||
"""
|
||||
List of :class:`.ScanData` that also stores all request parameters.
|
||||
|
||||
Events:
|
||||
* ``updateEvent`` (:class:`.ScanDataList`)
|
||||
"""
|
||||
|
||||
reqId: int
|
||||
subscription: ScannerSubscription
|
||||
scannerSubscriptionOptions: List[TagValue]
|
||||
scannerSubscriptionFilterOptions: List[TagValue]
|
||||
|
||||
def __init__(self, *args):
|
||||
super().__init__(*args)
|
||||
self.updateEvent = Event('updateEvent')
|
||||
|
||||
def __eq__(self, other):
|
||||
return self is other
|
||||
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
|
||||
|
||||
class DynamicObject:
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
clsName = self.__class__.__name__
|
||||
kwargs = ', '.join(f'{k}={v!r}' for k, v in self.__dict__.items())
|
||||
return f'{clsName}({kwargs})'
|
||||
|
||||
|
||||
class FundamentalRatios(DynamicObject):
|
||||
"""
|
||||
See:
|
||||
https://web.archive.org/web/20200725010343/https://interactivebrokers.github.io/tws-api/fundamental_ratios_tags.html
|
||||
"""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,412 @@
|
||||
"""Order types used by Interactive Brokers."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import ClassVar, List, NamedTuple, Set
|
||||
|
||||
from eventkit import Event
|
||||
|
||||
from .contract import Contract, TagValue
|
||||
from .objects import Fill, SoftDollarTier, TradeLogEntry
|
||||
from .util import UNSET_DOUBLE, UNSET_INTEGER, dataclassNonDefaults
|
||||
|
||||
|
||||
@dataclass
|
||||
class Order:
|
||||
"""
|
||||
Order for trading contracts.
|
||||
|
||||
https://interactivebrokers.github.io/tws-api/available_orders.html
|
||||
"""
|
||||
|
||||
orderId: int = 0
|
||||
clientId: int = 0
|
||||
permId: int = 0
|
||||
action: str = ''
|
||||
totalQuantity: float = 0.0
|
||||
orderType: str = ''
|
||||
lmtPrice: float = UNSET_DOUBLE
|
||||
auxPrice: float = UNSET_DOUBLE
|
||||
tif: str = ''
|
||||
activeStartTime: str = ''
|
||||
activeStopTime: str = ''
|
||||
ocaGroup: str = ''
|
||||
ocaType: int = 0
|
||||
orderRef: str = ''
|
||||
transmit: bool = True
|
||||
parentId: int = 0
|
||||
blockOrder: bool = False
|
||||
sweepToFill: bool = False
|
||||
displaySize: int = 0
|
||||
triggerMethod: int = 0
|
||||
outsideRth: bool = False
|
||||
hidden: bool = False
|
||||
goodAfterTime: str = ''
|
||||
goodTillDate: str = ''
|
||||
rule80A: str = ''
|
||||
allOrNone: bool = False
|
||||
minQty: int = UNSET_INTEGER
|
||||
percentOffset: float = UNSET_DOUBLE
|
||||
overridePercentageConstraints: bool = False
|
||||
trailStopPrice: float = UNSET_DOUBLE
|
||||
trailingPercent: float = UNSET_DOUBLE
|
||||
faGroup: str = ''
|
||||
faProfile: str = ''
|
||||
faMethod: str = ''
|
||||
faPercentage: str = ''
|
||||
designatedLocation: str = ''
|
||||
openClose: str = "O"
|
||||
origin: int = 0
|
||||
shortSaleSlot: int = 0
|
||||
exemptCode: int = -1
|
||||
discretionaryAmt: float = 0.0
|
||||
eTradeOnly: bool = False
|
||||
firmQuoteOnly: bool = False
|
||||
nbboPriceCap: float = UNSET_DOUBLE
|
||||
optOutSmartRouting: bool = False
|
||||
auctionStrategy: int = 0
|
||||
startingPrice: float = UNSET_DOUBLE
|
||||
stockRefPrice: float = UNSET_DOUBLE
|
||||
delta: float = UNSET_DOUBLE
|
||||
stockRangeLower: float = UNSET_DOUBLE
|
||||
stockRangeUpper: float = UNSET_DOUBLE
|
||||
randomizePrice: bool = False
|
||||
randomizeSize: bool = False
|
||||
volatility: float = UNSET_DOUBLE
|
||||
volatilityType: int = UNSET_INTEGER
|
||||
deltaNeutralOrderType: str = ''
|
||||
deltaNeutralAuxPrice: float = UNSET_DOUBLE
|
||||
deltaNeutralConId: int = 0
|
||||
deltaNeutralSettlingFirm: str = ''
|
||||
deltaNeutralClearingAccount: str = ''
|
||||
deltaNeutralClearingIntent: str = ''
|
||||
deltaNeutralOpenClose: str = ''
|
||||
deltaNeutralShortSale: bool = False
|
||||
deltaNeutralShortSaleSlot: int = 0
|
||||
deltaNeutralDesignatedLocation: str = ''
|
||||
continuousUpdate: bool = False
|
||||
referencePriceType: int = UNSET_INTEGER
|
||||
basisPoints: float = UNSET_DOUBLE
|
||||
basisPointsType: int = UNSET_INTEGER
|
||||
scaleInitLevelSize: int = UNSET_INTEGER
|
||||
scaleSubsLevelSize: int = UNSET_INTEGER
|
||||
scalePriceIncrement: float = UNSET_DOUBLE
|
||||
scalePriceAdjustValue: float = UNSET_DOUBLE
|
||||
scalePriceAdjustInterval: int = UNSET_INTEGER
|
||||
scaleProfitOffset: float = UNSET_DOUBLE
|
||||
scaleAutoReset: bool = False
|
||||
scaleInitPosition: int = UNSET_INTEGER
|
||||
scaleInitFillQty: int = UNSET_INTEGER
|
||||
scaleRandomPercent: bool = False
|
||||
scaleTable: str = ''
|
||||
hedgeType: str = ''
|
||||
hedgeParam: str = ''
|
||||
account: str = ''
|
||||
settlingFirm: str = ''
|
||||
clearingAccount: str = ''
|
||||
clearingIntent: str = ''
|
||||
algoStrategy: str = ''
|
||||
algoParams: List[TagValue] = field(default_factory=list)
|
||||
smartComboRoutingParams: List[TagValue] = field(default_factory=list)
|
||||
algoId: str = ''
|
||||
whatIf: bool = False
|
||||
notHeld: bool = False
|
||||
solicited: bool = False
|
||||
modelCode: str = ''
|
||||
orderComboLegs: List['OrderComboLeg'] = field(default_factory=list)
|
||||
orderMiscOptions: List[TagValue] = field(default_factory=list)
|
||||
referenceContractId: int = 0
|
||||
peggedChangeAmount: float = 0.0
|
||||
isPeggedChangeAmountDecrease: bool = False
|
||||
referenceChangeAmount: float = 0.0
|
||||
referenceExchangeId: str = ''
|
||||
adjustedOrderType: str = ''
|
||||
triggerPrice: float = UNSET_DOUBLE
|
||||
adjustedStopPrice: float = UNSET_DOUBLE
|
||||
adjustedStopLimitPrice: float = UNSET_DOUBLE
|
||||
adjustedTrailingAmount: float = UNSET_DOUBLE
|
||||
adjustableTrailingUnit: int = 0
|
||||
lmtPriceOffset: float = UNSET_DOUBLE
|
||||
conditions: List['OrderCondition'] = field(default_factory=list)
|
||||
conditionsCancelOrder: bool = False
|
||||
conditionsIgnoreRth: bool = False
|
||||
extOperator: str = ''
|
||||
softDollarTier: SoftDollarTier = field(default_factory=SoftDollarTier)
|
||||
cashQty: float = UNSET_DOUBLE
|
||||
mifid2DecisionMaker: str = ''
|
||||
mifid2DecisionAlgo: str = ''
|
||||
mifid2ExecutionTrader: str = ''
|
||||
mifid2ExecutionAlgo: str = ''
|
||||
dontUseAutoPriceForHedge: bool = False
|
||||
isOmsContainer: bool = False
|
||||
discretionaryUpToLimitPrice: bool = False
|
||||
autoCancelDate: str = ''
|
||||
filledQuantity: float = UNSET_DOUBLE
|
||||
refFuturesConId: int = 0
|
||||
autoCancelParent: bool = False
|
||||
shareholder: str = ''
|
||||
imbalanceOnly: bool = False
|
||||
routeMarketableToBbo: bool = False
|
||||
parentPermId: int = 0
|
||||
usePriceMgmtAlgo: bool = False
|
||||
duration: int = UNSET_INTEGER
|
||||
postToAts: int = UNSET_INTEGER
|
||||
advancedErrorOverride: str = ''
|
||||
manualOrderTime: str = ''
|
||||
minTradeQty: int = UNSET_INTEGER
|
||||
minCompeteSize: int = UNSET_INTEGER
|
||||
competeAgainstBestOffset: float = UNSET_DOUBLE
|
||||
midOffsetAtWhole: float = UNSET_DOUBLE
|
||||
midOffsetAtHalf: float = UNSET_DOUBLE
|
||||
|
||||
def __repr__(self):
|
||||
attrs = dataclassNonDefaults(self)
|
||||
if self.__class__ is not Order:
|
||||
attrs.pop('orderType', None)
|
||||
if not self.softDollarTier:
|
||||
attrs.pop('softDollarTier')
|
||||
clsName = self.__class__.__qualname__
|
||||
kwargs = ', '.join(
|
||||
f'{k}={v!r}' for k, v in attrs.items())
|
||||
return f'{clsName}({kwargs})'
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
def __eq__(self, other):
|
||||
return self is other
|
||||
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
|
||||
|
||||
class LimitOrder(Order):
|
||||
|
||||
def __init__(self, action: str, totalQuantity: float, lmtPrice: float,
|
||||
**kwargs):
|
||||
Order.__init__(
|
||||
self, orderType='LMT', action=action,
|
||||
totalQuantity=totalQuantity, lmtPrice=lmtPrice, **kwargs)
|
||||
|
||||
|
||||
class MarketOrder(Order):
|
||||
|
||||
def __init__(self, action: str, totalQuantity: float, **kwargs):
|
||||
Order.__init__(
|
||||
self, orderType='MKT', action=action,
|
||||
totalQuantity=totalQuantity, **kwargs)
|
||||
|
||||
|
||||
class StopOrder(Order):
|
||||
|
||||
def __init__(self, action: str, totalQuantity: float, stopPrice: float,
|
||||
**kwargs):
|
||||
Order.__init__(
|
||||
self, orderType='STP', action=action,
|
||||
totalQuantity=totalQuantity, auxPrice=stopPrice, **kwargs)
|
||||
|
||||
|
||||
class StopLimitOrder(Order):
|
||||
|
||||
def __init__(self, action: str, totalQuantity: float, lmtPrice: float,
|
||||
stopPrice: float, **kwargs):
|
||||
Order.__init__(
|
||||
self, orderType='STP LMT', action=action,
|
||||
totalQuantity=totalQuantity, lmtPrice=lmtPrice,
|
||||
auxPrice=stopPrice, **kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrderStatus:
|
||||
orderId: int = 0
|
||||
status: str = ''
|
||||
filled: float = 0.0
|
||||
remaining: float = 0.0
|
||||
avgFillPrice: float = 0.0
|
||||
permId: int = 0
|
||||
parentId: int = 0
|
||||
lastFillPrice: float = 0.0
|
||||
clientId: int = 0
|
||||
whyHeld: str = ''
|
||||
mktCapPrice: float = 0.0
|
||||
|
||||
PendingSubmit: ClassVar[str] = 'PendingSubmit'
|
||||
PendingCancel: ClassVar[str] = 'PendingCancel'
|
||||
PreSubmitted: ClassVar[str] = 'PreSubmitted'
|
||||
Submitted: ClassVar[str] = 'Submitted'
|
||||
ApiPending: ClassVar[str] = 'ApiPending'
|
||||
ApiCancelled: ClassVar[str] = 'ApiCancelled'
|
||||
Cancelled: ClassVar[str] = 'Cancelled'
|
||||
Filled: ClassVar[str] = 'Filled'
|
||||
Inactive: ClassVar[str] = 'Inactive'
|
||||
|
||||
DoneStates: ClassVar[Set[str]] = {'Filled', 'Cancelled', 'ApiCancelled'}
|
||||
ActiveStates: ClassVar[Set[str]] = {
|
||||
'PendingSubmit', 'ApiPending', 'PreSubmitted', 'Submitted'}
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrderState:
|
||||
status: str = ''
|
||||
initMarginBefore: str = ''
|
||||
maintMarginBefore: str = ''
|
||||
equityWithLoanBefore: str = ''
|
||||
initMarginChange: str = ''
|
||||
maintMarginChange: str = ''
|
||||
equityWithLoanChange: str = ''
|
||||
initMarginAfter: str = ''
|
||||
maintMarginAfter: str = ''
|
||||
equityWithLoanAfter: str = ''
|
||||
commission: float = UNSET_DOUBLE
|
||||
minCommission: float = UNSET_DOUBLE
|
||||
maxCommission: float = UNSET_DOUBLE
|
||||
commissionCurrency: str = ''
|
||||
warningText: str = ''
|
||||
completedTime: str = ''
|
||||
completedStatus: str = ''
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrderComboLeg:
|
||||
price: float = UNSET_DOUBLE
|
||||
|
||||
|
||||
@dataclass
|
||||
class Trade:
|
||||
"""
|
||||
Trade keeps track of an order, its status and all its fills.
|
||||
|
||||
Events:
|
||||
* ``statusEvent`` (trade: :class:`.Trade`)
|
||||
* ``modifyEvent`` (trade: :class:`.Trade`)
|
||||
* ``fillEvent`` (trade: :class:`.Trade`, fill: :class:`.Fill`)
|
||||
* ``commissionReportEvent`` (trade: :class:`.Trade`,
|
||||
fill: :class:`.Fill`, commissionReport: :class:`.CommissionReport`)
|
||||
* ``filledEvent`` (trade: :class:`.Trade`)
|
||||
* ``cancelEvent`` (trade: :class:`.Trade`)
|
||||
* ``cancelledEvent`` (trade: :class:`.Trade`)
|
||||
"""
|
||||
|
||||
events: ClassVar = (
|
||||
'statusEvent', 'modifyEvent', 'fillEvent',
|
||||
'commissionReportEvent', 'filledEvent',
|
||||
'cancelEvent', 'cancelledEvent')
|
||||
|
||||
contract: Contract = field(default_factory=Contract)
|
||||
order: Order = field(default_factory=Order)
|
||||
orderStatus: 'OrderStatus' = field(default_factory=OrderStatus)
|
||||
fills: List[Fill] = field(default_factory=list)
|
||||
log: List[TradeLogEntry] = field(default_factory=list)
|
||||
advancedError: str = ''
|
||||
|
||||
def __post_init__(self):
|
||||
self.statusEvent = Event('statusEvent')
|
||||
self.modifyEvent = Event('modifyEvent')
|
||||
self.fillEvent = Event('fillEvent')
|
||||
self.commissionReportEvent = Event('commissionReportEvent')
|
||||
self.filledEvent = Event('filledEvent')
|
||||
self.cancelEvent = Event('cancelEvent')
|
||||
self.cancelledEvent = Event('cancelledEvent')
|
||||
|
||||
def isActive(self):
|
||||
"""True if eligible for execution, false otherwise."""
|
||||
return self.orderStatus.status in OrderStatus.ActiveStates
|
||||
|
||||
def isDone(self):
|
||||
"""True if completely filled or cancelled, false otherwise."""
|
||||
return self.orderStatus.status in OrderStatus.DoneStates
|
||||
|
||||
def filled(self):
|
||||
"""Number of shares filled."""
|
||||
fills = self.fills
|
||||
if self.contract.secType == 'BAG':
|
||||
# don't count fills for the leg contracts
|
||||
fills = [f for f in fills if f.contract.secType == 'BAG']
|
||||
return sum(f.execution.shares for f in fills)
|
||||
|
||||
def remaining(self):
|
||||
"""Number of shares remaining to be filled."""
|
||||
return self.order.totalQuantity - self.filled()
|
||||
|
||||
|
||||
class BracketOrder(NamedTuple):
|
||||
parent: Order
|
||||
takeProfit: Order
|
||||
stopLoss: Order
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrderCondition:
|
||||
|
||||
@staticmethod
|
||||
def createClass(condType):
|
||||
d = {
|
||||
1: PriceCondition,
|
||||
3: TimeCondition,
|
||||
4: MarginCondition,
|
||||
5: ExecutionCondition,
|
||||
6: VolumeCondition,
|
||||
7: PercentChangeCondition}
|
||||
return d[condType]
|
||||
|
||||
def And(self):
|
||||
self.conjunction = 'a'
|
||||
return self
|
||||
|
||||
def Or(self):
|
||||
self.conjunction = 'o'
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class PriceCondition(OrderCondition):
|
||||
condType: int = 1
|
||||
conjunction: str = 'a'
|
||||
isMore: bool = True
|
||||
price: float = 0.0
|
||||
conId: int = 0
|
||||
exch: str = ''
|
||||
triggerMethod: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimeCondition(OrderCondition):
|
||||
condType: int = 3
|
||||
conjunction: str = 'a'
|
||||
isMore: bool = True
|
||||
time: str = ''
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarginCondition(OrderCondition):
|
||||
condType: int = 4
|
||||
conjunction: str = 'a'
|
||||
isMore: bool = True
|
||||
percent: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionCondition(OrderCondition):
|
||||
condType: int = 5
|
||||
conjunction: str = 'a'
|
||||
secType: str = ''
|
||||
exch: str = ''
|
||||
symbol: str = ''
|
||||
|
||||
|
||||
@dataclass
|
||||
class VolumeCondition(OrderCondition):
|
||||
condType: int = 6
|
||||
conjunction: str = 'a'
|
||||
isMore: bool = True
|
||||
volume: int = 0
|
||||
conId: int = 0
|
||||
exch: str = ''
|
||||
|
||||
|
||||
@dataclass
|
||||
class PercentChangeCondition(OrderCondition):
|
||||
condType: int = 7
|
||||
conjunction: str = 'a'
|
||||
isMore: bool = True
|
||||
changePercent: float = 0.0
|
||||
conId: int = 0
|
||||
exch: str = ''
|
||||
@@ -0,0 +1,362 @@
|
||||
"""Access to realtime market information."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import ClassVar, List, Optional, Union
|
||||
|
||||
from eventkit import Event, Op
|
||||
|
||||
from ib_insync.contract import Contract
|
||||
from ib_insync.objects import (
|
||||
DOMLevel, Dividends, FundamentalRatios, MktDepthData,
|
||||
OptionComputation, TickByTickAllLast, TickByTickBidAsk, TickByTickMidPoint,
|
||||
TickData)
|
||||
from ib_insync.util import dataclassRepr, isNan
|
||||
|
||||
nan = float('nan')
|
||||
|
||||
|
||||
@dataclass
|
||||
class Ticker:
|
||||
"""
|
||||
Current market data such as bid, ask, last price, etc. for a contract.
|
||||
|
||||
Streaming level-1 ticks of type :class:`.TickData` are stored in
|
||||
the ``ticks`` list.
|
||||
|
||||
Streaming level-2 ticks of type :class:`.MktDepthData` are stored in the
|
||||
``domTicks`` list. The order book (DOM) is available as lists of
|
||||
:class:`.DOMLevel` in ``domBids`` and ``domAsks``.
|
||||
|
||||
Streaming tick-by-tick ticks are stored in ``tickByTicks``.
|
||||
|
||||
For options the :class:`.OptionComputation` values for the bid, ask, resp.
|
||||
last price are stored in the ``bidGreeks``, ``askGreeks`` resp.
|
||||
``lastGreeks`` attributes. There is also ``modelGreeks`` that conveys
|
||||
the greeks as calculated by Interactive Brokers' option model.
|
||||
|
||||
Events:
|
||||
* ``updateEvent`` (ticker: :class:`.Ticker`)
|
||||
"""
|
||||
|
||||
events: ClassVar = ('updateEvent',)
|
||||
|
||||
contract: Optional[Contract] = None
|
||||
time: Optional[datetime] = None
|
||||
marketDataType: int = 1
|
||||
minTick: float = nan
|
||||
bid: float = nan
|
||||
bidSize: float = nan
|
||||
bidExchange: str = ''
|
||||
ask: float = nan
|
||||
askSize: float = nan
|
||||
askExchange: str = ''
|
||||
last: float = nan
|
||||
lastSize: float = nan
|
||||
lastExchange: str = ''
|
||||
prevBid: float = nan
|
||||
prevBidSize: float = nan
|
||||
prevAsk: float = nan
|
||||
prevAskSize: float = nan
|
||||
prevLast: float = nan
|
||||
prevLastSize: float = nan
|
||||
volume: float = nan
|
||||
open: float = nan
|
||||
high: float = nan
|
||||
low: float = nan
|
||||
close: float = nan
|
||||
vwap: float = nan
|
||||
low13week: float = nan
|
||||
high13week: float = nan
|
||||
low26week: float = nan
|
||||
high26week: float = nan
|
||||
low52week: float = nan
|
||||
high52week: float = nan
|
||||
bidYield: float = nan
|
||||
askYield: float = nan
|
||||
lastYield: float = nan
|
||||
markPrice: float = nan
|
||||
halted: float = nan
|
||||
rtHistVolatility: float = nan
|
||||
rtVolume: float = nan
|
||||
rtTradeVolume: float = nan
|
||||
rtTime: Optional[datetime] = None
|
||||
avVolume: float = nan
|
||||
tradeCount: float = nan
|
||||
tradeRate: float = nan
|
||||
volumeRate: float = nan
|
||||
shortableShares: float = nan
|
||||
indexFuturePremium: float = nan
|
||||
futuresOpenInterest: float = nan
|
||||
putOpenInterest: float = nan
|
||||
callOpenInterest: float = nan
|
||||
putVolume: float = nan
|
||||
callVolume: float = nan
|
||||
avOptionVolume: float = nan
|
||||
histVolatility: float = nan
|
||||
impliedVolatility: float = nan
|
||||
dividends: Optional[Dividends] = None
|
||||
fundamentalRatios: Optional[FundamentalRatios] = None
|
||||
ticks: List[TickData] = field(default_factory=list)
|
||||
tickByTicks: List[Union[
|
||||
TickByTickAllLast, TickByTickBidAsk, TickByTickMidPoint]] = \
|
||||
field(default_factory=list)
|
||||
domBids: List[DOMLevel] = field(default_factory=list)
|
||||
domAsks: List[DOMLevel] = field(default_factory=list)
|
||||
domTicks: List[MktDepthData] = field(default_factory=list)
|
||||
bidGreeks: Optional[OptionComputation] = None
|
||||
askGreeks: Optional[OptionComputation] = None
|
||||
lastGreeks: Optional[OptionComputation] = None
|
||||
modelGreeks: Optional[OptionComputation] = None
|
||||
auctionVolume: float = nan
|
||||
auctionPrice: float = nan
|
||||
auctionImbalance: float = nan
|
||||
regulatoryImbalance: float = nan
|
||||
bboExchange: str = ''
|
||||
snapshotPermissions: int = 0
|
||||
|
||||
def __post_init__(self):
|
||||
self.updateEvent = TickerUpdateEvent('updateEvent')
|
||||
|
||||
def __eq__(self, other):
|
||||
return self is other
|
||||
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
|
||||
__repr__ = dataclassRepr
|
||||
__str__ = dataclassRepr
|
||||
|
||||
def hasBidAsk(self) -> bool:
|
||||
"""See if this ticker has a valid bid and ask."""
|
||||
return (
|
||||
self.bid != -1 and not isNan(self.bid) and self.bidSize > 0
|
||||
and self.ask != -1 and not isNan(self.ask) and self.askSize > 0)
|
||||
|
||||
def midpoint(self) -> float:
|
||||
"""
|
||||
Return average of bid and ask, or NaN if no valid bid and ask
|
||||
are available.
|
||||
"""
|
||||
return (self.bid + self.ask) * 0.5 if self.hasBidAsk() else nan
|
||||
|
||||
def marketPrice(self) -> float:
|
||||
"""
|
||||
Return the first available one of
|
||||
|
||||
* last price if within current bid/ask or no bid/ask available;
|
||||
* average of bid and ask (midpoint).
|
||||
"""
|
||||
if self.hasBidAsk():
|
||||
if self.bid <= self.last <= self.ask:
|
||||
price = self.last
|
||||
else:
|
||||
price = self.midpoint()
|
||||
else:
|
||||
price = self.last
|
||||
return price
|
||||
|
||||
|
||||
class TickerUpdateEvent(Event):
|
||||
__slots__ = ()
|
||||
|
||||
def trades(self) -> "Tickfilter":
|
||||
"""Emit trade ticks."""
|
||||
return Tickfilter((4, 5, 48, 68, 71), self)
|
||||
|
||||
def bids(self) -> "Tickfilter":
|
||||
"""Emit bid ticks."""
|
||||
return Tickfilter((0, 1, 66, 69), self)
|
||||
|
||||
def asks(self) -> "Tickfilter":
|
||||
"""Emit ask ticks."""
|
||||
return Tickfilter((2, 3, 67, 70), self)
|
||||
|
||||
def bidasks(self) -> "Tickfilter":
|
||||
"""Emit bid and ask ticks."""
|
||||
return Tickfilter((0, 1, 66, 69, 2, 3, 67, 70), self)
|
||||
|
||||
def midpoints(self) -> "Tickfilter":
|
||||
"""Emit midpoint ticks."""
|
||||
return Midpoints((), self)
|
||||
|
||||
|
||||
class Tickfilter(Op):
|
||||
"""Tick filtering event operators that ``emit(time, price, size)``."""
|
||||
|
||||
__slots__ = ('_tickTypes',)
|
||||
|
||||
def __init__(self, tickTypes, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._tickTypes = set(tickTypes)
|
||||
|
||||
def on_source(self, ticker):
|
||||
for t in ticker.ticks:
|
||||
if t.tickType in self._tickTypes:
|
||||
self.emit(t.time, t.price, t.size)
|
||||
|
||||
def timebars(self, timer: Event) -> "TimeBars":
|
||||
"""
|
||||
Aggregate ticks into time bars, where the timing of new bars
|
||||
is derived from a timer event.
|
||||
Emits a completed :class:`Bar`.
|
||||
|
||||
This event stores a :class:`BarList` of all created bars in the
|
||||
``bars`` property.
|
||||
|
||||
Args:
|
||||
timer: Event for timing when a new bar starts.
|
||||
"""
|
||||
return TimeBars(timer, self)
|
||||
|
||||
def tickbars(self, count: int) -> "TickBars":
|
||||
"""
|
||||
Aggregate ticks into bars that have the same number of ticks.
|
||||
Emits a completed :class:`Bar`.
|
||||
|
||||
This event stores a :class:`BarList` of all created bars in the
|
||||
``bars`` property.
|
||||
|
||||
Args:
|
||||
count: Number of ticks to use to form one bar.
|
||||
"""
|
||||
return TickBars(count, self)
|
||||
|
||||
def volumebars(self, volume: int) -> "VolumeBars":
|
||||
"""
|
||||
Aggregate ticks into bars that have the same volume.
|
||||
Emits a completed :class:`Bar`.
|
||||
|
||||
This event stores a :class:`BarList` of all created bars in the
|
||||
``bars`` property.
|
||||
|
||||
Args:
|
||||
count: Number of ticks to use to form one bar.
|
||||
"""
|
||||
return VolumeBars(volume, self)
|
||||
|
||||
|
||||
class Midpoints(Tickfilter):
|
||||
__slots__ = ()
|
||||
|
||||
def on_source(self, ticker):
|
||||
if ticker.ticks:
|
||||
self.emit(ticker.time, ticker.midpoint(), 0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Bar:
|
||||
time: Optional[datetime]
|
||||
open: float = nan
|
||||
high: float = nan
|
||||
low: float = nan
|
||||
close: float = nan
|
||||
volume: int = 0
|
||||
count: int = 0
|
||||
|
||||
|
||||
class BarList(List[Bar]):
|
||||
|
||||
def __init__(self, *args):
|
||||
super().__init__(*args)
|
||||
self.updateEvent = Event('updateEvent')
|
||||
|
||||
def __eq__(self, other):
|
||||
return self is other
|
||||
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
|
||||
|
||||
class TimeBars(Op):
|
||||
__slots__ = ('_timer', 'bars',)
|
||||
__doc__ = Tickfilter.timebars.__doc__
|
||||
|
||||
bars: BarList
|
||||
|
||||
def __init__(self, timer, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._timer = timer
|
||||
self._timer.connect(self._on_timer, None, self._on_timer_done)
|
||||
self.bars = BarList()
|
||||
|
||||
def on_source(self, time, price, size):
|
||||
if not self.bars:
|
||||
return
|
||||
bar = self.bars[-1]
|
||||
if isNan(bar.open):
|
||||
bar.open = bar.high = bar.low = price
|
||||
bar.high = max(bar.high, price)
|
||||
bar.low = min(bar.low, price)
|
||||
bar.close = price
|
||||
bar.volume += size
|
||||
bar.count += 1
|
||||
self.bars.updateEvent.emit(self.bars, False)
|
||||
|
||||
def _on_timer(self, time):
|
||||
if self.bars:
|
||||
bar = self.bars[-1]
|
||||
if isNan(bar.close) and len(self.bars) > 1:
|
||||
bar.open = bar.high = bar.low = bar.close = \
|
||||
self.bars[-2].close
|
||||
self.bars.updateEvent.emit(self.bars, True)
|
||||
self.emit(bar)
|
||||
self.bars.append(Bar(time))
|
||||
|
||||
def _on_timer_done(self, timer):
|
||||
self._timer = None
|
||||
self.set_done()
|
||||
|
||||
|
||||
class TickBars(Op):
|
||||
__slots__ = ('_count', 'bars')
|
||||
__doc__ = Tickfilter.tickbars.__doc__
|
||||
|
||||
bars: BarList
|
||||
|
||||
def __init__(self, count, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._count = count
|
||||
self.bars = BarList()
|
||||
|
||||
def on_source(self, time, price, size):
|
||||
if not self.bars or self.bars[-1].count == self._count:
|
||||
bar = Bar(time, price, price, price, price, size, 1)
|
||||
self.bars.append(bar)
|
||||
else:
|
||||
bar = self.bars[-1]
|
||||
bar.high = max(bar.high, price)
|
||||
bar.low = min(bar.low, price)
|
||||
bar.close = price
|
||||
bar.volume += size
|
||||
bar.count += 1
|
||||
if bar.count == self._count:
|
||||
self.bars.updateEvent.emit(self.bars, True)
|
||||
self.emit(self.bars)
|
||||
|
||||
|
||||
class VolumeBars(Op):
|
||||
__slots__ = ('_volume', 'bars')
|
||||
__doc__ = Tickfilter.volumebars.__doc__
|
||||
|
||||
bars: BarList
|
||||
|
||||
def __init__(self, volume, source=None):
|
||||
Op.__init__(self, source)
|
||||
self._volume = volume
|
||||
self.bars = BarList()
|
||||
|
||||
def on_source(self, time, price, size):
|
||||
if not self.bars or self.bars[-1].volume >= self._volume:
|
||||
bar = Bar(time, price, price, price, price, size, 1)
|
||||
self.bars.append(bar)
|
||||
else:
|
||||
bar = self.bars[-1]
|
||||
bar.high = max(bar.high, price)
|
||||
bar.low = min(bar.low, price)
|
||||
bar.close = price
|
||||
bar.volume += size
|
||||
bar.count += 1
|
||||
if bar.volume >= self._volume:
|
||||
self.bars.updateEvent.emit(self.bars, True)
|
||||
self.emit(self.bars)
|
||||
@@ -0,0 +1,550 @@
|
||||
"""Utilities."""
|
||||
|
||||
import asyncio
|
||||
import datetime as dt
|
||||
import logging
|
||||
import math
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import fields, is_dataclass
|
||||
from typing import (
|
||||
AsyncIterator, Awaitable, Callable, Iterator, List, Optional, Union)
|
||||
|
||||
import eventkit as ev
|
||||
|
||||
try:
|
||||
from zoneinfo import ZoneInfo
|
||||
except ImportError:
|
||||
from backports.zoneinfo import ZoneInfo # type: ignore
|
||||
|
||||
|
||||
globalErrorEvent = ev.Event()
|
||||
"""
|
||||
Event to emit global exceptions.
|
||||
"""
|
||||
|
||||
EPOCH = dt.datetime(1970, 1, 1, tzinfo=dt.timezone.utc)
|
||||
UNSET_INTEGER = 2 ** 31 - 1
|
||||
UNSET_DOUBLE = sys.float_info.max
|
||||
|
||||
Time_t = Union[dt.time, dt.datetime]
|
||||
|
||||
|
||||
def df(objs, labels: Optional[List[str]] = None):
|
||||
"""
|
||||
Create pandas DataFrame from the sequence of same-type objects.
|
||||
|
||||
Args:
|
||||
labels: If supplied, retain only the given labels and drop the rest.
|
||||
"""
|
||||
import pandas as pd
|
||||
from .objects import DynamicObject
|
||||
if objs:
|
||||
objs = list(objs)
|
||||
obj = objs[0]
|
||||
if is_dataclass(obj):
|
||||
df = pd.DataFrame.from_records(dataclassAsTuple(o) for o in objs)
|
||||
df.columns = [field.name for field in fields(obj)]
|
||||
elif isinstance(obj, DynamicObject):
|
||||
df = pd.DataFrame.from_records(o.__dict__ for o in objs)
|
||||
else:
|
||||
df = pd.DataFrame.from_records(objs)
|
||||
if isinstance(obj, tuple):
|
||||
_fields = getattr(obj, '_fields', None)
|
||||
if _fields:
|
||||
# assume it's a namedtuple
|
||||
df.columns = _fields
|
||||
else:
|
||||
df = None
|
||||
if labels:
|
||||
exclude = [label for label in df if label not in labels]
|
||||
df = df.drop(exclude, axis=1)
|
||||
return df
|
||||
|
||||
|
||||
def dataclassAsDict(obj) -> dict:
|
||||
"""
|
||||
Return dataclass values as ``dict``.
|
||||
This is a non-recursive variant of ``dataclasses.asdict``.
|
||||
"""
|
||||
if not is_dataclass(obj):
|
||||
raise TypeError(f'Object {obj} is not a dataclass')
|
||||
return {field.name: getattr(obj, field.name) for field in fields(obj)}
|
||||
|
||||
|
||||
def dataclassAsTuple(obj) -> tuple:
|
||||
"""
|
||||
Return dataclass values as ``tuple``.
|
||||
This is a non-recursive variant of ``dataclasses.astuple``.
|
||||
"""
|
||||
if not is_dataclass(obj):
|
||||
raise TypeError(f'Object {obj} is not a dataclass')
|
||||
return tuple(getattr(obj, field.name) for field in fields(obj))
|
||||
|
||||
|
||||
def dataclassNonDefaults(obj) -> dict:
|
||||
"""
|
||||
For a ``dataclass`` instance get the fields that are different from the
|
||||
default values and return as ``dict``.
|
||||
"""
|
||||
if not is_dataclass(obj):
|
||||
raise TypeError(f'Object {obj} is not a dataclass')
|
||||
values = [getattr(obj, field.name) for field in fields(obj)]
|
||||
return {
|
||||
field.name: value for field, value in zip(fields(obj), values)
|
||||
if value != field.default
|
||||
and value == value
|
||||
and not (isinstance(value, list) and value == [])}
|
||||
|
||||
|
||||
def dataclassUpdate(obj, *srcObjs, **kwargs) -> object:
|
||||
"""
|
||||
Update fields of the given ``dataclass`` object from zero or more
|
||||
``dataclass`` source objects and/or from keyword arguments.
|
||||
"""
|
||||
if not is_dataclass(obj):
|
||||
raise TypeError(f'Object {obj} is not a dataclass')
|
||||
for srcObj in srcObjs:
|
||||
obj.__dict__.update(dataclassAsDict(srcObj))
|
||||
obj.__dict__.update(**kwargs)
|
||||
return obj
|
||||
|
||||
|
||||
def dataclassRepr(obj) -> str:
|
||||
"""
|
||||
Provide a culled representation of the given ``dataclass`` instance,
|
||||
showing only the fields with a non-default value.
|
||||
"""
|
||||
attrs = dataclassNonDefaults(obj)
|
||||
clsName = obj.__class__.__qualname__
|
||||
kwargs = ', '.join(f'{k}={v!r}' for k, v in attrs.items())
|
||||
return f'{clsName}({kwargs})'
|
||||
|
||||
|
||||
def isnamedtupleinstance(x):
|
||||
"""From https://stackoverflow.com/a/2166841/6067848"""
|
||||
t = type(x)
|
||||
b = t.__bases__
|
||||
if len(b) != 1 or b[0] != tuple:
|
||||
return False
|
||||
f = getattr(t, '_fields', None)
|
||||
if not isinstance(f, tuple):
|
||||
return False
|
||||
return all(type(n) == str for n in f)
|
||||
|
||||
|
||||
def tree(obj):
|
||||
"""
|
||||
Convert object to a tree of lists, dicts and simple values.
|
||||
The result can be serialized to JSON.
|
||||
"""
|
||||
if isinstance(obj, (bool, int, float, str, bytes)):
|
||||
return obj
|
||||
elif isinstance(obj, (dt.date, dt.time)):
|
||||
return obj.isoformat()
|
||||
elif isinstance(obj, dict):
|
||||
return {k: tree(v) for k, v in obj.items()}
|
||||
elif isnamedtupleinstance(obj):
|
||||
return {f: tree(getattr(obj, f)) for f in obj._fields}
|
||||
elif isinstance(obj, (list, tuple, set)):
|
||||
return [tree(i) for i in obj]
|
||||
elif is_dataclass(obj):
|
||||
return {obj.__class__.__qualname__: tree(dataclassNonDefaults(obj))}
|
||||
else:
|
||||
return str(obj)
|
||||
|
||||
|
||||
def barplot(bars, title='', upColor='blue', downColor='red'):
|
||||
"""
|
||||
Create candlestick plot for the given bars. The bars can be given as
|
||||
a DataFrame or as a list of bar objects.
|
||||
"""
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.lines import Line2D
|
||||
from matplotlib.patches import Rectangle
|
||||
|
||||
if isinstance(bars, pd.DataFrame):
|
||||
ohlcTups = [
|
||||
tuple(v) for v in bars[['open', 'high', 'low', 'close']].values]
|
||||
elif bars and hasattr(bars[0], 'open_'):
|
||||
ohlcTups = [(b.open_, b.high, b.low, b.close) for b in bars]
|
||||
else:
|
||||
ohlcTups = [(b.open, b.high, b.low, b.close) for b in bars]
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.set_title(title)
|
||||
ax.grid(True)
|
||||
fig.set_size_inches(10, 6)
|
||||
for n, (open_, high, low, close) in enumerate(ohlcTups):
|
||||
if close >= open_:
|
||||
color = upColor
|
||||
bodyHi, bodyLo = close, open_
|
||||
else:
|
||||
color = downColor
|
||||
bodyHi, bodyLo = open_, close
|
||||
line = Line2D(
|
||||
xdata=(n, n),
|
||||
ydata=(low, bodyLo),
|
||||
color=color,
|
||||
linewidth=1)
|
||||
ax.add_line(line)
|
||||
line = Line2D(
|
||||
xdata=(n, n),
|
||||
ydata=(high, bodyHi),
|
||||
color=color,
|
||||
linewidth=1)
|
||||
ax.add_line(line)
|
||||
rect = Rectangle(
|
||||
xy=(n - 0.3, bodyLo),
|
||||
width=0.6,
|
||||
height=bodyHi - bodyLo,
|
||||
edgecolor=color,
|
||||
facecolor=color,
|
||||
alpha=0.4,
|
||||
antialiased=True
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
|
||||
ax.autoscale_view()
|
||||
return fig
|
||||
|
||||
|
||||
def allowCtrlC():
|
||||
"""Allow Control-C to end program."""
|
||||
signal.signal(signal.SIGINT, signal.SIG_DFL)
|
||||
|
||||
|
||||
def logToFile(path, level=logging.INFO):
|
||||
"""Create a log handler that logs to the given file."""
|
||||
logger = logging.getLogger()
|
||||
if logger.handlers:
|
||||
logging.getLogger('ib_insync').setLevel(level)
|
||||
else:
|
||||
logger.setLevel(level)
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s %(name)s %(levelname)s %(message)s')
|
||||
handler = logging.FileHandler(path)
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
|
||||
def logToConsole(level=logging.INFO):
|
||||
"""Create a log handler that logs to the console."""
|
||||
logger = logging.getLogger()
|
||||
stdHandlers = [
|
||||
h for h in logger.handlers
|
||||
if type(h) is logging.StreamHandler and h.stream is sys.stderr]
|
||||
if stdHandlers:
|
||||
# if a standard stream handler already exists, use it and
|
||||
# set the log level for the ib_insync namespace only
|
||||
logging.getLogger('ib_insync').setLevel(level)
|
||||
else:
|
||||
# else create a new handler
|
||||
logger.setLevel(level)
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s %(name)s %(levelname)s %(message)s')
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
|
||||
def isNan(x: float) -> bool:
|
||||
"""Not a number test."""
|
||||
return x != x
|
||||
|
||||
|
||||
def formatSI(n: float) -> str:
|
||||
"""Format the integer or float n to 3 significant digits + SI prefix."""
|
||||
s = ''
|
||||
if n < 0:
|
||||
n = -n
|
||||
s += '-'
|
||||
if type(n) is int and n < 1000:
|
||||
s = str(n) + ' '
|
||||
elif n < 1e-22:
|
||||
s = '0.00 '
|
||||
else:
|
||||
assert n < 9.99e26
|
||||
log = int(math.floor(math.log10(n)))
|
||||
i, j = divmod(log, 3)
|
||||
for _try in range(2):
|
||||
templ = '%.{}f'.format(2 - j)
|
||||
val = templ % (n * 10 ** (-3 * i))
|
||||
if val != '1000':
|
||||
break
|
||||
i += 1
|
||||
j = 0
|
||||
s += val + ' '
|
||||
if i != 0:
|
||||
s += 'yzafpnum kMGTPEZY'[i + 8]
|
||||
return s
|
||||
|
||||
|
||||
class timeit:
|
||||
"""Context manager for timing."""
|
||||
|
||||
def __init__(self, title='Run'):
|
||||
self.title = title
|
||||
|
||||
def __enter__(self):
|
||||
self.t0 = time.time()
|
||||
|
||||
def __exit__(self, *_args):
|
||||
print(self.title + ' took ' + formatSI(time.time() - self.t0) + 's')
|
||||
|
||||
|
||||
def run(*awaitables: Awaitable, timeout: Optional[float] = None):
|
||||
"""
|
||||
By default run the event loop forever.
|
||||
|
||||
When awaitables (like Tasks, Futures or coroutines) are given then
|
||||
run the event loop until each has completed and return their results.
|
||||
|
||||
An optional timeout (in seconds) can be given that will raise
|
||||
asyncio.TimeoutError if the awaitables are not ready within the
|
||||
timeout period.
|
||||
"""
|
||||
loop = getLoop()
|
||||
if not awaitables:
|
||||
if loop.is_running():
|
||||
return
|
||||
loop.run_forever()
|
||||
result = None
|
||||
if sys.version_info >= (3, 7):
|
||||
all_tasks = asyncio.all_tasks(loop) # type: ignore
|
||||
else:
|
||||
all_tasks = asyncio.Task.all_tasks() # type: ignore
|
||||
if all_tasks:
|
||||
# cancel pending tasks
|
||||
f = asyncio.gather(*all_tasks)
|
||||
f.cancel()
|
||||
try:
|
||||
loop.run_until_complete(f)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
else:
|
||||
if len(awaitables) == 1:
|
||||
future = awaitables[0]
|
||||
else:
|
||||
future = asyncio.gather(*awaitables)
|
||||
if timeout:
|
||||
future = asyncio.wait_for(future, timeout)
|
||||
task = asyncio.ensure_future(future)
|
||||
|
||||
def onError(_):
|
||||
task.cancel()
|
||||
|
||||
globalErrorEvent.connect(onError)
|
||||
try:
|
||||
result = loop.run_until_complete(task)
|
||||
except asyncio.CancelledError as e:
|
||||
raise globalErrorEvent.value() or e
|
||||
finally:
|
||||
globalErrorEvent.disconnect(onError)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _fillDate(time: Time_t) -> dt.datetime:
|
||||
# use today if date is absent
|
||||
if isinstance(time, dt.time):
|
||||
t = dt.datetime.combine(dt.date.today(), time)
|
||||
else:
|
||||
t = time
|
||||
return t
|
||||
|
||||
|
||||
def schedule(time: Time_t, callback: Callable, *args):
|
||||
"""
|
||||
Schedule the callback to be run at the given time with
|
||||
the given arguments.
|
||||
This will return the Event Handle.
|
||||
|
||||
Args:
|
||||
time: Time to run callback. If given as :py:class:`datetime.time`
|
||||
then use today as date.
|
||||
callback: Callable scheduled to run.
|
||||
args: Arguments for to call callback with.
|
||||
"""
|
||||
t = _fillDate(time)
|
||||
now = dt.datetime.now(t.tzinfo)
|
||||
delay = (t - now).total_seconds()
|
||||
loop = getLoop()
|
||||
return loop.call_later(delay, callback, *args)
|
||||
|
||||
|
||||
def sleep(secs: float = 0.02) -> bool:
|
||||
"""
|
||||
Wait for the given amount of seconds while everything still keeps
|
||||
processing in the background. Never use time.sleep().
|
||||
|
||||
Args:
|
||||
secs (float): Time in seconds to wait.
|
||||
"""
|
||||
run(asyncio.sleep(secs))
|
||||
return True
|
||||
|
||||
|
||||
def timeRange(start: Time_t, end: Time_t, step: float) \
|
||||
-> Iterator[dt.datetime]:
|
||||
"""
|
||||
Iterator that waits periodically until certain time points are
|
||||
reached while yielding those time points.
|
||||
|
||||
Args:
|
||||
start: Start time, can be specified as datetime.datetime,
|
||||
or as datetime.time in which case today is used as the date
|
||||
end: End time, can be specified as datetime.datetime,
|
||||
or as datetime.time in which case today is used as the date
|
||||
step (float): The number of seconds of each period
|
||||
"""
|
||||
assert step > 0
|
||||
delta = dt.timedelta(seconds=step)
|
||||
t = _fillDate(start)
|
||||
tz = dt.timezone.utc if t.tzinfo else None
|
||||
now = dt.datetime.now(tz)
|
||||
while t < now:
|
||||
t += delta
|
||||
while t <= _fillDate(end):
|
||||
waitUntil(t)
|
||||
yield t
|
||||
t += delta
|
||||
|
||||
|
||||
def waitUntil(t: Time_t) -> bool:
|
||||
"""
|
||||
Wait until the given time t is reached.
|
||||
|
||||
Args:
|
||||
t: The time t can be specified as datetime.datetime,
|
||||
or as datetime.time in which case today is used as the date.
|
||||
"""
|
||||
now = dt.datetime.now(t.tzinfo)
|
||||
secs = (_fillDate(t) - now).total_seconds()
|
||||
run(asyncio.sleep(secs))
|
||||
return True
|
||||
|
||||
|
||||
async def timeRangeAsync(
|
||||
start: Time_t,
|
||||
end: Time_t,
|
||||
step: float) -> AsyncIterator[dt.datetime]:
|
||||
"""Async version of :meth:`timeRange`."""
|
||||
assert step > 0
|
||||
delta = dt.timedelta(seconds=step)
|
||||
t = _fillDate(start)
|
||||
tz = dt.timezone.utc if t.tzinfo else None
|
||||
now = dt.datetime.now(tz)
|
||||
while t < now:
|
||||
t += delta
|
||||
while t <= _fillDate(end):
|
||||
await waitUntilAsync(t)
|
||||
yield t
|
||||
t += delta
|
||||
|
||||
|
||||
async def waitUntilAsync(t: Time_t) -> bool:
|
||||
"""Async version of :meth:`waitUntil`."""
|
||||
now = dt.datetime.now(t.tzinfo)
|
||||
secs = (_fillDate(t) - now).total_seconds()
|
||||
await asyncio.sleep(secs)
|
||||
return True
|
||||
|
||||
|
||||
def patchAsyncio():
|
||||
"""Patch asyncio to allow nested event loops."""
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
|
||||
|
||||
def getLoop():
|
||||
"""Get the asyncio event loop for the current thread."""
|
||||
return asyncio.get_event_loop_policy().get_event_loop()
|
||||
|
||||
|
||||
def startLoop():
|
||||
"""Use nested asyncio event loop for Jupyter notebooks."""
|
||||
patchAsyncio()
|
||||
|
||||
|
||||
def useQt(qtLib: str = 'PyQt5', period: float = 0.01):
|
||||
"""
|
||||
Run combined Qt5/asyncio event loop.
|
||||
|
||||
Args:
|
||||
qtLib: Name of Qt library to use:
|
||||
|
||||
* PyQt5
|
||||
* PyQt6
|
||||
* PySide2
|
||||
* PySide6
|
||||
period: Period in seconds to poll Qt.
|
||||
"""
|
||||
def qt_step():
|
||||
loop.call_later(period, qt_step)
|
||||
if not stack:
|
||||
qloop = qc.QEventLoop()
|
||||
timer = qc.QTimer()
|
||||
timer.timeout.connect(qloop.quit)
|
||||
stack.append((qloop, timer))
|
||||
qloop, timer = stack.pop()
|
||||
timer.start(0)
|
||||
qloop.exec() if qtLib == 'PyQt6' else qloop.exec_()
|
||||
timer.stop()
|
||||
stack.append((qloop, timer))
|
||||
qApp.processEvents() # type: ignore
|
||||
|
||||
if qtLib not in ('PyQt5', 'PyQt6', 'PySide2', 'PySide6'):
|
||||
raise RuntimeError(f'Unknown Qt library: {qtLib}')
|
||||
from importlib import import_module
|
||||
qc = import_module(qtLib + '.QtCore')
|
||||
qw = import_module(qtLib + '.QtWidgets')
|
||||
global qApp
|
||||
qApp = (qw.QApplication.instance() # type: ignore
|
||||
or qw.QApplication(sys.argv)) # type: ignore
|
||||
loop = getLoop()
|
||||
stack: list = []
|
||||
qt_step()
|
||||
|
||||
|
||||
def formatIBDatetime(t: Union[dt.date, dt.datetime, str, None]) -> str:
|
||||
"""Format date or datetime to string that IB uses."""
|
||||
if not t:
|
||||
s = ''
|
||||
elif isinstance(t, dt.datetime):
|
||||
# convert to UTC timezone
|
||||
t = t.astimezone(tz=dt.timezone.utc)
|
||||
s = t.strftime('%Y%m%d %H:%M:%S UTC')
|
||||
elif isinstance(t, dt.date):
|
||||
t = dt.datetime(
|
||||
t.year, t.month, t.day, 23, 59, 59).astimezone(tz=dt.timezone.utc)
|
||||
s = t.strftime('%Y%m%d %H:%M:%S UTC')
|
||||
else:
|
||||
s = t
|
||||
return s
|
||||
|
||||
|
||||
def parseIBDatetime(s: str) -> Union[dt.date, dt.datetime]:
|
||||
"""Parse string in IB date or datetime format to datetime."""
|
||||
if len(s) == 8:
|
||||
# YYYYmmdd
|
||||
y = int(s[0:4])
|
||||
m = int(s[4:6])
|
||||
d = int(s[6:8])
|
||||
t = dt.date(y, m, d)
|
||||
elif s.isdigit():
|
||||
t = dt.datetime.fromtimestamp(int(s), dt.timezone.utc)
|
||||
elif s.count(' ') >= 2 and ' ' not in s:
|
||||
# 20221125 10:00:00 Europe/Amsterdam
|
||||
s0, s1, s2 = s.split(' ', 2)
|
||||
t = dt.datetime.strptime(s0 + s1, '%Y%m%d%H:%M:%S')
|
||||
t = t.replace(tzinfo=ZoneInfo(s2))
|
||||
else:
|
||||
# YYYYmmdd HH:MM:SS
|
||||
# or
|
||||
# YYYY-mm-dd HH:MM:SS.0
|
||||
ss = s.replace(' ', '').replace('-', '')[:16]
|
||||
t = dt.datetime.strptime(ss, '%Y%m%d%H:%M:%S')
|
||||
return t
|
||||
@@ -0,0 +1,4 @@
|
||||
"""Version info."""
|
||||
|
||||
__version_info__ = (0, 9, 86)
|
||||
__version__ = '.'.join(str(v) for v in __version_info__)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user