#!/usr/bin/env python3
"""
Cambium Fiber OLT SSH JSON getter.

Author: Joshaven Potter
Version: 0.1.0
Date: 2025-11-24
"""

import argparse
import difflib
import json
import os
import re
import subprocess
import sys
import tempfile
import time
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional, Sequence, Union
from urllib.parse import urlparse


@dataclass(frozen=True)
class OLTRequest:
    """Connection target for one OLT fetch: management host and admin password."""

    # OLT management IP or hostname; the session logs in as admin@<host>.
    host: str
    # Excluded from repr() so printing or logging the request never reveals it.
    password: str = field(repr=False)


@dataclass(frozen=True)
class CachePolicy:
    """Immutable settings for where the OLT JSON cache lives and how long it is trusted."""

    # Filesystem location of the cached JSON document.
    path: str
    # A cache file whose mtime age exceeds this many seconds is treated as stale.
    ttl_seconds: int
    # When False the cache is neither read nor written (the --no-cache flag).
    enabled: bool


class DebugLog:
    """Diagnostic sink that prints to stderr only while switched on."""

    def __init__(self, enabled: bool):
        # Toggled by the --debug command-line flag.
        self.enabled = enabled

    def emit(self, message: str) -> None:
        """Write *message* plus a newline to stderr; do nothing when disabled."""
        if not self.enabled:
            return
        print(message, file=sys.stderr)


class OLTOutput:
    """Extract the JSON document from raw OLT SSH session output.

    The raw text is cleaned in three passes:
    1. ANSI escape sequences are removed.
    2. ``<...#`` prompt prefixes are removed.
    3. Blank lines, "not a terminal" warnings, and any lines before the first
       line starting with ``{`` or ``[`` are dropped.

    The JSON object is then located in what remains.
    """

    # Prompt-like "<...#" fragments, matched non-greedily on each line.
    # NOTE(review): this would also cut a JSON string value that contains "<"
    # followed later by "#" on the same line -- confirm the OLT never emits one.
    prompt_prefix = re.compile(r"<.*?#")
    # ANSI CSI escape sequences (colour, cursor control).
    ansi = re.compile(r"\x1b\[[0-9;?]*[ -/]*[@-~]")
    # Widest brace span: first "{" through the LAST "}" in the text.
    json_block = re.compile(r"\{.*\}", re.S)
    # Finds where the first complete JSON value inside that span really ends.
    _decoder = json.JSONDecoder()

    def to_json_text(self, raw: str) -> str:
        """Return the JSON object text found in *raw*.

        The brace span found by ``json_block`` is trimmed to the first complete
        JSON value. The caller appends stderr after stdout, so the span can
        extend past the document into later output that happens to contain a
        ``}``. When the span is already valid JSON, the result is the span itself.

        Raises:
            ValueError: if no ``{...}`` span exists in the cleaned output.
        """
        cleaned = self._strip_noise(self._strip_prompts(self._strip_ansi(raw)))
        match = self.json_block.search(cleaned)
        if not match:
            raise ValueError("no JSON found in OLT output")
        candidate = match.group(0)
        try:
            _, end = self._decoder.raw_decode(candidate)
        except ValueError:
            # Not decodable from its first "{": hand back the span unchanged so
            # the caller's json.loads reports the genuine parse error.
            return candidate
        return candidate[:end]

    def _strip_ansi(self, text: str) -> str:
        """Remove ANSI escape sequences."""
        return self.ansi.sub("", text)

    def _strip_prompts(self, text: str) -> str:
        """Remove every ``<...#`` prompt fragment, line by line."""
        return "\n".join(self.prompt_prefix.sub("", ln) for ln in text.splitlines())

    def _strip_noise(self, text: str) -> str:
        """Drop blank lines, the non-tty warning, and leading banner lines."""
        lines = [ln.rstrip() for ln in text.splitlines()]
        lines = [ln for ln in lines if ln.strip() and "Warning: Input is not a terminal" not in ln]
        return "\n".join(self._drop_leading_banner(lines)).strip()

    def _drop_leading_banner(self, lines: Sequence[str]) -> Sequence[str]:
        """Return *lines* starting at the first one whose text begins with ``{`` or ``[``.

        If no such line exists, all lines are returned unchanged.
        """
        for i, ln in enumerate(lines):
            if ln.lstrip().startswith("{") or ln.lstrip().startswith("["):
                return lines[i:]
        return lines


