Files
gwitt1Repo/MidasHMM/hmm/midas/analysis.py
2025-01-29 23:39:42 -05:00

50 lines
1.7 KiB
Python

# midas/analysis.py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.lines import Line2D


class MarketRegimeAnalysis:
    """Plotting helpers for inspecting the regimes decoded by a fitted HMM.

    The wrapped model must expose ``predict(features)``, ``n_components`` and
    ``transmat_`` -- the only attributes this class reads. That matches the
    hmmlearn estimator API, but the concrete model type is an assumption
    (TODO confirm against callers).
    """

    def __init__(self, model, features):
        """Decode the most likely regime for every observation in *features*.

        Args:
            model: fitted HMM-like object exposing ``predict``,
                ``n_components`` and ``transmat_``.
            features: observations, passed unchanged to ``model.predict``.
        """
        self.model = model
        self.features = features
        # np.asarray so element-wise comparisons (``states == k``) and the
        # run-length detection below also work if ``predict`` returns a list.
        self.states = np.asarray(model.predict(features))

    @staticmethod
    def _regime_durations(states) -> np.ndarray:
        """Return the length of every consecutive run of identical states, in order.

        An empty input gives an empty array; a constant sequence of length n
        gives ``[n]``.
        """
        states = np.asarray(states)
        if states.size == 0:
            return np.array([], dtype=int)
        # Each run starts at index 0 or at a position whose state differs
        # from its predecessor. Including index 0 is what keeps the first
        # run from being dropped.
        run_starts = np.concatenate(
            ([0], np.flatnonzero(states[1:] != states[:-1]) + 1)
        )
        # Run lengths are the gaps between consecutive starts, with the
        # sequence length closing the final run.
        return np.diff(np.append(run_starts, states.size))

    def plot_regimes(self, prices: pd.Series):
        """Scatter-plot *prices* over time, coloured by decoded regime.

        Args:
            prices: price series. Its index and values are masked
                positionally with ``self.states``, so it must have the same
                length and order as the ``features`` given to ``__init__``.

        Returns:
            The ``matplotlib.pyplot`` module, with the new figure current,
            so callers can chain ``.show()`` or ``.savefig(...)``.
        """
        plt.figure(figsize=(15, 8))
        palette = sns.color_palette("husl", n_colors=self.model.n_components)
        for state in range(self.model.n_components):
            mask = self.states == state
            plt.scatter(prices.index[mask], prices[mask],
                        color=palette[state], s=10, label=f'Regime {state}')
        plt.title("Market Regime Visualization")
        plt.xlabel("Date")
        plt.ylabel("Price")
        plt.legend()
        return plt

    def plot_transition_matrix(self):
        """Draw the model's state-transition matrix as an annotated heatmap.

        Rows are the current state and columns are the next state, matching
        the axis labels.

        Returns:
            The ``matplotlib.pyplot`` module, with the new figure current.
        """
        transmat = self.model.transmat_
        plt.figure(figsize=(10, 8))
        # Heatmap x ticks label columns (shape[1]) and y ticks label rows
        # (shape[0]).
        sns.heatmap(transmat, annot=True, fmt=".2f", cmap="Blues",
                    xticklabels=range(transmat.shape[1]),
                    yticklabels=range(transmat.shape[0]))
        plt.title("State Transition Probabilities")
        plt.xlabel("Next State")
        plt.ylabel("Current State")
        return plt

    def plot_state_durations(self):
        """Plot a histogram (with KDE) of how many bars each regime run lasts.

        Every maximal run of consecutive identical states counts once,
        including the first and the last run.

        Returns:
            The ``matplotlib.pyplot`` module, with the new figure current.
        """
        durations = self._regime_durations(self.states)
        plt.figure(figsize=(10, 6))
        sns.histplot(durations, bins=30, kde=True)
        plt.title("Regime Duration Distribution")
        plt.xlabel("Duration (Bars)")
        plt.ylabel("Frequency")
        return plt