Skip to content

Commit

Permalink
inference: remove unused _validate_args function (#5505)
Browse files Browse the repository at this point in the history
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
  • Loading branch information
3 people authored Jan 7, 2025
1 parent f2cc809 commit c7f3032
Showing 1 changed file with 0 additions and 24 deletions.
24 changes: 0 additions & 24 deletions deepspeed/inference/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def __init__(self, model, config):
self.mp_group = config.tensor_parallel.tp_group
self.mpu = config.tensor_parallel.mpu

#self._validate_args(self.mpu, config.replace_with_kernel_inject)
self.quantize_merge_count = 1
self.quantization_scales = None

Expand Down Expand Up @@ -300,29 +299,6 @@ def _init_quantization_setting(self, quantization_setting):
f"mlp_extra_grouping = {self.mlp_extra_grouping}, "
f"quantize_groups = {self.quantize_groups}", [0])

# TODO: remove this function and add this functionality to pydantic config checking
def _validate_args(self, mpu, replace_with_kernel_inject):
# TODO: to support SD pipeline we need to avoid this check for now
if replace_with_kernel_inject and not isinstance(self.module, Module):
raise ValueError(f"model must be a torch.nn.Module, got {type(self.module)}")
if not isinstance(self._config.tensor_parallel.tp_size, int) or self._config.tensor_parallel.tp_size < 1:
raise ValueError(f"mp_size must be an int >= 1, got {self._config.tensor_parallel.tp_size}")

if mpu:
methods = ["get_model_parallel_group", "get_data_parallel_group"]
for method in methods:
if not hasattr(mpu, method):
raise ValueError(f"mpu is missing {method}")
if self._config.checkpoint is not None and not isinstance(self._config.checkpoint, (str, dict)):
raise ValueError(f"checkpoint must be None, str or dict, got {type(self._config.checkpoint)}")

supported_dtypes = [None, torch.half, torch.int8, torch.float, torch.bfloat16]
if self._config.dtype not in supported_dtypes:
raise ValueError(f"{self._config.dtype} not supported, valid dtype: {supported_dtypes}")

if self.injection_dict is not None and not isinstance(self.injection_dict, dict):
raise ValueError(f"injection_dict must be None or a dict, got: {self.injection_dict}")

def load_model_with_checkpoint(self, r_module):
self.mp_replace = ReplaceWithTensorSlicing(
mp_group=self.mp_group, mp_size=self._config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1)
Expand Down

0 comments on commit c7f3032

Please sign in to comment.