CLEFT: Language-Image Contrastive Learning with Efficient Large Language Model and Prompt Fine-Tuning
Abstract
The paper introduces a novel language-image contrastive learning method called CLEFT that leverages pre-trained large language models (LLMs) and visual models to address the challenges of limited medical data and constrained GPU resources. Key aspects include:
- Efficiently incorporating an LLM as the text encoder in the CLIP framework, using parameter-efficient fine-tuning (PEFT) techniques to reduce the number of trainable parameters.
- Treating the CLIP training as a knowledge distillation process, where the LLM serves as a teacher model to guide the visual encoder.
- Learning context-based prompts to mitigate the bias introduced by manually crafted prompts.
The proposed CLEFT method demonstrates state-of-the-art performance on multiple chest X-ray and mammography datasets compared to various baselines, while being parameter-efficient.
Q&A
[01] Boosting CLIP with an LLM
1. What are the key components of the CLEFT framework for boosting CLIP with an LLM?
- The CLEFT framework incorporates a pre-trained medical LLM (GPT-2 based) as the text encoder in the CLIP framework.
- It uses parameter-efficient fine-tuning (PEFT) techniques like LoRA, IA3, and prefix fine-tuning to efficiently adapt the LLM to the CLIP task, reducing the number of trainable parameters.
- It treats the CLIP training as a knowledge distillation process, where the pre-trained LLM serves as a teacher model to guide the training of the visual encoder (see the first sketch after question 3 below).
2. How does CLEFT address the challenges of limited medical data and constrained GPU resources?
- By leveraging the strengths of pre-trained LLMs and visual models, CLEFT can counterbalance the scarcity of medical data.
- The parameter-efficient fine-tuning of the LLM conserves GPU resources without compromising the knowledge acquired from natural language.
3. What are the benefits of using a causal LLM (GPT-2) versus a BERT-based LLM as the text encoder? The paper states that the causal LLM has shown better capability as it scales up to over a billion parameters, allowing the model to embed the input into a more robust feature space with less training.
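For concreteness, here is a minimal PyTorch sketch of the setup from question 1: a GPT-2-style text encoder wrapped with LoRA adapters via Hugging Face peft, plus a symmetric CLIP-style contrastive loss. The "gpt2" checkpoint, LoRA hyperparameters, and 512-dim projection are illustrative placeholders, not the paper's exact configuration.

```python
import torch
import torch.nn.functional as F
from torch import nn
from transformers import GPT2Model
from peft import LoraConfig, get_peft_model

# "gpt2" stands in for the paper's medical GPT-2-based LLM; the LoRA hyperparameters
# and the 512-dim projection are illustrative choices, not values from the paper.
gpt2 = GPT2Model.from_pretrained("gpt2")

# PEFT: inject low-rank LoRA adapters into GPT-2's fused attention projection ("c_attn"),
# so the pre-trained LLM weights stay frozen and only the adapters train.
lora_cfg = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, target_modules=["c_attn"])
text_encoder = get_peft_model(gpt2, lora_cfg)
text_encoder.print_trainable_parameters()  # prints the (small) trainable fraction

# Projection head mapping LLM features into the shared image-text embedding space.
text_proj = nn.Linear(gpt2.config.hidden_size, 512)

def clip_loss(image_emb: torch.Tensor, text_emb: torch.Tensor, temperature: float = 0.07):
    """Symmetric InfoNCE loss over the image-text similarity matrix.

    With the LLM largely frozen, the gradient mostly flows into the visual encoder,
    which is the sense in which the LLM acts as a teacher (the distillation view).
    """
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = image_emb @ text_emb.t() / temperature
    targets = torch.arange(logits.size(0), device=logits.device)
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))
```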
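The distinction in question 3 is easiest to see in how a sentence embedding is pulled out of each encoder type. The sketch below uses generic "bert-base-uncased" and "gpt2" checkpoints purely for illustration; the paper's encoder is a medical GPT-2-style LLM.

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Generic checkpoints used only to illustrate the two pooling schemes.
bert_tok = AutoTokenizer.from_pretrained("bert-base-uncased")
bert = AutoModel.from_pretrained("bert-base-uncased")
gpt_tok = AutoTokenizer.from_pretrained("gpt2")
gpt_tok.pad_token = gpt_tok.eos_token  # GPT-2 has no pad token by default
gpt = AutoModel.from_pretrained("gpt2")

captions = ["No acute cardiopulmonary abnormality.",
            "Findings consistent with right lower lobe pneumonia."]

# BERT-style (bidirectional) encoder: the [CLS] token summarizes the caption.
b = bert_tok(captions, padding=True, return_tensors="pt")
bert_emb = bert(**b).last_hidden_state[:, 0]                             # (B, D)

# Causal LLM: only the last real token has attended to the full caption,
# so pool the hidden state at the final non-padding position.
g = gpt_tok(captions, padding=True, return_tensors="pt")
last = g["attention_mask"].sum(dim=1) - 1
gpt_emb = gpt(**g).last_hidden_state[torch.arange(len(captions)), last]  # (B, D)
```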
[02] Learning the Context-Based Prompt
1. What is the motivation behind learning context-based prompts in the second stage of CLEFT? The paper argues that the common approach of manually crafting textual prompts for CLIP can lead to a lack of diversity in the text training prompts, which can result in the catastrophic forgetting phenomenon in the text encoder and limit the model's performance. Learning context-based prompts aims to mitigate this issue.
2. How does the prompt context learning stage work in CLEFT? After the initial pre-training stage, CLEFT freezes both the text and visual encoders. It then replaces the original hand-crafted prompt with a series of trainable context tokens that are fed into the language model. These context tokens are optimized with a zero-shot classification cross-entropy loss, allowing them to adapt to the different classes evenly and avoid potential shortcut issues (see the first sketch after question 3 below).
3. How are the context-based prompt tokens initialized in CLEFT? If the prompt length is longer than the original hand-crafted caption, the first few tokens are initialized from the original caption, and the remaining tokens are initialized from a random uniform distribution. This strategy leverages the information in the original prompt while still allowing the model to learn new context-based prompts.
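As a rough illustration of the second stage described in question 2: the hand-crafted prompt is replaced by trainable context embeddings, and only those embeddings receive gradients from a zero-shot classification cross-entropy loss while both encoders stay frozen. The class handling, shapes, and temperature below are assumptions for the sketch, not the paper's exact implementation.

```python
import torch
import torch.nn.functional as F
from torch import nn

class ContextPrompt(nn.Module):
    """Trainable context tokens shared across classes; both encoders stay frozen."""

    def __init__(self, n_ctx: int, embed_dim: int, class_token_embs: list[torch.Tensor]):
        super().__init__()
        # Learnable context embeddings that replace the hand-crafted prompt.
        self.context = nn.Parameter(torch.empty(n_ctx, embed_dim).uniform_(-0.02, 0.02))
        # Frozen token embeddings of each class name, e.g. "cardiomegaly", "edema", ...
        self.class_token_embs = [e.detach() for e in class_token_embs]

    def forward(self) -> list[torch.Tensor]:
        # One prompt per class: [learned context tokens] + [class-name tokens].
        return [torch.cat([self.context, cls_emb], dim=0) for cls_emb in self.class_token_embs]

def zero_shot_ce_loss(image_emb, class_text_embs, labels, temperature=0.07):
    """Cross-entropy over image-vs-class-prompt similarities; gradients reach only the context."""
    image_emb = F.normalize(image_emb, dim=-1)               # (B, D)
    class_text_embs = F.normalize(class_text_embs, dim=-1)   # (C, D), one embedding per class prompt
    logits = image_emb @ class_text_embs.t() / temperature   # (B, C)
    return F.cross_entropy(logits, labels)
```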
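And for the initialization in question 3, a minimal sketch that seeds the leading context positions with the token embeddings of the original hand-crafted caption and fills the rest uniformly at random. The caption text, context length, and "gpt2" checkpoint are placeholders.

```python
import torch
from torch import nn
from transformers import GPT2Model, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")   # stands in for the medical LLM's tokenizer
llm = GPT2Model.from_pretrained("gpt2")

def init_context(caption: str, n_ctx: int) -> nn.Parameter:
    """Seed the first context tokens from the original caption; fill the rest uniformly."""
    ids = tokenizer(caption, return_tensors="pt")["input_ids"][0]
    with torch.no_grad():
        caption_emb = llm.get_input_embeddings()(ids)          # (L, D)
    ctx = torch.empty(n_ctx, caption_emb.size(1)).uniform_(-0.02, 0.02)
    n_copy = min(n_ctx, caption_emb.size(0))
    ctx[:n_copy] = caption_emb[:n_copy]                        # copy caption token embeddings
    return nn.Parameter(ctx)

context = init_context("a chest x-ray showing", n_ctx=16)     # placeholder caption and length
```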
[03] Experimental Results
1. How does CLEFT perform compared to the baseline methods in the zero-shot, linear probing, and full fine-tuning settings?
- In zero-shot classification, CLEFT with LoRA outperforms other baselines by 7% on the CheXpert-5x200 dataset.
- In linear probing, CLEFT with LoRA achieves the best performance on both the CheXpert-5x200 and RSNA datasets, with a 5% gap on CheXpert-5x200 compared to the baselines.
- In full fine-tuning, CLEFT outperforms all other baselines on both the CheXpert and RSNA datasets.
2. How does CLEFT perform on the mammography dataset compared to the baselines? CLEFT with LoRA clearly surpasses the compared baselines on the EMBED mammography dataset, suggesting the proposed model has the potential to be applied to other medical domains beyond chest X-rays.
3. What are the key findings from the ablation experiments conducted on CLEFT?
- Removing the prompt context learning stage leads to a 3% drop in accuracy on both datasets.
- Using a fully frozen language model harms performance on in-domain data but improves performance on out-of-domain data.
- Fully fine-tuning the model improves performance but comes at the cost of higher GPU memory usage and longer training time.
- A longer prompt context does not always translate into better performance; the prompt length and accuracy are not consistently positively correlated.