稀疏矩阵是绝大多数元素为 0 的矩阵,只有少数非零元素。现在我必须填充我的矩阵Sparsematrix class,使得矩阵可以做到addsubtractmultiply。我使用 COO 来存储我的矩阵。

template <class T>
class VecList{
    private:
        int capacity;
        int length;
        T* arr;
        void doubleListSize(){
            T * oldArr = arr;
            arr = new T[2*capacity];
            capacity = 2 * capacity;
            for(int i=0;i<length;i++){
                arr[i] = oldArr[i];
            }
            delete [] oldArr;
        }
    public:
        VecList(){
            length = 0;
            capacity = 100;
            arr = new T[capacity];
        }
        VecList(T* a, int n){
            length = n;
            capacity = 100 + 2*n;
            arr = new T[capacity];
            for(int i=0;i<n;i++){
                arr[i] = a[i];
            }
            for (int i = 0; i < n; i++)
            {
                cout << arr[i] << " ";
            }
            cout << endl;
            printList();
        }
        ~VecList(){
            delete [] arr;
        }
        int getLength(){
            return length;
        }
        bool isEmpty(){
            return length==0;
        }
        void insertEleAtPos(int i, T x){
            if(length==capacity)
                doubleListSize();
            if(i > length || i < 0)
                throw "Illegal position";
            for(int j=length;j>i;j--)
                arr[j] = arr[j-1];
            arr[i] = x;
            length++;
        }
        T deleteEleAtPos(int i){
            if(i >= length || i < 0)
                throw "Illegal position";
            T tmp = arr[i];
            for(int j=i;j<length-1;j++)
                arr[j] = arr[j+1];
            length--;
            return tmp;
        }
        void setEleAtPos(int i, T x){
            if(i >= length || i < 0)
                throw "Illegal position";
            arr[i] = x;
        }
        T getEleAtPos(int i){
            if(i >= length || i < 0)
                throw "Illegal position";
            return arr[i];
        }
        int locateEle(T x){
            for(int i=0;i<length;i++){
                if(arr[i]==x)
                    return i;
            }
            return -1;
        }
        void printList(){
            for(int i=0;i<length;i++)
                cout << arr[i] << " ";
        }
};

COO 使用三个VecList来存储矩阵。

  1. rowIndex:表示行数。
  2. colIndex:表示列数。
  3. values:表示元素的值。以下是我的Sparsematrix class
template <class T>
class SparseMatrix{
    private:
        int rows;
        int cols;
        VecList<int>* rowIndex;
        VecList<int>* colIndex;
        VecList<T>* values;
    public:
        SparseMatrix(){ //Create a 10x10 Sparse matrix
            rows = 10;
            cols = 10;
            rowIndex = new VecList<int>();
            colIndex = new VecList<int>();
            values = new VecList<T>();
        }
        SparseMatrix(int r, int c){ //Create a rxc Sparse matrix
            rows = r;
            cols = c;
            rowIndex = new VecList<int>();
            colIndex = new VecList<int>();
            values = new VecList<T>();
        }
        ~SparseMatrix(){
            delete rowIndex;
            delete colIndex;
            delete values;
        }
};

如你所见,我只需要关注非零元素。例如:一个稀疏矩阵

0 2 0
0 0 1
3 0 0
rows = Exact number of rows, 3
cols = Exact number of columns, 3
rowIndex = 0 1 2
colIndex = 1 2 0
values   = 2 1 3

第一根竖线表示,row[0]col[1]有一个非零元素“2”。第二根竖线表示,row[1]col[2]有一个非零元素“1”。现在我写了几个函数来实现矩阵之间的运算。

