#!/usr/bin/python3 -I
import argparse
import os
import subprocess
import sys
import re
from typing import List, Tuple

class SectionAdder:
    """Collect objcopy arguments that append sections at increasing addresses.

    Every added section starts at the current address after that address has
    been passed through ``round_to_next``; the current address then advances
    by the on-disk size of the file backing the section.
    """

    def __init__(self, base_addr, round_to_next):
        """Begin placing sections at ``base_addr``.

        ``round_to_next`` is called with an address and must return the
        address at which the next section should actually start.
        """
        self.round_to_next = round_to_next
        self._args = []
        self._address = base_addr

    def add_section(self, section_name, path):
        """Queue the file at ``path`` as the section ``.<section_name>``.

        Raises ``OSError`` if ``path`` cannot be stat'ed. In that case the
        arguments for the section have already been queued.
        """
        vma = self.round_to_next(self._address)
        self._address = vma
        dotted_name = f".{section_name}"
        self._args.extend((
            f"--add-section={dotted_name}={path}",
            f"--change-section-vma={dotted_name}={vma}",
        ))
        # The next section may not start before the end of this one.
        self._address = vma + os.stat(path).st_size

    def objcopy_args(self):
        """Return the list of objcopy arguments accumulated so far."""
        return self._args

def main(args: List[str]) -> None:
    """Build a unified kernel image by appending sections to the Xen binary.

    Runs ``objdump`` on the hypervisor to find the highest address occupied by
    its existing sections, then runs ``objcopy`` to add the config, any custom
    sections, the kernel and the initramfs above that address. Each added
    section starts on a 2 MiB boundary.

    Args:
        args: Full argument vector including the program name at index 0
            (for example ``sys.argv``).

    Raises:
        subprocess.CalledProcessError: If ``objdump`` or ``objcopy`` exits
            with a non-zero status.
        SystemExit: On invalid command-line arguments (raised by argparse),
            or when an address cannot be aligned without exceeding 64 bits.
        OSError: If a tool cannot be executed or an input file cannot be
            stat'ed.
    """
    # One row of `objdump --section-headers`: index, name, size (8 hex
    # digits), VMA (16 hex digits), LMA, file offset, and finally a token
    # starting with "2" (presumably the `2**n` alignment column).
    section_re = re.compile(rb"\A *(?:0|[1-9][0-9]*) +([!-~]+) +([0-9a-f]{8}) +([0-9a-f]{16}) +[0-9a-f]{16} +[0-9a-f]{8} +2")
    alignment_mask = (1 << 21) - 1  # 2 MiB section alignment
    parser = argparse.ArgumentParser(
            prog='uki-generate',
            description='Create a unified kernel image for Qubes OS')
    parser.add_argument('hypervisor', help="Xen Kernel")
    parser.add_argument('config', help="Xen config file")
    parser.add_argument('kernel', help="Dom0 kernel")
    parser.add_argument('initramfs')
    parser.add_argument('output', help="UKI image save path")
    parser.add_argument('--custom-section', nargs=2, action='append', help='A custom section placed right before the kernel')
    opts = parser.parse_args(args[1:])
    hyp, cfg, kern, initramfs, custom, out = (
        opts.hypervisor,
        opts.config,
        opts.kernel,
        opts.initramfs,
        opts.custom_section,
        opts.output,
    )
    # Make relative paths explicit. startswith() (rather than indexing) also
    # copes with an empty string instead of raising IndexError.
    if not hyp.startswith('/'):
        hyp = './' + hyp
    if not out.startswith('/'):
        out = './' + out
    output = subprocess.check_output([
        "objdump",
        "-WE", # do not use debuginfod
        "--section-headers",
        "--",
        hyp,
    ])

    # Find the end address of the highest existing section.
    max_vma = 0
    for line in output.splitlines():
        m = section_re.match(line)
        if not m:
            continue
        section_name = m.group(1)
        size = int(m.group(2), 16)
        start_vma = int(m.group(3), 16)
        # .annobin.* sections are removed by the objcopy call below, so they
        # do not need to be kept clear of.
        if section_name.startswith(b".annobin"):
            continue
        # The section occupies [start_vma, start_vma + size); new sections
        # must be placed after its END, not merely after its start.
        max_vma = max(max_vma, size + start_vma)

    def round_to_next(f: int) -> int:
        """Round ``f`` up to the next alignment boundary, or exit on overflow."""
        # Largest aligned address that still fits in 64 bits.
        max_address = (0xffffffffffffffff & ~alignment_mask)
        if f > max_address:
            print(f"Fatal error: Address overflow: {f} exceeds {max_address}", file=sys.stderr)
            sys.exit(1)
        return (f + alignment_mask) & ~alignment_mask

    adder = SectionAdder(base_addr = max_vma, round_to_next = round_to_next)
    adder.add_section("config", cfg)
    # Custom sections go between the config and the kernel, in the order
    # they were given on the command line.
    for (section, path) in custom or []:
        adder.add_section(section, path)
    adder.add_section("kernel", kern)
    adder.add_section("ramdisk", initramfs)
    cmdline = (
        [
            "objcopy",
            f"--section-alignment={alignment_mask + 1}",
            f"--file-alignment={1 << 5}",
            #"--remove-section=.buildid",
            "--remove-section=.annobin.*",
            #"--strip-debug",
        ]
        + adder.objcopy_args()
        + [
            "--long-section-names=disable",
            "--",
            hyp,
            out,
        ]
    )
    subprocess.check_call(cmdline, stdin=subprocess.DEVNULL, stdout=subprocess.DEVNULL)
if __name__ == '__main__':
    try:
        main(sys.argv)
    except subprocess.CalledProcessError as tool_failure:
        # An external tool failed: exit with that tool's status instead of
        # letting Python print a traceback.
        raise SystemExit(tool_failure.returncode)
