/*
    Copyright (C) 1998  Andrey V. Savochkin <saw@msu.ru>

    This library is free software; you can redistribute it and/or
    modify it under the terms of the GNU Library General Public
    License as published by the Free Software Foundation; either
    version 2 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Library General Public License for more details.

    You should have received a copy of the GNU Library General Public
    License along with this library; if not, write to the Free
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */

#include <config.h>

#ifdef linux
# define _GNU_SOURCE
# define _BSD_SOURCE
# include <features.h>
#endif
#include <sys/types.h>
#include <malloc.h>
#include <sys/time.h>
#include <string.h>
#include <syslog.h>
#include <stdarg.h>
#include <unistd.h>
#include <stdio.h>
#include <sys/stat.h>
#include <pwd.h>
#ifdef HAVE_SYS_FSUID_H
# include <sys/fsuid.h>
#endif
#include "md5.h"

#define PAM_SM_AUTH

#include <security/pam_modules.h>
#include <security/_pam_macros.h>
#include <security/pam_client.h>

#define PLUGIN_NAME       "sh_secret"
#define PLUGIN_LEN        (sizeof(PLUGIN_NAME)-1)
#define ALLOW_SECRET_FILE ".ssh/allow_secret"
#define ALLOW_SECRET_LEN  (sizeof(ALLOW_SECRET_FILE)-1)

static int debug = 0;

/* Logging facilities */
static void report_no_mem(void);
static void dbg_head(const char *file, int line);
static void dbg_printf(const char *fmt, ...);
#define Dprintf(x) \
    (debug ? (dbg_head(__FILE__, __LINE__), dbg_printf x) : (void)0)
static void do_log(int pri, const char *fmt, ...);

/*
   Build one PAM binary prompt.

   A zero-filled buffer of 4+len octets is allocated; `len' (the size of
   the message counted from the control word onwards) is stored in its
   first four octets with pamc_write__u32.  *place is filled in as a
   PAM_BINARY_PROMPT referring to that buffer, and *msg is pointed at it.

   Returns PAM_SUCCESS, or PAM_AUTH_ERR when memory is exhausted (then
   *msgdat is NULL).  The caller owns *msgdat and must free it.
 */
static int make_msg(unsigned char **msgdat, unsigned int len,
    struct pam_message *msg[], struct pam_message place[])
{
    unsigned char *packet = calloc(1, 4+len);

    *msgdat = packet;
    if (packet == NULL) {
        report_no_mem();
        return PAM_AUTH_ERR;
    }

    pamc_write__u32(packet, len);
    place[0].msg_style = PAM_BINARY_PROMPT;
    place[0].msg = (const char *)packet;
    msg[0] = &place[0];

    return PAM_SUCCESS;
}

/*
   Pass the id and the challenge to the client plug-in, enveloped in PAM
   binary prompts, and collect its single binary reply.

   On entry (*dat, *dlen) hold the id (as built by gen_id()).  That buffer
   is always cleared, freed and *dat set to NULL, whatever the outcome.
   On PAM_SUCCESS (*dat, *dlen) describe a newly malloc'ed copy of the
   reply payload, which the caller must clear and free.

   The messages are composed according to the PAMC policy:
     #0  PAMC_CONTROL_SELECT followed by the plug-in name PLUGIN_NAME;
     #1  PAMC_CONTROL_EXCHANGE followed by
             [ 4 octets placed via pamc_write__u32 ] id length
             id
             [ 4 octets placed via pamc_write__u32 ] challenge length
             challenge
   Reply #0 must consist of nothing but the control PAMC_CONTROL_EXCHANGE;
   reply #1 must start with the control PAMC_CONTROL_DONE, and the octets
   following that control form the returned payload.

   Returns PAM_SUCCESS, the conversation function's own error code if it
   fails, or PAM_AUTH_ERR for any other failure.
 */
