#!/usr/bin/python3
#
# This script is run on a test VM (e.g. "west") to prepare it for a test.
# The test name may be given with --testname; when omitted, the name of the
# current directory is used.  See --help for the remaining options.

import os
import sys
import socket
import shutil
import distutils.dir_util
import subprocess
import pexpect
import glob


# argparse is required for command line handling; fail with a readable
# message instead of a traceback when it is not available.
try:
    import argparse
except ImportError:
    sys.exit("swan-prep requires the python argparse module")


def shell(command, out=False):
    """Execute *command* through the shell and hand back what it printed.

    The command line is echoed first when ``args.verbose`` is set.  A
    non-zero exit status is reported on stdout together with the captured
    output; it does not raise.  On success the captured output is printed
    only when it is non-empty and either ``args.verbose`` or *out* is set.

    Returns the captured output of the command as a string.
    """
    verbose = args.verbose
    if verbose:
        print(command)

    status, output = subprocess.getstatusoutput(command)

    if status != 0:
        print(("command '%s' failed with status %d: %s" %
               (command, status, output)))
        return output

    if output and (verbose or out):
        print(output)
    return output

# Refuse to continue when a swan userland is present both in the base
# system and under /usr/local, so a double install cannot go unnoticed.
base_install = any(os.path.isfile(path)
                   for path in ("/usr/sbin/ipsec", "/sbin/ipsec"))
local_install = os.path.isfile("/usr/local/sbin/ipsec")
if base_install and local_install:
    sys.exit("\n\n---------------------------------------------------------------------\n"
             "ABORT: found a swan userland in the base system as well as /usr/local\n"
             "---------------------------------------------------------------------\n")

# A fresh copy may have been installed under /usr/local; run restorecon
# over it when that tool is available.
if local_install and os.path.isfile("/usr/sbin/restorecon"):
    subprocess.getoutput(
        "restorecon -Rv /usr/local/libexec/ipsec /usr/local/sbin/ipsec")

# Command line definition.  Every option is optional; the parsed result is
# kept in the module-level ``args`` that the rest of the script (including
# shell()) reads.
parser = argparse.ArgumentParser(description='swan-prep arguments')
parser.add_argument('--testname', '-t', action='store',
                    default='', help='The name of the test to prepare')
parser.add_argument('--hostname', '-H', action='store',
                    default='', help='The name of the host to prepare as')
# we should get this from the testparams.sh file?
parser.add_argument('--userland', '-u', action='store',
                    default='libreswan', help='which userland to prepare')
parser.add_argument('--x509', '-x', action='store_true',
                    help='create X509 NSS file by importing test certs')
parser.add_argument('--dnssec', '-d', action='store_true',
                    help='start nsd and unbound for DNSSEC - meant only for nic')
# Only consulted together with --x509 (and without --eccert/--certchain),
# where it replaces the hostname as the name of the PKCS#12 file to import.
parser.add_argument('--x509name', '-X', action='store', default="",
                    help='with --x509, import the PKCS#12 file of this name '
                         'instead of the one named after the host')
parser.add_argument('--fips', '-f', action='store_true',
                    help='prepare /etc/ipsec.d for running in FIPS mode')
parser.add_argument('--revoked', '-r', action='store_true',
                    help='load a revoked certificate')
parser.add_argument('--signedbyother', '-o', action='store_true',
                    help='load the signedbyother certificate')
parser.add_argument('--eccert', '-e', action='store_true',
                    help='enable an EC cert for this host')
parser.add_argument('--nsspw', action='store_true',
                    help='set the security password (on the NSS DB)')
parser.add_argument('--certchain', '-c', action='store_true',
                    help='import the ca-chain test certs')
# The option strings --46/--64/--6 are not valid identifiers, hence the
# explicit dest= names.
parser.add_argument('--46', '--64', action='store_true',
                    help='Do not disable IPv6. Default is disable IPv6 ', dest='ipv', default=False)
parser.add_argument('--6', action='store_true',
                    help='Enable IPv6 and run - /etc/init.d/network restart', dest='ipv6', default=False)
parser.add_argument('--verbose', '-v', action='store_true',
                    help='more verbose')
parser.add_argument('--nokeys', action='store_true',
                    help='do not provide any keys')
args = parser.parse_args()

