#!/usr/bin/env python
# -*- coding: ISO-8859-1 -*-
"""
Simple tracd replacement based on the asyncore / asynchat framework

Code borrowed from tracd and the "Simple HTTP server based on asyncore/asynchat"
recipe by Pierre Quentel http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/259148

version 0.1 - November 2, 2005
Ludovico Magnocavallo <ludo@qix.it>

released under the same GNU General Public License (version 2 or later) as tracd

tracd Copyright (C) 2003, 2004 Edgewall Software
      Copyright (C) 2003, 2004 Jonas Borgström <jonas@edgewall.com>

"""

import asynchat, asyncore, socket, SimpleHTTPServer, select, urllib
import posixpath, sys, cgi, cStringIO, os, traceback, shutil

import re, locale, getopt, time, mimetypes, urllib2, md5
import trac.core
from trac import auth, siteconfig
from trac.Href import Href
from trac.Environment import Environment

class CI_dict(dict):
    """Dictionary of HTTP request headers with case-insensitive lookup.

    Replacement for the deprecated mimetools.Message class.  The underlying
    dict (and the `headers` attribute) keep header names as sent by the
    client, while getheader(), get(), [], `in` and has_key() ignore case.
    """

    def __init__(self, infile, *args):
        """Parse "Name: value" lines read from the file-like `infile`.

        Extra positional arguments (mimetools.Message's `seekable` flag) are
        accepted for signature compatibility and ignored.  A line starting
        with a space or tab continues the previous header (obsolete line
        folding); blank or colon-less lines are skipped instead of raising.
        """
        self._ci_dict = {}
        last_key = None
        for line in infile.readlines():
            if last_key is not None and line[:1] in (' ', '\t'):
                # Folded continuation line: append to the previous value.
                value = (self._ci_dict[last_key.lower()] + ' ' +
                         line.strip()).strip()
                self._ci_dict[last_key.lower()] = self[last_key] = value
                continue
            if ':' not in line:
                # Malformed or empty line: ignore it rather than failing
                # the whole request with a ValueError.
                continue
            k, v = line.split(":", 1)
            self._ci_dict[k.lower()] = self[k] = v.strip()
            last_key = k
        self.headers = self.keys()

    def getheader(self, key, default=""):
        """Return the value of header `key` (any case), or `default`."""
        return self._ci_dict.get(key.lower(), default)

    def get(self, key, default=""):
        """Return the value of header `key` (any case), or `default`."""
        return self._ci_dict.get(key.lower(), default)

    def __getitem__(self, key):
        """Return the value of header `key` (any case); KeyError if absent."""
        return self._ci_dict[key.lower()]

    def __contains__(self, key):
        """True if header `key` (any case) is present."""
        return key.lower() in self._ci_dict

    def has_key(self, key):
        """Case-insensitive has_key; dict.has_key would only match the
        exact spelling the client used."""
        return key.lower() in self._ci_dict
        
class socketStream:

    def __init__(self,sock):
        """Initiate a socket (non-blocking) and a buffer"""
        self.sock = sock
        self.buffer = cStringIO.StringIO()
        self.closed = 1   # compatibility with SocketServer
    
    def write(self, data):
        """Buffer the input, then send as many bytes as possible"""
        self.buffer.write(data)
        if self.writable():
            buff = self.buffer.getvalue()
            # next try/except clause suggested by Robert Brown
            try:
                    sent = self.sock.send(buff)
            except:
                    # Catch socket exceptions and abort
                    # writing the buffer
                    sent = len(data)

            # reset the buffer to the data that has not yet be sent
            self.buffer=cStringIO.StringIO()
            self.buffer.write(buff[sent:])
            
    def finish(self):
        """When all data has been received, send what remains
        in the buffer"""
        data = self.buffer.getvalue()
        # send data
        while len(data):
            while not self.writable():
                pass
            sent = self.sock.send(data)
            data = data[sent:]

    def getvalue(self):
        return self.buffer.getvalue()
            
    def writable(self):
        """Used as a flag to know if something can be sent to the socket"""
        return select.select([],[self.sock],[])[1]

