#!/usr/bin/python3

import libertine.utils
import os
import select
import signal
import sys

from socket import *


def accept_new_connection():
    """Accept a container client and pair it with a new host bus connection.

    The accepted socket and a freshly connected socket to
    ``host_dbus_session_addr`` are both registered in ``descriptors`` (so
    ``main_loop`` selects on them) and recorded together in ``socket_pairs``
    as ``[container_side, host_side]``.

    If the host bus cannot be reached, only this client is dropped; the
    proxy keeps serving other connections.
    """
    newconn = container_dbus_session_sock.accept()[0]

    host_dbus_session_sock = socket(AF_UNIX, SOCK_STREAM)
    try:
        host_dbus_session_sock.connect(host_dbus_session_addr)
    except OSError as err:
        # Without a host-side peer there is nothing to relay to. Close both
        # ends so neither descriptor leaks, and do not register them.
        host_dbus_session_sock.close()
        newconn.close()
        print("dbus proxy: cannot connect to host session bus: %s" % err,
              file=sys.stderr)
        return

    descriptors.append(newconn)
    descriptors.append(host_dbus_session_sock)

    socket_pairs.append([newconn, host_dbus_session_sock])


def get_socket_pair(socket):
    """Return the entry of ``socket_pairs`` that contains ``socket``.

    Each entry is a ``[container_side, host_side]`` list. Returns ``None``
    when ``socket`` is not part of any recorded pair.
    """
    return next((pair for pair in socket_pairs if socket in pair), None)


def get_socket_partner(socket):
    """Return the other socket in the pair that ``socket`` belongs to.

    Looks the pair up with ``get_socket_pair``. Raises ``TypeError`` if
    ``socket`` is not in any pair, because there is then no pair to iterate.
    """
    socket_pair = get_socket_pair(socket)
    return next((peer for peer in socket_pair if peer != socket), None)


def close_connections(remove_socket):
    """Tear down the proxied connection that ``remove_socket`` belongs to.

    Removes the pair from ``socket_pairs`` and both sockets from
    ``descriptors``, then shuts down and closes ``remove_socket`` followed by
    its partner. A failing ``shutdown()`` does not prevent ``close()``, so no
    descriptor is leaked.
    """
    partner_socket = get_socket_partner(remove_socket)
    socket_pairs.remove(get_socket_pair(remove_socket))

    for sock in (remove_socket, partner_socket):
        descriptors.remove(sock)
        try:
            sock.shutdown(SHUT_RDWR)
        except OSError:
            # shutdown() can fail, e.g. ENOTCONN on some platforms once the
            # peer is gone; the socket still has to be closed.
            pass
        sock.close()


def close_all_connections():
    """Shut down and close both sockets of every pair in ``socket_pairs``.

    Called from the SIGTERM/SIGINT cleanup. A ``shutdown()`` failure on one
    socket (for example one that is already closed) is ignored, so it does
    not leave the remaining sockets open.
    """
    for pair in socket_pairs:
        for sock in pair:
            try:
                sock.shutdown(SHUT_RDWR)
            except OSError:
                pass
            sock.close()


def get_host_dbus_socket(session_file=None):
    """Return the host session bus address as an abstract AF_UNIX name.

    Reads the ``dbus-session`` file from the user runtime directory (or
    ``session_file`` when given) and locates the
    ``DBUS_SESSION_BUS_ADDRESS=unix:abstract=`` entry. Returns that socket
    name prefixed with a NUL byte, which is how Python's ``socket`` module
    spells an abstract-namespace address.

    Only the value up to the end of its line and up to the first ``,`` is
    used: in a D-Bus address ``,`` separates key/value pairs (such as a
    trailing ``,guid=...``), so it is never part of the socket name.

    :param session_file: file to parse; defaults to
        ``<user runtime dir>/dbus-session``.
    :raises ValueError: if the file has no ``unix:abstract=`` session address.
    """
    socket_key = "DBUS_SESSION_BUS_ADDRESS=unix:abstract="

    if session_file is None:
        session_file = os.path.join(libertine.utils.get_user_runtime_dir(),
                                    'dbus-session')

    with open(session_file, 'r') as fd:
        dbus_session_str = fd.read()

    _, found, address = dbus_session_str.partition(socket_key)
    if not found:
        raise ValueError("no %s entry in %s" % (socket_key, session_file))

    # Stop at the end of this line and at the next key/value separator.
    socket_name = address.split('\n', 1)[0].split(',', 1)[0]

    return "\0%s" % socket_name