class OLTTransport:
    """Fetch the OLT's full JSON dump over SSH (through ``sshpass``) and decode it.

    The password is handed to ``sshpass -e`` via the ``SSHPASS`` environment
    variable rather than ``sshpass -p <password>``. With ``-p`` the secret sits
    in the child's argument list, where other local users can read it
    (``ps``, ``/proc/<pid>/cmdline``). ``subprocess.TimeoutExpired`` also
    embeds that argument list in its message, which the CLI prints on error.
    """

    def __init__(self, output: "OLTOutput", debug: "DebugLog", timeout_seconds: float = 30):
        # output: extracts JSON text from the raw session output.
        # debug: diagnostic sink.
        # timeout_seconds: upper bound on the whole SSH session.
        self.output = output
        self.debug = debug
        self.timeout_seconds = timeout_seconds

    def fetch_all(self, request: "OLTRequest") -> Any:
        """Log in as ``admin@<host>``, run the stdin script, and return the parsed JSON.

        Raises:
            subprocess.TimeoutExpired: the session ran longer than ``timeout_seconds``.
            FileNotFoundError: ``sshpass`` is not installed.
            ValueError: no JSON in the output, or malformed JSON
                (``json.JSONDecodeError`` is a ``ValueError``).
        """
        stdin_script = self._stdin_script()
        self.debug.emit(f"olt: stdin_script={stdin_script!r}")
        raw = self._run_sshpass(request.host, request.password, stdin_script)
        json_text = self.output.to_json_text(raw)
        return json.loads(json_text)

    def _stdin_script(self) -> str:
        """Return the commands fed to the remote shell on stdin, one per line."""
        return "info\nshow all\n"

    def _build_cmd(self, host: str) -> List[str]:
        """Return the argv for the SSH session; it never contains the password."""
        return [
            "sshpass", "-e",  # password comes from $SSHPASS, not argv
            "ssh",
            "-o", "PreferredAuthentications=password",
            "-o", "PubkeyAuthentication=no",
            "-T",  # no pseudo-terminal: the commands are piped in on stdin
            f"admin@{host}",
        ]

    def _run_sshpass(self, host: str, password: str, stdin_script: str) -> str:
        """Run the SSH session and return its stdout followed by its stderr.

        The exit status is logged but not checked. The caller decides from the
        output itself whether usable JSON came back.
        """
        cmd = self._build_cmd(host)
        # _redact_cmd is a no-op for the current argv; it is kept as a safety
        # net in case a "-p <secret>" pair is ever reintroduced.
        cmd_for_log = self._redact_cmd(cmd, password)
        self.debug.emit(
            f"olt: cmd=SSHPASS={self._shell_quote(self._redact_password(password))} "
            f"{' '.join(self._shell_quote(x) for x in cmd_for_log)}"
        )

        env = dict(os.environ)
        env["SSHPASS"] = password
        completed = subprocess.run(
            cmd,
            input=stdin_script,
            text=True,
            capture_output=True,
            timeout=self.timeout_seconds,
            env=env,
        )
        combined = (completed.stdout or "") + (completed.stderr or "")
        self.debug.emit(f"olt: fetched bytes={len(combined)} returncode={completed.returncode}")
        return combined

    def _redact_password(self, password: str) -> str:
        """Mask a password for logs.

        Keeps the first and last characters when it is longer than two
        characters; shorter passwords are fully masked.
        """
        if not password:
            return "<empty>"
        if len(password) <= 2:
            return "*" * len(password)
        return password[0] + "*" * (len(password) - 2) + password[-1:]

    def _redact_cmd(self, cmd: Sequence[str], password: str) -> List[str]:
        """Return a copy of *cmd* with the argument after any ``-p`` replaced by its redaction."""
        redacted = self._redact_password(password)
        out: List[str] = []
        i = 0
        while i < len(cmd):
            if cmd[i] == "-p" and i + 1 < len(cmd):
                out.extend([cmd[i], redacted])
                i += 2
            else:
                out.append(cmd[i])
                i += 1
        return out

    def _shell_quote(self, s: str) -> str:
        """Single-quote *s* for POSIX-shell display unless it is plainly safe."""
        if re.fullmatch(r"[A-Za-z0-9_./:@=-]+", s):
            return s
        return "'" + s.replace("'", "'\"'\"'") + "'"


