/* Copyright (C) 2021,2022 fef <owo@fef.moe>.  All rights reserved. */

/*
 * Slab allocator: carves pages into fixed-size objects for kmalloc()/kfree()
 * and for the object caches created with kmem_cache_register().
 */

#include <arch/atom.h>
#include <arch/cpufunc.h>
#include <arch/page.h>

#include <gay/bits.h>
#include <gay/cdefs.h>
#include <gay/clist.h>
#include <gay/config.h>
#include <gay/kprintf.h>
#include <gay/ktrace.h>
#include <gay/mm.h>
#include <gay/poison.h>
#include <gay/systm.h>
#include <gay/types.h>
#include <gay/vm/page.h>

#include <strings.h>

#if CFG_POISON_SLABS
/**
 * @brief Header in front of every kmalloc() allocation when slab poisoning
 * is enabled.
 *
 * kmalloc() grows each request by `sizeof(struct slab_poison)` and hands out
 * `data` instead of the start of the slab entry.  `low_poison` is the canary
 * directly below the caller's data.  The high canary is
 * `high_poison[align_ceil(exact_size, sizeof(long)) / sizeof(long)]`, i.e. the
 * first `long` behind the caller's data rounded up to a whole `long`.  `data`
 * is a zero-length array (GNU extension), so it shares its address with
 * `high_poison`; indexing `high_poison` past its declared bound is intended.
 */
struct slab_poison {
	void *_pad __unused;	/**< @brief That's where the freelist pointer is stored */
	void *alloc_source;	/**< @brief Code address that made the alloc call */
	u_long exact_size;	/**< @brief Size requested by the caller (0 if never allocated) */
	u_long low_poison;	/**< @brief Canary right below `data` */
	u8 data[0];		/**< @brief Start of the memory returned to the caller */
	u_long high_poison[1];	/**< @brief Indexed by the rounded-up data size, see above */
};

/** @brief Check the free canaries (if previously used) and arm the alloc canaries. */
static void poison_on_alloc(struct slab_poison *poison, u_long exact_size, void *alloc_source);
/** @brief Check the alloc canaries and fill the whole entry with the free pattern. */
static void poison_on_free(struct slab_poison *poison);
#endif

/*
 * Debugging helpers; all of them are no-ops unless CFG_DEBUG_SLAB_ALLOCS is set:
 *  - slab_debug(): kprintf() with a "[slab] " prefix
 *  - slab_debug_noisy(): like slab_debug(), but only if
 *    CFG_DEBUG_SLAB_ALLOCS_NOISY is set as well
 *  - SLAB_DEBUG_BLOCK: put in front of a `{ ... }` block that should only run
 *    in debug builds (it expands to `if (0)` otherwise)
 *  - SLAB_ASSERT(): KASSERT() in debug builds; otherwise the argument is
 *    dropped entirely, so it is neither evaluated nor type-checked
 */
#if CFG_DEBUG_SLAB_ALLOCS
#	define slab_debug(msg, ...) kprintf("[slab] " msg, ##__VA_ARGS__)
#	define SLAB_DEBUG_BLOCK
#	define SLAB_ASSERT KASSERT
#	if CFG_DEBUG_SLAB_ALLOCS_NOISY
#		define slab_debug_noisy(msg, ...) kprintf("[slab] " msg, ##__VA_ARGS__)
#	else
#		define slab_debug_noisy(msg, ...) ({})
#	endif
#else
#	define SLAB_DEBUG_BLOCK if (0)
#	define SLAB_ASSERT(x) ({})
#	define slab_debug(msg, ...) ({})
#	define slab_debug_noisy(msg, ...) ({})
#endif

/**
 * @brief Single node in the object cache system.
 * Each node owns one slab, i.e. the `1 << (cache->page_order + PAGE_SHIFT)`
 * bytes starting at `page`, carved up into `cache->object_size` sized objects.
 * For the `struct kmem_cache_node` cache itself, the node lives inside the
 * first object of the slab it manages.
 */
struct kmem_cache_node {
	struct clist link;		/* -> struct kmem_cache_pool::list */
	void **freelist;		/**< @brief Stack of free objects */
	struct kmem_cache *cache;	/**< @brief Object cache this node belongs to */
	spin_t lock;			/**< @brief Lock for `freelist` and `free_count` */
	u_int free_count;		/**< @brief Number of objects on `freelist` */
	vm_page_t page;			/**< @brief Physical (first) page this node manages */
};

