"""
NeuroLearn Plotting Tools
=========================
Functions for plotting brain images, model predictions, and distance matrices.
"""
# Names exported by ``from <this module> import *``.
__all__ = [
    # prediction / classification output plots
    "dist_from_hyperplane_plot",
    "scatterplot",
    "probability_plot",
    "roc_plot",
    # distance-matrix plots
    "plot_stacked_adjacency",
    "plot_mean_label_distance",
    "plot_between_label_distance",
    "plot_silhouette",
    # brain image plots
    "plot_t_brain",
    "plot_brain",
    "plot_interactive_brain",
]
__author__ = ["Luke Chang"]
__license__ = "MIT"
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from numpy.fft import fft, fftfreq
from nltools.stats import two_sample_permutation, one_sample_permutation
from nilearn.plotting import plot_glass_brain, plot_stat_map, view_img, view_img_on_surf
from nltools.prefs import MNI_Template, resolve_mni_path
from nltools.utils import attempt_to_import
import warnings
import sklearn
import os
# Optional dependency: the interactive viewers need ipywidgets. The functions
# below treat ``ipywidgets is None`` as "not installed" and raise ImportError.
ipywidgets = attempt_to_import(
    "ipywidgets",
    fromlist=["interact", "fixed", "widgets", "BoundedFloatText", "BoundedIntText"],
    name="ipywidgets",
)
def plot_interactive_brain(
    brain,
    threshold=1e-6,
    surface=False,
    percentile_threshold=False,
    anatomical=None,
    **kwargs,
):
    """
    Create an interactive brain viewer with nilearn's javascript-based viewers.

    A threshold box (and, for 2d data, a volume slider) is connected to
    ``_viewer`` through ``ipywidgets.interact``; ``_viewer`` renders the data
    with ``nilearn.plotting.view_img`` or ``nilearn.plotting.view_img_on_surf``.

    Args:
        brain (nltools.Brain_Data): a Brain_Data instance of 1d or 2d shape (i.e. 3d or 4d volume)
        threshold (float/str): threshold used to initialize the visualization; may be a
            percentile string ending in '%' (e.g. '95%' or '97.5%'); default 1e-6
        surface (bool): whether to create a surface-based plot; default False
        percentile_threshold (bool): whether to interpret threshold values as percentiles;
            forced to True when ``threshold`` is a percentile string
        anatomical: optional background image for volume plots; 'MNI152' is used when not given
        kwargs: optional arguments to nilearn.view_img or nilearn.view_img_on_surf.
            They are forwarded unchanged (wrapped in ``ipywidgets.fixed``) unless
            they already are widgets.

    Returns:
        None: the viewer is displayed by ``ipywidgets.interact``.

    Raises:
        ImportError: if ipywidgets is not installed
        ValueError: if a string threshold does not end in '%', or brain is not 1d or 2d
    """
    if ipywidgets is None:
        raise ImportError(
            "ipywidgets>=5.2.2 is required for interactive plotting. Please install this package manually or install nltools with optional arguments: pip install 'nltools[interactive_plots]'"
        )
    if isinstance(threshold, str):
        if not threshold.endswith("%"):
            raise ValueError("Starting threshold provided as string must end in '%'")
        percentile_threshold = True
        warnings.warn(
            "Percentile thresholding ignores brain mask. Results are likely more liberal than you expect (e.g. with non-interactive plotting)!"
        )
        # float rather than int so fractional percentiles such as '97.5%' work
        threshold = float(threshold[:-1])

    n_dims = len(brain.shape())
    if n_dims not in (1, 2):
        raise ValueError("Brain_Data object is not 1d or 2d")

    thresh_box = ipywidgets.widgets.FloatText(value=threshold, description="Threshold")
    if n_dims == 2:
        # step through the volumes along the first axis
        idx = ipywidgets.widgets.IntSlider(
            min=0,
            max=brain.shape()[0] - 1,
            step=1,
            value=0,
            orientation="horizontal",
            continuous_update=False,
            description="Volume",
            readout_format="d",
        )
    else:
        # single image: _viewer shows the whole object when idx is not an int
        idx = ipywidgets.widgets.HTML(
            value="Image is 3D", description="Volume", placeholder=""
        )

    # interact converts plain values into widgets (a list becomes a dropdown,
    # a tuple a slider, None is rejected). Wrap nilearn pass-through options in
    # ``fixed`` so they reach _viewer unchanged; explicit widgets are kept as-is.
    passthrough = {
        key: value
        if isinstance(value, (ipywidgets.widgets.Widget, ipywidgets.fixed))
        else ipywidgets.fixed(value)
        for key, value in kwargs.items()
    }
    ipywidgets.interact(
        _viewer,
        brain=ipywidgets.fixed(brain),
        thresh=thresh_box,
        idx=idx,
        percentile_threshold=percentile_threshold,
        surface=surface,
        anatomical=ipywidgets.fixed(anatomical),
        **passthrough,
    )
