#!/usr/bin/python
# Copyright 2008 Toby Dickenson toby@tarind.com
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License,
# version 2.

import os,sys,time,signal,errno,stat,re,traceback,struct,errno,math
import optparse
from array import array

# The subprocess module first shipped with Python 2.4. On anything older,
# replace the bare ImportError with a message that says what is required.
try:
    import subprocess
except ImportError:
    # sys.version_info is a tuple, not a function: compare it directly.
    # (Calling it raised TypeError and hid the helpful message below.)
    if sys.version_info<(2,4):
        sys.exit('brutefirwrapper needs python 2.4')
    else:
        raise

# Revision number of this wrapper; main() reports it in the startup log line.
REVISION = 12


# Per-user scratch directory (the uid is part of the name). init_tmp() creates
# and validates it; patched filter configurations, resampled coefficient files,
# saved tails and debug dumps are all written here.
TMPDIR = '/tmp/.BrutefirDrc-%d'%(os.getuid(),)
# Number of interleaved channels in every stream handled here. main() rejects
# wav input whose header reports a different channel count.
CHANNELS = 2

# What value should we set for TAIL_SECONDS?
#
# TAIL_SECONDS is both the amount of silence input_process() feeds to
# brutefir after the real input ends, and the length of output that
# output_process() saves to disk to be mixed into the start of the next track.
#
# Our largest typical window is a 64k filter at 44.1kHz.
# To capture the whole reverberation tail from that window
# we need TAIL_SECONDS = 1.5 and this will ensure that
# absolutely nothing is lost on a track transition.
# This is the one true audiophile option.
#
# If you have a low latency filter (in DRC this is the default,
# and corresponds to PSFilterType = T in the config file)
# then in practice most of the energy is in the first 100ms.
# TAIL_SECONDS = 0.1 will minimise latency, which may help
# eliminate any start-of-track stuttering (should that be a
# problem).
#
# Some users (including the Author) have PSFilterType = L, which
# gives a filter with 750ms latency and most of the energy is
# then over by 850ms. We therefore set TAIL_SECONDS = 0.85
# as a default here as a compromise between fidelity and
# latency.
TAIL_SECONDS = 0.85

# Extra seconds of brutefir output that output_process() holds back in memory
# in addition to the tail. See the comment in output_process() for why this
# is zero.
READAHEAD_SECONDS = 0
# Amount of audio, in seconds, moved by each read in input_process() and
# output_process().
BLOCK_SECONDS = 0.2

# Wall-clock time at which this module was loaded. Used for the startup log
# line, the elapsed times in later log lines and the name of the debug dump.
start_timestamp = time.time()

