import sys
from random import randint
from pyspark import SparkConf
from pyspark.sql import SQLContext
from typedecorator import params, Nullable, setup_typecheck
from varspark import java
from varspark.etc import find_jar


class VarsparkContext(object):
"""The main entry point for VariantSpark functionality.
"""

    @classmethod
def spark_conf(cls, conf=SparkConf()):
""" Adds the necessary option to the spark configuration.
Note: In client mode these need to be setup up using --jars or --driver-class-path
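
        A sketch of intended use when building a session manually (assumes
        ``SparkSession`` from ``pyspark.sql``)::

            >>> from pyspark.sql import SparkSession
            >>> conf = VarsparkContext.spark_conf()
            >>> spark = SparkSession.builder.config(conf=conf).getOrCreate()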
"""
return conf.set("spark.jars", find_jar())
def __init__(self, ss, silent=False):
"""The main entry point for VariantSpark functionality.
:param ss: SparkSession
:type ss: :class:`.pyspark.SparkSession`
:param bool silent: Do not produce welcome info.
"""
self.sc = ss.sparkContext
self.silent = silent
self.sql = SQLContext.getOrCreate(self.sc)
self._jsql = self.sql._jsqlContext
self._jvm = self.sc._jvm
self._vs_api = getattr(self._jvm, 'au.csiro.variantspark.api')
jss = ss._jsparkSession
self._jvsc = self._vs_api.VSContext.apply(jss)
setup_typecheck()
if not self.silent:
sys.stderr.write('Running on Apache Spark version {}\n'.format(self.sc.version))
if self.sc._jsc.sc().uiWebUrl().isDefined():
sys.stderr.write('SparkUI available at {}\n'.format(
self.sc._jsc.sc().uiWebUrl().get()))
sys.stderr.write(
'Welcome to\n'
' _ __ _ __ _____ __ \n'
'| | / /___ ______(_)___ _____ / /_/ ___/____ ____ ______/ /__ \n'
'| | / / __ `/ ___/ / __ `/ __ \/ __/\__ \/ __ \/ __ `/ ___/ //_/ \n'
'| |/ / /_/ / / / / /_/ / / / / /_ ___/ / /_/ / /_/ / / / ,< \n'
'|___/\__,_/_/ /_/\__,_/_/ /_/\__//____/ .___/\__,_/_/ /_/|_| \n'
' /_/ \n')

    @params(self=object, vcf_file_path=str, min_partitions=int)
def import_vcf(self, vcf_file_path, min_partitions=0):
""" Import features from a VCF file.
"""
return FeatureSource(self._jvm, self._vs_api,
self._jsql, self.sql, self._jvsc.importVCF(vcf_file_path,
min_partitions))

    @params(self=object, label_file_path=str, col_name=str)
def load_label(self, label_file_path, col_name):
""" Loads the label source file
:param label_file_path: The file path for the label source file
:param col_name: the name of the column containing labels
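
        A usage sketch (hypothetical file and column name)::

            >>> labels = vc.load_label('data/labels.csv', 'pheno')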
"""
return self._jvsc.loadLabel(label_file_path, col_name)

    def stop(self):
        """ Shut down the VarsparkContext.
"""
self.sc.stop()
self.sc = None


# Deprecated alias, kept for backwards compatibility.
VariantsContext = VarsparkContext


class FeatureSource(object):
    """ A source of features, as created by :meth:`VarsparkContext.import_vcf`.
    """
def __init__(self, _jvm, _vs_api, _jsql, sql, _jfs):
self._jfs = _jfs
self._jvm = _jvm
self._vs_api = _vs_api
self._jsql = _jsql
self.sql = sql

    @params(self=object, label_source=object, n_trees=Nullable(int), mtry_fraction=Nullable(float),
            oob=Nullable(bool), seed=Nullable(int), batch_size=Nullable(int),
            var_ordinal_levels=Nullable(int), max_depth=int, min_node_size=int)
def importance_analysis(self, label_source, n_trees=1000, mtry_fraction=None,
oob=True, seed=None, batch_size=100, var_ordinal_levels=3,
max_depth=java.MAX_INT, min_node_size=1):
"""Builds random forest classifier.
:param label_source: The ingested label source
:param int n_trees: The number of trees to build in the forest.
:param float mtry_fraction: The fraction of variables to try at each split.
:param bool oob: Should OOB error be calculated.
:param int seed: Random seed to use.
:param int batch_size: The number of trees to build in one batch.
:param int var_ordinal_levels:
:return: Importance analysis model.
:rtype: :py:class:`ImportanceAnalysis`
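
        A usage sketch (assumes ``features`` from :meth:`VarsparkContext.import_vcf`
        and ``labels`` from :meth:`VarsparkContext.load_label`)::

            >>> ia = features.importance_analysis(labels, n_trees=500, seed=13)
            >>> ia.oob_error()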
"""
        vs_algo = self._jvm.au.csiro.variantspark.algo
        # Fall back to a random seed when none is provided.
        jseed = java.jlong_or(seed, randint(java.MIN_LONG, java.MAX_LONG))
        jrf_params = vs_algo.RandomForestParams(bool(oob), java.jfloat_or(mtry_fraction),
                                                True, java.NAN, True, jseed,
                                                max_depth, min_node_size, False, 0)
jia = self._vs_api.ImportanceAnalysis(self._jsql, self._jfs, label_source,
jrf_params, n_trees, batch_size, var_ordinal_levels)
return ImportanceAnalysis(jia, self.sql)


class ImportanceAnalysis(object):
""" Model for random forest based importance analysis
"""
def __init__(self, _jia, sql):
self._jia = _jia
self.sql = sql

    @params(self=object, limit=Nullable(int))
def important_variables(self, limit=10):
""" Gets the top limit important variables as a list of tuples (name, importance) where:
- name: string - variable name
- importance: double - gini importance
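
        A usage sketch (assumes an :py:class:`ImportanceAnalysis` instance ``ia``)::

            >>> ia.important_variables(limit=5)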
"""
        jimpvarmap = self._jia.importantVariablesJavaMap(limit)
        # py4j exposes the Java map as a dict-like object; sort by descending importance
        return sorted(jimpvarmap.items(), key=lambda x: x[1], reverse=True)

    def oob_error(self):
""" OOB (Out of Bag) error estimate for the model
:rtype: float
"""
return self._jia.oobError()

    def variable_importance(self):
""" Returns a DataFrame with the gini importance of variables.
The DataFrame has two columns:
- variable: string - variable name
- importance: double - gini importance
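
        A usage sketch (assumes an :py:class:`ImportanceAnalysis` instance ``ia``)::

            >>> df = ia.variable_importance()
            >>> df.orderBy(df.importance.desc()).show(5)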
"""
        jdf = self._jia.variableImportance()
        # Force computation of the lazy Java DataFrame before exposing it
        jdf.count()
        # Register (or replace) a temp view so the result can be read back
        # through the Python SQL context; createOrReplaceTempView avoids an
        # error when this method is called more than once per session
        jdf.createOrReplaceTempView("df")
        return self.sql.table("df")
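

if __name__ == '__main__':
    # A minimal end-to-end sketch, not part of the library API. It assumes a
    # local Spark installation and hypothetical inputs: 'data/sample.vcf' and
    # a CSV 'data/labels.csv' with a 'pheno' label column.
    from pyspark.sql import SparkSession

    spark = (SparkSession.builder
             .appName('variant-spark-demo')
             .config(conf=VarsparkContext.spark_conf())
             .getOrCreate())
    vc = VarsparkContext(spark)
    features = vc.import_vcf('data/sample.vcf')
    labels = vc.load_label('data/labels.csv', 'pheno')
    ia = features.importance_analysis(labels, n_trees=500, seed=13)
    print('OOB error: {}'.format(ia.oob_error()))
    for name, importance in ia.important_variables(limit=10):
        print('{}: {}'.format(name, importance))
    vc.stop()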