Evaluation of explanation quality: feature importance vectors

In this notebook, we are going to explore how we can use teex to evaluate feature importance explanations.

[1]:
from teex.featureImportance.data import SenecaFI, lime_to_feature_importance, scale_fi_bounds
from teex.featureImportance.eval import feature_importance_scores

from lime.lime_tabular import LimeTabularExplainer

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import fbeta_score

import numpy as np

We are going to:

  1. Generate synthetic data with feature importance explanations

  2. Train a black box model on the data

  3. Generate LIME explanations of the data

  4. Evaluate the LIME explanations against the ground truth

1. Generating the data

[8]:
dataGen = SenecaFI(nSamples=300, nFeatures=4, randomState=0)

X, y, exps = dataGen[:]
[9]:
exps[0]
[9]:
array([0.9735, 0.2541, 0.8937, 0.4596], dtype=float32)

2. Training a black box model on the data

[10]:
# split data
Xtr, Xte = X[:200], X[200:]
ytr, yte = y[:200], y[200:]
etr, ete = exps[:200], exps[200:]
[11]:
model = RandomForestClassifier()
model.fit(Xtr, ytr)
[11]:
RandomForestClassifier()
[12]:
# F1 score: fbeta_score expects (y_true, y_pred), in that order
print('Train F1: ', fbeta_score(ytr, model.predict(Xtr), beta=1))
print('Test F1: ', fbeta_score(yte, model.predict(Xte), beta=1))
Train F1:  1.0
Test F1:  0.8378378378378379

The classifier fits the training data perfectly (train F1 of 1.0) and generalizes reasonably well to unseen data (test F1 of about 0.84).
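For reference, F1 is the harmonic mean of precision and recall, which reduces to F1 = 2TP / (2TP + FP + FN). The minimal pure-Python sketch below computes it from two label lists; the helper f1_from_labels is hypothetical and not part of scikit-learn. It also illustrates that swapping true and predicted labels only exchanges FP and FN, so F1 (beta=1) does not depend on argument order, whereas F-beta with beta other than 1 would.

```python
def f1_from_labels(y_true, y_pred, positive=1):
    """Compute F1 = 2TP / (2TP + FP + FN) for binary labels (hypothetical helper)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    denom = 2 * tp + fp + fn
    # Guard the degenerate case with no positives at all.
    return 2 * tp / denom if denom else 0.0

y_true = [1, 0, 1, 1, 0, 0]
y_pred = [1, 1, 1, 0, 0, 0]
print(f1_from_labels(y_true, y_pred))  # 0.6666666666666666
print(f1_from_labels(y_pred, y_true))  # same value: FP and FN just swap roles
```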

3. Generating LIME explanations

We first instantiate the explainer:

[13]:
explainer = LimeTabularExplainer(Xtr, feature_names=dataGen.featureNames, mode='classification')

# sample explanation
explainer.explain_instance(Xte[0], model.predict_proba).show_in_notebook()

We then generate the explanations for the test set. Unfortunately, LIME can only explain one observation at a time, so we loop over the test set. We use lime_to_feature_importance to transform each LIME explanation into a feature importance vector. This function does not rescale the importance values; they are kept exactly as LIME computed them.
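To make the conversion step concrete: LIME's Explanation.as_map() returns a dict mapping each explained label to a list of (feature_index, weight) pairs, typically sorted by absolute weight rather than by index. The sketch below shows one way to turn such a map into a dense vector of length nFeatures. The helper map_to_vector and the example weights are hypothetical, and this is an illustration of the idea, not teex's actual lime_to_feature_importance implementation. It assumes the positive class was explained (LIME's default labels=(1,)) and gives weight 0 to any feature LIME did not report.

```python
def map_to_vector(exp_map, n_features, label=1):
    """Turn a LIME-style {label: [(feature_idx, weight), ...]} map into a dense list.

    Hypothetical helper: features missing from the map get importance 0.0,
    and weights are copied as-is (no rescaling).
    """
    vector = [0.0] * n_features
    for feature_idx, weight in exp_map[label]:
        vector[feature_idx] = weight
    return vector

# Hypothetical map shaped like Explanation.as_map() output; the pairs are
# ordered by decreasing absolute weight, not by feature index.
example_map = {1: [(3, 0.108), (2, 0.093), (1, -0.080), (0, -0.072)]}
print(map_to_vector(example_map, n_features=4))  # [-0.072, -0.08, 0.093, 0.108]
```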

[14]:
limeExps = []

# takes a few moments to run
for testObs in Xte:
    exp = explainer.explain_instance(testObs, model.predict_proba)
    limeExps.append(lime_to_feature_importance(exp, nFeatures=4))

limeExps = np.array(limeExps)
[15]:
limeExps[:5]
[15]:
array([[-0.07230919, -0.0801119 ,  0.09253925,  0.10824497],
       [-0.18176845, -0.03213522, -0.01890396, -0.13922554],
       [-0.07199723,  0.03484774, -0.06955719, -0.11750789],
       [ 0.12277604,  0.02374334, -0.06775021,  0.08342091],
       [ 0.10764252, -0.07811511, -0.01083719,  0.09821577]])

4. Evaluating LIME explanations

Now that the explanations are computed, we can evaluate them against the ground truth explanations. Note that LIME explanations are not bounded to the (-1, 1) range, so comparing them directly with the ground truth would not be valid. For this reason, teex’s FI evaluation method scales each feature to the (-1, 1) or (0, 1) range if it is not already within that range.
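One simple way to bring each feature into the (-1, 1) range is to divide its column by the column's largest absolute value. The sketch below assumes that scheme and, mirroring the description above, only rescales columns that are out of range. The helper scale_columns_max_abs is hypothetical; teex's own scaling (for instance scale_fi_bounds) may follow a different rule.

```python
import numpy as np

def scale_columns_max_abs(fi):
    """Scale each column (feature) of an importance matrix into [-1, 1].

    Hypothetical helper: columns whose largest absolute value exceeds 1 are
    divided by that value; columns already inside [-1, 1] are left as they are.
    """
    fi = np.asarray(fi, dtype=float)
    max_abs = np.abs(fi).max(axis=0)
    # A divisor of 1 leaves in-range columns untouched.
    divisor = np.where(max_abs > 1, max_abs, 1.0)
    return fi / divisor

raw = np.array([[ 2.0, -0.5],
                [-4.0,  0.25]])
print(scale_columns_max_abs(raw))
```

Here the first column is divided by 4, while the second column, already inside the range, is returned unchanged.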

[16]:
metrics = ['fscore', 'cs', 'auc', 'prec', 'rec']
avgMetrics = feature_importance_scores(ete, limeExps, metrics)
/Users/master/Google Drive/U/4t/TFG/teex/venv/lib/python3.8/site-packages/teex/featureImportance/eval.py:80: UserWarning: A binary prediction contains uniform values, so one entry has been randomly flipped for the metrics to be defined.
  warnings.warn('A binary prediction contains uniform values, so one entry has been randomly flipped '
[17]:
avgMetrics
[17]:
array([0.34333327, 0.50356585, 0.51      , 0.58      , 0.24666668],
      dtype=float32)
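Judging by their names, the metric codes correspond to F1 score ('fscore'), cosine similarity ('cs'), ROC AUC ('auc'), precision ('prec') and recall ('rec'), returned in that order. The sketch below computes these quantities for a single pair of vectors. It is not teex's implementation: the binarization threshold of 0.5, and the choice to binarize both vectors for precision, recall and F1 while scoring AUC against the raw predicted values, are assumptions. The ground truth is exps[0] from section 1, paired with a hypothetical, already scaled explanation.

```python
import math

def binarize(vec, threshold=0.5):
    # Assumed rule: a feature counts as "important" when its score >= threshold.
    return [1 if v >= threshold else 0 for v in vec]

def precision_recall_f1(gt_bin, pred_bin):
    pairs = list(zip(gt_bin, pred_bin))
    tp = sum(1 for g, p in pairs if g and p)
    fp = sum(1 for g, p in pairs if not g and p)
    fn = sum(1 for g, p in pairs if g and not p)
    # Guard the 0/0 cases (e.g. no predicted positives) by returning 0.0.
    prec = tp / (tp + fp) if tp + fp else 0.0
    rec = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
    return prec, rec, f1

def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

def roc_auc(gt_bin, scores):
    # Fraction of (important, unimportant) feature pairs ranked correctly; ties count 0.5.
    pos = [s for g, s in zip(gt_bin, scores) if g]
    neg = [s for g, s in zip(gt_bin, scores) if not g]
    if not pos or not neg:
        raise ValueError("AUC is undefined when all ground-truth labels are equal")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

gt = [0.9735, 0.2541, 0.8937, 0.4596]  # exps[0] from section 1
pred = [0.8, 0.1, 0.3, 0.6]            # hypothetical scaled explanation
gt_bin, pred_bin = binarize(gt), binarize(pred)
print(precision_recall_f1(gt_bin, pred_bin))  # (0.5, 0.5, 0.5)
print(roc_auc(gt_bin, pred))                  # 0.75
print(cosine_similarity(gt, pred))
```

When a binarized vector is uniform (all 0s or all 1s), some of these quantities become 0/0 or undefined; this sketch returns 0.0 or raises, whereas teex, according to the warning above, randomly flips one entry so the metrics are defined.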

We can also obtain the metrics for each observation instead of averaging them across observations:

[19]:
allMetrics = feature_importance_scores(ete, limeExps, metrics, average=False)
/Users/master/Google Drive/U/4t/TFG/teex/venv/lib/python3.8/site-packages/teex/featureImportance/eval.py:80: UserWarning: A binary prediction contains uniform values, so one entry has been randomly flipped for the metrics to be defined.
  warnings.warn('A binary prediction contains uniform values, so one entry has been randomly flipped '
[20]:
allMetrics[:5]
[20]:
array([[0.6666667 , 0.34652364, 0.5       , 1.        , 0.5       ],
       [0.        , 0.7850784 , 0.5       , 0.        , 0.        ],
       [0.        , 0.85806304, 1.        , 0.        , 0.        ],
       [0.6666667 , 0.44639328, 0.5       , 1.        , 0.5       ],
       [0.        , 0.37971714, 0.33333334, 0.        , 0.        ]],
      dtype=float32)
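Each row of allMetrics holds the five metrics for one test observation, in the order given by metrics. Taking the column-wise mean of the full matrix should give a summary comparable to avgMetrics, assuming the averaged mode is a plain mean over observations (the two calls need not match exactly because of the random flips mentioned in the warning). The sketch below uses only the five rows printed above, so its means are not expected to equal avgMetrics; it also shows how to find the observations whose F1 score is zero.

```python
import numpy as np

metrics = ['fscore', 'cs', 'auc', 'prec', 'rec']

# First five rows of allMetrics as printed above (one row per test observation).
first_rows = np.array([
    [0.6666667, 0.34652364, 0.5,        1.0, 0.5],
    [0.0,       0.7850784,  0.5,        0.0, 0.0],
    [0.0,       0.85806304, 1.0,        0.0, 0.0],
    [0.6666667, 0.44639328, 0.5,        1.0, 0.5],
    [0.0,       0.37971714, 0.33333334, 0.0, 0.0],
], dtype=np.float32)

# Average each metric (column) over the observations (rows).
column_means = first_rows.mean(axis=0)
for name, value in zip(metrics, column_means):
    print(f"{name}: {value:.3f}")

# Indices of observations whose F1 score ('fscore', column 0) is zero.
print(np.flatnonzero(first_rows[:, 0] == 0))
```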