EventCount.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2016 Dmitry Vyukov <dvyukov@google.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H
#define EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H

#include "./InternalHeaderCheck.h"

namespace Eigen {

// EventCount allows waiting on arbitrary predicates in non-blocking
// algorithms. Think of a condition variable, but the wait predicate does not
// need to be protected by a mutex. Usage:
// Waiting thread does:
//
//   if (predicate)
//     return act();
//   EventCount::Waiter& w = waiters[my_index];
//   ec.Prewait();
//   if (predicate) {
//     ec.CancelWait();
//     return act();
//   }
//   ec.CommitWait(&w);
//
// Notifying thread does:
//
//   predicate = true;
//   ec.Notify(true);
//
// Notify is cheap if there are no waiting threads. Prewait/CommitWait are not
// cheap, but they are executed only if the preceding predicate check has
// failed.
//
// Algorithm outline:
// There are two main variables: predicate (managed by the user) and state_.
// Operation closely resembles Dekker's mutual exclusion algorithm:
// https://en.wikipedia.org/wiki/Dekker%27s_algorithm
// The waiting thread sets state_ and then checks the predicate; the notifying
// thread sets the predicate and then checks state_. Due to the seq_cst fences
// in between these operations it is guaranteed that either the waiter sees
// the predicate change and does not block, or the notifying thread sees the
// state_ change and unblocks the waiter, or both. What cannot happen is that
// both threads miss each other's changes, which would lead to deadlock.
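
The usage pattern above, expanded into a minimal self-contained sketch. Everything outside the EventCount/Waiter/MaxSizeVector calls is an assumption made for illustration: the include path below is where Eigen 3.4 ships the ThreadPool module (adjust it for your checkout), and a single Waiter slot is used because there is only one waiting thread. A separate sketch of the Dekker-style fence argument from the algorithm outline appears at the end of this page.

// Illustrative sketch only -- not part of EventCount.h.
#include <atomic>
#include <thread>
#include <unsupported/Eigen/CXX11/ThreadPool>  // assumed module path

int main() {
  // One Waiter slot per thread that may block.
  Eigen::MaxSizeVector<Eigen::EventCount::Waiter> waiters(1);
  waiters.resize(1);
  Eigen::EventCount ec(waiters);
  std::atomic<bool> predicate(false);

  std::thread waiting_thread([&] {
    for (;;) {
      if (predicate.load()) return;  // fast path: nothing to wait for
      ec.Prewait();                  // announce intent to block
      if (predicate.load()) {        // re-check the predicate after Prewait
        ec.CancelWait();
        return;
      }
      ec.CommitWait(&waiters[0]);    // block until notified, then re-check
    }
  });

  predicate.store(true);  // change the predicate first...
  ec.Notify(true);        // ...then notify
  waiting_thread.join();
  return 0;
}

Note the ordering on the notifying side: the predicate is changed first and Notify is called second, exactly as required by the comment on Notify below.
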
class EventCount {
 public:
  class Waiter;

  EventCount(MaxSizeVector<Waiter>& waiters)
      : state_(kStackMask), waiters_(waiters) {
    eigen_plain_assert(waiters.size() < (1 << kWaiterBits) - 1);
  }

  ~EventCount() {
    // Ensure there are no waiters.
    eigen_plain_assert(state_.load() == kStackMask);
  }

  // Prewait prepares for waiting.
  // After calling Prewait, the thread must re-check the wait predicate
  // and then call either CancelWait or CommitWait.
  void Prewait() {
    uint64_t state = state_.load(std::memory_order_relaxed);
    for (;;) {
      CheckState(state);
      uint64_t newstate = state + kWaiterInc;
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_seq_cst))
        return;
    }
  }

  // CommitWait commits waiting after Prewait.
  void CommitWait(Waiter* w) {
    eigen_plain_assert((w->epoch & ~kEpochMask) == 0);
    w->state = Waiter::kNotSignaled;
    const uint64_t me = (w - &waiters_[0]) | w->epoch;
    uint64_t state = state_.load(std::memory_order_seq_cst);
    for (;;) {
      CheckState(state, true);
      uint64_t newstate;
      if ((state & kSignalMask) != 0) {
        // Consume the signal and return immediately.
        newstate = state - kWaiterInc - kSignalInc;
      } else {
        // Remove this thread from pre-wait counter and add to the waiter stack.
        newstate = ((state & kWaiterMask) - kWaiterInc) | me;
        w->next.store(state & (kStackMask | kEpochMask),
                      std::memory_order_relaxed);
      }
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_acq_rel)) {
        if ((state & kSignalMask) == 0) {
          w->epoch += kEpochInc;
          Park(w);
        }
        return;
      }
    }
  }

  // CancelWait cancels effects of the previous Prewait call.
  void CancelWait() {
    uint64_t state = state_.load(std::memory_order_relaxed);
    for (;;) {
      CheckState(state, true);
      uint64_t newstate = state - kWaiterInc;
      // We don't know whether this thread was also notified, so we must not
      // consume a signal unconditionally. Only if the number of waiters
      // equals the number of signals do we know that this thread was
      // notified and must take away the signal.
      if (((state & kWaiterMask) >> kWaiterShift) ==
          ((state & kSignalMask) >> kSignalShift))
        newstate -= kSignalInc;
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_acq_rel))
        return;
    }
  }

  // Notify wakes one or all waiting threads.
  // Must be called after changing the associated wait predicate.
  void Notify(bool notifyAll) {
    std::atomic_thread_fence(std::memory_order_seq_cst);
    uint64_t state = state_.load(std::memory_order_acquire);
    for (;;) {
      CheckState(state);
      const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
      const uint64_t signals = (state & kSignalMask) >> kSignalShift;
      // Easy case: no waiters.
      if ((state & kStackMask) == kStackMask && waiters == signals) return;
      uint64_t newstate;
      if (notifyAll) {
        // Empty wait stack and set signal to number of pre-wait threads.
        newstate =
            (state & kWaiterMask) | (waiters << kSignalShift) | kStackMask;
      } else if (signals < waiters) {
        // There is a thread in pre-wait state, unblock it.
        newstate = state + kSignalInc;
      } else {
        // Pop a waiter from list and unpark it.
        Waiter* w = &waiters_[state & kStackMask];
        uint64_t next = w->next.load(std::memory_order_relaxed);
        newstate = (state & (kWaiterMask | kSignalMask)) | next;
      }
      CheckState(newstate);
      if (state_.compare_exchange_weak(state, newstate,
                                       std::memory_order_acq_rel)) {
        if (!notifyAll && (signals < waiters))
          return;  // unblocked pre-wait thread
        if ((state & kStackMask) == kStackMask) return;
        Waiter* w = &waiters_[state & kStackMask];
        if (!notifyAll) w->next.store(kStackMask, std::memory_order_relaxed);
        Unpark(w);
        return;
      }
    }
  }

  class Waiter {
    friend class EventCount;
    // Align to 128 byte boundary to prevent false sharing with other Waiter
    // objects in the same vector.
    EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<uint64_t> next;
    EIGEN_MUTEX mu;
    EIGEN_CONDVAR cv;
    uint64_t epoch = 0;
    unsigned state = kNotSignaled;
    enum {
      kNotSignaled,
      kWaiting,
      kSignaled,
    };
  };

 private:
  // State_ layout:
  // - low kWaiterBits is a stack of waiters that committed to wait
  //   (indexes in the waiters_ array are used as stack elements,
  //   kStackMask means empty stack).
  // - next kWaiterBits is the count of waiters in the pre-wait state.
  // - next kWaiterBits is the count of pending signals.
  // - remaining bits are an ABA counter for the stack
  //   (stored in the Waiter node and incremented on push).
  // (A decoding sketch for this layout follows the class definition below.)
  static const uint64_t kWaiterBits = 14;
  static const uint64_t kStackMask = (1ull << kWaiterBits) - 1;
  static const uint64_t kWaiterShift = kWaiterBits;
  static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
                                      << kWaiterShift;
  static const uint64_t kWaiterInc = 1ull << kWaiterShift;
  static const uint64_t kSignalShift = 2 * kWaiterBits;
  static const uint64_t kSignalMask = ((1ull << kWaiterBits) - 1)
                                      << kSignalShift;
  static const uint64_t kSignalInc = 1ull << kSignalShift;
  static const uint64_t kEpochShift = 3 * kWaiterBits;
  static const uint64_t kEpochBits = 64 - kEpochShift;
  static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
  static const uint64_t kEpochInc = 1ull << kEpochShift;
  std::atomic<uint64_t> state_;
  MaxSizeVector<Waiter>& waiters_;

  static void CheckState(uint64_t state, bool waiter = false) {
    static_assert(kEpochBits >= 20, "not enough bits to prevent ABA problem");
    const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
    const uint64_t signals = (state & kSignalMask) >> kSignalShift;
    eigen_plain_assert(waiters >= signals);
    eigen_plain_assert(waiters < (1 << kWaiterBits) - 1);
    eigen_plain_assert(!waiter || waiters > 0);
    (void)waiters;
    (void)signals;
  }

  void Park(Waiter* w) {
    EIGEN_MUTEX_LOCK lock(w->mu);
    while (w->state != Waiter::kSignaled) {
      w->state = Waiter::kWaiting;
      w->cv.wait(lock);
    }
  }

  void Unpark(Waiter* w) {
    for (Waiter* next; w; w = next) {
      uint64_t wnext = w->next.load(std::memory_order_relaxed) & kStackMask;
      next = wnext == kStackMask
                 ? nullptr
                 : &waiters_[internal::convert_index<size_t>(wnext)];
      unsigned state;
      {
        EIGEN_MUTEX_LOCK lock(w->mu);
        state = w->state;
        w->state = Waiter::kSignaled;
      }
      // Avoid notifying if it wasn't waiting.
      if (state == Waiter::kWaiting) w->cv.notify_one();
    }
  }

  EventCount(const EventCount&) = delete;
  void operator=(const EventCount&) = delete;
};
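
A small decoder for the state_ layout documented in the private section above (as referenced there). The constants are duplicated locally only because the originals are private to EventCount; their values mirror the definitions above, and DumpState is a hypothetical helper, not part of Eigen.

// Illustrative sketch only -- decodes a state_ word into its four fields.
#include <cstdint>
#include <cstdio>

namespace {
constexpr std::uint64_t kWaiterBits = 14;
constexpr std::uint64_t kStackMask = (1ull << kWaiterBits) - 1;
constexpr std::uint64_t kWaiterShift = kWaiterBits;
constexpr std::uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1) << kWaiterShift;
constexpr std::uint64_t kSignalShift = 2 * kWaiterBits;
constexpr std::uint64_t kSignalMask = ((1ull << kWaiterBits) - 1) << kSignalShift;
constexpr std::uint64_t kEpochShift = 3 * kWaiterBits;
}  // namespace

void DumpState(std::uint64_t state) {
  const std::uint64_t stack_top = state & kStackMask;  // kStackMask means empty
  const std::uint64_t prewait = (state & kWaiterMask) >> kWaiterShift;
  const std::uint64_t signals = (state & kSignalMask) >> kSignalShift;
  const std::uint64_t epoch = state >> kEpochShift;    // ABA counter
  std::printf("stack top %llu, pre-waiters %llu, signals %llu, epoch %llu\n",
              (unsigned long long)stack_top, (unsigned long long)prewait,
              (unsigned long long)signals, (unsigned long long)epoch);
}

For example, the initial value kStackMask decodes to an empty stack with zero pre-waiters, zero signals and epoch zero; each Prewait adds kWaiterInc (one pre-waiter), and notifying a pre-waiting thread adds kSignalInc (one signal).
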

}  // namespace Eigen

#endif  // EIGEN_CXX11_THREADPOOL_EVENTCOUNT_H
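
The Dekker-style argument from the algorithm outline, reduced to a standalone litmus test. The variable and thread names are illustrative only; in EventCount the roles are played by state_ (written by the waiter in Prewait) and by the user's predicate (written by the notifier before Notify).

// Minimal sketch of the store/load ordering argument: with seq_cst ordering
// it is impossible for both threads to miss each other's store, so a waiter
// either sees the predicate and skips blocking, or is seen and unparked.
#include <atomic>
#include <cassert>
#include <thread>

int main() {
  std::atomic<int> state(0);      // stands in for EventCount::state_
  std::atomic<int> predicate(0);  // stands in for the user-managed predicate
  int waiter_saw_predicate = 0;
  int notifier_saw_waiter = 0;

  std::thread waiter([&] {
    state.store(1, std::memory_order_seq_cst);  // "I am about to wait"
    waiter_saw_predicate = predicate.load(std::memory_order_seq_cst);
  });
  std::thread notifier([&] {
    predicate.store(1, std::memory_order_seq_cst);  // publish the work
    notifier_saw_waiter = state.load(std::memory_order_seq_cst);
  });
  waiter.join();
  notifier.join();

  // The deadlock outcome would require both loads to return 0; sequential
  // consistency forbids it.
  assert(waiter_saw_predicate == 1 || notifier_saw_waiter == 1);
  return 0;
}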