static int ask_for_data(pam_handle_t *pamh, const struct pam_conv *conv,
    unsigned char **dat, unsigned int *dlen,
    const unsigned char *challenge, unsigned int chlen)
{
    int retval;
    unsigned char *msgdat[2];
    struct pam_message *msg[2], msgcont[2];
    struct pam_response *resp;
    unsigned int idlen;
    unsigned int control; /* PAMC control of the reply */

    (void)pamh; /* unused */

    retval = make_msg(&msgdat[0], 4+PLUGIN_LEN, &msg[0], &msgcont[0]);
    if (retval != PAM_SUCCESS) {
        memset(*dat, 0, *dlen);
        free(*dat);
        *dat = NULL;
        return PAM_AUTH_ERR;
    }
    pamc_write__u32(pamc_packet_data(msgdat[0]), PAMC_CONTROL_SELECT);
    memcpy(pamc_packet_data(msgdat[0])+4, PLUGIN_NAME, PLUGIN_LEN);

    idlen = *dlen;
    retval = make_msg(&msgdat[1], 4+4+idlen+4+chlen, &msg[1], &msgcont[1]);
    if (retval != PAM_SUCCESS) {
        memset(*dat, 0, idlen);
        free(*dat);
        *dat = NULL;
        free(msgdat[0]); /* not sensitive */
        return PAM_AUTH_ERR;
    }
    pamc_write__u32(pamc_packet_data(msgdat[1]), PAMC_CONTROL_EXCHANGE);
    pamc_write__u32(pamc_packet_data(msgdat[1])+4, idlen);
    memcpy(pamc_packet_data(msgdat[1])+4+4, *dat, idlen);
    pamc_write__u32(pamc_packet_data(msgdat[1])+4+4+idlen, chlen);
    memcpy(pamc_packet_data(msgdat[1])+4+4+idlen+4, challenge, chlen);
    memset(*dat, 0, idlen);
    free(*dat);
    *dat = NULL;

    resp = NULL;
    retval = (*conv->conv)(2, (const struct pam_message **)msg,
                           &resp, conv->appdata_ptr);
    Dprintf(("conversation returns %d", retval));
    memset(pamc_packet_data(msgdat[1]), 0, pamc_packet_length(msgdat[1]));
    free(msgdat[1]);
    free(msgdat[0]); /* not sensitive */
    memset(msgdat, 0, sizeof(msgdat));
    if (retval != PAM_SUCCESS) return retval;

    if (resp == NULL) return PAM_AUTH_ERR;
    msgdat[0] = (unsigned char *)resp[0].resp;
    msgdat[1] = (unsigned char *)resp[1].resp;
    if (msgdat[0] == NULL || msgdat[1] == NULL) {
        /* Release whichever reply did arrive; #1 may carry the payload,
           so it is cleared before being freed. */
        if (msgdat[1] != NULL)
            memset(pamc_packet_data(msgdat[1]), 0,
                   pamc_packet_length(msgdat[1]));
        free(resp[1].resp);
        free(resp[0].resp);
        free(resp);
        return PAM_AUTH_ERR;
    }

    retval = PAM_AUTH_ERR;
    do {
        if (pamc_packet_length(msgdat[0]) != 4) {
            do_log(LOG_ERR, "Bad length(%u) in conversation reply #0",
                    (unsigned int)pamc_packet_length(msgdat[0]));
            break;
        }
        *dlen = pamc_packet_length(msgdat[1]);
        if (*dlen < 4) {
            do_log(LOG_ERR, "Bad length(%u) in conversation reply #1", *dlen);
            break;
        }
        (*dlen) -= 4; /* strip the control word */

        control = pamc_read__u32(pamc_packet_data(msgdat[0]));
        Dprintf(("First control is %u", control));
        if (control != PAMC_CONTROL_EXCHANGE) break;
        control = pamc_read__u32(pamc_packet_data(msgdat[1]));
        Dprintf(("Second control is %u", control));
        if (control != PAMC_CONTROL_DONE) break;

        /* malloc(0) may legitimately return NULL; request at least one
           octet so an empty payload is not reported as "no memory". */
        *dat = malloc(*dlen > 0 ? *dlen : 1);
        if (*dat == NULL) {
            report_no_mem();
            break;
        }
        memcpy(*dat, pamc_packet_data(msgdat[1])+4, *dlen);
        /* Only dump octets that actually exist in the payload. */
        if (*dlen >= 8) {
            Dprintf(("The data is: %02x %02x %02x %02x %02x %02x %02x %02x",
                (*dat)[0], (*dat)[1], (*dat)[2], (*dat)[3],
                (*dat)[4], (*dat)[5], (*dat)[6], (*dat)[7]));
        } else {
            Dprintf(("The data is only %u octets long", *dlen));
        }
        retval = PAM_SUCCESS;
    } while (0);

    memset(pamc_packet_data(msgdat[1]), 0, pamc_packet_length(msgdat[1]));

    free(resp[1].resp); /* the space is cleared (accessed via msgdat) */
    free(resp[0].resp); /* I don't bother to clear it */
    free(resp);

    return retval;
}


