#!/usr/bin/env python3
#
# EFI Boot Guard, unified kernel image generator
#
# Copyright (c) Siemens AG, 2022
#
# Authors:
#  Jan Kiszka <jan.kiszka@siemens.com>
#
# This work is licensed under the terms of the GNU GPL, version 2.  See
# the COPYING file in the top-level directory.
#
# SPDX-License-Identifier:	GPL-2.0

import argparse
import struct
import sys


def align(val, alignment):
    """Round val up to the next multiple of alignment.

    The bit-mask arithmetic only yields a true multiple when alignment is a
    power of two.
    """
    mask = alignment - 1
    return (val + mask) & ~mask


class Section:
    """One entry of a PE section table.

    An entry is 0x28 bytes long: an 8-byte name followed by four
    little-endian 32-bit values (virtual size, virtual address, raw data
    size, raw data file offset), 12 bytes that are not represented here,
    and the 32-bit characteristics flags.
    """

    # Characteristics flags used for the sections this tool creates
    IMAGE_SCN_CNT_INITIALIZED_DATA = 0x00000040
    IMAGE_SCN_MEM_READ = 0x40000000

    def __init__(self, name, virt_size, virt_addr, data_size, data_offs,
                 chars):
        """Store the given section table fields unmodified."""
        self.name = name
        self.virt_size = virt_size
        self.virt_addr = virt_addr
        self.data_size = data_size
        self.data_offs = data_offs
        self.chars = chars

    @staticmethod
    def from_struct(blob):
        """Create a Section from the packed table entry at the start of blob.

        The 12 bytes between the raw data offset and the characteristics
        are skipped.
        """
        fields = struct.unpack_from('<8sIIII12xI', blob)
        return Section(*fields)

    def get_struct(self):
        """Return the packed 0x28-byte table entry for this section.

        The 12 bytes between the raw data offset and the characteristics
        are always written as zeros.
        """
        fields = (self.name, self.virt_size, self.virt_addr,
                  self.data_size, self.data_offs, self.chars)
        return struct.pack('<8sIIII12xI', *fields)