class CacheStore:
    """Best-effort JSON file cache for the OLT document, governed by a CachePolicy.

    An unreadable or corrupt cache file counts as a miss. A failed write is
    reported on the debug log instead of discarding data that was just fetched.
    """

    def __init__(self, debug: "DebugLog"):
        self.debug = debug

    def load_if_fresh(self, policy: "CachePolicy") -> Optional[Any]:
        """Return the cached JSON when the file exists and is no older than ``ttl_seconds``.

        Returns None in any of these cases, and the caller then fetches afresh:
        caching is disabled, the file is missing or stale, or the file cannot
        be read or parsed.
        """
        self.debug.emit(f"cache: enabled={policy.enabled} path={policy.path} ttl={policy.ttl_seconds}s")
        if not policy.enabled:
            self.debug.emit("cache: disabled (--no-cache)")
            return None
        if not os.path.isfile(policy.path):
            self.debug.emit("cache: miss (no file)")
            return None
        try:
            age = time.time() - os.path.getmtime(policy.path)
        except OSError:
            # The file disappeared between the isfile() check and here.
            self.debug.emit("cache: miss (no file)")
            return None
        if age > policy.ttl_seconds:
            self.debug.emit(f"cache: stale age={age:.1f}s > ttl")
            return None
        try:
            with open(policy.path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            # A truncated or foreign file must not fail every run until it goes stale.
            self.debug.emit(f"cache: unreadable ({exc}); ignoring")
            return None
        self.debug.emit(f"cache: hit age={age:.1f}s")
        return data

    def save(self, policy: "CachePolicy", data: Any) -> None:
        """Atomically write *data* as indented JSON to ``policy.path`` (no-op when disabled).

        The temp file gets a unique, unpredictable name from ``tempfile.mkstemp``.
        A fixed ``<path>.tmp`` in a shared directory like /tmp could be
        pre-created as a symlink by another user, and two concurrent runs would
        collide on it. ``mkstemp`` creates the file owner-only (0600), and
        ``os.replace`` carries that mode over to the cache file.
        """
        if not policy.enabled:
            return
        directory = os.path.dirname(policy.path) or "."
        tmp_path: Optional[str] = None
        try:
            os.makedirs(directory, exist_ok=True)
            fd, tmp_path = tempfile.mkstemp(
                prefix=os.path.basename(policy.path) + ".", suffix=".tmp", dir=directory
            )
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, policy.path)
            tmp_path = None  # renamed into place; nothing left to clean up
        except OSError as exc:
            # A cache problem must not throw away data the caller already fetched.
            self.debug.emit(f"cache: write failed for {policy.path}: {exc}")
            return
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass
        self.debug.emit(f"cache: wrote {policy.path}")


@dataclass(frozen=True)
class PathToken:
    """One parsed step of a JSON path.

    ``kind`` is one of:
    - ``"key"``: ``value`` is the object key.
    - ``"index"``: ``value`` is a non-negative list index.
    - ``"wildcard"``: ``value`` is None.
    """

    kind: str
    value: Union[str, int, None] = None


