#!/usr/bin/perl -w

use Getopt::Long;
use vars qw($opt_cffile $opt_count $opt_lambda $opt_threshold);

# Command-line options (stored by Getopt::Long in the $opt_* package
# variables):
#   --cffile=PATH    rules location handed to parse-rules-for-masses
#   --count          only evaluate the current scores against the logs
#   --lambda=N       FP cost used for the TCR figure (see below)
#   --threshold=N    score at or above which a message counts as spam
GetOptions("cffile=s", "count", "lambda=f", "threshold=f");
my $argcffile = $opt_cffile;

my $justcount = $opt_count ? 1 : 0;

# Use defined() rather than truth so that an explicit "--threshold=0"
# is honoured instead of silently falling back to the default.
my $threshold = defined $opt_threshold ? $opt_threshold : 5;

# Weight applied to the summed false-positive score in the summary.
my $nybias = 10;

# lambda value for TCR equation, indicating the "cost" of recovering
# from an FP.  The values are: 1 = tagged only, 9 = mailed back to
# sender asking for token (TMDA style), 999 = deleted outright.
# We (SpamAssassin) use a default of 5, representing "moved to
# infrequently-read folder".

my $lambda = defined $opt_lambda ? $opt_lambda : 5;

# Filled by readlogs() when not in --count mode:
#   %is_spam    message number -> 1 (from spam.log) or 0 (from nonspam.log)
#   %tests_hit  message number -> array ref of rule names that hit
# Filled by read_ranges() / writescores_c():
#   %mutable_tests  rule name -> 1 if its score may be changed, else 0
my %is_spam = ();
my %tests_hit = ();
my %mutable_tests = ();

use vars qw(%rules %allrules);

readscores();

print "Reading per-message hit stat logs and scores...\n";
my ($num_tests, $num_spam, $num_nonspam);
my ($ga_yy, $ga_ny, $ga_yn, $ga_nn, $yyscore, $ynscore, $nyscore, $nnscore);

readlogs();
read_ranges();

if ($justcount) {
  # Scale the FP weight by the spam:nonspam ratio.  With an empty nonspam
  # corpus there are no false positives to weight, so skip the scaling
  # instead of dying on a division by zero.
  if ($num_nonspam) {
    $nybias = $nybias*($num_spam / $num_nonspam);
  }
  evaluate();
} else {
  print "Writing logs and current scores as C code...\n";
  writescores_c();
  writetests_c();
}
exit 0;


# readlogs: read "spam.log" and "nonspam.log" from the current directory and
# tally the rule hits of every message in them.
#
# Lines starting with "#" and empty lines are skipped.  Every other line must
# consist of one character, whitespace, an integer, whitespace, one non-blank
# token and then an optional comma-separated list of rule names; anything
# else is reported with a "bad line" warning and skipped.  Rule names with no
# entry in %scores are silently dropped.
#
# Always sets $num_spam, $num_nonspam and $num_tests (total messages read).
# With $justcount set, the scores of each message's hits are summed and
# compared against $threshold to update the $ga_yy/$ga_yn/$ga_ny/$ga_nn
# counters and the matching *score sums (first letter: y = spam.log,
# n = nonspam.log; second letter: y = score reached the threshold).
# Otherwise %tests_hit and %is_spam are filled, keyed by message number.
#
# Dies if either log file cannot be opened.
sub readlogs {
  my $count = 0;
  $num_spam = $num_nonspam = 0;

  if ($justcount) {
    $ga_yy = $ga_ny = $ga_yn = $ga_nn = 0;
    $yyscore = $ynscore = $nyscore = $nnscore = 0.0;
  }

  foreach my $file ("spam.log", "nonspam.log") {
    # A log that cannot be opened would otherwise be treated as an empty
    # corpus and quietly skew every figure computed from the counts.
    open (my $in, '<', $file) or die "cannot open $file: $!\n";
    my $is_spam_log = ($file eq "spam.log");

    while (my $line = <$in>) {
      next if $line =~ /^#/;
      next if $line =~ /^$/;
      # $1 captures whatever follows the three leading fields: the hit list.
      if ($line !~ /^.\s+[-\d]+\s+\S+\s*(.*)/s) { warn "bad line: $line"; next; }

      my $hitlist = $1;
      $hitlist =~ s/\s+$//;         # drop the newline / trailing blanks

      my $score = 0;
      my @tests = ();
      foreach my $tst (split (/,/, $hitlist)) {
        next if ($tst eq '');                 # empty field from ",,"
        next if (!defined $scores{$tst});     # unknown rule: ignore

        if ($justcount) {
          $score += $scores{$tst};
        } else {
          push (@tests, $tst);
        }
      }

      if ($justcount) {
        my $flagged = ($score >= $threshold);
        if ($is_spam_log) {
          if ($flagged) { $ga_yy++; $yyscore += $score; }
          else          { $ga_yn++; $ynscore += $score; }
        } else {
          if ($flagged) { $ga_ny++; $nyscore += $score; }
          else          { $ga_nn++; $nnscore += $score; }
        }
      } else {
        $tests_hit{$count} = \@tests;
        $is_spam{$count} = $is_spam_log ? 1 : 0;
      }

      if ($is_spam_log) { $num_spam++; } else { $num_nonspam++; }
      $count++;
    }
    close ($in);
  }
  $num_tests = $count;
}