int findPos(int a, int b){ //If there is a non-zero element at (a, b), then return its position in            "rowIndex", else return -1.
            for (int i = 0; i < rowIndex->getLength(); i++)
            {
                 if(rowIndex->getEleAtPos(i) == a && colIndex->getEleAtPos(i) == b)return i;
                 else if(rowIndex->getEleAtPos(rowIndex->getLength() - 1 - i) == a && colIndex->getEleAtPos(colIndex->getLength()-1-i) == b)return rowIndex->getLength()-1 - i;
            }
            return -1;
        }
        void setEntry(int rPos, int cPos, T x){ // Set (rPos, cPos) = x
            int pos = findPos(rPos,cPos);
            //Find if there is a non-zero element at (rPos, cPos).
            if(x != 0){
            //If the origin matrix does not have an element at(rPos, cPos),insert x to the matrix.
            if (pos == -1)
            {
                rowIndex->insertEleAtPos(rowIndex->getLength(),rPos);
                colIndex->insertEleAtPos(colIndex->getLength(),cPos);
                values->insertEleAtPos(values->getLength(),x);
            }
            else{
                //If the origin matrix has an element at(rPos, cPos),replace it with x.
                rowIndex->setEleAtPos(pos,rPos);
                colIndex->setEleAtPos(pos,cPos);
                values->setEleAtPos(pos,x);
            }
           }
           else{
            //If x == 0 and the origin matrix has an element at(rPos, cPos), delete the element.
                if(pos != -1){
                    rowIndex->deleteEleAtPos(pos);
                    colIndex->deleteEleAtPos(pos);
                    values->deleteEleAtPos(pos);
                }
            }
        //If x == 0, and the origin matrix does not have an element at(rPos, cPos), nothing changed.
        }
T getEntry(int rPos, int cPos){
        //Get the element at (rPos, cPos)
            return findPos(rPos,cPos) == -1 ? 0 : values->getEleAtPos(findPos(rPos,cPos));
        }
        SparseMatrix<T> * add(SparseMatrix<T> * B){
            if(rows != B->rows || cols != B->cols)throw "Matrices have incompatible sizes";
            SparseMatrix<T> *C = new SparseMatrix<T>(rows,cols);//Create a new matrix C as result.
            for (int i = 0; i < rowIndex->getLength(); i++)
            {
//I call the two input matrices "A" and "B". I put every elements of A into C, and also put every elements of B into C. But I use "C->setEntry", which means when A[i][j] has an element and B[i][j] also has an element, "setEntry" will cover the prior one. So I use C->setEntry(i,j,C->getEntry(i,j) + A[i][j] or B[i][j]), in another word, setEntry with (oldvalue + newvalue).That's what I did.
                C->setEntry(rowIndex->getEleAtPos(i),colIndex->getEleAtPos(i),C->getEntry(rowIndex->getEleAtPos(i),colIndex->getEleAtPos(i))+values->getEleAtPos(i));
                C->setEntry(B->rowIndex->getEleAtPos(i),B->colIndex->getEleAtPos(i),C->getEntry(B->rowIndex->getEleAtPos(i),B->colIndex->getEleAtPos(i))+B->values->getEleAtPos(i));
            }
            return C;
        }
        SparseMatrix<T> * subtract(SparseMatrix<T> * B){
//The same method as add.
            if(rows != B->rows || cols != B->cols)throw "Matrices have incompatible sizes";
            SparseMatrix<T> *C = new SparseMatrix<T>(rows,cols);
            for (int i = 0; i < rowIndex->getLength(); i++)
            {
                C->setEntry(rowIndex->getEleAtPos(i),colIndex->getEleAtPos(i),C->getEntry(rowIndex->getEleAtPos(i),colIndex->getEleAtPos(i))-values->getEleAtPos(i));
                C->setEntry(B->rowIndex->getEleAtPos(i),B->colIndex->getEleAtPos(i),C->getEntry(B->rowIndex->getEleAtPos(i),B->colIndex->getEleAtPos(i))-B->values->getEleAtPos(i));
            }
            return C;
        }

        SparseMatrix<T> * multiply(SparseMatrix<T> * B){
            //perform multiplication if the sizes of the matrices are compatible.
            if(rows != B->cols || cols != B->rows)throw "Matrices have incompatible sizes"; 
            SparseMatrix<T> *C = new SparseMatrix<T>(rows,B->cols);
//I call the two input matrices as "A" and "B".
//My method is take a row of A first, let this row do the arithmetic with each column of B,then I finish a row in C. Then continue to the next row.
            for (int i = 0; i < rowIndex->getLength();i++)
            {
                for (int j = 0; j < B->colIndex->getLength(); j++)
                {
                    if (B->findPos(colIndex->getEleAtPos(i),B->colIndex->getEleAtPos(j)) != -1)
                    {
                        C->setEntry(rowIndex->getEleAtPos(i),B->colIndex->getEleAtPos(j),C->getEntry(rowIndex->getEleAtPos(i),B->colIndex->getEleAtPos(j))+(values->getEleAtPos(i)*B->values->getEleAtPos(j)));
                    }
                } 
            }
            return C;
        }

        void printMatrix(){
            for (int i = 0; i < rows; i++)
            {
                for (int j = 0; j < cols; j++)
                {
                    cout << getEntry(i,j) << " ";
                }
                cout << endl;
            }
        }