class JsonPath:
    """Minimal JSON path dialect used to project parts of the OLT document.

    Syntax:
    - ``Name`` or ``.Name`` selects a key; unquoted keys match ``[A-Za-z_][A-Za-z0-9_-]*``.
    - ``"any key"`` selects a quoted key.
    - ``[N]`` selects a non-negative list index.
    - ``[*]`` applies the rest of the path to every element of a list and
      returns the list of results.

    Example: ``Ethernet[*].MAC`` returns the MAC of every port.
    """

    token_re = re.compile(
        r"""
        (?:
            \.?
            (?P<key>[A-Za-z_][A-Za-z0-9_-]*)
        )
        |
        (?:
            \.?
            "(?P<qkey>[^"]+)"
        )
        |
        (?:
            \[(?P<index>\d+|\*)\]
        )
        """,
        re.X,
    )

    def parse(self, path: str) -> List[PathToken]:
        """Tokenize *path*.

        Raises:
            ValueError: if any part of *path* is not a valid token.
        """
        tokens: List[PathToken] = []
        pos = 0
        for m in self.token_re.finditer(path):
            # Tokens must be contiguous; a gap means unparseable characters.
            if m.start() != pos:
                raise ValueError(f"invalid path near: {path[pos:]}")
            pos = m.end()
            if m.group("key"):
                tokens.append(PathToken("key", m.group("key")))
            elif m.group("qkey"):
                tokens.append(PathToken("key", m.group("qkey")))
            elif m.group("index"):
                idx = m.group("index")
                tokens.append(PathToken("wildcard" if idx == "*" else "index", None if idx == "*" else int(idx)))
        if pos != len(path):
            raise ValueError(f"invalid path near: {path[pos:]}")
        return tokens

    def select(self, data: Any, path: Optional[str]) -> Any:
        """Return the part of *data* addressed by *path*; an empty or None path returns *data*.

        Raises:
            KeyError: a missing key, or a key/index/wildcard applied to the wrong type.
            IndexError: an index out of range.
            ValueError: a malformed path.
        """
        if not path:
            return data
        tokens = self.parse(path)
        return self._apply_tokens(data, tokens)

    def _apply_tokens(self, data: Any, tokens: List[PathToken]) -> Any:
        """Walk *tokens* over *data*, fanning out at each wildcard."""
        current = data
        for pos, token in enumerate(tokens):
            if token.kind == "wildcard":
                # _apply_one validates that current is a list and copies it.
                # The remaining tokens are then applied per element; applying
                # them to the list itself would reject any key after "[*]".
                items = self._apply_one(current, token)
                rest = tokens[pos + 1:]
                return [self._apply_tokens(item, rest) for item in items]
            current = self._apply_one(current, token)
        return current

    def _apply_one(self, current: Any, token: PathToken) -> Any:
        """Apply a single token to *current*."""
        if token.kind == "key":
            if not isinstance(current, dict):
                raise KeyError(f"cannot access key {token.value} on non-object")
            key = str(token.value)
            if key not in current:
                available = sorted(current.keys())
                hint = self._closest_key_hint(key, available)
                raise KeyError(f"missing key {key}. available: {available}{hint}")
            return current[key]
        if token.kind == "index":
            if not isinstance(current, list):
                raise KeyError(f"cannot index non-array with {token.value}")
            idx = int(token.value)
            if idx < 0 or idx >= len(current):
                raise IndexError(f"index {idx} out of range")
            return current[idx]
        if token.kind == "wildcard":
            if not isinstance(current, list):
                raise KeyError("cannot wildcard non-array")
            return current[:]
        raise ValueError(f"unknown token kind {token.kind}")

    def _closest_key_hint(self, requested: str, available: Iterable[str]) -> str:
        """Return a "did you mean" suffix naming the closest available key, or ""."""
        matches = difflib.get_close_matches(requested, list(available), n=1, cutoff=0.6)
        return f" is '{matches[0]}' what you're looking for?" if matches else ""


