107 lines
2.4 KiB
Python
107 lines
2.4 KiB
Python
import os
import threading
from typing import List

from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from FlagEmbedding import BGEM3FlagModel
from pydantic import BaseModel, Field


# --- Application factory ---


def create_app() -> FastAPI:
    """Build the FastAPI application with CORS and the health-check routes.

    CORS origins are read from the optional ``CORS_ORIGINS`` environment
    variable as a comma-separated list. When it is unset or blank, every
    origin is allowed (``"*"``), matching the previous default.

    Returns:
        A configured ``FastAPI`` instance exposing ``GET /`` and
        ``GET /ping``.
    """
    app = FastAPI(
        title="embeding-services Backend",
        version="0.1.0",
    )

    origins = [
        origin.strip()
        for origin in os.getenv("CORS_ORIGINS", "*").split(",")
        if origin.strip()
    ] or ["*"]

    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        # A wildcard origin combined with credentials (disallowed by the
        # FastAPI CORS docs) would let any website send credentialed
        # requests here, so credentials are enabled only for an explicit
        # origin allow-list.
        allow_credentials="*" not in origins,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.get("/")
    async def simple():
        # Liveness response on the root path.
        return "embeding-services OK"

    @app.get("/ping")
    async def ping():
        # Same liveness response under an explicit /ping path.
        return "embeding-services OK"

    return app


# ---------- Model ----------

# Load variables from a local .env file (if present) before reading config.
load_dotenv()

# Checkpoint location handed to BGEM3FlagModel as model_name_or_path.
bge_m3_path = os.getenv("EMBED_BGE")
if not bge_m3_path:
    # Fail fast with an actionable message instead of handing None (or an
    # empty string) to the model loader.
    raise RuntimeError(
        "EMBED_BGE is not set: point it at the BGE-M3 model path or id "
        "(environment variable or .env file)."
    )

# Loaded once at import time; every request shares this instance.
bgem3_model = BGEM3FlagModel(
    model_name_or_path=bge_m3_path,
    # NOTE(review): fp16 is requested while running on CPU - confirm the
    # installed FlagEmbedding version falls back to fp32 there.
    use_fp16=True,
    devices="cpu",
    cache_dir="./cache",
    normalize_embeddings=True,
    trust_remote_code=False,
)


# ---------- Schema ----------


class InputForm(BaseModel):
    # One sentence to embed, tagged with a caller-supplied identifier.

    # Caller-chosen identifier; copied unchanged into the matching OutputForm.
    id: str
    # Text sent to the embedding model.
    sentence: str


class EmbedInput(BaseModel):
    # Request body for POST /embedding/bge-m3.

    # Forwarded as max_length to bgem3_model.encode() (presumably a token
    # cap - confirm against FlagEmbedding). Must be positive.
    max_length: int = Field(1024, gt=0)
    # Number of sentences per encode() call. Must be positive: 0 used to crash
    # the batching loop with a 500, and negatives silently returned no vectors.
    batch_size: int = Field(64, gt=0)
    # Items to embed; each id/sentence is echoed back next to its vector.
    sentences: List[InputForm]


class OutputForm(BaseModel):
    # One embedding result returned by the endpoint.

    # Identifier copied from the corresponding InputForm.
    id: str
    # Input text, echoed back unchanged.
    sentence: str
    # Row of the model's "dense_vecs" output for this sentence (presumably
    # unit-length, since the model is built with normalize_embeddings=True).
    vec: List[float]


class EmbedOut(BaseModel):
    # Response body: one OutputForm per input sentence, in request order.
    result: List[OutputForm]


def chunks(lst, n):
    """Yield consecutive slices of ``lst`` holding at most ``n`` items each.

    The last slice is shorter when ``len(lst)`` is not a multiple of ``n``;
    an empty ``lst`` yields nothing.

    Args:
        lst: A sliceable sequence (e.g. a list).
        n: Maximum slice size; must be at least 1.

    Raises:
        ValueError: If ``n`` is less than 1. Because this is a generator,
            the error surfaces when iteration starts.
    """
    if n < 1:
        # range() rejects a step of 0 but silently yields nothing for
        # negative steps, which would drop every item without complaint.
        raise ValueError(f"chunk size must be >= 1, got {n}")
    for start in range(0, len(lst), n):
        yield lst[start:start + n]


# ---------- Endpoint ----------

# Final module-level application instance; the embedding endpoint below is
# registered on it via @app.post.
app = create_app()


# Serializes calls into the shared model. The original handler called encode()
# directly inside ``async def``, so inference never overlapped; now that the
# work runs on worker threads, this lock keeps model calls one at a time.
# NOTE(review): the model/tokenizer are not assumed to be thread-safe.
_ENCODE_LOCK = threading.Lock()


def _encode_batches(items, batch_size, max_length):
    """Embed ``items`` in slices of ``batch_size``; return OutputForm rows.

    Blocking; meant to run on a worker thread via ``run_in_threadpool``.

    Args:
        items: List of InputForm objects from the request.
        batch_size: Number of sentences passed to each ``encode()`` call.
        max_length: Forwarded unchanged to ``bgem3_model.encode()``.

    Returns:
        A list of OutputForm, one per item, in input order.

    Raises:
        RuntimeError: If the model returns a different number of vectors
            than there are sentences in a batch.
    """
    results = []
    for batch in chunks(items, batch_size):
        texts = [item.sentence for item in batch]
        with _ENCODE_LOCK:
            dense = bgem3_model.encode(
                sentences=texts,
                max_length=max_length,
            )["dense_vecs"]
        if len(dense) != len(batch):
            raise RuntimeError(
                f"model returned {len(dense)} vectors for {len(batch)} sentences"
            )
        results.extend(
            OutputForm(
                id=item.id,
                sentence=item.sentence,
                # Array-like rows (e.g. numpy) expose tolist(); plain
                # sequences are passed through as-is.
                vec=vec.tolist() if hasattr(vec, "tolist") else vec,
            )
            for item, vec in zip(batch, dense)
        )
    return results


@app.post("/embedding/bge-m3", response_model=EmbedOut)
async def embed_bge(payload: EmbedInput):
    """Return a dense BGE-M3 embedding for every sentence in the request.

    Sentences are encoded in chunks of ``payload.batch_size`` with
    ``payload.max_length`` forwarded to the model. Each result echoes the
    input ``id`` and ``sentence``, and results keep request order. The
    blocking ``encode()`` work is offloaded to FastAPI's threadpool, so it no
    longer stalls the event loop (and with it the ``/`` and ``/ping`` routes).
    """
    results = await run_in_threadpool(
        _encode_batches,
        payload.sentences,
        payload.batch_size,
        payload.max_length,
    )
    return EmbedOut(result=results)