jeuko commited on
Commit
7638cbd
·
verified ·
1 Parent(s): 94a0f4c

Sync from GitHub (main)

Browse files
src/sentinel/risk_models/base.py CHANGED
@@ -2,7 +2,7 @@
2
 
3
  import re
4
  from abc import ABC, abstractmethod
5
- from typing import Any
6
 
7
  from pydantic import TypeAdapter, ValidationError
8
 
@@ -13,7 +13,7 @@ from sentinel.user_input import UserInput
13
  class RiskModel(ABC):
14
  """Base class for risk models."""
15
 
16
- REQUIRED_INPUTS: dict[str, tuple[type, bool]] = {}
17
 
18
  def __init__(self, name: str):
19
  self.name = name
@@ -120,14 +120,54 @@ class RiskModel(ABC):
120
  # Validate against type and constraints if value present
121
  if value is not None:
122
  try:
123
- adapter = TypeAdapter(field_type)
124
  adapter.validate_python(value)
125
  except ValidationError as e:
126
- error_msg = e.errors()[0]["msg"]
 
 
 
 
 
 
 
 
127
  errors.append(f"Field '{field_path}': {error_msg}")
128
 
129
  return (len(errors) == 0, errors)
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  def run(self, user: UserInput) -> RiskScore:
132
  """Compute all public fields and return a `RiskScore` summary.
133
 
 
2
 
3
  import re
4
  from abc import ABC, abstractmethod
5
+ from typing import Any, Literal, get_args, get_origin
6
 
7
  from pydantic import TypeAdapter, ValidationError
8
 
 
13
  class RiskModel(ABC):
14
  """Base class for risk models."""
15
 
16
+ REQUIRED_INPUTS: dict[str, tuple[Any, bool]] = {}
17
 
18
  def __init__(self, name: str):
19
  self.name = name
 
120
  # Validate against type and constraints if value present
121
  if value is not None:
122
  try:
123
+ adapter: TypeAdapter[Any] = TypeAdapter(field_type)
124
  adapter.validate_python(value)
125
  except ValidationError as e:
126
+ # Check if this is a Literal type with enum values for better error messages
127
+ if self._is_literal_enum_type(field_type):
128
+ allowed_values = self._extract_literal_enum_values(field_type)
129
+ if len(allowed_values) == 1:
130
+ error_msg = f"must be {allowed_values[0]}"
131
+ else:
132
+ error_msg = f"must be one of {', '.join(allowed_values)}"
133
+ else:
134
+ error_msg = e.errors()[0]["msg"]
135
  errors.append(f"Field '{field_path}': {error_msg}")
136
 
137
  return (len(errors) == 0, errors)
138
 
139
+ def _is_literal_enum_type(self, field_type: type) -> bool:
140
+ """Check if a field type is a Literal containing enum values.
141
+
142
+ Args:
143
+ field_type: The type to check
144
+
145
+ Returns:
146
+ True if the type is a Literal containing enum values
147
+ """
148
+ origin = get_origin(field_type)
149
+ if origin is not Literal:
150
+ return False
151
+
152
+ args = get_args(field_type)
153
+ # Check if all arguments are enum values (have a __class__ attribute and are instances of Enum)
154
+ return all(
155
+ hasattr(arg, "__class__") and hasattr(arg, "name") and hasattr(arg, "value")
156
+ for arg in args
157
+ )
158
+
159
+ def _extract_literal_enum_values(self, field_type: type) -> list[str]:
160
+ """Extract enum value names from a Literal type.
161
+
162
+ Args:
163
+ field_type: The Literal type containing enum values
164
+
165
+ Returns:
166
+ List of enum value names (e.g., ['FEMALE'] or ['WHITE', 'BLACK', 'ASIAN'])
167
+ """
168
+ args = get_args(field_type)
169
+ return [arg.name for arg in args]
170
+
171
  def run(self, user: UserInput) -> RiskScore:
172
  """Compute all public fields and return a `RiskScore` summary.
