using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.RegularExpressions;

namespace Reranker
{
    public class Tokenizer
    {
        private readonly List<VocabEntry> _vocab;
        private readonly Dictionary<string, int> _tokenToId;

        // Merge ranks for a full BPE merge loop (see TokenizeWordFullBpe below);
        // the greedy matcher used here does not need them.
        private readonly Dictionary<(string, string), int>? _merges;

        private readonly string _cls;
        private readonly string _sep;
        private readonly string _pad;
        private readonly string _unk;

        public Tokenizer(string tokenizerJsonPath)
        {
            var json = File.ReadAllText(tokenizerJsonPath);
            var root = JsonSerializer.Deserialize<TokenizerJson>(json);
            if (root?.Model?.VocabRaw == null)
                throw new Exception("The JSON does not contain the expected vocabulary");

            _vocab = root.Model.Vocab;

            // Build the token -> id dictionary; the id is the token's position
            // in the vocabulary list.
            _tokenToId = _vocab
                .Select((entry, idx) => new { entry.Token, Id = idx })
                .ToDictionary(x => x.Token, x => x.Id);

            // Special tokens: this vocabulary uses RoBERTa/XLM-R-style markers
            // (<s>, </s>, <pad>, <unk>) rather than BERT's [CLS]/[SEP]/[PAD]/[UNK].
            _cls = "<s>";
            _sep = "</s>";
            _pad = "<pad>";
            _unk = "<unk>";
        }

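        // For reference, the shape of tokenizer.json that this class expects:
        // a Unigram / SentencePiece-style file. The scores below are only
        // illustrative, not taken from a real model.
        //
        // {
        //   "normalizer": { "type": "..." },
        //   "model": {
        //     "type": "Unigram",
        //     "unk_id": 0,
        //     "vocab": [["<s>", 0.0], ["</s>", 0.0], ["\u2581hello", -8.7], ...]
        //   },
        //   "added_tokens": [ { "id": 0, "content": "<s>", ... } ],
        //   "decoder": { "type": "...", "replacement": "\u2581", "add_prefix_space": true }
        // }
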
        // Parses the "merges" list ("tokenA tokenB" per line) into rank order.
        // Currently unused: the constructor never calls it, but a full BPE
        // implementation would feed its result into _merges.
        private Dictionary<(string, string), int> LoadMerges(List<string> merges)
        {
            var dict = new Dictionary<(string, string), int>();
            for (int i = 0; i < merges.Count; i++)
            {
                var parts = merges[i].Split(' ');
                dict[(parts[0], parts[1])] = i;
            }
            return dict;
        }

        // ---- PUBLIC ------------------------------------------------------

        // Builds the cross-encoder input: <s> query </s> document </s>,
        // truncated/padded to maxLen, plus the matching attention mask.
        public (long[] inputIds, long[] attentionMask) EncodePair(string query, string document, int maxLen)
        {
            var tokens = new List<string> { _cls };
            tokens.AddRange(EncodeText(query));
            tokens.Add(_sep);
            tokens.AddRange(EncodeText(document));
            tokens.Add(_sep);

            // Convert tokens -> ids, falling back to <unk> for unknown tokens
            var ids = tokens
                .Select(t => _tokenToId.TryGetValue(t, out var id) ? id : _tokenToId[_unk])
                .ToList();

            // Truncation: keep a final separator so the sequence stays well-formed
            if (ids.Count > maxLen)
            {
                ids = ids.Take(maxLen).ToList();
                ids[maxLen - 1] = _tokenToId[_sep];
            }

            // The attention mask must cover the real tokens only, not the padding
            int realLength = ids.Count;

            // Padding
            while (ids.Count < maxLen)
                ids.Add(_tokenToId[_pad]);

            long[] inputIds = ids.Select(i => (long)i).ToArray();
            long[] attentionMask = Enumerable.Range(0, maxLen)
                .Select(i => i < realLength ? 1L : 0L)
                .ToArray();

            return (inputIds, attentionMask);
        }

        // ---- TOKENIZATION -----------------------------------------------
        // Greedy longest-match tokenization: at each position, take the longest
        // vocabulary entry. This approximates SentencePiece/BPE segmentation
        // without running the full merge algorithm.

        private List<string> EncodeText(string text)
        {
            // Note: lowercasing is a simplification; the exact rules live in the
            // "normalizer" section of tokenizer.json.
            text = text.ToLower().Trim();
            var words = Regex.Split(text, @"\s+").Where(w => w.Length > 0);

            var tokens = new List<string>();
            foreach (var word in words)
                // SentencePiece marks word starts with '\u2581' (▁); prepend it so
                // word-initial pieces can match the vocabulary.
                tokens.AddRange(TokenizeWordBpe("\u2581" + word));

            return tokens;
        }

        private List<string> TokenizeWordBpe(string word)
        {
            var tokens = new List<string>();
            int i = 0;

            while (i < word.Length)
            {
                string? match = null;

                // Look for the longest substring starting at i that is in the vocabulary
                for (int j = word.Length; j > i; j--)
                {
                    string sub = word[i..j];
                    if (_tokenToId.ContainsKey(sub))
                    {
                        match = sub;
                        break;
                    }
                }

                if (match != null)
                {
                    tokens.Add(match);
                    i += match.Length;
                }
                else
                {
                    // No vocabulary entry covers this character: emit <unk> and skip it
                    tokens.Add(_unk);
                    i += 1;
                }
            }

            return tokens;
        }

        // Returns all adjacent symbol pairs, the candidates for a BPE merge step.
        private List<(string, string)> GetPairs(List<string> symbols)
        {
            var pairs = new List<(string, string)>();

            for (int i = 0; i < symbols.Count - 1; i++)
                pairs.Add((symbols[i], symbols[i + 1]));

            return pairs;
        }

        // Replaces every occurrence of the given adjacent pair with its concatenation.
        private List<string> Merge(List<string> symbols, (string, string) pair)
        {
            var merged = new List<string>();

            int i = 0;
            while (i < symbols.Count)
            {
                if (i < symbols.Count - 1 &&
                    symbols[i] == pair.Item1 &&
                    symbols[i + 1] == pair.Item2)
                {
                    merged.Add(pair.Item1 + pair.Item2);
                    i += 2;
                }
                else
                {
                    merged.Add(symbols[i]);
                    i++;
                }
            }

            return merged;
        }

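        // Hypothetical sketch of the full BPE merge loop that GetPairs and Merge
        // support. It assumes _merges has been populated via LoadMerges from the
        // "merges" section of a BPE-type tokenizer.json; the greedy matcher above
        // is what the class actually uses.
        private List<string> TokenizeWordFullBpe(string word)
        {
            // Start from single characters and repeatedly apply the best-ranked merge
            var symbols = word.Select(c => c.ToString()).ToList();
            var merges = _merges;
            if (merges == null)
                return symbols;

            while (symbols.Count > 1)
            {
                // Among adjacent pairs, keep those with a known merge rank
                var candidates = GetPairs(symbols).Where(p => merges.ContainsKey(p)).ToList();
                if (candidates.Count == 0)
                    break;  // no applicable merge left

                // Apply the lowest-ranked (earliest-learned) merge first
                var best = candidates.OrderBy(p => merges[p]).First();
                symbols = Merge(symbols, best);
            }

            return symbols;
        }
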
        // ---- JSON STRUCTS -----------------------------------------------

        private class TokenizerJson
        {
            [JsonPropertyName("model")]
            public TokenizerModel Model { get; set; }

            [JsonPropertyName("added_tokens")]
            public List<AddedToken> AddedTokens { get; set; }

            [JsonPropertyName("decoder")]
            public Decoder Decoder { get; set; }

            [JsonPropertyName("normalizer")]
            public Normalizer Normalizer { get; set; }
        }

        private class TokenizerModel
        {
            [JsonPropertyName("type")]
            public string Type { get; set; }

            [JsonPropertyName("unk_id")]
            public int UnkId { get; set; }

            // The vocabulary: each element is a [token, score] pair
            [JsonPropertyName("vocab")]
            public List<List<object>> VocabRaw { get; set; }

            // Typed view over the raw pairs. System.Text.Json deserializes the
            // inner values as JsonElement, so they are unwrapped here.
            [JsonIgnore]
            public List<VocabEntry> Vocab => VocabRaw
                .Select(x => new VocabEntry
                {
                    Token = ((JsonElement)x[0]).GetString() ?? "",
                    Score = ((JsonElement)x[1]).GetDouble()
                })
                .ToList();
        }

        public class VocabEntry
        {
            public string Token { get; set; }
            public double Score { get; set; }
        }

        // Maps one entry of the added_tokens array
        public class AddedToken
        {
            [JsonPropertyName("id")]
            public int Id { get; set; }

            [JsonPropertyName("content")]
            public string Content { get; set; }

            [JsonPropertyName("single_word")]
            public bool SingleWord { get; set; }

            [JsonPropertyName("lstrip")]
            public bool LStrip { get; set; }

            [JsonPropertyName("rstrip")]
            public bool RStrip { get; set; }

            [JsonPropertyName("normalized")]
            public bool Normalized { get; set; }

            [JsonPropertyName("special")]
            public bool Special { get; set; }
        }

        // Minimal decoder description
        public class Decoder
        {
            [JsonPropertyName("type")]
            public string Type { get; set; }

            [JsonPropertyName("replacement")]
            public string Replacement { get; set; }

            [JsonPropertyName("add_prefix_space")]
            public bool AddPrefixSpace { get; set; }
        }

        // Minimal normalizer description
        public class Normalizer
        {
            [JsonPropertyName("type")]
            public string Type { get; set; }
        }
    }
}
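
// ---- USAGE SKETCH -------------------------------------------------
// Hypothetical example of how this tokenizer could feed a cross-encoder
// reranker exported to ONNX. It assumes the Microsoft.ML.OnnxRuntime
// package and a model with "input_ids" / "attention_mask" inputs;
// "reranker.onnx" and the input names are assumptions, not taken from
// the code above.
//
// using Microsoft.ML.OnnxRuntime;
// using Microsoft.ML.OnnxRuntime.Tensors;
//
// var tokenizer = new Reranker.Tokenizer("tokenizer.json");
// var (inputIds, attentionMask) = tokenizer.EncodePair(
//     "what is bpe?", "byte pair encoding merges frequent symbol pairs", maxLen: 512);
//
// using var session = new InferenceSession("reranker.onnx");
// var dims = new[] { 1, inputIds.Length };
// var inputs = new List<NamedOnnxValue>
// {
//     NamedOnnxValue.CreateFromTensor("input_ids", new DenseTensor<long>(inputIds, dims)),
//     NamedOnnxValue.CreateFromTensor("attention_mask", new DenseTensor<long>(attentionMask, dims))
// };
//
// using var results = session.Run(inputs);
// float score = results.First().AsTensor<float>().First();  // higher = more relevant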