/**
 * @brief Set of cache nodes in the same fill state (empty, partial or full).
 */
struct kmem_cache_pool {
	struct clist list;	/* -> struct kmem_cache_node::link */
	spin_t lock;		/**< @brief Protects `list` */
	atom_t count;		/**< @brief Number of nodes on `list` (updated atomically) */
};

/**
 * @brief Cache for one particular object type.
 * A pool holds multiple nodes, each of which hold the same number of slabs.
 */
struct kmem_cache {
	u_int object_size;		/**< @brief Object size in bytes */
	u_int page_order;		/**< @brief Order passed to `page_alloc()` */
	enum slab_flags flags;		/**< @brief Flags for how to allocate */
	u_int slabs_per_node;		/**< @brief Max number of slabs per cache node */
	latom_t total_used;		/**< @brief Total allocated entries */
	const char *name;		/**< @brief Unique name for this object type */
	/** @brief Run on every object by freelist_init() when a new node is set up */
	void (*ctor)(void *ptr, kmem_cache_t cache);
	/** @brief Destructor; NOTE(review): not invoked by any code shown in this file */
	void (*dtor)(void *ptr, kmem_cache_t cache);
	struct kmem_cache_pool empty;	/**< @brief Nodes with all objects free */
	struct kmem_cache_pool partial;	/**< @brief Nodes with some objects free */
	struct kmem_cache_pool full;	/**< @brief Nodes with no objects free */
	struct clist link;		/**< @brief List of all kmem caches */
};

/* values for struct kmem_cache::flags */

/** @brief Zone to request pages from (using `page_alloc()`), stored in the low two bits */
#define SLAB_ZONE(flags)	((flags) & 3)

/**
 * @brief List of all currently registered `struct kmem_cache`s.
 * Appended to by kmalloc_init() and kmem_cache_register(); no lock is taken.
 */
static CLIST(kmem_cache_list);

#define _MIN1(x) ((x) < 1 ? 1 : (x))
/** @brief Number of objects of size `sz` that fit into one page (at least 1) */
#define SLABS_PER_NODE(sz) _MIN1(PAGE_SIZE / (sz))

/** @brief Number of whole pages needed to hold a single object of size `sz` */
#define CACHE_NR_PAGES(sz) (((sz) + PAGE_SIZE - 1) / PAGE_SIZE)

/**
 * @brief Smallest page order (log2 of the page count) whose allocation fits
 * one object of size `sz`.
 * This has to be a constant expression because it is used in static
 * initializers, hence the ternary ladder.  Objects larger than 128 pages are
 * not supported (they would silently get order 7).
 */
#define CACHE_PAGE_ORDER(sz) (				\
	CACHE_NR_PAGES(sz) <= 1  ? 0 :			\
	CACHE_NR_PAGES(sz) <= 2  ? 1 :			\
	CACHE_NR_PAGES(sz) <= 4  ? 2 :			\
	CACHE_NR_PAGES(sz) <= 8  ? 3 :			\
	CACHE_NR_PAGES(sz) <= 16 ? 4 :			\
	CACHE_NR_PAGES(sz) <= 32 ? 5 :			\
	CACHE_NR_PAGES(sz) <= 64 ? 6 : 7)

/**
 * @brief Static initializer for a `struct kmem_cache` with objects of `sz` bytes.
 * The pools and list link are left zeroed and must be set up by kmalloc_init().
 */
#define CACHE_DEFINE(sz, _name, _flags) {			\
	.object_size		= (sz),				\
	.page_order		= CACHE_PAGE_ORDER(sz),		\
	.flags			= (_flags),			\
	.slabs_per_node		= SLABS_PER_NODE(sz),		\
	.total_used		= ATOM_DEFINE(0),		\
	.name			= (_name),			\
}

/*
 * General purpose caches backing kmalloc(), one table per zone.  kmalloc()
 * walks a table front to back and takes the first cache whose object_size is
 * large enough, so the entries must be sorted by ascending size.  Every table
 * ends with a zeroed terminator entry (object_size == 0).  With
 * CFG_POISON_SLABS, the poison header counts against the object size.
 */