我已经测试了几种情况,所有情况都表明addsubtractmultiply运行良好。但是有一个 10000×10000 矩阵(称为“X”和“Y”)测试我无法通过,X 和 Y 没有很多非零元素。而且它们只是加、减和乘。
时间限制是 1 秒。(不包括 printMatrix(),但包括 setEntry())我超过了它。我该如何减少程序的运行时间?(我还想知道 COO 存储是否错误,以及findPos()函数是否备用。)谢谢。我的工具是 VSCode2022,使用 C++11,Windows 11。这是测试代码的示例。

#include <iostream>
#include <algorithm>
#include <chrono>
using namespace std;
int main(){
    auto start = std::chrono::high_resolution_clock::now();
    SparseMatrix<int> X,Y;
    X.setEntry(1,3,4);
    X.setEntry(7,8,2);
    Y.setEntry(1,6,4);
    Y.setEntry(1,3,4);
    Y.setEntry(7,7,2);
    X.printMatrix();
    cout << endl;
    Y.printMatrix();
    cout << endl;
    X.add(&Y)->printMatrix();
    cout << endl;
    X.subtract(&Y)->printMatrix();
    cout << endl;
    Y.multiply(&X)->printMatrix();
    cout << "Done" << endl;
    auto stop = std::chrono::high_resolution_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start).count();
    cout << "Running Time:" << duration << "ms\n";
    return 0;
}

3

  • 范围检查很好,但不利于速度。当你将两个 10000×10000 矩阵相乘时,代码将调用getEleAtPos几亿次。即使我们知道它们在范围内(因为循环条件)。


    – 


  • 无关:如果你使用一些临时变量,你的代码会更加清晰,例如acol = colIndex->getEleAtPos(i)


    – 


  • 2
    没有任何理由让任何成员SparseMatrix成为指针,因此您可以通过移除指针来立即加快速度。


    – 


最佳答案
2

首先进行一些快速的粗略计算:矩阵中的项数findPos为 O(n) 。乘法调用有两个嵌套的 for 循环。内循环至少调用一次,对于and调用可能调用两次nfindPosgetEntrysetEntry

作为第一个改进,请考虑先按行rowIndex排序colIndex,然后按列排序。这样可以使用一种称为二分搜索的晦涩算法,其复杂度为 O(log n)。此外,它还允许您快速选择属于给定行的所有元素。

这也使您的addsubtract函数变得更加简单:您只需同时遍历两个矩阵,并将元素复制到结果矩阵,如果两个矩阵都填充了该位置,则偶尔进行加法/减法。

第二个优化是另一种常见的优化:在 A * B 乘法中,A 矩阵逐行读取,B 矩阵逐列读取。因此,如果您先转置 B,则可以逐行读取,如上所述,这样现在速度更快。

4

  • “一种名为二分查找的晦涩算法”晦涩难懂?


    – 

  • 这是个笑话;)这似乎是 C++ 或算法/数据结构课程的家庭作业,所以 OP 应该了解二分搜索。


    – 


  • 谢谢你的建议。二分查找很容易理解,但这意味着我必须在rowIndex之后对进行排序?顺便说一句,我实际上不明白 是什么addsubtractmultiplysorted by row first and then by column


    – 

  • 我的意思是,同一行中的单元格应该按递增顺序出现在colIndex数组中。您可以构造加法、乘法、减法运算,使它们只按正确的顺序生成索引。


    – 


  • 经过两天的努力,考试顺利通过了ACCEPTED,两个10000×10000的稀疏矩阵运算在0.01s~0.02s内完成。

  • 而且我必须说,使用三个VecList或任意一个two-dimensional array来存储稀疏矩阵绝对不是一个好主意。因为矩阵中有成千上万个 0。这次我只使用一个VecList来存储矩阵,但在 VecList 中有struct rHead成员。