# Host to prepare as: an explicit --hostname wins, otherwise ask the system.
hostname = args.hostname or socket.gethostname()
if args.hostname == "nic":
    # nothing to do for an explicitly requested "nic", just stop
    sys.exit()
# only the short name (text before the first dot) is used from here on
hostname = hostname.split(".")[0]

# Test to prepare for: either named explicitly (and then looked up under
# /testing/pluto) or taken from the current working directory.
if not args.testname:
    # Validate this is sane?
    testpath = os.getcwd()
    testname = os.path.basename(testpath)
else:
    # Should this instead try to guess /testing/pluto prefix, or
    # allow testname to specify a directory path?
    testname = args.testname
    testpath = "/testing/pluto/%s" % testname
    if not os.path.isdir(testpath):
        sys.exit("Unknown or bad testname '%s'" % args.testname)

# Point /tmp/<daemon>.log at a fresh, world-accessible log file in the
# test's OUTPUT directory (skipped on nic).
if hostname != "nic":
    # a pluto.log left behind by an earlier run is always removed
    if os.path.islink("/tmp/pluto.log") or os.path.isfile("/tmp/pluto.log"):
        os.unlink("/tmp/pluto.log")

    outputdir = "%s/OUTPUT/" % testpath
    if not os.path.isdir(outputdir):
        os.mkdir(outputdir)
        os.chmod(outputdir, 0o777)

    # name of the IKE daemon for the requested userland; anything not
    # listed here logs as "iked"
    daemon_names = {
        "libreswan": "pluto",
        "openswan": "pluto",
        "strongswan": "charon",
        "racoon": "racoon",
    }
    dname = daemon_names.get(args.userland, "iked")

    daemonlogfile = "%s/%s.%s.log" % (outputdir, hostname, dname)
    tmplink = "/tmp/%s.log" % dname
    if os.path.isfile(tmplink) or os.path.islink(tmplink):
        os.unlink(tmplink)
    os.symlink(daemonlogfile, tmplink)
    # create (or truncate) the log file the link points at
    with open(daemonlogfile, 'w'):
        pass
    os.chmod(daemonlogfile, 0o777)

# An empty --userland falls back to libreswan; any other value must be one
# of the known userlands.
userland = args.userland if args.userland else "libreswan"
if userland not in ("libreswan", "strongswan", "racoon", "shrew", "openswan"):
    sys.exit("swan-prep: unknown userland type '%s'" % args.userland)

# print "swan-prep running on %s for test %s with userland
# %s"%(hostname,testname,userland)

# Remove configuration left over from a previous test: the two top-level
# files are deleted, and /etc/ipsec.d (when present) is emptied by
# removing and re-creating it.
for stale_file in ("/etc/ipsec.conf", "/etc/ipsec.secrets"):
    if os.path.isfile(stale_file):
        os.unlink(stale_file)
if os.path.isdir("/etc/ipsec.d"):
    shutil.rmtree("/etc/ipsec.d")
    os.mkdir("/etc/ipsec.d")

# if using systemd, ensure we don't restart pluto on crash
if os.path.isfile("/lib/systemd/system/ipsec.service"):
    with open("/lib/systemd/system/ipsec.service") as fp:
        service = fp.read()
    if "Restart=always" in service:
        # The file must be closed (and so flushed to disk) before systemd
        # is asked to reload it below; leaving it open would let
        # daemon-reload see a truncated, still-empty unit file.
        with open("/lib/systemd/system/ipsec.service", "w") as fp:
            fp.write(service.replace("Restart=always", "Restart=no"))
    # always reload to avoid "service is masked" errors
    subprocess.getoutput("/usr/bin/systemctl daemon-reload")

# Start from an empty audit log so entries from a previous test cannot
# show up in this one; auditd is restarted afterwards when it is a
# systemd service.
audit_log = "/var/log/audit/audit.log"
if os.path.isfile(audit_log):
    with open(audit_log, "w") as fp:
        pass
    if os.path.isfile("/lib/systemd/system/auditd.service"):
        subprocess.getoutput("/usr/bin/systemctl restart auditd.service")