class OLTClient:
    """Cache-first access to the OLT's complete JSON document."""

    def __init__(self, transport: OLTTransport, cache: CacheStore):
        self.transport = transport
        self.cache = cache

    def get_all(self, request: OLTRequest, policy: CachePolicy) -> Any:
        """Return the fresh cached document if one exists.

        Otherwise fetch the document from the OLT, store it through the cache,
        and return it.
        """
        document = self.cache.load_if_fresh(policy)
        if document is None:
            # Cache miss (or disabled): go to the device, then refresh the cache.
            document = self.transport.fetch_all(request)
            self.cache.save(policy, document)
        return document


class PathProjector:
    """Reduce the fetched document to what the user asked for on the command line."""

    def __init__(self, selector: JsonPath, debug: DebugLog):
        self.selector = selector
        self.debug = debug

    def project(self, data: Any, paths: Sequence[str]) -> Any:
        """Apply the requested *paths* to *data*.

        - No paths: return *data* unchanged.
        - One path: return that selection directly.
        - Several paths: return a dict mapping each path string to its selection.
        """
        if not paths:
            self.debug.emit("path: <none> -> full json")
            return data
        if len(paths) > 1:
            by_path = {}
            for one_path in paths:
                by_path[one_path] = self.selector.select(data, one_path)
            self.debug.emit(f"path: {len(paths)} paths -> object")
            return by_path
        (sole_path,) = paths
        picked = self.selector.select(data, sole_path)
        self.debug.emit(f"path: {sole_path} -> {self._shape(picked)}")
        return picked

    def _shape(self, value: Any) -> str:
        """Describe *value* as "object", "array" or "scalar" for debug output."""
        for json_type, label in ((dict, "object"), (list, "array")):
            if isinstance(value, json_type):
                return label
        return "scalar"


