Logistic Regression and Newton's Method


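In the previous article I explained the principle of logistic regression and implemented it with gradient ascent. This article studies another way to fit logistic regression: Newton's method (Newton-Raphson iteration).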

 

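In the previous article we derived the partial derivatives of the logistic-regression log-likelihood

$$L(\theta)=\sum_{i=1}^{m}\Big[y_i\log h_\theta(x_i)+(1-y_i)\log\big(1-h_\theta(x_i)\big)\Big],\qquad h_\theta(x)=\frac{1}{1+e^{-\theta^{T}x}},$$

namely

$$\frac{\partial L(\theta)}{\partial\theta_j}=\sum_{i=1}^{m}\big(y_i-h_\theta(x_i)\big)x_{i,j},\qquad j=0,1,\dots,n-1,$$

where $x_i$ is the $i$-th of the $m$ training samples, $x_{i,j}$ is its $j$-th component ($x_{i,0}=1$ for the intercept) and $y_i\in\{0,1\}$ is its label.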
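A quick way to gain confidence in this gradient formula is to compare it with a central finite difference of $L(\theta)$. The following is a standalone sketch, not part of the article's program; the names Mat, Vec, predict, logLikelihood, gradient and numericPartial, and the row-by-row storage of X (row i is $x_i$, with X[i][0] = 1), are assumptions made for illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative sketch with hypothetical names.
// X is m x n (row i is sample x_i, X[i][0] = 1); y holds m labels in {0, 1}.
typedef std::vector<std::vector<double> > Mat;
typedef std::vector<double> Vec;

// h_theta(x_i) = 1 / (1 + exp(-theta^T x_i))
double predict(const Vec& theta, const Vec& xi) {
    double t = 0.0;
    for (std::size_t j = 0; j < theta.size(); ++j) t += theta[j] * xi[j];
    return 1.0 / (1.0 + std::exp(-t));
}

// L(theta) = sum_i [ y_i log h_i + (1 - y_i) log(1 - h_i) ]
double logLikelihood(const Mat& X, const Vec& y, const Vec& theta) {
    assert(X.size() == y.size());
    double L = 0.0;
    for (std::size_t i = 0; i < X.size(); ++i) {
        double h = predict(theta, X[i]);
        L += y[i] * std::log(h) + (1.0 - y[i]) * std::log(1.0 - h);
    }
    return L;
}

// Analytic gradient: dL/dtheta_j = sum_i (y_i - h_i) * x_{i,j}
Vec gradient(const Mat& X, const Vec& y, const Vec& theta) {
    assert(X.size() == y.size());
    Vec g(theta.size(), 0.0);
    for (std::size_t i = 0; i < X.size(); ++i) {
        double r = y[i] - predict(theta, X[i]);
        for (std::size_t j = 0; j < theta.size(); ++j) g[j] += r * X[i][j];
    }
    return g;
}

// Central difference (L(theta + eps e_j) - L(theta - eps e_j)) / (2 eps)
double numericPartial(const Mat& X, const Vec& y, Vec theta, std::size_t j, double eps) {
    theta[j] += eps;
    double up = logLikelihood(X, y, theta);
    theta[j] -= 2.0 * eps;
    double down = logLikelihood(X, y, theta);
    return (up - down) / (2.0 * eps);
}
```

On a small data set, gradient(X, y, theta)[j] and numericPartial(X, y, theta, j, 1e-6) should agree to roughly six decimal places or better.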

 

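Since $L(\theta)$ is a function of several variables, namely $\theta_0,\theta_1,\dots,\theta_{n-1}$, finding its maximum is a multivariate extremum problem, which was covered in an earlier article.

At an interior extremum all partial derivatives must be zero, so we write down $n$ equations and solve them simultaneously for all the parameters $\theta_0,\theta_1,\dots,\theta_{n-1}$.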


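Vanishing derivatives are only a necessary condition, so we first use the Hessian matrix to establish that a solution of this system really is a maximum. The system of equations is

$$\sum_{i=1}^{m}\big(y_i-h_\theta(x_i)\big)x_{i,j}=0,\qquad j=0,1,\dots,n-1.$$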

 

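That is $n$ equations in total, and the problem becomes how to solve this system. To obtain the Hessian we first need the second-order partial derivatives. Using $\partial h_\theta(x_i)/\partial\theta_k=h_\theta(x_i)\big(1-h_\theta(x_i)\big)x_{i,k}$, we get

$$\frac{\partial^2 L(\theta)}{\partial\theta_j\,\partial\theta_k}=-\sum_{i=1}^{m}h_\theta(x_i)\big(1-h_\theta(x_i)\big)x_{i,j}\,x_{i,k}.$$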
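The second derivatives can be checked numerically by differencing the analytic gradient. This is a hypothetical standalone sketch (the names and the row-by-row layout of X are assumptions); hessianEntry implements the formula above term by term.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative sketch with hypothetical names. X: m x n, row i = x_i, X[i][0] = 1.
typedef std::vector<std::vector<double> > Mat;
typedef std::vector<double> Vec;

// h_theta(x_i) = 1 / (1 + exp(-theta^T x_i))
double predict(const Vec& theta, const Vec& xi) {
    double t = 0.0;
    for (std::size_t j = 0; j < theta.size(); ++j) t += theta[j] * xi[j];
    return 1.0 / (1.0 + std::exp(-t));
}

// dL/dtheta_j = sum_i (y_i - h_i) x_{i,j}
Vec gradient(const Mat& X, const Vec& y, const Vec& theta) {
    assert(X.size() == y.size());
    Vec g(theta.size(), 0.0);
    for (std::size_t i = 0; i < X.size(); ++i) {
        double r = y[i] - predict(theta, X[i]);
        for (std::size_t j = 0; j < theta.size(); ++j) g[j] += r * X[i][j];
    }
    return g;
}

// d^2 L / (dtheta_j dtheta_k) = -sum_i h_i (1 - h_i) x_{i,j} x_{i,k}  (does not depend on y)
double hessianEntry(const Mat& X, const Vec& theta, std::size_t j, std::size_t k) {
    double s = 0.0;
    for (std::size_t i = 0; i < X.size(); ++i) {
        double h = predict(theta, X[i]);
        s -= h * (1.0 - h) * X[i][j] * X[i][k];
    }
    return s;
}

// Central difference of the j-th gradient component with respect to theta_k.
double numericHessianEntry(const Mat& X, const Vec& y, Vec theta,
                           std::size_t j, std::size_t k, double eps) {
    theta[k] += eps;
    double up = gradient(X, y, theta)[j];
    theta[k] -= 2.0 * eps;
    double down = gradient(X, y, theta)[j];
    return (up - down) / (2.0 * eps);
}
```

Note that the labels $y_i$ drop out of the second derivatives, and since $j$ and $k$ enter symmetrically, the Hessian is symmetric.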

 

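To write the Hessian in matrix form, let

$$A=\begin{bmatrix}h_\theta(x_1)\big(h_\theta(x_1)-1\big)&&\\&\ddots&\\&&h_\theta(x_m)\big(h_\theta(x_m)-1\big)\end{bmatrix},\qquad X=\begin{bmatrix}x_{1,0}&x_{1,1}&\cdots&x_{1,n-1}\\x_{2,0}&x_{2,1}&\cdots&x_{2,n-1}\\\vdots&\vdots&\ddots&\vdots\\x_{m,0}&x_{m,1}&\cdots&x_{m,n-1}\end{bmatrix},$$

so that the Hessian matrix is

$$H=X^{T}AX.$$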
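Because $A$ is diagonal, $H=X^{T}AX$ never requires storing the $m\times m$ matrix $A$: entrywise, $H_{jk}=\sum_i a_{ii}\,x_{i,j}x_{i,k}$. The sketch below (hypothetical helper names; X stored row by row with X[i][0] = 1) computes both the literal triple product and the accumulated form so they can be compared. The accumulated form needs only $O(n^2)$ memory instead of $O(m^2)$ for $A$.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative sketch with hypothetical names. X: m x n, row i = x_i.
typedef std::vector<std::vector<double> > Mat;
typedef std::vector<double> Vec;

// h_theta(x_i) = 1 / (1 + exp(-theta^T x_i))
double predict(const Vec& theta, const Vec& xi) {
    double t = 0.0;
    for (std::size_t j = 0; j < theta.size(); ++j) t += theta[j] * xi[j];
    return 1.0 / (1.0 + std::exp(-t));
}

// Literal H = X^T A X, with A = diag(h_i (h_i - 1)) stored as a full m x m matrix.
Mat hessianDense(const Mat& X, const Vec& theta) {
    std::size_t m = X.size(), n = theta.size();
    Mat A(m, Vec(m, 0.0));
    for (std::size_t i = 0; i < m; ++i) {
        double h = predict(theta, X[i]);
        A[i][i] = h * (h - 1.0);
    }
    Mat H(n, Vec(n, 0.0));
    for (std::size_t j = 0; j < n; ++j)
        for (std::size_t k = 0; k < n; ++k)
            for (std::size_t i = 0; i < m; ++i)
                for (std::size_t l = 0; l < m; ++l)
                    H[j][k] += X[i][j] * A[i][l] * X[l][k];  // (X^T)_{ji} A_{il} X_{lk}
    return H;
}

// Same matrix without storing A: H_jk = sum_i a_ii x_{i,j} x_{i,k}.
Mat hessianAccumulated(const Mat& X, const Vec& theta) {
    std::size_t n = theta.size();
    Mat H(n, Vec(n, 0.0));
    for (std::size_t i = 0; i < X.size(); ++i) {
        assert(X[i].size() == n);
        double h = predict(theta, X[i]);
        double a = h * (h - 1.0);  // negative, since 0 < h < 1
        for (std::size_t j = 0; j < n; ++j)
            for (std::size_t k = 0; k < n; ++k) H[j][k] += a * X[i][j] * X[i][k];
    }
    return H;
}
```

The accumulated form is the one worth using in practice; the dense version is only there to show that the two agree.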

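Because $0<h_\theta(x_i)<1$, every diagonal entry of $A$ is negative, so $A$ is negative definite. I will now prove that if $A$ is negative definite and $X$ has full column rank (no feature is a linear combination of the others), then the Hessian $H$ is also negative definite.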

 

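Proof: let $z$ be any nonzero $n$-dimensional column vector. Since $X$ has full column rank, $Xz\neq0$, and since $A$ is negative definite, the quadratic form $(Xz)^{T}A(Xz)$ is negative. Moreover

$$z^{T}Hz=z^{T}X^{T}AXz=(Xz)^{T}A(Xz)<0,$$

so $H$ is negative definite as well. (If $X$ does not have full column rank, then $Xz=0$ for some $z\neq0$, and $H$ is only negative semidefinite.)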
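The identity $z^{T}Hz=(Xz)^{T}A(Xz)$ used in the proof, and the role of the full-rank assumption, can both be checked numerically. In this hypothetical sketch, quadForm evaluates $z^{T}Hz$ from an assembled Hessian, and quadFormViaXz evaluates $\sum_i a_{ii}(x_i^{T}z)^2$ directly; the test data are made up.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative sketch with hypothetical names. X: m x n, row i = x_i.
typedef std::vector<std::vector<double> > Mat;
typedef std::vector<double> Vec;

// h_theta(x_i) = 1 / (1 + exp(-theta^T x_i))
double predict(const Vec& theta, const Vec& xi) {
    double t = 0.0;
    for (std::size_t j = 0; j < theta.size(); ++j) t += theta[j] * xi[j];
    return 1.0 / (1.0 + std::exp(-t));
}

// H = X^T A X with a_ii = h_i (h_i - 1), accumulated sample by sample.
Mat hessian(const Mat& X, const Vec& theta) {
    std::size_t n = theta.size();
    Mat H(n, Vec(n, 0.0));
    for (std::size_t i = 0; i < X.size(); ++i) {
        assert(X[i].size() == n);
        double h = predict(theta, X[i]);
        double a = h * (h - 1.0);
        for (std::size_t j = 0; j < n; ++j)
            for (std::size_t k = 0; k < n; ++k) H[j][k] += a * X[i][j] * X[i][k];
    }
    return H;
}

// z^T H z
double quadForm(const Mat& H, const Vec& z) {
    double s = 0.0;
    for (std::size_t j = 0; j < z.size(); ++j) {
        double row = 0.0;
        for (std::size_t k = 0; k < z.size(); ++k) row += H[j][k] * z[k];
        s += z[j] * row;
    }
    return s;
}

// (Xz)^T A (Xz) = sum_i a_ii (x_i . z)^2, which is negative whenever Xz != 0.
double quadFormViaXz(const Mat& X, const Vec& theta, const Vec& z) {
    double s = 0.0;
    for (std::size_t i = 0; i < X.size(); ++i) {
        double h = predict(theta, X[i]);
        double xz = 0.0;
        for (std::size_t j = 0; j < z.size(); ++j) xz += X[i][j] * z[j];
        s += h * (h - 1.0) * xz * xz;
    }
    return s;
}
```

If a column of X is duplicated, then $z=(0,1,-1)^{T}$ gives $Xz=0$ and the quadratic form is exactly zero. This is why full column rank is needed for strict negative definiteness.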

 

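Since the Hessian $H$ is negative definite, $L(\theta)$ is strictly concave. It therefore has at most one stationary point, and any such point is the global maximum, which is exactly the maximum-likelihood estimate we set out to find. (The maximum need not exist: if the two classes are linearly separable, the likelihood keeps increasing as $\|\theta\|$ grows.) The Hessian describes the local curvature of a multivariate function. With it in hand, we can go on to apply Newton's method.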

 

Recall how Newton's method finds a zero of a single-variable function $f(x)$. To solve $f(x)=0$, pick a starting point $x_0$ and iterate

$$x_{k+1}=x_k-\frac{f(x_k)}{f'(x_k)},\qquad k=0,1,2,\dots$$

until the required precision is reached.
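A minimal C++ sketch of this iteration follows. The text only says "until the required precision is reached", so the stopping rule (step size below tol) and the iteration cap maxIter are assumptions, as is the function name newton1D.

```cpp
#include <cassert>
#include <cmath>
#include <functional>

// Illustrative sketch: 1-D Newton iteration x_{k+1} = x_k - f(x_k) / f'(x_k).
// Assumed stopping rule: stop when |step| < tol, or after maxIter steps.
double newton1D(const std::function<double(double)>& f,
                const std::function<double(double)>& df,
                double x0, double tol = 1e-12, int maxIter = 100) {
    double x = x0;
    for (int k = 0; k < maxIter; ++k) {
        double d = df(x);
        assert(d != 0.0);  // the iteration is undefined where f'(x) = 0
        double step = f(x) / d;
        x -= step;
        if (std::fabs(step) < tol) break;
    }
    return x;
}
```

For example, with $f(x)=x^2-2$, $f'(x)=2x$ and $x_0=1$, the iterates are 1.5, 1.41667, 1.4142157, ..., converging quadratically to $\sqrt{2}$.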

 

For the underlying theory, see: http://zh.m.wikipedia.org/wiki/%E7%89%9B%E9%A1%BF%E6%B3%95

 

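The choice of starting point can matter, because Newton's method converges to a zero near where it starts, not to one chosen globally, and from a poor starting point it may not converge at all. If the function has only one zero, the choice of $x_0$ is less critical. If there are several, one usually looks for the zero near a given point $x_0$. For logistic regression the Hessian is negative definite for any data (given full column rank), so there is at most one extremum and the starting point matters much less. The code below simply starts from $\theta=0$.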
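To illustrate how the starting point can decide which zero Newton's method finds, here is a small sketch using $f(x)=x^3-x$. This example is chosen here and is not from the original text. Its zeros are $-1$, $0$ and $1$, and the helper names cubic, cubicDerivative and newtonFrom are hypothetical.

```cpp
#include <cassert>
#include <cmath>

// Illustrative example function with three zeros: -1, 0 and 1.
double cubic(double x)           { return x * x * x - x; }
double cubicDerivative(double x) { return 3.0 * x * x - 1.0; }

// Plain Newton iteration from x0; returns the zero it converges to.
double newtonFrom(double x0) {
    double x = x0;
    for (int k = 0; k < 100; ++k) {
        double d = cubicDerivative(x);
        assert(d != 0.0);  // f'(x) = 0 at x = +-1/sqrt(3); such starts are undefined
        double step = cubic(x) / d;
        x -= step;
        if (std::fabs(step) < 1e-12) break;
    }
    return x;
}
```

Starting from 2 the iteration reaches 1, starting from -2 it reaches -1, and starting from 0.1 it reaches 0. A strictly concave log-likelihood has no such competing solutions.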

 

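Newton's method extends to finding zeros of multivariate functions. Applied to the gradient of the logistic-regression log-likelihood, it gives the iteration

$$\theta^{(k+1)}=\theta^{(k)}-H^{-1}U,$$

where $H$ is the Hessian matrix evaluated at $\theta^{(k)}$ and $U$ is the gradient

$$U=\begin{bmatrix}\dfrac{\partial L(\theta)}{\partial\theta_0}\\\vdots\\\dfrac{\partial L(\theta)}{\partial\theta_{n-1}}\end{bmatrix}=X^{T}(y-h),\qquad y=\begin{bmatrix}y_1\\\vdots\\y_m\end{bmatrix},\quad h=\begin{bmatrix}h_\theta(x_1)\\\vdots\\h_\theta(x_m)\end{bmatrix}.$$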

 

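Since the Hessian $H$ is symmetric and negative definite, we pull a minus sign out of $A$:

$$A=-\bar{A},\qquad \bar{A}=\begin{bmatrix}h_\theta(x_1)\big(1-h_\theta(x_1)\big)&&\\&\ddots&\\&&h_\theta(x_m)\big(1-h_\theta(x_m)\big)\end{bmatrix}.$$

The Hessian then becomes $H=-X^{T}\bar{A}X$, and $\bar{H}=X^{T}\bar{A}X$ is symmetric positive definite. Since $H^{-1}=-\bar{H}^{-1}$, the Newton iteration now reads

$$\theta^{(k+1)}=\theta^{(k)}+\bar{H}^{-1}U.$$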
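Here is a compact, self-contained sketch of this update for the special case of two parameters: an intercept $\theta_0$ and one scalar feature, so each sample is $(1,x_i)^{T}$. With $n=2$, the system $\bar{H}\,\Delta\theta=U$ can be solved in closed form with Cramer's rule, which keeps the sketch short. The names newtonLogistic, gradNorm, assemble and Fit, as well as the tolerance, the iteration limit and the test data, are assumptions for illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative sketch (hypothetical names): theta = (t0, t1),
// model h(x) = 1 / (1 + exp(-(t0 + t1 * x))) for a scalar feature x.
struct Fit {
    double t0, t1;   // fitted parameters
    int iterations;  // Newton steps taken
};

// One pass over the data: U = X^T (y - h) and Hbar = X^T Abar X, Abar = diag(h (1 - h)).
static void assemble(const std::vector<double>& x, const std::vector<double>& y,
                     double t0, double t1, double u[2], double hb[3]) {
    assert(x.size() == y.size());
    u[0] = u[1] = hb[0] = hb[1] = hb[2] = 0.0;  // hb = {Hbar_00, Hbar_01, Hbar_11}
    for (std::size_t i = 0; i < x.size(); ++i) {
        double h = 1.0 / (1.0 + std::exp(-(t0 + t1 * x[i])));
        double r = y[i] - h, w = h * (1.0 - h);
        u[0] += r;
        u[1] += r * x[i];
        hb[0] += w;
        hb[1] += w * x[i];
        hb[2] += w * x[i] * x[i];
    }
}

// Newton iteration theta <- theta + Hbar^{-1} U, stopping when the step is tiny.
Fit newtonLogistic(const std::vector<double>& x, const std::vector<double>& y,
                   double t0, double t1, int maxIter = 50, double tol = 1e-10) {
    int it = 0;
    while (it < maxIter) {
        double u[2], hb[3];
        assemble(x, y, t0, t1, u, hb);
        // Solve the 2x2 system Hbar * d = U by Cramer's rule.
        double det = hb[0] * hb[2] - hb[1] * hb[1];
        assert(det > 0.0);  // Hbar is positive definite unless all x are equal
        double d0 = (hb[2] * u[0] - hb[1] * u[1]) / det;
        double d1 = (hb[0] * u[1] - hb[1] * u[0]) / det;
        t0 += d0;
        t1 += d1;
        ++it;
        if (std::sqrt(d0 * d0 + d1 * d1) < tol) break;
    }
    Fit fit = {t0, t1, it};
    return fit;
}

// Euclidean norm of the gradient U at (t0, t1); close to 0 at the maximum.
double gradNorm(const std::vector<double>& x, const std::vector<double>& y,
                double t0, double t1) {
    double u[2], hb[3];
    assemble(x, y, t0, t1, u, hb);
    return std::sqrt(u[0] * u[0] + u[1] * u[1]);
}
```

With non-separable data such as x = {0, 1, 2, 3, 4, 5}, y = {0, 0, 1, 0, 1, 1}, two different starting points should reach the same $\theta$, and the gradient there should be essentially zero.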

 

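The key question is now how to compute $\bar{H}^{-1}U$ quickly and reliably, that is, how to solve the linear system $\bar{H}\,\Delta\theta=U$ without ever forming the inverse. The usual approach is plain Gaussian elimination, but for this matrix it has two drawbacks: (1) efficiency: it ignores the symmetry of $\bar{H}$ and needs about $\tfrac{2}{3}n^3$ floating-point operations, roughly twice as many as a Cholesky factorization ($\approx\tfrac{1}{3}n^3$); (2) numerical stability: general-purpose Gaussian elimination relies on pivoting to stay stable, whereas Cholesky is stable on symmetric positive definite matrices without any pivoting.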

 

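Because $\bar{H}$ is symmetric positive definite, the system can be solved with a Cholesky factorization. The code below uses its square-root-free variant, $\bar{H}=LDL^{T}$, where $L$ is unit lower triangular and $D=\mathrm{diag}(d_0,\dots,d_{n-1})$ has positive entries. Writing $\bar{h}_{jk}$ for the entries of $\bar{H}$, the factors are computed column by column for $k=0,1,\dots,n-1$:

$$d_k=\bar{h}_{kk}-\sum_{i=0}^{k-1}l_{ki}^{2}\,d_i,\qquad l_{jk}=\frac{1}{d_k}\Big(\bar{h}_{jk}-\sum_{i=0}^{k-1}l_{ji}\,d_i\,l_{ki}\Big),\quad j=k+1,\dots,n-1.$$

The system $\bar{H}\,\Delta\theta=U$ is then solved in two triangular steps: forward substitution for $Ly=U$, followed by back substitution for $DL^{T}\Delta\theta=y$.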
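These formulas translate directly into code. This standalone sketch (hypothetical names ldlt and ldltSolve, dense std::vector storage) does the same job as the article's Cholesky and Solve functions, which use the Matrix class. The difference is that here L and D are kept in separate containers and the input matrix is left unchanged.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

typedef std::vector<std::vector<double> > Mat;
typedef std::vector<double> Vec;

// Illustrative sketch: square-root-free Cholesky H = L D L^T, with L unit lower
// triangular and D = diag(d). Only the lower triangle of H is read.
void ldlt(const Mat& H, Mat& L, Vec& d) {
    std::size_t n = H.size();
    L.assign(n, Vec(n, 0.0));
    d.assign(n, 0.0);
    for (std::size_t k = 0; k < n; ++k) {
        double dk = H[k][k];                              // d_k = h_kk - sum l_ki^2 d_i
        for (std::size_t i = 0; i < k; ++i) dk -= L[k][i] * L[k][i] * d[i];
        assert(dk > 0.0);  // fails if H is not positive definite
        d[k] = dk;
        L[k][k] = 1.0;
        for (std::size_t j = k + 1; j < n; ++j) {         // l_jk for rows below k
            double s = H[j][k];
            for (std::size_t i = 0; i < k; ++i) s -= L[j][i] * d[i] * L[k][i];
            L[j][k] = s / dk;
        }
    }
}

// Solve H x = b: forward substitution L y = b, then back substitution D L^T x = y.
Vec ldltSolve(const Mat& L, const Vec& d, const Vec& b) {
    std::size_t n = b.size();
    Vec x(b);
    for (std::size_t k = 0; k < n; ++k)                   // L y = b (unit diagonal)
        for (std::size_t i = 0; i < k; ++i) x[k] -= L[k][i] * x[i];
    for (std::size_t k = n; k-- > 0;) {                   // L^T x = D^{-1} y
        x[k] /= d[k];
        for (std::size_t i = k + 1; i < n; ++i) x[k] -= L[i][k] * x[i];
    }
    return x;
}
```

For instance, $\begin{bmatrix}4&2&2\\2&5&3\\2&3&6\end{bmatrix}$ factors as $D=\mathrm{diag}(4,4,4)$ with $l_{10}=l_{20}=l_{21}=0.5$.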

 


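That covers the essentials of solving logistic regression with Newton's method. Now let's implement it in C++. In the code, the parameter vector $\theta$ is named w, the matrix $\bar{H}$ is H and the gradient $U$ is U.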

 

Code:

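#include <cmath>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#include "matrix.h"

using namespace std;

typedef double Type;                           // scalar type used throughout
template <class T> using Vector = vector<T>;   // Vector<T> is std::vector<T>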
 
/** One sample: feature vector x (x[0] = 1 is the intercept term) and label y (0 or 1) */
struct Data
{
    Vector<Type> x;
    Type y;
};
 
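/** Read samples from the file at path and append them to data */
void PreProcessData(Vector<Data>& data, const string& path)
{
    ifstream file(path.c_str());
    if(!file.is_open())
    {
        cerr<<"Cannot open "<<path<<endl;
        return;
    }
    char s[1024];
    while(file.getline(s, 1024))
    {
        Data tmp;
        Type x1, x2, x3, x4, x5, x6, x7;
        // Each line must hold 7 numbers; skip blank or malformed lines.
        if(sscanf(s, "%lf %lf %lf %lf %lf %lf %lf", &x1, &x2, &x3, &x4, &x5, &x6, &x7) != 7)
            continue;
        // x[0] is the constant 1 (intercept). The first column x1 is read but
        // not used as a feature; the last column x7 is the label.
        tmp.x.push_back(1);
        tmp.x.push_back(x2);
        tmp.x.push_back(x3);
        tmp.x.push_back(x4);
        tmp.x.push_back(x5);
        tmp.x.push_back(x6);
        tmp.y = x7;
        data.push_back(tmp);
    }
}

/** Load the training set and initialize all weights to 0 */
void Init(Vector<Data> &data, Vector<Type> &w)
{
    w.clear();
    data.clear();
    PreProcessData(data, "TrainData.txt");
    assert(!data.empty());           // training cannot proceed without data
    w.assign(data[0].x.size(), 0);
}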
 
Type WX(const Vector<Type>& w, const Data& data)
{
    Type ans = 0;
    for(int i = 0; i < w.size(); i++)
        ans += w[i] * data.x[i];
    return ans;
}
 
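/** Logistic function h = 1 / (1 + e^{-w.x}); unlike e^x / (1 + e^x), this form cannot produce inf / inf = NaN for large w.x */
Type Sigmoid(const Vector<Type>& w, const Data& data)
{
    Type x = WX(w, data);
    return 1.0 / (1.0 + exp(-x));
}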
 
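/**
 * Compute U = X^T (y - h), the gradient, and H = X^T Abar X with
 * Abar = diag(h_i (1 - h_i)), i.e. minus the Hessian (symmetric positive
 * definite). Since Abar is diagonal, both are accumulated sample by sample
 * instead of building the m x m matrix Abar explicitly.
 * H must already be n x n and U n x 1 (as created in NewtonIter).
 */
void PreMatrix(Matrix<Type> &H, Matrix<Type> &U, const Vector<Data> &data, Vector<Type> &w)
{
    int n = data[0].x.size();   // number of parameters
    int m = data.size();        // number of samples
    for(int j = 0; j < n; j++)
    {
        U.put(j, 0, 0);
        for(int k = 0; k < n; k++)
            H.put(j, k, 0);
    }
    for(int i = 0; i < m; i++)
    {
        Type t = Sigmoid(w, data[i]);   // h_i
        Type a = t * (1 - t);           // i-th diagonal entry of Abar
        Type r = data[i].y - t;         // residual y_i - h_i
        for(int j = 0; j < n; j++)
        {
            U.put(j, 0, U.get(j, 0) + r * data[i].x[j]);
            for(int k = 0; k < n; k++)
                H.put(j, k, H.get(j, k) + a * data[i].x[j] * data[i].x[k]);
        }
    }
}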
 
Vector<Type> Matrix2Vector(Matrix<Type> &M)
{
    Vector<Type> X;
    X.clear();
    int ROWS = M.getRows();
    for(int i = 0; i < ROWS; i++)
        X.push_back(M.get(i, 0));
    return X;
}
 
Matrix<Type> Vector2Matrix(Vector<Type> &X)
{
    int ROWS = X.size();
    Matrix<Type> matrix(ROWS, 1);
    for(int i = 0; i < ROWS; i++)
        matrix.put(i, 0, X[i]);
    return matrix;
}
 
/** Square-root-free Cholesky factorization H = L D L^T (L unit lower triangular, D diagonal); H is overwritten in place */
void Cholesky(Matrix<Type> &H, Matrix<Type> &L, Matrix<Type> &D)
{
    Type t = 0;
    int n = H.getRows();
    for(int k = 0; k < n; k++)
    {
        for(int i = 0; i < k; i++)
        {
            t = H.get(i, i) * H.get(k, i) * H.get(k, i);
            H.put(k, k, H.get(k, k) - t);
        }
        for(int j = k + 1; j < n; j++)
        {
            for(int i = 0; i < k; i++)
            {
                t = H.get(j, i) * H.get(i, i) * H.get(k, i);
                H.put(j, k, H.get(j, k) - t);
            }
            t = H.get(j, k) / H.get(k, k);
            H.put(j, k, t);
        }
    }
    for(int i = 0; i < n; i++)
    {
        D.put(i, i, H.get(i, i));
        L.put(i, i, 1);
        for(int j = 0; j < i; j++)
            L.put(i, j, H.get(i, j));
    }
}
 
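/** Solve H x = b: X holds b on entry and the solution on exit (forward substitution with L, then back substitution with D L^T) */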
void Solve(Matrix<Type> &H, Vector<Type> &X)
{
    int ROWS = H.getRows();
    int COLS = H.getColumns();
    Matrix<Type> L(ROWS, COLS), D(ROWS, COLS);
    Cholesky(H, L, D);
 
    int n = ROWS;
    for(int k = 0; k < n; k++)
    {
        for(int i = 0; i < k; i++)
            X[k] -= X[i] * L.get(k, i);
        X[k] /= L.get(k, k);
    }
    L = D * L.getTranspose();
    for(int k = n - 1; k >= 0; k--)
    {
        for(int i = k + 1; i < n; i++)
            X[k] -= X[i] * L.get(k, i);
        X[k] /= L.get(k, k);
    }
}
 
/** Print the progress of one iteration */
void Display(int cnt, Type error, const Vector<Type>& w)
{
    cout<<"Iteration "<<cnt<<": change in w (0.5 * squared distance) = "<<error<<endl;
    cout<<"Parameters w: ";
    for(size_t i = 0; i < w.size(); i++)
        cout<<w[i]<<" ";
    cout<<endl;
    cout<<endl;
}
 
Type StopFlag(Vector<Type> w1, Vector<Type> w2)
{
    Type ans = 0;
    int size = w1.size();
    for(int i = 0; i < size; i++)
        ans += 0.5 * (w1[i] - w2[i]) * (w1[i] - w2[i]);
    return ans;
}
 
/** Newton iteration: solve H * step = U, then w += step */
void NewtonIter(Vector<Data> &data, Vector<Type> &w)
{
    int cnt = 0;
    const Type delta = 0.0001;     // stop once 0.5 * ||w_new - w_old||^2 < delta
    const int maxIter = 100;       // safety cap (value chosen for illustration) in case Newton does not converge
    int ROWS = data[0].x.size();   // number of parameters n

    while(cnt < maxIter)
    {
        Matrix<Type> H(ROWS, ROWS), U(ROWS, 1);
        PreMatrix(H, U, data, w);
        Vector<Type> X = Matrix2Vector(U);   // right-hand side U
        Solve(H, X);                         // X now holds the Newton step
        Vector<Type> _w = w;
        for(int i = 0; i < ROWS; i++)
            _w[i] += X[i];
        Type error = StopFlag(_w, w);
        w = _w;
        cnt++;
        Display(cnt, error, w);
        if(error < delta) break;
    }
}
 
/** Train on the data to obtain w, which defines the classifier */
void TrainData(Vector<Data> &data, Vector<Type> &w)
{
    Init(data, w);
    NewtonIter(data, w);
}
 
/** Classify the test data with the trained weights w */
void Separator(const Vector<Type>& w)
{
    Vector<Data> data;
    PreProcessData(data, "TestData.txt");
    cout<<"Predicted classes:"<<endl;
    for(size_t i = 0; i < data.size(); i++)
    {
        Type p1 = Sigmoid(w, data[i]);   // P(y = 1 | x)
        Type p0 = 1 - p1;                // P(y = 0 | x)
        cout<<"Sample: ";
        for(size_t j = 0; j < data[i].x.size(); j++)
            cout<<data[i].x[j]<<" ";
        cout<<"class: ";
        if(p1 >= p0) cout<<1<<endl;
        else cout<<0<<endl;
    }
}
 
int main()
{
    Vector<Type> w;
    Vector<Data> data;
    TrainData(data, w);
    Separator(w);
    return 0;
}

The Newton's-method code above relies on matrix operations provided by the header matrix.h, a good third-party matrix library. Its code follows.

Code: matrix.h


/*****************************************************************************/
/* Name: matrix.h                                                            */
/* Uses: Class for matrix math functions.                                    */
/* Date: 4/19/2011                                                           */
/* Author: Andrew Que <http://www.DrQue.net/>                                */
/* Revisions:                                                                */
/*   0.1 - 2011/04/19 - QUE - Creation.                                      */
/*   0.5 - 2011/04/24 - QUE - Most functions are complete.                   */
/*   0.8 - 2011/05/01 - QUE -                                                */
/*     = Bug fixes.                                                          */
/*     + Dot product.                                                        */
/*   1.0 - 2011/11/26 - QUE - Release.                                       */
/*                                                                           */
/* Notes:                                                                    */
/*   This unit implements some very basic matrix functions, which include:   */
/*    + Addition/subtraction                                                 */
/*    + Transpose                                                            */
/*    + Row echelon reduction                                                */
/*    + Determinant                                                          */
/*    + Dot product                                                          */
/*    + Matrix product                                                       */
/*    + Scalar product                                                       */
/*    + Inversion                                                            */
/*    + LU factorization/decomposition                                       */
/*     There isn't much for optimization in this unit as it was designed as  */
/*   more of a learning experience.                                          */
/*                                                                           */
/* License:                                                                  */
/*   This program is free software: you can redistribute it and/or modify    */
/*   it under the terms of the GNU General Public License as published by    */
/*   the Free Software Foundation, either version 3 of the License, or       */
/*   (at your option) any later version.                                     */
/*                                                                           */
/*   This program is distributed in the hope that it will be useful,         */
/*   but WITHOUT ANY WARRANTY; without even the implied warranty of          */
/*   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the           */
/*   GNU General Public License for more details.                            */
/*                                                                           */
/*   You should have received a copy of the GNU General Public License       */
/*   along with this program.  If not, see <http://www.gnu.org/licenses/>.   */
/*                                                                           */
/*                     (C) Copyright 2011 by Andrew Que                      */
/*                           http://www.DrQue.net/                           */
/*****************************************************************************/
#ifndef _MATRIX_H_
#define _MATRIX_H_
 
#include <iostream>
#include <cassert>
#include <climits>
#include <vector>
 
// Class forward for identity matrix.
template< class TYPE > class IdentityMatrix;
 
//=============================================================================
// Matrix template class
//   Contains a set of matrix manipulation functions.  The template is designed
// so that the values of the matrix can be of any type that allows basic
// arithmetic.
//=============================================================================
template< class TYPE = int >
  class Matrix
  {
    protected:
      // Matrix data.
      unsigned rows;
      unsigned columns;
 
      // Storage for matrix data.
      std::vector< std::vector< TYPE > > matrix;
 
      // Order sub-index for rows.
      //   Use: matrix[ order[ row ] ][ column ].
      unsigned * order;
 
      //-------------------------------------------------------------
      // Return the number of leading zeros in the given row.
      //-------------------------------------------------------------
      unsigned getLeadingZeros
      (
        // Row to count
        unsigned row
      ) const
      {
        TYPE const ZERO = static_cast< TYPE >( 0 );
        unsigned column = 0;
        while ( ZERO == matrix[ row ][ column ] )
          ++column;
 
        return column;
      }
 
      //-------------------------------------------------------------
      // Reorder the matrix so the rows with the most zeros are at
      // the end, and those with the least at the beginning.
      //
      // NOTE: The matrix data itself is not manipulated, just the
      // 'order' sub-indexes.
      //-------------------------------------------------------------
      void reorder()
      {
        unsigned * zeros = new unsigned[ rows ];
 
        for ( unsigned row = 0; row < rows; ++row )
        {
          order[ row ] = row;
          zeros[ row ] = getLeadingZeros( row );
        }
 
        for ( unsigned row = 0; row < (rows-1); ++row )
        {
          unsigned swapRow = row;
          for ( unsigned subRow = row + 1; subRow < rows; ++subRow )
          {
            if ( zeros[ order[ subRow ] ] < zeros[ order[ swapRow ] ] )
              swapRow = subRow;
          }
 
          unsigned hold    = order[ row ];
          order[ row ]     = order[ swapRow ];
          order[ swapRow ] = hold;
        }
 
        delete[] zeros;
      }
 
      //-------------------------------------------------------------
      // Divide a row by given value.  An elementary row operation.
      //-------------------------------------------------------------
      void divideRow
      (
        // Row to divide.
        unsigned row,
 
        // Divisor.
        TYPE const & divisor
      )
      {
        for ( unsigned column = 0; column < columns; ++column )
          matrix[ row ][ column ] /= divisor;
      }
 
      //-------------------------------------------------------------
      // Modify a row by adding a scaled row. An elementary row
      // operation.
      //-------------------------------------------------------------
      void rowOperation
      (
        unsigned row,
        unsigned addRow,
        TYPE const & scale
      )
      {
        for ( unsigned column = 0; column < columns; ++column )
          matrix[ row ][ column ] += matrix[ addRow ][ column ] * scale;
      }
 
      //-------------------------------------------------------------
      // Allocate memory for matrix data.
      //-------------------------------------------------------------
      void allocate
      (
        unsigned rowNumber,
        unsigned columnNumber
      )
      {
        // Allocate order integers.
        order = new unsigned[ rowNumber ];
 
        // Setup matrix sizes.
        matrix.resize( rowNumber );
        for ( unsigned row = 0; row < rowNumber; ++row )
          matrix[ row ].resize( columnNumber );
      }
 
      //-------------------------------------------------------------
      // Free memory used for matrix data.
      //-------------------------------------------------------------
      void deallocate
      (
        unsigned rowNumber,
        unsigned columnNumber
      )
      {
        // Free memory used for storing order (if there is any).
        if ( 0 != rowNumber )
          delete[] order;
      }
 
    public:
      // Used for matrix concatenation.
      typedef enum
      {
        TO_RIGHT,
        TO_BOTTOM
      } Position;
 
      //-------------------------------------------------------------
      // Return the number of rows in this matrix.
      //-------------------------------------------------------------
      unsigned getRows() const
      {
        return rows;
      }
 
      //-------------------------------------------------------------
      // Return the number of columns in this matrix.
      //-------------------------------------------------------------
      unsigned getColumns() const
      {
        return columns;
      }
 
      //-------------------------------------------------------------
      // Get an element of the matrix.
      //-------------------------------------------------------------
      TYPE get
      (
        unsigned row,   // Which row.
        unsigned column // Which column.
      ) const
      {
        assert( row < rows );
        assert( column < columns );
 
        return matrix[ row ][ column ];
      }
 
      //-------------------------------------------------------------
      // Perform LU decomposition.
      // This will create matrices L and U such that A=LxU
      //-------------------------------------------------------------
      void LU_Decomposition
      (
        Matrix & upper,
        Matrix & lower
      ) const
      {
        assert( rows == columns );
 
        TYPE const ZERO = static_cast< TYPE >( 0 );
 
        upper = *this;
        lower = *this;
 
        for ( unsigned row = 0; row < rows; ++row )
          for ( unsigned column = 0; column < columns; ++column )
            lower.matrix[ row ][ column ] = ZERO;
 
        for ( unsigned row = 0; row < rows; ++row )
        {
          TYPE value = upper.matrix[ row ][ row ];
          if ( ZERO != value )
          {
            upper.divideRow( row, value );
            lower.matrix[ row ][ row ] = value;
          }
 
          for ( unsigned subRow = row + 1; subRow < rows; ++subRow )
          {
            TYPE value = upper.matrix[ subRow ][ row ];
            upper.rowOperation( subRow, row, -value );
            lower.matrix[ subRow ][ row ] = value;
          }
        }
      }
 
      //-------------------------------------------------------------
      // Set an element in the matrix.
      //-------------------------------------------------------------
      void put
      (
        unsigned row,
        unsigned column,
        TYPE const & value
      )
      {
        assert( row < rows );
        assert( column < columns );
 
        matrix[ row ][ column ] = value;
      }
 
      //-------------------------------------------------------------
      // Return part of the matrix.
      // NOTE: The end points are the last elements copied.  They can
      // be equal to the first element when wanting just a single row
      // or column.  However, the span of the total matrix is
      // ( 0, rows - 1, 0, columns - 1 ).
      //-------------------------------------------------------------
      Matrix getSubMatrix
      (
        unsigned startRow,
        unsigned endRow,
        unsigned startColumn,
        unsigned endColumn,
        unsigned const * newOrder = NULL
      )
      {
        Matrix subMatrix( endRow - startRow + 1, endColumn - startColumn + 1 );
 
        for ( unsigned row = startRow; row <= endRow; ++row )
        {
          unsigned subRow;
          if ( NULL == newOrder )
            subRow = row;
          else
            subRow = newOrder[ row ];
 
          for ( unsigned column = startColumn; column <= endColumn; ++column )
            subMatrix.matrix[ row - startRow ][ column - startColumn ] =
              matrix[ subRow ][ column ];
        }
 
        return subMatrix;
      }
 
      //-------------------------------------------------------------
      // Return a single column from the matrix.
      //-------------------------------------------------------------
      Matrix getColumn
      (
        unsigned column
      )
      {
        return getSubMatrix( 0, rows - 1, column, column );
      }
 
      //-------------------------------------------------------------
      // Return a single row from the matrix.
      //-------------------------------------------------------------
      Matrix getRow
      (
        unsigned row
      )
      {
        return getSubMatrix( row, row, 0, columns - 1 );
      }
 
      //-------------------------------------------------------------
      // Place matrix in reduced row echelon form.
      //-------------------------------------------------------------
      void reducedRowEcholon()
      {
        TYPE const ZERO = static_cast< TYPE >( 0 );
 
        // For each row...
        for ( unsigned rowIndex = 0; rowIndex < rows; ++rowIndex )
        {
          // Reorder the rows.
          reorder();
 
          unsigned row = order[ rowIndex ];
 
          // Divide row down so first term is 1.
          unsigned column = getLeadingZeros( row );
          TYPE divisor = matrix[ row ][ column ];
          if ( ZERO != divisor )
          {
            divideRow( row, divisor );
 
            // Subtract this row from all subsequent rows.
            for ( unsigned subRowIndex = ( rowIndex + 1 ); subRowIndex < rows; ++subRowIndex )
            {
              unsigned subRow = order[ subRowIndex ];
              if ( ZERO != matrix[ subRow ][ column ] )
                rowOperation
                (
                  subRow,
                  row,
                  -matrix[ subRow ][ column ]
                );
            }
          }
 
        }
 
        // Back substitute all lower rows.
        for ( unsigned rowIndex = ( rows - 1 ); rowIndex > 0; --rowIndex )
        {
          unsigned row = order[ rowIndex ];
          unsigned column = getLeadingZeros( row );
          for ( unsigned subRowIndex = 0; subRowIndex < rowIndex; ++subRowIndex )
          {
            unsigned subRow = order[ subRowIndex ];
            rowOperation
            (
              subRow,
              row,
              -matrix[ subRow ][ column ]
            );
          }
        }
 
      } // reducedRowEcholon
 
      //-------------------------------------------------------------
      // Return the determinant of the matrix.
      // Recursive function.
      //-------------------------------------------------------------
      TYPE determinant() const
      {
        TYPE result = static_cast< TYPE >( 0 );
 
        // Must have a square matrix to even bother.
        assert( rows == columns );
 
        if ( rows > 2 )
        {
          int sign = 1;
          for ( unsigned column = 0; column < columns; ++column )
          {
            TYPE subDeterminant;
 
            Matrix subMatrix = Matrix( *this, 0, column );
 
            subDeterminant  = subMatrix.determinant();
            subDeterminant *= matrix[ 0 ][ column ];
 
            if ( sign > 0 )
              result += subDeterminant;
            else
              result -= subDeterminant;
 
            sign = -sign;
          }
        }
        else
        {
          result = ( matrix[ 0 ][ 0 ] * matrix[ 1 ][ 1 ] )
                 - ( matrix[ 0 ][ 1 ] * matrix[ 1 ][ 0 ] );
        }
 
        return result;
 
      } // determinant
 
      //-------------------------------------------------------------
      // Calculate a dot product between this and another matrix.
      //-------------------------------------------------------------
      TYPE dotProduct
      (
        Matrix const & otherMatrix
      ) const
      {
        // Dimensions of each matrix must be the same.
        assert( rows == otherMatrix.rows );
        assert( columns == otherMatrix.columns );
 
        TYPE result = static_cast< TYPE >( 0 );
        for ( unsigned row = 0; row < rows; ++row )
          for ( unsigned column = 0; column < columns; ++column )
          {
            result +=
              matrix[ row ][ column ]
              * otherMatrix.matrix[ row ][ column ];
          }
 
        return result;
 
      } // dotProduct
 
      //-------------------------------------------------------------
      // Return the transpose of the matrix.
      //-------------------------------------------------------------
      Matrix const getTranspose() const
      {
        Matrix result( columns, rows );
 
        // Transpose the matrix by filling the result's rows with
        // these columns, and vice versa.
        for ( unsigned row = 0; row < rows; ++row )
          for ( unsigned column = 0; column < columns; ++column )
            result.matrix[ column ][ row ] = matrix[ row ][ column ];
 
        return result;
 
      } // transpose
 
      //-------------------------------------------------------------
      // Transpose the matrix.
      //-------------------------------------------------------------
      void transpose()
      {
        *this = getTranspose();
      }
 
      //-------------------------------------------------------------
      // Return inverse matrix.
      //-------------------------------------------------------------
      Matrix const getInverse() const
      {
        // Concatenate the identity matrix onto this matrix.
        Matrix inverseMatrix
          (
            *this,
            IdentityMatrix< TYPE >( rows, columns ),
            TO_RIGHT
          );
 
        // Row reduce this matrix.  This will result in the identity
        // matrix on the left, and the inverse matrix on the right.
        inverseMatrix.reducedRowEcholon();
 
        // Copy the inverse matrix data back to this matrix.
        Matrix result
        (
          inverseMatrix.getSubMatrix
          (
            0,
            rows - 1,
            columns,
            columns + columns - 1,
            inverseMatrix.order
          )
        );
 
        return result;
 
      } // invert
 
 
      //-------------------------------------------------------------
      // Invert this matrix.
      //-------------------------------------------------------------
      void invert()
      {
        *this = getInverse();
 
      } // invert
 
      //=======================================================================
      // Operators.
      //=======================================================================
 
      //-------------------------------------------------------------
      // Add another matrix.
      //-------------------------------------------------------------
      Matrix const operator +
      (
        Matrix const & otherMatrix
      ) const
      {
        assert( otherMatrix.rows == rows );
        assert( otherMatrix.columns == columns );
 
        Matrix result( rows, columns );
 
        for ( unsigned row = 0; row < rows; ++row )
          for ( unsigned column = 0; column < columns; ++column )
            result.matrix[ row ][ column ] =
              matrix[ row ][ column ]
              + otherMatrix.matrix[ row ][ column ];
 
        return result;
      }
 
      //-------------------------------------------------------------
      // Add another matrix to self.
      //-------------------------------------------------------------
      Matrix const & operator +=
      (
        Matrix const & otherMatrix
      )
      {
        *this = *this + otherMatrix;
        return *this;
      }
 
      //-------------------------------------------------------------
      // Subtract another matrix.
      //-------------------------------------------------------------
      Matrix const operator -
      (
        Matrix const & otherMatrix
      ) const
      {
        assert( otherMatrix.rows == rows );
        assert( otherMatrix.columns == columns );
 
        Matrix result( rows, columns );
 
        for ( unsigned row = 0; row < rows; ++row )
          for ( unsigned column = 0; column < columns; ++column )
            result.matrix[ row ][ column ] =
              matrix[ row ][ column ]
              - otherMatrix.matrix[ row ][ column ];
 
        return result;
      }
 
      //-------------------------------------------------------------
      // Subtract another matrix from self.
      //-------------------------------------------------------------
      Matrix const & operator -=
      (
        Matrix const & otherMatrix
      )
      {
        *this = *this - otherMatrix;
        return *this;
      }
 
      //-------------------------------------------------------------
      // Matrix multiplication.
      //-------------------------------------------------------------
      Matrix const operator *
      (
        Matrix const & otherMatrix
      ) const
      {
        TYPE const ZERO = static_cast< TYPE >( 0 );
 
        assert( otherMatrix.rows == columns );
 
        Matrix result( rows, otherMatrix.columns );
 
        for ( unsigned row = 0; row < rows; ++row )
          for ( unsigned column = 0; column < otherMatrix.columns; ++column )
          {
            result.matrix[ row ][ column ] = ZERO;
 
            for ( unsigned index = 0; index < columns; ++index )
              result.matrix[ row ][ column ] +=
                matrix[ row ][ index ]
                * otherMatrix.matrix[ index ][ column ];
          }
 
        return result;
      }
 
      //-------------------------------------------------------------
      // Multiply self by matrix.
      //-------------------------------------------------------------
      Matrix const & operator *=
      (
        Matrix const & otherMatrix
      )
      {
        *this = *this * otherMatrix;
        return *this;
      }
 
      //-------------------------------------------------------------
      // Multiply by scalar constant.
      //-------------------------------------------------------------
      Matrix const operator *
      (
        TYPE const & scalar
      ) const
      {
        Matrix result( rows, columns );
 
        for ( unsigned row = 0; row < rows; ++row )
          for ( unsigned column = 0; column < columns; ++column )
            result.matrix[ row ][ column ] = matrix[ row ][ column ] * scalar;
 
        return result;
      }
 
      //-------------------------------------------------------------
      // Multiply self by scalar constant.
      //-------------------------------------------------------------
      Matrix const & operator *=
      (
        TYPE const & scalar
      )
      {
        *this = *this * scalar;
        return *this;
      }
 
      //-------------------------------------------------------------
      // Copy matrix.
      //-------------------------------------------------------------
      Matrix & operator =
      (
        Matrix const & otherMatrix
      )
      {
        if ( this == &otherMatrix )
          return *this;
 
        // Release memory currently in use.
        deallocate( rows, columns );
 
        rows    = otherMatrix.rows;
        columns = otherMatrix.columns;
        allocate( rows, columns );
 
        for ( unsigned row = 0; row < rows; ++row )
          for ( unsigned column = 0; column < columns; ++column )
            matrix[ row ][ column ] =
            otherMatrix.matrix[ row ][ column ];
 
        return *this;
      }
 
      //-------------------------------------------------------------
      // Copy matrix data from array.
      // Although matrix data is two dimensional, this copy function
      // assumes the previous row is immediately followed by the next
      // row's data.
      //
      // Example for 3x2 matrix:
      //     int const data[ 3 * 2 ] =
      //     {
      //       1, 2,
      //       3, 4,
      //       5, 6
      //     };
      //    Matrix< int > matrix( 3, 2 );
      //    matrix = data;
      //-------------------------------------------------------------
      Matrix & operator =
      (
        TYPE const * data
      )
      {
        unsigned index = 0;
 
        for ( unsigned row = 0; row < rows; ++row )
          for ( unsigned column = 0; column < columns; ++column )
            matrix[ row ][ column ] = data[ index++ ];
 
        return *this;
      }
 
      //-----------------------------------------------------------------------
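      // Return true if this matrix has the same dimensions and the same
      // elements as the parameter.
      //-----------------------------------------------------------------------
      bool operator ==
      (
        Matrix const & value
      ) const
      {
        // Matrices of different sizes are never equal; checking this first
        // also keeps the element loop from reading past the end of 'value'.
        if ( ( rows != value.rows ) || ( columns != value.columns ) )
          return false;

        for ( unsigned row = 0; row < rows; ++row )
          for ( unsigned column = 0; column < columns; ++column )
            if ( matrix[ row ][ column ] != value.matrix[ row ][ column ] )
              return false;

        return true;
      }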
 
      //-----------------------------------------------------------------------
      // Return true if this matrix is NOT the same as the parameter.
      //-----------------------------------------------------------------------
      bool operator !=
      (
        Matrix const & value
      ) const
      {
        return !( *this == value );
      }
 
      //-------------------------------------------------------------
      // Constructor for empty matrix.
      // Only useful if matrix is being assigned (i.e. copied) from
      // somewhere else sometime after construction.
      //-------------------------------------------------------------
      Matrix()
      :
        rows( 0 ),
        columns( 0 )
      {
        allocate( 0, 0 );
      }
 
      //-------------------------------------------------------------
      // Constructor using rows and columns.
      //-------------------------------------------------------------
      Matrix
      (
        unsigned rowsParameter,
        unsigned columnsParameter
      )
      :
        rows( rowsParameter ),
        columns( columnsParameter )
      {
        TYPE const ZERO = static_cast< TYPE >( 0 );
 
        // Allocate memory for new matrix.
        allocate( rows, columns );
 
        // Fill matrix with zero.
        for ( unsigned row = 0; row < rows; ++row )
        {
          order[ row ] = row;
 
          for ( unsigned column = 0; column < columns; ++column )
            matrix[ row ][ column ] = ZERO;
        }
      }
 
      //-------------------------------------------------------------
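      // This constructor creates a matrix based on another matrix.  It
      // can copy the matrix entirely, or omit one row and/or one column
      // (e.g. to build the submatrix for a minor).  Leaving
      // omittedRow/omittedColumn at INT_MAX means nothing is omitted.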
      //-------------------------------------------------------------
      Matrix
      (
        Matrix const & copyMatrix,
        unsigned omittedRow    = INT_MAX,
        unsigned omittedColumn = INT_MAX
      )
      {
        // Start with the number of rows/columns from matrix to be copied.
        rows    = copyMatrix.getRows();
        columns = copyMatrix.getColumns();
 
        // If a row is omitted, then there is one less row.
        if ( INT_MAX != omittedRow  )
          rows--;
 
        // If a column is omitted, then there is one less column.
        if ( INT_MAX != omittedColumn )
          columns--;
 
        // Allocate memory for new matrix.
        allocate( rows, columns );
 
        unsigned rowIndex = 0;
        for ( unsigned row = 0; row < rows; ++row )
        {
          // If this row is to be skipped...
          if ( rowIndex == omittedRow )
            rowIndex++;
 
          // Set default order.
          order[ row ] = row;
 
          unsigned columnIndex = 0;
          for ( unsigned column = 0; column < columns; ++column )
          {
            // If this column is to be skipped...
            if ( columnIndex == omittedColumn )
              columnIndex++;
 
            matrix[ row ][ column ] = copyMatrix.matrix[ rowIndex ][ columnIndex ];
 
            columnIndex++;
          }
 
          ++rowIndex;
        }
 
      }
 
      //-------------------------------------------------------------
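      // Constructor to concatenate two matrices.  Matrix B is placed
      // either to the right of matrix A or below it:
      //   TO_RIGHT:  result = [ A | B ]  (A and B need the same row count)
      //   otherwise: result = [ A ; B ]  (A and B need the same column count)
      //-------------------------------------------------------------
      Matrix
      (
        Matrix const & copyMatrixA,
        Matrix const & copyMatrixB,
        Position position = TO_RIGHT
      )
      {
        unsigned rowOffset    = 0;
        unsigned columnOffset = 0;

        if ( TO_RIGHT == position )
        {
          // [ A | B ]: B's columns follow A's columns.
          columnOffset = copyMatrixA.columns;
          rows    = copyMatrixA.rows;
          columns = copyMatrixA.columns + copyMatrixB.columns;
        }
        else
        {
          // [ A ; B ]: B's rows follow A's rows.
          rowOffset = copyMatrixA.rows;
          rows    = copyMatrixA.rows + copyMatrixB.rows;
          columns = copyMatrixA.columns;
        }

        // Allocate memory for new matrix.
        allocate( rows, columns );

        // Set default order.
        for ( unsigned row = 0; row < rows; ++row )
          order[ row ] = row;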
 
        for ( unsigned row = 0; row < copyMatrixA.rows; ++row )
          for ( unsigned column = 0; column < copyMatrixA.columns; ++column )
            matrix[ row ][ column ] = copyMatrixA.matrix[ row ][ column ];
 
        for ( unsigned row = 0; row < copyMatrixB.rows; ++row )
          for ( unsigned column = 0; column < copyMatrixB.columns; ++column )
            matrix[ row + rowOffset ][ column + columnOffset ] =
              copyMatrixB.matrix[ row ][ column ];
      }
 
      //-------------------------------------------------------------
      // Destructor.
      //-------------------------------------------------------------
      ~Matrix()
      {
        // Release memory.
        deallocate( rows, columns );
      }
 
  };
 
//=============================================================================
// Class for identity matrix.
//=============================================================================
template< class TYPE >
  class IdentityMatrix : public Matrix< TYPE >
  {
    public:
      IdentityMatrix
      (
        unsigned rowsParameter,
        unsigned columnsParameter
      )
      :
        Matrix< TYPE >( rowsParameter, columnsParameter )
      {
        TYPE const ZERO = static_cast< TYPE >( 0 );
        TYPE const ONE  = static_cast< TYPE >( 1 );
 
        for ( unsigned row = 0; row < Matrix< TYPE >::rows; ++row )
        {
          for ( unsigned column = 0; column < Matrix< TYPE >::columns; ++column )
            if ( row == column )
              Matrix< TYPE >::matrix[ row ][ column ] = ONE;
            else
              Matrix< TYPE >::matrix[ row ][ column ] = ZERO;
        }
      }
  };
 
//-----------------------------------------------------------------------------
// Stream insertion operator: writes the matrix to a stream, one row per line
// with tab-separated elements.
//-----------------------------------------------------------------------------
template< class TYPE >
  std::ostream & operator<<
  (
    // Stream data to place string.
    std::ostream & stream,
 
    // A matrix.
    Matrix< TYPE > const & matrix
  )
  {
    for ( unsigned row = 0; row < matrix.getRows(); ++row )
    {
      for ( unsigned column = 0; column < matrix.getColumns(); ++column )
        stream << "\t" << matrix.get( row , column );
 
      stream << std::endl;
    }
 
    return stream;
  }
 
#endif // _MATRIX_H_

 