# Ensure cores are just dropped, not sent to abrt-hook-ccpp or systemd.
# Setting /proc/sys/kernel/core_pattern to "core" or a pattern does not
# work, and shell redirection (">") cannot be used there, so the core is
# piped through "tee" instead.
with open("/proc/sys/kernel/core_pattern", "w") as fp:
    fp.write("|/usr/bin/tee /tmp/core.%h.%e.%p")

if userland in ("libreswan", "openswan", "strongswan"):
    # Install the per-host base configuration, overlay the test specific
    # files on top of it, and restore /etc/hosts and /etc/resolv.conf.

    if hostname != "nic":
        # this brings in the nss *.db files that are path-specific -
        # they have pathnames hardcoded inside the file.  This default
        # database contains a default key-pair for basic testing.
        distutils.dir_util.copy_tree("/testing/baseconfigs/%s/etc/ipsec.d" % hostname, "/etc/ipsec.d/")
        # fill in any missing dirs
        if userland == "strongswan":
            prefix = "strongswan/"
        else:
            prefix = ""
        confdir = "/etc/%sipsec.d" % prefix
        for dirname in (confdir, confdir + "/policies", confdir + "/cacerts",
                        confdir + "/crls", confdir + "/certs", confdir + "/private"):
            if not os.path.isdir(dirname):
                os.mkdir(dirname)
                # the private key directory is restricted to its owner
                if "private" in dirname:
                    os.chmod(dirname, 0o700)

        # test specific files
        ipsecconf = "%s/ipsec.conf" % (testpath)
        ipsecsecrets = "%s/ipsec.secrets" % (testpath)
        h_ipsecconf = "%s/%s.conf" % (testpath, hostname)
        h_ipsecsecrets = "%s/%s.secrets" % (testpath, hostname)
        xl2tpdconf = "%s/%s.xl2tpd.conf" % (testpath, hostname)
        pppoptions = "%s/%s.ppp-options.xl2tpd" % (testpath, hostname)
        chapfile = "%s/chap-secrets" % testpath

        # Pick the config to install: the host specific file if it is the
        # only one present, the generic test file if that exists (a
        # conflict is reported but the generic file still wins), and the
        # host's base config when the test provides neither.
        if os.path.isfile(h_ipsecconf):
            if os.path.isfile(ipsecconf):
                print("conflicting files %s %s" % (ipsecconf, h_ipsecconf))
            else:
                ipsecconf = h_ipsecconf
        elif not os.path.isfile(ipsecconf):
            ipsecconf = "/testing/baseconfigs/%s/etc/ipsec.conf" % hostname

        # same selection rules for the secrets file
        if os.path.isfile(h_ipsecsecrets):
            if os.path.isfile(ipsecsecrets):
                print("conflicting files %s %s" % (ipsecsecrets, h_ipsecsecrets))
            else:
                ipsecsecrets = h_ipsecsecrets
        elif not os.path.isfile(ipsecsecrets):
            ipsecsecrets = "/testing/baseconfigs/%s/etc/ipsec.secrets" % hostname

        if userland == "strongswan":
            # check version and fail early
            output = subprocess.getoutput("strongswan version")
            if "U5.6.0" not in output:
                sys.exit("strongswan 5.6.0 must be installed")
            # required to write log file in /tmp
            subprocess.getoutput("setenforce 0")
            strongswanconf = "%s/%sstrongswan.conf" % (testpath, hostname)
            shutil.copy(strongswanconf, "/etc/strongswan/strongswan.conf")
            for dirname in ("/etc/strongswan/ipsec.d/aacerts", "/etc/strongswan/ipsec.d/ocspcerts"):
                if not os.path.isdir(dirname):
                    os.mkdir(dirname)

        dstconf = "/etc/%sipsec.conf" % (prefix)
        dstsecrets = "/etc/%sipsec.secrets" % (prefix)
        shutil.copy(ipsecconf, dstconf)
        shutil.copy(ipsecsecrets, dstsecrets)
        os.chmod(dstsecrets, 0o600)

        # optional L2TP/PPP files, installed only when the test ships them
        if os.path.isfile(xl2tpdconf):
            shutil.copyfile(xl2tpdconf, "/etc/xl2tpd/xl2tpd.conf")
        if os.path.isfile(pppoptions):
            shutil.copyfile(pppoptions, "/etc/ppp/options.xl2tpd")
        if os.path.isfile(chapfile):
            shutil.copyfile(chapfile, "/etc/ppp/chap-secrets")

        # the copied NSS database files must be owned by root:root
        for dbfile in glob.glob("/etc/ipsec.d/*db"):
            os.chown(dbfile, 0, 0)

    # restore /etc/hosts to original - some tests make changes
    shutil.copyfile("/testing/baseconfigs/all/etc/hosts", "/etc/hosts")
    # a host specific resolv.conf takes precedence over the shared one
    resolv = "/testing/baseconfigs/%s/etc/resolv.conf" % hostname
    if not os.path.isfile(resolv):
        resolv = "/testing/baseconfigs/all/etc/resolv.conf"
    dst = "/etc/resolv.conf"
    if os.path.islink(dst):  # on fedora 22 it is a link; remove the link first
        os.unlink(dst)
    shutil.copyfile(resolv, dst)

