#!/opt/catchpoint/bin/python
##
# @file cp_admin_svc
#
# @brief Server that can execute on demand sudoers secured scripts
#
# This is Catchpoint's service that executes only a few commands requiring 
# escalation by launching separate hardened scripts that utilize the sudoers 
# functionality along with NOPASSWD. The scripts are hardened by not accepting 
# arguments from the command line, instead using standard input. This means we 
# don't need to allow wildcards in the sudoers list for these commands, 
# to prevent shell escaping techniques. The scripts are also 'hashed' with an 
# SHA512 sudoer's Digest_Spec which will prevent them from running if modified.
# It only accepts command requests via a unix socket from the Agent user 
# (typically serveruser unless customized). 

from logging.handlers import TimedRotatingFileHandler
import logging
import pwd, grp
import socket
import os
import struct
import json
import subprocess
import sys
import re

##
# @brief namespace like constant class
class Constants:
    # Plain holder of service-wide settings: it is never instantiated, every
    # member is read as Constants.<NAME>.

    ##
    # @brief maximum length, in bytes, accepted for a received json
    MAX_JSON_LENGTH = 100000   # bytes
    ##
    # @brief path to the unix socket used for receiving "connections"
    SOCK_PATH = "/var/run/catchpoint/cp_admin.sock"
    ##
    # @brief magic number used for troubleshooting
    #
    # the ASCII bytes of "json" (0x6A 0x73 0x6F 0x6E) packed most significant
    # byte first; useful for testing with nc -U <SOCK_PATH>
    JSON_IN_BYTES = 0x6A736F6E
    ##
    # @brief catchpoint bin folder
    CATCHPOINT_BIN_FOLDER = "/opt/catchpoint/bin"
    ##
    # @brief location for the logs
    LOG_PATH = "/var/log/catchpoint/cp_admin_svc.log"
    ##
    # @brief default catchpoint username (same as the default in the rpm)
    DEFAULT_CATCHPOINT_USER = "serveruser"
    ##
    # @brief default catchpoint group (same as the default in the rpm)
    DEFAULT_CATCHPOINT_GROUP = "cp"
    ##
    # @brief location of the config file that can override DEFAULT_ values
    CATCHPOINT_CONFIG_FILE = "/etc/catchpoint.d/agent.conf"
    ##
    # @brief override key for DEFAULT_CATCHPOINT_USER in CATCHPOINT_CONFIG_FILE
    CATCHPOINT_CONFIG_USER_ENTRY = "catchpoint_user"
    ##
    # @brief override key for DEFAULT_CATCHPOINT_GROUP in CATCHPOINT_CONFIG_FILE
    CATCHPOINT_CONFIG_GROUP_ENTRY = "catchpoint_group"
    ##
    # @brief backlog passed to listen() on SOCK_PATH
    MAX_ACCEPTED_CONNECTIONS = 4
    ##
    # @brief permission applied to SOCK_PATH (read/write for the owner only)
    OWNER_ONLY_READ_WRITE_PERMISSION = 0o600
    ##
    # @brief socket option used to retrieve the peer credentials, see man 7 unix
    #
    # The option number is architecture specific, so the value exported by the
    # socket module is preferred. The literal 17 (the value this service has
    # always used) is kept as a fallback for interpreters that do not expose
    # the constant.
    SO_PEERCRED = getattr(socket, "SO_PEERCRED", 17)
    ##
    # @brief timeout, in seconds, applied to every client socket operation
    REQUEST_TIMEOUT_SECONDS = 5
    ##
    # @brief default encoding for communications
    #
    # NOTE: the name is misspelled ("ECODING") but it is referenced throughout
    # the file, so it is kept as is.
    CHARSET_ECODING = "utf-8"
    ##
    # @brief log rotation interval, in hours @see Logger
    LOG_INTERVAL_ROTATION_IN_HOURS = 60 # XXX isn't better to rotate at midnight?