class DigestAuth:
    """A simple HTTP DigestAuth implementation (rfc2617)"""
    MAX_NONCES = 100
    def __init__(self, htdigest, realm):
        self.active_nonces = []
        self.hash = {}
        self.realm = realm
        self.load_htdigest(htdigest, realm)

    def load_htdigest(self, filename, realm):
        """
        Load account information from apache style htdigest files,
        only users from the specified realm are used
        """
        fd = open(filename, 'r')
        for line in fd.readlines():
            line = line.strip()
            if not line:
                continue
            u, r, a1 = line.split(':')
            if r == realm:
                self.hash[u] = a1
        if self.hash == {}:
            print >> sys.stderr, "Warning: found no users in realm:", realm
        
    def parse_auth_header(self, authorization):
        values = {}
        for value in urllib2.parse_http_list(authorization):
            n, v = value.split('=', 1)
            if v[0] == '"' and v[-1] == '"':
                values[n] = v[1:-1]
            else:
                values[n] = v
        return values

    def send_auth_request(self, req, stale='false'):
        """
        Send a digest challange to the browser. Record used nonces
        to avoid replay attacks.
        """
        nonce = trac.util.hex_entropy()
        self.active_nonces.append(nonce)
        if len(self.active_nonces) > DigestAuth.MAX_NONCES:
            self.active_nonces = self.active_nonces[-DigestAuth.MAX_NONCES:]
        req.send_response(401)
        req.send_header('WWW-Authenticate',
                        'Digest realm="%s", nonce="%s", qop="auth", stale="%s"'
                        % (self.realm, nonce, stale))
        req.end_headers()

    def do_auth(self, req):
        if not 'Authorization' in req.headers or \
               req.headers['Authorization'][:6] != 'Digest':
            self.send_auth_request(req)
            return None
        auth = self.parse_auth_header(req.headers['Authorization'][7:])
        required_keys = ['username', 'realm', 'nonce', 'uri', 'response',
                           'nc', 'cnonce']
        # Invalid response?
        for key in required_keys:
            if not auth.has_key(key):
                self.send_auth_request(req)
                return None
        # Unknown user?
        if not self.hash.has_key(auth['username']):
            self.send_auth_request(req)
            return None

        kd = lambda x: md5.md5(':'.join(x)).hexdigest()
        a1 = self.hash[auth['username']]
        a2 = kd([req.command, auth['uri']])
        # Is the response correct?
        correct = kd([a1, auth['nonce'], auth['nc'],
                      auth['cnonce'], auth['qop'], a2])
        if auth['response'] != correct:
            self.send_auth_request(req)
            return None
        # Is the nonce active, if not ask the client to use a new one
        if not auth['nonce'] in self.active_nonces:
            self.send_auth_request(req, stale='true')
            return None
        self.active_nonces.remove(auth['nonce'])
        return auth['username']


