Multivariate Prediction

Running MVPA-style analyses with multivariate regression is even easier and faster than univariate methods. All you need to do is specify the algorithm and the cross-validation parameters. Several linear algorithms from scikit-learn are currently implemented.

Load Data

First, let’s load the pain data for this example. We need to specify the training labels, so we grab the pain intensity variable from the data.X field and assign it to data.Y.

from nltools.datasets import fetch_pain

data = fetch_pain()
data.Y = data.X['PainLevel']

Prediction with Cross-Validation

We can now predict the output variable. The result is a dictionary of the most useful output from the prediction analyses. The predict function runs the prediction multiple times: one iteration uses all of the data to calculate the ‘weight_map’, and the remaining iterations estimate the cross-validated predictive accuracy.

stats = data.predict(algorithm='ridge',
                    cv_dict={'type': 'kfolds','n_folds': 5,'stratified':data.Y})

Out:

/opt/hostedtoolcache/Python/3.8.9/x64/lib/python3.8/site-packages/sklearn/utils/validation.py:70: FutureWarning: Pass shuffle=False, random_state=None as keyword args. From version 1.0 (renaming of 0.25) passing these as positional arguments will result in an error
  warnings.warn(f"Pass {args_msg} as keyword args. From version "
overall Root Mean Squared Error: 0.00
overall Correlation: 1.00
overall CV Root Mean Squared Error: 0.56
overall CV Correlation: 0.74
/opt/hostedtoolcache/Python/3.8.9/x64/lib/python3.8/site-packages/seaborn/_decorators.py:36: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.
  warnings.warn(
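The cross-validated numbers above come from exactly this train/test machinery. As a rough sketch of what predict does internally, here is the same pattern with plain scikit-learn on synthetic data (the dimensions, signal, and alpha below are made up for illustration, not nltools’s internals):

```python
import numpy as np
from scipy.stats import pearsonr
from sklearn.linear_model import Ridge
from sklearn.model_selection import KFold

# Synthetic stand-in for the imaging data: 90 "images" x 200 "voxels".
rng = np.random.default_rng(0)
X = rng.standard_normal((90, 200))
y = X[:, :5].sum(axis=1) + 0.5 * rng.standard_normal(90)

# Fit on all of the data (analogous to the 'weight_map' iteration)...
model = Ridge(alpha=1.0).fit(X, y)
yfit_all = model.predict(X)

# ...then estimate cross-validated accuracy with 5-fold CV.
yfit_xval = np.zeros_like(y)
for train, test in KFold(n_splits=5, shuffle=True, random_state=0).split(X):
    yfit_xval[test] = Ridge(alpha=1.0).fit(X[train], y[train]).predict(X[test])

rmse_all = np.sqrt(np.mean((y - yfit_all) ** 2))
rmse_xval = np.sqrt(np.mean((y - yfit_xval) ** 2))
print(f"overall RMSE: {rmse_all:.2f}, CV RMSE: {rmse_xval:.2f}")
print(f"overall r: {pearsonr(y, yfit_all)[0]:.2f}, CV r: {pearsonr(y, yfit_xval)[0]:.2f}")
```

Just as in the output above, the in-sample fit is nearly perfect while the cross-validated accuracy is the honest estimate.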

Display the available data in the output dictionary

stats.keys()

Out:

dict_keys(['Y', 'yfit_all', 'intercept', 'weight_map', 'yfit_xval', 'intercept_xval', 'weight_map_xval', 'cv_idx', 'rmse_all', 'r_all', 'rmse_xval', 'r_xval'])

Plot the multivariate weight map

stats['weight_map'].plot()

Return the cross-validated predicted data

stats['yfit_xval']

Out:

array([2.429143  , 1.6314082 , 1.8126671 , 3.5586872 , 0.9020816 ,
       1.2204387 , 2.7179723 , 1.3614657 , 1.6188056 , 3.3171613 ,
       0.8049244 , 1.515244  , 2.9259906 , 1.4362226 , 1.7737616 ,
       1.9213707 , 1.8017511 , 1.1947873 , 2.4830327 , 0.893635  ,
       1.8442726 , 2.4820771 , 1.4400073 , 1.4230402 , 4.0127153 ,
       1.4100523 , 2.3947031 , 1.8863165 , 1.3747299 , 2.1128185 ,
       2.119888  , 1.1560311 , 2.0790844 , 3.9388707 , 1.2056164 ,
       2.1786842 , 3.2859573 , 1.6076977 , 1.677222  , 2.524438  ,
       1.4692608 , 2.2496471 , 2.2506433 , 1.5619733 , 1.9757571 ,
       3.2622285 , 1.2873312 , 2.0369186 , 1.7596395 , 0.9663484 ,
       2.233838  , 1.7051153 , 1.8357867 , 0.75496656, 2.624507  ,
       0.991569  , 2.3120227 , 2.2533536 , 1.9480464 , 2.0109935 ,
       2.0345416 , 1.0215064 , 1.3073304 , 2.4055963 , 1.38359   ,
       1.560549  , 2.3911996 , 1.6636779 , 2.0051494 , 2.4855537 ,
       1.2744507 , 2.5009933 , 2.8917952 , 1.9525683 , 1.7255962 ,
       3.1939049 , 1.4861531 , 1.9714726 , 2.1241016 , 1.5545266 ,
       1.5466385 , 2.7263155 , 1.1705334 , 2.1081424 ], dtype=float32)
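The summary statistics in the dictionary are simple functions of these arrays. For example, ‘rmse_xval’ and ‘r_xval’ can be recomputed by hand (the ratings below are hypothetical, not from the pain dataset):

```python
import numpy as np
from scipy.stats import pearsonr

# Hypothetical observed and cross-validated predicted ratings.
y = np.array([1.0, 2.0, 3.0, 1.5, 2.5])
yfit_xval = np.array([1.2, 1.8, 2.9, 1.7, 2.2])

rmse_xval = np.sqrt(np.mean((y - yfit_xval) ** 2))  # cf. stats['rmse_xval']
r_xval = pearsonr(y, yfit_xval)[0]                  # cf. stats['r_xval']
print(round(rmse_xval, 3), round(r_xval, 3))
```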

Algorithms

Several types of linear algorithms are implemented, including Support Vector Regression (svr), Principal Components Regression (pcr), and penalized methods such as ridge and lasso. These examples use 5-fold cross-validation, holding out the same subjects in each fold.

subject_id = data.X['SubjectID']
svr_stats = data.predict(algorithm='svr',
                        cv_dict={'type': 'kfolds','n_folds': 5,
                        'subject_id':subject_id}, **{'kernel':"linear"})

Out:

overall Root Mean Squared Error: 0.10
overall Correlation: 0.99
overall CV Root Mean Squared Error: 0.88
overall CV Correlation: 0.57

Lasso Regression

lasso_stats = data.predict(algorithm='lasso',
                        cv_dict={'type': 'kfolds','n_folds': 5,
                        'subject_id':subject_id}, **{'alpha':.1})

Out:

overall Root Mean Squared Error: 0.69
overall Correlation: 0.58
overall CV Root Mean Squared Error: 0.74
overall CV Correlation: 0.43

Principal Components Regression

pcr_stats = data.predict(algorithm='pcr',
                        cv_dict={'type': 'kfolds','n_folds': 5,
                        'subject_id':subject_id})

Out:

overall Root Mean Squared Error: 0.00
overall Correlation: 1.00
overall CV Root Mean Squared Error: 0.91
overall CV Correlation: 0.58

Principal Components Regression with Lasso

lassopcr_stats = data.predict(algorithm='lassopcr',
                        cv_dict={'type': 'kfolds','n_folds': 5,
                        'subject_id':subject_id})

Out:

overall Root Mean Squared Error: 0.48
overall Correlation: 0.84
overall CV Root Mean Squared Error: 0.73
overall CV Correlation: 0.54
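For intuition, the ‘lassopcr’ algorithm chains dimensionality reduction with sparse regression. A sketch of the same idea using a scikit-learn pipeline on synthetic latent-structured data (the sizes and alpha are illustrative assumptions, not nltools’s internal implementation):

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import Lasso
from sklearn.pipeline import make_pipeline

rng = np.random.default_rng(0)
Z = rng.standard_normal((90, 5))                  # latent sources
W = rng.standard_normal((5, 200))                 # mixing into "voxels"
X = Z @ W + 0.1 * rng.standard_normal((90, 200))
y = Z[:, 0] + 0.2 * rng.standard_normal(90)

# Reduce to principal components, then fit a sparse lasso model in
# component space -- the idea behind 'lassopcr'.
model = make_pipeline(PCA(n_components=10), Lasso(alpha=0.1))
model.fit(X, y)
print(f"in-sample R^2: {model.score(X, y):.2f}")
```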

Cross-Validation Schemes

There are several different ways to perform cross-validation. The standard approach is k-folds, where the data are divided into k equal subsets and each subset serves once as the test set while the remaining folds are used for training. Often we want to hold out the same subjects in each fold. This can be done by passing in a vector of unique subject IDs that correspond to the images in the data frame.

subject_id = data.X['SubjectID']
ridge_stats = data.predict(algorithm='ridge',
                        cv_dict={'type': 'kfolds','n_folds': 5,'subject_id':subject_id},
                        plot=False, **{'alpha':.1})

Out:

/opt/hostedtoolcache/Python/3.8.9/x64/lib/python3.8/site-packages/sklearn/linear_model/_ridge.py:187: LinAlgWarning: Ill-conditioned matrix (rcond=2.7471e-08): result may not be accurate.
  dual_coef = linalg.solve(K, y, sym_pos=True,
overall Root Mean Squared Error: 0.00
overall Correlation: 1.00
overall CV Root Mean Squared Error: 0.91
overall CV Correlation: 0.58
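Holding out whole subjects is the same idea as scikit-learn’s grouped cross-validation. A minimal sketch with GroupKFold, using made-up subject IDs, confirms that a subject’s images never land on both sides of a split:

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Three images per subject, ten hypothetical subjects.
subject_id = np.repeat(np.arange(10), 3)
X = np.random.randn(30, 4)

for train, test in GroupKFold(n_splits=5).split(X, groups=subject_id):
    # Each subject is held out intact in exactly one fold.
    assert set(subject_id[train]).isdisjoint(subject_id[test])
print("each subject is held out intact")
```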

Sometimes we want to ensure that the training labels are balanced across folds. This can be done using the stratified k-folds method.

ridge_stats = data.predict(algorithm='ridge',
                        cv_dict={'type': 'kfolds','n_folds': 5, 'stratified':data.Y},
                        plot=False, **{'alpha':.1})

Out:

/opt/hostedtoolcache/Python/3.8.9/x64/lib/python3.8/site-packages/sklearn/linear_model/_ridge.py:187: LinAlgWarning: Ill-conditioned matrix (rcond=2.7471e-08): result may not be accurate.
  dual_coef = linalg.solve(K, y, sym_pos=True,
/opt/hostedtoolcache/Python/3.8.9/x64/lib/python3.8/site-packages/sklearn/utils/validation.py:70: FutureWarning: Pass shuffle=False, random_state=None as keyword args. From version 1.0 (renaming of 0.25) passing these as positional arguments will result in an error
  warnings.warn(f"Pass {args_msg} as keyword args. From version "
overall Root Mean Squared Error: 0.00
overall Correlation: 1.00
overall CV Root Mean Squared Error: 0.56
overall CV Correlation: 0.74
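Note that scikit-learn’s StratifiedKFold expects discrete labels, so a continuous outcome is typically binned before stratifying. The quartile binning below is one illustrative choice, not necessarily what nltools does internally:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

rng = np.random.default_rng(0)
y = rng.uniform(0, 5, size=90)  # continuous ratings
# Quartile bins give discrete labels 0-3 for stratification.
bins = np.digitize(y, np.quantile(y, [0.25, 0.5, 0.75]))

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for train, test in skf.split(np.zeros((90, 1)), bins):
    # Each test fold contains a similar spread of low-to-high ratings.
    print(np.bincount(bins[test]))
```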

Leave-One-Subject-Out cross-validation (LOSO) is the special case where k equals the number of subjects. It can be performed by passing in a vector of subject IDs indicating which subject each image belongs to and using the ‘loso’ flag.

ridge_stats = data.predict(algorithm='ridge',
                        cv_dict={'type': 'loso','subject_id': subject_id},
                        plot=False, **{'alpha':.1})

Out:

/opt/hostedtoolcache/Python/3.8.9/x64/lib/python3.8/site-packages/sklearn/linear_model/_ridge.py:187: LinAlgWarning: Ill-conditioned matrix (rcond=2.7471e-08): result may not be accurate.
  dual_coef = linalg.solve(K, y, sym_pos=True,
overall Root Mean Squared Error: 0.00
overall Correlation: 1.00
overall CV Root Mean Squared Error: 0.91
overall CV Correlation: 0.59
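LOSO corresponds to scikit-learn’s LeaveOneGroupOut splitter, which yields one fold per subject. A minimal sketch with hypothetical subject IDs:

```python
import numpy as np
from sklearn.model_selection import LeaveOneGroupOut

subject_id = np.repeat(["s1", "s2", "s3"], 4)  # hypothetical IDs
X = np.random.randn(12, 2)

logo = LeaveOneGroupOut()
print(logo.get_n_splits(X, groups=subject_id))  # one fold per subject
for train, test in logo.split(X, groups=subject_id):
    print(np.unique(subject_id[test]))          # the held-out subject
```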

There are also methods to estimate the shrinkage (regularization) parameter for the penalized methods using nested cross-validation, via the ridgeCV and lassoCV algorithms.

import numpy as np

ridgecv_stats = data.predict(algorithm='ridgeCV',
                        cv_dict={'type': 'kfolds','n_folds': 5, 'stratified':data.Y},
                        plot=False, **{'alphas':np.linspace(.1, 10, 5)})

Out:

/opt/hostedtoolcache/Python/3.8.9/x64/lib/python3.8/site-packages/sklearn/utils/validation.py:70: FutureWarning: Pass shuffle=False, random_state=None as keyword args. From version 1.0 (renaming of 0.25) passing these as positional arguments will result in an error
  warnings.warn(f"Pass {args_msg} as keyword args. From version "
overall Root Mean Squared Error: 0.00
overall Correlation: 1.00
overall CV Root Mean Squared Error: 0.56
overall CV Correlation: 0.74
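The same alpha search can be sketched with scikit-learn’s RidgeCV, which selects the best penalty from a candidate grid by internal cross-validation (synthetic data for illustration; the grid matches the tutorial’s np.linspace(.1, 10, 5)):

```python
import numpy as np
from sklearn.linear_model import RidgeCV

rng = np.random.default_rng(0)
X = rng.standard_normal((90, 20))
y = X @ rng.standard_normal(20) + 0.5 * rng.standard_normal(90)

# RidgeCV evaluates each candidate alpha internally and keeps the best.
model = RidgeCV(alphas=np.linspace(0.1, 10, 5)).fit(X, y)
print(f"selected alpha: {model.alpha_}")
```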

Total running time of the script: (1 minute 49.844 seconds)

Gallery generated by Sphinx-Gallery