Starter Code
This commit is contained in:
49
MidasHMM/hmm/midas/analysis.py
Normal file
49
MidasHMM/hmm/midas/analysis.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# midas/analysis.py
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from matplotlib.lines import Line2D
|
||||
import numpy as np
|
||||
|
||||
class MarketRegimeAnalysis:
|
||||
def __init__(self, model, features):
|
||||
self.model = model
|
||||
self.features = features
|
||||
self.states = model.predict(features)
|
||||
|
||||
def plot_regimes(self, prices: pd.Series):
|
||||
plt.figure(figsize=(15, 8))
|
||||
palette = sns.color_palette("husl", n_colors=self.model.n_components)
|
||||
|
||||
for state in range(self.model.n_components):
|
||||
mask = self.states == state
|
||||
plt.scatter(prices.index[mask], prices[mask],
|
||||
color=palette[state], s=10, label=f'Regime {state}')
|
||||
|
||||
plt.title("Market Regime Visualization")
|
||||
plt.xlabel("Date")
|
||||
plt.ylabel("Price")
|
||||
plt.legend()
|
||||
return plt
|
||||
|
||||
def plot_transition_matrix(self):
|
||||
transmat = self.model.transmat_
|
||||
plt.figure(figsize=(10, 8))
|
||||
sns.heatmap(transmat, annot=True, fmt=".2f", cmap="Blues",
|
||||
xticklabels=range(transmat.shape[0]),
|
||||
yticklabels=range(transmat.shape[1]))
|
||||
plt.title("State Transition Probabilities")
|
||||
plt.xlabel("Next State")
|
||||
plt.ylabel("Current State")
|
||||
return plt
|
||||
|
||||
def plot_state_durations(self):
|
||||
state_changes = np.diff(self.states, prepend=self.states[0])
|
||||
change_points = np.where(state_changes != 0)[0]
|
||||
durations = np.diff(np.append(change_points, len(self.states)))
|
||||
|
||||
plt.figure(figsize=(10, 6))
|
||||
sns.histplot(durations, bins=30, kde=True)
|
||||
plt.title("Regime Duration Distribution")
|
||||
plt.xlabel("Duration (Bars)")
|
||||
plt.ylabel("Frequency")
|
||||
return plt
|
||||
Reference in New Issue
Block a user