Phương pháp đào tạo "đa mã thông báo" của Meta: tốc độ suy luận tăng gấp 3 lần, hiệu suất tăng hơn 10%

Các nhà nghiên cứu đã đề xuất một phương pháp đào tạo mô hình ngôn ngữ quy mô lớn mới nhằm cải thiện hiệu quả mẫu và hiệu suất mô hình bằng cách dự đoán nhiều mã thông báo (multi-token prediction) trong tương lai cùng một lúc, cho thấy những lợi thế đáng kể trong cả nhiệm vụ tạo mã và ngôn ngữ tự nhiên. thời gian và tốc độ suy luận có thể tăng lên gấp ba lần.

Hiện tại, các mô hình ngôn ngữ lớn như GPT và Llama chủ yếu được đào tạo bằng cách dự đoán "mã thông báo tiếp theo" dựa trên "chuỗi từ trước đó". Nhưng bạn đã bao giờ nghĩ đến một câu hỏi, tại sao không cùng lúc dự đoán các token tiếp theo?

Gần đây, các nhà nghiên cứu từ Meta, Đại học ParisTech và Đại học Paris-Saclay đã cùng đề xuất một phương pháp đào tạo mới có thể dự đoán nhiều mã thông báo trong tương lai cùng một lúc, có thể cải thiện hiệu quả mẫu của mô hình.

Link file pdf: https://arxiv.org/pdf/2404.19737

Cách cổ điển để đào tạo LLM được gọi là “dự đoán mã thông báo tiếp theo”, một kỹ thuật học tập tự giám sát trong đó mô hình được cung cấp một chuỗi mã thông báo và phải dự đoán mã tiếp theo.

Sau đó, nó thêm mã thông báo được dự đoán vào đầu vào và lặp lại quy trình, mỗi lần một mã thông báo. Bằng cách thực hiện việc này nhiều lần trên một lượng lớn văn bản, mô hình sẽ học các mẫu chung cho phép nó tạo ra các đoạn văn bản mạch lạc.

Các nhà nghiên cứu đã nghiên cứu và ghi lại những hạn chế của dự đoán mã thông báo tiếp theo trong việc tiếp thu ngôn ngữ, kiến thức thế giới và khả năng suy luận .

Ví dụ: bằng cách chỉ tập trung vào một mã thông báo, mô hình sẽ trở nên quá nhạy cảm với các mẫu cục bộ và bỏ qua các dự đoán đòi hỏi phải suy luận trong phạm vi dài hơn. Các mô hình được đào tạo về dự đoán mã thông báo tiếp theo cũng yêu cầu lượng dữ liệu khổng lồ để đạt được mức độ trôi chảy mà con người có được với ít văn bản hơn nhiều.

1717377689239.png
Ví dụ: mô hình có tham số 13B có khả năng giải quyết vấn đề trên điểm chuẩn HumanEval tốt hơn 12% so với mô hình một mã thông báo có cùng kích thước và tốt hơn 17% trên điểm chuẩn MBPP. Ngoài ra, thông qua thử nghiệm trên các nhiệm vụ thuật toán nhỏ, các nhà nghiên cứu nhận thấy rằng dự đoán nhiều mã thông báo có lợi cho việc cải thiện phần đầu cảm ứng của mô hình và khả năng suy luận thuật toán.

Hơn nữa, các mô hình được đào tạo bằng cách sử dụng dự đoán nhiều mã thông báo sẽ suy luận nhanh hơn tới ba lần, ngay cả khi xử lý các lô dữ liệu lớn.

Dự đoán nhiều mã thông báo​

Dự đoán nhiều mã thông báo hướng dẫn LLM dự đoán một số mã thông báo trong tương lai từ mỗi vị trí trong tập đoàn đào tạo cùng một lúc. Các nhà nghiên cứu đề xuất một kiến trúc dự đoán nhiều mã thông báo đơn giản không yêu cầu thêm thời gian đào tạo hoặc chi phí bộ nhớ.

Các mô hình ngôn ngữ tiêu chuẩn học từ kho văn bản lớn bằng cách thực hiện nhiệm vụ "dự đoán mã thông báo tiếp theo". Mục tiêu của nhiệm vụ là giảm thiểu mất entropy chéo, trong đó mô hình cần tối đa hóa "dự đoán mã thông báo tiếp theo dựa trên lịch sử của chuỗi mã thông báo trước đó”.

