Source code for libuplift.utils.multi_array
"""Multiple simultaneously indexed arrays.
Needed to work around lack of sample properties in scikit-learn."""
import numpy as np
from sklearn.utils import check_consistent_length
[docs]
class MultiArray:
    """A main array bundled with companion arrays that are indexed together.

    Indexing a ``MultiArray`` indexes ``main_array`` and every array in
    ``array_dict`` with the same index and returns a new ``MultiArray``
    holding the results; ``scalar_dict`` is handed to the result unchanged.

    Attributes
    ----------
    main_array : the wrapped primary array.
    array_dict : dict mapping names to arrays indexed alongside main_array.
    scalar_dict : dict of values carried over, unindexed, to indexing results.
    shape, ndim, dtype : copied from ``main_array`` at construction time.
    """

    __array_ufunc__ = None  # don't allow numpy operations by default

    def __init__(self, main_array, array_dict=None, scalar_dict=None):
        """Wrap ``main_array`` together with its companion arrays.

        Parameters
        ----------
        main_array : array-like exposing ``shape``, ``ndim`` and ``dtype``.
        array_dict : dict or None
            Arrays to be indexed simultaneously with ``main_array``.
            ``None`` means no companion arrays.
        scalar_dict : dict or None
            Values passed as-is to every indexing result.
            ``None`` means no such values.

        Raises
        ------
        ValueError
            If ``main_array`` and the arrays in ``array_dict`` do not all
            have the same number of samples (first dimension).
        """
        if array_dict is None:
            array_dict = dict()
        if scalar_dict is None:
            scalar_dict = dict()
        # check_consistent_length takes the arrays as separate positional
        # arguments.  Handing it a single list would only measure the length
        # of that list, so the check could never fail.
        check_consistent_length(main_array, *array_dict.values())
        self._assign(main_array, array_dict, scalar_dict)

    def _assign(self, main_array, array_dict, scalar_dict):
        """Store the wrapped objects and mirror main_array's array metadata."""
        self.main_array = main_array
        self.array_dict = array_dict
        self.scalar_dict = scalar_dict
        self.shape = main_array.shape
        self.ndim = main_array.ndim
        self.dtype = main_array.dtype

    def __getitem__(self, idx):
        """Return a new MultiArray with ``idx`` applied to every array.

        ``scalar_dict`` is not indexed; the same dict object is passed on to
        the result.
        """
        new_dict = {k: arr[idx] for k, arr in self.array_dict.items()}
        # The result is assembled without re-running the length check: an
        # integer index legitimately pairs a row of main_array with scalar
        # entries from array_dict, which a length check would reject.
        result = MultiArray.__new__(MultiArray)
        result._assign(self.main_array[idx], new_dict, self.scalar_dict)
        return result