C++多线程实现快速排序算法

用 C++ 实现快速排序，并对比单线程与多线程两个版本的排序耗时。

speed

107
#include <chrono>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <random>
#include <system_error>
#include <thread>
#include <utility>
#include <vector>

using namespace std;

// Number of elements in the benchmark array; generated values lie in [0, N).
// A typed constant rather than a macro: a one-letter macro would silently rewrite
// every later token spelled `N` in this file.
constexpr int N = 10000000;

/// Fills nums[0..n) with pseudo-random integers uniformly distributed in [0, N - 1].
///
/// Uses std::mt19937 with std::uniform_int_distribution instead of `rand() % N`:
/// RAND_MAX may be as small as 32767 (e.g. MSVC), in which case `rand() % N` can never
/// exceed 32767, and the modulo skews the distribution in any case.
/// The generator is re-seeded from time() on every call, so two calls within the same
/// second produce the same sequence.
///
/// @param nums destination buffer with room for at least n ints
/// @param n    number of elements to fill; n <= 0 writes nothing
void rand_nums(int* nums, int n)
{
    std::mt19937 gen(static_cast<unsigned int>(std::time(nullptr)));
    std::uniform_int_distribution<int> dist(0, N - 1);
    for (int i = 0; i < n; ++i)
    {
        nums[i] = dist(gen);
    }
}

/// Sorts nums[left..right] (both bounds inclusive) in ascending order, in place, on the
/// calling thread.
///
/// Quicksort with a median-of-three pivot. It recurses only into the smaller partition
/// and loops over the larger one, so recursion depth stays O(log n). A plain
/// first-element pivot would recurse n levels deep on already-sorted input and overflow
/// the stack for large arrays.
/// Known limitation: a range made mostly of equal keys still costs quadratic time,
/// because equal elements all end up on one side of the pivot.
///
/// @param nums  array to sort; must be valid for every index in left..right
/// @param left  first index of the range (inclusive)
/// @param right last index of the range (inclusive); left >= right is a no-op
void quick_sort_sigle_thread(int* nums, int left, int right)
{
    while (left < right)
    {
        // Median-of-three: order nums[left] <= nums[mid] <= nums[right], then move the
        // median to nums[left] so it becomes the pivot. Sorted / reverse-sorted input
        // then splits near the middle instead of 0 / n-1.
        const int mid = left + (right - left) / 2;
        if (nums[mid] < nums[left]) std::swap(nums[mid], nums[left]);
        if (nums[right] < nums[left]) std::swap(nums[right], nums[left]);
        if (nums[right] < nums[mid]) std::swap(nums[right], nums[mid]);
        std::swap(nums[left], nums[mid]);

        const int base = nums[left];
        int i = left, j = right;
        while (i < j)
        {
            // Scan from the right first so that nums[i] <= base when the scans meet.
            while (i < j && nums[j] >= base) --j; // find an element smaller than base
            while (i < j && nums[i] <= base) ++i; // find an element larger than base
            if (i < j) std::swap(nums[i], nums[j]);
        }
        std::swap(nums[left], nums[i]); // pivot is now at its final position i

        // Recurse into the smaller side, keep looping on the larger side.
        if (i - left < right - i)
        {
            quick_sort_sigle_thread(nums, left, i - 1);
            left = i + 1;
        }
        else
        {
            quick_sort_sigle_thread(nums, i + 1, right);
            right = i - 1;
        }
    }
}

void quick_sort_multi_thread(int* nums, int left, int right)
{
    if (left >= right) return;
    int i = left, j = right, base = nums[left];

    while (i < j)
    {
        while (i < j && nums[j] >= base) --j; //找一个比base小的
        while (i < j && nums[i] <= base) ++i; //找一个比base大的
        if (i < j) swap(nums[i], nums[j]);
    }

    swap(nums[left], nums[i]);
  
    thread threads[2];

    if (right - left <= 100000)
    {   //小于10w用递归
        threads[0] = thread(quick_sort_sigle_thread, nums, left, i - 1);
        threads[1] = thread(quick_sort_sigle_thread, nums, i + 1, right);
    }
    else
    {
        //大于10w继续用多线程分割
        threads[0] = thread(quick_sort_multi_thread, nums, left, i - 1);
        threads[1] = thread(quick_sort_multi_thread, nums, i + 1, right);
    }
    
    for (int i = 0; i < 2; ++ i) threads[i].join(); 
}


/// Writes `end` consecutive elements of `nums`, beginning at index `start`, to cout.
/// Each element is followed by a single space, and the output finishes with endl.
/// Despite its name, `end` is an element count, not a one-past-the-end index.
/// A count <= 0 prints only the line break.
void print_nums(int* nums, int end, int start = 0)
{
    const int stop = start + end; // one past the last index printed
    int idx = start;
    while (idx < stop)
    {
        cout << nums[idx] << ' ';
        ++idx;
    }
    cout << endl;
}

/// Benchmark driver. It:
///   1. fills an N-element array with random values;
///   2. sorts it with the single-threaded quicksort and prints the elapsed time;
///   3. refills the array and sorts it with the multi-threaded quicksort, printing the
///      elapsed time again.
/// Times are wall-clock seconds measured with std::chrono::steady_clock. std::clock()
/// reports CPU time summed over all threads on POSIX, which hides any multi-thread
/// speedup, and its CLOCKS_PER_SEC is 1000000 there rather than 1000.
int main()
{
    // std::vector owns the buffer, so it is released on every exit path.
    std::vector<int> storage(N);
    int* nums = storage.data();

    rand_nums(nums, N);
    // print_nums(nums, N);

    std::cout << "数组大小:" << N << std::endl;

    using steady = std::chrono::steady_clock;
    // Wall-clock seconds elapsed since `since`.
    auto seconds_since = [](steady::time_point since) {
        return std::chrono::duration<double>(steady::now() - since).count();
    };

    auto start_time = steady::now();
    quick_sort_sigle_thread(nums, 0, N - 1);
    const double single_secs = seconds_since(start_time);

    std::cout << "单线程用时:" << single_secs << 's' << std::endl;

    // print_nums(nums, N);

    rand_nums(nums, N);

    start_time = steady::now();
    quick_sort_multi_thread(nums, 0, N - 1);
    const double multi_secs = seconds_since(start_time);

    std::cout << "多线程用时:" << multi_secs << 's' << std::endl;

    // print_nums(nums, N);

    return 0;
}