我认为我的算法是正确的,但是当值增加到 10 6或更多时,我超出了允许的 MEMORY 或 TIMELIMIT。起初我尝试将元素推送到向量,然后我更改了方法以重用变量,并通过了更多测试。

公式:A i = (A i-1 + 2 * A i-2 + 3 * A i-3 ) mod M,其中 M = 10 9 + 7。1
<= n <= 10 12 时间限制:1 秒,内存:256mb

代码:

#include<iostream>
#include<cmath>

using namespace std;
using ull = unsigned long long;

ull func(ull n){
    ull a = 1;
    ull b = 1;
    ull c = 2;
    if (n < 2) return a;
    if (n == 3) return c;
    ull res = 0;
    for (ull i = 0; i < n - 3; i++){
        res = (3 * a + 2 * b + c) % (ull)(pow(10, 9) + 7);
        a = b;
        b = c;
        c = res;
    }
    return c;
}

int main() {
    int x; 
    cin >> x;
    cout << func(x);
}

现在我有一个通过了 3 个初始测试的算法(然后失败了 63 个测试,我认为值 > 10^6)

测试 1 输入:6
输出:
34

测试 2 输入:10
输出:
1096

测试 3 输入:500
输出:
340736120

我是否需要改变算法或者通过任何方法来加速?

18

  • 4
    一般而言,只要挑战包含mod 100000007,就会有一个公式,不需要进行计算。您必须先进行数学运算。


    – 


  • 2
    10^9 + 7在循环之前计算一次(您可以简单地将其存储100000007在 const 变量中)。一般不要将其用于pow整数幂运算。


    – 


  • 5
    有一个 O(log n) 算法可以解决此问题。我不知道是否应该直接告诉您它是什么(有点破坏了挑战),但知道它存在可能会帮助您找到它。


    – 

  • 1
    @lilof但是当值增加到 10^6 或更多时,我就有 MEMORY 或 TIMELIMIT 了——而且您已经遇到了这些“竞争性编码”网站使用的技巧之一。这个技巧是,这些网站上提出的问题几乎总是有一个易于编码的简单解决方案,如果数据集很大,它永远不会起作用。除非您尝试回答的问题被标记为“初学者”,否则简单的单行循环不会起作用,也不会做任何特殊的事情。目标是找出算法、数据结构或数学,以获得适用于大量数据集的解决方案。


    – 


  • 1


    – 


最佳答案
1

您当前的解决方案是,当其大小达到 10 12O(n)时,它的速度太慢了n

我们可以找到一个矩阵,M使得我们可以通过乘法从一个状态转换到下一个状态。M满足

[Ai , Ai -1 , Ai -2 ] T =M*[Ai -1 , Ai -2 , Ai -3 ] T

显然,的最后一行M只是[0, 1, 0]为了得到 A i-2

同理,第二行也是[1, 0, 0]

第一行是[1, 2, 3],它直接来自递归关系。

现在,对于n > 3,我们可以n通过将初始条件 [A 3 , A 2 , A 1 ] =[2, 1, 1]乘以M总次数(左) n-3,然后从第一行读出答案,来找到序列的第 个元素。这相当于乘以 M n-3。矩阵指数运算可以在 O(S 3 log(N)) 中执行,其中 S 是矩阵的维度(在本例中为常数3),N 是二进制指数运算的指数。

这导致了以下解决方案:

#include <iostream>
#include <vector>
#include <span>
#include <initializer_list>
#include <stdexcept>
#include <cstddef>
constexpr int MOD = 1e9 + 7;
template<typename T>
class Matrix {
    std::size_t rows, cols;
    std::vector<std::vector<T>> values;

public:
    Matrix(std::size_t rows, std::size_t cols) : rows{rows}, cols{cols}, values(rows, std::vector<T>(cols)) {}
    Matrix(std::initializer_list<std::initializer_list<T>> initVals) : rows{initVals.size()} {
        values.reserve(rows);
        for (auto& row : initVals) {
            values.emplace_back(row);
            if ((cols = row.size()) != values[0].size()) throw std::domain_error("Not a matrix: rows have unequal size");
        }
    }
    std::span<T> operator[](std::size_t r) {
        return values[r];
    }
    std::span<const T> operator[](std::size_t r) const {
        return values[r];
    }
    static Matrix identity(std::size_t size) {
        Matrix id(size, size);
        for (std::size_t i = 0; i < size; ++i) id.values[i][i] = 1;
        return id;
    }
    Matrix operator*(const Matrix& m) const {
        if (cols != m.rows) throw std::domain_error("Matrix dimensions do not match");
        Matrix res(rows, m.cols);
        for (std::size_t r = 0; r < rows; ++r)
            for (std::size_t c = 0; c < m.cols; ++c)
                for (std::size_t i = 0; i < cols; ++i)
                    res.values[r][c] += values[r][i] * m.values[i][c];
        return res;
    }
    Matrix operator%(T mod) const {
        auto res = *this;
        for (std::size_t r = 0; r < rows; ++r)
            for (std::size_t c = 0; c < cols; ++c)
                res.values[r][c] %= mod;
        return res;
    }
    Matrix modPow(std::size_t exp, T mod) const {
        if (rows != cols) throw std::domain_error("Matrix is not square");
        auto res = identity(rows), sq = *this;
        for (; exp; exp >>= 1) {
            if (exp & 1) res = res * sq % mod;
            sq = sq * sq % mod;
        }
        return res;
    }
};
const Matrix<unsigned long long> transition{{1, 2, 3}, {1, 0, 0}, {0, 1, 0}}, 
                                 initialConditions{{2}, {1}, {1}};
unsigned long long nthValue(unsigned long long n){
    if (n < 3) return 1;
    return (transition.modPow(n - 3, MOD) * initialConditions % MOD)[0][0];
}

int main() {
    unsigned long long n; 
    std::cin >> n;
    std::cout << nthValue(n) << '\n';
}