#!/usr/bin/env python
#
#   Copyright 2016-2019 Blaise Frederick
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
#
#       $Author: frederic $
#       $Date: 2016/06/14 12:04:51 $
#       $Id: showstxcorr,v 1.11 2016/06/14 12:04:51 frederic Exp $
#
import os

import numpy as np
import rapidtide.io as tide_io
from scipy import stats

'''% This code is released under the terms of the GNU GPL v2. This code
% is not FDA approved for clinical use; it is provided
% freely for research purposes. If using this in a publication
% please reference this properly as: 

% Finn ES, Shen X, Scheinost D, Rosenberg MD, Huang J, Chun MM,
% Papademetris X & Constable RT. (2015). Functional connectome
% fingerprinting: Identifying individuals using patterns of brain
% connectivity. Nature Neuroscience 18, 1664-1671.

% This code provides a framework for implementing functional
% connectivity-based behavioral prediction in a leave-one-subject-out
% cross-validation scheme, as described in Finn, Shen et al 2015 (see above
% for full reference). The first input ('all_mats') is a pre-calculated
% MxMxN matrix containing all individual-subject connectivity matrices,
% where M = number of nodes in the chosen brain atlas and N = number of
% subjects. Each element (i,j,k) in these matrices represents the
% correlation between the BOLD timecourses of nodes i and j in subject k
% during a single fMRI session. The second input ('all_behav') is the
% Nx1 vector of scores for the behavior of interest for all subjects.

% As in the reference paper, the predictive power of the model is assessed
% via correlation between predicted and observed scores across all
% subjects. Note that this assumes normal or near-normal distributions for
% both vectors, and does not assess absolute accuracy of predictions (only
% relative accuracy within the sample). It is recommended to explore
% additional/alternative metrics for assessing predictive power, such as
% prediction error sum of squares or prediction r^2.


clear;
clc;

% ------------ INPUTS -------------------

all_mats  = load('mid_all_VP.mat');  % 268 x 268 x N subjects
all_mats=all_mats.mid_all_vP;'''
# ------------ INPUTS -------------------

# NIfTI file with the stacked individual-subject connectivity matrices
# (the MATLAB original documents them as 268 x 268 x N subjects) and a text
# file with one behavioral score per subject.
MATS_FILE = 'mid_all_VP.nii.gz'
BEHAV_FILE = 'any_drug_free_urine.txt'

# threshold for feature selection
thresh = 0.05

# ---------------------------------------


def _as_subject_stack(mats):
    """Return *mats* as a float array shaped (no_node, no_node, no_sub).

    A 4-D array with a singleton third axis (no_node, no_node, 1, no_sub) is
    accepted as well and collapsed to 3-D.

    Raises:
        ValueError: if the array is not a stack of square matrices.
    """
    mats = np.asarray(mats, dtype=np.float64)
    if mats.ndim == 4 and mats.shape[2] == 1:
        mats = mats[:, :, 0, :]
    if mats.ndim != 3 or mats.shape[0] != mats.shape[1]:
        raise ValueError(
            'expected connectivity matrices shaped (M, M, N) or (M, M, 1, N), '
            'got %s' % (mats.shape,)
        )
    return mats


def _edge_behav_corr(edge_vcts, behav):
    """Pearson-correlate every edge with behavior across subjects.

    Args:
        edge_vcts: (no_edge, no_sub) array, one row per edge.
        behav: (no_sub,) array of behavioral scores.

    Returns:
        (r, p): two (no_edge,) arrays holding the correlation coefficients and
        their two-sided p-values (Student t, no_sub - 2 degrees of freedom).
        Edges that are constant across subjects yield NaN in both.
    """
    dof = behav.shape[0] - 2
    edge_dev = edge_vcts - edge_vcts.mean(axis=1, keepdims=True)
    behav_dev = behav - behav.mean()
    # Constant edges give 0/0 and perfect correlations give x/0; both are
    # expected here, so silence the floating point warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        norm = np.sqrt((edge_dev ** 2).sum(axis=1) * (behav_dev ** 2).sum())
        # clip guards against |r| creeping past 1 through rounding
        r = np.clip(edge_dev @ behav_dev / norm, -1.0, 1.0)
        t = r * np.sqrt(dof / (1.0 - r ** 2))
    p = 2.0 * stats.t.sf(np.abs(t), dof)
    return r, p