/*
   Common prologue of gen_id() and gen_challenge().

   Must be expanded where `username', `userlen', `hostname' (a char array)
   and `hostlen' are in scope.  It sets userlen to strlen(username) and
   fills hostname/hostlen from gethostname().  If gethostname() fails, or
   the name fills the whole array without a terminating NUL, it logs an
   error and makes the *enclosing function* return PAM_AUTH_ERR.
   The trailing (void)0 lets an invocation be written as `gen_init();'.
 */
#define gen_init() \
    userlen = strlen(username);                                             \
    if (gethostname(hostname, sizeof(hostname)) == -1) {                    \
        do_log(LOG_ERR, "gethostname() failed.");                           \
        return PAM_AUTH_ERR;                                                \
    }                                                                       \
    for (hostlen = 0;                                                       \
         hostlen < sizeof(hostname) && hostname[hostlen]; hostlen++);       \
    if (hostlen >= sizeof(hostname)) {                                      \
        do_log(LOG_ERR,                                                     \
                "unterminated hostname is returned from gethostname()!");   \
        return PAM_AUTH_ERR;                                                \
    }                                                                       \
(void)0

/*
   Build the identity "user@host" as raw octets (no terminating NUL).
   On success *dat points to a freshly malloc'ed buffer of *dlen octets
   owned by the caller.  Returns PAM_SUCCESS, or PAM_AUTH_ERR when the
   host name cannot be obtained or memory runs out.
 */
static int gen_id(const char *username,
    unsigned char **dat, unsigned int *dlen)
{
    /* userlen, hostname and hostlen are the names gen_init() fills in. */
    int userlen;
    char hostname[256];
    int hostlen;
    unsigned char *out;

    gen_init();

    *dlen = userlen+1+hostlen;
    out = malloc(*dlen);
    *dat = out;
    if (out == NULL) {
        report_no_mem();
        return PAM_AUTH_ERR;
    }

    memcpy(out, username, userlen);
    out[userlen] = '@';
    memcpy(out + userlen + 1, hostname, hostlen);

    return PAM_SUCCESS;
}

/*
   Generate the challenge sent to the client, as raw octets:
       username '@' hostname | struct timeval (now) | pid_t (getpid())
   The challenge contains no random component, only these values.
   On success *challenge points to a freshly malloc'ed buffer of *chlen
   octets owned by the caller.  Returns PAM_SUCCESS, or PAM_AUTH_ERR when
   the host name cannot be obtained or memory runs out.
 */