struct OrthNode{
    /*I apologize for misspelling ‘row’ as ‘rol’.It was only when the program 
    was mostly written that I noticed this.*/
    int rol, col, value;
    struct OrthNode *right;
};
struct rHead{
    int nums;
//This variable 'nums' is useless.I didn't use it in the later program.
    struct OrthNode * right;
};
  • 包含rHead在 中VecList,一个矩阵有多少行,那么rHead中就有多少行VecList。每个行rHead占一整行,当我想访问一行时,我应该rHead先访问它。
  • 就像OrthNode是一张身份证,记录了所有非零元素的信息,它的col,它的rol,它的value。值得注意的是,它有一个*right指向同一行下一个元素的指针。
  • 现在,我以rHead作为每一行的开头。 和rHead->right指向这一行中第一个非零元素OrthNode。 而这个OrthNode的指针OrthNode->right指向这一行的下一个元素。无论rHeadOrthNode,如果右侧没有元素,则它指向NULL
  • 这样,当我要访问的时候A[3][4],就先访问rHead第三行的,也就是现在VecList[3]我在A[3][?],然后找rHead->right,假设是rHead->right = A[3][1],再找A[3][1]->right,直到得到A[3][4]
  • 因此,我使用一个VecList来逐行存储稀疏矩阵。我按列的顺序将元素放在每行中(例如,在row[3]A rHead->A[3][1]->A[3][2]->A[3][4]->A[3][6][3][k] 中,k 是按顺序排列的。)
  • 这是我的代码:
