/* Copyright (C) 2021,2022 fef <owo@fef.moe>.  All rights reserved. */

#include <arch/atom.h>
#include <arch/cpufunc.h>
#include <arch/page.h>

#include <gay/cdefs.h>
#include <gay/clist.h>
#include <gay/config.h>
#include <gay/kprintf.h>
#include <gay/ktrace.h>
#include <gay/mm.h>
#include <gay/poison.h>
#include <gay/systm.h>
#include <gay/types.h>
#include <gay/vm/page.h>

/*
 * XXX this implementation is still missing object caches
 */

#if CFG_POISON_SLABS
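/*
 * When slab poisoning is enabled, every allocation is prefixed with a
 * struct slab_poison.  low_poison sits immediately before the user data and
 * the high canary immediately after it (data[] itself has zero size; the
 * word that actually gets checked is high_poison[offset], where offset is
 * the exact allocation size rounded up to whole longs).
 * poison_after_free() fills the entire area with SLAB_POISON_FREE so the
 * next allocation can detect use-after-free writes, and poison_after_alloc()
 * fills it with SLAB_POISON_ALLOC so kfree() can detect out-of-bounds writes
 * to either canary.
 */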
struct slab_poison {
	void *_pad __unused;	/**< @brief That's where the freelist pointer is stored */
	void *alloc_source;	/**< @brief Code address that made the alloc call */
	u_long exact_size;	/**< @brief Requested allocation size in bytes */
	u_long low_poison;	/**< @brief Canary immediately before the user data */
	u8 data[0];		/**< @brief Start of the user data area */
	u_long high_poison[1];	/**< @brief Canary after the (aligned) user data */
};

static void poison_after_alloc(struct slab_poison *poison, u_int exact_size, void *alloc_source);
static void poison_after_free(struct slab_poison *poison);
#endif

#if CFG_DEBUG_SLAB_ALLOCS
#	define slab_debug(msg, ...) kprintf("[slab] " msg, ##__VA_ARGS__)
#	define SLAB_DEBUG_BLOCK
#	define SLAB_ASSERT KASSERT
#	if CFG_DEBUG_SLAB_ALLOCS_NOISY
#		define slab_debug_noisy(msg, ...) kprintf("[slab] " msg, ##__VA_ARGS__)
#	else
#		define slab_debug_noisy(msg, ...) ({})
#	endif
#else
#	define SLAB_DEBUG_BLOCK if (0)
#	define SLAB_ASSERT(x) ({})
#	define slab_debug(msg, ...) ({})
#	define slab_debug_noisy(msg, ...) ({})
#endif

struct slab_pool {
	const u_int entry_size;		/**< @brief Size of one entry in bytes */
	const u_int entries_per_slab;	/**< @brief Max number of entries per slab */
	atom_t total_used;		/**< @brief Total allocated entries */
	const u_int page_order;		/**< @brief Order passed to `get_pages()` */
	struct clist empty_list;	/* -> struct vm_page::link */
	struct clist partial_list;	/* -> struct vm_page::link */
	struct clist full_list;		/* -> struct vm_page::link */
	spin_t empty_lock;		/**< @brief Lock for `empty_list` */
	spin_t partial_lock;		/**< @brief Lock for `partial_list` */
	spin_t full_lock;		/**< @brief Lock for `full_list` */
	atom_t empty_count;		/**< @brief Number of empty slabs */
	atom_t partial_count;		/**< @brief Number of partially empty slabs */
	atom_t full_count;		/**< @brief Number of full slabs */
};

/*
 * Size calculations are straightforward here because all slab bookkeeping
 * (pool pointer, freelist head, free count) lives in struct vm_page rather
 * than at the beginning of the page itself.  That way, the individual slab
 * entries stay nice even powers of two and perfectly aligned.
 */
#define _MIN1(x) ((x) < 1 ? 1 : (x))
/* how many entries fit into one slab of (1 << page_order) pages */
#define POOL_ENTRIES_PER_SLAB(sz) \
	_MIN1((PAGE_SIZE << (((sz) - 1) / PAGE_SIZE)) / (sz))

#define POOL_DEFINE(sz) {					\
	.entry_size		= (sz),				\
	.entries_per_slab	= POOL_ENTRIES_PER_SLAB(sz),	\
	.total_used		= ATOM_DEFINE(0),		\
	.page_order		= ((sz) - 1) / PAGE_SIZE,	\
	.empty_lock		= SPIN_DEFINE,			\
	.partial_lock		= SPIN_DEFINE,			\
	.full_lock		= SPIN_DEFINE,			\
	.empty_count		= ATOM_DEFINE(0),		\
	.partial_count		= ATOM_DEFINE(0),		\
	.full_count		= ATOM_DEFINE(0),		\
}
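/*
 * Example (assuming PAGE_SIZE == 4096): POOL_DEFINE(1024) produces a pool
 * with page_order 0 and 4 entries per slab, whereas POOL_DEFINE(8192) gets
 * page_order 1 (a two page slab) holding a single entry.
 */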

static struct slab_pool slab_pools_normal[] = {
	POOL_DEFINE(32),
	POOL_DEFINE(64),
	POOL_DEFINE(128),
	POOL_DEFINE(256),
	POOL_DEFINE(512),
	POOL_DEFINE(1024),
	POOL_DEFINE(2048),
	POOL_DEFINE(4096),
	POOL_DEFINE(8192),
	POOL_DEFINE(16384),
	POOL_DEFINE(32768),
	{ /* terminator */ }
};
static struct slab_pool slab_pools_dma[] = {
	POOL_DEFINE(32),
	POOL_DEFINE(64),
	POOL_DEFINE(128),
	POOL_DEFINE(256),
	POOL_DEFINE(512),
	POOL_DEFINE(1024),
	{ /* terminator */ }
};
#undef _MIN1 /* we don't wanna end up using this in actual code, do we? */

static struct slab_pool *slab_zone_pools[MM_NR_ZONES] = {
	[_M_ZONE_DMA]		= slab_pools_dma,
	[_M_ZONE_NORMAL]	= slab_pools_normal,
};

static vm_page_t slab_create(struct slab_pool *pool, enum mflags flags);

void kmalloc_init(void)
{
	for (int i = 0; i < MM_NR_ZONES; i++) {
		struct slab_pool *pool = slab_zone_pools[i];

		while (pool->entry_size != 0) {
			clist_init(&pool->empty_list);
			clist_init(&pool->partial_list);
			clist_init(&pool->full_list);
			pool++;
		}
	}
}

void *kmalloc(usize size, enum mflags flags)
{
	if (size == 0)
		return nil;

#if CFG_POISON_SLABS
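	/* make room for the poison header and the canaries around the user data */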
	size += sizeof(struct slab_poison);
#endif

	SLAB_DEBUG_BLOCK {
		if (!(flags & _M_NOWAIT) && in_irq()) {
			slab_debug("kmalloc() called from irq without M_NOWAIT "
				   "(caller: %p)\n", ktrace_return_addr());
			flags |= _M_NOWAIT;
		}
	}

	SLAB_ASSERT(_M_ZONE_INDEX(flags) < ARRAY_SIZE(slab_zone_pools));
	struct slab_pool *pool = slab_zone_pools[_M_ZONE_INDEX(flags)];
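	/* find the smallest pool whose entries can hold the requested size */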
	while (pool->entry_size != 0) {
		if (pool->entry_size >= size)
			break;
		pool++;
	}

	if (pool->entry_size == 0) {
		slab_debug("Refusing to allocate %zu bytes in zone %d (limit is %u)\n",
			   size, _M_ZONE_INDEX(flags), pool[-1].entry_size);
		return nil;
	}

	slab_debug_noisy("alloc %zu bytes from zone %d, pool size %u\n",
			 size, _M_ZONE_INDEX(flags), pool->entry_size);

	/*
	 * Before locking a slab, we always remove it from its pool.
	 * This is far from optimal, because if multiple CPUs allocate from the
	 * same pool at the same time, we could end up creating several slabs
	 * with one used entry each (not to mention the overhead of the mostly
	 * unnecessary list deletions/insertions).  However, it allows me to be
	 * lazier when freeing unused slabs from a background thread since that
	 * thread knows for sure that once it has removed a slab from the empty list,
	 * it can't possibly be used for allocations anymore.
	 * This is probably not worth the overhead, though.
	 */
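	/* Lookup order: partially used slab -> completely empty slab -> new slab. */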
	vm_page_t page = INVALID_PAGE;

	/* try to use a slab that is already partially used first */
	register_t cpuflags = intr_disable();
	spin_lock(&pool->partial_lock);
	if (!clist_is_empty(&pool->partial_list)) {
		atom_dec(&pool->partial_count);
		page = clist_del_first_entry(&pool->partial_list, typeof(*page), link);
	}
	spin_unlock(&pool->partial_lock);

	if (!page) {
		/* no partially used slab available, see if we have a completely free one */
		spin_lock(&pool->empty_lock);
		if (!clist_is_empty(&pool->empty_list)) {
			atom_dec(&pool->empty_count);
			page = clist_del_first_entry(&pool->empty_list, typeof(*page), link);
		}
		spin_unlock(&pool->empty_lock);

		if (!page) {
			/* we're completely out of usable slabs, allocate a new one */
			intr_restore(cpuflags);
			page = slab_create(pool, flags);
			if (!page) {
				slab_debug("kernel OOM\n");
				return nil;
			}
			intr_disable();
		}
	}

	/* if we've made it to here, we have a slab and interrupts are disabled */
	page_lock(page);
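	/*
	 * Pop the first entry off the slab's freelist; every free entry
	 * stores the address of the next free entry in its first word.
	 */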
	void *ret = SLAB(page)->freelist;
	SLAB(page)->freelist = *SLAB(page)->freelist;
	if (--SLAB(page)->free_count == 0) {
		spin_lock(&pool->full_lock);
		clist_add(&pool->full_list, &page->link);
		spin_unlock(&pool->full_lock);
		atom_inc(&pool->full_count);
	} else {
		spin_lock(&pool->partial_lock);
		clist_add(&pool->partial_list, &page->link);
		spin_unlock(&pool->partial_lock);
		atom_inc(&pool->partial_count);
	}
	page_unlock(page);
	intr_restore(cpuflags);

	atom_inc(&pool->total_used);

#if CFG_POISON_SLABS
	struct slab_poison *poison = ret;
	poison_after_alloc(poison, size - sizeof(*poison), ktrace_return_addr());
	ret = poison->data;
#endif
	return ret;
}

void kfree(void *ptr)
{
	if (ptr == nil)
		return;

	SLAB_ASSERT(ptr >= DMAP_START && ptr < DMAP_END);

	vm_page_t page = vaddr2pg(ptr);
	SLAB_ASSERT(pga_slab(page));
	struct slab_pool *pool = SLAB(page)->pool;
#if CFG_POISON_SLABS
	struct slab_poison *poison = container_of(ptr, typeof(*poison), data);
	poison_after_free(poison);
	ptr = poison;
#endif

	register_t cpuflags = intr_disable();
	page_lock(page);
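	/* put the entry back on the freelist; its first word stores the old head */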
	*(void **)ptr = SLAB(page)->freelist;
	SLAB(page)->freelist = (void **)ptr;
	u_int free_count = ++SLAB(page)->free_count;
	if (free_count == 1) {
		/* slab was full until now, take it off the full list */
		spin_lock(&pool->full_lock);
		clist_del(&page->link);
		spin_unlock(&pool->full_lock);
		atom_dec(&pool->full_count);
	} else if (free_count == pool->entries_per_slab) {
		/* slab becomes completely unused, take it off the partial list */
		spin_lock(&pool->partial_lock);
		clist_del(&page->link);
		spin_unlock(&pool->partial_lock);
		atom_dec(&pool->partial_count);
	}
	if (free_count == pool->entries_per_slab) {
		spin_lock(&pool->empty_lock);
		clist_add(&pool->empty_list, &page->link);
		spin_unlock(&pool->empty_lock);
		atom_inc(&pool->empty_count);
	} else if (free_count == 1) {
		/* slab went from full to partially used */
		spin_lock(&pool->partial_lock);
		clist_add(&pool->partial_list, &page->link);
		spin_unlock(&pool->partial_lock);
		atom_inc(&pool->partial_count);
	}
	page_unlock(page);
	atom_dec(&pool->total_used);
	intr_restore(cpuflags);
}

static vm_page_t slab_create(struct slab_pool *pool, enum mflags flags)
{
	slab_debug_noisy("Creating new slab for entry_size %u\n", pool->entry_size);
	vm_page_t page = page_alloc(pool->page_order, flags);

	if (page) {
		pga_set_slab(page, true);
		SLAB(page)->pool = pool;
		SLAB(page)->free_count = pool->entries_per_slab;
		void *prev = nil;
		/* XXX this should not rely on a direct map */
		void *start = pfn2vaddr(pg2pfn(page));
		void *end = start + (1 << (pool->page_order + PAGE_SHIFT));
		void *pos = end;
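		/*
		 * Walk backwards through the slab and store the address of
		 * the next higher entry in the first word of each entry, so
		 * the finished freelist starts at the lowest address.
		 */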
		do {
			pos -= pool->entry_size;
			*(void **)pos = prev;
			prev = pos;
		} while (pos > start);
		SLAB(page)->freelist = pos;
	}

	return page;
}

#if CFG_POISON_SLABS

static inline void poison_after_alloc(struct slab_poison *poison, u_int exact_size,
				      void *alloc_source)
{
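	/* exact_size still holds the previous allocation's size (or 0 for a fresh entry) */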
	u_int offset = align_ceil(poison->exact_size, sizeof(long)) / sizeof(long);
	u_long *poison_start = &poison->low_poison;

	/*
	 * page_alloc() always initializes the allocated page to zeroes.
	 * Therefore, if exact_size is 0, we know this particular slab entry has
	 * never been used before, and we can skip the check.
	 */
	if (poison->exact_size != 0) {
		for (u_long *pos = poison_start; pos < &poison->high_poison[offset]; pos++) {
			if (*pos != SLAB_POISON_FREE) {
				kprintf("Use-after-free in %p (alloc by %p)\n",
					poison->data, poison->alloc_source);
				break;
			}
		}
	}

	/* update offset to the new size */
	offset = align_ceil(exact_size, sizeof(long)) / sizeof(long);

	poison->alloc_source = alloc_source;
	poison->exact_size = exact_size;
	for (u_long *pos = &poison->low_poison; pos <= &poison->high_poison[offset]; pos++)
		*pos = SLAB_POISON_ALLOC;
}

static inline void poison_after_free(struct slab_poison *poison)
{
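	/* index of the high canary (in longs) for the current allocation size */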
	u_int offset = align_ceil(poison->exact_size, sizeof(long)) / sizeof(long);

	if (poison->low_poison != SLAB_POISON_ALLOC) {
		kprintf("Low out-of-bounds write to %p (alloc by %p)\n",
			poison->data, poison->alloc_source);
	}

	if (poison->high_poison[offset] != SLAB_POISON_ALLOC) {
		kprintf("High out-of-bounds write to %p (alloc by %p)\n",
			poison->data, poison->alloc_source);
	}

	for (u_long *pos = &poison->low_poison; pos <= &poison->high_poison[offset]; pos++)
		*pos = SLAB_POISON_FREE;
}

#endif /* CFG_POISON_SLABS */

/*
 * for certain libc routines
 */

__weak void *malloc(usize size)
{
	return kmalloc(size, M_KERN);
}

__weak void free(void *ptr)
{
	kfree(ptr);
}