import base64
import json
import math
from typing import Any, Dict, Iterator, List, Tuple

import torch
from torch.utils.data import Dataset, IterableDataset


def vread(buf: bytes, i: int) -> Tuple[int, int]:
    """Read one LEB128-style varint from ``buf`` starting at index ``i``.

    Returns the decoded value and the index of the first byte after it.
    """
    shift = val = 0
    while True:
        b = buf[i]
        i += 1
        val |= (b & 0x7F) << shift
        if b < 0x80:  # high bit clear marks the final byte
            return val, i
        shift += 7


def vwrite(v: int, out: bytearray) -> None:
    """Append a non-negative integer to ``out`` as an LEB128-style varint."""
    while True:
        byte = v & 0x7F
        v >>= 7
        out.append((byte | 0x80) if v else byte)  # set high bit on non-final bytes
        if not v:
            break


def compress_windows_starts_lens(starts: List[int], lens: List[int]) -> str:
    """Encode window (start, length) pairs as base64-wrapped varints.

    Each start is delta-encoded as the gap from the end of the previous
    window, so the windows must be sorted and non-overlapping (gaps >= 0).
    """
    buf = bytearray()
    cursor = 0
    for s, L in zip(starts, lens):
        gap = s - cursor
        vwrite(gap, buf)
        vwrite(L, buf)
        cursor = s + L
    return base64.b64encode(buf).decode("ascii")


def decompress_windows_starts_lens(b64_stream: str) -> Tuple[List[int], List[int]]:
    """Inverse of :func:`compress_windows_starts_lens`."""
    buf = base64.b64decode(b64_stream)
    i = 0
    cursor = 0
    starts, lens = [], []
    while i < len(buf):
        gap, i = vread(buf, i)
        size, i = vread(buf, i)
        start = cursor + gap
        starts.append(start)
        lens.append(size)
        cursor = start + size
    return starts, lens


def unpack_windows(
    input_bytes: bytes,
    b64_stream: str,
) -> List[Tuple[bytes, int]]:
    """Split ``input_bytes`` into an ordered list of (bytes, flag) tuples.

    The flag is 1 for a compressed window described by ``b64_stream`` and 0
    for a raw "hole" between windows. Concatenating the byte pieces in order
    reproduces ``input_bytes`` exactly.
    """
    buf = base64.b64decode(b64_stream)
    i = 0
    cursor = 0
    byte_windows = []
    while i < len(buf):
        gap, i = vread(buf, i)
        size, i = vread(buf, i)
        start = cursor + gap
        if gap > 0:  # raw bytes between the previous window and this one
            hole = input_bytes[cursor:start]
            byte_windows.append((hole, 0))
        end = start + size
        win = input_bytes[start:end]
        byte_windows.append((win, 1))
        cursor = end
    if cursor < len(input_bytes):  # trailing raw bytes after the last window
        hole = input_bytes[cursor:]
        byte_windows.append((hole, 0))
    return byte_windows


def pseudo_to_packed_bytes(lst: List[int]) -> bytes:
    """Pack 9-bit values (0..511) into bytes, little-endian bit order."""
    out = bytearray()
    acc = bits = 0
    for v in lst:
        acc |= (v & 0x1FF) << bits
        bits += 9
        while bits >= 8:
            out.append(acc & 0xFF)
            acc >>= 8
            bits -= 8
    if bits:  # flush the partial tail byte
        out.append(acc)
    return bytes(out)


def packed_bytes_to_pseudo(b: bytes) -> List[int]:
    """Inverse of :func:`pseudo_to_packed_bytes`; padding bits are discarded."""
    out, acc, bits = [], 0, 0
    for byte in b:
        acc |= byte << bits
        bits += 8
        while bits >= 9:
            out.append(acc & 0x1FF)
            acc >>= 9
            bits -= 9
    return out
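# Illustrative self-test for the codec helpers above. This is a minimal
# sketch, not part of any external API: the sample data and the `_demo_codec`
# name are assumptions made for this example.
def _demo_codec() -> None:
    data = b"abcdefghijklmnopqrstuvwxyz"
    starts, lens = [2, 10], [3, 5]  # windows data[2:5] and data[10:15]
    b64 = compress_windows_starts_lens(starts, lens)

    # The delta encoding must invert exactly.
    assert decompress_windows_starts_lens(b64) == (starts, lens)

    # unpack_windows interleaves raw holes (flag 0) with windows (flag 1)
    # and covers the input without gaps or overlaps.
    pieces = unpack_windows(data, b64)
    assert b"".join(p for p, _ in pieces) == data
    assert [flag for _, flag in pieces] == [0, 1, 0, 1, 0]

    # 9-bit packing round-trips exactly: the tail padding is always fewer
    # than 9 bits, so it can never decode to a spurious extra value.
    vals = [0, 1, 255, 256, 511]
    assert packed_bytes_to_pseudo(pseudo_to_packed_bytes(vals)) == vals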
def pad_batch(batch: List[bytes]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Right-pad a batch of byte sequences with zeros; also return lengths."""
    batch_tensors = [torch.tensor(data, dtype=torch.int64) for data in batch]
    lengths = torch.tensor([len(data) for data in batch], dtype=torch.int64)
    # padding_side was added to pad_sequence in PyTorch 2.4; right-padding is
    # also the default behavior.
    padded_batch = torch.nn.utils.rnn.pad_sequence(
        batch_tensors, batch_first=True, padding_value=0, padding_side="right"
    )
    return padded_batch, lengths


class JsonlShardedDataset(Dataset):
    """Map-style dataset that loads a JSONL file and keeps one contiguous
    shard of its lines per process."""

    def __init__(
        self,
        file_path: str,
        current_proc_rank: int = 0,
        total_procs: int = 1,
    ) -> None:
        assert 0 <= current_proc_rank < total_procs, "rank must be in [0, world_size)"
        self.current_proc_rank = current_proc_rank
        self.total_procs = total_procs

        # -- load the whole file once (fast for < few-GB files) -------------
        with open(file_path, "r", encoding="utf-8") as f:
            full_data: List[Dict[str, Any]] = [json.loads(line) for line in f]

        # -- pick the slice that belongs to *this* process ------------------
        total = len(full_data)
        per_proc = math.ceil(total / total_procs)
        start = current_proc_rank * per_proc
        end = min(start + per_proc, total)
        self.data = full_data[start:end]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        return self.data[idx]


class InterleavedJsonlDataset(IterableDataset):
    """An iterable-style dataset that reads a large JSONL file in an
    interleaved/strided pattern, designed for multi-process data loading.

    Each process scans the entire file but only parses the lines that match
    its rank (offset): with ``N`` total processes (world_size), process ``r``
    (rank) reads lines r, r+N, r+2N, ... (0-indexed). This spreads lines
    evenly across processes.

    Args:
        file_path (str): Path to the JSONL file.
        rank (int): The rank of the current process, used as the offset.
        world_size (int): The total number of processes, used as the stride.
    """

    def __init__(
        self,
        file_path: str,
        rank: int,
        world_size: int,
    ) -> None:
        super().__init__()
        if not (0 <= rank < world_size):
            raise ValueError(f"Rank must be in [0, {world_size - 1}], but got {rank}")
        self.file_path = file_path
        self.offset = rank
        self.block_size = world_size

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Yield the parsed JSON objects for the lines assigned to this rank."""
        try:
            with open(self.file_path, "r", encoding="utf-8") as f:
                # line_number is 0-indexed and decides which rank owns the line.
                for line_number, line in enumerate(f):
                    if (line_number % self.block_size) != self.offset:
                        continue
                    try:
                        yield json.loads(line)
                    except json.JSONDecodeError:
                        # Malformed line: warn and skip rather than abort.
                        print(
                            f"Warning: Rank {self.offset} could not decode JSON "
                            f"on line {line_number + 1}. Skipping."
                        )
        except Exception as e:
            print(f"Error in worker {self.offset}: {e}")
            raise


def batched_m1_compress_predict_fn(model):
    """Wrap ``model`` into a predict function returning per-position
    probabilities over the 256 byte values."""

    def predict_fn(input_tensor: torch.Tensor, **kwargs) -> torch.Tensor:
        if input_tensor.dim() == 1:
            input_tensor = input_tensor.unsqueeze(0)  # add a batch dimension
        with torch.no_grad():
            logits = model(input_tensor, **kwargs)
            logits = logits[..., :256]  # keep only the byte vocabulary
            logits = logits.float()
            assert torch.isfinite(logits).all(), "Logits contain NaN or Inf values."
            probs = torch.softmax(logits, dim=-1)
        return probs

    return predict_fn


def find_next_batch_range(all_windows, start_idx, max_m1_batch_size, get_batch_size_for_length_fn):
    """Return the half-open range [start_idx, end_idx) of the next batch.

    Assumes ``all_windows`` is ordered so that windows mapping to the same
    batch size are contiguous (e.g., sorted by length). The returned batch
    contains at most ``base_batch_size`` windows, all sharing the batch size
    of the window at ``start_idx``.
    """
    M = len(all_windows)
    if start_idx >= M:
        return start_idx, start_idx

    first_window_len = len(all_windows[start_idx])
    base_batch_size = get_batch_size_for_length_fn(first_window_len, max_m1_batch_size)

    low = start_idx
    high = min(start_idx + base_batch_size, M)

    # Fast path: if the last candidate window already matches, the whole
    # range shares the same batch size.
    high_batch_size = get_batch_size_for_length_fn(len(all_windows[high - 1]), max_m1_batch_size)
    if high_batch_size == base_batch_size:
        return start_idx, high

    # Binary search for the partition point: the first window in [low, high)
    # whose batch size differs from base_batch_size.
    search_low = low
    search_high = high
    while search_low < search_high:
        mid = search_low + (search_high - search_low) // 2
        mid_window_len = len(all_windows[mid])
        if get_batch_size_for_length_fn(mid_window_len, max_m1_batch_size) == base_batch_size:
            # This window still matches; the partition point is to its right,
            # so continue searching in [mid + 1, high).
            search_low = mid + 1
        else:
            # This window does not match; it may be the partition point itself
            # or the point lies to its left, so continue searching in [low, mid).
            search_high = mid

    end_idx = search_low
    if end_idx == start_idx:
        # Defensive: always make progress by emitting at least one window.
        return start_idx, start_idx + 1
    return start_idx, end_idx
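# Hedged usage sketch tying the batching helpers together: walk a
# length-sorted list of windows, carve out ranges that share a batch size via
# find_next_batch_range, pad each range with pad_batch, and run it through a
# predict function (e.g., one built by batched_m1_compress_predict_fn). The
# toy `_batch_size_for_length` policy and the `_demo_batched_inference` name
# are assumptions made for this example, not part of the module's API.
def _demo_batched_inference(all_windows: List[bytes], predict_fn) -> List[torch.Tensor]:
    def _batch_size_for_length(length: int, max_batch: int) -> int:
        # Toy policy: halve the batch once windows exceed 512 bytes.
        return max_batch if length <= 512 else max(1, max_batch // 2)

    results = []
    start = 0
    while start < len(all_windows):
        start, end = find_next_batch_range(
            all_windows,
            start,
            max_m1_batch_size=32,
            get_batch_size_for_length_fn=_batch_size_for_length,
        )
        if end == start:  # no progress possible; should not happen in practice
            break
        padded, lengths = pad_batch(all_windows[start:end])
        # `lengths` lets the caller mask out padded positions downstream.
        results.append(predict_fn(padded))  # (batch, seq, 256) probabilities
        start = end
    return results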