template <class T>
class VecList{
//Only this member funcion has changed, so that it can return a correct thing.
public:
        T* getEleAtPos(int i){
            if(i >= length || i < 0)
                throw "Illegal position";
            return &arr[i];
        }
//Others things in VecList are the same as in question.
};
template <class T>
class SparseMatrix{
    private:
        VecList<rHead> M;//The VecList is the same as in the question.
        int totalrows;
        int totalcols;
    public:
        SparseMatrix(){
            totalrows = 10;
            totalcols = 10;
            for(int i=0; i<10; i++){
                M.insertEleAtPos(M.getLength(), {0, nullptr});
            }
        }
        SparseMatrix(int r, int c){
            totalrows = r;
            totalcols = c;
            for(int i=0; i<r; i++){
                M.insertEleAtPos(M.getLength(), {0, nullptr});
            }
        }
        ~SparseMatrix(){
            for (int i = 0; i < totalrows; i++)
            {
                if(M.getEleAtPos(i)->right != NULL)
                {
                    OrthNode * temp = M.getEleAtPos(i)->right;
                    while (temp != NULL)
                    {
                        M.getEleAtPos(i)->right = temp->right;
                        OrthNode*delNode = temp;
                        temp = temp->right;
                        delete delNode;
                    }
                    delete temp;
                }
            }
        }
        void setEntry(int rPos, int cPos, T x){
            OrthNode* newNode = new OrthNode;
            newNode->rol = rPos;
            newNode->col = cPos;
            newNode->value = x;
            if (x == 0)
            {
                if (M.getEleAtPos(rPos)->right == NULL)
                {
                    delete newNode;
                    return;
                }
                OrthNode*temp = M.getEleAtPos(rPos)->right;
                if (temp->col == cPos)
                {
                    M.getEleAtPos(rPos)->right = temp->right;
                    delete newNode;
                    delete temp;
                    M.getEleAtPos(rPos)->nums --;
                    return;
                }
                
                while(temp->col < cPos && temp != NULL){
                    if (temp->right->col == cPos)
                    {
                        OrthNode*delNode = temp->right;
                        temp = temp->right->right;
                        delete delNode;
                        M.getEleAtPos(rPos)->nums --;
                        return;
                    }
                    temp = temp->right;
                }
                if (temp->right == NULL)
                {
                    delete temp;
                    delete newNode;
                    return;
                }
            }
            else{
                if (M.getEleAtPos(rPos)->right == NULL)
                {
                    newNode->right = NULL;
                    M.getEleAtPos(rPos)->right = newNode;
                    M.getEleAtPos(rPos)->nums ++;
                    return;
                }
                else{
                    OrthNode* temp = M.getEleAtPos(rPos)->right;
                    if (cPos < temp->col)
                    {
                        newNode->right = temp;
                        M.getEleAtPos(rPos)->right = newNode;
                        M.getEleAtPos(rPos)->nums ++;
                        return;
                    }
                    while (temp->col <= cPos && temp != NULL)
                    {
                        if (temp->col == cPos)
                        {
                            temp->value = x;
                            M.getEleAtPos(rPos)->nums ++;
                            delete newNode;
                            return;
                        }
                        if (temp->right == NULL)
                        {
                            newNode->right =NULL;
                            temp->right = newNode;
                            M.getEleAtPos(rPos)->nums ++;
                            return;
                        } 
                        if (temp->col < cPos && temp->right->col > cPos)
                        {
                            newNode->right = temp->right;
                            temp->right = newNode;
                            M.getEleAtPos(rPos)->nums ++;
                            return;
                        }
                        temp = temp->right;
                    }    
            }
        }
    }
        T getEntry(int rPos, int cPos){
            OrthNode* read = new OrthNode;
            read = M.getEleAtPos(rPos)->right;
            while (read != NULL)
            {
                if (read->col == cPos)
                {
                    return read->value;
                }
                read = read->right;
            }
            delete read;
            return 0;
        }
        SparseMatrix<T> * add(SparseMatrix<T> * B){
            if(totalrows != B->totalrows || totalcols != B->totalcols)throw "Matrices have incompatible sizes";
            SparseMatrix<T>* C = new SparseMatrix(totalrows,totalcols);
            for (int i = 0; i < totalrows; i++)
            {
                if (M.getEleAtPos(i)->right != NULL && B->M.getEleAtPos(i)->right == NULL)
                {
                    OrthNode * tempA = M.getEleAtPos(i)->right;
                    while (tempA != NULL)
                    {
                        C->setEntry(tempA->rol,tempA->col,tempA->value);
                        tempA = tempA->right;
                    }
                    delete tempA;
                    continue;
                }
                else if (M.getEleAtPos(i)->right == NULL && B->M.getEleAtPos(i)->right != NULL)
                {
                    OrthNode * tempB = B->M.getEleAtPos(i)->right;
                    while (tempB != NULL)
                    {
                        C->setEntry(tempB->rol,tempB->col,tempB->value);
                        tempB = tempB->right;
                    }
                    delete tempB;
                    continue;
                }
                else if (M.getEleAtPos(i)->right == NULL && B->M.getEleAtPos(i)->right == NULL)
                {
                    continue;
                }
                else{
                    OrthNode * tempA = M.getEleAtPos(i)->right;
                    OrthNode * tempB = B->M.getEleAtPos(i)->right;
                    while (tempA != NULL)
                    {
                        C->setEntry(tempA->rol,tempA->col,tempA->value);
                        tempA = tempA->right;
                    }
                    delete tempA;
                    while (tempB != NULL)
                    {
                        int oldEntry = C->getEntry(tempB->rol,tempB->col);
                        C->setEntry(tempB->rol,tempB->col,oldEntry + tempB->value);
                        tempB = tempB->right;
                    }
                    delete tempB;
                    continue;
                }
            }
        return C;
        }
        SparseMatrix<T> * subtract(SparseMatrix<T> * B){
            if(totalrows != B->totalrows || totalcols != B->totalcols)throw "Matrices have incompatible sizes";
            SparseMatrix<T>* C = new SparseMatrix(totalrows,totalcols);
            for (int i = 0; i < totalrows; i++)
            {
                if (M.getEleAtPos(i)->right != NULL && B->M.getEleAtPos(i)->right == NULL)
                {
                    OrthNode * tempA = M.getEleAtPos(i)->right;
                    while (tempA != NULL)
                    {
                        C->setEntry(tempA->rol,tempA->col,tempA->value);
                        tempA = tempA->right;
                    }
                    delete tempA;
                    continue;
                }
                else if (M.getEleAtPos(i)->right == NULL && B->M.getEleAtPos(i)->right != NULL)
                {
                    OrthNode * tempB = B->M.getEleAtPos(i)->right;
                    while (tempB != NULL)
                    {
                        C->setEntry(tempB->rol,tempB->col, -(tempB->value));
                        tempB = tempB->right;
                    }
                    delete tempB;
                    continue;
                }
                else if (M.getEleAtPos(i)->right == NULL && B->M.getEleAtPos(i)->right == NULL)
                {
                    continue;
                }
                else{
                    OrthNode * tempA = M.getEleAtPos(i)->right;
                    OrthNode * tempB = B->M.getEleAtPos(i)->right;
                    while (tempA != NULL)
                    {
                        C->setEntry(tempA->rol,tempA->col,tempA->value);
                        tempA = tempA->right;
                    }
                    delete tempA;
                    while (tempB != NULL)
                    {
                        int oldEntry = C->getEntry(tempB->rol,tempB->col);
                        C->setEntry(tempB->rol,tempB->col,oldEntry - tempB->value);
                        tempB = tempB->right;
                    }
                    delete tempB;
                    continue;
                }
            }
        return C;
        }

