import faiss
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import cv2
import os
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

# 📌 Configuración
DIMENSION = 2048  # ResNet50 genera vectores de 2048 dimensiones
INDEX_PATH = "image_index.faiss"
IMAGE_FOLDER = "/home/shopin86/public_html/public/files/demo"  # Carpeta donde están las imágenes

# 🔹 Cargar modelo ResNet50 pre-entrenado
resnet = models.resnet18(pretrained=False)
resnet.fc = torch.nn.Identity()  # Quitar la última capa (usamos el embedding)
resnet.eval()

# 🔹 Transformaciones para la imagen
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def extract_features(image_path):
    """ Extrae el embedding de una imagen con ResNet50 """
    img = cv2.imread(image_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convertir de BGR a RGB
    img = transform(img).unsqueeze(0)  # Convertir a tensor

    with torch.no_grad():
        features = resnet(img)  # Extraer el vector de características
    return features.numpy().flatten()  # Convertir a numpy array

def create_faiss_index():
    """ Crea un índice FAISS con IndexIVFFlat """
    quantizer = faiss.IndexFlatL2(DIMENSION)  # Cuantizador básico
    index = faiss.IndexIVFFlat(quantizer, DIMENSION, 100, faiss.METRIC_L2)  # IVF con 100 clusters
    return index

def load_or_create_index():
    """ Carga el índice FAISS desde disco o crea uno nuevo """
    if os.path.exists(INDEX_PATH):
        return faiss.read_index(INDEX_PATH)  # Cargar desde disco
    else:
        index = create_faiss_index()
        print("[INFO] Nuevo índice creado.")
        return index

def add_images_to_index(index):
    """ Extrae características de todas las imágenes y las guarda en FAISS """
    image_paths = [os.path.join(IMAGE_FOLDER, img) for img in os.listdir(IMAGE_FOLDER) if img.endswith((".jpg", ".png"))]


    print(f"[INFO] Imágenes detectadas: {image_paths}")  # 🔹 Agregar este print


    if not image_paths:
        print("[ERROR] No hay imágenes en la carpeta.")
        return

    feature_list = []
    for img_path in image_paths:
        features = extract_features(img_path)
        feature_list.append(features)

    feature_array = np.array(feature_list).astype("float32")

    if not index.is_trained:
        print("[INFO] Entrenando índice FAISS...")
        index.train(feature_array)  # Entrenar antes de agregar datos

    index.add(feature_array)  # Agregar datos a FAISS
    faiss.write_index(index, INDEX_PATH)  # Guardar en disco

    print(f"[INFO] Se añadieron {len(image_paths)} imágenes al índice.")
    return image_paths

def search_similar_images(image_path, index, image_paths, k=5):
    """ Busca imágenes similares y muestra los resultados """
    query_features = extract_features(image_path).reshape(1, -1)
    distances, indices = index.search(query_features, k)  # Buscar en FAISS

    print("[INFO] Resultados de búsqueda:")
    for i, idx in enumerate(indices[0]):
        if idx == -1:
            continue
        print(f"{i+1}. {image_paths[idx]} (Distancia: {distances[0][i]:.2f})")

    # 🔹 Mostrar imágenes de resultado
    fig, axes = plt.subplots(1, k+1, figsize=(15, 5))

    # Imagen de consulta
    query_img = cv2.imread(image_path)
    query_img = cv2.cvtColor(query_img, cv2.COLOR_BGR2RGB)
    axes[0].imshow(query_img)
    axes[0].set_title("Imagen de Consulta")
    axes[0].axis("off")

    # Imágenes más similares
    for i, idx in enumerate(indices[0]):
        if idx == -1:
            continue
        result_img = cv2.imread(image_paths[idx])
        result_img = cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)
        axes[i+1].imshow(result_img)
        axes[i+1].set_title(f"Similar {i+1}")
        axes[i+1].axis("off")

    plt.show()

# 🚀 EJECUCIÓN PRINCIPAL
if __name__ == "__main__":
    index = load_or_create_index()
    image_paths = add_images_to_index(index)  # Agregar imágenes a FAISS

    consulta = "query.jpg"  # Cambia esta ruta por la imagen de prueba
    search_similar_images(consulta, index, image_paths, k=5)
