import torch import torch.nn.functional as F from torch.utils.data import IterableDataset, DataLoader import json import numpy as np from pathlib import Path from typing import Iterator, List, Dict, Any, Tuple, Optional, Union, Callable import logging import argparse import base64 import gc from collections import defaultdict, Counter, deque from m1_compression.batched_arithmetic_coder import ( _pdf_to_cdf, ) from m1_compression.hybrid_arithmetic_coder import CPUArithmeticEncoder from m1_compression import utils from m1_compression.compressor import ( load_m1_model_and_tokenizer, load_m1_model_cpu, ALPHABET_SIZE, ARITHMETIC_CODER_BASE, ARITHMETIC_CODER_PRECISION, ) import torch.multiprocessing as mp from offline_utils import ( unpack_windows, pseudo_to_packed_bytes, pad_batch, find_next_batch_range, packed_bytes_to_pseudo, pseudo_to_packed_bytes, pad_batch, InterleavedJsonlDataset, batched_m1_compress_predict_fn, ) MINIMUM_SEGMENT_SIZE = 3 COMPRESSION_OFFSET = 256 GC_FREQ = 10 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger() class SegmentCache: """Cache for segments""" def __init__(self, cache_size: int = 819200, cache_desc: str = "Prediction"): self.cache_size = cache_size self.cache_desc = cache_desc self.cache: Dict[bytes, Union[torch.Tensor, List[int]]] = {} # segment -> CDF tensor or compressed pseudo bytes logger.info(f"Created Cache with size: {cache_size}, type: {cache_desc}") def get_batch(self, segments: List[bytes]) -> Tuple[List[bytes], List[torch.Tensor], List[int]]: """ Returns: - cache_misses: unique segments not in cache - cache_results: CDF tensors for segments in cache (in order of hit_indices) - hit_indices: indices of segments that were cache hits (in input order) """ segment_to_indices = defaultdict(list) for idx, seg in enumerate(segments): segment_to_indices[seg].append(idx) unique_segments = list(segment_to_indices.keys()) cache_results = {} cache_misses = [] for seg in unique_segments: if seg in self.cache: cache_results[seg] = self.cache[seg] else: cache_misses.append((seg, segment_to_indices[seg])) hit_indices = [] for seg, indices in segment_to_indices.items(): if seg in cache_results: for idx in indices: hit_indices.append(idx) logger.info(f"{self.cache_desc} cache: {len(unique_segments)} unique segments, {len(cache_results)} hits, {len(cache_misses)} misses, {len(segments)} total segments") return cache_misses, cache_results, hit_indices def put_batch(self, segments: List[bytes], values: List[Union[torch.Tensor, List[int]]]): """Store segment -> value mappings""" if self.cache_size <= 0: return for segment, value in zip(segments, values): if segment not in self.cache: if len(self.cache) < self.cache_size: if isinstance(value, tuple): assert len(value) == 2 or len(value) == 5, "value must be a tuple of length 2 or 5" cloned_value = tuple(v.clone() if isinstance(v, torch.Tensor) else v for v in value) self.cache[segment] = cloned_value elif isinstance(value, torch.Tensor): self.cache[segment] = value.clone() else: self.cache[segment] = value def get_batch_size_for_length(window_len, max_batch_size): """Determines the batch size for a given window length.""" BATCH_SIZE_TIERS = { 128: max_batch_size, 512: max(max_batch_size // 64, 1), 1024: max(max_batch_size // 128, 1), 2048: max(max_batch_size // 256, 1), } for max_len, batch_size in BATCH_SIZE_TIERS.items(): if window_len <= max_len: return batch_size return 1 def segment_prediction_fn( batch: List[Dict[str, Any]], max_m1_batch_size, batched_predict_fn, first_byte_prob, debug, prediction_cache: Optional[SegmentCache] = None ): """ Consumer: reads from task_queue, compresses, puts result in result_queue. """ all_segments = [] compressed_or_raw_segments = [] sample_idx_to_list_segment_idx = defaultdict(list) segment_idx = 0 for sample_idx, item in enumerate(batch): assert "windows_starts_lens_b64" in item, "windows_starts_lens_b64 must be in item" sample_bytes = item["text"].encode('utf-8') byte_windows = unpack_windows(sample_bytes, item["windows_starts_lens_b64"]) for byte_window_indicator in byte_windows: all_segments.append(byte_window_indicator[0]) compressed_or_raw_segments.append(byte_window_indicator[1]) sample_idx_to_list_segment_idx[sample_idx].append(segment_idx) segment_idx += 1 effective_segments = [] ineffective_segments = [] for orig_idx, (segment, indicator) in enumerate(zip(all_segments, compressed_or_raw_segments)): if len(segment) > MINIMUM_SEGMENT_SIZE and indicator == 1: effective_segments.append((orig_idx, segment)) else: ineffective_segments.append((orig_idx, segment)) sorted_effective_segments = sorted(effective_segments, key=lambda x: len(x[1])) sorted_idx, sorted_segments = zip(*sorted_effective_segments) sorted_segments = list(sorted_segments) # Convert tuple to list effective_segments_idx_map = { orig_idx: new_idx for new_idx, orig_idx in enumerate(sorted_idx) } raw_idx, raw_segments = zip(*ineffective_segments) raw_segments = list(raw_segments) ineffective_segments_idx_map = { orig_idx: new_idx for new_idx, orig_idx in enumerate(raw_idx) } if debug: start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() torch.cuda.synchronize()## make sure all previous events are completed print("[Debug CUDA] time start", flush=True) assert first_byte_prob.shape == (1, 1, ALPHABET_SIZE), "first_byte_prob must be of shape (1, 1, ALPHABET_SIZE)" batched_windows_np = [np.frombuffer(bytes(data), dtype=np.uint8) for data in sorted_segments] M = len(batched_windows_np) batched_cdf_ends = [None] * M # Pre-allocate to maintain order if debug: batched_pdfs = [None] * M else: batched_pdfs = None if prediction_cache is not None: cache_misses_tup, cache_results, hit_indices = prediction_cache.get_batch(sorted_segments) # Fill in cache hits for hit_idx in hit_indices: segment = sorted_segments[hit_idx] value = cache_results[segment] if debug: batched_cdf_ends[hit_idx] = value[0] batched_pdfs[hit_idx] = value[1] else: batched_cdf_ends[hit_idx] = value # Update miss_indices to only include segments not in cache cache_misses, cache_miss_indices = zip(*cache_misses_tup) else: cache_misses = sorted_segments cache_miss_indices = [[i] for i in range(M)] # Process cache misses if cache_misses: cache_miss_cdf_ends = [] cache_miss_pdfs = [] if debug else None start_idx = 0 batched_windows_np = [np.frombuffer(bytes(data), dtype=np.uint8) for data in cache_misses] miss_count = len(batched_windows_np) while start_idx < miss_count: # Use the new helper function to find the exact range for the next safe batch start_idx, end_idx = find_next_batch_range(batched_windows_np, start_idx, max_m1_batch_size, get_batch_size_for_length) windows_np_chunked = batched_windows_np[start_idx:end_idx] padded_batched_windows, lengths = pad_batch(windows_np_chunked) padded_batched_windows, lengths = padded_batched_windows.cuda(), lengths.cuda() with torch.no_grad(): prompt_probs = batched_predict_fn(padded_batched_windows) prompt_probs = torch.cat( [ first_byte_prob.expand(prompt_probs.shape[0], -1, -1), prompt_probs[:, :-1, ...] ], dim=1 ) prompt_probs = utils.batched_normalize_pdf_for_arithmetic_coding(prompt_probs) cdfs_gpu = _pdf_to_cdf(prompt_probs) cdf_low = cdfs_gpu.gather(2, padded_batched_windows.unsqueeze(-1)).squeeze(-1) cdf_high = cdfs_gpu.gather(2, (padded_batched_windows + 1).unsqueeze(-1)).squeeze(-1) cdf_ends = torch.stack([cdf_low, cdf_high], dim=-1) start_idx = end_idx if debug: cache_miss_pdfs.extend(prompt_probs.cpu()) cache_miss_cdf_ends.extend(cdf_ends.cpu()) # each miss idx maps a list of original indices for idx, miss_indices in enumerate(cache_miss_indices): for orig_idx in miss_indices: batched_cdf_ends[orig_idx] = cache_miss_cdf_ends[idx] if debug: batched_pdfs[orig_idx] = cache_miss_pdfs[idx] # Store new results in cache if prediction_cache is not None: if debug: prediction_cache.put_batch(cache_misses, zip(cache_miss_cdf_ends, cache_miss_pdfs)) else: prediction_cache.put_batch(cache_misses, cache_miss_cdf_ends) return ( batch, sorted_segments, raw_segments, effective_segments_idx_map, ineffective_segments_idx_map, sample_idx_to_list_segment_idx, batched_cdf_ends, batched_pdfs, ) def segment_compression_fn( batch: List[Dict[str, Any]], sorted_segments: List[List[int]], raw_segments: List[List[int]], effective_segments_idx_map: Dict[int, int], ineffective_segments_idx_map: Dict[int, int], sample_idx_to_list_segment_idx: Dict[int, List[int]], batched_cdf_ends: List[torch.Tensor], batched_pdfs: List[torch.Tensor], output_window_size: int, escape_first_byte: bool, iterative_compress: bool, force_padding_to_threshold: bool, predict_fn: Callable, first_byte_prob: torch.Tensor, debug: bool = False, compression_cache: Optional[SegmentCache] = None ): ENCODING_BATCH_SIZE = 512 # 384 if iterative_compress: assert not escape_first_byte, "iterative_compress does not support escape_first_byte" M = len(batched_cdf_ends) processed_batched_compressed_bytes = [None] * M if debug: batched_stop_steps = [None] * M batched_num_padded_bits = [None] * M batched_prompt_probs = [None] * M batched_lengths = [None] * M # Check cache for all segments if compression_cache is not None: cache_misses_tup, cache_results, hit_indices = compression_cache.get_batch(sorted_segments) # Fill in cache hits for hit_idx in hit_indices: segment = sorted_segments[hit_idx] if debug: assert len(cache_results[segment]) == 5, "cache_results must be a tuple of length 5" if isinstance(cache_results[segment][0], tuple): processed_batched_compressed_bytes[hit_idx] = cache_results[segment][0][0] batched_stop_steps[hit_idx] = None batched_num_padded_bits[hit_idx] = None batched_prompt_probs[hit_idx] = None batched_lengths[hit_idx] = None else: processed_batched_compressed_bytes[hit_idx] = cache_results[segment][0] batched_stop_steps[hit_idx] = cache_results[segment][1] batched_num_padded_bits[hit_idx] = cache_results[segment][2] batched_prompt_probs[hit_idx] = cache_results[segment][3] batched_lengths[hit_idx] = cache_results[segment][4] else: processed_batched_compressed_bytes[hit_idx] = cache_results[segment] # Update miss_indices to only include segments not in cache cache_misses, cache_miss_indices = zip(*cache_misses_tup) else: cache_misses = sorted_segments cache_miss_indices = [[i] for i in range(M)] # Process cache misses if cache_misses: cache_miss_compressed_bytes = [] cache_miss_stop_steps = [] cache_miss_num_padded_bits = [] cache_miss_prompt_probs = [] cache_miss_lengths = [] ######## Use BatchArithmeticEncoder to replace address one by one ########### encoder = CPUArithmeticEncoder( base=ARITHMETIC_CODER_BASE, precision=ARITHMETIC_CODER_PRECISION ) # Get CDF ends and segments for cache misses only miss_cdf_ends = [batched_cdf_ends[miss_indices[0]] for miss_indices in cache_miss_indices] if debug: miss_pdfs = [batched_pdfs[miss_indices[0]] for miss_indices in cache_miss_indices] else: miss_pdfs = None miss_count = len(cache_misses) cache_miss_compressed_results = [] for chunk_idx in range(0, miss_count, ENCODING_BATCH_SIZE): chunk_start = chunk_idx chunk_end = min(chunk_idx + ENCODING_BATCH_SIZE, miss_count) chunk_size = chunk_end - chunk_start chunk_segments = cache_misses[chunk_start:chunk_end] chunk_cdf_ends = miss_cdf_ends[chunk_start:chunk_end] lengths = torch.tensor([len(segment) for segment in chunk_segments], dtype=torch.int64) padded_chunk_cdf_ends = torch.zeros( (chunk_size, lengths.max().item(), 2), device="cpu" ) for idx, (cdf_end, length) in enumerate(zip(chunk_cdf_ends, lengths)): padded_chunk_cdf_ends[idx, :length, :] = cdf_end[:length, :] if escape_first_byte: chunked_compressed_bytes, chunked_stop_steps, chunked_num_padded_bits = encoder.incremental_batched_encode( padded_chunk_cdf_ends[:, 1:, ...], ALPHABET_SIZE, lengths - 1, bit_threshold=output_window_size, force_padding_to_threshold=force_padding_to_threshold, return_num_padded_bits=True ) # if we escape the first byte, we need to add offset 1 to the stop step chunked_stop_steps = [step + 1 for step in chunked_stop_steps] else: chunked_compressed_bytes, chunked_stop_steps, chunked_num_padded_bits = encoder.incremental_batched_encode( padded_chunk_cdf_ends, ALPHABET_SIZE, lengths, bit_threshold=output_window_size, force_padding_to_threshold=force_padding_to_threshold, return_num_padded_bits=True ) cache_miss_compressed_bytes.extend(chunked_compressed_bytes) cache_miss_stop_steps.extend(chunked_stop_steps) if debug: chunk_pdfs = miss_pdfs[chunk_start:chunk_end] padded_chunk_pdfs = torch.zeros( (chunk_size, lengths.max().item(), ALPHABET_SIZE), device="cpu" ) for idx, (pdf, length) in enumerate(zip(chunk_pdfs, lengths)): padded_chunk_pdfs[idx, :length, :] = pdf[:length, :] if escape_first_byte: cache_miss_num_padded_bits.extend(chunked_num_padded_bits) cache_miss_prompt_probs.extend(padded_chunk_pdfs[:, 1:, ...]) cache_miss_lengths.extend(lengths - 1) else: cache_miss_num_padded_bits.extend(chunked_num_padded_bits) cache_miss_prompt_probs.extend(padded_chunk_pdfs) cache_miss_lengths.extend(lengths) for i in range(chunk_start, chunk_end): window_bytes = cache_misses[i] stop_step = cache_miss_stop_steps[i] _compressed_bytes = list(cache_miss_compressed_bytes[i]) compressed_bytes = [COMPRESSION_OFFSET + b for b in _compressed_bytes] if escape_first_byte: compressed_bytes = list(window_bytes[0:1]) + compressed_bytes if stop_step == -1 or stop_step >= len(window_bytes): cache_miss_compressed_results.append(compressed_bytes) else: remaining_raw_bytes = list(window_bytes[stop_step:]) if iterative_compress and len(remaining_raw_bytes) > MINIMUM_SEGMENT_SIZE: cache_miss_compressed_results.append((remaining_raw_bytes, compressed_bytes)) else: compressed_bytes = compressed_bytes + remaining_raw_bytes cache_miss_compressed_results.append(compressed_bytes) if iterative_compress: incomplete_window_ids = [] incomplete_window_remaining_bytes = [] incomplete_window_compressed_bytes = [] for i, compressed_bytes in enumerate(cache_miss_compressed_results): if isinstance(compressed_bytes, tuple): incomplete_window_ids.append(i) incomplete_window_remaining_bytes.append(compressed_bytes[0]) incomplete_window_compressed_bytes.append(compressed_bytes[1]) remaining_compressed_bytes = iterative_compress_ac( incomplete_window_remaining_bytes, predict_fn, first_byte_prob, output_window_size, force_padding_to_threshold, ENCODING_BATCH_SIZE, debug ) for i, remaining_compressed_b in enumerate(remaining_compressed_bytes): id_in_cache = incomplete_window_ids[i] final_compressed_bytes = incomplete_window_compressed_bytes[i] + remaining_compressed_b if debug: cache_miss_compressed_results[id_in_cache] = (final_compressed_bytes, "skip_debug") else: cache_miss_compressed_results[id_in_cache] = final_compressed_bytes logger.info(f"[DEBUG] total remaining windows: {len(incomplete_window_ids)}") # Fill in cache misses in the correct positions for idx, miss_indices in enumerate(cache_miss_indices): for orig_idx in miss_indices: if debug: if isinstance(cache_miss_compressed_results[idx], tuple): assert cache_miss_compressed_results[idx][1] == "skip_debug" processed_batched_compressed_bytes[orig_idx] = cache_miss_compressed_results[idx][0] batched_stop_steps[orig_idx] = None batched_num_padded_bits[orig_idx] = None batched_prompt_probs[orig_idx] = None batched_lengths[orig_idx] = None else: processed_batched_compressed_bytes[orig_idx] = cache_miss_compressed_results[idx] batched_stop_steps[orig_idx] = cache_miss_stop_steps[idx] batched_num_padded_bits[orig_idx] = cache_miss_num_padded_bits[idx] batched_prompt_probs[orig_idx] = cache_miss_prompt_probs[idx] batched_lengths[orig_idx] = cache_miss_lengths[idx] else: processed_batched_compressed_bytes[orig_idx] = cache_miss_compressed_results[idx] # Store new results in cache if compression_cache is not None: if debug: compression_cache.put_batch( cache_misses, zip( cache_miss_compressed_results, cache_miss_stop_steps, cache_miss_num_padded_bits, cache_miss_prompt_probs, cache_miss_lengths ) ) else: compression_cache.put_batch(cache_misses, cache_miss_compressed_results) # 4.recompose all segmentations B = len(batch) #### fix: add pseudo length to split the compressed and raw bytes pseudo_lens_per_segment = [[] for _ in range(B)] #### fix end compressed_bytes = [[] for _ in range(B)] original_bytes = [[] for _ in range(B)] for sample_idx, list_segment_idx in sample_idx_to_list_segment_idx.items(): for segment_idx in list_segment_idx: if segment_idx in effective_segments_idx_map: compressed_idx = effective_segments_idx_map[segment_idx] compressed_byte = processed_batched_compressed_bytes[compressed_idx] else: raw_idx = ineffective_segments_idx_map[segment_idx] compressed_byte = raw_segments[raw_idx] #### fix: whatever the compressed or raw bytes windows,restore the pseudo bytes pseudo_lens_per_segment[sample_idx].append(len(compressed_byte)) #### fix end compressed_bytes[sample_idx].extend(list(compressed_byte)) if debug: if segment_idx in effective_segments_idx_map: compressed_idx = effective_segments_idx_map[segment_idx] original_byte = sorted_segments[compressed_idx] _debug_prompt_probs = batched_prompt_probs[compressed_idx] _debug_padded_bits = batched_num_padded_bits[compressed_idx] _debug_lengths = batched_lengths[compressed_idx] _debug_stop_step = batched_stop_steps[compressed_idx] if _debug_prompt_probs is None: original_bytes[sample_idx].append(original_byte) continue processed_compressed_byte = processed_batched_compressed_bytes[compressed_idx] # de-postprocess the compressed byte if escape_first_byte: _debug_escaped_compressed_byte = processed_compressed_byte[1:] else: _debug_escaped_compressed_byte = processed_compressed_byte if _debug_stop_step == -1 or _debug_stop_step >= len(original_byte): _debug_compressed_byte = _debug_escaped_compressed_byte _debug_raw_remaining_bytes = None raw_bytes_len = None else: raw_bytes_len = len(original_byte[_debug_stop_step:]) _debug_compressed_byte = _debug_escaped_compressed_byte[:-raw_bytes_len] _debug_raw_remaining_bytes = _debug_escaped_compressed_byte[-raw_bytes_len:] _debug_compressed_byte = [b - COMPRESSION_OFFSET for b in _debug_compressed_byte] print(f"##### _debug_pdfs is {_debug_prompt_probs.shape}") print(f"##### _debug_padded is {_debug_padded_bits}") print(f"##### _debug_compressed is {_debug_compressed_byte}") print(f"##### _debug_lengths is {_debug_lengths}") print(f"##### _debug_stop_step is {_debug_stop_step}") print(f"##### _debug_raw_remaining_bytes is {_debug_raw_remaining_bytes}") print(f"##### raw_bytes_len is {raw_bytes_len}") print(f"##### original_byte len is {len(original_byte)}") decoded = encoder.batched_decode( _debug_prompt_probs.unsqueeze(0), [_debug_compressed_byte], [_debug_padded_bits], _debug_lengths.unsqueeze(0) )[0, :_debug_lengths.item()].cpu().tolist() print(f"##### AC decoded is {decoded}") if escape_first_byte: decoded = processed_compressed_byte[0:1] + decoded if _debug_stop_step < (_debug_lengths.item() + 1): decoded = decoded[:_debug_stop_step] else: if _debug_stop_step < _debug_lengths.item(): decoded = decoded[:_debug_stop_step] print(f"##### escape_first_byte decoded is {decoded}") if _debug_raw_remaining_bytes: decoded = decoded + _debug_raw_remaining_bytes print(f"##### decoded is {decoded}") print(f"##### original_byte is {list(original_byte)}") assert bytes(decoded) == original_byte, "roundtrip encoding/decoding failed \n{} and \n{}".format(bytes(decoded), original_byte) else: raw_idx = ineffective_segments_idx_map[segment_idx] original_byte = raw_segments[raw_idx] original_bytes[sample_idx].append(original_byte) # --- 关键:内部自验证测试 (仅在 debug 模式下运行) --- if debug: logger.info("Running internal self-verification test...") for i in range(B): item = batch[i] # 重新获取原始分段信息 original_segments = unpack_windows(item["text"].encode('utf-8'), item["windows_starts_lens_b64"]) generated_lens = pseudo_lens_per_segment[i] generated_pseudo_list = compressed_bytes[i] # 测试 1: 元数据列表的长度必须和原始分段数量一致 assert len(original_segments) == len(generated_lens), \ f"Metadata length mismatch for sample {i}: segments={len(original_segments)}, lens={len(generated_lens)}" # 测试 2: 使用元数据“走查”一遍生成的伪字节流 test_ptr = 0 for j in range(len(original_segments)): raw_chunk, indicator = original_segments[j] segment_len = generated_lens[j] pseudo_slice = generated_pseudo_list[test_ptr : test_ptr + segment_len] # 测试 2a: 对于“洞”,内容必须完全一致 if indicator == 0: assert list(raw_chunk) == pseudo_slice, \ f"Hole content mismatch for sample {i}, segment {j}" # 移动指针 test_ptr += segment_len # 测试 3: 所有分段长度加起来必须等于总伪字节流长度 assert test_ptr == len(generated_pseudo_list), \ f"Total length mismatch for sample {i}: ptr_sum={test_ptr}, total_len={len(generated_pseudo_list)}" logger.info("✓ Internal self-verification test passed for all samples in the batch!") # --- 自验证测试结束 --- if debug: assert len(compressed_bytes) == len(batch) for sample_idx in range(len(batch)): assert b"".join(original_bytes[sample_idx]) == batch[sample_idx]["text"].encode('utf-8'), ( "Assembled original bytes does not match the original batch: \n{} and \n{}".format( b"".join(original_bytes[sample_idx]), batch[sample_idx]["text"].encode('utf-8') ) ) # window_size_stats = collect_window_size_statistics(original_bytes) # logger.info(f"Window size stats: {window_size_stats}") # logger.info(f"original_bytes: {original_bytes}") # logger.info(f"Finish compressing, Avg compress ratio is ..: {np.mean(compression_ratios):.4f}") logger.info(f"Example compressed bytes: {compressed_bytes[0]}") write_results = [] ac_key = f"m1_ac_ow{output_window_size}_escapefb-{escape_first_byte}_iterative-{iterative_compress}_forcepadding-{force_padding_to_threshold}" for item, compressed_bytes_item in zip(batch, compressed_bytes): item = batch[i] compressed = pseudo_to_packed_bytes(compressed_bytes_item) result = { **item, ac_key: base64.b64encode(compressed).decode("ascii"), "pseudo_lens_per_segment": pseudo_lens_per_segment[i] } if debug: unpacked = packed_bytes_to_pseudo(compressed) assert unpacked == compressed_bytes_item, "Unpacked does not match compressed bytes item: \n{} and \n{}".format(unpacked, compressed_bytes_item) logger.info("✓ pseudo-bytes-enc-dec round-trip passes") write_results.append(result) orig_total_bytes = sum([len(data["text"].encode('utf-8')) for data in batch]) compressed_total_bytes = sum([len(data) for data in compressed_bytes]) compression_ratio = orig_total_bytes / compressed_total_bytes if compressed_total_bytes > 0 else 0 logger.info(f"[DEBUG] original total bytes: {orig_total_bytes}, compressed total bytes: {compressed_total_bytes}, compression rate : {compression_ratio:.3f}") return write_results def iterative_compress_ac( batch_windows: List[List[int]], predict_fn: Callable, first_byte_prob: torch.Tensor, output_window_size: int, force_padding_to_threshold: bool, max_m1_batch_size: int = 4096, debug: bool = False, ) -> List[bytes]: """ Buffer-based compression pipeline that reads max_window_size from each file, performs batched compression, advances positions based on stop_steps, and repeats. """ if debug: start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() torch.cuda.synchronize() print("[Debug CUDA] time start", flush=True) original_total_bytes = sum([len(data) for data in batch_windows]) print(f"[Debug] BufferBased-> Original total bytes: {original_total_bytes}", flush=True) print(f"[Debug] BufferBased-> Batch size: {len(batch_windows)}", flush=True) B = len(batch_windows) # Initialize buffers and positions for each file window_positions = [0] * B output_compressed_bytes = [[] for _ in range(B)] windows_done = [False] * B if debug: output_padded_bits = [[] for _ in range(B)] output_prompt_probs = [[] for _ in range(B)] output_lengths = [[] for _ in range(B)] iter_step = 0 while not all(windows_done): iter_step += 1 # Step 1: Read max_window_size bytes from each file to buffer current_windows = [] active_file_indices = [] for i in range(B): if windows_done[i]: continue # Read up to max_window_size bytes from current position start_pos = window_positions[i] end_pos = len(batch_windows[i]) if start_pos >= len(batch_windows[i]) - MINIMUM_SEGMENT_SIZE: windows_done[i] = True continue window_bytes = batch_windows[i][start_pos:end_pos] current_windows.append(window_bytes) active_file_indices.append(i) if not current_windows: break start_idx = 0 batched_windows_np = [np.array(data, dtype=np.uint8) for data in current_windows] current_windows_count = len(batched_windows_np) encoder = CPUArithmeticEncoder( base=ARITHMETIC_CODER_BASE, precision=ARITHMETIC_CODER_PRECISION ) batched_compressed_bytes = [] batched_stop_steps = [] if debug: batched_num_padded_bits = [] batched_pdfs = [] _temp_cdf_ends = [] _temp_lengths = [] while start_idx < current_windows_count: # Use the new helper function to find the exact range for the next safe batch start_idx, end_idx = find_next_batch_range(batched_windows_np, start_idx, max_m1_batch_size, get_batch_size_for_length) windows_np_chunked = batched_windows_np[start_idx:end_idx] padded_batched_windows, lengths = pad_batch(windows_np_chunked) # NOTE: switch to GPU padded_batched_windows = padded_batched_windows.cuda() with torch.no_grad(): prompt_probs = predict_fn(padded_batched_windows) prompt_probs = torch.cat( [ first_byte_prob.expand(prompt_probs.shape[0], -1, -1), prompt_probs[:, :-1, ...] ], dim=1 ) prompt_probs = utils.batched_normalize_pdf_for_arithmetic_coding(prompt_probs) cdfs_gpu = _pdf_to_cdf(prompt_probs) cdf_low = cdfs_gpu.gather(2, padded_batched_windows.unsqueeze(-1)).squeeze(-1) cdf_high = cdfs_gpu.gather(2, (padded_batched_windows + 1).unsqueeze(-1)).squeeze(-1) cdf_ends = torch.stack([cdf_low, cdf_high], dim=-1) start_idx = end_idx _temp_cdf_ends.append(cdf_ends.cpu()) _temp_lengths.append(lengths) if debug: batched_pdfs.extend(prompt_probs.cpu()) for cdf_ends, lengths in zip(_temp_cdf_ends, _temp_lengths): chunked_compressed_bytes, chunked_stop_steps, chunked_num_padded_bits = encoder.incremental_batched_encode( # NOTE: switch to CPU cdf_ends, ALPHABET_SIZE, lengths, bit_threshold=output_window_size, force_padding_to_threshold=force_padding_to_threshold, return_num_padded_bits=True ) batched_compressed_bytes.extend(chunked_compressed_bytes) batched_stop_steps.extend(chunked_stop_steps) if debug: batched_num_padded_bits.extend(chunked_num_padded_bits) # NOTE: debug this function # Step 3: Process results and advance positions for window_idx, file_idx in enumerate(active_file_indices): compressed_bytes = batched_compressed_bytes[window_idx] stop_step = batched_stop_steps[window_idx] # Add compressed bytes to output output_compressed_bytes[file_idx].append(compressed_bytes) if debug: output_padded_bits[file_idx].append(batched_num_padded_bits[window_idx]) output_prompt_probs[file_idx].append(batched_pdfs[window_idx]) length = torch.tensor([stop_step], dtype=torch.long, device=batched_pdfs[window_idx].device) output_lengths[file_idx].append(length) window_positions[file_idx] += stop_step if window_positions[file_idx] >= len(batch_windows[file_idx]) - MINIMUM_SEGMENT_SIZE: windows_done[file_idx] = True # Concatenate all compressed bytes for each file final_compressed = [] for i in range(B): _original_byte_window = batch_windows[i] _stopped_position = window_positions[i] _byte_array = b''.join(output_compressed_bytes[i]) offset_compressed_bytes = [b + COMPRESSION_OFFSET for b in list(_byte_array)] if _stopped_position < len(_original_byte_window): raw_leftover_bytes = _original_byte_window[_stopped_position:] offset_compressed_bytes = offset_compressed_bytes + list(raw_leftover_bytes) final_compressed.append(offset_compressed_bytes) if debug: end_event.record() torch.cuda.synchronize() elapsed_time = start_event.elapsed_time(end_event) print(f"[Debug CUDA] Elapsed time: {elapsed_time:.3f}ms", flush=True) encoder = CPUArithmeticEncoder( base=ARITHMETIC_CODER_BASE, precision=ARITHMETIC_CODER_PRECISION ) for ( output_compressed_bytes_item, output_padded_bits_item, output_prompt_probs_item, output_lengths_item, batch_windows_item, stopped_position ) in zip(output_compressed_bytes, output_padded_bits, output_prompt_probs, output_lengths, batch_windows, window_positions): original_bytes = batch_windows_item[:stopped_position] decoded_bytes = [] for ( _debug_compressed, _debug_padded, _debug_pdfs, _debug_lengths ) in zip( output_compressed_bytes_item, output_padded_bits_item, output_prompt_probs_item, output_lengths_item ): print(f"##### _debug_pdfs is {_debug_pdfs.shape}") print(f"##### _debug_padded is {_debug_padded}") print(f"##### _debug_compressed is {_debug_compressed}") print(f"##### _debug_lengths is {_debug_lengths}") print(f"##### original_bytes is {original_bytes}") decoded = encoder.batched_decode(_debug_pdfs.unsqueeze(0), [_debug_compressed], [_debug_padded], _debug_lengths) decoded_bytes += decoded[0, :_debug_lengths.item()].cpu().tolist() print(f"##### decoded is {bytes(decoded[0, :_debug_lengths.item()].cpu().tolist())}") assert decoded_bytes == original_bytes, "roundtrip encoding/decoding failed \n{} and \n{}".format(decoded_bytes, original_bytes) return final_compressed def writer_consumer( write_queue, output_file, buffer_size=100, debug=False, output_window_size=16, escape_first_byte=False, compression_cache_size=819200, iterative_compress=False, force_padding_to_threshold=False, entropy_model_path=None, firstbyte_prob_path=None, num_workers=None, ): """ Writer consumer: reads compressed results from write_queue and writes to file. Maintains its own buffer and writes when buffer is full or receives sentinel. """ if num_workers is not None: num_threads = torch.get_num_threads() # new_num_threads = max(1, int(num_threads // 2 // num_workers)) # TODO: HACK new_num_threads = 1 # max(1, int(num_threads // (num_workers + 1))) torch.set_num_threads(new_num_threads) logger.info(f"[Debug] Set num threads to {new_num_threads} for writer process {mp.current_process().name}") write_buf = [] # Initialize compression cache for this worker compression_cache = SegmentCache(cache_size=compression_cache_size, cache_desc="Compression") if compression_cache_size > 0 else None if iterative_compress: model, _, _ = load_m1_model_and_tokenizer(entropy_model_path) predict_fn = batched_m1_compress_predict_fn(model) if firstbyte_prob_path is not None: with open(firstbyte_prob_path, 'r', encoding='utf-8') as f: first_byte_prob = json.load(f) print(first_byte_prob) first_byte_prob = torch.tensor(first_byte_prob, dtype=torch.float32, device="cuda").unsqueeze(0).unsqueeze(0) else: first_byte_prob = torch.ones((1, 1, ALPHABET_SIZE), dtype=torch.float32, device="cuda") / ALPHABET_SIZE # NOTE: use CPU # model = load_m1_model_cpu(entropy_model_path) # predict_fn = batched_m1_compress_predict_fn(model) # if firstbyte_prob_path is not None: # with open(firstbyte_prob_path, 'r', encoding='utf-8') as f: # first_byte_prob = json.load(f) # print(first_byte_prob) # first_byte_prob = torch.tensor(first_byte_prob, dtype=torch.float32, device="cpu").unsqueeze(0).unsqueeze(0) # else: # first_byte_prob = torch.ones((1, 1, ALPHABET_SIZE), dtype=torch.float32, device="cpu") / ALPHABET_SIZE else: predict_fn = None first_byte_prob = None try: with open(output_file, 'w', encoding='utf-8') as f: while True: args = write_queue.get() if args is None: break ( batch, sorted_segments, raw_segments, effective_segments_idx_map, ineffective_segments_idx_map, sample_idx_to_list_segment_idx, batched_cdf_ends, batched_pdfs, ) = args write_results = segment_compression_fn( batch, sorted_segments, raw_segments, effective_segments_idx_map, ineffective_segments_idx_map, sample_idx_to_list_segment_idx, batched_cdf_ends, batched_pdfs, output_window_size, escape_first_byte, iterative_compress, force_padding_to_threshold, predict_fn, first_byte_prob, debug=debug, compression_cache=compression_cache ) write_buf.extend(write_results) # Write buffer when it's full if len(write_buf) >= buffer_size: logger.info(f"Writer: Dumping buffer of {len(write_buf)} items to {output_file}") for buffered_item in write_buf: f.write(json.dumps(buffered_item) + '\n') f.flush() write_buf = [] # Clean up GPU memory gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() # Write remaining items in buffer if write_buf: logger.info(f"Writer: Dumping remaining {len(write_buf)} items to {output_file}") for buffered_item in write_buf: f.write(json.dumps(buffered_item) + '\n') f.flush() except Exception as e: logger.error(f"Writer process error: {e}") raise def merge_output_files(output_file, writer_output_files): """Merge all writer output files into a single file""" logger.info(f"Merging {len(writer_output_files)} writer files into {output_file}") with open(output_file, 'w', encoding='utf-8') as outf: for writer_output_file in writer_output_files: if writer_output_file.exists(): with open(writer_output_file, 'r', encoding='utf-8') as inf: for line in inf: outf.write(line) # Optionally remove the individual writer files writer_output_file.unlink() logger.info(f"Merged and removed writer file: {writer_output_file}") logger.info(f"Merged output written to: {output_file}") return output_file def shutdown_writers(write_queue, writer_processes): """Send shutdown signals to shared queue and wait for all writers to complete""" # Send one sentinel per writer to ensure all writers get the shutdown signal for i in range(len(writer_processes)): write_queue.put(None) logger.info(f"Sent shutdown signal {i+1}/{len(writer_processes)}") # Wait for all writers to complete for i, writer_process in enumerate(writer_processes): writer_process.join() if writer_process.exitcode != 0: logger.error(f"Writer process {i} failed with exit code: {writer_process.exitcode}") else: logger.info(f"Writer process {i} completed successfully") def main(): # Set up argument parser parser = argparse.ArgumentParser(description='Process JSONL files using M1 arithmetic compression with buffer-based approach') parser.add_argument('--input_file', type=str, required=True, help='Directory containing input JSONL files') parser.add_argument('--output_dir', type=str, required=True, help='Directory to write compressed results') parser.add_argument('--entropy_model_path', type=str, required=True, help='Path to the M1 model checkpoint') parser.add_argument('--compression_model_path', type=str, required=True, help='Path to the M1 model checkpoint') parser.add_argument('--data_batch_size', type=int, default=512, help='Size of batches for processing (default: 512)') parser.add_argument('--output_window_size', type=int, default=16, help='Size of window for compression (default: 16)') parser.add_argument('--escape_first_byte', action='store_true', default=False, help='Escape the first byte of each window (default: False)') parser.add_argument('--max_window_size', type=int, default=1024, help='Maximum window size for reading from each file (default: 1024)') parser.add_argument('--max_entropy_batch_size', type=int, default=4096, help='Size of max batch for compression (default: 4096)') parser.add_argument('--max_compression_batch_size', type=int, default=4096, help='Size of max batch for compression (default: 4096)') parser.add_argument('--chunk_size', type=int, default=512, help='Size of chunk for compression (default: 512)') parser.add_argument('--base_global_quantile', type=float, default=0.9, help='Base global quantile for compression (default: 0.9)') parser.add_argument('--base_monotonic_quantile', type=float, default=0.9, help='Base monotonic quantile for compression (default: 0.9)') parser.add_argument('--debug', action='store_true', default=False, help='Debug mode (default: False)') parser.add_argument('--firstbyte_prob_path', type=str, default=None, help='Probability path for the first word of each window (default : None)') parser.add_argument('--num_workers', type=int, default=1, help='Number of workers for CPU jobs (default: 1)') parser.add_argument('--process_id', type=int, default=0, help='Process ID for distributed processing (default: 0)') parser.add_argument('--num_processes', type=int, default=1, help='Number of processes for distributed processing (default: 1)') parser.add_argument('--merge_output', action='store_true', default=False, help='Merge all writer output files into a single file (default: False)') parser.add_argument('--prediction_cache_size', type=int, default=81920, help='Size of prediction cache per process (default: 819200)') parser.add_argument('--compression_cache_size', type=int, default=81920, help='Size of compression cache per worker (default: 819200)') parser.add_argument('--disable_caching', action='store_true', default=False, help='Disable both prediction and compression caching (default: False)') parser.add_argument('--iterative_compress', action='store_true', default=False, help='Iterative compression (default: False)') parser.add_argument('--force_padding_to_threshold', action='store_true', default=False, help='Force padding to threshold (default: False)') args = parser.parse_args() num_threads = torch.get_num_threads() # new_num_threads = max(1, int(num_threads // 2)) # TODO: HACK new_num_threads = 2 # max(1, int(num_threads // (args.num_workers + 1))) torch.set_num_threads(new_num_threads) logger.info(f"[Debug] Set num threads to {new_num_threads} for main process") mp.set_start_method('spawn', force=True) dump_freq = 100 # Create output directory if it doesn't exist output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) # Load model and tokenizer model, _, _ = load_m1_model_and_tokenizer(args.entropy_model_path) batched_predict_fn = batched_m1_compress_predict_fn(model) if args.firstbyte_prob_path is not None: with open(args.firstbyte_prob_path, 'r', encoding='utf-8') as f: first_byte_prob = json.load(f) print(first_byte_prob) first_byte_prob = torch.tensor(first_byte_prob, dtype=torch.float32, device="cuda").unsqueeze(0).unsqueeze(0) else: first_byte_prob = torch.ones((1, 1, ALPHABET_SIZE), dtype=torch.float32, device="cuda") / ALPHABET_SIZE # Create dataset and dataloader dataset = InterleavedJsonlDataset( file_path=args.input_file, rank=args.process_id, world_size=args.num_processes, ) dataloader = DataLoader( dataset, batch_size=args.data_batch_size, shuffle=False, collate_fn=lambda x: x ) input_file = Path(args.input_file) logger.info(f"Processing file: {input_file}") output_file = output_dir / f"{input_file.stem}_out_{args.process_id}.jsonl" logger.info("Data loaded. Start processing...") # Initialize prediction cache for this process prediction_cache = None if not args.disable_caching and args.prediction_cache_size > 0: prediction_cache = SegmentCache(cache_size=args.prediction_cache_size, cache_desc="Prediction") logger.info(f"Prediction cache enabled with size: {args.prediction_cache_size}") else: logger.info("Prediction cache disabled") compression_cache_size = 0 if args.disable_caching else args.compression_cache_size if compression_cache_size > 0: logger.info(f"Compression cache enabled with size: {compression_cache_size} per worker") else: logger.info("Compression cache disabled") write_queue = mp.Queue(maxsize=200) writer_processes = [] writer_output_files = [] for i in range(args.num_workers): # Create unique output file for each writer output_path = Path(output_file) writer_output_file = output_path.parent / f"{output_path.stem}_writer_{i}.jsonl" writer_output_files.append(writer_output_file) writer_process = mp.Process( target=writer_consumer, args=( write_queue, writer_output_file, dump_freq, args.debug, args.output_window_size, args.escape_first_byte, compression_cache_size, args.iterative_compress, args.force_padding_to_threshold, args.entropy_model_path, args.firstbyte_prob_path, args.num_workers, ) ) writer_processes.append(writer_process) writer_process.start() logger.info(f"Started writer process {i} for output file: {writer_output_file}") try: # Process each batch for batch_idx, batch in enumerate(dataloader): pred_results = segment_prediction_fn( batch, max_m1_batch_size=args.max_compression_batch_size, batched_predict_fn=batched_predict_fn, first_byte_prob=first_byte_prob, debug=args.debug, prediction_cache=prediction_cache ) logger.info(f"Processed batch {batch_idx}") write_queue.put(pred_results) if batch_idx % GC_FREQ == 0: # Clean up GPU memory gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() # Signal completion to all writer processes shutdown_writers(write_queue, writer_processes) except Exception as e: logger.error(f"Error during processing: {e}") # Try to terminate writer processes cleanly try: shutdown_writers(write_queue, writer_processes) except: pass raise if args.merge_output: final_output_file = merge_output_files(output_file, writer_output_files) logger.info(f"Completed processing successfully, merged output written to {final_output_file}") else: logger.info(f"Completed processing successfully, outputs written to {args.num_workers} separate files") if __name__ == "__main__": main()