static int gen_challenge(const char *username,
    unsigned char **challenge, unsigned int *chlen)
{
    int userlen; /* :-) */
    char hostname[256];
    int hostlen;
    struct timeval timeval;
    pid_t pid;
    unsigned char *p;

    gen_init();
    gettimeofday(&timeval, NULL);
    pid = getpid();

    /* Account for every piece copied below, including the "username@"
       prefix; leaving it out would overrun the buffer by userlen+1. */
    *chlen = userlen+1+hostlen+sizeof(timeval)+sizeof(pid);
    *challenge = malloc(*chlen);
    if (*challenge == NULL) {
        report_no_mem();
        return PAM_AUTH_ERR;
    }

    p = *challenge;
    memcpy(p, username, userlen); p += userlen;
    *p++ = '@';
    memcpy(p, hostname, hostlen); p += hostlen;
    memcpy(p, &timeval, sizeof(timeval)); p += sizeof(timeval);
    memcpy(p, &pid, sizeof(pid));

    return PAM_SUCCESS;
}

#undef gen_init


/*
   Put the filesystem uid/gid back to `uid' and `gid'.
   setfsuid()/setfsgid() return the previous value rather than success or
   failure, so each one is issued twice: the value returned by the second
   call shows whether the first took effect.  A failure is only logged
   (LOG_ALERT); nothing is reported to the caller.
 */
static void restore_cred(uid_t uid, gid_t gid)
{
    int prev;

    (void)setfsuid(uid);
    prev = setfsuid(uid);
    if (prev != uid)
        do_log(LOG_ALERT, "Can\'t restore fsuid!");

    (void)setfsgid(gid);
    prev = setfsgid(gid);
    if (prev != gid)
        do_log(LOG_ALERT, "Can\'t restore fsgid!");
}

/*
   Look up the secret for `id' in the user's file of allowed secrets.

   Each line has the form
       <id> ' ' { ' ' } <secret>
   The first line whose id matches exactly decides the outcome: the secret
   (the rest of the line after the separating blanks, newline excluded)
   must be at least 16 octets long.  On success *secret points to a freshly
   malloc'ed copy of *seclen octets (not NUL-terminated), which the caller
   must clear and free.

   Lines longer than the internal buffer are examined only up to the
   buffer size; the remaining fragments of such a line are skipped instead
   of being parsed as separate lines.  The line buffer is wiped before
   returning on every path, since it may hold this or another id's secret.

   Returns PAM_SUCCESS or PAM_AUTH_ERR.
 */
static int lookup_secret(FILE *file,
    const unsigned char *id, unsigned int id_len,
    unsigned char **secret, unsigned int *seclen)
{
    enum { MIN_SECRET_LEN = 16 }; /* magic minimal secret length */
    char buf[1024], *p;
    size_t linelen;
    int in_long_line = 0; /* previous chunk filled buf without a newline */
    int retval = PAM_AUTH_ERR;

    while (fgets(buf, sizeof(buf), file) != NULL) {
        int is_fragment = in_long_line;

        linelen = strlen(buf);
        /* A full buffer without '\n' means the line continues in the next
           chunk; that continuation must not be parsed as a new line. */
        in_long_line = (linelen == sizeof(buf)-1 && buf[linelen-1] != '\n');
        if (is_fragment) continue;

        if (linelen > 0 && buf[linelen-1] == '\n') linelen--;
        /* Need at least the id and one blank.  Done in size_t so a huge
           id_len cannot wrap around and let buf[id_len] go out of bounds. */
        if (linelen <= id_len) continue;
        if (memcmp(buf, id, id_len) != 0 || buf[id_len] != ' ')
            continue;

        /* Skip the separating blanks; the rest of the line is the secret. */
        p = buf + id_len + 1;
        linelen -= (size_t)id_len + 1;
        while (linelen > 0 && *p == ' ') {
            p++;
            linelen--;
        }
        if (linelen < MIN_SECRET_LEN) break;
        *secret = malloc(linelen);
        if (*secret == NULL) {
            report_no_mem();
            break;
        }
        *seclen = (unsigned int)linelen;
        memcpy(*secret, p, linelen);
        retval = PAM_SUCCESS;
        break;
    }

    memset(buf, 0, sizeof(buf));
    return retval;
}

