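"""Prompt injection filter and input sanitization.

Detects known jailbreak patterns: user queries containing one are rejected,
while corpus chunks have matches stripped at ingestion time. Behavior is
driven by the ``policy.prompt_filter`` section of a YAML config file.
"""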

import re
from typing import List

import yaml

from .errors import PromptInjectionError


def load_filter_config(config_path: str = "config.yaml") -> dict:
    """Load the ``policy.prompt_filter`` section of a YAML config file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The ``policy.prompt_filter`` mapping, or an empty dict when the file
        is empty or either section is missing or null.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # YAML is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(config_path, encoding="utf-8") as f:
        # safe_load returns None for an empty document.
        cfg = yaml.safe_load(f) or {}
    # A key written as `policy:` with no value loads as None, so
    # `.get(key, {})` alone would not protect the chained lookup.
    policy = cfg.get("policy") or {}
    return policy.get("prompt_filter") or {}


def check_input(text: str, config_path: str = "config.yaml") -> str:
    """Validate a user query against the prompt filter, rejecting violations.

    Args:
        text: The raw user input.
        config_path: Path to the YAML config holding ``policy.prompt_filter``.

    Returns:
        ``text`` unchanged when the filter is disabled or the input is clean.

    Raises:
        PromptInjectionError: If ``text`` is longer than ``max_input_chars``
            (default 4000) or contains any ``banned_patterns`` entry,
            compared case-insensitively.
    """
    filter_cfg = load_filter_config(config_path)

    if not filter_cfg.get("enabled", False):
        return text

    max_chars = filter_cfg.get("max_input_chars", 4000)
    if len(text) > max_chars:
        raise PromptInjectionError(
            f"Input exceeds maximum length ({len(text)} > {max_chars})",
            details={"length": len(text), "max": max_chars}
        )

    # casefold() rather than lower(): lower() leaves case variants such as the
    # long s ("ſ", which casefolds to "s") untouched, so "previouſ" would slip
    # past a "previous" pattern.
    folded = text.casefold()
    # A `banned_patterns:` key with no value loads as None; treat it as empty.
    for pattern in filter_cfg.get("banned_patterns") or []:
        # An empty entry would match every input ("" is in any string) and a
        # null entry would crash on .casefold(); neither is a real pattern.
        if not pattern:
            continue
        if pattern.casefold() in folded:
            raise PromptInjectionError(
                "Input contains a blocked pattern",
                details={"pattern_hint": pattern[:20] + "..."}
            )

    return text


def _load_strip_regex(config_path: str):
    """Build the strip regex from the prompt-filter config at ``config_path``.

    Returns:
        A compiled, case-insensitive regex matching any banned literal, or
        None when the filter is disabled or has no usable patterns (in which
        case stripping is a no-op).
    """
    filter_cfg = load_filter_config(config_path)
    if not filter_cfg.get("enabled", False):
        return None
    # Skip empty/null entries: re.escape("") matches at every position and
    # would splice "[FILTERED]" between every character of the chunk.
    # Longest first so that when two patterns match at the same position the
    # longer one is replaced whole rather than leaving its tail behind.
    literals = sorted(
        (p for p in filter_cfg.get("banned_patterns") or [] if p),
        key=len,
        reverse=True,
    )
    if not literals:
        return None
    # One alternation applied in a single pass, so the inserted "[FILTERED]"
    # marker can never be re-matched by a later pattern (e.g. "filter").
    return re.compile("|".join(re.escape(p) for p in literals), re.IGNORECASE)


def sanitize_chunk(text: str, config_path: str = "config.yaml") -> str:
    """Sanitize a corpus chunk at ingestion time (strip, don't reject).

    Every case-insensitive occurrence of a configured banned pattern is
    replaced with ``[FILTERED]``.

    Args:
        text: The chunk text.
        config_path: Path to the YAML config holding ``policy.prompt_filter``.

    Returns:
        The sanitized text, or ``text`` unchanged when the filter is disabled.
    """
    regex = _load_strip_regex(config_path)
    if regex is None:
        return text
    return regex.sub("[FILTERED]", text)


def sanitize_corpus_batch(chunks: List[str], config_path: str = "config.yaml") -> List[str]:
    """Sanitize a batch of chunks with the same rules as ``sanitize_chunk``.

    The config is read and the regex compiled once for the whole batch
    instead of once per chunk.

    Args:
        chunks: Chunk texts to sanitize.
        config_path: Path to the YAML config holding ``policy.prompt_filter``.

    Returns:
        A new list with each chunk sanitized, in input order.
    """
    # Preserve the original behavior of not touching the config for an
    # empty batch.
    if not chunks:
        return []
    regex = _load_strip_regex(config_path)
    if regex is None:
        return list(chunks)
    return [regex.sub("[FILTERED]", chunk) for chunk in chunks]