##
# @brief responsible for retrieving the configurable options
#
# it transparently handles whether a default option or its override is in use
class Configuration:
    # Characters accepted in a user/group value and the layout of one config
    # entry: optional leading blanks, key=value with the value bare, single
    # quoted or double quoted, then an optional trailing "# comment".
    _VALUE_CHARS = r"[a-zA-Z0-9_-]+"
    _ENTRY_TEMPLATE = r"""^\s*{key}=("{value}"|{value}|'{value}')\s*(#.*)?$"""

    ##
    # @brief default constructor
    #
    # @param self class pointer
    # @param logger logger for the application @see Logger
    def __init__(self, logger):
        self._logger = logger
        # start from the packaged defaults; fetch() may override them
        self._data = {
            Constants.CATCHPOINT_CONFIG_USER_ENTRY: Constants.DEFAULT_CATCHPOINT_USER,
            Constants.CATCHPOINT_CONFIG_GROUP_ENTRY: Constants.DEFAULT_CATCHPOINT_GROUP,
        }
        self._catchpoint_uid = 0
        self._catchpoint_gid = 0

    ##
    # @brief fetch override from file and resolve the user/group to numeric ids
    #
    # A missing or unreadable config file is not fatal: the defaults are kept.
    #
    # @param self class pointer
    #
    # @return self
    #
    # @throws KeyError if the resulting user or group does not exist on the system
    def fetch(self):
        try:
            with open(Constants.CATCHPOINT_CONFIG_FILE) as config_file:
                self._fetch_from_file(config_file)
        # IOError is listed explicitly because on Python 2 open() raises it and
        # it is not an OSError there; on Python 3 both names are the same class
        except (IOError, OSError) as exception:
            self._logger.warning("unable to read %s because: %s. Using default Catchpoint Agent settings.", Constants.CATCHPOINT_CONFIG_FILE, exception)
        except Exception as e:
            self._logger.error("error while reading configuration %s", e)

        catchpoint_user = self._data[Constants.CATCHPOINT_CONFIG_USER_ENTRY]
        catchpoint_group = self._data[Constants.CATCHPOINT_CONFIG_GROUP_ENTRY]
        self._logger.info("using catchpoint_user=%s and catchpoint_group=%s", catchpoint_user, catchpoint_group)
        self._catchpoint_uid = pwd.getpwnam(catchpoint_user).pw_uid
        self._catchpoint_gid = grp.getgrnam(catchpoint_group).gr_gid
        return self

    ##
    # @brief parse information from the config file and fetch it into an internal structure
    #
    # Only the keys already present in the internal structure are looked up.
    # When a key appears on several lines the last matching line wins.
    #
    # @param self class pointer
    # @param config_file open config file (any iterable of text lines works)
    def _fetch_from_file(self, config_file):
        # the file is scanned once per key, so keep the lines in memory
        lines = list(config_file)

        # a value containing characters outside _VALUE_CHARS makes the whole
        # line fail to match, so that line is ignored and the previous value
        # (or the default) is kept
        for key in list(self._data):
            entry = re.compile(Configuration._ENTRY_TEMPLATE.format(
                key=re.escape(key), value=Configuration._VALUE_CHARS))
            for line in lines:
                match = entry.match(line)
                if match:
                    # the pattern guarantees a non-empty value whose optional
                    # surrounding quotes are the only quote characters
                    self._data[key] = match.group(1).strip("\"'")

    ##
    # @brief return the user id relative to the catchpoint user
    #
    # @return an integer representing the uid (0 until fetch() has run)
    @property
    def catchpoint_uid(self):
        return self._catchpoint_uid

    ##
    # @brief return the group id relative to the catchpoint group
    #
    # @return an integer representing the gid (0 until fetch() has run)
    @property
    def catchpoint_gid(self):
        return self._catchpoint_gid

