Counterfactual explanations (CFEs), an active area of research in machine learning explainability, describe actionable steps that move a data point from one side of a classifier's decision boundary to the other. They have clear use cases in applications ranging from loan decisions (example shown below) to healthcare diagnosis.
- Problem: A binary classifier used by a financial services company predicts whether an individual should be approved (1) or denied (0) for a loan. Individuals whom the model rejects want to know what they would need to change to get approved.
- Counterfactual Explanation: We can provide a series of steps, expressed as changes to the input features, that would help an individual get approved for a loan, e.g., increase your salary by $10,000 and gain 2 more years of education.
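To make the loan example concrete, here is a minimal sketch with a hypothetical linear scoring model over two features, salary (in dollars) and years of education. The weights, bias, and applicant values are invented for illustration and do not come from any model in this article. The point is only that a counterfactual explanation is a feature change that flips the prediction from denied (0) to approved (1).

```python
# Hypothetical linear loan model: approve when the score is non-negative.
# Weights, bias, and applicant values are illustrative assumptions only.
WEIGHTS = {"salary": 0.0001, "education_years": 0.5}
BIAS = -12.5


def predict(applicant):
    """Return 1 (approve) or 0 (deny) for a dict of feature values."""
    score = BIAS + sum(WEIGHTS[name] * value for name, value in applicant.items())
    return 1 if score >= 0 else 0


def apply_changes(applicant, changes):
    """Return a new applicant with the suggested feature deltas added."""
    updated = dict(applicant)  # copy so the original applicant is untouched
    for name, delta in changes.items():
        updated[name] += delta
    return updated


applicant = {"salary": 50_000, "education_years": 12}
counterfactual = {"salary": 10_000, "education_years": 2}

print(predict(applicant))                                 # 0: denied
print(predict(apply_changes(applicant, counterfactual)))  # 1: approved
```

With these made-up weights, neither change is enough on its own. Adding only $10,000 of salary, or only 2 years of education, leaves the score below zero, which is why the explanation lists both.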
In practice, we often need more from CFEs than just reaching the other side of the decision boundary. For example, we might need to restrict changes to certain actionable features, or we might need the resulting counterfactual to be realistic and similar to the training data. Recent work has summarized these desiderata and the common research themes in the field. We also need to compute counterfactuals efficiently enough for high-volume use cases. In this article, we present a distributed, scalable reinforcement learning framework that produces real-time counterfactual explanations for any binary classifier. We give an overview of the algorithm used to find counterfactual explanations in real time and describe how we used OpenAI Gym and Ray (RLlib) to put this work into practice.
We define a counterfactual explanation (CFE) as follows:
One key shortcoming of most counterfactual approaches is that they compute each explanation independently, one instance at a time. For every data point that needs an explanation, we must solve a new optimization problem to find a counterfactual explanation with the following properties:
Solving a new optimization problem for each data point can be expensive, which prevents real-time counterfactual explanations. Our approach instead trains a model once, paying the expensive cost upfront for the entire dataset, and then produces fast inferences for any new data point that needs a counterfactual explanation.
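To make the per-instance cost concrete, here is one common way to write the optimization problem solved for each rejected point $x$. The classifier is $f$, with $f(x) = 0$. The specific terms are an illustrative assumption and not necessarily the exact objective used in this work:

- a distance $d$;
- a set $\mathcal{A}(x)$ of points reachable through actionable feature changes only;
- a realism penalty $\rho$, weighted by $\lambda$.

$$
x^{\mathrm{cf}} \;=\; \arg\min_{x' \in \mathcal{A}(x)} \; d(x, x') + \lambda\,\rho(x')
\qquad \text{subject to} \qquad f(x') = 1
$$

A per-instance method solves this problem from scratch for each of the $N$ rejected points it is asked to explain. An amortized model, such as an RL policy, is trained once on the whole dataset. After that, each new query costs only a short rollout of the learned policy.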
Reinforcement Learning Framework for CFEs
Because our goal is near real-time explanations, we built on a reinforcement learning framework [Verma, Hines, Dickerson] that trains the model once and then produces explanations in real time. In this section, we give a brief overview of reinforcement learning and how we applied it to counterfactual explanations.
Reinforcement Learning Overview
Reinforcement learning (RL) is a machine learning framework in which an agent interactively learns the best set of actions to take in a given environment to reach a desired state. RL has had its greatest successes in robotics and game playing, where RL models have beaten the best players in the world at poker and Go.
A few key terms will be used throughout this article as we model our counterfactual problem:
This article does not give a comprehensive overview of reinforcement learning. For that, I highly recommend reading [this blog].
Reinforcement Learning for CFEs
Figure 1: Decision boundary for an arbitrary dataset. Note that the binary classifier is not perfect and mislabels a few points, namely the blue points on the right side of the graph.
To frame this problem as a reinforcement learning problem, we define the following:
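Here is one possible framing, sketched under assumed choices:

- The state is the current feature vector.
- Each discrete action nudges one feature up or down by a fixed step.
- Every step that does not flip the label costs a small penalty.
- Reaching the approved class ends the episode with a large reward.

The constants STEP_SIZE, STEP_PENALTY, and SUCCESS_REWARD and the helpers decode_action and transition are hypothetical, not the values or names used in FastCFE.

```python
# Hypothetical MDP pieces for a counterfactual environment. All constants and
# helper names are illustrative assumptions, not FastCFE's actual values.
STEP_SIZE = 0.1          # how far one action moves a single feature
STEP_PENALTY = -1.0      # reward for every move that does not flip the label
SUCCESS_REWARD = 100.0   # reward when the classifier flips to the desired class


def decode_action(action, n_features):
    """Map a discrete action id to (feature index, direction).

    Ids 0..n-1 increase feature i; ids n..2n-1 decrease feature i - n.
    """
    feature = action % n_features
    direction = 1 if action < n_features else -1
    return feature, direction


def transition(state, action, classifier, immutable=frozenset()):
    """Apply one action to a state; return (next_state, reward, done).

    state      -- current feature vector (list of floats)
    classifier -- any binary classifier mapping a feature vector to 0 or 1
    immutable  -- indices of features that must never change
    """
    feature, direction = decode_action(action, len(state))
    next_state = list(state)
    if feature not in immutable:
        next_state[feature] += direction * STEP_SIZE
    # An action on an immutable feature is a wasted step and still pays STEP_PENALTY.
    if classifier(next_state) == 1:
        return next_state, SUCCESS_REWARD, True  # reached the other side
    return next_state, STEP_PENALTY, False
```

Encoding actions as 2 * n_features discrete ids maps directly onto a Gym Discrete(2 * n_features) action space. The per-step penalty pushes the agent toward short paths, that is, counterfactuals close to the original point.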
Figure 2 shows a possible path that a point can take to find a CFE.
Figure 2: Sample path that an agent can take in our environment.
OpenAI Gym
OpenAI Gym provides a framework for creating environments that an agent interacts with. It is the de facto standard for defining RL environments and ships with a set of predefined environments for different tasks, but you can also define your own. To learn a model that achieves our RL task, we must fit the task into the Gym framework by defining the following four components:
1. Observation Space: You must define what every state looks like in your RL environment. In Gym, observation spaces are defined with the classes in gym.spaces: Tuple, Discrete, Box, Dict, MultiBinary, and MultiDiscrete (see the Gym spaces documentation).
2. Action Space: You must also define what an action looks like for any given agent. Action spaces use the same types as observation spaces.
3. Step Function: Given an action from the action space, this function applies it to the environment's current state. It returns the new state the agent moved to (the transition may or may not be deterministic) and the reward produced by taking that action. It also returns a flag indicating whether the episode has ended and a dictionary of diagnostic info.
4. Reset Function: Once the agent has reached a desired state (or we want to start over), we need to reset the environment to some starting state. The reset function defines the rule for choosing that starting state.
We can now formally define our FastCFE (fast counterfactual explanations) class, which combines all of the above components in one Python class.
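A minimal, hypothetical sketch of such a class follows. To stay dependency-free, it does not subclass gym.Env. Instead it mirrors the classic Gym interface: reset() returns an observation, and step() returns observation, reward, done, info. A real version would subclass gym.Env and declare observation_space (e.g., a Box) and action_space (e.g., Discrete(2 * n_features)). Newer Gymnasium releases change the signatures to reset() -> (obs, info) and a five-value step(). The class name FastCFEEnv, the reward values, and the episode limit are assumptions.

```python
import random


class FastCFEEnv:
    """Hypothetical, dependency-free sketch of a Gym-style CFE environment.

    Observation: the current feature vector (list of floats).
    Action: an integer in [0, 2 * n_features); the first half increases a
    feature by step_size, the second half decreases it.
    """

    def __init__(self, classifier, rejected_points, step_size=0.1,
                 max_steps=50, seed=None):
        self.classifier = classifier            # any binary classifier: list -> 0/1
        self.rejected_points = rejected_points  # starting states: points predicted 0
        self.n_features = len(rejected_points[0])
        self.n_actions = 2 * self.n_features    # +/- step for each feature
        self.step_size = step_size
        self.max_steps = max_steps              # episode cutoff if no flip is found
        self.rng = random.Random(seed)
        self.state = None
        self.steps = 0

    def reset(self):
        """Start a new episode from a randomly chosen rejected point."""
        self.state = list(self.rng.choice(self.rejected_points))  # copy, not alias
        self.steps = 0
        return list(self.state)

    def step(self, action):
        """Classic Gym step: returns (observation, reward, done, info)."""
        if not 0 <= action < self.n_actions:
            raise ValueError(f"invalid action {action}")
        feature = action % self.n_features
        direction = 1 if action < self.n_features else -1
        self.state[feature] += direction * self.step_size
        self.steps += 1

        flipped = self.classifier(self.state) == 1
        reward = 100.0 if flipped else -1.0     # assumed reward shaping
        done = flipped or self.steps >= self.max_steps
        return list(self.state), reward, done, {"flipped": flipped}
```

Because the starting state is sampled from all rejected points, a single environment (and therefore a single trained policy) covers the whole dataset. This is what makes the one-time training cost pay off at inference time.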
Training the Model Using Ray + RLlib
Ray and RLlib
Ray is a distributed computing framework for Python designed to distribute workloads, including training, across a cluster. RLlib is a library within Ray for training RL agents on different environments. RLlib implements a variety of RL algorithms and exposes a configuration dictionary that makes it easy to distribute training and evaluation across cores and machines. It also provides an API for modifying the internals of the deep RL optimizer. RLlib is maintained by Anyscale (founded out of Berkeley's RISELab) and is one of the state-of-the-art frameworks for distributed computing and machine learning.
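RLlib's configuration dictionary is where distribution is controlled. As a hedged sketch, the hypothetical helper make_ppo_config below builds a PPO config using key names from RLlib's Ray 1.x dict-based API. The numeric values are illustrative placeholders, not the settings used for FastCFE.

```python
def make_ppo_config(num_workers=4, num_gpus=0, env_config=None):
    """Build a Ray 1.x-style RLlib PPO config dict (hypothetical helper).

    Scaling out is mostly a matter of raising num_workers: each rollout
    worker is a separate process (possibly on another machine) that steps
    its own copy of the environment and sends samples to the learner.
    """
    return {
        "framework": "torch",            # deep learning backend ("tf" also works)
        "num_workers": num_workers,      # parallel rollout workers
        "num_gpus": num_gpus,            # GPUs assigned to the learner
        "env_config": env_config or {},  # forwarded to the environment constructor
        "train_batch_size": 4000,        # samples collected per training iteration
        "sgd_minibatch_size": 128,       # minibatch size for PPO's SGD epochs
        "num_sgd_iter": 30,              # SGD epochs over each training batch
        "lr": 5e-5,                      # learning rate
        "gamma": 0.99,                   # discount factor
    }
```

With Ray 1.x installed, this dict would be passed to the trainer, for example PPOTrainer(env=..., config=make_ppo_config()) from ray.rllib.agents.ppo. Later Ray releases replaced this dict-based interface with the PPOConfig builder and renamed some keys, including num_workers.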
For our implementation, we opted for Proximal Policy Optimization (PPO), a deep RL algorithm that outperforms other online policy gradient methods and overall strikes a favorable balance between sample complexity, simplicity, and wall-clock time [Schulman et al.]. We needed an algorithm that trains relatively fast and distributes simply, and PPO provides both out of the box. RLlib lets us run PPO on a laptop or distribute it across a cluster with the same code. Below is pseudo-code for the FastCFE-specific wrapper we built around RLlib's native PPO optimizer, which packages RLlib and our FastCFE model into an easy-to-use interface:
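After training, producing an explanation for a new point is just a rollout of the learned policy, which is why inference is fast. Start at the rejected point, repeatedly apply the policy's action, and record the feature changes until the classifier flips. The sketch below shows that loop with a hand-written greedy_policy standing in for the trained PPO policy. With Ray 1.x, one would instead query the trained trainer, for example through its compute_single_action method. The function names, step size, and step limit are hypothetical.

```python
def greedy_policy(state):
    """Stand-in for a trained policy: increase the currently smallest feature.

    Returns an action id in [0, n), which by convention means "increase
    feature i"; ids in [n, 2n) would mean "decrease feature i - n".
    """
    return min(range(len(state)), key=lambda i: state[i])


def explain(point, classifier, policy, step_size=0.1, max_steps=50):
    """Roll out a policy from a rejected point and return the steps taken.

    Each step is (feature index, delta). Returns [] if the point is already
    approved, or None if the classifier has not flipped within max_steps.
    """
    n = len(point)
    state = list(point)
    changes = []
    for _ in range(max_steps):
        if classifier(state) == 1:
            return changes
        action = policy(state)                  # discrete id in [0, 2n)
        feature = action % n
        delta = step_size if action < n else -step_size
        state[feature] += delta
        changes.append((feature, delta))
    return changes if classifier(state) == 1 else None


def summarize(changes):
    """Collapse individual steps into a net change per feature."""
    net = {}
    for feature, delta in changes:
        net[feature] = net.get(feature, 0.0) + delta
    return net
```

The summary is the human-readable explanation ("raise feature 0 by X and feature 1 by Y"). The step-by-step list preserves the path the agent took, as in the sample path of Figure 2.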
We want to showcase two major benchmarks for this algorithm and implementation:
- Performance Metrics: We want to understand how our FastCFE approach using RL compares against other methods for performance metrics (described later).
- Training Time: We want to measure how much training time we save by using a distributed RL framework like Ray instead of a single-threaded research implementation.
We ran these benchmarks on a variety of combinations of the following credit risk datasets. Each dataset contains some form of credit risk data, and the models are binary classifiers that predict whether a single applicant (one row) should be accepted or rejected for a loan. Our goal is to find counterfactual explanations for all rejected applicants. The dataset sizes are shown below (number of rows by number of columns):
- German Credit: 1,000 data-points x 21 features
- Adult Credit: 48,842 data-points x 14 features
- Credit Default: 30,000 data-points x 25 features
- Small Encoded Credit Dataset: 332,000 data-points x 44 features
- Large Encoded Credit Dataset: 332,000 data-points x 652 features
The first three datasets (German Credit, Adult Credit, Credit Default) are publicly available. The last two are proprietary datasets with obfuscated column names. These datasets are larger and were used to test the scalability of our implementation.
We want to see how our FastCFE model compares to other well-known methods. Specifically, we are focusing on the following two metrics:
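As an illustration only, two metrics that are common in the CFE literature are validity and proximity. These may not be the two metrics used in the comparison. Validity is the fraction of rejected points for which a method returns a counterfactual that the classifier actually accepts. Proximity is the average distance between each point and its counterfactual. Here is a minimal sketch, with L1 distance as an assumed choice and hypothetical function names:

```python
def validity(counterfactuals, classifier):
    """Fraction of queries whose counterfactual exists and is classified as 1."""
    if not counterfactuals:
        return 0.0
    valid = sum(1 for cf in counterfactuals
                if cf is not None and classifier(cf) == 1)
    return valid / len(counterfactuals)


def mean_l1_proximity(points, counterfactuals):
    """Average L1 distance over the points that received a counterfactual."""
    pairs = [(p, cf) for p, cf in zip(points, counterfactuals) if cf is not None]
    if not pairs:
        return float("nan")  # undefined when no counterfactual was produced
    total = sum(sum(abs(a - b) for a, b in zip(p, cf)) for p, cf in pairs)
    return total / len(pairs)
```

Reporting the two together matters, because a method can trivially score well on one at the expense of the other. A method can reach perfect validity by proposing huge changes, or perfect proximity by rarely producing a counterfactual at all.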
The results below compare FastCFE against a number of state-of-the-art counterfactual explanation methods:
As the results show, FastCFE performs nearly as well as the best method (DiCE-Random) across these three datasets, with inference times up to 20x faster than DiCE-Random.
The first implementation of this project used the Stable-Baselines3 package and was trained naively on a single machine. This section compares the training time of our scalable, distributed RLlib implementation against that naive implementation. The results are shown below:
We achieve a nearly 6x reduction in training time and can handle much larger datasets than our naive implementation could. This shows the promise of scalable, distributed reinforcement learning frameworks: they can significantly reduce training time, which is a major bottleneck for many reinforcement learning applications.
We hope this article provided an overview of the following ideas:
- Counterfactual Explanations: What they are and how they are useful for industrial and explainability applications.
- Reinforcement Learning Implementation: How we implemented a production-level reinforcement learning model.
- Power of Distribution: How we can achieve substantial savings by using scalable, distributed reinforcement learning frameworks such as RLlib.
We hope this article offered some interesting ideas and starter code to help you build your own reinforcement learning model. If you would like to learn more, please reach out!