        SparseMatrix<T> * multiply(SparseMatrix<T> * B){
            //perform multiplication if the sizes of the matrices are compatible.
            if (totalrows != B->totalcols || totalcols != B->totalrows)throw "Matrices have incompatible sizes";
            SparseMatrix<T>* C = new SparseMatrix(totalrows,B->totalcols);
            for (int i = 0; i < totalrows; i++)
            {
                if (M.getEleAtPos(i)->right == NULL)
                {
                    continue;
                }
                else{
                    OrthNode* tempA = M.getEleAtPos(i)->right;
                    while (tempA != NULL)
                    {
                        if (B->M.getEleAtPos(tempA->col) != NULL)
                        {
                        OrthNode* tempB = B->M.getEleAtPos(tempA->col)->right;
                        while (tempB != NULL)
                            {
                                int oldEntry = C->getEntry(tempA->rol,tempB->col);
                                C->setEntry(tempA->rol,tempB->col,oldEntry + tempA->value * tempB->value);
                                tempB = tempB->right;
                            }
                        }
                        tempA = tempA->right;
                    }  
                }
            }
            return C;
        }

        // Only call this function if you know the size of matrix is reasonable.
        void printMatrix(){
            // for (int i = 0; i < totalrows; i++)
            // {
            //     for (int j = 0; j < totalcols; j++)
            //     {
            //         cout << getEntry(i,j) << " ";
            //     }
            //     cout << endl;
            // }
            cout << "Be careful, when the matrix is too big, do not use Print!" << endl;
        }
  • 测试代码如下:
#include <iostream>
#include <algorithm>
#include <chrono>
using namespace std;
int main(){
    SparseMatrix<int> X(10,10);
    SparseMatrix<int> Y(10,10);
//It is a Sparse Matrix so do not give so much elements.
    for (int i = 0; i < 10; i++)
    {
        X.setEntry(i,i,2);
        Y.setEntry(3,i,i+2);
    }
//If you want to test the time cost, do not use printMatrix();
    auto start = std::chrono::high_resolution_clock::now();
    X.printMatrix();
    cout << endl;
    Y.printMatrix();
    cout << endl;
    X.add(&Y)->printMatrix();
    cout << endl;
    X.subtract(&Y)->printMatrix();
    cout << endl;
    X.multiply(&Y)->printMatrix();
    cout << endl;
    cout << "Done" << endl;
    auto stop = std::chrono::high_resolution_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start).count();
    cout << "Running Time:" << duration << "ms\n";
    return 0;
}

3

  • 我不确定这是否仍可称为稀疏矩阵。在我看来,1Bx1B 稀疏矩阵(空或其他)不应占用 8GB 的​​ RAM。


    – 


  • 运行内存是个大问题,不管一行里有没有元素,我都创建了一个rHead指针,更优化的算法是添加一个*next指向下一个有元素struct rHead的指针,当一行里没有元素时,就表示这一行没有元素,另外创建一个指向一整列的指针,这样存储和定位都比较方便,创建后,就必须添加一个指向其下一级元素的指针,不过,可能是我没有正确删除分配的内存。rHead->nextrHeadrHeadcHeadcHeadOrthNode*down


    – 


  • ..这将使你的 getPos 操作符再次变为 O(n)。你需要某种树结构或至少二分搜索。


    –