我正在尝试优化一种计算具有特定约束的排列的算法。给定整数 n、t、a、b,其中:
- n 是排列的长度(1 到 n)
- t 是所需的固定点数(原始位置上的数字)
- a 是所需元素的数量小于其位置
- b 是所需的元素数量大于其位置
例如,n=3、t=1、a=1、b=1:
该排列[1,3,2]
是有效的,因为:
- 1 处于其原始位置(计入 t=1)
- 2 小于位置 3(计入 a=1)
- 3 大于位置 2(计入 b=1)
当前问题:
- 内存使用量为 O(2^n * n),当 n > 20 时就会出现问题
- 运行时间约为 O(n * 2^n)
- DP 转换中存在许多冗余计算
问题:
- 对于这种类型的组合计数是否有任何已知的优化方法?
- 如果不使用位掩码 DP 可以解决这个问题吗?
- 我可以利用哪些数学特性来降低复杂性?
我考虑过:
- 使用容斥原理
- 尝试寻找一个数学公式
- 使用不同的DP状态表示
限制如下:
- 1≤n≤100
- 0 ≤ t,a,b ≤ n
- t + a + b 必须等于 n
我只是无法编写出足够快的代码。最重要的是,这是代码:
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
using ll = long long;
const int MAXN = 21;
ll combination[MAXN][MAXN];
void init_combination() {
for (int n = 0; n < MAXN; ++n) {
combination[n][0] = combination[n][n] = 1;
for (int k = 1; k < n; ++k)
combination[n][k] = combination[n - 1][k - 1] +
combination[n - 1][k];
}
}
int main() {
int n, t, a, b;
std::cin >> n >> t >> a >> b;
if (t + a + b != n || t < 0 || a < 0 || b < 0 || t > n || a > n || b > n) {
std::cout << 0 << std::endl;
return 0;
}
init_combination();
int n_prime = n - t;
ll fixed_points_ways = combination[n][t];
int total_states = 1 << n_prime;
std::vector<std::vector<ll>> dp(total_states, std::vector<ll>(n_prime + 1, 0));
dp[0][0] = 1;
for (int mask = 0; mask < total_states; ++mask) {
int pos = __builtin_popcount(mask);
if (pos == n_prime) continue;
for (int num = 0; num < n_prime; ++num) {
if (mask & (1 << num)) continue;
if (num == pos) continue;
int next_excedances = 0;
if (num > pos) {
next_excedances = 1;
}
int total_excedances = 0;
for (int k = 0; k <= n_prime; ++k) {
int new_excedances = k + next_excedances;
dp[mask | (1 << num)][new_excedances] += dp[mask][k];
}
}
}
ll total_permutations = dp[total_states - 1][b];
ll result = fixed_points_ways * total_permutations;
std::cout << result << std::endl;
return 0;
}
4
最佳答案
1
我将逐步解释我的解决方案,以便您能够跟上。
步骤1:循环排列
我们将每个排列定义为一个映射:f: list(1 to N) => list(P)
,其中P
是排列本身。
举个例子,对于排列4, 1, 6, 2, 5, 3
,它是一个映射:
[1, 2, 3, 4, 5, 6] => [4, 1, 6, 2, 5, 3]
我们从中画一条边i => f(i)
,我们将得到一个图。
对于我们的例子,我们有1 => 4, 2 => 1, 3 => 6, 4 => 1, 5 => 5, 6 => 3
,然后我们有:
因此我们会发现有 3 个独立的图表。
再比如,对于排列4, 1, 2, 5, 6, 3
,我们有:
我们可以发现只有 1 个图,并且所有节点都与该排列相连。
当图中所有节点都连通时,我们将该排列定义为循环排列。
我们引入循环符号,即:
(a b c .. x y z): a=>b, b=>c, .. x=>y, y=>z, z=>a
例如,排列的循环符号4, 1, 2, 5, 6, 3
是(1, 4, 5, 6, 3, 2)
。
我们可以发现:
(a b c .. x y z) = (b c .. x y z a) = (c .. x y z a b) = (z a b c .. x y)
另外,排列的循环符号4, 1, 6, 2, 5, 3
是(1 4 2)(3 6)(5)
。
因此,循环排列的循环符号中只有 1 对括号。
步骤2:欧拉数
我们先处理一个简单的任务:
对于长度为 N 的循环排列,有多少个排列恰好具有A
小于其位置的数字?
回到我们的循环符号,为了消除重复计数,我们将节点 1 固定在符号的开头。然后其他数字将填充到我们的符号中:
(1 ? ? ... ?) # N-1 '?' in total
对于任意两个相邻的?
,如果第一个小于第二个,则映射后第一个数小于其位置。
我们很容易看出,1
映射后一定小于其位置。所以问题变成了:
在所有可能性中,有多少种组合的确切A-1
数字大于之前的数字?
该问题的答案是:
在组合学中,欧拉数是数字 1 到 的排列数,其中元素恰好大于前一个元素(具有“上升”的排列)。
欧拉数的前几行:
我们以 为例E(3, 1)
,有 4 种组合匹配:
1 3 2, 2 1 3, 2 3 1, 3 1 2
我们将更多细节留给维基百科并继续。
Step3:原始问题与较小的问题
现在我们回到最初的问题,考虑我们排列中最大的数字,这个数字只有 2 种情况:
Case1: in its original position.
Case2: in a length k(2 to n) cycle permutation
我们以 6 长度排列为例:
Case1: ? ? ? ? ? 6
Case2: ? ? ? 6 ? 4(a length 2 cycle), ? ? 6 3 ? 4 (a length 3 cycle) ...
请注意,在 中Case2
,我们遇到的问题较小,因为只剩下n-k
?
残骸。
我们定义:
sol(n, t, a) is the answer of length n permutation with t fixed point and "a" numbers less than their position
在我们的案例中,有:
sol(n, t, a) is the sum of
sol(n-1, t-1, a) # case 1
for 2<=i<=n and 0<=j<=i-1, C(n-1, i-1)* eulerian_number(i-1, j)*sol(n-i, t, a-j-1) # case 2
解释C(n-1, i-1) * eulerian_number(i-1, j) * sol(n-i, t, a-j-1)
:
对于长度为i的循环,由于数字n是固定的,我们可以i-1
从数字中选择位置n-1
,所以我们有C(n-1, i-1)
,这C
就是组合数字。
并且我们将枚举小于位置的数量j
,因此我们有eulerian_number(i-1, j)
(j
小于其位置的数字,并且(i-1)!
由于它是i
长度循环排列,因此存在组合)
另外,请注意,将有一个数字映射到最大数字的位置,因此我们有sol(n-i, t, a-j-1)
。
总结以上所有内容,并通过记忆,我们得出了O(N^3)
空间和O(N^5)
时间的解决方案。
您可以查看我的代码以了解详细信息。
附录:代码
import itertools
import math
ed = {}
dp = {}
def eulerian_number(n, k):
if ed.get((n, k)) is not None:
return ed[(n, k)]
if n == 0:
if k == 0:
ans = 1
else:
ans = 0
else:
ans = (n - k) * eulerian_number(n - 1, k - 1) + (k + 1) * eulerian_number(n - 1, k)
ed[(n, k)] = ans
return ans
def sol(n, t, a):
p = (n, t, a)
if dp.get(p) is not None:
return dp[p]
if n == 0:
if a == 0 and t == 0:
return 1
else:
return 0
ans = sol(n - 1, t - 1, a)
for i in range(2, n + 1):
for j in range(0, i - 1):
ans += sol(n - i, t, a - j - 1) * eulerian_number(i - 1, j) * math.comb(n - 1, i - 1)
dp[p] = ans
return and
我写了一些测试:
def test():
def _test(permutation_length):
d = {}
for numbers in itertools.permutations(list(range(permutation_length))):
t = 0
a = 0
b = 0
for index, n in enumerate(numbers):
if n == index:
t += 1
if n < index:
a += 1
if n > index:
b += 1
if d.get((t, a, b)) is None:
d[(t, a, b)] = 0
d[(t, a, b)] += 1
test_pass = True
for k in d:
t, a, b = k
ans = sol(permutation_length, t, a)
if ans != d[k]:
test_pass = False
print('Error on {}: {}, answer: {}, expected: {}'.format(permutation_length, k, ans, d[k]))
if test_pass:
print('Passed on Permutation Length: {}'.format(permutation_length))
for i in range(1, 8):
_test(i)
test()
输出:
Passed on Permutation Length: 1
Passed on Permutation Length: 2
Passed on Permutation Length: 3
Passed on Permutation Length: 4
Passed on Permutation Length: 5
Passed on Permutation Length: 6
Passed on Permutation Length: 7
更好的解决方案
我们可以先选择t
固定点,这样Case1
我们的公式中就没有了。
因此我们得出这个O(N^4)
(可能的O(N^3)
)时间解和O(N^2)
空间解。
def sol2(n, t, a):
def _sol2(n, a):
p = (n, a)
if dp2.get(p) is not None:
return dp2[p]
if n == 0:
if a == 0:
return 1
else:
return 0
ans = 0
for i in range(2, n + 1):
for j in range(0, i - 1):
ans += _sol2(n - i, a - j - 1) * eulerian_number(i - 1, j) * math.comb(n - 1, i - 1)
dp2[p] = ans
return ans
return math.comb(n, t) * _sol2(n - t, a)
|
–
–
–
–
|