Accelerating Model Training with Unsloth: My Chameleon CMS AI Journey

In the rapidly evolving world of AI, staying ahead with cutting-edge tools and techniques is vital. Recently, I completed a successful training session with Unsloth, a library designed for faster, more memory-efficient fine-tuning. My goal was to improve how well the Gemma-2-9b model understands and generates explanations of PHP classes in the Chameleon CMS framework. Here's a step-by-step account of how I used Unsloth to get fast, accurate results while keeping memory usage low.

Setting Up Unsloth for Model Fine-Tuning

I began by installing Unsloth and Flash Attention 2. FlashAttention speeds up attention and reduces its memory use, and recent versions (the install below requires 2.6.3 or later) support the attention softcapping that Gemma 2 relies on. Here's the setup I used:

!pip install unsloth
!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps packaging ninja einops "flash-attn>=2.6.3"
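For context, softcapping smoothly squashes attention scores into a fixed range instead of letting them grow without bound. A minimal plain-Python sketch of the formula cap · tanh(x / cap); the cap of 50.0 mirrors the attention softcap in Gemma 2's published configuration, so treat it as an assumption if your checkpoint differs:

```python
import math


def softcap(logit: float, cap: float = 50.0) -> float:
    """Smoothly squash a logit into the range [-cap, cap] (cap = 50.0 is assumed)."""
    return cap * math.tanh(logit / cap)


# Small logits pass through almost unchanged; large ones saturate near the cap.
for x in (1.0, 25.0, 100.0, 1000.0):
    print(f"{x:>7.1f} -> {softcap(x):.3f}")
```

Because tanh is smooth, gradients still flow for large scores, unlike a hard clip.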

Loading the Model

I opted for Unsloth's Gemma-2-9b checkpoint, loaded with 4-bit quantization, which cuts the memory taken by the model weights by up to 4x compared with 16-bit precision. Quantization lets a model largely retain its quality while greatly reducing its memory footprint.

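from unsloth import FastLanguageModel

# Load Gemma-2-9b with 4-bit quantization to save memory
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/gemma-2-9b",
    max_seq_length=2048,
    dtype=None,  # None lets Unsloth pick bfloat16 or float16 for the GPU
    load_in_4bit=True
)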
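The "up to 4x" figure comes from simple arithmetic: a 16-bit weight takes 2 bytes, a 4-bit weight takes half a byte. A rough sketch, assuming a round 9 billion parameters and counting only the weights (activations, LoRA adapters, optimizer state, and quantization metadata all add to the real footprint):

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Memory for model weights alone, in decimal gigabytes (1 GB = 1e9 bytes)."""
    bytes_total = num_params * bits_per_param / 8
    return bytes_total / 1e9


PARAMS = 9e9  # assumed round parameter count for Gemma-2-9b

fp16_gb = weight_memory_gb(PARAMS, 16)
four_bit_gb = weight_memory_gb(PARAMS, 4)
print(f"16-bit: {fp16_gb:.1f} GB, 4-bit: {four_bit_gb:.1f} GB, ratio: {fp16_gb / four_bit_gb:.1f}x")
# prints: 16-bit: 18.0 GB, 4-bit: 4.5 GB, ratio: 4.0x
```

By this estimate the 16-bit weights alone would need about 18 GB, roughly twice the 9.21 GB peak reported for the whole training run.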

Fine-Tuning with LoRA Adapters

To keep fine-tuning cheap, I used LoRA (Low-Rank Adaptation) adapters. LoRA freezes the base model's weights and trains small low-rank matrices added alongside them, so only a small fraction of the parameters (often quoted as 1-10%) is actually updated.

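# Attach LoRA adapters; only these small low-rank matrices are trained
model = FastLanguageModel.get_peft_model(
    model,
    r=16,                 # rank of the low-rank update matrices
    lora_alpha=16,        # updates are scaled by lora_alpha / r (1.0 here)
    lora_dropout=0,       # Unsloth is optimized for a dropout of 0
    use_gradient_checkpointing="unsloth"  # Unsloth's memory-saving checkpointing
)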
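LoRA represents the update to a frozen d_out × d_in weight matrix W as the product of two small matrices, B (d_out × r) and A (r × d_in), scaled by lora_alpha / r. A sketch of the parameter arithmetic for one square projection; the width of 3584 is an illustrative assumption, not a value taken from this run:

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d_out x d_in weight: A is r x d_in, B is d_out x r."""
    return r * d_in + d_out * r


def lora_fraction(d_in: int, d_out: int, r: int) -> float:
    """Trainable LoRA parameters relative to the frozen weight they adapt."""
    return lora_param_count(d_in, d_out, r) / (d_in * d_out)


r, lora_alpha = 16, 16
d = 3584  # illustrative hidden width (assumption)

print("scaling lora_alpha / r:", lora_alpha / r)
print("LoRA params for this matrix:", lora_param_count(d, d, r))
print(f"fraction of this matrix: {lora_fraction(d, d, r):.2%}")
```

The overall share of trainable parameters depends on the rank and on which projection matrices receive adapters, which is why it is usually quoted as a range rather than a single number.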

Preparing the Dataset

For training, I used a dataset of explanations for various PHP classes in Chameleon CMS. Each entry was turned into a question-answer pair: the question asks the model to explain a class, and the answer is the corresponding explanation.

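from datasets import Dataset

# new_data: list of dicts with 'class_name' and 'explanation' keys
formatted_new_data = {
    'question': [f"Explain the class {entry['class_name']}" for entry in new_data],
    'answer': [entry['explanation'] for entry in new_data]
}
new_dataset = Dataset.from_dict(formatted_new_data)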

I then mapped a formatting function over the dataset, turning each question-answer pair into a single training prompt.

# formatting_prompts_func combines each question-answer pair into one prompt string
formatted_dataset = new_dataset.map(formatting_prompts_func, batched=True)
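The formatting function itself isn't shown above. A minimal sketch of what such a function commonly looks like, assuming an Alpaca-style prompt template and a 'text' output column; the template, the make_formatting_func factory, and the "<eos>" string are hypothetical placeholders (in a real session you would pass tokenizer.eos_token):

```python
# Hypothetical Alpaca-style template; not necessarily the one used for this model.
PROMPT_TEMPLATE = """Below is an instruction about a Chameleon CMS PHP class. Write a response that explains it.

### Instruction:
{question}

### Response:
{answer}"""


def make_formatting_func(eos_token: str):
    """Build a batched formatter suitable for datasets.Dataset.map(..., batched=True)."""

    def formatting_prompts_func(batch: dict) -> dict:
        texts = [
            # Appending EOS teaches the model where an answer ends.
            PROMPT_TEMPLATE.format(question=q, answer=a) + eos_token
            for q, a in zip(batch["question"], batch["answer"])
        ]
        return {"text": texts}

    return formatting_prompts_func


# Hypothetical EOS string; use tokenizer.eos_token with a real tokenizer.
formatting_prompts_func = make_formatting_func(eos_token="<eos>")
example = formatting_prompts_func(
    {"question": ["Explain the class MTFeedbackErrors"], "answer": ["Manages form field errors."]}
)
print(example["text"][0])
```

Without the EOS token the model has no signal for when to stop and tends to keep generating past the end of an explanation.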

Training the Model

With the dataset ready, I fine-tuned the model using SFTTrainer from Hugging Face's TRL library. Unsloth's optimizations kept the run fast and lightweight: training finished in under 10 minutes with modest GPU memory usage.

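from transformers import TrainingArguments
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=formatted_dataset,
    dataset_text_field="text",  # assumes formatting_prompts_func writes a 'text' column
    max_seq_length=2048,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,  # effective batch size of 2 x 4 = 8
        max_steps=60,
        learning_rate=2e-4,
        fp16=True,
        optim="adamw_8bit",  # 8-bit AdamW (bitsandbytes) to save optimizer memory
        output_dir="outputs"
    )
)
trainer_stats = trainer.train()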
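Because gradients are accumulated over 4 micro-batches of 2 examples, each optimizer step sees an effective batch of 8, so 60 steps process 480 examples in total. A quick sketch of that arithmetic; the dataset size in the last line is a made-up placeholder, since the number of class explanations isn't stated:

```python
def effective_batch_size(per_device: int, grad_accum: int, num_devices: int = 1) -> int:
    """Examples contributing to each optimizer step."""
    return per_device * grad_accum * num_devices


def epochs_covered(max_steps: int, per_device: int, grad_accum: int, dataset_size: int) -> float:
    """How many passes over the dataset a fixed number of optimizer steps amounts to."""
    return max_steps * effective_batch_size(per_device, grad_accum) / dataset_size


print(effective_batch_size(2, 4))        # 8
print(60 * effective_batch_size(2, 4))   # 480
print(epochs_covered(60, 2, 4, dataset_size=200))  # hypothetical dataset size
```

With a dataset smaller than 480 examples, max_steps=60 amounts to more than one pass over the data.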

Results

The training was completed in just over 9 minutes, with a peak memory usage of 9.21 GB, leaving more than 30% of the GPU’s total memory available for other processes. The model achieved rapid convergence with a final training loss of 0.0438 after 60 steps.

Here are some key performance statistics:

  • Total Time: 9.01 minutes
  • Peak Memory Usage: 9.213 GB
  • Training Loss: 0.0438
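Statistics like these can be collected right after training. The sketch below only does the arithmetic. It assumes the runtime comes from the metrics["train_runtime"] entry (in seconds) of the object returned by trainer.train(), and the memory figures from torch.cuda.max_memory_reserved() and torch.cuda.get_device_properties(0).total_memory. The 16 GiB total is a placeholder because the GPU isn't named:

```python
GIB = 1024 ** 3  # torch reports memory in bytes; convert to GiB


def summarize_run(train_runtime_s: float, peak_reserved_bytes: int, total_bytes: int) -> dict:
    """Turn raw runtime and memory numbers into summary figures."""
    peak_gb = peak_reserved_bytes / GIB
    total_gb = total_bytes / GIB
    return {
        "minutes": round(train_runtime_s / 60, 2),
        "peak_gb": round(peak_gb, 3),
        "peak_pct": round(100 * peak_gb / total_gb, 1),
        "free_pct": round(100 * (1 - peak_gb / total_gb), 1),
    }


# Placeholder inputs: 540.6 s of training, 9.213 GiB peak, 16 GiB total (assumed GPU size).
print(summarize_run(540.6, int(9.213 * GIB), 16 * GIB))
```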

Demo: Generating PHP Class Explanations

Once trained, I tested the model by asking it to explain specific PHP classes. For instance, here’s how the model explained the MTFeedbackErrors class, which handles error management for form fields:

FastLanguageModel.for_inference(model)  # switch Unsloth to its faster inference mode

instruction = "Explain the MTFeedbackErrors class. I don't want to see code, I want only the explanation."
inputs = tokenizer([instruction], return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=256)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(generated_text[0])
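Two details are worth checking here. A fine-tuned model usually responds best when the inference prompt follows the same template it saw during training. Also, batch_decode returns the prompt together with the generated continuation. A minimal sketch of both steps, assuming a hypothetical Alpaca-style "### Instruction:" / "### Response:" template (the training template isn't shown in this post):

```python
# Hypothetical template; it must match whatever format the training data used.
PROMPT_TEMPLATE = """Below is an instruction about a Chameleon CMS PHP class. Write a response that explains it.

### Instruction:
{question}

### Response:
"""

RESPONSE_MARKER = "### Response:\n"


def build_inference_prompt(question: str) -> str:
    """Fill the template but leave the response empty for the model to complete."""
    return PROMPT_TEMPLATE.format(question=question)


def extract_answer(decoded: str) -> str:
    """Keep only the text generated after the response marker."""
    _, sep, answer = decoded.partition(RESPONSE_MARKER)
    return answer.strip() if sep else decoded.strip()


prompt = build_inference_prompt("Explain the MTFeedbackErrors class.")
# In a live session the prompt would go through tokenizer(...), model.generate(...), batch_decode(...).
fake_decoded = prompt + "The class manages form field errors."
print(extract_answer(fake_decoded))  # prints: The class manages form field errors.
```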

The result was an insightful, detailed explanation:

The `MTFeedbackErrors` class is designed to manage error handling for form fields. It provides methods to add, retrieve, and check for errors associated with specific field names. Key functions include `AddError`, `FieldHasErrors`, and `GetFieldErrors`, which streamline error management and improve user experience by providing clear feedback mechanisms for form validation.

Using Unsloth, I was able to fine-tune a large model on a relatively small dataset with incredible speed and efficiency. The 4-bit quantization, combined with LoRA adapters, allowed me to save memory while still achieving high-quality results. If you’re working with models on limited resources, I highly recommend giving Unsloth a try for your next project!
