En proyectos generados por usuarios, a menudo tenemos que lidiar con duplicados, lo cual es especialmente importante para nosotros, ya que el contenido principal de la aplicación móvil iFunny son imágenes que se publican por decenas de miles todos los días. Hemos escrito un sistema separado para buscar repeticiones para facilitar el proceso y ahorrar mucho tiempo.
Debajo del corte, consideraremos las herramientas utilizadas y luego pasaremos a una implementación de ejemplo.
Red neuronal convolucional (CNN)
Hay una gran cantidad de algoritmos de búsqueda duplicados diferentes, cada uno con sus pros y sus contras. Uno de ellos es la búsqueda de los vectores más similares (más cercanos) obtenidos usando las redes CNN .
Una vez que la imagen se clasifica a través de la red CNN, la salida es un vector de "lo que la red vio en la imagen". En teoría, este método debería ser menos sensible al recorte, pero habrá más imágenes similares falsas en comparación con métodos más precisos.
También hay otro inconveniente. En la salida de la clasificación, se obtiene un vector grande (2048 float para resnet152), que debe almacenarse en algún lugar y poder encontrar todos los N vectores similares para el buscado en un período de tiempo razonable, lo que en sí mismo no es fácil.
FAISS
Encontrar los vectores más cercanos es una tarea común para la que ya existen excelentes herramientas. Aquí, la biblioteca FAISS de Facebook se considera líder . Utiliza una agrupación de vectores eficiente, lo que le permite organizar búsquedas incluso para vectores que no caben en la RAM.
Pero trabajar con FAISS directamente no es muy conveniente. Esta no es una base de datos, no puede simplemente guardar un vector allí y consultar uno similar (además, después de crear un índice, solo puede volver a crearlo). Por lo tanto, para la operación industrial, necesita construir su arnés alrededor del sistema de indexación.
Milvus
Para ello existe un proyecto muy prometedor, Milvus , que es muy similar en diseño a Elasticsearch. La única diferencia es que Elasticsearch se basa en el índice lucene, mientras que en Milvus toda la arquitectura se basa en el índice FAISS.
La estructura de las colecciones también es similar:
para cada colección, puede crear varias particiones, por lo que la búsqueda se limitará más adelante. Una partición consta de segmentos, que son un conjunto simple de archivos con id, índices iniciales y datos de servicio.
La información sobre colecciones, particiones y segmentos se almacena en una base de datos SQL separada. SQLite incorporado se usa para el lanzamiento independiente, y también es posible usar una base de datos MySQL externa.
El proyecto Milvus está en desarrollo activo (la versión actual es 0.11.0). Hasta ahora, no hay replicación de datos, así como la capacidad de usar otras bases de datos SQL (o NoSQL) como almacenamiento de metainformación. Por lo tanto, por ahora, para las soluciones de alta disponibilidad, solo puede usar un esquema con dos instancias con almacenamiento compartido: una se ejecutará y la otra estará "inactiva". Mishards se pueden usar para escalar , pero en 0.11.0 está roto.
Además, en 0.11.0, fue posible guardar datos adicionales en la colección junto con el vector y la identificación en sí. Es cierto, hasta ahora sin índices adicionales para ellos, pero con la capacidad de buscar.
Desde el punto de vista del uso, Milvus parece una base de datos externa normal. Existe una API (cliente gRPC y un conjunto de métodos http) para almacenar y buscar un vector, administrar colecciones e índices, así como obtener información sobre todas las entidades.
Al crear una colección, puede especificar el número máximo de entradas por segmento (segmento_row_limit). Si se excede este límite, Milvus comenzará a construir el índice FAISS. Esto está relacionado con una de las características de Milvus: para todos los vectores agregados, para los que aún no se ha creado un índice, la búsqueda funcionará sobre la base de una búsqueda completa. Por lo tanto, para valores grandes de segmento_row_limit, habrá muchos registros para los que aún no se ha creado el índice (también afecta la cantidad de segmentos que se crearán para la colección). Para encontrar vectores similares en la colección, debe hacer una búsqueda en cada segmento, y cuanto más haya, más larga será la búsqueda.
Tenga en cuenta que el segmento recién creado no se llena hasta el límite al agregar nuevos registros. En cambio, hay una fusión gradual de segmentos basada en el principio de un juego. 2048 (hasta que el tamaño supere el límite). Por lo tanto, si especifica un valor mayor para el segmento_row_limit, puede haber muchos segmentos más pequeños para los que no hay índice, lo que significa que la búsqueda en ellos será lenta.
A pesar de todas las peculiaridades, la búsqueda vectorial es rápida. La arquitectura de los índices FAISS y Milvus en sí le permite buscar simultáneamente valores a lo largo de varios vectores a la vez. Y en la práctica, una búsqueda secuencial de dos vectores será significativamente más lenta que la búsqueda de ambos vectores a la vez .
Implementación de búsqueda duplicada
Milvus se puede ejecutar tanto en la versión de CPU como en la GPU. El primero se utiliza mejor en procesadores que admiten la instrucción AVX512 . Para hacer esto, simplemente inicie el contenedor:
docker run -d --rm --name milvusdb -p 19530:19530 -p 19121:19121 \ milvusdb/milvus:0.11.0-cpu-d101620-4c44c0
En este caso, 19530 será el puerto para el cliente gRPC y 19121 para la API http .
Como una cadena de CNN, puede tomar cualquiera de los previamente entrenados (o aprender usted mismo); los resultados pueden diferir ligeramente. En este ejemplo, usaremos el resnet152 previamente entrenado:
model = models.resnet152(pretrained=True)
El vector se eliminará de la capa `avgpool`:
layer = model._modules.get('avgpool')
Y obtenga el vector en sí usando un gancho:
vector = torch.zeros(2048)
def copy_data(m, i, o):
vector.copy_(torch.reshape(o.data, (1, 2048))[0])
hook = layer.register_forward_hook(copy_data)
model(prepared_image)
hook.remove()
El código completo para obtener un vector se ve así:
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
from PIL import Image
model = models.resnet152(pretrained=True)
layer = model._modules.get('avgpool')
model.eval()
pipeline = [
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
]
def _prepare_Image(img: Image) -> Variable:
raw = img
for action in pipeline:
raw = action(raw)
return Variable(raw.unsqueeze(0))
def image_vectorization(image_path: str) -> np.ndarray:
img = Image.open(image_path)
prepared_image = _prepare_Image(img)
vector = torch.zeros(2048)
def copy_data(m, i, o):
vector.copy_(torch.reshape(o.data, (1, 2048))[0])
hook = layer.register_forward_hook(copy_data)
model(prepared_image)
hook.remove()
# vector normalization
norm_vector = vector / torch.norm(vector)
return np.array(norm_vector)
Ahora necesitas un cliente para trabajar con Milvus. Puede tomar cualquiera de los compatibles (Python, Java, Go, Rest, C ++). Tomemos un cliente Java y escribamos un ejemplo en Kotlin. ¿Por qué? Por qué no.
Conectamos Milvus SDK:
implementation("io.milvus:milvus-sdk-java:0.9.0")
Crear conexión con Milvus:
val connectParam = ConnectParam.Builder()
.withHost("localhost")
.withPort(19530)
.build()
val client = MilvusGrpcClient(connectParam)
Crear una colección para el vector 2048:
val collectionMapping = CollectionMapping.create(collectionName)
.addVectorField("float_vec", DataType.VECTOR_FLOAT, 2048)
// id
.setParamsInJson(JsonBuilder()
.param("auto_id", false)
.param("segment_row_limit", segmentRowLimit)
.build()
)
client.createCollection(collectionMapping)
Crear índice IVF_SQ8:
Index.create(collectionName, "float_vec")
.setIndexType(IndexType.IVF_SQ8)
.setMetricType(MetricType.L2)
.setParamsInJson(JsonBuilder()
.param("nlist", 16384)
.build()
)
client.createIndex(index)
Guarde algunos vectores en una colección:
InsertParam.create(collectionName)
.setEntityIds(listOf(1L, 2L))
.addVectorField("float_vec", DataType.VECTOR_FLOAT, listOf(vector1, vector2))
client.insert(insertParam)
client.flush(collectionName) //
Buscamos un vector previamente guardado:
val dsl = JsonBuilder().param(
"bool", mapOf(
"must" to listOf(
mapOf(
"vector" to mapOf(
"float_vec" to
mapOf(
"topk" to 10,
"metric_type" to MetricType.L2,
"type" to "float",
"query" to listOf(vector1),
"params" to mapOf("nprobe" to 50)
)
)
)
)
)
).build()
val searchParam = SearchParam.create(collectionName)
.setDsl(dsl)
val result = client.search(searchParam)
println(result.queryResultsList[0].map { it.entityId to it.distance })
Si todo funciona y está configurado correctamente, se devolverá un resultado similar:
[(1, 0.0), (2, 0.2)]
Para el primer vector, la distancia L2 consigo mismo será 0, y con el otro vector será mayor que 0.
Todo lo anterior, por supuesto, es solo un boceto, pero esto es suficiente para intentar crear un servicio Python para clasificar y obtener un vector. Y termine la API para guardar y buscar vectores para él, o hacerlo en un servicio separado (por ejemplo, en Kotlin), que recibirá el vector y lo guardará en Milvus por sí solo.
Gracias a todos los que leyeron hasta el final, espero que hayan encontrado algo nuevo para ustedes. Y si está interesado en el proyecto Milvus, puede apoyarlo en Github .