Causal Explanations for Deep Neural Networks

Deepankar
4 min read · Jun 28, 2021

Today we are going to learn about model interpretation using the explainable AI (XAI) method CXplain (causal explanations). To follow along, you should have a working knowledge of machine learning and deep learning models, in particular of how a set of input features influences the output.

The aim of CXplain is to estimate how important each input feature is for the target model to predict the correct output. It frames the task of explaining the decisions of a machine-learning model as a causal learning task. To interpret a model, one first has to train a CXplain model. Once trained, it can explain the target model in little time, and it quantifies the uncertainty of its feature importance estimates via bootstrap ensembling.
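To make "importance of a feature" concrete, here is a minimal sketch of the masking idea that, as described in the CXPlain paper, underlies the training targets: a feature's importance is the increase in the target model's loss when that feature is masked, normalised so that all importances sum to 1. The `toy_predict` model, the zero mask value, and the clipping of negative differences at zero are assumptions made for this illustration only.

```python
import math


def cross_entropy(probs, true_class):
    """Negative log-probability assigned to the true class."""
    return -math.log(max(probs[true_class], 1e-12))


def masked_importance(predict, x, true_class, mask_value=0.0):
    """Normalised loss increase caused by masking each feature in turn."""
    base_loss = cross_entropy(predict(x), true_class)
    deltas = []
    for i in range(len(x)):
        x_masked = list(x)
        x_masked[i] = mask_value  # hide feature i from the model
        loss_without_i = cross_entropy(predict(x_masked), true_class)
        # Assumption: features whose removal does not hurt get zero credit.
        deltas.append(max(loss_without_i - base_loss, 0.0))
    total = sum(deltas)
    if total == 0.0:
        # No feature mattered: fall back to a uniform attribution.
        return [1.0 / len(x)] * len(x)
    return [d / total for d in deltas]


def toy_predict(x):
    """Hypothetical two-class model; feature 1 has weight 0 and is ignored."""
    z = 3.0 * x[0] + 0.0 * x[1] + 1.0 * x[2]
    p1 = 1.0 / (1.0 + math.exp(-z))
    return [1.0 - p1, p1]


importances = masked_importance(toy_predict, [1.0, 1.0, 1.0], true_class=1)
print([round(w, 3) for w in importances])
```

In this toy setup, feature 0 receives most of the importance and the ignored feature 1 receives none. Computing importances this way costs one extra forward pass per feature for every sample, which is presumably why CXplain trains a separate explanation model once and reuses it afterwards.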

These days, deep learning methods are used very frequently, often without knowing what is going on inside the model; that is, users and data scientists treat deep learning models as complete black boxes. However, complex models such as ensemble models and deep neural networks need to be explained thoroughly so that one can rely on the model's output.

CXplain addresses this by estimating feature importance and by reporting the uncertainty of those estimates as confidence intervals. Let's take a quick look at what a confidence interval is:

Confidence intervals are a way of quantifying the uncertainty of an estimate.

For example, a confidence interval could be used in presenting the skill of a classification model, which could be stated as:

Given the sample, there is an 80% likelihood that the range x to y covers the true model accuracy.

  • Smaller Confidence Interval: A more precise estimate.
  • Larger Confidence Interval: A less precise estimate.

Nonparametric Confidence Interval

When to apply: The assumptions that underlie parametric confidence intervals are often violated. The predicted variable sometimes isn’t normally distributed, and even when it is, the variance of the normal distribution might not be equal at all levels of the predictor variable.

In these cases, bootstrap resampling can be used as a nonparametric method for calculating confidence intervals, commonly called bootstrap confidence intervals. CXplain uses this approach to quantify the uncertainty of its estimates.
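As a rough illustration of a bootstrap confidence interval, the sketch below resamples a small set of values with replacement, recomputes the statistic on every resample, and reads the interval off the percentiles of those estimates. The helper name `bootstrap_ci`, its parameters, and the accuracy values are made up for this example; it shows the general percentile bootstrap, not CXplain's internal code.

```python
import random
import statistics


def bootstrap_ci(values, statistic=statistics.mean, confidence_level=0.80,
                 num_resamples=2000, seed=0):
    """Percentile bootstrap confidence interval for `statistic` of `values`."""
    rng = random.Random(seed)
    n = len(values)
    estimates = []
    for _ in range(num_resamples):
        # Resample the data with replacement, keeping the original size.
        resample = [values[rng.randrange(n)] for _ in range(n)]
        estimates.append(statistic(resample))
    estimates.sort()
    alpha = 1.0 - confidence_level
    lower_index = int((alpha / 2.0) * (num_resamples - 1))
    upper_index = int((1.0 - alpha / 2.0) * (num_resamples - 1))
    return estimates[lower_index], estimates[upper_index]


# Made-up accuracies from ten evaluation runs, for illustration only.
accuracies = [0.91, 0.93, 0.90, 0.94, 0.92, 0.89, 0.95, 0.92, 0.91, 0.93]
lower, upper = bootstrap_ci(accuracies, confidence_level=0.80)
print(f"80% bootstrap CI for mean accuracy: [{lower:.3f}, {upper:.3f}]")
```

Raising `confidence_level` widens the interval, and no normality assumption is needed at any point. Bootstrap ensembling applies the same idea to models: several explanation models are trained on resampled data, and the spread of their attributions gives the interval.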

Let’s Code CXplain for a deep learning model

We will be using the famous MNIST dataset. It contains images of handwritten digits from 0 to 9, i.e. the dataset has 10 classes, and the model predicts which digit an image shows.

Below are 10 training images from each class of the MNIST dataset.

Note: Google Colab is recommended. The code for this article can be found here. A GPU runtime is required for smooth and fast execution of the code. The training and test data can be reduced for faster execution.

Steps to train the CXplain model

1. Install the required libraries

2. Import the libraries

3. Define the CXplain model architecture

4. Load the CNN model which we have to explain using CXplain

5. Train the CXplain model

6. Visualize Attributions & Confidence

The above figure shows that CXplain's uncertainty is low, which means that CXplain is able to explain the CNN model very well. The attributions indicate which input features the CNN model takes into account to classify the image.
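Steps 3 and 5 above, together with the call that produces the attributions and confidence intervals for step 6, can be sketched roughly as follows. The calls follow the usage example in the `cxplain` README as best I recall it, so treat the argument names as assumptions and check them against the installed version. The two wrapper functions and the hyperparameter values are illustrative choices, and the multilayer-perceptron builder stands in for whichever explanation architecture the notebook actually defines. `explained_model` is the already-trained CNN from step 4.

```python
def fit_cxplain_explainer(explained_model, x_train, y_train, num_models=10):
    """Fit a bootstrap ensemble of CXPlain explanation models (sketch)."""
    # Imported lazily: both packages are heavy and only needed when fitting.
    from tensorflow.python.keras.losses import categorical_crossentropy
    from cxplain import CXPlain, MLPModelBuilder, ZeroMasking

    # Architecture of the explanation model (illustrative hyperparameters).
    model_builder = MLPModelBuilder(num_layers=2, num_units=64,
                                    batch_size=256, learning_rate=0.001)
    # Features are removed by setting them to zero.
    masking_operation = ZeroMasking()
    # num_models > 1 trains an ensemble, which is what enables uncertainty
    # estimates later on.
    explainer = CXPlain(explained_model, model_builder, masking_operation,
                        categorical_crossentropy, num_models=num_models)
    explainer.fit(x_train, y_train)
    return explainer


def explain_with_confidence(explainer, x_test, confidence_level=0.80):
    """Return attributions and their confidence intervals for x_test."""
    attributions, confidence = explainer.explain(
        x_test, confidence_level=confidence_level, return_confidence=True)
    return attributions, confidence
```

The returned attributions and confidence intervals can then be plotted next to the input images, which is what step 6 does.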

References

  • CXPlain: Causal Explanations for Model Interpretation under Uncertainty
  • CXplain codes
  • My code from scratch for the MNIST dataset
  • The MNIST dataset