class OLTCLI:
    """Command-line front end.

    Parses arguments, fetches (or loads cached) OLT JSON, projects the
    requested paths, and prints the result.
    """

    def __init__(self):
        self.parser = self._build_parser()

    def run(self, argv: Sequence[str]) -> int:
        """Execute one invocation and return the exit status.

        Returns 0 after printing the selection and 2 after printing an error.
        Argument errors do not return: the parser prints help and exits with status 2.
        """
        args = self.parser.parse_args(argv)
        if args.cache_file is None:
            args.cache_file = self._default_cache_path(args.host)
        debug = DebugLog(args.debug)
        request = OLTRequest(args.host, args.password)
        policy = CachePolicy(args.cache_file, args.cache_ttl, not args.no_cache)
        client = OLTClient(OLTTransport(OLTOutput(), debug), CacheStore(debug))
        projector = PathProjector(JsonPath(), debug)
        try:
            data = client.get_all(request, policy)
            selection = projector.project(data, args.paths)
        except KeyError as e:
            # Path lookups report missing keys as KeyError with a hint message.
            self._emit_key_error(e)
            return 2
        except Exception as e:
            # Top-level boundary: any fetch, parse or cache failure becomes a one-line error.
            self._emit_error(str(e))
            return 2
        self._emit_value(selection)
        return 0

    def _build_parser(self) -> argparse.ArgumentParser:
        """Build the argument parser. On a usage error it prints full help before the message."""
        class FriendlyParser(argparse.ArgumentParser):
            def error(self, message):
                self.print_help(sys.stderr)
                self.exit(2, f"\nerror: {message}\n")


        p = FriendlyParser(
            prog="olt_get.py",
            description="Fetch full OLT JSON, cache it, optionally project one or more JSON paths.",
            formatter_class=argparse.RawTextHelpFormatter,
            epilog=(
                "Examples:\n"
                "  olt_get.py 192.168.50.10 FiberDemo!\n"
                "  olt_get.py 192.168.50.10 FiberDemo! Ethernet\n"
                "  olt_get.py 192.168.50.10 FiberDemo! Ethernet[0] Ethernet[0].MAC\n"
                "  olt_get.py 192.168.50.10 FiberDemo! --no-cache Ethernet[0].RxMulticastPackets\n"
                "  olt_get.py 192.168.50.10 FiberDemo! --debug Ethernet[0].TxBytes\n"
            )
        )
        p.add_argument("host", help="OLT management IP/hostname")
        p.add_argument("password", help="Password for admin@<host> SSH login")
        p.add_argument("paths", nargs="*", help="Optional JSON path(s)")
        p.add_argument("--cache-file", default=None, help="Cache file path")
        p.add_argument("--cache-ttl", type=int, default=60, help="Cache TTL seconds")
        p.add_argument("--no-cache", action="store_true", help="Disable cache read/write")
        p.add_argument("--debug", action="store_true", help="Print debug info to stderr")
        return p

    @staticmethod
    def _sanatize_host(host: str) -> str:
        """Reduce a host, host:port, [IPv6]:port or URL to a filename-safe token.

        Steps:
        1. Drop any http(s) scheme, path, IPv6 brackets, and a single ``:port``.
        2. Turn the remaining colons into ``-``.
        3. Replace any other character outside ``[A-Za-z0-9._-]`` with ``_``.

        Returns "unknown" when nothing usable is left.
        """
        h = (host or "").strip()
        if not h:
            return "unknown"

        if h.startswith(("http://", "https://")):
            p = urlparse(h)
            h = p.netloc or p.path

        h = h.split("/", 1)[0]

        if h.startswith("[") and "]" in h:
            h = h[1:h.index("]")]

        # Exactly one colon means host:port; more colons mean a bare IPv6 address.
        if ":" in h and h.count(":") == 1 and "[" not in h and "]" not in h:
            h = h.split(":", 1)[0]

        h = h.replace(":", "-")

        safe_chars = []
        for ch in h:
            if ch.isalnum() or ch in "._-":
                safe_chars.append(ch)
            else:
                safe_chars.append("_")

        h = "".join(safe_chars).strip("._-")
        return h or "unknown"

    def _default_cache_path(self, host: str) -> str:
        """Return ``<tmpdir>/<sanitized host>.stats.json``, preferring /tmp when it exists."""
        name = self._sanatize_host(host)
        filename = f"{name}.stats.json"
        if os.path.isdir("/tmp"):
            return os.path.join("/tmp", filename)
        base = tempfile.gettempdir()
        return os.path.join(base, filename)

    def _emit_value(self, value: Any) -> None:
        """Print scalars bare; print objects and arrays as indented JSON.

        A None selection prints nothing at all, not even a newline.
        """
        if self._is_scalar(value):
            sys.stdout.write("" if value is None else str(value))
            if value is not None:
                sys.stdout.write("\n")
            return
        print(json.dumps(value, indent=2))

    def _emit_error(self, message: str) -> None:
        """Write a one-line ``error: ...`` message to stderr."""
        sys.stderr.write(f"error: {message}\n")

    def _emit_key_error(self, err: KeyError) -> None:
        """Report a KeyError, moving any "available: [...]" key list onto its own line.

        A KeyError may carry a non-string key, so the message is coerced to
        str before string methods are applied.
        """
        msg = str(err.args[0]) if err.args else str(err)
        sys.stderr.write("error: " + msg.replace(". available:", ".\navailable:") + "\n")


    def _is_scalar(self, value: Any) -> bool:
        """Return True for JSON scalars (str, number, bool) and None."""
        return isinstance(value, (str, int, float, bool)) or value is None


class Program:
    """Process entry point that delegates to the command-line front end."""

    def main(self, argv: Sequence[str]) -> int:
        """Run the CLI on *argv* (arguments only, no program name) and return its exit status."""
        cli = OLTCLI()
        status = cli.run(argv)
        return status


if __name__ == "__main__":
    # argv[0] is the script path; the CLI parser expects only the user arguments.
    _exit_status = Program().main(sys.argv[1:])
    raise SystemExit(_exit_status)

