/* Copyright (C) 2021,2022 fef <owo@fef.moe>.  All rights reserved. */

#include <gay/clist.h>
#include <gay/linker.h>
#include <gay/mm.h>
#include <gay/systm.h>
#include <gay/util.h>

#include <limits.h>

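/*
 * Fixed pool of area descriptors for the boot allocator; nothing can be
 * allocated dynamically this early, so get_bmem_area() hands these out from
 * the freelist below and panics once the pool is exhausted.
 */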
static struct _bmem_area _bmem_area_cache[128];
static CLIST(bmem_area_freelist);

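/*
 * In debug builds, descriptors on the freelist are poisoned with an all-ones
 * start address; get_bmem_area() asserts the poison is intact, which catches
 * entries that were put on the list without being freed properly or were
 * scribbled on afterwards.
 */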
#ifdef DEBUG
#define debug_free_bmem_area(area) ({ (area)->start = ~(vm_paddr_t)0; })
#define debug_get_bmem_area(area) KASSERT((area)->start == ~(vm_paddr_t)0)
#else
#define debug_free_bmem_area(area) ({})
#define debug_get_bmem_area(area) ({})
#endif

static struct _bmem_area *get_bmem_area(void)
{
	/* XXX this should pretty much never happen, but it would still be nice to
	 *     have at least some sort of error recovery rather than giving up */
	if (clist_is_empty(&bmem_area_freelist))
		panic("Boot memory allocator has run out of areas");

	struct _bmem_area *area = clist_del_first_entry(&bmem_area_freelist, typeof(*area), link);
	debug_get_bmem_area(area);
	return area;
}

static void free_bmem_area(struct _bmem_area *area)
{
	debug_free_bmem_area(area);
	clist_add(&bmem_area_freelist, &area->link);
}

/* insert an area into its zone's sorted list when we already know it doesn't overlap the kernel image */
static void insert_area_unsafe(vm_paddr_t start, vm_paddr_t end, enum mm_zone_type zone_type)
{
	KASSERT((start % PAGE_SIZE) == 0);
	KASSERT((end % PAGE_SIZE) == 0);

	struct _bmem_area *area = get_bmem_area();
	area->start = start;
	area->end = end;

	struct mm_zone *zone = &mm_zones[zone_type];
	struct _bmem_area *cursor;
	clist_foreach_entry(&zone->_bmem_areas, cursor, link) {
		if (cursor->start > area->start)
			break;
	}
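	/*
	 * If no existing area starts beyond the new one, the loop runs to
	 * completion and &cursor->link is the list head itself, so the insert
	 * below appends at the tail.  Either way the list stays sorted by
	 * start address.
	 */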
	clist_insert_before(&cursor->link, &area->link);
}

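/*
 * Put every descriptor from the static cache onto the freelist and initialize
 * the per-zone boot area lists.  This must run before any other routine in
 * this file is used.
 */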
void __boot_pmalloc_init(void)
{
	for (int i = 0; i < ARRAY_SIZE(_bmem_area_cache); i++) {
		struct _bmem_area *area = &_bmem_area_cache[i];
		debug_free_bmem_area(area);
		clist_add(&bmem_area_freelist, &area->link);
	}

	for (int i = 0; i < MM_NR_ZONES; i++)
		clist_init(&mm_zones[i]._bmem_areas);
}

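/*
 * Register a usable region of physical memory with the boot allocator.
 * The bounds are trimmed to page granularity, and any part overlapping the
 * kernel image is carved out before the remainder is added to the zone's
 * sorted area list.
 */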
void __boot_register_mem_area(vm_paddr_t start, vm_paddr_t end, enum mm_zone_type zone_type)
{
	KASSERT(start < end);

	start = align_ceil(start, PAGE_SIZE);
	end = align_floor(end, PAGE_SIZE);
	if (start == end)
		return;

	/* check for any overlaps with the kernel image and avoid those regions */
	if (start <= image_start_phys && end >= image_end_phys) {
		/*
		 * 0x8000 ---------------------- end (-> high_end)
		 * 0x7000   <free real estate>
		 * 0x6000 ---------------------- image_end_phys (-> high_start)
		 * 0x5000  <kernel code & data>
		 * 0x4000 ---------------------- image_start_phys (-> low_end)
		 * 0x3000   <free real estate>
		 * 0x2000 ---------------------- start (-> low_start)
		 */
		vm_paddr_t low_start = start;
		vm_paddr_t low_end = align_floor(image_start_phys, PAGE_SIZE);
		if (low_start < low_end)
			insert_area_unsafe(low_start, low_end, zone_type);

		vm_paddr_t high_start = align_ceil(image_end_phys, PAGE_SIZE);
		vm_paddr_t high_end = end;
		if (high_start < high_end)
			insert_area_unsafe(high_start, high_end, zone_type);
	} else if (start >= image_start_phys && start <= image_end_phys) {
		/*
		 * 0x8000 ---------------------- end (-> high_end)
		 * 0x7000   <free real estate>
		 * 0x6000 ---------------------- image_end_phys (-> high_start)
		 * 0x5000  <kernel code & data>
		 * 0x4000 ---------------------- start
		 * 0x3000   <not part of area>
		 * 0x2000 ---------------------- image_start_phys
		 */
		vm_paddr_t high_start = align_ceil(image_end_phys, PAGE_SIZE);
		vm_paddr_t high_end = end;
		if (high_start < high_end)
			insert_area_unsafe(high_start, high_end, zone_type);
	} else if (end >= image_start_phys && end <= image_end_phys) {
		/*
		 * 0x8000 ---------------------- image_end_phys
		 * 0x7000   <not part of area>
		 * 0x6000 ---------------------- end
		 * 0x5000  <kernel code & data>
		 * 0x4000 ---------------------- image_start_phys (-> low_end)
		 * 0x3000   <free real estate>
		 * 0x2000 ---------------------- start (-> low_start)
		 */
		vm_paddr_t low_start = start;
		vm_paddr_t low_end = align_floor(image_start_phys, PAGE_SIZE);
		if (low_start < low_end)
			insert_area_unsafe(low_start, low_end, zone_type);
	} else {
		insert_area_unsafe(start, end, zone_type);
	}
}

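/*
 * Allocate a block of (1 << log2) bytes of physical memory, aligned to its
 * own size, from the given zone.  The area it is carved from is split as
 * needed; returns BOOT_PMALLOC_ERR if no area in the zone can fit the block.
 */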
vm_paddr_t __boot_pmalloc(u_int log2, enum mm_zone_type zone_type)
{
	/* never hand out less than a full page */
	KASSERT(log2 >= PAGE_SHIFT);
	/* this might fail if someone accidentally gives us a size rather than a shift */
	KASSERT(log2 < sizeof(vm_paddr_t) * CHAR_BIT);

	const vm_size_t alloc_size = (vm_size_t)1 << log2;
	struct mm_zone *zone = &mm_zones[zone_type];

	struct _bmem_area *cursor;
	clist_foreach_entry_rev(&zone->_bmem_areas, cursor, link) {
		vm_paddr_t area_start = cursor->start;
		vm_paddr_t area_end = cursor->end;
		KASSERT(area_start < area_end);

		/* XXX we should really use a best-fit algorithm for this */
		vm_paddr_t alloc_start = align_ceil(area_start, alloc_size);
		vm_paddr_t alloc_end = alloc_start + alloc_size;

		if (alloc_start >= area_start && alloc_end <= area_end) {
			/*
			 * Example with log2 == 14 (alloc_size == 0x4000):
			 *
			 * 0xa000 ------------------- area_end
			 * 0x9000     <high_rest>
			 * 0x8000 ------------------- alloc_end (alloc_start + 0x4000)
			 *   :    <allocated block>
			 * 0x4000 ------------------- alloc_start (aligned to 0x4000)
			 * 0x3000     <low_rest>
			 * 0x2000 ------------------- area_start
			 */

			if (alloc_start > area_start) {
				struct _bmem_area *low_rest = get_bmem_area();
				low_rest->start = area_start;
				low_rest->end = alloc_start;
				clist_insert_before(&cursor->link, &low_rest->link);
			}

			if (alloc_end < area_end) {
				struct _bmem_area *high_rest = get_bmem_area();
				high_rest->start = alloc_end;
				high_rest->end = area_end;
				clist_insert_after(&cursor->link, &high_rest->link);
			}

			clist_del(&cursor->link);
			free_bmem_area(cursor);
			return alloc_start;
		}
	}

	return BOOT_PMALLOC_ERR;
}
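
/*
 * Rough usage sketch (comment only, not compiled): how early platform setup
 * code might drive this allocator.  MM_ZONE_NORMAL and the address range are
 * made-up placeholders for whatever enum mm_zone_type values and memory map
 * the platform actually provides.
 *
 *	__boot_pmalloc_init();
 *
 *	// register every usable RAM range reported by the firmware/bootloader
 *	__boot_register_mem_area(0x00100000, 0x3fe00000, MM_ZONE_NORMAL);
 *
 *	// grab one naturally aligned page, e.g. for an early page table
 *	vm_paddr_t page = __boot_pmalloc(PAGE_SHIFT, MM_ZONE_NORMAL);
 *	if (page == BOOT_PMALLOC_ERR)
 *		panic("out of boot memory");
 */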