diff --git a/usr.bin/mdo/mdo.c b/usr.bin/mdo/mdo.c --- a/usr.bin/mdo/mdo.c +++ b/usr.bin/mdo/mdo.c @@ -5,6 +5,7 @@ */ #include +#include #include #include @@ -27,6 +28,8 @@ { struct passwd *pw; const char *username = "root"; + struct setcred wcred = SETCRED_INITIALIZER; + u_int setcred_flags = 0; bool uidonly = false; int ch; @@ -50,20 +53,45 @@ const char *errp = NULL; uid_t uid = strtonum(username, 0, UID_MAX, &errp); if (errp != NULL) - err(EXIT_FAILURE, "%s", errp); + err(EXIT_FAILURE, "invalid user ID '%s'", + username); pw = getpwuid(uid); } if (pw == NULL) err(EXIT_FAILURE, "invalid username '%s'", username); } + + wcred.sc_uid = wcred.sc_ruid = wcred.sc_svuid = pw->pw_uid; + setcred_flags |= SETCREDF_UID | SETCREDF_RUID | SETCREDF_SVUID; + if (!uidonly) { - if (initgroups(pw->pw_name, pw->pw_gid) == -1) - err(EXIT_FAILURE, "failed to call initgroups"); - if (setgid(pw->pw_gid) == -1) - err(EXIT_FAILURE, "failed to call setgid"); + /* + * If there are too many groups specified for some UID, setting + * the groups will fail. We preserve this condition by + * allocating one more group slot than allowed, as + * getgrouplist() itself is just some getter function and thus + * doesn't (and shouldn't) check the limit, and to allow + * setcred() to actually check for overflow. + */ + const long ngroups_alloc = sysconf(_SC_NGROUPS_MAX) + 2; + gid_t *const groups = malloc(sizeof(*groups) * ngroups_alloc); + int ngroups = ngroups_alloc; + + if (groups == NULL) + err(EXIT_FAILURE, "cannot allocate memory for groups"); + + getgrouplist(pw->pw_name, pw->pw_gid, groups, &ngroups); + + wcred.sc_gid = wcred.sc_rgid = wcred.sc_svgid = pw->pw_gid; + wcred.sc_supp_groups = groups + 1; + wcred.sc_supp_groups_nb = ngroups - 1; + setcred_flags |= SETCREDF_GID | SETCREDF_RGID | SETCREDF_SVGID | + SETCREDF_SUPP_GROUPS; } - if (setuid(pw->pw_uid) == -1) - err(EXIT_FAILURE, "failed to call setuid"); + + if (setcred(setcred_flags, &wcred, sizeof(wcred)) != 0) + err(EXIT_FAILURE, "calling setcred() failed"); + if (*argv == NULL) { const char *sh = getenv("SHELL"); if (sh == NULL)