"""
Common utility for Plotly figures.
"""
__author__ = "York <york.jong@gmail.com>"
__date__ = "2023/02/09 (initial version) ~ 2024/10/03 (last revision)"
__all__ = [
'get_candlestick_colors',
'get_volume_colors',
'hide_nontrading_periods',
'add_crosshair_cursor',
'add_hovermode_menu',
]
import pandas as pd
from ..utils import MarketColorStyle
[docs]
def get_candlestick_colors(market_color_style=MarketColorStyle.WESTERN):
    """Return the candlestick line/fill colors for a market color style.

    Parameters
    ----------
    market_color_style: MarketColorStyle
        the color convention.  With ``MarketColorStyle.WESTERN`` increasing
        candles are green and decreasing candles are red; with any other
        value the two color sets are swapped.

    Returns
    -------
    dict
        a dictionary with the keys ``increasing_line_color``,
        ``increasing_fillcolor``, ``decreasing_line_color`` and
        ``decreasing_fillcolor``.
    """
    # Each pair is (line color, semi-transparent fill color).
    green = ('#32a455', 'rgba(50, 164, 85, 0.4)')
    red = ('#d71917', 'rgba(215, 25, 23, 0.4)')

    is_western = market_color_style == MarketColorStyle.WESTERN
    rising, falling = (green, red) if is_western else (red, green)

    return {
        'increasing_line_color': rising[0],
        'increasing_fillcolor': rising[1],
        'decreasing_line_color': falling[0],
        'decreasing_fillcolor': falling[1],
    }
[docs]
def get_volume_colors(market_color_style=MarketColorStyle.WESTERN):
    """Return the volume colors for a market color style.

    Parameters
    ----------
    market_color_style: MarketColorStyle
        the color convention.  With ``MarketColorStyle.WESTERN`` 'up' is
        green and 'down' is red; with any other value they are swapped.

    Returns
    -------
    dict
        a dictionary with the keys ``up`` and ``down``.
    """
    is_western = market_color_style == MarketColorStyle.WESTERN
    up_color, down_color = ('green', 'red') if is_western else ('red', 'green')
    return {'up': up_color, 'down': down_color}
#------------------------------------------------------------------------------
[docs]
def hide_nontrading_periods(fig, df, interval):
    """Hide non-trading time-periods.

    This function hides the time-periods that have no row in `df` to avoid
    the gaps at non-trading time-periods.  The x-axes of `fig` are updated
    in place; `df` itself is not modified.  An empty `df` leaves the figure
    untouched.

    Parameters
    ----------
    fig: plotly.graph_objects.Figure
        the figure
    df: pandas.DataFrame
        the stock table.  Its index holds the timestamps of the available
        items; a non-datetime index is converted with
        ``pandas.to_datetime`` on a local copy.
    interval: str
        the interval of an OHLC item.
        Valid values are 1m, 2m, 5m, 15m, 30m, 60m, 90m, 1h, 1d, 5d, 1wk, 1mo,
        3mo. Intraday data cannot extend last 60 days:

        - 1m - max 7 days within last 30 days
        - up to 90m - max 60 days
        - 60m, 1h - max 730 days (yes 1h is technically < 90m but this what
          Yahoo does)

    Raises
    ------
    ValueError
        if `interval` does not end with a supported unit ('m', 'h', 'd',
        'wk', 'mo') or its count is not an integer.
    """
    # An empty table has no first/last timestamp and nothing to hide.
    if len(df.index) == 0:
        return

    # Work on a local datetime index so the caller's DataFrame keeps its own.
    index = df.index
    if not isinstance(index, pd.DatetimeIndex):
        index = pd.to_datetime(index)

    # Map the unit suffix of `interval` to a pandas offset.  Offset objects
    # are used instead of alias strings because aliases such as 'H' and 'M'
    # are deprecated in recent pandas releases.
    offset_makers = (
        ('mo', pd.offsets.MonthEnd),
        ('wk', lambda n: pd.offsets.Week(n, weekday=6)),  # Sunday-anchored
        ('m', pd.offsets.Minute),
        ('h', pd.offsets.Hour),
        ('d', pd.offsets.Day),
    )
    for unit, make_offset in offset_makers:
        if interval.endswith(unit):
            count_text = interval[:-len(unit)]
            break
    else:
        raise ValueError("unsupported interval: {!r}".format(interval))
    count = int(count_text) if count_text else 1  # a bare unit means 1

    # Calculate nontrading time-periods: every expected timestamp between
    # the first and the last item that is missing from the index.
    dt_all = pd.date_range(start=index[0], end=index[-1],
                           freq=make_offset(count))
    dt_breaks = dt_all.difference(index)

    # Calculate dvalue (the size of one hidden item) in milliseconds
    if unit == 'm':  # minute
        dvalue = 60 * count * 1000
    elif unit == 'h':  # hour
        dvalue = 60*60 * count * 1000
    else:
        dvalue = 24*60*60 * 1000  # 1 day in milliseconds

    # Update xaxes to hide non-trading time-periods
    fig.update_xaxes(rangebreaks=[dict(values=dt_breaks, dvalue=dvalue)])
#------------------------------------------------------------------------------
[docs]
def add_crosshair_cursor(fig):
    """Add crosshair cursor to a given figure.

    Spike lines with the same appearance are configured on the y-axes and
    then on the x-axes, and the layout hovermode is set to 'x'.

    Parameters
    ----------
    fig: plotly.graph_objects.Figure
        the figure (updated in place)
    """
    # One shared spike-line style: a thin solid grey line drawn across the
    # plot and snapped to the cursor.
    spike_style = {
        'spikemode': 'across',
        'spikesnap': 'cursor',
        'spikethickness': 1,
        'spikedash': 'solid',
        'spikecolor': 'grey',
    }
    fig.update_yaxes(**spike_style)
    fig.update_xaxes(**spike_style)

    # Other hovermode choices: 'x', 'y', 'closest', False,
    # 'x unified', 'y unified'
    fig.update_layout(hovermode='x')
#------------------------------------------------------------------------------