稀疏矩阵是绝大多数元素为 0、只有少数非零元素的矩阵。现在我需要补全我的 `SparseMatrix` 类,使矩阵支持 `add`、`subtract` 和 `multiply` 运算。我使用 COO(坐标格式)来存储矩阵。
/**
 * Minimal growable array with position-based access.
 *
 * Storage is a heap array that doubles when full. Every positional
 * accessor throws the C-string "Illegal position" when the index is
 * out of range. The class owns its buffer, so it defines copy
 * construction and copy assignment as deep copies.
 */
template <class T>
class VecList{
private:
    int capacity;   // number of allocated slots in arr
    int length;     // number of slots currently in use
    T* arr;         // owned backing storage

    // Doubles the backing storage, keeping the first `length` elements.
    void doubleListSize(){
        T* grown = new T[2 * capacity];
        for(int i = 0; i < length; i++){
            grown[i] = arr[i];
        }
        delete [] arr;
        arr = grown;
        capacity = 2 * capacity;
    }
public:
    // Creates an empty list with room for 100 elements.
    VecList() : capacity(100), length(0), arr(new T[100]) {}

    // Creates a list holding a copy of the first n elements of a.
    // A negative n is treated as 0. The former debug printing of the
    // copied elements has been removed: constructing a list no longer
    // writes to std::cout.
    VecList(T* a, int n){
        if(n < 0)
            n = 0;
        length = n;
        capacity = 100 + 2 * n;
        arr = new T[capacity];
        for(int i = 0; i < n; i++){
            arr[i] = a[i];
        }
    }

    // Deep copy: the new list gets its own buffer.
    VecList(const VecList& other)
        : capacity(other.capacity), length(other.length), arr(new T[other.capacity]) {
        for(int i = 0; i < length; i++){
            arr[i] = other.arr[i];
        }
    }

    // Deep copy assignment; the new buffer is filled before the old one
    // is released, so a failed allocation leaves *this untouched.
    VecList& operator=(const VecList& other){
        if(this != &other){
            T* fresh = new T[other.capacity];
            for(int i = 0; i < other.length; i++){
                fresh[i] = other.arr[i];
            }
            delete [] arr;
            arr = fresh;
            capacity = other.capacity;
            length = other.length;
        }
        return *this;
    }

    ~VecList(){
        delete [] arr;
    }

    // Number of stored elements.
    int getLength() const {
        return length;
    }

    // True when no element is stored.
    bool isEmpty() const {
        return length == 0;
    }

    // Inserts x so that it ends up at index i (0 <= i <= length),
    // shifting later elements one slot to the right.
    // Throws "Illegal position" for any other i.
    void insertEleAtPos(int i, T x){
        // Validate first so that a rejected call does not grow the buffer.
        if(i > length || i < 0)
            throw "Illegal position";
        if(length == capacity)
            doubleListSize();
        for(int j = length; j > i; j--)
            arr[j] = arr[j-1];
        arr[i] = x;
        length++;
    }

    // Removes and returns the element at index i, shifting later
    // elements one slot to the left. Throws "Illegal position" if i is
    // not a valid index.
    T deleteEleAtPos(int i){
        if(i >= length || i < 0)
            throw "Illegal position";
        T tmp = arr[i];
        for(int j = i; j < length - 1; j++)
            arr[j] = arr[j+1];
        length--;
        return tmp;
    }

    // Overwrites the element at index i with x.
    // Throws "Illegal position" if i is not a valid index.
    void setEleAtPos(int i, T x){
        if(i >= length || i < 0)
            throw "Illegal position";
        arr[i] = x;
    }

    // Returns a copy of the element at index i.
    // Throws "Illegal position" if i is not a valid index.
    T getEleAtPos(int i) const {
        if(i >= length || i < 0)
            throw "Illegal position";
        return arr[i];
    }

    // Returns the index of the first element equal to x, or -1.
    int locateEle(T x) const {
        for(int i = 0; i < length; i++){
            if(arr[i] == x)
                return i;
        }
        return -1;
    }

