#!/usr/bin/env python3

import argparse
import sys, os
import traceback

import hyperchamber as hc
import hypergan as hg
import hypergan.cli as cli
from hypergan.viewer import GlobalViewer
import pkg_resources
import semantic_version
# Version of the installed `hypergan` distribution, parsed as a semantic
# version so it can be matched against a configuration's `hypergan_version`
# spec further down in this script.
hg_version=semantic_version.Version(pkg_resources.require("hypergan")[0].version)

class CommandParser:
    """Builds the argparse parser for the HyperGAN command line.

    The parser exposes five sub-commands (``train``, ``test``, ``sample``,
    ``build`` and ``new``); the selected one is stored in ``args.method``.
    The same set of option flags is registered on the top-level parser and on
    every sub-command parser.
    """

    @staticmethod
    def _str_to_bool(value):
        """Convert a command-line string into a real boolean.

        ``type=bool`` must not be used with argparse: ``bool("False")`` is
        ``True`` because any non-empty string is truthy, so the option could
        never be switched off explicitly.

        Args:
            value: Raw string taken from the command line.

        Returns:
            ``True`` or ``False`` according to the (case-insensitive) spelling.

        Raises:
            argparse.ArgumentTypeError: If the string is not a recognised
                boolean spelling; argparse reports it as a usage error.
        """
        text = value.strip().lower()
        if text in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        # The empty string stays falsy, as it was under `type=bool`.
        if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value such as true or false, got %r' % value)

    def common(self, parser, directory=True):
        """Register the arguments of a sub-command on ``parser``.

        Args:
            parser: The argparse parser to add arguments to.
            directory: When true, a required positional ``directory`` argument
                is added before the shared option flags.
        """
        if directory:
            parser.add_argument('directory', action='store', type=str, help='The location of your data.  Subdirectories are treated as different classes.  You must have at least 1 subdirectory.')
        self.common_flags(parser)

    def common_flags(self, parser):
        """Register the option flags shared by every parser.

        Args:
            parser: The argparse parser to add the option flags to.
        """
        parser.add_argument('--size', '-s', type=str, default='64x64x3', help='Size of your data.  For images it is widthxheightxchannels.')
        parser.add_argument('--batch_size', '-b', type=int, default=0, help='Number of samples to include in each batch.  If using batch norm, this needs to be preserved when in server mode')
        parser.add_argument('--config', '-c', action='store', default=None, type=str, help='The configuration file to load.')
        parser.add_argument('--input_config', '-i', action='store', default=None, type=str, help='The input configuration to use.')
        parser.add_argument('--toml', '-T', action='store_true', help='Use TOML as configration file format instead of JSON')
        parser.add_argument('--device', '-d', type=str, default='/gpu:0', help='In the form "/gpu:0", "/cpu:0", etc.  Always use a GPU (or TPU) to train')
        parser.add_argument('--crop', dest='crop', action='store_true', help='If your images are perfectly sized you can skip cropping.')
        parser.add_argument('--random_crop', dest='random_crop', action='store_true', help='Randomly crop images.')
        parser.add_argument('--resize', dest='resize', action='store_true', help='If your images are perfectly sized you can skip resize.')
        parser.add_argument('--align', '-a', nargs='?', action='append', dest='align', help='Adds an additional input folder.')
        parser.add_argument('--save_every', type=int, default=-1, help='Saves the model every n steps.')
        parser.add_argument('--sample_every', type=int, default=100, help='Saves a sample every X steps.')
        parser.add_argument('--save_samples', action='store_true', help='Saves samples to the local `samples` directory.')
        parser.add_argument('--sampler', type=str, default=None, help='Select a sampler.  Some choices: static_batch, batch, grid, progressive')
        parser.add_argument('--save_file', type=str, default=None, help='The save file to load and save from.  Use gs://bucket/mymodel/model.ckpt for TPU saving')
        # Parsed with an explicit converter: with `type=bool`, "--ipython False"
        # evaluated to True.
        parser.add_argument('--ipython', type=self._str_to_bool, default=False, help='Enables iPython embedded mode.')
        parser.add_argument('--steps', type=int, default=-1, help='Number of steps to train for.  -1 is unlimited (default)')
        # `--noviewer` and `--nomenu` are store_false flags, so `args.viewer`
        # and `args.menu` default to True.
        parser.add_argument('--noviewer', dest='viewer', action='store_false', help='Disables the display of samples in a window.')
        parser.add_argument('--viewer_size', '-z', type=float, dest='viewer_size', default=1, help='Size of the viewer window as a multiplier. WARNING: values above 60 may cause crashes')
        parser.add_argument('--list-templates', '-l', dest='list_templates', action='store_true', help='List available templates.')
        parser.add_argument('--debug', dest='debug', action='store_true', help='Start debugger.')
        parser.add_argument('--version', action='version', version='%(prog)s 0.10.0 alpha')
        parser.add_argument('--nomenu', dest='menu', action='store_false', help='Disables the file menu.')
        parser.add_argument('--width', '-w', type=int, default=8, help='Number of samples per row')
        parser.add_argument('--loss_every', type=int, default=1, help='Saves a losses graph every X steps.')
        parser.add_argument('--save_losses', action='store_true', help='Saves losses graph to the local `losses` directory.')

    def get_parser(self):
        """Build and return the fully configured top-level parser.

        Returns:
            An ``argparse.ArgumentParser`` with a mandatory sub-command (stored
            as ``method``). Every sub-command except ``test`` also takes a
            positional ``directory``.
        """
        parser = argparse.ArgumentParser(description='Train, run, and deploy your GANs.', add_help=True)
        subparsers = parser.add_subparsers(dest='method')
        subparsers.required = True
        self.common_flags(parser)
        for method in ('train', 'test', 'sample', 'build', 'new'):
            # `test` is the only sub-command without the positional directory.
            self.common(subparsers.add_parser(method), directory=(method != 'test'))

        return parser