/*
   Find the secret the user allows for `id'.

   The file is ALLOW_SECRET_FILE below the home directory.  The home
   directory is $HOME from the PAM environment when that is set and
   non-empty, and the passwd entry's pw_dir otherwise.  The file is opened
   with the filesystem uid/gid switched to the user's.  It is accepted only
   when owned by the user and without group/other permission bits or the
   owner execute bit.  The filesystem credentials are restored before
   returning.

   On success *secret / *seclen receive a malloc'ed secret (see
   lookup_secret()) that the caller must clear and free.
   Returns PAM_SUCCESS, PAM_AUTH_ERR, or pam_get_user()'s error code.
 */
static int find_secret(pam_handle_t *pamh,
    const unsigned char *id, unsigned int idlen,
    unsigned char **secret, unsigned int *seclen)
{
    const char *username;
    const char *envhome;
    struct passwd *pw;
    int retval;
    char filename[1024];
    int r;
    FILE *file;
    struct stat statbuf;
    uid_t uid;
    gid_t gid;

    /* "%.*s" takes an int precision; cap it so a large idlen can neither
       mismatch the type nor become negative (which would make vsyslog
       treat `id' as a NUL-terminated string). */
    Dprintf(("find_secret is called for %.*s",
             (int)(idlen < 256 ? idlen : 256), id));

    r = 0;
    filename[0] = '\0';

    retval = pam_get_user(pamh, &username, NULL);
    if (retval != PAM_SUCCESS) return retval;

    envhome = pam_getenv(pamh, "HOME");
    if (envhome != NULL) {
        r = strlen(envhome);
        if (r+1+ALLOW_SECRET_LEN+1 > sizeof(filename)) {
            do_log(LOG_ERR, "HOME too long (%d) for %s", r, username);
            return PAM_AUTH_ERR;
        }
        if (r > 0) memcpy(filename, envhome, r);
    }

    pw = getpwnam(username);
    if (pw == NULL) {
        do_log(LOG_ALERT, "No passwd entry for %s", username);
        return PAM_AUTH_ERR;
    }

    if (r <= 0) {
        r = strlen(pw->pw_dir);
        if (r <= 0 || r+1+ALLOW_SECRET_LEN+1 > sizeof(filename)) {
            do_log(LOG_ALERT, "Bad home dir for %s", username);
            return PAM_AUTH_ERR;
        }
        memcpy(filename, pw->pw_dir, r);
    }
    filename[r] = '/';
    memcpy(filename+r+1, ALLOW_SECRET_FILE, ALLOW_SECRET_LEN+1);

    /* setfsuid()/setfsgid() return the previous value, not an error code;
       a second call with the same argument tells whether the first one
       took effect. */
    uid = setfsuid(pw->pw_uid);
    if (setfsuid(pw->pw_uid) != pw->pw_uid) {
        do_log(LOG_ALERT, "Can\'t set fsuid");
        return PAM_AUTH_ERR;
    }
    gid = setfsgid(pw->pw_gid);
    if (setfsgid(pw->pw_gid) != pw->pw_gid) {
        do_log(LOG_ALERT, "Setfsuid succeded but setfsgid failed!");
        restore_cred(uid, gid);
        return PAM_AUTH_ERR;
    }

    /* From here on every exit must restore the credentials. */
    retval = PAM_AUTH_ERR;
    file = fopen(filename, "r");
    if (file == NULL) {
        do_log(LOG_ERR, "User %s has no allow_secret file", username);
        goto out_restore;
    }
    /* fstat() on the opened stream, so the checked file is the read one. */
    if (fstat(fileno(file), &statbuf) == -1) {
        do_log(LOG_ERR, "Can\'t stat %s\'s allow_secret file", username);
        goto out_close;
    }
    if (statbuf.st_uid != pw->pw_uid ||
        (statbuf.st_mode &
         (S_IXUSR|S_IRGRP|S_IWGRP|S_IXGRP|S_IROTH|S_IWOTH|S_IXOTH)))
    {
        do_log(LOG_ERR, "%s\'s allow_secret file has illegal permissions",
                username);
        Dprintf(("Owner %u, perm %u", (unsigned)statbuf.st_uid,
                (unsigned)statbuf.st_mode));
        goto out_close;
    }

    retval = lookup_secret(file, id, idlen, secret, seclen);

out_close:
    fclose(file);
out_restore:
    restore_cred(uid, gid);
    return retval;
}