# Install the matching /etc/sysconfig/pluto and create or remove the
# /etc/system-fips marker file depending on --fips.
sysconfig_pluto = "/etc/sysconfig/pluto"
if not args.fips:
    shutil.copyfile(
        "/testing/baseconfigs/all/etc/sysconfig/pluto", sysconfig_pluto)
    if os.path.isfile("/etc/system-fips"):
        os.unlink("/etc/system-fips")
else:
    # create (or empty) the marker file
    with open("/etc/system-fips", "w") as fp:
        pass
    shell("/testing/guestbin/fipson")
    # the test also requires using a modutil cmd which we cannot run here
    shutil.copyfile(
        "/testing/baseconfigs/all/etc/sysconfig/pluto.fips", sysconfig_pluto)

# Set up NSS DB
if userland in ("libreswan", "openswan"):

    # Set password options.  util_pw is appended to certutil command
    # lines and p12cmd_pw to pk12util command lines further down.
    if args.nsspw or args.fips:
        dbpassword = "s3cret"
        util_pw = " -f /tmp/nsspw"
        p12cmd_pw = " -k /tmp/nsspw"
        with open("/tmp/nsspw", "w") as f:
            f.write(dbpassword)
            f.write("\n")
        with open("/etc/ipsec.d/nsspassword", "w") as f:
            if args.nsspw:
                f.write("NSS Certificate DB:" + dbpassword + "\n")
            if args.fips:
                f.write("NSS FIPS 140-2 Certificate DB:" + dbpassword + "\n")
    else:
        util_pw = ""
        p12cmd_pw = " -K ''"

    if args.x509 or args.nokeys:
        # Delete any existing db files, and start fresh.
        if args.x509:
            print("Preparing X.509 files")
        else:
            print("Creating empty NSS database")
        for oldfile in glob.glob("/etc/ipsec.d/*db"):
            os.unlink(oldfile)
        shell("/usr/bin/certutil -N --empty-password -d sql:/etc/ipsec.d")

    # If needed set a password (this will upgrade any existing
    # database)
    if args.nsspw or args.fips:
        # /tmp/pw holds the current (empty) password, /tmp/nsspw the new one
        with open("/tmp/pw", "w") as f:
            f.write("\n")
        shell("/usr/bin/certutil -W -f /tmp/pw -@ /tmp/nsspw -d sql:/etc/ipsec.d", out=True)

    # Switch on fips in the NSS db
    if args.fips:
        shell("/usr/bin/modutil -dbdir sql:/etc/ipsec.d -fips true -force", out=True)

    # this section is getting rough. could use a nice refactoring
    if args.x509:
        # Empty strongswan's cacerts directory.  Only touch it when the
        # strongswan config tree actually exists, so that a host without
        # it does not abort here with an exception.
        strongswan_cacerts = "/etc/strongswan/ipsec.d/cacerts/"
        if os.path.isdir(strongswan_cacerts):
            shutil.rmtree(strongswan_cacerts)
        if os.path.isdir("/etc/strongswan/ipsec.d"):
            os.mkdir(strongswan_cacerts)

        if not os.path.isfile("/testing/x509/keys/mainca.key"):
            print("\n\n---------------------------------------------------------------------\n"
                  "WARNING: no mainca.key file, did you run testing/x509/dist_certs.py?\n"
                  "---------------------------------------------------------------------\n")

        # Select the PKCS#12 bundle to import, the CA directory it lives
        # in, and the pk12util option giving the bundle's own password.
        if args.eccert:
            p12 = hostname + "-ec"
            ca = "curveca"
            pw = "-W \"\""
        else:
            if args.x509name:
                p12 = args.x509name
            else:
                p12 = hostname
            ca = "mainca"
            pw = "-w /testing/x509/nss-pw"

        if args.certchain:
            icanum = 2
            pw = "-w /testing/x509/nss-pw"
            root = "mainca"
            ica_p = hostname + "_chain_int_"
            ee = hostname + "_chain_endcert"

            # a note about DB trusts
            # 'CT,,' is our root's trust. T is important!! it is for verifying "SSLClient" x509 KU
            # ',,' is an intermediate's trust
            # 'P,,' (trusted peer) is nic's for OCSP
            # 'u,u,u' will be end cert trusts that are p12 imported (with privkey)

            # mainca and nic import
            shell("/usr/bin/certutil -A -n %s -t 'CT,,' -d sql:/etc/ipsec.d/ -a -i /testing/x509/cacerts/%s.crt%s" %
                  (root, root, util_pw))
            shell("/usr/bin/certutil -A -n nic -t 'P,,' -d sql:/etc/ipsec.d/ -a -i /testing/x509/certs/nic.crt%s" % util_pw)

            # install ee
            shell("/usr/bin/pk12util -i /testing/x509/pkcs12/%s.p12 -d sql:/etc/ipsec.d %s%s" %
                  (ee, pw, p12cmd_pw))

            # install intermediates
            for i in range(1, icanum + 1):
                acrt = ica_p + str(i)
                shell("/usr/bin/certutil -A -n %s -t ',,' -d sql:/etc/ipsec.d/ -a -i /testing/x509/certs/%s.crt%s" %
                      (acrt, acrt, util_pw))
            if args.revoked:
                # use the same sql: database as every other import here
                shell("/usr/bin/pk12util -i /testing/x509/pkcs12/%s_revoked.p12 -d sql:/etc/ipsec.d %s%s" %
                      (hostname + "_chain", pw, p12cmd_pw))

        else:
            shell("/usr/bin/pk12util -i /testing/x509/pkcs12/%s/%s.p12 -d sql:/etc/ipsec.d %s%s" %
                  (ca, p12, pw, p12cmd_pw))
            # install all other public certs
            # libreswanhost = os.getenv("LIBRESWANHOSTS") #kvmsetu.sh is not
            # sourced

            if args.revoked:
                shell(
                    "/usr/bin/pk12util -i /testing/x509/pkcs12/mainca/revoked.p12 -d sql:/etc/ipsec.d %s%s" % (pw, p12cmd_pw))

            if args.signedbyother:
                shell(
                    "/usr/bin/pk12util -i /testing/x509/pkcs12/otherca/signedbyother.p12 -d sql:/etc/ipsec.d %s%s" % (pw, p12cmd_pw))

            # fix trust import from p12
            # is pw needed?
            shell("/usr/bin/certutil -M -n 'Libreswan test CA for mainca - Libreswan' -d sql:/etc/ipsec.d/ -t 'CT,,'%s" % (util_pw))

            # Import the public certs of the other hosts as trusted peers;
            # any cert whose name contains this host's name (e.g. "west"
            # and "west-ec" on west) is skipped.
            for certname in ("west", "east", "road", "north", "hashsha2", "west-ec", "east-ec", "nic"):
                if hostname not in certname:
                    shell("/usr/bin/certutil -A -n %s -t 'P,,' -d sql:/etc/ipsec.d -a -i /testing/x509/certs/%s.crt%s" %
                          (certname, certname, util_pw))

