#!/usr/bin/env python3

import sys
import argparse
import os
from io import BytesIO
from textwrap import dedent
from types import SimpleNamespace
from functools import partial

import requests
import yaml

from requests_toolbelt import MultipartEncoder, MultipartEncoderMonitor
from tqdm import tqdm
from neotermcolor import colored

# Shared mutable state for the upload progress callback: `prev` holds the
# kilobyte count most recently reported to the progress bar.
_d = SimpleNamespace()


def gen_random_str(length=32):
    """
    Generate a random alphanumeric string.

    pyaltt2 copy/paste to avoid unnecessary imports. The result is used as
    an encryption password, so characters are drawn with ``secrets`` (OS
    CSPRNG) instead of the predictable ``random`` generator.

    Args:
        length: number of characters to generate

    Returns:
        str of ``length`` characters taken from ASCII letters and digits
    """
    import secrets
    import string
    symbols = string.ascii_letters + string.digits
    return ''.join(secrets.choice(symbols) for _ in range(length))


def encrypt_data(data, key=None):
    """
    Encrypt bytes with the ``openssl`` command-line tool (aes-256-cbc).

    The password is written as the first line of openssl's stdin
    (``-pass stdin``), immediately followed by the payload.

    Args:
        data: bytes to encrypt
        key: password string; a random 12-character one is generated when
            None

    Returns:
        tuple ``(key, encrypted)``: the password that was used and the
        bytes openssl wrote to stdout

    Raises:
        RuntimeError: openssl finished with a non-zero exit code
    """
    import subprocess

    if key is None:
        key = gen_random_str(12)

    p = subprocess.run(
        ['openssl', 'aes-256-cbc', '-a', '-salt', '-pbkdf2', '-pass', 'stdin'],
        input=key.encode() + b'\n' + data,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)

    if p.returncode:
        # stderr is captured as bytes; decode it so the message shows the
        # actual text instead of a b'...' repr
        raise RuntimeError(
            f'openssl error {p.stderr.decode(errors="replace").strip()}')
    return key, p.stdout


def safe_request(rfn, on_except=None):
    """
    Run a request callable and terminate the program if it fails.

    Args:
        rfn: zero-argument callable which performs the request and returns
            a response object exposing ``ok``, ``status_code`` and
            ``reason``
        on_except: optional zero-argument callable invoked before the
            error is printed (e.g. to close a progress bar)

    Returns:
        the response object, when its ``ok`` attribute is true

    On any exception, or a response that is not ``ok``, the error is
    printed in bold red and the process exits with code 1.
    """
    try:
        r = rfn()
        if not r.ok:
            raise RuntimeError(
                f'Server response code: {r.status_code} {r.reason}')
        return r
    except Exception as e:
        if on_except:
            on_except()
        print(colored(e, color='red', attrs='bold'))
        # sys.exit rather than the builtin exit(): the latter is injected by
        # the site module for interactive use and is absent under python -S
        sys.exit(1)


def ok():
    """Print a bold green "OK" marker to stdout."""
    marker = colored('OK', color='green', attrs='bold')
    print(marker)


# Optional per-user defaults, read from the "sshare" section of
# ~/.sshare.yml. A missing file simply means "no defaults".
try:
    with open(os.path.join(os.path.expanduser('~'), '.sshare.yml')) as fh:
        # An empty file parses to None, and the section itself may be
        # absent or empty; fall back to {} so config.get() keeps working.
        config = (yaml.safe_load(fh) or {}).get('sshare') or {}
except FileNotFoundError:
    config = {}

# Command-line interface. Defaults for url, key and timeout come from the
# config file section when present.
ap = argparse.ArgumentParser()

# Positional target: a file path, "-" for stdin, or a "c:" command.
ap.add_argument(
    'FILE',
    nargs='?',
    help='File to share ("-" for stdin [text]) or command (c:)',
)

# Connection settings.
ap.add_argument(
    '-u', '--url',
    default=config.get('url', 'http://localhost:8008'),
    help='Service URL',
)
ap.add_argument(
    '-k', '--key',
    default=config.get('upload-key'),
    help='upload key',
)
ap.add_argument(
    '-T', '--timeout',
    type=int,
    default=config.get('timeout', 5),
    help='timeout',
)

# Sharing options.
ap.add_argument(
    '-s', '--one-shot',
    action='store_true',
    help='One-shot sharing',
)
ap.add_argument(
    '-x', '--expires',
    type=int,
    help='Expiration (in seconds)',
)

# Client-side encryption options.
ap.add_argument(
    '-c', '--encrypt',
    action='store_true',
    help='Encrypt file on the client-side',
)
ap.add_argument(
    '-w', '--prompt-password',
    action='store_true',
    help='Ask for encryption password, instead of auto-generation',
)

a = ap.parse_args()

# Encryption password. Stays None unless the user types one here;
# encrypt_data() generates a random one when it receives None.
LOCAL_KEY = None

if a.prompt_password:
    from getpass import getpass
    pass1 = getpass('Password: ')
    pass2 = getpass('Repeat password: ')
    if pass1 != pass2:
        print(colored('Passwords don\'t match', color='red'))
        sys.exit(2)
    if not pass1:
        # An empty password would leave LOCAL_KEY falsy, so later checks
        # of LOCAL_KEY would treat the upload as if no key were set.
        print(colored('Password must not be empty', color='red'))
        sys.exit(2)
    LOCAL_KEY = pass1

# The auth header is sent only when an upload key is available.
headers = {'X-Auth-Key': a.key} if a.key else {}