/*
   Authenticate the user with a shared-secret challenge/response exchange.

   1. Build the id "user@host" and a challenge (gen_id, gen_challenge).
   2. Send both to the client plug-in (ask_for_data).  The reply payload
      must be laid out as
          [ 4 octets, pamc_read__u32 ] id length
          id
          [ 4 octets ]  not examined here (NOTE(review): presumably the
                        digest length -- confirm against the client)
          [ 16 octets ] MD5 digest
   3. Look up the secret for the returned id in the user's allow_secret
      file (find_secret).
   4. Succeed iff the digest equals MD5(challenge || secret).

   Recognised module argument: "debug".  Returns PAM_SUCCESS, PAM_AUTH_ERR
   or an error code reported by libpam / the conversation function.
 */
PAM_EXTERN
int pam_sm_authenticate(pam_handle_t *pamh, int flags,
    int argc, const char **argv)
{
    int i;
    int retval;
    const struct pam_conv *conv;
    const char *username;
    unsigned char *dat, *challenge, *secret, *id;
    unsigned int dlen, chlen, seclen, idlen;
    struct MD5Context context;
    unsigned char digest[16];
    unsigned char diff;

    (void)flags; /* unused */

    openlog("PAM_sh_secret", LOG_CONS | LOG_PID, LOG_AUTHPRIV);
    Dprintf(("called."));

    /* Every exit goes through do_return so the syslog link is closed. */
#define do_return(x) return(closelog(),(x))

    for (i = 0; i < argc; i++) {
        if (!strcmp(argv[i], "debug")) debug = 1;
        else {
            do_log(LOG_ERR, "illegal option \"%s\"", argv[i]);
            break;
        }
    }

    retval = pam_get_user(pamh, &username, NULL);
    if (retval != PAM_SUCCESS) do_return(retval);

    retval = pam_get_item(pamh, PAM_CONV, (const void **)&conv);
    if (retval != PAM_SUCCESS) do_return(retval);

    dat = NULL; /* for safety */
    challenge = NULL;
    secret = NULL;
    id = NULL;
    retval = gen_id(username, &dat, &dlen);
    Dprintf(("gen_id returns %d", retval));
    if (retval != PAM_SUCCESS) do_return(retval);
    retval = gen_challenge(username, &challenge, &chlen);
    Dprintf(("gen_challenge returns %d", retval));
    if (retval != PAM_SUCCESS) {
        free(dat); /* the id from gen_id(); not sensitive */
        dat = NULL;
        do_return(retval);
    }

    /*
       Ask for checksum of id+challenge.
       Result is placed in `dat'.
     */
    retval = ask_for_data(pamh, conv, &dat, &dlen, challenge, chlen);
    Dprintf(("ask_for_data returns %d", retval));
    if (retval != PAM_SUCCESS) {
        free(challenge); /* nothing sensitive */
        challenge = NULL;
        do_return(retval);
    }

    /*
       The reply must hold at least the fixed part (id length, 4 further
       octets, digest).  The client-supplied idlen is then compared with
       what remains; unlike computing 4+idlen+4+sizeof(digest), this
       cannot wrap around for a huge idlen.
     */
    idlen = 0;
    if (dlen < 4+4+sizeof(digest) ||
            (idlen = pamc_read__u32(dat),
             idlen != dlen-(4+4+sizeof(digest)))) {
        Dprintf(("reply has illegal dlen=%u with idlen=%u", dlen, idlen));
        if (dlen >= 8) {
            Dprintf(("The data was: %02x %02x %02x %02x %02x %02x %02x %02x",
                dat[0], dat[1], dat[2], dat[3],
                dat[4], dat[5], dat[6], dat[7]));
        }
        memset(dat, 0, dlen);
        free(dat);
        dat = NULL;
        free(challenge); /* nothing sensitive */
        challenge = NULL;
        do_return(PAM_AUTH_ERR);
    }
    id = dat+4;

    /* Read user's file with allowed secrets. */
    retval = find_secret(pamh, id, idlen, &secret, &seclen);
    Dprintf(("find_secret returns %d", retval));
    if (retval != PAM_SUCCESS) {
        memset(id, 0, idlen);
        free(dat); /* the rest isn't sensitive */
        dat = NULL;
        free(challenge); /* nothing sensitive */
        challenge = NULL;
        do_return(retval);
    }

    /* Verify MD5sum */
    MD5Init(&context);
    MD5Update(&context, challenge, chlen);
    free(challenge);            /* nothing sensitive */
    challenge = NULL;
    MD5Update(&context, secret, seclen);
    memset(secret, 0, seclen);  /* it's sensitive information */
    free(secret);
    secret = NULL;
    MD5Final(digest, &context);
    memset(&context, 0, sizeof(context)); /* state derived from the secret */

    /* Compare without an early exit so the timing does not reveal how
       many leading octets of the digest matched. */
    diff = 0;
    for (i = 0; i < (int)sizeof(digest); i++)
        diff |= (unsigned char)(digest[i] ^ dat[4+idlen+4+i]);
    retval = diff ? PAM_AUTH_ERR : PAM_SUCCESS;
    memset(digest, 0, sizeof(digest));
    memset(id, 0, idlen);
    free(dat); /* the rest isn't sensitive */
    dat = NULL;

    Dprintf(("%d is for return", retval));
    do_return(retval);

#undef do_return
}