class RequestHandler(asynchat.async_chat,
    SimpleHTTPServer.SimpleHTTPRequestHandler, trac.core.Request):

    url_re = re.compile('/(?P<project>[^/\?]+)'
                        '(?P<path_info>/?[^\?]*)?'
                        '(?:\?(?P<query_string>.*))?')
    server_version = 'tracd/' + trac.__version__

    env = None
    
    protocol_version = "HTTP/1.1"
    MessageClass = CI_dict

    def __init__(self,conn,addr,server):
        asynchat.async_chat.__init__(self,conn)
        self.client_address = addr
        self.connection = conn
        self.server = server
        # set the terminator : when it is received, this means that the
        # http request is complete ; control will be passed to
        # self.found_terminator
        self.set_terminator ('\r\n\r\n')
        self.rfile = cStringIO.StringIO()
        self.found_terminator = self.handle_request_line
        self.request_version = "HTTP/1.1"
        # buffer the response and headers to avoid several calls to select()
        self.wfile = cStringIO.StringIO()

    def read(self, len):
        return self.rfile.read(len)
    
    def write(self, data):
        self.wfile.write(data)
        
    def init_request(self):
        trac.core.Request.init_request(self)
        self.remote_addr = str(self.client_address[0])
        self.hdf.setValue('HTTP.Host', self.server.http_host)
        self.hdf.setValue('HTTP.Protocol', 'http')
        self.hdf.setValue('HTTP.PathInfo', self.path_info)
        if self.headers.has_key('Cookie'):
            self.incookie.load(self.headers['Cookie'])
        
    def collect_incoming_data(self,data):
        """Collect the data arriving on the connexion"""
        self.rfile.write(data)

    def get_header(self, name):
        return self.headers.get(name)

    def parse_path(self, path):
        m = self.url_re.findall(self.path)
        if not m:
            raise trac.util.TracError('Unknown URI')
        self.project_name, self.path_info, self.query_string = m[0]
        if not self.server.projects.has_key(self.project_name):
            raise trac.util.TracError('Unknown Project')
        self.path_info = urllib.unquote(self.path_info)
        self.env = self.server.projects[self.project_name]
        self.log = self.env.log
        self.cgi_location = '/' + self.project_name
        self.base_url = 'http://%s/%s' % (self.server.http_host, self.project_name)

    def log_message(self, format, *args):
        if self.env:
            self.log.debug(format%args)

    def do_project_index(self):
        self.send_response(200)
        self.send_header('Content-Type', 'text/html')
        self.end_headers()
        self.write('<html><head><title>Available Projects</title></head>')
        self.write('<body><h1>Available Projects</h1><ul>')
        for project in self.server.projects.keys():
            self.write('<li><a href="%s">%s</a></li>' % (project, project))
        self.write('</ul></body><html>')

    def do_htdocs_req(self, path):
        """This function serves request for static img/css files"""
        path = urllib.unquote(path)
        # Make sure the path doesn't contain any dangerous ".."-parts.
        path = '/'.join(filter(lambda x: x not in ['..', ''],
                               path.split('/')))
        filename = os.path.join(siteconfig.__default_htdocs_dir__,
                                os.path.normcase(path))
        try:
            f = open(filename, 'rb')
        except IOError:
            self.send_error(404, path)
            return
        self.send_response(200)
        mtype, enc = mimetypes.guess_type(filename)
        stat = os.fstat(f.fileno())
        content_length = stat[6]
        last_modified = time.strftime("%a, %d %b %Y %H:%M:%S GMT",
                                      time.gmtime(stat[8]))
        self.send_header('Content-Type', mtype)
        self.send_header('Conten-Length', str(content_length))
        self.send_header('Last-Modified', last_modified)
        self.end_headers()
        shutil.copyfileobj(f, self.wfile)
        
    def prepare_POST(self):
        """Prepare to read the request body"""
        bytesToRead = int(self.headers.getheader('content-length'))
        # set terminator to length (will read bytesToRead bytes)
        self.set_terminator(bytesToRead)
        self.rfile = cStringIO.StringIO()
        # control will be passed to a new found_terminator
        self.found_terminator = self.handle_post_data
    
    def handle_post_data(self):
        """Called when a POST request body has been read"""
        self.rfile.seek(0)
        self.do_POST()
        self.finish()
            
    def do_POST(self):
        self.do_trac_req()

    def do_HEAD(self):
        self.do_GET()
        
    def do_GET(self):
        if self.path[0:13] == '/':
            self.do_project_index()
        elif self.path[0:13] == '/trac_common/':
            self.do_htdocs_req(self.path[13:])
        else:
            self.do_trac_req()
        
    def do_trac_req(self):
        try:
            self.parse_path(self.path)
        except:
            self.send_error(404, 'Unknown request')
            return
        try:
            self.do_real_trac_req()
        except Exception, e:
            trac.core.send_pretty_error(e, self.env, self)

    def do_real_trac_req(self):
        start = time.time()
        self.init_request()

        self.remote_user = None
        if self.path_info == '/login':
            if not self.env.auth:
                raise trac.util.TracError('Authentication not enabled. '
                                          'Please use the tracd --auth option.\n')

            self.remote_user = self.env.auth.do_auth(self)
            if not self.remote_user:
                return

        # Parse arguments
        args = trac.core.parse_args(self.command, self.path_info,
                                    self.query_string,
                                    self.rfile, None, self.headers)

        trac.core.dispatch_request(self.path_info, args, self, self.env)

        self.log.debug('Total request time: %f s', time.time() - start)


    def handle_request_line(self):
        """Called when the http request line and headers have been received"""
        # prepare attributes needed in parse_request()
        self.rfile.seek(0)
        self.raw_requestline = self.rfile.readline()
        self.parse_request()

        if self.command in ['GET','HEAD']:
            # if method is GET or HEAD, call do_GET or do_HEAD and finish
            method = "do_"+self.command
            if hasattr(self,method):
                getattr(self,method)()
                self.finish()
        elif self.command=="POST":
            # if method is POST, call prepare_POST, don't finish yet
            self.prepare_POST()
        else:
            self.send_error(501, "Unsupported method (%s)" %self.command)

    def end_headers(self):
        """Send the blank line ending the MIME headers, send the buffered
        response and headers on the connection, then set self.wfile to
        this connection
        This is faster than sending the response line and each header
        separately because of the calls to select() in socketStream"""
        if self.request_version != 'HTTP/0.9':
            self.wfile.write("\r\n")
        self.start_resp = cStringIO.StringIO(self.wfile.getvalue())
        self.wfile = socketStream(self.connection)
        self.copyfile(self.start_resp, self.wfile)

    def handle_error(self):
        traceback.print_exc(sys.stderr)
        self.close()

    def copyfile(self, source, outputfile):
        """Copy all data between two file objects
        Set a big buffer size"""
        shutil.copyfileobj(source, outputfile, length = 128*1024)

    def finish(self):
        """Send data, then close"""
        try:
            self.wfile.finish()
        except AttributeError: 
            # if end_headers() wasn't called, wfile is a StringIO
            # this happens for error 404 in self.send_head() for instance
            self.wfile.seek(0)
            self.copyfile(self.wfile, socketStream(self.connection))
        self.close()

