import json
import base64
import argparse
import os
import re
from collections import defaultdict
from itertools import combinations
from typing import List, Dict, Any, Tuple

# before running: pip install python-Levenshtein matplotlib seaborn numpy tqdm
# version_1: directory-level analysis, complete matches

try:
    from tqdm import tqdm
except ImportError:
    print("⚠️ 'tqdm' is not installed; progress bars are disabled. Run: pip install tqdm")
    def tqdm(iterable, *args, **kwargs):
        return iterable

try:
    import Levenshtein
except ImportError:
    print("❌ Error: 'python-Levenshtein' is not installed")
    print("Run: pip install python-Levenshtein")
    exit(1)

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


def vread(buf: bytes, i: int):
    """Read one varint from buf starting at index i; return (value, next_index)."""
    shift = val = 0
    while True:
        b = buf[i]
        i += 1
        val |= (b & 0x7F) << shift
        if b < 0x80:
            return val, i
        shift += 7


def unpack_windows(input_bytes: bytes, b64_stream: str) -> List[Tuple[bytes, int]]:
    """Split input_bytes into (chunk, indicator) segments; indicator 1 marks a compressed window, 0 a raw gap."""
    try:
        buf, i, cursor, byte_windows = base64.b64decode(b64_stream), 0, 0, []
        while i < len(buf):
            gap, i = vread(buf, i)
            size, i = vread(buf, i)
            start = cursor + gap
            if gap > 0:
                byte_windows.append((input_bytes[cursor:start], 0))
            end = start + size
            byte_windows.append((input_bytes[start:end], 1))
            cursor = end
        if cursor < len(input_bytes):
            byte_windows.append((input_bytes[cursor:], 0))
        return byte_windows
    except (base64.binascii.Error, IndexError):
        return []


def packed_bytes_to_pseudo(b: bytes) -> list[int]:
    """Unpack a little-endian bit stream of 9-bit pseudo tokens from raw bytes."""
    out, acc, bits = [], 0, 0
    for byte in b:
        acc |= byte << bits
        bits += 8
        while bits >= 9:
            out.append(acc & 0x1FF)
            acc >>= 9
            bits -= 9
    return out
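

# NOTE: the helper below is an illustrative addition, not part of the original
# pipeline. It sketches the inverse of packed_bytes_to_pseudo under the same
# assumption (9-bit tokens packed little-endian into bytes), which is handy for
# a quick round-trip sanity check of the unpacking logic. The name
# pseudo_to_packed_bytes is hypothetical.
def pseudo_to_packed_bytes(tokens: List[int]) -> bytes:
    out, acc, bits = bytearray(), 0, 0
    for t in tokens:
        acc |= (t & 0x1FF) << bits  # append 9 bits per token, LSB first
        bits += 9
        while bits >= 8:
            out.append(acc & 0xFF)
            acc >>= 8
            bits -= 8
    if bits:
        out.append(acc & 0xFF)  # flush the final partial byte (zero-padded)
    return bytes(out)
# Example round trip: packed_bytes_to_pseudo(pseudo_to_packed_bytes([1, 2, 355]))
# returns [1, 2, 355]; the zero-padded flush never yields a spurious token
# because fewer than 9 padding bits remain.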


# decode the base64 window header into parallel (starts, lens) integer lists
def decompress_windows_starts_lens(b64_stream: str) -> tuple[list[int], list[int]]:
    try:
        buf = base64.b64decode(b64_stream)
        i = 0
        cursor = 0
        starts, lens = [], []
        while i < len(buf):
            gap, i = vread(buf, i)
            size, i = vread(buf, i)
            start = cursor + gap
            length = size
            starts.append(start)
            lens.append(length)
            cursor = start + length
        return starts, lens
    except (base64.binascii.Error, IndexError):
        # decode failed
        return [], []


def parse_parameters_from_path(file_path: str) -> dict:
    """Parse run parameters from the directory (or file) name."""
    params = {}
    base_name = os.path.basename(os.path.normpath(file_path))
    parts = base_name.split('_')
    for part in parts:
        if '-' in part:
            # key-value, e.g., "iterative-true"
            key, value = part.split('-', 1)
            params[key.lower()] = value.lower()
        else:
            # keyvalue, e.g., "ow20"
            # use re.match so we only match from the start of the part
            match = re.match(r'([a-zA-Z]+)(\d+)', part)
            if match:
                key, value = match.groups()
                params[key.lower()] = value
    params['bits_per_compressed'] = 10
    print(f"Parsed params from path '{base_name}': {params}")
    return params


def construct_compression_key(params: dict) -> str:
    """Construct the JSON key under which the compressed data is stored."""
    ow = params.get('ow', 20)
    escape_fb = 'True' if params.get('escapefb', 'false') == 'true' else 'False'
    iterative = 'True' if params.get('iterative', 'true') == 'true' else 'False'
    force_padding = 'True' if params.get('forcepadding', 'false') == 'true' else 'False'
    key = f"m1_ac_ow{ow}_escapefb-{escape_fb}_iterative-{iterative}_forcepadding-{force_padding}"
    print(f"Compressed data key is: '{key}'")
    return key
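

# Illustrative trace (an added note, derived from the two functions above, using
# the sample directory name from the usage comment at the bottom of this file):
#   parse_parameters_from_path(".../..._entropy90_splits_chunk512_ow20_iterative-true_forcepadding-true_merged_ac")
#     -> {'entropy': '90', 'chunk': '512', 'ow': '20', 'iterative': 'true',
#         'forcepadding': 'true', 'bits_per_compressed': 10}
#   construct_compression_key(<that dict>)
#     -> 'm1_ac_ow20_escapefb-False_iterative-True_forcepadding-True'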


def analyze_token_collisions_in_directory(input_dir: str, output_dir: str, compression_offset: int = 256,
                                           max_files: int = -1, max_lines: int = -1):
    """Analyze token-level collisions across every .jsonl file in a directory."""
    if not os.path.isdir(input_dir):
        print(f"❌ Error: input directory is not valid: '{input_dir}'")
        return

    # derive run parameters from the directory name
    params = parse_parameters_from_path(input_dir)
    if 'ow' not in params:
        print(f"❌ Error: could not parse 'ow' from '{input_dir}'")
        return
    compression_bit_threshold = params['ow']
    bits_per_compressed = params['bits_per_compressed']
    compression_key = construct_compression_key(params)
    print(f"Compression key is: '{compression_key}'")
    print(f"Params: compression_bit_threshold={compression_bit_threshold}, bits_per_compressed={bits_per_compressed}")

    # collect all .jsonl files
    jsonl_files = []
    for root, _, files in os.walk(input_dir):
        for file in files:
            if file.endswith('.jsonl') and file.startswith('ocp'):
                jsonl_files.append(os.path.join(root, file))
    if not jsonl_files:
        print(f"❌ Error: no .jsonl files found under '{input_dir}'")
        return
    print(f"🔍 Found {len(jsonl_files)} .jsonl files to process")

    # small-batch debugging
    if max_files > 0:
        jsonl_files = jsonl_files[:max_files]  # take only the first max_files files
        print(f"🔍 Small-batch mode: processing only {len(jsonl_files)} files, at most {max_lines} lines each")

    # global mapping
    # token_to_raw_map = defaultdict(list)
    sequence_to_raw_map = defaultdict(list)  # token tuple -> list of raw chunks
    total_lines = 0
    total_processed = 0
    total_mismatches = 0
    key_not_found_count = 0
    decode_errors = 0
    total_failed = 0

    print("🚀 Processing all files, building the global token sequence -> raw_chunk_list map...")
    for file_path in tqdm(jsonl_files, desc="Processing files"):
        with open(file_path, 'r', errors='ignore') as f:
            line_nums = 0
            for line in f:
                line_nums += 1
                if max_lines > 0 and line_nums > max_lines:
                    print(f"📌 File {os.path.basename(file_path)}: reached {max_lines} lines, stopping")
                    break
                total_lines += 1
                try:
                    data = json.loads(line)
                    if compression_key not in data or not data[compression_key]:
                        if key_not_found_count == 0:
                            print("\n\n--- Debug info: key mismatch ---")
                            print(f"Constructed key: '{compression_key}'")
                            print(f"Keys available in the JSON: {list(data.keys())}")
                            print("---------------------------------")
                        key_not_found_count += 1
                        continue
                    required_keys = [compression_key, 'text', 'windows_starts_lens_b64', 'pseudo_lens_per_segment']
                    if not all(k in data and data[k] for k in required_keys):
                        print("⚠️ Some required key is missing or empty; skipping line")
                        continue

                    # 1. parse the packed compressed stream into mixed pseudo tokens
                    b64_decoded_bytes = base64.b64decode(data[compression_key])
                    mixed_pseudo_bytes = packed_bytes_to_pseudo(b64_decoded_bytes)

                    # 2. unpack_windows to split the original text into segments
                    raw_text_bytes = data['text'].encode('utf-8')
                    all_segments = unpack_windows(raw_text_bytes, data['windows_starts_lens_b64'])
                    pseudo_lens = data["pseudo_lens_per_segment"]
                    if len(pseudo_lens) != len(all_segments):
                        raise ValueError("Metadata length mismatch between pseudo_lens and all_segments")

                    # compressed stream: 1 2 3 355 356 1 2 3   (values >= 256 are compressed tokens)
                    # raw bytes:         1 2 3 17 18 19 1 2 3
                    # 3. walk a pointer through the pseudo stream to locate each compressed window
                    pseudo_ptr = 0
                    for i in range(len(all_segments)):
                        raw_chunk, indicator = all_segments[i]
                        segment_pseudo_len = pseudo_lens[i]
                        if pseudo_ptr >= len(mixed_pseudo_bytes):
                            raise ValueError("Pseudo bytes stream exhausted prematurely.")
                        if indicator == 0:
                            # raw gap: just skip its pseudo bytes
                            pseudo_ptr += segment_pseudo_len
                        elif indicator == 1:
                            # compressed window
                            # single-token variant (kept for reference):
                            # token = mixed_pseudo_bytes[pseudo_ptr]
                            # if token < 256:
                            #     raise ValueError(f"Expected a compressed token (>=256), but got {token}")
                            # token_to_raw_map[token].append(raw_chunk)
                            # pseudo_ptr += 1
                            if pseudo_ptr + segment_pseudo_len > len(mixed_pseudo_bytes):
                                raise ValueError("Pseudo bytes stream exhausted for a window.")
                            # raw bytes of this window
                            current_raw_bytes = list(raw_chunk)
                            # mixed pseudo tokens of this window
                            current_pseudo_sequence = mixed_pseudo_bytes[pseudo_ptr: pseudo_ptr + segment_pseudo_len]
                            # trim the shared literal tail (trailing values < 256 that equal the raw tail)
                            while current_pseudo_sequence and current_pseudo_sequence[-1] < 256:
                                last_pseudo = current_pseudo_sequence.pop()
                                last_raw = current_raw_bytes.pop()
                                # the trailing literal must match the raw byte
                                assert last_pseudo == last_raw, "Mismatch in raw tail"
                            # after trimming: record the mapping
                            if current_pseudo_sequence:
                                pure_token_sequence = tuple(current_pseudo_sequence)
                                pure_raw_chunk = bytes(current_raw_bytes)
                                sequence_to_raw_map[pure_token_sequence].append(pure_raw_chunk)
                            # advance the pointer
                            pseudo_ptr += segment_pseudo_len
                    total_processed += 1
                except Exception:
                    total_failed += 1
                    continue
                # We cannot recover every window from starts/lengths alone, because only the
                # compressed windows are recorded; the start positions plus a cursor are
                # needed to skip the raw bytes. Legacy single-token approach, kept for reference:
                # pseudo_tokens = [t for t in mixed_pseudo_bytes if t >= 256]
                # # checksum
                # if len(starts) != len(pseudo_tokens):
                #     total_mismatches += 1
                #     continue
                # # mapping
                # raw_text_bytes = data['text'].encode('utf-8')
                # for i, token in enumerate(pseudo_tokens):
                #     start, length = starts[i], lens[i]
                #     if start + length <= len(raw_text_bytes):
                #         raw_chunk = raw_text_bytes[start: start + length]
                #         token_to_raw_map[token].append(raw_chunk)
                # except (json.JSONDecodeError, TypeError, KeyError, base64.binascii.Error, struct.error):
                #     if total_lines % 100000 == 0:
                #         print(f"⚠️ Error while processing line {total_lines}: {e}")
                #     continue  # skip problematic lines

    print(f"✅ All files processed. Total: {total_lines:,} lines")
    if key_not_found_count > 0:
        print(f"   {key_not_found_count:,} lines skipped because the key did not match.")
    if total_mismatches > 0:
        print(f"   {total_mismatches:,} lines skipped because window and token counts did not match.")
    if decode_errors > 0:
        print(f"   {decode_errors:,} lines skipped because of decode errors.")
    print(f"   Global mapping finished: {len(sequence_to_raw_map):,} unique token sequences found.")

    print("\n🔍 Starting token-level collision analysis...")
    analysis_results = []
    all_distances = []

    # analyze every token sequence that maps to more than one distinct raw chunk
    for token_sequence, raw_chunks_list in tqdm(sequence_to_raw_map.items(), desc="Analyzing collisions"):
        unique_raw_chunks = list(set(raw_chunks_list))
        if len(unique_raw_chunks) > 1:
            raw_strings = [c.decode('utf-8', 'replace') for c in unique_raw_chunks]
            distances = []
            for str1, str2 in combinations(raw_strings, 2):
                dist = Levenshtein.distance(str1, str2)
                distances.append(dist)
            all_distances.extend(distances)
            analysis_results.append({
                "colliding_token_sequence": list(token_sequence),
                "num_raw_variants": len(unique_raw_chunks),
                "raw_chunk_variants": raw_strings,
                "levenshtein_analysis": {
                    "distances": distances,
                    "average_distance": np.mean(distances) if distances else 0,
                    "max_distance": max(distances) if distances else 0,
                    "min_distance": min(distances) if distances else 0,
                }
            })

    print(f"✅ Finished analysis. Found {len(analysis_results):,} colliding token sequences in total.")
    if not analysis_results:
        print("🎉 Congratulations! No token-level collisions found.")
        return

    os.makedirs(output_dir, exist_ok=True)
    output_json_path = os.path.join(output_dir, "token_collision_report.json")
    analysis_results.sort(key=lambda x: x['levenshtein_analysis']['average_distance'], reverse=True)
    with open(output_json_path, 'w', encoding='utf-8') as f:
        json.dump(analysis_results, f, indent=2, ensure_ascii=False)
    print(f"\n💾 Analysis report saved to: {output_json_path}")

    print("\n📋 Samples (sorted by average distance):")
    for i, result in enumerate(analysis_results[:5]):
        print("-" * 20)
        print(f"Sample {i + 1}:")
        print(f"  Colliding token sequence: {result['colliding_token_sequence']}")
        print(f"  Maps to {result['num_raw_variants']} distinct raw chunks")
        print(f"  Avg distance: {result['levenshtein_analysis']['average_distance']:.2f}")
        print(f"  Raw 1: {repr(result['raw_chunk_variants'][0][:80])}")
        print(f"  Raw 2: {repr(result['raw_chunk_variants'][1][:80])}")

    output_plot_path = os.path.join(output_dir, "token_collision_levenshtein_distribution.png")
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(12, 7))
    if all_distances:
        sns.histplot(all_distances, bins=max(50, min(len(set(all_distances)), 100)), kde=False, ax=ax)
        stats_text = (f"Total Colliding Pairs: {len(all_distances):,}\n"
                      f"Mean Distance: {np.mean(all_distances):.2f}\n"
                      f"Median Distance: {np.median(all_distances):.2f}\n"
                      f"Max Distance: {np.max(all_distances):,}")
        ax.text(0.95, 0.95, stats_text, transform=ax.transAxes, fontsize=10,
                verticalalignment='top', horizontalalignment='right',
                bbox=dict(boxstyle='round,pad=0.5', fc='wheat', alpha=0.5))
    else:
        ax.text(0.5, 0.5, "No collisions found.", transform=ax.transAxes, fontsize=15,
                verticalalignment='center', horizontalalignment='center')
    ax.set_title('Levenshtein Distance between Raw Chunks with Same Compressed Token (Entire Dataset)', fontsize=14)
    ax.set_xlabel('Levenshtein Distance', fontsize=12)
    ax.set_ylabel('Frequency (Number of Pairs)', fontsize=12)
    ax.set_yscale('log')
    plt.tight_layout()
    plt.savefig(output_plot_path)
    print(f"📊 Levenshtein distance plot saved to: {output_plot_path}")
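

# Illustrative helper (an addition, not called anywhere in this script): reload
# a report written by analyze_token_collisions_in_directory and print headline
# counts, assuming the JSON layout produced above. The name summarize_report is
# hypothetical.
def summarize_report(report_path: str) -> None:
    with open(report_path, 'r', encoding='utf-8') as f:
        results = json.load(f)
    # each entry stores one Levenshtein distance per colliding raw-chunk pair
    pair_count = sum(len(r["levenshtein_analysis"]["distances"]) for r in results)
    print(f"{len(results):,} colliding token sequences, {pair_count:,} raw-chunk pairs in {report_path}")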


# Usage:
#   python analysis_dynamic_dis.py /mnt/hdfs/linzheng/data/ocpython_subsampled_50G_entropy90_splits_chunk512_ow20_iterative-true_forcepadding-true_merged_ac
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Check for token-level compression collisions across a directory",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("input_dir", type=str, help="Input directory containing the .jsonl data.")
    parser.add_argument("-o", "--output_dir", type=str, default="analysis_output_token_collision",
                        help="Directory to store the output.")
    parser.add_argument("--max_files", type=int, default=-1, help="Maximum number of files to process (-1 = all).")
    parser.add_argument("--max_lines", type=int, default=-1, help="Maximum number of lines per file (-1 = all).")
    args = parser.parse_args()

    analyze_token_collisions_in_directory(
        args.input_dir, args.output_dir,
        max_files=args.max_files, max_lines=args.max_lines)
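
# Added note: an illustrative small-batch debug invocation using the flags
# defined above (<input_dir> is a placeholder for a directory like the sample
# path in the usage comment):
#   python analysis_dynamic_dis.py <input_dir> -o analysis_output_token_collision --max_files 2 --max_lines 1000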