#!/usr/bin/env python3
# vim: et ts=4 sw=4 sts=4
#
# import from obsolete JSON KASP to LMDB-beckended KASP database.
#

from __future__ import print_function

import datetime
import time
import json
import sys
import re
import glob
import argparse
import time
import traceback
import os
import hashlib
import importlib
import codecs

opt_force = False
lmdb = None

def lmdb_requirement():
	global lmdb

	try:
		lmdb = importlib.import_module('lmdb')
	except ImportError:
		print("Error: unable to import module LMDB.")
		print("Probably you need to 'apt install python3-lmdb'.")
		sys.exit(10)

# workarounding that python 2 doesn't have int.to_bytes()
def to_bytes(n, length, endianness='big'):
	h = '%x' % n
	assert len(h) <= length * 2
	s = ('0'*(len(h) % 2) + h).zfill(length * 2)
	if sys.version_info >= (3,0):
		sb = codecs.decode(s, 'hex')
	else:
		sb = s.decode('hex')
	return bytearray(sb) if endianness == 'big' else bytearray(sb[::-1])

def from_bytes(ba, endianness='big'):
	x = ba if endianness == 'big' else bytearray(s[::-1])
	if sys.version_info >= (3,0):
		hx = codecs.encode(x, 'hex')
	else:
		hx = str(x).encode('hex')
	return int(hx, 16)

# aka knot_dname_from_str_alloc()
def str2dname(s):
	if s.endswith('.') is False:
		s += '.'
	res = bytearray(b"")
	nodes = s.lower().split('.')
	if nodes[-1] != "":
		nodes.append("")

	for node in nodes:
		res.append(len(node))
		res.extend(bytearray(node.lower(), 'ascii'))

	return res

def dname2str(dn):
	res = ""
	beg = 0
	end = ord(dn[0]) + 1
	while ord(dn[beg]) > 0:
		res += str(dn[beg+1:end]) + "."
		beg = end
		end = beg + ord(dn[beg]) + 1

	return res

# this is just helper for shuffling time
def shuffle_unixtime(base_time, shuffle_years, shuffle_months):
	rsm = shuffle_months + 12 * shuffle_years
	dt = datetime.datetime.fromtimestamp(base_time)
	newmonth = (dt.month - 1 + rsm) % 12 + 1 # in python, % always returns [0, 11]
	sameyear = dt.month + rsm % 12
	newyear = dt.year + rsm // 12 + (0 if sameyear in range(1, 13) else 1) # in python, (-1)//12 = -1
	dt2 = dt.replace(month=newmonth, year=newyear)
	print(dt2.month, "/", dt2.year)
	ttuple = dt2.timetuple()
	return int(time.mktime(ttuple))

def timespec2unix(spec):
	if re.match(r"^\d+$", spec):
		return int(spec)

	now = int(time.time())
	s = re.sub(r"^now", "t", spec)
	if s == "t":
		return now

	unitmap = { "" : 1, "mi" : 60, "h" : 3600, "d" : 86400 }
	unitmap_mo = { "mo" : 1, "y" : 12 }

	if re.match(r"^t[-+]\d+", s):
		unit = re.sub(r"^t[-+]\d+", "", s)
		cutend = len(s) if unit == "" else -len(unit)
		if unit in list(unitmap.keys()):
			return now + int(s[1:cutend]) * unitmap[unit]
		elif unit in list(unitmap_mo.keys()):
			return shuffle_unixtime(now, 0, int(s[1:cutend]) * unitmap_mo[unit])
		else:
			print("Error in time unit specification")

	print("Error in time specification")	

