#!/usr/bin/perl -w
#
# Copyright (C) 2016 Julian Andres Klode <jak@jak-linux.org>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.


=head1 NAME

triehash - Generate a perfect hash function derived from a trie.

=cut

use strict;
use warnings;
use Getopt::Long;

=head1 SYNOPSIS

B<triehash> [S<I<option>>] [S<I<input file>>]

=head1 DESCRIPTION

triehash takes a list of words from the input file and generates a function
and an enumeration that map each word to a value.

=head1 INPUT FILE FORMAT

The file consists of multiple lines of the form:

    [label ~ ] word [= value]

This maps word to value, and generates an enumeration with entries of the form:

    label = value

If I<label> is undefined, the word will be used; for C output each underscore
is doubled and each minus character is replaced by an underscore. If value is
undefined it is counted upwards from the last value.

There may also be one line of the format

    [ label ~] = value

Which defines the value to be used for non-existing keys. Note that this also
changes default value for other keys, as for normal entries. So if you place

    = 0

at the beginning of the file, unknown strings map to 0, and the other strings
map to values starting with 1. If label is not specified, the default is
I<Unknown>.

=head1 OPTIONS

=over 4

=item B<-C>I<.c file> B<--code>=I<.c file>

Generate code in the given file.

=item B<-H>I<header file> B<--header>=I<header file>

Generate a header in the given file, containing a declaration of the hash
function and an enumeration.

=item B<--enum-name=>I<word>

The name of the enumeration.

=item B<--function-name=>I<word>

The name of the function.

=item B<--namespace=>I<name>

Put the function and enum into a namespace (C++)

=item B<--class=>I<name>

Put the function and enum into a class (C++)

=item B<--enum-class>

Generate an enum class instead of an enum (C++)

=item B<--counter-name=>I<name>

Use I<name> for a counter that is set to the latest entry in the enumeration
+ 1. This can be useful for defining array sizes.

=item B<--extern-c>

Wrap everything into an extern "C" block. Not compatible with the C++
options, as a header with namespaces, classes, or enum classes is not
valid C.

=item B<--multi-byte>=I<value>

Generate code reading multiple bytes at once. The value is a string of digits,
each digit I<n> enabling reads of 2**I<n> bytes. The default value is 320,
meaning that 8, 4, and single byte reads are enabled. Specify 0 to disable
multi-byte completely, or add 1 (for example 3210) if you also want to allow
2-byte reads. 2-byte reads are disabled by default because they negatively
affect performance on older Intel architectures.

This generates code for both multiple bytes and single byte reads, but only
enables the multiple byte reads on GNU C compatible compilers, as the following
extensions are used:

=over 8

=item Byte-aligned integers

We must be able to generate integers that are aligned to a single byte using:

    typedef uint64_t __attribute__((aligned (1))) triehash_uu64;

=item Byte-order

The macros __BYTE_ORDER__ and __ORDER_LITTLE_ENDIAN__ must be defined.

=back

We forcefully disable multi-byte reads on platforms where the variable
I<__ARM_ARCH> is defined and I<__ARM_FEATURE_UNALIGNED> is not defined,
as there is a measurable overhead from emulating the unaligned reads on
ARM.

=item B<--language=>I<language>

Generate a file in the specified language. Currently known are 'C' and 'tree',
the latter generating a tree.

=item B<--include=>I<header>

Add the header to the include statements of the header file. The value must
be surrounded by quotes or angle brackets for C code. May be specified multiple
times.

=back

=cut

# Value and enumeration label used for strings that are not in the word list.
# Both can be overridden by a "[label ~] = value" line in the input file.
my $unknown = -1;
my $unknown_label = "Unknown";
# First value assigned to a word that does not specify one itself.
my $counter_start = 0;
my $enum_name = "PerfectKey";
my $function_name = "PerfectHash";
my $enum_class = 0;

# Output destinations; "-" selects standard output.
my $code_name = "-";
my $header_name = "-";
# Output file handles; they are assigned by the code generator's open_output().
my $code;
my $header;
my $ignore_case = 0;
# Each digit n in this string enables reads of 2**n bytes (3 => 8 bytes,
# 2 => 4 bytes, 1 => 2 bytes); see Trie::alignpower2().
my $multi_byte = "320";
my $language = 'C';
my $counter_name = undef;
my @includes = ();


