#!/usr/bin/python3
#
# The Qubes OS Project, http://www.qubes-os.org
#
# Copyright (C) 2022  Marek Marczykowski-Górecki
#                               <marmarek@invisiblethingslab.com>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#

from __future__ import annotations

from itertools import cycle
import subprocess
import sys

import dbus
import qubesdb
from ipaddress import IPv4Address
import os

def get_dns_resolv_conf():
    """Return the IPv4 nameservers listed in ``/etc/resolv.conf``.

    Every line whose first whitespace-separated token is ``nameserver`` is
    considered; its second token is kept only if it parses as an IPv4
    address.  An empty list is returned when the file cannot be opened.
    """
    found = []
    try:
        handle = open("/etc/resolv.conf", "r", encoding="UTF-8")
    except IOError:
        # No readable resolv.conf: report "no nameservers" instead of failing.
        return found
    with handle:
        for entry in handle:
            fields = entry.split(None, 2)
            if len(fields) >= 2 and fields[0] == "nameserver":
                try:
                    address = IPv4Address(fields[1])
                except ValueError:
                    # Not an IPv4 address (for example an IPv6 one); skip it.
                    continue
                found.append(address)
    return found

def get_dns_resolved():
    """Return the IPv4 DNS servers reported by systemd-resolved.

    The ``DNS`` property of ``org.freedesktop.resolve1.Manager`` is read over
    the system D-Bus.  Entries with interface index 0 (the global ones) are
    ordered ahead of the others, and only entries with address family 2
    (IPv4) are returned, as ``IPv4Address`` objects.

    The result of ``get_dns_resolv_conf()`` is returned instead when the
    system bus does not reply, when the resolve1 service is unknown, has no
    owner or has no unit, when a systemd1 error is reported, or when the
    property value is ``None``.  Any other D-Bus error is re-raised.
    """
    try:
        system_bus = dbus.SystemBus()
    except dbus.exceptions.DBusException as exc:
        if exc.get_dbus_name() != 'org.freedesktop.DBus.Error.NoReply':
            raise
        return get_dns_resolv_conf()

    try:
        manager = system_bus.get_object('org.freedesktop.resolve1',
                                        '/org/freedesktop/resolve1')
        servers = manager.Get('org.freedesktop.resolve1.Manager',
                              'DNS',
                              dbus_interface='org.freedesktop.DBus.Properties')
    except dbus.exceptions.DBusException as exc:
        error_name = exc.get_dbus_name()
        resolved_unavailable = error_name in (
            'org.freedesktop.DBus.Error.ServiceUnknown',
            'org.freedesktop.DBus.Error.NameHasNoOwner',
            'org.freedesktop.DBus.Error.NoSuchUnit',
        ) or error_name.startswith('org.freedesktop.systemd1.')
        if not resolved_unavailable:
            raise
        return get_dns_resolv_conf()

    if servers is None:
        return get_dns_resolv_conf()

    # Global entries (ifindex == 0) go first; the sort is stable, so the
    # relative order inside each group is preserved.
    ordered = sorted(servers, key=lambda entry: entry[0] != 0)

    # Only keep IPv4 entries. systemd-resolved is trusted to return valid
    # addresses.
    # TODO: only the IPv4 entries with ifindex == 0 are really needed, which
    # would ensure that static DNS servers of disconnected network interfaces
    # are not added.
    ipv4_servers = []
    for _ifindex, family, raw_address in ordered:
        if family == 2:
            ipv4_servers.append(IPv4Address(bytes(raw_address)))
    return ipv4_servers