# readscores: regenerate and load the rule definitions.
#
# Runs "./parse-rules-for-masses -d <cffile>" (cffile defaults to "../rules"
# when --cffile was not given), then loads "./tmp/rules.pl" and copies the
# %rules hash it defines into %allrules.
# NOTE(review): %scores is never assigned in this file, so it is presumably
# also defined by tmp/rules.pl -- confirm against parse-rules-for-masses.
#
# Dies if the helper cannot be run or exits with a non-zero status.
sub readscores {
  if (!defined $argcffile) { $argcffile = "../rules"; }
  print "Reading scores from \"$argcffile\"...\n";

  # List form: the path is passed as a single argument without going
  # through the shell, so quotes, "$" or backticks in it are harmless.
  system ("./parse-rules-for-masses", "-d", $argcffile) == 0
      or die "./parse-rules-for-masses -d \"$argcffile\" failed (status $?)\n";

  require "./tmp/rules.pl";
  %allrules = %rules;           # ensure it stays global
}

# writescores_c: dump the current scores and ranges for the C code.
#
# Writes two files:
#   tmp/scores.data  "N<count>" followed, for each non-ignored rule in
#                    @index_to_rule order, by the records
#                    ".<index>", "n<name>", "b<score>", "m<mutable flag>",
#                    "l<range low>" and "h<range high>", one per line
#   tmp/scores.h     C declarations of the arrays sized to <count>
#
# Rules that are not mutable get both range ends pinned to their current
# score.  Dies if either file cannot be written.
sub writescores_c {
  my $output = '';
  my $size = 0;
  foreach my $name (@index_to_rule) {
    # jm: now, score-ranges-from-freqs has tflags to work from, so
    # it will always list all mutable tests.
    if (!defined $mutable_tests{$name}) {
      $mutable_tests{$name} = 0;
    }

    if ($mutable_tests{$name} == 0) {
      $range_lo{$name} = $range_hi{$name} = $scores{$name};
    } else {
      #$range_lo{$name} ||= 0.1;
      #$range_hi{$name} ||= 1.5;
    }

    if ($ignored_rule{$name}) { next; }

    # The position in the output must match the index handed out by
    # read_ranges(), because tests.data refers to rules by that index.
    if ($size != $rule_to_index{$name}) {
      warn "oops: size != rule_to_index{$name}: $size $rule_to_index{$name}";
    }

    $output .= ".".$size."\n".
                "n".$name."\n".
                "b".$scores{$name}."\n".
                "m".$mutable_tests{$name}."\n".
                "l".$range_lo{$name}."\n".
                "h".$range_hi{$name}."\n";
    $size++;
  }

  # Check open and close: buffered write errors (full disk, missing tmp/
  # directory) would otherwise leave a truncated file behind unnoticed.
  open (my $dat, '>', 'tmp/scores.data')
      or die "cannot write tmp/scores.data: $!\n";
  print $dat "N$size\n", $output;
  close ($dat) or die "error writing tmp/scores.data: $!\n";

  open (my $hdr, '>', 'tmp/scores.h')
      or die "cannot write tmp/scores.h: $!\n";
  print $hdr "

int num_scores;
unsigned char is_mutatable[$size]; 	/* er, is_mutable I think ;) */
double range_lo[$size];
double range_hi[$size];
double bestscores[$size];
double scores[$size];
char *score_names[$size];

/* readscores() is defined in tests.h */

";
  close ($hdr) or die "error writing tmp/scores.h: $!\n";
}

