"""Local embedding service using sentence-transformers.

CPU-only. No Ollama dependency. Model is loaded once and cached.
Thread-safe for concurrent indexing via ThreadPoolExecutor.
"""

import threading
from typing import List, Optional

import numpy as np
import yaml

from utils.errors import EmbeddingServiceError

# Module-level cache for the single loaded SentenceTransformer; these names
# are used by _get_model() and reset_model().
_model_instance = None  # cached model object; None until a successful load
_model_name: Optional[str] = None  # model name the cached instance was loaded with
_model_lock = threading.Lock()  # held while loading the model or resetting the cache


def _get_model(model_name: str, cache_dir: Optional[str] = None):
    """Return the cached sentence-transformers model, loading it on first use.

    The model is created on CPU and stored in module globals. If a different
    ``model_name`` is requested than the one currently cached, a new model is
    loaded and replaces the cached one. ``cache_dir`` is only used when a load
    actually happens; it is not part of the cache key.

    Args:
        model_name: sentence-transformers model identifier or path.
        cache_dir: forwarded to ``SentenceTransformer`` as ``cache_folder``.

    Returns:
        The loaded ``SentenceTransformer`` instance.

    Raises:
        EmbeddingServiceError: if sentence-transformers cannot be imported or
            the model fails to load. The cache is left unchanged in that case.
    """
    global _model_instance, _model_name

    # Check and update the cache only while holding the lock. The cache is two
    # globals written in separate statements, so an unlocked reader could pair
    # a newly loaded instance with the previous name (or return an instance
    # other than the one it checked) while another thread switches models.
    # An uncontended lock is negligible next to model.encode().
    with _model_lock:
        if _model_instance is not None and _model_name == model_name:
            return _model_instance

        try:
            # Imported lazily so the module can be imported without the
            # (heavy) sentence-transformers dependency being loaded.
            from sentence_transformers import SentenceTransformer
            model = SentenceTransformer(
                model_name,
                cache_folder=cache_dir,
                device="cpu",
            )
        except Exception as e:
            raise EmbeddingServiceError(
                f"Failed to load embedding model '{model_name}': {e}",
                details={"model": model_name}
            ) from e

        _model_instance = model
        _model_name = model_name
        return model


def get_embedding(text: str, config_path: str = "config.yaml") -> List[float]:
    """Compute the embedding vector for a single text string.

    The YAML config is re-read on every call. The model named by
    ``models.embeddings.model`` is obtained through the shared model cache.

    Args:
        text: the text to embed.
        config_path: path to the YAML config providing
            ``models.embeddings.model`` and, optionally,
            ``models.embeddings.cache_dir``.

    Returns:
        The embedding as a list of floats.

    Raises:
        OSError: if the config file cannot be opened.
        KeyError: if the expected config keys are missing (an empty config
            file yields ``None`` and raises ``TypeError`` instead).
        EmbeddingServiceError: if the model cannot be loaded or encoding fails.
    """
    # YAML is UTF-8; don't depend on the platform's locale encoding.
    with open(config_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    emb_cfg = cfg["models"]["embeddings"]
    model_name = emb_cfg["model"]
    cache_dir = emb_cfg.get("cache_dir")

    model = _get_model(model_name, cache_dir)

    try:
        embedding = model.encode(text, convert_to_numpy=True, show_progress_bar=False)
        return embedding.tolist()
    except Exception as e:
        raise EmbeddingServiceError(
            f"Embedding generation failed: {e}",
            details={"model": model_name, "text_length": len(text)}
        ) from e


def get_embeddings_batch(
    texts: List[str],
    config_path: str = "config.yaml",
    *,
    batch_size: int = 32,
) -> List[List[float]]:
    """Compute embeddings for a batch of texts.

    More efficient than calling get_embedding() in a loop —
    sentence-transformers batches internally. The YAML config is re-read on
    every non-empty call.

    Args:
        texts: texts to embed. An empty list returns ``[]`` immediately,
            without reading the config or loading the model.
        config_path: path to the YAML config providing
            ``models.embeddings.model`` and, optionally,
            ``models.embeddings.cache_dir``.
        batch_size: number of texts the model encodes per internal batch
            (keyword-only, defaults to 32).

    Returns:
        One embedding (a list of floats) per input text.

    Raises:
        OSError: if the config file cannot be opened.
        KeyError: if the expected config keys are missing (an empty config
            file yields ``None`` and raises ``TypeError`` instead).
        EmbeddingServiceError: if the model cannot be loaded or encoding fails.
    """
    # Nothing to encode: avoid the config read and a potentially expensive
    # model load just to return an empty result.
    if not texts:
        return []

    # YAML is UTF-8; don't depend on the platform's locale encoding.
    with open(config_path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    emb_cfg = cfg["models"]["embeddings"]
    model_name = emb_cfg["model"]
    cache_dir = emb_cfg.get("cache_dir")

    model = _get_model(model_name, cache_dir)

    try:
        embeddings = model.encode(
            texts,
            convert_to_numpy=True,
            show_progress_bar=len(texts) > 50,  # only for batches of more than 50 texts
            batch_size=batch_size,
        )
        return embeddings.tolist()
    except Exception as e:
        raise EmbeddingServiceError(
            f"Batch embedding failed: {e}",
            # NOTE: "batch_size" here records the number of input texts, not
            # the batch_size argument; kept unchanged for existing consumers.
            details={"model": model_name, "batch_size": len(texts)}
        ) from e


def reset_model() -> None:
    """Forget the cached embedding model (intended for tests).

    Both cache globals are cleared while holding the module lock, so the next
    model request triggers a fresh load.
    """
    global _model_instance, _model_name
    with _model_lock:
        _model_instance = _model_name = None