def install_firewall_rules(dns):
    """Install the nftables rules that redirect qube DNS traffic upstream.

    The virtual DNS addresses are read from QubesDB
    (``/qubes-netvm-primary-dns`` and ``/qubes-netvm-secondary-dns``).  UDP
    and TCP traffic to port 53 of each of them is DNAT-ed to one of the
    resolvers in *dns*; if there are fewer resolvers than virtual addresses,
    the resolvers are reused cyclically.  When *dns* is empty, DNS traffic to
    the virtual addresses is dropped instead.

    :param dns: list of upstream IPv4 resolvers (``IPv4Address`` objects),
        for example the result of ``get_dns_resolved()``.  Passing ``None``
        makes the function query ``get_dns_resolved()`` itself.

    This function does not return normally.  If the ``dnat-dns`` chain
    already contains exactly the generated rules, the process exits with
    status 100.  Otherwise the process is replaced by ``nft``, which loads
    the new ruleset; ``OSError`` is raised if ``nft`` cannot be executed.
    """
    qdb = qubesdb.QubesDB()
    qubesdb_dns = []
    for key in ('/qubes-netvm-primary-dns', '/qubes-netvm-secondary-dns'):
        ns_maybe = qdb.read(key)
        if ns_maybe is None:
            continue
        try:
            qubesdb_dns.append(IPv4Address(ns_maybe.decode("ascii", "strict")))
        except (UnicodeDecodeError, ValueError):
            # Best effort: a malformed QubesDB entry is ignored, so the other
            # address (if any) still gets its rules.
            pass
    preamble = [
        'add table ip qubes',
        # Add the chain so that the subsequent delete will work. If the chain already
        # exists this is a harmless no-op.
        'add chain ip qubes dnat-dns',
        # Delete the chain so that if the chain already exists, it will be removed.
        # The removal of the old chain and addition of the new one happen as a single
        # atomic operation, so there is no period where neither chain is present or
        # where both are present.
        'delete chain ip qubes dnat-dns',
    ]
    rules = [
        'table ip qubes {',
        'chain dnat-dns {',
        'type nat hook prerouting priority dstnat; policy accept;',
    ]
    # Use the resolver list supplied by the caller. It was previously ignored
    # and looked up a second time here. Only query when nothing was passed.
    dns_resolved = get_dns_resolved() if dns is None else dns
    if not dns_resolved:
        # User has no IPv4 DNS set in sys-net. Maybe IPv6 only environment.
        # Or maybe user wants to enforce DNS-Over-HTTPS.
        # Drop IPv4 DNS requests to qubesdb_dns addresses.
        for vm_nameserver in qubesdb_dns:
            vm_ns_ = str(vm_nameserver)
            rules += [
                f"ip daddr {vm_ns_} udp dport 53 drop",
                f"ip daddr {vm_ns_} tcp dport 53 drop",
            ]
    else:
        # cycle() lets a single upstream resolver serve every virtual address.
        for vm_nameserver, dest in zip(qubesdb_dns, cycle(dns_resolved)):
            vm_ns_ = str(vm_nameserver)
            dns_ = str(dest)
            rules += [
                f"ip daddr {vm_ns_} udp dport 53 dnat to {dns_}",
                f"ip daddr {vm_ns_} tcp dport 53 dnat to {dns_}",
            ]
    rules += ["}", "}"]

    # check if new rules are the same as the old ones - if so, don't reload
    # and return that info via exit code
    try:
        old_rules = subprocess.check_output(
            ["nft", "list", "chain", "ip", "qubes", "dnat-dns"]).decode().splitlines()
    except subprocess.CalledProcessError:
        # nft failed, e.g. because the chain does not exist yet: treat it as
        # "no rules installed".
        old_rules = []
    # nft indents its listing; strip so lines compare equal to `rules`.
    old_rules = [line.strip() for line in old_rules]

    if old_rules == rules:
        sys.exit(100)

    if os.path.exists('/run/qubes-service/qubes-firewall'):
        # Request the reload before exec'ing nft below, since nothing in this
        # script runs after the exec. --no-block avoids waiting for it.
        subprocess.call(["systemctl",
                         "--no-block",
                         "try-reload-or-restart",
                         "qubes-firewall.service"
        ])

    # Replace this process with nft; preamble and rules are passed as a
    # single command argument.
    os.execvp("nft", ("nft", "--", "\n".join(preamble + rules)))

if __name__ == '__main__':
    # Look up the upstream DNS servers, then hand them to the rule installer.
    resolved_dns = get_dns_resolved()
    install_firewall_rules(resolved_dns)