# writetests_c: dump the per-message hit data for the C code.
#
# Writes two files:
#   tmp/tests.h     C declarations sized from $num_tests and the largest
#                   per-message hit count, followed by the C source kept in
#                   this script's DATA section
#   tmp/tests.data  for each message number: ".<number>", "n<hit count>",
#                   "s<is_spam flag>" and one "t<rule index>" line per hit
#
# Hits on ignored rules and hits on rules without a C index are reported
# with a warning and left out.  Dies if either file cannot be written.
sub writetests_c {
  # figure out max hits per message (plus one spare slot)
  my $max_hits_per_msg = 0;
  foreach my $msg (0 .. $num_tests - 1) {
    my $hits = scalar @{$tests_hit{$msg}} + 1;
    if ($hits > $max_hits_per_msg) { $max_hits_per_msg = $hits; }
  }

  open (my $top, '>', 'tmp/tests.h')
      or die "cannot write tmp/tests.h: $!\n";
  print $top "

int num_tests = $num_tests;
int num_spam = $num_spam;
int num_nonspam = $num_nonspam;
int max_hits_per_msg = $max_hits_per_msg;
unsigned char num_tests_hit[$num_tests];
unsigned char is_spam[$num_tests];
unsigned short tests_hit[$num_tests][$max_hits_per_msg];

";
  print {$top} join ('', <DATA>);
  close ($top) or die "error writing tmp/tests.h: $!\n";

  open (my $dat, '>', 'tmp/tests.data')
      or die "cannot write tmp/tests.data: $!\n";

  foreach my $msg (0 .. $num_tests - 1) {
    print $dat ".".$msg."\n";

    my $out = '';
    $out .= "s".$is_spam{$msg}."\n";

    my $num_tests_hit = 0;
    foreach my $test (@{$tests_hit{$msg}}) {
      if ($test eq '') { next; }

      if ($ignored_rule{$test}) {
        warn "ignored rule $test got a hit in $msg!\n";
        next;
      }

      # Without an index the record would be a bare "t", which the C
      # loader parses as index 0 and so credits to the wrong rule.
      if (!defined $rule_to_index{$test}) {
        warn "test with no C index: $test\n";
        next;
      }

      $num_tests_hit++;
      $out .= "t".$rule_to_index{$test}."\n";

      if ($num_tests_hit >= $max_hits_per_msg) {
        die "Need to increase \$max_hits_per_msg";
      }
    }

    print $dat "n".$num_tests_hit."\n".$out;
  }
  close ($dat) or die "error writing tmp/tests.data: $!\n";
}