class Keykey:
	'''Kasp DB key serialized (type, zone_name, key_id)'''

	def __init__(self, raw_bytearray):
		self.raw = bytearray(raw_bytearray)

	@classmethod
	def from_params(self, valtype, zone_name, key_id):
		selfraw = to_bytes(valtype, 1)
		if zone_name is not None:
			selfraw.extend(zone_name)
		if key_id is not None:
			selfraw.extend(bytearray(key_id.encode("ascii")))
			selfraw.append(0)
		return Keykey(selfraw)

	def getRaw(self):
		return bytearray(self.raw)

	def getType(self):
		return self.raw[0]

	def __getSplit(self):
		x = self.raw.find(to_bytes(0, 1))
		assert x > 0
		return x + 1

	def getZone(self):
		if self.getType() == 2:
			return None
		return str(self.raw[1:self.__getSplit()])

	def getKeyid(self):
		if self.getType() != 1:
			return None
		return str(self.raw[self.__getSplit():])

class Keyparams:
	'''Serialized key parameters for kasp-db.'''

	def __init__(self, raw_bytearray):
		self.raw = bytearray(raw_bytearray)
		self.timers_dict = { "created" : [ 0, 20, 28 ],
		                     "publish" : [ 1, 28, 36 ],
		                     "ready"   : [ 2, 36, 44 ],
		                     "active"  : [ 3, 44, 52 ],
		                     "retire"  : [ 4, 52, 60 ],
		                     "remove"  : [ 5, 60, 68 ] }

	@classmethod
	def from_params(self, pubkey, keytag, algorithm, isksk, timers):
		assert len(timers) == 6
		if sys.version_info >= (3,0):
			pk = codecs.decode(bytearray(pubkey, 'ascii'), "base64")
		else:
			pk = pubkey.decode("base64")
		selfraw = to_bytes(len(pk), 8)
		selfraw.extend(to_bytes(0, 8)) # zero length of unused-future
		selfraw.extend(to_bytes(int(keytag), 2))
		selfraw.extend(to_bytes(int(algorithm), 1))
		selfraw.extend(to_bytes((1 if isksk else 0), 1))
		for t in timers:
			if t < 0:
				print("keytag=%i timers=(%i, %i, %i, %i, %i, %i)" % (keytag,
				timers[0], timers[1], timers[2], timers[3], timers[4], timers[5]))
				assert False
			selfraw.extend(to_bytes(t, 8))
		selfraw.extend(pk)
		return Keyparams(selfraw)

	def _check(self):
		assert len(self.raw) >= 16
		pkl = from_bytes(self.raw[0:8])
		ufl = from_bytes(self.raw[8:16])
		assert len(self.raw) == 68 + pkl + ufl
		assert self.raw[19] < 2

	def getRaw(self):
		self._check()
		return bytearray(self.raw)

	def getAlgorithm(self):
		self._check()
		return int(self.raw[18])

	def setAlgorithm(self, algorithm):
		self._check()
		self.raw[18] = to_bytes(algorithm, 1)[0]

	def isKSK(self):
		self._check()
		return 1 if self.raw[19] != 0 else 0

	def setKSK(self, isksk):
		self._check()
		self.raw[11] = (b"\01" if isksk else b"\00")[0]

	def getKeytag(self):
		self._check()
		return from_bytes(self.raw[16:18])

	def setKeytag(self, keytag):
		self._check()
		self.raw[16:18] = to_bytes(keytag, 2)

	def getTimers(self):
		self._check()
		res = [ 0, 0, 0, 0, 0, 0 ]
		for i, x, y in list(self.timers_dict.values()):
			res[i] = from_bytes(self.raw[x:y])
		return res

	def getTimersString(self):
		self._check()
		res = "["
		for ti in list(self.timers_dict.keys()):
			_, x, y = self.timers_dict[ti];
			res += (" " if res == "[" else ", ") + ti + ": " + str(from_bytes(self.raw[x:y]))
		return res + " ]"

	def setTimers(self, timers):
		self._check()
		assert len(timers) == 5
		for i, x, y in list(self.timers_dict.values()):
			self.raw[x:y] = to_bytes(timers[i], 8)

	def getPubKey(self):
		self._check()
		pkl = from_bytes(self.raw[0:8])
		return self.raw[68:68+pkl].encode("base64")

	def getParams(self):
		return [ self.getPubKey(), self.getKeytag(), self.getAlgorithm(),
		         self.isKSK(), self.getTimers() ];

	def setByParamName(self, param_name, new_val):
		if param_name == "algorithm":
			self.setAlgorithm(int(new_val))
		elif param_name == "isksk":
			if new_val in ("1", "True", "true", "on", "yes", "Yes"):
				self.setKSK(True)
			elif new_val in ("0", "False", "false", "off", "no", "No"):
				self.setKSK(False)
			else:
				print("Error: bad true/false value", new_val)
		elif param_name == "keytag":
			self.setKeytag(int(new_val))
		elif param_name in list(self.timers_dict.keys()):
			_, x, y = self.timers_dict[param_name]
			self.raw[x:y] = to_bytes(timespec2unix(new_val), 8)
		else:
			print("Error: bad parameter", param_name)

	def computeDS(self, zone_str, digestalg):
		ds_raw = bytearray(str2dname(zone_str))
		ds_raw.extend(to_bytes(257 if self.isKSK() else 256, 2))
		ds_raw.extend(b"\x03") # protocol is always == 3
		ds_raw.extend(self.raw[18:19]) # algorithm
		pkl = from_bytes(self.raw[0:8])
		ds_raw.extend(self.raw[68:68+pkl]) # pubkey
		if digestalg == "sha1":
			ds_hash = hashlib.sha1(ds_raw).hexdigest()
			algno = " 1 "
		elif digestalg == "sha256":
			ds_hash = hashlib.sha256(ds_raw).hexdigest()
			algno = " 2 "
		elif digestalg == "sha384":
			ds_hash = hashlib.sha384(ds_raw).hexdigest()
			algno = " 4 "
		else:
			print("Error: bad DS digest algorith", ds_hash)
			return
		return zone_str + ' DS ' + str(self.getKeytag()) + ' ' + str(self.getAlgorithm()) + algno + ds_hash

	def isPublished(self, moment):
		tmrs = self.getTimers()
		if tmrs[self.timers_dict["publish"][0]] <= moment:
			if moment < tmrs[self.timers_dict["remove"][0]]:
				return True
		return False

	def isReady(self, moment):
		tmrs = self.getTimers()
		if tmrs[self.timers_dict["ready"][0]] <= moment:
			if moment < tmrs[self.timers_dict["ready"][0]]:
				return True
		return False

	def isActive(self, moment):
		tmrs = self.getTimers()
		if tmrs[self.timers_dict["active"][0]] <= moment:
			if moment < tmrs[self.timers_dict["retire"][0]]:
				return True
		return False

	def isRetired(self, moment):
		tmrs = self.getTimers()
		if tmrs[self.timers_dict["retire"][0]] <= moment:
			return True
		return False

	def isRemoved(self, moment):
		tmrs = self.getTimers()
		if tmrs[self.timers_dict["remove"][0]] <= moment:
			return True
		return False

# static: just for use in following method
def arr_ind2unix(arr, ind, defaul):
	try:
		ttuple = datetime.datetime.strptime(arr[ind], "%Y-%m-%dT%H:%M:%S+0000").timetuple()
		res = int(time.mktime(ttuple))
		return res if res >= 0 else 0
	except KeyError:
		return defaul

def import_nsec3salt(keys, env, db_keys, zname):
	try:
		with lmdb.Transaction(env, db_keys, write=True) as txn_keys:
			dbk1 = Keykey.from_params(3, zname, None).getRaw()
			dbv1 = keys["nsec3_salt"]
			if dbv1 is None:
				return
			if sys.version_info >= (3,0):
				dbv1d = codecs.decode(bytearray(dbv1, 'ascii'), "base64")
			else:
				dbv1d = dbv1.decode("base64")
			txn_keys.put(dbk1, dbv1d, dupdata=False, overwrite=True)

			dbk2 = Keykey.from_params(4, zname, None).getRaw()
			dbv2 = to_bytes(arr_ind2unix(keys, "nsec3_salt_created", 0), 8)
			txn_keys.put(dbk2, dbv2, dupdata=False, overwrite=True)
	except (KeyError, AttributeError):
		pass # nsec3salt not configured or set to null, no problem

# import single JSON zone config into open LMDB env
def import_file(fname, env, db_keys):
	try:
		with open(fname) as f:
			keys = json.load(f)

	except ValueError:
		print("Warning: not imported ", fname)
		return False

	try:
		zname_str = re.sub(r'^zone_', '', re.sub(r'\.json$', '', re.sub(r'.*/', '', fname)))
		print("Importing zone", zname_str)
		zname = str2dname(zname_str)
		import_nsec3salt(keys, env, db_keys, zname)

		import_now = int(time.time())

		for key in keys["keys"]:
			dbk3 = Keykey.from_params(1, zname, key["id"]).getRaw()

			infty = 0x00ffffffffffff00 # time infinity, this is year 142'715'360

			dbv3 = Keyparams.from_params(key["public_key"], key["keytag"],
			                             key["algorithm"], key["ksk"], [
			                               arr_ind2unix(key, "created", 0),
			                               arr_ind2unix(key, "publish", 0),
			                               arr_ind2unix(key, "active", 0), # taking active for ready
			                               arr_ind2unix(key, "active", 0),
			                               arr_ind2unix(key, "retire", infty),
			                               arr_ind2unix(key, "remove", infty)
			                             ])

			if dbv3.isRemoved(import_now):
				continue

			with lmdb.Transaction(env, db_keys, write=True) as txn_keys:
				txn_keys.put(dbk3, dbv3.getRaw(), dupdata=False, overwrite=True)

	except (KeyError, KeyboardInterrupt, TypeError):
		print("Warning: not imported ", fname)
		return False

	return True

def import_dir(dirname):
	print("Importing json key config in", dirname)
	if os.path.isfile(dirname + "/data.mdb"):
		print("Warning: LMDB key configuration in", dirname, "already exists.")
		if opt_force:
			print("...deleting it.")
			os.remove(dirname + "/data.mdb")
			os.remove(dirname + "/lock.mdb")
		else:
			print("If you want to delete it and import again, use 'force' option.")
			return False

	env = lmdb.open(dirname, max_dbs=2, map_size=500*1024*1024)
	db_keys = env.open_db(b"keys_db")
	something_imported = False
	for json_file in glob.glob(dirname + "/*.json"):
		something_imported = import_file(json_file, env, db_keys) or something_imported

	if not something_imported:
		print("Warning: nothing imported in", dirname)

class VersionAction(argparse.Action):
	def __init__(self, option_strings, version=None, dest=argparse.SUPPRESS,
	             default=argparse.SUPPRESS, help="show program's version number and exit"):
		super(VersionAction, self).__init__(option_strings=option_strings, dest=dest,
		      default=default, nargs=0, help=help)
		self.version = version

	def __call__(self, parser, namespace, values, option_string=None):
		version = self.version
		if version is None:
			version = parser.version
		formatter = parser._get_formatter()
		formatter.add_text(version)
		sys.stdout.write(formatter.format_help())
		sys.exit(0)

def main():
	global opt_force
	parser = argparse.ArgumentParser(description="Knot DNSSEC KASP converter (JSON to LMDB)",
	                                 formatter_class=argparse.RawTextHelpFormatter)
	parser.add_argument("-i", "--import", action="append", nargs="?", dest="importdir",
	                    help='''Import zone-key configuration from JSON.
Syntax: -i <key_dir>
(You can import multiple key_dirs at once by repeating this option.)''')
	parser.add_argument("-f", "--force", action="store_true", dest="force", help="Do stuff even if dangerous.")
	parser.add_argument("-V", "--version", action=VersionAction, version="knot KASP legacy JSON importer (debian support for Knot DNS), version 2.7.1")
	args = parser.parse_args()
	opt_force = args.force

	if args.importdir is not None:
		lmdb_requirement()
		if isinstance(args.importdir, (list, tuple)):
			importdir = args.importdir
		else:
			importdir = [args.importdir]

		for dirn in importdir:
			import_dir(dirn)

if __name__ == "__main__":
	main()
