# Copyright(C) [2025] Advanced Micro Devices, Inc. All rights reserved.

import numpy as np

from .base import BaseMetric


class PassK(BaseMetric):
    """
    Pass@k metric.

    Estimates the probability that at least one of k samples, drawn without
    replacement from n generated samples of which c are correct, is correct.
    """

    def __init__(self):
        super().__init__(name="Pass@k")

    def compute(self, n: int, c: int, k: int) -> float:
        """
        Compute the Pass@k metric.

        The score is ``1 - C(n - c, k) / C(n, k)``, evaluated as a running
        product so that no large binomial coefficients are formed.

        Args:
            n (int): The total number of generated instances.
            c (int): The number of correct instances among the n.
            k (int): The number of instances to consider.

        Returns:
            float: The Pass@k score, in the range [0.0, 1.0] for valid
                inputs (0 <= c <= n, 1 <= k <= n).
        """
        # With no correct instance the score must be 0.0. This has to be
        # checked before the shortcut below, which would otherwise report
        # 1.0 whenever k > n even though nothing passed.
        if c <= 0:
            return 0.0

        # Fewer than k incorrect instances exist, so any choice of k
        # instances necessarily contains a correct one.
        if n - c < k:
            return 1.0

        # prod_{i = n-c+1 .. n} (1 - k / i) == C(n - c, k) / C(n, k)
        denominators = np.arange(n - c + 1, n + 1)
        all_fail_probability = np.prod(1.0 - k / denominators)

        # Convert the numpy scalar so both return paths yield a plain float.
        return float(1.0 - all_fail_probability)