class PEHeaders:
    """Editable copy of the headers of a PE image.

    The constructor splits an image into DOS header, COFF header, optional
    header and section table. The accessors read and patch 32-bit fields of
    the optional header, and add_section() appends section table entries
    while keeping the dependent header fields consistent.
    """

    # Byte offsets of the 32-bit optional header fields accessed below
    OPT_OFFS_SIZE_OF_INIT_DATA = 0x8
    OPT_OFFS_ADDRESS_OF_ENTRY_POINT = 0x10
    OPT_OFFS_BASE_OF_CODE = 0x14
    OPT_OFFS_SECTION_ALIGNMENT = 0x20
    OPT_OFFS_FILE_ALIGNMENT = 0x24
    OPT_OFFS_SIZE_OF_IMAGE = 0x38
    OPT_OFFS_SIZE_OF_HEADERS = 0x3C

    @staticmethod
    def _invalid(name, reason):
        """Report a malformed image on stderr and exit with status 1."""
        print("Invalid %s, %s" % (name, reason), file=sys.stderr)
        # sys.exit() instead of the exit() helper: the latter is injected by
        # the site module and is not guaranteed to exist (e.g. python -S).
        sys.exit(1)

    def __init__(self, name, blob):
        """Parse the headers found at the start of blob.

        name is only used in error messages. If blob is not a well-formed
        image, a message is printed to stderr and the process exits with
        status 1 (SystemExit).
        """
        # Parse headers: DOS, COFF, optional header
        if len(blob) < 0x40:
            self._invalid(name, "image too small")

        # 16-bit magic at offset 0, 32-bit offset of the PE header at 0x3c
        (magic, pe_offs) = struct.unpack_from('<H58xI', blob)

        if magic != 0x5a4d:
            self._invalid(name, "bad DOS header magic")

        self.dos_header = blob[:pe_offs]

        # header_size always marks the end of the headers handled so far
        self.header_size = pe_offs + 0x18
        if self.header_size > len(blob):
            self._invalid(name, "incomplete COFF header")

        self.coff_header = blob[pe_offs:self.header_size]

        (magic, self.machine, num_sections, opt_header_size) = \
            struct.unpack_from('<IHH12xH2x', self.coff_header)
        if magic != 0x4550:
            self._invalid(name, "bad PE header magic")

        coff_offs = self.header_size

        self.header_size += opt_header_size
        if self.header_size > len(blob):
            self._invalid(name, "incomplete optional header")

        self.opt_header = blob[coff_offs:self.header_size]

        section_offs = self.header_size

        self.header_size += num_sections * 0x28
        if self.header_size > len(blob):
            self._invalid(name, "incomplete section headers")

        # Lowest file offset of any non-empty section data; stays at the
        # image size if no section carries data
        self.first_data = len(blob)
        # Highest file offset reached by any section data
        self.end_of_sections = 0

        self.sections = []
        for _ in range(num_sections):
            section = Section.from_struct(
                blob[section_offs:section_offs+0x28])
            end_of_section = section.data_offs + section.data_size
            if end_of_section > len(blob):
                self._invalid(name, "section data missing")

            if section.data_size and section.data_offs < self.first_data:
                self.first_data = section.data_offs

            if end_of_section > self.end_of_sections:
                self.end_of_sections = end_of_section

            self.sections.append(section)

            section_offs += 0x28

    def get_opt_header_field(self, offset):
        """Return the little-endian 32-bit optional header field at offset."""
        return struct.unpack_from('<I', self.opt_header, offset)[0]

    def set_opt_header_field(self, offset, val):
        """Overwrite the 32-bit optional header field at offset with val."""
        # The header is kept as immutable bytes, so patch a mutable copy
        patched = bytearray(self.opt_header)
        struct.pack_into('<I', patched, offset, val)
        self.opt_header = bytes(patched)

    def get_size_of_init_data(self):
        return self.get_opt_header_field(PEHeaders.OPT_OFFS_SIZE_OF_INIT_DATA)

    def set_size_of_init_data(self, size):
        self.set_opt_header_field(PEHeaders.OPT_OFFS_SIZE_OF_INIT_DATA, size)

    def get_address_of_entry_point(self):
        return self.get_opt_header_field(
            PEHeaders.OPT_OFFS_ADDRESS_OF_ENTRY_POINT)

    def set_address_of_entry_point(self, addr):
        self.set_opt_header_field(PEHeaders.OPT_OFFS_ADDRESS_OF_ENTRY_POINT,
                                  addr)

    def get_base_of_code(self):
        return self.get_opt_header_field(PEHeaders.OPT_OFFS_BASE_OF_CODE)

    def set_base_of_code(self, base):
        self.set_opt_header_field(PEHeaders.OPT_OFFS_BASE_OF_CODE, base)

    def get_section_alignment(self):
        return self.get_opt_header_field(PEHeaders.OPT_OFFS_SECTION_ALIGNMENT)

    def set_section_alignment(self, alignment):
        """Set the section alignment and re-align the size of image to it."""
        self.set_opt_header_field(PEHeaders.OPT_OFFS_SECTION_ALIGNMENT,
                                  alignment)
        self.set_size_of_image(align(self.get_size_of_image(), alignment))

    def get_file_alignment(self):
        return self.get_opt_header_field(PEHeaders.OPT_OFFS_FILE_ALIGNMENT)

    def get_size_of_image(self):
        return self.get_opt_header_field(PEHeaders.OPT_OFFS_SIZE_OF_IMAGE)

    def set_size_of_image(self, size):
        self.set_opt_header_field(PEHeaders.OPT_OFFS_SIZE_OF_IMAGE, size)

    def get_size_of_headers(self):
        return self.get_opt_header_field(PEHeaders.OPT_OFFS_SIZE_OF_HEADERS)

    def set_size_of_headers(self, size):
        self.set_opt_header_field(PEHeaders.OPT_OFFS_SIZE_OF_HEADERS, size)

    def add_section(self, section):
        """Append section to the section table and update the headers.

        Grows the recorded size of headers if the extra table entry no
        longer fits, updates the section count and the size of image and,
        for initialized-data sections, the size of initialized data. If the
        headers then reach into the first section data, the file offsets of
        all sections with data (including the new one) are moved up by the
        overlap.
        """
        self.header_size += 0x28
        size_of_headers = self.get_size_of_headers()
        if self.header_size > size_of_headers:
            size_of_headers = align(self.header_size,
                                    self.get_file_alignment())
            self.set_size_of_headers(size_of_headers)

        self.sections.append(section)
        # Patch the 16-bit section count at offset 6 of the COFF header
        self.coff_header = struct.pack('<6sH16s', self.coff_header[:6],
                                       len(self.sections),
                                       self.coff_header[8:])

        new_size = align(section.virt_addr + section.virt_size,
                         self.get_section_alignment())
        if new_size > self.get_size_of_image():
            self.set_size_of_image(new_size)

        if section.chars & Section.IMAGE_SCN_CNT_INITIALIZED_DATA:
            new_size = self.get_size_of_init_data() + section.data_size
            self.set_size_of_init_data(new_size)

        if size_of_headers > self.first_data:
            # Headers grew into the section data: shift all data upwards
            file_relocation = size_of_headers - self.first_data
            self.first_data += file_relocation
            for sect in self.sections:
                if sect.data_size > 0:
                    sect.data_offs += file_relocation

        # Evaluated after the relocation so that the shifted offset counts
        end_of_section = section.data_offs + section.data_size
        if end_of_section > self.end_of_sections:
            self.end_of_sections = end_of_section