static struct kmem_cache kmem_caches[] = {
	CACHE_DEFINE(32,	"kmem_32",	_M_ZONE_NORMAL | SLAB_POISON),
	CACHE_DEFINE(64,	"kmem_64",	_M_ZONE_NORMAL | SLAB_POISON),
	CACHE_DEFINE(128,	"kmem_128",	_M_ZONE_NORMAL | SLAB_POISON),
	CACHE_DEFINE(256,	"kmem_256",	_M_ZONE_NORMAL | SLAB_POISON),
	CACHE_DEFINE(512,	"kmem_512",	_M_ZONE_NORMAL | SLAB_POISON),
	CACHE_DEFINE(1024,	"kmem_1024",	_M_ZONE_NORMAL | SLAB_POISON),
	CACHE_DEFINE(2048,	"kmem_2048",	_M_ZONE_NORMAL | SLAB_POISON),
	CACHE_DEFINE(4096,	"kmem_4096",	_M_ZONE_NORMAL | SLAB_POISON),
	CACHE_DEFINE(8192,	"kmem_8192",	_M_ZONE_NORMAL | SLAB_POISON),
	CACHE_DEFINE(16384,	"kmem_16384",	_M_ZONE_NORMAL | SLAB_POISON),
	CACHE_DEFINE(32768,	"kmem_32768",	_M_ZONE_NORMAL | SLAB_POISON),
	{ /* terminator */ }
};
static struct kmem_cache kmem_dma_caches[] = {
	CACHE_DEFINE(32,	"kmem_dma_32",		_M_ZONE_DMA | SLAB_POISON),
	CACHE_DEFINE(64,	"kmem_dma_64",		_M_ZONE_DMA | SLAB_POISON),
	CACHE_DEFINE(128,	"kmem_dma_128",		_M_ZONE_DMA | SLAB_POISON),
	CACHE_DEFINE(256,	"kmem_dma_256",		_M_ZONE_DMA | SLAB_POISON),
	CACHE_DEFINE(512,	"kmem_dma_512",		_M_ZONE_DMA | SLAB_POISON),
	CACHE_DEFINE(1024,	"kmem_dma_1024",	_M_ZONE_DMA | SLAB_POISON),
	{ /* terminator */ }
};

/**
 * This is a little fucked.
 *
 * So, every `vm_page_t` in use by the slab allocator gets a corresponding
 * `struct kmem_cache_node` that keeps track of everything we need to know to
 * make allocations.  However, the memory for those structs themselves doesn't
 * magically grow on trees.  In other words, we need to allocate memory in
 * order to be able to allocate memory.
 *
 * So what we have here is a separate object cache for `struct kmem_cache_node`
 * that works slightly differently than all the other ones:  Instead of making
 * an extra allocation for the cache node, that node sits at the beginning of
 * the page that we allocate from itself.  Other caches don't do this because
 * it destroys the perfect page alignment of the allocated area itself, but that
 * doesn't matter here.
 *
 * Because the first object of every page is taken by that management node,
 * kmalloc_init() decrements `slabs_per_node` of this cache by one.
 */
static struct kmem_cache kmem_cache_node_caches =
	CACHE_DEFINE(sizeof(struct kmem_cache_node), "kmem_cache_node", _M_ZONE_NORMAL | SLAB_DMAP);

#undef _MIN1 /* we don't wanna end up using this in actual code, do we? */

/**
 * @brief kmalloc() cache table for each zone, indexed by `_M_ZONE_INDEX()`.
 * Each table is terminated by an entry with object_size == 0.
 */
static struct kmem_cache *kmem_cache_zones[MM_NR_ZONES] = {
	[_M_ZONE_DMA]		= kmem_dma_caches,
	[_M_ZONE_NORMAL]	= kmem_caches,
};

/** @brief Put a cache pool into its initial state: unlocked, no nodes. */
static void cache_pool_init(struct kmem_cache_pool *pool)
{
	spin_init(&pool->lock);
	atom_init(&pool->count, 0);
	clist_init(&pool->list);
}

