Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

关于 MoE 版本的辅助损失函数. #110

Open
Leiyi-Hu opened this issue Jan 9, 2025 · 6 comments
Open

关于 MoE 版本的辅助损失函数. #110

Leiyi-Hu opened this issue Jan 9, 2025 · 6 comments

Comments

@Leiyi-Hu
Copy link

Leiyi-Hu commented Jan 9, 2025

您好,这里的 aux_loss 看起来并没有被使用?还是通过其他的方式参与了训练呢?
image

@jingyaogong
Copy link
Owner

是的,简单起见这部分loss并没有加入训练😊

@Leiyi-Hu
Copy link
Author

Leiyi-Hu commented Jan 9, 2025

是的,简单起见这部分loss并没有加入训练😊

谢谢!如果需要加入训练,其实现是不是应该将每层的 loss 都存下来和最后的 ce loss 一起进行梯度计算呢?

@jingyaogong
Copy link
Owner

是的,简单起见这部分loss并没有加入训练😊

谢谢!如果需要加入训练,其实现是不是应该将每层的 loss 都存下来和最后的 ce loss 一起进行梯度计算呢?

是的,只需要把每一层的 aux_loss 累加,最后和 logits分类交叉熵loss 相加即可。

@Leiyi-Hu
Copy link
Author

Leiyi-Hu commented Jan 9, 2025

是的,简单起见这部分loss并没有加入训练😊

谢谢!如果需要加入训练,其实现是不是应该将每层的 loss 都存下来和最后的 ce loss 一起进行梯度计算呢?

是的,只需要把每一层的 aux_loss 累加,最后和 logits分类交叉熵loss 相加即可。

明白了,谢谢!另外有一个关于数据预处理问题想请教,
image
这里的 history 为什么截断为 max_length的一半?同时比较疑惑的是 max_length应该是以 token 为单位,这里的 history 看起来是字符串?

@jingyaogong
Copy link
Owner

jingyaogong commented Jan 9, 2025

是的,简单起见这部分loss并没有加入训练😊

谢谢!如果需要加入训练,其实现是不是应该将每层的 loss 都存下来和最后的 ce loss 一起进行梯度计算呢?

是的,只需要把每一层的 aux_loss 累加,最后和 logits分类交叉熵loss 相加即可。

明白了,谢谢!另外有一个关于数据预处理问题想请教, image 这里的 history 为什么截断为 max_length的一半?同时比较疑惑的是 max_length应该是以 token 为单位,这里的 history 看起来是字符串?


也是简单起见
如果在这里用tokenzier去严格统计token数量会增加不必要的时间(当然对后面的input_id做统计也行,只不过更麻烦)
考虑到M的字符串中的token数量一定是 < M 的


为什么是一半?
单轮对话的时候希望Q、A至多各占一半字符,超过部分直接简单粗暴的截断

@Leiyi-Hu
Copy link
Author

Leiyi-Hu commented Jan 9, 2025

是的,简单起见这部分loss并没有加入训练😊

谢谢!如果需要加入训练,其实现是不是应该将每层的 loss 都存下来和最后的 ce loss 一起进行梯度计算呢?

是的,只需要把每一层的 aux_loss 累加,最后和 logits分类交叉熵loss 相加即可。

明白了,谢谢!另外有一个关于数据预处理问题想请教, image 这里的 history 为什么截断为 max_length的一半?同时比较疑惑的是 max_length应该是以 token 为单位,这里的 history 看起来是字符串?

也是简单起见 如果在这里用tokenzier去严格统计token数量会增加不必要的时间(当然对后面的input_id做统计也行,只不过更麻烦) 考虑到M的字符串中的token数量一定是 < M 的

感谢您的解答!🙏

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants