#!/usr/bin/env python

"""
This is a small utility script I originally whipped up for dealing with
re-writing the copyright headers in GPIO-Zero, which had so many
contributors that it was becoming a pain to keep the headers accurate.

Usage is simple: just run ./copyrights in a clone of the repo after all PRs
have been merged and just before you're about to do a release. It should (if
it's working properly!) rewrite the copyright headers in all the files.

If you want to change its operation, the crucial bits are get_license() which
dictates what legalese gets written at the top, and update_copyright() which
does the actual re-writing (and which can be adapted to insert, e.g. a
corporate copyright holder before/after all individual contributors).
Everything else gets derived from the git history, so as long as this is
reasonable the output should be too.

Obviously, you should check the "git diff" is vaguely sane before committing
it!
"""

import io
import re
import sys
from collections import namedtuple
from operator import attrgetter
from itertools import groupby, tee
from datetime import datetime
from subprocess import Popen, PIPE, DEVNULL, CalledProcessError
from pathlib import Path
from fnmatch import fnmatch
from functools import lru_cache


# One record per blamed line: who wrote it, in which year, in which file
Contribution = namedtuple(
    'Contribution', ['author', 'email', 'year', 'filename'])

class Copyright(namedtuple('Copyright', ('author', 'email', 'years'))):
    """
    A copyright holder (*author*, *email*) together with the collection of
    *years* they hold copyright for. Converting an instance to a string gives
    the text of a single copyright header line (without any comment prefix).
    """

    def __str__(self):
        # The e-mail address is optional; leave out the angle brackets
        # entirely when there isn't one
        holder = (
            '{self.author} <{self.email}>'.format(self=self)
            if self.email else
            self.author
        )
        return 'Copyright (c) {years} {name}'.format(
            years=int_ranges(self.years), name=holder)


def main():
    """
    Script entry point: re-write the copyright header of every selected file.

    By default all tracked files matching the built-in include patterns are
    processed. Any command line arguments replace those include patterns; the
    exclude patterns apply either way. Files whose suffix has no known comment
    prefix are reported on stderr and left untouched.
    """
    includes = {
        '**/*.py',
        '**/*.rst',
    }
    excludes = {
        'docs/examples/*.py',
        'docs/license.rst',
    }
    # Comment leader used for header lines, keyed by file suffix
    prefixes = {
        '.py': '#',
        '.rst': '..',
    }
    if len(sys.argv) > 1:
        includes = set(sys.argv[1:])
    contributions = get_contributions(includes, excludes)
    for filename, copyrights in contributions.items():
        filename = Path(filename)
        try:
            prefix = prefixes[filename.suffix]
        except KeyError:
            # Include patterns given on the command line can select files
            # whose comment syntax we don't know. Skip those rather than dying
            # with a bare KeyError part-way through a run that has already
            # re-written other files
            print(
                'Skipping {filename}: no comment prefix known for '
                '"{suffix}" files'.format(
                    filename=filename, suffix=filename.suffix),
                file=sys.stderr)
            continue
        update_copyright(filename, copyrights, prefix)


def get_contributions(include, exclude):
    """
    Return a dict mapping each blamed filename to the set of
    :class:`Copyright` holders for that file.

    Every distinct (author, email) pair found by :func:`get_blame` in a file
    becomes one :class:`Copyright` whose *years* is the frozenset of years in
    which that person's surviving lines were authored. Keys appear in sorted
    filename order.
    """
    def identity(contribution):
        return (contribution.author, contribution.email)

    # groupby() only merges adjacent items, so order by file first and then
    # by contributor within each file
    ordered = sorted(
        get_blame(include, exclude),
        key=lambda c: (c.filename, c.author, c.email)
    )
    result = {}
    for filename, lines in groupby(ordered, key=attrgetter('filename')):
        holders = set()
        for (author, email), entries in groupby(lines, key=identity):
            holders.add(Copyright(
                author, email, frozenset(entry.year for entry in entries)))
        result[filename] = holders
    return result


def get_blame(include, exclude):
    """
    Yield one :class:`Contribution` per line of every selected source file,
    as attributed by ``git blame`` at HEAD.

    The files examined are those produced by :func:`get_source_files` for the
    *include* and *exclude* patterns. The year of each contribution is taken
    from the commit's author time, converted in the local timezone.

    Raises :exc:`subprocess.CalledProcessError` if ``git blame`` exits with a
    non-zero status for any file.
    """
    for filename in get_source_files(include, exclude):
        cmd = ['git', 'blame', '--line-porcelain', 'HEAD', '--', filename]
        # stderr is discarded rather than piped: nothing here ever reads it,
        # and an unread pipe would stall git once the pipe buffer filled up.
        # The context manager closes stdout and reaps the child even if the
        # consumer abandons this generator early
        with Popen(cmd, stdout=PIPE, stderr=DEVNULL,
                   universal_newlines=True) as blame:
            author = email = year = None
            for line in blame.stdout:
                if line.startswith('author '):
                    author = line.split(' ', 1)[1].rstrip()
                elif line.startswith('author-mail '):
                    email = line.split(' ', 1)[1].rstrip().lstrip('<').rstrip('>')
                elif line.startswith('author-time '):
                    # Forget the timezone; we only want the year anyway
                    year = datetime.fromtimestamp(
                        int(line.split(' ', 1)[1].rstrip())).year
                elif line.startswith('filename '):
                    # "filename" is the last header emitted for a line, so
                    # the record gathered so far is complete at this point
                    yield Contribution(
                        author=author, email=email, year=year,
                        filename=filename)
                    author = email = year = None
        # Raise explicitly instead of asserting; asserts vanish under -O
        if blame.returncode != 0:
            raise CalledProcessError(blame.returncode, cmd)


def get_source_files(include, exclude):
    """
    Yield the name of every file tracked at HEAD which matches at least one
    *include* pattern and no *exclude* pattern.

    Patterns are matched against the whole repository-relative path with
    :func:`fnmatch.fnmatch`. An empty *include* collection selects every
    file that is not excluded.

    Raises :exc:`subprocess.CalledProcessError` if ``git ls-tree`` exits with
    a non-zero status.
    """
    cmd = ['git', 'ls-tree', '-r', '--name-only', 'HEAD']
    if not include:
        include = {'*'}
    # The context manager closes the pipe and reaps the child even if the
    # consumer abandons this generator before it is exhausted
    with Popen(cmd, stdout=PIPE, stderr=DEVNULL,
               universal_newlines=True) as ls_tree:
        for filename in ls_tree.stdout:
            filename = filename.strip()
            # Exclusions take precedence over inclusions
            if any(fnmatch(filename, pattern) for pattern in exclude):
                continue
            if any(fnmatch(filename, pattern) for pattern in include):
                yield filename
    # Raise explicitly instead of asserting; asserts vanish under -O
    if ls_tree.returncode != 0:
        raise CalledProcessError(ls_tree.returncode, cmd)


# Unique sentinel yielded by parse_source_file() to mark where the new header
# belongs; consumers compare against it with "is"
insertion_point = object()
def parse_source_file(filename, prefix):
    """
    Generate the lines of *filename* with any existing copyright / license
    header removed, and the :data:`insertion_point` sentinel yielded (at most
    once) at the position where a fresh header should be written.

    *filename* must provide an ``open`` method (e.g. a :class:`pathlib.Path`)
    and *prefix* is the comment leader for the file's language (e.g. ``'#'``).

    A ``#!`` line on line 1, and any line within the first nine lines
    containing ``set fileencoding``, are passed through ahead of the insertion
    point. Lines consisting solely of the prefix, and lines starting with
    ``<prefix> Copyright (c)``, are dropped. A line starting with the prefix
    followed by the first line of :func:`get_license` begins an existing
    license, which is dropped up to and including the line starting with the
    prefix followed by the last line of the license.

    NOTE(review): if the file ends while still in the "preamble" or "license"
    state, the insertion point is never yielded; in the "license" state that
    also means every line after the license start is dropped. Worth confirming
    this cannot happen with the files in the repository.
    """
    license = get_license()
    license_start = license[0]
    license_end = license[-1]
    with filename.open('r') as source:
        # States: 'preamble' (header region at top of file), 'license'
        # (inside an existing license block), 'blank' (first line after the
        # insertion point), 'body' (everything else, copied verbatim)
        state = 'preamble'
        for linenum, line in enumerate(source, start=1):
            if state == 'preamble':
                if linenum == 1 and line.startswith('#!'):
                    yield line
                elif linenum < 10 and 'set fileencoding' in line:
                    yield line
                elif line.rstrip() == prefix:
                    pass # skip blank comment lines
                elif line.startswith(prefix + ' Copyright (c)'):
                    pass # skip existing copyright lines
                elif line.startswith(prefix + ' ' + license_start):
                    state = 'license' # skip existing license lines
                else:
                    # First line that is not part of a header: the new header
                    # goes immediately before it. The current line is then
                    # handled by the 'blank' branch below (deliberately "if",
                    # not "elif")
                    yield insertion_point
                    state = 'blank'
            elif state == 'license':
                if line.startswith(prefix + ' ' + license_end):
                    yield insertion_point
                    state = 'blank'
                    # The final license line itself is dropped too
                    continue
            if state == 'blank':
                # Ensure there's a blank line between license and start of the
                # source body
                if line.strip():
                    yield '\n'
                yield line
                state = 'body'
            elif state == 'body':
                yield line


