二分查找
整数二分
通用模板
check()
判断 mid
是否满足某种性质。
模板一
区间
C++
bool check(int x) {/* ... */} // 检查x是否满足某种性质
int bsearch(int l, int r) {
while (l < r) {
int mid = l + (r - l >> 1);
if (check(mid)) r = mid;
else l = mid + 1;
}
return l;
}
bool check(int x) {/* ... */} // 检查x是否满足某种性质
int bsearch(int l, int r) {
while (l < r) {
int mid = l + (r - l >> 1);
if (check(mid)) r = mid;
else l = mid + 1;
}
return l;
}
模板二
C++
bool check(int x) {/* ... */} // 检查x是否满足某种性质
int bsearch(int l, int r) {
while (l < r) {
int mid = l + (r - l + 1 >> 1);
if (check(mid)) l = mid;
else r = mid - 1;
}
return l;
}
bool check(int x) {/* ... */} // 检查x是否满足某种性质
int bsearch(int l, int r) {
while (l < r) {
int mid = l + (r - l + 1 >> 1);
if (check(mid)) l = mid;
else r = mid - 1;
}
return l;
}
应用
根据上面给出的模板一、二,可以按照如下方法如下搜寻边界:
C++
// 返回数组中最后一个小于等于x的元素的下标 不存在则返回l
int last_lt(int a[], int l, int r, int x) {
while (l < r) {
int m = l + (r - l + 1 >> 1);
if (a[m] < x) l = m;
else r = m - 1;
}
return r;
}
// 返回数组中最后一个小于等于x的元素的下标 不存在则返回l
int last_lt(int a[], int l, int r, int x) {
while (l < r) {
int m = l + (r - l + 1 >> 1);
if (a[m] < x) l = m;
else r = m - 1;
}
return r;
}
C++
// 返回数组中最后一个小于等于x的元素的下标 不存在则返回l
int last_le(int a[], int l, int r, int x) {
while (l < r) {
int m = l + (r - l + 1 >> 1);
if (a[m] <= x) l = m;
else r = m - 1;
}
return r;
}
// 返回数组中最后一个小于等于x的元素的下标 不存在则返回l
int last_le(int a[], int l, int r, int x) {
while (l < r) {
int m = l + (r - l + 1 >> 1);
if (a[m] <= x) l = m;
else r = m - 1;
}
return r;
}
C++
// 返回数组中第一个大于x的元素的下标 不存在则返回r
int first_gt(int a[], int l, int r, int x) {
while (l < r) {
int m = l + (r - l >> 1);
if (a[m] > x) r = m;
else l = m + 1;
}
return r;
}
// 返回数组中第一个大于x的元素的下标 不存在则返回r
int first_gt(int a[], int l, int r, int x) {
while (l < r) {
int m = l + (r - l >> 1);
if (a[m] > x) r = m;
else l = m + 1;
}
return r;
}
C++
// 返回数组中第一个大于等于x的元素的下标 不存在则返回r
int first_ge(int a[], int l, int r, int x) {
while (l < r) {
int m = l + (r - l >> 1);
if (a[m] >= x) r = m;
else l = m + 1;
}
return r;
}
// 返回数组中第一个大于等于x的元素的下标 不存在则返回r
int first_ge(int a[], int l, int r, int x) {
while (l < r) {
int m = l + (r - l >> 1);
if (a[m] >= x) r = m;
else l = m + 1;
}
return r;
}
相等
Java
int binarySearch(int[] nums, int target) {
int left = 0, right = nums.length - 1;
while (left <= right) {
int mid = left + ((right - left) >> 1);
if (nums[mid] == target) return mid;
else if (nums[mid] < target) left = mid + 1;
else right = mid - 1;
}
return -1;
}
int binarySearch(int[] nums, int target) {
int left = 0, right = nums.length - 1;
while (left <= right) {
int mid = left + ((right - left) >> 1);
if (nums[mid] == target) return mid;
else if (nums[mid] < target) left = mid + 1;
else right = mid - 1;
}
return -1;
}
C++
int binary_search(vector<int>& nums, int target) {
int left = 0, right = nums.size() - 1;
while (left <= right) {
int mid = left + ((right - left) >> 1);
if (nums[mid] == target) return mid;
else if (nums[mid] < target) left = mid + 1;
else right = mid - 1;
}
return -1;
}
int binary_search(vector<int>& nums, int target) {
int left = 0, right = nums.size() - 1;
while (left <= right) {
int mid = left + ((right - left) >> 1);
if (nums[mid] == target) return mid;
else if (nums[mid] < target) left = mid + 1;
else right = mid - 1;
}
return -1;
}
浮点数二分
eps
表示精度,取决于题目对精度的要求(一般来说取题目精度要求小两个数量级,比如题目要求 eps
按经验会取
C++
const double eps = 1e-6;
double bsearch(double l, double r) {
while (r - l > eps) {
double mid = (l + r) / 2;
if (check(mid)) r = mid;
else l = mid;
}
return l;
}
const double eps = 1e-6;
double bsearch(double l, double r) {
while (r - l > eps) {
double mid = (l + r) / 2;
if (check(mid)) r = mid;
else l = mid;
}
return l;
}
n
的 m
次方根
模板题:数的三次方根。
Java
double root(double n, int m, double eps) {
double l = 0.0, r = Math.max(n, 1.0);
while (r - l > eps) {
double c = (l + r) / 2;
double prod = 1.0;
for (int i = 0; i < m; ++i) prod *= c;
if (prod >= n) r = c;
else l = c;
}
return l;
}
double root(double n, int m, double eps) {
double l = 0.0, r = Math.max(n, 1.0);
while (r - l > eps) {
double c = (l + r) / 2;
double prod = 1.0;
for (int i = 0; i < m; ++i) prod *= c;
if (prod >= n) r = c;
else l = c;
}
return l;
}