updated it to accept a --output_dir option and defaulot it to output.

This commit is contained in:
2025-03-03 03:37:49 +00:00
parent 9ef6f1ba64
commit f1a63d5642
2 changed files with 14 additions and 8 deletions

View File

@@ -1 +1 @@
3.11.4 3.9

View File

@@ -235,6 +235,8 @@ def parse_arguments():
help='Number of worker processes for data preprocessing. Defaults to (logical cores - 2).') help='Number of worker processes for data preprocessing. Defaults to (logical cores - 2).')
parser.add_argument('--monitor_resources', action='store_true', parser.add_argument('--monitor_resources', action='store_true',
help='Enable real-time resource monitoring.') help='Enable real-time resource monitoring.')
parser.add_argument('--output_dir', type=str, default='output',
help='Directory where all output files will be saved.')
return parser.parse_args() return parser.parse_args()
# ============================ # ============================
@@ -517,6 +519,8 @@ def train_and_evaluate_dqn(hyperparams, env_params, total_timesteps, eval_episod
def main(): def main():
args = parse_arguments() args = parse_arguments()
csv_path = args.csv_path csv_path = args.csv_path
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
lstm_window_size = args.lstm_window_size lstm_window_size = args.lstm_window_size
dqn_total_timesteps = args.dqn_total_timesteps dqn_total_timesteps = args.dqn_total_timesteps
dqn_eval_episodes = args.dqn_eval_episodes dqn_eval_episodes = args.dqn_eval_episodes
@@ -529,8 +533,10 @@ def main():
# ----------------------------- # -----------------------------
logging.basicConfig(level=logging.INFO, logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s', format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.FileHandler("LSTMDQN.log"), logging.StreamHandler(sys.stdout)]) handlers=[
logging.FileHandler(os.path.join(output_dir, "LSTMDQN.log")),
logging.StreamHandler(sys.stdout)
])
# ----------------------------- # -----------------------------
# Resource Detection & Logging # Resource Detection & Logging
# ----------------------------- # -----------------------------
@@ -703,7 +709,7 @@ def main():
plt.title('LSTM: Actual vs Predicted Closing Prices') plt.title('LSTM: Actual vs Predicted Closing Prices')
plt.legend() plt.legend()
plt.grid(True) plt.grid(True)
plt.savefig('lstm_actual_vs_pred.png') plt.savefig(os.path.join(output_dir, 'lstm_actual_vs_pred.png')
plt.close() plt.close()
table = [] table = []
@@ -718,9 +724,9 @@ def main():
_r2, _diracc = evaluate_final_lstm(final_lstm, X_test, y_test) _r2, _diracc = evaluate_final_lstm(final_lstm, X_test, y_test)
# 9) Save final LSTM model and scalers # 9) Save final LSTM model and scalers
final_lstm.save('best_lstm_model.h5') final_lstm.save(os.path.join(output_dir, 'best_lstm_model.h5')
joblib.dump(scaler_features, 'scaler_features.pkl') joblib.dump(scaler_features, os.path.join(output_dir, 'scaler_features.pkl')
joblib.dump(scaler_target, 'scaler_target.pkl') joblib.dump(scaler_target, os.path.join(output_dir, 'scaler_target.pkl')
logging.info("Saved best LSTM model and scaler objects (best_lstm_model.h5, scaler_features.pkl, scaler_target.pkl).") logging.info("Saved best LSTM model and scaler objects (best_lstm_model.h5, scaler_features.pkl, scaler_target.pkl).")
############################################################ ############################################################
@@ -766,7 +772,7 @@ def main():
if net_worth >= PERFORMANCE_THRESHOLD: if net_worth >= PERFORMANCE_THRESHOLD:
logging.info("Agent meets performance criteria!") logging.info("Agent meets performance criteria!")
best_agent = agent best_agent = agent
best_agent.save("best_dqn_model_lstm.zip") best_agent.save(os.path.join(output_dir, "best_dqn_model_lstm.zip")
break break
else: else:
logging.info("Performance below threshold. Adjusting hyperparameters and retrying...") logging.info("Performance below threshold. Adjusting hyperparameters and retrying...")