def update_copyright(filename, copyrights, prefix):
    """
    Re-write *filename* in place with a freshly generated header.

    The header consists of one copyright line for Canonical Ltd. (current
    year), followed by one line per entry in *copyrights*, followed by the
    license text from :func:`get_license`; each line is prefixed with the
    comment leader *prefix*. Contributors are listed most-recent-first: they
    are ordered by their years, latest year first, in descending order, with
    ties broken by author name in reverse alphabetical order.

    *filename* must be a :class:`pathlib.Path` (or offer a compatible
    ``open`` method) and *copyrights* an iterable of :class:`Copyright`.
    """
    def recency(holder):
        # Years normally arrive as a (frozen)set. Sets only support subset
        # comparison, which is not a total order, so sorting on them directly
        # gives an arbitrary result. Compare tuples of years, latest first,
        # instead. A bare (non-collection) year is still accepted
        if isinstance(holder.years, (set, frozenset, tuple, list)):
            years = tuple(sorted(holder.years, reverse=True))
        else:
            years = (holder.years,)
        return (years, holder.author)

    print('Re-writing {filename}...'.format(filename=filename))
    license = get_license()
    copyrights = [
        Copyright('Canonical Ltd.', '', {datetime.now().year}),
    ] + sorted(copyrights, reverse=True, key=recency)
    content = []
    for line in parse_source_file(filename, prefix):
        if line is insertion_point:
            # Separate the header from anything kept above it (shebang,
            # encoding line) with an empty comment line
            if content:
                content.append(prefix + '\n')
            content.extend(
                prefix + ' ' + str(holder) + '\n'
                for holder in copyrights
            )
            content.append(prefix + '\n')
            # strip() reduces blank license lines to the bare prefix, so no
            # trailing whitespace is written
            content.extend(
                (prefix + ' ' + l).strip() + '\n'
                for l in license
            )
        else:
            content.append(line)
    # Yes, if I was doing this "properly" I'd write to a temporary file and
    # rename it over the target. However, I'm assuming you're running this
    # under a git clone ... after all, you are ... aren't you?
    with filename.open('w') as target:
        target.writelines(content)


@lru_cache()
def get_license():
    """
    Return the license notice written beneath the copyright lines, as a list
    of lines without comment prefixes or line endings. The result is cached,
    so every call returns the same list object.
    """
    text = """\
This file is part of pibootctl.

pibootctl is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

pibootctl is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with pibootctl.  If not, see <https://www.gnu.org/licenses/>.
"""
    return text.splitlines()


def pairwise(iterable):
    """
    Return an iterator over the overlapping pairs ``(s0, s1), (s1, s2), ...``
    of *iterable*. Inputs with fewer than two items produce no pairs.

    This is the ``pairwise`` recipe from the documentation for
    :mod:`itertools`.
    """
    leading, trailing = tee(iterable, 2)
    # Discard the first item of the second copy (if there is one) so that
    # zip() lines every item up with its successor
    for _ in trailing:
        break
    return zip(leading, trailing)


def int_ranges(values):
    """
    Given a set of integer *values*, returns a compressed string representation
    of all values in the set, in ascending order. Runs of consecutive values
    are collapsed to ``first-last``, except that a set of exactly two values
    is always written out in full. For example:

        >>> int_ranges({1, 2})
        '1, 2'
        >>> int_ranges({1, 2, 3})
        '1-3'
        >>> int_ranges({1, 2, 3, 4, 8})
        '1-4, 8'
        >>> int_ranges({1, 2, 3, 4, 8, 9})
        '1-4, 8-9'

    An empty set produces the empty string.
    """
    # Sets iterate in arbitrary order, so sort before formatting anything;
    # formatting the raw set could render {1, 8} as '8, 1'
    ordered = sorted(values)
    if len(ordered) <= 2:
        return ', '.join('{0}'.format(value) for value in ordered)
    ranges = []
    start = ordered[0]
    for i, j in pairwise(ordered):
        if j > i + 1:
            # Gap found: close the run ending at i and open a new one at j
            ranges.append((start, i))
            start = j
    # Whatever run is still open ends at the largest value
    ranges.append((start, ordered[-1]))
    return ', '.join(
        ('{start}-{finish}' if finish > start else '{start}').format(
            start=start, finish=finish)
        for start, finish in ranges
    )


# Run the re-write only when executed as a script, not when imported
if __name__ == '__main__':
    main()