def main():
    """Command line entry point.

    Redirects stderr to the log file, prepares the scratch directory, parses
    the command line, reads the 44 byte wav header from stdin to learn the
    sample rate and sample width, selects (and, unless --filter-as-is was
    given, patches) the brutefir configuration, then starts brutefir and
    pumps audio through it.

    Exits through sys.exit()/parser.error() on bad arguments, on input that
    is too short, and on wav headers that are not plain 16 or 24 bit stereo
    PCM.
    """
    init_stderr()
    init_tmp()
    #
    now_text = time.strftime('%Y-%m-%d %H:%M:%S',time.gmtime(start_timestamp))
    sys.stderr.write('\n======== brutefirwrapper rev %s starting at %d %s\n'%(REVISION,start_timestamp,now_text,))
    #
    parser = optparse.OptionParser(usage='usage: %prog [options excluding -f and -c] client_id filter_name\n'
                                         '   or: %prog [options including -f and -c]')
    parser.add_option('-o','--output',dest='stream_format',default='wav',help='Output stream is wav or pcm (wav default)')
    parser.add_option('-c','--client-id',dest='client_id',help='Client id (usually MAC address)')
    parser.add_option('-f','--filter',dest='filter_name',help='Brutefir filter configuration file name')
    parser.add_option('--copy-sample-format',dest='copy_sample_format',action='store_true',default=False,
                      help='Patch the brutefir output sample format (16 or 24 bit) to match '
                           'the input wav stream. This matches the behaviour expected by '
                           'squeezeboxserver for pcm output streams.')
    parser.add_option('--filter-as-is',dest='allow_patching',action='store_false',default=True,
                      help='Always use the specified brutefir filter configuration unchanged. '
                           'If not specified, then any mismatches in sample rate or sample format '
                           'are automatically fixed in a temporary configuration file.')
    parser.add_option('--dump-output',dest='dump_output',help='Capture output stream for debugging',action='store_true',default=False)
    opts,args = parser.parse_args()
    if len(args)==2:
        # Version 1.2.0 used two positional arguments (client id and filter).
        opts.client_id = args[0]
        opts.filter_name = args[1]
    elif args:
        parser.error('unexpected arguments')
    #
    if not opts.filter_name:
        parser.error('must specify a filter')
    if not opts.client_id:
        parser.error('must specify a client_id')
    if opts.stream_format not in ('wav','pcm'):
        parser.error('stream format must be wav or pcm')

    # Process wav header from input stream to determine most of our operating parameters.
    # Historically we supported setting these on the command line, but using the wav header
    # is better all round.
    header = sys.stdin.read(44)
    if len(header)!=44:
        sys.exit('Input to brutefirwrapper too short (%d bytes)'%(len(header),))
    fields = struct.unpack('<IIIIIHHIIHHII',header)
    # (name, index into `fields`, required value). The magic numbers are the
    # little-endian encodings of 'RIFF', 'WAVE', 'fmt ' and 'data'.
    for field_name,field_index,expected_value in [
                        ('chunkid',0,0x46464952),
                        ('wavformat',2,0x45564157),
                        ('subchunk1id',3,0x20746d66),
                        ('audioformat',5,1),
                        ('subchunk1size',4,16),
                        ('numchannels',6,CHANNELS),
                        ('subchunk2id',11,0x61746164)]:
        if fields[field_index]!=expected_value:
            sys.exit('Unexpected wav header value. %s expected 0x%x got 0x%x'%(field_name,expected_value,fields[field_index]))
    samplerate = fields[7]
    bitspersample = fields[10]
    if bitspersample==16:
        in_format = format_16le()
    elif bitspersample==24:
        in_format = format_24le()
    else:
        sys.exit('%d bit samples are not supported'%(bitspersample,))

    # make client id filename/filesystem friendly
    client_id = opts.client_id.replace(':','_')

    # find a better filter for sample rate that is provided, i.e. support automatic filter selection based on filter filenames
    filter_name = find_filter(opts.filter_name,samplerate)
    # patch and/or resample the filter if necessary and allowed
    if opts.allow_patching:
        filter_name = patch_filter(filter_name,client_id,samplerate,in_format,opts.copy_sample_format)

    brutefir = create_brutefir_process(filter_name)
    # NB: `filter` here is this module's own function, not the builtin.
    filter(brutefir,in_format,detect_out_format(filter_name),client_id,samplerate,opts.stream_format,opts.dump_output)

def create_brutefir_process(brutefir_filter):
    """Start brutefir on the given configuration file and return the Popen object.

    Both stdin and stdout of the child are pipes. If the executable cannot be
    found the program exits with an explanatory message; any other OSError is
    passed on to the caller.
    """
    argv = ['brutefir', '-nodefault', brutefir_filter]
    try:
        return subprocess.Popen(argv,
                                stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE)
    except OSError:
        # sys.exc_info() is used to reach the exception object because it is
        # spelled the same way on every Python version.
        if sys.exc_info()[1].errno!=errno.ENOENT:
            raise
        sys.exit('Can not find brutefir executable. Is it installed and on PATH?')

def filter(brutefir,in_format,out_format,client_id,samplerate,stream_format,dump_output):
    """Pump audio through a running brutefir process using two processes.

    A forked child copies our stdin into brutefir (input_process); the parent
    copies brutefir's output to our stdout (output_process). When the parent
    side finishes, for whatever reason, both brutefir and the child are sent
    SIGTERM.
    """
    child_pid = os.fork()
    if child_pid==0:
        # Child: feed brutefir, then leave without running any of the
        # parent's cleanup or exit handlers.
        try:
            input_process(sys.stdin,brutefir.stdin,in_format,samplerate)
        finally:
            os._exit(0)

    # Parent from here on (the child never gets past os._exit above).
    try:
        # Only the child writes to brutefir; drop our copy of that pipe end.
        brutefir.stdin.close()
        try:
            output_process(brutefir.stdout,sys.stdout,out_format,client_id,samplerate,stream_format,dump_output)
        except IOError:
            # EPIPE is a common error when squeezeboxserver aborts playback,
            # so it is deliberately not logged. Anything else is re-raised.
            if sys.exc_info()[1].errno!=errno.EPIPE:
                raise
    finally:
        # We end up here if playback completes, or the playback is aborted
        # and squeezeboxserver closes our output pipe. Terminate the input
        # process, because it will otherwise wait forever.
        os.kill(brutefir.pid,signal.SIGTERM)
        os.kill(child_pid,signal.SIGTERM)