def _read_and_close(stream):
    """Return the complete content of stream and close it afterwards."""
    with stream:
        return stream.read()


def _add_data_section(pe_headers, name, virt_addr, data, file_offs,
                      file_align):
    """Register a readable initialized-data section that will hold data.

    Virtual and raw size are both len(data) rounded up to file_align. The
    section is added to pe_headers and returned; its data_offs may have
    been moved by add_section() if section data had to be relocated.
    """
    sect_size = align(len(data), file_align)
    section = Section(name, sect_size, virt_addr, sect_size, file_offs,
                      Section.IMAGE_SCN_CNT_INITIALIZED_DATA |
                      Section.IMAGE_SCN_MEM_READ)
    pe_headers.add_section(section)
    return section


def _pad_to(image, offset):
    """Zero-fill the bytearray image up to the given file offset."""
    image.extend(bytes(offset - len(image)))


def main():
    """Build a unified kernel image from the command line arguments.

    Appends .cmdline, .kernel and the optional .initrd and .dtb-N sections
    to the stub image and writes the result to the output file. Exits with
    status 1 if an input image is malformed.
    """
    parser = argparse.ArgumentParser(
        description='Generate unified kernel image')
    parser.add_argument('-c', '--cmdline', metavar='"CMDLINE"', default='',
                        help='kernel command line')
    parser.add_argument('-d', '--dtb', metavar='DTB', action="append",
                        default=[], type=argparse.FileType('rb'),
                        help='device tree for the kernel '
                        '(can be specified multiple times)')
    parser.add_argument('-i', '--initrd', metavar='INITRD',
                        type=argparse.FileType('rb'),
                        help='initrd/initramfs for the kernel')
    parser.add_argument('stub', metavar='STUB',
                        type=argparse.FileType('rb'),
                        help='stub image to use')
    parser.add_argument('kernel', metavar='KERNEL',
                        type=argparse.FileType('rb'),
                        help='image of the kernel')
    parser.add_argument('output', metavar='UNIFIEDIMAGE',
                        type=argparse.FileType('wb'),
                        help='name of unified kernel image file')

    try:
        args = parser.parse_args()
    except IOError as e:
        print(e.strerror, file=sys.stderr)
        sys.exit(1)

    # The command line is stored as NUL-terminated UTF-16
    cmdline = (args.cmdline + '\0').encode('utf-16-le')

    stub = _read_and_close(args.stub)

    pe_headers = PEHeaders('stub image', stub)
    # Remember where the stub's section data lives before add_section()
    # possibly relocates it
    stub_first_data = pe_headers.first_data
    stub_end_of_sections = pe_headers.end_of_sections
    file_align = pe_headers.get_file_alignment()

    # (section, data) pairs of the extra sections, in file order
    payloads = []

    # Add extra section headers
    current_offs = align(stub_end_of_sections, file_align)
    section = _add_data_section(pe_headers, b'.cmdline', 0x30000, cmdline,
                                current_offs, file_align)
    payloads.append((section, cmdline))
    current_offs = section.data_offs + section.data_size

    kernel = _read_and_close(args.kernel)

    # Just to perform an integrity test for the kernel image
    PEHeaders('kernel', kernel)

    section = _add_data_section(pe_headers, b'.kernel', 0x2000000, kernel,
                                current_offs, file_align)
    payloads.append((section, kernel))
    current_offs = section.data_offs + section.data_size

    if args.initrd:
        initrd = _read_and_close(args.initrd)
        section = _add_data_section(pe_headers, b'.initrd', 0x6000000,
                                    initrd, current_offs, file_align)
        payloads.append((section, initrd))
        current_offs = section.data_offs + section.data_size

    # Device trees are placed back to back, starting at this address
    dtb_virt = 0x40000
    for n, dtb_file in enumerate(args.dtb, start=1):
        dtb = _read_and_close(dtb_file)
        section = _add_data_section(pe_headers,
                                    bytes('.dtb-{}'.format(n), 'ascii'),
                                    dtb_virt, dtb, current_offs, file_align)
        payloads.append((section, dtb))

        dtb_virt += section.data_size
        current_offs = section.data_offs + section.data_size

    #
    # Some ARM toolchains use a minimal alignment and put the text section at
    # a too low virtual address. This causes trouble when we relocate
    # sections because the PE header is also loaded into memory.
    #
    # Align the virtual address of the text section to 4K therefore.
    #
    for sect in pe_headers.sections:
        if sect.name == b'.text\0\0\0' and sect.virt_addr < 0x1000:
            if pe_headers.first_data > 0x1000:
                print("PE header too large - way too many DTBs?!",
                      file=sys.stderr)
                sys.exit(1)
            virt_relocation = 0x1000 - sect.virt_addr
            sect.virt_addr = 0x1000
            pe_headers.set_address_of_entry_point(
                pe_headers.get_address_of_entry_point() + virt_relocation)
            pe_headers.set_base_of_code(
                pe_headers.get_base_of_code() + virt_relocation)
            break

    # Build unified image header. A bytearray is extended in place, so the
    # kernel and initrd are not copied again on every append.
    image = bytearray(pe_headers.dos_header)
    image += pe_headers.coff_header
    image += pe_headers.opt_header
    for section in pe_headers.sections:
        image += section.get_struct()

    # Pad till first section data
    _pad_to(image, pe_headers.first_data)

    # Write remaining stub
    image += stub[stub_first_data:stub_end_of_sections]

    # Write data of extra sections
    for section, data in payloads:
        _pad_to(image, section.data_offs)
        image += data

    # Align to promised size of last section
    _pad_to(image, align(len(image), file_align))

    # Close explicitly so that a failing final flush raises here instead of
    # being ignored at interpreter shutdown
    with args.output as output:
        output.write(image)


# Run the generator only when this file is executed as a script, not when
# it is imported as a module
if __name__ == "__main__":
    main()
