updated it to accept a --output_dir option and defaulot it to output.
This commit is contained in:
@@ -1 +1 @@
|
|||||||
3.11.4
|
3.9
|
||||||
|
|||||||
@@ -235,6 +235,8 @@ def parse_arguments():
|
|||||||
help='Number of worker processes for data preprocessing. Defaults to (logical cores - 2).')
|
help='Number of worker processes for data preprocessing. Defaults to (logical cores - 2).')
|
||||||
parser.add_argument('--monitor_resources', action='store_true',
|
parser.add_argument('--monitor_resources', action='store_true',
|
||||||
help='Enable real-time resource monitoring.')
|
help='Enable real-time resource monitoring.')
|
||||||
|
parser.add_argument('--output_dir', type=str, default='output',
|
||||||
|
help='Directory where all output files will be saved.')
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
# ============================
|
# ============================
|
||||||
@@ -517,6 +519,8 @@ def train_and_evaluate_dqn(hyperparams, env_params, total_timesteps, eval_episod
|
|||||||
def main():
|
def main():
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
csv_path = args.csv_path
|
csv_path = args.csv_path
|
||||||
|
output_dir = args.output_dir
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
lstm_window_size = args.lstm_window_size
|
lstm_window_size = args.lstm_window_size
|
||||||
dqn_total_timesteps = args.dqn_total_timesteps
|
dqn_total_timesteps = args.dqn_total_timesteps
|
||||||
dqn_eval_episodes = args.dqn_eval_episodes
|
dqn_eval_episodes = args.dqn_eval_episodes
|
||||||
@@ -529,8 +533,10 @@ def main():
|
|||||||
# -----------------------------
|
# -----------------------------
|
||||||
logging.basicConfig(level=logging.INFO,
|
logging.basicConfig(level=logging.INFO,
|
||||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
handlers=[logging.FileHandler("LSTMDQN.log"), logging.StreamHandler(sys.stdout)])
|
handlers=[
|
||||||
|
logging.FileHandler(os.path.join(output_dir, "LSTMDQN.log")),
|
||||||
|
logging.StreamHandler(sys.stdout)
|
||||||
|
])
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
# Resource Detection & Logging
|
# Resource Detection & Logging
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
@@ -703,7 +709,7 @@ def main():
|
|||||||
plt.title('LSTM: Actual vs Predicted Closing Prices')
|
plt.title('LSTM: Actual vs Predicted Closing Prices')
|
||||||
plt.legend()
|
plt.legend()
|
||||||
plt.grid(True)
|
plt.grid(True)
|
||||||
plt.savefig('lstm_actual_vs_pred.png')
|
plt.savefig(os.path.join(output_dir, 'lstm_actual_vs_pred.png')
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
table = []
|
table = []
|
||||||
@@ -718,9 +724,9 @@ def main():
|
|||||||
_r2, _diracc = evaluate_final_lstm(final_lstm, X_test, y_test)
|
_r2, _diracc = evaluate_final_lstm(final_lstm, X_test, y_test)
|
||||||
|
|
||||||
# 9) Save final LSTM model and scalers
|
# 9) Save final LSTM model and scalers
|
||||||
final_lstm.save('best_lstm_model.h5')
|
final_lstm.save(os.path.join(output_dir, 'best_lstm_model.h5')
|
||||||
joblib.dump(scaler_features, 'scaler_features.pkl')
|
joblib.dump(scaler_features, os.path.join(output_dir, 'scaler_features.pkl')
|
||||||
joblib.dump(scaler_target, 'scaler_target.pkl')
|
joblib.dump(scaler_target, os.path.join(output_dir, 'scaler_target.pkl')
|
||||||
logging.info("Saved best LSTM model and scaler objects (best_lstm_model.h5, scaler_features.pkl, scaler_target.pkl).")
|
logging.info("Saved best LSTM model and scaler objects (best_lstm_model.h5, scaler_features.pkl, scaler_target.pkl).")
|
||||||
|
|
||||||
############################################################
|
############################################################
|
||||||
@@ -766,7 +772,7 @@ def main():
|
|||||||
if net_worth >= PERFORMANCE_THRESHOLD:
|
if net_worth >= PERFORMANCE_THRESHOLD:
|
||||||
logging.info("Agent meets performance criteria!")
|
logging.info("Agent meets performance criteria!")
|
||||||
best_agent = agent
|
best_agent = agent
|
||||||
best_agent.save("best_dqn_model_lstm.zip")
|
best_agent.save(os.path.join(output_dir, "best_dqn_model_lstm.zip")
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
logging.info("Performance below threshold. Adjusting hyperparameters and retrying...")
|
logging.info("Performance below threshold. Adjusting hyperparameters and retrying...")
|
||||||
|
|||||||
Reference in New Issue
Block a user