#!/usr/local/bin/python2.7
# encoding: utf-8
'''
hmf (script) provides command-line access to much of the functionality of hmf.

This script takes the input arguments and runs every combination of
them through :func:`hmf.tools.get_hmf`, writing the results to files whose
names start with the specified filename. Only the requested attributes are
written out, which keeps the output compact.

.. note :: at this time, this script is in alpha. It works for the obvious arguments!
'''

import sys
import os
import traceback

from src import hmf
import numpy as np
from argparse import ArgumentParser
from argparse import RawDescriptionHelpFormatter

# Only the command-line entry point is exported.
__all__ = ["main"]
# __version__ and __updated__ are shown in the ``--version`` message.
__version__ = "0.2"
__date__ = "2014 - 01 - 23"
__updated__ = "2014 - 01 - 28"

# Development switches (0 = off):
#   DEBUG   - the __main__ guard appends "-h" and "-v" to sys.argv, and
#             main() re-raises errors instead of returning an exit code.
#   TESTRUN - the __main__ guard runs doctest.testmod() first, and main()
#             re-raises errors instead of returning an exit code.
#   PROFILE - the __main__ guard runs main() under cProfile and writes the
#             statistics to "profile_stats.txt".
DEBUG = 0
TESTRUN = 0
PROFILE = 0

class CLIError(Exception):
    '''Generic exception to raise and log different fatal errors.

    The message given to the constructor is stored on ``msg`` with an
    ``"E: "`` prefix; that prefixed text is what ``str()`` returns. The raw
    message is passed on to :class:`Exception`, so ``args`` holds it unchanged.
    '''
    def __init__(self, msg):
        # The two-argument form is required here: ``super(CLIError)`` on its
        # own is an unbound proxy whose ``__init__`` merely re-initialises the
        # proxy and never reaches ``Exception.__init__``.
        super(CLIError, self).__init__(msg)
        self.msg = "E: %s" % msg
    def __str__(self):
        return self.msg
    def __unicode__(self):
        return self.msg

def _write_columns(path, x_label, x_values, source, attrs):
    '''Write ``x_values`` and the named attributes of ``source`` as columns.

    The first column holds ``x_values``; each following column holds
    ``getattr(source, attr)`` for the entries of ``attrs``, in order. The file
    starts with a ``#`` header line naming the columns, separated by tabs.
    '''
    table = np.empty((len(x_values), len(attrs) + 1))
    table[:, 0] = x_values
    for column, attr in enumerate(attrs):
        table[:, column + 1] = getattr(source, attr)

    # Write once, after every column has been filled in.
    with open(path, 'w') as f:
        f.write("# " + x_label + "\t" + "\t".join(attrs) + "\n")
        np.savetxt(f, table)


