Source code for libuplift.datasets.IST

"""The International Stroke Trial dataset.

This is a randomized clinical trial of heparin and aspirin treatment
for stroke patients.

This dataset is derived from the corrected dataset available here:
https://datashare.ed.ac.uk/handle/10283/128
The webpage contains detailed descriptions.

This version only includes pre-randomization variables, two targets,
and several additional targets related to side effects.

"""

import numpy as np

from .base import _fetch_remote_csv
from .base import RemoteFileMetadata


# Remote location and expected checksum of the compressed IST CSV file.
# The checksum is 64 hex digits (presumably SHA-256 -- confirm against the
# verification done by the download helper in .base).
ARCHIVE = RemoteFileMetadata(
    filename="IST.csv.gz",
    url=(
        "https://github.com/jszymon/uplift_sklearn_data"
        "/releases/download/IST/IST.csv.gz"
    ),
    checksum=(
        "24401e85937748cb0488994c1d8aaf6e"
        "be34a5d12d5446840b4f038f8a8e4de7"
    ),
)


def fetch_IST(include_pilot=True, include_location_vars=True,
              include_prediction_model_vars=True,
              data_home=None, download_if_missing=True,
              random_state=None, shuffle=False,
              categ_as_strings=False, return_X_y=False,
              as_frame=False):
    """Load the International Stroke Trial (IST) dataset.

    Download it if necessary.

    This is a randomized clinical trial of heparin and aspirin treatment
    for stroke patients.

    This dataset is derived from the corrected dataset available here:
    https://datashare.ed.ac.uk/handle/10283/128
    The webpage contains detailed descriptions.

    This version only includes pre-randomization variables, two main
    targets, and several additional targets related to side effects.

    The two main targets are:

    - target_ID14 - death after 14 days
    - target_OCCODE - outcome after 6 months.  The original study used
      ("dead" or "dependent") as the outcome of interest.

    Additionally there are 9 targets describing side effects at 14 days:
    target_H14, target_ISC14, target_NK14, target_STRK14, target_HTI14,
    target_PE14, target_DVT14, target_TRAN14, target_NCB14

    The two treatment variables are treatment_asp (RXASP) and
    treatment_hep (RXHEP, levels N/L/M).

    **Variables**

    See https://datashare.ed.ac.uk/handle/10283/128

    **Changes to the original dataset**

    - Only pretreatment variables, variables describing outcomes at
      14 days and the 6 month outcome code are included
    - Change all N/Y variables to 0/1
    - Level H of RXHEP recoded as M for pilot study cases
    - Add var IS_PILOT indicating pilot study, obtained by testing if
      RHEP24 is NaN.  The variable is only added if include_pilot is
      True.
    - RDATE variable has been split into RYEAR and RMONTH, month names
      have been translated to English
    - Recoded OCCODE to descriptive values, merged two "missing status"
      categories into "NA"

    Parameters
    ----------
    include_pilot : boolean, default=True
        Whether to include records from a pilot study with 984
        patients.  Some values (RATRIAL and RASP3) are missing in the
        pilot.  When False, the first 984 records are masked out and
        the IS_PILOT variable is removed.

    include_location_vars : boolean, default=True
        Should variables describing hospitals and their locations be
        included.  These are categorical variables with a large number
        of levels.  The variables are: HOSPNUM, COUNTRY

    include_prediction_model_vars : boolean, default=True
        Whether to include the variables EXPDD, EXPD6 and EXPD14.
        When False they are removed from the returned data.

    data_home : string, optional
        Specify another download and cache folder for the datasets.
        NOTE: this function currently does not forward the value to
        the loader, so it has no effect here.

    download_if_missing : boolean, default=True
        If False, raise an IOError if the data is not locally available
        instead of trying to download the data from the source site.

    random_state : int, RandomState instance or None (default)
        Determines random number generation for dataset shuffling.
        Pass an int for reproducible output across multiple function
        calls.

    shuffle : bool, default=False
        Whether to shuffle dataset.

    categ_as_strings : bool, default=False
        Whether to return categorical variables as strings.

    return_X_y : boolean, default=False.
        If True, the tuple form produced by ``_fetch_remote_csv`` is
        returned instead of a dict-like object.

    as_frame : boolean, default=False
        If True features are returned as pandas DataFrame.  If False
        features are returned as object or float array.  Float array
        is returned if all features are floats.

    Returns
    -------
    dataset : dict-like object
        The object built by ``_fetch_remote_csv`` from the feature,
        treatment and target descriptions defined in this function.
        When ``return_X_y`` is False the module docstring is attached
        to it as ``dataset.descr``.

    tuple
        Returned instead of ``dataset`` if ``return_X_y`` is True; see
        ``_fetch_remote_csv`` for the exact layout.
    """

    # dictionaries (allowed levels of the categorical variables)
    treatment_heparin_values = ["N", "L", "M"]
    hospnum_values = [
        "1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
        "11", "12", "13", "14", "15", "16", "17", "18", "19", "20",
        "21", "24", "25", "26", "27", "28", "29", "30", "31", "32",
        "33", "34", "35", "36", "37", "38", "39", "40", "41", "46",
        "47", "48", "49", "50", "51", "52", "53", "54", "55", "56",
        "57", "58", "59", "60", "61", "62", "63", "64", "65", "66",
        "67", "68", "69", "70", "71", "72", "73", "74", "75", "76",
        "77", "78", "79", "80", "81", "82", "83", "84", "85", "86",
        "87", "88", "89", "90", "93", "95", "96", "97", "98", "101",
        "102", "105", "106", "107", "108", "109", "110", "111", "113",
        "114", "115", "116", "117", "118", "119", "121", "122", "123",
        "124", "125", "126", "127", "129", "130", "131", "132", "133",
        "134", "135", "136", "137", "139", "141", "142", "143", "144",
        "147", "149", "152", "153", "154", "155", "156", "158", "159",
        "161", "162", "163", "164", "165", "166", "169", "170", "171",
        "172", "173", "174", "175", "176", "177", "178", "179", "180",
        "181", "183", "184", "186", "187", "188", "189", "190", "191",
        "193", "194", "195", "196", "197", "198", "200", "201", "202",
        "203", "204", "206", "207", "208", "209", "210", "211", "212",
        "213", "214", "215", "217", "218", "219", "220", "222", "223",
        "224", "225", "226", "227", "228", "229", "230", "231", "232",
        "233", "234", "236", "237", "238", "239", "240", "241", "242",
        "243", "244", "245", "246", "247", "248", "249", "250", "251",
        "252", "253", "255", "256", "257", "258", "259", "260", "262",
        "264", "265", "267", "268", "271", "274", "278", "279", "281",
        "283", "285", "286", "289", "290", "291", "292", "293", "294",
        "295", "296", "297", "299", "300", "301", "302", "303", "304",
        "305", "306", "307", "308", "309", "310", "311", "312", "313",
        "314", "317", "319", "320", "322", "323", "324", "326", "327",
        "328", "330", "331", "332", "333", "334", "336", "337", "339",
        "341", "342", "343", "344", "345", "346", "348", "349", "350",
        "351", "352", "353", "354", "355", "359", "360", "361", "362",
        "363", "364", "365", "366", "367", "368", "369", "371", "372",
        "373", "374", "375", "376", "377", "378", "380", "381", "382",
        "383", "384", "387", "388", "390", "391", "392", "394", "395",
        "396", "399", "400", "402", "403", "404", "405", "406", "407",
        "408", "409", "410", "411", "412", "413", "414", "415", "416",
        "417", "418", "419", "420", "421", "422", "423", "424", "425",
        "428", "429", "430", "431", "433", "434", "435", "436", "437",
        "438", "439", "440", "441", "443", "445", "447", "449", "452",
        "453", "454", "455", "456", "457", "458", "461", "462", "463",
        "464", "465", "467", "468", "469", "470", "471", "472", "473",
        "474", "476", "477", "478", "479", "480", "481", "482", "483",
        "484", "485", "486", "487", "488", "491", "492", "495", "496",
        "497", "498", "499", "500", "501", "502", "503", "504", "505",
        "506", "507", "508", "510", "511", "512", "513", "514", "515",
        "516", "518", "519", "520", "521", "522", "523", "524", "527",
        "528", "529", "531", "532", "533", "534", "535", "536", "538",
        "539", "540", "541", "542", "543", "545", "546", "547", "548",
        "549", "550", "551", "553", "554", "557", "558", "559", "560",
        "561", "562", "563", "564", "565", "567", "568",
    ]
    country_values = [
        'UK', 'ITAL', 'PORT', 'EIRE', 'BELG', 'FINL', 'AUSL', 'CZEC',
        'USA', 'HUNG', 'NETH', 'NEW', 'SWIT', 'AUST', 'SLOV', 'SPAI',
        'NORW', 'SWED', 'CHIL', 'GREE', 'POLA', 'TURK', 'SOUT', 'ISRA',
        'SLOK', 'CANA', 'HONG', 'BRAS', 'INDI', 'ARGE', 'DENM', 'FRAN',
        'SRI', 'ROMA', 'JAPA', 'SING',
    ]
    rconsc_values = ["F", "D", "U"]
    sex_values = ["F", "M"]
    ync_values = ["N", "Y", "C"]  # C=Can't assess
    stype_values = ["PACS", "LACS", "TACS", "POCS", "OTH"]
    day_values = ["Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"]
    occode_values = ["dead", "dependent", "not_recovered", "recovered",
                     "NA"]

    # attribute descriptions
    # Treatment/target entries are (name, type-or-levels, third string);
    # the third element looks like the original IST column name --
    # confirm against _fetch_remote_csv.
    treatment_descr = [
        ("treatment_asp", np.int32, "RXASP"),
        ("treatment_hep", treatment_heparin_values, "RXHEP"),
    ]
    target_descr = [
        ("target_ID14", np.int32, "ID14"),
        ("target_OCCODE", occode_values, "OCCODE"),
        ("target_H14", np.int32, "H14"),
        ("target_ISC14", np.int32, "ISC14"),
        ("target_NK14", np.int32, "NK14"),
        ("target_STRK14", np.int32, "STRK14"),
        ("target_HTI14", np.int32, "HTI14"),
        ("target_PE14", np.int32, "PE14"),
        ("target_DVT14", np.int32, "DVT14"),
        ("target_TRAN14", np.int32, "TRAN14"),
        ("target_NCB14", np.int32, "NCB14"),
    ]
    # Feature entries are (name, type-or-levels).
    feature_descr = [
        ("IS_PILOT", np.int32),
        ("HOSPNUM", hospnum_values),
        ("COUNTRY", country_values),
        ("RDELAY", np.int32),
        ("RCONSC", rconsc_values),
        ("SEX", sex_values),
        ("AGE", np.int32),
        ("RSLEEP", np.int32),
        ("RATRIAL", float),
        ("RCT", np.int32),
        ("RVISINF", np.int32),
        ("RHEP24", float),
        ("RASP3", float),
        ("RSBP", np.int32),
        ("RDEF1", ync_values),
        ("RDEF2", ync_values),
        ("RDEF3", ync_values),
        ("RDEF4", ync_values),
        ("RDEF5", ync_values),
        ("RDEF6", ync_values),
        ("RDEF7", ync_values),
        ("RDEF8", ync_values),
        ("STYPE", stype_values),
        ("RYEAR", np.int32),
        ("RMONTH", np.int32),
        ("HOURLOCAL", np.int32),
        ("MINLOCAL", np.int32),
        ("DAYLOCAL", day_values),
        ("EXPDD", float),
        ("EXPD6", float),
        ("EXPD14", float),
    ]

    arch = ARCHIVE
    dataset_name = "IST"

    # Total number of records in the file and how many leading records
    # are treated as the pilot study when building the mask below.
    n_records = 19435
    n_pilot_records = 984

    remove_vars = []
    if not include_location_vars:
        remove_vars += ["HOSPNUM", "COUNTRY"]
    if not include_prediction_model_vars:
        remove_vars += ["EXPDD", "EXPD6", "EXPD14"]
    if not include_pilot:
        # IS_PILOT would be constant once the pilot rows are dropped.
        remove_vars.append("IS_PILOT")
        # Use the builtin ``bool``: the ``np.bool`` alias was deprecated
        # in NumPy 1.20 and removed in 1.24, where it raises
        # AttributeError.
        record_mask = np.ones(n_records, dtype=bool)
        record_mask[:n_pilot_records] = False
    else:
        record_mask = None
    if not remove_vars:
        # The loader receives None rather than an empty list.
        remove_vars = None

    # NOTE(review): ``data_home`` is not forwarded here; confirm whether
    # _fetch_remote_csv accepts it before wiring it through.
    ret = _fetch_remote_csv(arch, dataset_name,
                            feature_attrs=feature_descr,
                            treatment_attrs=treatment_descr,
                            target_attrs=target_descr,
                            categ_as_strings=categ_as_strings,
                            return_X_y=return_X_y, as_frame=as_frame,
                            download_if_missing=download_if_missing,
                            random_state=random_state, shuffle=shuffle,
                            # 31 features + 2 treatments + 11 targets
                            total_attrs=44,
                            all_num=False,
                            remove_vars=remove_vars,
                            record_mask=record_mask
                            )
    if not return_X_y:
        # ``__doc__`` here is the module docstring, not this function's.
        ret.descr = __doc__
    return ret