def _viewer(brain, thresh, idx, percentile_threshold, surface, anatomical, **kwargs):
    """Render one view for ``plot_interactive_brain`` (called via ipywidgets.interact).

    Args:
        brain (nltools.Brain_Data): data to display
        thresh (float): threshold; 0 is replaced by 1e-6, otherwise it is turned
            into a percentile string (e.g. '95.0%') when ``percentile_threshold`` is True
        idx (int/str): volume index from the slider; any non-int value (the 3d
            placeholder text) means the whole object is shown
        percentile_threshold (bool): interpret ``thresh`` as a percentile
        surface (bool): use ``view_img_on_surf`` instead of ``view_img``
        anatomical: background image for volume views; 'MNI152' when falsy
        kwargs: forwarded to the nilearn viewer; ``cut_coords`` (default [0, 0, 0])
            is only used by the volume viewer

    Returns:
        The object returned by the nilearn viewer function.
    """
    if thresh == 0:
        thresh = 1e-6
    elif percentile_threshold:
        thresh = f"{thresh}%"

    image = brain[idx] if isinstance(idx, int) else brain
    nifti = image.to_nifti()

    # pop (not get): cut_coords is passed explicitly to view_img below, so it
    # must not also be present in **kwargs; it is not used for surface views
    cut_coords = kwargs.pop("cut_coords", [0, 0, 0])
    if surface:
        return view_img_on_surf(nifti, threshold=thresh, **kwargs)
    bg_img = anatomical if anatomical else "MNI152"
    return view_img(
        nifti, bg_img=bg_img, threshold=thresh, cut_coords=cut_coords, **kwargs
    )