Các nhà nghiên cứu đã khái quát nhiệm vụ "dự đoán một mã thông báo" thành "dự đoán nhiều mã thông báo". Tại mỗi vị trí trong dự đoán huấn luyện, mô hình cần dự đoán n mã thông báo trong tương lai cùng một lúc.

1717378631507.png

Kiến trúc biến áp với dự đoán đa mã thông báo
Để làm cho vấn đề có thể giải quyết được, giả sử rằng một mô hình ngôn ngữ lớn sử dụng mạng đường trục dùng chung để tạo ra biểu diễn tiềm ẩn z của bối cảnh được quan sát, sau đó đưa biểu diễn này vào n mạng đầu độc lập để dự đoán từng tương lai theo cách song song.

Việc mất entropy chéo của dự đoán nhiều mã thông báo có thể được chia thành hai phần: biểu diễn tiềm ẩn theo một chuỗi mã thông báo nhất định và dự đoán về n mã thông báo trong tương lai theo điều kiện của biểu diễn tiềm ẩn này.

Trong thực tế, kiến trúc bao gồm một mô hình đường trục Transformer dùng chung tạo ra các biểu diễn tiềm ẩn dựa trên các chuỗi từ ngữ cảnh, n đầu ra dựa trên lớp Transformer độc lập và một ma trận không nhúng chung.

Tiết kiệm bộ nhớ​

Khi đào tạo bộ dự đoán nhiều mã thông báo, vấn đề chính là việc sử dụng bộ nhớ GPU quá mức.

Trong các mô hình ngôn ngữ lớn (LLM) hiện tại, kích thước của từ vựng V thường lớn hơn nhiều so với kích thước d của biểu diễn tiềm năng, do đó vectơ logit trở thành nút thắt cổ chai trong việc sử dụng bộ nhớ GPU.

Nếu chúng ta chỉ triển khai một bộ dự đoán nhiều mã thông báo và lưu trữ tất cả các vectơ logit và độ dốc của chúng trong bộ nhớ thì mức sử dụng bộ nhớ sẽ tăng lên nhanh chóng vì mỗi vectơ có hình dạng (n, V), điều này rất tốn kém. mô hình có thể xử lý đồng thời và tăng mức sử dụng bộ nhớ GPU trung bình.

1717378698534.png

Các nhà nghiên cứu đã đề xuất một phương pháp triển khai hiệu quả về bộ nhớ nhằm giảm mức sử dụng bộ nhớ bằng cách điều chỉnh thứ tự của các hoạt động truyền tiến và truyền ngược.

Cụ thể, sau khi hoàn thành quá trình truyền tiến qua mạng đường trục dùng chung fs, mô hình thực hiện truyền tiến và lùi trên từng đầu ra độc lập fi theo trình tự và tích lũy gradient tại mạng đường trục, mỗi đầu ra vector logit và gradient của fi sẽ được giải phóng sau khi tính toán và không cần chiếm bộ nhớ cho đến khi hoàn thành tất cả các phép tính đầu.

Điều này có nghĩa là không cần lưu trữ lâu dài bất kỳ gradient nào ngoài các gradient của mạng đường trục, giúp giảm đáng kể mức sử dụng bộ nhớ GPU.

Bằng cách này, độ phức tạp của bộ nhớ của mô hình giảm từ O(nV+d) xuống O(V+d), giảm đáng kể mức sử dụng bộ nhớ tối đa của GPU mà không làm giảm thời gian chạy.

Giai đoạn suy luận​

Trong quá trình suy luận, cách sử dụng cơ bản nhất của mô hình này là sử dụng "đầu dự đoán mã thông báo tiếp theo" cho "dự đoán tự động hồi quy mã thông báo tiếp theo cơ bản" trong khi loại bỏ tất cả các mạng đầu khác.

Bạn cũng có thể sử dụng các mạng đầu ra bổ sung để giải mã tự suy luận nhằm tăng tốc độ giải mã từ mạng đầu dự đoán mã thông báo tiếp theo:

1. Giải mã song song theo từng khối, một biến thể của giải mã suy luận có thể dự đoán song song nhiều mã thông báo mà không cần thêm mô hình dự thảo;

2. Sử dụng giải mã suy đoán tương tự như cơ chế chú ý cây Medusa có thể cải thiện tốc độ và hiệu quả giải mã.

Kết quả thực nghiệm

Các nhà nghiên cứu đã tiến hành tổng cộng bảy thử nghiệm quy mô lớn để chứng minh tính hiệu quả của việc mất dự đoán nhiều mã thông báo.