##
# @brief handle the connection part of the communication with other sources
class Server:
    ##
    # @brief default constructor for the server class
    #
    # Creates the listening unix socket at Constants.SOCK_PATH (removing a
    # stale socket file first), hands the file to the catchpoint user/group
    # and restricts it to owner read/write.
    #
    # @param self class pointer
    # @param logger logger for the application @see Logger
    # @param configuration configuration for the application @see Configuration
    def __init__(self, logger, configuration):
        self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self._logger = logger
        self._configuration = configuration
        if os.path.exists(Constants.SOCK_PATH):
            try:
                os.unlink(Constants.SOCK_PATH)
            except Exception as e:
                # bind() below will fail if the stale file is really still there
                self._logger.warning("unable to unlink %s, reason %s", Constants.SOCK_PATH, e)
        self._socket.bind(Constants.SOCK_PATH)
        os.chown(Constants.SOCK_PATH, self._configuration.catchpoint_uid, self._configuration.catchpoint_gid)
        os.chmod(Constants.SOCK_PATH, Constants.OWNER_ONLY_READ_WRITE_PERMISSION)
        self._socket.listen(Constants.MAX_ACCEPTED_CONNECTIONS)

    ##
    # @brief support for with:
    #
    # @param self class pointer
    #
    # @return self
    def __enter__(self):
        return self

    ##
    # @brief support for with:. "gracefully" close the socket
    #
    # @param self class pointer
    # @param type not used
    # @param value not used
    # @param traceback not used
    def __exit__(self, type, value, traceback):
        self._socket.close()
        try:
            os.unlink(Constants.SOCK_PATH)
        except Exception as e:
            self._logger.info("unable to unlink %s, reason %s", Constants.SOCK_PATH, e)

    ##
    # @brief get a socket for communicating with a client
    #
    # basically is equivalent to an accept but with some additional protection:
    # the peer credentials are read from the kernel and only a peer running as
    # the catchpoint user is returned. Any other connection is logged, closed
    # and the wait continues.
    #
    # @param self class pointer
    #
    # @return a socket representing a connection with a client, with the
    #   request timeout already applied
    def get_socket(self):
        credentials_format = "3i"  # pid, uid, gid as reported by SO_PEERCRED
        credentials_size = struct.calcsize(credentials_format)
        while True:
            conn, _ = self._socket.accept()
            try:
                credentials = conn.getsockopt(socket.SOL_SOCKET, Constants.SO_PEERCRED, credentials_size)
                pid, uid, gid = struct.unpack(credentials_format, credentials)
            except Exception:
                # do not leak the accepted connection when it cannot be vetted
                conn.close()
                raise

            # we do not allow communication from other users
            if uid == self._configuration.catchpoint_uid:
                conn.settimeout(Constants.REQUEST_TIMEOUT_SECONDS)
                return conn

            self._logger.warning("rejected connection from pid=%s uid=%s gid=%s", pid, uid, gid)
            # close right away instead of keeping the peer connected until the
            # next accept() happens to rebind the variable
            conn.close()
     
##
# @brief Logger wrapper around the standard logger for IOC purposes
class Logger:
    ##
    # @brief init the standard logger
    #
    # The root logger is configured at DEBUG level and, when possible, gets an
    # additional time-rotated file handler. If the file handler cannot be set
    # up the reason is written to stderr and the root logger is used as is.
    #
    # @param self class pointer
    def __init__(self):
        record_layout = "%(asctime)s - %(levelname)s - %(message)s"
        logging.basicConfig(level=logging.DEBUG, format=record_layout)
        # with or without the file handler, messages go through the root logger
        self._log = logging.getLogger()
        try:
            rotating_handler = TimedRotatingFileHandler(
                filename=Constants.LOG_PATH,
                when="h",
                interval=Constants.LOG_INTERVAL_ROTATION_IN_HOURS,
            )
            rotating_handler.setFormatter(logging.Formatter(record_layout))
            self._log.addHandler(rotating_handler)
        except Exception as e:
            sys.stderr.write("error while setting the logger. reason: " + str(e))

    # exposed methods: one thin forwarder per severity

    ##
    # @brief debug message
    #
    # @param self class pointer
    # @param msg The msg is the message format string
    # @param args used to compile the format string positional placeholders
    # @param kwargs forwarded untouched to the underlying logger
    def debug(self, msg, *args, **kwargs):
        self._log.log(logging.DEBUG, msg, *args, **kwargs)

    ##
    # @brief info message
    #
    # @param self class pointer
    # @param msg The msg is the message format string
    # @param args used to compile the format string positional placeholders
    # @param kwargs forwarded untouched to the underlying logger
    def info(self, msg, *args, **kwargs):
        self._log.log(logging.INFO, msg, *args, **kwargs)

    ##
    # @brief warning message
    #
    # @param self class pointer
    # @param msg The msg is the message format string
    # @param args used to compile the format string positional placeholders
    # @param kwargs forwarded untouched to the underlying logger
    def warning(self, msg, *args, **kwargs):
        self._log.log(logging.WARNING, msg, *args, **kwargs)

    ##
    # @brief error message
    #
    # @param self class pointer
    # @param msg The msg is the message format string
    # @param args used to compile the format string positional placeholders
    # @param kwargs forwarded untouched to the underlying logger
    def error(self, msg, *args, **kwargs):
        self._log.log(logging.ERROR, msg, *args, **kwargs)

    ##
    # @brief critical message
    #
    # @param self class pointer
    # @param msg The msg is the message format string
    # @param args used to compile the format string positional placeholders
    # @param kwargs forwarded untouched to the underlying logger
    def critical(self, msg, *args, **kwargs):
        self._log.log(logging.CRITICAL, msg, *args, **kwargs)

