#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Command-line tool for inspecting Jamf packages and the records that use them.

Packages can be listed by usage or by group, patch definitions can be
updated, and when no action flag is given an interactive menu lets the user
repoint policies, computer groups, and patch policies at a different package
from the same group.
"""


__author__ = "James Reynolds"
__email__ = "reynolds@biology.utah.edu"
__copyright__ = "Copyright (c) 2020 University of Utah, School of Biological Sciences"
__license__ = "MIT"
__version__ = "0.8"
# Minimum python-jamf version this script accepts; enforced by check_version().
min_jamf_version = "0.7.3"


from pprint import pprint
import argparse
import ast
import jamf
import json
import logging
import re
import sys
import time


class Parser:
    """Command-line argument parser for the package control script."""

    def __init__(self):
        """Create the underlying argparse parser and register every option."""
        self.parser = argparse.ArgumentParser()
        # https://docs.python.org/3/library/argparse.html

        # Actions that print something and exit.
        self.parser.add_argument(
            "-c", "--cleanup", action="store_true", help="Show packages sorted by usage"
        )
        self.parser.add_argument(
            "-p",
            "--patch-definitions",
            action="store_true",
            help="Set patch definitions",
        )
        self.parser.add_argument(
            "-g",
            "--groups",
            action="store_true",
            help="Display packages as groups and exit",
        )
        self.parser.add_argument(
            "-u", "--usage", action="store_true", help="Display package usage and exit"
        )
        # NOTE(review): --auto-promote is parsed, but nothing visible in this
        # file reads the resulting value.
        self.parser.add_argument(
            "-a",
            "--auto-promote",
            nargs="*",
            help="Automatically update policies, groups, and patch policies that match "
            "supplied regular expression.",
        )

        # Package filters; each accepts zero or more values and defaults to None.
        self.parser.add_argument("-i", "--id", nargs="*", help="Search for id matches")
        self.parser.add_argument(
            "-n", "--name", nargs="*", help="Search for exact name matches"
        )
        self.parser.add_argument(
            "-r", "--regex", nargs="*", help="Search for regular expression matches"
        )

        self.parser.add_argument("-C", "--config", help="path to config file")
        self.parser.add_argument(
            "-v", "--version", action="store_true", help="print version and exit"
        )

    def parse(self, argv):
        """
        :param argv:    list of arguments to parse
        :returns:       argparse.Namespace object
        """
        return self.parser.parse_args(argv)


def check_version():
    """
    Stop the program when the installed python-jamf is too old.

    The result of ``jamf.version.jamf_version_up_to(min_jamf_version)`` is
    compared with ``min_jamf_version``; any other value is treated as "too
    old" and reported to the user.
    (Presumably the helper returns the minimum itself once the installed
    release is at least that new -- confirm against python-jamf.)

    :raises SystemExit: with status 1 when the requirement is not met
    """
    python_jamf_version = jamf.version.jamf_version_up_to(min_jamf_version)
    if python_jamf_version != min_jamf_version:
        print(
            f"pkgctl requires python-jamf {min_jamf_version} or newer. "
            f"You have {python_jamf_version}."
        )
        # An unmet requirement is a failure: report it through the exit
        # status, and use sys.exit rather than the site-provided exit().
        sys.exit(1)


def confirm(_message):
    """
    Keep prompting until the user types Y or N (either case).

    :param _message: prompt text shown before each attempt
    :return: True for "y", False for "n"
    :rtype: bool
    """
    while True:
        reply = input(_message).lower()
        if reply in ("y", "n"):
            return reply == "y"


def printWaitText():
    """Print the "Analyzing Jamf data" notice the first time this is called."""
    # The marker lives on the function object itself, so it persists for the
    # life of the process without a module-level flag.
    if hasattr(printWaitText, "printed"):
        return
    printWaitText.printed = True
    print("Analyzing Jamf data")


def loadRelatedData(packages):
    """
    Show the one-time wait notice, then read ``related`` on every package.

    :param packages: iterable of package records
    """
    printWaitText()
    for pkg in packages:
        # Only the attribute read matters; the value is discarded.
        # Presumably ``related`` is loaded on first access -- confirm.
        pkg.related


def process_package_promotion(item, packages):
    """
    Ask which package a related record should use, then update that record.

    :param item:     ``[record_type, record_name, package]``; record_type is
                     "Policies", "ComputerGroups", or "PatchPolicies", and
                     package is the package the record currently refers to
    :param packages: indexable collection of packages offered as choices
    :returns:        True when the record was modified and saved; False when
                     the user chose [b]ack or nothing in the record matched
    :raises SystemExit: when the user chooses e[x]it or [q]uit
    """
    record_type, record_name, package = item[0], item[1], item[2]
    # Plain attribute lookup gives the same class the former
    # eval("jamf." + name) did, without evaluating text as code.
    class_name = getattr(jamf, record_type).singular_class.__name__

    choices = [str(number) for number in range(1, len(packages) + 1)]
    answer = ""
    while answer not in choices + ["b", "x", "q"]:
        print(f"Set the {class_name} named {record_name} to which package?")
        for index, val in enumerate(packages, start=1):
            print(f"  [{index}] {val}")
        answer = input("Enter a number, [b]ack, or e[x]it/[q]uit: ").lower()
        print()
    if answer in ("x", "q"):
        sys.exit()
    if answer == "b":
        return False

    new_package = packages[int(answer) - 1]
    the_class = jamf.records.class_name(record_type)
    rec = the_class().find(record_name)

    if record_type == "PatchPolicies":
        # Patch policies are retargeted by version, not by package name.
        if "version" not in new_package.metadata:
            print(f"Couldn't determine version for {new_package.name}")
            return False
        rec.data["general"]["target_version"] = new_package.metadata["version"]
        rec.save()
        rec.refresh()
        new_package.refresh_related()
        print(f"Saved patch policy named {record_name}")
        return True

    if record_type == "Policies":
        rec_packages = rec.data["package_configuration"]["packages"]["package"]
        # A single entry is stored bare rather than in a list; normalise by
        # type instead of trusting the separate "size" field.
        if not isinstance(rec_packages, list):
            rec_packages = [rec_packages]
        for rec_package in rec_packages:
            if rec_package["name"] == package.name:
                # Drop the old id so only the new name identifies the package
                # (presumably the server resolves it by name -- confirm).
                rec_package.pop("id", None)
                rec_package["name"] = new_package.name
                rec.save()
                rec.refresh()
                package.refresh_related()
                print(f"Saved policy named {record_name}")
                return True
        return False

    if record_type == "ComputerGroups":
        criteria = rec.data["criteria"]["criterion"]
        if not isinstance(criteria, list):
            criteria = [criteria]
        for crit in criteria:
            if crit["value"] == package.name:
                crit["value"] = new_package.name
                rec.save()
                rec.refresh()
                package.refresh_related()
                print(f"Saved computer group named {record_name}")
                return True

    # Unknown record type, or no entry referred to the old package.
    return False


def process_package_group(group, packages):
    """
    Show one package group with its related records and let the user pick one.

    Every related policy, computer group, and patch policy is listed with a
    number. Choosing a number hands that record to
    process_package_promotion(); afterwards the listing is rebuilt and shown
    again.

    :param group:    group name printed above the listing
    :param packages: packages belonging to the group
    :returns:        False when the user chooses [b]ack
    """
    sections = ("Policies", "ComputerGroups", "PatchPolicies")
    while True:
        # Rebuilt on every pass so the listing reflects any change just made.
        items = []
        lines = []
        for package in packages:
            if "PatchSoftwareTitles" in package.related:
                lines.append(f"  {package}\n")
            else:
                lines.append(f"  {package} [no patch defined]\n")
            for section in sections:
                if section not in package.related:
                    continue
                for rec in package.related[section]:
                    items.append([section, rec, package])
                    lines.append(f"      {section}\n        [{len(items)}] {rec}\n")
        text = "".join(lines)

        valid = [str(number) for number in range(1, len(items) + 1)]
        valid += ["b", "x", "q"]
        answer = ""
        while answer not in valid:
            print(group)
            print(text)
            answer = input("Enter a number, [b]ack, or e[x]it/[q]uit: ").lower()

        if answer in ("x", "q"):
            exit()
        if answer == "b":
            return False
        process_package_promotion(items[int(answer) - 1], packages)


def print_groups(packages, related=False, group_index=False):
    """
    Print each package group followed by the packages it contains.

    :param packages:    object whose ``groups`` attribute maps a group name
                        to the packages in that group
    :param related:     when true, load related data first and append a
                        per-package count of related patch titles, policies,
                        computer groups, and patch policies
    :param group_index: when true, print only groups holding more than one
                        package and prefix each with a 1-based number
    :returns:           names of the numbered groups in display order (empty
                        when group_index is false)
    """
    if related:
        loadRelatedData(packages)
    # (label printed, key looked up in child.related), in display order.
    summary_fields = (
        ("patch", "PatchSoftwareTitles"),
        ("Policies", "Policies"),
        ("ComputerGroups", "ComputerGroups"),
        ("PatchPolicies", "PatchPolicies"),
    )
    group_choices = []
    # The raw pprint() dump of packages.groups that used to precede this loop
    # was debugging output duplicating the listing below; it is not emitted.
    for group, children in packages.groups.items():
        if group_index:
            if len(children) <= 1:
                # Nothing to choose between in a single-package group.
                continue
            group_choices.append(group)
            lines = [f"[{len(group_choices)}] {group}\n"]
        else:
            lines = [f"{group}\n"]
        for child in children:
            line = f"  {child}"
            if related:
                line += " - "
                for label, key in summary_fields:
                    if key in child.related:
                        line += f"{label} {len(child.related[key])}; "
            lines.append(line + "\n")
        print("".join(lines))
    return group_choices


def _filter_packages(all_packages, args):
    """Return the records picked by --regex/--name/--id, or all_packages when no filter applies."""
    if not (all_packages and (args.regex or args.name or args.id)):
        return all_packages
    found = []
    for regex in args.regex or []:
        found.extend(all_packages.recordsWithRegex(regex))
    for name in args.name or []:
        found.append(all_packages.recordWithName(name))
    for raw_id in args.id or []:
        try:
            record_id = int(raw_id)
        except ValueError:
            print(f"ID must be a number: {raw_id}")
            sys.exit(1)
        found.append(all_packages.recordWithId(record_id))
    # Lookups that matched nothing yield falsy entries; drop them.
    return [record for record in found if record]


def main(argv):
    """
    Entry point: parse the arguments and run the requested action.

    With --version, --usage, --cleanup, --groups, or --patch-definitions the
    matching report or update runs and the program exits. With none of them
    an interactive menu loops until the user quits.

    :param argv: command-line arguments, without the program name
    :raises SystemExit: on every normal way out of the program; status 1 for
                        a non-numeric --id or an unmet python-jamf version
    """
    logger = logging.getLogger(__name__)
    args = Parser().parse(argv)
    logger.debug("args: %r", args)
    if args.version:
        # Use the same program name as the version-requirement message, and
        # exit 0: printing the version is not an error.
        print("pkgctl " + __version__)
        print(f"python_jamf {jamf.__version__} ({min_jamf_version} required)")
        sys.exit(0)
    check_version()
    # The API object is not used directly below; it is still constructed
    # (presumably to configure the connection the record classes use --
    # confirm) and kept referenced for the duration of main().
    if args.config:
        api = jamf.API(config_path=args.config)
    else:
        api = jamf.API()
    all_packages = jamf.records.Packages()
    # Quick filter records
    quick = _filter_packages(all_packages, args)
    sorted_results = sorted(quick) if quick else []
    # Print package usage
    if args.usage:
        printWaitText()
        for record in sorted_results:
            record.usage_print_during()
        sys.exit()
    # Generate groups: only the attribute read matters here (presumably
    # reading metadata is what populates all_packages.groups -- confirm).
    for record in sorted_results:
        record.metadata
    # Print package cleanup
    if args.cleanup:
        print_groups(all_packages, True)
        sys.exit()
    # Print package groups
    if args.groups:
        print_groups(all_packages)
        sys.exit()
    # Set package definitions
    if args.patch_definitions:
        print("Updating patch definitions...")
        change_made = False
        for pst in jamf.records.PatchSoftwareTitles():
            if pst.set_all_packages_update_during():
                pst.save()
                change_made = True
        if not change_made:
            print("No packages match patch software titles")
        sys.exit()
    # Interactive package manager
    while True:
        # print_groups(related=True) shows the wait notice and loads the
        # related data itself, so no separate pre-load is needed here.
        group_choices = print_groups(all_packages, True, True)
        choices = [str(number) for number in range(1, len(group_choices) + 1)]
        choices += ["q", "x"]
        answer = ""
        while answer not in choices:
            answer = input("Enter a number, or e[x]it/[q]uit: ").lower()
            print("")
        if answer in ("x", "q"):
            sys.exit()
        group = group_choices[int(answer) - 1]
        process_package_group(group, all_packages.groups[group])


if __name__ == "__main__":
    # Log lines carry timestamp, padded level, logger name, and function name.
    logging.basicConfig(
        format="%(asctime)s: %(levelname)8s: %(name)s - %(funcName)s(): %(message)s",
        level=logging.INFO,
    )
    try:
        main(sys.argv[1:])
    except KeyboardInterrupt:
        # Ctrl-C ends the program with a failure status instead of a traceback.
        exit(1)
