import os

# Must be set before TensorFlow is imported, or the C++ log level is ignored.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import numpy as np
from sentence_transformers import SentenceTransformer
import tensorflow as tf
from typing import List, Tuple, Dict, Optional, Union, Any
import math
import json
from pathlib import Path
import datetime
import faiss
import gc
import re
from response_quality_checker import ResponseQualityChecker
from cross_encoder_reranker import CrossEncoderReranker
from conversation_summarizer import DeviceAwareModel, Summarizer
from chatbot_config import ChatbotConfig
from tf_data_pipeline import TFDataPipeline
import absl.logging
from logger_config import config_logger
from tqdm.auto import tqdm

absl.logging.set_verbosity(absl.logging.WARNING)
logger = config_logger(__name__)
logger.setLevel("WARNING")

class RetrievalChatbot(DeviceAwareModel):
    """
    Retrieval-based learning chatbot model.
    Uses trained embeddings and FAISS for similarity search.
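
    Example (illustrative; assumes a model directory produced by save_models):
        chatbot = RetrievalChatbot.load_model("saved_models/", mode="inference")
        results = chatbot.retrieve_responses("Can I book a table for two?", top_k=5)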
| | """ |

    def __init__(
        self,
        config: ChatbotConfig,
        device: Optional[str] = None,
        strategy=None,
        reranker: Optional[CrossEncoderReranker] = None,
        summarizer: Optional[Summarizer] = None,
        mode: str = 'training'
    ):
        super().__init__()
        self.config = config
        self.strategy = strategy
        self.device = device or self._setup_default_device()
        self.mode = mode.lower()

        # Core components: bi-encoder, tokenizer, cross-encoder reranker, summarizer.
        self.encoder = self._initialize_encoder()
        self.tokenizer = self.encoder.tokenizer
        self.reranker = reranker or self._initialize_reranker()
        self.summarizer = summarizer or self._initialize_summarizer()

        logger.info("Initializing TFDataPipeline.")
        self.data_pipeline = TFDataPipeline(
            config=self.config,
            tokenizer=self.tokenizer,
            encoder=self.encoder,
            response_pool=[],
            query_embeddings_cache={},
        )

        # Inference requires a prebuilt FAISS index and response pool.
        if self.mode == 'inference':
            logger.info("Mode set to 'inference'. Loading FAISS index and response pool.")
            self._load_faiss_index_and_responses()
        elif self.mode != 'training':
            logger.error(f"Unsupported mode in RetrievalChatbot init: {self.mode}")
            raise ValueError(f"Unsupported mode in RetrievalChatbot init: {self.mode}")

        self.history = {
            "train_loss": [],
            "val_loss": [],
            "train_metrics": {},
            "val_metrics": {}
        }

    def _setup_default_device(self) -> str:
        """Set up default device if none is provided."""
        if tf.config.list_physical_devices('GPU'):
            return 'GPU'
        return 'CPU'

    def _initialize_reranker(self) -> CrossEncoderReranker:
        """Initialize the CrossEncoderReranker."""
        logger.info("Initializing default CrossEncoderReranker...")
        return CrossEncoderReranker(model_name=self.config.cross_encoder_model)

    def _initialize_summarizer(self) -> Summarizer:
        """Initialize the Summarizer."""
        return Summarizer(
            tokenizer=self.tokenizer,
            model_name=self.config.summarizer_model,
            max_summary_length=self.config.max_context_length // 4,
            device=self.device,
            max_summary_rounds=2
        )

    def _initialize_encoder(self) -> SentenceTransformer:
        """Initialize the SentenceTransformer encoder model."""
        logger.info("Initializing SentenceTransformer encoder model...")
        return SentenceTransformer(self.config.pretrained_model)

    def _load_faiss_index_and_responses(self) -> None:
        """Load FAISS index and response pool for inference."""
        try:
            logger.info(f"Loading FAISS index from {self.data_pipeline.faiss_index_file_path}...")
            self.data_pipeline.load_faiss_index(self.data_pipeline.faiss_index_file_path)
            logger.info("FAISS index loaded successfully.")

            # The response pool lives next to the index file.
            response_pool_path = self.data_pipeline.faiss_index_file_path.replace('.index', '_responses.json')
            if os.path.exists(response_pool_path):
                with open(response_pool_path, 'r', encoding='utf-8') as f:
                    self.data_pipeline.response_pool = json.load(f)
                logger.info(f"Loaded {len(self.data_pipeline.response_pool)} responses from {response_pool_path}.")
            else:
                logger.error(f"Response pool file not found at {response_pool_path}.")
                raise FileNotFoundError(f"Response pool file not found at {response_pool_path}.")

            self.data_pipeline.validate_faiss_index()
            logger.info("FAISS index and response pool validated successfully.")
        except Exception as e:
            logger.error(f"Failed to load FAISS index and response pool: {e}")
            raise

    @classmethod
    def load_model(cls, load_dir: Union[str, Path], mode: str = 'training') -> 'RetrievalChatbot':
        """Load chatbot model and configuration."""
        load_dir = Path(load_dir)

        config_path = load_dir / "config.json"
        if config_path.exists():
            with open(config_path, "r") as f:
                config = ChatbotConfig.from_dict(json.load(f))
            logger.info("Loaded ChatbotConfig from config.json.")
        else:
            raise FileNotFoundError(f"Config file not found at {config_path}. Please ensure it exists.")

        chatbot = cls(config, mode=mode)

        # Prefer a locally fine-tuned encoder; fall back to the pretrained hub model.
        model_path = load_dir / "sentence_transformer"
        if model_path.exists():
            chatbot.encoder = SentenceTransformer(str(model_path))
            logger.info("Loaded SentenceTransformer model from local path successfully.")
        else:
            chatbot.encoder = SentenceTransformer(config.pretrained_model)
            logger.info(f"Loaded SentenceTransformer model '{config.pretrained_model}' from the hub successfully.")

        return chatbot

    @staticmethod
    def _prepare_model_for_inference(chatbot: 'RetrievalChatbot', load_dir: Path) -> None:
        """Load inference components (FAISS index and response pool) onto a chatbot."""
        try:
            faiss_path = load_dir / 'faiss_indices/faiss_index_production.index'
            if faiss_path.exists():
                chatbot.index = faiss.read_index(str(faiss_path))
                logger.info("FAISS index loaded successfully")
            else:
                raise FileNotFoundError(f"FAISS index not found at {faiss_path}")

            response_pool_path = load_dir / 'faiss_indices/faiss_index_production_responses.json'
            if response_pool_path.exists():
                with open(response_pool_path, 'r') as f:
                    chatbot.response_pool = json.load(f)
                logger.info(f"Loaded {len(chatbot.response_pool)} responses")
            else:
                raise FileNotFoundError(f"Response pool not found at {response_pool_path}")

            # Sanity check: index and encoder must agree on embedding dimensionality.
            if chatbot.index.d != chatbot.config.embedding_dim:
                raise ValueError(
                    f"FAISS index dimension {chatbot.index.d} doesn't match "
                    f"model dimension {chatbot.config.embedding_dim}"
                )
        except Exception as e:
            logger.error(f"Error loading inference components: {e}")
            raise

    def save_models(self, save_dir: Union[str, Path]):
        """Save the SentenceTransformer model and config."""
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        with open(save_dir / "config.json", "w") as f:
            json.dump(self.config.to_dict(), f, indent=2)

        self.encoder.save(str(save_dir / "sentence_transformer"))
        logger.info(f"Model and config saved to {save_dir}.")

    def retrieve_responses(
        self,
        query: str,
        top_k: int = 10,
        reranker: Optional[CrossEncoderReranker] = None,
        summarizer: Optional[Summarizer] = None,
        summarize_threshold: int = 512,
        boost_factor: float = 1.15
    ) -> List[Tuple[str, float]]:
        """
        Retrieve top-k responses using FAISS and cross-encoder re-ranking.
        Args:
            query: The user's input text.
            top_k: Number of responses to return.
            reranker: Optional reranker for refined scoring.
            summarizer: Optional summarizer for long queries.
            summarize_threshold: Word-count threshold above which the query is summarized.
            boost_factor: Factor to boost scores for keyword matches.
        Returns:
            List of (response_text, final_score).
        """
        def sigmoid(x: float) -> float:
            return 1 / (1 + np.exp(-x))

        # Summarize very long queries before encoding.
        if summarizer and len(query.split()) > summarize_threshold:
            logger.info(f"Query is long ({len(query.split())} words). Summarizing...")
            query = summarizer.summarize_text(query)
            logger.info(f"Summarized query: {query}")

        detected_domain = self.detect_domain_from_query(query)

        # Over-retrieve from FAISS so the cross-encoder has a wide candidate pool.
        faiss_candidates = self.data_pipeline.retrieve_responses(query, top_k=top_k * 10)
        if not faiss_candidates:
            logger.warning("No candidates retrieved from FAISS.")
            return []

        # Prefer in-domain candidates when the domain is recognized.
        if detected_domain != 'other':
            in_domain_candidates = [c for c in faiss_candidates if c[0]["domain"] == detected_domain]
            if in_domain_candidates:
                faiss_candidates = in_domain_candidates
            else:
                logger.info(f"No in-domain responses found for '{query}'. Using all candidates.")

        texts = [item[0]["text"] for item in faiss_candidates]
        faiss_scores = [item[1] for item in faiss_candidates]

        if reranker is None:
            reranker = CrossEncoderReranker(model_name=self.config.cross_encoder_model)
        ce_logits = reranker.rerank(query, texts, max_length=256)

        # Combine cross-encoder probability with the rescaled FAISS cosine score.
        query_keywords = self.extract_keywords(query)
        final_candidates = []
        for resp_text, faiss_score, logit in zip(texts, faiss_scores, ce_logits):
            ce_prob = sigmoid(logit)
            faiss_norm = (faiss_score + 1) / 2  # map cosine similarity [-1, 1] -> [0, 1]
            combined_score = 0.75 * ce_prob + 0.25 * faiss_norm

            # Boost responses that share a domain keyword with the query.
            if query_keywords and any(kw in resp_text.lower() for kw in query_keywords):
                combined_score *= boost_factor

            # Penalize very short responses, lightly reward longer ones.
            length_adjusted_score = self.length_adjust_score(resp_text, combined_score)
            final_candidates.append((resp_text, length_adjusted_score))

        # Sort by descending score and return the top-k.
        final_candidates.sort(key=lambda x: x[1], reverse=True)
        return final_candidates[:top_k]
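
    # Worked example of the score combination above (illustrative numbers): a
    # cross-encoder logit of 2.0 gives ce_prob = sigmoid(2.0) ≈ 0.88, and a FAISS
    # cosine score of 0.6 gives faiss_norm = 0.8, so combined_score =
    # 0.75 * 0.88 + 0.25 * 0.8 ≈ 0.86; a keyword match then boosts this to
    # ≈ 0.86 * 1.15 ≈ 0.99 before length adjustment.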

    def extract_keywords(self, query: str) -> List[str]:
        """
        Return any domain keywords present in the query (lowercased).
        Note: plain substring matching, so short keywords can match inside words.
        """
        domain_keywords = {
            'restaurant': ['restaurant', 'dining', 'food', 'dine', 'reservation', 'table', 'menu', 'cuisine', 'eat', 'place to eat', 'hungry', 'chef', 'dish', 'meal', 'brunch', 'bistro', 'buffet', 'catering', 'gourmet', 'fast food', 'fine dining', 'takeaway', 'delivery', 'restaurant booking'],
            'movie': ['movie', 'cinema', 'film', 'ticket', 'showtime', 'showing', 'theater', 'flick', 'screening', 'film ticket', 'film show', 'blockbuster', 'premiere', 'trailer', 'director', 'actor', 'actress', 'plot', 'genre', 'screen', 'sequel', 'animation', 'documentary'],
            'ride_share': ['ride', 'taxi', 'uber', 'lyft', 'car service', 'pickup', 'dropoff', 'driver', 'cab', 'hailing', 'rideshare', 'ride hailing', 'carpool', 'chauffeur', 'transit', 'transportation', 'hail ride'],
            'coffee': ['coffee', 'café', 'cafe', 'starbucks', 'espresso', 'latte', 'mocha', 'americano', 'barista', 'brew', 'cappuccino', 'macchiato', 'iced coffee', 'cold brew', 'espresso machine', 'coffee shop', 'tea', 'chai', 'java', 'bean', 'roast', 'decaf'],
            'pizza': ['pizza', 'delivery', 'order food', 'pepperoni', 'topping', 'pizzeria', 'slice', 'pie', 'margherita', 'deep dish', 'thin crust', 'cheese', 'oven', 'tossed', 'sauce', 'garlic bread', 'calzone'],
            'auto': ['car', 'vehicle', 'repair', 'maintenance', 'mechanic', 'oil change', 'garage', 'auto shop', 'tire', 'check engine', 'battery', 'transmission', 'brake', 'engine diagnostics', 'carwash', 'detail', 'alignment', 'exhaust', 'spark plug', 'dashboard'],
        }

        query_lower = query.lower()
        found = set()
        for kw_list in domain_keywords.values():
            for kw in kw_list:
                if kw in query_lower:
                    found.add(kw)
        return list(found)
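
    # e.g. extract_keywords("any good pizza delivery nearby?") returns
    # ['pizza', 'delivery'] (set-derived, so the order is not guaranteed).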

    def length_adjust_score(self, text: str, base_score: float) -> float:
        """
        Penalize very short responses, lightly reward longer ones.
        """
        wcount = len(text.split())

        # Penalize responses under 4 words.
        if wcount < 4:
            return base_score * 0.8

        # Small bonus for responses over 15 words, capped at +0.03.
        if wcount > 15:
            base_score += min(0.03, 0.001 * (wcount - 15))

        return base_score
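
    # e.g. a 3-word response returns base_score * 0.8, while a 25-word response
    # gains min(0.03, 0.001 * (25 - 15)) = 0.01.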

    def detect_domain_from_query(self, query: str) -> str:
        """
        Detect the domain of the query based on keywords. Used for filtering FAISS search.
        """
        domain_patterns = {
            'restaurant': r'\b(restaurant|restaurants?|dining|food|foods?|dine|reservation|reservations?|table|tables?|menu|menus?|cuisine|cuisines?|eat|eats?|place\s?to\s?eat|places\s?to\s?eat|hungry|chef|chefs?|dish|dishes?|meal|meals?|fork|forks?|knife|knives?|spoon|spoons?|brunch|bistro|buffet|buffets?|catering|caterings?|gourmet|fast\s?food|fine\s?dining|takeaway|takeaways?|delivery|deliveries|restaurant\s?booking)\b',
            'movie': r'\b(movie|movies?|cinema|cinemas?|film|films?|ticket|tickets?|showtime|showtimes?|showing|showings?|theater|theaters?|flick|flicks?|screening|screenings?|film\s?ticket|film\s?tickets?|film\s?show|film\s?shows?|blockbuster|blockbusters?|premiere|premieres?|trailer|trailers?|director|directors?|actor|actors?|actress|actresses?|plot|plots?|genre|genres?|screen|screens?|sequel|sequels?|animation|animations?|documentary|documentaries)\b',
            'ride_share': r'\b(ride|rides?|taxi|taxis?|uber|lyft|car\s?service|car\s?services?|pickup|pickups?|dropoff|dropoffs?|driver|drivers?|cab|cabs?|hailing|hailings?|rideshare|rideshares?|ride\s?hailing|ride\s?hailings?|carpool|carpools?|chauffeur|chauffeurs?|transit|transits?|transportation|transportations?|hail\s?ride|hail\s?rides?)\b',
            'coffee': r'\b(coffee|coffees?|café|cafés?|cafe|cafes?|starbucks|espresso|espressos?|latte|lattes?|mocha|mochas?|americano|americanos?|barista|baristas?|brew|brews?|cappuccino|cappuccinos?|macchiato|macchiatos?|iced\s?coffee|iced\s?coffees?|cold\s?brew|cold\s?brews?|espresso\s?machine|espresso\s?machines?|coffee\s?shop|coffee\s?shops?|tea|teas?|chai|chais?|java|javas?|bean|beans?|roast|roasts?|decaf)\b',
            'pizza': r'\b(pizza|pizzas?|delivery|deliveries|order\s?food|order\s?foods?|pepperoni|pepperonis?|topping|toppings?|pizzeria|pizzerias?|slice|slices?|pie|pies?|margherita|margheritas?|deep\s?dish|deep\s?dishes?|thin\s?crust|thin\s?crusts?|cheese|cheeses?|oven|ovens?|tossed|tosses?|sauce|sauces?|garlic\s?bread|garlic\s?breads?|calzone|calzones?)\b',
            'auto': r'\b(car|cars?|vehicle|vehicles?|repair|repairs?|maintenance|maintenances?|mechanic|mechanics?|oil\s?change|oil\s?changes?|garage|garages?|auto\s?shop|auto\s?shops?|tire|tires?|check\s?engine|check\s?engines?|battery|batteries?|transmission|transmissions?|brake|brakes?|engine\s?diagnostics|engine\s?diagnostic|carwash|carwashes?|detail|details?|alignment|alignments?|exhaust|exhausts?|spark\s?plug|spark\s?plugs?|dashboard|dashboards?)\b',
        }

        query_lower = query.lower()
        for domain, pattern in domain_patterns.items():
            if re.search(pattern, query_lower):
                return domain

        return 'other'
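
    # e.g. detect_domain_from_query("I need an oil change soon") returns 'auto';
    # an unmatched query such as "what's the weather?" returns 'other'.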

    def is_numeric_response(self, text: str) -> bool:
        """
        Return True if `text` contains only digits, whitespace, and ',' or '.' separators.
        """
        pattern = r'^[\s]*[\d]+([\s.,\d]+)*[\s]*$'
        return bool(re.match(pattern, text.strip()))
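
    # e.g. is_numeric_response("12 345.6") is True; is_numeric_response("12 dollars") is False.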

    def introduction_message(self) -> None:
        """Print an introduction message to introduce the chatbot."""
        print(
            "\nAssistant: Hello! I'm a simple chatbot assistant. I've been trained to answer "
            "basic questions about topics including restaurants, movies, ride sharing, coffee, and pizza. "
            "Please ask me a question and I'll do my best to assist you."
        )

    def run_interactive_chat(self, quality_checker, show_alternatives=False):
        """Interactive chat loop."""
        self.introduction_message()

        while True:
            try:
                user_input = input("\nYou: ")
            except (KeyboardInterrupt, EOFError):
                print("\nAssistant: Goodbye!")
                break

            if user_input.lower() in ["quit", "exit", "bye"]:
                print("\nAssistant: Goodbye!")
                break

            response, candidates, metrics, top_response_score = self.chat(
                query=user_input,
                conversation_history=None,
                quality_checker=quality_checker,
                top_k=10
            )

            print(f"\nAssistant: {response}")

            if show_alternatives and candidates and metrics.get("is_confident", False):
                print("\n Alternative responses:")
                for resp, score in candidates[1:4]:
                    print(f" Score: {score:.4f} - {resp}")
            elif top_response_score < 0.7:
                print("\n[Low Confidence]: Consider rephrasing your query for better assistance.")

    def chat(
        self,
        query: str,
        conversation_history: Optional[List[Tuple[str, str]]] = None,
        quality_checker: Optional['ResponseQualityChecker'] = None,
        top_k: int = 10,
    ) -> Tuple[str, List[Tuple[str, float]], Dict[str, Any], float]:
        """
        Live chat with the chatbot. Uses the same processing flow as validation,
        except for context handling and quality checking.
        Returns (response_text, candidates, quality_metrics, top_response_score).
        """
        @self.run_on_device
        def get_response(self_arg, query_arg):
            # Fold the conversation history into a single context string.
            conversation_str = self_arg._build_conversation_context(query_arg, conversation_history)

            responses = self_arg.retrieve_responses(
                query=conversation_str,
                top_k=top_k,
                reranker=self_arg.reranker,
                summarizer=self_arg.summarizer,
                summarize_threshold=512
            )

            if not responses:
                return ("I'm sorry, but I couldn't find a relevant response.", [], {}, 0.0)

            # Optional quality gate on the retrieved candidates.
            if quality_checker is not None:
                metrics = quality_checker.check_response_quality(query_arg, responses)
                is_confident = metrics.get('is_confident', False)
            else:
                # No quality checker supplied; rely on the score threshold alone.
                metrics = {}
                is_confident = True
            top_response_score = responses[0][1]

            # Ask for clarification when confidence is low.
            if not is_confident or top_response_score < 0.5:
                return ("I need more information to provide a good answer. Could you please clarify?",
                        responses, metrics, top_response_score)

            return responses[0][0], responses, metrics, top_response_score

        return get_response(self, query)
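
    # Illustrative call (assumes a ResponseQualityChecker instance `checker`):
    #   reply, candidates, metrics, score = bot.chat("Find me a cheap pizza place",
    #                                                quality_checker=checker)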

    def _build_conversation_context(
        self,
        query: str,
        conversation_history: Optional[List[Tuple[str, str]]]
    ) -> str:
        """
        Build a conversation context string from the conversation history,
        using literal <USER> and <ASSISTANT> tokens (no tokenizer special index).
        """
        USER_TOKEN = "<USER>"
        ASSISTANT_TOKEN = "<ASSISTANT>"

        if not conversation_history:
            return f"{USER_TOKEN} {query}"

        conversation_parts = []
        for user_txt, assistant_txt in conversation_history:
            conversation_parts.append(f"{USER_TOKEN} {user_txt}")
            conversation_parts.append(f"{ASSISTANT_TOKEN} {assistant_txt}")

        conversation_parts.append(f"{USER_TOKEN} {query}")
        return "\n".join(conversation_parts)

    def train_model(
        self,
        tfrecord_file_path: str,
        epochs: int = 20,
        batch_size: int = 16,
        validation_split: float = 0.2,
        checkpoint_dir: str = "checkpoints/",
        use_lr_schedule: bool = True,
        peak_lr: float = 1e-5,
        warmup_steps_ratio: float = 0.1,
        early_stopping_patience: int = 3,
        min_delta: float = 1e-4,
        test_mode: bool = False,
        initial_epoch: int = 0
    ) -> None:
        """
        Train the retrieval model using a pre-prepared TFRecord dataset. Handles:
        - Checkpoint loading/restoring
        - LR scheduling
        - Epoch/iteration tracking
        - Training-history logging
        - Early stopping
        - Custom loss (softmax cross-entropy over the positive and hard negatives)
        """
        logger.info("Starting training with pre-prepared TFRecord dataset...")

        def parse_tfrecord_fn(example_proto, max_length, neg_samples):
            """
            Parse a single TFRecord example into (query_ids, positive_ids, negative_ids).
            """
            feature_description = {
                'query_ids': tf.io.FixedLenFeature([max_length], tf.int64),
                'positive_ids': tf.io.FixedLenFeature([max_length], tf.int64),
                'negative_ids': tf.io.FixedLenFeature([neg_samples * max_length], tf.int64),
            }
            parsed_features = tf.io.parse_single_example(example_proto, feature_description)

            query_ids = tf.cast(parsed_features['query_ids'], tf.int32)
            positive_ids = tf.cast(parsed_features['positive_ids'], tf.int32)
            # Negatives are stored flat; restore the [neg_samples, max_length] shape.
            negative_ids = tf.cast(parsed_features['negative_ids'], tf.int32)
            negative_ids = tf.reshape(negative_ids, [neg_samples, max_length])

            return query_ids, positive_ids, negative_ids

        # Count pairs (requires one full pass over the TFRecord file).
        raw_dataset = tf.data.TFRecordDataset(tfrecord_file_path)
        total_pairs = sum(1 for _ in raw_dataset)
        logger.info(f"Total pairs in TFRecord: {total_pairs}")

        train_size = int(total_pairs * (1 - validation_split))
        val_size = total_pairs - train_size
        steps_per_epoch = math.ceil(train_size / batch_size)
        val_steps = math.ceil(val_size / batch_size)
        total_steps = steps_per_epoch * epochs
        buffer_size = max(1, total_pairs // 2)

        logger.info(f"Training pairs: {train_size}")
        logger.info(f"Validation pairs: {val_size}")
        logger.info(f"Steps per epoch: {steps_per_epoch}")
        logger.info(f"Validation steps: {val_steps}")
        logger.info(f"Total steps: {total_steps}")

        # Optimizer with an optional warmup + cosine-decay schedule.
        if use_lr_schedule:
            warmup_steps = int(total_steps * warmup_steps_ratio)
            lr_schedule = self._get_lr_schedule(
                total_steps=total_steps,
                peak_lr=tf.cast(peak_lr, tf.float32),
                warmup_steps=warmup_steps
            )
            self.optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
            logger.info("Using custom learning rate schedule.")
        else:
            self.optimizer = tf.keras.optimizers.Adam(learning_rate=tf.cast(peak_lr, tf.float32))
            logger.info("Using fixed learning rate.")

        # Dummy forward/backward pass so model and optimizer variables exist
        # before checkpoint restore tries to map them.
        dummy_input = tf.zeros((1, self.config.max_context_length), dtype=tf.int32)
        with tf.GradientTape() as tape:
            dummy_output = self.encoder(dummy_input)
            dummy_loss = tf.cast(tf.reduce_mean(dummy_output), tf.float32)
        dummy_grads = tape.gradient(dummy_loss, self.encoder.trainable_variables)
        self.optimizer.apply_gradients(zip(dummy_grads, self.encoder.trainable_variables))

        # Checkpoint object and manager (keep the three most recent checkpoints).
        checkpoint = tf.train.Checkpoint(
            epoch=tf.Variable(0, dtype=tf.int32),
            optimizer=self.optimizer,
            model=self.encoder
        )
        manager = tf.train.CheckpointManager(
            checkpoint,
            directory=checkpoint_dir,
            max_to_keep=3,
            checkpoint_name='ckpt'
        )

        latest_checkpoint = manager.latest_checkpoint
        history_path = Path(checkpoint_dir) / 'training_history.json'

        if not hasattr(self, 'history'):
            self.history = {'train_loss': [], 'val_loss': [], 'learning_rate': []}

        if latest_checkpoint and not test_mode:
            status = checkpoint.restore(latest_checkpoint)
            status.assert_consumed()
            logger.info(f"Restored from checkpoint: {latest_checkpoint}")
            logger.info(f"Optimizer iterations after restore: {self.optimizer.iterations.numpy()}")

            # Log the learning rate in effect after the restore.
            if use_lr_schedule:
                current_lr = float(lr_schedule(self.optimizer.iterations))
            else:
                current_lr = float(self.optimizer.learning_rate.numpy())
            logger.info(f"Current learning rate after restore: {current_lr:.2e}")

            # Derive the resume epoch from the checkpoint number unless overridden.
            ckpt_number = int(latest_checkpoint.split('ckpt-')[-1])
            if initial_epoch == 0:
                initial_epoch = ckpt_number

            checkpoint.epoch.assign(tf.cast(initial_epoch, tf.int32))
            logger.info(f"Resuming from epoch {initial_epoch}")

            # Restore the training history if a previous run left one behind.
            if history_path.exists():
                try:
                    with open(history_path, 'r') as f:
                        self.history = json.load(f)
                    logger.info(f"Loaded previous training history from {history_path}")
                except Exception as e:
                    logger.warning(f"Could not load history, starting fresh: {e}")

            self.save_models(Path(checkpoint_dir) / "pretrained_full_model")
            logger.info("Manually saved custom weights after restore.")
        else:
            logger.info("Starting training from scratch")
            checkpoint.epoch.assign(tf.cast(0, tf.int32))
            initial_epoch = 0

        # TensorBoard writers for train/validation scalars.
        log_dir = Path(checkpoint_dir) / "tensorboard_logs"
        log_dir.mkdir(parents=True, exist_ok=True)
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_log_dir = str(log_dir / f"train_{current_time}")
        val_log_dir = str(log_dir / f"val_{current_time}")
        train_summary_writer = tf.summary.create_file_writer(train_log_dir)
        val_summary_writer = tf.summary.create_file_writer(val_log_dir)
        logger.info(f"TensorBoard logs will be saved in {log_dir}")

        dataset = tf.data.TFRecordDataset(tfrecord_file_path)

        # In test mode, train on a small subset with tighter limits.
        if test_mode:
            subset_size = 200
            dataset = dataset.take(subset_size)
            logger.info(f"TEST MODE: Using only {subset_size} examples")

            total_pairs = subset_size
            train_size = int(total_pairs * (1 - validation_split))
            val_size = total_pairs - train_size
            batch_size = min(batch_size, val_size)
            steps_per_epoch = math.ceil(train_size / batch_size)
            val_steps = math.ceil(val_size / batch_size)
            total_steps = steps_per_epoch * epochs
            buffer_size = max(1, total_pairs // 10)
            epochs = min(epochs, 5)
            early_stopping_patience = 2
            logger.info(f"New training pairs: {train_size}")
            logger.info(f"New validation pairs: {val_size}")

        dataset = dataset.map(
            lambda x: parse_tfrecord_fn(x, self.config.max_context_length, self.data_pipeline.neg_samples),
            num_parallel_calls=tf.data.AUTOTUNE
        )

        # Split sequentially into train/validation subsets.
        train_dataset = dataset.take(train_size)
        val_dataset = dataset.skip(train_size).take(val_size)

        train_dataset = train_dataset.shuffle(buffer_size=buffer_size)
        train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
        train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)

        # Cache validation batches before prefetching so each epoch reuses them.
        val_dataset = val_dataset.batch(batch_size, drop_remainder=False)
        val_dataset = val_dataset.cache()
        val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)

        # Main training loop with early stopping.
        best_val_loss = float("inf")
        epochs_no_improve = 0

        for epoch in range(int(checkpoint.epoch.numpy()) + 1, epochs + 1):
            checkpoint.epoch.assign(epoch)
            logger.info(f"Starting Epoch {epoch}...")

            epoch_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
            batches_processed = 0
            train_pbar = tqdm(
                total=steps_per_epoch,
                desc=f"Training Epoch {epoch}",
                unit="batch"
            )

            for q_batch, p_batch, n_batch in train_dataset:
                loss, grad_norm, post_clip_norm = self.train_step(q_batch, p_batch, n_batch)
                epoch_loss_avg(loss)
                batches_processed += 1

                # Per-step scalars for TensorBoard.
                with train_summary_writer.as_default():
                    step = (epoch - 1) * steps_per_epoch + batches_processed
                    tf.summary.scalar("loss", tf.cast(loss, tf.float32), step=step)
                    tf.summary.scalar("gradient_norm_pre_clip", tf.cast(grad_norm, tf.float32), step=step)
                    tf.summary.scalar("gradient_norm_post_clip", tf.cast(post_clip_norm, tf.float32), step=step)

                if use_lr_schedule:
                    current_lr = float(lr_schedule(self.optimizer.iterations))
                else:
                    current_lr = float(self.optimizer.learning_rate.numpy())

                train_pbar.update(1)
                train_pbar.set_postfix({
                    "loss": f"{loss.numpy():.4f}",
                    "pre_clip": f"{grad_norm.numpy():.2e}",
                    "post_clip": f"{post_clip_norm.numpy():.2e}",
                    "lr": f"{current_lr:.2e}",
                    "batches": f"{batches_processed}/{steps_per_epoch}"
                })

                gc.collect()

                if batches_processed >= steps_per_epoch:
                    break

            train_pbar.close()

            # Validation pass.
            val_loss_avg = tf.keras.metrics.Mean(dtype=tf.float32)
            val_batches_processed = 0
            val_pbar = tqdm(total=val_steps, desc="Validation", unit="batch")

            last_valid_val_loss = None
            valid_batches = False

            for q_batch, p_batch, n_batch in val_dataset:
                # Skip degenerate batches; the contrastive loss needs more than one example.
                if tf.shape(q_batch)[0] < 2:
                    logger.warning(f"Skipping validation batch of size {tf.shape(q_batch)[0]}")
                    continue

                valid_batches = True
                val_loss = self.validation_step(q_batch, p_batch, n_batch)
                val_loss_avg(val_loss)
                last_valid_val_loss = val_loss
                val_batches_processed += 1

                val_pbar.update(1)
                val_pbar.set_postfix({
                    "val_loss": f"{val_loss.numpy():.4f}",
                    "batches": f"{val_batches_processed}/{val_steps}"
                })

                gc.collect()

                if val_batches_processed >= val_steps:
                    break

            if not valid_batches:
                # No usable batches: fall back so the epoch still records a val loss.
                logger.warning("No valid validation batches in this epoch")
                if last_valid_val_loss is not None:
                    val_loss_avg(last_valid_val_loss)
                else:
                    val_loss_avg(epoch_loss_avg.result())

            val_pbar.close()

            # End-of-epoch logging, checkpointing, and history bookkeeping.
            train_loss = epoch_loss_avg.result().numpy()
            val_loss = val_loss_avg.result().numpy()
            logger.info(f"Epoch {epoch} Complete: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")

            with train_summary_writer.as_default():
                tf.summary.scalar("epoch_loss", train_loss, step=epoch)
            with val_summary_writer.as_default():
                tf.summary.scalar("val_loss", val_loss, step=epoch)

            manager.save()

            # Also save a full model snapshot for this epoch.
            model_save_path = Path(checkpoint_dir) / f"model_epoch_{epoch}"
            self.save_models(model_save_path)
            logger.info(f"Saved model for epoch {epoch} at {model_save_path}")

            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history.setdefault('learning_rate', []).append(current_lr)

            def convert_to_py_floats(obj):
                """Recursively convert numpy/TF scalars to plain Python floats for JSON."""
                if isinstance(obj, dict):
                    return {k: convert_to_py_floats(v) for k, v in obj.items()}
                if isinstance(obj, list):
                    return [convert_to_py_floats(x) for x in obj]
                if isinstance(obj, (np.float32, np.float64)):
                    return float(obj)
                if tf.is_tensor(obj):
                    return float(obj.numpy())
                return obj

            json_history = convert_to_py_floats(self.history)
            with open(history_path, 'w') as f:
                json.dump(json_history, f)
            logger.info(f"Saved training history to {history_path}")

            # Early stopping on validation loss.
            if val_loss < best_val_loss - min_delta:
                best_val_loss = val_loss
                epochs_no_improve = 0
                logger.info(f"Validation loss improved to {val_loss:.4f}. Reset patience.")
            else:
                epochs_no_improve += 1
                logger.info(f"No improvement this epoch. Patience: {epochs_no_improve}/{early_stopping_patience}")
                if epochs_no_improve >= early_stopping_patience:
                    logger.info("Early stopping triggered.")
                    break

        logger.info("Training completed!")

    @tf.function
    def train_step(
        self,
        q_batch: tf.Tensor,
        p_batch: tf.Tensor,
        n_batch: tf.Tensor
    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """
        Single training step using queries, positives, and hard negatives.
        Returns (loss, pre-clip gradient norm, post-clip gradient norm).
        """
        with tf.GradientTape() as tape:
            # Encode queries and positives: [batch, dim].
            q_enc = self.encoder(q_batch, training=True)
            p_enc = self.encoder(p_batch, training=True)

            shape = tf.shape(n_batch)
            bs = shape[0]
            neg_samples = shape[1]

            # Flatten negatives to encode them in one pass, then restore the shape.
            n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
            n_enc_flat = self.encoder(n_batch_flat, training=True)
            n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])

            # Stack the positive in slot 0 ahead of the negatives: [batch, 1 + neg, dim].
            combined_p_n = tf.concat([tf.expand_dims(p_enc, axis=1), n_enc], axis=1)

            # Dot products of each query against its positive and negatives; label 0
            # marks the positive, giving a softmax cross-entropy contrastive loss.
            dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
            labels = tf.zeros([bs], dtype=tf.int32)
            loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=labels,
                logits=dot_products
            )
            loss = tf.cast(tf.reduce_mean(loss), tf.float32)

        # Clip gradients by global norm, tracking norms before and after clipping.
        gradients = tape.gradient(loss, self.encoder.trainable_variables)
        gradients_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)
        max_grad_norm = tf.constant(1.5, dtype=tf.float32)
        gradients, _ = tf.clip_by_global_norm(gradients, max_grad_norm, gradients_norm)
        post_clip_norm = tf.cast(tf.linalg.global_norm(gradients), tf.float32)

        self.optimizer.apply_gradients(zip(gradients, self.encoder.trainable_variables))

        return loss, gradients_norm, post_clip_norm
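
    # Shape walkthrough (illustrative; assumes embedding dim D=384): with batch
    # size B=16 and K=3 negatives, q_enc is [16, 384], combined_p_n is
    # [16, 4, 384] and dot_products is [16, 4]; the zero labels select the
    # positive in slot 0.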

    @tf.function
    def validation_step(
        self,
        q_batch: tf.Tensor,
        p_batch: tf.Tensor,
        n_batch: tf.Tensor
    ) -> tf.Tensor:
        """
        Single validation step using queries, positives, and hard negatives.
        Same computation as train_step, but without gradient updates.
        """
        q_enc = self.encoder(q_batch, training=False)
        p_enc = self.encoder(p_batch, training=False)

        shape = tf.shape(n_batch)
        bs = shape[0]
        neg_samples = shape[1]

        n_batch_flat = tf.reshape(n_batch, [bs * neg_samples, shape[2]])
        n_enc_flat = self.encoder(n_batch_flat, training=False)
        n_enc = tf.reshape(n_enc_flat, [bs, neg_samples, -1])

        combined_p_n = tf.concat(
            [tf.expand_dims(p_enc, axis=1), n_enc],
            axis=1
        )

        dot_products = tf.cast(tf.einsum('bd,bkd->bk', q_enc, combined_p_n), tf.float32)
        labels = tf.zeros([bs], dtype=tf.int32)
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels,
            logits=dot_products
        )
        return tf.cast(tf.reduce_mean(loss), tf.float32)

    def _get_lr_schedule(
        self,
        total_steps: int,
        peak_lr: float,
        warmup_steps: int
    ) -> tf.keras.optimizers.schedules.LearningRateSchedule:
        """
        Custom learning rate schedule with linear warmup and cosine decay.
        """
        class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
            def __init__(
                self,
                total_steps: int,
                peak_lr: float,
                warmup_steps: int
            ):
                super().__init__()
                self.total_steps = tf.cast(total_steps, tf.float32)
                self.peak_lr = tf.cast(peak_lr, tf.float32)

                # Cap warmup at 10% of total steps, with a floor of one step.
                adjusted_warmup_steps = min(warmup_steps, max(1, total_steps // 10))
                self.warmup_steps = tf.cast(adjusted_warmup_steps, tf.float32)

                # Warm up from 10% of the peak; decay down to 1% of the peak.
                self.initial_lr = tf.cast(self.peak_lr * 0.1, tf.float32)
                self.min_lr = tf.cast(self.peak_lr * 0.01, tf.float32)

                logger.info("Learning rate schedule initialized:")
                logger.info(f"  Initial LR: {float(self.initial_lr):.6f}")
                logger.info(f"  Peak LR: {float(self.peak_lr):.6f}")
                logger.info(f"  Min LR: {float(self.min_lr):.6f}")
                logger.info(f"  Warmup steps: {int(self.warmup_steps)}")
                logger.info(f"  Total steps: {int(self.total_steps)}")

            def __call__(self, step):
                step = tf.cast(step, tf.float32)

                # Linear warmup from initial_lr to peak_lr.
                warmup_factor = tf.minimum(1.0, step / self.warmup_steps)
                warmup_lr = self.initial_lr + (self.peak_lr - self.initial_lr) * warmup_factor

                # Cosine decay from peak_lr down to min_lr after warmup.
                decay_steps = tf.maximum(1.0, self.total_steps - self.warmup_steps)
                decay_factor = (step - self.warmup_steps) / decay_steps
                decay_factor = tf.minimum(tf.maximum(0.0, decay_factor), 1.0)
                cosine_decay = 0.5 * (1.0 + tf.cos(tf.constant(math.pi, dtype=tf.float32) * decay_factor))
                decay_lr = self.min_lr + (self.peak_lr - self.min_lr) * cosine_decay

                final_lr = tf.where(step < self.warmup_steps, warmup_lr, decay_lr)

                # Guard against underflow and non-finite values.
                final_lr = tf.maximum(self.min_lr, final_lr)
                final_lr = tf.where(tf.math.is_finite(final_lr), final_lr, self.min_lr)

                return final_lr

            def get_config(self):
                return {
                    "total_steps": int(self.total_steps.numpy()),
                    "peak_lr": float(self.peak_lr.numpy()),
                    "warmup_steps": int(self.warmup_steps.numpy()),
                }

        return CustomSchedule(total_steps, peak_lr, warmup_steps)
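
    # Worked example (illustrative): with total_steps=1000, warmup_steps=100 and
    # peak_lr=1e-5, the schedule starts at 1e-6 and reaches 1e-5 at step 100
    # (step 50 gives 1e-6 + 0.5 * 9e-6 = 5.5e-6), then cosine-decays to 1e-7.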