void kmalloc_init(void)
{
	cache_pool_init(&kmem_cache_node_caches.empty);
	cache_pool_init(&kmem_cache_node_caches.partial);
	cache_pool_init(&kmem_cache_node_caches.full);
	/* for the management node at the beginning of the page */
	kmem_cache_node_caches.slabs_per_node--;
	clist_add(&kmem_cache_list, &kmem_cache_node_caches.link);

	for (int i = 0; i < MM_NR_ZONES; i++) {
		struct kmem_cache *cache = kmem_cache_zones[i];

		while (cache->object_size != 0) {
			clist_init(&cache->empty.list);
			clist_init(&cache->partial.list);
			clist_init(&cache->full.list);
			clist_add(&kmem_cache_list, &cache->link);
			cache++;
		}
	}
}

/**
 * @brief Create a new object cache.
 *
 * @param name Name of the cache (stored by reference, not copied)
 * @param obj_size Object size in bytes; rounded up to a multiple of
 *	`sizeof(long)` and must not exceed PAGE_SIZE afterwards
 * @param flags Slab flags for the cache
 * @param ctor Optional constructor, run on every object when a new node is set up
 * @param dtor Optional destructor
 * @return The new cache, or `nil` if `obj_size` is unsupported or out of memory
 */
kmem_cache_t kmem_cache_register(const char *name, u_int obj_size, enum slab_flags flags,
				 void (*ctor)(void *ptr, kmem_cache_t cache),
				 void (*dtor)(void *ptr, kmem_cache_t cache))
{
	obj_size = align_ceil(obj_size, sizeof(long));
	/* we only support objects up to PAGE_SIZE for now */
	if (obj_size > PAGE_SIZE || obj_size == 0)
		return nil;

	struct kmem_cache *cache = kmalloc(sizeof(*cache), M_KERN);
	if (cache == nil)
		return nil;

	/*
	 * kmalloc() does not zero memory, so assign the whole struct at once.
	 * This also initializes total_used, which was previously left holding
	 * whatever garbage the slab entry contained.
	 * NOTE(review): kfree() un-poisons objects of caches with SLAB_POISON,
	 * but kmem_cache_alloc() never poisons them; confirm callers don't pass
	 * SLAB_POISON here.
	 */
	*cache = (struct kmem_cache) {
		.object_size	= obj_size,
		/* XXX this is pretty wasteful for larger obj_sizes */
		.page_order	= 0,
		.flags		= flags,
		.slabs_per_node	= PAGE_SIZE / obj_size,
		.total_used	= ATOM_DEFINE(0),
		.name		= name,
		.ctor		= ctor,
		.dtor		= dtor,
	};
	cache_pool_init(&cache->empty);
	cache_pool_init(&cache->partial);
	cache_pool_init(&cache->full);

	clist_add(&kmem_cache_list, &cache->link);

	return cache;
}

/**
 * @brief Chain every object of a freshly allocated slab into a freelist.
 *
 * The slab spans `1 << (cache->page_order + PAGE_SHIFT)` bytes from the start
 * of `page`, rounded down to a multiple of the object size.  Objects are
 * visited from the highest address downwards; if the cache has a
 * constructor, it runs on each object before that object's first word is
 * overwritten with the link to the next-higher object (the highest one links
 * to `nil`).
 *
 * @param page First page of the slab
 * @param cache Cache the slab belongs to
 * @return Head of the freelist, which is the object at the start of the slab
 */
static inline void **freelist_init(vm_page_t page, struct kmem_cache *cache)
{
	void *base = __v(pg2paddr(page));
	void *obj = base + align_floor(1 << (cache->page_order + PAGE_SHIFT), cache->object_size);
	void *next = nil;

	do {
		obj -= cache->object_size;
		if (cache->ctor != nil)
			cache->ctor(obj, cache);
		*(void **)obj = next;
		next = obj;
	} while (obj >= base + cache->object_size);

	return (void **)obj;
}

/**
 * @brief Detach the first node of a cache pool (partial/empty/full list).
 * Call with interrupts disabled.
 * @return The detached node, or `nil` if the pool was empty
 */
static inline struct kmem_cache_node *pool_del_first_node(struct kmem_cache_pool *pool)
{
	struct kmem_cache_node *first;

	spin_lock(&pool->lock);
	if (clist_is_empty(&pool->list)) {
		first = nil;
	} else {
		atom_dec(&pool->count);
		first = clist_del_first_entry(&pool->list, typeof(*first), link);
	}
	spin_unlock(&pool->lock);

	return first;
}

