/**
 * The Matrix type defines a common
 * interface for matrix operations.
 *
 * The base class owns the flattened element storage (`linear_`): it is
 * allocated in the constructor and released in the destructor. Because the
 * storage is held through a raw owning pointer, copying is disabled.
 */
template <typename T>
class Matrix {
 protected:
  /**
   * Construct a new Matrix instance.
   *
   * Allocates `rows * cols` value-initialized elements (zero for arithmetic
   * types). Value-initialization is used instead of `memset` so that element
   * types that are not trivially copyable are still constructed correctly.
   *
   * @param rows The number of rows
   * @param cols The number of columns
   */
  Matrix(int rows, int cols) : rows_(rows), cols_(cols), linear_(new T[rows * cols]()) {}

  /** The number of rows in the matrix */
  int rows_;
  /** The number of columns in the matrix */
  int cols_;

  /**
   * A flattened array containing the elements of the matrix.
   * Allocated in the constructor and deallocated in the destructor.
   */
  T *linear_;

 public:
  /** Copying is disabled: two copies would both delete the same `linear_` array. */
  Matrix(const Matrix &) = delete;
  auto operator=(const Matrix &) -> Matrix & = delete;

  /** @return The number of rows in the matrix */
  virtual auto GetRowCount() const -> int = 0;

  /** @return The number of columns in the matrix */
  virtual auto GetColumnCount() const -> int = 0;

  /**
   * Get the (i,j)th matrix element.
   *
   * Throw OUT_OF_RANGE if either index is out of range.
   *
   * @param i The row index
   * @param j The column index
   * @return The (i,j)th matrix element
   * @throws OUT_OF_RANGE if either index is out of range
   */
  virtual auto GetElement(int i, int j) const -> T = 0;

  /**
   * Set the (i,j)th matrix element.
   *
   * Throw OUT_OF_RANGE if either index is out of range.
   *
   * @param i The row index
   * @param j The column index
   * @param val The value to insert
   * @throws OUT_OF_RANGE if either index is out of range
   */
  virtual void SetElement(int i, int j, T val) = 0;

  /**
   * Fill the elements of the matrix from `source`.
   *
   * Throw OUT_OF_RANGE in the event that `source`
   * does not contain the required number of elements.
   *
   * @param source The source container
   * @throws OUT_OF_RANGE if `source` is incorrect size
   */
  virtual void FillFrom(const std::vector<T> &source) = 0;

