updated it to accept a --output_dir option and defaulot it to output.

This commit is contained in:
2025-03-03 03:37:49 +00:00
parent 9ef6f1ba64
commit f1a63d5642
2 changed files with 14 additions and 8 deletions

View File

@@ -1 +1 @@
3.11.4
3.9

View File

@@ -235,6 +235,8 @@ def parse_arguments():
help='Number of worker processes for data preprocessing. Defaults to (logical cores - 2).')
parser.add_argument('--monitor_resources', action='store_true',
help='Enable real-time resource monitoring.')
parser.add_argument('--output_dir', type=str, default='output',
help='Directory where all output files will be saved.')
return parser.parse_args()
# ============================
@@ -517,6 +519,8 @@ def train_and_evaluate_dqn(hyperparams, env_params, total_timesteps, eval_episod
def main():
args = parse_arguments()
csv_path = args.csv_path
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
lstm_window_size = args.lstm_window_size
dqn_total_timesteps = args.dqn_total_timesteps
dqn_eval_episodes = args.dqn_eval_episodes
@@ -529,8 +533,10 @@ def main():
# -----------------------------
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.FileHandler("LSTMDQN.log"), logging.StreamHandler(sys.stdout)])
handlers=[
logging.FileHandler(os.path.join(output_dir, "LSTMDQN.log")),
logging.StreamHandler(sys.stdout)
])
# -----------------------------
# Resource Detection & Logging
# -----------------------------
@@ -703,7 +709,7 @@ def main():
plt.title('LSTM: Actual vs Predicted Closing Prices')
plt.legend()
plt.grid(True)
plt.savefig('lstm_actual_vs_pred.png')
plt.savefig(os.path.join(output_dir, 'lstm_actual_vs_pred.png')
plt.close()
table = []
@@ -718,9 +724,9 @@ def main():
_r2, _diracc = evaluate_final_lstm(final_lstm, X_test, y_test)
# 9) Save final LSTM model and scalers
final_lstm.save('best_lstm_model.h5')
joblib.dump(scaler_features, 'scaler_features.pkl')
joblib.dump(scaler_target, 'scaler_target.pkl')
final_lstm.save(os.path.join(output_dir, 'best_lstm_model.h5')
joblib.dump(scaler_features, os.path.join(output_dir, 'scaler_features.pkl')
joblib.dump(scaler_target, os.path.join(output_dir, 'scaler_target.pkl')
logging.info("Saved best LSTM model and scaler objects (best_lstm_model.h5, scaler_features.pkl, scaler_target.pkl).")
############################################################
@@ -766,7 +772,7 @@ def main():
if net_worth >= PERFORMANCE_THRESHOLD:
logging.info("Agent meets performance criteria!")
best_agent = agent
best_agent.save("best_dqn_model_lstm.zip")
best_agent.save(os.path.join(output_dir, "best_dqn_model_lstm.zip")
break
else:
logging.info("Performance below threshold. Adjusting hyperparameters and retrying...")