2 matrix.cc -- implement Matrix
4 source file of the Flower Library
6 (c) 1997 Han-Wen Nienhuys <hanwen@stack.nl>
10 #include "full-storage.hh"
11 #include "diagonal-storage.hh"
16 return dat->is_type_b (Diagonal_storage::static_name());
20 Matrix::set_full() const
22 if ( dat->name() != Full_storage::static_name ()) {
23 Matrix_storage::set_full (((Matrix*)this)->dat);
28 Matrix::try_set_band()const
36 // it only looks constant
37 Matrix*self = (Matrix*)this;
38 Matrix_storage::set_band (self->dat,b);
45 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
46 r += sqr (dat->elem (i,j));
53 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
58 Matrix::set_diag (Real r)
60 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
61 dat->elem (i,j)=(i==j) ? r: 0.0;
65 Matrix::set_diag (Vector d)
67 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
68 dat->elem (i,j)=(i==j) ? d (i): 0.0;
72 Matrix::operator+=(Matrix const &m)
74 Matrix_storage::set_addition_result (dat, m.dat);
75 assert (m.cols() == cols ());
76 assert (m.rows() == rows ());
77 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
78 dat->elem (i,j) += m (i,j);
82 Matrix::operator-=(Matrix const &m)
84 Matrix_storage::set_addition_result (dat, m.dat);
85 assert (m.cols() == cols ());
86 assert (m.rows() == rows ());
87 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
88 dat->elem (i,j) -= m (i,j);
93 Matrix::operator*=(Real a)
95 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
100 Matrix::operator=(Matrix const &m)
105 dat = m.dat->clone();
109 Matrix::band_i()const
112 Diagonal_storage const * diag = (Diagonal_storage*) dat;
113 return diag->band_size_i();
116 while (starty >= 0) {
117 for ( int i = starty, j = 0; i < dim(); i++, j++)
120 for ( int i=0, j = starty; j < dim(); i++,j++)
129 Matrix::Matrix (Matrix const &m)
133 dat = m.dat->clone();
137 Matrix::Matrix (int n, int m)
139 dat = Matrix_storage::get_full (n,m);
143 Matrix::Matrix (Matrix_storage*stor_p)
148 Matrix::Matrix (int n)
150 dat = Matrix_storage::get_full (n,n);
154 Matrix::Matrix (Vector v, Vector w)
156 dat = Matrix_storage::get_full (v.dim(), w.dim ());
157 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
158 dat->elem (i,j)=v (i)*w (j);
163 Matrix::row (int k) const
169 for (int i=0; i < n; i++)
170 v (i)=dat->elem (k,i);
176 Matrix::col (int k) const
180 for (int i=0; i < n; i++)
181 v (i)=dat->elem (i,k);
186 Matrix::left_multiply (Vector const & v) const
188 Vector dest (v.dim());
189 assert (dat->cols()==v.dim ());
190 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
191 dest (i)+= dat->elem (j,i)*v (j);
196 Matrix::operator *(Vector const & v) const
198 Vector dest (rows());
199 assert (dat->cols()==v.dim ());
200 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
201 dest (i)+= dat->elem (i,j)*v (j);
206 operator /(Matrix const& m1,Real a)
214 ugh. Only works for square matrices.
217 Matrix::transpose() // delegate to storage?
220 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j)) {
223 Real r=dat->elem (i,j);
224 dat->elem (i,j) = dat->elem (j,i);
231 Matrix::operator-() const
240 Matrix::transposed() const
248 operator *(Matrix const &m1, Matrix const &m2)
250 Matrix result (Matrix_storage::get_product_result (m1.dat, m2.dat));
252 result.set_product (m1,m2);
257 Matrix::set_product (Matrix const &m1, Matrix const &m2)
259 assert (m1.cols()==m2.rows ());
260 assert (cols()==m2.cols () && rows ()==m1.rows ());
262 if (m1.dat->try_right_multiply (dat, m2.dat))
265 for (int i=0, j=0; dat->mult_ok (i,j);
266 dat->mult_next (i,j)) {
268 for (int k = 0; k < m1.cols(); k++)
269 r += m1(i,k)*m2(k,j);
275 Matrix::insert_row (Vector v, int k)
278 assert (v.dim()==cols ());
280 for (int j=0; j < c; j++)
281 dat->elem (k,j)=v (j);
286 Matrix::swap_columns (int c1, int c2)
288 assert (c1>=0&& c1 < cols()&&c2 < cols () && c2 >=0);
290 for (int i=0; i< r; i++) {
291 Real r=dat->elem (i,c1);
292 dat->elem (i,c1) = dat->elem (i,c2);
298 Matrix::swap_rows (int c1, int c2)
300 assert (c1>=0&& c1 < rows()&&c2 < rows () && c2 >=0);
302 for (int i=0; i< c; i++) {
303 Real r=dat->elem (c1,i);
304 dat->elem (c1,i) = dat->elem (c2,i);
313 assert (cols() == rows ());