#! /usr/bin/python3
# Copyright(c) 2017, Intel Corporation
#
# Redistribution  and  use  in source  and  binary  forms,  with  or  without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of  source code  must retain the  above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name  of Intel Corporation  nor the names of its contributors
#   may be used to  endorse or promote  products derived  from this  software
#   without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING,  BUT NOT LIMITED TO,  THE
# IMPLIED WARRANTIES OF  MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.  IN NO EVENT  SHALL THE COPYRIGHT OWNER  OR CONTRIBUTORS BE
# LIABLE  FOR  ANY  DIRECT,  INDIRECT,  INCIDENTAL,  SPECIAL,  EXEMPLARY,  OR
# CONSEQUENTIAL  DAMAGES  (INCLUDING,  BUT  NOT LIMITED  TO,  PROCUREMENT  OF
# SUBSTITUTE GOODS OR SERVICES;  LOSS OF USE,  DATA, OR PROFITS;  OR BUSINESS
# INTERRUPTION)  HOWEVER CAUSED  AND ON ANY THEORY  OF LIABILITY,  WHETHER IN
# CONTRACT,  STRICT LIABILITY,  OR TORT  (INCLUDING NEGLIGENCE  OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,  EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import argparse
import struct
import fnmatch
import glob
import sys
import tempfile
import os
import subprocess
import fcntl
import filecmp
import stat
import re
from array import *


def check_rpd(ifile):
    """Validate the header of an rpd image and return its fifth header word.

    The first 0x20 bytes of ``ifile`` are read and unpacked as eight
    native-order unsigned 32-bit words.  The first three words must be
    0xffffffff and the fourth must be 0x6a6a6a6a.

    Args:
        ifile: binary file object positioned at the start of the image.

    Returns:
        The fifth header word; the caller uses it as the start offset of
        the user image.

    Raises:
        Exception: if the header is truncated or the signature words do
            not match.
    """
    data = ifile.read(0x20)
    if len(data) != 0x20:
        # A short read would otherwise surface as an opaque struct.error.
        print("invalid rpd file")
        raise Exception("invalid rpd file: truncated header")

    pof_hdr = struct.unpack('IIIIIIII', data)

    signature_ok = (all(word == 0xffffffff for word in pof_hdr[:3]) and
                    pof_hdr[3] == 0x6a6a6a6a)
    if not signature_ok:
        print("invalid rpd file")
        raise Exception("invalid rpd file: bad header signature")

    return pof_hdr[4]


def reverse_bits(x, n):
    """Return ``x`` with its low ``n`` bits in reversed order.

    Bit ``i`` of ``x`` (counting from the least significant bit) becomes
    bit ``n - 1 - i`` of the result.  Bits of ``x`` at position ``n`` and
    above are ignored.

    Args:
        x: non-negative integer whose bits are reversed.
        n: number of low-order bits to consider.

    Returns:
        The bit-reversed value, in the range ``0 .. 2**n - 1``.
    """
    result = 0
    # range() rather than xrange(): the latter does not exist on Python 3.
    for i in range(n):
        if (x >> i) & 1:
            result |= 1 << (n - 1 - i)
    return result


def reverse_bits_in_file(ifile, ofile):
    """Copy ``ifile`` to ``ofile``, reversing the bit order of every byte.

    Data is read from the current position of ``ifile`` until EOF in
    4096-byte chunks.  Both files must be opened in binary mode.

    Args:
        ifile: binary file object to read from.
        ofile: binary file object to write the bit-reversed bytes to.
    """
    # 256-entry lookup table: entry i holds byte i with its 8 bits mirrored.
    # bytes(bytearray(...)) yields a byte string on both Python 2 and 3.
    bit_rev = bytes(bytearray(
        sum(((value >> bit) & 1) << (7 - bit) for bit in range(8))
        for value in range(256)))

    while True:
        ichunk = ifile.read(4096)
        if not ichunk:
            break

        # translate() maps every byte through the table in one pass,
        # avoiding per-byte chr()/ord() and repeated concatenation.
        ofile.write(ichunk.translate(bit_rev))

