#!/usr/bin/env python
import os
import sys
from pathlib import Path
import argparse

from totalsegmentator.libs import download_pretrained_weights


# Tasks offered on the command line via --task.
TASK_CHOICES = [
    "total", "total_fast", "lung_vessels", "cerebral_bleed",
    "hip_implant", "coronary_arteries", "body", "pleural_pericard_effusion",
    "liver_vessels", "bones_extremities",
]

# Ids passed to download_pretrained_weights() for each task.
# NOTE(review): "body" and "bones_extremities" are offered via --task but have
# no ids here, and "covid" has an id but is not offered -- confirm the intended
# ids and bring the two lists in sync.
TASK_TO_ID = {
    "total": [251, 252, 253, 254, 255],
    "total_fast": [256],
    "lung_vessels": [258],
    "covid": [201],
    "cerebral_bleed": [150],
    "hip_implant": [260],
    "coronary_arteries": [503],
    "pleural_pericard_effusion": [315],
    "liver_vessels": [8],
}


def main(argv=None):
    """Download the pretrained weights belonging to the selected task.

    Parses ``-t/--task`` (default ``"total"``) and calls
    ``download_pretrained_weights`` once for every id mapped to that task.

    Args:
        argv: Command-line arguments to parse. ``None`` makes argparse read
            ``sys.argv[1:]``, which is what happens when run as a script.

    Raises:
        SystemExit: With status 2 when the arguments are invalid or when the
            chosen task has no ids registered in ``TASK_TO_ID``.
    """
    parser = argparse.ArgumentParser(
        description="Download pretrained TotalSegmentator weights.",
        epilog="Written by Jakob Wasserthal.",
    )
    parser.add_argument(
        "-t", "--task",
        choices=TASK_CHOICES,
        default="total",
        help="Task for which to download the weights",
    )
    args = parser.parse_args(argv)

    # Some selectable tasks have no id mapping; report that as a usage error
    # instead of letting the lookup fail with a bare KeyError traceback.
    task_ids = TASK_TO_ID.get(args.task)
    if task_ids is None:
        parser.error(f"No pretrained weights are registered for task '{args.task}'.")

    for task_id in task_ids:
        print(f"Processing {task_id}...")
        download_pretrained_weights(task_id)


if __name__ == "__main__":
    main()
