#!/usr/bin/env python
import sys
import os
import argparse
from pkg_resources import require
from pathlib import Path
import time

import numpy as np
import nibabel as nib
import torch

from totalsegmentator.libs import setup_nnunet, download_pretrained_weights
from totalsegmentator.preview import generate_preview
from totalsegmentator.statistics import get_basic_statistics_for_entire_dir, get_radiomics_features_for_entire_dir


def _absolute_path(p):
    """Convert a command-line path string into an absolute ``Path``."""
    return Path(p).absolute()


def _build_parser():
    """Build the argument parser for the command-line interface.

    Returns:
        argparse.ArgumentParser: parser with all CLI options registered.
    """
    parser = argparse.ArgumentParser(
        description="Segment 104 anatomical structures in CT images.",
        epilog="Written by Jakob Wasserthal. If you use this tool please cite https://arxiv.org/abs/2208.05868")

    parser.add_argument("-i", metavar="filepath", dest="input",
                        help="CT nifti image",
                        type=_absolute_path, required=True)

    parser.add_argument("-o", metavar="directory", dest="output",
                        help="Output directory for segmentation masks",
                        type=_absolute_path, required=True)

    parser.add_argument("--ml", action="store_true",
                        help="Save one multilabel image for all classes")

    parser.add_argument("-nr", "--nr_thr_resamp", type=int,
                        help="Nr of threads for resampling", default=1)

    parser.add_argument("-ns", "--nr_thr_saving", type=int,
                        help="Nr of threads for saving segmentations", default=6)

    parser.add_argument("-f", "--fast", action="store_true",
                        help="Run faster lower resolution model")

    # The default is the *string* "None" (not the None object); it is passed
    # through unchanged to nnUNet_predict_image.
    parser.add_argument("-t", "--nora_tag", type=str,
                        help="tag in nora as mask. Pass nora project id as argument.",
                        default="None")

    parser.add_argument("-p", "--preview", action="store_true",
                        help="Generate a png preview of segmentation")

    # TODO (prio2): add a "-rs/--roi_subset" option (type=str, nargs="+") to
    # manually define only a subset of classes to predict. For the "15mm"
    # model (presumably the 1.5 mm one - confirm) only run the sub-models
    # which are needed for these rois.

    parser.add_argument("-s", "--statistics", action="store_true",
                        help="Calc volume and mean intensity. Results will be in statistics.json")

    parser.add_argument("-r", "--radiomics", action="store_true",
                        help="Calc radiomics features. Requires pyradiomics. Results will be in statistics_radiomics.json")

    parser.add_argument("-q", "--quiet", action="store_true",
                        help="Print no intermediate outputs")

    parser.add_argument("-v", "--verbose", action="store_true",
                        help="Show more intermediate output")

    parser.add_argument("--test", metavar="0|1|2|3", choices=[0, 1, 2, 3], type=int,
                        help="Only needed for unittesting.",
                        default=0)

    # The version string is looked up from the installed package metadata
    # while the parser is being built.
    parser.add_argument('--version', action='version',
                        version=require("TotalSegmentator")[0].version)

    return parser


def main():
    """Command-line entry point: segment a CT image and optionally compute statistics.

    Parses the command line, checks that CUDA is available (skipped when
    ``--test`` is non-zero), downloads the pretrained weights for the selected
    task(s), runs ``nnUNet_predict_image`` and, if requested, writes
    ``statistics.json`` and/or ``statistics_radiomics.json`` below the output
    path.

    Raises:
        ValueError: if ``--test`` is 0 and ``torch.cuda.is_available()`` is False.
    """
    args = _build_parser().parse_args()

    quiet, verbose = args.quiet, args.verbose

    if args.test == 0 and not torch.cuda.is_available():
        raise ValueError("TotalSegmentator only works with a NVidia CUDA GPU. CUDA not found. "
                         "If you do not have a GPU check out our online tool: www.totalsegmentator.com")

    setup_nnunet()

    from totalsegmentator.nnunet import nnUNet_predict_image  # this has to be after setting new env vars

    if args.fast:
        # Fast mode: one task, resample value 3.0.
        task_id = 256
        resample = 3.0
        trainer = "nnUNetTrainerV2_ep8000_nomirror"
    else:
        # Default mode: five tasks, resample value 1.5.
        task_id = [251, 252, 253, 254, 255]
        resample = 1.5
        trainer = "nnUNetTrainerV2_ep4000_nomirror"

    # Normalise to a list only for downloading; task_id itself is handed to
    # nnUNet_predict_image unchanged (int in fast mode, list otherwise).
    task_ids = task_id if isinstance(task_id, list) else [task_id]
    for tid in task_ids:
        download_pretrained_weights(tid)

    folds = [0]  # None
    seg = nnUNet_predict_image(args.input, args.output, task_id, model="3d_fullres", folds=folds,
                               trainer=trainer, tta=False, multilabel_image=args.ml, resample=resample,
                               nora_tag=args.nora_tag, preview=args.preview,
                               nr_threads_resampling=args.nr_thr_resamp,
                               nr_threads_saving=args.nr_thr_saving,
                               quiet=quiet, verbose=verbose, test=args.test)

    if args.statistics:
        if not quiet:
            print("Calculating statistics...")
        st = time.time()
        get_basic_statistics_for_entire_dir(seg, args.input, args.output / "statistics.json", quiet)
        if not quiet:
            print(f"  calculated in {time.time()-st:.2f}s")

    if args.radiomics:
        if not quiet:
            print("Calculating radiomics...")
        st = time.time()
        get_radiomics_features_for_entire_dir(args.input, args.output,
                                              args.output / "statistics_radiomics.json")
        if not quiet:
            print(f"  calculated in {time.time()-st:.2f}s")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