def plot_t_brain(
    objIn, how="full", thr="unc", alpha=None, nperm=None, cut_coords=None, **kwargs
):
    """
    Compute a one-sample t-test across the first axis of a Brain_Data object and plot it.

    If a list of two Brain_Data objects is provided, the difference
    ``objIn[0] - objIn[1]`` is tested instead (i.e. a paired-samples t-test).

    Args:
        objIn (list/Brain_Data): data to test; if a list of two objects, their
            difference map is tested
        how (str): 'glass' for a glass brain, 'mni' for 3 multi-slice views on the
            MNI template, or 'full' for both; default 'full'
        thr (str): multiple comparisons correction: 'unc', 'fdr', or 'tfce'
        alpha (float): p/q-value threshold; default 0.001 for 'unc', 0.05 for 'fdr'
        nperm (int): number of permutations for tfce; default 1000
        cut_coords (list): [[xs], [ys], [zs]] slice coordinates for the 'mni' views;
            None or empty uses built-in coordinates
        kwargs: optional args to nilearn plot functions (e.g. vmax)

    Raises:
        ValueError: for an invalid ``thr``, ``how`` or ``cut_coords``, or a list
            that does not contain exactly 2 items
    """
    if thr not in ["unc", "fdr", "tfce"]:
        raise ValueError("Acceptable threshold methods are 'unc','fdr','tfce'")
    if how not in ("full", "glass", "mni"):
        raise ValueError("how must be one of 'full', 'glass', or 'mni'")

    views = ["x", "y", "z"]
    if cut_coords is None or len(cut_coords) == 0:
        cut_coords = [
            range(-40, 50, 10),
            [-88, -72, -58, -38, -26, 8, 20, 34, 46],
            [-34, -22, -10, 0, 16, 34, 46, 56, 66],
        ]
    elif len(cut_coords) != 3:
        raise ValueError(
            "cut_coords must be a list of coordinates like [[xs],[ys],[zs]]"
        )

    if isinstance(objIn, list):
        if len(objIn) != 2:
            raise ValueError("Contrasts should contain only 2 list items!")
        # paired test: run the one-sample test on the difference map
        data = objIn[0] - objIn[1]
    else:
        data = objIn

    if thr == "tfce":
        if nperm is None:
            nperm = 1000
        thr_dict = {"permutation": thr, "n_permutations": nperm}
        print("1-sample t-test corrected using: TFCE w/ %s permutations" % nperm)
    elif thr == "unc":
        if alpha is None:
            alpha = 0.001
        thr_dict = {thr: alpha}
        print("1-sample t-test uncorrected at p < %.3f " % alpha)
    else:  # thr == "fdr"
        if alpha is None:
            alpha = 0.05
        thr_dict = {thr: alpha}
        print("1-sample t-test corrected at q < %.3f " % alpha)

    # a threshold_dict is always supplied, so the thresholded t map is plotted
    stat_img = data.ttest(threshold_dict=thr_dict)["thr_t"].to_nifti()
    cmap = "RdBu_r"
    if how in ("full", "glass"):
        plot_glass_brain(
            stat_img,
            display_mode="lzry",
            colorbar=True,
            cmap=cmap,
            plot_abs=False,
            **kwargs,
        )
    if how in ("full", "mni"):
        bg_img = resolve_mni_path(MNI_Template)["brain"]
        for view, coords in zip(views, cut_coords):
            plot_stat_map(
                stat_img,
                cut_coords=coords,
                display_mode=view,
                cmap=cmap,
                bg_img=bg_img,
                **kwargs,
            )
def plot_brain(objIn, how="full", thr_upper=None, thr_lower=None, save=False, **kwargs):
    """
    More complete brain plotting of a Brain_Data instance.

    Args:
        objIn (Brain_Data): object to plot
        how (str): 'glass' for a glass brain, 'mni' for 3 multi-slice views on the
            MNI template, or 'full' for both; default 'full'
        thr_upper (str/float): thresholding of image. Can be string for percentage, or float for data units (see Brain_Data.threshold())
        thr_lower (str/float): thresholding of image. Can be string for percentage, or float for data units (see Brain_Data.threshold())
        save (str): if a file name or path is provided, each plot is saved to that
            path with '_glass', '_x', '_y' or '_z' inserted before the extension
            (e.g. 'out/brain.png' -> 'out/brain_glass.png', 'out/brain_x.png', ...)
        kwargs: optionals args to nilearn plot functions (e.g. vmax)

    Raises:
        ValueError: if ``how`` is not 'full', 'glass' or 'mni'
    """
    if how not in ("full", "glass", "mni"):
        raise ValueError("how must be one of 'full', 'glass', or 'mni'")

    # `is not None` so a threshold of 0 is applied, matching the messages below
    if thr_upper is not None or thr_lower is not None:
        obj = objIn.threshold(upper=thr_upper, lower=thr_lower)
    else:
        obj = objIn.copy()

    views = ["x", "y", "z"]
    coords = [
        range(-50, 51, 8),
        range(-80, 50, 10),
        range(-40, 71, 9),
    ]
    cmap = "RdBu_r"

    if thr_upper is None and thr_lower is None:
        print("Plotting unthresholded image")
    else:
        if isinstance(thr_upper, str):
            print("Plotting top %s of voxels" % thr_upper)
        elif isinstance(thr_upper, (float, int)):
            print("Plotting voxels with stat value >= %s" % thr_upper)
        if isinstance(thr_lower, str):
            print("Plotting lower %s of voxels" % thr_lower)
        elif isinstance(thr_lower, (float, int)):
            print("Plotting voxels with stat value <= %s" % thr_lower)

    if save:
        # splitext copes with file names containing several dots or no extension
        root, extension = os.path.splitext(save)
        glass_save = root + "_glass" + extension
        saves = [root + "_" + view + extension for view in views]
    else:
        glass_save = None
        saves = [None] * len(views)

    stat_img = obj.to_nifti()
    if how in ("full", "glass"):
        plot_glass_brain(
            stat_img,
            display_mode="lzry",
            colorbar=True,
            cmap=cmap,
            plot_abs=False,
            **kwargs,
        )
        if save:
            plt.savefig(glass_save, bbox_inches="tight")
    if how in ("full", "mni"):
        bg_img = resolve_mni_path(MNI_Template)["brain"]
        for view, cut, savefile in zip(views, coords, saves):
            plot_stat_map(
                stat_img,
                cut_coords=cut,
                display_mode=view,
                cmap=cmap,
                bg_img=bg_img,
                **kwargs,
            )
            if save:
                plt.savefig(savefile, bbox_inches="tight")