173
 
src/sentinel/risk_models/boadicea.py CHANGED
@@ -7,7 +7,7 @@ The model is specifically designed for women with genetic predispositions,
7
  particularly BRCA1 and BRCA2 mutation carriers.
8
  """
9
 
10
- from typing import Annotated
11
 
12
  from pydantic import Field
13
 
@@ -25,7 +25,7 @@ class BOADICEARiskModel(RiskModel):
25
 
26
  REQUIRED_INPUTS: dict[str, tuple[type, bool]] = {
27
  "demographics.age_years": (Annotated[int, Field(ge=18, le=100)], True),
28
- "demographics.sex": (Sex, True),
29
  "demographics.ethnicity": (Ethnicity | None, False),
30
  "demographics.anthropometrics.height_cm": (
31
  Annotated[float, Field(gt=0)],
 
7
  particularly BRCA1 and BRCA2 mutation carriers.
8
  """
9
 
10
+ from typing import Annotated, Literal
11
 
12
  from pydantic import Field
13
 
 
25
 
26
  REQUIRED_INPUTS: dict[str, tuple[type, bool]] = {
27
  "demographics.age_years": (Annotated[int, Field(ge=18, le=100)], True),
28
+ "demographics.sex": (Literal[Sex.FEMALE], True),
29
  "demographics.ethnicity": (Ethnicity | None, False),
30
  "demographics.anthropometrics.height_cm": (
31
  Annotated[float, Field(gt=0)],
src/sentinel/risk_models/extended_pbcg.py CHANGED
@@ -4,7 +4,7 @@ import json
4
  from functools import lru_cache
5
  from math import exp, log
6
  from pathlib import Path
7
- from typing import Annotated
8
 
9
  from pydantic import Field
10
 
@@ -65,7 +65,7 @@ class ExtendedPBCGRiskModel(RiskModel):
65
 
66
  REQUIRED_INPUTS: dict[str, tuple[type, bool]] = {
67
  "demographics.age_years": (Annotated[int, Field(ge=40, le=90)], True),
68
- "demographics.sex": (Sex, True),
69
  "demographics.ethnicity": (Ethnicity | None, False),
70
  "clinical_tests.psa": (PSATest, True),
71
  "clinical_tests.prostate_volume": (ProstateVolumeTest, False),
 
4
  from functools import lru_cache
5
  from math import exp, log
6
  from pathlib import Path
7
+ from typing import Annotated, Literal
8
 
9
  from pydantic import Field
10
 
 
65
 
66
  REQUIRED_INPUTS: dict[str, tuple[type, bool]] = {
67
  "demographics.age_years": (Annotated[int, Field(ge=40, le=90)], True),
68
+ "demographics.sex": (Literal[Sex.MALE], True),
69
  "demographics.ethnicity": (Ethnicity | None, False),
70
  "clinical_tests.psa": (PSATest, True),
71
  "clinical_tests.prostate_volume": (ProstateVolumeTest, False),
src/sentinel/risk_models/gail.py CHANGED
@@ -17,7 +17,7 @@ https://dceg.cancer.gov/tools/risk-assessment/bcra.
17
  """
18
 
19
  from math import ceil, exp, log
20
- from typing import Annotated
21
 
22
  from pydantic import Field
23
 
@@ -503,10 +503,20 @@ class GailRiskModel(RiskModel):
503
  def __init__(self):
504
  super().__init__("gail")
505
 
506
- REQUIRED_INPUTS: dict[str, tuple[type, bool]] = {
507
  "demographics.age_years": (Annotated[int, Field(ge=35, le=85)], True),
508
- "demographics.sex": (Sex, True),
509
- "demographics.ethnicity": (Ethnicity | None, False),
 
 
 
 
 
 
 
 
 
 
510
  "female_specific.menstrual.age_at_menarche": (
511
  Annotated[int, Field(ge=7, le=20)],
512
  False,
@@ -732,10 +742,6 @@ class GailRiskModel(RiskModel):
732
  if not is_valid:
733
  raise ValueError(f"Invalid inputs for Gail: {'; '.join(errors)}")
734
 
735
- # Check sex
736
- if user.demographics.sex != Sex.FEMALE:
737
- return "N/A: Gail model is only applicable to female patients."
738
-
739
  # Check female-specific data
740
  if user.female_specific is None:
741
  return "N/A: Missing female-specific information required for Gail."
 
17
  """
18
 
19
  from math import ceil, exp, log
20
+ from typing import Annotated, Any, Literal
21
 
22
  from pydantic import Field
23
 
 
503
  def __init__(self):
504
  super().__init__("gail")
505
 
506
+ REQUIRED_INPUTS: dict[str, tuple[Any, bool]] = {
507
  "demographics.age_years": (Annotated[int, Field(ge=35, le=85)], True),
508
+ "demographics.sex": (Literal[Sex.FEMALE], True),
509
+ "demographics.ethnicity": (
510
+ Literal[
511
+ Ethnicity.WHITE,
512
+ Ethnicity.BLACK,
513
+ Ethnicity.ASIAN,
514
+ Ethnicity.PACIFIC_ISLANDER,
515
+ Ethnicity.HISPANIC,
516
+ ]
517
+ | None,
518
+ False,
519
+ ),
520
  "female_specific.menstrual.age_at_menarche": (
521
  Annotated[int, Field(ge=7, le=20)],
522
  False,
 
742
  if not is_valid:
743
  raise ValueError(f"Invalid inputs for Gail: {'; '.join(errors)}")
744
 
 
 
 
 
745
  # Check female-specific data
746
  if user.female_specific is None:
747
  return "N/A: Missing female-specific information required for Gail."
src/sentinel/risk_models/prostate_mortality.py CHANGED
@@ -7,7 +7,7 @@ methodology to predict mortality outcomes at specified time horizons.
7
  """
8
 
9
  from math import exp, log
10
- from typing import Annotated
11
 
12
  from pydantic import Field
13
 
@@ -29,7 +29,7 @@ class ProstateMortalityRiskModel(RiskModel):
29
 
30
  REQUIRED_INPUTS: dict[str, tuple[type, bool]] = {
31
  "demographics.age_years": (Annotated[int, Field(ge=35, le=95)], True),
32
- "demographics.sex": (Sex, True),
33
  "clinical_tests.psa": (PSATest, True),
34
  "personal_medical_history.prostate_cancer_grade_group": (
35
  Annotated[int, Field(ge=1, le=5)],
 
7
  """
8
 
9
  from math import exp, log
10
+ from typing import Annotated, Literal
11
 
12
  from pydantic import Field
13
 
 
29
 
30
  REQUIRED_INPUTS: dict[str, tuple[type, bool]] = {
31
  "demographics.age_years": (Annotated[int, Field(ge=35, le=95)], True),
32
+ "demographics.sex": (Literal[Sex.MALE], True),
33
  "clinical_tests.psa": (PSATest, True),
34
  "personal_medical_history.prostate_cancer_grade_group": (
35
  Annotated[int, Field(ge=1, le=5)],
src/sentinel/risk_models/tyrer_cuzick.py CHANGED
@@ -1109,7 +1109,7 @@ class TyrerCuzickRiskModel(RiskModel):
1109
 
1110
  REQUIRED_INPUTS: dict[str, tuple[type, bool]] = {
1111
  "demographics.age_years": (Annotated[int, Field(ge=18, le=100)], True),
1112
- "demographics.sex": (Sex, True),
1113
  "demographics.anthropometrics.height_cm": (
1114
  Annotated[float, Field(gt=0)],
1115
  False,
 
1109
 
1110
  REQUIRED_INPUTS: dict[str, tuple[type, bool]] = {
1111
  "demographics.age_years": (Annotated[int, Field(ge=18, le=100)], True),
1112
+ "demographics.sex": (Literal[Sex.FEMALE], True),
1113
  "demographics.anthropometrics.height_cm": (
1114
  Annotated[float, Field(gt=0)],
1115
  False,
tests/test_risk_models/test_base_model_enum_validation.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for enum subset validation in REQUIRED_INPUTS.
2
+
3
+ This module tests the enhanced validation logic that supports restricting
4
+ enum fields to specific subsets using Literal types.
5
+ """
6
+
7
+ from typing import Any, Literal
8
+
9
+ from sentinel.risk_models.base import RiskModel
10
+ from sentinel.user_input import (
11
+ Anthropometrics,
12
+ Demographics,
13
+ Ethnicity,
14
+ Lifestyle,
15
+ PersonalMedicalHistory,
16
+ Sex,
17
+ SmokingHistory,
18
+ SmokingStatus,
19
+ UserInput,
20
+ )
21
+
22
+
23
+ class EnumValidationTestModel(RiskModel):
24
+ """Test risk model with various enum restrictions for validation testing."""
25
+
26
+ def __init__(self):
27
+ super().__init__("test_enum_validation")
28
+
29
+ # Test cases for different enum restriction patterns
30
+ REQUIRED_INPUTS: dict[str, tuple[type | Any, bool]] = {
31
+ # Single enum value restriction
32
+ "demographics.sex": (Literal[Sex.FEMALE], True),
33
+ # Multiple enum value restriction (subset)
34
+ "demographics.ethnicity": (
35
+ Literal[Ethnicity.WHITE, Ethnicity.BLACK, Ethnicity.ASIAN] | None,
36
+ False,
37
+ ),
38
+ }
39
+
40
+ def compute_score(self, user: UserInput) -> str:
41
+ """Test implementation.
42
+
43
+ Args:
44
+ user: The user profile to score.
45
+
46
+ Returns:
47
+ A test score string.
48
+ """
49
+ return "test_score"
50
+
51
+ def cancer_type(self) -> str:
52
+ return "test"
53
+
54
+ def description(self) -> str:
55
+ return "Test model"
56
+
57
+ def interpretation(self) -> str:
58
+ return "Test interpretation"
59
+
60
+ def references(self) -> list[str]:
61
+ return ["Test reference"]
62
+
63
+ def time_horizon_years(self) -> float | None:
64
+ return None
65
+
66
+
67
+ class TestEnumSubsetValidation:
68
+ """Test enum subset validation functionality."""
69
+
70
+ def setup_method(self):
71
+ """Set up test model."""
72
+ self.model = EnumValidationTestModel()
73
+
74
+ def _create_user_input(
75
+ self, sex: Sex, ethnicity: Ethnicity | None = None
76
+ ) -> UserInput:
77
+ """Create a valid UserInput instance for testing.
78
+
79
+ Args:
80
+ sex: The biological sex for the user.
81
+ ethnicity: The ethnicity for the user (optional).
82
+
83
+ Returns:
84
+ A valid UserInput instance for testing.
85
+ """
86
+ return UserInput(
87
+ demographics=Demographics(
88
+ age_years=40,
89
+ sex=sex,
90
+ ethnicity=ethnicity,
91
+ anthropometrics=Anthropometrics(height_cm=165.0, weight_kg=65.0),
92
+ ),
93
+ lifestyle=Lifestyle(
94
+ smoking=SmokingHistory(status=SmokingStatus.NEVER),
95
+ ),
96
+ personal_medical_history=PersonalMedicalHistory(),
97
+ )
98
+
99
+ def test_single_enum_value_restriction_valid(self):
100
+ """Test that valid single enum value passes validation."""
101
+ user = self._create_user_input(Sex.FEMALE, Ethnicity.WHITE)
102
+
103
+ is_valid, errors = self.model.validate_inputs(user)
104
+ assert is_valid
105
+ assert len(errors) == 0
106
+
107
+ def test_single_enum_value_restriction_invalid(self):
108
+ """Test that invalid single enum value fails validation with clear message."""
109
+ user = self._create_user_input(Sex.MALE, Ethnicity.WHITE) # Should be FEMALE
110
+
111
+ is_valid, errors = self.model.validate_inputs(user)
112
+ assert not is_valid
113
+ assert len(errors) == 1
114
+ assert "Field 'demographics.sex': must be FEMALE" in errors[0]
115
+
116
+ def test_multiple_enum_value_restriction_valid(self):
117
+ """Test that valid enum values from subset pass validation."""
118
+ valid_ethnicities = [Ethnicity.WHITE, Ethnicity.BLACK, Ethnicity.ASIAN]
119
+
120
+ for ethnicity in valid_ethnicities:
121
+ user = self._create_user_input(Sex.FEMALE, ethnicity)
122
+
123
+ is_valid, errors = self.model.validate_inputs(user)
124
+ assert is_valid, f"Failed for ethnicity: {ethnicity}"
125
+ assert len(errors) == 0
126
+
127
+ def test_multiple_enum_value_restriction_invalid(self):
128
+ """Test that invalid enum values fail validation with clear message."""
129
+ invalid_ethnicities = [
130
+ Ethnicity.HISPANIC,
131
+ Ethnicity.ASHKENAZI_JEWISH,
132
+ Ethnicity.NATIVE_AMERICAN,
133
+ Ethnicity.PACIFIC_ISLANDER,
134
+ Ethnicity.OTHER,
135
+ Ethnicity.UNKNOWN,
136
+ ]
137
+
138
+ for ethnicity in invalid_ethnicities:
139
+ user = self._create_user_input(Sex.FEMALE, ethnicity)
140
+
141
+ is_valid, errors = self.model.validate_inputs(user)
142
+ assert not is_valid, f"Should have failed for ethnicity: {ethnicity}"
143
+ assert len(errors) == 1
144
+ assert "Field 'demographics.ethnicity': Input should be" in errors[0]
145
+ assert (
146
+ "WHITE" in errors[0] and "BLACK" in errors[0] and "ASIAN" in errors[0]
147
+ )
148
+
149
+ def test_optional_enum_field_with_none(self):
150
+ """Test that None values are handled correctly for optional enum fields."""
151
+ user = self._create_user_input(Sex.FEMALE, None) # Optional field
152
+
153
+ is_valid, errors = self.model.validate_inputs(user)
154
+ assert is_valid
155
+ assert len(errors) == 0
156
+
157
+ def test_missing_required_enum_field(self):
158
+ """Test that missing required enum fields are caught."""
159
+
160
+ # Create a model that requires a field that's not in the user input
161
+ class MissingFieldModel(RiskModel):
162
+ """Test model for missing field validation."""
163
+
164
+ def __init__(self):
165
+ super().__init__("missing_field_test")
166
+
167
+ REQUIRED_INPUTS: dict[str, tuple[Any, bool]] = {
168
+ "demographics.sex": (Literal[Sex.FEMALE], True),
169
+ "demographics.ethnicity": (Ethnicity | None, False),
170
+ "demographics.nonexistent_field": (
171
+ str,
172
+ True,
173
+ ), # This field doesn't exist
174
+ }
175
+
176
+ def compute_score(self, user: UserInput) -> str:
177
+ return "test"
178
+
179
+ def cancer_type(self) -> str:
180
+ return "test"
181
+
182
+ def description(self) -> str:
183
+ return "test"
184
+
185
+ def interpretation(self) -> str:
186
+ return "test"
187
+
188
+ def references(self) -> list[str]:
189
+ return ["test"]
190
+
191
+ def time_horizon_years(self) -> float | None:
192
+ return None
193
+
194
+ model = MissingFieldModel()
195
+ user = self._create_user_input(Sex.FEMALE, Ethnicity.WHITE)
196
+
197
+ is_valid, errors = model.validate_inputs(user)
198
+ assert not is_valid
199
+ assert len(errors) == 1
200
+ assert "Required field 'demographics.nonexistent_field' is missing" in errors[0]
201
+
202
+ def test_multiple_validation_errors(self):
203
+ """Test that multiple validation errors are reported."""
204
+ user = self._create_user_input(Sex.MALE, Ethnicity.HISPANIC) # Both wrong
205
+
206
+ is_valid, errors = self.model.validate_inputs(user)
207
+ assert not is_valid
208
+ assert len(errors) == 2
209
+
210
+ # Check that both errors are present
211
+ error_messages = " ".join(errors)
212
+ assert "must be FEMALE" in error_messages
213
+ assert "Input should be" in error_messages
214
+ assert (
215
+ "WHITE" in error_messages
216
+ and "BLACK" in error_messages
217
+ and "ASIAN" in error_messages
218
+ )
219
+
220
+ def test_literal_enum_type_detection(self):
221
+ """Test the _is_literal_enum_type helper method."""
222
+ # Test Literal with single enum value
223
+ single_literal = Literal[Sex.FEMALE]
224
+ assert self.model._is_literal_enum_type(single_literal)
225
+
226
+ # Test Literal with multiple enum values
227
+ multi_literal = Literal[Ethnicity.WHITE, Ethnicity.BLACK]
228
+ assert self.model._is_literal_enum_type(multi_literal)
229
+
230
+ # Test non-Literal types
231
+ assert not self.model._is_literal_enum_type(Sex)
232
+ assert not self.model._is_literal_enum_type(int)
233
+ assert not self.model._is_literal_enum_type(str)
234
+
235
+ def test_extract_literal_enum_values(self):
236
+ """Test the _extract_literal_enum_values helper method."""
237
+ # Test single enum value
238
+ single_literal = Literal[Sex.FEMALE]
239
+ values = self.model._extract_literal_enum_values(single_literal)
240
+ assert values == ["FEMALE"]
241
+
242
+ # Test multiple enum values
243
+ multi_literal = Literal[Ethnicity.WHITE, Ethnicity.BLACK, Ethnicity.ASIAN]
244
+ values = self.model._extract_literal_enum_values(multi_literal)
245
+ assert set(values) == {"WHITE", "BLACK", "ASIAN"}
246
+
247
+ def test_backward_compatibility_unrestricted_enum(self):
248
+ """Test that unrestricted enum types still work (backward compatibility)."""
249
+
250
+ # Create a model with unrestricted enum
251
+ class UnrestrictedModel(RiskModel):
252
+ """Test model for backward compatibility with unrestricted enums."""
253
+
254
+ def __init__(self):
255
+ super().__init__("unrestricted_test")
256
+
257
+ REQUIRED_INPUTS: dict[str, tuple[type | Any, bool]] = {
258
+ "demographics.sex": (Sex, True),
259
+ "demographics.ethnicity": (Ethnicity | None, False),
260
+ }
261
+
262
+ def compute_score(self, user: UserInput) -> str:
263
+ return "test"
264
+
265
+ def cancer_type(self) -> str:
266
+ return "test"
267
+
268
+ def description(self) -> str:
269
+ return "test"
270
+
271
+ def interpretation(self) -> str:
272
+ return "test"
273
+
274
+ def references(self) -> list[str]:
275
+ return ["test"]
276
+
277
+ def time_horizon_years(self) -> float | None:
278
+ return None
279
+
280
+ model = UnrestrictedModel()
281
+
282
+ # Test with any valid enum values
283
+ user = self._create_user_input(
284
+ Sex.MALE, Ethnicity.HISPANIC
285
+ ) # Any values should work
286
+
287
+ is_valid, errors = model.validate_inputs(user)
288
+ assert is_valid
289
+ assert len(errors) == 0
tests/test_risk_models/test_boadicea_model.py CHANGED
@@ -154,7 +154,14 @@ def test_model_metadata(boadicea_model: BOADICEARiskModel) -> None:
154
  def test_ineligible_patients_return_messages(
155
  boadicea_model: BOADICEARiskModel, user: UserInput, expected: str
156
  ) -> None:
157
- assert boadicea_model.compute_score(user) == expected
 
 
 
 
 
 
 
158
 
159
 
160
  @pytest.mark.parametrize(
 
154
  def test_ineligible_patients_return_messages(
155
  boadicea_model: BOADICEARiskModel, user: UserInput, expected: str
156
  ) -> None:
157
+ # For male patients, validation now raises ValueError instead of returning N/A
158
+ if user.demographics.sex == Sex.MALE:
159
+ with pytest.raises(ValueError) as exc_info:
160
+ boadicea_model.compute_score(user)
161
+ assert "Invalid inputs for BOADICEA" in str(exc_info.value)
162
+ assert "must be FEMALE" in str(exc_info.value)
163
+ else:
164
+ assert boadicea_model.compute_score(user) == expected
165
 
166
 
167
  @pytest.mark.parametrize(
tests/test_risk_models/test_extended_pbcg_model.py CHANGED
@@ -408,10 +408,9 @@ class TestExtendedPBCGRiskModel:
408
  dre=DRETest(result=DREResult.NORMAL),
409
  ),
410
  )
411
- assert (
412
- self.model.compute_score(user)
413
- == "N/A: Extended PBCG is validated for male patients only."
414
- )
415
 
416
  def test_compute_score_invalid_age(self) -> None:
417
  user = UserInput(
 
408
  dre=DRETest(result=DREResult.NORMAL),
409
  ),
410
  )
411
+ # Validation now returns N/A message instead of raising ValueError
412
+ result = self.model.compute_score(user)
413
+ assert result == "N/A: Invalid inputs - Field 'demographics.sex': must be MALE"
 
414
 
415
  def test_compute_score_invalid_age(self) -> None:
416
  user = UserInput(
tests/test_risk_models/test_gail_model.py CHANGED
@@ -286,7 +286,7 @@ class TestGailModel:
286
  assert float(score) > 0
287
 
288
  def test_male_patient_handling(self):
289
- """Test that male patients receive N/A response."""
290
  male_user = UserInput(
291
  demographics=Demographics(
292
  age_years=45,
@@ -299,8 +299,12 @@ class TestGailModel:
299
  personal_medical_history=PersonalMedicalHistory(),
300
  )
301
 
302
- score = self.model.compute_score(male_user)
303
- assert score == "N/A: Gail model is only applicable to female patients."
 
 
 
 
304
 
305
  def test_age_validation(self):
306
  """Test age validation (35-85 range)."""
@@ -365,3 +369,97 @@ class TestGailModel:
365
  assert "1.67" in self.model.interpretation()
366
  assert isinstance(self.model.references(), list)
367
  assert len(self.model.references()) > 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  assert float(score) > 0
287
 
288
  def test_male_patient_handling(self):
289
+ """Test that male patients raise ValueError due to validation failure."""
290
  male_user = UserInput(
291
  demographics=Demographics(
292
  age_years=45,
 
299
  personal_medical_history=PersonalMedicalHistory(),
300
  )
301
 
302
+ # Male patients should now raise ValueError due to validation failure
303
+ with pytest.raises(ValueError) as exc_info:
304
+ self.model.compute_score(male_user)
305
+
306
+ assert "Invalid inputs for Gail" in str(exc_info.value)
307
+ assert "must be FEMALE" in str(exc_info.value)
308
 
309
  def test_age_validation(self):
310
  """Test age validation (35-85 range)."""
 
369
  assert "1.67" in self.model.interpretation()
370
  assert isinstance(self.model.references(), list)
371
  assert len(self.model.references()) > 0
372
+
373
+ def test_male_patient_validation_rejection(self):
374
+ """Test that male patients are rejected during validation."""
375
+ male_user = UserInput(
376
+ demographics=Demographics(
377
+ age_years=40,
378
+ sex=Sex.MALE, # Male patient should be rejected
379
+ ethnicity=Ethnicity.WHITE,
380
+ anthropometrics=Anthropometrics(height_cm=175.0, weight_kg=75.0),
381
+ ),
382
+ lifestyle=Lifestyle(
383
+ smoking=SmokingHistory(status=SmokingStatus.NEVER),
384
+ ),
385
+ personal_medical_history=PersonalMedicalHistory(),
386
+ )
387
+
388
+ # Validation should fail before compute_score is called
389
+ is_valid, errors = self.model.validate_inputs(male_user)
390
+ assert not is_valid
391
+ assert len(errors) == 1
392
+ assert "Field 'demographics.sex': must be FEMALE" in errors[0]
393
+
394
+ def test_male_patient_compute_score_raises_error(self):
395
+ """Test that compute_score raises ValueError for male patients."""
396
+ male_user = UserInput(
397
+ demographics=Demographics(
398
+ age_years=40,
399
+ sex=Sex.MALE,
400
+ ethnicity=Ethnicity.WHITE,
401
+ anthropometrics=Anthropometrics(height_cm=175.0, weight_kg=75.0),
402
+ ),
403
+ lifestyle=Lifestyle(
404
+ smoking=SmokingHistory(status=SmokingStatus.NEVER),
405
+ ),
406
+ personal_medical_history=PersonalMedicalHistory(),
407
+ )
408
+
409
+ # compute_score should raise ValueError due to validation failure
410
+ with pytest.raises(ValueError) as exc_info:
411
+ self.model.compute_score(male_user)
412
+
413
+ assert "Invalid inputs for Gail" in str(exc_info.value)
414
+ assert "must be FEMALE" in str(exc_info.value)
415
+
416
+ def test_ethnicity_restriction_validation(self):
417
+ """Test that unsupported ethnicities are rejected during validation."""
418
+ # Test with unsupported ethnicity
419
+ user = UserInput(
420
+ demographics=Demographics(
421
+ age_years=40,
422
+ sex=Sex.FEMALE,
423
+ ethnicity=Ethnicity.ASHKENAZI_JEWISH, # Not in supported list
424
+ anthropometrics=Anthropometrics(height_cm=165.0, weight_kg=65.0),
425
+ ),
426
+ lifestyle=Lifestyle(
427
+ smoking=SmokingHistory(status=SmokingStatus.NEVER),
428
+ ),
429
+ personal_medical_history=PersonalMedicalHistory(),
430
+ )
431
+
432
+ # Validation should fail
433
+ is_valid, errors = self.model.validate_inputs(user)
434
+ assert not is_valid
435
+ assert len(errors) == 1
436
+ assert "Field 'demographics.ethnicity': Input should be" in errors[0]
437
+ assert "WHITE" in errors[0] and "BLACK" in errors[0] and "ASIAN" in errors[0]
438
+
439
+ def test_supported_ethnicities_pass_validation(self):
440
+ """Test that all supported ethnicities pass validation."""
441
+ supported_ethnicities = [
442
+ Ethnicity.WHITE,
443
+ Ethnicity.BLACK,
444
+ Ethnicity.ASIAN,
445
+ Ethnicity.PACIFIC_ISLANDER,
446
+ Ethnicity.HISPANIC,
447
+ ]
448
+
449
+ for ethnicity in supported_ethnicities:
450
+ user = UserInput(
451
+ demographics=Demographics(
452
+ age_years=40,
453
+ sex=Sex.FEMALE,
454
+ ethnicity=ethnicity,
455
+ anthropometrics=Anthropometrics(height_cm=165.0, weight_kg=65.0),
456
+ ),
457
+ lifestyle=Lifestyle(
458
+ smoking=SmokingHistory(status=SmokingStatus.NEVER),
459
+ ),
460
+ personal_medical_history=PersonalMedicalHistory(),
461
+ )
462
+
463
+ is_valid, errors = self.model.validate_inputs(user)
464
+ assert is_valid, f"Failed for ethnicity: {ethnicity}"
465
+ assert len(errors) == 0