increasing_heap_allocator/allocator.rs
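//! A simple heap allocator that grows on demand: it requests pages from a
//! `PageAllocatorProvider` and manages them with a doubly-linked free list
//! (best-fit allocation, coalescing of adjacent blocks on free).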

use core::mem;

use crate::{is_aligned, HeapStats, PageAllocatorProvider};

use super::align_up;

const HEAP_MAGIC: u32 = 0xF0B0CAFE;

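/// Header written immediately before every returned allocation; `dealloc`
/// reads it back to recover the allocation's real size and its offset from
/// the start of the underlying block.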
#[repr(C, align(16))]
struct AllocatedHeapBlockInfo {
    magic: u32,
    size: usize,
    pre_padding: usize,
}

const KERNEL_HEAP_BLOCK_INFO_SIZE: usize = mem::size_of::<AllocatedHeapBlockInfo>();

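/// A node of the doubly-linked free list, stored in-place at the start of
/// the free region it describes.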
#[derive(Debug)]
struct HeapFreeBlock {
    prev: *mut HeapFreeBlock,
    next: *mut HeapFreeBlock,
    // including this header
    size: usize,
}

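/// Allocator state: the heap's start and total size, the head of the free
/// list, and the free/used byte counters reported by `stats()`.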
pub struct HeapAllocator<const PAGE_SIZE: usize, T: PageAllocatorProvider<PAGE_SIZE>> {
    heap_start: usize,
    total_heap_size: usize,
    free_list_addr: *mut HeapFreeBlock,
    free_size: usize,
    used_size: usize,
    page_allocator: T,
}

unsafe impl<const PAGE_SIZE: usize, T: PageAllocatorProvider<PAGE_SIZE>> Send
    for HeapAllocator<PAGE_SIZE, T>
{
}

impl<const PAGE_SIZE: usize, T> HeapAllocator<PAGE_SIZE, T>
where
    T: PageAllocatorProvider<PAGE_SIZE>,
{
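    /// Returns `true` if the free list contains a cycle, detected with
    /// Floyd's tortoise-and-hare algorithm.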
    fn is_free_blocks_in_cycle(&self) -> bool {
        // use Floyd's algorithm to detect whether the list contains a cycle
        let mut slow = self.free_list_addr;
        let mut fast = self.free_list_addr;

        // advance fast first
        if fast.is_null() {
            return false;
        } else {
            fast = unsafe { (*fast).next };
        }

        while fast != slow {
            if fast.is_null() {
                return false;
            } else {
                fast = unsafe { (*fast).next };
            }
            if fast.is_null() {
                return false;
            } else {
                fast = unsafe { (*fast).next };
            }

            if slow.is_null() {
                return false;
            } else {
                slow = unsafe { (*slow).next };
            }
        }

        true
    }

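    /// Returns `true` if the free list is inconsistent: walking it forward
    /// and backward yields a different number of blocks.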
    fn check_free_blocks(&self) -> bool {
        let mut forward_count = 0;
        let mut last: *mut HeapFreeBlock = core::ptr::null_mut();
        for block in self.iter_free_blocks() {
            forward_count += 1;
            last = block as _;
        }

        // walk back from the last block to the first
        let mut backward_count = 0;
        while !last.is_null() {
            backward_count += 1;
            last = unsafe { (*last).prev };
        }

        forward_count != backward_count
    }

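    /// Runtime sanity check run after every `alloc`/`dealloc`; returns `true`
    /// if any free-list corruption was detected.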
    fn check_issues(&self) -> bool {
        self.is_free_blocks_in_cycle() || self.check_free_blocks()
    }

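    /// Finds the smallest free block that can hold `size` bytes (best-fit),
    /// growing the heap by whole pages when nothing fits.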
    fn get_free_block(&mut self, size: usize) -> *mut HeapFreeBlock {
        if self.total_heap_size == 0 {
            let size = align_up(size, PAGE_SIZE);
            self.allocate_more_pages(size / PAGE_SIZE);
            // retry now that the heap has been grown
            return self.get_free_block(size);
        }
        // find the best-fit (smallest sufficient) block
        let mut best_block: *mut HeapFreeBlock = core::ptr::null_mut();
        for block in self.iter_free_blocks() {
            if block.size >= size
                && (best_block.is_null() || block.size < unsafe { (*best_block).size })
            {
                best_block = block as _;
            }
        }

        if best_block.is_null() {
            // no block found, allocate more pages
            let size = align_up(size, PAGE_SIZE);
            self.allocate_more_pages(size / PAGE_SIZE);
            // retry now that the heap has been grown
            return self.get_free_block(size);
        }

        best_block
    }

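    /// Walks the free list from the head, yielding every free block in turn.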
    fn iter_free_blocks(&self) -> impl Iterator<Item = &mut HeapFreeBlock> {
        let mut current_block = self.free_list_addr;
        core::iter::from_fn(move || {
            if current_block.is_null() {
                None
            } else {
                let block = current_block;
                current_block = unsafe { (*current_block).next };
                Some(unsafe { &mut *block })
            }
        })
    }

    /// Allocates more pages and adds them to the free list
    fn allocate_more_pages(&mut self, pages: usize) {
        assert!(pages > 0);

        let new_heap_start = if self.total_heap_size == 0 {
            // first allocation
            self.heap_start = self.page_allocator.allocate_pages(pages).unwrap() as usize;
            self.heap_start
        } else {
            // allocate more pages
            self.page_allocator.allocate_pages(pages).unwrap() as usize
        };

        self.total_heap_size += pages * PAGE_SIZE;

        // add to the free list (fast path)
        if self.free_list_addr.is_null() {
            // no free list for now, add this as the very first free entry
            let free_block = new_heap_start as *mut HeapFreeBlock;

            unsafe {
                (*free_block).prev = core::ptr::null_mut();
                (*free_block).next = core::ptr::null_mut();
                (*free_block).size = pages * PAGE_SIZE;
            }

            self.free_list_addr = free_block;
        } else {
            unsafe {
                self.free_block(new_heap_start as _, pages * PAGE_SIZE);
            }
        }
        self.free_size += pages * PAGE_SIZE;
    }

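    /// Puts the block at `freeing_block` (spanning `size` bytes) back on the
    /// free list, merging it with any directly adjacent free blocks.
    /// Panics on a detected double free.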
    unsafe fn free_block(&mut self, freeing_block: usize, size: usize) {
        assert!(freeing_block <= self.heap_start + self.total_heap_size);
        assert!(freeing_block + size <= self.heap_start + self.total_heap_size);

        let freeing_block = freeing_block as *mut HeapFreeBlock;
        let freeing_block_start = freeing_block as usize;
        let freeing_block_end = freeing_block_start + size;

        // find free blocks directly adjacent to this block (before or after it)
        let mut prev_block: *mut HeapFreeBlock = core::ptr::null_mut();
        let mut next_block: *mut HeapFreeBlock = core::ptr::null_mut();
        let mut closest_prev_block: *mut HeapFreeBlock = core::ptr::null_mut();
        for block in self.iter_free_blocks() {
            let block_addr = block as *mut _ as usize;
            let block_end = block_addr + block.size;

            if block_addr == freeing_block_start {
                // our block should not be in the free list
                panic!("double free");
            }

            // assert that we are not in the middle of a block
            assert!(
                (freeing_block_end <= block_addr) || (freeing_block_start >= block_end),
                "Free block at {freeing_block_start:x}..{freeing_block_end:x} is in the middle of another block at {block_addr:x}..{block_end:x}",
            );

            if block_end == freeing_block_start {
                // this block is directly before the freeing block
                prev_block = block as _;
            } else if freeing_block_end == block_addr {
                // this block is directly after the freeing block
                next_block = block as _;
            }

            if block_addr < freeing_block_start {
                // track the closest free block that starts before the freeing block
                if closest_prev_block.is_null() || block_addr > (closest_prev_block as usize) {
                    closest_prev_block = block as _;
                }
            }
        }

        if !prev_block.is_null() && !next_block.is_null() {
            let new_block = prev_block;
            // free blocks directly before and after us:
            // merge all three into `prev_block`
            (*new_block).size += size + (*next_block).size;

            // unlink `next_block`: point its neighbors at the merged block instead
            if !(*next_block).next.is_null() {
                (*(*next_block).next).prev = new_block;
            }

            if !(*next_block).prev.is_null() {
                (*(*next_block).prev).next = new_block;
            } else {
                // this is the first block
                self.free_list_addr = new_block;
            }

            (*new_block).next = (*next_block).next;
        } else if !prev_block.is_null() {
            // only a free block directly before us:
            // merge easily, we only need to grow its size
            (*prev_block).size += size;
        } else if !next_block.is_null() {
            let new_block = freeing_block;

            // replace `next_block` with a block that starts here and covers both
            (*new_block).size = (*next_block).size + size;
            (*new_block).prev = (*next_block).prev;
            (*new_block).next = (*next_block).next;

            // update references:
            // point the block after `next_block` back at the new block
            if !(*next_block).next.is_null() {
                (*(*next_block).next).prev = new_block;
            }
            // point the block before `next_block` at the new block
            if !(*next_block).prev.is_null() {
                (*(*next_block).prev).next = new_block;
            } else {
                // this is the first block
                self.free_list_addr = new_block;
            }
        } else {
            // no adjacent free blocks:
            // add this to the free list, keeping it ordered by address
            if closest_prev_block.is_null() {
                // this is the first block
                (*freeing_block).prev = core::ptr::null_mut();
                (*freeing_block).next = self.free_list_addr;
                (*freeing_block).size = size;

                // the old head (if any) now points back at this block
                if !(*freeing_block).next.is_null() {
                    (*(*freeing_block).next).prev = freeing_block;
                }

                self.free_list_addr = freeing_block;
            } else {
                // put this after the closest previous block
                let closest_next_block = (*closest_prev_block).next;
                (*freeing_block).prev = closest_prev_block;
                (*freeing_block).next = closest_next_block;
                (*freeing_block).size = size;

                (*closest_prev_block).next = freeing_block;
                if !closest_next_block.is_null() {
                    (*closest_next_block).prev = freeing_block;
                }
            }
        }
    }
}

// public interface
impl<const PAGE_SIZE: usize, T> HeapAllocator<PAGE_SIZE, T>
where
    T: PageAllocatorProvider<PAGE_SIZE>,
{
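    /// Creates an empty allocator; no pages are requested from
    /// `page_allocator` until the first allocation.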
    pub fn new(page_allocator: T) -> Self {
        Self {
            heap_start: 0,
            free_list_addr: core::ptr::null_mut(),
            total_heap_size: 0,
            free_size: 0,
            used_size: 0,
            page_allocator,
        }
    }

    pub fn stats(&self) -> HeapStats {
        HeapStats {
            allocated: self.used_size,
            free_size: self.free_size,
            heap_size: self.total_heap_size,
        }
    }

    pub fn debug_free_blocks(&self) -> impl Iterator<Item = (usize, usize)> + '_ {
        self.iter_free_blocks()
            .map(|block| (block as *mut _ as usize, block.size))
    }

    /// # Safety
    /// Check [`core::alloc::GlobalAlloc::alloc`] for more info
    pub unsafe fn alloc(&mut self, layout: core::alloc::Layout) -> *mut u8 {
        // info header
        let block_info_layout = core::alloc::Layout::new::<AllocatedHeapBlockInfo>();

        // use at least the alignment of `AllocatedHeapBlockInfo`
        // `whole_layout` here is the layout of the requested block + the info header
        // `block_offset_from_header` is the offset of the block after the info header
        let (whole_layout, block_offset_from_header) = block_info_layout
            .extend(layout.align_to(block_info_layout.align()).unwrap())
            .unwrap();
        // at least align to AllocatedHeapBlockInfo (see above)
        // `allocation_size` is the size of the block we are going to allocate as a whole;
        // this block includes the info header, the requested block and maybe some padding
        let mut allocation_size = whole_layout.pad_to_align().size();

        let free_block = self.get_free_block(allocation_size);

        if free_block.is_null() {
            return core::ptr::null_mut();
        }

        // place the info header just before the returned pointer, handling alignment,
        // so we can use it to deallocate later
        let base = free_block as usize;
        // this should never fail, we always allocate with `block_info_layout.align()` alignment
        assert!(is_aligned(base, block_info_layout.align()));
        let possible_next_offset = align_up(base, layout.align()) - base;
        let allocated_block_offset = if possible_next_offset < KERNEL_HEAP_BLOCK_INFO_SIZE {
            // if we can't fit the info header, we need to add to the offset
            possible_next_offset + KERNEL_HEAP_BLOCK_INFO_SIZE.max(layout.align())
        } else {
            possible_next_offset
        };
        assert!(allocated_block_offset >= KERNEL_HEAP_BLOCK_INFO_SIZE);
        if allocated_block_offset > block_offset_from_header {
            // we can exceed the block sizes calculated from the layout above; if that happens
            // we must increase the allocation size to account for it.
            // this can happen when the alignment of the requested block is larger than the info block's
            //
            // example:
            //   requested layout: size=512, align=64
            //   info layout: size=32, align=16
            //   the above calculation gives `block_offset_from_header` = 64
            //   the allocator, i.e. `free_block`, will always be aligned to 16 (the info block)
            //   then, if `possible_next_offset` happens to be 16, i.e. we are 48 bytes into a 64-byte block
            //
            //       [ 16 bytes ][ 16 bytes ][ 16 bytes ][ 16 bytes ]
            //       ^ <64 byte alignment>               ^ free_block
            //
            //   since 16 is less than 32, we need to add more offset, and `layout.align()` is 64. So we are going to
            //   use 80 (64 + 16) as the `allocated_block_offset`, but that already exceeds `64`.
            //   the `allocation_size` before this fix would have been 512+64=576,
            //   but the actual size we need is 512+80=592. That's why we need this fix.
            //   (as you might have expected, these numbers are from an actual bug I found and debugged -_-)
            allocation_size += allocated_block_offset - block_offset_from_header;
        }
        let allocated_ptr = (free_block as *mut u8).add(allocated_block_offset);
        let allocated_block_info =
            allocated_ptr.sub(KERNEL_HEAP_BLOCK_INFO_SIZE) as *mut AllocatedHeapBlockInfo;

        let free_block_size = (*free_block).size;
        // for now, we hope we get enough size
        // FIXME: get a new block if this is not enough
        assert!(free_block_size >= allocation_size);
        let free_block_end = free_block as usize + allocation_size;
        let new_free_block = free_block_end as *mut HeapFreeBlock;

        // we have to make sure that the leftover space after us is enough to hold the free-block
        // metadata, so we won't corrupt the block that comes after it (if there is any)
        let required_safe_size = allocation_size + mem::size_of::<HeapFreeBlock>();

        // store the actual size of the block:
        // if we needed to extend (since the next free block is too small)
        // this will include the whole size and not just the size that
        // we were asked to allocate
        let mut this_allocation_size = allocation_size;

        // do we have empty space left?
        if free_block_size > required_safe_size {
            // carve the remainder into a new free block that takes `free_block`'s place
            (*new_free_block).prev = (*free_block).prev;
            (*new_free_block).next = (*free_block).next;
            (*new_free_block).size = free_block_size - allocation_size;

            // update the next block to point to this new subblock instead
            if !(*new_free_block).next.is_null() {
                (*(*new_free_block).next).prev = new_free_block;
            }

            // update the previous block to point to this new subblock instead
            if !(*new_free_block).prev.is_null() {
                (*(*new_free_block).prev).next = new_free_block;
            } else {
                // this is the first block
                self.free_list_addr = new_free_block;
            }
        } else {
            // too small to split, consume the whole block
            this_allocation_size = free_block_size;

            // update the previous block to point to the next block instead
            if !(*free_block).prev.is_null() {
                (*(*free_block).prev).next = (*free_block).next;
            } else {
                // this is the first block
                self.free_list_addr = (*free_block).next;
            }
            if !(*free_block).next.is_null() {
                (*(*free_block).next).prev = (*free_block).prev;
            }
        }
        self.free_size -= this_allocation_size;
        self.used_size += this_allocation_size;

        // TODO: add flag to control when to enable this runtime checking
        if self.check_issues() {
            panic!("Found issues in `alloc`");
        }

        // make sure we are aligned
        assert!(is_aligned(allocated_ptr as _, layout.align()),
            "base_block={allocated_block_info:p}, offset={allocated_block_offset}, ptr={allocated_ptr:?}, layout={layout:?}, should_be_addr={:x}",
            align_up(allocated_block_info as usize, layout.align()));

        // write the info header
        (*allocated_block_info).magic = HEAP_MAGIC;
        (*allocated_block_info).size = this_allocation_size;
        (*allocated_block_info).pre_padding = allocated_block_offset;

        allocated_ptr
    }

    /// # Safety
    /// Check [`core::alloc::GlobalAlloc::dealloc`] for more info
    pub unsafe fn dealloc(&mut self, ptr: *mut u8, layout: core::alloc::Layout) {
        assert!(!ptr.is_null());

        // info header
        let base_layout = core::alloc::Layout::new::<AllocatedHeapBlockInfo>();

        let (whole_layout, _) = base_layout
            .extend(layout.align_to(base_layout.align()).unwrap())
            .unwrap();
        let size_to_free_from_layout = whole_layout.pad_to_align().size();

        let allocated_block_info =
            ptr.sub(KERNEL_HEAP_BLOCK_INFO_SIZE) as *mut AllocatedHeapBlockInfo;

        assert_eq!((*allocated_block_info).magic, HEAP_MAGIC);
        // This could be more than the layout size, because
        // we might increase the size of the block a bit to not leave
        // free blocks that are too small (see `alloc`)
        assert!((*allocated_block_info).size >= size_to_free_from_layout);
        assert!((*allocated_block_info).pre_padding >= KERNEL_HEAP_BLOCK_INFO_SIZE);
        let this_allocation_size = (*allocated_block_info).size;

        let freeing_block = ptr.sub((*allocated_block_info).pre_padding) as usize;

        self.free_block(freeing_block, this_allocation_size);
        self.used_size -= this_allocation_size;
        self.free_size += this_allocation_size;

        // TODO: add flag to control when to enable this runtime checking
        if self.check_issues() {
            panic!("Found issues in `dealloc`");
        }
    }
}
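
// --- Illustrative usage sketch (editor's addition, not from the original file) ---
// One plausible way to expose this allocator as the global allocator: wrap it in a
// lock and forward `GlobalAlloc` calls to it. The `spin::Mutex` dependency and the
// existence of a concrete `PageAllocatorProvider` implementation are assumptions
// made for illustration only.
//
// use core::alloc::{GlobalAlloc, Layout};
//
// struct LockedHeap<const PAGE_SIZE: usize, T: PageAllocatorProvider<PAGE_SIZE>>(
//     spin::Mutex<HeapAllocator<PAGE_SIZE, T>>,
// );
//
// unsafe impl<const PAGE_SIZE: usize, T: PageAllocatorProvider<PAGE_SIZE>> GlobalAlloc
//     for LockedHeap<PAGE_SIZE, T>
// {
//     unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
//         self.0.lock().alloc(layout)
//     }
//
//     unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
//         self.0.lock().dealloc(ptr, layout)
//     }
// }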