Để so sánh công bằng giữa bộ dự đoán mã thông báo tiếp theo và bộ dự đoán n-mã thông báo, số lượng tham số mô hình trong thử nghiệm là như nhau, nghĩa là khi thêm n-1 lớp vào mạng đầu dự đoán trong tương lai, nó cũng sẽ là được chuyển từ đường trục mô hình dùng chung Ngoại trừ lớp n-1.

1717378960413.png
1. Hiệu suất được cải thiện khi kích thước mô hình tăng lên

Để nghiên cứu tác động của kích thước mô hình, các nhà nghiên cứu đã đào tạo "sáu" mô hình từ đầu, với kích thước từ 300M đến 13B tham số, sử dụng ít nhất 91B mã thông báo.
1717378984913.png
Như có thể thấy từ kết quả đánh giá, các thử nghiệm trên MBPP và HumanEval cho thấy rằng việc sử dụng dự đoán nhiều mã thông báo có thể đạt được hiệu suất tốt hơn trên các tập dữ liệu cố định trong cùng một tải tính toán.

Các nhà nghiên cứu tin rằng tính năng này chỉ có thể được phản ánh trong dữ liệu quy mô lớn và các mô hình kích thước lớn. Đây cũng có thể là lý do tại sao dự đoán đa mã thông báo chưa được sử dụng rộng rãi trong đào tạo mô hình ngôn ngữ quy mô lớn.

2. Tốc độ suy luận nhanh hơn

Các nhà nghiên cứu đã triển khai giải mã tự suy đoán tham lam bằng cách sử dụng xFormers với kích thước lô không đồng nhất và đo tốc độ giải mã của mô hình dự đoán 4 mã thông báo tốt nhất (tham số 7B) khi hoàn thành mã và dữ liệu ngôn ngữ tự nhiên.

Có thể thấy, phương pháp này nhanh hơn 3,0 lần trong tác vụ tạo mã, nhanh hơn 2,7 lần trong tác vụ tạo văn bản và nhanh hơn 6,4 lần trên mô hình dự đoán 8 byte.
1717379169087.png
Khi được đào tạo trước với tính năng dự đoán nhiều mã thông báo, mạng đầu bổ sung có thể chính xác hơn so với việc tinh chỉnh một mô hình dự đoán mã thông báo tiếp theo duy nhất, cho phép mô hình nhận ra toàn bộ tiềm năng của việc giải mã tự suy đoán.

3. Sử dụng dự đoán nhiều byte để tìm hiểu các mẫu chung

Để chứng minh rằng nhiệm vụ dự đoán mã thông báo tiếp theo có thể nắm bắt các mẫu cục bộ, các nhà nghiên cứu đã sử dụng trường hợp cực đoan của mã thông báo cấp byte và huấn luyện mô hình Biến áp cấp byte tham số 7B để xử lý 314B byte, tương đương với 116B mã thông báo

Mô hình dự đoán 8 byte đã đạt được những cải tiến hiệu suất đáng kể so với dự đoán byte tiếp theo, giải quyết hơn 67% vấn đề trên MBPP pass@1 và 20% vấn đề trên HumanEval pass@1.

Do đó, dự đoán nhiều byte là một phương pháp rất hứa hẹn giúp cho việc huấn luyện các mô hình cấp byte hiệu quả hơn.

Giải mã tự suy đoán có thể đạt được tốc độ tăng gấp 6 lần của mô hình dự đoán 8 byte, điều này hoàn toàn bù đắp chi phí cho "chuỗi byte dài hơn" trong quá trình suy luận và thậm chí nhanh gần gấp đôi so với mô hình dự đoán mã thông báo tiếp theo.

Mặc dù lượng dữ liệu được sử dụng để huấn luyện ít hơn 1,7 lần nhưng hiệu suất của mô hình dự đoán 8 byte vẫn gần bằng hiệu suất của mô hình dựa trên mã thông báo.

4. Tìm giá trị n tối ưu

Để hiểu rõ hơn về tác động của số lượng mã thông báo được dự đoán, các nhà nghiên cứu đã tiến hành thử nghiệm cắt bỏ toàn diện trên mô hình kích thước 7B (dữ liệu huấn luyện chứa mã thông báo 200B) và thử n = 1, 2, 4 trong các cài đặt thử nghiệm khác nhau.

Kết quả thử nghiệm cho thấy khi đào tạo với 4 mã thông báo trong tương lai, nó vượt qua các mô hình so sánh khác ở tất cả các chỉ số 1, 10 và 100 của HumanEval và MBPP: mức cải thiện của MBPP lần lượt là +3,8%, +2,1% và +3,2%, mức cải thiện của HumanEval lần lượt là +1,2%, +3,7% và +4,1%.

