Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions gemma/flash_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ float HWY_INLINE SingleFlashAttentionRowVector(DF df, VF& x, float& old_max,
template <class DF, typename T = hn::TFromD<DF>,
class DF4 = hn::CappedTag<T, 4>, class VF4 = hn::Vec<DF4>,
class VF = hn::Vec<DF>, typename F>
static VF4 HWY_INLINE Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
static HWY_INLINE VF4 Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
F reducer) {
const DF4 df4;
constexpr size_t kMaxLanes = hn::MaxLanes(df);
Expand All @@ -469,7 +469,7 @@ static VF4 HWY_INLINE Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,

// Handles Up to 4 Q rows by NF*2 timesteps of flash attention.
template <int kNumQueries, class DF, class VF = hn::Vec<DF>>
static void HWY_INLINE FlashAttentionTileStepAndApplySoftCap(
static HWY_INLINE void FlashAttentionTileStepAndApplySoftCap(
DF df, float att_cap, float one_over_att_cap, VF& x_0_p0, VF& x_0_p1,
VF& x_1_p0, VF& x_1_p1, VF& x_2_p0, VF& x_2_p1, VF& x_3_p0, VF& x_3_p1,
float* HWY_RESTRICT old_max, float* HWY_RESTRICT old_d,
Expand Down
Loading