def dist_from_hyperplane_plot(stats_output):
    """Plot SVM classification distance from the hyperplane for each subject.

    Uses the cross-validated column 'dist_from_hyperplane_xval' when present,
    otherwise 'dist_from_hyperplane_all'.

    Args:
        stats_output: a pandas DataFrame with prediction output; must contain
            'subject_id', 'Y' and one of the distance columns above

    Returns:
        None: the point plot is drawn on a new seaborn figure.
    """
    if "dist_from_hyperplane_xval" in stats_output.columns:
        y_col = "dist_from_hyperplane_xval"
    else:
        y_col = "dist_from_hyperplane_all"
    # sns.factorplot was superseded by sns.catplot and is no longer available in
    # current seaborn; x/y must be passed by keyword.
    sns.catplot(
        x="subject_id",
        y=y_col,
        hue="Y",
        data=stats_output,
        kind="point",
    )
    plt.xlabel("Subject", fontsize=16)
    plt.ylabel("Distance from Hyperplane", fontsize=16)
    plt.title("Classification", fontsize=18)
def scatterplot(stats_output):
    """Plot observed vs. predicted values with a linear regression fit.

    Args:
        stats_output: a pandas DataFrame of prediction output with a 'Y' column;
            the cross-validated 'yfit_xval' column is used when present,
            otherwise 'yfit_all'

    Returns:
        None: the plot is drawn on a new seaborn figure.
    """
    has_xval = "yfit_xval" in stats_output.columns
    predicted_col = "yfit_xval" if has_xval else "yfit_all"
    sns.lmplot(x="Y", y=predicted_col, data=stats_output)

    plt.xlabel("Y", fontsize=16)
    plt.ylabel("Predicted Value", fontsize=16)
    plt.title("Prediction", fontsize=18)
def probability_plot(stats_output):
    """Plot predicted classification probability against the observed class.

    Uses the cross-validated column 'Probability_xval' when present, otherwise
    'Probability_all', and overlays a logistic regression fit.

    Args:
        stats_output: a pandas DataFrame with prediction output and a 'Y' column

    Returns:
        None: the plot is drawn on a new seaborn figure.
    """
    if "Probability_xval" in stats_output.columns:
        y_col = "Probability_xval"
    else:
        y_col = "Probability_all"
    # x/y by keyword: recent seaborn no longer accepts them positionally
    sns.lmplot(x="Y", y=y_col, data=stats_output, logistic=True)
    plt.xlabel("Y", fontsize=16)
    plt.ylabel("Predicted Probability", fontsize=16)
    plt.title("Prediction", fontsize=18)
