我正在尝试优化一种计算具有特定约束的排列的算法。给定整数 n、t、a、b,其中:

  • n 是排列的长度(1 到 n)
  • t 是所需的固定点数(原始位置上的数字)
  • a 是所需元素的数量小于其位置
  • b 是所需的元素数量大于其位置

例如,n=3、t=1、a=1、b=1:

该排列[1,3,2]是有效的,因为:

  • 1 处于其原始位置(计入 t=1)
  • 2 小于位置 3(计入 a=1)
  • 3 大于位置 2(计入 b=1)

当前问题:

  1. 内存使用量为 O(2^n * n),当 n > 20 时就会出现问题
  2. 运行时间约为 O(n * 2^n)
  3. DP 转换中存在许多冗余计算

问题:

  1. 对于这种类型的组合计数是否有任何已知的优化方法?
  2. 如果不使用位掩码 DP 可以解决这个问题吗?
  3. 我可以利用哪些数学特性来降低复杂性?

我考虑过:

  • 使用容斥原理
  • 尝试寻找一个数学公式
  • 使用不同的DP状态表示

限制如下:

  • 1≤n≤100
  • 0 ≤ t,a,b ≤ n
  • t + a + b 必须等于 n

我只是无法编写出足够快的代码。最重要的是,这是代码:

    #include <iostream>
    #include <vector>
    #include <algorithm>

    using namespace std;
    using ll = long long;

    const int MAXN = 21;

    ll combination[MAXN][MAXN];

    void init_combination() {
       for (int n = 0; n < MAXN; ++n) {
          combination[n][0] = combination[n][n] = 1;
          for (int k = 1; k < n; ++k)
               combination[n][k] = combination[n - 1][k - 1] + 
 combination[n - 1][k];
      }
    }

 int main() {
int n, t, a, b;
std::cin >> n >> t >> a >> b;

if (t + a + b != n || t < 0 || a < 0 || b < 0 || t > n || a > n || b > n) {
    std::cout << 0 << std::endl;
    return 0;
}

init_combination();

int n_prime = n - t;

ll fixed_points_ways = combination[n][t];

int total_states = 1 << n_prime;
std::vector<std::vector<ll>> dp(total_states, std::vector<ll>(n_prime + 1, 0));
dp[0][0] = 1;


for (int mask = 0; mask < total_states; ++mask) {
    int pos = __builtin_popcount(mask);
    if (pos == n_prime) continue; 

    for (int num = 0; num < n_prime; ++num) {
        if (mask & (1 << num)) continue; 
        if (num == pos) continue; 

        int next_excedances = 0;
        if (num > pos) {
            next_excedances = 1;
        }
        int total_excedances = 0;
        for (int k = 0; k <= n_prime; ++k) {
            int new_excedances = k + next_excedances;
            dp[mask | (1 << num)][new_excedances] += dp[mask][k];
        }
    }
}

ll total_permutations = dp[total_states - 1][b];
ll result = fixed_points_ways * total_permutations;

std::cout << result << std::endl;
return 0;

}

4

  • 您是否在寻求帮助来优化这里没人见过的代码?


    – 

  • @ScottHunter,现在添加了代码,抱歉,有点混乱,我在缩进方面遇到了麻烦。


    – 

  • 您能解释一下这段代码是如何工作的吗?


    – 

  • 您能分享一些示例数据吗?


    – 


最佳答案
1

我将逐步解释我的解决方案,以便您能够跟上。

步骤1:循环排列

我们将每个排列定义为一个映射:f: list(1 to N) => list(P),其中P是排列本身。

举个例子,对于排列4, 1, 6, 2, 5, 3,它是一个映射:

[1, 2, 3, 4, 5, 6] => [4, 1, 6, 2, 5, 3]

我们从中画一条边i => f(i),我们将得到一个图。

对于我们的例子,我们有1 => 4, 2 => 1, 3 => 6, 4 => 1, 5 => 5, 6 => 3,然后我们有:

因此我们会发现有 3 个独立的图表。

再比如,对于排列4, 1, 2, 5, 6, 3,我们有:

我们可以发现只有 1 个图,并且所有节点都与该排列相连。

当图中所有节点都连通时,我们将该排列定义为循环排列。

我们引入循环符号,即:

(a b c .. x y z): a=>b, b=>c, .. x=>y, y=>z, z=>a

例如,排列的循环符号4, 1, 2, 5, 6, 3(1, 4, 5, 6, 3, 2)

我们可以发现:

(a b c .. x y z) = (b c .. x y z a) = (c .. x y z a b) = (z a b c .. x y)

另外,排列的循环符号4, 1, 6, 2, 5, 3(1 4 2)(3 6)(5)

因此,循环排列的循环符号中只有 1 对括号。

步骤2:欧拉数

我们先处理一个简单的任务:

对于长度为 N 的循环排列,有多少个排列恰好具有A小于其位置的数字?

回到我们的循环符号,为了消除重复计数,我们将节点 1 固定在符号的开头。然后其他数字将填充到我们的符号中:

