You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

94 lines
2.9KB

  1. using Microsoft.ML.OnnxRuntime;
  2. using Microsoft.ML.OnnxRuntime.Tensors;
  3. using Models;
  4. using ToolsServices;
  5. namespace Reranker;
  /// <summary>
  /// Reranks search results against a query by scoring each (query, document) pair
  /// with an ONNX model. The tokenizer and the inference session are created lazily
  /// on the first call to <see cref="Rerank"/> and then kept for the process lifetime.
  /// </summary>
  /// <remarks>
  /// Lazy initialisation is not synchronised; NOTE(review): confirm callers do not
  /// invoke <see cref="Rerank"/> concurrently before the first load has completed.
  /// </remarks>
  public static class BgeReranker
  {
      /// <summary>
      /// Third argument passed to <c>Tokenizer.EncodePair</c>; presumably the maximum
      /// number of tokens per encoded pair — confirm against the tokenizer implementation.
      /// </summary>
      private const int MaxSequenceLength = 100;

      private static InferenceSession? _session;
      private static Tokenizer? _tokenizer;

      /// <summary>
      /// Loads <c>onnx/tokenizer.json</c> and <c>onnx/model.onnx</c> from the directory
      /// of the executing assembly and stores them in the static fields.
      /// </summary>
      /// <exception cref="Exception">
      /// Any failure while loading is logged and rethrown unchanged.
      /// </exception>
      private static void LoadModeles()
      {
          LoggerService.LogDebug("BgeReranker.LoadModeles");
          try
          {
              LoggerService.LogDebug(typeof(InferenceSession).Assembly.Location);

              // Assembly.Location is an empty string when the assembly was not loaded
              // from a file on disk (e.g. single-file publish); fall back to the
              // application base directory instead of passing null to Path.Combine.
              var assemblyLocation = System.Reflection.Assembly.GetExecutingAssembly().Location;
              var baseDirectory = string.IsNullOrEmpty(assemblyLocation)
                  ? AppContext.BaseDirectory
                  : Path.GetDirectoryName(assemblyLocation) ?? AppContext.BaseDirectory;
              var modeleFullPath = Path.Combine(baseDirectory, "onnx");

              // Build both objects first so the fields are only published together.
              var tokenizer = new Tokenizer(Path.Combine(modeleFullPath, "tokenizer.json"));
              var session = new InferenceSession(Path.Combine(modeleFullPath, "model.onnx"));

              _tokenizer = tokenizer;
              _session = session;
          }
          catch (Exception ex)
          {
              LoggerService.LogError($"BgeReranker.LoadModeles : {ex.Message}");
              throw;
          }
      }

      /// <summary>
      /// Scores every document against <paramref name="query"/> and returns the
      /// <paramref name="topK"/> highest-scoring ones, best first.
      /// </summary>
      /// <param name="query">Query text each document is scored against.</param>
      /// <param name="docs">Candidate documents to rerank.</param>
      /// <param name="topK">Maximum number of documents to return; values &lt;= 0 yield an empty list.</param>
      /// <returns>A new list of at most <paramref name="topK"/> documents ordered by descending score.</returns>
      /// <exception cref="ArgumentNullException"><paramref name="docs"/> is null.</exception>
      public static List<RankedDocument> Rerank(string query, List<SearchResult> docs, int topK = 5)
      {
          LoggerService.LogDebug("BgeReranker.Rerank");

          if (docs is null)
          {
              throw new ArgumentNullException(nameof(docs));
          }

          // Nothing to score or nothing requested: skip model loading and inference.
          if (docs.Count == 0 || topK <= 0)
          {
              return new List<RankedDocument>();
          }

          if (_tokenizer is null || _session is null)
          {
              // Throws on failure, so both fields are set once this returns.
              LoadModeles();
          }

          var scored = new List<RankedDocument>(docs.Count);
          foreach (var doc in docs)
          {
              scored.Add(new RankedDocument
              {
                  Text = doc.Text,
                  Nom_Fichier = doc.Nom_Fichier,
                  Score = ScorePair(query, doc.Text)
              });
          }

          return scored
              .OrderByDescending(x => x.Score)
              .Take(topK)
              .ToList();
      }

      /// <summary>
      /// Encodes one (query, document) pair, runs the ONNX session on it and returns
      /// the first float of the first model output as the score.
      /// </summary>
      /// <param name="query">Query text.</param>
      /// <param name="document">Document text.</param>
      /// <returns>The first value of the first output tensor.</returns>
      /// <exception cref="InvalidOperationException">
      /// The tokenizer returned input ids and an attention mask of different lengths.
      /// </exception>
      private static float ScorePair(string query, string document)
      {
          LoggerService.LogDebug("BgeReranker.ScorePair");

          var encoded = _tokenizer!.EncodePair(query, document, MaxSequenceLength);

          int length = encoded.inputIds.Length;
          if (encoded.attentionMask.Length != length)
          {
              throw new InvalidOperationException(
                  $"BgeReranker.ScorePair : inputIds length ({length}) differs from attentionMask length ({encoded.attentionMask.Length}).");
          }

          // Both tensors are shaped [1, length]: a batch of one sequence.
          var inputIds = new DenseTensor<long>(new[] { 1, length });
          var attentionMask = new DenseTensor<long>(new[] { 1, length });
          for (int i = 0; i < length; i++)
          {
              inputIds[0, i] = encoded.inputIds[i];
              attentionMask[0, i] = encoded.attentionMask[i];
          }

          var inputs = new List<NamedOnnxValue>
          {
              NamedOnnxValue.CreateFromTensor("input_ids", inputIds),
              NamedOnnxValue.CreateFromTensor("attention_mask", attentionMask)
          };

          // The result collection is disposable; release it as soon as the score is read.
          using var result = _session!.Run(inputs);
          return result.First().AsEnumerable<float>().First();
      }
  }