61 #if defined(INSTANTIATE_TEMPLATES) 62 #include "../base_class/EST_TVector.cc" 76 for (i=0,e=examples; e !=
NIL; e=
cdr(e),i++)
80 void EST_bracketed_string::init()
98 set_bracketed_string(
string);
107 for (i=0; i < p_length; i++)
108 delete [] valid_spans[i];
109 delete [] valid_spans;
118 p_length = find_num_nodes(
string);
119 symbols =
new LISP[p_length];
121 set_leaf_indices(
string,0,symbols);
126 valid_spans =
new int*[length()];
127 for (i=0; i < length(); i++)
129 valid_spans[i] =
new int[length()+1];
130 for (j=i+1; j <= length(); j++)
131 valid_spans[i][j] = 0;
140 int EST_bracketed_string::find_num_nodes(LISP
string)
145 else if (
CONSP(
string))
146 return find_num_nodes(
car(
string))+
147 find_num_nodes(
cdr(
string));
152 int EST_bracketed_string::set_leaf_indices(LISP
string,
int i,LISP *syms)
159 return set_leaf_indices(
cdr(
string),i+1,syms);
163 return set_leaf_indices(
cdr(
string),
164 set_leaf_indices(
car(
string),i,syms),
169 void EST_bracketed_string::find_valid(
int s,LISP t)
const 176 for (c=s,l=t; l !=
NIL; l=
cdr(l))
178 c += num_leafs(
car(l));
179 valid_spans[s][c] = 1;
181 find_valid(s,
car(t));
182 find_valid(s+num_leafs(
car(t)),
cdr(t));
186 int EST_bracketed_string::num_leafs(LISP t)
const 193 return num_leafs(
car(t)) + num_leafs(
cdr(t));
215 double EST_SCFG_traintest::f_I_cal(
int c,
int p,
int i,
int k)
239 double pBpqr =
prob_B(p,q,r);
241 for (j=i+1; j < k; j++)
243 double in = f_I(c,q,i,j);
245 s += pBpqr * in * f_I(c,r,j,k);
253 inside[p][i][k] = res;
261 double EST_SCFG_traintest::f_O_cal(
int c,
int p,
int i,
int k)
290 double out = f_O(c,q,j,k);
292 s2 += out * f_I(c,r,j,i);
301 double out = f_O(c,q,i,j);
303 s3 += out * f_I(c,r,k,j);
314 outside[p][i][k] = res;
319 void EST_SCFG_traintest::reestimate_rule_prob_B(
int c,
int ri,
int p,
int q,
int r)
325 double pBpqr =
prob_B(p,q,r);
332 double d1 = f_I(c,q,i,j);
333 if (d1 == 0)
continue;
336 double d2 = f_I(c,r,j,k);
337 if (d2 == 0)
continue;
338 double d3 = f_O(c,p,i,k);
339 if (d3 == 0)
continue;
359 void EST_SCFG_traintest::reestimate_rule_prob_U(
int c,
int ri,
int p,
int m)
373 n2 +=
prob_U(p,m) * f_O(c,p,i-1,i);
379 d[ri] += f_P(c,p) / fP;
383 double EST_SCFG_traintest::f_P(
int c)
388 double EST_SCFG_traintest::f_P(
int c,
int p)
396 double d1 = f_O(c,p,i,j);
397 if (d1 == 0)
continue;
398 db += f_I(c,p,i,j)*d1;
404 void EST_SCFG_traintest::reestimate_grammar_probs(
int passes,
422 for (pass = startpass; pass < passes; pass++)
431 for (mC=0.0,lPc=0.0,c=0; c < corpus.
length(); c++)
434 if ((spread > 0) && (((c+(pass*spread))%100) >= spread))
436 printf(
" %d",c); fflush(stdout);
442 reestimate_rule_prob_B(c,ri,
444 rules(r).daughter1(),
445 rules(r).daughter2());
447 reestimate_rule_prob_U(c,
460 double n_prob = n[ri]/d[ri];
463 se += (n_prob-
rules(r).prob())*(n_prob-
rules(r).prob());
464 rules(r).set_prob(n_prob);
466 printf(
"pass %d cross entropy %g RMSE %f %f %d\n",
470 if (checkpoint != -1 && checkpoint != 0)
472 if ((pass % checkpoint) == checkpoint-1)
475 sprintf(cp,
".%03d",pass);
492 reestimate_grammar_probs(passes, startpass, checkpoint,
496 void EST_SCFG_traintest::init_io_cache(
int c,
int nt)
502 inside =
new double**[nt];
503 outside =
new double**[nt];
504 for (i=0; i < nt; i++)
506 inside[i] =
new double*[mc];
507 outside[i] =
new double*[mc];
508 for (j=0; j < mc; j++)
510 inside[i][j] =
new double[mc];
511 outside[i][j] =
new double[mc];
512 for (k=0; k < mc; k++)
514 inside[i][j][k] = -1;
515 outside[i][j][k] = -1;
521 void EST_SCFG_traintest::clear_io_cache(
int c)
531 for (j=0; j < mc; j++)
533 delete [] inside[i][j];
534 delete [] outside[i][j];
537 delete [] outside[i];
547 double EST_SCFG_traintest::cross_entropy()
552 for (c=0; c < corpus.
length(); c++)
575 for (mC=0.0,lPc=0.0,c=0; c < corpus.
length(); c++)
596 cout <<
"cross entropy " << -(lPc/mC) <<
" (" << failed <<
" failed out of " <<
597 corpus.
length() <<
" sentences )" << endl;
int valid(int i, int k) const
If a bracketing from i to k is valid in string.
void set_bracketed_string(LISP string)
void set_rule_prob_cache()
(re-)set rule probability caches
double safe_log(const double x)
double prob_B(int p, int q, int r) const
The rule probability of given binary rule.
void fill(const T &v)
Fill entire array will value v.
int siod_llength(LISP list)
INLINE const T & a_no_check(ssize_t n) const
read-only const access operator: without bounds checking
void gc_unprotect(LISP *location)
A class representing a stochastic context free grammar (SCFG).
This class represents a bracketed string used in training of SCFGs.
int distinguished_symbol() const
int num_nonterminals() const
Number of nonterminals.
EST_write_status save(const EST_String &filename)
Save current grammar to named file.
LISP vload(const char *fname, long cflag)
EST_String terminal(int m) const
Convert terminal index to string form.
void resize(ssize_t n, int set=1)
INLINE ssize_t length() const
number of items in vector.
SCFGRuleList rules
The rules themselves.
void load_corpus(const EST_String &filename)
void set_corpus(EST_Bcorpus &b, LISP examples)
const EST_String symbol_at(int i) const
The nth symbol in the string.
void gc_protect(LISP *location)
void train_inout(int passes, int startpass, int checkpoint, int spread, const EST_String &outfile)
double prob_U(int p, int m) const
The rule probability of given unary rule.
EST_Item * daughter1(const EST_Item *n)
return first daughter of n
void resize(int n, int set=1)
resize vector