Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support split qkv linear and sp overlap comm #415

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
12 changes: 9 additions & 3 deletions megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def validate_args(args, defaults={}):

if args.ds_sequence_parallel_size > 1:
assert version.parse(deepspeed.__version__) >= version.parse("0.10.2"), "sequence parallelism requires DeepSpeed version 0.10.2+"

if args.ds_sequence_parallel_overlap_comm:
assert args.split_qkv_linear, \
"ds_sequence_parallel_overlap_comm requires split_qkv_linear is True"
model_parallel_size = args.pipeline_model_parallel_size * \
args.tensor_model_parallel_size * \
args.ds_sequence_parallel_size
Expand Down Expand Up @@ -924,6 +926,9 @@ def _add_training_args(parser):
group.add_argument('--disable-moe-top2-2nd-expert-sampling', action='store_false',
help='Disable MoE top2 sampling of the 2nd expert. Instead of sampling, use argmax.',
dest='moe_top2_2nd_expert_sampling')
group.add_argument('--split-qkv-linear', action='store_true',
help='Separate linear computations for query, key, and value.',
dest='split_qkv_linear')
group.add_argument('--use-flash-attn', '--use-flash-attn-v1', dest='use_flash_attn_v1', action='store_true',
help='use first version FlashAttention implementation of attention. '
'https://arxiv.org/abs/2205.14135')
Expand Down Expand Up @@ -975,14 +980,15 @@ def _add_training_args(parser):
help='Enable DeepSpeed\'s sequence parallel. Cannot be combined with "--sequence-parallel", which enables Megatron-LM\'s sequence parallel.')
group.add_argument('--force-ds-sequence-parallel', action='store_true',
help='use DeepSpeed sequence parallelism regardless of sequence parallel size.')

group.add_argument('--ds-sequence-parallel-overlap-comm', action='store_true',
help='overlap comm for ds-sequence-parallel',
dest='ds_sequence_parallel_overlap_comm')
group.add_argument('--ds-sequence-parallel-fpdt', action='store_true',
help='use DeepSpeed sequence parallelism with FPDT.')
group.add_argument('--ds-sequence-parallel-fpdt-chunk-size', type=int, default=65536,
help='Chunk size used in FPDT attention.')
group.add_argument('--ds-sequence-parallel-fpdt-offloading', action='store_true',
help='use DeepSpeed sequence parallelism FPDT with offloading.')

group.add_argument('--no-gradient-accumulation-fusion',
action='store_false',
help='Disable fusing gradient accumulation to weight '
Expand Down
34 changes: 24 additions & 10 deletions megatron/core/tensor_parallel/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from megatron import get_args

from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.tensor_parallel.weight_grad_store import WeightGradStore

from megatron.core.parallel_state import (
get_tensor_model_parallel_rank,
Expand Down Expand Up @@ -248,13 +249,14 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
async_grad_allreduce, sequence_parallel):
async_grad_allreduce, sequence_parallel, bwd_stream=None):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce
ctx.sequence_parallel = sequence_parallel

ctx.bwd_stream = bwd_stream

if sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
Expand Down Expand Up @@ -314,6 +316,7 @@ def backward(ctx, grad_output):
total_input = all_gather_buffer
else:
total_input = input

grad_input = grad_output.matmul(weight)

if ctx.sequence_parallel:
Expand Down Expand Up @@ -368,23 +371,30 @@ def backward(ctx, grad_output):
# grad_weight = None
# else:
# grad_weight = grad_output.t().matmul(total_input)
if args.enable_zbh1_pipeline:
from megatron.core.tensor_parallel.weight_grad_store import WeightGradStore

if ctx.bwd_stream is not None:
# for sp overlap communication
ctx.bwd_stream.wait_stream(get_accelerator().current_stream())
with get_accelerator().stream(ctx.bwd_stream):
WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction)
grad_weight = None
elif args.enable_zbh1_pipeline:
WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)

grad_bias = grad_output.sum(dim=0) if use_bias else None

if ctx.bwd_stream is not None:
total_input.record_stream(ctx.bwd_stream)
grad_output.record_stream(ctx.bwd_stream)
if ctx.sequence_parallel:
handle.wait()
return sub_grad_input, grad_weight, grad_bias, None, None, None

if ctx.async_grad_allreduce:
handle.wait()

return grad_input, grad_weight, grad_bias, None, None, None
return grad_input, grad_weight, grad_bias, None, None, None, None

