2 matrix.cc -- implement Matrix
4 source file of the Flower Library
6 (c) 1997 Han-Wen Nienhuys <hanwen@stack.nl>
10 #include "full-storage.hh"
11 #include "diagonal-storage.hh"
14 Matrix::band_b() const
16 return dat->is_type_b (Diagonal_storage::static_name());
20 Matrix::set_full() const
22 if ( dat->name() != Full_storage::static_name ())
24 Matrix_storage::set_full (((Matrix*)this)->dat);
29 Matrix::try_set_band() const
37 // it only looks constant
38 Matrix*self = (Matrix*)this;
39 Matrix_storage::set_band (self->dat,b);
46 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
47 r += sqr (dat->elem (i,j));
54 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
59 Matrix::set_diag (Real r)
61 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
62 dat->elem (i,j)=(i==j) ? r: 0.0;
66 Matrix::set_diag (Vector d)
68 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
69 dat->elem (i,j)=(i==j) ? d (i): 0.0;
73 Matrix::operator+=(Matrix const &m)
75 Matrix_storage::set_addition_result (dat, m.dat);
76 assert (m.cols() == cols ());
77 assert (m.rows() == rows ());
78 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
79 dat->elem (i,j) += m (i,j);
83 Matrix::operator-=(Matrix const &m)
85 Matrix_storage::set_addition_result (dat, m.dat);
86 assert (m.cols() == cols ());
87 assert (m.rows() == rows ());
88 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
89 dat->elem (i,j) -= m (i,j);
94 Matrix::operator*=(Real a)
96 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
101 Matrix::operator=(Matrix const &m)
106 dat = m.dat->clone();
110 Matrix::band_i() const
114 Diagonal_storage const * diag = (Diagonal_storage*) dat;
115 return diag->band_size_i();
120 for ( int i = starty, j = 0; i < dim(); i++, j++)
123 for ( int i=0, j = starty; j < dim(); i++,j++)
132 Matrix::Matrix (Matrix const &m)
136 dat = m.dat->clone();
140 Matrix::Matrix (int n, int m)
142 dat = Matrix_storage::get_full (n,m);
146 Matrix::Matrix (Matrix_storage*stor_p)
151 Matrix::Matrix (int n)
153 dat = Matrix_storage::get_full (n,n);
157 Matrix::Matrix (Vector v, Vector w)
159 dat = Matrix_storage::get_full (v.dim(), w.dim ());
160 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
161 dat->elem (i,j)=v (i)*w (j);
166 Matrix::row (int k) const
172 for (int i=0; i < n; i++)
173 v (i)=dat->elem (k,i);
179 Matrix::col (int k) const
183 for (int i=0; i < n; i++)
184 v (i)=dat->elem (i,k);
189 Matrix::left_multiply (Vector const & v) const
191 Vector dest (v.dim());
192 assert (dat->cols()==v.dim ());
193 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
194 dest (i)+= dat->elem (j,i)*v (j);
199 Matrix::operator *(Vector const & v) const
201 Vector dest (rows());
202 assert (dat->cols()==v.dim ());
203 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
204 dest (i)+= dat->elem (i,j)*v (j);
209 operator /(Matrix const& m1,Real a)
217 ugh. Only works for square matrices.
220 Matrix::transpose() // delegate to storage?
223 for (int i=0, j=0; dat->mult_ok (i,j); dat->mult_next (i,j))
227 Real r=dat->elem (i,j);
228 dat->elem (i,j) = dat->elem (j,i);
235 Matrix::operator-() const
244 Matrix::transposed() const
252 operator *(Matrix const &m1, Matrix const &m2)
254 Matrix result (Matrix_storage::get_product_result (m1.dat, m2.dat));
256 result.set_product (m1,m2);
261 Matrix::set_product (Matrix const &m1, Matrix const &m2)
263 assert (m1.cols()==m2.rows ());
264 assert (cols()==m2.cols () && rows ()==m1.rows ());
266 if (m1.dat->try_right_multiply (dat, m2.dat))
269 for (int i=0, j=0; dat->mult_ok (i,j);
270 dat->mult_next (i,j))
273 for (int k = 0; k < m1.cols(); k++)
274 r += m1(i,k)*m2(k,j);
280 Matrix::insert_row (Vector v, int k)
283 assert (v.dim()==cols ());
285 for (int j=0; j < c; j++)
286 dat->elem (k,j)=v (j);
291 Matrix::swap_columns (int c1, int c2)
293 assert (c1>=0&& c1 < cols()&&c2 < cols () && c2 >=0);
295 for (int i=0; i< r; i++)
297 Real r=dat->elem (i,c1);
298 dat->elem (i,c1) = dat->elem (i,c2);
304 Matrix::swap_rows (int c1, int c2)
306 assert (c1>=0&& c1 < rows()&&c2 < rows () && c2 >=0);
308 for (int i=0; i< c; i++)
310 Real r=dat->elem (c1,i);
311 dat->elem (c1,i) = dat->elem (c2,i);
320 assert (cols() == rows ());