kernel/sync/spin/rwlock.rs

use core::{
    cell::UnsafeCell,
    fmt,
    marker::PhantomData,
    ops::{Deref, DerefMut},
    ptr::NonNull,
    sync::atomic::{AtomicI64, Ordering},
};

use crate::cpu;

use super::lock;

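/// A spin-based readers-writer lock.
///
/// Any number of readers can share the data at once, while a writer gets
/// exclusive access. Taking the write lock disables interrupts on the
/// current CPU (`push_cli`) until the guard drops (`pop_cli`), and
/// `owner_cpu` records the writer's CPU so that a recursive `write` on the
/// same CPU panics instead of deadlocking.
///
/// A minimal usage sketch (not a compiled doctest, since `cpu::cpu()` only
/// exists inside the kernel):
///
/// ```ignore
/// static COUNTER: RwLock<u64> = RwLock::new(0);
///
/// // Many CPUs may read concurrently.
/// let value = *COUNTER.read();
///
/// // Writers are exclusive; interrupts stay off while the guard lives.
/// *COUNTER.write() += 1;
/// ```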
pub struct RwLock<T: ?Sized> {
    lock: lock::Lock,
    owner_cpu: AtomicI64,
    data: UnsafeCell<T>,
}

unsafe impl<T: ?Sized + Send> Send for RwLock<T> {}
unsafe impl<T: ?Sized + Send + Sync> Sync for RwLock<T> {}

impl<T> fmt::Debug for RwLock<T>
where
    T: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut s = f.debug_struct("RwLock");
        s.field("owner_cpu", &self.owner_cpu);
        if let Some(data) = self.try_read() {
            s.field("data", &data);
        } else {
            s.field("data", &"[write locked]");
        }
        s.finish()
    }
}

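/// RAII guard for shared (read) access, created by `read`/`try_read`; the
/// read lock is released when this guard drops.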
pub struct RwLockReadGuard<'a, T: ?Sized + 'a> {
    // NB: we use a pointer instead of `&'a T` to avoid `noalias` violations, because a
    // `RwLockReadGuard` argument doesn't hold immutability for its whole scope, only
    // until it drops. `NonNull` is also covariant over `T`, just like we would have
    // with `&T`, and `NonNull` is preferable over `*const T` to allow for niche
    // optimization.
    data: NonNull<T>,
    inner_lock: &'a lock::Lock,
    marker: PhantomData<*const ()>, // !Send
}

unsafe impl<T: ?Sized + Sync> Sync for RwLockReadGuard<'_, T> {}

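/// RAII guard for exclusive (write) access, created by `write`/`try_write`;
/// dropping it releases the write lock and re-enables interrupts via `pop_cli`.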
pub struct RwLockWriteGuard<'a, T: ?Sized + 'a> {
    lock: &'a RwLock<T>,
    // !Send: `drop` calls `pop_cli`, which must run on the same CPU that
    // executed `push_cli` when the guard was created.
    marker: PhantomData<*const ()>, // !Send
}

unsafe impl<T: ?Sized + Sync> Sync for RwLockWriteGuard<'_, T> {}

impl<T: ?Sized + fmt::Debug> fmt::Debug for RwLockReadGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        (**self).fmt(f)
    }
}

impl<T: ?Sized + fmt::Display> fmt::Display for RwLockReadGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        (**self).fmt(f)
    }
}

impl<T: ?Sized + fmt::Debug> fmt::Debug for RwLockWriteGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        (**self).fmt(f)
    }
}

impl<T: ?Sized + fmt::Display> fmt::Display for RwLockWriteGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        (**self).fmt(f)
    }
}

#[allow(dead_code)]
impl<T> RwLock<T> {
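    /// Creates a new, unlocked `RwLock` wrapping `data`.
    ///
    /// Being `const`, it can initialize a `static` directly. A minimal
    /// sketch (not a compiled doctest):
    ///
    /// ```ignore
    /// static READY: RwLock<bool> = RwLock::new(false);
    /// ```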
    pub const fn new(data: T) -> Self {
        Self {
            lock: lock::Lock::new(),
            owner_cpu: AtomicI64::new(-1),
            data: UnsafeCell::new(data),
        }
    }
}

#[allow(dead_code)]
impl<T: ?Sized> RwLock<T> {
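    /// Acquires a shared read lock, spinning until any writer releases the lock.
    ///
    /// Multiple read guards can be alive at once. Unlike `write`, reading does
    /// not disable interrupts and records no owner CPU.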
    pub fn read(&self) -> RwLockReadGuard<'_, T> {
        self.lock.read_lock();
        // a held read lock has no exclusive owner, so the owner must be -1
        self.owner_cpu.store(-1, Ordering::Relaxed);
        RwLockReadGuard {
            data: unsafe { NonNull::new_unchecked(self.data.get()) },
            inner_lock: &self.lock,
            marker: PhantomData,
        }
    }

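    /// Attempts to acquire a read lock without spinning, returning `None` if
    /// a writer currently holds the lock.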
    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
        if self.lock.try_read_lock() {
            // a held read lock has no exclusive owner, so the owner must be -1
            self.owner_cpu.store(-1, Ordering::Relaxed);
            Some(RwLockReadGuard {
                data: unsafe { NonNull::new_unchecked(self.data.get()) },
                inner_lock: &self.lock,
                marker: PhantomData,
            })
        } else {
            None
        }
    }

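    /// Acquires the write lock, spinning until all readers and writers are done.
    ///
    /// Interrupts are disabled on the current CPU for as long as the returned
    /// guard lives. Panics if this CPU already holds the write lock, since
    /// spinning on it would deadlock.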
    pub fn write(&self) -> RwLockWriteGuard<'_, T> {
        let cpu = cpu::cpu();
        cpu.push_cli(); // disable interrupts to avoid deadlock with interrupt handlers
        let cpu_id = cpu.id as i64;

        if self.owner_cpu.load(Ordering::Relaxed) == cpu_id {
            panic!("RwLock already write-locked by this CPU");
        } else {
            self.lock.write_lock();
            self.owner_cpu.store(cpu_id, Ordering::Relaxed);
            RwLockWriteGuard {
                lock: self,
                marker: PhantomData,
            }
        }
    }

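    /// Attempts to acquire the write lock without spinning.
    ///
    /// Returns `None` if the lock is already held, including when this CPU
    /// itself holds it; the interrupt state is restored before returning.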
    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
        let cpu = cpu::cpu();
        cpu.push_cli(); // disable interrupts to avoid deadlock with interrupt handlers
        let cpu_id = cpu.id as i64;

        if self.owner_cpu.load(Ordering::Relaxed) == cpu_id {
            // don't panic here: the CPU may legitimately retry later, and
            // returning `None` is at least not a deadlock
            cpu.pop_cli();
            None
        } else if self.lock.try_write_lock() {
            self.owner_cpu.store(cpu_id, Ordering::Relaxed);
            Some(RwLockWriteGuard {
                lock: self,
                marker: PhantomData,
            })
        } else {
            cpu.pop_cli();
            None
        }
    }

    /// We know statically that no one else is accessing the lock, so we can
    /// just return a mutable reference to the data without acquiring the lock.
    #[allow(dead_code)]
    pub fn get_mut(&mut self) -> &mut T {
        self.data.get_mut()
    }
}

impl<T: ?Sized> Deref for RwLockReadGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        // SAFETY: the read lock is held. Other readers may hold references at
        //         the same time, but no writer can mutate the value, so shared
        //         references are sound for the guard's lifetime.
        unsafe { self.data.as_ref() }
    }
}

impl<T: ?Sized> Deref for RwLockWriteGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        // SAFETY: the write lock is held, so we are the only accessor, and the
        //         pointer is valid, since it was created from a valid `T`
        unsafe { &*self.lock.data.get() }
    }
}

impl<T: ?Sized> DerefMut for RwLockWriteGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: the write lock is held, so we are the only accessor, and the
        //         pointer is valid, since it was created from a valid `T`
        unsafe { &mut *self.lock.data.get() }
    }
}

impl<T: ?Sized> Drop for RwLockReadGuard<'_, T> {
    fn drop(&mut self) {
        // SAFETY: this guard still holds a read lock, so this `read_unlock`
        //         pairs with the `read_lock` that created the guard
        unsafe { self.inner_lock.read_unlock() };
    }
}

impl<T: ?Sized> Drop for RwLockWriteGuard<'_, T> {
    fn drop(&mut self) {
        // a live write guard implies a recorded owner CPU
        assert_ne!(self.lock.owner_cpu.load(Ordering::Relaxed), -1);
        self.lock.owner_cpu.store(-1, Ordering::Relaxed);
        // SAFETY: the write lock is held by this guard, so this `write_unlock`
        //         pairs with the `write_lock` that created the guard
        unsafe { self.lock.lock.write_unlock() };
        cpu::cpu().pop_cli(); // re-enable interrupts (pops one `push_cli` level)
    }
}