updated it to accept a --output_dir option and defaulot it to output.
This commit is contained in:
@@ -1 +1 @@
|
||||
3.11.4
|
||||
3.9
|
||||
|
||||
@@ -235,6 +235,8 @@ def parse_arguments():
|
||||
help='Number of worker processes for data preprocessing. Defaults to (logical cores - 2).')
|
||||
parser.add_argument('--monitor_resources', action='store_true',
|
||||
help='Enable real-time resource monitoring.')
|
||||
parser.add_argument('--output_dir', type=str, default='output',
|
||||
help='Directory where all output files will be saved.')
|
||||
return parser.parse_args()
|
||||
|
||||
# ============================
|
||||
@@ -517,6 +519,8 @@ def train_and_evaluate_dqn(hyperparams, env_params, total_timesteps, eval_episod
|
||||
def main():
|
||||
args = parse_arguments()
|
||||
csv_path = args.csv_path
|
||||
output_dir = args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
lstm_window_size = args.lstm_window_size
|
||||
dqn_total_timesteps = args.dqn_total_timesteps
|
||||
dqn_eval_episodes = args.dqn_eval_episodes
|
||||
@@ -529,8 +533,10 @@ def main():
|
||||
# -----------------------------
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
handlers=[logging.FileHandler("LSTMDQN.log"), logging.StreamHandler(sys.stdout)])
|
||||
|
||||
handlers=[
|
||||
logging.FileHandler(os.path.join(output_dir, "LSTMDQN.log")),
|
||||
logging.StreamHandler(sys.stdout)
|
||||
])
|
||||
# -----------------------------
|
||||
# Resource Detection & Logging
|
||||
# -----------------------------
|
||||
@@ -703,7 +709,7 @@ def main():
|
||||
plt.title('LSTM: Actual vs Predicted Closing Prices')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.savefig('lstm_actual_vs_pred.png')
|
||||
plt.savefig(os.path.join(output_dir, 'lstm_actual_vs_pred.png')
|
||||
plt.close()
|
||||
|
||||
table = []
|
||||
@@ -718,9 +724,9 @@ def main():
|
||||
_r2, _diracc = evaluate_final_lstm(final_lstm, X_test, y_test)
|
||||
|
||||
# 9) Save final LSTM model and scalers
|
||||
final_lstm.save('best_lstm_model.h5')
|
||||
joblib.dump(scaler_features, 'scaler_features.pkl')
|
||||
joblib.dump(scaler_target, 'scaler_target.pkl')
|
||||
final_lstm.save(os.path.join(output_dir, 'best_lstm_model.h5')
|
||||
joblib.dump(scaler_features, os.path.join(output_dir, 'scaler_features.pkl')
|
||||
joblib.dump(scaler_target, os.path.join(output_dir, 'scaler_target.pkl')
|
||||
logging.info("Saved best LSTM model and scaler objects (best_lstm_model.h5, scaler_features.pkl, scaler_target.pkl).")
|
||||
|
||||
############################################################
|
||||
@@ -766,7 +772,7 @@ def main():
|
||||
if net_worth >= PERFORMANCE_THRESHOLD:
|
||||
logging.info("Agent meets performance criteria!")
|
||||
best_agent = agent
|
||||
best_agent.save("best_dqn_model_lstm.zip")
|
||||
best_agent.save(os.path.join(output_dir, "best_dqn_model_lstm.zip")
|
||||
break
|
||||
else:
|
||||
logging.info("Performance below threshold. Adjusting hyperparameters and retrying...")
|
||||
|
||||
Reference in New Issue
Block a user