# # and plot the result
# plt.figure(1, figsize=(4, 3))
# plt.clf()
# plt.scatter(X.ravel(), y, color='black', zorder=20)
# X_test = np.linspace(-5, 10, 300)
# def model(x):
# return 1 / (1 + np.exp(-x))
# loss = model(X_test * clf.coef_ + clf.intercept_).ravel()
# plt.plot(X_test, loss, color='blue', linewidth=3)
def roc_plot(fpr, tpr):
    """Plot sensitivity against 1 - specificity (an ROC curve).

    Args:
        fpr: false positive rates (x-axis), e.g. from Roc.calculate
        tpr: true positive rates (y-axis), e.g. from Roc.calculate

    Returns:
        None: the curve is drawn on a newly created matplotlib figure.
    """
    plt.figure()
    plt.plot(fpr, tpr, color="red", linewidth=3)
    plt.xlabel("(1 - Specificity)", fontsize=16)
    plt.ylabel("Sensitivity", fontsize=16)
    plt.title("ROC Plot", fontsize=18)
def plot_stacked_adjacency(adjacency1, adjacency2, normalize=True, **kwargs):
    """Plot two Adjacency matrices as one heatmap to illustrate their similarity.

    ``adjacency1`` fills the upper triangle and ``adjacency2`` the lower
    triangle; the diagonal is left at zero.

    Args:
        adjacency1: Adjacency instance shown in the upper triangle
        adjacency2: Adjacency instance shown in the lower triangle
        normalize: (boolean) mean-center each matrix and divide its triangle by
            that triangle's maximum
        kwargs: passed to seaborn.heatmap

    Returns:
        the matplotlib Axes returned by seaborn.heatmap

    Raises:
        ValueError: if either input is not an Adjacency instance
    """
    from nltools.data import Adjacency

    if not isinstance(adjacency1, Adjacency) or not isinstance(adjacency2, Adjacency):
        raise ValueError("This function requires Adjacency() instances as input.")

    def _to_square(adjacency):
        # optionally mean-center, then expand to a square matrix
        if normalize:
            adjacency = adjacency - adjacency.mean()
        return adjacency.squareform()

    def _scale(triangle):
        # divide by the maximum; skip a zero maximum to avoid NaN/inf
        peak = np.max(triangle)
        return triangle / peak if peak != 0 else triangle

    # same layout for both settings of `normalize`
    upper = np.triu(_to_square(adjacency1), k=1)
    lower = np.tril(_to_square(adjacency2), k=-1)
    if normalize:
        upper = _scale(upper)
        lower = _scale(lower)
    return sns.heatmap(
        upper + lower, xticklabels=False, yticklabels=False, square=True, **kwargs
    )