def get_flash_size(dev):
    """Query an mtd character device and return its size field.

    Issues the MEMGETINFO ioctl with a zeroed 'BIIIIIQ' buffer and returns
    the third unpacked field (presumably the ``size`` member of the
    kernel's ``mtd_info_user`` structure).

    Args:
        dev: path of the mtd device node.

    Returns:
        The third field of the structure returned by the ioctl.

    Raises:
        OSError/IOError: if the device cannot be opened or the ioctl
            fails; a diagnostic is printed before re-raising.
    """
    MEMGETINFO = 0x80204d01

    try:
        fd = os.open(dev, os.O_SYNC | os.O_RDONLY)
    except Exception as e:
        print("failed to open {}:  {}".format(dev, e))
        raise

    # Ensure the descriptor is released even when the ioctl raises.
    try:
        ioctl_data = struct.pack('BIIIIIQ', 0, 0, 0, 0, 0, 0, 0)
        try:
            ret = fcntl.ioctl(fd, MEMGETINFO, ioctl_data)
        except Exception as e:
            print("ioctl failed:  {}".format(e))
            raise
    finally:
        os.close(fd)

    ioctl_odata = struct.unpack_from('BIIIIIQ', ret)

    return ioctl_odata[2]


def flash_erase(dev, start, nbytes):
    """Erase a region of an mtd device using the MEMERASE ioctl.

    Args:
        dev: path of the mtd device node.
        start: offset of the region; packed as an unsigned 32-bit value.
        nbytes: length of the region; packed as an unsigned 32-bit value.

    Raises:
        OSError/IOError: if the device cannot be opened or the ioctl
            fails; a diagnostic is printed before re-raising.
        struct.error: if ``start`` or ``nbytes`` does not fit in 32 bits.
    """
    MEMERASE = 0x40084d02

    try:
        fd = os.open(dev, os.O_SYNC | os.O_RDWR)
    except Exception as e:
        print("failed to open {}:  {}".format(dev, e))
        raise

    # Ensure the descriptor is released even when packing or ioctl raises.
    try:
        ioctl_data = struct.pack('II', start, nbytes)
        try:
            fcntl.ioctl(fd, MEMERASE, ioctl_data)
        except Exception as e:
            print("ioctl failed:  {}".format(e))
            raise
    finally:
        os.close(fd)


def flash_write(dev, start, nbytes, ifile):
    """Write ``nbytes`` bytes from ``ifile`` to ``dev`` at offset ``start``.

    Data is transferred in chunks of at most 4096 bytes.

    Args:
        dev: path of the device (or file) to write to.
        start: byte offset in ``dev`` at which writing begins.
        nbytes: total number of bytes to copy.
        ifile: binary file object supplying the data.

    Raises:
        OSError: if the device cannot be opened, sought or written.
        Exception: if ``ifile`` runs out of data before ``nbytes`` bytes
            were copied, or the device accepts no more data.
    """
    try:
        fd = os.open(dev, os.O_SYNC | os.O_RDWR)
    except Exception as e:
        print("failed to open {}:  {}".format(dev, e))
        raise

    # Ensure the descriptor is released on every exit path.
    try:
        os.lseek(fd, start, os.SEEK_SET)

        while nbytes > 0:
            ichunk = ifile.read(min(nbytes, 4096))

            if not ichunk:
                print("read of input file failed")
                raise Exception("read of input file failed")

            # Count what was actually read; read() may return fewer bytes
            # than requested.
            nbytes -= len(ichunk)

            # os.write() may accept only part of the buffer, so keep
            # writing until the whole chunk has been consumed.
            while ichunk:
                written = os.write(fd, ichunk)
                if written <= 0:
                    print("write to flash failed")
                    raise Exception("write to flash failed")
                ichunk = ichunk[written:]
    finally:
        os.close(fd)


def flash_read(dev, start, nbytes, ofile):
    """Read ``nbytes`` bytes from ``dev`` at offset ``start`` into ``ofile``.

    Data is transferred in chunks of at most 4096 bytes.

    Args:
        dev: path of the device (or file) to read from.
        start: byte offset in ``dev`` at which reading begins.
        nbytes: total number of bytes to copy.
        ofile: binary file object receiving the data.

    Raises:
        OSError: if the device cannot be opened, sought or read.
        Exception: if the device returns no data before ``nbytes`` bytes
            were copied.
    """
    try:
        fd = os.open(dev, os.O_SYNC | os.O_RDONLY)
    except Exception as e:
        print("failed to open {}:  {}".format(dev, e))
        raise

    # Ensure the descriptor is released on every exit path.
    try:
        os.lseek(fd, start, os.SEEK_SET)

        while nbytes > 0:
            ichunk = os.read(fd, min(nbytes, 4096))

            if not ichunk:
                print("read of flash failed")
                raise Exception("read of flash failed")

            ofile.write(ichunk)
            # os.read() may return fewer bytes than requested; count only
            # what was actually transferred so no data is skipped.
            nbytes -= len(ichunk)
    finally:
        os.close(fd)


