# presidio_recognizer_wrapper.py
# Forked from microsoft/presidio-research.
from typing import List
from presidio_analyzer import EntityRecognizer
from presidio_analyzer.nlp_engine import NlpEngine
from presidio_evaluator import InputSample
from presidio_evaluator.models import BaseModel
from presidio_evaluator.span_to_tag import span_to_tag
class PresidioRecognizerWrapper(BaseModel):
    """Evaluation wrapper around a single Presidio ``EntityRecognizer``.

    The wrapper runs one recognizer over the text of an ``InputSample`` and
    converts the detected spans into one tag per token via ``span_to_tag``.
    """

    def __init__(
        self,
        recognizer: EntityRecognizer,
        nlp_engine: NlpEngine,
        entities_to_keep: List[str] = None,
        labeling_scheme: str = "BILUO",
        with_nlp_artifacts: bool = False,
        verbose: bool = False,
    ):
        """
        Evaluator for one specific PII recognizer.

        To evaluate the entire set of recognizers, refer to
        PresidioAnalyzerWrapper.

        :param recognizer: An object of type EntityRecognizer
        (in presidio-analyzer)
        :param nlp_engine: An object of type NlpEngine, e.g. SpacyNlpEngine
        (in presidio-analyzer)
        :param entities_to_keep: List of entity types to focus on while
        ignoring all the rest. Default=None would look at all entity types
        :param labeling_scheme: Tagging scheme handed to ``span_to_tag``
        when spans are converted to token tags. Default is "BILUO"
        :param with_nlp_artifacts: Whether NLP artifacts should be obtained
        (faster if not, but some recognizers need it)
        :param verbose: Forwarded unchanged to the base model constructor
        """
        super().__init__(
            entities_to_keep=entities_to_keep,
            verbose=verbose,
            labeling_scheme=labeling_scheme,
        )
        self.with_nlp_artifacts = with_nlp_artifacts
        self.recognizer = recognizer
        self.nlp_engine = nlp_engine

    def __make_nlp_artifacts(self, text: str):
        """Return the NLP engine's output for ``text``.

        :param text: Raw text to process
        """
        # NOTE(review): the language is hard-coded to English ("en"), so
        # recognizers that rely on NLP artifacts are evaluated as English only.
        return self.nlp_engine.process_text(text, "en")

    def predict(self, sample: InputSample) -> List[str]:
        """Tag the tokens of ``sample`` using the wrapped recognizer.

        Side effect: when ``sample.tags`` is empty (or ``None``) it is filled
        with one "O" (outside) placeholder per predicted tag.

        :param sample: Sample providing ``full_text``, ``tokens`` and ``tags``
        :return: The tags returned by ``span_to_tag`` for the detected spans
        """
        nlp_artifacts = None
        if self.with_nlp_artifacts:
            # Only pay for the NLP pipeline when the recognizer needs it.
            nlp_artifacts = self.__make_nlp_artifacts(sample.full_text)

        # ``self.entities`` is not assigned in this class; presumably the base
        # model derives it from ``entities_to_keep`` - verify against BaseModel.
        results = self.recognizer.analyze(
            sample.full_text, self.entities, nlp_artifacts
        )

        starts = []
        ends = []
        tags = []
        scores = []
        for res in results:
            # A missing start offset is treated as the beginning of the text.
            # The value is normalised locally instead of being written back
            # onto the recognizer's result object.
            starts.append(res.start or 0)
            ends.append(res.end)
            tags.append(res.entity_type)
            scores.append(res.score)

        response_tags = span_to_tag(
            scheme=self.labeling_scheme,
            text=sample.full_text,
            start=starts,
            end=ends,
            tag=tags,
            tokens=sample.tokens,
            scores=scores,
        )

        # Samples without gold tags get an all-outside annotation so they line
        # up with the prediction. The outside tag is the letter "O" (as in
        # BILUO/BIO/IO), not the digit "0".
        if not sample.tags:
            sample.tags = ["O" for _ in response_tags]

        return response_tags