def predict_behavior_loo(all_mats, all_behav, thresh=0.05):
    """Predict behavior from connectivity with leave-one-subject-out validation.

    For every subject in turn, the remaining subjects are used to (1) select
    the edges whose strength correlates with behavior at p < *thresh*,
    separately for positive and negative correlations, (2) sum the selected
    edges per subject and (3) fit a first-order polynomial from that sum to
    behavior.  The fit is then applied to the left-out subject.

    Args:
        all_mats: connectivity matrices, (M, M, N) or (M, M, 1, N), where M is
            the number of nodes and N the number of subjects.
        all_behav: N behavioral scores, one per subject.
        thresh: p-value threshold for feature (edge) selection.

    Returns:
        dict with
            'behav_pred_pos', 'behav_pred_neg', 'behav_pred_both':
                (N,) predicted scores from the positive, negative and
                combined (positive minus negative) edge sums;
            'pos_mask_all', 'neg_mask_all':
                (M, M) masks that are 1 only for edges selected in every fold.

    Raises:
        ValueError: on malformed matrices, a score count that does not match
            the number of subjects, or fewer than 4 subjects (the p-value
            needs at least 3 training subjects).
    """
    all_mats = _as_subject_stack(all_mats)
    all_behav = np.asarray(all_behav, dtype=np.float64).ravel()

    no_node = all_mats.shape[0]
    no_sub = all_mats.shape[2]
    if all_behav.shape[0] != no_sub:
        raise ValueError(
            'got %d behavioral scores for %d subjects' % (all_behav.shape[0], no_sub)
        )
    if no_sub < 4:
        raise ValueError('need at least 4 subjects, got %d' % no_sub)

    behav_pred_pos = np.zeros(no_sub, dtype=np.float64)
    behav_pred_neg = np.zeros(no_sub, dtype=np.float64)
    behav_pred_both = np.zeros(no_sub, dtype=np.float64)

    pos_mask_all = np.ones((no_node, no_node))
    neg_mask_all = np.ones((no_node, no_node))

    # One column of edge values per subject; row i * no_node + j is edge (i, j).
    all_vcts = all_mats.reshape(no_node * no_node, no_sub)

    for leftout in range(no_sub):
        print('\n Leaving out subj # %6.3f' % leftout)

        # leave out subject from matrices and behavior
        # (np.delete returns a new array, the inputs stay untouched)
        train_vcts = np.delete(all_vcts, leftout, axis=1)
        train_behav = np.delete(all_behav, leftout)

        # correlate all edges with behavior
        r_vct, p_vct = _edge_behav_corr(train_vcts, train_behav)

        # set threshold and define masks; NaN entries compare False and are
        # therefore never selected
        with np.errstate(invalid='ignore'):
            significant = p_vct < thresh
            pos_mask = ((r_vct > 0) & significant).astype(np.float64)
            neg_mask = ((r_vct < 0) & significant).astype(np.float64)

        pos_mask_all *= pos_mask.reshape(no_node, no_node)
        neg_mask_all *= neg_mask.reshape(no_node, no_node)

        # get sum of all edges in TRAIN subs (divide by 2 to control for the
        # fact that matrices are symmetric)
        train_sumpos = pos_mask @ train_vcts / 2.0
        train_sumneg = neg_mask @ train_vcts / 2.0
        train_sumboth = train_sumpos - train_sumneg

        # build model on TRAIN subs
        # NOTE: the MATLAB original carried a commented-out multinomial
        # logistic alternative (mnrfit/mnrval); only the linear fit is ported.
        fit_pos = np.polyfit(train_sumpos, train_behav, 1)
        fit_neg = np.polyfit(train_sumneg, train_behav, 1)
        fit_both = np.polyfit(train_sumboth, train_behav, 1)

        # run model on TEST sub
        test_vct = all_vcts[:, leftout]
        test_sumpos = pos_mask @ test_vct / 2.0
        test_sumneg = neg_mask @ test_vct / 2.0
        test_sumboth = test_sumpos - test_sumneg

        # np.polyfit returns the highest power first: [slope, intercept]
        behav_pred_pos[leftout] = fit_pos[0] * test_sumpos + fit_pos[1]
        behav_pred_neg[leftout] = fit_neg[0] * test_sumneg + fit_neg[1]
        behav_pred_both[leftout] = fit_both[0] * test_sumboth + fit_both[1]

    return {
        'behav_pred_pos': behav_pred_pos,
        'behav_pred_neg': behav_pred_neg,
        'behav_pred_both': behav_pred_both,
        'pos_mask_all': pos_mask_all,
        'neg_mask_all': neg_mask_all,
    }


def main():
    """Load the inputs, run the cross-validated prediction and report it."""
    # readfromnifti is unpacked as five values; only the data array is used.
    # NOTE(review): assumes the data array is (M, M, N) or (M, M, 1, N) and
    # that readvec yields one score per subject - confirm against the files.
    _, all_mats, _, _, _ = tide_io.readfromnifti(MATS_FILE)
    all_behav = tide_io.readvec(BEHAV_FILE)

    results = predict_behavior_loo(all_mats, all_behav, thresh)

    # compare predicted and observed scores
    for label in ('pos', 'neg', 'both'):
        r_val, p_val = stats.spearmanr(results['behav_pred_' + label], all_behav)
        print('R_%s = %f, P_%s = %g' % (label, r_val, label, p_val))

    # TODO: the MATLAB original also scatter-plotted predicted vs. observed
    # scores with a least-squares line (figure/plot/lsline) and had
    # commented-out code saving pos_mask_all / neg_mask_all as ASCII text;
    # neither has been ported.


if __name__ == "__main__":
    main()