def main(argv=None):
    '''Generate halo mass functions and write them to file (BETA).

    Parses the command line, forwards every option that was given to
    :func:`hmf.tools.get_hmf`, and for each result writes the requested
    mass-based quantities to ``<filename>_MDATA_<label>`` and the requested
    wavenumber-based quantities to ``<filename>_KDATA_<label>``.

    :param argv: optional extra arguments. They are parsed *in addition to*
        ``sys.argv[1:]`` (appended after them); ``sys.argv`` itself is left
        untouched.
    :returns: ``0`` on success or on keyboard interrupt, ``2`` if any other
        exception occurred (the traceback is printed to stderr). When
        ``DEBUG`` or ``TESTRUN`` is set the exception is re-raised instead.
    '''

    # Extra arguments are appended to the real command line. A fresh list is
    # built so that repeated calls do not accumulate arguments in sys.argv.
    cli_args = sys.argv[1:]
    if argv is not None:
        cli_args = cli_args + list(argv)

    program_name = os.path.basename(sys.argv[0])
    program_version = "v%s" % __version__
    program_build_date = str(__updated__)
    program_version_message = '%%(prog)s %s (%s)\n\nHISTORY\n------%s' % (program_version, program_build_date, HISTORY)
    # The short description is the first text line of this module's docstring.
    # Using the module's own __doc__ (rather than __main__'s) keeps this
    # working when main() is imported, and the guards cover a stripped
    # docstring (python -OO).
    doc_lines = (__doc__ or "").split("\n")
    program_shortdesc = doc_lines[1] if len(doc_lines) > 1 else ""
    program_license = '''%s

    Copyright (c) 2014 Steven Murray

    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"), to deal
    in the Software without restriction, including without limitation the rights
    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    copies of the Software, and to permit persons to whom the Software is
    furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in
    all copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    THE SOFTWARE.

USAGE
''' % (program_shortdesc)

    try:
        # A default instance, used only to display default values in --help.
        h = hmf.MassFunction()
        m_attrs = ["dndlog10m", "lnsigma", "n_eff", "sigma",
                   "dndm", "ngtm", "fsigma", "mgtm", "nltm", "dndlnm",
                   "how_big", "mltm", "_sigma_0", "_dlnsdlnm"]
        k_attrs = ["power", "delta_k", "transfer", "nonlinear_power",
                   "_lnP_0", "_lnP_cdm_0", "_lnT_cdm", "_unnormalised_lnP",
                   "_unnormalised_lnT"]
        # Setup argument parser
        parser = ArgumentParser(description=program_license, formatter_class=RawDescriptionHelpFormatter)
        parser.add_argument("-v", "--verbose", dest="verbose", action="count", help="set verbosity level [default: %(default)s]")
        parser.add_argument('-V', '--version', action='version', version=program_version_message)

        # HMF specific arguments
        config = parser.add_argument_group("Config", "Variables of Configuration")
        config.add_argument("filename", help="filename to write to")
        config.add_argument("--get", nargs="*", default=["dndm"],
                            choices=m_attrs + k_attrs)

        hmfargs = parser.add_argument_group("HMF", "HMF-specific arguments")
        hmfargs.add_argument("--M", nargs=3, type=float,
                            help="the mass range and intervals, min max step [default: %s %s %s]" %
                            (np.log10(h.M[0]), np.log10(h.M[-1]), np.log10(h.M[1]) - np.log10(h.M[0])))
        hmfargs.add_argument("--mf-fit", nargs="*", choices=hmf.Fits.mf_fits + ["all"],
                             help="fitting function(s) to use. 'all' uses all of them [default: %s]" % h.mf_fit)
        # Each option reports its own default (these two were swapped).
        hmfargs.add_argument("--delta-h", nargs="*", type=float,
                            help="overdensity of halo w.r.t delta_wrt [default %s]" % h.delta_h)
        hmfargs.add_argument("--delta-wrt", choices=["mean", "crit"],
                            help="what delta_h is with respect to [default: %s]" % h.delta_wrt)
        # The quotes belong inside the format string: ``%`` binds tighter
        # than ``+``, so concatenating them afterwards garbles the text.
        hmfargs.add_argument("--user-fit", help="a custom fitting function defined as a string in terms of x for sigma [default: '%s']" % h.user_fit)
        hmfargs.add_argument("--no-cut-fit", action="store_true", help="whether to cut the fitting function at tested boundaries")
        hmfargs.add_argument("--z2", nargs="*", type=float, help="upper redshift for volume weighting")
        hmfargs.add_argument("--nz", nargs="*", type=float, help="number of redshift bins for volume weighting")
        hmfargs.add_argument("--delta-c", nargs="*", type=float, help="critical overdensity for collapse [default: %s]" % h.delta_c)

        # # Transfer-specific arguments
        transferargs = parser.add_argument_group("Transfer", "Transfer-specific arguments")
        transferargs.add_argument("--z", nargs="*", type=float, help="redshift of analysis [default: %s]" % h.transfer.z)
        transferargs.add_argument("--lnk", nargs=3, type=float, help="the wavenumber range and intervals, min max step [default: %s %s %s]" %
                                  (h.transfer.lnk[0], h.transfer.lnk[-1], h.transfer.lnk[1] - h.transfer.lnk[0]))
        transferargs.add_argument("--wdm-mass", nargs="*", type=float, help="warm dark matter mass (0 is CDM)")
        # Same list that is used below when 'all' is selected; the previous
        # ``cosmology.transfer.Transfer`` name was never imported.
        transferargs.add_argument("--transfer-fit", nargs="*", choices=list(hmf.Transfer.fits) + ['all'],
                                  help="which fit for the transfer function to use ('all' uses all of them) [default: %s]" % h.transfer.transfer_fit)

        cambargs = parser.add_argument_group("CAMB", "CAMB-specific arguments")
        cambargs.add_argument("--Scalar-initial-condition", nargs="*", type=int, choices=[1, 2, 3, 4, 5],
                              help="[CAMB] initial scalar perturbation mode [default: %s]" % h.transfer._camb_options["Scalar_initial_condition"])
        cambargs.add_argument("--lAccuracyBoost", nargs="*", type=float,
                            help="[CAMB] optional accuracy boost [default: %s]" % h.transfer._camb_options["lAccuracyBoost"])
        cambargs.add_argument("--AccuracyBoost", nargs="*", type=float,
                            help="[CAMB] optional accuracy boost [default: %s]" % h.transfer._camb_options["AccuracyBoost"])
        cambargs.add_argument("--w-perturb", action="store_true", help="[CAMB] whether w should be perturbed or not")
        cambargs.add_argument("--transfer--k-per-logint", nargs="*", type=float,
                            help="[CAMB] number of estimated wavenumbers per interval [default: %s]" % h.transfer._camb_options["transfer__k_per_logint"])
        cambargs.add_argument("--transfer--kmax", nargs='*', type=float,
                            help="[CAMB] maximum wavenumber to estimate [default: %s]" % h.transfer._camb_options["transfer__kmax"])
        cambargs.add_argument("--ThreadNum", type=int,
                              help="number of threads to use (0 is automatic detection) [default: %s]" % h.transfer._camb_options["ThreadNum"])

        # # Cosmo-specific arguments
        # NOTE(review): --default and --force-flat are parsed but never
        # forwarded to get_hmf below -- confirm whether that is intended.
        cosmoargs = parser.add_argument_group("Cosmology", "Cosmology arguments")
        cosmoargs.add_argument("--default", nargs="*",
                            choices=['planck1_base'], help="base cosmology to use [default: %s]" % h.transfer.cosmo.default)
        cosmoargs.add_argument("--force-flat", action="store_true",
                            help="force cosmology to be flat (changes omega_lambda) [default: %s]" % h.transfer.cosmo.force_flat)
        cosmoargs.add_argument("--sigma-8", nargs="*", type=float, help="mass variance in top-hat spheres with r=8")
        cosmoargs.add_argument("--n", nargs="*", type=float, help="spectral index")
        cosmoargs.add_argument("--w", nargs="*", type=float, help="dark energy equation of state")
        cosmoargs.add_argument("--cs2-lam", nargs="*", type=float, help="constant comoving sound speed of dark energy")

        h_group = cosmoargs.add_mutually_exclusive_group()
        h_group.add_argument("--h", nargs="*", type=float, help="The hubble parameter")
        h_group.add_argument("--H0", nargs="*", type=float, help="The hubble constant")

        omegab_group = cosmoargs.add_mutually_exclusive_group()
        omegab_group.add_argument("--omegab", nargs="*", type=float, help="baryon density")
        omegab_group.add_argument("--omegab-h2", nargs="*", type=float, help="baryon density by h^2")

        omegac_group = cosmoargs.add_mutually_exclusive_group()
        omegac_group.add_argument("--omegac", nargs="*", type=float, help="cdm density")
        omegac_group.add_argument("--omegac-h2", nargs="*", type=float, help="cdm density by h^2")
        omegac_group.add_argument("--omegam", nargs="*", type=float, help="total matter density")

        cosmoargs.add_argument("--omegav", type=float, nargs="*", help="the dark energy density")

        # Process arguments
        args = parser.parse_args(cli_args)

        # # Collect only the options that were actually given, so that
        # # everything else falls back to the defaults of get_hmf.
        kwargs = {}
        for arg in ["omegab", "omegab_h2", "omegac", "omegac_h2", "omegam", "h", "H0",
                    "sigma_8", "n", "w", "cs2_lam", "omegav", "ThreadNum", "transfer__kmax",
                    "transfer__k_per_logint", "AccuracyBoost", "lAccuracyBoost",
                    "Scalar_initial_condition", "z", "z2", "nz", "delta_c", "user_fit", "delta_h",
                    "delta_wrt"]:
            value = getattr(args, arg)
            if value is not None:
                kwargs[arg] = make_scalar(value)

        if args.M is not None:
            kwargs["M"] = np.arange(args.M[0], args.M[1], args.M[2])

        if args.mf_fit is not None:
            if "all" in args.mf_fit:
                # Build a new list: removing from hmf.Fits.mf_fits in place
                # would alter the library's own list for the whole process.
                kwargs['mf_fit'] = [fit for fit in hmf.Fits.mf_fits if fit != "user_model"]
            else:
                kwargs['mf_fit'] = make_scalar(args.mf_fit)

        if args.user_fit is not None:
            # A custom fit only takes effect if "user_model" is among the
            # fits. mf_fit may be missing (option not given) or a bare string
            # (single value collapsed by make_scalar), so normalise to a list.
            fits = kwargs.get('mf_fit', [])
            fits = list(fits) if isinstance(fits, (list, tuple)) else [fits]
            if "user_model" not in fits:
                fits.append("user_model")
            kwargs['mf_fit'] = make_scalar(fits)

        if args.no_cut_fit:
            kwargs['cut_fit'] = not args.no_cut_fit

        if args.w_perturb:
            kwargs["w_perturb"] = args.w_perturb

        if args.lnk is not None:
            kwargs["lnk"] = np.arange(args.lnk[0], args.lnk[1], args.lnk[2])

        if args.transfer_fit is not None:
            if 'all' in args.transfer_fit:
                kwargs["transfer_fit"] = list(hmf.Transfer.fits)
            else:
                kwargs["transfer_fit"] = make_scalar(args.transfer_fit)

        # Split the requested quantities by the axis they are tabulated on.
        m_att = [a for a in args.get if a in m_attrs]
        k_att = [a for a in args.get if a in k_attrs]
        # # run the hmf
        for res, label in hmf.tools.get_hmf(args.get, **kwargs):
            if m_att:
                _write_columns(args.filename + "_MDATA_" + label, "M",
                               res.M, res, m_att)
            if k_att:
                _write_columns(args.filename + "_KDATA_" + label, "lnk",
                               res.transfer.lnk, res.transfer, k_att)

        return 0
    except KeyboardInterrupt:
        ### handle keyboard interrupt ###
        return 0
    except Exception as e:
        if DEBUG or TESTRUN:
            # Bare raise keeps the original traceback.
            raise
        traceback.print_exc()
        indent = len(program_name) * " "
        sys.stderr.write(program_name + ": " + repr(e) + "\n")
        sys.stderr.write(indent + "  for help use --help\n")
        return 2