def plot_mean_label_distance(
    distance,
    labels,
    ax=None,
    permutation_test=False,
    n_permute=5000,
    fontsize=18,
    **kwargs,
):
    """Create a split violin plot of within- vs. between-label distances.

    For every unique label, "Within" distances are the upper-triangle entries
    among items sharing that label, and "Between" distances are all entries from
    items with that label to items with any other label.

    Args:
        distance: square pandas DataFrame of pairwise distances
        labels: labels (same length as distance) indicating columns and rows to
            group; matched to ``distance`` by position
        ax: matplotlib axis to plot on
        permutation_test: (bool) whether to run a within vs. between permutation
            test for each label
        n_permute: (int) number of permutations to run
        fontsize: (int) fontsize for plot labels
        kwargs: passed to seaborn.violinplot

    Returns:
        f: matplotlib Axes containing the violin plot
        stats: (only if permutation_test=True) dict mapping str(label) to the
            two_sample_permutation result

    Raises:
        ValueError: if distance is not a square DataFrame or labels has the wrong length
    """
    if not isinstance(distance, pd.DataFrame):
        raise ValueError("distance must be a pandas dataframe")
    if distance.shape[0] != distance.shape[1]:
        raise ValueError("distance must be square.")
    if len(labels) != distance.shape[0]:
        raise ValueError("Labels must be same length as distance matrix")

    # positional masks: no reliance on labels' index matching distance's axes
    label_values = np.asarray(labels)
    dist_values = distance.values
    groups = pd.unique(label_values)

    frames = []
    for group in groups:
        in_group = label_values == group
        within = dist_values[np.ix_(in_group, in_group)][
            np.triu_indices(in_group.sum(), k=1)
        ]
        between = dist_values[np.ix_(in_group, ~in_group)].flatten()
        frames.append(
            pd.DataFrame({"Distance": within, "Group": group, "Type": "Within"})
        )
        frames.append(
            pd.DataFrame({"Distance": between, "Group": group, "Type": "Between"})
        )
    # DataFrame.append was removed in pandas 2.0: collect pieces, concatenate once
    if frames:
        out = pd.concat(frames, ignore_index=True)
    else:
        out = pd.DataFrame(columns=["Distance", "Group", "Type"])

    f = sns.violinplot(
        x="Group",
        y="Distance",
        hue="Type",
        data=out,
        split=True,
        inner="quartile",
        palette={"Within": "lightskyblue", "Between": "red"},
        ax=ax,
        **kwargs,
    )
    f.set_ylabel("Average Distance", fontsize=fontsize)
    f.set_title("Average Group Distance", fontsize=fontsize)

    if not permutation_test:
        return f

    stats = dict()
    for group in groups:
        # within- vs. between-label distances for this label
        is_group = out["Group"] == group
        within = out.loc[is_group & (out["Type"] == "Within"), "Distance"]
        between = out.loc[is_group & (out["Type"] == "Between"), "Distance"]
        stats[str(group)] = two_sample_permutation(within, between, n_permute=n_permute)
    return (f, stats)
def plot_between_label_distance(
    distance,
    labels,
    ax=None,
    permutation_test=True,
    n_permute=5000,
    fontsize=18,
    **kwargs,
):
    """Create a heatmap of the average distance between every pair of labels.

    Args:
        distance: (pandas DataFrame) square pairwise brain distance matrix
        labels: group label for each row/column of ``distance`` (matched by position)
        ax: axis to plot on (default=None creates a new figure)
        permutation_test: (boolean) if True, for each pair of labels (i, j) test
            whether distances i->j differ from the within-i distances, plot the
            mean differences, and annotate cells with p <= 0.05
        n_permute: (int) number of samples for the permutation test
        fontsize: (int) not used by this function; kept for API compatibility
        kwargs: passed to the main seaborn.heatmap call

    Returns:
        out: pandas DataFrame of pairwise distances in long form
            (columns 'Distance', 'Group', 'Comparison')
        within_dist_out: average pairwise distance for each (Group, Comparison)
        mn_dist_out: (only if permutation_test=True) average difference in distance between conditions
        p_dist_out: (only if permutation_test=True) p-value for difference in distance between conditions
    """
    label_values = np.asarray(labels)
    label_set = np.unique(label_values)
    dist_values = np.asarray(distance)
    n_labels = len(label_set)

    within_dist_out = pd.DataFrame(
        np.zeros((n_labels, n_labels)), columns=label_set, index=label_set
    )
    frames = []
    for group in label_set:
        in_group = label_values == group
        for comparison in label_set:
            if comparison == group:
                # within-label: each unordered pair once (upper triangle, no diagonal)
                block = dist_values[np.ix_(in_group, in_group)]
                values = block[np.triu_indices(in_group.sum(), k=1)]
            else:
                in_comparison = label_values == comparison
                values = dist_values[np.ix_(in_group, in_comparison)].flatten()
            frames.append(
                pd.DataFrame(
                    {"Distance": values, "Group": group, "Comparison": comparison}
                )
            )
            within_dist_out.loc[group, comparison] = (
                values.mean() if values.size else np.nan
            )
    # DataFrame.append was removed in pandas 2.0: collect pieces, concatenate once
    if frames:
        out = pd.concat(frames, ignore_index=True)
    else:
        out = pd.DataFrame(columns=["Distance", "Group", "Comparison"])

    if ax is None:
        _, ax = plt.subplots(1)

    if not permutation_test:
        sns.heatmap(within_dist_out, ax=ax, square=True, **kwargs)
        return (out, within_dist_out)

    mn_dist_out = pd.DataFrame(
        np.zeros((n_labels, n_labels)), columns=label_set, index=label_set
    )
    p_dist_out = pd.DataFrame(
        np.zeros((n_labels, n_labels)), columns=label_set, index=label_set
    )
    for i in label_set:
        in_row = out["Group"] == i
        within = out.loc[in_row & (out["Comparison"] == i), "Distance"]
        for j in label_set:
            other = out.loc[in_row & (out["Comparison"] == j), "Distance"]
            s = two_sample_permutation(within, other, n_permute=n_permute)
            mn_dist_out.loc[i, j] = s["mean"]
            p_dist_out.loc[i, j] = s["p"]
    sns.heatmap(mn_dist_out, ax=ax, square=True, **kwargs)
    # overlay annotations only on cells with p <= 0.05
    sns.heatmap(
        mn_dist_out,
        mask=p_dist_out > 0.05,
        square=True,
        linewidth=2,
        annot=True,
        ax=ax,
        cbar=False,
    )
    return (out, within_dist_out, mn_dist_out, p_dist_out)