# read_ranges: load score ranges and mutability flags from tmp/ranges.data
# and hand out the C array index of every rule.
#
# If the file is missing, "make tmp/ranges.data" is tried first.  Each line
# of the form "<lo> <hi> <mutable> <rule>" sets %range_lo and %range_hi for
# that rule; other lines are skipped.  Rules whose name starts with "__" and
# rules whose range is 0..0 are marked in %ignored_rule and get no index.
# Every other rule gets the next free index (@index_to_rule and
# %rule_to_index) and a 0/1 entry in %mutable_tests.  Finally, rules in
# %allrules that were not listed in the file are appended in sorted order,
# again leaving out "__" rules.
sub read_ranges {
  system ("make tmp/ranges.data") unless -f 'tmp/ranges.data';

  # read ranges, and mutatableness, from ranges.data.
  open (my $ranges, '<', 'tmp/ranges.data')
  	or die "need to run score-ranges-from-freqs first!";

  my $next_index = 0;
  while (my $line = <$ranges>) {
    my ($lo, $hi, $mutable, $rule) = $line =~ /^(\S+) (\S+) (\d+) (\S+)$/
        or next;

    $range_lo{$rule} = $lo + 0;
    $range_hi{$rule} = $hi + 0;

    my $zero_range = ($range_lo{$rule} == 0.0 && $range_hi{$rule} == 0.0);
    if ($rule =~ /^__/ || $zero_range) {
      $ignored_rule{$rule} = 1;
      next;
    }

    $ignored_rule{$rule} = 0;
    $index_to_rule[$next_index] = $rule;
    $rule_to_index{$rule} = $next_index;
    $next_index++;

    $mutable_tests{$rule} = ($mutable + 0) ? 1 : 0;
  }
  close ($ranges);

  # catch up on the ones missed; seems to be userconf or 0-hitters mostly.
  foreach my $rule (sort keys %allrules) {
    next if exists $rule_to_index{$rule};
    next if $ignored_rule{$rule};

    if ($rule =~ /^__/) {
      $ignored_rule{$rule} = 1;
      next;
    }

    $index_to_rule[$next_index] = $rule;
    $rule_to_index{$rule} = $next_index;
    $next_index++;
  }
}

# evaluate: print the classification summary for the counters gathered by
# readlogs() in --count mode.
#
# Reports the four outcome counts ($ga_nn, $ga_yy, $ga_ny, $ga_yn) as a
# share of all messages and of their own corpus, the summed false-positive
# score weighted by $nybias, the summed false-negative score, and finally
# TCR, spam recall and spam precision, with $lambda as the relative cost of
# a false positive.
#
# Any ratio whose denominator is zero (empty corpus, or nothing classified
# as spam) is reported as 0 instead of aborting with a division by zero.
sub evaluate {
  # percentage of $part in $whole; 0 when $whole is zero
  my $pct = sub {
    my ($part, $whole) = @_;
    return $whole ? ($part / $whole) * 100.0 : 0.0;
  };

  printf ("\n# SUMMARY for threshold %3.1f:\n", $threshold);
  printf "# Correctly non-spam: %6d  %4.2f%%  (%4.2f%% of non-spam corpus)\n", $ga_nn,
      $pct->($ga_nn, $num_tests),
      $pct->($ga_nn, $num_nonspam);
  printf "# Correctly spam:     %6d  %4.2f%%  (%4.2f%% of spam corpus)\n" , $ga_yy,
      $pct->($ga_yy, $num_tests),
      $pct->($ga_yy, $num_spam);
  printf "# False positives:    %6d  %4.2f%%  (%4.2f%% of nonspam, %6.0f weighted)\n", $ga_ny,
      $pct->($ga_ny, $num_tests),
      $pct->($ga_ny, $num_nonspam),
      $nyscore*$nybias;
  printf "# False negatives:    %6d  %4.2f%%  (%4.2f%% of spam, %6.0f weighted)\n", $ga_yn,
      $pct->($ga_yn, $num_tests),
      $pct->($ga_yn, $num_spam),
      $ynscore;

  # convert to the TCR metrics used in the published lit
  my $nspamspam = $ga_yy;       # spam classified as spam
  my $nspamlegit = $ga_yn;      # spam classified as legitimate
  my $nlegitspam = $ga_ny;      # legitimate classified as spam
  my $nlegit = $num_nonspam;
  my $nspam = $num_spam;

  # weighted error rate of the filter, and of the "no filter" baseline
  my $weighted_total = $lambda * $nlegit + $nspam;
  my $werr = $weighted_total
      ? ($lambda * $nlegitspam + $nspamlegit) / $weighted_total
      : 0;
  my $werr_base = $weighted_total ? $nspam / $weighted_total : 0;

  $werr ||= 0.000001;     # avoid / by 0
  my $tcr = $werr_base / $werr;

  my $sr = $pct->($nspamspam, $nspam);
  my $sp = $pct->($nspamspam, $nspamspam + $nlegitspam);
  printf "# TCR: %3.6f   SpamRecall: %3.6f%%   SpamPrecision: %3.6f%%\n",
                  $tcr, $sr, $sp;
}

