LeetCode 698. 划分为k个相等的子集
给定一个整数数组 nums
和一个正整数 k
,找出是否有可能把这个数组分成 k
个非空子集,其总和都相等。
示例 1:
输入: nums = [4, 3, 2, 3, 5, 2, 1], k = 4
输出: True
说明: 有可能将其分成 4 个子集(5),(1,4),(2,3),(2,3)等于总和。
示例 2:
输入: nums = [1,2,3,4], k = 3
输出: false
method 1: 回溯
- 从
bucket
的角度
等价于把这些数都放进k
个bucket
里面,使得每个bucket
都等于target
bool dfs(vector<int>& nums, vector<int>& bucket, int target, int k, int index, vector<bool>& used) {
if (k == -1) return true;
if (bucket[k] == target) {
return dfs(nums, bucket, target, k - 1, 0, used); // 处理下一个bucket,nums要从0开始选
}
for (int i = index; i < nums.size(); i++) {
if (bucket[k] + nums[i] > target) break; // 已经从小到大排序了
if (used[i]) continue;
bucket[k] += nums[i];
used[i] = true;
if (dfs(nums, bucket, target, k, i + 1, used)) return true;
bucket[k] -= nums[i];
used[i] = false;
}
return false;
}
bool canPartitionKSubsets(vector<int>& nums, int k) {
int n = nums.size();
sort(nums.begin(), nums.end());
int sum = 0;
for (auto n : nums) sum += n;
if (sum % k != 0) return false; // 不能整除肯定不行
int target = sum / k;
if (nums.back() > target) return false;
vector<int> bucket(k);
vector<bool> used(n, false);
return dfs(nums, bucket, target, k - 1, 0, used);
}
时间复杂度:$\mathcal{O}(k\cdot2^n)$,对于$k$个bucket
,每个都有$2^n$种选择
空间复杂度:$\mathcal{O}(k)$,递归最大深度
- 从
nums[i]
的角度
每个数都可以选择一个bucket
装进去
剪枝:
- 从大到小排序,让较大的数先选
- 第一个数装哪里都一样
- 如果两个
bucket
一样,装哪个结果都是一样的,可以跳过
bool dfs(vector<int>& nums, vector<int>& bucket, int target, int k, int index) {
if (index == nums.size()) return true;
for (int i = 0; i < k; i++) {
if (i > 0 && index == 0) break;
if (i > 0 && bucket[i] == bucket[i - 1]) continue;
if (bucket[i] + nums[index] > target) continue;
bucket[i] += nums[index];
if (dfs(nums, bucket, target, k, index + 1)) return true;
bucket[i] -= nums[index];
}
return false;
}
bool canPartitionKSubsets(vector<int>& nums, int k) {
sort(nums.rbegin(), nums.rend()); // 这样快一点
int sum = 0;
for (auto n : nums) sum += n;
if (sum % k != 0) return false;
int target = sum / k;
if (nums[0] > target) return false;
vector<int> bucket(k);
return dfs(nums, bucket, target, k, 0);
}
时间复杂度:$\mathcal{O}(k^N)$,每个数都有$k$种选择,$N$个数就是$k^N$
空间复杂度:$\mathcal{O}(N)$,递归最大深度
method 2: 状态压缩dp
由于数组最长只有16,所以可以用一个int
型的state
表示每个数的状态,0表示没用过,1表示用过了
dp[state]
:true
表示该state
是有效的,false
表示该state
无效
$n$个数就有$2^n$种状态
每种state要变成nextState,需要选择一个数nums[j]
来累加,并把选的那个数的位置置1,也就是state | (1 << j)
,如果当前状态的累加和curSum[state]
加上nums[j]
之后没超过target
,那nextState
的累加和就是curSum[state]+nums[j]
,并且dp[nextState]=true
bool canPartitionKSubsets(vector<int>& nums, int k) {
int n = nums.size();
sort(nums.begin(), nums.end());
int sum = 0;
for (auto n : nums) sum += n;
if (sum % k != 0) return false;
int target = sum / k;
int size = (1 << n);
vector<bool> dp(size, false);
dp[0] = true;
vector<int> curSum(size, 0);
for (int state = 0; state < size; state++) {
if (!dp[state]) continue; // 如果当前状态不行,就不用看了
for (int j = 0; j < n; j++) {
if ((state & (1 << j)) != 0) continue; // 已经用过就不能再用了
int nextState = state | (1 << j); // 第j位置1
if (dp[nextState]) continue; // 下一个状态已经是true了就不用管了
if (curSum[state] % target + nums[j] <= target) {
curSum[nextState] = curSum[state] + nums[j];
dp[nextState] = true;
} else break; // 从小到大排序了,后面只会更大,肯定也不行
}
}
return dp[size - 1];
}
时间复杂度:$\mathcal{O}(N \cdot 2^N)$。其中 $N$ 是输入数组nums
的长度。有$2^N$个状态,每个状态对nums
执行$\mathcal{O}(N)$次尝试
空间复杂度:$\mathcal{O}(2^N)$,状态数组的长度、数组curSum
的长度
LeetCode 473. 火柴拼正方形
你将得到一个整数数组 matchsticks
,其中 matchsticks[i]
是第 i
个火柴棒的长度。你要用 所有的火柴棍 拼成一个正方形。你 不能折断 任何一根火柴棒,但你可以把它们连在一起,而且每根火柴棒必须 使用一次 。
如果你能使这个正方形,则返回 true
,否则返回 false
。
示例 1:
输入: matchsticks = [1,1,2,2,2]
输出: true
解释: 能拼成一个边长为2的正方形,每边两根火柴。
method
相当于上一题k=4
的情况
bool dfs(vector<int>& nums, vector<int>& bucket, int target, int k, int index, vector<bool>& used) {
if (k == 4) return true;
if (bucket[k] == target) {
return dfs(nums, bucket, target, k + 1, 0, used);
}
for (int i = index; i < nums.size(); i++) {
if (bucket[k] + nums[i] > target) break;
if (used[i]) continue;
bucket[k] += nums[i];
used[i] = true;
if (dfs(nums, bucket, target, k, i + 1, used)) return true;
bucket[k] -= nums[i];
used[i] = false;
}
return false;
}
bool makesquare(vector<int>& nums) {
sort(nums.begin(), nums.end());
int n = nums.size();
int sum = 0;
for (auto n : nums) sum += n;
if (sum % 4 != 0) return false;
int target = sum / 4;
if (nums.back() > target) return false;
vector<int> bucket(4, 0);
vector<bool> used(n, false);
return dfs(nums, bucket, target, 0, 0, used);
}