def patch_filter(filter_filename,client_id,samplerate,in_format,copy_sample_format):
    """Return the name of a brutefir configuration that matches the input stream.

    The configuration in `filter_filename` is checked against the stream's
    sample rate and sample format. The input sample format is always aligned
    with `in_format`; the output sample format is aligned too when
    `copy_sample_format` is true. If nothing needs changing (or the file cannot
    be read) the original name is returned. Otherwise a patched copy is written
    to TMPDIR and its name is returned.
    """
    try:
        source = open(filter_filename)
        try:
            original_filter_body = source.read()
        finally:
            source.close()
    except EnvironmentError:
        # cant read the file; let brutefir report the problem itself
        return filter_filename

    filter_body = patch_samplerate(original_filter_body,samplerate)
    filter_body = patch_format(filter_body,'input',in_format)
    if copy_sample_format:
        filter_body = patch_format(filter_body,'output',in_format)

    if filter_body==original_filter_body:
        return filter_filename

    patched_filename = '%s-%s-%d-%s' % (os.path.basename(filter_filename),client_id,samplerate,in_format.brutefir_name)
    patched_filename = os.path.join(TMPDIR,patched_filename)
    # Close explicitly: brutefir is about to read this file, so the contents
    # must be on disk now and not whenever the file object gets collected.
    target = open(patched_filename,'w')
    try:
        target.write(filter_body)
    finally:
        target.close()
    return patched_filename

