"""SafeClaw LangGraph Controller — Topology = Enforcement.

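Graph flow (matching SafeClaw final diagram):
  retrieve -> route_by_score -> local_llm (high score)
                              -> offline_best_effort (retrieval error)
                              -> user_gate (low score)
                                  -> audit_logger (not yet confirmed: the run ends
                                     so the caller can ask the user, then re-invoke)
                                  -> grok_fallback (confirmed + hybrid)
                                  -> offline_best_effort (denied or offline)
  ALL paths -> audit_logger -> END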

Invariants enforced by graph edges, not prompts:
1. Every query passes through retrieval first.
2. No LLM is called before score gate.
3. No Grok without explicit user confirmation AND hybrid mode.
4. Every response passes through audit logging.
"""

from typing import List, Optional, Literal, TypedDict

from langgraph.graph import StateGraph, END

from retrieval.hybrid_search import HybridRetriever, SearchResult
from llm.client import LocalLLMClient, GrokClient
from utils.logger import audit_log, hash_query
from utils.errors import RAGError, LLMServiceError, GrokServiceError


# =============================================================================
# State Definition
# =============================================================================

class RetrievedDoc(TypedDict):
    """One retrieved chunk as carried in graph state.

    Class-based TypedDict with the default ``total=True``: every key is required.
    """

    text: str  # chunk body inserted into LLM prompts
    score: float  # retrieval score of this hit
    source: str  # origin label rendered as "[Source: ...]" in the local prompt
    chunk_id: int  # chunk identifier; -1 for the synthetic Grok entry
    stem_tags: List[str]  # retrieve_node keeps at most 5
    mode: str  # retrieval mode of the hit; "online" for the synthetic Grok entry


class GraphState(TypedDict, total=False):
    """State threaded through the SafeClaw graph.

    ``total=False``: every key is optional. Nodes read with ``.get`` and
    return only the keys they update.
    """

    # Inputs
    query: str

    # Retrieval outputs (written by retrieve_node)
    retrieved_docs: List[RetrievedDoc]
    top_score: float  # score of the first retrieved doc; 0.0 when there are none
    retrieval_mode: str  # "semantic" | "keyword" | "hybrid" | "none"

    # Control flags
    needs_user_confirm: bool  # True when top_score is below cfg["retrieval"]["min_score"]
    user_confirmed_online: Optional[bool]  # None = user not asked yet; no node sets it, so it comes from the invoking caller

    # Model outputs
    answer: str
    answer_model: str  # "local" | "grok" | "offline-best-effort", or "" while awaiting user confirmation
    answer_sources: List[RetrievedDoc]

    # Audit
    audit_event: dict

    # Error
    error: Optional[str]  # "<code>: <message>" when retrieval raised RAGError


# =============================================================================
# Node Functions
# =============================================================================

def retrieve_node(state: GraphState, retriever: HybridRetriever, cfg: dict) -> dict:
    """Node 1: the graph's entry node; run hybrid retrieval for the query.

    Args:
        state: Graph state; ``state["query"]`` must be present.
        retriever: Backend whose ``hybrid_search(query)`` yields hits exposing
            ``text``, ``score``, ``source``, ``chunk_id``, ``stem_tags`` and
            ``retrieval_mode`` attributes.
        cfg: Application config (not read by this node).

    Returns:
        Partial state with ``retrieved_docs``, ``top_score`` and
        ``retrieval_mode``. The score and mode come from the first hit
        (0.0 / "none" when there are no hits). If the retriever raises
        ``RAGError``, the update has no docs and also carries
        ``error`` = ``"<code>: <message>"``.
    """
    query = state["query"]

    try:
        hits = retriever.hybrid_search(query)
    except RAGError as exc:
        return {
            "retrieved_docs": [],
            "top_score": 0.0,
            "retrieval_mode": "none",
            "error": f"{exc.code}: {exc.message}",
        }

    def to_doc(hit):
        # Copy the hit into a plain state dict, keeping at most 5 stem tags.
        return RetrievedDoc(
            text=hit.text,
            score=hit.score,
            source=hit.source,
            chunk_id=hit.chunk_id,
            stem_tags=hit.stem_tags[:5],
            mode=hit.retrieval_mode,
        )

    docs = list(map(to_doc, hits))

    if not docs:
        return {"retrieved_docs": docs, "top_score": 0.0, "retrieval_mode": "none"}

    # NOTE(review): treats the first hit as the best one, presumably because
    # hybrid_search returns hits best-first; confirm.
    first = docs[0]
    return {
        "retrieved_docs": docs,
        "top_score": first["score"],
        "retrieval_mode": first["mode"],
    }


def route_by_score_node(state: GraphState, cfg: dict) -> dict:
    """Node 2: score gate. Flag whether the best retrieval hit is too weak.

    Compares ``state["top_score"]`` (0.0 if absent) with
    ``cfg["retrieval"]["min_score"]``. A score at or above the threshold lets
    the local LLM answer. Anything else, including a NaN score, which compares
    False, requires the user gate.

    Returns:
        ``{"needs_user_confirm": bool}``.
    """
    min_score = cfg["retrieval"]["min_score"]
    score_ok = state.get("top_score", 0.0) >= min_score
    # Negating ``>=`` (rather than testing ``<``) keeps NaN on the gated side.
    return {"needs_user_confirm": not score_ok}


def local_llm_node(state: GraphState, llm: LocalLLMClient, cfg: dict) -> dict:
    """Node 3: answer locally, grounded in the top retrieved chunks.

    Formats the first five retrieved docs as labelled context blocks, builds
    a "answer strictly from context" prompt around the query, and calls
    ``llm.generate``. An ``LLMServiceError`` is not propagated; its message
    becomes the answer text.

    Args:
        state: Graph state with ``"query"`` and optionally ``"retrieved_docs"``.
        llm: Local model client.
        cfg: Application config (not read by this node).

    Returns:
        Partial state with ``answer``, ``answer_model="local"`` and the (up to
        five) docs used as ``answer_sources``.
    """
    query = state["query"]
    top_docs = state.get("retrieved_docs", [])[:5]

    context_chunks = "\n\n---\n\n".join(
        f"[Source: {doc['source']}, Score: {doc['score']:.3f}]\n{doc['text']}"
        for doc in top_docs
    )

    prompt = f"""You are a helpful assistant with access to a local knowledge base.

USER QUERY: {query}

RETRIEVED CONTEXT:
{context_chunks}

Answer based STRICTLY on the retrieved context above. If the context is insufficient, say so explicitly. Do not fabricate information."""

    try:
        answer = llm.generate(prompt)
    except LLMServiceError as err:
        answer = f"[LLM Error: {err.message}]"

    return {"answer": answer, "answer_model": "local", "answer_sources": top_docs}


def user_gate_node(state: GraphState, cfg: dict) -> dict:
    """Node 4: user confirmation gate in front of the Grok fallback.

    This node never calls a model.

    * First pass (``user_confirmed_online`` absent or ``None``): clears the
      answer fields and sets ``needs_user_confirm`` so the gateway knows to
      prompt the user. The conditional edge after this node then sends the
      run to the audit logger.
    * Later pass (True/False supplied by the caller): returns no updates. The
      conditional edge decides between Grok and the offline node.

    Args:
        state: Graph state; only ``user_confirmed_online`` is read.
        cfg: Application config (unused; kept for a uniform node signature).

    Returns:
        A partial state update, empty once the user has answered.
    """
    if state.get("user_confirmed_online") is None:
        return {"answer": "", "answer_model": "", "needs_user_confirm": True}

    # User has responded; routing is handled by the conditional edge.
    return {}


def grok_fallback_node(state: GraphState, grok: GrokClient, cfg: dict) -> dict:
    """Node 5: answer via the Grok API (online fallback).

    Only reachable when hybrid mode is on and the user explicitly confirmed.

    Local KB text goes into the prompt only when
    ``cfg["policy"]["fallback"]["send_local_context_to_grok"]`` is truthy.
    The context is at most 3 docs, each truncated to 200 characters.
    A missing or empty ``policy``/``fallback`` section is treated as False
    (send the bare query). Previously it raised after the user had already
    confirmed.

    A ``GrokServiceError`` is not propagated; its message becomes the answer.

    Returns:
        Partial state with ``answer``, ``answer_model="grok"`` and a single
        synthetic ``answer_sources`` entry marking the online fallback.
    """
    query = state["query"]
    # Fail closed on privacy: a missing or None config section means "do not send".
    fallback_policy = (cfg.get("policy") or {}).get("fallback") or {}
    send_ctx = fallback_policy.get("send_local_context_to_grok", False)

    if send_ctx:
        docs = state.get("retrieved_docs", [])
        context = "\n".join([d["text"][:200] for d in docs[:3]])
        prompt = f"Context (local KB, partial):\n{context}\n\nQuery: {query}"
    else:
        prompt = query

    try:
        answer = grok.generate(prompt)
    except GrokServiceError as e:
        answer = f"[Grok Error: {e.message}]"

    return {
        "answer": answer,
        "answer_model": "grok",
        "answer_sources": [{
            "source": "Grok Fallback",
            "score": 0.0,
            "chunk_id": -1,
            "stem_tags": [],
            "mode": "online",
            "text": "",
        }],
    }


def offline_best_effort_node(state: GraphState, llm: LocalLLMClient, cfg: dict) -> dict:
    """Node 6: best-effort local answer without the online fallback.

    Reached on a retrieval error, or from the user gate when Grok is declined
    or not allowed. With retrieved docs, the first three (each cut to 300
    characters) are offered as "partial context". Without docs, the model is
    told the local knowledge base had nothing relevant. An ``LLMServiceError``
    is not propagated; its message becomes the answer.

    Args:
        state: Graph state with ``"query"`` and optionally ``"retrieved_docs"``.
        llm: Local model client.
        cfg: Application config (not read by this node).

    Returns:
        Partial state with ``answer``, ``answer_model="offline-best-effort"``
        and the docs offered as context (``[]`` when there were none).
    """
    query = state["query"]
    docs = state.get("retrieved_docs", [])

    if not docs:
        prompt = f"""You are a helpful assistant operating without local knowledge base context.

USER QUERY: {query}

Provide the best general answer you can. Note that your local knowledge base did not have relevant information for this query."""
        sources = []
    else:
        context = "\n\n".join(doc["text"][:300] for doc in docs[:3])
        prompt = f"""You are a helpful assistant. The following context may be partially relevant.

PARTIAL CONTEXT:
{context}

USER QUERY: {query}

Provide the best answer you can. Clearly note where you lack sufficient context."""
        sources = docs[:3]

    try:
        answer = llm.generate(prompt)
    except LLMServiceError as err:
        answer = f"[LLM Error: {err.message}]"

    return {
        "answer": answer,
        "answer_model": "offline-best-effort",
        "answer_sources": sources,
    }


def audit_logger_node(state: GraphState, cfg: dict) -> dict:
    """Node 7: last node before END on every path; write one audit event.

    Summarises the run into an ``event`` dict: retrieval score and mode,
    which model answered, whether Grok was used, and the hit count. The dict
    is handed to ``audit_log``.

    Args:
        state: Final graph state; every key is read with a default.
        cfg: Application config (not read by this node).

    Returns:
        ``{"audit_event": event}``, the dict that was passed to ``audit_log``.
    """
    model_used = state.get("answer_model", "unknown")

    event = {
        "event": "rag_query",
        # Raw query; the original note says audit_log() hashes it.
        # NOTE(review): unless audit_log rewrites this dict in place, the
        # returned audit_event still carries the plain query — confirm.
        "query": state.get("query", ""),
        "top_score": state.get("top_score", 0.0),
        "retrieval_mode": state.get("retrieval_mode", "none"),
        "online_escalated": model_used == "grok",
        "model_used": model_used,
        "hit_count": len(state.get("retrieved_docs", [])),
    }

    audit_log(event)
    return {"audit_event": event}


# =============================================================================
# Graph Builder
# =============================================================================

def build_graph(cfg: dict, retriever: HybridRetriever,
                local_llm: LocalLLMClient, grok: Optional[GrokClient] = None):
    """Build and compile the SafeClaw LangGraph state machine.

    The allowed paths are exactly the edges registered here. Each conditional
    edge declares its possible targets explicitly, so LangGraph can validate
    them at compile time. ``grok_fallback`` exists as a node only when a Grok
    client is supplied and reports itself available at build time. Routing
    additionally requires explicit confirmation, ``app.mode == "hybrid"``,
    ``models.grok.enabled`` and live availability.

    Args:
        cfg: Application config shared with every node.
        retriever: Backend for the retrieve node.
        local_llm: Client used by the local and offline nodes.
        grok: Optional Grok client for the online fallback.

    Returns:
        A compiled graph ready for ``.invoke(initial_state)``.
    """
    graph = StateGraph(GraphState)

    # Sample availability once, so node registration, edges and routing agree.
    # Previously is_available() was re-checked at routing time only. A client
    # that came up after build made the router return "grok_fallback", a node
    # that was never added.
    grok_registered = bool(grok and grok.is_available())

    # Register nodes with closures binding dependencies
    graph.add_node("retrieve", lambda s: retrieve_node(s, retriever, cfg))
    graph.add_node("route_by_score", lambda s: route_by_score_node(s, cfg))
    graph.add_node("local_llm", lambda s: local_llm_node(s, local_llm, cfg))
    graph.add_node("user_gate", lambda s: user_gate_node(s, cfg))
    if grok_registered:
        graph.add_node("grok_fallback", lambda s: grok_fallback_node(s, grok, cfg))
    graph.add_node("offline_best_effort", lambda s: offline_best_effort_node(s, local_llm, cfg))
    graph.add_node("audit_logger", lambda s: audit_logger_node(s, cfg))

    # Entry point: ALWAYS retrieve first
    graph.set_entry_point("retrieve")

    # retrieve -> route_by_score
    graph.add_edge("retrieve", "route_by_score")

    # Conditional: score gate
    def route_from_score(state: GraphState) -> str:
        if state.get("error"):
            return "offline_best_effort"
        if not state.get("needs_user_confirm", False):
            return "local_llm"
        return "user_gate"

    score_targets = ["local_llm", "user_gate", "offline_best_effort"]
    graph.add_conditional_edges(
        "route_by_score",
        route_from_score,
        {name: name for name in score_targets},
    )

    def grok_allowed_by_config() -> bool:
        # Missing/None config sections mean "not allowed" (fail closed)
        # instead of raising KeyError mid-run.
        app_cfg = cfg.get("app") or {}
        grok_cfg = (cfg.get("models") or {}).get("grok") or {}
        return app_cfg.get("mode") == "hybrid" and bool(grok_cfg.get("enabled", False))

    # Conditional: user gate
    def route_from_user_gate(state: GraphState) -> str:
        confirmed = state.get("user_confirmed_online")

        if confirmed is None:
            # Not answered yet. user_gate has set needs_user_confirm, and this
            # run ends via the audit logger (no in-graph pause/checkpoint).
            # The caller is expected to ask the user and invoke again with
            # user_confirmed_online set.
            return "audit_logger"

        # `is True` on purpose: only an explicit boolean confirmation counts.
        if (confirmed is True
                and grok_registered
                and grok_allowed_by_config()
                and grok.is_available()):
            return "grok_fallback"
        return "offline_best_effort"

    gate_targets = ["audit_logger", "offline_best_effort"]
    if grok_registered:
        gate_targets.append("grok_fallback")
    graph.add_conditional_edges(
        "user_gate",
        route_from_user_gate,
        {name: name for name in gate_targets},
    )

    # All model-producing nodes -> audit_logger
    graph.add_edge("local_llm", "audit_logger")
    if grok_registered:
        graph.add_edge("grok_fallback", "audit_logger")
    graph.add_edge("offline_best_effort", "audit_logger")

    # audit_logger -> END
    graph.add_edge("audit_logger", END)

    return graph.compile()
