Learn what attention head pruning is, why transformer models have built-in redundancy, and how removing certain attention heads can make models smaller and faster without sacrificing performance.

You trained a massive transformer model. It runs slow, uses too much memory, and your GPU bill is climbing. The obvious answer is: use a smaller model. But what if you didn't have to start from scratch?
It turns out that many transformer models are carrying dead weight. Not every attention head is doing useful work. Some heads repeat what others already learned. Some barely activate at all. This "slack" in the architecture is not a flaw. It is a feature that helped the model train well. But once training is done, you can cut it.
That is exactly what attention head pruning does. It finds and removes the redundant heads, leaving you with a leaner model that runs faster, uses less memory, and often performs just as well on the tasks that matter.
A transformer model processes text using a mechanism called multi-head attention. Instead of having one attention layer look at the input, it splits into multiple "heads," each learning to focus on different relationships in the data.
For example, in a sentence like "The cat sat on the mat," one head might learn subject-verb relationships, another might track pronoun references, and another might focus on positional patterns.
Models like BERT-base have 12 layers with 12 heads each, giving 144 attention heads total. GPT-2 medium has 24 layers with 16 heads each, totaling 384 heads.
| Model | Layers | Heads per Layer | Total Heads |
|---|---|---|---|
| BERT-base | 12 | 12 | 144 |
| BERT-large | 24 | 16 | 384 |
| GPT-2 small | 12 | 12 | 144 |
| GPT-2 medium | 24 | 16 | 384 |
| GPT-2 large | 36 | 20 | 720 |
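The totals in the table are just layers multiplied by heads per layer. The short sketch below reproduces them and also derives the per-head dimension. The hidden sizes are the commonly published configurations for these models; they do not appear in the table and are included here as an assumption.

```python
# (layers, heads per layer, hidden size). Hidden sizes are the published
# configurations and are an addition to the table above.
MODELS = {
    "BERT-base": (12, 12, 768),
    "BERT-large": (24, 16, 1024),
    "GPT-2 small": (12, 12, 768),
    "GPT-2 medium": (24, 16, 1024),
    "GPT-2 large": (36, 20, 1280),
}

def head_stats(layers, heads, hidden):
    """Return the total head count and the width of each head."""
    return {"total_heads": layers * heads, "head_dim": hidden // heads}

for name, cfg in MODELS.items():
    stats = head_stats(*cfg)
    print(f"{name}: {stats['total_heads']} heads, {stats['head_dim']} dims per head")
```

Every model in the list ends up with 64 dimensions per head, so removing one head always removes the same-sized slice of each attention projection.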
During training, the model benefits from having many heads because they act as an ensemble. Different initializations explore different solutions, and the best ones survive through gradient descent. The redundancy helps training stability and generalization. But after training, many heads overlap or become nearly inactive.
Attention head pruning is the process of identifying and permanently removing attention heads that contribute little to model performance.
The core idea: not all 144 heads in BERT-base matter equally. Michel et al. (2019) found that a large percentage of attention heads can be removed at test time without significantly hurting performance on downstream tasks.
There are three common pruning strategies:
Magnitude-based pruning removes heads with the smallest weight norms. The assumption is that heads with small weights are not doing much work.
Importance-score pruning computes a score for each head based on how much the model's output changes if that head is masked out. Heads with low importance scores get removed.
Structured pruning with learned gates adds a binary or soft gate to each head during fine-tuning. The model learns which gates to close, effectively learning to prune itself.
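The first two strategies can be sketched on a toy attention layer. This is a minimal NumPy illustration, not the Hugging Face implementation: the layer, its random weights, and the helper names `attention`, `magnitude_scores`, and `ablation_scores` are all hypothetical. Head 2 is deliberately shrunk so that both scores should single it out.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention(x, wq, wk, wv, wo, num_heads, head_mask=None):
    """Toy multi-head self-attention. x: (seq, hidden); weights: (hidden, hidden)."""
    seq_len, hidden = x.shape
    head_dim = hidden // num_heads
    if head_mask is None:
        head_mask = np.ones(num_heads)

    def split(t):  # (seq, hidden) -> (heads, seq, head_dim)
        return t.reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)

    q, k, v = split(x @ wq), split(x @ wk), split(x @ wv)
    weights = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(head_dim))
    context = (weights @ v) * head_mask[:, None, None]  # zero out masked heads
    merged = context.transpose(1, 0, 2).reshape(seq_len, hidden)
    return merged @ wo

def magnitude_scores(wq, wk, wv, num_heads):
    """Magnitude-based score: L2 norm of each head's slice of Q, K and V."""
    head_dim = wq.shape[1] // num_heads
    scores = []
    for h in range(num_heads):
        cols = slice(h * head_dim, (h + 1) * head_dim)
        squared = sum(float(np.sum(w[:, cols] ** 2)) for w in (wq, wk, wv))
        scores.append(squared ** 0.5)
    return scores

def ablation_scores(x, wq, wk, wv, wo, num_heads):
    """Importance score: mean squared change in output when one head is masked."""
    full = attention(x, wq, wk, wv, wo, num_heads)
    scores = []
    for h in range(num_heads):
        mask = np.ones(num_heads)
        mask[h] = 0.0
        masked = attention(x, wq, wk, wv, wo, num_heads, mask)
        scores.append(float(np.mean((masked - full) ** 2)))
    return scores

rng = np.random.default_rng(0)
hidden, num_heads, seq_len = 16, 4, 5
x = rng.normal(size=(seq_len, hidden))
wq, wk, wv, wo = (0.1 * rng.normal(size=(hidden, hidden)) for _ in range(4))
for w in (wq, wk, wv):
    w[:, 8:12] *= 0.01  # columns 8..11 belong to head 2; make it nearly inactive

mag = magnitude_scores(wq, wk, wv, num_heads)
abl = ablation_scores(x, wq, wk, wv, wo, num_heads)
print("lowest magnitude head:", int(np.argmin(mag)))
print("lowest ablation head:", int(np.argmin(abl)))
```

On a real model the two scores do not always agree, which is one reason importance-based scoring is usually preferred over raw weight norms.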
Before you prune, you need to know which heads to cut. The most practical method is the "head importance score" from Michel et al. (2019).
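The idea is simple: mask out a head (set its output to zero) and measure the effect on the loss. If the loss barely changes, the head is not important. Ablating every head one at a time is expensive, so in practice the effect is estimated from gradients collected during ordinary backward passes. The function below uses a simple proxy of that kind, scoring each head from the gradient on its slice of the query projection.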
```python
import torch

def compute_head_importance(model, dataloader):
    """Accumulate a gradient-based sensitivity score for every attention head.

    Assumes a Hugging Face BERT classification model and batches that include
    `labels`, so that `outputs.loss` is populated.
    """
    num_layers = model.config.num_hidden_layers
    num_heads = model.config.num_attention_heads
    head_size = model.config.hidden_size // num_heads
    head_importance = torch.zeros(num_layers, num_heads)

    model.eval()  # disable dropout; gradients are still computed
    for batch in dataloader:
        model.zero_grad()  # avoid mixing gradients across batches
        outputs = model(**batch)
        outputs.loss.backward()

        for layer_idx in range(num_layers):
            query = model.bert.encoder.layer[layer_idx].attention.self.query
            # |gradient * weight| gives a first-order sensitivity score
            sensitivity = (query.weight.grad * query.weight.detach()).abs()
            for head_idx in range(num_heads):
                start = head_idx * head_size
                end = start + head_size
                # Rows start:end of the query projection belong to this head
                head_importance[layer_idx, head_idx] += sensitivity[start:end].mean()

    return head_importance
```
Once you have the importance scores, you can sort them and remove the bottom N percent.
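Sorting the scores and cutting the bottom N percent can be sketched as follows. The helper `select_heads_to_prune` is hypothetical, and the floor of one surviving head per layer is an assumption added here so that no layer loses its attention entirely. The result uses the `{layer_index: [head_indices]}` layout that `prune_heads()` expects.

```python
def select_heads_to_prune(importance, fraction, min_heads_per_layer=1):
    """Pick the lowest-scoring heads.

    importance: list of per-layer lists of scores.
    fraction: share of all heads to remove, between 0 and 1.
    Returns {layer_index: [head_indices]}.
    """
    # Rank every (score, layer, head) triple from least to most important
    ranked = sorted(
        (score, layer, head)
        for layer, row in enumerate(importance)
        for head, score in enumerate(row)
    )
    budget = int(len(ranked) * fraction)
    remaining = [len(row) for row in importance]
    heads_to_prune = {}
    for _, layer, head in ranked:
        if budget == 0:
            break
        if remaining[layer] <= min_heads_per_layer:
            continue  # keep at least this many heads in every layer
        heads_to_prune.setdefault(layer, []).append(head)
        remaining[layer] -= 1
        budget -= 1
    return {layer: sorted(heads) for layer, heads in sorted(heads_to_prune.items())}

# Toy scores for a 3-layer, 4-head model (made-up values)
scores = [
    [0.90, 0.05, 0.40, 0.02],
    [0.01, 0.03, 0.04, 0.06],
    [0.70, 0.80, 0.60, 0.50],
]
print(select_heads_to_prune(scores, fraction=0.5))
# {0: [1, 2, 3], 1: [0, 1, 2]}
```

With a score tensor from `compute_head_importance`, calling `.tolist()` on it first gives the nested-list input this helper assumes.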
Here is a simplified example of pruning specific heads in a Hugging Face BERT model:
```python
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

# Define which heads to prune: {layer_index: [head_indices_to_prune]}
heads_to_prune = {
    0: [3, 7, 11],  # prune heads 3, 7, 11 in layer 0
    3: [0, 5],      # prune heads 0, 5 in layer 3
    8: [2, 9, 10],  # prune heads 2, 9, 10 in layer 8
}
model.prune_heads(heads_to_prune)

# Check the new config
for i, layer in enumerate(model.bert.encoder.layer):
    remaining = layer.attention.self.num_attention_heads
    print(f"Layer {i}: {remaining} heads remaining")
```
Hugging Face makes this easy with the built-in prune_heads() method. It physically removes the pruned heads from the weight matrices, so you get real memory and compute savings, not just masking.
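To get a feel for how much is physically removed, the parameter count per head can be worked out by hand. The sketch below assumes the standard BERT layout: each head owns a slice of the query, key, and value projections (weights plus biases) and a matching slice of the attention output projection, whose bias is left in place. The function names are hypothetical.

```python
def params_per_head(hidden_size, num_heads):
    """Parameters that disappear when one attention head is removed."""
    head_dim = hidden_size // num_heads
    qkv = 3 * (hidden_size * head_dim + head_dim)  # Q, K, V weight slices and biases
    out = hidden_size * head_dim                   # output projection slice, bias kept
    return qkv + out

def params_removed(heads_to_prune, hidden_size=768, num_heads=12):
    """Total parameters removed for a {layer: [heads]} pruning plan."""
    num_pruned = sum(len(heads) for heads in heads_to_prune.values())
    return num_pruned * params_per_head(hidden_size, num_heads)

heads_to_prune = {0: [3, 7, 11], 3: [0, 5], 8: [2, 9, 10]}
print(params_per_head(768, 12))        # 196800
print(params_removed(heads_to_prune))  # 1574400
```

Against the roughly 110 million parameters of BERT-base, eight heads come to a little over one percent of the model, so meaningful savings require pruning a substantial share of heads.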
The key question is how much accuracy pruning costs. The good news: pruning a significant fraction of heads often costs very little accuracy.
A landmark paper by Michel et al. (2019) found that on the MNLI task, you can prune 20 percent of BERT's attention heads and lose less than 1 percent accuracy. Some heads are actually slightly harmful, and removing them can improve performance.
| Pruning Level | Accuracy Drop (MNLI) | Speed Improvement |
|---|---|---|
| 0% pruned (baseline) | 0% | 1x |
| 20% pruned | less than 1% | 1.1x |
| 40% pruned | 1-2% | 1.25x |
| 60% pruned | 3-5% | 1.5x |
| 80% pruned | 8-15% | 1.8x |
The sweet spot is usually around 20-40% pruning, where you get meaningful speedup with minimal accuracy loss. Beyond 50%, the accuracy drop becomes more noticeable and task-dependent.
There are two main approaches to when you prune:
One-shot pruning removes all targeted heads at once after training. It is fast and simple but can cause a larger accuracy drop.
Iterative pruning removes a small fraction of heads, then fine-tunes the model, then prunes again. This cycle repeats until you hit your target. It takes longer but preserves accuracy much better.
```text
# Iterative pruning workflow
Initial model
    |
    +--> Compute importance scores
    |
    +--> Prune bottom 10% of heads
    |
    +--> Fine-tune for 1-2 epochs
    |
    +--> Repeat until target sparsity reached
    |
    v
Final pruned model
```
For production models where accuracy matters, iterative pruning is almost always worth the extra time.
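The loop can be sketched in plain Python. Everything here is illustrative: `pruning_schedule` and `iterative_prune` are hypothetical helpers, the three callables stand in for your own scoring, pruning, and fine-tuning code, and the 10% step is interpreted as 10% of the original head count, which is an assumption.

```python
def pruning_schedule(total_heads, target_fraction, step_fraction=0.10):
    """Return how many heads to remove in each round."""
    target = round(total_heads * target_fraction)
    step = max(1, int(total_heads * step_fraction))
    rounds = []
    while target > 0:
        n = min(step, target)
        rounds.append(n)
        target -= n
    return rounds

def iterative_prune(total_heads, target_fraction, score_fn, prune_fn, finetune_fn,
                    step_fraction=0.10):
    """Run score -> prune -> fine-tune until the target sparsity is reached.

    score_fn() returns importance scores, prune_fn(scores, n) removes the n
    least important heads, and finetune_fn() trains briefly to recover accuracy.
    """
    for n in pruning_schedule(total_heads, target_fraction, step_fraction):
        scores = score_fn()  # recompute: importance shifts after each round
        prune_fn(scores, n)
        finetune_fn()

# BERT-base, 40% target: four full rounds and one small final round
print(pruning_schedule(144, 0.40))  # [14, 14, 14, 14, 2]
```

Recomputing the scores inside the loop matters, because the remaining heads tend to take over some of the work of the ones just removed.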
Pruning is not the only way to make models smaller. Here is how it compares to other popular approaches:
| Method | What It Removes | Accuracy Impact | Implementation Complexity |
|---|---|---|---|
| Head pruning | Redundant attention heads | Low if done carefully | Medium |
| Weight pruning | Individual weights (sparse) | Low | Medium |
| Knowledge distillation | Trains a smaller student model | Low to medium | High |
| Quantization | Reduces precision (float32 to int8) | Very low | Low |
| Layer pruning | Entire transformer layers | Medium to high | Low |
Pruning pairs very well with quantization. You can prune heads first, then quantize the remaining weights, stacking both optimizations for maximum compression.
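A back-of-the-envelope calculation shows how the two savings stack. The figures are assumptions for illustration: about 110 million parameters for BERT-base, 58 heads removed (roughly 40% of 144), and a per-head cost derived from the standard BERT layout. Quantization overhead such as scale factors is ignored.

```python
def size_mb(num_params, bytes_per_param):
    """Approximate storage size in megabytes."""
    return num_params * bytes_per_param / 1e6

TOTAL_PARAMS = 110_000_000                # approximate BERT-base size (assumption)
PARAMS_PER_HEAD = 4 * 768 * 64 + 3 * 64   # Q, K, V, output slices plus Q/K/V biases
HEADS_REMOVED = 58                        # about 40% of 144 heads

pruned_params = TOTAL_PARAMS - HEADS_REMOVED * PARAMS_PER_HEAD

baseline = size_mb(TOTAL_PARAMS, 4)       # float32
pruned = size_mb(pruned_params, 4)        # float32 after head pruning
pruned_int8 = size_mb(pruned_params, 1)   # int8 after head pruning

print(f"baseline float32: {baseline:.1f} MB")
print(f"pruned float32:   {pruned:.1f} MB")
print(f"pruned int8:      {pruned_int8:.1f} MB")
```

Under these assumptions quantization contributes most of the size reduction, while head pruning mainly reduces attention compute.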
Here is a complete example of a pruning pipeline using Hugging Face Transformers:
```python
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
from datasets import load_dataset

# 1. Load model and tokenizer
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# 2. Load dataset
dataset = load_dataset("glue", "sst2")

def tokenize(batch):
    return tokenizer(batch["sentence"], truncation=True, padding="max_length", max_length=128)

tokenized = dataset.map(tokenize, batched=True)

# 3. Define pruning targets (manually or from importance scores)
heads_to_prune = {
    1: [2, 5],
    4: [0, 7, 11],
    9: [3, 8],
}

# 4. Prune the model
model.prune_heads(heads_to_prune)

# 5. Fine-tune after pruning
training_args = TrainingArguments(
    output_dir="./pruned_bert",
    num_train_epochs=3,
    per_device_train_batch_size=32,
    evaluation_strategy="epoch",  # named eval_strategy in newer transformers releases
    save_strategy="epoch",
    load_best_model_at_end=True,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
)
trainer.train()

# 6. Save the pruned model
model.save_pretrained("./pruned_bert_final")
tokenizer.save_pretrained("./pruned_bert_final")
```
After saving, you can load and deploy this model just like any other Hugging Face model. The pruned heads are gone from the weight matrices permanently.
1. Does pruning change the model architecture permanently?
Yes. When you use prune_heads() in Hugging Face, it physically removes rows and columns from the weight matrices. The pruned heads are gone, and the model is smaller on disk and in memory.
2. Can I prune any transformer model, not just BERT?
Yes. The concept applies to any multi-head attention model, including GPT-style models, T5, and others. The implementation details differ per architecture, but the core idea is the same.
3. What is the difference between attention head pruning and layer pruning?
Head pruning removes individual heads within a layer, offering finer control. Layer pruning removes entire transformer blocks, which is more aggressive and typically causes a bigger accuracy drop.
4. Do I always need to fine-tune after pruning?
For small pruning rates (under 10-15%), you might get away without fine-tuning. But for anything above that, fine-tuning after pruning is strongly recommended to recover accuracy.
5. How do I know how many heads to prune?
Start conservative. Prune 10-20%, evaluate on your target task, and check if accuracy is acceptable. Increase gradually. The right number depends on your accuracy-vs-speed tradeoff.
6. Does attention head pruning work on generative models like GPT?
Yes, though the evaluation is trickier because generative models are assessed on perplexity and generation quality rather than classification accuracy. The same importance scoring methods apply.
7. Can pruning actually improve accuracy?
Sometimes, yes. If certain heads are learning noise or conflicting patterns, removing them can reduce interference and slightly improve performance. This is rare but documented in research.
8. Is attention head pruning the same as sparse attention?
No. Sparse attention limits which tokens each head attends to (a runtime operation). Head pruning removes entire heads from the model permanently (a structural change).
9. Can I combine pruning with quantization?
Absolutely, and it is a common production strategy. Prune first to reduce the number of heads, then quantize the remaining weights to int8 or int4. You get multiplicative compression benefits.
10. What tools support attention head pruning out of the box?
Hugging Face Transformers has built-in prune_heads() support. The nn_pruning library by Hugging Face also supports structured pruning with learned masks. PyTorch's torch.nn.utils.prune module supports magnitude-based pruning at the weight level.
Michel, P., Levy, O., and Neubig, G. (2019). Are Sixteen Heads Really Better than One? - https://arxiv.org/abs/1905.10650
Voita, E., Talbot, D., Moiseev, F., Sennrich, R., and Titov, I. (2019). Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned - https://arxiv.org/abs/1905.09418
Sanh, V., Wolf, T., and Rush, A. (2020). Movement Pruning: Adaptive Sparsity by Fine-Tuning - https://arxiv.org/abs/2005.07683
Lagunas, F., Charlaix, E., Sanh, V., and Rush, A. (2021). Block Pruning For Faster Transformers - https://arxiv.org/abs/2109.04838