def patch_samplerate(filter_body,samplerate):
    """Return `filter_body` adjusted to run at `samplerate`.

    If the configuration text has no recognisable `sampling_rate` setting, or
    already uses `samplerate`, it is returned unchanged. Otherwise:

    * the `sampling_rate` value is replaced;
    * every coefficient file named by a `filename: "...";` setting is
      resampled with sox into TMPDIR and the setting is pointed at the new
      file ("dirac pulse" is left alone);
    * every `attenuation` value is increased by 20*log10(new/old rate) so the
      output volume stays consistent;
    * `filter_length` gets more partitions if a resampled file no longer fits.

    Raises OSError if sox cannot be started or produces no output file.
    """
    match = re.search(r'\bsampling_rate\s*:\s*([0-9]+)\s*;',filter_body)
    if not match:
        # unexpected
        return filter_body
    baseline_samplerate = int(match.group(1))
    if baseline_samplerate==samplerate:
        return filter_body
    #
    sys.stderr.write('Resampling brutefir filter from %d to %d\n' % (baseline_samplerate,samplerate))
    # Patch the sample rate
    filter_body = filter_body[:match.start(1)]+str(samplerate)+filter_body[match.end(1):]
    #
    # Next patch all filter files, and resample them with sox.
    # filter_lengths collects the length, in samples, of every resampled file.
    filter_lengths = [0]
    def resample_filter(original):
        source_pcm_filename = original.group(1)
        if source_pcm_filename == 'dirac pulse':
            return 'filename: "dirac pulse";'
        output_pcm_filename = '%s-%d' % (os.path.basename(source_pcm_filename),samplerate)
        output_pcm_filename = os.path.join(TMPDIR,output_pcm_filename)
        sox_process = subprocess.Popen(['/usr/share/squeezeboxserver/Bin/i386-linux/sox',
                                        '-t','raw',
                                        '-f', # floating point
                                        '-4', # 32 bit
                                        '-r',str(baseline_samplerate),
                                        source_pcm_filename,
                                        '-t','raw',
                                        output_pcm_filename,
                                        'rate','-v',str(samplerate),
                                        ],
                                        )
        exit_status = sox_process.wait()
        if exit_status!=0:
            sys.stderr.write('sox exited with status %d while resampling "%s"\n' % (exit_status,source_pcm_filename))
        # 4 bytes per sample: sox was asked for 32 bit floating point output
        filter_lengths.append(os.stat(output_pcm_filename).st_size//4)
        return 'filename:"'+output_pcm_filename+'";'
    # [^"]* rather than .* so that a second quoted value on the same line
    # (e.g. `filename: "x.pcm"; format: "FLOAT_LE";`) is not swallowed.
    filter_body = re.sub(r'\bfilename\s*:\s*"([^"]*)"\s*;',resample_filter,filter_body)
    #
    # add attenuation because the number of filter samples changes; this
    # ensures consistent output volume
    def add_attenuation(original):
        attenuation = float(original.group(1))
        attenuation += 20*math.log10(float(samplerate)/baseline_samplerate)
        return 'attenuation: %.9f;' % (attenuation,)
    filter_body = re.sub(r'\battenuation\s*:\s*([^;]*);',add_attenuation,filter_body)
    #
    # Having resampled the filter, maybe we need more taps in the brutefir configuration
    match = re.search(r'\bfilter_length\s*:\s*([0-9]+\s*(?:,\s*[0-9]+\s*)?);',filter_body)
    if match:
        filter_length_text = match.group(1)
        if ',' in filter_length_text:
            samples_per_filter,number_filters = map(int,filter_length_text.split(','))
        else:
            samples_per_filter = int(filter_length_text)
            number_filters = 1
        max_filter_length = max(filter_lengths)
        if max_filter_length > samples_per_filter*number_filters:
            number_filters = int(math.ceil(float(max_filter_length)/samples_per_filter))
            filter_length_text = '%d,%d' % (samples_per_filter,number_filters,)
            filter_body = filter_body[:match.start(1)] + filter_length_text + filter_body[match.end(1):]
    #
    return filter_body

def patch_format(filter_body,section_name,in_format):
    """Return `filter_body` with the sample format of one section replaced.

    The first `sample: "...";` setting found after the word `section_name`
    ('input' or 'output') is set to `in_format.brutefir_name`. The text is
    returned unchanged when no such setting is found (this is logged) or when
    it already has the wanted value.
    """
    pattern = r'\b'+re.escape(section_name)+r'.*?\bsample\s*:\s*"([^"]*)"\s*;'
    match = re.search(pattern,filter_body,re.DOTALL)
    if not match:
        # not so unexpected - the regular expression isnt a very tight match
        sys.stderr.write('Could not find section "%s" to patch to %s\n' % (section_name,in_format.brutefir_name))
        return filter_body
    baseline_format = match.group(1)
    if baseline_format==in_format.brutefir_name:
        return filter_body
    #
    sys.stderr.write('Patching brutefir %s format from %s to %s\n' % (section_name,baseline_format,in_format.brutefir_name))
    return filter_body[:match.start(1)]+in_format.brutefir_name+filter_body[match.end(1):]

def find_filter(filter_name,samplerate):
    """Return the filter configuration best suited to `samplerate`.

    Any literal 'SAMPLERATE' in `filter_name` is replaced by the rate. If the
    resulting path is a symbolic link it is followed one level. When the file
    name then contains one of the well known rates directly before its
    extension (e.g. `room44100.conf`), a sibling file named for `samplerate`
    is preferred if it exists. Otherwise the path is returned as it stands.
    """
    # kept for compatibility: 'SAMPLERATE' acts as a placeholder for the rate
    path_to_filter = filter_name.replace('SAMPLERATE',str(samplerate))

    # make sure that symbolic links are followed before the filter is searched
    if os.path.islink(path_to_filter):
        path_to_filter = os.path.join(os.path.dirname(path_to_filter), os.readlink(path_to_filter))

    match = re.search(r'(.*)(44100|48000|88200|96000|176400|192000)\.(\w+)',path_to_filter)
    if not match:
        # user has not provided filters for different sample rates; use the selected one
        return path_to_filter

    # (not called `filter`: that name is a function in this module)
    candidate = '%s%d.%s' % (match.group(1),samplerate,match.group(3))
    if os.path.isfile(candidate):
        sys.stderr.write('Found filter "%s" for sample_rate %d\n' % (candidate,samplerate))
        return candidate

    # no filter for this specific samplerate; fall back to the selected filter
    sys.stderr.write('Filter %s for sample_rate %d does not exist\n' % (candidate,samplerate))
    return path_to_filter
        
def detect_out_format(filter_name):
    """Return the sample format object matching brutefir's output section.

    The configuration file is searched for the first `sample: "...";` setting
    after the word 'output'. If the file cannot be read, has no such setting,
    or names a format other than S16_LE/S24_LE, this is logged and the 24 bit
    format is returned.
    """
    filter_body = None
    try:
        config = open(filter_name)
        try:
            filter_body = config.read()
        finally:
            config.close()
    except EnvironmentError:
        pass
    if filter_body is not None:
        match = re.search(r'\boutput.*?\bsample\s*:\s*"([^"]*)"\s*;',filter_body,re.DOTALL)
        if match:
            format_name = match.group(1)
            for candidate in (format_16le(),format_24le()):
                if candidate.brutefir_name==format_name:
                    return candidate
    # default if no match
    sys.stderr.write('Unable to detect output sample format for file "%s" using standard format 24LE\n' % (filter_name,))
    return format_24le()

def input_process(our_stdin,brutefir_stdin,format,samplerate):
    """Copy our stdin to brutefir, then append TAIL_SECONDS of silence.

    This is the job of the forked child process. Data is moved in blocks of
    BLOCK_SECONDS and flushed after every block. Once the input is exhausted
    the trailing silence is written and brutefir's stdin is closed.
    """
    block_bytes = int(BLOCK_SECONDS*samplerate)*CHANNELS*format.width
    chunk = our_stdin.read(block_bytes)
    while chunk:
        brutefir_stdin.write(chunk)
        brutefir_stdin.flush()
        chunk = our_stdin.read(block_bytes)
    # One silent sample per channel, repeated for every frame of the tail.
    silent_frame = format.silent_sample*CHANNELS
    brutefir_stdin.write(silent_frame*int(TAIL_SECONDS*samplerate))
    brutefir_stdin.close()

def output_process(brutefir_stdout,our_stdout,format,client_id,samplerate,stream_format,dump_output):
    """Copy brutefir's output to our stdout, handling the inter-track tail.

    brutefir_stdout -- file object to read brutefir's output from
    our_stdout      -- file object the processed stream is written to
    format          -- sample format object describing brutefir's output
    client_id       -- used to name the tail files in TMPDIR
    samplerate      -- sample rate of the stream
    stream_format   -- 'wav' to emit a wav header first, anything else for raw pcm
    dump_output     -- when true, also copy everything written to a dump file

    Returns early, without saving a tail, if brutefir's output ends while the
    saved tail is still being mixed in.
    """
    # This process handles the brutefir output side.
    # First it loads the saved tail from disk. Next it reads
    # a little from the brutefir process and mixes it with
    # the tail. This gets written to our stdout. Next
    # it copies data from brutefir to our stdout, but keeping
    # at least `readahead_samples` samples in RAM. When the brutefir
    # output ends we know that the last `tail_bytes` bytes
    # need to be saved to disk as the tail for the next time
    # this process is run.
    #
    # Historically we kept `readahead_samples` quite high so that
    # we can write the tail file to disk as early as possible. This ensures it
    # will be ready for the next invocation of this program. In practice
    # this was unnecessarily defensive; squeezeboxserver 7 always
    # buffers plenty of our output, allowing one transcoding process
    # to finish well in advance of the track playback end, and in
    # advance of the next transcoding process starting.
    #
    # Indeed such a high readahead could potentially cause problems when
    # the stream feeding our input is decoded in real time, and our output
    # stalls waiting for our readahead buffer to fill
    #
    # Today we set READAHEAD_SECONDS to zero.
    def log_progress(message):
        # Every progress line carries the time elapsed since program start.
        sys.stderr.write('%s (%.3fs)\n' % (message,time.time()-start_timestamp))

    dump_file = None
    if dump_output:
        # binary mode: the dump is a byte-for-byte copy of the audio stream
        dump_file = open(os.path.join(TMPDIR,'dump-%.2f.%s' % (start_timestamp,stream_format)),'wb')
    try:
        if stream_format=='wav':
            header = struct.pack('<IIIIIHHIIHHII',
                        0x46464952, 0xffffffff, 0x45564157, 0x20746d66, 16, # various magic numbers. report huge size because we dont know yet
                        1, 2, samplerate,                                   # pcm, stereo
                        samplerate * format.width * 2,
                        format.width * 2,
                        format.bitspersample,
                        0x61746164, 0xffffffff)
            our_stdout.write(header)
            if dump_file is not None:
                dump_file.write(header)
        #
        tail_filename = os.path.join(TMPDIR,'tail-%s.pcm'%(client_id,))
        info_filename = os.path.join(TMPDIR,'tail-%s.txt'%(client_id,))
        saved_tail = read_tail(tail_filename,info_filename,client_id,format,samplerate)
        # A tail must only ever be used once.
        try_unlink(tail_filename)
        if saved_tail:
            # read the start of brutefir output to mix into this saved tail
            log_progress('Starting to write tail-mixed output')
            for tail_sample in saved_tail:
                data = brutefir_stdout.read(format.width)
                if len(data)!=format.width:
                    return
                to_write = format.encode(format.decode(data)+tail_sample)
                our_stdout.write(to_write)
                if dump_file is not None:
                    dump_file.write(to_write)
            our_stdout.flush()
            log_progress('Finished writing tail-mixed output. Proper output follows')
        else:
            log_progress('Output follows')
        #
        block_bytes = int(BLOCK_SECONDS*samplerate)*CHANNELS*format.width
        readahead_samples = int(samplerate*(READAHEAD_SECONDS+TAIL_SECONDS))*CHANNELS
        readahead_bytes = readahead_samples*format.width
        readahead = Readahead()
        readahead_pending = True
        while True:
            # Release only what exceeds the readahead. The split point is
            # computed here as a non-negative offset: a negative offset that
            # reaches back past the start of the buffer (which happens while
            # the buffer is still filling) would release data too early.
            releasable = max(0,readahead.total_bytes-readahead_bytes)
            excess,readahead = readahead.split(releasable)
            if excess:
                excess.write_to(our_stdout)
                if dump_file is not None:
                    excess.write_to(dump_file)
                our_stdout.flush()
                if readahead_pending:
                    log_progress('Readahead complete')
                    readahead_pending = False
            #
            chunk = brutefir_stdout.read(block_bytes)
            if not chunk:
                break
            readahead.append(chunk)
        #
        # At this point brutefir has completed its output, and we have in memory
        # the readahead followed by the tail. First we dump the tail
        # to disk, so it is ready for the next invocation of this process.
        # Both files are written under a temporary name and renamed, so a
        # reader never sees a partly written file.
        tail_bytes = int(samplerate*TAIL_SECONDS)*CHANNELS*format.width
        if readahead.total_bytes>=tail_bytes:
            info_file = open(info_filename+'.new','w')
            try:
                info_file.write(str(samplerate)+'\n')
                info_file.write(format.brutefir_name+'\n')
            finally:
                info_file.close()
            os.rename(info_filename+'.new',info_filename)
            #
            readahead,tail = readahead.split(readahead.total_bytes-tail_bytes)
            tail_file = open(tail_filename+'.new','wb')
            try:
                tail.write_to(tail_file)
            finally:
                tail_file.close()
            os.rename(tail_filename+'.new',tail_filename)
        #
        # Finally it is safe to write the readahead buffer to stdout
        readahead.write_to(our_stdout)
        if dump_file is not None:
            readahead.write_to(dump_file)

        log_progress('Output complete')
    finally:
        if dump_file is not None:
            dump_file.close()



class Readahead(object):
    """A readahead buffer that avoids moving data around in memory.

    It can be appended to, it can be split into two individual Readahead
    buffer objects, and the whole content can be written to a file.

    Attributes:
        chunks      -- the buffered pieces of data, oldest first
        total_bytes -- combined length of all chunks
    """
    def __init__(self):
        self.chunks = []
        self.total_bytes = 0
    def append(self,chunk):
        """Add `chunk` to the end of the buffer."""
        self.chunks.append(chunk)
        self.total_bytes += len(chunk)
    def __nonzero__(self):
        """A buffer is true when it holds at least one byte."""
        for chunk in self.chunks:
            if chunk:
                return True
        return False
    # Python 3 spells the truth-value hook __bool__.
    __bool__ = __nonzero__
    def split(self,position):
        """Return two new buffers holding the data before and after `position`.

        `position` is a byte offset. A negative value counts back from the
        end, as with slicing. Offsets outside the buffer are clamped, so
        asking for more than is available yields one empty buffer and one
        with everything. This buffer itself is left untouched.
        """
        if position<0:
            position += self.total_bytes
        # Clamp: without this a still-negative position would be applied to
        # every chunk separately as a from-the-end slice.
        position = max(0,min(position,self.total_bytes))
        l = Readahead()
        r = Readahead()
        for chunk in self.chunks:
            remaining_bytes = position-l.total_bytes
            if remaining_bytes<=0:
                r.append(chunk)
            elif remaining_bytes>=len(chunk):
                l.append(chunk)
            else:
                # the split point falls inside this chunk
                l.append(chunk[:remaining_bytes])
                r.append(chunk[remaining_bytes:])
        return l,r
    def write_to(self,file):
        """Write the whole content, in order, to `file`."""
        for chunk in self.chunks:
            file.write(chunk)

def read_tail(tail_filename,info_filename,client_id,format,samplerate):
    """Return the saved tail of the previous track as a list of sample values.

    Returns None (after logging why) when the tail cannot be used: the tail
    file is missing or unreadable, it is older than 45 seconds, it was
    recorded at another sample rate or in another sample format, or its size
    is not a whole number of samples.

    tail_filename -- raw samples saved by the previous run
    info_filename -- text file: sample rate on line 1, format name on line 2
    client_id     -- not used here
    format        -- sample format object (width, brutefir_name, decode)
    samplerate    -- sample rate of the current stream
    """
    try:
        age = time.time()-os.stat(tail_filename).st_mtime
        # I typically see tails that are 22s old on an SB3. This limit below
        # may need to be larger with future players (or a future squeezeboxserver)
        # with a larger buffer.
        if not -1<=age<=45:
            sys.stderr.write('ignoring tail; it is %ds old\n'%(age,))
            return None
        sys.stderr.write('tail is %ds old\n'%(age,))

        info_file = open(info_filename,'rb')
        try:
            tail_samplerate = int(info_file.readline())
            tail_format_name = info_file.readline().strip()
        finally:
            info_file.close()

        if tail_samplerate!=samplerate:
            # The tail was recorded from the output of brutefir running at a different
            # sample rate. Ideally we should resample it to the new sample rate,
            # however that is difficult right now. brutefir's output is 24 bit,
            # and sox - our resampling toolkit - does not support 24 bit.
            # The best we can do for now is drop the tail. This means you do not
            # get gapless playback between changes in sample rate.
            sys.stderr.write('ignoring tail; it has the wrong sample rate (%d not %d)\n'%(tail_samplerate,samplerate))
            return None

        if tail_format_name!=format.brutefir_name:
            sys.stderr.write('ignoring tail; it has the wrong sample format (%s not %s)\n'%(tail_format_name,format.brutefir_name))
            return None

        tail_file = open(tail_filename,'rb')
        try:
            body = tail_file.read()
        finally:
            tail_file.close()

        if len(body)%format.width:
            # Decoding a partial trailing sample would fail; a file like this
            # is damaged, so do not trust any of it.
            sys.stderr.write('ignoring tail; its size (%d bytes) is not a whole number of samples\n'%(len(body),))
            return None

        return [ format.decode(body[offset:offset+format.width]) for offset in range(0,len(body),format.width) ]
    except (EnvironmentError,ValueError):
        traceback.print_exc()
        return None

def try_unlink(filename):
    """Delete `filename` if possible; quietly do nothing if that fails."""
    try:
        os.remove(filename)
    except EnvironmentError:
        # Best effort only: the file may simply not exist.
        return

class format_16le(object):
    """Signed 16 bit little-endian samples (brutefir name S16_LE)."""
    brutefir_name = 'S16_LE'
    bitspersample = 16
    width = 2                       # bytes per sample
    silent_sample = '\x00\x00'      # one sample of digital silence
    def decode(self,a):
        """Return the signed integer value of the two-character sample `a`."""
        low,high = ord(a[0]),ord(a[1])
        value = (high<<8)|low
        if value&0x8000:
            # sign bit set: map the upper half of the range to negatives
            value -= 0x10000
        return value
    def encode(self,y):
        """Return integer `y` as a two-character sample, clipped to 16 bits."""
        clipped = min(max(y,-0x8000),0x7fff)
        low = clipped&0xff
        high = (clipped>>8)&0xff
        return chr(low)+chr(high)

class format_24le(object):
    """Signed 24 bit little-endian samples (brutefir name S24_LE)."""
    brutefir_name = 'S24_LE'
    bitspersample = 24
    width = 3                           # bytes per sample
    silent_sample = '\x00\x00\x00'      # one sample of digital silence
    def decode(self,a):
        """Return the signed integer value of the three-character sample `a`."""
        low,mid,high = ord(a[0]),ord(a[1]),ord(a[2])
        value = (high<<16)|(mid<<8)|low
        if value&0x800000:
            # sign bit set: map the upper half of the range to negatives
            value -= 0x1000000
        return value
    def encode(self,y):
        """Return integer `y` as a three-character sample, clipped to 24 bits."""
        clipped = min(max(y,-0x800000),0x7fffff)
        return ''.join([chr((clipped>>shift)&0xff) for shift in (0,8,16)])

def init_tmp():
    """Create the private scratch directory TMPDIR and verify it is safe to use.

    The directory lives under /tmp, where other users could have created
    something with the same name first. The process is aborted if TMPDIR is
    not a real directory or does not belong to the current user.
    """
    # Mask out all group/other permission bits (octal 077) while creating it.
    previous_umask = os.umask(stat.S_IRWXG|stat.S_IRWXO)
    try:
        # First try to create the directory. This will fail if it already exists.
        # Dont panic if it fails for any reason.
        try:
            os.mkdir(TMPDIR)
        except EnvironmentError:
            pass
        # Now check that our directory exists, and has the right ownership.
        # Abort if it is wrong for any reason; this probably indicates a security problem.
        # lstat, not stat: a symbolic link planted here by someone else must be
        # rejected even when it points at a directory that we own.
        tmpdir_stat = os.lstat(TMPDIR)
        if not stat.S_ISDIR(tmpdir_stat.st_mode):
            sys.stderr.write(TMPDIR+' is not a directory. This probably indicates a security issue\n')
            os.abort()
        if tmpdir_stat.st_uid!=os.getuid():
            sys.stderr.write(TMPDIR+' has the wrong owner. This probably indicates a security issue\n')
            os.abort()
    finally:
        os.umask(previous_umask)

def init_stderr():
    """Send stderr to the squeezeboxserver log directory.

    The redirection applies to this process and to brutefir.
    """
    log_path = '/var/log/squeezeboxserver/brutefir.log'
    reopen_stderr_as_log(log_path)

def reopen_stderr_as_log(filename):
    """Make file descriptor 2 (stderr) append to the log file `filename`.

    If the log cannot be opened, stderr is left as it is and a note is written
    to it. Once the log has grown beyond 100 KiB it is renamed to
    `filename`.old, so the next run starts a fresh file while this run keeps
    writing to the renamed one. A failure to rename is reported but is not
    fatal.
    """
    # rw-rw-rw- (octal 666) before the umask is applied; spelled with int()
    # so that the same source parses on Python 2 and Python 3.
    log_mode = int('666',8)
    try:
        fd = os.open(filename,os.O_APPEND|os.O_CREAT|os.O_WRONLY,log_mode)
    except OSError:
        # cant open the log file? leave stderr as is
        sys.stderr.write(str(sys.exc_info()[1])+' - continuing anyway\n')
        return
    # we can open the log file, so lets use it as stderr.
    # dup2 replaces descriptor 2 in one step. If stderr was closed when we
    # started, the log may already have been given descriptor 2; then there is
    # nothing to duplicate and nothing that may be closed.
    if fd!=2:
        os.dup2(fd,2)
        os.close(fd)
    # check size, and maybe rename it so it wont be used again
    try:
        if os.fstat(2).st_size>100*1024:
            os.rename(filename,filename+'.old')
    except OSError:
        sys.stderr.write(str(sys.exc_info()[1])+' - log rotation skipped\n')

# Run the wrapper only when this file is executed as a script, not when it
# is imported as a module.
if __name__=='__main__':
    main()

