---
base_model: unsloth/gemma-2-2b-it-bnb-4bit
tags:
- text-generation-inference
- transformers
- unsloth
- gemma2
- text-to-sql
- qlora
- sql-generation
license: apache-2.0
language:
- en
datasets:
- gretelai/synthetic_text_to_sql
pipeline_tag: text-generation
---

# Gemma-2-2B Text-to-SQL QLoRA Fine-tuned Model

- **Developed by:** rajaykumar12959
- **License:** apache-2.0
- **Finetuned from model:** unsloth/gemma-2-2b-it-bnb-4bit
- **Dataset:** gretelai/synthetic_text_to_sql
- **Task:** Text-to-SQL Generation
- **Fine-tuning Method:** QLoRA (Quantized Low-Rank Adaptation)

This gemma2 model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Hugging Face's TRL library.

[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth)

## Model Description

This model is specifically fine-tuned to generate SQL queries from natural language questions and database schemas. It excels at handling complex multi-table queries requiring JOINs, aggregations, filtering, and advanced SQL operations.

### Key Features

- ✅ **Multi-table JOINs** (INNER, LEFT, RIGHT)
- ✅ **Aggregation functions** (SUM, COUNT, AVG, MIN, MAX)
- ✅ **GROUP BY and HAVING clauses**
- ✅ **Complex WHERE conditions**
- ✅ **Subqueries and CTEs**
- ✅ **Date/time operations**
- ✅ **String functions and pattern matching**

## Training Configuration

The model was fine-tuned using QLoRA with the following configuration:

```python
# LoRA Configuration
r = 16  # Rank: 16 is a good balance for 2B models
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
lora_alpha = 16
lora_dropout = 0
bias = "none"
use_gradient_checkpointing = "unsloth"

# Training Parameters
max_seq_length = 2048
per_device_train_batch_size = 2
gradient_accumulation_steps = 4  # Effective batch size = 8
warmup_steps = 5
max_steps = 100  # Demo configuration - increase to 300+ for production
learning_rate = 2e-4
optim = "adamw_8bit"  # 8-bit optimizer for memory efficiency
weight_decay = 0.01
lr_scheduler_type = "linear"
```
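
For context, the values above map onto Unsloth's standard PEFT setup roughly as follows (a minimal sketch; the notebook's exact code may differ):

```python
from unsloth import FastLanguageModel

# Base model in 4-bit, as used for this fine-tune (sketch)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gemma-2-2b-it-bnb-4bit",
    max_seq_length = 2048,
    load_in_4bit = True,
)

# Attach the LoRA adapters with the configuration listed above
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
)
```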

## Installation

```bash
pip install unsloth transformers torch trl datasets
```

## Usage

### Loading the Model

```python
from unsloth import FastLanguageModel
import torch

max_seq_length = 2048
dtype = None
load_in_4bit = True

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "rajaykumar12959/gemma-2-2b-text-to-sql-qlora",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

FastLanguageModel.for_inference(model)  # Enable faster inference
```

### Inference Function

```python
def inference_text_to_sql(model, tokenizer, schema, question, max_new_tokens=300):
    """
    Perform inference to generate SQL from natural language question and database schema.
    
    Args:
        model: Fine-tuned Gemma model
        tokenizer: Model tokenizer
        schema: Database schema as string
        question: Natural language question
        max_new_tokens: Maximum tokens to generate
    
    Returns:
        Generated SQL query as string
    """
    # Format the input prompt
    input_prompt = f"""<start_of_turn>user
You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.

### Schema:
{schema}

### Question:
{question}<end_of_turn>
<start_of_turn>model
"""
    
    # Tokenize input
    inputs = tokenizer([input_prompt], return_tensors="pt").to("cuda")
    
    # Generate output
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            do_sample=True,
            temperature=0.1,  # Low temperature for more deterministic output
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id
        )
    
    # Decode and clean the result
    result = tokenizer.batch_decode(outputs)[0]
    sql_query = result.split("<start_of_turn>model")[-1].replace("<end_of_turn>", "").strip()
    
    return sql_query
```

### Example Usage

#### Example 1: Simple Single-Table Query

```python
# Simple employee database
simple_schema = """
CREATE TABLE employees (
    employee_id INT PRIMARY KEY,
    name TEXT,
    department TEXT,
    salary DECIMAL,
    hire_date DATE
);
"""

simple_question = "Find all employees in the 'Engineering' department with salary greater than 75000"

sql_result = inference_text_to_sql(model, tokenizer, simple_schema, simple_question)
print(f"Generated SQL:\n{sql_result}")
```

**Expected Output:**
```sql
SELECT * FROM employees 
WHERE department = 'Engineering' 
AND salary > 75000;
```
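
#### Example 2: Multi-Table JOIN Query

A second example exercising the JOIN and aggregation features listed above. The schema, question, and expected query below are illustrative (not taken from the training data), and the model's exact output formatting may vary:

```python
# Two-table order database (illustrative schema)
join_schema = """
CREATE TABLE customers (
    customer_id INT PRIMARY KEY,
    name TEXT,
    country TEXT
);

CREATE TABLE orders (
    order_id INT PRIMARY KEY,
    customer_id INT,
    amount DECIMAL,
    order_date DATE,
    FOREIGN KEY (customer_id) REFERENCES customers(customer_id)
);
"""

join_question = "Show the total order amount per customer in 'Germany', but only for customers whose total exceeds 1000"

sql_result = inference_text_to_sql(model, tokenizer, join_schema, join_question)
print(f"Generated SQL:\n{sql_result}")
```

**Expected Output** (illustrative; actual generation may vary slightly):
```sql
SELECT c.name, SUM(o.amount) AS total_amount
FROM customers c
JOIN orders o ON c.customer_id = o.customer_id
WHERE c.country = 'Germany'
GROUP BY c.name
HAVING SUM(o.amount) > 1000;
```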

## Training Details

### Dataset
- **Source:** gretelai/synthetic_text_to_sql
- **Size:** 100,000 synthetic text-to-SQL examples
- **Columns used:** 
  - `sql_context`: Database schema
  - `sql_prompt`: Natural language question
  - `sql`: Target SQL query
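
To inspect the data yourself, the dataset can be loaded with the `datasets` library (a minimal sketch using the column names listed above):

```python
from datasets import load_dataset

# Load the synthetic text-to-SQL dataset used for fine-tuning
dataset = load_dataset("gretelai/synthetic_text_to_sql", split="train")

# The three columns used during training
example = dataset[0]
print(example["sql_context"])  # database schema
print(example["sql_prompt"])   # natural language question
print(example["sql"])          # target SQL query
```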

### Training Process
The model uses a custom formatting function to structure the training data:

```python
def formatting_prompts_func(examples):
    schemas   = examples["sql_context"]
    questions = examples["sql_prompt"] 
    outputs   = examples["sql"]

    texts = []
    for schema, question, output in zip(schemas, questions, outputs):
        text = gemma_prompt.format(schema, question, output) + EOS_TOKEN
        texts.append(text)
    return { "text" : texts, }
```
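
The `gemma_prompt` template and `EOS_TOKEN` referenced above are defined in the training notebook; a plausible definition, consistent with the inference prompt shown earlier, is sketched below (the exact template string in the notebook may differ slightly):

```python
# Gemma chat-style training template; the {} slots are filled with schema, question, and target SQL
gemma_prompt = """<start_of_turn>user
You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.

### Schema:
{}

### Question:
{}<end_of_turn>
<start_of_turn>model
{}"""

# EOS token comes from the tokenizer so the model learns where to stop generating
EOS_TOKEN = tokenizer.eos_token

# Apply the formatting function to the whole dataset before training
dataset = dataset.map(formatting_prompts_func, batched=True)
```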

### Hardware Requirements
- **GPU:** Single GPU with 8GB+ VRAM
- **Training Time:** ~30 minutes for 100 steps
- **Memory Optimization:** 4-bit quantization + 8-bit optimizer

## Performance Characteristics

### Strengths
- Excellent performance on multi-table JOINs
- Accurate aggregation and GROUP BY operations
- Proper handling of foreign key relationships
- Good understanding of filtering logic (WHERE/HAVING)

### Model Capabilities Test
The model was tested on a complex 4-table JOIN query (a representative query is sketched after this list) requiring:
1. **Multi-table JOINs** (users → orders → order_items → products)
2. **Category filtering** (WHERE p.category = 'Electronics')
3. **User grouping** (GROUP BY user fields)
4. **Aggregation** (SUM of price × quantity)
5. **Aggregate filtering** (HAVING total > 500)
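
A query satisfying all five requirements would look roughly like the following (illustrative sketch; the exact table and column names used in the test schema are assumptions):

```sql
SELECT u.user_id, u.name, SUM(oi.price * oi.quantity) AS total_spent
FROM users u
JOIN orders o ON u.user_id = o.user_id
JOIN order_items oi ON o.order_id = oi.order_id
JOIN products p ON oi.product_id = p.product_id
WHERE p.category = 'Electronics'
GROUP BY u.user_id, u.name
HAVING SUM(oi.price * oi.quantity) > 500;
```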

## Limitations

- **Training Scale:** Trained with only 100 steps for demonstration. For production use, increase `max_steps` to 300+
- **Context Length:** Limited to 2048 tokens maximum sequence length
- **SQL Dialects:** Primarily trained on standard SQL syntax
- **Complex Subqueries:** May require additional fine-tuning for highly complex nested queries

## Reproduction

To reproduce this training:

1. **Clone the notebook:** Use the provided `Fine_tune_qlora.ipynb`
2. **Install dependencies:** 
   ```bash
   pip install unsloth transformers torch trl datasets
   ```
3. **Configure training:** Adjust `max_steps` in TrainingArguments for longer training
4. **Run training:** Execute all cells in the notebook

### Production Training Recommendations
```python
# For production use, update these TrainingArguments parameters:
max_steps = 300,  # Increase from 100
warmup_steps = 10,  # Increase warmup
per_device_train_batch_size = 4,  # If you have more GPU memory
```
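
Putting these values together, a production trainer configuration would look roughly like the sketch below, assuming TRL's `SFTTrainer` with `TrainingArguments` (the usual Unsloth notebook pattern; argument names may differ with newer TRL releases):

```python
from transformers import TrainingArguments
from trl import SFTTrainer

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = 2048,
    args = TrainingArguments(
        per_device_train_batch_size = 4,   # production value (2 in the demo)
        gradient_accumulation_steps = 4,
        warmup_steps = 10,                 # production value (5 in the demo)
        max_steps = 300,                   # production value (100 in the demo)
        learning_rate = 2e-4,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        logging_steps = 10,
        output_dir = "outputs",
    ),
)
trainer.train()
```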

## Model Card

| Parameter | Value |
|-----------|--------|
| Base Model | Gemma-2-2B (4-bit quantized) |
| Fine-tuning Method | QLoRA |
| LoRA Rank | 16 |
| Training Steps | 100 (demo) |
| Learning Rate | 2e-4 |
| Batch Size | 8 (effective) |
| Max Sequence Length | 2048 |
| Dataset Size | 100k examples |

## Citation

```bibtex
@misc{gemma-2-2b-text-to-sql-qlora,
  author = {rajaykumar12959},
  title = {Gemma-2-2B Text-to-SQL QLoRA Fine-tuned Model},
  year = {2024},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/rajaykumar12959/gemma-2-2b-text-to-sql-qlora}},
}
```

## Acknowledgments

- **Base Model:** Google's Gemma-2-2B via Unsloth optimization
- **Dataset:** Gretel AI's synthetic text-to-SQL dataset
- **Framework:** Unsloth for efficient fine-tuning and TRL for training
- **Method:** QLoRA for parameter-efficient training

## License

This model is licensed under Apache 2.0. See the LICENSE file for details.

---

*This model is intended for research and educational purposes. Please ensure compliance with your organization's data and AI usage policies when using in production environments.*