Điều thú vị là trên APPS/Intro, mức cải thiện hiệu suất khi n = 6 lần lượt là + 0,7%, +3,0% và +5,3%.

Kích thước cửa sổ tối ưu có thể phụ thuộc vào việc phân phối dữ liệu đầu vào. Đối với mô hình cấp byte, kích thước cửa sổ tối ưu nhất quán hơn trên các điểm chuẩn (8 byte).

5. Đào tạo nhiều token

Khi huấn luyện các mô hình machine learning, phương pháp huấn luyện nhiều token vẫn thể hiện ưu điểm cho nhiệm vụ dự đoán token tiếp theo khi xử lý nhiều chu kỳ huấn luyện của cùng một tập dữ liệu.

Mặc dù lợi thế giảm nhẹ khi thời gian đào tạo tăng lên, nhưng vẫn thấy mức cải thiện 2,4% ở chỉ báo pass@1 trên tập dữ liệu MBPP trên chỉ báo pass@ 100 trên tập dữ liệu HumanEval, mức cải thiện đạt 3,2%;
1717379210775.png
Kết quả cho thấy ngay cả sau nhiều lần đào tạo, phương pháp đào tạo đa mã thông báo vẫn có thể mang lại những cải thiện hiệu suất nhất định.

Nhưng đối với tập dữ liệu APPS/Intro, khi số lượng mã thông báo đào tạo đạt 200B, sử dụng phương pháp đào tạo với kích thước cửa sổ là 4 không còn là lựa chọn tối ưu. Kích thước cửa sổ có thể cần phải được điều chỉnh hoặc có thể sử dụng các chiến lược khác để cải thiện hơn nữa hiệu suất của mô hình.

6. Tinh chỉnh dự đoán nhiều mã thông báo

Trong lĩnh vực học máy, mô hình đào tạo trước được đào tạo thông qua chức năng dự đoán mất nhiều mã thông báo. So với mô hình dự đoán một mã thông báo truyền thống, phương pháp này cho thấy hiệu suất tốt hơn trong giai đoạn tinh chỉnh tiếp theo.

Các nhà nghiên cứu đã tinh chỉnh mô hình với tham số 7B trên tập dữ liệu CodeContests, so sánh mô hình có thể dự đoán 4 mã thông báo tiếp theo với mô hình dự đoán một mã thông báo cơ bản và thử mô hình dự đoán 4 mã thông báo sau khi loại bỏ tiêu đề dự đoán bổ sung các cài đặt được tinh chỉnh theo mục tiêu dự đoán một mã thông báo truyền thống.

1717379231983.png

Kết quả thử nghiệm cho thấy trên chỉ báo pass@k, dù sử dụng phương pháp tinh chỉnh nào thì hiệu suất của mô hình dự đoán 4 mã thông báo đều vượt trội so với mô hình dự đoán 4 mã thông báo. Nó cũng cho thấy mô hình dự đoán 4 mã thông báo. tốt hơn trong việc hiểu nhiệm vụ, giải quyết vấn đề và tạo ra sự đa dạng hóa. Câu trả lời là tốt hơn về mặt khả năng.

Các kết quả thử nghiệm cũng cho thấy rằng việc tinh chỉnh dự đoán một mã thông báo dựa trên quá trình đào tạo trước dự đoán 4 mã thông báo có thể là một chiến lược có hiệu suất tổng thể tốt nhất, so với mô hình học máy cổ điển là trước tiên sử dụng các tác vụ phụ trợ cho quá trình đào tạo trước và sau đó.

7. Dự đoán nhiều token bằng ngôn ngữ tự nhiên

Các nhà nghiên cứu đã đào tạo một mô hình với các tham số 7B và sử dụng ba phương pháp dự đoán mất mát khác nhau: dự đoán 4 mã thông báo, 2 mã thông báo và mã thông báo đơn và thực hiện trên 6 điểm chuẩn Đánh giá xử lý ngôn ngữ tự nhiên (NLP).
1717379250500.png
Trong nhiệm vụ tóm tắt, 8 điểm chuẩn khác nhau đã được sử dụng và chất lượng của văn bản được tạo được tự động đánh giá thông qua chỉ số ROUGE. Kết quả cho thấy cả 2 mã thông báo và 4 mã thông báo đều hoạt động tốt hơn so với đường cơ sở dự đoán một mã thông báo.
1717379269606.png
 


Đăng nhập một lần thảo luận tẹt ga
Top