def _silhouette_scores(dist_values, label_values):
    """Return one silhouette score per item from a distance matrix and labels.

    score = (mean distance to other-label items - mean distance to same-label
    items) / max of those two means. An item that is alone in its label gets NaN,
    because its within-label mean is taken over an empty set.
    """
    positions = np.arange(len(label_values))
    scores = np.empty(len(label_values))
    for position, label in enumerate(label_values):
        same = (label_values == label) & (positions != position)
        other = label_values != label
        mean_within = np.mean(dist_values[position, same])
        mean_between = np.mean(dist_values[position, other])
        scores[position] = (mean_between - mean_within) / max(
            mean_between, mean_within
        )
    return scores


def plot_silhouette(
    distance, labels, ax=None, permutation_test=True, n_permute=5000, **kwargs
):
    """Create a silhouette plot of between- relative to within-label distance.

    Args:
        distance: (pandas DataFrame) square pairwise brain distance matrix
        labels: group label for each row/column of ``distance`` (matched by position)
        ax: axis to plot on (default=None creates a new figure)
        permutation_test: (boolean) run a one-sample permutation test on the
            silhouette scores of each label whose mean score is positive
        n_permute: (int) number of samples for the permutation test

    Optional keyword args:
        figsize: (tuple) size of the new figure when ax is None; default (6, 4)
        colors: (list) color triplets for silhouettes. Length must equal number of unique labels

    Returns:
        outAll: (only if permutation_test=True) DataFrame with columns 'label',
            'mean' and 'p' (one row per label); p is 999 for labels whose mean
            silhouette score is not positive. None otherwise.
    """
    label_values = np.asarray(labels)
    label_set = np.unique(label_values)
    n_clusters = len(label_set)

    # plot design options, with defaults when not supplied
    colors = kwargs.get("colors")
    if colors is None:
        colors = sns.color_palette("hls", n_clusters)
    figsize = kwargs.get("figsize", (6, 4))

    silhouettes = _silhouette_scores(np.asarray(distance), label_values)

    with sns.axes_style("white"):
        if ax is None:
            _, ax = plt.subplots(1, figsize=figsize)

    x_lower = 10
    tick_positions = []
    for label_index, label in enumerate(label_set):
        cluster_scores = np.sort(silhouettes[label_values == label])
        x_upper = x_lower + cluster_scores.shape[0]
        color = colors[label_index]
        with sns.axes_style("white"):
            ax.fill_between(
                np.arange(x_lower, x_upper),
                0,
                cluster_scores,
                facecolor=color,
                edgecolor=color,
            )
        tick_positions.append(np.mean([x_lower, x_upper]))
        # gap of 3 between consecutive labels
        x_lower = x_upper + 3

    ax.set_xticks(tick_positions)
    ax.set_xticklabels(label_set)
    ax.set_title("Silhouettes", fontsize=18)
    ax.set_xlim([5, 10 + len(label_values) + n_clusters * 3])

    if not permutation_test:
        return None

    # Permutation test on mean silhouette score per label
    rows = []
    for label in label_set:
        cluster_scores = silhouettes[label_values == label]
        mean_score = np.mean(cluster_scores)
        if mean_score > 0:  # Only test positive mean silhouette scores
            p_value = one_sample_permutation(cluster_scores, n_permute=n_permute)["p"]
        else:
            p_value = 999
        rows.append({"label": label, "mean": mean_score, "p": p_value})
    return pd.DataFrame(rows, columns=["label", "mean", "p"])
