#!/usr/bin/env python

import os
import csv
import time
import shlex
import atexit
import subprocess

from typing import List

import click
import psutil

import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter


# Select seaborn's 'talk' plotting context once, at import time, so it applies
# to every figure this module creates.
sns.set_context('talk')


def get_clean_command(cmd: List[str]) -> str:
    """Return a command line with the executable's directory stripped.

    Args:
        cmd: Argument vector with the executable as first element.

    Returns:
        The executable's base name followed by its arguments, joined by
        single spaces. An empty string is returned when ``cmd`` is empty.
    """
    if not cmd:
        # A process may report an empty command line; there is no
        # executable name to extract in that case.
        return ''
    name = os.path.basename(cmd[0])
    return ' '.join([name, *cmd[1:]])


def gather_data(proc: psutil.Process) -> None:
    """Append one sample of process metrics to the per-PID CSV file.

    The sample consists of the current timestamp, the cleaned command line,
    the resident set size and the status of ``proc``. It is appended as one
    row to ``procwatch_<pid>_data.csv`` (relative to the current working
    directory); the header row is written first if the file does not exist.

    Rows are written with the ``csv`` module and every field quoted, so
    commands containing double quotes or commas stay parseable.
    """
    # query all metrics inside a single psutil oneshot() context
    with proc.oneshot():
        row = [
            pd.Timestamp.now(),
            get_clean_command(proc.cmdline()),
            proc.memory_info().rss,
            proc.status(),
        ]

    # write data
    fname = f'procwatch_{proc.pid}_data.csv'
    needs_header = not os.path.exists(fname)

    # newline='' is what the csv module expects of the underlying file
    with open(fname, 'a', newline='') as fd:
        if needs_header:
            fd.write('date,command,RSS,status\n')
        writer = csv.writer(fd, quoting=csv.QUOTE_ALL, lineterminator='\n')
        writer.writerow(row)


def shorten_msg(text: str, width: int, suffix: str = '[...]') -> str:
    """Truncate ``text`` so that the result is at most ``width`` characters.

    Args:
        text: The string to shorten.
        width: Maximum length of the returned string.
        suffix: Marker appended to a truncated string; it counts towards
            ``width``.

    Returns:
        ``text`` unchanged if it already fits. Otherwise the leading part of
        ``text`` followed by ``suffix``, or, when ``width`` is too small to
        hold ``suffix``, just the first ``width`` characters of ``text``.
    """
    if len(text) <= width:
        return text
    if width < len(suffix):
        # No room for the suffix: a negative slice index would make the
        # result longer than ``width``, so fall back to a hard cut.
        return text[:max(width, 0)]
    return text[:width - len(suffix)] + suffix


def sizeof_fmt(num: float, suffix: str = 'B') -> str:
    """Format ``num`` as a human-readable size using binary (1024) prefixes.

    The value is divided by 1024 until its magnitude drops below 1024 and is
    then rendered with one decimal place, a narrow no-break space, the
    matching prefix and ``suffix``. Values beyond the 'Zi' range are
    reported with the 'Yi' prefix.
    """
    gap = '\u202F'
    prefixes = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi')

    for prefix in prefixes:
        if abs(num) < 1024.0:
            break
        num /= 1024.0
    else:
        # ran through every prefix without fitting
        prefix = 'Yi'

    return f'{num:3.1f}{gap}{prefix}{suffix}'


def plot_result(proc: psutil.Process) -> None:
    """Plot the recorded memory usage of ``proc`` and save it as a PDF.

    Reads ``procwatch_<pid>_data.csv``, draws RSS over time as a line plot
    and writes the figure to ``procwatch_<pid>_memory.pdf`` (both relative
    to the current working directory). If the data file does not exist,
    i.e. no sample was ever recorded, a notice is printed and nothing is
    plotted.
    """
    # read data
    fname = f'procwatch_{proc.pid}_data.csv'
    if not os.path.exists(fname):
        print(f'No data recorded for PID {proc.pid}, skipping plot.')
        return
    df = pd.read_csv(fname, parse_dates=['date'])

    # sometimes processes change their title (e.g. using setproctitle).
    # Use most commonly observed command as educated guess.
    commands = df['command'].mode()
    # mode() is empty when no command value was recorded at all
    command = str(commands.iloc[0]) if not commands.empty else '<unknown>'

    # prevent overflow of legend
    df['command'] = shorten_msg(command, width=50)

    # plot
    fig_scale = 1.5
    fig, ax = plt.subplots(figsize=(8*fig_scale, 6*fig_scale))
    try:
        sns.lineplot(x='date', y='RSS', hue='command', data=df)

        plt.xticks(rotation=90)

        # label the y axis with human-readable sizes instead of raw bytes
        @FuncFormatter
        def formatter(x, pos):
            return sizeof_fmt(x)
        ax.yaxis.set_major_formatter(formatter)

        fig.tight_layout()
        fig.savefig(f'procwatch_{proc.pid}_memory.pdf')
    finally:
        # release the figure even if plotting or saving failed
        plt.close(fig)


def exit_handler(proc: psutil.Process) -> None:
    """Announce the end of tracking and plot the data recorded for ``proc``.

    ``main`` registers this function with ``atexit``.
    """
    print('Stop PID tracking, plotting result...')
    plot_result(proc)


def handle_input(inp: str) -> int:
    """Resolve the user's PID/CMD argument to a process ID.

    If ``inp`` parses as an integer it is returned as the PID. Otherwise
    ``inp`` is split like a shell command line, started as a new process,
    and that process's PID is returned.
    """
    try:
        return int(inp)
    except ValueError:
        # not a number: treat the input as a command to launch
        argv = shlex.split(inp)
        return subprocess.Popen(argv).pid


@click.command()
@click.argument('pid_cmd', metavar='PID/CMD')
@click.option(
    '--interval', default=5,
    help='Data retrieval interval in seconds.')
def main(pid_cmd: str, interval: int) -> None:
    """Retrieve information about a running process."""
    # PID/CMD is either the PID of an existing process or a command to start
    pid = handle_input(pid_cmd)
    proc = psutil.Process(pid)
    # plot whatever was collected once the interpreter exits
    atexit.register(exit_handler, proc)

    while True:
        try:
            if not proc.is_running() or proc.status() == psutil.STATUS_ZOMBIE:
                break

            gather_data(proc)
        except psutil.NoSuchProcess:
            # The process disappeared between the liveness check and the
            # query; treat it like a normal end of tracking.
            break
        time.sleep(interval)


# Start the command-line interface only when this file is run as a script.
if __name__ == '__main__':
    main()