__DATA__

/*
 * loadtests: fill num_tests_hit[], is_spam[] and tests_hit[][] from
 * tmp/tests.data.
 *
 * Each line is a one-character command followed by a decimal number:
 *   .N  select message N for the records that follow
 *   nN  message has N hits; restart the hit slot counter
 *   sN  is_spam flag of the message
 *   tN  next hit of the message: rule index N
 * Other lines are ignored.  Exits with status 1 if the file cannot be
 * opened.
 */
void loadtests (void) {
  FILE *fin = fopen ("tmp/tests.data", "r");
  char buf[256];
  int file = 0;
  int tnum = 0;

  /* fgets() on a NULL stream is undefined behaviour; fail cleanly. */
  if (fin == NULL) {
    perror ("tmp/tests.data");
    exit (1);
  }

  while (fgets (buf, 255, fin) != NULL) {
    char cmd;
    long arg;

    cmd = (char) *buf;
    arg = strtol (buf+1, NULL, 10);

    if (cmd == '.') {
      file = arg;

    } else if (cmd == 'n') {
      tnum = 0;
      num_tests_hit[file] = arg;

    } else if (cmd == 's') {
      is_spam[file] = arg;

    } else if (cmd == 't') {
      tests_hit[file][tnum] = arg; tnum++;
    }
  }
  fclose(fin);

  /* file holds the last message number seen, hence the +1 */
  printf ("Read test results for %d messages.\n", file+1);
}

/*
 * loadscores: fill num_scores, bestscores[], range_lo[], range_hi[],
 * score_names[] and is_mutatable[] from tmp/scores.data.
 *
 * Each line is a one-character command followed by its argument:
 *   NN  total number of scores
 *   .N  select score slot N for the records that follow
 *   nS  rule name (copied with strdup, never freed)
 *   bF  current score
 *   mN  mutable flag
 *   lF  lower end of the allowed range
 *   hF  upper end of the allowed range
 * Other lines are ignored.  Exits with status 1 if the file cannot be
 * opened.
 */
void loadscores (void) {
  FILE *fin = fopen ("tmp/scores.data", "r");
  char buf[256];
  int snum = 0;

  /* fgets() on a NULL stream is undefined behaviour; fail cleanly. */
  if (fin == NULL) {
    perror ("tmp/scores.data");
    exit (1);
  }

  while (fgets (buf, 255, fin) != NULL) {
    char cmd;
    long arg;
    double argf;	/* double: the target arrays are double, keep full precision */
    char *str, *white;

    /* parse the argument every way it may be needed; cmd picks one */
    cmd = (char) *buf;
    arg = strtol (buf+1, NULL, 10);
    argf = strtod (buf+1, NULL);
    str = buf+1;

    /* strip the line terminator so names are stored without it */
    while ((white = strchr (str, '\n')) != NULL) {
      *white = '\0';
    }

    if (cmd == '.') {
      snum = arg;

    } else if (cmd == 'N') {
      num_scores = arg;

    } else if (cmd == 'b') {
      bestscores[snum] = argf;

    } else if (cmd == 'l') {
      range_lo[snum] = argf;

    } else if (cmd == 'h') {
      range_hi[snum] = argf;

    } else if (cmd == 'n') {
      score_names[snum] = strdup (str);	/* leaky leak ;) */

    } else if (cmd == 'm') {
      is_mutatable[snum] = arg;
    }
  }
  fclose(fin);

  printf ("Read scores for %d tests.\n", num_scores);
}