/* call with interrupts disabled */
static inline void pool_del_node(struct kmem_cache_pool *pool, struct kmem_cache_node *node)
{
	atom_dec(&pool->count);
	spin_lock(&pool->lock);
	clist_del(&node->link);
	spin_unlock(&pool->lock);
}

/* call with interrupts disabled */
static inline void pool_add_node(struct kmem_cache_pool *pool, struct kmem_cache_node *node)
{
	spin_lock(&pool->lock);
	clist_add(&pool->list, &node->link);
	spin_unlock(&pool->lock);
	atom_inc(&pool->count);
}

/**
 * @brief Take one object off `node`'s freelist and put the node back into
 * the full or partial pool of `cache`, depending on what is left.
 *
 * The node must have at least one free object and must not be on any pool.
 * Call with interrupts disabled.
 *
 * @return The allocated object
 */
static inline void *pop_freelist_and_insert(struct kmem_cache *cache, struct kmem_cache_node *node)
{
	void **obj;
	u_int remaining;

	spin_lock(&node->lock);
	obj = node->freelist;
	node->freelist = *obj;
	remaining = --node->free_count;
	spin_unlock(&node->lock);

	latom_inc(&cache->total_used);
	pool_add_node(remaining == 0 ? &cache->full : &cache->partial, node);

	return obj;
}

/* call with interrupts disabled */
static struct kmem_cache_node *node_alloc(void)
{
	/*
	 * This is really the same basic procedure as kmem_cache_alloc(),
	 * except that we allocate everything manually if we run out of caches
	 * and interrupts are disabled.
	 * It definitely needs a cleanup at some point, most of the stuff here
	 * can probably be eliminated if kmem_cache_alloc() is split up.
	 */
	struct kmem_cache_node *mgmt_node = pool_del_first_node(&kmem_cache_node_caches.partial);
	if (!mgmt_node) {
		mgmt_node = pool_del_first_node(&kmem_cache_node_caches.empty);
		if (!mgmt_node) {
			vm_page_t page = page_alloc(0, M_ATOMIC);
			if (!page)
				return nil;

			void **freelist = freelist_init(page, &kmem_cache_node_caches);
			mgmt_node = (struct kmem_cache_node *)freelist;
			mgmt_node->freelist = *freelist;

			mgmt_node = __v(pg2paddr(page));
			spin_init(&mgmt_node->lock);
			mgmt_node->free_count = kmem_cache_node_caches.slabs_per_node;
			mgmt_node->cache = &kmem_cache_node_caches;
			mgmt_node->page = page;
		}
	}

	struct kmem_cache_node *new_node = pop_freelist_and_insert(&kmem_cache_node_caches,
								   mgmt_node);
	return new_node;
}

/**
 * @brief Allocate and set up a brand new cache node for `cache`.
 *
 * Call with interrupts disabled; this always returns with interrupts
 * disabled as well.  Interrupts are restored to `cpuflags` while the slab
 * pages are allocated (with `flags | M_ZERO`) and the freelist is built.
 *
 * @return The new node (not on any pool yet), or `nil` if either the node
 *	itself or its pages could not be allocated
 */
static inline struct kmem_cache_node *node_create(struct kmem_cache *cache, enum mflags flags,
						  register_t cpuflags)
{
	struct kmem_cache_node *node = node_alloc();
	if (node == nil)
		return nil;

	intr_restore(cpuflags);

	vm_page_t page = page_alloc(cache->page_order, flags | M_ZERO);
	if (page == nil) {
		kfree(node);
		node = nil;
	} else {
		pga_set_slab(page, true);
		page->slab = node;

		node->freelist = freelist_init(page, cache);
		spin_init(&node->lock);
		node->free_count = cache->slabs_per_node;
		node->cache = cache;
		node->page = page;
	}

	intr_disable();
	return node;
}

/**
 * @brief Allocate one object from `cache`.
 *
 * @param cache Cache to allocate from
 * @param flags Allocation flags; passed on to `page_alloc()` if a new node
 *	has to be created
 * @return The object, or `nil` if out of memory
 */
