#! /usr/bin/env python

"""
A simple Hashcash implementation

Visit U{http://www.hashcash.org} for more info about
the theory and usage of hashcash.

Run this module through epydoc to get pretty doco.

Overview:
    - implements a class L{HashCash}, with very configurable parameters
    - offers two convenience wrapper functions, L{generate} and L{verify},
      for those who can't be bothered instantiating a class
    - given a string s, L{generate} (or L{HashCash.generate}) produces a
      hashcash token string t, as binary or base64
    - generating t consumes a lot of cpu time
    - verifying t against s is almost instantaneous
    - this implementation produces clusters of tokens, to even out
      the token generation time

Performance:
    - this implementation is vulnerable to:
        - people with lots of computers, especially big ones
        - people writing bruteforcers in C (python is way slow)
    - even with the smoothing effect of creating token clusters,
      the time taken to create a token can vary by a factor of 7

Theory of this implementation:

    - a hashcash token is created by a brute-force algorithm
      of finding an n-bit partial hash collision

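    - given a string s, and a quality level q, generate each token
      chunk as a string h of chunksize bytes, such that:

        1. h != s
        2. len(h) == chunksize
        3. ((sha(s) XOR sha(h)) AND (2**q - 1)) == 0

    - in other words, sha(h) and sha(s) have their q least
      significant bits in common

    - a complete token is numchunks such chunks, all distinct,
      concatenated together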

If you come up with a faster, but PURE PYTHON implementation,
using only modules included in standard python distribution,
please let me know so I can upgrade mine or link to yours.

Written by David McNab, August 2004
Released to the public domain.
"""
import sha, array, random, base64, math
from random import randint

# module-level alias for sha.new, used when hashing candidate chunks in
# HashCash.generate
shanew = sha.new

# your own config settings - set these to get a good trade-off between
# token size and uniformity of time taken to generate tokens
#
# the final token size will be defaultChunkSize * defaultNumChunks bytes
# for binary tokens, or defaultNumChunks * ceil(4/3 * defaultChunkSize)
# characters for base64 tokens (each chunk is base64-encoded on its own,
# with its '=' padding stripped)
#
# the reason for building a token out of multiple token chunks is to
# try to even out the time taken for token generation
#
# without this, token generation time is very random, with some tokens
# generating almost instantaneously, and other tokens taking ages

defaultChunkSize = 3        # size of each chunk in a token, in bytes
defaultNumChunks = 12       # number of chunks in each token
defaultQuality = 12         # number of partial hash collision bits required
defaultFormat = 'base64'    # by default, return tokens in base64 format
defaultVerbosity = 0        # increase this to get more verbose output

class HashCash:
    """
    Class for creating/verifying hashcash tokens

    A token consists of C{numchunks} distinct chunks laid end to end.
    Each chunk is exactly C{chunksize} bytes long, and its SHA-1 hash has
    the same C{quality} least-significant bits as the SHA-1 hash of the
    value the token is bound to. In 'base64' format every chunk is
    base64-encoded on its own with the '=' padding removed, so each
    encoded chunk is C{b64ChunkLen} characters long.

    Feel free to subclass this, overriding the default attributes:
        - chunksize
        - numchunks
        - quality
        - format
        - verbosity
    """
    # override these at your pleasure

    chunksize = defaultChunkSize
    numchunks = defaultNumChunks
    quality = defaultQuality
    format = defaultFormat
    verbosity = defaultVerbosity

    def __init__(self, **kw):
        """
        Create a HashCash object

        Keywords:
            - chunksize - size of each token chunk, in bytes
            - numchunks - number of chunks per token
            - quality - strength of token, in bits:
                - legal values are 1 to 160
                - typical values are 10 to 30, larger values taking much
                  longer to generate
            - format - 'base64' to output tokens in base64 format; any other
              value causes tokens to be generated in binary string format
            - verbosity - verbosity of output messages:
                - 0 = silent
                - 1 = critical only
                - 2 = noisy

        Keywords not listed above are silently ignored; omitted keywords
        fall back to the class attributes.
        """
        for key in ['chunksize', 'numchunks', 'quality', 'format', 'verbosity']:
            if key in kw:
                setattr(self, key, kw[key])

        # length of one chunk once base64-encoded with '=' padding removed
        self.b64ChunkLen = int(math.ceil(self.chunksize * 4.0 / 3))

    def generate(self, value):
        """
        Generate a hashcash token against string 'value'

        Random C{chunksize}-byte candidates are drawn until C{numchunks}
        distinct ones pass the hash test, so this costs on the order of
        numchunks * 2 ** quality hash computations. If fewer than
        C{numchunks} distinct chunks can possibly pass (e.g. a tiny
        chunksize combined with a high quality), this never returns.

        @param value: the string the token is bound to
        @return: the token - numchunks * chunksize bytes in binary format,
            or numchunks * b64ChunkLen characters in base64 format
        """
        chunksize = self.chunksize
        chunksPerToken = self.numchunks
        if chunksPerToken <= 0:
            # nothing to search for (and the loop below would never end);
            # verify() accepts the empty token for this configuration
            return ""

        # mask is an int with the least-significant 'quality' bits set to 1
        mask = 2 ** self.quality - 1
        hV = shanew(value).digest()
        nHV = intify(hV)

        # largest integer that fits in chunksize bytes (randint is inclusive)
        maxTokInt = 2 ** (chunksize * 8) - 1

        tokenChunks = []

        # loop around generating random chunks until we get one whose hash,
        # when xor'ed with the hash of value, has its 'quality' least
        # significant bits all zero
        while True:
            sNTok = binify(randint(0, maxTokInt))
            # binify drops leading zero bytes; pad back to exactly chunksize
            # bytes, otherwise verify() cannot split the token into chunks
            sNTok = "\0" * (chunksize - len(sNTok)) + sNTok
            hSNTok = shanew(sNTok).digest()
            if hSNTok == hV:
                # verify() rejects a chunk that is the value itself
                continue
            if (nHV ^ intify(hSNTok)) & mask != 0:
                continue

            # got a good chunk
            if self.format == 'base64':
                if not self._checkBase64(sNTok):
                    # chunk fails to encode/decode base64
                    if self.verbosity >= 2:
                        print("Ditching bad candidate token")
                    continue
                bSNTok = self._enc64(sNTok)
                if self.verbosity >= 2:
                    print("encoded %s to %s, expect chunklen %s" % (
                        repr(sNTok), repr(bSNTok), self.b64ChunkLen))
                sNTok = bSNTok

            # verify() rejects tokens containing duplicate chunks
            if sNTok in tokenChunks:
                continue
            tokenChunks.append(sNTok)
            if len(tokenChunks) == chunksPerToken:
                return "".join(tokenChunks)

    def verify(self, value, token):
        """
        Verifies a hashcash token against string 'value'

        A token exactly chunksize * numchunks long is treated as binary;
        any other token must be base64 and exactly b64ChunkLen * numchunks
        characters long. Every chunk must pass the hash test, no chunk may
        appear twice, and no chunk may hash the same as the value itself.

        @param value: the string the token should be bound to
        @param token: the token string to check
        @return: True if the token is valid, False otherwise
        """
        if self.verbosity >= 2:
            print("Verify: checking token %s (len %s) against %s" % (token, len(token), value))
        # mask is an int with least-significant 'q' bits set to 1
        mask = 2 ** self.quality - 1

        # break the token up into its constituent (decoded) chunks
        chunks = []
        if len(token) == self.chunksize * self.numchunks:
            # binary token: fixed-width chunks laid end to end
            for i in range(self.numchunks):
                chunks.append(token[(i * self.chunksize) : ((i + 1) * self.chunksize)])
        else:
            # base64 token: insist on the exact length so that extra data
            # cannot be appended after the last chunk
            if len(token) != self.b64ChunkLen * self.numchunks:
                if self.verbosity >= 2:
                    print("Bad token length")
                return False
            for i in range(self.numchunks):
                b64chunk = token[(i * self.b64ChunkLen) : ((i + 1) * self.b64ChunkLen)]
                try:
                    chunk = self._dec64(b64chunk)
                except Exception:
                    if self.verbosity >= 2:
                        print("Base64 decode failed")
                    return False
                if len(chunk) != self.chunksize:
                    if self.verbosity >= 2:
                        print("Bad chunk length in decoded base64, wanted %s, got %s" % (
                            self.chunksize, len(chunk)))
                    return False
                chunks.append(chunk)

        # produce hash string and hash int for input string
        hV = shanew(value).digest()
        nHv = intify(hV)

        if self.verbosity >= 2:
            print("chunks = %s" % repr(chunks))

        # test each chunk
        while chunks:
            chunk = chunks.pop()

            # defeat duplicate chunks
            if chunk in chunks:
                if self.verbosity >= 2:
                    print("Rejecting token chunk - duplicate exists")
                return False

            hTok = shanew(chunk).digest()

            # defeat the obvious attack: a chunk that is the value itself
            if hTok == hV:
                if self.verbosity >= 2:
                    print("Rejecting token chunk - equal to value")
                return False

            # test if these hashes have the least significant n bits in common
            if (intify(hTok) ^ nHv) & mask != 0:
                if self.verbosity >= 2:
                    print("Rejecting token chunk %s - hash test failed" % repr(chunk))
                return False

        # pass
        return True

    def _checkBase64(self, item):
        """
        Ensures the item correctly encodes then decodes to/from base64

        @return: True if the padding-stripped encoding of item is exactly
            b64ChunkLen characters long and decodes back to item
        """
        enc = self._enc64(item)
        if len(enc) != self.b64ChunkLen:
            if self.verbosity >= 1:
                print("Bad candidate token")
            return False
        return self._dec64(enc) == item

    def _enc64(self, item):
        """
        Base64-encode a string, remove padding

        encodestring() breaks its output into newline-terminated lines, so
        all whitespace is removed (not just the trailing newline) before the
        '=' padding is stripped. An empty item encodes to ''.
        """
        enc = "".join(base64.encodestring(item).split())
        return enc.rstrip("=")

    def _dec64(self, item):
        """
        Base64-decode a string

        Padding is re-appended before decoding because _enc64 strips it.
        """
        return base64.decodestring(item + "====")
    
def generate(value, quality, b64=False):
    """
    Produce a hashcash token (convenience wrapper)

    Saves the caller from instantiating a HashCash object: a throwaway
    one is built with the given quality, all other settings at their
    defaults, and its generate() result is returned.

    Arguments:
        - value - the string the token is bound to
        - quality - required strength in bits, an int from 1 to 160
        - b64 - if True, return the token as base64 (suitable for email,
          news, and other text-based contexts), otherwise return a binary
          string

    For desktop PC use, quality should typically lie between 16 and 30:
    too low makes an attacker's life easy, too high makes life painful
    for the user.
    """
    if not b64:
        fmt = 'binary'
    else:
        fmt = 'base64'
    hc = HashCash(quality=quality, format=fmt)
    return hc.generate(value)

def verify(value, quality, token):
    """
    Check a hashcash token (convenience wrapper)

    Saves the caller from instantiating a HashCash object: a throwaway
    one is built with the given quality, all other settings at their
    defaults, and its verify() method is used.

    Arguments:
        - value - the string the token is supposed to be bound to
        - quality - the number of bits of token quality we require
        - token - a hashcash token string, binary or base64

    Returns True if the token is valid for value, False otherwise.
    """
    return HashCash(quality=quality).verify(value, token)

def binify(L, length=0):
    """
    Convert a non-negative python int/long into a big-endian binary string

    By default the result is as short as possible, so binify(0) == ''.
    Pass length to left-pad the result with NUL bytes to at least that
    many bytes (longer results are never truncated).

    @raise ValueError: if L is negative (the conversion would never end,
        because shifting a negative number right never reaches zero)
    """
    if L < 0:
        raise ValueError("binify: cannot convert negative integer %r" % (L,))
    res = []
    while L:
        res.append(chr(L & 0xFF))
        L >>= 8
    # bytes were collected least-significant first
    res.reverse()
    s = "".join(res)
    if len(s) < length:
        s = "\0" * (length - len(s)) + s
    return s

def intify(s):
    """
    Convert a binary string to a non-negative python int/long

    The string is read as a big-endian unsigned integer, making this the
    inverse of binify: intify(binify(n)) == n for every n >= 0. The empty
    string converts to 0.
    """
    # plain 0 suffices: ints promote to longs automatically as they grow,
    # and the 0L literal is a syntax error on Python 3
    n = 0
    for c in s:
        n = (n << 8) | ord(c)
    return n

def _randomString():
    """
    For our tests below.
    Generates a random-length human-readable random string,
    between 16 and 80 chars
    """
    chars = []
    slen = randint(16, 80)
    for i in range(slen):
        chars.append(chr(randint(32, 128)))
    value = "".join(chars)
    return value

# get a boost of speed if psyco is available on target machine
try:
    import psyco
    # binding the class covers all of its methods, including the
    # generate() brute-force loop
    psyco.bind(HashCash)
    psyco.bind(binify)
    psyco.bind(intify)
except Exception:
    # psyco is an optional accelerator: if it is missing or cannot bind,
    # carry on unaccelerated rather than failing the import
    pass

def test(nbits=14):
    """
    Self-test of the module-level wrapper functions

    Generates one random string, then produces and verifies first a plain
    binary token and then a base64 token of strength nbits against it,
    printing every step along with each verification result.
    """
    print("Test, using wrapper functions")

    value = _randomString()
    print("Generated random string\n%s" % value)
    print("")

    def _roundTrip(genMsg, gotMsg, useB64):
        # one generate/verify cycle, reporting each step
        print(genMsg % (nbits, value))
        tok = generate(value, nbits, useB64)
        print(gotMsg % repr(tok))
        ok = verify(value, nbits, tok)
        print("Verify = %s" % repr(ok))

    _roundTrip("Generating plain binary %s-bit token for:\n%s",
               "Got token %s, now verifying", False)
    print("")
    _roundTrip("Now generating base64 %s-bit token for:\n%s",
               "Got base64 token %s, now verifying", True)

def ctest(quality=14):
    """
    Basic test function - perform token generation and verify, against
    a random string. Instantiate a HashCash class instead of just using the
    wrapper funcs.

    The HashCash object is configured for base64 output, so the token that
    is generated and verified here is a base64 token.
    """
    print("Test using HashCash class")

    value = _randomString()
    print("Generated random string\n%s" % value)
    print("")

    hc = HashCash(quality=quality, format='base64')

    # hc emits base64 tokens, so the progress message must say base64
    print("Generating base64 %s-bit token for:\n%s" % (quality, value))
    tok = hc.generate(value)

    print("Got token %s, now verifying" % repr(tok))
    result = hc.verify(value, tok)

    print("Verify = %s" % repr(result))
    print("")

def ntest(chunksize=3, numchunks=32, quality=6, numIterations=256):
    """
    This function does numIterations (by default 256) key generations in
    a row, and dumps some statistical results

    Each iteration generates a token for a fresh random string, times it,
    and verifies it; the test stops early, with a message, if any token
    fails to verify.

    Requires the third-party 'stats' module; if it cannot be imported, a
    message explaining where to get it is printed and the test is skipped.

    @param chunksize: chunk size, in bytes, of the HashCash object used
    @param numchunks: chunks per token of the HashCash object used
    @param quality: token strength in bits
    @param numIterations: number of tokens to generate
    """
    import time
    try:
        import stats
    except ImportError:
        print("This test requires the stats module")
        print("Get it (and its dependencies) from:")
        print("http://www.nmr.mgh.harvard.edu/Neural_Systems_Group/gary/python.html")
        return

    print("Thrash test")

    times = []

    # create a hashcash token generator object
    hc = HashCash(
        chunksize=chunksize,
        numchunks=numchunks,
        quality=quality
        )

    # numIterations times, create a random string and a matching token
    for i in range(numIterations):

        value = _randomString()

        # measure time for a single token generation
        then = time.time()
        tok = hc.generate(value)
        now = time.time()
        times.append(now - then)

        # sanity check, make sure it's valid
        result = hc.verify(value, tok)
        if not result:
            print("Verify failed, token length=%s" % len(tok))
            return

        # i is zero-based; report how many tokens exist so far
        print("Generated %s of %s tokens" % (i + 1, numIterations))

    if not times:
        print("No tokens generated")
        return

    fastest = min(times)
    slowest = max(times)
    # a fast machine or a coarse clock can record a zero duration, so the
    # ratio is only computed when it is defined
    if fastest > 0:
        ratioText = "%.3f" % (slowest / fastest)
    else:
        ratioText = "n/a"

    print("---------------------------------")
    print("Thrash test performance results")
    print("Token quality: %s bits" % quality)
    print("Min=%.3f max=%.3f max/min=%s mean=%.3f, median=%.3f, stdev=%.3f" % (
        fastest,
        slowest,
        ratioText,
        stats.lmean(times),
        stats.lmedian(times),
        stats.lstdev(times)
        ))

if __name__ == "__main__":
    # executed as a script: run the wrapper-function self-test
    test()