  /**
   * Destroy a matrix instance and release the flattened element array.
   * Virtual so that deleting a derived matrix through a `Matrix<T> *` is well-defined.
   */
  virtual ~Matrix() { delete[] linear_; }
};
/** * The RowMatrix type is a concrete matrix implementation. * It implements the interface defined by the Matrix type. */ template <typename T> classRowMatrix : public Matrix<T> { public: /** * TODO(P0): Add implementation * * Construct a new RowMatrix instance. * @param rows The number of rows * @param cols The number of columns */ RowMatrix(int rows, int cols) : Matrix<T>(rows, cols) { data_ = new T* [rows]; for (int i = 0; i < rows; i++) { data_[i] = this->linear_ + i*cols; } }
/** * TODO(P0): Add implementation * @return The number of rows in the matrix */ autoGetRowCount()const -> intoverride{ returnthis->rows_; }
/** * TODO(P0): Add implementation * @return The number of columns in the matrix */ autoGetColumnCount()const -> intoverride{ returnthis->cols_; }
/** * TODO(P0): Add implementation * * Get the (i,j)th matrix element. * * Throw OUT_OF_RANGE if either index is out of range. * * @param i The row index * @param j The column index * @return The (i,j)th matrix element * @throws OUT_OF_RANGE if either index is out of range */ autoGetElement(int i, int j)const -> T override{ if (i < 0 || j < 0 || i >= GetRowCount() || j >= GetColumnCount()) throwException(ExceptionType::OUT_OF_RANGE, "source does not contain the required number of elements"); else return data_[i][j]; }
/** * Set the (i,j)th matrix element. * * Throw OUT_OF_RANGE if either index is out of range. * * @param i The row index * @param j The column index * @param val The value to insert * @throws OUT_OF_RANGE if either index is out of range */ voidSetElement(int i, int j, T val)override{ if (i < 0 || j < 0 || i >= GetRowCount() || j >= GetColumnCount()) throwException(ExceptionType::OUT_OF_RANGE, "source does not contain the required number of elements"); data_[i][j] = val; }
/** * TODO(P0): Add implementation * * Fill the elements of the matrix from `source`. * * Throw OUT_OF_RANGE in the event that `source` * does not contain the required number of elements. * * @param source The source container * @throws OUT_OF_RANGE if `source` is incorrect size */ voidFillFrom(const std::vector<T> &source)override{ if (static_cast<int>(source.size()) != GetColumnCount()*GetRowCount()) { throwException(ExceptionType::OUT_OF_RANGE, "size error"); return; } int cnt = 0; for (int i = 0; i < GetRowCount(); i++) for (int j = 0; j < GetColumnCount(); j++) { data_[i][j] = source[cnt++]; } }
private: /** * A 2D array containing the elements of the matrix in row-major format. * * TODO(P0): * - Allocate the array of row pointers in the constructor. * - Use these pointers to point to corresponding elements of the `linear` array. * - Don't forget to deallocate the array in the destructor. */ T **data_; };
/** * The RowMatrixOperations class defines operations * that may be performed on instances of `RowMatrix`. */ template <typename T> classRowMatrixOperations { public: /** * Compute (`matrixA` + `matrixB`) and return the result. * Return `nullptr` if dimensions mismatch for input matrices. * @param matrixA Input matrix * @param matrixB Input matrix * @return The result of matrix addition */ staticautoAdd(const RowMatrix<T> *matrixA, const RowMatrix<T> *matrixB) -> std::unique_ptr<RowMatrix<T>> { // TODO(P0): Add implementation if (matrixA->GetColumnCount() != matrixB->GetColumnCount() || matrixA->GetRowCount() != matrixB->GetRowCount()) return std::unique_ptr<RowMatrix<T>>(nullptr); int row = matrixA->GetRowCount(); int col = matrixA->GetColumnCount(); std::unique_ptr<RowMatrix<T>>ptr = std::make_unique<RowMatrix<T>>(row, col); for (int i = 0; i < row; i++) for (int j = 0; j < col; j++) { ptr->SetElement(i, j, matrixA->GetElement(i, j) + matrixB->GetElement(i, j)); } return ptr; }
/** * Compute the matrix multiplication (`matrixA` * `matrixB` and return the result. * Return `nullptr` if dimensions mismatch for input matrices. * @param matrixA Input matrix * @param matrixB Input matrix * @return The result of matrix multiplication */ staticautoMultiply(const RowMatrix<T> *matrixA, const RowMatrix<T> *matrixB) -> std::unique_ptr<RowMatrix<T>> { // TODO(P0): Add implementation if (matrixA->GetColumnCount() != matrixB->GetRowCount()) return std::unique_ptr<RowMatrix<T>>(nullptr); int row = matrixA->GetRowCount(); int col = matrixB->GetColumnCount(); std::unique_ptr<RowMatrix<T>>ptr = std::make_unique<RowMatrix<T>>(row, col); for (int i = 0; i < row; i++) for (int j = 0; j < col; j++) { int temp = 0; for (int z = 0; z < matrixA->GetColumnCount();z++) { temp += matrixA->GetElement(i, z) * matrixB->GetElement(z, j); } ptr->SetElement(i, j, temp); } return ptr; }
/** * Simplified General Matrix Multiply operation. Compute (`matrixA` * `matrixB` + `matrixC`). * Return `nullptr` if dimensions mismatch for input matrices. * @param matrixA Input matrix * @param matrixB Input matrix * @param matrixC Input matrix * @return The result of general matrix multiply */ staticautoGEMM(const RowMatrix<T> *matrixA, const RowMatrix<T> *matrixB, const RowMatrix<T> *matrixC) -> std::unique_ptr<RowMatrix<T>> { // TODO(P0): Add implementation if (matrixA->GetColumnCount() != matrixB->GetRowCount()) return std::unique_ptr<RowMatrix<T>>(nullptr); if (matrixA->GetRowCount() != matrixC->GetRowCount() || matrixB->GetColumnCount() != matrixC->GetRowCount()) return std::unique_ptr<RowMatrix<T>>(nullptr); auto mul = Multiply(matrixA, matrixB); const RowMatrix<T>*p1 = mul.get(); returnAdd(p1, matrixC); } }; } // namespace bustub