FedPFT: Federated Proxy Fine-Tuning of Foundation Models
Abstract
The paper introduces Federated Proxy Fine-Tuning (FedPFT), a method for adapting Foundation Models (FMs) to downstream tasks through Federated Learning (FL). FedPFT addresses two challenges, insufficient FM fine-tuning and gradient error accumulation, by:
- Constructing sub-FMs through layer-wise compression based on neuron saliency in the Feed-Forward Network (FFN) of transformer layers.
- Aligning sub-FMs with FMs via a two-step distillation process: layer-level distillation before FL fine-tuning and neuron-level distillation during FL fine-tuning.
The proposed approach aims for optimal downstream-task adaptation of the FM without directly sharing the server's FM or the clients' data.
Q&A
[01] Sub-FM Construction Module
1. How does the sub-FM construction module work? The sub-FM construction module performs layer-wise compression on the FM by measuring the saliency of neurons in the Feed-Forward Network (FFN) of each transformer layer. It removes low-saliency neurons at a fixed ratio to construct the sub-FM, which serves as a proxy for the full FM (a pruning sketch follows this list).
2. What is the rationale behind compressing the FFN rather than the Multi-Head Attention (MHA) layer? The authors compress the FFN rather than the MHA because the majority of a transformer layer's parameters reside in the FFN. Compressing the FFN reduces the sub-FM's parameter count while keeping a consistent set of trainable parameters (i.e., the MHA) between the FM and its sub-FM at each layer.
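A minimal PyTorch-style sketch of the pruning step, assuming a simple weight-magnitude saliency proxy (the paper's exact saliency measure is not reproduced here, and `compress_ffn` / `keep_ratio` are illustrative names):

```python
import torch
import torch.nn as nn

def compress_ffn(ffn_in: nn.Linear, ffn_out: nn.Linear, keep_ratio: float = 0.5):
    """Keep the top-`keep_ratio` FFN neurons ranked by a saliency proxy.

    Saliency proxy (assumption): L2 norm of each hidden neuron's input and
    output weights; the paper's actual saliency metric may differ.
    """
    d_ff = ffn_in.out_features
    saliency = ffn_in.weight.norm(dim=1) + ffn_out.weight.norm(dim=0)
    k = max(1, int(d_ff * keep_ratio))
    keep_idx = torch.topk(saliency, k).indices.sort().values

    # Build the smaller sub-FFN and copy over the retained neurons' weights.
    sub_in = nn.Linear(ffn_in.in_features, k, bias=ffn_in.bias is not None)
    sub_out = nn.Linear(k, ffn_out.out_features, bias=ffn_out.bias is not None)
    with torch.no_grad():
        sub_in.weight.copy_(ffn_in.weight[keep_idx])
        sub_out.weight.copy_(ffn_out.weight[:, keep_idx])
        if ffn_in.bias is not None:
            sub_in.bias.copy_(ffn_in.bias[keep_idx])
        if ffn_out.bias is not None:
            sub_out.bias.copy_(ffn_out.bias)
    return sub_in, sub_out, keep_idx
```

Applying this per transformer layer leaves the MHA untouched, which is what keeps the trainable parameters consistent between the FM and its sub-FM.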
[02] Sub-FM Alignment Module
1. Why is it important to align the sub-FM with the full FM? The authors provide theoretical analysis (Theorem 1) showing that a significant disparity between the gradients of the sub-FM and the full FM can impede the convergence of fine-tuning methods using a proxy sub-model. Therefore, aligning the sub-FM with the full FM is crucial to ensure the convergence of the fine-tuned FM.
2. How does the two-step distillation process work? The sub-FM alignment module conducts a two-step distillation process (a minimal loss sketch follows this list):
- Layer-level distillation before FL fine-tuning: This distillation uses the outputs of all layers to compute the distillation loss and adds a regularization term that penalizes the disparity between the weights of each FFN and its sub-FFN.
- Neuron-level distillation during FL fine-tuning: This distillation selectively updates a subset of low-saliency neurons during local fine-tuning to reduce the risk of the sub-FM forgetting knowledge of the local data.
3. What is the theoretical guarantee provided for the sub-FM alignment? The authors provide Theorem 2, which shows that the gradient error can be reduced by narrowing the differences in outputs and weights between the sub-FM and the full FM.
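A minimal sketch of what the layer-level distillation objective could look like, assuming MSE between per-layer hidden states plus a weight-alignment regularizer; the loss form, `lam`, and function names are illustrative, not the paper's exact formulation:

```python
import torch
import torch.nn.functional as F

def layer_level_distill_loss(teacher_hiddens, student_hiddens,
                             teacher_ffns, student_ffns, keep_idx, lam=0.1):
    """Align the sub-FM (student) with the FM (teacher).

    teacher_hiddens / student_hiddens: per-layer hidden states.
    teacher_ffns / student_ffns: per-layer (W_in, W_out) weight pairs.
    keep_idx: per-layer indices of the retained FFN neurons.
    """
    loss = torch.zeros(())
    # Output alignment: match every layer's hidden states.
    for h_t, h_s in zip(teacher_hiddens, student_hiddens):
        loss = loss + F.mse_loss(h_s, h_t.detach())
    # Weight alignment: keep sub-FFN weights close to the retained FFN weights.
    for (w_in_t, w_out_t), (w_in_s, w_out_s), idx in zip(teacher_ffns,
                                                         student_ffns, keep_idx):
        loss = loss + lam * (
            F.mse_loss(w_in_s, w_in_t[idx].detach()) +
            F.mse_loss(w_out_s, w_out_t[:, idx].detach())
        )
    return loss
```

Per Theorem 2 as summarized above, driving both terms down is what shrinks the gradient error between the sub-FM and the full FM.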
[03] Experimental Results
1. How do the results compare FedPFT with the baseline methods? The experimental results show that FedPFT consistently outperforms FedOT across various datasets and models (BERT, RoBERTa, and ViT) and approaches the performance of FedPETuning, which fine-tunes the full model, whereas FedOT exhibits a substantial gap from FedPETuning.
2. How does FedPFT perform under data heterogeneity scenarios? The authors evaluate FedPFT, FedOT, and FedPETuning under different non-IID data distributions (Dir-1.0, Dir-5.0, Dir-10.0). The results show that all methods degrade as the data become more non-IID, but FedPFT still outperforms FedOT and remains competitive with FedPETuning.
3. What are the key findings from the ablation study? The ablation study demonstrates the importance of both the sub-FM construction module and the sub-FM alignment module in FedPFT. Variants of FedPFT that lack either of these components (FedPFT_N and FedPFT_D) exhibit notably poorer performance compared to the full FedPFT method.
[04] Computational and Communication Cost Analysis
1. How does FedPFT reduce the computational cost compared to the full model? The authors provide a theoretical analysis showing that FedPFT cuts clients' computational cost by almost half compared to fine-tuning the full model, since the sub-FM has far fewer parameters to process.
2. How does FedPFT reduce the communication cost compared to the full model? The authors analyze the communication cost, showing that without PEFT methods, FedPFT cuts the communication cost by nearly half compared to the full model. If PEFT methods are used (e.g., LoRA), the communication cost drops further, as only a small subset of parameters needs to be transmitted during the alignment process (a back-of-envelope parameter estimate follows this list).
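A rough back-of-envelope estimate, assuming BERT-base-like dimensions and ignoring biases, layer norms, and embeddings, illustrates how per-layer parameter savings scale with the FFN compression ratio (the exact ratio used by FedPFT is not restated in this summary):

```python
# Per-layer parameter counts for a BERT-base-like transformer layer.
d_model, d_ff = 768, 3072

mha_params = 4 * d_model * d_model        # Q, K, V, and output projections
ffn_params = 2 * d_model * d_ff           # FFN up- and down-projections
full_layer = mha_params + ffn_params

# Savings for a few illustrative keep ratios (fraction of FFN neurons retained).
for keep in (0.75, 0.50, 0.25):
    sub_layer = mha_params + 2 * d_model * int(d_ff * keep)
    print(f"keep {keep:.0%} of FFN neurons -> "
          f"{1 - sub_layer / full_layer:.0%} fewer parameters per layer")
```

Under these assumptions the MHA stays intact, so the per-layer savings come entirely from the FFN and grow with more aggressive pruning, which is consistent with the roughly halved cost described above.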