def socket_cleanup(signum, frame):
    """SIGTERM/SIGINT handler: release the proxy's sockets and socket file.

    Closes the listening socket and every proxied connection, then removes
    the socket file that was bound at ``dbus_session_socket_path``. With the
    descriptors closed, the ``select()`` in ``main_loop`` fails and its loop
    exits.

    :param signum: signal number (unused).
    :param frame: interrupted stack frame (unused).
    """
    try:
        container_dbus_session_sock.close()
        close_all_connections()
    finally:
        # Remove the socket file even if closing a connection failed: a
        # leftover file makes the next bind() on this path fail.
        try:
            os.remove(dbus_session_socket_path)
        except FileNotFoundError:
            # Already removed, e.g. by an earlier SIGTERM/SIGINT.
            pass


def main_loop():
    """Relay data between container clients and the host session bus.

    Installs ``socket_cleanup`` as the SIGTERM/SIGINT handler, then loops on
    ``select()`` over ``descriptors``:

    - readiness on the listening socket accepts a new client;
    - readiness on any other socket forwards what was read to its partner.

    A connection that reaches EOF or fails on recv/send has its pair torn
    down; other connections are unaffected. The loop ends once ``select()``
    itself fails, which is what happens after the signal handler has closed
    the descriptors.
    """
    signal.signal(signal.SIGTERM, socket_cleanup)
    signal.signal(signal.SIGINT, socket_cleanup)

    while True:
        try:
            rlist, _, _ = select.select(descriptors, [], [])
        except InterruptedError:
            continue
        except (OSError, ValueError):
            # Closed descriptors (fileno() == -1) make select() raise
            # ValueError; other descriptor errors raise OSError.
            break

        for sock in rlist:
            # Skip sockets already closed earlier in this pass together
            # with their partner.
            if sock.fileno() == -1:
                continue

            if sock is container_dbus_session_sock:
                accept_new_connection()
                continue

            try:
                data = sock.recv(4096)
            except OSError:
                # e.g. ConnectionResetError: drop only this connection.
                close_connections(sock)
                continue

            if not data:
                # EOF: the peer closed its end.
                close_connections(sock)
                continue

            send_sock = get_socket_partner(sock)

            if send_sock.fileno() < 0:
                continue

            try:
                # sendall() continues a partial write from where it stopped;
                # re-sending ``data`` from the start would duplicate bytes.
                send_sock.sendall(data)
            except OSError:
                # e.g. BrokenPipeError when the partner has gone away.
                close_connections(sock)


if len(sys.argv) < 2:
    sys.exit("usage: %s <container-dbus-socket-path>" % sys.argv[0])

dbus_session_socket_path = sys.argv[1]

# Resolve the host bus address before binding. If this fails, no socket
# file has been created yet, so nothing stale is left on disk to break
# the next bind().
host_dbus_session_addr = get_host_dbus_socket()

container_dbus_session_sock = socket(AF_UNIX, SOCK_STREAM)
container_dbus_session_sock.bind(dbus_session_socket_path)
container_dbus_session_sock.listen(5)

# Everything main_loop() selects on: the listening socket plus both ends of
# every proxied connection.
descriptors = [container_dbus_session_sock]
# One [container_side, host_side] entry per proxied client.
socket_pairs = []

main_loop()