class Server(asyncore.dispatcher):
    """Listening TCP socket that creates one `handler` per accepted client.

    Adapted from medusa's http_server.  `handler` is called as
    handler(conn, addr, server) for every accepted connection.
    """

    def __init__(self, ip, port, handler):
        self.ip, self.port, self.handler = ip, port, handler
        asyncore.dispatcher.__init__(self)
        self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
        self.set_reuse_addr()
        self.bind((ip, port))
        # lower this to 5 if your OS complains
        self.listen(1024)

    def handle_accept(self):
        """Accept a pending connection and hand it to a new handler."""
        try:
            conn, addr = self.accept()
        except socket.error:
            warning = 'warning: server accept() threw an exception'
        except TypeError:
            # A TypeError from the unpacking is reported as EWOULDBLOCK
            # (presumably accept() returned None instead of a pair).
            warning = 'warning: server accept() threw EWOULDBLOCK'
        else:
            self.handler(conn, addr, self)
            return
        self.log_info(warning, 'warning')

def usage():
    """Print the command line help to stdout and exit with status 1."""
    help_lines = [
        'usage: %s [options] <projenv> [projenv] ...' % sys.argv[0],
        '\nOptions:\n',
        '-a --auth [project],[htdigest_file],[realm]',
        '-p --port [port]\t\tPort number to use (default: 80)',
        '-b --hostname [hostname]\tIP to bind to (default: \'\')',
        '',
    ]
    sys.stdout.write('\n'.join(help_lines) + '\n')
    sys.exit(1)
    
def main():
    locale.setlocale(locale.LC_ALL, '')
    port = 80
    hostname = ''
    auths = {}
    projects = {}
    try:
        opts, args = getopt.getopt(sys.argv[1:], "a:p:b:",
                                   ["auth=", "port=", "hostname="])
    except getopt.GetoptError:
        usage()
    for o, a in opts:
        if o in ("-a", "--auth"):
            info = a.split(',', 3)
            if len(info) != 3:
                usage()
            p, h, r = info
            auths[p] = DigestAuth(h, r)
        if o in ("-p", "--port"):
            port = int(a)
        elif o in ("-b", "--hostname"):
            hostname = a

    if not args:
        usage()
        
    httpd = Server(hostname,port,RequestHandler)
    httpd.server_port = port
    if port == 80:
        httpd.http_host = hostname
    else:
        httpd.http_host = '%s:%d' % (hostname, port)
    
    for path in args:
        # Remove trailing slashes
        while path and not os.path.split(path)[1]:
            path = os.path.split(path)[0]
        project = os.path.split(path)[1]
        # We assume the projenv filenames follow the following
        # naming convention: /some/path/project
        auth = auths.get(project, None)
        env = Environment(path)
        # Upgrade the database schema if needed
        env.upgrade(backup=1)
        env.href = Href('/' + project)
        env.abs_href = Href('http://%s/%s' % (httpd.http_host, project))
        projects[project] = env
        projects[project].set_config('trac', 'htdocs_location',
                                     '/trac_common/')
        projects[project].auth = auth
        
    httpd.projects = projects
    try:
        asyncore.loop(timeout=2)
    except KeyboardInterrupt:
        print "Crtl+C pressed. Shutting down."

if __name__ == '__main__':
    # Start the server only when this file is executed as a script.
    main()