def parse_args():
    """Parse the command line and return the resulting namespace.

    Positional arguments:
        type: 'user' or 'factory'.
        file: rpd image, opened for binary reading by argparse.
        bdf:  optional PCI bus:device.function of the card to program.

    Returns:
        The argparse namespace with ``type``, ``file`` and ``bdf`` set.
    """
    description = ('A tool to help update the flash used to configure an '
                   'Intel FPGA at power up.')
    epilog = ('example usage:\n\n'
              '    fpgaflash user new_image.rpd 0000:04:00.0\n\n')
    bdf_help = ("bdf of device to program (e.g. 04:00.0 or 0000:04:00.0)"
                " optional when one device in system")

    # The raw formatter keeps the epilog's line breaks and indentation.
    parser = argparse.ArgumentParser(
        description=description,
        epilog=epilog,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument('type',
                        choices=['user', 'factory'],
                        help='type of flash programming')
    parser.add_argument('file',
                        type=argparse.FileType('rb'),
                        help='rpd file to program into flash')
    parser.add_argument('bdf', nargs='?', help=bdf_help)

    return parser.parse_args()


def get_bdf_mtd_mapping():
    """Build a mapping from FPGA bdf to its flash device node.

    For every entry under /sys/class/fpga the basename of the target of
    its ``device`` symlink is used as the bdf.  The value is the /dev path
    named after the first matching mtd sysfs entry whose path does not
    end in "ro".  Entries without such an mtd are left out.

    Returns:
        dict mapping bdf strings to '/dev/<mtd name>' paths.
    """
    mapping = {}

    for fpga_dir in glob.glob('/sys/class/fpga/*'):
        device_link = os.path.join(fpga_dir, "device")
        bdf = os.path.basename(os.readlink(device_link))
        if bdf:
            mtd_pattern = os.path.join(fpga_dir, 'intel-fpga-fme.*',
                                       'altr-asmip2*', 'mtd', 'mtd*')
            # Skip the "...ro" entries and keep the remaining candidates.
            candidates = [path for path in glob.glob(mtd_pattern)
                          if not fnmatch.fnmatchcase(path, "*ro")]
            if candidates:
                mtd_name = os.path.basename(candidates[0])
                mapping[bdf] = os.path.join('/dev', mtd_name)

    return mapping


def print_bdf_mtd_mapping(bdf_map):
    """Print the bdf of every FPGA card that can be flashed.

    Args:
        bdf_map: mapping whose keys are the bdf strings to list.
    """
    print("\nFPGA cards available for flashing:")

    for key in bdf_map:
        print("    {}".format(key))

    # An explicit empty string emits a blank line on Python 2 and 3 alike.
    print("")


def normalize_bdf(bdf):
    """Return ``bdf`` in full 'dddd:bb:dd.f' form, or None if malformed.

    A string already containing the four-digit domain is returned
    unchanged; the short 'bb:dd.f' form is prefixed with domain '0000'.

    Args:
        bdf: PCI address string supplied by the user.

    Returns:
        The normalized bdf string, or None when ``bdf`` matches neither
        accepted form.
    """
    # Raw strings keep the regex escape '\.' from being parsed as a
    # (invalid) string escape sequence.
    pat = r'[0-9a-fA-F]{4}:[0-9a-fA-F]{2}:[0-9a-fA-F]{2}\.[0-9a-fA-F]$'
    if re.match(pat, bdf):
        return bdf

    if re.match(r'[0-9a-fA-F]{2}:[0-9a-fA-F]{2}\.[0-9a-fA-F]$', bdf):
        return "0000:{}".format(bdf)

    return None


def _select_mtd_device(bdf_map, bdf):
    """Return the flash device for ``bdf``, or None after printing why not."""
    if not bdf:
        if len(bdf_map) > 1:
            print("Must specify a bdf. More than one device found.")
            return None
        # Exactly one card present: use it.
        return next(iter(bdf_map.values()))

    normalized = normalize_bdf(bdf)
    if not normalized:
        # Report what the user typed, not the failed normalization result.
        print("{} is an invalid bdf".format(bdf))
        return None

    if normalized not in bdf_map:
        print("Could not find flash device for {}".format(normalized))
        return None

    return bdf_map[normalized]


def _remove_quietly(paths):
    """Delete each path in ``paths``, ignoring files that cannot be removed."""
    for path in paths:
        try:
            os.remove(path)
        except OSError:
            # Best-effort cleanup; never mask the original error.
            pass


def update_flash(update_type, ifile, bdf):
    """Program an rpd image into the flash of an FPGA card and verify it.

    Steps: locate the card's mtd device, validate the rpd header, write a
    bit-reversed copy of the image to a temporary file, erase the flash
    from the start address to its end, write the image, read it back and
    compare it byte for byte with what was written.

    Args:
        update_type: 'factory' programs from offset 0; any other value
            programs from the start address found in the rpd header.
        ifile: open binary file object for the rpd image; it is closed
            once the bit-reversed copy has been produced.
        bdf: PCI address of the card, or None/empty when only one card is
            present.

    Raises:
        Exception: if no usable device is found, the image is invalid, or
            verification fails.  Errors from the underlying OS calls
            propagate unchanged.
    """
    bdf_map = get_bdf_mtd_mapping()

    if not bdf_map:
        print("No FPGA devices found")
        raise Exception("No FPGA devices found")

    mtd_dev = _select_mtd_device(bdf_map, bdf)
    if not mtd_dev:
        print_bdf_mtd_mapping(bdf_map)
        raise Exception("no flash device selected")

    try:
        mode = os.stat(mtd_dev).st_mode
    except Exception:
        print("Couldn't stat {}".format(mtd_dev))
        raise

    if not stat.S_ISCHR(mode):
        print("{} is not a device file.".format(mtd_dev))
        raise Exception("{} is not a device file.".format(mtd_dev))

    flash_size = get_flash_size(mtd_dev)

    print("flash size is {}".format(flash_size))

    start_addr = check_rpd(ifile)

    if update_type == 'factory':
        start_addr = 0

    # Temporary files are created with delete=False so they can be
    # reopened by name; track them so they are removed on every exit path.
    temp_paths = []
    try:
        ofile = tempfile.NamedTemporaryFile(mode='wb', delete=False)
        temp_paths.append(ofile.name)

        try:
            ifile.seek(start_addr)

            print("reversing bits")
            reverse_bits_in_file(ifile, ofile)
        finally:
            ifile.close()
            ofile.close()

        print("erasing flash")
        flash_erase(mtd_dev, start_addr, (flash_size - start_addr))

        nbytes = os.path.getsize(ofile.name)
        try:
            rfile = open(ofile.name, 'rb')
        except Exception:
            print("open({}) failed:".format(ofile.name))
            raise

        print("writing flash")
        try:
            flash_write(mtd_dev, start_addr, nbytes, rfile)
        finally:
            rfile.close()

        vfile = tempfile.NamedTemporaryFile(mode='wb', delete=False)
        temp_paths.append(vfile.name)

        print("reading back flash")
        try:
            flash_read(mtd_dev, start_addr, nbytes, vfile)
        finally:
            vfile.close()

        print("verifying flash")

        # shallow=False forces a byte-by-byte comparison; the default only
        # compares os.stat() signatures when they happen to be identical.
        retval = filecmp.cmp(ofile.name, vfile.name, shallow=False)
    finally:
        _remove_quietly(temp_paths)

    if retval:
        print("flash successfully verified")
    else:
        print("failed to verify flash")
        raise Exception("failed to verify flash")


# Script entry point: parse the command line, ask for confirmation before
# a factory update, then program the flash.
if __name__ == "__main__":

    args = parse_args()

    if args.type == 'factory':
        msg = "Are you sure you want to perform a factory update? [Yes/No]"
        sys.stdout.write(msg)
        # The prompt has no trailing newline, so flush explicitly to make
        # sure it is visible before blocking on stdin.
        sys.stdout.flush()
        line = sys.stdin.readline()
        # Only an exact "Yes" confirms; anything else aborts.
        if line != "Yes\n":
            sys.exit(1)

    update_flash(args.type, args.file, args.bdf)