def component_viewer(output, tr=2.0):
    """Interactively view the results of a decomposition analysis.

    Shows ipywidgets controls for choosing a component and a threshold. Each
    update plots the z-scored, thresholded spatial map, the component's
    timecourse, and the power spectrum of that timecourse.

    Args:
        output: (dict) output dictionary from running Brain_data.decompose();
            the 'components', 'weights' and 'decomposition_object' entries are used
        tr: (float) repetition time of data in seconds, used in the timecourse
            title and for the frequency axis

    Raises:
        ImportError: if ipywidgets is not installed
    """
    if ipywidgets is None:
        raise ImportError(
            "ipywidgets is required for interactive plotting. Please install this package manually or install nltools with optional arguments: pip install 'nltools[interactive_plots]'"
        )
    # explicit submodule import: `import sklearn` alone does not guarantee that
    # `sklearn.decomposition` is loaded
    from sklearn.decomposition import PCA

    n_components = len(output["components"])

    def component_inspector(component, threshold):
        """Plot one component's thresholded map, timecourse and power spectrum.

        Args:
            component: (int) index of the component to display
            threshold: (float) voxels whose z-scored value has |z| <= threshold are set to 0
        """
        _, ax = plt.subplots(nrows=3, figsize=(12, 8))
        spatial_map = output["components"][component]
        # z-score the spatial map before thresholding
        thresholded = (spatial_map - spatial_map.mean()) * (1 / spatial_map.std())
        thresholded.data[np.abs(thresholded.data) <= threshold] = 0
        plot_stat_map(
            thresholded.to_nifti(),
            cut_coords=range(-40, 70, 10),
            display_mode="z",
            black_bg=True,
            colorbar=True,
            annotate=False,
            draw_cross=False,
            axes=ax[0],
        )
        if isinstance(output["decomposition_object"], PCA):
            var_exp = output["decomposition_object"].explained_variance_ratio_[
                component
            ]
            ax[0].set_title(
                f"Component: {component}/{n_components}, Variance Explained: {var_exp:2.2}",
                fontsize=18,
            )
        else:
            ax[0].set_title(f"Component: {component}/{n_components}", fontsize=18)

        timecourse = output["weights"][:, component]
        ax[1].plot(timecourse, linewidth=2, color="red")
        ax[1].set_ylabel("Intensity (AU)", fontsize=18)
        ax[1].set_title(f"Timecourse (TR={tr})", fontsize=16)

        # power spectrum of the timecourse, positive frequencies only
        spectrum = fft(timecourse)
        freqs = fftfreq(len(spectrum), d=tr)
        positive = freqs > 0
        ax[2].plot(freqs[positive], np.abs(spectrum)[positive] ** 2)
        ax[2].set_ylabel("Power", fontsize=18)
        ax[2].set_xlabel("Frequency (Hz)", fontsize=16)

    ipywidgets.interact(
        component_inspector,
        component=ipywidgets.BoundedIntText(
            description="Component", value=0, min=0, max=n_components - 1
        ),
        threshold=ipywidgets.BoundedFloatText(
            description="Threshold", value=2.0, min=0, max=4, step=0.1
        ),
    )