""" new pipeline: 1.prepare data: from original batch to unpack,sort,rerank windows and reconstruct indexes 2.compress_segment_xxx(): use different compression algorithm 3.reconstruct_result: reconstruct the window from compressed results and idx 4.main()--- use prepare_segments and compress_seg_xx to produce and pass them to consumers 5.write_consumer: get from compressed data, reconstruct and write the result """ import torch import torch.nn.functional as F from torch.utils.data import IterableDataset, Dataset, DataLoader import json import numpy as np from pathlib import Path from typing import Iterator, List, Dict, Any, Callable, Tuple, Optional import logging import argparse import base64 import time import math import gc from collections import defaultdict, Counter,deque from m1_compression.utils import * from m1_compression.compressor import ( load_m1_model_and_tokenizer, ALPHABET_SIZE, ) import multiprocessing as mp from m1_compression.enumerative_coder_simple import SimpleAdaptiveRankCodec from m1_compression.batched_arithmetic_coder import BatchedArithmeticEncoder from m1_compression.hybrid_arithmetic_coder import HybridArithmeticEncoder from m1_compression.compressor import ( load_m1_model_and_tokenizer, ALPHABET_SIZE, ARITHMETIC_CODER_BASE, ARITHMETIC_CODER_PRECISION, ) from offline_entropy_window_split import unpack_windows logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger() def pseudo_to_packed_bytes(lst: list[int]) -> bytes: out = bytearray() acc = bits = 0 for v in lst: acc |= (v & 0x1FF) << bits bits += 9 while bits >= 8: out.append(acc & 0xFF) acc >>= 8 bits -= 8 if bits: # flush tail out.append(acc) return bytes(out) def packed_bytes_to_pseudo(b: bytes) -> list[int]: out, acc, bits = [], 0, 0 for byte in b: acc |= byte << bits bits += 8 while bits >= 9: out.append(acc & 0x1FF) acc >>= 9 bits -= 9 return out def calculate_compression_ratio(original_bytes: List[bytes], compressed_segments: List[bytes]) -> float: if not compressed_segments or len(original_bytes) == 0: return 1.0 total_compressed_length = sum(len(compressed_seg) for compressed_seg in compressed_segments) ratio = total_compressed_length / sum(len(orig_seg) for orig_seg in original_bytes) if ratio > 2.0: logger.warning(f"Unusual compression ratio: {ratio:.4f} (compressed larger than original)") return ratio def collect_window_size_statistics(segmented_results: List[List[bytes]]) -> Dict[int, int]: window_size_counts = Counter() for segments in segmented_results: for segment in segments: window_size = len(segment) window_size_counts[window_size] += 1 return dict(window_size_counts) # def pad_batch(batch: List[bytes]): # batch_tensors = [torch.tensor(data, dtype=torch.int64) for data in batch] # lengths = torch.tensor([len(data) for data in batch], dtype=torch.int64) # padded_batch = torch.nn.utils.rnn.pad_sequence( # batch_tensors, # batch_first=True, # padding_value=0, # padding_side="right" # ) # return padded_batch, lengths def pad_batch(batch: List[bytes]): # fix 1: transfer bytes to (list(data)) batch_tensors = [torch.tensor(list(data), dtype=torch.int64) for data in batch] lengths = torch.tensor([len(data) for data in batch], dtype=torch.int64) # fix 2: remove torch.nn.utils.rnn.pad_sequence 不支持的 'padding_side' 参数。 # right padding(对于 batch_first=True) padded_batch = torch.nn.utils.rnn.pad_sequence( batch_tensors, batch_first=True, padding_value=0 ) return padded_batch, lengths # control long seq with smaller batch def get_batch_size_for_length(window_len, max_batch_size): """ Determines the batch size for a given window length. VERY AGGRESSIVE reduction for long sequences to prevent OOM. """ # max_batch_size only for short len if window_len <= 128: return max_batch_size if window_len <= 256: return max(max_batch_size // 4, 1) if window_len <= 512: return max(max_batch_size // 16, 1) if window_len <= 1024: return max(max_batch_size // 64, 1) if window_len <= 2048: return 2 # 对于 1k-2k 的序列,最多处理 2 个 # 对于超过 2048 的超长序列,一次只处理 1 个 return 1 # def get_batch_size_for_length(window_len, max_batch_size): # """Determines the batch size for a given window length.""" # BATCH_SIZE_TIERS = { # 128: max_batch_size, # 512: max(max_batch_size // 64, 1), # 1024: max(max_batch_size // 128, 1), # 2048: max(max_batch_size // 256, 1), # } # for max_len, batch_size in BATCH_SIZE_TIERS.items(): # if window_len <= max_len: # return batch_size # return 1 def find_next_batch_range(all_windows, start_idx, max_m1_batch_size): M = len(all_windows) if start_idx >= M: return start_idx, start_idx first_window_len = len(all_windows[start_idx]) base_batch_size = get_batch_size_for_length(first_window_len, max_m1_batch_size) low = start_idx high = min(start_idx + base_batch_size, M) high_batch_size = get_batch_size_for_length(len(all_windows[high - 1]), max_m1_batch_size) if high_batch_size == base_batch_size: return start_idx, high search_low = low search_high = high while search_low < search_high: mid = search_low + (search_high - search_low) // 2 mid_window_len = len(all_windows[mid]) if get_batch_size_for_length(mid_window_len, max_m1_batch_size) == base_batch_size: # This window is valid. The partition point must be to the right of it. # So, we continue searching in the range [mid + 1, high). search_low = mid + 1 else: # This window is NOT valid. It might be the partition point itself, # or the point is to its left. # So, we continue searching in the range [low, mid). search_high = mid end_idx = search_low if end_idx == start_idx: return start_idx, start_idx + 1 else: return start_idx, end_idx class JsonlShardedDataset(Dataset): def __init__( self, file_path: str, current_proc_rank: int = 0, total_procs: int = 1, ) -> None: assert 0 <= current_proc_rank < total_procs, "rank must be in [0, world_size)" self.current_proc_rank = current_proc_rank self.total_procs = total_procs # -- load the whole file once (fast for < few-GB files) ------------- with open(file_path, "r", encoding="utf-8") as f: full_data: List[Dict[str, Any]] = [json.loads(line) for line in f] # -- pick the slice that belongs to *this* process ------------------ total = len(full_data) per_proc = math.ceil(total / total_procs) start = current_proc_rank * per_proc end = min(start + per_proc, total) self.data = full_data[start:end] def __len__(self) -> int: return len(self.data) def __getitem__(self, idx: int) -> Dict[str, Any]: return self.data[idx] class InterleavedJsonlDataset(IterableDataset): """ An iterable-style dataset for reading a large JSONL file using an interleaving/striding pattern, without yielding state information. This is designed for multi-process data loading. Each process reads the entire file but only processes lines that match its rank (offset). For `N` total processes (world_size), process `r` (rank) will read lines r, r+N, r+2N, ... (0-indexed). This method ensures an even distribution of lines across processes. Args: file_path (str): Path to the JSONL file. rank (int): The rank of the current process, used as the offset. world_size (int): The total number of processes, used as the block_size/stride. """ def __init__( self, file_path: str, rank: int, world_size: int, ) -> None: super().__init__() if not (0 <= rank < world_size): raise ValueError(f"Rank must be in [0, {world_size-1}], but got {rank}") self.file_path = file_path self.offset = rank self.block_size = world_size def __iter__(self) -> Iterator[Dict[str, Any]]: """ The iterator method that yields the parsed JSON data for the assigned lines. """ try: with open(self.file_path, "r", encoding="utf-8") as f: # We use a simple line counter to determine which lines to process. # The line_number is 0-indexed. for line_number, line in enumerate(f): # Check if the current line number belongs to this process if (line_number % self.block_size) == self.offset: try: # Yield the parsed JSON object yield json.loads(line) except json.JSONDecodeError: # This line is malformed. We can either raise an error # or, more robustly, just print a warning and skip it. print(f"Warning: Rank {self.offset} could not decode JSON on line ~{line_number+1}. Skipping.") continue except Exception as e: print(f"Error in worker {self.offset}: {e}") raise def batched_m1_compress_predict_fn(model): def predict_fn(input_tensor: torch.Tensor, **kwargs) -> torch.Tensor: if input_tensor.dim() == 1: input_tensor = input_tensor.unsqueeze(0) with torch.no_grad(): logits = model(input_tensor, **kwargs) logits = logits[..., :256] logits = logits.float() assert torch.isfinite(logits).all(), "Logits contain NaN or Inf values." probs = torch.softmax(logits, dim=-1) return probs return predict_fn class CachingCompressorWrapper: def __init__( self, base_compression_fn: Callable, # add cache on base compressor, make module seperate cache_size: int = 819200, # default a big cache size cache_policy: str = 'fifo' # default fifo ): if cache_policy not in ['fifo']: raise ValueError(f"no caching policy: {cache_policy}.") self.base_compression_fn = base_compression_fn self.cache_size = cache_size self.cache_policy = cache_policy # self.cache 存储: raw_bytes -> compressed_pseudo_bytes (List[int]) self.cache: Dict[bytes, List[int]] = {} self.fifo_queue: deque[bytes] = deque() logger.info(f"Create CachingCompressorWrapper '{self.base_compression_fn.__name__}'," f"Cache size: {self.cache_size}, policy: {self.cache_policy}") def compress( self, sorted_segments: List[bytes], *args, **kwargs ) -> List[List[int]]: """ compressors with cache """ if not sorted_segments: return [] M = len(sorted_segments) # 1. unique data and indxes segment_to_indices = defaultdict(list) for i, seg in enumerate(sorted_segments): segment_to_indices[seg].append(i) unique_segments = list(segment_to_indices.keys()) # 2. check in cache or not misses_data = [] results_for_uniques: Dict[bytes, List[int]] = {} ## for each unique segment, check in cache or not for segment in unique_segments: if segment in self.cache: results_for_uniques[segment] = self.cache[segment] else: misses_data.append(segment) hit_count = len(unique_segments) - len(misses_data) logger.info(f"Cache checking: {len(unique_segments)} segments, " f"Get {hit_count}, No caching {len(misses_data)} ") # 3. compress non-caching segments if misses_data: ## keep use original one newly_compressed = self.base_compression_fn( misses_data, *args, **kwargs ) # 4.refresh cache and fill in result for i in range(len(misses_data)): raw_segment = misses_data[i] compressed_result = newly_compressed[i] results_for_uniques[raw_segment] = compressed_result # refresh cache if self.cache_size > 0 and raw_segment not in self.cache: if len(self.cache) >= self.cache_size: if self.cache_policy == 'fifo': oldest_key = self.fifo_queue.popleft() del self.cache[oldest_key] self.cache[raw_segment] = compressed_result self.fifo_queue.append(raw_segment) # 5. rebuild all results all_compressed_results = [None] * M for seg, indices in segment_to_indices.items(): result = results_for_uniques[seg] for original_index in indices: all_compressed_results[original_index] = result return all_compressed_results def __call__(self, *args, **kwargs): return self.compress(*args, **kwargs) def compress_segments_hybrid_arithmetic( sorted_segments: List[bytes], batched_predict_fn: Callable, first_byte_prob: torch.Tensor, max_m1_batch_size: int=4096, debug: bool = True ) -> List[List[int]]: """ 这个函数现在只处理它收到的数据,不需要关心缓存或去重。 这些逻辑已经被外层的 CachingCompressorWrapper 处理了。 """ M = len(sorted_segments) if M == 0: return [] # 注意:这里的 sorted_segments 已经是去重后、未命中缓存的数据了。 logger.info(f"Hybrid AC 核心: 正在处理 {M} 个不重复、未命中缓存的段。") segment_to_compressed = {} ENCODING_BATCH_SIZE = 128 encoder = HybridArithmeticEncoder( batched_predict_fn=batched_predict_fn, first_byte_prob=first_byte_prob ) all_compressed_results = [] for i in range(0, M, ENCODING_BATCH_SIZE): batch_start = i batch_end = min(i + ENCODING_BATCH_SIZE, M) batch_segments = sorted_segments[batch_start:batch_end] try: codes = encoder.batched_encode(batch_segments, return_num_padded_bits=False) # 对比压缩效果 for seg, code in zip(batch_segments, codes): if len(code) < len(seg): all_compressed_results.append(list(code)) else: all_compressed_results.append(list(seg)) # 压缩效果不好,用原始数据 except Exception as e: logger.warning(f"Hybrid AC 核心: 批次 {batch_start}-{batch_end} 编码失败: {e}. 该批次使用原始字节。") for seg in batch_segments: all_compressed_results.append(list(seg)) gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() return all_compressed_results def prepare_segments(batch: List[Dict[str, Any]])->Dict[str,Any]: """ remove the unpack,sort and rerank methods from segment_pre.. address unpack and judge the compressiable simultaneously """ all_segments = [] is_compressible_indicator = [] sample_idx_to_list_segment_idx = defaultdict(list) segment_idx = 0 for sample_idx, item in enumerate(batch): assert "windows_starts_lens_b64" in item, "windows_starts_lens_b64 must be in item" sample_bytes = item["text"].encode('utf-8') byte_windows = unpack_windows(sample_bytes, item["windows_starts_lens_b64"]) for segment,indicator in byte_windows: all_segments.append(segment) is_compressible = (indicator == 1 and len(segment) > 3) is_compressible_indicator.append(is_compressible) # record mapping for sample--segment sample_idx_to_list_segment_idx[sample_idx].append(segment_idx) segment_idx += 1 effective_segments = {} # {index: window} raw_segments_map = {} # {index: window} for i, (segment, is_comp) in enumerate(zip(all_segments, is_compressible_indicator)): if is_comp: effective_segments[i] = segment else: raw_segments_map[i] = segment # rerank by length, reduce padding sorted_indices_to_compress = sorted( effective_segments.keys(), key=lambda idx: len(effective_segments[idx]) ) sorted_segments_to_compress = [effective_segments[idx] for idx in sorted_indices_to_compress] # create reconstruct information -- # 1. mapping sample and segment idx -- in one time unpack # 2. mapping old and new idx sorted_to_original_idx_map = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices_to_compress)} reconstruction_info = { "sample_idx_to_list_segment_idx": sample_idx_to_list_segment_idx, "sorted_to_original_idx_map": sorted_to_original_idx_map, "raw_segments_map": raw_segments_map, "total_segments": len(all_segments), "batch_meta": batch, # meta data "effective_segments_map": effective_segments } return { "sorted_segments_to_compress": sorted_segments_to_compress, "reconstruction_info": reconstruction_info, } # wrap it as the first processing function def simple_rle_topk_compression( batch: List[bytes], predict_fn: Callable, first_byte_prob: torch.Tensor, max_m1_batch_size: int = 4096, debug: bool = True, ): """use language model to compress, return compressed bytes and padded bits Args: sliding_windows: List of byte sequences to compress predict_fn: Function that predicts next token probabilities return_num_padded_bits: Whether to return number of padded bits profile: Whether to print timing information for each major step """ if debug: start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() torch.cuda.synchronize()## make sure all previous events are completed print("[Debug CUDA] time start", flush=True) assert first_byte_prob.shape == (1, 1, ALPHABET_SIZE), "first_byte_prob must be of shape (1, 1, ALPHABET_SIZE)" # refactored batch output window AC: #### 1. pad the current batch batched_windows_np = [np.frombuffer(bytes(data), dtype=np.uint8) for data in batch] M = len(batched_windows_np) batched_repeat_probs = [] batched_ranks = [] batched_lengths = [] if debug: batched_sorted_indices = [] start_idx = 0 while start_idx < M: # Use the new helper function to find the exact range for the next safe batch start_idx, end_idx = find_next_batch_range(batched_windows_np, start_idx, max_m1_batch_size) windows_np_chunked = batched_windows_np[start_idx:end_idx] padded_batched_windows, lengths = pad_batch(windows_np_chunked) padded_batched_windows, lengths = padded_batched_windows.cuda(), lengths.cuda() prompt_probs = predict_fn(padded_batched_windows) prompt_probs = torch.cat( [ first_byte_prob.expand(prompt_probs.shape[0], -1, -1), prompt_probs[:, :-1, ...] ], dim=1 ) prompt_probs = utils.batched_normalize_pdf_for_arithmetic_coding(prompt_probs) ######## Use BatchArithmeticEncoder to replace address one by one ########### # we calculate two quantiles from prompt_probs # 1. the probability of the next byte # 2. the byte ids of the topk next bytes next_token_probs = torch.gather( prompt_probs, dim=-1, index=padded_batched_windows.unsqueeze(-1) ).squeeze(-1) # [B, L] sorted_indices = torch.argsort(prompt_probs, dim=-1, descending=True) rank_bitvector = padded_batched_windows.unsqueeze(-1) == sorted_indices ranks = torch.argmax(rank_bitvector.float(), dim=-1) # [B, L] start_idx = end_idx batched_repeat_probs.extend(next_token_probs.cpu().numpy().tolist()) batched_ranks.extend(ranks.cpu().numpy().tolist()) batched_lengths.extend(lengths.cpu().numpy().tolist()) if debug: batched_sorted_indices.extend(sorted_indices.cpu().numpy().tolist()) if debug: return batched_repeat_probs, batched_ranks, batched_lengths, batched_sorted_indices else: return batched_repeat_probs, batched_ranks, batched_lengths def compress_segments_rank_based( sorted_segments: List[bytes], batched_predict_fn: Callable, first_byte_prob: torch.Tensor, max_m1_batch_size: int=4096, debug: bool = True ) -> List[List[int]]: """ (SimpleAdaptiveRankCodec)。 decompress GPU probs and CPU compression。 """ # --- GPU Stage : acquire probs and rank --- # use original simple_rle_topk_compression try: gpu_result = simple_rle_topk_compression( sorted_segments, batched_predict_fn, first_byte_prob, max_m1_batch_size=max_m1_batch_size, debug=debug, ) if debug: batched_repeat_probs, batched_ranks, batched_lengths, batched_sorted_indices = gpu_result else: batched_repeat_probs, batched_ranks, batched_lengths = gpu_result batched_sorted_indices = None # --- CPU Stage: encoding one by one --- if len(batched_lengths) != len(sorted_segments): logger.error(f"FATAL: Length mismatch after GPU stage. Expected {len(sorted_segments)}, got {len(batched_lengths)}. Falling back to raw data.") # 如果长度不匹配,说明上游出错了,直接返回原始数据 return [list(seg) for seg in sorted_segments] M = len(batched_lengths) batched_compressed_bytes = [] for i in range(M): lengths = batched_lengths[i] window_bytes = sorted_segments[i] repeat_probs = batched_repeat_probs[i][:lengths] ranks = batched_ranks[i][:lengths] codec = SimpleAdaptiveRankCodec(top_k=4) encoding = codec.encode_window(list(window_bytes), repeat_probs, ranks) compressed_bytes = codec.encoding_to_pseudo_bytes(encoding) #Add: compare compress result if len(compressed_bytes) >= len(window_bytes): # use raw bytes replace batched_compressed_bytes.append(list(window_bytes)) else: # compress successfully batched_compressed_bytes.append(compressed_bytes) if debug: # make sure batched_compressed_bytes[i] is a list if batched_sorted_indices is None or batched_sorted_indices[i] is None: logger.warning(f"Debug mode is on but sorted_indices for segment {i} is None. Skipping decode check.") continue #Add: valid encode and decode rather than pseudo_bytes -> encoding sorted_indices = batched_sorted_indices[i][:lengths] decoded = codec.decode_window(encoding, lengths, sorted_indices) assert bytes(decoded) == window_bytes, "decoded does not match window_bytes: \n{} and \n{}".format(decoded, window_bytes) if i < 10: logger.info(f"Example input window bytes: {window_bytes}") logger.info(f"Example encoding : {encoding}") logger.info(f"Example compressed bytes : {compressed_bytes}") # batched_compressed_bytes.append(compressed_bytes) 这里重复会导致keyerror return batched_compressed_bytes except Exception as e: logger.error(f"Unhandled exception in compress_segments_rank_based: {e}. Falling back to raw data for the entire batch.", exc_info=True) # if any error back to original bytes return [list(seg) for seg in sorted_segments] def compress_segments_arithmetic( sorted_segments: List[bytes], batched_predict_fn: Callable, first_byte_prob: torch.Tensor, max_m1_batch_size: int = 4096, debug: bool = True ) -> List[List[int]]: """ Final robust version for arithmetic compression. This version is inspired by successful production code and is designed to be stable. It compresses unique segments in small, manageable batches and handles failures gracefully. """ device = first_byte_prob.device M = len(sorted_segments) if M == 0: return [] # --- 1. unique --- logger.info(f"Step 1: Identifying unique segments to compress.") # 创建从原始段到其所有出现位置的映射--这里是为unique准备 # mapping original <-> position segment_to_indices = defaultdict(list) for i, seg in enumerate(sorted_segments): segment_to_indices[seg].append(i) # only valid segment unique_segments = [seg for seg in segment_to_indices.keys() if len(seg) > 2] logger.info(f"Found {len(unique_segments)} unique segments (len>2) out of {M} total segments.") # store only compressed result segment_to_compressed = {} # --- 2. safe encoding --- ENCODING_BATCH_SIZE = 128 encoder = BatchedArithmeticEncoder(base=ARITHMETIC_CODER_BASE, precision=ARITHMETIC_CODER_PRECISION) logger.info(f"Step 2: Encoding unique segments in batches of size {ENCODING_BATCH_SIZE}.") for i in range(0, len(unique_segments), ENCODING_BATCH_SIZE): batch_start = i batch_end = min(i + ENCODING_BATCH_SIZE, len(unique_segments)) batch_unique_segments = unique_segments[batch_start:batch_end] # for each small batch try: # prepare for bytes--padding batch_padded_segments, batch_lengths = pad_batch(batch_unique_segments) batch_padded_segments = batch_padded_segments.to(device) batch_lengths = batch_lengths.to(device) # batch_predict_fn with torch.no_grad(): # 净化输入,防止模型侧的错误 safe_padded_segments = batch_padded_segments.clamp(0, ALPHABET_SIZE - 1) probs = batched_predict_fn(safe_padded_segments) # # NOTE : normalize probs to avoid NaN/Inf # if not torch.isfinite(probs).all(): # logger.warning(f"NaN/Inf detected in model probabilities for batch {i//ENCODING_BATCH_SIZE}. Clamping.") # probs = torch.nan_to_num(probs, nan=1e-9, posinf=1.0, neginf=1e-9) # probs = probs / probs.sum(dim=-1, keepdim=True) final_probs = torch.cat([first_byte_prob.expand(probs.shape[0], -1, -1), probs[:, :-1, ...]], dim=1) normalized_probs = utils.batched_normalize_pdf_for_arithmetic_coding(final_probs) if not torch.isfinite(normalized_probs).all(): raise ValueError("NaN or Inf in normalized probabilities after normalization.") # batch_encode codes, _ = encoder.batched_encode( normalized_probs, batch_padded_segments, lengths=batch_lengths, return_num_padded_bits=True ) # store result for seg, code in zip(batch_unique_segments, codes): segment_to_compressed[seg] = list(code) except Exception as e: # fail,fallback to original data logger.warning(f"Batch encoding failed for unique segments {batch_start}-{batch_end}: {e}. Using raw bytes for this batch.") for seg in batch_unique_segments: segment_to_compressed[seg] = list(seg) # clean step gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() # --- 3. reconstruct all result --- logger.info("Step 3: Reconstructing final list from unique compressed segments.") all_compressed_results = [None] * M for seg, indices in segment_to_indices.items(): if len(seg) <= 2: # short segments result = list(seg) else: # from mapping to get results compressed_data = segment_to_compressed.get(seg, list(seg)) # not find,use original data if len(compressed_data) >= len(seg): result = list(seg) # back to origin else: result = compressed_data # fill back original data for original_index in indices: all_compressed_results[original_index] = result return all_compressed_results def compress_segments_hybrid_arithmetic( sorted_segments: List[bytes], batched_predict_fn: Callable, first_byte_prob: torch.Tensor, max_m1_batch_size: int=4096, debug: bool = True ) -> List[List[int]]: """ GPU and CPU hybrid version for arithmetic compression. """ M = len(sorted_segments) if M == 0: return [] logger.info("Step 1: Identifying unique segments to compress.") device = first_byte_prob.device segment_to_indices = defaultdict(list) for i, seg in enumerate(sorted_segments): segment_to_indices[seg].append(i) unique_segments = [seg for seg in segment_to_indices.keys() if len(seg) > 2] logger.info(f"Found {len(unique_segments)} unique segments (len>2) out of {M} total segments.") # store only compressed result segment_to_compressed = {} # --- 2. safe encoding --- ENCODING_BATCH_SIZE = 128 encoder = HybridArithmeticEncoder( batched_predict_fn=batched_predict_fn, first_byte_prob=first_byte_prob ) logger.info(f"Step 2: Encoding unique segments in batches of size {ENCODING_BATCH_SIZE}.") for i in range(0, len(unique_segments), ENCODING_BATCH_SIZE): batch_start = i batch_end = min(i + ENCODING_BATCH_SIZE, len(unique_segments)) batch_unique_segments = unique_segments[batch_start:batch_end] try: if debug: # in debug pattetn, get padded_bits for validation codes, padded_bits = encoder.batched_encode(batch_unique_segments, return_num_padded_bits=True) decoded_tensor = encoder.batched_decode(codes, padded_bits, batch_unique_segments) for j, original_seg_bytes in enumerate(batch_unique_segments): original_len = len(original_seg_bytes) decoded_bytes = bytes(decoded_tensor[j, :original_len].cpu().tolist()) assert decoded_bytes == original_seg_bytes, f"Hybrid decode mismatch for segment!" else: codes = encoder.batched_encode(batch_unique_segments, return_num_padded_bits=False) # store results for seg, code in zip(batch_unique_segments, codes): segment_to_compressed[seg] = list(code) except Exception as e: logger.warning(f"Batch encoding failed for unique segments {batch_start}-{batch_end}: {e}. Using raw bytes for this batch.") for seg in batch_unique_segments: segment_to_compressed[seg] = list(seg) gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() logger.info("Step 3: Reconstructing final list from unique compressed segments.") all_compressed_results = [None] * M for seg, indices in segment_to_indices.items(): if len(seg) <= 2: result = list(seg) else: compressed_data = segment_to_compressed.get(seg, list(seg)) if len(compressed_data) >= len(seg): result = list(seg) else: result = compressed_data for original_index in indices: all_compressed_results[original_index] = result return all_compressed_results def reconstruct_results( compressed_map: Dict[int, List[int]], reconstruction_info: Dict[str, Any], debug: bool = True ) -> List[Dict[str, Any]]: """ Reconstruct the original results from the compressed results. Need and compressed ratio and assert the reconstruction is correct. """ sample_idx_to_list_segment_idx = reconstruction_info["sample_idx_to_list_segment_idx"] raw_segments_map = reconstruction_info["raw_segments_map"] batch_meta = reconstruction_info["batch_meta"] #Add: valid reconstruction sorted_to_original_idx_map = reconstruction_info["sorted_to_original_idx_map"] # mapping original_idx -> compressed_data original_idx_to_compressed_data = { v: compressed_map[k] for k, v in sorted_to_original_idx_map.items() if k in compressed_map } write_results = [] ac_key = "m1_enumerative" # compute compress ratio total_original_bytes = 0 total_compressed_pseudo_bytes = 0 for sample_idx,item in enumerate(batch_meta): final_pseudo_bytes = [] if debug: reconstructed_original_segments = [] segment_indices_for_sample = sample_idx_to_list_segment_idx.get(sample_idx, []) for original_idx in segment_indices_for_sample: if original_idx in original_idx_to_compressed_data: # compressed bytes compressed_data = original_idx_to_compressed_data[original_idx] final_pseudo_bytes.extend(compressed_data) if debug: #Add: assert encode and decode total_compressed_pseudo_bytes += len(compressed_data) original_segment_bytes = reconstruction_info["effective_segments_map"][original_idx] reconstructed_original_segments.append(original_segment_bytes) total_original_bytes += len(original_segment_bytes) elif original_idx in raw_segments_map: # raw bytes raw_data = raw_segments_map[original_idx] final_pseudo_bytes.extend(list(raw_data)) if debug: total_compressed_pseudo_bytes += len(raw_data) reconstructed_original_segments.append(raw_data) total_original_bytes += len(raw_data) else: # Case 3: wrong both not exist logger.error(f"FATAL LOGIC ERROR: Segment with original_idx {original_idx} does not exist in effective_segments_map or raw_segments_map!") # try from effective_segments_map original_segment_bytes = reconstruction_info["effective_segments_map"].get(original_idx) if original_segment_bytes: final_pseudo_bytes.extend(list(original_segment_bytes)) packed_bytes = pseudo_to_packed_bytes(final_pseudo_bytes) result = { **item, "m1_compressed_data": base64.b64encode(packed_bytes).decode("ascii") } write_results.append(result) #Add: assert encode and decode if debug and reconstructed_original_segments: original_sample_bytes = item["text"].encode('utf-8') reconstructed_sample_bytes = b"".join(reconstructed_original_segments) assert reconstructed_sample_bytes == original_sample_bytes, \ f"Sample {sample_idx} reconstruction failed!" # check pack and unpack unpacked_pseudo_bytes = packed_bytes_to_pseudo(packed_bytes) assert unpacked_pseudo_bytes == final_pseudo_bytes, \ f"Pseudo-bytes packing/unpacking round-trip failed for sample {sample_idx}" #Add: copmute compress ratio if debug and total_original_bytes > 0: compression_ratio = total_compressed_pseudo_bytes / total_original_bytes logger.info(f"Batch compression stats: " f"Original bytes: {total_original_bytes}, " f"Compressed pseudo-bytes: {total_compressed_pseudo_bytes}, " f"Ratio: {compression_ratio:.4f}") # this tab wrong return write_results def writer_consumer(write_queue, output_file, buffer_size=100,debug=True): write_buf = [] try: with open(output_file, 'w', encoding='utf-8') as f: while True: payload = write_queue.get() if payload is None: break # get result from reconstruct results write_results = reconstruct_results( payload["compressed_map"], payload["reconstruction_info"], debug=debug ) write_buf.extend(write_results) # clean the complex expression of before segmentation # Write buffer when it's full if len(write_buf) >= buffer_size: logger.info(f"Writer: Dumping buffer of {len(write_buf)} items to {output_file}") for buffered_item in write_buf: f.write(json.dumps(buffered_item) + '\n') f.flush() write_buf = [] # Write remaining items in buffer if write_buf: logger.info(f"Writer: Dumping remaining {len(write_buf)} items to {output_file}") for buffered_item in write_buf: f.write(json.dumps(buffered_item) + '\n') f.flush() except Exception as e: logger.error(f"Writer process error: {e}") raise def merge_output_files(output_file, writer_output_files): """Merge all writer output files into a single file""" logger.info(f"Merging {len(writer_output_files)} writer files into {output_file}") with open(output_file, 'w', encoding='utf-8') as outf: for writer_output_file in writer_output_files: if writer_output_file.exists(): with open(writer_output_file, 'r', encoding='utf-8') as inf: for line in inf: outf.write(line) # Optionally remove the individual writer files writer_output_file.unlink() logger.info(f"Merged and removed writer file: {writer_output_file}") logger.info(f"Merged output written to: {output_file}") return output_file def shutdown_writers(write_queue, writer_processes): """Send shutdown signals to shared queue and wait for all writers to complete""" # Send one sentinel per writer to ensure all writers get the shutdown signal for i in range(len(writer_processes)): write_queue.put(None) logger.info(f"Sent shutdown signal {i+1}/{len(writer_processes)}") # Wait for all writers to complete for i, writer_process in enumerate(writer_processes): writer_process.join() if writer_process.exitcode != 0: logger.error(f"Writer process {i} failed with exit code: {writer_process.exitcode}") else: logger.info(f"Writer process {i} completed successfully") def main_processor_fn( batch: List[Dict[str, Any]], compression_fn: Callable, # <-- 传入一个压缩函数作为参数! predict_fn: Callable, first_byte_prob: torch.Tensor, max_m1_batch_size: int, debug: bool = True ): # 1. preparing data prep_data = prepare_segments(batch) sorted_segments = prep_data["sorted_segments_to_compress"] reconstruction_info = prep_data["reconstruction_info"] # 2. compress data if sorted_segments: #Add: Time consume start_time = time.time() compressed_pseudo_bytes = compression_fn( sorted_segments, predict_fn, first_byte_prob, max_m1_batch_size, debug ) end_time = time.time() duration = end_time - start_time logger.info( f"Compressed {len(sorted_segments)} segments " f"in {duration:.4f} seconds ({len(sorted_segments)/duration if duration > 0 else float('inf'):.2f} segments/sec)." ) # create a mapping from origal idx to compressed results #sorted_to_original_idx_map = reconstruction_info["sorted_to_original_idx_map"] compressed_map = { i: data for i, data in enumerate(compressed_pseudo_bytes) } # compressed_map: key--sorted i; value--compressed bytes # but below sorted_to_original_idx_map[i] is original idx.. # compressed_map = { # sorted_to_original_idx_map[i]: data # for i, data in enumerate(compressed_pseudo_bytes) # } else: compressed_map = {} # 3. pack result to consumer payload = { "compressed_map": compressed_map, "reconstruction_info": reconstruction_info } return payload def main(): # Set up argument parser parser = argparse.ArgumentParser(description='Process JSONL files using M1 arithmetic compression with buffer-based approach') parser.add_argument('--input_file', type=str, required=True, help='Directory containing input JSONL files') parser.add_argument('--output_dir', type=str, required=True, help='Directory to write compressed results') parser.add_argument('--entropy_model_path', type=str, required=True, help='Path to the M1 model checkpoint') parser.add_argument('--compression_model_path', type=str, required=True, help='Path to the M1 model checkpoint') parser.add_argument('--compressor', type=str, default='rank_based', choices=['rank_based', 'arithmetic','hybrid_arithmetic'], help='Choose the compression algorithm.') parser.add_argument('--data_batch_size', type=int, default=512, help='Size of batches for processing (default: 512)') parser.add_argument('--output_window_size', type=int, default=16, help='Size of window for compression (default: 16)') parser.add_argument('--max_window_size', type=int, default=1024, help='Maximum window size for reading from each file (default: 1024)') parser.add_argument('--max_entropy_batch_size', type=int, default=4096, help='Size of max batch for compression (default: 4096)') parser.add_argument('--max_compression_batch_size', type=int, default=4096, help='Size of max batch for compression (default: 4096)') parser.add_argument('--chunk_size', type=int, default=512, help='Size of chunk for compression (default: 512)') parser.add_argument('--base_global_quantile', type=float, default=0.9, help='Base global quantile for compression (default: 0.9)') parser.add_argument('--base_monotonic_quantile', type=float, default=0.9, help='Base monotonic quantile for compression (default: 0.9)') parser.add_argument('--debug', action='store_true', default=True, help='Debug mode (default: False)') parser.add_argument('--firstbyte_prob_path', type=str, default=None, help='Probability path for the first word of each window (default : None)') parser.add_argument('--num_workers', type=int, default=1, help='Number of workers for CPU jobs (default: 1)') parser.add_argument('--process_id', type=int, default=0, help='Process ID for distributed processing (default: 0)') parser.add_argument('--num_processes', type=int, default=1, help='Number of processes for distributed processing (default: 1)') parser.add_argument('--merge_output', action='store_true', default=False, help='Merge all writer output files into a single file (default: False)') # adding cache parameters parser.add_argument('--use_global_cache', action='store_true', default=True, help='Enable the global compression cache.') parser.add_argument('--cache_size', type=int, default=819200, help='Size of the global compression cache.') args = parser.parse_args() # choose compression algorithm if args.compressor == 'rank_based': compression_algorithm = compress_segments_rank_based elif args.compressor == 'arithmetic': compression_algorithm = compress_segments_arithmetic elif args.compressor == 'hybrid_arithmetic': compression_algorithm = compress_segments_hybrid_arithmetic else: raise ValueError(f"Unknown compressor: {args.compressor}") logger.info(f"Using compression algorithm: {compression_algorithm.__name__}") # use wrapper to make cache for each algorithm if args.use_global_cache: caching_wrapper = CachingCompressorWrapper( base_compression_fn=compression_algorithm, cache_size=args.cache_size ) # use cache compression_algorithm_to_use = caching_wrapper logger.info("Global cache start....") else: # no cache compression_algorithm_to_use = compression_algorithm logger.info("No Global cache ...") mp.set_start_method('spawn', force=True) gc_freq = 100 dump_freq = 25 # Create output directory if it doesn't exist output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) # Load model and tokenizer model, _, _ = load_m1_model_and_tokenizer(args.entropy_model_path) batched_predict_fn = batched_m1_compress_predict_fn(model) if args.firstbyte_prob_path is not None: with open(args.firstbyte_prob_path, 'r', encoding='utf-8') as f: first_byte_prob = json.load(f) print(first_byte_prob) first_byte_prob = torch.tensor(first_byte_prob, dtype=torch.float32, device="cuda").unsqueeze(0).unsqueeze(0) else: first_byte_prob = torch.ones((1, 1, ALPHABET_SIZE), dtype=torch.float32, device="cuda") / ALPHABET_SIZE # Create dataset and dataloader dataset = InterleavedJsonlDataset( file_path=args.input_file, rank=args.process_id, world_size=args.num_processes, ) dataloader = DataLoader( dataset, batch_size=args.data_batch_size, shuffle=False, collate_fn=lambda x: x ) input_file = Path(args.input_file) logger.info(f"Processing file: {input_file}") output_file = output_dir / f"{input_file.stem}_out_{args.process_id}.jsonl" logger.info("Data loaded. Start processing...") write_queue = mp.Queue(maxsize=200) writer_processes = [] writer_output_files = [] for i in range(args.num_workers): # Create unique output file for each writer output_path = Path(output_file) writer_output_file = output_path.parent / f"{output_path.stem}_writer_{i}.jsonl" writer_output_files.append(writer_output_file) writer_process = mp.Process( target=writer_consumer, args=(write_queue, writer_output_file, dump_freq,args.debug) ) writer_processes.append(writer_process) writer_process.start() logger.info(f"Started writer process {i} for output file: {writer_output_file}") try: # Process each batch for batch_idx, batch in enumerate(dataloader): payload_for_writer = main_processor_fn( batch, compression_algorithm_to_use, # compressor with cache # compression_algorithm, # 把选择的算法传进去 batched_predict_fn, first_byte_prob, args.max_compression_batch_size, args.debug, ) logger.info(f"Processed batch {batch_idx}") write_queue.put(payload_for_writer) if batch_idx % gc_freq == 0: # Clean up GPU memory gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() # Signal completion to all writer processes shutdown_writers(write_queue, writer_processes) except Exception as e: logger.error(f"Error during processing: {e}") # Try to terminate writer processes cleanly try: shutdown_writers(write_queue, writer_processes) except: pass raise if args.merge_output: final_output_file = merge_output_files(output_file, writer_output_files) logger.info(f"Completed processing successfully, merged output written to {final_output_file}") else: logger.info(f"Completed processing successfully, outputs written to {args.num_workers} separate files") if __name__ == "__main__": main()