您最多选择25个主题 主题必须以字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符

262 行
7.7KB

  1. using Qdrant.Client.Grpc;
  2. using System.Linq;
  3. using System.Text;
  4. using System.Text.Json;
  5. using System.Text.Json.Serialization;
  6. using System.Text.RegularExpressions;
  7. namespace Reranker
  8. {
  /// <summary>
  /// Minimal tokenizer that turns a (query, document) pair into fixed-length
  /// <c>input_ids</c> / <c>attention_mask</c> arrays for a reranker model.
  /// The vocabulary is read from the <c>model.vocab</c> array of a tokenizer JSON file,
  /// where every entry is a <c>[token, score]</c> pair and the token id is the entry's
  /// position in that array.
  /// </summary>
  /// <remarks>
  /// Words are segmented with a greedy longest-prefix match against the vocabulary;
  /// the scores stored in the vocabulary are not used for segmentation.
  /// NOTE(review): the JSON model exposes a decoder "replacement"/"add_prefix_space"
  /// setting, which suggests the vocabulary may use a word-start marker; this class does
  /// not prepend any such marker before matching - confirm against the reference tokenizer.
  /// </remarks>
  public class Tokenizer
  {
      // Special tokens looked up in the vocabulary.
      private const string ClsToken = "<s>";
      private const string SepToken = "</s>";
      private const string PadToken = "<pad>";
      private const string UnkToken = "<unk>";

      // Splits normalized text into whitespace-delimited words.
      private static readonly Regex WhitespaceRegex = new Regex(@"\s+");

      // token -> id (id = index of the token in the vocabulary array).
      private readonly Dictionary<string, int> _tokenToId;

      // Length of the longest vocabulary entry; bounds the greedy search window.
      private readonly int _maxTokenLength;

      /// <summary>
      /// Loads the vocabulary from a tokenizer JSON file.
      /// </summary>
      /// <param name="tokenizerJsonPath">Path of the tokenizer JSON file.</param>
      /// <exception cref="InvalidDataException">
      /// The JSON has no <c>model</c> object or no <c>model.vocab</c> array.
      /// </exception>
      /// <exception cref="ArgumentException">The vocabulary contains the same token twice.</exception>
      public Tokenizer(string tokenizerJsonPath)
      {
          var json = File.ReadAllText(tokenizerJsonPath);
          var root = JsonSerializer.Deserialize<TokenizerJson>(json);

          // Evaluate the typed projection exactly once: it walks the whole vocabulary.
          var vocab = root?.Model?.Vocab;
          if (vocab == null)
          {
              throw new InvalidDataException("Le JSON ne contient pas le vocabulaire attendu");
          }

          // Build the token -> index dictionary.
          var tokenToId = new Dictionary<string, int>(vocab.Count, StringComparer.Ordinal);
          var maxTokenLength = 0;
          for (int id = 0; id < vocab.Count; id++)
          {
              var token = vocab[id].Token;
              tokenToId.Add(token, id);
              if (token.Length > maxTokenLength)
              {
                  maxTokenLength = token.Length;
              }
          }

          _tokenToId = tokenToId;
          _maxTokenLength = maxTokenLength;
      }

      /// <summary>
      /// Parses merge rules of the form <c>"left right"</c> into a pair -> rank dictionary,
      /// the rank being the rule's position in <paramref name="merges"/>.
      /// Not called anywhere in this class at the moment.
      /// </summary>
      private static Dictionary<(string, string), int> LoadMerges(List<string> merges)
      {
          var ranks = new Dictionary<(string, string), int>(merges.Count);
          for (int rank = 0; rank < merges.Count; rank++)
          {
              var parts = merges[rank].Split(' ');
              ranks[(parts[0], parts[1])] = rank;
          }

          return ranks;
      }

      // ---- PUBLIC ------------------------------------------------------

      /// <summary>
      /// Encodes a query/document pair as <c>&lt;s&gt; query &lt;/s&gt; document &lt;/s&gt;</c>,
      /// truncated or padded to exactly <paramref name="maxLen"/> positions.
      /// </summary>
      /// <param name="query">Query text.</param>
      /// <param name="document">Document text.</param>
      /// <param name="maxLen">Length of the returned arrays; values below 1 yield empty arrays.</param>
      /// <returns>
      /// <c>inputIds</c>: token ids, padded with the <c>&lt;pad&gt;</c> id.
      /// <c>attentionMask</c>: 1 for real tokens, 0 for padding positions.
      /// </returns>
      /// <exception cref="ArgumentNullException"><paramref name="query"/> or <paramref name="document"/> is null.</exception>
      /// <exception cref="KeyNotFoundException">
      /// A token is unknown and <c>&lt;unk&gt;</c> is not in the vocabulary, or padding is
      /// required and <c>&lt;pad&gt;</c> is not in the vocabulary.
      /// </exception>
      /// <remarks>
      /// NOTE(review): a single separator is emitted between the two segments; some model
      /// families expect two - confirm against the model's reference tokenizer.
      /// </remarks>
      public (long[] inputIds, long[] attentionMask) EncodePair(string query, string document, int maxLen)
      {
          if (query is null)
          {
              throw new ArgumentNullException(nameof(query));
          }

          if (document is null)
          {
              throw new ArgumentNullException(nameof(document));
          }

          var tokens = new List<string> { ClsToken };
          tokens.AddRange(EncodeText(query));
          tokens.Add(SepToken);
          tokens.AddRange(EncodeText(document));
          tokens.Add(SepToken);

          // Convert tokens -> ids. The indexer on the fallback is deliberate: it raises
          // KeyNotFoundException only when an unknown token actually occurs.
          var ids = new List<int>(tokens.Count);
          foreach (var token in tokens)
          {
              ids.Add(_tokenToId.TryGetValue(token, out var id) ? id : _tokenToId[UnkToken]);
          }

          int length = Math.Max(maxLen, 0);
          int realCount = Math.Min(ids.Count, length);
          var inputIds = new long[length];
          var attentionMask = new long[length];

          for (int i = 0; i < realCount; i++)
          {
              inputIds[i] = ids[i];
              attentionMask[i] = 1L;
          }

          // Truncation: keep the closing separator in the last kept position instead of
          // silently cutting it off.
          if (realCount > 0 && ids.Count > length)
          {
              inputIds[realCount - 1] = ids[^1];
          }

          // Padding: the mask stays 0 there. The mask is derived from the real length, not
          // from the id value, so a genuine token that maps to the pad id is not masked out.
          if (realCount < length)
          {
              long padId = _tokenToId[PadToken];
              for (int i = realCount; i < length; i++)
              {
                  inputIds[i] = padId;
              }
          }

          return (inputIds, attentionMask);
      }

      // ---- TOKENIZATION -----------------------------------------------

      /// <summary>
      /// Lower-cases and trims <paramref name="text"/>, splits it on whitespace and
      /// segments every word with <see cref="TokenizeWordBpe"/>.
      /// </summary>
      private List<string> EncodeText(string text)
      {
          // Invariant casing: the result must not depend on the machine's current culture.
          text = text.ToLowerInvariant().Trim();

          var tokens = new List<string>();
          foreach (var word in WhitespaceRegex.Split(text))
          {
              tokens.AddRange(TokenizeWordBpe(word));
          }

          return tokens;
      }

      /// <summary>
      /// Segments one word by repeatedly taking the longest vocabulary entry that matches
      /// at the current position (no merge rules are applied, despite the method name).
      /// A position where nothing matches yields one <c>&lt;unk&gt;</c> token.
      /// </summary>
      private List<string> TokenizeWordBpe(string word)
      {
          var tokens = new List<string>();
          int i = 0;
          while (i < word.Length)
          {
              // No vocabulary entry is longer than _maxTokenLength, so longer candidates
              // cannot match and need not be tried.
              int matchLength = 0;
              int furthestEnd = Math.Min(word.Length, i + _maxTokenLength);
              for (int end = furthestEnd; end > i; end--)
              {
                  if (_tokenToId.ContainsKey(word[i..end]))
                  {
                      matchLength = end - i;
                      break;
                  }
              }

              if (matchLength > 0)
              {
                  tokens.Add(word.Substring(i, matchLength));
                  i += matchLength;
              }
              else
              {
                  tokens.Add(UnkToken);
                  // Skip a whole surrogate pair so one unknown character gives one <unk>.
                  i += char.IsSurrogatePair(word, i) ? 2 : 1;
              }
          }

          return tokens;
      }

      /// <summary>
      /// Returns every pair of adjacent symbols, in order.
      /// Not called anywhere in this class at the moment.
      /// </summary>
      private static List<(string, string)> GetPairs(List<string> symbols)
      {
          var pairs = new List<(string, string)>();
          for (int i = 0; i + 1 < symbols.Count; i++)
          {
              pairs.Add((symbols[i], symbols[i + 1]));
          }

          return pairs;
      }

      /// <summary>
      /// Returns a copy of <paramref name="symbols"/> in which every non-overlapping
      /// occurrence of <paramref name="pair"/> (scanned left to right) is replaced by the
      /// concatenation of its two items. Not called anywhere in this class at the moment.
      /// </summary>
      private static List<string> Merge(List<string> symbols, (string, string) pair)
      {
          var merged = new List<string>(symbols.Count);
          int i = 0;
          while (i < symbols.Count)
          {
              bool isPair = i < symbols.Count - 1
                  && symbols[i] == pair.Item1
                  && symbols[i + 1] == pair.Item2;

              if (isPair)
              {
                  merged.Add(pair.Item1 + pair.Item2);
                  i += 2;
              }
              else
              {
                  merged.Add(symbols[i]);
                  i++;
              }
          }

          return merged;
      }

      // ---- JSON STRUCTS -----------------------------------------------

      /// <summary>Root object of the tokenizer JSON file.</summary>
      private class TokenizerJson
      {
          [JsonPropertyName("model")]
          public TokenizerModel? Model { get; set; }

          [JsonPropertyName("added_tokens")]
          public List<AddedToken>? AddedTokens { get; set; }

          [JsonPropertyName("decoder")]
          public Decoder? Decoder { get; set; }

          [JsonPropertyName("normalizer")]
          public Normalizer? Normalizer { get; set; }
      }

      /// <summary>The <c>model</c> object of the tokenizer JSON file.</summary>
      private class TokenizerModel
      {
          [JsonPropertyName("type")]
          public string? Type { get; set; }

          [JsonPropertyName("unk_id")]
          public int UnkId { get; set; }

          // Vocabulary: every element is a [string, double] pair.
          [JsonPropertyName("vocab")]
          public List<List<JsonElement>>? VocabRaw { get; set; }

          // Typed view of the vocabulary; null when the JSON had no "vocab" array.
          // Rebuilt on every access, so callers should read it once and keep the result.
          [JsonIgnore]
          public List<VocabEntry>? Vocab => VocabRaw?
              .Select(pair => new VocabEntry
              {
                  Token = pair[0].ToString(),
                  Score = pair[1].GetDouble()
              })
              .ToList();
      }

      /// <summary>One vocabulary entry: the token text and its score.</summary>
      public class VocabEntry
      {
          public string Token { get; set; }
          public double Score { get; set; }
      }

      /// <summary>One element of the <c>added_tokens</c> array.</summary>
      public class AddedToken
      {
          [JsonPropertyName("id")]
          public int Id { get; set; }

          [JsonPropertyName("content")]
          public string Content { get; set; }

          [JsonPropertyName("single_word")]
          public bool SingleWord { get; set; }

          [JsonPropertyName("lstrip")]
          public bool LStrip { get; set; }

          [JsonPropertyName("rstrip")]
          public bool RStrip { get; set; }

          [JsonPropertyName("normalized")]
          public bool Normalized { get; set; }

          [JsonPropertyName("special")]
          public bool Special { get; set; }
      }

      /// <summary>Minimal view of the <c>decoder</c> object.</summary>
      public class Decoder
      {
          [JsonPropertyName("type")]
          public string Type { get; set; }

          [JsonPropertyName("replacement")]
          public string Replacement { get; set; }

          [JsonPropertyName("add_prefix_space")]
          public bool AddPrefixSpace { get; set; }
      }

      /// <summary>Minimal view of the <c>normalizer</c> object.</summary>
      public class Normalizer
      {
          [JsonPropertyName("type")]
          public string Type { get; set; }
      }
  }
  210. }