import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms
[docs]
def highlight_text(explanation,
input_tokens=None,
show_plot=True,
output_filename=None,
colormap="bwr",
alpha=1.0,
heatmap_range=(-1, 1)):
"""Highlights a given text based on values in a given explanation object.
Args:
explanation: list of tuples of (word, index of word in original data, importance)
input_tokens: list of all tokens (including those without importance)
show_plot: Shows plot if true (for testing or writing plots to disk instead)
output_filename: Name of the file to save the plot to (optional).
colormap: color map for the heatmap plot (see mpl.Axes.imshow documentation for options).
alpha: alpha value for the heatmap plot.
heatmap_range: a tuple (vmin, vmax) to set the range of the heatmap.
Returns:
None
"""
tokens, _, importances = zip(*explanation)
if input_tokens:
# Make a list of tuples (token, i, importance) for each token in the
# input_tokens. if a token isnot in the explanantion, the importance is
# None
explanation = [
(
token, i, importances[tokens.index(token)] if token in tokens else None
)
for i, token in enumerate(input_tokens)
]
vmin, vmax = heatmap_range
x, y = (0, 0) # the initial position of the text
space_token = ' '
fig, ax = plt.subplots(figsize=(10, 1))
ax.axis('off')
for token, _, importance in explanation:
color = _get_text_color(importance, vmin, vmax, colormap, alpha)
text = ax.text(x, y, token, fontsize=12, backgroundcolor=color)
# Get the bounding box of the text in display space and convert it to data space
bbox = text.get_window_extent()
bbox_data = mtransforms.Bbox(ax.transData.inverted().transform(bbox))
x = bbox_data.x1
# Add a space after each token
text = ax.text(x, y, space_token, fontsize=12)
bbox = text.get_window_extent()
bbox_data = mtransforms.Bbox(ax.transData.inverted().transform(bbox))
# The next x is the right side of the bbox plus the fixed space
x = bbox_data.x1
# Wrap the text if token is a dot
if token == '.':
x = 0
y -= 0.5 # space between lines in inches
# adjust the height of the figure
y_hight = 1
if abs(y) > 1:
y_hight = y
ax.set_ylim(y_hight, 0)
fig.set_figheight(abs(y_hight))
# add colorbar
sm = plt.cm.ScalarMappable(
cmap=plt.get_cmap(colormap),
norm=plt.Normalize(vmin, vmax)
)
sm.set_array([])
fig.colorbar(sm, ax=ax, orientation='horizontal', aspect=20, use_gridspec=True)
# TODO add alpha to the colorbar
if not show_plot:
plt.close()
if output_filename:
plt.savefig(output_filename)
return fig, ax
[docs]
def _get_text_color(importance, vmin, vmax, colormap, alpha):
"""Assign a color to a text based on its importance.
Args:
importance (float): The importance of the text (between vmin and vmax)
vmin (float): The minimum value of the importance range
vmax (float): The maximum value of the importance range
colormap (str): color map for the heatmap plot (see mpl.Axes.imshow
documentation for options).
alpha (float): alpha value for the color.
Returns:
tuple: (r, g, b, alpha) values of the color.
"""
if importance is None:
return "none"
cmap = plt.get_cmap(colormap)
norm = plt.Normalize(vmin, vmax)
r, g, b, _ = cmap(norm(importance))
return (r, g, b, alpha)