# Script body: parse the command line, load the GAN and input configurations,
# then either list the prepackaged templates or hand control to the HyperGAN
# CLI.  The viewer is always closed on the way out.
try:
    args = CommandParser().get_parser().parse_args()

    # Without periodic saving and without the viewer there is nothing here
    # that would let the trained model be saved, so refuse to start.
    if args.method == 'train' and args.save_every == -1 and not args.viewer:
        print("Error, there is no way to save the model.  Refusing to run.  Please add --save_every N")
        sys.exit(1)

    # `--size` is WIDTHxHEIGHTxCHANNELS; the None padding lets a shorter
    # value (e.g. "128x128") fall back to the defaults below.
    if args.size:
        size = [int(x) for x in args.size.split("x")] + [None, None, None]
    else:
        size = [None, None, None]
    width = size[0] or 64
    height = size[1] or 64
    channels = size[2] or 3

    use_toml = bool(args.toml)
    config_format = '.toml' if use_toml else '.json'
    config_name = args.config or 'default'

    if args.method != 'new':
        config_filename = hg.Configuration.find(config_name, config_format=config_format)
        config = hc.Selector().load(config_filename, load_toml=use_toml)

        # Explicit raises instead of `assert`: assertions are removed when
        # Python runs with -O, which would silently skip the version check.
        if not config.hypergan_version:
            raise ValueError("hypergan_version must be specified as a base configuration string.  Example \"hypergan_version\": \"~1\"")
        config_version = semantic_version.SimpleSpec(config.hypergan_version)
        if not config_version.match(hg_version):
            raise ValueError("incompatible config version ("+config.hypergan_version+") for hypergan ("+str(hg_version)+")")

        if config.input is None and args.input_config is None:
            # No input described anywhere: build an image-loader input from
            # the command-line arguments.
            # NOTE(review): the `test` sub-command is registered without a
            # `directory` argument, so `args.directory` does not exist on this
            # path for `test` - confirm whether `test` configs always carry
            # their own input section.
            directories = [args.directory]
            if args.align:
                directories += args.align
            input_config = hc.Config({
                "class": "class:hypergan.inputs.image_loader.ImageLoader",
                "batch_size": args.batch_size,
                "directories": directories,
                "channels": channels,
                "crop": args.crop,
                "height": height,
                "random_crop": args.random_crop,
                "resize": args.resize,
                "shuffle": True,
                "width": width
            })
        elif config.input:
            # The GAN config carries its own input section; only an explicit
            # (non-zero) --batch_size overrides it.
            if args.batch_size:
                config.input["batch_size"] = args.batch_size
            input_config = hc.Config(config.input)
        else:
            input_config_filename = hg.Configuration.find(args.input_config, config_format=config_format)
            input_config = hc.Selector().load(input_config_filename, load_toml=use_toml)

        if 'inherit' in config:
            base_filename = hg.Configuration.find(config['inherit']+config_format, config_format=config_format)
            base_config = hc.Selector().load(base_filename, load_toml=use_toml)
            # Keys of the loaded config win over the inherited base.
            # NOTE(review): the merge produces a plain dict - confirm that the
            # CLI accepts one in place of the loaded config object.
            config = {**base_config, **config}
    else:
        input_config = None
        config = None

    if args.list_templates:
        template_names = hg.Configuration.list(config_format=config_format, prepackaged=True)
        for template_name in template_names:
            try:
                template = hg.Configuration.load(template_name+config_format, verbose=False, use_toml=use_toml, prepackaged=True)
            except Exception as e:
                # A template that fails to load is still listed, with the
                # error text shown in place of its description.
                template = hc.Config({"description": str(e), "publication": "error"})
            print("%-30s - %-60s" %  (template_name, str(template.description)+"("+str(template.publication)+")"))
            if template.runtime and "train" in template.runtime:
                print("%30s   Usage: %s" %  ("", template.runtime["train"]))
    else:
        gancli = cli.CLI(args=vars(args), input_config=input_config, gan_config=config)
        gancli.run()
except Exception:
    # Only real errors are reported here.  SystemExit (raised by --help,
    # --version, argparse usage errors and sys.exit above) and
    # KeyboardInterrupt pass through untouched, so they are no longer
    # reported as a crash and keep their own exit status.
    traceback.print_exc()
    print("Error detected, HyperGAN exiting")
    # Report the failure to the calling shell instead of exiting with 0.
    sys.exit(1)
finally:
    GlobalViewer.close()
