{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports and setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os, sys\n", "import re\n", "from six.moves.cPickle import load\n", "import argparse\n", "from time import time, strftime\n", "from inspect import getmembers, isfunction\n", "from imp import load_source\n", "import numpy as np\n", "from pandas import DataFrame, Series\n", "import json\n", "from subprocess import check_output\n", "import pickle\n", "import pandas as pd\n", "# Note: We import this *before* any import of TF to avoid weird issues\n", "# see- https://github.com/tensorflow/models/issues/523\n", "# Is a bit of a hack but seems harmless as long as SpaCy is installed\n", "import spacy\n", "\n", "### Note: No Snorkel imports until after the $SNORKELDB env var is set!\n", "\n", "expts = list(filter(lambda p : not re.match(r'.*\\.pyc?$', p),\n", " os.listdir('experiments/')))\n", "\n", "from utils_mderoche import ArgumentEmulator\n", "\n", "\n", "args = ArgumentEmulator(verbose=True, exp=\"fda\", disc_model_search_space=1, start_at=7,db_name = \"_snorkel_fda_temp\")\n", "DEV_SPLIT = 1\n", "TEST_SPLIT = 2 \n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "$SNORKELDB = sqlite:///_snorkel_fda_temp.db\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/matthieu/Documents/Snorkel/snorkel/.snorkelenv.gpu/local/lib/python2.7/site-packages/tensorflow/contrib/tensorboard/plugins/trace/trace.py:22: ImportWarning: Not importing directory '/home/matthieu/Documents/Snorkel/snorkel/parser': missing __init__.py\n", " import parser\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "{'gen-init-params': {'lf_propensity': True, 'seed': 123}, 'deps-thresh': 0.01, 'disc-params-range': {'dim': [50, 100], 'dropout': [0.1, 0.25, 0.5], 'lr': [0.0002, 1e-05, 0.0001, 0.0005, 5e-05], 'l2_penalty': [0.01, 0.001, 0.0001], 'rebalance': [0.1, 0.2, 0.05, 0.4], 'l1_penalty': [0.001, 0.0001, 1e-05]}, 'featurizer-init-params': {}, 'disc-eval-batch-size': 32, 'disc-params-default': {'dropout': 0.5, 'batch_size': 64, 'n_epochs': 750, 'beta': 0.9, 'max_sentence_length': 100, 'dim': 50, 'print_freq': 50, 'lr': 0.0002, 'l2_penalty': 0.0001, 'rebalance': 0.1, 'l1_penalty': 0.0002, 'allchecks': False}, 'gen-params-default': {'epochs': 50, 'reg_param': 0.1, 'decay': 0.95}, 'disc-model-class': , 'disc-init-params': {'seed': 123, 'n_threads': 4}, 'featurizer-class': , 'gen-params-range': {'step_size': [0.01, 0.001, 0.0001, 1e-05], 'reg_param': [0.0, 0.01, 0.1, 0.5], 'LF_acc_prior_weight_default': [0.5, 1.0, 1.5]}}\n" ] } ], "source": [ "\n", "if args.verbose:\n", " print(args)\n", "\n", "# Get the DB connection string and add to globals\n", "DB_NAME = \"snorkel_\" + args.exp if args.db_name is None else args.db_name\n", "if not args.postgres:\n", " DB_NAME += \".db\"\n", "DB_TYPE = \"postgres\" if args.postgres else \"sqlite\"\n", "DB_ADDR = \"localhost:{0}\".format(args.db_port) if args.db_port else \"\"\n", "os.environ['SNORKELDB'] = '{0}://{1}/{2}'.format(DB_TYPE, DB_ADDR, DB_NAME)\n", "print(\"$SNORKELDB = {0}\".format(os.environ['SNORKELDB']))\n", "\n", "# All Snorkel imports here, after SNORKELDB has been set\n", "from snorkel.annotations import (\n", " LabelAnnotator, load_label_matrix, load_gold_labels, load_marginals\n", ")\n", "from snorkel.learning.structure import DependencySelector\n", "from snorkel.learning import GenerativeModel, RandomSearch\n", "from snorkel.models.meta import SnorkelBase, snorkel_engine\n", "from snorkel.models import Document, Sentence\n", "from snorkel import SnorkelSession\n", "\n", "from utils import *\n", "\n", "############################################################################\n", "### [0] Start from clean slate: Clear the DB\n", "############################################################################\n", "if args.start_at == 0:\n", " if args.verbose > 0:\n", " print(\"Reseting DB...\")\n", " if args.postgres:\n", " raise NotImplementedError(\"TODO: DB clearing for Postgres.\")\n", " else:\n", " try:\n", " _ = check_output(['rm', DB_NAME])\n", " except:\n", " pass\n", " SnorkelBase.metadata.create_all(snorkel_engine)\n", "# Start Snorkel sess\n", "sess = SnorkelSession()\n", "\n", "# Only use parallelism > 1 with UDFs if using Postgres\n", "UDF_THREADS = args.n_threads if args.postgres else 1\n", "\n", "# Import from parsers.py\n", "preprocess = load_source('preprocess',\n", " os.path.join('experiments', args.exp, 'preprocess.py'))\n", "\n", "# Get candidate subclass from loaded module\n", "C = preprocess.C\n", "\n", "# Load the config dictionary\n", "# Load global first, then override any entries with local config file\n", "from config import config\n", "local_config_path = os.path.join('experiments', args.exp, 'config.py')\n", "if os.path.exists(local_config_path):\n", " local_config = load_source('local_config', local_config_path)\n", " config = recursive_merge_dicts(config, local_config.config)\n", "if args.verbose > 0:\n", " print(config)\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "GEN_MODEL_NAME = 'G_final_{0}'.format(args.exp)\n", "L_train, L_dev, L_test = None, None, None\n", "Y_dev, Y_test = None, None\n", "gen_model = None\n", "if args.start_at <= 5:\n", "\n", " # Load L_train if starting here\n", " if L_train is None or L_dev is None or L_test is None:\n", " with PrintTimer(\"[5.0] Loading label matrices...\"):\n", " # Optionally subsample the training set here\n", " if args.training_docs > 0:\n", " cids_query = get_training_cids_query(\n", " sess,\n", " preprocess.CONTEXT_HIERARCHY,\n", " C,\n", " preprocess.CANDIDATE_CONTEXT,\n", " args.training_docs,\n", " training_docs_shuffle=args.training_docs_shuffle,\n", " verbose=args.verbose,\n", " training_splits=args.training_splits\n", " )\n", " L_train = load_label_matrix(sess, cids_query=cids_query)\n", " else:\n", " L_train = load_label_matrix(sess, split=0)\n", " L_train = load_label_matrix(sess, split=0)\n", " assert L_train.nnz > 0\n", " L_dev = load_label_matrix(sess, split=DEV_SPLIT)\n", " assert L_dev.nnz > 0\n", " L_test = load_label_matrix(sess, split=TEST_SPLIT)\n", " assert L_test.nnz > 0\n", " if args.verbose > 0:\n", " print(\"Using L_train: {0}\".format(L_train.__repr__()))\n", " print(\"Using L_dev: {0}\".format(L_dev.__repr__()))\n", " print(\"Using L_test: {0}\".format(L_test.__repr__()))\n", "\n", " # Select dependencies to model\n", " with PrintTimer(\"[5.1] Selecting dependencies...\"):\n", " if args.deps:\n", " ds = DependencySelector()\n", " np.random.seed(args.rand_seed)\n", " deps = ds.select(L_train, threshold=config['deps-thresh'])\n", " if args.verbose > 0:\n", " print(\"Selected {0} dependencies.\".format(len(deps)))\n", " else:\n", " deps = ()\n", " if args.verbose > 0:\n", " print(\"Skipping.\")\n", "\n", "# Run grid search to select best generative model\n", " with PrintTimer(\"[5.2] Searching over & training generative models\"):\n", " # Load dev and test labels\n", " if Y_dev is None:\n", " Y_dev = load_gold_labels(sess, annotator_name='gold', split=DEV_SPLIT)\n", " assert Y_dev.nonzero()[0].shape[0] > 0\n", " if Y_test is None:\n", " Y_test = load_gold_labels(sess, annotator_name='gold', split=TEST_SPLIT)\n", " assert Y_test.nonzero()[0].shape[0] > 0\n", "\n", " # Pass in the dependencies via default params\n", " gen_params_default = config['gen-params-default']\n", " gen_params_default['deps'] = deps\n", "\n", " # Train generative model with grid search if applicable\n", " gen_model = train_model(\n", " GenerativeModel,\n", " L_train,\n", " X_dev=L_dev,\n", " Y_dev=Y_dev,\n", " search_size=2,#args.gen_model_search_space,\n", " search_params=config['gen-params-range'],\n", " rand_seed=args.rand_seed,\n", " n_threads=args.n_threads,\n", " verbose=(args.verbose > 0),\n", " params_default=gen_params_default,\n", " model_init_params=config['gen-init-params'],\n", " model_name=GEN_MODEL_NAME,\n", " save_dir=args.save_dir,\n", " beta=args.gen_f_beta\n", " )\n", "\n", " # Save training marginals\n", " gen_model.save_marginals(sess, L_train, training=True)\n", "\n", " # Score generative model on test set\n", " print(\"\\n### Gen. model (DP) score on test set:\")\n", " _ = gen_model.error_analysis(sess, L_test, Y_test, display=True)\n", "\n", " if args.one_only:\n", " sys.exit(0)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We select a model to load: " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "DISC_MODEL_NAME = \"SparseLogisticRegression_fullchk_pr3__epoch_439\"" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "############################################################################\n", "### [6] Fit discriminative model\n", "############################################################################\n", "\n", "# DISC_MODEL_NAME = 'D_final_{0}'.format(args.exp)\n", "\n", "X_train, X_dev, Y_train, disc_model = None, None, None, None\n", "F_train,F_dev,F_test,featurizer = None,None,None,None\n", "if args.start_at <= 6:\n", " # Load data: candidates, dev labels, training marginals\n", " with PrintTimer(\"[6.0] Loading data\"):\n", " # Optionally subsample the training set here\n", " if args.training_docs > 0:\n", " cids_query = get_training_cids_query(\n", " sess,\n", " preprocess.CONTEXT_HIERARCHY,\n", " C,\n", " preprocess.CANDIDATE_CONTEXT,\n", " args.training_docs,\n", " training_docs_shuffle=args.training_docs_shuffle,\n", " verbose=args.verbose,\n", " training_splits=args.training_splits\n", " )\n", " X_train = sess.query(C)\\\n", " .filter(C.id == cids_query.subquery().c.id)\\\n", " .order_by(C.id)\\\n", " .all()\n", " Y_train = load_marginals(sess, cids_query=cids_query)\n", " else:\n", " X_train = sess.query(C).filter(C.split==0).order_by(C.id).all()\n", " Y_train = load_marginals(sess, split=0)\n", " assert len(X_train) > 0\n", " assert Y_train.nonzero()[0].shape[0] > 0\n", "\n", " X_dev = sess.query(C).filter(C.split == DEV_SPLIT).order_by(C.id).all()\n", " assert len(X_dev) > 0\n", " if Y_dev is None:\n", " Y_dev = load_gold_labels(sess,annotator_name='gold',split=DEV_SPLIT)\n", " assert Y_dev.nonzero()[0].shape[0] > 0\n", " if args.verbose > 0:\n", " print(\"Loaded X_train: {0}\".format(len(X_train)))\n", " print(\"Loaded Y_train: {0}\".format(Y_train.shape))\n", " print(\"Loaded X_dev: {0}\".format(len(X_dev)))\n", " print(\"Loaded Y_dev: {0}\".format(Y_dev.shape))\n", "\n", " if not config['disc-model-class'].representation:\n", " featurizer = config['featurizer-class'](**config.get('featurizer-init-params',dict()))\n", " with PrintTimer(\"[6.0.1] Computing Features\"):\n", " if args.recompute_feats:\n", " F_train = featurizer.apply(split=0)\n", " F_dev = featurizer.apply_existing(split=DEV_SPLIT)\n", " else:\n", " F_train = featurizer.load_matrix(sess,split=0)\n", " F_dev = featurizer.load_matrix(sess,split=DEV_SPLIT)\n", "\n", " else:\n", " F_train = X_train\n", " F_dev = X_dev\n", " # Run grid search to select best generative model\n", " with PrintTimer(\"[6.1] Searching over & training end disc. models\"):\n", " disc_model = train_model(\n", " config['disc-model-class'],\n", " F_train,\n", " Y_train=Y_train,\n", " X_dev=F_dev,\n", " Y_dev=Y_dev,\n", " cardinality=C.cardinality,\n", " search_size=args.disc_model_search_space,\n", " search_params=config['disc-params-range'],\n", " rand_seed=args.rand_seed,\n", " n_threads=args.n_threads,\n", " verbose=(args.verbose > 0),\n", " params_default=config['disc-params-default'],\n", " model_init_params=config['disc-init-params'],\n", " model_name=DISC_MODEL_NAME,\n", " save_dir=args.save_dir,\n", " eval_batch_size=config['disc-eval-batch-size']\n", " )\n", "\n", " if args.one_only:\n", " sys.exit(0)\n", "\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "### [7.0] Loading all data for final evaluation\n", "INFO:tensorflow:Restoring parameters from checkpoints/SparseLogisticRegression_fullchk_pr3__epoch_439/SparseLogisticRegression_fullchk_pr3__epoch_439-439\n", "[SparseLogisticRegression] Loaded model \n", "[GenerativeModel] Model loaded.\n", "Loaded X_test: 1232\n", "Loaded Y_test: (1232, 1)\n", "Loaded F_test: representation is False\n", "### Done in 13.6s.\n", "\n", "### [7.2] Evaluating heuristic baselines\n", "### Done in 0.0s.\n", "\n", "### [7.3] Evaluating generative model\n", "### Done in 0.3s.\n", "\n", "### [7.4] Evaluate full DP pipeline (disc. model)\n", "### Done in 0.1s.\n", "\n", " Coverage F1 Score Precision Recall\n", "CS 1.000000 0.079501 0.041396 1.000000\n", "MV 0.727273 0.338462 0.785714 0.215686\n", "Gen 0.731331 0.338462 0.785714 0.215686\n", "DP 1.000000 0.552632 0.840000 0.411765\n", "fda & 4.1 & 100.0 & 8.0 & 78.6 & 21.6 & 33.8 & 78.6 & 21.6 & 33.8 & - & - & - & 84.0 & 41.2 & 55.3 & - & - & -\\\\\n" ] } ], "source": [ " if args.start_at <= 7:\n", " with PrintTimer(\"[7.0] Loading all data for final evaluation\"):\n", " # TODO: Reload models if needed\n", " if disc_model is None:\n", " disc_model = config['disc-model-class'](\n", " cardinality=C.cardinality,\n", " **config['disc-init-params'])\n", " disc_model.load(model_name=DISC_MODEL_NAME,\n", " save_dir=args.save_dir)\n", "\n", " if gen_model is None:\n", " gen_model = GenerativeModel(**config['gen-init-params'])\n", " gen_model.load(model_name=GEN_MODEL_NAME, save_dir=args.save_dir)\n", " if not disc_model.representation and (featurizer is None):\n", " featurizer = config['featurizer-class'](**config.get('featurizer-init-params', dict()))\n", "\n", "\n", " # TODO: Handle logistic regression as well!\n", " X_test = sess.query(C).filter(C.split == TEST_SPLIT).order_by(C.id).all()\n", " if Y_test is None:\n", " Y_test = load_gold_labels(sess, annotator_name='gold', split=TEST_SPLIT)\n", " assert Y_test.nonzero()[0].shape[0] > 0\n", " if L_test is None:\n", " L_test = load_label_matrix(sess, split=TEST_SPLIT)\n", " assert L_test.nnz > 0\n", "\n", " if F_test is None:\n", " if not disc_model.representation:\n", " if args.recompute_feats:\n", " F_test = featurizer.apply_existing(split=TEST_SPLIT)\n", " F_test = featurizer.load_matrix(sess,split=TEST_SPLIT)\n", " else:\n", " F_test = X_test\n", "\n", " if args.verbose > 0:\n", " print(\"Loaded X_test: {0}\".format(len(X_test)))\n", " print(\"Loaded Y_test: {0}\".format(Y_test.shape))\n", " print(\"Loaded F_test: representation is {}\".format(disc_model.representation))\n", "\n", " if args.custom_error_analysis:\n", " with PrintTimer(\"[7.1] Custom error analysis exportation\"):\n", " if X_dev is None:\n", " X_dev = sess.query(C).filter(C.split == DEV_SPLIT).order_by(C.id).all()\n", " if L_dev is None:\n", " L_dev = load_label_matrix(sess, split=DEV_SPLIT)\n", " assert L_test.nnz > 0\n", " if Y_dev is None:\n", " Y_dev = load_gold_labels(sess, annotator_name='gold', split=DEV_SPLIT)\n", " if F_dev is None:\n", " if not disc_model.representation:\n", " F_dev = featurizer.load_matrix(sess,split=DEV_SPLIT)\n", " else:\n", " F_dev = X_dev\n", "\n", "\n", " custom_report_dir = os.path.join(args.reports_dir, strftime(\"%Y_%m_%d\"))\n", " custom_report_name = '{0}_{1}_custom.pkl'.format(args.exp, strftime(\"%H_%M_%S\"))\n", " if not os.path.exists(custom_report_dir):\n", " os.makedirs(custom_report_dir)\n", " print \"Generative Model\"\n", " custom_report = dict()\n", " tp, fp, tn, fn = gen_model.error_analysis(sess, L_dev, Y_dev)\n", " for key,cand_list in zip(['tp', 'fp', 'tn', 'fn'],[tp, fp, tn, fn]):\n", " custom_report[key] = [cand.id for cand in cand_list]\n", "\n", " print \"Discriminative Model\"\n", " tpd, fpd, tnd, fnd = disc_model.error_analysis(sess, F_dev, Y_dev)\n", " for key,cand_list in zip(['tpd', 'fpd', 'tnd', 'fnd'],[tpd, fpd, tnd, fnd]):\n", " custom_report[key] = [cand.id for cand in cand_list]\n", "\n", " with open(os.path.join(custom_report_dir, custom_report_name), 'wb') as f:\n", " pickle.dump(custom_report, f)\n", "\n", "\n", " if args.export_pred:\n", " if X_train is None:\n", " X_train = sess.query(C).filter(C.split == 0).order_by(C.id).all()\n", "\n", " if F_train is None:\n", " if not disc_model.representation:\n", " if args.recompute_feats:\n", " F_train = featurizer.apply_existing(split=0)\n", " F_train = featurizer.load_matrix(sess,split=0)\n", " else:\n", " F_train = X_train\n", "\n", " if X_dev is None:\n", " X_dev = sess.query(C).filter(C.split == DEV_SPLIT).order_by(C.id).all()\n", "\n", " if F_dev is None:\n", " if not disc_model.representation:\n", " if args.recompute_feats:\n", " F_dev = featurizer.apply_existing(split=DEV_SPLIT)\n", " F_dev = featurizer.load_matrix(sess, split=DEV_SPLIT)\n", " else:\n", " F_dev = X_dev\n", "\n", " pred_train = disc_model.predictions(F_train,batch_size = int(len(X_train)/400)+1)\n", " pred_dev = disc_model.predictions(F_dev, batch_size=int(len(X_dev) / 400) + 1)\n", " pred_test = disc_model.predictions(F_test)\n", "\n", " list_dic_res= list()\n", " for cand,pr in zip(X_train+X_dev+X_test,list(pred_train)+list(pred_dev)+list(pred_test)):\n", " dico_info = dict()\n", " for i,c in enumerate(cand.get_contexts()):\n", " dico_info[\"cand_\"+str(i)+\"_start\"] = c.char_start\n", " dico_info[\"cand_\" + str(i) + \"_end\"] = c.char_end\n", "\n", " dico_info[\"sentence_pos\"] = cand.get_parent().position\n", " dico_info[\"doc_id\"] = cand.get_parent().document.name\n", " dico_info[\"prediction\"] = pr\n", " list_dic_res.append(dico_info)\n", " custom_report_dir = os.path.join(args.reports_dir, strftime(\"%Y_%m_%d\"))\n", " custom_report_name = '{0}_{1}_preds.csv'.format(args.exp, strftime(\"%H_%M_%S\"))\n", " pd.DataFrame(list_dic_res).to_csv(os.path.join(custom_report_dir, custom_report_name),sep = \";\")\n", "\n", "\n", " scores = {}\n", " with PrintTimer(\"[7.2] Evaluating heuristic baselines\"):\n", " # Test candidate set score - applicable for binary case only\n", " if C.cardinality == 2:\n", " cs_test = np.ones(Y_test.shape[0])\n", " scores['CS'] = score_marginals(cs_test, Y_test)\n", " else:\n", " if args.verbose > 0:\n", " print(\"Candidate-set not applicable for categorical tasks.\")\n", "\n", " # Test majority vote of LFs on test set\n", " mv_test = majority_vote_marginals(L_test, cardinality=C.cardinality)\n", " scores['MV'] = score_marginals(mv_test, Y_test)\n", "\n", " with PrintTimer(\"[7.3] Evaluating generative model\"):\n", " # Score generative model on test set\n", " # TODO: Make sure this is the same as scuba.utils score function!!!\n", " np.random.seed(args.rand_seed)\n", " scores['Gen'] = score_marginals(gen_model.marginals(L_test), Y_test)\n", "\n", " with PrintTimer(\"[7.4] Evaluate full DP pipeline (disc. model)\"):\n", " # Score discriminative model trained on generative model predictions\n", " # TODO: Make sure this is the same as scuba.utils score function!!!\n", " np.random.seed(args.rand_seed)\n", " scores['DP'] = score_marginals(disc_model.marginals(F_test,\n", " batch_size=config['disc-eval-batch-size']), Y_test)\n", "\n", " if args.ds_tests:\n", " with PrintTimer(\"[7.4] Evaluating distant supervision baseline\"):\n", " # Score discriminative model trained on LF majority vote (hard)\n", " # Load data\n", " if L_train is None:\n", " # Optionally subsample the training set here\n", " if args.training_docs > 0:\n", " cids_query = get_training_cids_query(\n", " sess,\n", " preprocess.CONTEXT_HIERARCHY,\n", " C,\n", " preprocess.CANDIDATE_CONTEXT,\n", " args.training_docs,\n", " training_docs_shuffle=args.training_docs_shuffle,\n", " verbose=args.verbose,\n", " training_splits=args.training_splits\n", " )\n", " L_train = load_label_matrix(sess, cids_query=cids_query)\n", " else:\n", " L_train = load_label_matrix(sess, split=0)\n", " assert L_train.nnz > 0\n", " if X_train is None:\n", " # Optionally subsample the training set here\n", " if args.training_docs > 0:\n", " cids_query = get_training_cids_query(\n", " sess,\n", " preprocess.CONTEXT_HIERARCHY,\n", " C,\n", " preprocess.CANDIDATE_CONTEXT,\n", " args.training_docs,\n", " training_docs_shuffle=args.training_docs_shuffle,\n", " verbose=args.verbose,\n", " training_splits=args.training_splits\n", " )\n", " X_train = sess.query(C)\\\n", " .filter(C.id == cids_query.subquery().c.id)\\\n", " .order_by(C.id)\\\n", " .all()\n", " else:\n", " X_train = sess.query(C).filter(C.split == 0)\\\n", " .order_by(C.id).all()\n", " assert len(X_train) > 0\n", " if X_dev is None:\n", " X_dev = sess.query(C).filter(C.split == DEV_SPLIT)\\\n", " .order_by(C.id).all()\n", " assert len(X_dev) > 0\n", " if Y_dev is None:\n", " Y_dev = load_gold_labels(sess,annotator_name='gold',split=DEV_SPLIT)\n", " assert Y_dev.nonzero()[0].shape[0] > 0\n", "\n", " # Compute soft ([0,1]) majority vote training marginals\n", " Y_train_mv = majority_vote_marginals(L_train,\n", " cardinality=C.cardinality)\n", "\n", " # Train discriminative model with MV training labels\n", " disc_model = train_model(\n", " config['disc-model-class'],\n", " X_train,\n", " Y_train=Y_train_mv,\n", " X_dev=X_dev,\n", " Y_dev=Y_dev,\n", " cardinality=C.cardinality,\n", " search_size=args.disc_model_search_space,\n", " search_params=config['disc-params-range'],\n", " rand_seed=args.rand_seed,\n", " n_threads=args.n_threads,\n", " verbose=(args.verbose > 0),\n", " params_default=config['disc-params-default'],\n", " model_init_params=config['disc-init-params'],\n", " model_name=DISC_MODEL_NAME + \"_ds\",\n", " save_dir=args.save_dir,\n", " eval_batch_size=config['disc-eval-batch-size']\n", " )\n", " np.random.seed(args.rand_seed)\n", " scores['DS-MV'] = score_marginals(disc_model.marginals(X_test,\n", " batch_size=config['disc-eval-batch-size']), Y_test)\n", "\n", " if args.supervised_tests:\n", " with PrintTimer(\"[7.5] Evaluating fully-supervised baseline\"):\n", " # Load data\n", " if X_train is None:\n", " # Optionally subsample the training set here\n", " if args.training_docs > 0:\n", " cids_query = get_training_cids_query(\n", " sess,\n", " preprocess.CONTEXT_HIERARCHY,\n", " C,\n", " preprocess.CANDIDATE_CONTEXT,\n", " args.training_docs,\n", " training_docs_shuffle=args.training_docs_shuffle,\n", " verbose=args.verbose,\n", " training_splits=args.training_splits\n", " )\n", " X_train = sess.query(C)\\\n", " .filter(C.id == cids_query.subquery().c.id)\\\n", " .order_by(C.id)\\\n", " .all()\n", " else:\n", " X_train = sess.query(C).filter(C.split == 0)\\\n", " .order_by(C.id).all()\n", " assert len(X_train) > 0\n", " if X_dev is None:\n", " X_dev = sess.query(C).filter(C.split == DEV_SPLIT)\\\n", " .order_by(C.id).all()\n", " assert len(X_dev) > 0\n", " if Y_dev is None:\n", " Y_dev = load_gold_labels(sess,annotator_name='gold',split=DEV_SPLIT)\n", " assert Y_dev.nonzero()[0].shape[0] > 0\n", "\n", " # Load ground-truth training set labels\n", " # Note we load in {0,1} not {-1,1} format for binary\n", " # Optionally subsample the training set here\n", " if args.training_docs > 0:\n", " cids_query = get_training_cids_query(\n", " sess,\n", " preprocess.CONTEXT_HIERARCHY,\n", " C,\n", " preprocess.CANDIDATE_CONTEXT,\n", " args.training_docs,\n", " training_docs_shuffle=args.training_docs_shuffle,\n", " verbose=args.verbose,\n", " training_splits=args.training_splits\n", " )\n", " Y_train_gt = load_gold_labels(sess, cids_query=cids_query,\n", " zero_one=(C.cardinality == 2), load_as_array=True,\n", " annotator_name='gold')\n", " else:\n", " Y_train_gt = load_gold_labels(sess, split=0,\n", " zero_one=(C.cardinality == 2), load_as_array=True,\n", " annotator_name='gold')\n", " assert Y_train_gt.nonzero()[0].shape[0] > 0\n", "\n", " # If categorical, convert to one-hot\n", " if C.cardinality > 2:\n", " Y_train_gt = labels_to_one_hot(Y_train_gt, C.cardinality)\n", "\n", " # Train discriminative model and score\n", " disc_model = train_model(\n", " config['disc-model-class'],\n", " X_train,\n", " Y_train=Y_train_gt,\n", " X_dev=X_dev,\n", " Y_dev=Y_dev,\n", " cardinality=C.cardinality,\n", " search_size=args.disc_model_search_space,\n", " search_params=config['disc-params-range'],\n", " rand_seed=args.rand_seed,\n", " n_threads=args.n_threads,\n", " verbose=(args.verbose > 0),\n", " params_default=config['disc-params-default'],\n", " model_init_params=config['disc-init-params'],\n", " model_name=DISC_MODEL_NAME + \"_supervised\",\n", " save_dir=args.save_dir,\n", " eval_batch_size=config['disc-eval-batch-size']\n", " )\n", " np.random.seed(args.rand_seed)\n", " scores['Sup'] = score_marginals(disc_model.marginals(X_test,\n", " batch_size=config['disc-eval-batch-size']), Y_test)\n", "\n", " # Print and save final score report\n", " ks = list(scores.keys())\n", " if C.cardinality > 2:\n", " cols = ['Accuracy', 'Coverage']\n", " d = {\n", " 'Accuracy': Series(data=[scores[k][0] for k in ks], index=ks),\n", " 'Coverage': Series(data=[scores[k][1] for k in ks], index=ks),\n", " }\n", " else:\n", " cols = ['Precision', 'Recall', 'F1 Score']\n", " d = {\n", " 'Precision' : Series(data=[scores[k][0] for k in ks], index=ks),\n", " 'Recall' : Series(data=[scores[k][1] for k in ks], index=ks),\n", " 'F1 Score' : Series(data=[scores[k][2] for k in ks], index=ks),\n", " 'Coverage' : Series(data=[scores[k][3] for k in ks], index=ks)\n", " }\n", " df = DataFrame(data=d, index=ks)\n", " print(df)\n", "\n", " # Assemble the report, to be saved as a json file\n", " df_scores = df.to_dict()\n", " row_str = print_latex_table_row(args.exp, scores,\n", " cardinality=C.cardinality)\n", " if args.verbose > 0:\n", " print(row_str)\n", " report = {\n", " 'scores': df_scores,\n", " 'row-string': row_str,\n", " 'scuba-commit': git_commit_hash(),\n", " 'snorkel-commit': git_commit_hash(path=os.environ['SNORKELHOME']),\n", " 'args': vars(args)\n", " }\n", "\n", " # Save to file\n", " report_dir = os.path.join(args.reports_dir, strftime(\"%Y_%m_%d\"))\n", " report_name = '{0}_{1}.json'.format(args.exp, strftime(\"%H_%M_%S\"))\n", " if not os.path.exists(report_dir):\n", " os.makedirs(report_dir)\n", " with open(os.path.join(report_dir, report_name), 'wb') as f:\n", " json.dump(report, f, indent=2)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.14" } }, "nbformat": 4, "nbformat_minor": 2 }