# strongswan reads plain certificate and key files, so --x509 only needs to
# copy them into place.
if userland == "strongswan" and args.x509:
    # EC certs come from the curve CA and use a "-ec" suffixed key
    ca, key = ("curveca", hostname + "-ec") if args.eccert else ("mainca", hostname)

    shutil.copy("/testing/x509/cacerts/%s.crt" % ca,
                "/etc/strongswan/ipsec.d/cacerts/")
    shutil.copy("/testing/x509/keys/%s.key" % key,
                "/etc/strongswan/ipsec.d/private/")

    for certname in ("west", "east", "road", "north", "hashsha2", "west-ec", "east-ec"):
        shutil.copy("/testing/x509/certs/%s.crt" % certname,
                    "/etc/strongswan/ipsec.d/certs/")


if userland in ("racoon2", "racoon"):
    # Racoon uses x509 files straight from /testing/x509/*, so only the
    # test's <host>-racoon directory is copied into /etc/racoon2/.
    distutils.dir_util.copy_tree("%s/%s-racoon" % (testpath, hostname),
                                 "/etc/racoon2/")
    # secrets are readable and writable by the owner only
    for secret_file in ("/etc/racoon2/psk/test.psk", "/etc/racoon2/spmd.pwd"):
        os.chmod(secret_file, 0o600)
    if not os.path.isdir("/var/run/racoon2"):
        os.mkdir("/var/run/racoon2")
        os.chmod("/var/run/racoon2", 0o600)
    # create stub ipsec.conf for stackmanager to load netkey
    with open("/etc/ipsec.conf", "w") as stub:
        stub.write("version 2\n")
        stub.write("config setup\n")
        stub.write("\tprotostack=netkey\n")

