import sys
from pathlib import Path
import argparse
import subprocess

import nibabel as nib
import numpy as np

from totalsegmentator.map_to_binary import class_map_5_parts


def find_existing_masks(mask_dir, mask_names):
    """
    Return the paths of the mask files that exist in mask_dir.

    Each name is looked up as "<name>.nii.gz" inside mask_dir. The order of
    mask_names is preserved and names without a file are skipped, so the
    result is empty when none of the masks is present.
    """
    candidates = (Path(mask_dir) / f"{name}.nii.gz" for name in mask_names)
    return [path for path in candidates if path.exists()]


def combine_masks(mask_paths):
    """
    Merge several mask files into one binary mask image.

    The first file provides the shape and affine of the result. Every voxel
    whose value is > 0.5 in any of the files is set to 1 in the result, all
    other voxels stay 0.

    mask_paths must contain at least one path.
    """
    ref_img = nib.load(mask_paths[0])
    combined = np.zeros(ref_img.shape, dtype=np.uint8)
    for path in mask_paths:
        data = nib.load(path).get_fdata()
        combined[data > 0.5] = 1
    return nib.Nifti1Image(combined, ref_img.affine)


def main():
    """
    Combine binary labels into a binary file.

    Works with any number of label files.

    Usage:
    totalseg_combine_masks -i totalsegmentator_output_dir -o combined_mask.nii.gz -m lung
    """
    parser = argparse.ArgumentParser(description="Combine masks.",
                                     epilog="Written by Jakob Wasserthal. If you use this tool please cite https://arxiv.org/abs/2208.05868")

    parser.add_argument("-i", metavar="directory", dest="mask_dir",
                        help="TotalSegmentator output directory containing all the masks",
                        type=lambda p: Path(p).absolute(), required=True)

    parser.add_argument("-o", metavar="filepath", dest="output",
                        help="Output file path for the combined mask",
                        type=lambda p: Path(p).absolute(), required=True)

    parser.add_argument("-m", "--masks", type=str, choices=["lung", "vertebrae", "ribs", "vertebrae_ribs", "heart"],
                        help="The type of masks you want to combine", required=True)

    parser.add_argument("-t", "--nora_tag", type=str, help="tag in nora as mask. Pass nora project id as argument.",
                        default="None")

    args = parser.parse_args()

    # argparse `choices` guarantees that exactly one of these branches matches.
    if args.masks == "ribs":
        masks = list(class_map_5_parts["class_map_part_ribs"].values())
    elif args.masks == "vertebrae":
        masks = list(class_map_5_parts["class_map_part_vertebrae"].values())
    elif args.masks == "vertebrae_ribs":
        masks = list(class_map_5_parts["class_map_part_vertebrae"].values()) + list(class_map_5_parts["class_map_part_ribs"].values())
    elif args.masks == "lung":
        masks = ["lung_upper_lobe_left", "lung_lower_lobe_left", "lung_upper_lobe_right",
                 "lung_middle_lobe_right", "lung_lower_lobe_right"]
    elif args.masks == "heart":
        masks = ["heart_myocardium", "heart_atrium_left", "heart_ventricle_left",
                 "heart_atrium_right", "heart_ventricle_right"]

    # Only files that are actually present are combined. The first present
    # file (not necessarily masks[0]) serves as the reference image.
    mask_paths = find_existing_masks(args.mask_dir, masks)
    if not mask_paths:
        raise ValueError("All masks are not present in the directory")

    nib.save(combine_masks(mask_paths), args.output)

    # The string "None" is the sentinel for "no nora tagging requested".
    if args.nora_tag != "None":
        # Argument list instead of a shell string, so paths or tags containing
        # spaces or shell metacharacters are passed through unchanged.
        subprocess.call(["/opt/nora/src/node/nora", "-p", args.nora_tag,
                         "--add", str(args.output), "--addtag", "mask"])


if __name__ == "__main__":
    main()
