{"id":109902,"date":"2024-06-14T07:00:00","date_gmt":"2024-06-14T14:00:00","guid":{"rendered":"https:\/\/devblogs.microsoft.com\/oldnewthing\/?p=109902"},"modified":"2024-06-14T11:09:06","modified_gmt":"2024-06-14T18:09:06","slug":"20240614-00","status":"publish","type":"post","link":"https:\/\/devblogs.microsoft.com\/oldnewthing\/20240614-00\/?p=109902","title":{"rendered":"Lock-free reference-counting a TLS slot using atomics, part 3"},"content":{"rendered":"<p>Last time, we tried to <a title=\"Lock-free reference-counting a TLS slot using atomics, part 2\" href=\"https:\/\/devblogs.microsoft.com\/oldnewthing\/20240613-00\/?p=109892\">remove the mutex bottleneck<\/a> from our class that allocates a TLS slot on demand and frees the TLS slot when the last client disconnects. We figured out how to allocate the TLS slot on demand, but freeing the TLS slot on the disconnection of the last client was a problem because of a race that can occur if an <code>Acquire()<\/code> occurs while the last reference is being <code>Release()<\/code>d.<\/p>\n<p>This conflict between <code>Acquire()<\/code> and <code>Release()<\/code> arises because we are manipulating two separate atomic variables, when we really want to treat the two variables as an atomic unit.<\/p>\n<p>So let&#8217;s make them an atomic unit.<\/p>\n<pre>struct TlsManager\r\n{\r\n    struct State\r\n    {\r\n        DWORD count = 0;\r\n        DWORD tls; \/\/ valid if count != 0\r\n    };\r\n\r\n    std::atomic&lt;State&gt; m_state;\r\n\r\n    DWORD Acquire()\r\n    {\r\n        auto previous = m_state.load();\r\n        while (true) {\r\n            auto state = previous;\r\n            if (++state.count != 1) {\r\n                if (m_state.compare_exchange_weak(previous, state)) {\r\n                    return state.tls;\r\n                }\r\n            } else {\r\n                state.tls = TlsAlloc();\r\n                THROW_LAST_ERROR_IF(state.tls == TLS_OUT_OF_INDEXES);\r\n                if 
(m_state.compare_exchange_weak(previous, state)) {\r\n                    return state.tls;\r\n                } else {\r\n                    TlsFree(state.tls);\r\n                }\r\n            }\r\n        }\r\n    }\r\n};\r\n<\/pre>\n<p>We capture the initial state and calculate what the desired new state is. We increment the reference count, and if we didn&#8217;t increment to 1, then the increment is all we needed to do. Try to save this as the new state and return if successful. Otherwise, another thread won the race against us, so we restart the loop to try again. (When writing these types of lock-free algorithms, don&#8217;t forget to loop back and try again if you want the operation to eventually succeed.)<\/p>\n<p>If we incremented to 1, then we are also responsible for allocating the TLS slot. Allocate it and try to save the TLS slot and the incremented reference count as an atomic unit. If this succeeds, then return. Otherwise, clean up the TLS slot we mistakenly allocated and try again.<\/p>\n<p>It&#8217;s possible to optimize this loop a tiny bit more by caching the result of <code>TlsAlloc()<\/code> in case we go a second time through the <code>else<\/code> branch inside the loop. However, I don&#8217;t think this is likely, because it means that we have to lose <i>two<\/i> races: While we are calling <code>TlsAlloc()<\/code>, another thread successfully performed an <code>Acquire()<\/code>, and then when we go back and try to increment the reference count, we find that another thread also successfully performed a <code>Release()<\/code>, forcing us into the <code>TlsAlloc()<\/code> branch again.<\/p>\n<p>This race would occur if another thread exactly interleaves an <code>Acquire()<\/code>\/<code>Release()<\/code> pair inside our <code>Acquire()<\/code>. 
Some instrumentation would tell us whether this race is likely in practice.<\/p>\n<p>I could imagine it being a possibility if <code>TlsAlloc()<\/code> and <code>TlsFree()<\/code> are slow enough that they open the necessary race window for the other thread to sneak in.<\/p>\n<p>So let&#8217;s add the caching, just to see how it looks.<\/p>\n<pre>    DWORD Acquire()\r\n    {\r\n        wil::unique_tls tls;\r\n\r\n        auto previous = m_state.load();\r\n        while (true) {\r\n            auto state = previous;\r\n            if (++state.count != 1) {\r\n                if (m_state.compare_exchange_weak(previous, state)) {\r\n                    return state.tls;\r\n                }\r\n            } else {\r\n                if (!tls) {\r\n                    tls.reset(TlsAlloc());\r\n                    THROW_LAST_ERROR_IF(!tls);\r\n                }\r\n                state.tls = tls.get();\r\n                if (m_state.compare_exchange_weak(previous, state)) {\r\n                    tls.release(); \/\/ owned by the TlsManager now\r\n                    return state.tls;\r\n                }\r\n            }\r\n        }\r\n    }\r\n<\/pre>\n<p>We take advantage of the <code>wil::<wbr \/>unique_tls<\/code> RAII type which manages a TLS slot.<\/p>\n<p>With this combined state, we can now <code>Release()<\/code> atomically.<\/p>\n<pre>    void Release()\r\n    {\r\n        auto previous = m_state.load();\r\n        while (true) {\r\n            auto state = previous;\r\n            --state.count;\r\n            if (m_state.compare_exchange_weak(previous, state)) {\r\n                if (state.count == 0) {\r\n                    TlsFree(state.tls);\r\n                }\r\n                return;\r\n            }\r\n        }\r\n    }\r\n<\/pre>\n<p><b>Bonus chatter<\/b>: I didn&#8217;t talk about memory ordering, but the <code>.load()<\/code> calls can be weakened to acquire, and the <code>compare_exchange_weak()<\/code> calls can be weakened to 
release.<\/p>\n<p><b>Bonus bonus chatter<\/b>: Instead of using a structure, we can pack the values manually into a <code>uint64_t<\/code>. If we continue to assume that the 32-bit reference count won&#8217;t overflow, we can increment and decrement the entire <code>uint64_t<\/code> rather than having to take it apart into two 32-bit integers.<\/p>\n<pre>\/\/ Version where the count is kept in the low-order bits.\r\nstruct TlsManager\r\n{\r\n    <span style=\"border: solid 1px currentcolor;\">std::atomic&lt;uint64_t&gt; m_state;<\/span>\r\n\r\n    DWORD Acquire()\r\n    {\r\n        auto previous = m_state.load();\r\n        while (true) {\r\n            auto state = <span style=\"border: solid 1px currentcolor;\">previous + 1<\/span>;\r\n            if (<span style=\"border: solid 1px currentcolor;\">static_cast&lt;uint32_t&gt;(state)<\/span> != 1) {\r\n                if (m_state.compare_exchange_weak(previous, state)) {\r\n                    return <span style=\"border: solid 1px currentcolor;\">static_cast&lt;uint32_t&gt;(state &gt;&gt; 32);<\/span>\r\n                }\r\n            } else {\r\n                auto tls = TlsAlloc();\r\n                THROW_LAST_ERROR_IF(tls == TLS_OUT_OF_INDEXES);\r\n                <span style=\"border: solid 1px currentcolor;\">state = (static_cast&lt;uint64_t&gt;(tls) &lt;&lt; 32) + 1;<\/span>\r\n                if (m_state.compare_exchange_weak(previous, state)) {\r\n                    return <span style=\"border: solid 1px currentcolor;\">static_cast&lt;uint32_t&gt;(state &gt;&gt; 32);<\/span>\r\n                } else {\r\n                    TlsFree(tls);\r\n                }\r\n            }\r\n        }\r\n    }\r\n\r\n    void Release()\r\n    {\r\n        auto previous = m_state.load();\r\n        while (true) {\r\n            auto state = <span style=\"border: solid 1px currentcolor;\">previous - 1<\/span>;\r\n            if (m_state.compare_exchange_weak(previous, state)) {\r\n                if (<span 
style=\"border: solid 1px currentcolor;\">static_cast&lt;uint32_t&gt;(state)<\/span> == 0) {\r\n                    TlsFree(<span style=\"border: solid 1px currentcolor;\">static_cast&lt;uint32_t&gt;(state &gt;&gt; 32)<\/span>);\r\n                }\r\n                return;\r\n            }\r\n        }\r\n    }\r\n};\r\n<\/pre>\n<p>Or maybe put the count in the high-order 32 bits.<\/p>\n<pre>\/\/ Version where the count is kept in the high-order bits.\r\nstruct TlsManager\r\n{\r\n    <span style=\"border: solid 1px currentcolor; border-bottom: none;\">std::atomic&lt;uint64_t&gt; m_state;                                  <\/span>\r\n    <span style=\"border: solid 1px currentcolor; border-top: none;\">static constexpr uint64_t unit = static_cast&lt;uint64_t&gt;(1) &lt;&lt; 32;<\/span>\r\n\r\n    DWORD Acquire()\r\n    {\r\n        auto previous = m_state.load();\r\n        while (true) {\r\n            auto state = <span style=\"border: solid 1px currentcolor;\">previous + unit<\/span>;\r\n            if (<span style=\"border: solid 1px currentcolor;\">(state &gt;&gt; 32)<\/span> != 1) {\r\n                if (m_state.compare_exchange_weak(previous, state)) {\r\n                    return <span style=\"border: solid 1px currentcolor;\">static_cast&lt;uint32_t&gt;(state);<\/span>\r\n                }\r\n            } else {\r\n                auto tls = TlsAlloc();\r\n                THROW_LAST_ERROR_IF(tls == TLS_OUT_OF_INDEXES);\r\n                <span style=\"border: solid 1px currentcolor;\">state = tls + unit;<\/span>\r\n                if (m_state.compare_exchange_weak(previous, state)) {\r\n                    return <span style=\"border: solid 1px currentcolor;\">static_cast&lt;uint32_t&gt;(state);<\/span>\r\n                } else {\r\n                    TlsFree(tls);\r\n                }\r\n            }\r\n        }\r\n    }\r\n\r\n    void Release()\r\n    {\r\n        auto previous = m_state.load();\r\n        while (true) {\r\n            auto 
state = <span style=\"border: solid 1px currentcolor;\">previous - unit<\/span>;\r\n            if (m_state.compare_exchange_weak(previous, state)) {\r\n                if (<span style=\"border: solid 1px currentcolor;\">(state &gt;&gt; 32)<\/span> == 0) {\r\n                    TlsFree(<span style=\"border: solid 1px currentcolor;\">static_cast&lt;uint32_t&gt;(state)<\/span>);\r\n                }\r\n                return;\r\n            }\r\n        }\r\n    }\r\n};\r\n<\/pre>\n<p>Different compilers and different target architectures may produce better results for one formulation over another.<\/p>\n","protected":false},"excerpt":{"rendered":"<p>Keeping track of two things at once.<\/p>\n","protected":false},"author":1069,"featured_media":111744,"comment_status":"open","ping_status":"closed","sticky":false,"template":"","format":"standard","meta":{"_acf_changed":false,"footnotes":""},"categories":[1],"tags":[25],"class_list":["post-109902","post","type-post","status-publish","format-standard","has-post-thumbnail","hentry","category-oldnewthing","tag-code"],"acf":[],"blog_post_summary":"<p>Keeping track of two things at 
once.<\/p>\n","_links":{"self":[{"href":"https:\/\/devblogs.microsoft.com\/oldnewthing\/wp-json\/wp\/v2\/posts\/109902","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/devblogs.microsoft.com\/oldnewthing\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/devblogs.microsoft.com\/oldnewthing\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/devblogs.microsoft.com\/oldnewthing\/wp-json\/wp\/v2\/users\/1069"}],"replies":[{"embeddable":true,"href":"https:\/\/devblogs.microsoft.com\/oldnewthing\/wp-json\/wp\/v2\/comments?post=109902"}],"version-history":[{"count":0,"href":"https:\/\/devblogs.microsoft.com\/oldnewthing\/wp-json\/wp\/v2\/posts\/109902\/revisions"}],"wp:featuredmedia":[{"embeddable":true,"href":"https:\/\/devblogs.microsoft.com\/oldnewthing\/wp-json\/wp\/v2\/media\/111744"}],"wp:attachment":[{"href":"https:\/\/devblogs.microsoft.com\/oldnewthing\/wp-json\/wp\/v2\/media?parent=109902"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/devblogs.microsoft.com\/oldnewthing\/wp-json\/wp\/v2\/categories?post=109902"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/devblogs.microsoft.com\/oldnewthing\/wp-json\/wp\/v2\/tags?post=109902"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}