/*
   Credential management hook.  This module establishes no credentials:
   the arguments are ignored and PAM_SUCCESS is always returned.
 */
PAM_EXTERN
int pam_sm_setcred(pam_handle_t *pamh, int flags,
    int argc, const char **argv)
{
    (void)pamh;
    (void)flags;
    (void)argc;
    (void)argv;

    return PAM_SUCCESS;
}


/* Logging implementation */
/*
   Log an out-of-memory condition at LOG_CRIT.

   The syslog connection is deliberately left open.  Per syslog(3), closing
   it here would make any later message of the same call reopen the log
   with default settings, losing the ident and facility that
   pam_sm_authenticate() set with openlog().  pam_sm_authenticate() closes
   the log itself on every return path.
 */
static void report_no_mem(void)
{
    syslog(LOG_CRIT, "no memory");
}

/*
   Log (at LOG_DEBUG) the "<source file>:<line>" header that Dprintf()
   emits right before the actual debug message.
 */
static void dbg_head(const char *file, int line)
{
    syslog(LOG_DEBUG,
           "%s:%d wants to report:",
           file, line);
}

/* printf-style debug message, forwarded to syslog at LOG_DEBUG. */
static void dbg_printf(const char *fmt, ...)
{
    va_list ap;

    va_start(ap, fmt);
    vsyslog(LOG_DEBUG, fmt, ap);
    va_end(ap);
}

/* printf-style message, forwarded to syslog with priority `pri'. */
static void do_log(int pri, const char *fmt, ...)
{
    va_list ap;

    va_start(ap, fmt);
    vsyslog(pri, fmt, ap);
    va_end(ap);
}


#ifdef PAM_STATIC

/* static module data */

/*
   Module descriptor, compiled only when PAM_STATIC is defined.  It names
   the module and lists pam_sm_authenticate and pam_sm_setcred; the four
   remaining entries are NULL because this module defines no other hooks.
   NOTE(review): presumably looked up by libpam when modules are linked
   statically -- confirm the field order against <security/pam_modules.h>.
 */
struct pam_module _pam_sh_secret_modstruct = {
    "pam_sh_secret",
    pam_sm_authenticate,
    pam_sm_setcred,
    NULL,
    NULL,
    NULL,
    NULL,
};

#endif /* PAM_STATIC */
