50 lines
1.7 KiB
Python
50 lines
1.7 KiB
Python
# midas/analysis.py
|
|
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.lines import Line2D
|
|
|
|
class MarketRegimeAnalysis:
    """Visualize the regimes (hidden states) a fitted model assigns to market data.

    The model must expose ``predict(features)``, ``n_components`` and
    ``transmat_``; presumably an hmmlearn-style HMM (TODO confirm).
    """

    def __init__(self, model, features):
        """Store the model/features and decode the per-bar regime labels.

        Args:
            model: Fitted model providing ``predict``, ``n_components`` and
                ``transmat_``.
            features: Input passed verbatim to ``model.predict``.
        """
        self.model = model
        self.features = features
        # One regime label per row of ``features``.
        self.states = model.predict(features)

    def plot_regimes(self, prices: pd.Series):
        """Scatter-plot ``prices`` coloured by the regime of each bar.

        Args:
            prices: Price series; assumed to be row-aligned with the
                ``features`` used to compute ``self.states`` (TODO confirm
                at call sites).

        Returns:
            The ``matplotlib.pyplot`` module, so callers can ``show()`` or
            ``savefig()`` the current figure.
        """
        plt.figure(figsize=(15, 8))
        n_regimes = self.model.n_components
        palette = sns.color_palette("husl", n_colors=n_regimes)

        for state in range(n_regimes):
            mask = self.states == state
            plt.scatter(prices.index[mask], prices[mask],
                        color=palette[state], s=10, label=f'Regime {state}')

        plt.title("Market Regime Visualization")
        plt.xlabel("Date")
        plt.ylabel("Price")
        plt.legend()
        return plt

    def plot_transition_matrix(self):
        """Draw the model's state transition matrix as an annotated heatmap.

        Rows are the current state and columns the next state.

        Returns:
            The ``matplotlib.pyplot`` module.
        """
        transmat = self.model.transmat_
        plt.figure(figsize=(10, 8))
        # Heatmap x ticks label the columns (axis 1) and y ticks the rows (axis 0).
        sns.heatmap(transmat, annot=True, fmt=".2f", cmap="Blues",
                    xticklabels=range(transmat.shape[1]),
                    yticklabels=range(transmat.shape[0]))
        plt.title("State Transition Probabilities")
        plt.xlabel("Next State")
        plt.ylabel("Current State")
        return plt

    @staticmethod
    def _regime_durations(states) -> np.ndarray:
        """Return the length of every consecutive run of identical labels.

        Every run is counted, including the first one and the single run of a
        sequence that never changes. The durations sum to ``len(states)``; an
        empty input yields an empty array.
        """
        states = np.asarray(states)
        if states.size == 0:
            return np.array([], dtype=int)
        # Compare neighbours directly rather than via np.diff, so unsigned or
        # non-numeric labels cannot wrap around or fail.
        boundaries = np.flatnonzero(states[1:] != states[:-1]) + 1
        starts = np.concatenate(([0], boundaries))
        return np.diff(np.append(starts, states.size))

    def plot_state_durations(self):
        """Plot a histogram (with KDE) of how many bars each regime run lasts.

        Returns:
            The ``matplotlib.pyplot`` module.
        """
        durations = self._regime_durations(self.states)

        plt.figure(figsize=(10, 6))
        sns.histplot(durations, bins=30, kde=True)
        plt.title("Regime Duration Distribution")
        plt.xlabel("Duration (Bars)")
        plt.ylabel("Frequency")
        return plt
|