# Optional form fields sent along with the request; values are strings.
data = {}
if a.one_shot:
    data.update(oneshot='1')
if a.expires:
    data.update(expires=str(a.expires))

# A FILE argument starting with "c:" is a service command, not an upload.
# Every branch ends the program: 0 on success, 2 on an unknown command
# (safe_request itself exits with 1 when the request fails).
if a.FILE is not None and a.FILE.startswith('c:'):
    if a.FILE == 'c:token':
        # Create a one-time upload token
        r = safe_request(
            partial(requests.post,
                    f'{a.url}/api/v1/token',
                    data=data,
                    headers=headers,
                    timeout=a.timeout))
        print(f'One-time token created, expires: {r.headers["Expires"]}')
        print()
        print(colored(r.json()['token'], color='white', attrs='bold'))
    elif a.FILE.startswith('c:delete:token:'):
        # Strip the complete "c:delete:token:" prefix so that only the
        # token value ends up in the URL path (a fixed [9:] slice left the
        # "token:" part in place).
        safe_request(
            partial(requests.delete,
                    f'{a.url}/api/v1/token/{a.FILE[len("c:delete:token:"):]}',
                    headers=headers,
                    timeout=a.timeout))
        ok()
    elif a.FILE.startswith('c:delete'):
        # Everything after "c:delete:" is used as the URL to delete
        safe_request(
            partial(requests.delete,
                    a.FILE[len('c:delete:'):],
                    headers=headers,
                    timeout=a.timeout))
        ok()
    else:
        print(colored('Command unsupported', color='red', attrs='bold'))
        sys.exit(2)
    sys.exit(0)

# "-" is an explicit alias for reading the payload from stdin.
if a.FILE == '-':
    a.FILE = None

if not a.FILE:
    # Payload comes from stdin; it is read fully into memory.
    s = sys.stdin.buffer.read()
    try:
        s.decode()
    except UnicodeDecodeError:
        # Not valid UTF-8 text: treat the input as opaque binary
        ctype = 'application/octet-stream'
        ext = 'dat'
    else:
        ctype = 'text/plain'
        ext = 'txt'
    if a.encrypt:
        LOCAL_KEY, s = encrypt_data(s, LOCAL_KEY)
    size = len(s)
    if size == 0:
        # Nothing to upload
        sys.exit(0)
    b = BytesIO(s)
    files = {
        'file': (
            f'paste.{ext}',
            b,
        )
    }
else:
    if a.encrypt:
        # Encryption needs the whole file in memory
        with open(a.FILE, 'rb') as fh:
            contents = fh.read()
        LOCAL_KEY, s = encrypt_data(contents, LOCAL_KEY)
        b = BytesIO(s)
        size = len(s)
        stream = b
    else:
        size = os.path.getsize(a.FILE)
        # Left open on purpose: the handle is read later, while the
        # multipart body is being sent.
        stream = open(a.FILE, 'rb')
    files = {'file': (a.FILE, stream)}

# Label shown on the progress bar and in the final summary
fname = 'STDIN' if not a.FILE else os.path.basename(a.FILE)

# Show which service is being contacted.
print('>>', colored(a.url, color="grey"))

# No kilobytes have been reported to the progress bar yet.
_d.prev = 0


def progress(mon):
    """
    Advance the progress bar by the kilobytes read since the last call.

    Args:
        mon: object exposing ``bytes_read``, the running byte total
    """
    kb_read = int(mon.bytes_read / 1000)
    delta = kb_read - _d.prev
    pbar.update(delta)
    _d.prev = kb_read


# The file part joins the other form fields in a single multipart body.
data.update(files)

# NOTE(review): "raw" is keyed on LOCAL_KEY rather than on a.encrypt, so it
# is also sent when a password was prompted without --encrypt - confirm
# that this is intended.
if LOCAL_KEY:
    data['raw'] = '1'

# The monitor wraps the encoder and is given progress() as its callback.
encoder = MultipartEncoder(fields=data)
monitor = MultipartEncoderMonitor(encoder, progress)

headers['Content-Type'] = monitor.content_type

pbar = tqdm(total=int(size / 1000), unit="KB", desc=fname)

upload = partial(requests.post,
                 f'{a.url}/u',
                 data=monitor,
                 headers=headers,
                 timeout=a.timeout)
# On failure the bar is closed before the error message is printed.
r = safe_request(upload, on_except=pbar.close)

pbar.clear()
pbar.close()
url = r.headers["Location"]
summary = dedent(f"""
        {colored(fname + " uploaded", color="green")}

        Access URL: {colored(url, color="cyan", attrs="bold")}
        Expires: {r.headers["Expires"]}
        """)
print(summary)

if a.one_shot:
    print(colored('(Single access attempt allowed)', color='yellow'),
          end='\n\n')

if a.encrypt:
    # Horizontal rule sized relative to the URL; frames the instructions.
    rule = colored('=' * (len(url) + 12), color='grey')
    print(rule)
    if not a.prompt_password:
        # The password was auto-generated, so show it to the user.
        print(f'Decrypt password: {LOCAL_KEY}')
        print()
    # Shell pipeline that downloads the file and decrypts it with openssl,
    # writing to a file named after the last URL path segment.
    command = (f'curl -s {url} |\n    openssl aes-256-cbc -d -a -pbkdf2'
               f' -out {url.rsplit("/",1)[-1]}')
    print(colored(command, color='green', attrs='bold'))
    print(rule)
    print()
