In the rapidly evolving world of AI, staying current with new tools and techniques is vital. Recently, I completed a successful fine-tuning session with Unsloth, a library designed to make model fine-tuning faster and more memory-efficient. My goal was to improve how the Gemma-2-9b model understands and generates explanations of PHP classes in the Chameleon CMS framework. Here’s a step-by-step account of how I used Unsloth to get fast, accurate results while keeping memory usage low.
Setting Up Unsloth for Model Fine-Tuning
I began by installing Unsloth and Flash Attention 2. Flash Attention 2 provides faster attention kernels, and recent versions support the logit softcapping that Gemma 2 relies on. Here’s the setup I used:
!pip install unsloth
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps packaging ninja einops "flash-attn>=2.6.3"
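Softcapping squashes attention scores and final logits into a bounded range with a scaled tanh, which is why a Flash Attention build that supports it matters for Gemma 2. A minimal sketch of the formula in plain Python, assuming the cap values listed in the Hugging Face Gemma 2 config (50.0 for attention logits, 30.0 for final logits); it illustrates the math only, not how Flash Attention implements it:

```python
import math

# Cap values assumed from the Hugging Face Gemma 2 config
# (attn_logit_softcapping and final_logit_softcapping).
ATTN_SOFTCAP = 50.0
FINAL_SOFTCAP = 30.0


def softcap(logit: float, cap: float) -> float:
    """Smoothly bound a logit to the open interval (-cap, cap)."""
    return cap * math.tanh(logit / cap)


if __name__ == "__main__":
    # Small values pass through almost unchanged; large ones saturate near the cap.
    for raw in (1.0, 25.0, 100.0, -100.0):
        print(raw, round(softcap(raw, ATTN_SOFTCAP), 3))
```

Because tanh is smooth, the cap limits extreme logits without the hard cutoff that clipping would introduce.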
Loading the Model
I chose the Gemma-2-9b model and loaded it with 4-bit quantization, which cuts the memory needed for the weights by up to 4x compared with 16-bit precision. Quantization lets a model largely preserve its quality while greatly reducing memory and compute requirements.
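from unsloth import FastLanguageModel

# Load Gemma-2-9b with 4-bit weights; max_seq_length caps the context length used in training
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/gemma-2-9b",
    max_seq_length=2048,
    load_in_4bit=True,
)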
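The "up to 4x" figure follows from bytes per weight: a 16-bit weight takes 2 bytes, a 4-bit weight about half a byte. A back-of-the-envelope sketch, assuming a rounded count of 9 billion parameters and ignoring quantization metadata, activations, optimizer state, and the KV cache, so real usage is higher:

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate memory for the model weights alone, in gigabytes (1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


NUM_PARAMS = 9e9  # assumption: rounded parameter count for Gemma-2-9b

fp16_gb = weight_memory_gb(NUM_PARAMS, 16)
int4_gb = weight_memory_gb(NUM_PARAMS, 4)
print(f"16-bit: {fp16_gb:.1f} GB, 4-bit: {int4_gb:.1f} GB, ratio: {fp16_gb / int4_gb:.0f}x")
```

This gives roughly 18 GB of weights at 16-bit versus 4.5 GB at 4-bit.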
Fine-Tuning with LoRA Adapters
To update only a fraction of the parameters during fine-tuning, I used LoRA (Low-Rank Adaptation) adapters. This let me fine-tune the model efficiently while training only about 1-10% of its parameters.
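# Attach LoRA adapters: rank r=16, scaling lora_alpha / r = 1, no dropout,
# plus Unsloth's memory-saving gradient checkpointing
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    use_gradient_checkpointing="unsloth",
)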
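LoRA keeps a pretrained weight matrix W (shape d × k) frozen and learns a low-rank update ΔW = (lora_alpha / r) · B · A, where B is d × r and A is r × k. Each adapted matrix therefore adds only r · (d + k) trainable parameters. A small sketch of that count, using an illustrative 3584 × 3584 matrix (3584 matches Gemma 2 9B’s hidden size, but actual projection shapes vary):

```python
def lora_trainable_params(d: int, k: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d x k weight: B (d x r) plus A (r x k)."""
    return r * (d + k)


def lora_scaling(lora_alpha: float, r: int) -> float:
    """Factor applied to B @ A before it is added to the frozen weight."""
    return lora_alpha / r


d = k = 3584  # illustrative shape only
r, lora_alpha = 16, 16  # values passed to get_peft_model above

full = d * k
lora = lora_trainable_params(d, k, r)
print(f"full: {full:,}  lora: {lora:,}  fraction: {lora / full:.2%}")
print(f"scaling: {lora_scaling(lora_alpha, r)}")
```

With lora_alpha equal to r, the scaling is 1.0, so the learned update is added unscaled. The overall trainable fraction for the whole model depends on which modules are adapted and on r.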
Preparing the Dataset
For training, I used a dataset of explanations for various PHP classes in Chameleon CMS. I formatted each entry as a question-answer pair, where the question asks for an explanation of a class and the answer is the corresponding explanation.
formatted_new_data = {
    'question': [f"Explain the class {entry['class_name']}" for entry in new_data],
    'answer': [entry['explanation'] for entry in new_data]
}
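The comprehension above assumes new_data is a list of dictionaries with class_name and explanation keys. A minimal sketch with two hypothetical entries (the class names and explanations are placeholders, not from the actual dataset), plus a quick check that every question lines up with an answer:

```python
# Hypothetical sample records; the real dataset's contents are not shown in the post.
new_data = [
    {"class_name": "ExampleClassA", "explanation": "ExampleClassA does X."},
    {"class_name": "ExampleClassB", "explanation": "ExampleClassB does Y."},
]

formatted_new_data = {
    "question": [f"Explain the class {entry['class_name']}" for entry in new_data],
    "answer": [entry["explanation"] for entry in new_data],
}

# Sanity check: both columns must have the same length so row i pairs up.
assert len(formatted_new_data["question"]) == len(formatted_new_data["answer"])
for question, answer in zip(formatted_new_data["question"], formatted_new_data["answer"]):
    print(question, "->", answer)
```

With the Hugging Face datasets library, Dataset.from_dict(formatted_new_data) would turn this dictionary into a dataset. The post does not show how new_dataset was built, so that step is an assumption.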
I then mapped the dataset into a training-ready format, checking that each question stayed paired with its answer.
formatted_dataset = new_dataset.map(formatting_prompts_func, batched=True)
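The post does not show formatting_prompts_func. With batched=True, datasets passes the function a dictionary of columns (lists of values) and expects a dictionary of new columns back. A hypothetical version, assuming an Alpaca-style template and Gemma’s <eos> end-of-sequence token (in practice, read it from tokenizer.eos_token):

```python
# Hypothetical prompt template; the template actually used in training is not shown.
PROMPT_TEMPLATE = """### Instruction:
{}

### Response:
{}"""

EOS_TOKEN = "<eos>"  # assumption: in practice use tokenizer.eos_token


def formatting_prompts_func(batch: dict) -> dict:
    """Turn a batch of question/answer columns into a single 'text' column."""
    texts = [
        PROMPT_TEMPLATE.format(question, answer) + EOS_TOKEN
        for question, answer in zip(batch["question"], batch["answer"])
    ]
    return {"text": texts}


if __name__ == "__main__":
    sample = {"question": ["Explain the class ExampleClassA"], "answer": ["It does X."]}
    print(formatting_prompts_func(sample)["text"][0])
```

Appending the EOS token shows the model where an answer ends; without it, generation tends to run on past the explanation.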
Training the Model
With the dataset ready, I fine-tuned the model using Hugging Face’s SFTTrainer from the TRL library. Unsloth’s speed and small memory footprint kept the process short: training finished in under 10 minutes with modest GPU memory usage.
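from transformers import TrainingArguments
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=formatted_dataset,
    max_seq_length=2048,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,  # effective batch of 8 examples per optimizer step
        max_steps=60,
        learning_rate=2e-4,
        fp16=True,
        optim="adamw_8bit",  # 8-bit AdamW (bitsandbytes) to reduce optimizer memory
        output_dir="outputs",
    ),
)
trainer_stats = trainer.train()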
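These TrainingArguments determine how much data the 60 steps actually cover: each optimizer step accumulates gradients over per_device_train_batch_size × gradient_accumulation_steps examples. A quick sketch of that arithmetic, assuming a single GPU:

```python
def effective_batch_size(per_device: int, accumulation_steps: int, num_devices: int = 1) -> int:
    """Number of examples contributing to each optimizer update."""
    return per_device * accumulation_steps * num_devices


batch = effective_batch_size(per_device=2, accumulation_steps=4)
max_steps = 60
print(f"effective batch: {batch}, examples processed: {batch * max_steps}")
```

That works out to 8 examples per step and 480 over the run. If the dataset holds fewer than 480 examples, max_steps=60 makes the trainer cycle through it more than once.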
Results
Training finished in just over 9 minutes with a peak memory usage of 9.21 GB, leaving more than 30% of the GPU’s total memory free. The model converged quickly, reaching a final training loss of 0.0438 after 60 steps.
Here are some key performance statistics:
- Total Time: 9.01 minutes
- Peak Memory Usage: 9.213 GB
- Training Loss: 0.0438
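The reported statistics also imply a couple of derived figures. A short sketch using only the numbers above: the average time per optimizer step, and the smallest total GPU memory consistent with a 9.213 GB peak leaving more than 30% free (the post does not name the GPU):

```python
TOTAL_MINUTES = 9.01
STEPS = 60
PEAK_GB = 9.213
MIN_FREE_FRACTION = 0.30

seconds_per_step = TOTAL_MINUTES * 60 / STEPS
# If more than 30% is free, the peak is below 70% of capacity.
min_total_gb = PEAK_GB / (1 - MIN_FREE_FRACTION)

print(f"{seconds_per_step:.2f} s/step, GPU capacity above {min_total_gb:.2f} GB")
```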
Demo: Generating PHP Class Explanations
Once trained, I tested the model by asking it to explain specific PHP classes. For instance, here’s how the model explained the MTFeedbackErrors class, which handles error management for form fields:
# Ask for a prose explanation of a specific class
instruction = "Explain the MTFeedbackErrors class. I don't want to see code, I want only the explanation."
# Tokenize and move the tensors to the GPU
inputs = tokenizer([instruction], return_tensors="pt").to("cuda")
# Generate up to 256 new tokens, then decode back to text
outputs = model.generate(**inputs, max_new_tokens=256)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(generated_text[0])
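The demo passes the raw instruction to the model, while training used whatever template formatting_prompts_func applied. Prompting with the same template the model was fine-tuned on usually gives more consistent answers. A hypothetical helper, assuming an Alpaca-style template (the post’s actual training template is not shown), that leaves the response slot empty for the model to complete:

```python
# Hypothetical template; it must match whatever was used during fine-tuning.
PROMPT_TEMPLATE = """### Instruction:
{}

### Response:
{}"""


def build_inference_prompt(instruction: str) -> str:
    """Fill in the instruction and leave the response empty for the model to complete."""
    return PROMPT_TEMPLATE.format(instruction, "")


prompt = build_inference_prompt(
    "Explain the MTFeedbackErrors class. I don't want to see code, I want only the explanation."
)
print(prompt)
# The prompt would then replace the raw instruction in the tokenizer call:
# inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
```

With Unsloth, calling FastLanguageModel.for_inference(model) before generate also switches the model into its faster inference mode.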
The result was an insightful, detailed explanation:
The `MTFeedbackErrors` class is designed to manage error handling for form fields. It provides methods to add, retrieve, and check for errors associated with specific field names. Key functions include `AddError`, `FieldHasErrors`, and `GetFieldErrors`, which streamline error management and improve user experience by providing clear feedback mechanisms for form validation.
Using Unsloth, I fine-tuned a large model on a relatively small dataset quickly and efficiently. Combining 4-bit quantization with LoRA adapters kept memory usage low while still producing high-quality results. If you’re working with large models on limited hardware, I highly recommend giving Unsloth a try for your next project!