# Configure() is the current name of the deprecated Getopt::Long::config().
Getopt::Long::Configure('default',
                        'bundling',
                        'no_getopt_compat',
                        'no_auto_abbrev',
                        'permute',
                        'auto_help');

GetOptions ("code|C=s" => \$code_name,
            "header|H=s"   => \$header_name,
            "function-name=s" => \$function_name,
            "ignore-case" => \$ignore_case,
            "enum-name=s" => \$enum_name,
            "language|l=s" => \$language,
            "multi-byte=s" => \$multi_byte,
            "enum-class" => \$enum_class,
            "include=s" => \@includes,
            "counter-name=s" => \$counter_name)
    or die("Could not parse options!");


# A minimal trie whose edges may span several bytes. Every node is a hash with
# three entries:
#
# children - maps an edge key (one or more bytes) to the child node
# value    - the value stored at this node, undef if no word ends here
# label    - the symbolic name belonging to the value
#
# Edge keys on one level may initially differ in length; rebuild_tree()
# returns a normalized copy in which all keys of a level are equally long.
package Trie {

    # Create an empty node: no children, no value, no label.
    sub new {
        my ($class) = @_;

        return bless {
            children => {},
            value    => undef,
            label    => undef,
        }, $class;
    }

    # Return the largest enabled read size (8, 4, or 2 bytes) that does not
    # exceed $length, or 1 if none fits. A size of 2**n bytes is enabled when
    # the digit n occurs in $multi_byte.
    sub alignpower2 {
        my ($self, $length) = @_;

        foreach my $exponent (3, 2, 1) {
            my $size = 1 << $exponent;
            return $size if ($length >= $size && $multi_byte =~ /$exponent/);
        }

        return 1;
    }

    # Split the key into (head, tail), where head is as long as the largest
    # read size allowed for the key's length.
    sub split_key {
        my ($self, $key) = @_;
        my $head_length = $self->alignpower2(length $key);

        my $head = substr($key, 0, $head_length);
        my $tail = substr($key, $head_length);

        return ($head, $tail);
    }

    # Store $label and $value under $key, creating intermediate nodes as
    # needed. An entry that already exists for $key is overwritten.
    sub insert {
        my ($self, $key, $label, $value) = @_;

        # Walk down one head block at a time until the key is used up.
        my $node = $self;
        while (length($key) > 0) {
            my ($head, $tail) = $node->split_key($key);

            $node->{children}{$head} //= Trie->new;
            $node = $node->{children}{$head};
            $key = $tail;
        }

        $node->{label} = $label;
        $node->{value} = $value;
        return;
    }

    # Build a new trie holding only the words that are exactly $togo bytes
    # long. This lets the generated code switch on the word length first and
    # then use a per-length trie. Returns undef if no word of that length
    # exists below this node.
    sub filter_depth {
        my ($self, $togo) = @_;

        my $new = Trie->new;

        # All bytes consumed: this node is the end point of a word.
        if ($togo == 0) {
            $new->{value} = $self->{value};
            $new->{label} = $self->{label};
            return $new;
        }

        my $kept = 0;
        foreach my $key (sort keys %{$self->{children}}) {
            my $child = $self->{children}{$key};
            next unless ($togo > length($key) || defined $child->{value});

            my $filtered = $child->filter_depth($togo - length($key));
            next unless defined $filtered;

            $new->{children}{$key} = $filtered;
            $kept++;
        }

        return $kept ? $new : undef;
    }

    # (helper for rebuild_tree)
    # Insert every value node at or below this node into $trie, using $prefix
    # followed by the node's path below this node as the key.
    sub reinsert_value_nodes_into {
        my ($self, $trie, $prefix) = @_;

        if (defined $self->{value}) {
            $trie->insert($prefix, $self->{label}, $self->{value});
        }

        my $children = $self->{children};
        foreach my $key (sort keys %{$children}) {
            $children->{$key}->reinsert_value_nodes_into($trie, $prefix . $key);
        }
    }

    # (helper for rebuild_tree; the misspelled name is kept for callers)
    # Determine how many leading bytes of $key can be matched in one go.
    # Normally that is the largest enabled read size that is less than or
    # equal to the length of the key. In an ASCII-optimised case-insensitive
    # trie, which simply ORs each byte with 0x20, the block must end before
    # the first ambiguous character:
    #
    # For example, the words a-bc and a\rbc are identical in such a situation:
    #       '-' | 0x20 == '-' == '\r' | 0x20
    # Switching on all 4 bytes at once would be wrong, so the ambiguous
    # character has to be processed on its own.
    sub find_ealier_split {
        my ($self, $key) = @_;

        my $usable = length $key;

        if ($ignore_case) {
            for my $pos (0..length($key)-1) {
                next unless main::ambiguous(substr($key, $pos, 1));

                # A leading ambiguous character is taken alone; otherwise
                # only the bytes in front of it are usable.
                $usable = $pos || 1;
                last;
            }
        }

        return $self->alignpower2($usable);
    }

    # Return a normalized copy of the trie. Each key is split before
    # ambiguous characters as explained in find_ealier_split(), and the
    # smallest such split is applied to all keys of a level, so that every
    # level has keys of uniform length (suitable for a multi-byte switch).
    sub rebuild_tree {
        my $self = shift;
        my @keys = sort keys %{$self->{children}};

        # The shortest block any key of this level allows
        my $new_split;
        foreach my $key (@keys) {
            my $candidate = $self->find_ealier_split($key);
            $new_split = $candidate if (!defined($new_split) || $candidate < $new_split);
        }

        # Start building a new uniform trie
        my $newself = Trie->new;
        $newself->{label} = $self->{label};
        $newself->{value} = $self->{value};

        foreach my $key (@keys) {
            my $head = substr($key, 0, $new_split);
            my $tail = substr($key, $new_split);

            # Collect everything below $key in the node for $head, with
            # $tail pushed down one level.
            my $bucket = ($newself->{children}{$head} //= Trie->new);
            $self->{children}{$key}->reinsert_value_nodes_into($bucket, $tail);

            # Only the first special character of each key was handled here,
            # so the lower levels have to be rebuilt as well.
            $newself->{children}{$head} = $bucket->rebuild_tree();
        }

        return $newself;
    }
}

# Code generator for C and C++
#
# Writes the enumeration and the function declaration to $header and the
# lookup functions to $code. The generator reads the file-level option
# variables ($code_name, $header_name, $enum_name, $function_name, ...).
package CCodeGen {
    # The lookup function is declared static if code and header are written
    # to the same file.
    my $static = ($code_name eq $header_name) ? "static" : "";
    my $enum_specifier = $enum_class ? "enum class" : "enum";

    # Create a generator object; it carries no state of its own.
    sub new {
        my $class = shift;
        my $self = {};
        bless $self, $class;

        return $self;
    }

    # Open $code and $header for writing. The name "-" selects standard
    # output, and identical names share one handle. Dies if a file cannot
    # be opened.
    sub open_output {
        my $self = shift;
        if ($code_name ne "-") {
            open($code, '>', $code_name) or die "Cannot open $code_name: $!" ;
        } else {
            $code = *STDOUT;
        }
        if($code_name eq $header_name) {
            $header = $code;
        } elsif ($header_name ne "-") {
            open($header, '>', $header_name) or die "Cannot open $header_name: $!" ;
        } else {
            $header = *STDOUT;
        }
    }

    # Derive an enumerator name from a word: underscores are doubled first,
    # so that the underscore substituted for a minus stays distinguishable.
    sub word_to_label {
        my ($class, $word) = @_;

        $word =~ s/_/__/g;
        $word =~ s/-/_/g;
        return $word;
    }

    # (helper for case_label)
    # Return a C character constant for a single character. The single quote
    # and the backslash are escaped, as they would otherwise produce an
    # invalid constant.
    sub _char_literal {
        my ($self, $char) = @_;

        $char =~ s/(['\\])/\\$1/;
        return "'" . $char . "'";
    }

    # Return a case label for $key. In single-byte mode this is the character
    # constant of the first character; in multi-byte mode it is an expression
    # or-ing one onechar() term per character, which the generated macro
    # shifts into place.
    sub case_label {
        my ($self, $key) = @_;

        return $self->_char_literal(substr($key, 0, 1)) if not $multi_byte;

        my $output = '0';
        my $bits = 8 * length($key);

        for my $i (0..length($key)-1) {
            $output .= sprintf("| onechar(%s, %d, %d)", $self->_char_literal(substr($key, $i, 1)), 8 * $i, $bits);
        }

        return $output;
    }

    # Return an appropriate read instruction for $length bytes from $offset
    sub switch_key {
        my ($self, $offset, $length) = @_;

        return "string[$offset]" if $length == 1;
        return sprintf("*((triehash_uu%s*) &string[$offset])", $length * 8);
    }

    # Render the trie as nested switch statements on $fh, so that the
    # generated code matches the longest prefix. $indent is the indentation
    # level and $index the byte offset in the string; both default to 0.
    sub print_table {
        my ($self, $trie, $fh, $indent, $index) = @_;
        $indent //= 0;
        $index //= 0;

        # If we have children, try to match them.
        if (%{$trie->{children}}) {
            # The difference between lowercase and uppercase alphabetical characters
            # is that they have one bit flipped. If we have alphabetical characters
            # in the search space, and the entire search space works fine if we
            # always turn on the flip, just OR the character we are switching over
            # with the bit.
            my $want_use_bit = 0;
            my $can_use_bit = 1;
            my $key_length = 0;
            foreach my $key (sort keys %{$trie->{children}}) {
                $can_use_bit &= not main::ambiguous($key);
                $want_use_bit |= ($key =~ /^[a-zA-Z]+$/);
                # The length of the last key is used for the whole level.
                $key_length = length($key);
            }

            if ($ignore_case && $can_use_bit && $want_use_bit) {
                printf $fh (("    " x $indent) . "switch(%s | 0x%s) {\n", $self->switch_key($index, $key_length), "20" x $key_length);
            } else {
                printf $fh (("    " x $indent) . "switch(%s) {\n", $self->switch_key($index, $key_length));
            }

            my $notfirst = 0;
            foreach my $key (sort keys %{$trie->{children}}) {
                # Terminate the previous case before starting the next one
                if ($notfirst) {
                    printf $fh ("    " x $indent . "    break;\n");
                }
                if ($ignore_case) {
                    printf $fh ("    " x $indent . "case %s:\n", $self->case_label(lc($key)));
                    # The uppercase label is only needed if the switch
                    # expression is not already folded with | 0x20.
                    printf $fh ("    " x $indent . "case %s:\n", $self->case_label(uc($key))) if lc($key) ne uc($key) && !($can_use_bit && $want_use_bit);
                } else {
                    printf $fh ("    " x $indent . "case %s:\n", $self->case_label($key));
                }

                $self->print_table($trie->{children}{$key}, $fh, $indent + 1, $index + length($key));

                $notfirst=1;
            }

            printf $fh ("    " x $indent . "}\n");
        }


        # This node has a value, so it is a possible end point. If no children
        # matched, we have found our longest prefix.
        if (defined $trie->{value}) {
            printf $fh ("    " x $indent . "return %s;\n", ($enum_class ? "${enum_name}::" : "").$trie->{label});
        }

    }

    # Print one "label = value," enumerator line to $fh for every value node
    # of the trie. $sofar accumulates the path to the current node; it is
    # passed along but not used for the output.
    sub print_words {
        my ($self, $trie, $fh, $indent, $sofar) = @_;

        $indent //= 0;
        $sofar //= "";


        printf $fh ("    " x $indent."%s = %s,\n", $trie->{label}, $trie->{value}) if defined $trie->{value};

        foreach my $key (sort keys %{$trie->{children}}) {
            $self->print_words($trie->{children}{$key}, $fh, $indent, $sofar . $key);
        }
    }

    # Print one static lookup function per word length in %lengths to $code.
    # Each function handles only the words of its length and returns the
    # unknown label if none matches.
    sub print_functions {
        my ($self, $trie, %lengths) = @_;
        foreach my $local_length (sort { $a <=> $b } (keys %lengths)) {
            print $code ("static enum ${enum_name} ${function_name}${local_length}(const char *string)\n");
            print $code ("{\n");
            $self->print_table($trie->filter_depth($local_length)->rebuild_tree(), $code, 1);
            printf $code ("    return %s%s;\n", ($enum_class ? "${enum_name}::" : ""), $unknown_label);
            print $code ("}\n");
        }
    }

    # Generate the complete output: the header with the enumeration and the
    # declaration, then the per-length functions and the dispatching function.
    # $num_values is the value assigned to the counter enumerator, if one was
    # requested. Note that $multi_byte is left at 0 when multi-byte code was
    # generated.
    sub main {
        my ($self, $trie, $num_values, %lengths) = @_;
        print $header ("#ifndef TRIE_HASH_${function_name}\n");
        print $header ("#define TRIE_HASH_${function_name}\n");
        print $header ("#include <stddef.h>\n");
        print $header ("#include <stdint.h>\n");
        foreach my $include (@includes) {
            print $header ("#include $include\n");
        }
        # These lines contain user-supplied text, so they are written with
        # print; printf would interpret any % in them as a conversion.
        print $header ("enum { $counter_name = $num_values };\n") if (defined($counter_name));
        print $header ("${enum_specifier} ${enum_name} {\n");
        $self->print_words($trie, $header, 1);
        print $header ("    $unknown_label = $unknown,\n");
        print $header ("};\n");
        print $header ("$static enum ${enum_name} ${function_name}(const char *string, size_t length);\n");

        print $code ("#include \"$header_name\"\n") if ($header_name ne $code_name);

        if ($multi_byte) {
            print $code ("#ifdef __GNUC__\n");
            for (my $i=16; $i <= 64; $i *= 2) {
                print $code ("typedef uint${i}_t __attribute__((aligned (1))) triehash_uu${i};\n");
                print $code ("typedef char static_assert${i}[__alignof__(triehash_uu${i}) == 1 ? 1 : -1];\n");
            }

            print $code ("#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__\n");
            print $code ("#define onechar(c, s, l) (((uint64_t)(c)) << (s))\n");
            print $code ("#else\n");
            print $code ("#define onechar(c, s, l) (((uint64_t)(c)) << (l-8-s))\n");
            print $code ("#endif\n");
            print $code ("#if (!defined(__ARM_ARCH) || defined(__ARM_FEATURE_UNALIGNED)) && !defined(TRIE_HASH_NO_MULTI_BYTE)\n");
            print $code ("#define TRIE_HASH_MULTI_BYTE\n");
            print $code ("#endif\n");
            print $code ("#endif /*GNUC */\n");

            print $code ("#ifdef TRIE_HASH_MULTI_BYTE\n");
            $self->print_functions($trie, %lengths);
            # Switch to single-byte mode for the fallback implementation
            $multi_byte = 0;
            print $code ("#else\n");
            $self->print_functions($trie, %lengths);
            print $code ("#endif /* TRIE_HASH_MULTI_BYTE */\n");
        } else {
            $self->print_functions($trie, %lengths);
        }

        print $code ("$static enum ${enum_name} ${function_name}(const char *string, size_t length)\n");
        print $code ("{\n");
        print $code ("    switch (length) {\n");
        foreach my $local_length (sort { $a <=> $b } (keys %lengths)) {
            print $code ("    case $local_length:\n");
            print $code ("        return ${function_name}${local_length}(string);\n");
        }
        print $code ("    default:\n");
        printf $code ("        return %s%s;\n", ($enum_class ? "${enum_name}::" : ""), $unknown_label);
        print $code ("    }\n");
        print $code ("}\n");

        # Print end of header here, in case header and code point to the same file
        print $header ("#endif                       /* TRIE_HASH_${function_name} */\n");
    }
}

# Decide whether the |0x20 optimization of a case-insensitive trie is unsafe
# for a word.
#
# A character is unambiguous only if setting the 1<<5 (0x20) bit yields its
# lowercase form and clearing that bit yields its uppercase form. A word is
# ambiguous as soon as one of its characters is.
#
# Returns 1 if $word is ambiguous, 0 otherwise (including the empty word).
sub ambiguous {
    my ($word) = @_;

    for my $position (0 .. length($word) - 1) {
        my $char = substr($word, $position, 1);
        my $byte = ord($char);

        # 0x20 must be exactly the lowercase bit of this character
        my $lower_matches = ord(lc($char)) == ($byte | 0x20);
        my $upper_matches = ord(uc($char)) == ($byte & ~0x20);

        return 1 unless $lower_matches && $upper_matches;
    }

    return 0;
}

# Read the word list and build the trie from it.
#
# The list is read from the file named by the first command-line argument,
# or from standard input if no input file was given. Each line has the form
# "[label ~] word [= value]", or "[label ~] = value" to set the value (and
# optionally the label) used for unknown words. $codegen provides
# word_to_label(), which derives a label for words that do not specify one.
#
# Returns the trie, the counter value following the last entry, and a hash
# whose keys are the lengths of all words. Dies if the input file cannot be
# opened or a line contains neither a word nor a value.
sub build_trie {
    my $codegen = shift;
    my $trie = Trie->new;

    my $counter = $counter_start;
    my %lengths;

    # The input file is optional; without one, read standard input.
    my $input;
    my $opened_file = defined $ARGV[0];
    if ($opened_file) {
        open($input, '<', $ARGV[0]) or die "Cannot open ".$ARGV[0].": $!";
    } else {
        $input = \*STDIN;
    }

    while (my $line = <$input>) {
        # The value may be followed by whitespace or by the end of the
        # input, so that a final line without a newline keeps its value.
        my ($label, $word, $value) = $line =~/\s*(?:([^~\s]+)\s*~)?(?:\s*([^~=\s]+)\s*)?(?:=\s*(\S+)(?:\s+|\z))?\s*/;

        if (defined $word) {
            # An explicit value restarts the counting at that value
            $counter = $value if defined($value);
            $label //= $codegen->word_to_label($word);

            $trie->insert($word, $label, $counter);
            $lengths{length($word)} = 1;
            $counter++;
        } elsif (defined $value) {
            # A line without a word defines the entry for unknown words
            $unknown = $value;
            $unknown_label = $label if defined($label);
            $counter = $value + 1;
        } else {
            die "Invalid line: $line";
        }
    }

    close($input) if $opened_file;

    return ($trie, $counter, %lengths);
}

# Generates an ASCII art tree
#
# All output is written to the file-level handle $code, which open_output()
# sets up.
package TreeCodeGen {

    # Create a generator object; it carries no state of its own.
    sub new {
        my $class = shift;
        my $self = {};
        bless $self, $class;

        return $self;
    }

    # Labels are shown as the words themselves.
    sub word_to_label {
        my ($self, $word) = @_;
        return $word;
    }

    # Print the initial trie, the rebuilt trie, and the rebuilt trie for each
    # word length in %lengths, each introduced by a boxed heading. $counter
    # is accepted for compatibility with the other generators and not used.
    sub main {
        my ($self, $trie, $counter, %lengths) = @_;
        printf $code ("┌────────────────────────────────────────────────────┐\n");
        printf $code ("│                   Initial trie                     │\n");
        printf $code ("└────────────────────────────────────────────────────┘\n");
        $self->print($trie);
        printf $code ("┌────────────────────────────────────────────────────┐\n");
        printf $code ("│                   Rebuilt trie                     │\n");
        printf $code ("└────────────────────────────────────────────────────┘\n");
        $self->print($trie->rebuild_tree());

        foreach my $local_length (sort { $a <=> $b } (keys %lengths)) {
            printf $code ("┌────────────────────────────────────────────────────┐\n");
            printf $code ("│              Trie for words of length %-4d         │\n", $local_length);
            printf $code ("└────────────────────────────────────────────────────┘\n");
            $self->print($trie->filter_depth($local_length)->rebuild_tree());
        }
    }

    # Open $code for writing; the name "-" selects standard output. Dies if
    # the output file cannot be opened.
    sub open_output {
        my $self = shift;
        if ($code_name ne "-") {
            # Report the output file that failed, not the input file
            open($code, '>', $code_name) or die "Cannot open $code_name: $!" ;
        } else {
            $code = *STDOUT;
        }
    }

    # Print a trie: the label of the node (if any) ends the current line,
    # then each child follows on a line of its own, indented by $depth levels.
    sub print {
        my ($self, $trie, $depth) = @_;
        $depth //= 0;

        print $code (" → ") if defined($trie->{label});
        print $code ($trie->{label} // "", "\n");
        foreach my $key (sort keys %{$trie->{children}}) {
            print $code ("│   " x ($depth), "├── $key");
            $self->print($trie->{children}{$key}, $depth + 1);
        }
    }
}

# Maps each supported --language value to the package that implements it.
my %codegens = (C => "CCodeGen", tree => "TreeCodeGen");

if (!defined($codegens{$language})) {
    die "Unknown language $language. Valid choices: ", join(", ", keys %codegens);
}

# Read the word list first, then let the chosen generator write the output.
my $codegen = $codegens{$language}->new();
my ($trie, $counter, %lengths) = build_trie($codegen);

$codegen->open_output();
$codegen->main($trie, $counter, %lengths);


=head1 LICENSE

triehash is available under the MIT/Expat license, see the source code
for more information.

=head1 AUTHOR

Julian Andres Klode <jak@jak-linux.org>

=cut