# Changelog text; main() appends it to the message printed by -V/--version.
HISTORY = """
0.2 - 28/01/2014
    Made compatible with more versions of numpy by removing "header" arg.
    Fixed access to k-based data.
    Fixed some labeling issues with the M array and lnk array
    M and lnk are written out for any choice of M or k attributes as first column

0.1 - 24/01/2014
    First version
"""

def make_scalar(a):
    '''Unwrap a one-element list or tuple; return anything else unchanged.

    Lists and tuples of any other length (including empty ones) and all
    non-list, non-tuple values are handed back as they are.
    '''
    is_singleton = isinstance(a, (list, tuple)) and len(a) == 1
    return a[0] if is_singleton else a

# Script entry point: honour the development switches, then exit with the
# status returned by main().
if __name__ == "__main__":
    if DEBUG:
        sys.argv.append("-h")
        sys.argv.append("-v")
    if TESTRUN:
        import doctest
        doctest.testmod()
    if PROFILE:
        import cProfile
        import pstats
        profile_filename = 'scripts.hmfrun_profile.txt'
        cProfile.run('main()', profile_filename)
        # pstats prints text, so the report is opened in text mode; the
        # with-block closes it even if printing the statistics fails.
        with open("profile_stats.txt", "w") as statsfile:
            p = pstats.Stats(profile_filename, stream=statsfile)
            stats = p.strip_dirs().sort_stats('cumulative')
            stats.print_stats()
        sys.exit(0)
    sys.exit(main())