    // Writes all elements to std::cout, each followed by one space.
    void printList() const {
        for(int i = 0; i < length; i++)
            std::cout << arr[i] << " ";
    }
};
COO 使用三个 `VecList` 来存储矩阵:`rowIndex` 存放每个非零元素的行号,`colIndex` 存放其列号,`values` 存放其值。以下是我的 `SparseMatrix` 类:
template <class T>
class SparseMatrix{
private:
int rows;
int cols;
VecList<int>* rowIndex;
VecList<int>* colIndex;
VecList<T>* values;
public:
SparseMatrix(){ //Create a 10x10 Sparse matrix
rows = 10;
cols = 10;
rowIndex = new VecList<int>();
colIndex = new VecList<int>();
values = new VecList<T>();
}
SparseMatrix(int r, int c){ //Create a rxc Sparse matrix
rows = r;
cols = c;
rowIndex = new VecList<int>();
colIndex = new VecList<int>();
values = new VecList<T>();
}
~SparseMatrix(){
delete rowIndex;
delete colIndex;
delete values;
}
};
如你所见,我只需要关注非零元素。例如:一个稀疏矩阵
0 2 0
0 0 1
3 0 0
rows = Exact number of rows, 3
cols = Exact number of columns, 3
rowIndex = 0 1 2
colIndex = 1 2 0
values = 2 1 3
这三个数组的第一列表示:在第 0 行、第 1 列处有一个非零元素 "2";第二列表示:在第 1 行、第 2 列处有一个非零元素 "1"。现在我写了几个函数来实现矩阵之间的运算。
// Returns the index in rowIndex/colIndex of the stored entry at (a, b),
// or -1 when no entry is stored there. Each pass checks one candidate
// from the front and one from the back of the lists.
int findPos(int a, int b){
    const int rowCount = rowIndex->getLength();
    for (int front = 0; front < rowCount; ++front)
    {
        const bool frontHit = rowIndex->getEleAtPos(front) == a
                           && colIndex->getEleAtPos(front) == b;
        if (frontHit)
            return front;

        const int rowBack = rowCount - 1 - front;
        if (rowIndex->getEleAtPos(rowBack) == a)
        {
            // The column list is indexed from its own length, as before.
            const int colBack = colIndex->getLength() - 1 - front;
            if (colIndex->getEleAtPos(colBack) == b)
                return rowBack;
        }
    }
    return -1;
}
// Sets the element at (rPos, cPos) to x.
//  - x is zero:     an existing entry at that position is removed;
//                   otherwise nothing changes.
//  - x is non-zero: an existing entry is overwritten, otherwise a new
//                   entry is appended to the three lists.
void setEntry(int rPos, int cPos, T x){
    const int slot = findPos(rPos, cPos);
    const bool stored = (slot != -1);

    if (!(x != 0))
    {
        // Zero is never stored: drop the entry if there is one.
        if (stored)
        {
            rowIndex->deleteEleAtPos(slot);
            colIndex->deleteEleAtPos(slot);
            values->deleteEleAtPos(slot);
        }
        return;
    }

    if (stored)
    {
        // Replace the existing entry in place.
        rowIndex->setEleAtPos(slot, rPos);
        colIndex->setEleAtPos(slot, cPos);
        values->setEleAtPos(slot, x);
        return;
    }

    // No entry yet at this position: append one at the end of each list.
    rowIndex->insertEleAtPos(rowIndex->getLength(), rPos);
    colIndex->insertEleAtPos(colIndex->getLength(), cPos);
    values->insertEleAtPos(values->getLength(), x);
}
// Returns the element at (rPos, cPos); positions without a stored entry
// read as 0. findPos is a linear scan, so it is called exactly once.
T getEntry(int rPos, int cPos){
    const int slot = findPos(rPos, cPos);
    if (slot == -1)
        return 0;
    return values->getEleAtPos(slot);
}
// Returns a newly allocated matrix C = A + B, where A is *this.
// The caller owns the result and must delete it.
// Throws "Matrices have incompatible sizes" if the shapes differ.
SparseMatrix<T> * add(SparseMatrix<T> * B){
    if(rows != B->rows || cols != B->cols)throw "Matrices have incompatible sizes";
    SparseMatrix<T> *C = new SparseMatrix<T>(rows,cols);

    // A and B generally store different numbers of entries, so each one is
    // walked with its own bound. Sharing one index dropped surplus entries
    // of B or ran past the end of B's lists.
    const int countA = rowIndex->getLength();
    for (int i = 0; i < countA; i++)
    {
        const int r = rowIndex->getEleAtPos(i);
        const int c = colIndex->getEleAtPos(i);
        // Accumulate onto whatever C already holds at (r, c).
        C->setEntry(r, c, C->getEntry(r, c) + values->getEleAtPos(i));
    }

    const int countB = B->rowIndex->getLength();
    for (int i = 0; i < countB; i++)
    {
        const int r = B->rowIndex->getEleAtPos(i);
        const int c = B->colIndex->getEleAtPos(i);
        C->setEntry(r, c, C->getEntry(r, c) + B->values->getEleAtPos(i));
    }
    return C;
}
// Returns a newly allocated matrix C = A - B, where A is *this.
// The caller owns the result and must delete it.
// Throws "Matrices have incompatible sizes" if the shapes differ.
SparseMatrix<T> * subtract(SparseMatrix<T> * B){
    if(rows != B->rows || cols != B->cols)throw "Matrices have incompatible sizes";
    SparseMatrix<T> *C = new SparseMatrix<T>(rows,cols);

    // Entries of A are ADDED to the (initially empty) result; subtracting
    // them, as before, produced -A - B.
    const int countA = rowIndex->getLength();
    for (int i = 0; i < countA; i++)
    {
        const int r = rowIndex->getEleAtPos(i);
        const int c = colIndex->getEleAtPos(i);
        C->setEntry(r, c, C->getEntry(r, c) + values->getEleAtPos(i));
    }

    // Entries of B are subtracted, using B's own entry count as the bound.
    const int countB = B->rowIndex->getLength();
    for (int i = 0; i < countB; i++)
    {
        const int r = B->rowIndex->getEleAtPos(i);
        const int c = B->colIndex->getEleAtPos(i);
        C->setEntry(r, c, C->getEntry(r, c) - B->values->getEleAtPos(i));
    }
    return C;
}
// Returns a newly allocated matrix C = A * B, where A is *this.
// The caller owns the result and must delete it.
// Throws "Matrices have incompatible sizes" unless A has as many columns
// as B has rows (the only requirement for a matrix product).
SparseMatrix<T> * multiply(SparseMatrix<T> * B){
    if(cols != B->rows)throw "Matrices have incompatible sizes";
    SparseMatrix<T> *C = new SparseMatrix<T>(rows,B->cols);

    const int countA = rowIndex->getLength();
    const int countB = B->rowIndex->getLength();
    for (int i = 0; i < countA; i++)
    {
        const int r = rowIndex->getEleAtPos(i);   // row of the A entry
        const int k = colIndex->getEleAtPos(i);   // shared inner index
        const T   a = values->getEleAtPos(i);
        for (int j = 0; j < countB; j++)
        {
            // A(r,k) only pairs with entries of B that lie in row k.
            if (B->rowIndex->getEleAtPos(j) != k)
                continue;
            const int c = B->colIndex->getEleAtPos(j);
            C->setEntry(r, c, C->getEntry(r, c) + a * B->values->getEleAtPos(j));
        }
    }
    return C;
}
// Prints the matrix in dense form: one output line per row, every
// element (zeros included) followed by a single space.
void printMatrix(){
    int r = 0;
    while (r < rows)
    {
        int c = 0;
        while (c < cols)
        {
            cout << getEntry(r, c) << " ";
            ++c;
        }
        cout << endl;
        ++r;
    }
}
我已经测试了几种情况,结果都表明 `add`、`subtract` 和 `multiply` 运行正确。但是有一个使用 10000×10000 矩阵(称为"X"和"Y")的测试我无法通过:X 和 Y 的非零元素并不多,测试也只是对它们做加、减、乘运算。
时间限制是 1 秒(不包括 `printMatrix()`,但包括 `setEntry()`),而我超时了。我该如何减少程序的运行时间?(我还想知道 COO 存储方式是否用错了,以及 `findPos()` 函数是否多余。)谢谢。我的工具是 VSCode2022,使用 C++11,系统为 Windows 11。下面是测试代码的示例。
#include <iostream>
#include <algorithm>
#include <chrono>
using namespace std;
int main(){
auto start = std::chrono::high_resolution_clock::now();
SparseMatrix<int> X,Y;
X.setEntry(1,3,4);
X.setEntry(7,8,2);
Y.setEntry(1,6,4);
Y.setEntry(1,3,4);
Y.setEntry(7,7,2);
X.printMatrix();
cout << endl;
Y.printMatrix();
cout << endl;
X.add(&Y)->printMatrix();
cout << endl;
X.subtract(&Y)->printMatrix();
cout << endl;
Y.multiply(&X)->printMatrix();
cout << "Done" << endl;
auto stop = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start).count();
cout << "Running Time:" << duration << "ms\n";
return 0;
}
3
最佳答案
2
首先做一个粗略的估算:`findPos` 的复杂度是 O(n),其中 n 为矩阵中的项数。乘法有两层嵌套的 for 循环,内层循环至少调用一次 `findPos`,而 `getEntry` 和 `setEntry` 还可能各再调用一次。
作为第一个改进,请考虑让 `rowIndex` 和 `colIndex` 保持有序:先按行排序,然后按列排序。这样就可以使用一种称为二分搜索的"晦涩"算法,其复杂度为 O(log n)。此外,它还允许您快速选出属于给定行的所有元素。
这也使您的 `add` 和 `subtract` 函数变得更加简单:您只需同时遍历两个矩阵,并将元素复制到结果矩阵;只有当两个矩阵在同一位置都有元素时,才需要做加法/减法。
第二个优化是另一种常见的优化:在 A * B 乘法中,A 矩阵逐行读取,B 矩阵逐列读取。因此,如果您先转置 B,则可以逐行读取,如上所述,这样现在速度更快。
4
-
“一种名为二分查找的晦涩算法”晦涩难懂?
– -
这是个笑话;)这似乎是 C++ 或算法/数据结构课程的家庭作业,所以 OP 应该了解二分搜索。
–
-
谢谢你的建议。二分查找很容易理解,但这是否意味着我必须在 `add`、`subtract` 和 `multiply` 之后对 `rowIndex` 进行排序?顺便说一句,我其实不明白"sorted by row first and then by column"(先按行排序,然后按列排序)是什么意思。
– -
我的意思是,同一行中的单元格应该按列号递增的顺序出现在 `colIndex` 数组中。您可以这样构造加法、乘法、减法运算,使它们只按正确的顺序生成索引。
–
|
-
经过两天的努力,测试终于通过了(ACCEPTED):两个 10000×10000 稀疏矩阵的运算在 0.01s~0.02s 内完成。
-
而且我必须说,使用三个 `VecList` 或任意一个二维数组(two-dimensional array)来存储稀疏矩阵绝对不是一个好主意,因为矩阵中有成千上万个 0。这次我只使用一个 `VecList` 来存储矩阵,其中的元素是 `struct rHead`。
// One stored (non-zero) element of the matrix, linked to the next
// element of the same row.
struct OrthNode{
    // NOTE: 'rol' is the author's misspelling of 'row'; the name is kept
    // because the rest of the program refers to it.
    int rol;           // row of the element
    int col;           // column of the element
    int value;         // value of the element
    OrthNode* right;   // next element in the same row
};
// Head of one matrix row.
struct rHead{
    // Counter kept per row; by the author's own note it is never read
    // by the later program.
    int nums;
    // First element of the row.
    struct OrthNode* right;
};
- `VecList` 中存放的是 `rHead`:矩阵有多少行,`VecList` 中就有多少个 `rHead`。每个 `rHead` 代表一整行,当我想访问某一行时,应该先访问它的 `rHead`。
- `OrthNode` 就像一张身份证,记录了一个非零元素的全部信息:它的 `col`、它的 `rol` 和它的 `value`。值得注意的是,它还有一个 `*right` 指针,指向同一行中的下一个元素。
- 现在,我以 `rHead` 作为每一行的开头。`rHead->right` 指向这一行中第一个非零元素的 `OrthNode`,而这个 `OrthNode` 的指针 `OrthNode->right` 指向这一行的下一个元素。无论是 `rHead` 还是 `OrthNode`,如果右侧没有元素,它的 `right` 就指向 `NULL`。
- 这样,当我要访问 `A[3][4]` 时,就先访问第 3 行的 `rHead`,也就是 `VecList[3]`,此时我位于 `A[3][?]`;然后找 `rHead->right`,假设 `rHead->right = A[3][1]`,再找 `A[3][1]->right`,依此类推,直到到达 `A[3][4]`。
- 因此,我使用一个 `VecList` 来逐行存储稀疏矩阵,并在每一行中按列号的顺序存放元素(例如,第 3 行为 `rHead->A[3][1]->A[3][2]->A[3][4]->A[3][6]`,即 `A[3][k]` 中的 k 是按顺序排列的)。
- 这是我的代码:
template <class T>
class VecList{
// Excerpt: only the accessor that differs from the question's VecList
// is shown here.
public:
// Returns the address of the element stored at index i.
// Throws "Illegal position" when i is outside [0, length).
T* getEleAtPos(int i){
    const bool outOfRange = (i >= length) || (i < 0);
    if(outOfRange)
        throw "Illegal position";
    T& slot = arr[i];
    return &slot;
}
// The remaining members of VecList are the same as in the question.
};
// Sparse matrix stored row by row: M holds one rHead per row, and each
// rHead leads to a chain of OrthNode elements linked through `right`.
template <class T>
class SparseMatrix{
private:
VecList<rHead> M;// Row heads; the constructors append one per row. VecList is the class from the question, except that getEleAtPos returns a pointer to the element.
int totalrows;// Row count fixed by the constructor.
int totalcols;// Column count fixed by the constructor.
public:
SparseMatrix(){
totalrows = 10;
totalcols = 10;
for(int i=0; i<10; i++){
M.insertEleAtPos(M.getLength(), {0, nullptr});
}
}
SparseMatrix(int r, int c){
totalrows = r;
totalcols = c;
for(int i=0; i<r; i++){
M.insertEleAtPos(M.getLength(), {0, nullptr});
}
}
// Frees every OrthNode of every row. Each row head is re-pointed at the
// remaining chain as nodes are released, and ends up with right == null.
~SparseMatrix(){
    for (int row = 0; row < totalrows; ++row)
    {
        rHead* head = M.getEleAtPos(row);
        OrthNode* cursor = head->right;
        while (cursor != nullptr)
        {
            // Remember the successor before the current node is destroyed.
            OrthNode* following = cursor->right;
            head->right = following;
            delete cursor;
            cursor = following;
        }
    }
}
void setEntry(int rPos, int cPos, T x){
OrthNode* newNode = new OrthNode;
newNode->rol = rPos;
newNode->col = cPos;
newNode->value = x;
if (x == 0)
{
if (M.getEleAtPos(rPos)->right == NULL)
{
delete newNode;
return;
}
OrthNode*temp = M.getEleAtPos(rPos)->right;
if (temp->col == cPos)
{
M.getEleAtPos(rPos)->right = temp->right;
delete newNode;
delete temp;
M.getEleAtPos(rPos)->nums --;
return;
}
while(temp->col < cPos && temp != NULL){
if (temp->right->col == cPos)
{
OrthNode*delNode = temp->right;
temp = temp->right->right;
delete delNode;
M.getEleAtPos(rPos)->nums --;
return;
}
temp = temp->right;
}
if (temp->right == NULL)
{
delete temp;
delete newNode;
return;
}
}
else{
if (M.getEleAtPos(rPos)->right == NULL)
{
newNode->right = NULL;
M.getEleAtPos(rPos)->right = newNode;
M.getEleAtPos(rPos)->nums ++;
return;
}
else{
OrthNode* temp = M.getEleAtPos(rPos)->right;
if (cPos < temp->col)
{
newNode->right = temp;
M.getEleAtPos(rPos)->right = newNode;
M.getEleAtPos(rPos)->nums ++;
return;
}
while (temp->col <= cPos && temp != NULL)
{
if (temp->col == cPos)
{
temp->value = x;
M.getEleAtPos(rPos)->nums ++;
delete newNode;
return;
}
if (temp->right == NULL)
{
newNode->right =NULL;
temp->right = newNode;
M.getEleAtPos(rPos)->nums ++;
return;
}
if (temp->col < cPos && temp->right->col > cPos)
{
newNode->right = temp->right;
temp->right = newNode;
M.getEleAtPos(rPos)->nums ++;
return;
}
temp = temp->right;
}
}
}
}
// Returns the element at (rPos, cPos); positions without a node read as 0.
// Throws "Illegal position" (from VecList) if rPos is not a valid row.
// No node is allocated for the lookup (the previous version leaked one
// OrthNode per call).
T getEntry(int rPos, int cPos){
    // setEntry keeps each row sorted by column, so the walk can stop at
    // the first node whose column is past cPos.
    for (const OrthNode* node = M.getEleAtPos(rPos)->right;
         node != nullptr && node->col <= cPos;
         node = node->right)
    {
        if (node->col == cPos)
        {
            return node->value;
        }
    }
    return 0;
}
// Returns a newly allocated matrix C = A + B, where A is *this.
// The caller owns the result and must delete it.
// Throws "Matrices have incompatible sizes" if the shapes differ.
//
// Each pair of rows is merged in a single pass (both are sorted by
// column, as maintained by setEntry) and result nodes are appended
// straight to the tail of C's row. This replaces one setEntry/getEntry
// list walk per element.
SparseMatrix<T> * add(SparseMatrix<T> * B){
    if(totalrows != B->totalrows || totalcols != B->totalcols)throw "Matrices have incompatible sizes";
    SparseMatrix<T>* C = new SparseMatrix<T>(totalrows,totalcols);
    for (int i = 0; i < totalrows; i++)
    {
        const OrthNode* a = M.getEleAtPos(i)->right;
        const OrthNode* b = B->M.getEleAtPos(i)->right;
        rHead* outHead = C->M.getEleAtPos(i);
        OrthNode** tail = &outHead->right;   // where the next result node is linked

        while (a != nullptr || b != nullptr)
        {
            int col;
            T sum;
            if (b == nullptr || (a != nullptr && a->col < b->col))
            {
                // Column present only in A.
                col = a->col;
                sum = a->value;
                a = a->right;
            }
            else if (a == nullptr || b->col < a->col)
            {
                // Column present only in B.
                col = b->col;
                sum = b->value;
                b = b->right;
            }
            else
            {
                // Column present in both rows.
                col = a->col;
                sum = a->value + b->value;
                a = a->right;
                b = b->right;
            }

            if (sum == 0)
            {
                continue;   // cancelled out: zeros are never stored
            }
            OrthNode* node = new OrthNode;
            node->rol = i;
            node->col = col;
            node->value = sum;
            node->right = nullptr;
            *tail = node;
            tail = &node->right;
            outHead->nums++;
        }
    }
    return C;
}
// Returns a newly allocated matrix C = A - B, where A is *this.
// The caller owns the result and must delete it.
// Throws "Matrices have incompatible sizes" if the shapes differ.
//
// Each pair of rows is merged in a single pass (both are sorted by
// column, as maintained by setEntry) and result nodes are appended
// straight to the tail of C's row. This replaces one setEntry/getEntry
// list walk per element.
SparseMatrix<T> * subtract(SparseMatrix<T> * B){
    if(totalrows != B->totalrows || totalcols != B->totalcols)throw "Matrices have incompatible sizes";
    SparseMatrix<T>* C = new SparseMatrix<T>(totalrows,totalcols);
    for (int i = 0; i < totalrows; i++)
    {
        const OrthNode* a = M.getEleAtPos(i)->right;
        const OrthNode* b = B->M.getEleAtPos(i)->right;
        rHead* outHead = C->M.getEleAtPos(i);
        OrthNode** tail = &outHead->right;   // where the next result node is linked

        while (a != nullptr || b != nullptr)
        {
            int col;
            T diff;
            if (b == nullptr || (a != nullptr && a->col < b->col))
            {
                // Column present only in A: copied unchanged.
                col = a->col;
                diff = a->value;
                a = a->right;
            }
            else if (a == nullptr || b->col < a->col)
            {
                // Column present only in B: negated.
                col = b->col;
                diff = -(b->value);
                b = b->right;
            }
            else
            {
                // Column present in both rows.
                col = a->col;
                diff = a->value - b->value;
                a = a->right;
                b = b->right;
            }

            if (diff == 0)
            {
                continue;   // cancelled out: zeros are never stored
            }
            OrthNode* node = new OrthNode;
            node->rol = i;
            node->col = col;
            node->value = diff;
            node->right = nullptr;
            *tail = node;
            tail = &node->right;
            outHead->nums++;
        }
    }
    return C;
}
// Returns a newly allocated matrix C = A * B, where A is *this.
// The caller owns the result and must delete it.
// Throws "Matrices have incompatible sizes" unless A has as many columns
// as B has rows (the only requirement for a matrix product).
//
// For every element A(i,k), row k of B is scaled and merged into row i of
// C. Both chains are sorted by column, so one forward-moving link pointer
// serves the whole merge and no getEntry/setEntry lookup is needed.
SparseMatrix<T> * multiply(SparseMatrix<T> * B){
    if (totalcols != B->totalrows)throw "Matrices have incompatible sizes";
    SparseMatrix<T>* C = new SparseMatrix<T>(totalrows,B->totalcols);
    for (int i = 0; i < totalrows; i++)
    {
        rHead* outHead = C->M.getEleAtPos(i);
        for (const OrthNode* a = M.getEleAtPos(i)->right; a != nullptr; a = a->right)
        {
            // Restart at the head of C's row for each element of A.
            OrthNode** link = &outHead->right;
            // getEleAtPos throws "Illegal position" if a->col is not a row of B.
            for (const OrthNode* b = B->M.getEleAtPos(a->col)->right; b != nullptr; b = b->right)
            {
                while (*link != nullptr && (*link)->col < b->col)
                {
                    link = &(*link)->right;
                }
                if (*link != nullptr && (*link)->col == b->col)
                {
                    // C(i, b->col) already exists: accumulate into it.
                    const T updated = (*link)->value + a->value * b->value;
                    (*link)->value = updated;
                }
                else
                {
                    const T product = a->value * b->value;
                    OrthNode* node = new OrthNode;
                    node->rol = i;
                    node->col = b->col;
                    node->value = product;
                    node->right = *link;
                    *link = node;
                    outHead->nums++;
                }
            }
        }

        // Partial products may cancel; drop nodes that ended up as zero so
        // the result stores only non-zero elements.
        OrthNode** link = &outHead->right;
        while (*link != nullptr)
        {
            if ((*link)->value == 0)
            {
                OrthNode* dead = *link;
                *link = dead->right;
                delete dead;
                outHead->nums--;
            }
            else
            {
                link = &(*link)->right;
            }
        }
    }
    return C;
}
// Only call this function if you know the size of matrix is reasonable.
// Placeholder printer: the dense dump (getEntry for every cell of the
// totalrows x totalcols grid) is deliberately disabled, so only a
// warning line is written.
void printMatrix(){
    const char* const warning = "Be careful, when the matrix is too big, do not use Print!";
    cout << warning << endl;
}
- 测试代码如下:
#include <iostream>
#include <algorithm>
#include <chrono>
using namespace std;
int main(){
SparseMatrix<int> X(10,10);
SparseMatrix<int> Y(10,10);
//It is a Sparse Matrix so do not give so much elements.
for (int i = 0; i < 10; i++)
{
X.setEntry(i,i,2);
Y.setEntry(3,i,i+2);
}
//If you want to test the time cost, do not use printMatrix();
auto start = std::chrono::high_resolution_clock::now();
X.printMatrix();
cout << endl;
Y.printMatrix();
cout << endl;
X.add(&Y)->printMatrix();
cout << endl;
X.subtract(&Y)->printMatrix();
cout << endl;
X.multiply(&Y)->printMatrix();
cout << endl;
cout << "Done" << endl;
auto stop = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start).count();
cout << "Running Time:" << duration << "ms\n";
return 0;
}
3
-
我不确定这是否仍可称为稀疏矩阵。在我看来,1Bx1B 稀疏矩阵(空或其他)不应占用 8GB 的 RAM。
–
-
运行内存是个大问题:不管一行里有没有元素,我都为它创建了一个 `rHead`。更优化的做法是在 `struct rHead` 中添加一个 `*next` 指针(`rHead->next`),指向下一个含有元素的 `rHead`;这样,当一行里没有元素时,就不需要这一行的 `rHead`。另外还可以创建指向一整列的 `cHead`,这样存储和定位都比较方便;创建 `cHead` 之后,就必须在 `OrthNode` 中添加一个 `*down` 指针,指向同一列中的下一个元素。不过,也可能是我没有正确释放已分配的内存。
–
-
..这将使你的 getPos 操作符再次变为 O(n)。你需要某种树结构或至少二分搜索。
–
|
getEleAtPos
几亿次。即使我们知道它们在范围内(因为循环条件)。–
acol = colIndex->getEleAtPos(i)
–
SparseMatrix
成为指针,因此您可以通过移除指针来立即加快速度。–
|