balanced_quicksort.h

Go to the documentation of this file.
00001 // -*- C++ -*-
00002 
00003 // Copyright (C) 2007, 2008 Free Software Foundation, Inc.
00004 //
00005 // This file is part of the GNU ISO C++ Library.  This library is free
00006 // software; you can redistribute it and/or modify it under the terms
00007 // of the GNU General Public License as published by the Free Software
00008 // Foundation; either version 2, or (at your option) any later
00009 // version.
00010 
00011 // This library is distributed in the hope that it will be useful, but
00012 // WITHOUT ANY WARRANTY; without even the implied warranty of
00013 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00014 // General Public License for more details.
00015 
00016 // You should have received a copy of the GNU General Public License
00017 // along with this library; see the file COPYING.  If not, write to
00018 // the Free Software Foundation, 59 Temple Place - Suite 330, Boston,
00019 // MA 02111-1307, USA.
00020 
00021 // As a special exception, you may use this file as part of a free
00022 // software library without restriction.  Specifically, if other files
00023 // instantiate templates or use macros or inline functions from this
00024 // file, or you compile this file and link it with other files to
00025 // produce an executable, this file does not by itself cause the
00026 // resulting executable to be covered by the GNU General Public
00027 // License.  This exception does not however invalidate any other
00028 // reasons why the executable file might be covered by the GNU General
00029 // Public License.
00030 
00031 /** @file parallel/balanced_quicksort.h
00032  *  @brief Implementation of a dynamically load-balanced parallel quicksort.
00033  *
00034  *  It works in-place and needs only logarithmic extra memory.
00035  *  The algorithm is similar to the one proposed in
00036  *
00037  *  P. Tsigas and Y. Zhang.
00038  *  A simple, fast parallel implementation of quicksort and
00039  *  its performance evaluation on SUN enterprise 10000.
00040  *  In 11th Euromicro Conference on Parallel, Distributed and
00041  *  Network-Based Processing, page 372, 2003.
00042  *
00043  *  This file is a GNU parallel extension to the Standard C++ Library.
00044  */
00045 
00046 // Written by Johannes Singler.
00047 
00048 #ifndef _GLIBCXX_PARALLEL_BAL_QUICKSORT_H
00049 #define _GLIBCXX_PARALLEL_BAL_QUICKSORT_H 1
00050 
00051 #include <parallel/basic_iterator.h>
00052 #include <bits/stl_algo.h>
00053 
00054 #include <parallel/settings.h>
00055 #include <parallel/partition.h>
00056 #include <parallel/random_number.h>
00057 #include <parallel/queue.h>
00058 #include <functional>
00059 
00060 #if _GLIBCXX_ASSERTIONS
00061 #include <parallel/checkers.h>
00062 #endif
00063 
00064 namespace __gnu_parallel
00065 {
00066 /** @brief Information local to one thread in the parallel quicksort run. */
00067 template<typename RandomAccessIterator>
00068   struct QSBThreadLocal
00069   {
00070     typedef std::iterator_traits<RandomAccessIterator> traits_type;
00071     typedef typename traits_type::difference_type difference_type;
00072 
00073     /** @brief Continuous part of the sequence, described by an
00074     iterator pair. */
00075     typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;
00076 
00077     /** @brief Initial piece to work on. */
00078     Piece initial;
00079 
00080     /** @brief Work-stealing queue. */
00081     RestrictedBoundedConcurrentQueue<Piece> leftover_parts;
00082 
00083     /** @brief Number of threads involved in this algorithm. */
00084     thread_index_t num_threads;
00085 
00086     /** @brief Pointer to a counter of elements left over to sort. */
00087     volatile difference_type* elements_leftover;
00088 
00089     /** @brief The complete sequence to sort. */
00090     Piece global;
00091 
00092     /** @brief Constructor.
00093      *  @param queue_size Size of the work-stealing queue. */
00094     QSBThreadLocal(int queue_size) : leftover_parts(queue_size) { }
00095   };
00096 
00097 /** @brief Balanced quicksort divide step.
00098   *  @param begin Begin iterator of subsequence.
00099   *  @param end End iterator of subsequence.
00100   *  @param comp Comparator.
00101   *  @param num_threads Number of threads that are allowed to work on
00102   *  this part.
00103   *  @pre @c (end-begin)>=1 */
00104 template<typename RandomAccessIterator, typename Comparator>
00105   typename std::iterator_traits<RandomAccessIterator>::difference_type
00106   qsb_divide(RandomAccessIterator begin, RandomAccessIterator end,
00107              Comparator comp, thread_index_t num_threads)
00108   {
00109     _GLIBCXX_PARALLEL_ASSERT(num_threads > 0);
00110 
00111     typedef std::iterator_traits<RandomAccessIterator> traits_type;
00112     typedef typename traits_type::value_type value_type;
00113     typedef typename traits_type::difference_type difference_type;
00114 
00115     RandomAccessIterator pivot_pos =
00116       median_of_three_iterators(begin, begin + (end - begin) / 2,
00117                 end  - 1, comp);
00118 
00119 #if defined(_GLIBCXX_ASSERTIONS)
00120     // Must be in between somewhere.
00121     difference_type n = end - begin;
00122 
00123     _GLIBCXX_PARALLEL_ASSERT(
00124            (!comp(*pivot_pos, *begin) && !comp(*(begin + n / 2), *pivot_pos))
00125         || (!comp(*pivot_pos, *begin) && !comp(*(end - 1), *pivot_pos))
00126         || (!comp(*pivot_pos, *(begin + n / 2)) && !comp(*begin, *pivot_pos))
00127         || (!comp(*pivot_pos, *(begin + n / 2)) && !comp(*(end - 1), *pivot_pos))
00128         || (!comp(*pivot_pos, *(end - 1)) && !comp(*begin, *pivot_pos))
00129         || (!comp(*pivot_pos, *(end - 1)) && !comp(*(begin + n / 2), *pivot_pos)));
00130 #endif
00131 
00132     // Swap pivot value to end.
00133     if (pivot_pos != (end - 1))
00134       std::swap(*pivot_pos, *(end - 1));
00135     pivot_pos = end - 1;
00136 
00137     __gnu_parallel::binder2nd<Comparator, value_type, value_type, bool>
00138         pred(comp, *pivot_pos);
00139 
00140     // Divide, returning end - begin - 1 in the worst case.
00141     difference_type split_pos = parallel_partition(
00142         begin, end - 1, pred, num_threads);
00143 
00144     // Swap back pivot to middle.
00145     std::swap(*(begin + split_pos), *pivot_pos);
00146     pivot_pos = begin + split_pos;
00147 
00148 #if _GLIBCXX_ASSERTIONS
00149     RandomAccessIterator r;
00150     for (r = begin; r != pivot_pos; ++r)
00151       _GLIBCXX_PARALLEL_ASSERT(comp(*r, *pivot_pos));
00152     for (; r != end; ++r)
00153       _GLIBCXX_PARALLEL_ASSERT(!comp(*r, *pivot_pos));
00154 #endif
00155 
00156     return split_pos;
00157   }
00158 
00159 /** @brief Quicksort conquer step.
00160   *  @param tls Array of thread-local storages.
00161   *  @param begin Begin iterator of subsequence.
00162   *  @param end End iterator of subsequence.
00163   *  @param comp Comparator.
00164   *  @param iam Number of the thread processing this function.
00165   *  @param num_threads
00166   *          Number of threads that are allowed to work on this part. */
00167 template<typename RandomAccessIterator, typename Comparator>
00168   void
00169   qsb_conquer(QSBThreadLocal<RandomAccessIterator>** tls,
00170               RandomAccessIterator begin, RandomAccessIterator end,
00171               Comparator comp,
00172               thread_index_t iam, thread_index_t num_threads,
00173               bool parent_wait)
00174   {
00175     typedef std::iterator_traits<RandomAccessIterator> traits_type;
00176     typedef typename traits_type::value_type value_type;
00177     typedef typename traits_type::difference_type difference_type;
00178 
00179     difference_type n = end - begin;
00180 
00181     if (num_threads <= 1 || n <= 1)
00182       {
00183         tls[iam]->initial.first  = begin;
00184         tls[iam]->initial.second = end;
00185 
00186         qsb_local_sort_with_helping(tls, comp, iam, parent_wait);
00187 
00188         return;
00189       }
00190 
00191     // Divide step.
00192     difference_type split_pos = qsb_divide(begin, end, comp, num_threads);
00193 
00194 #if _GLIBCXX_ASSERTIONS
00195     _GLIBCXX_PARALLEL_ASSERT(0 <= split_pos && split_pos < (end - begin));
00196 #endif
00197 
00198     thread_index_t num_threads_leftside =
00199         std::max<thread_index_t>(1, std::min<thread_index_t>(
00200                           num_threads - 1, split_pos * num_threads / n));
00201 
00202 #   pragma omp atomic
00203     *tls[iam]->elements_leftover -= (difference_type)1;
00204 
00205     // Conquer step.
00206 #   pragma omp parallel num_threads(2)
00207     {
00208       bool wait;
00209       if(omp_get_num_threads() < 2)
00210         wait = false;
00211       else
00212         wait = parent_wait;
00213 
00214 #     pragma omp sections
00215         {
00216 #         pragma omp section
00217             {
00218               qsb_conquer(tls, begin, begin + split_pos, comp,
00219                           iam,
00220                           num_threads_leftside,
00221                           wait);
00222               wait = parent_wait;
00223             }
00224           // The pivot_pos is left in place, to ensure termination.
00225 #         pragma omp section
00226             {
00227               qsb_conquer(tls, begin + split_pos + 1, end, comp,
00228                           iam + num_threads_leftside,
00229                           num_threads - num_threads_leftside,
00230                           wait);
00231               wait = parent_wait;
00232             }
00233         }
00234     }
00235   }
00236 
00237 /**
00238   *  @brief Quicksort step doing load-balanced local sort.
00239   *  @param tls Array of thread-local storages.
00240   *  @param comp Comparator.
00241   *  @param iam Number of the thread processing this function.
00242   */
00243 template<typename RandomAccessIterator, typename Comparator>
00244   void
00245   qsb_local_sort_with_helping(QSBThreadLocal<RandomAccessIterator>** tls,
00246                               Comparator& comp, int iam, bool wait)
00247   {
00248     typedef std::iterator_traits<RandomAccessIterator> traits_type;
00249     typedef typename traits_type::value_type value_type;
00250     typedef typename traits_type::difference_type difference_type;
00251     typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;
00252 
00253     QSBThreadLocal<RandomAccessIterator>& tl = *tls[iam];
00254 
00255     difference_type base_case_n = _Settings::get().sort_qsb_base_case_maximal_n;
00256     if (base_case_n < 2)
00257       base_case_n = 2;
00258     thread_index_t num_threads = tl.num_threads;
00259 
00260     // Every thread has its own random number generator.
00261     random_number rng(iam + 1);
00262 
00263     Piece current = tl.initial;
00264 
00265     difference_type elements_done = 0;
00266 #if _GLIBCXX_ASSERTIONS
00267     difference_type total_elements_done = 0;
00268 #endif
00269 
00270     for (;;)
00271       {
00272         // Invariant: current must be a valid (maybe empty) range.
00273         RandomAccessIterator begin = current.first, end = current.second;
00274         difference_type n = end - begin;
00275 
00276         if (n > base_case_n)
00277           {
00278             // Divide.
00279             RandomAccessIterator pivot_pos = begin +  rng(n);
00280 
00281             // Swap pivot_pos value to end.
00282             if (pivot_pos != (end - 1))
00283               std::swap(*pivot_pos, *(end - 1));
00284             pivot_pos = end - 1;
00285 
00286             __gnu_parallel::binder2nd
00287                 <Comparator, value_type, value_type, bool>
00288                 pred(comp, *pivot_pos);
00289 
00290             // Divide, leave pivot unchanged in last place.
00291             RandomAccessIterator split_pos1, split_pos2;
00292             split_pos1 = __gnu_sequential::partition(begin, end - 1, pred);
00293 
00294             // Left side: < pivot_pos; right side: >= pivot_pos.
00295 #if _GLIBCXX_ASSERTIONS
00296             _GLIBCXX_PARALLEL_ASSERT(begin <= split_pos1 && split_pos1 < end);
00297 #endif
00298             // Swap pivot back to middle.
00299             if (split_pos1 != pivot_pos)
00300               std::swap(*split_pos1, *pivot_pos);
00301             pivot_pos = split_pos1;
00302 
00303             // In case all elements are equal, split_pos1 == 0.
00304             if ((split_pos1 + 1 - begin) < (n >> 7)
00305             || (end - split_pos1) < (n >> 7))
00306               {
00307                 // Very unequal split, one part smaller than one 128th
00308                 // elements not strictly larger than the pivot.
00309                 __gnu_parallel::unary_negate<__gnu_parallel::binder1st
00310           <Comparator, value_type, value_type, bool>, value_type>
00311           pred(__gnu_parallel::binder1st
00312                <Comparator, value_type, value_type, bool>(comp,
00313                                   *pivot_pos));
00314 
00315                 // Find other end of pivot-equal range.
00316                 split_pos2 = __gnu_sequential::partition(split_pos1 + 1,
00317                              end, pred);
00318               }
00319             else
00320               // Only skip the pivot.
00321               split_pos2 = split_pos1 + 1;
00322 
00323             // Elements equal to pivot are done.
00324             elements_done += (split_pos2 - split_pos1);
00325 #if _GLIBCXX_ASSERTIONS
00326             total_elements_done += (split_pos2 - split_pos1);
00327 #endif
00328             // Always push larger part onto stack.
00329             if (((split_pos1 + 1) - begin) < (end - (split_pos2)))
00330               {
00331                 // Right side larger.
00332                 if ((split_pos2) != end)
00333                   tl.leftover_parts.push_front(std::make_pair(split_pos2,
00334                                   end));
00335 
00336                 //current.first = begin;    //already set anyway
00337                 current.second = split_pos1;
00338                 continue;
00339               }
00340             else
00341               {
00342                 // Left side larger.
00343                 if (begin != split_pos1)
00344                   tl.leftover_parts.push_front(std::make_pair(begin,
00345                                   split_pos1));
00346 
00347                 current.first = split_pos2;
00348                 //current.second = end; //already set anyway
00349                 continue;
00350               }
00351           }
00352         else
00353           {
00354             __gnu_sequential::sort(begin, end, comp);
00355             elements_done += n;
00356 #if _GLIBCXX_ASSERTIONS
00357             total_elements_done += n;
00358 #endif
00359 
00360             // Prefer own stack, small pieces.
00361             if (tl.leftover_parts.pop_front(current))
00362               continue;
00363 
00364 #           pragma omp atomic
00365             *tl.elements_leftover -= elements_done;
00366 
00367             elements_done = 0;
00368 
00369 #if _GLIBCXX_ASSERTIONS
00370             double search_start = omp_get_wtime();
00371 #endif
00372 
00373             // Look for new work.
00374             bool successfully_stolen = false;
00375             while (wait && *tl.elements_leftover > 0 && !successfully_stolen
00376 #if _GLIBCXX_ASSERTIONS
00377               // Possible dead-lock.
00378               && (omp_get_wtime() < (search_start + 1.0))
00379 #endif
00380               )
00381               {
00382                 thread_index_t victim;
00383                 victim = rng(num_threads);
00384 
00385                 // Large pieces.
00386                 successfully_stolen = (victim != iam)
00387                     && tls[victim]->leftover_parts.pop_back(current);
00388                 if (!successfully_stolen)
00389                   yield();
00390 #if !defined(__ICC) && !defined(__ECC)
00391 #               pragma omp flush
00392 #endif
00393               }
00394 
00395 #if _GLIBCXX_ASSERTIONS
00396             if (omp_get_wtime() >= (search_start + 1.0))
00397               {
00398                 sleep(1);
00399                 _GLIBCXX_PARALLEL_ASSERT(omp_get_wtime()
00400                      < (search_start + 1.0));
00401               }
00402 #endif
00403             if (!successfully_stolen)
00404               {
00405 #if _GLIBCXX_ASSERTIONS
00406                 _GLIBCXX_PARALLEL_ASSERT(*tl.elements_leftover == 0);
00407 #endif
00408                 return;
00409               }
00410           }
00411       }
00412   }
00413 
00414 /** @brief Top-level quicksort routine.
00415   *  @param begin Begin iterator of sequence.
00416   *  @param end End iterator of sequence.
00417   *  @param comp Comparator.
00418   *  @param n Length of the sequence to sort.
00419   *  @param num_threads Number of threads that are allowed to work on
00420   *  this part.
00421   */
00422 template<typename RandomAccessIterator, typename Comparator>
00423   void
00424   parallel_sort_qsb(RandomAccessIterator begin, RandomAccessIterator end,
00425                     Comparator comp,
00426                     typename std::iterator_traits<RandomAccessIterator>
00427                         ::difference_type n,
00428                     thread_index_t num_threads)
00429   {
00430     _GLIBCXX_CALL(end - begin)
00431 
00432     typedef std::iterator_traits<RandomAccessIterator> traits_type;
00433     typedef typename traits_type::value_type value_type;
00434     typedef typename traits_type::difference_type difference_type;
00435     typedef std::pair<RandomAccessIterator, RandomAccessIterator> Piece;
00436 
00437     typedef QSBThreadLocal<RandomAccessIterator> tls_type;
00438 
00439     if (n <= 1)
00440       return;
00441 
00442     // At least one element per processor.
00443     if (num_threads > n)
00444       num_threads = static_cast<thread_index_t>(n);
00445 
00446     // Initialize thread local storage
00447     tls_type** tls = new tls_type*[num_threads];
00448     difference_type queue_size = num_threads * (thread_index_t)(log2(n) + 1);
00449     for (thread_index_t t = 0; t < num_threads; ++t)
00450       tls[t] = new QSBThreadLocal<RandomAccessIterator>(queue_size);
00451 
00452     // There can never be more than ceil(log2(n)) ranges on the stack, because
00453     // 1. Only one processor pushes onto the stack
00454     // 2. The largest range has at most length n
00455     // 3. Each range is larger than half of the range remaining
00456     volatile difference_type elements_leftover = n;
00457     for (int i = 0; i < num_threads; ++i)
00458       {
00459         tls[i]->elements_leftover = &elements_leftover;
00460         tls[i]->num_threads = num_threads;
00461         tls[i]->global = std::make_pair(begin, end);
00462 
00463         // Just in case nothing is left to assign.
00464         tls[i]->initial = std::make_pair(end, end);
00465       }
00466 
00467     // Main recursion call.
00468     qsb_conquer(tls, begin, begin + n, comp, 0, num_threads, true);
00469 
00470 #if _GLIBCXX_ASSERTIONS
00471     // All stack must be empty.
00472     Piece dummy;
00473     for (int i = 1; i < num_threads; ++i)
00474       _GLIBCXX_PARALLEL_ASSERT(!tls[i]->leftover_parts.pop_back(dummy));
00475 #endif
00476 
00477     for (int i = 0; i < num_threads; ++i)
00478       delete tls[i];
00479     delete[] tls;
00480   }
00481 } // namespace __gnu_parallel
00482 
00483 #endif

Generated on Sat Oct 25 05:08:58 2008 for libstdc++ by  doxygen 1.5.6