void *kmem_cache_alloc(kmem_cache_t cache, enum mflags flags)
{
	SLAB_DEBUG_BLOCK {
		if (!(flags & _M_NOWAIT) && in_irq()) {
			slab_debug("kmem_cache_alloc() called from irq %p w/o M_NOWAIT\n",
				   ktrace_return_addr());
			flags |= _M_NOWAIT;
		}
	}

	/* these used to reference undeclared names, breaking debug builds */
	SLAB_ASSERT(_M_ZONE_INDEX(flags) < ARRAY_SIZE(kmem_cache_zones));
	slab_debug_noisy("alloc %u bytes from zone %d, cache %s\n",
			 cache->object_size, _M_ZONE_INDEX(flags), cache->name);

	/*
	 * Before locking a node, we always remove it from its cache pool.
	 * This is far from optimal, because if multiple CPUs allocate from the
	 * same pool at the same time, we could end up creating several slabs
	 * with one used entry each (not to mention the overhead of the mostly
	 * unnecessary list deletions/insertions).  However, it allows me to be
	 * lazier when freeing unused slabs from a background thread since that
	 * thread knows for sure that once it has removed a slab from free_list,
	 * it can't possibly be used for allocations anymore.
	 * This is probably not worth the overhead, though.
	 */
	register_t cpuflags = intr_disable();

	/* try to use a slab that is already partially used first */
	struct kmem_cache_node *node = pool_del_first_node(&cache->partial);
	/* no partially used node available, see if we have a completely free one */
	if (!node)
		node = pool_del_first_node(&cache->empty);
	if (!node) {
		/* we're completely out of usable nodes, allocate a new one */
		node = node_create(cache, flags, cpuflags);
		if (!node) {
			/* node_create() returns with interrupts disabled, even on failure */
			intr_restore(cpuflags);
			slab_debug("kernel OOM\n");
			return nil;
		}
	}

	/* if we've made it to here, we have a cache node and interrupts are disabled */
	void *ret = pop_freelist_and_insert(cache, node);
	intr_restore(cpuflags);

	return ret;
}

/**
 * @brief Allocate `size` bytes from the general purpose cache of the zone
 * selected by `flags`.
 *
 * With CFG_POISON_SLABS, the returned pointer is preceded by a
 * `struct slab_poison` header and followed by a high canary.
 *
 * @return The allocation, or `nil` if `size` is 0, too large, or out of memory
 */
void *kmalloc(usize size, enum mflags flags)
{
	if (size == 0)
		return nil;

#if CFG_POISON_SLABS
	/*
	 * Make room for the poison header and high canary.  Refuse sizes where
	 * this would wrap around; otherwise a tiny entry would be chosen while
	 * poison_on_alloc() writes canaries for the huge original size.
	 */
	if (size + sizeof(struct slab_poison) < size)
		return nil;
	size += sizeof(struct slab_poison);
#endif

	SLAB_ASSERT(_M_ZONE_INDEX(flags) < ARRAY_SIZE(kmem_cache_zones));
	/* tables are sorted by size; take the first cache that fits */
	struct kmem_cache *cache = kmem_cache_zones[_M_ZONE_INDEX(flags)];
	while (cache->object_size != 0 && cache->object_size < size)
		cache++;

	if (cache->object_size == 0) {
		slab_debug("Refusing to allocate %zu bytes in zone %d (limit is %u)\n",
			   size, _M_ZONE_INDEX(flags), cache[-1].object_size);
		return nil;
	}

	void *ret = kmem_cache_alloc(cache, flags);

#if CFG_POISON_SLABS
	if (ret) {
		struct slab_poison *poison = ret;
		poison_on_alloc(poison, size - sizeof(*poison), ktrace_return_addr());
		ret = poison->data;
	}
#endif
	return ret;
}

/**
 * @brief Return an object obtained from the slab allocator.
 *
 * The owning node is found through the page the object lives in.  If this
 * frees the first object of a full node, the node moves to the partial pool;
 * if it frees the last used object, the node moves to the empty pool.
 *
 * @param ptr Object to free; `nil` is ignored
 */