##
# @brief handles the server side of the client/server interaction over a socket
class Communicator:
    ##
    # @brief size of integer from c#. Used to receive an integer from socket
    INTEGER_SIZE = 4
    ##
    # @brief literal header that selects the manual (troubleshooting) mode
    #
    # these are the bytes of Constants.JSON_IN_BYTES as they arrive on the wire
    # when someone types "json" in nc -U
    _MANUAL_HEADER = b"json"

    ##
    # @brief init the standard Communicator
    #
    # Blocks until the server hands over an accepted client connection.
    #
    # @param self class pointer
    # @param logger logger for the application @see Logger
    # @param server Server for the application @see Server
    def __init__(self, logger, server):
        self._logger = logger
        self._socket = server.get_socket()

    ##
    # @brief receive up to size bytes, looping over short reads
    #
    # a stream socket may deliver a message in several pieces, so a single
    # recv() is not enough to get a length-prefixed payload
    #
    # @param self class pointer
    # @param size number of bytes wanted
    #
    # @return the bytes received; shorter than size only if the peer closed
    #   the connection first
    def _recv_exact(self, size):
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = self._socket.recv(remaining)
            if not chunk:  # peer closed the connection
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    ##
    # @brief read information from the socket and parse it to a pair action, is_manual
    #
    # The request starts with a 4 byte header: either the literal "json"
    # (manual mode, whatever follows in the next read is the json) or a native
    # order unsigned 32 bit length followed by that many bytes of json.
    #
    # @param self class pointer
    #
    # @return a pair (action_str, was_manual_json)
    #   action_str: is a string that contains action to perform @see Executor,
    #     or None when nothing usable was received
    #   was_manual_json: if the action_str was done manually (because it
    #     requires a different response)
    def read(self):
        try:
            header = self._recv_exact(Communicator.INTEGER_SIZE)
            if not header:
                return (None, False)
            if len(header) < Communicator.INTEGER_SIZE:
                self._logger.warning("error while reading commands from socket. reason: %s", "truncated header")
                return (None, False)

            length = struct.unpack("I", header)[0]
            # The raw byte comparison makes the typed "json" work regardless of
            # the host byte order; the numeric comparison is kept so that any
            # client relying on the previous check keeps working.
            if header == Communicator._MANUAL_HEADER or length == Constants.JSON_IN_BYTES:
                # no length is known in manual mode: take what one read delivers
                payload = self._socket.recv(Constants.MAX_JSON_LENGTH)
                return (payload.decode(Constants.CHARSET_ECODING), True)

            length = min(length, Constants.MAX_JSON_LENGTH)  # sanity
            payload = self._recv_exact(length)
            return (payload.decode(Constants.CHARSET_ECODING), False)
        except Exception as e:
            self._logger.warning("error while reading commands from socket. reason: %s", e)
            return (None, False)

    ##
    # @brief write the message back to the socket
    #
    # @param self class pointer
    # @param message message to send back (text, or bytes already encoded)
    # @param json_instead_of_len when True the message is sent as is (manual
    #   mode); otherwise it is prefixed with its length as a native order
    #   unsigned 32 bit integer
    def write(self, message, json_instead_of_len):
        # bytes are taken as already encoded; on Python 2 this also covers a
        # plain str, which is what the old "already utf-8" fallback was for
        if not isinstance(message, bytes):
            message = message.encode(Constants.CHARSET_ECODING)

        data = bytearray()
        if not json_instead_of_len:
            data.extend(struct.pack("I", len(message)))
        data.extend(message)
        self._socket.sendall(data)

