Counterfactual Logit Pairing
April 6, 2023

Posted by Bhaktipriya Radharapu, Software Engineer

TensorFlow Model Remediation is an open source toolkit that showcases solutions to help mitigate unfair bias in Machine Learning models. The toolkit offers resources to build fairer models for everyone – in line with Google’s AI Principles. Today, we’re excited to announce a new technique within the TensorFlow Model Remediation Library called Counterfactual Logit Pairing (CLP) to address unintended bias in ML models.

ML models are prone to making different, and sometimes incorrect, predictions when a sensitive attribute in an input is removed or replaced, which leads to unintended bias. For instance, the Perspective API, used to identify offensive or toxic text in comments, revealed a positive correlation between identity terms referencing race or sexual orientation and the predicted toxicity score. The phrase “I am a lesbian” received a toxicity score of 0.51, while “I am a man” received a lower score of 0.2. This correlation resulted in higher toxicity scores for some identity terms, even when they were used non-pejoratively. For more information on the Perspective API, see the blog post on unintended bias and identity terms.

Counterfactual Logit Pairing (CLP) is a technique that addresses such issues to ensure that a model’s prediction doesn’t change when a sensitive attribute referenced in an example is either removed or replaced. It improves a model’s robustness to such perturbations, and can positively influence a model’s stability, fairness, and safety.

CLP mitigates such counterfactual fairness issues at training time. It does so by adding a term to the model’s training loss that penalizes differences in the model’s outputs between training examples and their counterfactuals.

Another advantage of CLP is that you can use it even on unlabelled data: the added loss compares the model’s outputs on an example and on its counterfactual, so it does not need a label. As long as the model treats the counterfactual examples similarly, you can validate that it adheres to counterfactual fairness.
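
Put as a formula, the training objective described above can be sketched as follows. Here f is the model, (x_i, y_i) a labelled training example with task loss ℓ, (x_j, x_j′) an original/counterfactual pair, and λ the loss weight. The squared difference is an assumed choice of pairing term for illustration, not the only option.

```latex
\mathcal{L}_{\text{total}}
  = \frac{1}{N}\sum_{i=1}^{N} \ell\bigl(f(x_i),\, y_i\bigr)
  + \lambda \cdot \frac{1}{M}\sum_{j=1}^{M} \bigl(f(x_j) - f(x_j')\bigr)^{2}
```

The second term depends only on the model’s outputs for each pair and never on a label, which is why the counterfactual pairs can be drawn from unlabelled data.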

For an in-depth discussion on this topic, see research on counterfactual fairness, adversarial logit pairing, and counterfactual logit pairing.

Counterfactual Logit Pairing Walkthrough:

The CLP with Keras codelab provides an end-to-end example. In this overview, we'll emphasize key points from the notebook, while providing additional context.

The notebook trains a text classifier to identify toxic content. This type of model attempts to identify content that is rude, disrespectful or otherwise likely to make someone leave a discussion, and assigns the content a toxicity score. For this task, our baseline model will be a simple Keras sequential model pre-trained on the Civil Comments dataset.

We will use CLP to avoid having identity terms unfairly skew what is classified as offensive. We consider a narrow class of counterfactuals that involves removing gender and sexual orientation related identity tokens in the input, such as removing “gay” in the input “I’m a gay person” to create the counterfactual example “I’m a person.”
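
The naive, term-matching ablation described above can be sketched in plain Python. The function name remove_identity_terms is hypothetical and only illustrates the idea; in the walkthrough this step is handled by the library’s build_counterfactual_dataset helper, whose exact matching rules may differ from this sketch.

```python
import re
from typing import Iterable


def remove_identity_terms(text: str, terms: Iterable[str]) -> str:
    """Hypothetical naive ablation: delete whole-word, case-insensitive matches of `terms`."""
    terms = list(terms)
    if not terms:
        return text
    # One alternation of escaped terms; \b keeps "he" from matching inside "the".
    pattern = re.compile(
        r"\b(?:" + "|".join(re.escape(t) for t in terms) + r")\b",
        flags=re.IGNORECASE,
    )
    without_terms = pattern.sub("", text)
    # Collapse the whitespace left behind by each removal and trim the ends.
    return re.sub(r"\s+", " ", without_terms).strip()


print(remove_identity_terms("I'm a gay person", ["gay"]))
```

Matching whole words keeps “he” from matching inside “the”, but a heuristic like this can still produce ungrammatical counterfactuals (removing “he” from “he said” leaves just “said”), which is one reason to prefer semantically grounded counterfactuals in production.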

The high-level steps will be to:

  1. Calculate flip rate and flip count of the classifier on original and counterfactual examples.
  2. Build a counterfactual dataset using CounterfactualPackedInputs by performing a naive ablation based on term matching.
  3. Improve performance on flip rate and flip count by training with CLP.
  4. Evaluate the new model’s performance on flip rate and flip count.

Be aware that this is a minimal workflow to demonstrate usage of the CLP technique, and not a complete approach to fairness in machine learning. CLP addresses one specific challenge that may impact fairness in machine learning. See the Responsible AI toolkit for additional information on responsible AI and tools that can be used to complement CLP.

In a production setting, you would want to approach each of these steps with more rigor. For example:

  • Consider the fairness goals of your model. What qualifies as “fair” for your model? Which definitions of fairness are you trying to achieve?
  • Consider when counterfactual pairs should have the same prediction. Many syntactic counterfactuals generated by token substitution may not require identical output. Consider the application space and the potential societal impact of your model and understand when the outputs should be the same and when they shouldn’t be.
  • Consider using semantically and grammatically grounded counterfactuals instead of heuristic-based ablations.
  • Experiment with the configuration of CLP by tuning hyperparameters to get optimal performance.

Let’s begin by examining the flip count and flip rate of the original model on the counterfactual examples. The flip count is the number of times the classifier gives a different decision when the identity term in a given example is changed. The flip rate is the flip count divided by the total number of examples evaluated.
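
A minimal sketch of both metrics, assuming each original example is paired with exactly one counterfactual, the model outputs a toxicity score in [0, 1], and a decision threshold of 0.5. The threshold and the helper name flip_stats are assumptions for illustration.

```python
from typing import Sequence, Tuple


def flip_stats(
    original_scores: Sequence[float],
    counterfactual_scores: Sequence[float],
    threshold: float = 0.5,  # assumed decision threshold
) -> Tuple[int, float]:
    """Return (flip_count, flip_rate) for paired original/counterfactual scores."""
    if len(original_scores) != len(counterfactual_scores):
        raise ValueError("each original example needs exactly one counterfactual")
    if not original_scores:
        return 0, 0.0
    # A flip is a pair whose thresholded decisions disagree.
    flips = sum(
        (orig >= threshold) != (cf >= threshold)
        for orig, cf in zip(original_scores, counterfactual_scores)
    )
    return flips, flips / len(original_scores)


print(flip_stats([0.9, 0.2, 0.6, 0.4], [0.3, 0.1, 0.7, 0.55]))
```

In this toy input, the first and last pairs cross the threshold in opposite directions, so two of four pairs flip and the flip rate is 0.5.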

Let’s use the Fairness Indicators widget in the notebook to measure the flip rates and counts. Select flip_rate/overall in the widget. Notice that the flip rate for female is about 13% and for male about 14%, both higher than the 8% flip rate over the whole dataset. This means that the model is likely to change its classification based on the presence of gender-related terms.

We’ll now use CLP to try to reduce the model's flip rate and flip count for gender-related terms in our dataset. We start by creating an instance of CounterfactualPackedInputs, which packs the original_input and counterfactual_data.

CounterfactualPackedInputs(
    original_input=(x, y, sample_weight),
    counterfactual_data=(original_x, counterfactual_x, counterfactual_sample_weight))

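
To make that layout concrete, here is a plain-Python stand-in with the same two fields. PackedInputsSketch and pack_sketch are hypothetical names, not part of the library; the sketch assumes each original text is paired position by position with one counterfactual and that missing weights default to 1.0.

```python
from typing import List, NamedTuple, Optional, Tuple


class PackedInputsSketch(NamedTuple):
    """Hypothetical stand-in mirroring the fields of CounterfactualPackedInputs."""

    # (x, y, sample_weight): the usual inputs for the main task.
    original_input: Tuple[List[str], List[int], List[float]]
    # (original_x, counterfactual_x, counterfactual_sample_weight): no labels.
    counterfactual_data: Tuple[List[str], List[str], List[float]]


def pack_sketch(
    x: List[str],
    y: List[int],
    counterfactual_x: List[str],
    sample_weight: Optional[List[float]] = None,
    counterfactual_sample_weight: Optional[List[float]] = None,
) -> PackedInputsSketch:
    """Hypothetical helper pairing each original text with its counterfactual."""
    if not (len(x) == len(y) == len(counterfactual_x)):
        raise ValueError("x, y and counterfactual_x must have the same length")
    n = len(x)
    # Assumption: missing weights default to 1.0 for every example.
    if sample_weight is None:
        sample_weight = [1.0] * n
    if counterfactual_sample_weight is None:
        counterfactual_sample_weight = [1.0] * n
    return PackedInputsSketch(
        original_input=(x, y, sample_weight),
        # original_x reuses x, so original_x[i] and counterfactual_x[i] form one pair.
        counterfactual_data=(x, counterfactual_x, counterfactual_sample_weight),
    )


packed = pack_sketch(["I'm a gay person"], [0], ["I'm a person"])
print(packed.counterfactual_data)
```

The counterfactual half carries only the original text, its counterfactual, and a weight, with no label.
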
Next, we remove instances of gender-specific terms using the helper function build_counterfactual_dataset. Note that we only include non-pejorative terms, since pejorative terms should receive a different toxicity score. Requiring equal predictions across examples with pejorative terms would both weaken the model’s ability to perform its task and potentially increase harm to vulnerable groups.

import tensorflow as tf
from tensorflow_model_remediation import counterfactual

sensitive_terms_to_remove = [
    'aunt', 'boy', 'brother', 'dad', 'daughter', 'father', 'female', 'gay',
    'girl', 'grandma', 'grandpa', 'grandson', 'grannie', 'granny', 'he',
    'heir', 'her', 'him', 'his', 'hubbies', 'hubby', 'husband', 'king',
    'knight', 'lad', 'ladies', 'lady', 'lesbian', 'lord', 'man', 'male',
    'mom', 'mother', 'mum', 'nephew', 'niece', 'prince', 'princess', 'queen',
    'queens', 'she', 'sister', 'son', 'uncle', 'waiter', 'waitress', 'wife',
    'wives', 'woman', 'women'
]

# Convert the Pandas DataFrame to a TF Dataset
dataset_train_main = tf.data.Dataset.from_tensor_slices(
    (data_train[TEXT_FEATURE].values, labels_train)).batch(BATCH_SIZE)

# Build counterfactual examples by removing the sensitive terms
counterfactual_data = counterfactual.keras.utils.build_counterfactual_dataset(
    original_dataset=dataset_train_main,
    sensitive_terms_to_remove=sensitive_terms_to_remove)

# Pack the original and counterfactual data together for training
counterfactual_packed_input = counterfactual.keras.utils.pack_counterfactual_data(
    dataset_train_main,
    counterfactual_data)

To train with a Counterfactual model, simply take the original model and wrap it in a CounterfactualModel with a corresponding loss and loss_weight. This will co-train the model on the main classification task and on the debiasing task using the CLP loss.

We use 1.0 as the default loss_weight, but this parameter can be tuned for your use case, since the right value depends on your model and product requirements. Experiment with changing the value to see how it affects the model: increasing it makes the model penalize differences between its outputs on original and counterfactual examples more heavily. You can test a range of values to explore the trade-off between task performance and flip rate.
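
One way to organize that exploration is sketched below. train_and_evaluate is a hypothetical function you would supply yourself (for example, one that wraps the baseline in a CounterfactualModel with the given loss_weight, trains it, and returns its accuracy and flip rate). The selection rule, the lowest flip rate among runs whose accuracy is within a chosen tolerance of the best run, is one assumed criterion among many.

```python
from typing import Callable, Dict, List, Sequence


def sweep_loss_weights(
    weights: Sequence[float],
    train_and_evaluate: Callable[[float], Dict[str, float]],
) -> List[Dict[str, float]]:
    """Run one training job per loss_weight and collect its metrics."""
    results = []
    for weight in weights:
        # Hypothetical callable: expected to return {"accuracy": ..., "flip_rate": ...}.
        metrics = train_and_evaluate(weight)
        results.append({"loss_weight": weight, **metrics})
    return results


def pick_loss_weight(results: List[Dict[str, float]], max_accuracy_drop: float) -> float:
    """Return the loss_weight with the lowest flip rate among runs whose
    accuracy is within max_accuracy_drop of the most accurate run."""
    if not results:
        raise ValueError("no sweep results to choose from")
    best_accuracy = max(r["accuracy"] for r in results)
    eligible = [r for r in results if best_accuracy - r["accuracy"] <= max_accuracy_drop]
    # The most accurate run is always eligible, so `eligible` is never empty.
    return min(eligible, key=lambda r: r["flip_rate"])["loss_weight"]
```

A tighter max_accuracy_drop favors task performance; a looser one lets the sweep trade more accuracy for a lower flip rate.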

Here, we use the pairwise mean squared error loss (PairwiseMSELoss). You can experiment with the other losses in the library to find which option gives the best results.
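
For intuition, the pairwise squared-error penalty and one alternative, a mean absolute difference, can be written in plain Python. These functions are illustrative sketches that operate on lists of paired model outputs (for example, logits); they are not the library’s loss classes, and they ignore sample weights.

```python
from typing import Sequence


def _check_pairs(original: Sequence[float], counterfactual: Sequence[float]) -> None:
    """Ensure the outputs line up one-to-one and are non-empty."""
    if len(original) != len(counterfactual):
        raise ValueError("outputs must be paired one-to-one")
    if not original:
        raise ValueError("need at least one pair")


def pairwise_mse(original: Sequence[float], counterfactual: Sequence[float]) -> float:
    """Mean squared difference between paired outputs."""
    _check_pairs(original, counterfactual)
    return sum((o - c) ** 2 for o, c in zip(original, counterfactual)) / len(original)


def pairwise_abs_diff(original: Sequence[float], counterfactual: Sequence[float]) -> float:
    """Mean absolute difference: an alternative penalty that grows linearly with the gap."""
    _check_pairs(original, counterfactual)
    return sum(abs(o - c) for o, c in zip(original, counterfactual)) / len(original)
```

Because the error is squared, a single large gap costs more under the MSE form: for outputs [2.0, 0.0] versus [0.0, 0.0], pairwise_mse gives 2.0 while pairwise_abs_diff gives 1.0.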

import tensorflow as tf
from tensorflow_model_remediation import counterfactual

counterfactual_weight = 1.0
counterfactual_model = counterfactual.keras.CounterfactualModel(
    baseline_model,
    loss=counterfactual.losses.PairwiseMSELoss(),
    loss_weight=counterfactual_weight)

# Compile the model normally after wrapping the original model.
# Note that this means we use the baseline model's loss here.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
loss = tf.keras.losses.BinaryCrossentropy()
counterfactual_model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])

# Train on the packed original and counterfactual data.
counterfactual_model.fit(counterfactual_packed_input, epochs=1)

Once again, we evaluate the results by looking at the flip count and flip rate. Select “flip_rate/overall” within Fairness Indicators and compare the results for female and male between the two models. You should notice that the flip rates for overall, female, and male have all decreased by about 90%, leaving the final flip rate for female at approximately 1.3% and for male at approximately 1.4%.

You can get started with Counterfactual by visiting TensorFlow Responsible AI, and learn more about evaluating fairness with Fairness Indicators.


The Counterfactual framework was developed in collaboration with
  • Amy Wang, Ben Packer, Bhaktipriya Radharapu, Christina Greer, Nick Blumm, Parker Barnes, Piyush Kumar, Sean O’Keefe, Shivam Jindal, Shivani Poddar, Summer Misherghi, Thomas Greenspan.
This research effort was jointly led by
  • Alex Beutel, Jilin Chen, Tulsee Doshi in collaboration with Sahaj Garg, Vincent Perot, Nicole Limtiaco, Ankur Taly, Ed H. Chi.
Further, this work was pursued in collaboration with
  • Andrew Smart, Francois Chollet, Molly FitzMorris, Tomer Kaftan, Mark Daoust, Daniel 'Wolff' Dobson, Soo Sung.