multiway_mergesort.h

Go to the documentation of this file.
00001 // -*- C++ -*-
00002 
00003 // Copyright (C) 2007, 2008 Free Software Foundation, Inc.
00004 //
00005 // This file is part of the GNU ISO C++ Library.  This library is free
00006 // software; you can redistribute it and/or modify it under the terms
00007 // of the GNU General Public License as published by the Free Software
00008 // Foundation; either version 2, or (at your option) any later
00009 // version.
00010 
00011 // This library is distributed in the hope that it will be useful, but
00012 // WITHOUT ANY WARRANTY; without even the implied warranty of
00013 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00014 // General Public License for more details.
00015 
00016 // You should have received a copy of the GNU General Public License
00017 // along with this library; see the file COPYING.  If not, write to
00018 // the Free Software Foundation, 59 Temple Place - Suite 330, Boston,
00019 // MA 02111-1307, USA.
00020 
00021 // As a special exception, you may use this file as part of a free
00022 // software library without restriction.  Specifically, if other files
00023 // instantiate templates or use macros or inline functions from this
00024 // file, or you compile this file and link it with other files to
00025 // produce an executable, this file does not by itself cause the
00026 // resulting executable to be covered by the GNU General Public
00027 // License.  This exception does not however invalidate any other
00028 // reasons why the executable file might be covered by the GNU General
00029 // Public License.
00030 
00031 /** @file parallel/multiway_mergesort.h
00032  *  @brief Parallel multiway merge sort.
00033  *  This file is a GNU parallel extension to the Standard C++ Library.
00034  */
00035 
00036 // Written by Johannes Singler.
00037 
00038 #ifndef _GLIBCXX_PARALLEL_MERGESORT_H
00039 #define _GLIBCXX_PARALLEL_MERGESORT_H 1
00040 
00041 #include <vector>
00042 
00043 #include <parallel/basic_iterator.h>
00044 #include <bits/stl_algo.h>
00045 #include <parallel/parallel.h>
00046 #include <parallel/multiway_merge.h>
00047 
00048 namespace __gnu_parallel
00049 {
00050 
/** @brief Description of a subsequence as a pair of offsets.
  *
  *  The merge sort code in this file treats the pair as a half-open
  *  range: @c end - @c begin is the number of elements. */
template<typename _DifferenceTp>
  struct Piece
  {
    /** @brief Integral type used for both offsets. */
    typedef _DifferenceTp difference_type;

    /** @brief Offset of the first element of the subsequence. */
    difference_type begin;

    /** @brief Offset just past the last element of the subsequence. */
    difference_type end;
  };
00063 
00064 /** @brief Data accessed by all threads.
00065   *
00066   *  PMWMS = parallel multiway mergesort */
00067 template<typename RandomAccessIterator>
00068   struct PMWMSSortingData
00069   {
00070     typedef std::iterator_traits<RandomAccessIterator> traits_type;
00071     typedef typename traits_type::value_type value_type;
00072     typedef typename traits_type::difference_type difference_type;
00073 
00074     /** @brief Number of threads involved. */
00075     thread_index_t num_threads;
00076 
00077     /** @brief Input begin. */
00078     RandomAccessIterator source;
00079 
00080     /** @brief Start indices, per thread. */
00081     difference_type* starts;
00082 
00083     /** @brief Storage in which to sort. */
00084     value_type** temporary;
00085 
00086     /** @brief Samples. */
00087     value_type* samples;
00088 
00089     /** @brief Offsets to add to the found positions. */
00090     difference_type* offsets;
00091 
00092     /** @brief Pieces of data to merge @c [thread][sequence] */
00093     std::vector<Piece<difference_type> >* pieces;
00094 };
00095 
00096 /**
00097   *  @brief Select samples from a sequence.
00098   *  @param sd Pointer to algorithm data. Result will be placed in
00099   *  @c sd->samples.
00100   *  @param num_samples Number of samples to select.
00101   */
00102 template<typename RandomAccessIterator, typename _DifferenceTp>
00103   void 
00104   determine_samples(PMWMSSortingData<RandomAccessIterator>* sd,
00105                     _DifferenceTp num_samples)
00106   {
00107     typedef std::iterator_traits<RandomAccessIterator> traits_type;
00108     typedef typename traits_type::value_type value_type;
00109     typedef _DifferenceTp difference_type;
00110 
00111     thread_index_t iam = omp_get_thread_num();
00112 
00113     difference_type* es = new difference_type[num_samples + 2];
00114 
00115     equally_split(sd->starts[iam + 1] - sd->starts[iam], 
00116                   num_samples + 1, es);
00117 
00118     for (difference_type i = 0; i < num_samples; ++i)
00119       ::new(&(sd->samples[iam * num_samples + i]))
00120       value_type(sd->source[sd->starts[iam] + es[i + 1]]);
00121 
00122     delete[] es;
00123   }
00124 
/** @brief Split consistently.
  *
  *  Primary template, deliberately empty: the actual work is done by
  *  the two partial specializations selected through @c exact
  *  (@c true: exact splitting, @c false: splitting by sampling). */
template<bool exact,
         typename RandomAccessIterator,
         typename Comparator,
         typename SortingPlacesIterator>
  struct split_consistently
  { };
00131 
00132 /** @brief Split by exact splitting. */
00133 template<typename RandomAccessIterator, typename Comparator,
00134           typename SortingPlacesIterator>
00135   struct split_consistently
00136     <true, RandomAccessIterator, Comparator, SortingPlacesIterator>
00137   {
00138     void operator()(
00139       const thread_index_t iam,
00140       PMWMSSortingData<RandomAccessIterator>* sd,
00141       Comparator& comp,
00142       const typename
00143         std::iterator_traits<RandomAccessIterator>::difference_type
00144           num_samples)
00145       const
00146   {
00147 #   pragma omp barrier
00148 
00149     std::vector<std::pair<SortingPlacesIterator, SortingPlacesIterator> >
00150         seqs(sd->num_threads);
00151     for (thread_index_t s = 0; s < sd->num_threads; s++)
00152       seqs[s] = std::make_pair(sd->temporary[s],
00153                                 sd->temporary[s]
00154                                     + (sd->starts[s + 1] - sd->starts[s]));
00155 
00156     std::vector<SortingPlacesIterator> offsets(sd->num_threads);
00157 
00158     // if not last thread
00159     if (iam < sd->num_threads - 1)
00160       multiseq_partition(seqs.begin(), seqs.end(),
00161                           sd->starts[iam + 1], offsets.begin(), comp);
00162 
00163     for (int seq = 0; seq < sd->num_threads; seq++)
00164       {
00165         // for each sequence
00166         if (iam < (sd->num_threads - 1))
00167           sd->pieces[iam][seq].end = offsets[seq] - seqs[seq].first;
00168         else
00169           // very end of this sequence
00170           sd->pieces[iam][seq].end =
00171               sd->starts[seq + 1] - sd->starts[seq];
00172       }
00173 
00174 #   pragma omp barrier
00175 
00176     for (thread_index_t seq = 0; seq < sd->num_threads; seq++)
00177       {
00178         // For each sequence.
00179         if (iam > 0)
00180           sd->pieces[iam][seq].begin = sd->pieces[iam - 1][seq].end;
00181         else
00182           // Absolute beginning.
00183           sd->pieces[iam][seq].begin = 0;
00184       }
00185   }   
00186   };
00187 
00188 /** @brief Split by sampling. */ 
00189 template<typename RandomAccessIterator, typename Comparator,
00190           typename SortingPlacesIterator>
00191   struct split_consistently<false, RandomAccessIterator, Comparator,
00192                              SortingPlacesIterator>
00193   {
00194     void operator()(
00195         const thread_index_t iam,
00196         PMWMSSortingData<RandomAccessIterator>* sd,
00197         Comparator& comp,
00198         const typename
00199           std::iterator_traits<RandomAccessIterator>::difference_type
00200             num_samples)
00201         const
00202     {
00203       typedef std::iterator_traits<RandomAccessIterator> traits_type;
00204       typedef typename traits_type::value_type value_type;
00205       typedef typename traits_type::difference_type difference_type;
00206 
00207       determine_samples(sd, num_samples);
00208 
00209 #     pragma omp barrier
00210 
00211 #     pragma omp single
00212       __gnu_sequential::sort(sd->samples,
00213                              sd->samples + (num_samples * sd->num_threads),
00214                              comp);
00215 
00216 #     pragma omp barrier
00217 
00218       for (thread_index_t s = 0; s < sd->num_threads; ++s)
00219         {
00220           // For each sequence.
00221           if (num_samples * iam > 0)
00222             sd->pieces[iam][s].begin =
00223                 std::lower_bound(sd->temporary[s],
00224                     sd->temporary[s]
00225                         + (sd->starts[s + 1] - sd->starts[s]),
00226                     sd->samples[num_samples * iam],
00227                     comp)
00228                 - sd->temporary[s];
00229           else
00230             // Absolute beginning.
00231             sd->pieces[iam][s].begin = 0;
00232 
00233           if ((num_samples * (iam + 1)) < (num_samples * sd->num_threads))
00234             sd->pieces[iam][s].end =
00235                 std::lower_bound(sd->temporary[s],
00236                         sd->temporary[s]
00237                             + (sd->starts[s + 1] - sd->starts[s]),
00238                         sd->samples[num_samples * (iam + 1)],
00239                         comp)
00240                 - sd->temporary[s];
00241           else
00242             // Absolute end.
00243             sd->pieces[iam][s].end = sd->starts[s + 1] - sd->starts[s];
00244         }
00245     }
00246   };
00247   
/** @brief Sort functor whose stability is chosen at compile time.
  *
  *  Primary template, deliberately empty; see the specializations for
  *  @c stable == true and @c stable == false. */
template<bool stable,
         typename RandomAccessIterator,
         typename Comparator>
  struct possibly_stable_sort
  { };
00252 
00253 template<typename RandomAccessIterator, typename Comparator>
00254   struct possibly_stable_sort<true, RandomAccessIterator, Comparator>
00255   {
00256     void operator()(const RandomAccessIterator& begin,
00257                      const RandomAccessIterator& end, Comparator& comp) const
00258     {
00259       __gnu_sequential::stable_sort(begin, end, comp); 
00260     }
00261   };
00262 
00263 template<typename RandomAccessIterator, typename Comparator>
00264   struct possibly_stable_sort<false, RandomAccessIterator, Comparator>
00265   {
00266     void operator()(const RandomAccessIterator begin,
00267                      const RandomAccessIterator end, Comparator& comp) const
00268     {
00269       __gnu_sequential::sort(begin, end, comp); 
00270     }
00271   };
00272 
/** @brief Multiway merge functor whose stability is chosen at compile
  *  time.
  *
  *  Primary template, deliberately empty; see the specializations for
  *  @c stable == true and @c stable == false. */
template<bool stable,
         typename SeqRandomAccessIterator,
         typename RandomAccessIterator,
         typename Comparator,
         typename DiffType>
  struct possibly_stable_multiway_merge
  { };
00279 
00280 template<typename SeqRandomAccessIterator, typename RandomAccessIterator,
00281           typename Comparator, typename DiffType>
00282   struct possibly_stable_multiway_merge
00283     <true, SeqRandomAccessIterator, RandomAccessIterator, Comparator,
00284     DiffType>
00285   {
00286     void operator()(const SeqRandomAccessIterator& seqs_begin,
00287                       const SeqRandomAccessIterator& seqs_end,
00288                       const RandomAccessIterator& target,
00289                       Comparator& comp,
00290                       DiffType length_am) const
00291     {
00292       stable_multiway_merge(seqs_begin, seqs_end, target, comp,
00293                        length_am, sequential_tag());
00294     }
00295   };
00296 
00297 template<typename SeqRandomAccessIterator, typename RandomAccessIterator,
00298           typename Comparator, typename DiffType>
00299   struct possibly_stable_multiway_merge
00300     <false, SeqRandomAccessIterator, RandomAccessIterator, Comparator,
00301     DiffType>
00302   {
00303     void operator()(const SeqRandomAccessIterator& seqs_begin,
00304                       const SeqRandomAccessIterator& seqs_end,
00305                       const RandomAccessIterator& target,
00306                       Comparator& comp,
00307                       DiffType length_am) const
00308     {
00309       multiway_merge(seqs_begin, seqs_end, target, comp,
00310                        length_am, sequential_tag());
00311     }
00312   };
00313 
00314 /** @brief PMWMS code executed by each thread.
00315   *  @param sd Pointer to algorithm data.
00316   *  @param comp Comparator.
00317   */
00318 template<bool stable, bool exact, typename RandomAccessIterator,
00319           typename Comparator>
00320   void 
00321   parallel_sort_mwms_pu(PMWMSSortingData<RandomAccessIterator>* sd,
00322                         Comparator& comp)
00323   {
00324     typedef std::iterator_traits<RandomAccessIterator> traits_type;
00325     typedef typename traits_type::value_type value_type;
00326     typedef typename traits_type::difference_type difference_type;
00327 
00328     thread_index_t iam = omp_get_thread_num();
00329 
00330     // Length of this thread's chunk, before merging.
00331     difference_type length_local = sd->starts[iam + 1] - sd->starts[iam];
00332 
00333     // Sort in temporary storage, leave space for sentinel.
00334 
00335     typedef value_type* SortingPlacesIterator;
00336 
00337     sd->temporary[iam] =
00338         static_cast<value_type*>(
00339         ::operator new(sizeof(value_type) * (length_local + 1)));
00340 
00341     // Copy there.
00342     std::uninitialized_copy(sd->source + sd->starts[iam],
00343                             sd->source + sd->starts[iam] + length_local,
00344                             sd->temporary[iam]);
00345 
00346     possibly_stable_sort<stable, SortingPlacesIterator, Comparator>()
00347         (sd->temporary[iam], sd->temporary[iam] + length_local, comp);
00348 
00349     // Invariant: locally sorted subsequence in sd->temporary[iam],
00350     // sd->temporary[iam] + length_local.
00351 
00352     // No barrier here: Synchronization is done by the splitting routine.
00353 
00354     difference_type num_samples =
00355         _Settings::get().sort_mwms_oversampling * sd->num_threads - 1;
00356     split_consistently
00357       <exact, RandomAccessIterator, Comparator, SortingPlacesIterator>()
00358         (iam, sd, comp, num_samples);
00359 
00360     // Offset from target begin, length after merging.
00361     difference_type offset = 0, length_am = 0;
00362     for (thread_index_t s = 0; s < sd->num_threads; s++)
00363       {
00364         length_am += sd->pieces[iam][s].end - sd->pieces[iam][s].begin;
00365         offset += sd->pieces[iam][s].begin;
00366       }
00367 
00368     typedef std::vector<
00369       std::pair<SortingPlacesIterator, SortingPlacesIterator> >
00370         seq_vector_type;
00371     seq_vector_type seqs(sd->num_threads);
00372 
00373     for (int s = 0; s < sd->num_threads; ++s)
00374       {
00375         seqs[s] =
00376           std::make_pair(sd->temporary[s] + sd->pieces[iam][s].begin,
00377         sd->temporary[s] + sd->pieces[iam][s].end);
00378       }
00379 
00380     possibly_stable_multiway_merge<
00381         stable,
00382         typename seq_vector_type::iterator,
00383         RandomAccessIterator,
00384         Comparator, difference_type>()
00385           (seqs.begin(), seqs.end(),
00386            sd->source + offset, comp,
00387            length_am);
00388 
00389 #   pragma omp barrier
00390 
00391     ::operator delete(sd->temporary[iam]);
00392   }
00393 
00394 /** @brief PMWMS main call.
00395   *  @param begin Begin iterator of sequence.
00396   *  @param end End iterator of sequence.
00397   *  @param comp Comparator.
00398   *  @param n Length of sequence.
00399   *  @param num_threads Number of threads to use.
00400   */
00401 template<bool stable, bool exact, typename RandomAccessIterator,
00402            typename Comparator>
00403   void
00404   parallel_sort_mwms(RandomAccessIterator begin, RandomAccessIterator end,
00405                      Comparator comp,
00406                      thread_index_t num_threads)
00407   {
00408     _GLIBCXX_CALL(end - begin)
00409 
00410     typedef std::iterator_traits<RandomAccessIterator> traits_type;
00411     typedef typename traits_type::value_type value_type;
00412     typedef typename traits_type::difference_type difference_type;
00413 
00414     difference_type n = end - begin;
00415 
00416     if (n <= 1)
00417       return;
00418 
00419     // at least one element per thread
00420     if (num_threads > n)
00421       num_threads = static_cast<thread_index_t>(n);
00422 
00423     // shared variables
00424     PMWMSSortingData<RandomAccessIterator> sd;
00425     difference_type* starts;
00426 
00427 #   pragma omp parallel num_threads(num_threads)
00428       {
00429         num_threads = omp_get_num_threads();  //no more threads than requested
00430 
00431 #       pragma omp single
00432           {
00433             sd.num_threads = num_threads;
00434             sd.source = begin;
00435 
00436             sd.temporary = new value_type*[num_threads];
00437 
00438             if (!exact)
00439               {
00440                 difference_type size =
00441                     (_Settings::get().sort_mwms_oversampling * num_threads - 1)
00442                         * num_threads;
00443                 sd.samples = static_cast<value_type*>(
00444                               ::operator new(size * sizeof(value_type)));
00445               }
00446             else
00447               sd.samples = NULL;
00448 
00449             sd.offsets = new difference_type[num_threads - 1];
00450             sd.pieces = new std::vector<Piece<difference_type> >[num_threads];
00451             for (int s = 0; s < num_threads; ++s)
00452               sd.pieces[s].resize(num_threads);
00453             starts = sd.starts = new difference_type[num_threads + 1];
00454 
00455             difference_type chunk_length = n / num_threads;
00456             difference_type split = n % num_threads;
00457             difference_type pos = 0;
00458             for (int i = 0; i < num_threads; ++i)
00459               {
00460                 starts[i] = pos;
00461                 pos += (i < split) ? (chunk_length + 1) : chunk_length;
00462               }
00463             starts[num_threads] = pos;
00464           } //single
00465 
00466         // Now sort in parallel.
00467         parallel_sort_mwms_pu<stable, exact>(&sd, comp);
00468       } //parallel
00469 
00470     delete[] starts;
00471     delete[] sd.temporary;
00472 
00473     if (!exact)
00474       ::operator delete(sd.samples);
00475 
00476     delete[] sd.offsets;
00477     delete[] sd.pieces;
00478   }
00479 } //namespace __gnu_parallel
00480 
00481 #endif

Generated on Fri Jan 23 20:12:15 2009 for libstdc++ by  doxygen 1.5.6