Explainable AI for Image Classification using Twin System and Grad-CAM
Introduction
Explainable Artificial Intelligence (XAI) is essential for making machine learning models more interpretable and trustworthy, particularly in opaque or high-stakes domains. This project applied two complementary post-hoc explanation methods to a binary classification task: distinguishing real cat images from AI-generated (fake) ones.
The goals:
- Train a high-performing ResNet-18-based image classifier.
- Explain its predictions using visual and example-based techniques.
XAI Techniques Used
Grad-CAM (Gradient-weighted Class Activation Mapping)
Grad-CAM generates heatmaps showing which regions of an image influence the model’s prediction. It backpropagates the gradient of the target class score to the final convolutional layer, averages those gradients into per-channel weights, and uses the weighted sum of the activation maps as a coarse localization heatmap.
Purpose: Visual explanation of what the model is attending to.
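Below is a minimal sketch of how Grad-CAM can be computed for this setup with PyTorch hooks on ResNet-18’s `layer4` block. The `model` object and variable names are illustrative assumptions, not the project’s actual code.

```python
import torch
import torch.nn.functional as F
from torchvision import models

# Assumed setup: a fine-tuned 2-class ResNet-18 (stand-in for the project's checkpoint).
model = models.resnet18(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 2)
model.eval()

activations, gradients = {}, {}

def save_activation(module, inputs, output):
    activations["value"] = output.detach()

def save_gradient(module, grad_input, grad_output):
    gradients["value"] = grad_output[0].detach()

# Hook the final convolutional block of ResNet-18.
model.layer4.register_forward_hook(save_activation)
model.layer4.register_full_backward_hook(save_gradient)

def grad_cam(image):                                   # image: (1, 3, 224, 224)
    logits = model(image)
    target = logits.argmax(dim=1)
    model.zero_grad()
    logits[0, target].backward()

    acts = activations["value"]                        # (1, 512, 7, 7)
    grads = gradients["value"]                         # (1, 512, 7, 7)
    weights = grads.mean(dim=(2, 3), keepdim=True)     # global-average-pooled gradients
    cam = F.relu((weights * acts).sum(dim=1, keepdim=True))   # (1, 1, 7, 7)
    cam = F.interpolate(cam, size=image.shape[-2:], mode="bilinear", align_corners=False)
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    return cam.squeeze()                               # heatmap in [0, 1], same size as input
```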
Twin System (Embedding Similarity via Case-Based Reasoning)
This method explains predictions by retrieving visually similar images from the training set. Embeddings are extracted from the penultimate layer of ResNet-18, and cosine similarity is used to find the top matches.
Purpose: Intuitive justification by referencing known cases.
Inspired by: “This Looks Like That: Deep Learning for Interpretable Image Recognition” (Chen et al., 2018)
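A minimal sketch of the retrieval step, assuming 512-d penultimate-layer embeddings have already been computed for the training set; `train_embeddings` and the helper names are placeholders, not the project’s code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

# Penultimate-layer feature extractor: ResNet-18 with the classification head removed.
backbone = models.resnet18(weights=None)
backbone.fc = nn.Identity()            # forward() now returns 512-d embeddings
backbone.eval()

@torch.no_grad()
def embed(images):                     # images: (N, 3, 224, 224)
    return backbone(images)            # (N, 512)

def top_k_twins(query_image, train_embeddings, k=3):
    """Indices of the k most similar training images by cosine similarity."""
    # In the Twin System, candidates are first restricted to the same predicted class.
    q = embed(query_image.unsqueeze(0))                     # (1, 512)
    sims = F.cosine_similarity(q, train_embeddings, dim=1)  # (N_train,)
    return sims.topk(k).indices
```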
Dataset
- Total Images: 300
  - 150 real cat images from public datasets
  - 150 fake cat images generated with the google/ddpm-cat-256 diffusion model
- Preprocessing: Resized to 224x224, normalized (mean=0.5, std=0.5)
- Split:
  - Train: 100 real + 100 fake
  - Validation: 50 real + 50 fake
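A sketch of this preprocessing and split with torchvision; the `data/real` and `data/fake` folder layout is an assumption.

```python
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# Resize to 224x224 and normalize each channel with mean=0.5, std=0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Assumed layout: data/fake/... and data/real/... (ImageFolder assigns fake=0, real=1).
dataset = datasets.ImageFolder("data", transform=transform)

# 200 training / 100 validation images; note random_split does not stratify,
# so the exact 100+100 / 50+50 per-class split may have been done manually.
train_set, val_set = random_split(dataset, [200, 100])
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)
```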
Model Architecture
- Base Model: Pretrained ResNet-18
- Final Layer: Modified for 2-class output
- Training Setup:
  - Optimizer: Adam (lr=1e-4)
  - Loss: CrossEntropyLoss
  - Epochs: 10
  - Batch Size: 32
Final Validation Accuracy: 91%
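A sketch of the training setup above; `train_loader` and `val_loader` are assumed from the dataset sketch.

```python
import torch
import torch.nn as nn
from torchvision import models

device = "cuda" if torch.cuda.is_available() else "cpu"

# Pretrained ResNet-18 with the final layer replaced for 2-class output.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, 2)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(10):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

    # Track validation accuracy each epoch.
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            preds = model(images.to(device)).argmax(dim=1)
            correct += (preds == labels.to(device)).sum().item()
            total += labels.size(0)
    print(f"epoch {epoch + 1}: val acc = {correct / total:.2%}")
```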
Evaluation Metrics
Metric | Value |
---|---|
Accuracy | 91% |
Precision (Real) | 0.94 |
Recall (Real) | 0.88 |
Precision (Fake) | 0.89 |
Recall (Fake) | 0.94 |
F1 Score (Overall) | 0.91 |
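Metrics of this kind can be reproduced with scikit-learn’s `classification_report`; the snippet below assumes the `model`, `val_loader`, and `device` objects from the earlier sketches and the ImageFolder label convention (0 = fake, 1 = real).

```python
import torch
from sklearn.metrics import classification_report

# Collect validation labels and predictions (assumed convention: 0 = fake, 1 = real).
y_true, y_pred = [], []
model.eval()
with torch.no_grad():
    for images, labels in val_loader:
        preds = model(images.to(device)).argmax(dim=1).cpu()
        y_true.extend(labels.tolist())
        y_pred.extend(preds.tolist())

# Per-class precision/recall and overall accuracy/F1, as in the table above.
print(classification_report(y_true, y_pred, target_names=["fake", "real"], digits=2))
```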
Grad-CAM Results
Sample saliency visualizations show which parts of the input image the model focused on:
Key insight: the model focuses on fur texture, eyes, and facial shape when classifying.
Twin System Results
The Twin System retrieves the most similar training samples (with the same predicted class) based on ResNet-18 embeddings:
Misclassification Analysis
Error Type | Sample IDs |
---|---|
Real → Fake (FN) | 13, 18, 22, 34, 40, 44 |
Fake → Real (FP) | 57, 77, 80 |
Grad-CAM and Twin visualizations revealed blur and atypical poses as key contributors to misclassification.
Conclusion
This project combined two explainability approaches to enhance understanding of model behavior.
Method | Explanation Type | Contribution |
---|---|---|
Grad-CAM | Visual (pixel) | Shows where the model looks |
Twin System | Example-based | Shows why via similar cases |
Multi-view XAI builds both trust in and insight into deep learning models.
Future Work
- Add counterfactual examples (nearest neighbour from the opposite class; see the sketch after this list)
- Use CLIP embeddings for better semantic similarity
- Improve Twin system via ProtoPNet architecture
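A possible starting point for the counterfactual idea, reusing the Twin System embeddings to retrieve the nearest training example of the opposite class; all names follow the earlier sketches and are assumptions.

```python
import torch
import torch.nn.functional as F

def nearest_counterfactual(query_image, train_embeddings, train_labels, predicted_class):
    """Index of the closest training image belonging to the *other* class."""
    q = embed(query_image.unsqueeze(0))                     # (1, 512), from the Twin sketch
    sims = F.cosine_similarity(q, train_embeddings, dim=1)  # (N_train,)
    sims[train_labels == predicted_class] = float("-inf")   # mask out same-class images
    return sims.argmax().item()
```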
ProtoPNet Attempt
- Backbone: ResNet-18
- Added 10 learnable prototypes per class
- Goal: Learn and match local image regions
Validation Accuracy: 50%
Problem: Overfit to “real” class due to prototype imbalance
Learned Prototypes:
Although accuracy was low, the model successfully:
- Learned and projected prototypes
- Visualized most activating examples
- Demonstrated potential for local-region interpretability
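For reference, a simplified sketch of the prototype-layer idea behind this attempt (ResNet-18 feature map, 10 prototypes per class, ProtoPNet-style log-distance similarity); it omits prototype projection and the cluster/separation losses of the full architecture, so it is an illustration rather than the model trained here.

```python
import torch
import torch.nn as nn
from torchvision import models

class ProtoLayer(nn.Module):
    """10 prototypes per class, compared against every spatial position of the feature map."""
    def __init__(self, n_classes=2, protos_per_class=10, dim=512):
        super().__init__()
        self.prototypes = nn.Parameter(torch.randn(n_classes * protos_per_class, dim))
        self.fc = nn.Linear(n_classes * protos_per_class, n_classes, bias=False)

    def forward(self, features):                            # features: (B, 512, 7, 7)
        B, C, H, W = features.shape
        patches = features.permute(0, 2, 3, 1).reshape(B, H * W, C)     # (B, 49, 512)
        dists = torch.cdist(patches, self.prototypes.unsqueeze(0)) ** 2 # squared L2 distances
        sims = torch.log((dists + 1) / (dists + 1e-4))                  # ProtoPNet similarity
        scores = sims.max(dim=1).values                     # best-matching patch per prototype
        return self.fc(scores)                              # class logits

# ResNet-18 convolutional backbone (avgpool and fc removed) feeding the prototype layer.
backbone = nn.Sequential(*list(models.resnet18(weights=None).children())[:-2])
logits = ProtoLayer()(backbone(torch.randn(1, 3, 224, 224)))            # (1, 2)
```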