if userland == "shrew":
    print("shrew not yet tested/integrated")

# IPv6 is switched off on every host except nic unless --46/--64 was given.
if hostname != "nic" and not args.ipv:
    for sysctl_cmd in ("sysctl net.ipv6.conf.all.disable_ipv6=1",
                       "sysctl net.ipv6.conf.default.disable_ipv6=1"):
        subprocess.getoutput(sysctl_cmd)

# --6 switches IPv6 (back) on and restarts the network service.
if args.ipv6:
    for sysctl_cmd in ("sysctl net.ipv6.conf.all.disable_ipv6=0",
                       "sysctl net.ipv6.conf.default.disable_ipv6=0"):
        subprocess.getoutput(sysctl_cmd)

    have_systemctl = os.path.isfile("/usr/bin/systemctl")
    subprocess.getoutput("systemctl restart network.service" if have_systemctl
                         else "service network restart")

# --dnssec: start the nsd and unbound services
if args.dnssec:
    for service_cmd in ("service nsd start", "service unbound start"):
        subprocess.getoutput(service_cmd)

# give root a .gdbinit that sets the auto-load safe path to "/", unless
# one already exists
if not os.path.isfile("/root/.gdbinit"):
    with open("/root/.gdbinit", "w") as fp:
        fp.write("set auto-load safe-path /")

# flush all IPv4 rules and delete the user-defined chains
for flush_cmd in ("iptables -F", "iptables -X"):
    subprocess.getoutput(flush_cmd)

# Create LOGDROP (used to be done in swan-transmogrify but we want it here
# for docker): a chain that logs a packet and then drops it, for both
# IPv4 and IPv6.
for ipt in ("iptables", "ip6tables"):
    for rule in ("%s -N LOGDROP ", "%s -A LOGDROP -j LOG", "%s -A LOGDROP -j DROP"):
        subprocess.getoutput(rule % ipt)

# final prep - this kills any running userland
subprocess.call(["systemctl", "stop", "ipsec"])
# for some reason this fails to stop strongswan?
subprocess.call(["systemctl", "stop", "strongswan"])
# Make sure no IKE daemon survives the service stop: SIGKILL each known
# daemon by name.
for dname in ("pluto", "charon", "starter", "iked", "racoon2-spmd", "racoon2-iked"):
    try:
        subprocess.check_output(["killall", "-9", dname], stderr=subprocess.STDOUT)
    except (subprocess.CalledProcessError, OSError):
        # Best effort: a non-zero exit from killall (the usual case, as
        # most of these daemons are not running) or a killall that cannot
        # be executed is ignored.  Unlike a bare "except", this still lets
        # KeyboardInterrupt and SystemExit through.
        pass

# remove stale pid files
# note shrew and racoon2 share a pid file name
for pidfile in ("/var/run/pluto/pluto.pid", "/var/run/charon.pid", "/var/run/iked.pid", "/var/run/spmd.pid", "/var/run/starter.charon.pid"):
    if os.path.isfile(pidfile):
        os.unlink(pidfile)

# remove stacks so test can start the stack it needs.
subprocess.getoutput("ipsec _stackmanager stop")