##
# @brief takes a json_like action and execute the action listed in it 
# 
# an action corresponds to a sudoer file
class Executor:
    ##
    # @brief error marker returned when something bad happen
    ERROR_STRING = "error"
    ##
    # @brief marker used in case the command succeeded
    SUCCESS_STRING = "result"
    ##
    # @brief command name -> script file name inside the catchpoint bin folder
    #
    # For security each of these scripts must be in the sudoers file, along
    # with SHA512 hashes of the contents. An empty file name marks a command
    # that is known but has no script to run (reported as unsupported).
    _SUDOERS_MAPPING = {
        "SET_NET_ADDR": "", #set_adapter_params
        "CONFIG_WIFI": "cpsudo_configure_wifi",
        "SERVICE_CTL": "cpsudo_service_ctl",
        #"COUNT_FD": "cpsudo_count_file_descriptors", #this is used from crontab but not admin service currently
        "CHROME_SANDBOX_SECURE": "cpsudo_chrome_sandbox_secure",
    }

    ##
    # @brief default constructor
    #
    # @param self class pointer
    # @param logger logger for the application @see Logger
    def __init__(self, logger):
        self._logger = logger

    ##
    # @brief build the json error reply and log the reason
    #
    # @param self class pointer
    # @param msg human readable reason
    #
    # @return json string [ERROR_STRING, msg]
    def _reject(self, msg):
        self._logger.warning(msg)
        return json.dumps([Executor.ERROR_STRING, msg])

    ##
    # @brief execute the actions listed in json_str
    #
    # Commands are executed in the order the parsed object yields them; the
    # first unknown or unsupported command stops the processing and an error
    # is returned (commands handled before it have already run).
    #
    # @param self class pointer
    # @param json_str a json string in the form {"command": <arg>} here <arg> a
    #   json object itself
    #
    # @return a json string: [SUCCESS_STRING, {command: script output}] or
    #   [ERROR_STRING, reason]
    def execute(self, json_str):
        try:
            commands = json.loads(json_str)
        except ValueError as e:
            self._logger.error("received invalid json %s", e)
            return json.dumps([Executor.ERROR_STRING, str(e)])

        if not isinstance(commands, dict):
            return self._reject("execute_cmd: expected a json object of commands")

        results_dict = {}
        try:
            for cmd, args in commands.items():
                if cmd not in Executor._SUDOERS_MAPPING:
                    # early exit if any 1 command fails
                    return self._reject("execute_cmd: unrecognized cmd received: " + cmd)

                script_name = Executor._SUDOERS_MAPPING[cmd]
                if not script_name:
                    # Without this check the path below would resolve to the
                    # bin folder itself and sudo would be asked to run it.
                    return self._reject("execute_cmd: unsupported cmd received: " + cmd)

                # realpath already returns an absolute, normalized path
                script = os.path.realpath(os.path.join(Constants.CATCHPOINT_BIN_FOLDER, script_name))
                results_dict[cmd] = self.exec_sudoers(script, args)

            return json.dumps([Executor.SUCCESS_STRING, results_dict])
        except Exception as e:
            self._logger.error("error while executing the commands. reason: %s", e)
            return json.dumps([Executor.ERROR_STRING, str(e)])

    ##
    # @brief run one sudoers script, feeding it its arguments on standard input
    #
    # @param self class pointer
    # @param script full path to a script command
    # @param args json to dump into standard input to script
    #
    # @return the script's standard output parsed as json
    #
    # @throws ValueError if the standard output is not valid utf-8 json
    def exec_sudoers(self, script, args):
        # NOTE(review): args are logged verbatim; if any command carries
        # secrets (presumably CONFIG_WIFI could) they end up in the log.
        # Confirm this is acceptable.
        self._logger.debug("%s %s", script, args)
        proc = subprocess.Popen(['sudo', script], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        out, err = proc.communicate(input=json.dumps(args).encode(Constants.CHARSET_ECODING))
        self._logger.debug("STDOUT:%s", out)
        self._logger.debug("STDERR:%s", err)
        return json.loads(out.decode(Constants.CHARSET_ECODING))

##
# @brief main program         
def main():
    # Service loop: load the configuration, open the admin socket, then serve
    # one request per accepted connection until an unrecoverable error occurs.
    logger = Logger()
    try:
        configuration = Configuration(logger).fetch()
        # the executor keeps no per-request state, one instance serves them all
        executor = Executor(logger)
        with Server(logger, configuration) as server:
            while True:
                # blocks until an authorized client connects
                communicator = Communicator(logger, server)
                try:
                    action, manual_json = communicator.read()
                    if action:
                        results = executor.execute(action)
                        communicator.write(results, manual_json)
                except Exception as e:
                    logger.error("recoverable error, skipping current command. reason: %s", e)
                finally:
                    # Each connection carries a single request. Close it now:
                    # otherwise it stays open until the next client is accepted
                    # and the old object is garbage collected. Communicator
                    # exposes no close(), hence the direct access to its socket.
                    try:
                        communicator._socket.close()
                    except Exception as e:
                        logger.warning("unable to close client connection. reason: %s", e)

        # the listening socket is closed and unlinked by Server.__exit__
    except Exception as e:
        logger.critical("unrecoverable error. reason: %s", e)

# start the service only when the file is executed as a script, not on import
if __name__ == "__main__":
    main()
