import '@tensorflow/tfjs-backend-cpu';
import '@tensorflow/tfjs-backend-webgl';
import * as tf from '@tensorflow/tfjs-core';
import { BertTokenizer } from './bert_tokenizer';
import { meanPooling } from './pooling';

/**
 * Ranks documents by their similarity to a query string.
 *
 * The query is tokenized, run through `model.model.predict`, mean-pooled with
 * the tokenizer's `attention_mask`, and scaled to unit L2 length. It is then
 * scored against every row of `documentEmbeddings` with a single matrix
 * product. Only the query is normalized here, so the scores are true cosine
 * similarities only if the document embeddings were L2-normalized beforehand.
 * NOTE(review): confirm that with the code that builds `documentEmbeddings`.
 *
 * @param queryString - Free-text query to rank documents against.
 * @param model - Wrapper whose `model.predict(tokens)` returns an indexable
 *   result; element `[0]` is what gets mean-pooled.
 * @param tokenizer - Produces the model inputs, including `attention_mask`.
 * @param documentEmbeddings - One embedding per row. Presumably shaped
 *   `[numDocuments, dim]` (TODO confirm). It is read, not disposed, by this
 *   function.
 * @returns Row indices into `documentEmbeddings`, ordered from most to least
 *   similar. Returns `[]` when there are no documents.
 */
const scoreDocuments = async (
  queryString: string,
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- the model wrapper's type is not declared in this module
  model: any,
  tokenizer: BertTokenizer,
  documentEmbeddings: tf.Tensor,
): Promise<number[]> => {
  const numDocuments = documentEmbeddings.shape[0] ?? 0;
  if (numDocuments === 0) {
    // Nothing to rank; skip tokenization and inference entirely.
    return [];
  }

  // All tensor-producing work, including tokenization, happens inside tidy.
  // Intermediates and any tensors the tokenizer creates are freed here, and
  // only the returned `indices` tensor outlives the scope.
  const rankedIndices = tf.tidy('calculate similarity matrix', () => {
    const tokens = tokenizer.tokenize(queryString);
    const modelOut = model.model.predict(tokens);
    let queryEmbedding = meanPooling(modelOut[0], tokens['attention_mask']);
    // Scale the query embedding to unit L2 length (norm over all of its values).
    queryEmbedding = tf.div(queryEmbedding, tf.norm(queryEmbedding, 2));

    // documents x query^T: one score per document, assuming a single-row
    // query embedding (TODO confirm meanPooling's output shape).
    const similarityMatrix = tf.matMul(documentEmbeddings, queryEmbedding, false, true);
    // k = numDocuments with sorted=true yields a full ranking, best match first.
    const { indices } = tf.topk(tf.reshape(similarityMatrix, [-1]), numDocuments, true);
    return indices;
  });

  try {
    // Async readback: unlike arraySync(), this does not block the main
    // thread while waiting on the GPU under the WebGL backend.
    return (await rankedIndices.array()) as number[];
  } finally {
    rankedIndices.dispose();
  }
};

export { scoreDocuments };