(1 ? ? ... ?) # N-1 '?' in total 

对于任意两个相邻的?,如果第一个小于第二个,则映射后第一个数小于其位置。

我们很容易看出,1映射后一定小于其位置。所以问题变成了:

在所有可能性中,有多少种组合的确切A-1数字大于之前的数字?

该问题的答案是

在组合学中,欧拉数是数字 1 到 的排列数,其中元素恰好大于前一个元素(具有“上升”的排列)。

欧拉数的前几行:

我们以 为例E(3, 1),有 4 种组合匹配:

1 3 2, 2 1 3, 2 3 1, 3 1 2

我们将更多细节留给维基百科并继续。

Step3:原始问题与较小的问题

现在我们回到最初的问题,考虑我们排列中最大的数字,这个数字只有 2 种情况:

Case1: in its original position.
Case2: in a length k(2 to n) cycle permutation

我们以 6 长度排列为例:

Case1: ? ? ? ? ? 6
Case2: ? ? ? 6 ? 4(a length 2 cycle), ? ? 6 3 ? 4 (a length 3 cycle) ...

请注意,在 中Case2,我们遇到的问题较小,因为只剩下n-k ?残骸。

我们定义:

sol(n, t, a) is the answer of length n permutation with t fixed point and "a" numbers less than their position

在我们的案例中,有:

sol(n, t, a) is the sum of 
 sol(n-1, t-1, a) # case 1
 for 2<=i<=n and 0<=j<=i-1, C(n-1, i-1)* eulerian_number(i-1, j)*sol(n-i, t, a-j-1) # case 2

解释C(n-1, i-1) * eulerian_number(i-1, j) * sol(n-i, t, a-j-1)

对于长度为i的循环,由于数字n是固定的,我们可以i-1从数字中选择位置n-1,所以我们有C(n-1, i-1),这C就是组合数字。

并且我们将枚举小于位置的数量j,因此我们有eulerian_number(i-1, j)j小于其位置的数字,并且(i-1)!由于它是i长度循环排列,因此存在组合)

另外,请注意,将有一个数字映射到最大数字的位置,因此我们有sol(n-i, t, a-j-1)

总结以上所有内容,并通过记忆,我们得出了O(N^3)空间和O(N^5)时间的解决方案。

您可以查看我的代码以了解详细信息。

附录:代码

import itertools
import math

ed = {}
dp = {}


def eulerian_number(n, k):
    if ed.get((n, k)) is not None:
        return ed[(n, k)]
    if n == 0:
        if k == 0:
            ans = 1
        else:
            ans = 0
    else:
        ans = (n - k) * eulerian_number(n - 1, k - 1) + (k + 1) * eulerian_number(n - 1, k)
    ed[(n, k)] = ans
    return ans


def sol(n, t, a):
    p = (n, t, a)
    if dp.get(p) is not None:
        return dp[p]
    if n == 0:
        if a == 0 and t == 0:
            return 1
        else:
            return 0
    ans = sol(n - 1, t - 1, a)
    for i in range(2, n + 1):
        for j in range(0, i - 1):
            ans += sol(n - i, t, a - j - 1) * eulerian_number(i - 1, j) * math.comb(n - 1, i - 1)
    dp[p] = ans
    return and

我写了一些测试:

def test():
    def _test(permutation_length):
        d = {}
        for numbers in itertools.permutations(list(range(permutation_length))):
            t = 0
            a = 0
            b = 0
            for index, n in enumerate(numbers):
                if n == index:
                    t += 1
                if n < index:
                    a += 1
                if n > index:
                    b += 1
            if d.get((t, a, b)) is None:
                d[(t, a, b)] = 0
            d[(t, a, b)] += 1
        test_pass = True
        for k in d:
            t, a, b = k
            ans = sol(permutation_length, t, a)
            if ans != d[k]:
                test_pass = False
                print('Error on {}: {}, answer: {}, expected: {}'.format(permutation_length, k, ans, d[k]))
        if test_pass:
            print('Passed on Permutation Length: {}'.format(permutation_length))
    for i in range(1, 8):
        _test(i)


test()

输出:

Passed on Permutation Length: 1
Passed on Permutation Length: 2
Passed on Permutation Length: 3
Passed on Permutation Length: 4
Passed on Permutation Length: 5
Passed on Permutation Length: 6
Passed on Permutation Length: 7

更好的解决方案

我们可以先选择t固定点,这样Case1我们的公式中就没有了。

因此我们得出这个O(N^4)(可能的O(N^3))时间解和O(N^2)空间解。

def sol2(n, t, a):

    def _sol2(n, a):
        p = (n, a)
        if dp2.get(p) is not None:
            return dp2[p]
        if n == 0:
            if a == 0:
                return 1
            else:
                return 0
        ans = 0
        for i in range(2, n + 1):
            for j in range(0, i - 1):
                ans += _sol2(n - i, a - j - 1) * eulerian_number(i - 1, j) * math.comb(n - 1, i - 1)
        dp2[p] = ans
        return ans

    return math.comb(n, t) * _sol2(n - t, a)