void kfree(void *ptr)
{
	if (ptr == nil)
		return;

	SLAB_ASSERT(ptr >= DMAP_START && ptr < DMAP_END);

	vm_page_t page = vaddr2pg(ptr);
	SLAB_ASSERT(pga_slab(page));
	struct kmem_cache_node *node = page_slab(page);
	struct kmem_cache *cache = node->cache;
#if CFG_POISON_SLABS
	if (cache->flags & SLAB_POISON) {
		struct slab_poison *poison = container_of(ptr, typeof(*poison), data);
		poison_on_free(poison);
		ptr = poison;
	}
#endif

	register_t cpuflags = intr_disable();

	spin_lock(&node->lock);
	*(void **)ptr = node->freelist;
	node->freelist = (void **)ptr;
	u_int free_count = ++node->free_count;
	spin_unlock(&node->lock);

	/*
	 * pop_freelist_and_insert() puts a node into the full pool once its last
	 * object is handed out and into the partial pool otherwise.  So if the
	 * free count just went from 0 to 1, the node is currently in the full
	 * pool; in every other case it is in the partial pool.
	 * NOTE(review): kmem_cache_alloc() takes a node off its pool before
	 * popping from it, so a concurrent kfree() on another CPU may see a node
	 * that is temporarily on no pool at all; that case is not handled here.
	 */
	struct kmem_cache_pool *from = free_count == 1 ? &cache->full : &cache->partial;
	struct kmem_cache_pool *to = nil;
	if (free_count == cache->slabs_per_node)
		to = &cache->empty;
	else if (free_count == 1)
		to = &cache->partial;

	if (to != nil) {
		pool_del_node(from, node);
		pool_add_node(to, node);
	}

	latom_dec(&cache->total_used);
	intr_restore(cpuflags);
}

#if CFG_POISON_SLABS

/**
 * @brief Verify the free pattern left by the previous user of a slab entry,
 * then arm the alloc canaries for the new allocation.
 *
 * @param poison Poison header at the start of the slab entry
 * @param exact_size Number of bytes requested by the caller
 * @param alloc_source Code address of the allocating caller
 */
static inline void poison_on_alloc(struct slab_poison *poison, u_long exact_size,
				   void *alloc_source)
{
	/*
	 * node_create() requests slab pages with M_ZERO.  Therefore, if
	 * exact_size is 0, we know this particular slab entry has never been
	 * used before, and we can skip the check.
	 */
	if (poison->exact_size != 0) {
		u_long old_words = align_ceil(poison->exact_size, sizeof(long)) / sizeof(long);
		/*
		 * poison_on_free() filled everything from low_poison up to and
		 * including the high canary, so check that whole range (the high
		 * canary used to be skipped here).
		 */
		for (u_long *pos = &poison->low_poison; pos <= &poison->high_poison[old_words]; pos++) {
			if (*pos != SLAB_POISON_FREE) {
				kprintf("Use-after-free in %p (alloc by %p)\n",
					poison->data, poison->alloc_source);
				break;
			}
		}
	}

	u_long new_words = align_ceil(exact_size, sizeof(long)) / sizeof(long);

	poison->alloc_source = alloc_source;
	poison->exact_size = exact_size;
	for (u_long *pos = &poison->low_poison; pos <= &poison->high_poison[new_words]; pos++)
		*pos = SLAB_POISON_ALLOC;
}

/**
 * @brief Check both alloc canaries of a slab entry being freed and report
 * out-of-bounds writes, then fill the entry from `low_poison` through the
 * high canary with the free pattern.
 *
 * @param poison Poison header at the start of the slab entry
 */
static inline void poison_on_free(struct slab_poison *poison)
{
	u_long words = align_ceil(poison->exact_size, sizeof(long)) / sizeof(long);
	u_long *low = &poison->low_poison;
	u_long *high = &poison->high_poison[words];

	if (*low != SLAB_POISON_ALLOC)
		kprintf("Low out-of-bounds write to %p (alloc by %p)\n",
			poison->data, poison->alloc_source);

	if (*high != SLAB_POISON_ALLOC)
		kprintf("High out-of-bounds write to %p (alloc by %p)\n",
			poison->data, poison->alloc_source);

	for (u_long *pos = low; pos <= high; pos++)
		*pos = SLAB_POISON_FREE;
}

#endif /* CFG_POISON_SLABS */

/*
 * libc-style entry points for certain libc routines.  They forward to the
 * kernel heap and are declared weak, so another definition takes precedence.
 */

/** @brief Allocate `size` bytes with M_KERN. */
__weak void *malloc(usize size)
{
	void *mem = kmalloc(size, M_KERN);
	return mem;
}

/** @brief Release memory from malloc(); `nil` is ignored by kfree(). */
__weak void free(void *ptr)
{
	kfree(ptr);
}