def linear_with_grad_accumulation_and_async_allreduce(
input: torch.Tensor,
Expand All @@ -393,6 +403,7 @@ def linear_with_grad_accumulation_and_async_allreduce(
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel: bool,
async_sp_all2all_stream=None
) -> torch.Tensor:
"""Linear layer execution with asynchronous communication and
gradient accumulation fusion in backprop.
Expand Down Expand Up @@ -453,6 +464,7 @@ def linear_with_grad_accumulation_and_async_allreduce(
gradient_accumulation_fusion,
async_grad_allreduce,
sequence_parallel,
async_sp_all2all_stream
]

if not linear_with_grad_accumulation_and_async_allreduce.warned:
Expand Down Expand Up @@ -607,7 +619,6 @@ def __init__(self, input_size, output_size, *,
"cannot be enabled at the same time."
)


def forward(self,
input_: torch.Tensor,
weight: Optional[torch.Tensor] = None):
Expand Down Expand Up @@ -706,9 +717,10 @@ def __init__(self, input_size: int, output_size: int, *,
stride: int = 1,
keep_master_weight_for_test: bool = False,
skip_bias_add: bool = False,
moe=False, enable_expert_tensor_parallelism=False):
moe=False, enable_expert_tensor_parallelism=False, ds_sp_async_stream=None):
torch.nn.Module.__init__(self)

self.ds_sp_async_stream = ds_sp_async_stream

# Keep input parameters
self.input_size = input_size
self.output_size = output_size
Expand Down Expand Up @@ -784,13 +796,15 @@ def forward(self, input_):
assert not self.sequence_parallel
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.

output_parallel = linear_with_grad_accumulation_and_async_allreduce(
input=input_parallel,
weight=self.weight,
bias=None,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=False,
sequence_parallel=False,
async_sp_all2all_stream=self.ds_sp_async_stream
)

# All-reduce across all the partitions.
Expand Down
106 changes: 79 additions & 27 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,14 @@ class ParallelAttention(MegatronModule):
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""

sp_stream=None

def get_sp_stream(self):
if not self.ds_sp_overlap:
return None
if ParallelAttention.sp_stream is None:
ParallelAttention.sp_stream=get_accelerator().Stream()
return ParallelAttention.sp_stream
def __init__(self, config, layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding):
Expand All @@ -524,7 +531,8 @@ def __init__(self, config, layer_number,
self.num_attention_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.use_gqa = (self.num_attention_heads != self.num_key_value_heads)

self.split_qkv = args.split_qkv_linear
self.ds_sp_overlap = args.ds_sequence_parallel_overlap_comm
self.use_flash_attn = (args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2 or \
args.use_flash_attn_builder) \
and attention_type == AttnType.self_attn \
Expand Down Expand Up @@ -577,13 +585,31 @@ def __init__(self, config, layer_number,

# Strided linear layer.
if attention_type == AttnType.self_attn:
self.query_key_value = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
projection_size + 2 * kv_projection_size,
config=config,
init_method=config.init_method,
bias=args.add_bias_linear,
gather_output=False)
if not self.split_qkv:
self.query_key_value = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
projection_size + 2 * kv_projection_size,
config=config,
init_method=config.init_method,
bias=args.add_bias_linear,
gather_output=False)

else:
linear_configs = [
("query_linear", projection_size),
("key_linear", kv_projection_size),
("value_linear", kv_projection_size),
]

for attr_name, output_size in linear_configs:
setattr(self, attr_name, tensor_parallel.ColumnParallelLinear(
config.hidden_size,
output_size,
config=config,
init_method=config.init_method,
bias=args.add_bias_linear,
gather_output=False
))
else:
assert attention_type == AttnType.cross_attn
self.query = tensor_parallel.ColumnParallelLinear(
Expand Down Expand Up @@ -614,12 +640,14 @@ def __init__(self, config, layer_number,
self.enable_ds_sequence_parallel = parallel_state.get_sequence_parallel_world_size() > 1 \
or args.force_ds_sequence_parallel
if self.enable_ds_sequence_parallel:

assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version'
assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0

self.dist_attn = DistributedAttention(
local_attn,
parallel_state.get_sequence_parallel_group(),
gather_idx=1 if args.use_flash_attn_v1 or args.use_flash_attn_v2 else 0)
gather_idx=1 if args.use_flash_attn_v1 or args.use_flash_attn_v2 else 0,sp_stream=self.get_sp_stream())
# flash_attn_cuda assumes [b, s, nh, hd] layout, we need to make sure all2all gathers into the correct sequence dimension.
else:
if self.use_flash_attn:
Expand All @@ -636,7 +664,9 @@ def __init__(self, config, layer_number,
init_method=config.output_layer_init_method,
bias=args.add_bias_linear,
input_is_parallel=True,
skip_bias_add=True)
skip_bias_add=True,
ds_sp_async_stream=self.get_sp_stream()
)


def _checkpointed_attention_forward(self, query_layer, key_layer,
Expand Down Expand Up @@ -722,22 +752,41 @@ def forward(self, hidden_states, attention_mask,
# Query, Key, and Value
# =====================
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)

if self.enable_ds_sequence_parallel:
assert self.projection_size == self.kv_projection_size
seq_len, bs = mixed_x_layer.shape[0], mixed_x_layer.shape[1]
query_layer = mixed_x_layer[:, :, :self.projection_size].reshape(seq_len, bs, -1, self.head_dim)
key_layer = mixed_x_layer[:, :, self.projection_size:self.projection_size+self.kv_projection_size].reshape(seq_len, bs, -1, self.head_dim)
value_layer = mixed_x_layer[:, :, self.projection_size+self.kv_projection_size:].reshape(seq_len, bs, -1, self.head_dim)
if self.sequence_parallel or not self.enable_ds_sequence_parallel:
seq_len, bs = mixed_x_layer.shape[0], mixed_x_layer.shape[1]
each_hidden_size = mixed_x_layer.shape[-1] // 3
query_layer = mixed_x_layer[:, :, :each_hidden_size].reshape(seq_len, bs, -1, self.head_dim)
key_layer = mixed_x_layer[:, :, each_hidden_size:each_hidden_size+each_hidden_size].reshape(seq_len, bs, -1, self.head_dim)
value_layer = mixed_x_layer[:, :, each_hidden_size+each_hidden_size:].reshape(seq_len, bs, -1, self.head_dim)

if not self.split_qkv:
# Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)

if self.enable_ds_sequence_parallel:
assert self.projection_size == self.kv_projection_size
seq_len, bs = mixed_x_layer.shape[0], mixed_x_layer.shape[1]
query_layer = mixed_x_layer[:, :, :self.projection_size].reshape(seq_len, bs, -1, self.head_dim)
key_layer = mixed_x_layer[:, :, self.projection_size:self.projection_size+self.kv_projection_size].reshape(seq_len, bs, -1, self.head_dim)
value_layer = mixed_x_layer[:, :, self.projection_size+self.kv_projection_size:].reshape(seq_len, bs, -1, self.head_dim)
if self.sequence_parallel or not self.enable_ds_sequence_parallel:
seq_len, bs = mixed_x_layer.shape[0], mixed_x_layer.shape[1]
each_hidden_size = mixed_x_layer.shape[-1] // 3
query_layer = mixed_x_layer[:, :, :each_hidden_size].reshape(seq_len, bs, -1, self.head_dim)
key_layer = mixed_x_layer[:, :, each_hidden_size:each_hidden_size+each_hidden_size].reshape(seq_len, bs, -1, self.head_dim)
value_layer = mixed_x_layer[:, :, each_hidden_size+each_hidden_size:].reshape(seq_len, bs, -1, self.head_dim)
else:
assert self.ds_sp_overlap, """
Currently, the split_qkv operation is only applicable
when ds_sp_overlap is enabled.
"""
self.get_sp_stream().wait_stream(get_accelerator().current_stream())
with get_accelerator().stream(self.get_sp_stream()):
query_layer,_ = self.query_linear(hidden_states)
query_layer=query_layer.reshape(query_layer.shape[0],query_layer.shape[1],self.num_attention_heads,-1)
fwd_query_layer_done_event = get_accelerator().Event()
fwd_query_layer_done_event.record(self.get_sp_stream())
key_layer,_ = self.key_linear(hidden_states)
key_layer=key_layer.reshape(key_layer.shape[0],key_layer.shape[1],self.num_attention_heads,-1)

fwd_key_layer_done_event = get_accelerator().Event()
fwd_key_layer_done_event.record(self.get_sp_stream())
value_layer,_ = self.value_linear(hidden_states)
value_layer=value_layer.reshape(value_layer.shape[0],value_layer.shape[1],self.num_attention_heads,-1)

# Repeat kv
if self.use_gqa:
key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
Expand Down Expand Up @@ -833,6 +882,9 @@ def forward(self, hidden_states, attention_mask,
# value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

if self.enable_ds_sequence_parallel:
if self.ds_sp_overlap:
key_layer.done_event=fwd_key_layer_done_event
query_layer.done_event=fwd_query_layer_done_event
batch_dim_idx = 1
if self.use_flash_attn:
if not self.use_flash_attn_triton:
Expand Down