#!/opt/anaconda/bin/python

from microtc.utils import tweet_iterator, TEXT, KLASS, VALUE
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, r2_score
from sklearn.metrics import cohen_kappa_score
from scipy.stats import pearsonr, spearmanr
import numpy as np
import json
import os
import sys


def show_mismatches(gold, predicted, output=None, avgf1=None):
    """Print every predicted record whose class differs from the gold class.

    Records from ``gold`` and ``predicted`` are read with ``tweet_iterator``
    and compared pairwise, in order.  Classes are compared as strings.

    For each mismatch the predicted record is printed as a JSON object
    (keys sorted) with its class converted to ``str`` and an extra
    ``klass_from_gold`` key holding the gold class, followed by the 1-based
    position of the record.

    Parameters
    ----------
    gold : passed to ``tweet_iterator``; source of the gold-standard records.
    predicted : passed to ``tweet_iterator``; source of the predicted records.
    output : unused; kept so the signature parallels ``performance``.
    avgf1 : unused; kept so the signature parallels ``performance``.

    Notes
    -----
    ``zip`` stops at the shorter input, so trailing records of the longer
    input are not compared.
    """
    pairs = zip(tweet_iterator(gold), tweet_iterator(predicted))
    for position, (g, p) in enumerate(pairs, start=1):
        gold_label = str(g[KLASS])
        # The stringified class is written back because the record itself is
        # what gets printed below.
        p[KLASS] = str(p[KLASS])
        if gold_label != p[KLASS]:
            p['klass_from_gold'] = gold_label
            print(json.dumps(p, sort_keys=True), position)


def performance(gold, predicted, output=None, avgf1=None):
    """Score classification predictions against a gold standard.

    Classes are read from the ``KLASS`` field of the records yielded by
    ``tweet_iterator`` for ``gold`` and ``predicted`` and compared as strings.

    The report contains the quadratic weighted kappa, macro F1, macro recall,
    micro F1, accuracy, the product ``accuracy * macrof1``
    (``macrof1accuracy``), the ``predicted`` argument (``filename``), and one
    ``f1_<class>`` entry per class found in the gold set.

    Parameters
    ----------
    gold : passed to ``tweet_iterator``; source of the gold-standard records.
    predicted : passed to ``tweet_iterator``; source of the predicted records.
    output : str or None
        File to write the report to as a single JSON line.  When ``None`` the
        report is pretty-printed to stdout instead.
    avgf1 : str or None
        ``:``-separated class names.  When given, the mean of their per-class
        F1 scores is added to the report under ``avg_f1_<avgf1>``.

    Raises
    ------
    ValueError
        If ``avgf1`` names a class that does not occur in the gold set.
        ``LabelEncoder.transform`` raises the same type for predicted classes
        it has not been fitted on.
    """
    y = [str(tweet[KLASS]) for tweet in tweet_iterator(gold)]
    le = LabelEncoder().fit(y)
    y = le.transform(y)
    hy = le.transform([str(tweet[KLASS]) for tweet in tweet_iterator(predicted)])

    # The encoder was fitted on the gold labels, so every encoded class occurs
    # in ``y`` and the per-class scores line up with ``le.classes_``.
    f1 = dict(zip(le.classes_, f1_score(y, hy, average=None)))

    report = {
        'quadratic_weighted_kappa': cohen_kappa_score(y, hy, weights='quadratic'),
        'macrof1': f1_score(y, hy, average='macro'),
        'macrorecall': recall_score(y, hy, average='macro'),
        'microf1': f1_score(y, hy, average='micro'),
        'accuracy': accuracy_score(y, hy),
        'filename': predicted,
    }
    report['macrof1accuracy'] = report['accuracy'] * report['macrof1']

    for key, value in f1.items():
        report["f1_{0}".format(key)] = value

    if avgf1:
        selected = avgf1.split(':')
        unknown = [k for k in selected if k not in f1]
        if unknown:
            # Fail with a clear message instead of the opaque TypeError that
            # summing a missing (None) score would produce.
            raise ValueError(
                "avgf1 refers to classes not present in the gold set: {0}".format(
                    ', '.join(unknown)))
        report['avg_f1_{0}'.format(avgf1)] = sum(f1[k] for k in selected) / len(selected)

    if output is None:
        print(json.dumps(report, sort_keys=True, indent=2) + "\n")
    else:
        with open(output, 'w') as f:
            f.write(json.dumps(report, sort_keys=True) + "\n")

def regression_performance(gold, predicted, output=None):
    """Score regression predictions against a gold standard.

    Target values are read from the ``VALUE`` field of the records yielded by
    ``tweet_iterator`` for ``gold`` and ``predicted`` and converted to floats.

    The report contains ``r2`` (``r2_score``), ``pearsonr`` and ``spearmanr``
    (the results of the SciPy functions of the same name, stored as
    returned), and ``filename`` (the ``predicted`` argument).

    Parameters
    ----------
    gold : passed to ``tweet_iterator``; source of the gold-standard records.
    predicted : passed to ``tweet_iterator``; source of the predicted records.
    output : str or None
        File to write the report to as a single JSON line.  When ``None`` the
        report is pretty-printed to stdout instead.
    """
    # The builtin ``float`` replaces the ``np.float`` alias, which recent
    # NumPy releases no longer provide.
    y = np.array([str(tweet[VALUE]) for tweet in tweet_iterator(gold)], dtype=float)
    hy = np.array([str(tweet[VALUE]) for tweet in tweet_iterator(predicted)], dtype=float)

    report = {
        'r2': r2_score(y, hy),
        'pearsonr': pearsonr(y, hy),
        'spearmanr': spearmanr(y, hy),
        'filename': predicted,
    }

    if output is None:
        print(json.dumps(report, sort_keys=True, indent=2) + "\n")
    else:
        with open(output, 'w') as f:
            f.write(json.dumps(report, sort_keys=True) + "\n")


if __name__ == '__main__':
    from argparse import ArgumentParser

    # Command-line entry point: two positional files (gold, predicted) plus
    # flags that select which of the three scoring routines to run.
    cli = ArgumentParser(description='microtc')
    cli.add_argument('gold', nargs=1, default=None, help='Gold standard set')
    cli.add_argument('predicted', nargs=1, default=None, help='Predicted set')
    cli.add_argument(
        '-a', '--avgf1',
        dest='avgf1',
        default=None,
        type=str,
        help='Indicates if microtc-perf must show the average of some selected classes (using `:` as separator',
    )
    cli.add_argument(
        '-o', '--output',
        dest='output',
        default=None,
        type=str,
        help='Indicates the name of the output file, defaults to stdout',
    )
    cli.add_argument(
        '--mismatches',
        dest='mismatches',
        default=False,
        action='store_true',
        help='Prints the mismatching classifications to stdout',
    )
    cli.add_argument(
        '-r', '--regression',
        dest='regression',
        default=False,
        action='store_true',
        help='Score regression data',
    )
    opts = cli.parse_args()

    # ``nargs=1`` makes each positional a one-element list.
    gold_path = opts.gold[0]
    predicted_path = opts.predicted[0]

    # --mismatches takes precedence over --regression; the default is the
    # classification report.
    if opts.mismatches:
        show_mismatches(gold_path, predicted_path)
    elif opts.regression:
        regression_performance(gold_path, predicted_path, opts.output)
    else:
        performance(gold_path, predicted_path, opts.output, opts.avgf1)

