/* Copyright (C) 2021,2022 fef <owo@fef.moe>.  All rights reserved. */

#include <gay/bits.h>
#include <gay/types.h>

#include <limits.h>

/*
 * Set `count` consecutive bits in `bitfield`, beginning at bit index `first`.
 * Bit 0 is the least significant bit of bitfield[0], and every longword
 * holds LONG_BIT bits.  A `count` of 0 leaves the bitfield untouched.
 */
void bit_set_range(unsigned long *bitfield, usize first, usize count)
{
	/* cursor to the longword holding the first bit, and the bit's offset in it */
	unsigned long *word = bitfield + first / LONG_BIT;
	unsigned int offset = first % LONG_BIT;
	usize end = offset + count;

	/* fast path: the range stops short of the top bit of this longword */
	if (end < LONG_BIT) {
		unsigned long below_end = (1lu << end) - 1;	/* bits [0, end) */
		unsigned long below_start = (1lu << offset) - 1;	/* bits [0, offset) */
		*word |= below_end & ~below_start;
		return;
	}

	/* leading partial longword: everything from `offset` up to the top bit */
	if (offset != 0) {
		*word |= ~0lu << offset;
		word++;
		count -= LONG_BIT - offset;
	}

	/* whole longwords in the middle are simply overwritten with all ones */
	for (; count >= LONG_BIT; count -= LONG_BIT) {
		*word++ = ~0lu;
	}

	/* trailing partial longword: the lowest `count` bits, if any are left */
	if (count != 0) {
		*word |= (1lu << count) - 1;
	}
}

/*
 * Clear `count` consecutive bits in `bitfield`, beginning at bit index `first`.
 * Structurally the mirror image of bit_set_range: the same head / full
 * longword / tail split, but masking bits out instead of in.
 * A `count` of 0 leaves the bitfield untouched.
 */
void bit_clr_range(unsigned long *bitfield, usize first, usize count)
{
	/* cursor to the longword holding the first bit, and the bit's offset in it */
	unsigned long *word = bitfield + first / LONG_BIT;
	unsigned int offset = first % LONG_BIT;
	usize end = offset + count;

	/* fast path: the range stops short of the top bit of this longword */
	if (end < LONG_BIT) {
		unsigned long below_end = (1lu << end) - 1;	/* bits [0, end) */
		unsigned long below_start = (1lu << offset) - 1;	/* bits [0, offset) */
		/* keep only what lies below the range or at/above its end */
		*word &= below_start | ~below_end;
		return;
	}

	/* leading partial longword: keep only the bits below `offset` */
	if (offset != 0) {
		*word &= (1lu << offset) - 1;
		word++;
		count -= LONG_BIT - offset;
	}

	/* whole longwords in the middle are simply zeroed */
	for (; count >= LONG_BIT; count -= LONG_BIT) {
		*word++ = 0;
	}

	/* trailing partial longword: drop the lowest `count` bits, if any are left */
	if (count != 0) {
		*word &= ~((1lu << count) - 1);
	}
}

/*
 * Test whether the bit at index `pos` is set in `bitfield`.
 * Bit 0 is the least significant bit of bitfield[0], and every longword
 * holds LONG_BIT bits.  Returns true if the bit is 1, false otherwise.
 */
bool bit_tst(const unsigned long *bitfield, usize pos)
{
	usize index = pos / LONG_BIT;
	/* the constant must be an unsigned long: shifting a plain int `1` by
	 * the width of int or more is undefined and can never reach the upper
	 * bits of a longword */
	unsigned long mask = 1lu << (pos % LONG_BIT);
	return (bitfield[index] & mask) != 0;
}

/*
 * Set the single bit at index `pos` in `bitfield`.
 * Bit 0 is the least significant bit of bitfield[0], and every longword
 * holds LONG_BIT bits.  All other bits are left unchanged.
 */
void bit_set(unsigned long *bitfield, usize pos)
{
	usize index = pos / LONG_BIT;
	/* the constant must be an unsigned long: shifting a plain int `1` by
	 * the width of int or more is undefined and can never reach the upper
	 * bits of a longword */
	unsigned long mask = 1lu << (pos % LONG_BIT);
	bitfield[index] |= mask;
}

/*
 * Clear the single bit at index `pos` in `bitfield`.
 * Bit 0 is the least significant bit of bitfield[0], and every longword
 * holds LONG_BIT bits.  All other bits are left unchanged.
 */
void bit_clr(unsigned long *bitfield, usize pos)
{
	usize index = pos / LONG_BIT;
	/* the constant must be an unsigned long: shifting a plain int `1` by
	 * the width of int or more is undefined and can never reach the upper
	 * bits of a longword */
	unsigned long mask = 1lu << (pos % LONG_BIT);
	bitfield[index] &= ~mask;
}