Edinburgh Speech Tools  2.1-release
wagon_aux.cc
Go to the documentation of this file.
1 /*************************************************************************/
2 /* */
3 /* Centre for Speech Technology Research */
4 /* University of Edinburgh, UK */
5 /* Copyright (c) 1996,1997 */
6 /* All Rights Reserved. */
7 /* */
8 /* Permission is hereby granted, free of charge, to use and distribute */
9 /* this software and its documentation without restriction, including */
10 /* without limitation the rights to use, copy, modify, merge, publish, */
11 /* distribute, sublicense, and/or sell copies of this work, and to */
12 /* permit persons to whom this work is furnished to do so, subject to */
13 /* the following conditions: */
14 /* 1. The code must retain the above copyright notice, this list of */
15 /* conditions and the following disclaimer. */
16 /* 2. Any modifications must be clearly marked as such. */
17 /* 3. Original authors' names are not deleted. */
18 /* 4. The authors' names are not used to endorse or promote products */
19 /* derived from this software without specific prior written */
20 /* permission. */
21 /* */
22 /* THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK */
23 /* DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING */
24 /* ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT */
25 /* SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE */
26 /* FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES */
27 /* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN */
28 /* AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, */
29 /* ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF */
30 /* THIS SOFTWARE. */
31 /* */
32 /*************************************************************************/
33 /* Author : Alan W Black */
34 /* Date : May 1996 */
35 /*-----------------------------------------------------------------------*/
36 /* */
37 /* Various method functions */
38 /*=======================================================================*/
39 
40 #include <cstdlib>
41 #include <iostream>
42 #include <cstring>
43 #include "EST_unix.h"
44 #include "EST_cutils.h"
45 #include "EST_Token.h"
46 #include "EST_Wagon.h"
47 #include "EST_math.h"
48 
49 using namespace std;
50 
// Body of predict(): the cross-reference index places predict(const WVector &)
// at listing line 51, but that signature line is absent from this listing.
// Walks the tree for the vector d and returns the value found at the leaf.
52 {
 // A leaf answers with the value of its stored impurity.
53  if (leaf())
54  return impurity.value();
 // Otherwise the node's question routes d: true -> left, false -> right.
55  else if (question.ask(d))
56  return left->predict(d);
57  else
58  return right->predict(d);
59 }
60 
// Body of predict_node(): the cross-reference index places
// predict_node(const WVector &d) at listing line 61; the signature line is
// absent from this listing.  Same descent as predict(), but the leaf node
// itself is returned rather than its value.
62 {
63  if (leaf())
64  return this;
 // question true -> left subtree, false -> right subtree
65  else if (question.ask(d))
66  return left->predict_node(d);
67  else
68  return right->predict_node(d);
69 }
70 
71 int WNode::pure(void)
72 {
73  // A node is pure if it has no sub-nodes or its not of type class
74 
75  if ((left == 0) && (right == 0))
76  return TRUE;
77  else if (get_impurity().type() != wnim_class)
78  return TRUE;
79  else
80  return FALSE;
81 }
82 
83 void WNode::prune(void)
84 {
85  // Check all sub-nodes and if they are all of the same class
86  // delete their sub nodes. Returns pureness of this node
87 
88  if (pure() == FALSE)
89  {
90  // Ok lets try and make it pure
91  if (left != 0) left->prune();
92  if (right != 0) right->prune();
93 
94  // Have to check purity as well as values to ensure left and right
95  // don't further split
96  if ((left != 0) && (left->pure() == TRUE) &&
97  (right != 0) && (right->pure() == TRUE) &&
98  (left->get_impurity().value() == right->get_impurity().value()))
99  {
100  delete left; left = 0;
101  delete right; right = 0;
102  }
103  }
104 
105 }
106 
// Body of held_out_prune(): the cross-reference index places
// held_out_prune(void) at listing line 107; the signature line is absent
// from this listing.
108 {
109  // prune tree with held out data
110  // Check if node's questions differentiates for the held out data
111  // if not, prune all sub_nodes
112 
113  // Rescore with prune data
114  set_impurity(WImpurity(get_data())); // for this new data
115 
 // Only left is tested for null; right is dereferenced below on the
 // assumption that it is set whenever left is -- TODO confirm.
116  if (left != 0)
117  {
 // Re-score this node's question on the data now held by the node.
118  wgn_score_question(question,get_data());
119  if (question.get_score() < get_impurity().measure())
120  { // it's worth going to the next level
 // Redistribute this node's data into the two children, then recurse.
121  wgn_find_split(question,get_data(),
122  left->get_data(),
123  right->get_data());
124  left->held_out_prune();
125  right->held_out_prune();
126  }
127  else
128  { // not worth the split so prune both sub_nodes
129  delete left; left = 0;
130  delete right; right = 0;
131  }
132  }
133 }
134 
135 void WNode::print_out(ostream &s, int margin)
136 {
137  int i;
138 
139  s << endl;
140  for (i=0;i<margin;i++) s << " ";
141  s << "(";
142  if (left==0) // base case
143  s << impurity;
144  else
145  {
146  s << question;
147  left->print_out(s,margin+1);
148  right->print_out(s,margin+1);
149  }
150  s << ")";
151 }
152 
153 ostream & operator <<(ostream &s, WNode &n)
154 {
155  // Output this node and its sub-node
156 
157  n.print_out(s,0);
158  s << endl;
159  return s;
160 }
161 
// Body of a member routine whose signature line (listing line 162) is absent
// from this listing.  It marks every field whose type is neither wndt_binary
// nor wndt_float as ignored; flags already set in p_ignore are left alone.
163 {
164  /* For ols we want to ignore anything that is categorical */
165  int i;
166 
167  for (i=0; i<dlength; i++)
168  {
 // numeric fields keep whatever ignore flag they already have
169  if ((p_type[i] == wndt_binary) ||
170  (p_type[i] == wndt_float))
171  continue;
172  else
173  {
174  p_ignore[i] = TRUE;
175  }
176  }
177 
178  return;
179 }
180 
// Read a field description from file `fname` and set up the per-field
// tables p_name, p_type and p_ignore (one entry per described field).
// `ignores` is a LISP list of field names to be treated as ignored.
// Also resolves the globals wgn_predictee and wgn_count_field from
// wgn_predictee_name / wgn_count_field_name; wagon_error() is called for an
// ignored predictee, an unknown type name, or a predictee name that is not
// found.
181 void WDataSet::load_description(const EST_String &fname, LISP ignores)
182 {
183  // Initialise a dataset with sizes and types
184  EST_String tname;
185  int i;
186  LISP description,d;
187 
 // The description is the first item loaded from the file; its length
 // gives the number of fields.
188  description = car(vload(fname,1));
189  dlength = siod_llength(description);
190 
191  p_type.resize(dlength);
192  p_ignore.resize(dlength);
193  p_name.resize(dlength);
194 
 // -1 marks "named predictee not found yet"; checked after the loop.
195  if (wgn_predictee_name == "")
196  wgn_predictee = 0; // default predictee is first field
197  else
198  wgn_predictee = -1;
199 
200  for (i=0,d=description; d != NIL; d=cdr(d),i++)
201  {
 // Each entry is a list: first element the field name, second the
 // type name (or the first value of a discrete field).
202  p_name[i] = get_c_string(car(car(d)));
203  tname = get_c_string(car(cdr(car(d))));
204  p_ignore[i] = FALSE;
205  if ((wgn_predictee_name != "") && (wgn_predictee_name == p_name[i]))
206  wgn_predictee = i;
207  if ((wgn_count_field_name != "") &&
208  (wgn_count_field_name == p_name[i]))
209  wgn_count_field = i;
 // The branch order below matters: count, then ignore, then discrete
 // (more than two elements), then the single-keyword types.
210  if ((tname == "count") || (i == wgn_count_field))
211  {
212  // The count must be ignored, repeat it if you want it too
213  p_type[i] = wndt_ignore; // the count must be ignored
214  p_ignore[i] = TRUE;
215  wgn_count_field = i;
216  }
217  else if ((tname == "ignore") || (siod_member_str(p_name[i],ignores)))
218  {
219  p_type[i] = wndt_ignore; // user specified ignore
220  p_ignore[i] = TRUE;
221  if (i == wgn_predictee)
222  wagon_error(EST_String("predictee \"")+p_name[i]+
223  "\" can't be ignored \n");
224  }
225  else if (siod_llength(car(d)) > 2)
226  {
 // Discrete field: everything after the name is registered as its
 // value list; the returned id becomes the field type.
227  LISP rest = cdr(car(d));
228  EST_StrList sl;
229  siod_list_to_strlist(rest,sl);
230  p_type[i] = wgn_discretes.def(sl);
 // A leading "_other_" value is registered via def_val().
231  if (streq(get_c_string(car(rest)),"_other_"))
232  wgn_discretes[p_type[i]].def_val("_other_");
233  }
234  else if (tname == "binary")
235  p_type[i] = wndt_binary;
236  else if (tname == "cluster")
237  p_type[i] = wndt_cluster;
238  else if (tname == "vector")
239  p_type[i] = wndt_vector;
240  else if (tname == "trajectory")
241  p_type[i] = wndt_trajectory;
242  else if (tname == "ols")
243  p_type[i] = wndt_ols;
244  else if (tname == "matrix")
245  p_type[i] = wndt_matrix;
246  else if (tname == "float")
247  p_type[i] = wndt_float;
248  else
249  {
250  wagon_error(EST_String("Unknown type \"")+tname+
251  "\" for field number "+itoString(i)+
252  "/"+p_name[i]+" in description file \""+fname+"\"");
253  }
254  }
255 
 // A named predictee that never matched a field name is an error.
256  if (wgn_predictee == -1)
257  {
258  wagon_error(EST_String("predictee field \"")+wgn_predictee_name+
259  "\" not found in description ");
260  }
261 }
262 
263 int WQuestion::ask(const WVector &w) const
264 {
265  // Ask this question of the given vector
266  switch (op)
267  {
268  case wnop_equal: // for numbers
269  if (w.get_flt_val(feature_pos) == operand1.Float())
270  return TRUE;
271  else
272  return FALSE;
273  case wnop_binary: // for numbers
274  if (w.get_int_val(feature_pos) == 1)
275  return TRUE;
276  else
277  return FALSE;
278  case wnop_greaterthan:
279  if (w.get_flt_val(feature_pos) > operand1.Float())
280  return TRUE;
281  else
282  return FALSE;
283  case wnop_lessthan:
284  if (w.get_flt_val(feature_pos) < operand1.Float())
285  return TRUE;
286  else
287  return FALSE;
288  case wnop_is: // for classes
289  if (w.get_int_val(feature_pos) == operand1.Int())
290  return TRUE;
291  else
292  return FALSE;
293  case wnop_in: // for subsets -- note operand is list of ints
294  if (ilist_member(operandl,w.get_int_val(feature_pos)))
295  return TRUE;
296  else
297  return FALSE;
298  default:
299  wagon_error("Unknown test operator");
300  }
301 
302  return FALSE;
303 }
304 
305 ostream& operator<<(ostream& s, const WQuestion &q)
306 {
307  EST_String name;
308  static EST_Regex needquotes(".*[()'\";., \t\n\r].*");
309 
310  s << "(" << wgn_dataset.feat_name(q.get_fp());
311  switch (q.get_op())
312  {
313  case wnop_equal:
314  s << " = " << q.get_operand1().string();
315  break;
316  case wnop_binary:
317  break;
318  case wnop_greaterthan:
319  s << " > " << q.get_operand1().Float();
320  break;
321  case wnop_lessthan:
322  s << " < " << q.get_operand1().Float();
323  break;
324  case wnop_is:
325  name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
326  name(q.get_operand1().Int());
327  s << " is ";
328  if (name.matches(needquotes))
329  s << quote_string(name,"\"","\\",1);
330  else
331  s << name;
332  break;
333  case wnop_matches:
334  name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
335  name(q.get_operand1().Int());
336  s << " matches " << quote_string(name,"\"","\\",1);
337  break;
338  case wnop_in:
339  s << " in (";
340  for (int l=0; l < q.get_operandl().length(); l++)
341  {
342  name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
343  name(q.get_operandl().nth(l));
344  if (name.matches(needquotes))
345  s << quote_string(name,"\"","\\",1);
346  else
347  s << name;
348  s << " ";
349  }
350  s << ")";
351  break;
352  // SunCC wont let me add this
353 // default:
354 // s << " unknown operation ";
355  }
356  s << ")";
357 
358  return s;
359 }
360 
// Body of a member routine whose signature line (listing line 361) is absent
// from this listing; predict() in this file calls impurity.value(), so this
// is presumably WImpurity::value() -- TODO confirm.
// Yields an EST_Val: 0.0 (with a message on cerr) when nothing has been
// cumulated, the most probable class for wnim_class, and a.mean() for every
// other type.
362 {
363  // Returns the recommended value for this
 // NOTE(review): s is never used in this body.
364  EST_String s;
 // receives the probability from most_probable(); not used afterwards
365  double prob;
366 
367  if (t==wnim_unset)
368  {
369  cerr << "WImpurity: no value currently set\n";
370  return EST_Val(0.0);
371  }
372  else if (t==wnim_class)
373  return EST_Val(p.most_probable(&prob));
 // All remaining branches return the same expression; they are kept
 // separate as placeholders per the original author's remarks.
374  else if (t==wnim_cluster)
375  return EST_Val(a.mean());
376  else if (t==wnim_ols) /* OLS TBA */
377  return EST_Val(a.mean());
378  else if (t==wnim_vector)
379  return EST_Val(a.mean()); /* wnim_vector */
380  else if (t==wnim_trajectory)
381  return EST_Val(a.mean()); /* NOT YET WRITTEN */
382  else
383  return EST_Val(a.mean());
384 }
385 
386 double WImpurity::samples(void)
387 {
388  if (t==wnim_float)
389  return a.samples();
390  else if (t==wnim_class)
391  return (int)p.samples();
392  else if (t==wnim_cluster)
393  return members.length();
394  else if (t==wnim_ols)
395  return members.length();
396  else if (t==wnim_vector)
397  return members.length();
398  else if (t==wnim_trajectory)
399  return members.length();
400  else
401  return 0;
402 }
403 
// Body of a member routine whose signature line (listing line 404) is absent
// from this listing; it reads a collection `ds`, and other code in this file
// builds WImpurity(get_data()), so this is presumably the constructor taking
// the node's data vectors -- TODO confirm.
// Resets all state and cumulates one entry per element of ds.
405 {
406  int i;
407  score = NAN;
408 
409  t=wnim_unset;
410  a.reset(); trajectory=0; l=0; width=0;
411  data = &ds; // for ols, model calculation
412  for (i=0; i < ds.n(); i++)
413  {
 // For ols the position i within ds is cumulated rather than a value.
 // NOTE(review): t is still wnim_unset on the first pass, so the first
 // element goes through the predictee-value branches below even if the
 // type later becomes wnim_ols -- confirm that this is intended.
414  if (t == wnim_ols)
415  cumulate(i,1);
 // no count field: every sample has weight 1
416  else if (wgn_count_field == -1)
417  cumulate((*(ds(i)))[wgn_predictee],1);
418  else
 // weight the predictee value by the sample's count field
419  cumulate((*(ds(i)))[wgn_predictee],
420  (*(ds(i)))[wgn_count_field]);
421  }
422 }
423 
// Body of measure(): the cross-reference index places float measure(void) at
// listing line 424; the signature line is absent from this listing.
// Dispatches on the impurity type t and returns the matching score; an
// unrecognised/unset type prints a message on cerr and yields 0.0.
425 {
 // variance scaled by the number of samples
426  if (t == wnim_float)
427  return a.variance()*a.samples();
428  else if (t == wnim_vector)
429  return vector_impurity();
430  else if (t == wnim_trajectory)
431  return trajectory_impurity();
 // same formula as the float case
432  else if (t == wnim_matrix)
433  return a.variance()*a.samples();
 // entropy scaled by the number of samples
434  else if (t == wnim_class)
435  return p.entropy()*p.samples();
436  else if (t == wnim_cluster)
437  return cluster_impurity();
438  else if (t == wnim_ols)
439  return ols_impurity(); /* RMSE for OLS model */
440  else
441  {
442  cerr << "WImpurity: can't measure unset object" << endl;
443  return 0.0;
444  }
445 }
446 
// Impurity for vector-valued predictees.
// Only the "#if 1" variant below is compiled: for every channel j that is
// enabled in wgn_VertexFeats (row 0 value > 0), the count-weighted standard
// deviation over all members is computed and added to `a`.  The result is
// a.mean() multiplied by `count`, the sample total of the last enabled
// channel (1 if no channel is enabled).  `a` is reset and overwritten.
447 float WImpurity::vector_impurity()
448 {
449  // Find the mean/stddev for all values in all vectors
450  // sum the variances and multiply them by the number of members
451  EST_Litem *pp;
452  EST_Litem *countpp;
453  ssize_t i,j;
454  EST_SuffStats b;
455  double count = 1;
456 
457  a.reset();
458 
459 #if 1
460  /* simple distance */
461  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
462  {
463  if (wgn_VertexFeats.a(static_cast<ssize_t>(0),j) > 0.0)
464  {
465  b.reset();
 // members and member_counts are walked in lockstep; this assumes the
 // two lists have the same length -- TODO confirm.
466  for (pp=members.head(), countpp=member_counts.head(); pp != 0; pp=pp->next(), countpp=countpp->next())
467  {
468  i = members.item(pp);
469 
470  // Accumulate the value with count
471  b.cumulate(wgn_VertexTrack.a(i,j), member_counts.item(countpp)) ;
472  }
473  a += b.stddev();
474  count = b.samples();
475  }
476  }
477 #endif
478 
479 #if 0
480  /* full covariance */
481  /* worse in listening experiments */
482  EST_SuffStats **cs;
483  int mmm;
 /* NOTE(review): listing line 484 is absent from this listing; cs is */
 /* indexed below without a visible allocation of the outer array. */
485  for (j=0; j<=wgn_VertexTrack.num_channels(); j++)
486  cs[j] = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];
487  /* Find means for diagonal */
488  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
489  {
490  if (wgn_VertexFeats.a(0,j) > 0.0)
491  {
492  for (pp=members.head(); pp != 0; pp=pp->next())
493  cs[j][j] += wgn_VertexTrack.a(members.item(pp),j);
494  }
495  }
496  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
497  {
498  for (i=j+1; i<wgn_VertexFeats.num_channels(); i++)
499  if (wgn_VertexFeats.a(0,j) > 0.0)
500  {
501  for (pp=members.head(); pp != 0; pp=pp->next())
502  {
503  mmm = members.item(pp);
504  cs[i][j] += (wgn_VertexTrack.a(mmm,i)-cs[j][j].mean())*
505  (wgn_VertexTrack.a(mmm,j)-cs[j][j].mean());
506  }
507  }
508  }
509  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
510  {
511  for (i=j+1; i<wgn_VertexFeats.num_channels(); i++)
512  if (wgn_VertexFeats.a(0,j) > 0.0)
513  a += cs[i][j].stddev();
514  }
515  count = cs[0][0].samples();
516 #endif
517 
518 #if 0
519  // look at mean euclidean distance between vectors
520  EST_Litem *qq;
521  int x,y;
522  double d,q;
523  count = 0;
524  for (pp=members.head(); pp != 0; pp=pp->next())
525  {
526  x = members.item(pp);
527  count++;
528  for (qq=pp->next(); qq != 0; qq=qq->next())
529  {
530  y = members.item(qq);
531  for (q=0.0,j=0; j<wgn_VertexFeats.num_channels(); j++)
532  if (wgn_VertexFeats.a(0,j) > 0.0)
533  {
534  d = wgn_VertexTrack(x,j)-wgn_VertexTrack(y,j);
535  q += d*d;
536  }
537  a += sqrt(q);
538  }
539 
540  }
541 #endif
542 
543  // This is sum of stddev*samples
544  return a.mean() * count;
545 }
546 
// Body of a member routine whose signature line (listing line 547) is absent
// from this listing -- presumably the destructor; TODO confirm.
// Releases the trajectory table: l rows, each allocated with new[], then
// the row-pointer array itself, and clears both trajectory and l.
548 {
549  int j;
550 
 // nothing to release unless a trajectory table exists
551  if (trajectory != 0)
552  {
553  for (j=0; j<l; j++)
554  delete [] trajectory[j];
555  delete [] trajectory;
556  trajectory = 0;
557  l = 0;
558  }
559 }
560 
561 
// Impurity for trajectory-valued predictees.
// Builds (once) a table `trajectory` of l rows of EST_SuffStats, resamples
// every member unit onto those l points, and scores the cluster as the mean
// per-point/per-channel stddev times the number of members.  The result is
// stored in `score`; when `trajectory` already exists the stored score is
// returned without recomputation.  Units whose track contains a -1.0 marker
// in channel 0 switch the whole cluster to the two-halves "OLA" layout.
562 float WImpurity::trajectory_impurity()
563 {
564  // Find the mean length of all the units in the cluster
565  // Create that number of points
566  // Interpolate each unit to that number of points
567  // collect means and standard deviations for each point
568  // impurity is sum of the variance for each point and each coef
569  // multiplied by the number of units.
570  EST_Litem *pp;
571  ssize_t i, j;
572  int ti;
573  ssize_t s, q, ni;
574  int s1l, s2l;
575  double n, m, m1, m2, w;
576  EST_SuffStats lss, stdss;
577  EST_SuffStats l1ss, l2ss;
578  int l1, l2;
579  int ola=0;
580 
581  if (trajectory != 0)
582  { /* already done this */
583  return score;
584  }
585 
 // First pass: gather length statistics.  wgn_UnitTrack row i holds the
 // unit's start frame in channel 0 and its frame count in channel 1 (as
 // used below).
586  lss.reset();
587  l = 0;
588  for (pp=members.head(); pp != 0; pp=pp->next())
589  {
590  i = members.item(pp);
 // look for the -1.0 centre marker; q is its offset when found
591  for (q=0; q<wgn_UnitTrack.a(i,1); q++)
592  {
593  ni = wgn_UnitTrack.a(i,0)+q;
594  if (wgn_VertexTrack.a(ni,0) == -1.0)
595  {
596  l1ss += q;
597  ola = 1;
598  break;
599  }
600  }
601  if (q==wgn_UnitTrack.a(i,1))
602  { /* can't find -1 center point, so put all in l2 */
603  l1ss += 0;
604  l2ss += q;
605  }
606  else
607  l2ss += wgn_UnitTrack.a(i,1) - (q+1) - 1;
608  lss += wgn_UnitTrack.a(i,1); /* length of each unit in the cluster */
609  if (wgn_UnitTrack.a(i,1) > l)
610  l = (int)wgn_UnitTrack.a(i,1);
611  }
612 
613  if (ola==0) /* no -1's so its not an ola type cluster */
614  {
 // number of trajectory points: mean unit length, but at least 7
615  l = ((int)lss.mean() < 7) ? 7 : (int)lss.mean();
616 
617  /* a list of SuffStats on for each point in the trajectory */
618  trajectory = new EST_SuffStats *[l];
619  width = wgn_VertexTrack.num_channels()+1;
620  for (j=0; j<l; j++)
621  trajectory[j] = new EST_SuffStats[width];
622 
623  for (pp=members.head(); pp != 0; pp=pp->next())
624  { /* for each unit */
625  i = members.item(pp);
626  m = (float)wgn_UnitTrack.a(i,1L)/(float)l; /* find interpolation */
627  s = wgn_UnitTrack.a(i,0); /* start point */
 // step through the unit in l increments of m frames
628  for (ti=0,n=0.0; ti<l; ti++,n+=m)
629  {
630  ni = (int)n; // hmm floor or nint ??
631  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
632  {
633  if (wgn_VertexFeats.a(static_cast<ssize_t>(0),j) > 0.0)
634  trajectory[ti][j] += wgn_VertexTrack.a(s+ni,j);
635  }
636  }
637  }
638 
639  /* find sum of sum of stddev for all coefs of all traj points */
640  stdss.reset();
641  for (ti=0; ti<l; ti++)
642  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
643  {
644  if (wgn_VertexFeats.a(static_cast<ssize_t>(0),j) > 0.0)
645  stdss += trajectory[ti][j].stddev();
646  }
647 
648  // This is sum of all stddev * samples
649  score = stdss.mean() * members.length();
650  }
651  else
652  { /* OLA model */
 // each half gets its mean length in points, but at least 10; the two
 // extra rows hold the centre marker and the end marker
653  l1 = (l1ss.mean() < 10.0) ? 10 : (int)l1ss.mean();
654  l2 = (l2ss.mean() < 10.0) ? 10 : (int)l2ss.mean();
655  l = l1 + l2 + 1 + 1;
656 
657  /* a list of SuffStats on for each point in the trajectory */
 // NOTE(review): unlike the non-OLA branch, `width` is not updated here.
658  trajectory = new EST_SuffStats *[l];
659  for (j=0; j<l; j++)
660  trajectory[j] = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];
661 
662  for (pp=members.head(); pp != 0; pp=pp->next())
663  { /* for each unit */
664  i = members.item(pp);
665  s1l = 0;
666  s = wgn_UnitTrack.a(i,0); /* start point */
 // s1l: frames before the -1.0 marker (stays 0 when there is none)
667  for (q=0; q<wgn_UnitTrack.a(i,1); q++)
668  if (wgn_VertexTrack.a(s+q,0) == -1.0)
669  {
670  s1l = q; /* printf("awb q is -1 at %d\n",q); */
671  break;
672  }
673  s2l = (int)wgn_UnitTrack.a(i,1) - (s1l + 2);
674  m1 = (float)(s1l)/(float)l1; /* find interpolation step */
675  m2 = (float)(s2l)/(float)l2; /* find interpolation step */
676  /* First half */
677  for (ti=0,n=0.0; s1l > 0 && ti<l1; ti++,n+=m1)
678  {
 // clamp the source frame to the last frame of the first half
679  ni = s + (((int)n < s1l) ? (int)n : s1l - 1);
680  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
681  if (wgn_VertexFeats.a(static_cast<ssize_t>(0),j) > 0.0)
682  trajectory[ti][j] += wgn_VertexTrack.a(ni,j);
683  }
684  ti = l1; /* do it explicitly in case s1l < 1 */
 // row l1 records the centre marker value -1
685  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
686  if (wgn_VertexFeats.a(static_cast<ssize_t>(0),j) > 0.0)
687  trajectory[ti][j] += -1;
688  /* Second half */
689  s += s1l+1;
690  for (ti++,n=0.0; s2l > 0 && ti<l-1; ti++,n+=m2)
691  {
 // clamp the source frame to the last frame of the second half
692  ni = s + (((int)n < s2l) ? (int)n : s2l - 1);
693  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
694  if (wgn_VertexFeats.a(static_cast<ssize_t>(0),j) > 0.0)
695  trajectory[ti][j] += wgn_VertexTrack.a(ni,j);
696  }
 // the row reached after the second half records the marker value -2
697  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
698  if (wgn_VertexFeats.a(static_cast<ssize_t>(0),j) > 0.0)
699  trajectory[ti][j] += -2;
700  }
701 
702  /* find sum of sum of stddev for all coefs of all traj points */
703  /* windowing the sums with a triangular weight window */
704  stdss.reset();
705  m = 1.0/(float)l1;
 // weights rise from 0 over the first half ...
706  for (w=0.0,ti=0; ti<l1; ti++,w+=m)
707  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
708  if (wgn_VertexFeats.a(static_cast<ssize_t>(0),j) > 0.0)
709  stdss += trajectory[ti][j].stddev() * w;
710  m = 1.0/(float)l2;
 // ... and fall from 1 over the second half; the marker rows are skipped
711  for (w=1.0,ti++; ti<l-1; ti++,w-=m)
712  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
713  if (wgn_VertexFeats.a(static_cast<ssize_t>(0),j) > 0.0)
714  stdss += trajectory[ti][j].stddev() * w;
715 
716  // This is sum of all stddev * samples
717  score = stdss.mean() * members.length();
718  }
719  return score;
720 }
721 
722 static void part_to_ols_data(EST_FMatrix &X, EST_FMatrix &Y,
723  EST_IVector &included,
724  EST_StrList &feat_names,
725  const EST_IList &members,
726  const WVectorVector &d)
727 {
728  int m,n,p;
729  int w, xm=0;
730  EST_Litem *pp;
731  WVector *wv;
732 
733  w = wgn_dataset.width();
734  included.resize(w);
735  X.resize(members.length(),w);
736  Y.resize(members.length(),1);
737  feat_names.append("Intercept");
738  included[0] = TRUE;
739 
740  for (p=0,pp=members.head(); pp; p++,pp=pp->next())
741  {
742  n = members.item(pp);
743  if (n < 0)
744  {
745  p--;
746  continue;
747  }
748  wv = d(n);
749  Y.a_no_check(p,0) = (*wv)[0];
750  X.a_no_check(p,0) = 1;
751  for (m=1,xm=1; m < w; m++)
752  {
753  if (wgn_dataset.ftype(m) == wndt_float)
754  {
755  if (p == 0) // only do this once
756  {
757  feat_names.append(wgn_dataset.feat_name(m));
758  }
759  X.a_no_check(p,xm) = (*wv)[m];
760  included.a_no_check(xm) = FALSE;
761  included.a_no_check(xm) = TRUE;
762  xm++;
763  }
764  }
765  }
766 
767  included.resize(xm);
768  X.resize(p,xm);
769  Y.resize(p,1);
770 }
771 
772 float WImpurity::ols_impurity()
773 {
774  // Build an OLS model for the current data and measure it against
775  // the data itself and give a RMSE
776  EST_FMatrix X,Y;
777  EST_IVector included;
778  EST_FMatrix coeffs;
779  EST_StrList feat_names;
780  float best_score;
781  EST_FMatrix coeffsl;
782  EST_FMatrix pred;
783  float cor,rmse;
784 
785  // Load the sample members into matrices for ols
786  part_to_ols_data(X,Y,included,feat_names,members,*data);
787 
788  // Find the best ols model.
789  // Far too computationally expensive
790  // if (!stepwise_ols(X,Y,feat_names,0.0,coeffs,
791  // X,Y,included,best_score))
792  // return WGN_HUGE_VAL; // couldn't find a model
793 
794  // Non stepwise model
795  if (!robust_ols(X,Y,included,coeffsl))
796  {
797  // printf("no robust ols\n");
798  return WGN_HUGE_VAL;
799  }
800  ols_apply(X,coeffsl,pred);
801  ols_test(Y,pred,cor,rmse);
802  best_score = cor;
803 
804  printf("Impurity OLS X(%zd,%zd) Y(%zd,%zd) %f, %f, %f\n",
805  X.num_rows(),X.num_columns(),Y.num_rows(),Y.num_columns(),
806  rmse,cor,
807  1-best_score);
808  if (fabs(coeffsl[0]) > 10000)
809  {
810  // printf("weird sized Intercept %f\n",coeffsl[0]);
811  return WGN_HUGE_VAL;
812  }
813 
814  return (1-best_score) *members.length();
815 }
816 
817 float WImpurity::cluster_impurity()
818 {
819  // Find the mean distance between all members of the dataset
820  // Uses the global DistMatrix for distances between members of
821  // the cluster set. Distances are assumed to be symmetric thus only
822  // the bottom half of the distance matrix is filled
823  EST_Litem *pp, *q;
824  int i,j;
825  double dist;
826 
827  a.reset();
828  for (pp=members.head(); pp != 0; pp=pp->next())
829  {
830  i = members.item(pp);
831  for (q=pp->next(); q != 0; q=q->next())
832  {
833  j = members.item(q);
834  dist = (j < i ? wgn_DistMatrix.a_no_check(i,j) :
836  a+=dist; // cumulate for whole cluster
837  }
838  }
839 
840  // This is sum distance between cross product of members
841 // return a.sum();
842  if (a.samples() > 1)
843  return a.stddev() * a.samples();
844  else
845  return 0.0;
846 }
847 
// Body of cluster_distance(): the cross-reference index places
// float cluster_distance(int i) at listing line 848; the signature line is
// absent from this listing.
849 {
850  // Distance this unit is from all others in this cluster
851  // in absolute standard deviations from the mean.
852  float dist = cluster_member_mean(i);
853  float mdist = dist-a.mean();
854 
 // An exact zero difference is answered directly, which also avoids the
 // division by a.stddev() below in that case.
855  if (mdist == 0.0)
856  return 0.0;
857  else
858  return fabs((dist-a.mean())/a.stddev());
859 
860 }
861 
// Body of in_cluster(): the cross-reference index places int in_cluster(int i)
// at listing line 862; the signature line is absent from this listing.
863 {
864  // Would this be a member of this cluster?. Returns 1 if
865  // its distance is less than at least one other
866  float dist = cluster_member_mean(i);
867  EST_Litem *pp;
868 
 // Compare i's mean distance with that of every existing member; the
 // first member that is further out settles the answer.
869  for (pp=members.head(); pp != 0; pp=pp->next())
870  {
871  if (dist < cluster_member_mean(members.item(pp)))
872  return 1;
873  }
874  return 0;
875 }
876 
// Body of a member routine whose signature line (listing line 877) is absent
// from this listing; it reads a parameter i and returns an int rank.
878 {
879  // Position in ranking closest to centre
880  float dist = cluster_distance(i);
881  EST_Litem *pp;
882  int ranking = 1;
883 
 // Starting from 1, add one for every member whose cluster_distance is
 // not greater than i's (i itself counts if it is in members).
884  for (pp=members.head(); pp != 0; pp=pp->next())
885  {
886  if (dist >= cluster_distance(members.item(pp)))
887  ranking++;
888  }
889 
890  return ranking;
891 }
892 
893 float WImpurity::cluster_member_mean(int i)
894 {
895  // Returns the mean difference between this member and all others
896  // in cluster
897  EST_Litem *q;
898  int j,n;
899  double dist,sum;
900 
901  for (sum=0.0,n=0,q=members.head(); q != 0; q=q->next())
902  {
903  j = members.item(q);
904  if (i != j)
905  {
906  dist = (j < i ? wgn_DistMatrix(i,j) : wgn_DistMatrix(j,i));
907  sum += dist;
908  n++;
909  }
910  }
911 
912  return ( n == 0 ? 0.0 : sum/n );
913 }
914 
// Add one sample to this impurity.  pv is the sample's predictee value (or,
// for member-list types, an index that is stored after an int cast) and
// count its weight.
// NOTE(review): the condition line that opens each branch (listing lines
// 919, 924, 929, 937, 942, 949 and 954) and listing line 945 are absent from
// this listing, so the selectors of the branches below are not visible here;
// presumably they test the type of the predictee field -- TODO confirm
// against the upstream source.
915 void WImpurity::cumulate(const float pv,double count)
916 {
917  // Cumulate data for impurity calculation
918 
 // cluster: remember the member index
920  {
921  t = wnim_cluster;
922  members.append((int)pv);
923  }
 // ols: remember the member index
925  {
926  t = wnim_ols;
927  members.append((int)pv);
928  }
 // vector: remember the member index together with its weight
930  {
931  t = wnim_vector;
932 
933  // AUP: Implement counts in vectors
934  members.append((int)pv);
935  member_counts.append((float)count);
936  }
 // trajectory: remember the member index (the weight is not stored)
938  {
939  t = wnim_trajectory;
940  members.append((int)pv);
941  }
 // class: cumulate the class id into the distribution p
943  {
 // the statement guarded by this test (listing line 945) is missing
944  if (t == wnim_unset)
946  t = wnim_class;
947  p.cumulate((int)pv,count);
948  }
 // value truncated to int before being cumulated into a
950  {
951  t = wnim_float;
952  a.cumulate((int)pv,count);
953  }
 // plain float value cumulated into a
955  {
956  t = wnim_float;
957  a.cumulate(pv,count);
958  }
959  else
960  {
961  wagon_error("WImpurity: cannot cumulate EST_Val type");
962  }
963 }
964 
// Print an impurity in the bracketed form used for tree leaves.  The layout
// depends on imp.t:
//   float       (stddev mean)
//   vector      (( per-channel (value stddev) pairs ) mean-of-a)
//   trajectory  (( one bracketed row of (mean stddev) per point ) mean-of-a)
//   cluster     (( (member mean-distance) ... ) mean-of-a)
//   ols         (( (feature coefficient) ... ) correlation)
//   class       ( (name prob) ... most-probable-name)
//   otherwise   ([WImpurity unset])
// The vector and trajectory branches call the corresponding *_impurity()
// routine first, so printing can modify imp.
965 ostream & operator <<(ostream &s, WImpurity &imp)
966 {
967  ssize_t j,i;
968  EST_SuffStats b;
969 
970  if (imp.t == wnim_float)
971  s << "(" << imp.a.stddev() << " " << imp.a.mean() << ")";
972  else if (imp.t == wnim_vector)
973  {
974  EST_Litem *p, *countp;
975  s << "((";
 // recompute so that imp.a holds the per-channel stddevs printed at the end
976  imp.vector_impurity();
977  if (wgn_vertex_output == "mean") //output means
978  {
979  for (j=0; j<wgn_VertexTrack.num_channels(); j++)
980  {
981  b.reset();
982  for (p=imp.members.head(), countp=imp.member_counts.head(); p != 0; p=p->next(), countp=countp->next())
983  {
984  // Accumulate the members with their counts
985  b.cumulate(wgn_VertexTrack.a((ssize_t)imp.members.item(p),j), imp.member_counts.item(countp));
986  //b += wgn_VertexTrack.a(imp.members.item(p),j);
987  }
988  s << "(" << b.mean() << " ";
 // a non-finite stddev is replaced by the literal 0.001
989  if (isfinite(b.stddev()))
990  s << b.stddev() << ")";
991  else
992  s << "0.001" << ")";
993  if (j+1<wgn_VertexTrack.num_channels())
994  s << " ";
995  }
996  }
997  else /* output best in the cluster */
998  {
999  /* print out vector closest to center, rather than average */
1000  double best = WGN_HUGE_VAL;
1001  double x,d;
1002  ssize_t bestp = 0;
1003  EST_SuffStats *cs;
1004 
 // NOTE(review): listing line 1005 is absent from this listing; cs is
 // used below and released with delete [], so the lost line presumably
 // allocated it with new EST_SuffStats[...] -- confirm upstream.
1006 
 // per-channel statistics over all members (enabled channels only)
1007  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
1008  if (wgn_VertexFeats.a(static_cast<ssize_t>(0),j) > 0.0)
1009  {
1010  cs[j].reset();
1011  for (p=imp.members.head(); p != 0; p=p->next())
1012  {
1013  cs[j] += wgn_VertexTrack.a((ssize_t)imp.members.item(p),j);
1014  }
1015  }
1016 
 // pick the member with the smallest squared distance to the channel means
1017  for (p=imp.members.head(); p != 0; p=p->next())
1018  {
1019  for (x=0.0,j=0; j<wgn_VertexFeats.num_channels(); j++)
1020  if (wgn_VertexFeats.a(static_cast<ssize_t>(0),j) > 0.0)
1021  {
1022  d = (wgn_VertexTrack.a((ssize_t)imp.members.item(p),j)-cs[j].mean())
1023  /* / cs[j].stddev() */ ; /* seems worse 061218 */
1024  x += d*d;
1025  }
1026  if (x < best)
1027  {
1028  bestp = imp.members.item(p);
1029  best = x;
1030  }
1031  }
 // print the chosen member's values with the cluster's per-channel stddev
1032  for (j=0; j<wgn_VertexTrack.num_channels(); j++)
1033  {
1034  s << "( ";
1035  s << wgn_VertexTrack.a(bestp,j);
1036  // s << " 0 "; // fake stddev
1037  s << " ";
1038  if (isfinite(cs[j].stddev()))
1039  s << cs[j].stddev();
1040  else
1041  s << "0";
1042  s << " ) ";
1043  if (j+1<wgn_VertexTrack.num_channels())
1044  s << " ";
1045  }
1046 
1047  delete [] cs;
1048  }
1049  s << ") ";
1050  s << imp.a.mean() << ")";
1051  }
1052  else if (imp.t == wnim_trajectory)
1053  {
1054  s << "((";
 // builds imp.trajectory and imp.l if that has not been done yet
1055  imp.trajectory_impurity();
1056  for (i=0; i<imp.l; i++)
1057  {
1058  s << "(";
1059  for (j=0; j<wgn_VertexTrack.num_channels(); j++)
1060  {
1061  s << "(" << imp.trajectory[i][j].mean() << " "
1062  << imp.trajectory[i][j].stddev() << " " << ")";
1063  }
1064  s << ")\n";
1065  }
1066  s << ") ";
1067  // Trailing score: mean of imp.a
1068  s << imp.a.mean() << ")";
1069  }
1070  else if (imp.t == wnim_cluster)
1071  {
1072  EST_Litem *p;
1073  s << "((";
1074  for (p=imp.members.head(); p != 0; p=p->next())
1075  {
1076  // Output cluster member and its mean distance to others
1077  s << "(" << imp.members.item(p) << " " <<
1078  imp.cluster_member_mean(imp.members.item(p)) << ")";
1079  if (p->next() != 0)
1080  s << " ";
1081  }
1082  s << ") ";
1083  // Mean of cross product of distances (cluster score)
1084  s << imp.a.mean() << ")";
1085  }
1086  else if (imp.t == wnim_ols)
1087  {
1088  /* Output intercept, feature names and coefficients for ols model */
1089  EST_FMatrix X,Y;
1090  EST_IVector included;
1091  EST_FMatrix coeffs;
1092  EST_StrList feat_names;
1093  EST_FMatrix coeffsl;
1094  EST_FMatrix pred;
1095  float cor=0.0,rmse;
1096 
1097  s << "((";
1098  // Load the sample members into matrices for ols
1099  part_to_ols_data(X,Y,included,feat_names,imp.members,*(imp.data));
1100  if (!robust_ols(X,Y,included,coeffsl))
1101  {
1102  printf("no robust ols\n");
1103  // shouldn't happen
1104  }
1105  else
1106  {
1107  ols_apply(X,coeffsl,pred);
1108  ols_test(Y,pred,cor,rmse);
 // one (name coefficient) pair per row of the coefficient matrix
1109  for (i=0; i<coeffsl.num_rows(); i++)
1110  {
1111  s << "(";
1112  s << feat_names.nth(i);
1113  s << " ";
1114  s << coeffsl[i];
1115  s << ") ";
1116  }
1117  }
1118 
1119  // Close the coefficient list and append the correlation (0.0 if no model was fitted)
1120  s << ") " << cor << ")";
1121  }
1122  else if (imp.t == wnim_class)
1123  {
1124  EST_Litem *i;
1125  EST_String name;
1126  double prob;
1127 
1128  s << "(";
 // every class with its probability, then the most probable class name
1129  for (i=imp.p.item_start(); !imp.p.item_end(i); i=imp.p.item_next(i))
1130  {
1131  imp.p.item_prob(i,name,prob);
1132  s << "(" << name << " " << prob << ") ";
1133  }
1134  s << imp.p.most_probable(&prob) << ")";
1135  }
1136  else
1137  s << "([WImpurity unset])";
1138 
1139  return s;
1140 }
1141 
1142 
1143 
1144 
int ols_test(const EST_FMatrix &real, const EST_FMatrix &predicted, float &correlation, float &rmse)
Definition: EST_ols.cc:288
int width(void) const
Definition: EST_Wagon.h:94
int Int(void) const
Definition: EST_Val.h:141
int wgn_count_field
Definition: wagon.cc:71
const WVectorVector * data
Definition: EST_Wagon.h:159
EST_String wgn_vertex_output
Definition: wagon.cc:78
float wgn_score_question(WQuestion &q, WVectorVector &ds)
Definition: wagon.cc:1091
EST_Litem * item_next(EST_Litem *idx) const
Used for iterating through members of the distribution.
int robust_ols(const EST_FMatrix &X, const EST_FMatrix &Y, EST_FMatrix &coeffs)
Definition: EST_ols.cc:73
const EST_String & most_probable(double *prob=NULL) const
Return the most probable member of the distribution.
WDataSet wgn_dataset
Definition: wagon.cc:59
EST_Val predict(const WVector &w)
Definition: wagon_aux.cc:51
double stddev(void) const
standard deviation of currently cumulated values
ssize_t num_columns() const
return number of columns
Definition: EST_TMatrix.h:179
float cluster_distance(int i)
Definition: wagon_aux.cc:848
void cumulate(double a, double count=1.0)
int wgn_predictee
Definition: wagon.cc:73
A Regular expression class to go with the CSTR EST_String class.
Definition: EST_Regex.h:56
STATIC void left(STATUS Change)
Definition: editline.c:523
void wgn_find_split(WQuestion &q, WVectorVector &ds, WVectorVector &y, WVectorVector &n)
Definition: wagon.cc:775
int ask(const WVector &w) const
Definition: wagon_aux.cc:263
float measure(void)
Definition: wagon_aux.cc:424
Discretes wgn_discretes
Definition: wagon.cc:57
int num_channels() const
return number of channels in track
Definition: EST_Track.h:657
double mean(void) const
mean of currently cumulated values
#define NIL
Definition: siod_defs.h:92
STATIC void right(STATUS Change)
Definition: editline.c:538
int siod_llength(LISP list)
Definition: siod.cc:202
wn_oper get_op(void) const
Definition: EST_Wagon.h:130
INLINE const T & a_no_check(ssize_t n) const
read-only const access operator: without bounds checking
Definition: EST_TVector.h:254
EST_Track wgn_UnitTrack
Definition: wagon.cc:64
EST_String itoString(int n)
Make an EST_String object from an integer.
Definition: util_io.cc:141
ssize_t num_rows() const
return number of rows
Definition: EST_TMatrix.h:177
T & nth(int n)
return the Nth value
Definition: EST_TList.h:145
EST_SuffStats ** trajectory
Definition: EST_Wagon.h:158
ostream & operator<<(ostream &s, WNode &n)
Definition: wagon_aux.cc:153
int ssize_t
#define streq(X, Y)
Definition: EST_cutils.h:57
EST_Litem * item_start() const
Used for iterating through members of the distribution.
EST_IList members
Definition: EST_Wagon.h:156
int item_end(EST_Litem *idx) const
Used for iterating through members of the distribution.
int def(const EST_StrList &members)
EST_FList member_counts
Definition: EST_Wagon.h:157
void held_out_prune(void)
Definition: wagon_aux.cc:107
EST_UItem * next()
Definition: EST_UList.h:55
int in_cluster(int i)
Definition: wagon_aux.cc:862
int get_fp(void) const
Definition: EST_Wagon.h:129
WNode * predict_node(const WVector &d)
Definition: wagon_aux.cc:61
LISP vload(const char *fname, long cflag)
Definition: slib_file.cc:632
void siod_list_to_strlist(LISP l, EST_StrList &a)
Definition: siod.cc:520
const char * get_c_string(LISP x)
Definition: slib.cc:638
void cumulate(const float pv, double count=1.0)
Definition: wagon_aux.cc:915
float & a(ssize_t i, int c=0)
Definition: EST_Track.cc:1025
#define l2
int ols_apply(const EST_FMatrix &samples, const EST_FMatrix &coeffs, EST_FMatrix &res)
Definition: EST_ols.cc:185
const EST_String & feat_name(const int &i) const
Definition: EST_Wagon.h:92
#define FALSE
Definition: EST_bool.h:119
void prune(void)
Definition: wagon_aux.cc:83
EST_Val value(void)
Definition: wagon_aux.cc:361
void item_prob(EST_Litem *idx, EST_String &s, double &prob) const
During iteration returns name and probability given index.
EST_String wgn_count_field_name
Definition: wagon.cc:72
float mean(EST_FVector &m)
EST_FMatrix wgn_DistMatrix
Definition: wagon.cc:61
#define wagon_error(WMESS)
Definition: EST_Wagon.h:50
const EST_String & string(void) const
Definition: EST_Val.h:161
int matches(const char *e, ssize_t pos=0) const
Exactly match this string?
Definition: EST_String.cc:651
getString int
Definition: EST_item_aux.cc:50
EST_String wgn_predictee_name
Definition: wagon.cc:74
void append(const T &item)
add item onto end of list
Definition: EST_TList.h:196
void reset(void)
reset internal values
int length() const
Definition: EST_UList.cc:57
void ignore_non_numbers()
Definition: wagon_aux.cc:162
EST_Track wgn_VertexTrack
Definition: wagon.cc:62
const EST_IList & get_operandl(void) const
Definition: EST_Wagon.h:132
double samples(void)
number of samples in set
const EST_Val get_operand1(void) const
Definition: EST_Wagon.h:131
#define X
T & item(const EST_Litem *p)
Definition: EST_TList.h:139
int ilist_member(const EST_IList &l, int i)
int get_int_val(int n) const
Definition: EST_Wagon.h:60
EST_UItem * head() const
Definition: EST_UList.h:97
float get_flt_val(int n) const
Definition: EST_Wagon.h:61
INLINE const T & a_no_check(ssize_t row, ssize_t col) const
const access with no bounds check; care recommended
Definition: EST_TMatrix.h:182
#define l1
void resize(int rows, int cols, int set=1)
resize matrix
LISP car(LISP x)
Definition: slib_list.cc:115
EST_Track wgn_VertexFeats
Definition: wagon.cc:63
#define WGN_HUGE_VAL
Definition: EST_Wagon.h:54
EST_String
int ftype(const int &i) const
Definition: EST_Wagon.h:89
EST_String quote_string(const EST_String &s, const EST_String &quote, const EST_String &escape, int force)
Definition: EST_Token.cc:844
float sum(const EST_FMatrix &a)
sum of elements
Definition: vec_mat_aux.cc:147
#define TRUE
Definition: EST_bool.h:118
INLINE ssize_t n() const
number of items in vector.
Definition: EST_TVector.h:251
LISP siod_member_str(const char *key, LISP list)
Definition: siod.cc:167
void load_description(const EST_String &descfname, LISP ignores)
Definition: wagon_aux.cc:181
float cluster_ranking(int i)
Definition: wagon_aux.cc:877
void resize(int n, int set=1)
resize vector
LISP cdr(LISP x)
Definition: slib_list.cc:124
double samples(void)
Definition: wagon_aux.cc:386
float Float(void) const
Definition: EST_Val.h:149