/* Copyright (C) 2021,2022 fef . All rights reserved. */ /* * Ok, i'm gonna be honest, this is one of those functions where i wish i was * writing them in Rust rather than C. Anyway, the tricky part of this entire * implementation is that we can't call kmalloc() because that uses kprintf() * internally for debug messages and the such, so it could potentially end in an * infinite loop. But, since format strings allow arbitrary padding, we need to * resort to our good old friend alloca(). * Also, this code is probably buggy as shit and i only tested some basic format * sequences, so don't push it too far. Be gentle. * I hope i never have to touch this again. */ #include #include #include #include #include #include static struct kprintf_printer *printer = NULL; int kprintf_set_printer(struct kprintf_printer *new) { int ret = 0; if (printer != NULL) ret = printer->flush(printer); printer = new; return ret; } enum length_modifier { LENGTH_DEFAULT = 0, LENGTH_H, LENGTH_HH, LENGTH_L, LENGTH_LL, LENGTH_J, LENGTH_T, LENGTH_Z, }; struct fmt_sequence { isize (*render)(const struct fmt_sequence *sequence, va_list *ap); unsigned int min_width; unsigned int max_precision; enum length_modifier length_modifier; struct { bool hash; bool zero; bool minus; bool space; bool plus; bool apos; } flags __packed; /* save some bytes on the stack :) */ bool uppercase; }; static void parse_fmt_sequence(struct fmt_sequence *sequence, const char **restrict posptr); /** @brief Write a NUL terminated string using the current `printer`. */ static isize write_asciz(const char *s); /** @brief Write a specific amount of bytes using the current `printer`. */ static isize write_bytes(const void *buf, usize len); int kvprintf(const char *fmt, va_list _args) { isize ret = 0; const char *tmp = fmt; va_list args; va_copy(args, _args); while (*tmp != '\0') { if (*tmp++ == '%') { /* write out everything we have so far (minus one char for %) */ isize write_ret = write_bytes(fmt, (usize)tmp - (usize)fmt - 1); if (write_ret < 0) { ret = write_ret; break; } ret += write_ret; isize fmt_ret = 0; struct fmt_sequence sequence; parse_fmt_sequence(&sequence, &tmp); if (sequence.render != NULL) fmt_ret = sequence.render(&sequence, &args); /* * act as if the current position were the beginning in * order to make the first step of this if block easier */ fmt = tmp; if (fmt_ret < 0) { ret = fmt_ret; break; } ret += fmt_ret; } } if (tmp != fmt && ret >= 0) { isize render_ret = write_bytes(fmt, (usize)tmp - (usize)fmt); if (render_ret < 0) ret = render_ret; else ret += render_ret; } isize flush_ret = printer->flush(printer); if (flush_ret < 0) ret = flush_ret; else ret += flush_ret; va_end(args); return (int)ret; } int kprintf(const char *fmt, ...) { int ret; va_list args; va_start(args, fmt); ret = kvprintf(fmt, args); va_end(args); return ret; } static isize render_c(const struct fmt_sequence *sequence, va_list *ap); static isize render_d(const struct fmt_sequence *sequence, va_list *ap); static isize render_o(const struct fmt_sequence *sequence, va_list *ap); static isize render_p(const struct fmt_sequence *sequence, va_list *ap); static isize render_s(const struct fmt_sequence *sequence, va_list *ap); static isize render_u(const struct fmt_sequence *sequence, va_list *ap); static isize render_x(const struct fmt_sequence *sequence, va_list *ap); static isize render_percent(const struct fmt_sequence *sequence, va_list *ap); /* * Oh boi this is gonna be fun. * So, this is basically a step by step implementation of the FreeBSD manpage * for printf(), except that there is no support for all the deprecated * specifiers and the (imho insane) $ directive: * */ void parse_fmt_sequence(struct fmt_sequence *sequence, const char **restrict posptr) { memset(sequence, 0, sizeof(*sequence)); if (**posptr == '%') { /* %% */ sequence->render = render_percent; *posptr += 1; return; } /* * parse optional flags */ bool continue_parse_flags = true; while (continue_parse_flags) { switch (**posptr) { case '#': sequence->flags.hash = true; break; case '0': sequence->flags.zero = true; break; case '-': sequence->flags.minus = true; break; case ' ': /* the FreeBSD manpage says plus overrides space if both are used */ if (!sequence->flags.plus) sequence->flags.space = true; break; case '+': sequence->flags.plus = true; sequence->flags.space = false; break; case '\'': sequence->flags.apos = true; break; default: continue_parse_flags = false; break; } *posptr += 1; } *posptr -= 1; /* * parse optional minimum digits */ while (**posptr >= '0' && **posptr <= '9') { sequence->min_width *= 10; sequence->min_width += **posptr - '0'; if (sequence->max_precision > 128) sequence->max_precision = 128; *posptr += 1; } /* * parse optional maximum precision */ if (**posptr == '.') { *posptr += 1; while (**posptr >= '0' && **posptr <= '9') { sequence->max_precision *= 10; sequence->max_precision += **posptr - '0'; /* sanitize length (prevents stack overflow) */ if (sequence->max_precision > 128) sequence->max_precision = 128; *posptr += 1; } } /* * parse optional length modifier */ switch (**posptr) { case 'h': case 'H': if ((*posptr)[1] == 'h' || (*posptr)[1] == 'H') { sequence->length_modifier = LENGTH_HH; *posptr += 2; } else { sequence->length_modifier = LENGTH_H; *posptr += 1; } break; case 'l': case 'L': if ((*posptr)[1] == 'l' || (*posptr)[1] == 'L') { sequence->length_modifier = LENGTH_LL; *posptr += 2; } else { sequence->length_modifier = LENGTH_L; *posptr += 1; } break; case 'j': case 'J': sequence->length_modifier = LENGTH_J; *posptr += 1; break; case 't': case 'T': sequence->length_modifier = LENGTH_T; *posptr += 1; break; case 'z': case 'Z': sequence->length_modifier = LENGTH_Z; *posptr += 1; break; default: break; } /* * parse type specifier */ switch (**posptr) { case 'C': sequence->length_modifier = LENGTH_L; /* fall through */ case 'c': sequence->render = render_c; break; case 'd': case 'i': sequence->render = render_d; break; case 'o': sequence->render = render_o; break; case 'P': sequence->uppercase = true; /* fall through */ case 'p': sequence->render = render_p; break; case 'S': sequence->length_modifier = LENGTH_L; /* fall through */ case 's': sequence->render = render_s; break; case 'u': sequence->render = render_u; break; case 'X': sequence->uppercase = true; /* fall through */ case 'x': sequence->render = render_x; break; default: sequence->render = NULL; break; } *posptr += 1; } static ssize_t render_c(const struct fmt_sequence *sequence, va_list *ap) { /* we don't support wchars until we have a UTF-8 encoder */ if (sequence->length_modifier != LENGTH_DEFAULT) return -1; char val = (char)va_arg(*ap, int); return printer->write(printer, &val, sizeof(val)); } static ssize_t render_s(const struct fmt_sequence *sequence, va_list *ap) { /* * the string is a wchar_t if LENGTH_L is set, but that would require * a full UTF-8 encoder which i won't write in the near future. Cope. */ if (sequence->length_modifier != LENGTH_DEFAULT) return -1; const char *s = va_arg(*ap, char *); if (s == nil) return write_asciz("(null)"); if (sequence->max_precision) return write_bytes(s, strnlen(s, sequence->max_precision)); else return write_asciz(s); } static inline intmax_t get_arg_signed(const struct fmt_sequence *sequence, va_list *ap) { switch (sequence->length_modifier) { case LENGTH_H: case LENGTH_HH: case LENGTH_DEFAULT: /* short and char will be promoted to int with parameter passing */ return va_arg(*ap, int); case LENGTH_L: return va_arg(*ap, long); case LENGTH_LL: return va_arg(*ap, long long); case LENGTH_Z: return va_arg(*ap, isize); case LENGTH_J: return va_arg(*ap, intmax_t); case LENGTH_T: return va_arg(*ap, intptr_t); } } static inline uintmax_t get_arg_unsigned(const struct fmt_sequence *sequence, va_list *ap) { switch (sequence->length_modifier) { case LENGTH_H: case LENGTH_HH: case LENGTH_DEFAULT: /* short and char will be promoted to int with parameter passing */ return va_arg(*ap, unsigned int); case LENGTH_L: return va_arg(*ap, unsigned long); case LENGTH_LL: return va_arg(*ap, unsigned long long); case LENGTH_Z: return va_arg(*ap, usize); case LENGTH_J: return va_arg(*ap, uintmax_t); case LENGTH_T: return va_arg(*ap, uintptr_t); } } static isize render_d(const struct fmt_sequence *sequence, va_list *ap) { isize ret = 0; intmax_t val = get_arg_signed(sequence, ap); if (val < 0) { val = -val; ret = write_asciz("-"); if (ret < 0) return ret; } else if (sequence->flags.plus) { ret = write_asciz("+"); if (ret < 0) return ret; } else if (sequence->flags.space) { ret = write_asciz(" "); if (ret < 0) return ret; } usize len = 20; /* 2**64 has 20 decimal digits, let's hope intmax_t isn't 128 bits */ if (sequence->min_width > len) len = sequence->min_width; char *buf = alloca(len); char *pos = &buf[len - 1]; do { *pos-- = (char)(val % 10) + '0'; /* NOLINT */ val /= 10; } while (val > 0); char fillchr; if (sequence->flags.zero) fillchr = '0'; else fillchr = ' '; while (sequence->min_width > len - (pos - buf) - 1) *pos-- = fillchr; pos += 1; isize tmp = write_bytes(pos, len - (pos - buf)); if (tmp > 0) ret += tmp; else ret = tmp; return ret; } static isize render_o(const struct fmt_sequence *sequence, va_list *ap) { isize ret = 0; if (sequence->flags.plus) { ret = write_asciz("+"); if (ret < 0) return ret; } else if (sequence->flags.space) { ret = write_asciz(" "); if (ret < 0) return ret; } uintmax_t val = get_arg_unsigned(sequence, ap); usize len = 22; /* 2**64 has 22 octal digits, let's hope intmax_t isn't 128 bits */ if (sequence->min_width > len) len = sequence->min_width; char *buf = alloca(len); char *pos = &buf[len - 1]; do { *pos-- = (char)(val % 010) + '0'; /* NOLINT */ val /= 010; } while (val > 0); char fillchr; if (sequence->flags.zero) fillchr = '0'; else fillchr = ' '; while (sequence->min_width > len - (pos - buf) - 1) *pos-- = fillchr; pos++; isize tmp = write_bytes(pos, len - (pos - buf)); if (tmp > 0) ret += tmp; else ret = tmp; return ret; } static const char *const digit_table_smol = "0123456789abcdef"; static const char *const digit_table_big = "0123456789ABCDEF"; static isize render_p(const struct fmt_sequence *sequence, va_list *ap) { /* 2 hex digits per byte + 2 for 0x prefix */ char buf[sizeof(uintptr_t) * 2 + 2]; char *pos = &buf[sizeof(uintptr_t) * 2 + 1]; const char *digit_table; uintptr_t ptr = va_arg(*ap, uintptr_t); buf[0] = '0'; buf[1] = 'x'; if (sequence->uppercase) digit_table = digit_table_big; else digit_table = digit_table_smol; while (pos > &buf[1]) { *pos-- = digit_table[ptr % 0x10]; ptr /= 0x10; } return write_bytes(buf, sizeof(buf)); } static isize render_u(const struct fmt_sequence *sequence, va_list *ap) { isize ret = 0; if (sequence->flags.plus) { ret = write_asciz("+"); if (ret < 0) return ret; } else if (sequence->flags.space) { ret = write_asciz(" "); if (ret < 0) return ret; } uintmax_t val = get_arg_unsigned(sequence, ap); usize len = 20; /* 2^64 has 20 decimal digits, let's hope intmax_t isn't 128 bits */ if (sequence->min_width > len) len = sequence->min_width; char *buf = alloca(len); char *pos = &buf[len - 1]; do { *pos-- = (char)(val % 10) + '0'; /* NOLINT */ val /= 10; } while (val > 0); char fillchr; if (sequence->flags.zero) fillchr = '0'; else fillchr = ' '; while (sequence->min_width > len - (pos - buf) - 1) *pos-- = fillchr; pos++; isize tmp = write_bytes(pos, len - (pos - buf)); if (tmp > 0) ret += tmp; else ret = tmp; return ret; } static isize render_x(const struct fmt_sequence *sequence, va_list *ap) { char *buf; usize len = sizeof(uintmax_t) * 2; /* 2 hex digits per byte */ if (len < sequence->min_width) len = sequence->min_width; buf = alloca(len); char *pos = &buf[len - 1]; uintmax_t val = get_arg_unsigned(sequence, ap); const char *digit_table; if (sequence->uppercase) digit_table = digit_table_big; else digit_table = digit_table_smol; do { *pos-- = digit_table[val % 0x10]; val /= 0x10; } while (val > 0); char fillchr; if (sequence->flags.zero) fillchr = '0'; else fillchr = ' '; while (sequence->min_width > len - (pos - buf) - 1) *pos-- = fillchr; pos++; return write_bytes(pos, len - (pos - buf)); } static isize render_percent(const struct fmt_sequence *sequence, va_list *ap) { return write_asciz("%"); } static inline isize write_asciz(const char *s) { return printer->write(printer, s, strlen(s)); } static inline isize write_bytes(const void *buf, usize len) { return printer->write(printer, buf, len); }