from pymongo import MongoClient
import pymongo.errors
import gridfs
import sys
import traceback
import os
import itertools
import time
import chipathlon.conf
from pprint import pprint
import hashlib


class MongoDB(object):

    def __init__(self, host, username, password, debug=False):
        """
        :param host: The host address of the MongoDB database.
        :type host: str
        :param username: The username of the account for the MongoDB database.
        :type username: str
        :param password: The password for the user.
        :type password: str
        :param debug: If True, print out debug messages.
        :type debug: bool
        """
        self.debug = debug
        self.host = host
        self.username = username
        self.password = password
        self.client = MongoClient(host)
        self.db = self.client.chipseq
        self.cache = {}
        try:
            self.db.authenticate(username, password, mechanism="SCRAM-SHA-1")
        except Exception:
            print("Could not authenticate to db %s!" % (host,))
            print(traceback.format_exc())
            sys.exit(1)
        self.gfs = gridfs.GridFS(self.db)
        return
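
    # A minimal connection sketch (host and credentials here are hypothetical,
    # not part of this module). Construction both connects and authenticates,
    # and calls sys.exit(1) on failure, so no separate login step is needed:
    #
    #   db = MongoDB("mongodb.example.com", "reader", "secret", debug=True)
    #   print(db.db.name)  # -> "chipseq"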

    def add_cache(self, function, key, data):
        """
        :param function: The function name
        :type function: str
        :param key: The key to add the cache entry under.
        :type key: Any hashable
        :param data: The data to add to the cache.
        :type data: Object
        """
        if function not in self.cache:
            self.cache[function] = {}
        self.cache[function][key] = data
        return

    def get_cache(self, function, key):
        """
        :param function: The function name
        :type function: str
        :param key: The key to get from the cache.
        :type key: Any hashable

        Returns the cached data if it exists, otherwise None.
        """
        if function in self.cache:
            if key in self.cache[function]:
                return self.cache[function][key]
        return None
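
    # A small sketch of the two cache helpers (values hypothetical). The cache
    # is keyed first by function name, then by any hashable key -- a tuple here:
    #
    #   db.add_cache("get_sample", ("ENCSR000ABC", "fastq"), {"accession": "ENCSR000ABC"})
    #   db.get_cache("get_sample", ("ENCSR000ABC", "fastq"))  # -> the dict above
    #   db.get_cache("get_sample", ("missing", "bam"))        # -> None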

    def delete_result(self, result, genome):
        """
        :param result: The result to delete.
        :type result: :py:class:`~chipathlon.result.Result`
        :param genome: The genome to find information from.
        :type genome: :py:class:`~chipathlon.genome.Genome`

        Deletes a result and its corresponding gridfs entry.
        """
        result_id = self.get_result_id(result, genome)
        cursor = self.db.results.find({
            "_id": result_id
        })
        if cursor.count() == 1:
            result = cursor.next()
            self.gfs.delete(result["gridfs_id"])
            self.db.results.delete_one({"_id": result["_id"]})
        else:
            print("result_id %s doesn't exist." % (result_id,))
        return

    def _get_result_query(self, result, genome):
        query = {
            "result_type": result.file_type,
            "assembly": genome.assembly,
            "timestamp": {"$exists": True},
            "file_name": result.full_name
        }
        # In the case that there are 0 samples we just want to check for existence.
        control_sample_accessions = result.get_accessions("control")
        signal_sample_accessions = result.get_accessions("signal")
        query["control_sample_accessions"] = {"$all": control_sample_accessions} if (len(control_sample_accessions) > 0) else {"$exists": True}
        query["signal_sample_accessions"] = {"$all": signal_sample_accessions} if (len(signal_sample_accessions) > 0) else {"$exists": True}
        for job in result.all_jobs:
            job_args = job.get_db_arguments()
            arg_keys = job_args.keys()
            if len(arg_keys) == 0:
                query[job.job_name] = {"$exists": True}
            else:
                for arg_name in arg_keys:
                    query[job.job_name + "." + arg_name] = job_args[arg_name]
        if self.debug:
            print("Result query: %s" % (query,))
        return query
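
    # A sketch of the kind of query document this builds (field values are
    # hypothetical; the job-argument keys depend on the jobs attached to the
    # result, and argument-less jobs collapse to existence checks):
    #
    #   {
    #       "result_type": "bed",
    #       "assembly": "hg19",
    #       "timestamp": {"$exists": True},
    #       "file_name": "full_name_of_result",
    #       "control_sample_accessions": {"$all": ["ENCSR000XYZ"]},
    #       "signal_sample_accessions": {"$all": ["ENCSR000ABC"]},
    #       "peak_call.qvalue": 0.05
    #   }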

    def result_exists(self, result, genome):
        """
        :param result: The result to check.
        :type result: :py:class:`~chipathlon.result.Result`
        :param genome: The genome to find information from.
        :type genome: :py:class:`~chipathlon.genome.Genome`

        Check if a result exists.
        """
        try:
            cursor = self.db.results.find(self._get_result_query(result, genome))
            return cursor.count() > 0
        except pymongo.errors.OperationFailure as e:
            print "Error checking result [%s]: %s" % (file_name, e)
        return False

    def get_result_id(self, result, genome):
        """
        :param result: The result to check.
        :type result: :py:meth:~chipathlon.result.Result
        :param genome: The genome to find information from.
        :type genome: :py:meth:~chipathlon.genome.Genome

        Get the id of a result.
        """
        try:
            cursor = self.db.results.find(self._get_result_query(result, genome))
            if cursor.count() == 1:
                return cursor.next()["_id"]
        except pymongo.errors.OperationFailure as e:
            print("Error getting result id [%s]: %s" % (result.full_name, e))
        return None

    def get_result(self, result, genome):
        """
        :param result: The result to check.
        :type result: :py:meth:~chipathlon.result.Result
        :param genome: The genome to find information from.
        :type genome: :py:meth:~chipathlon.genome.Genome

        Get the metadata for the result from the database.  If multiple results
        exist, the most recently saved result is returned.
        """
        try:
            cursor = self.db.results.find(self._get_result_query(result, genome))
            if cursor.count() > 0:
                return cursor.sort("timestamp", pymongo.DESCENDING).next()
        except pymongo.errors.OperationFailure as e:
            print("Error getting result [%s]: %s" % (result.full_name, e))
        return None
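
    # A hedged sketch tying the result lookups together (my_result / my_genome
    # are assumed to be chipathlon.result.Result / chipathlon.genome.Genome
    # instances; the output file name is hypothetical):
    #
    #   if db.result_exists(my_result, my_genome):
    #       doc = db.get_result(my_result, my_genome)  # newest by timestamp
    #       db.fetch_from_gridfs(doc["gridfs_id"], "local_copy.bed")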

    def save_result(self, output_file, control_sample_accessions, signal_sample_accessions, result_type, additional_data=None, gfs_attributes=None):
        """
        :param output_file: The path to the result to save.
        :type output_file: str
        :param control_sample_accessions: A list of control accessions.
        :type control_sample_accessions: list
        :param signal_sample_accessions: A list of signal accessions.
        :type signal_sample_accessions: list
        :param result_type: The file type of the result.
        :type result_type: str
        :param additional_data: Additional metadata to store in mongo.
        :type additional_data: dict
        :param gfs_attributes: Additional metadata to store in gridfs.
        :type gfs_attributes: dict

        Saves a result file into mongodb and also creates the corresponding
        gridfs file.
        """
        # Guard against mutable default arguments being shared across calls.
        if additional_data is None:
            additional_data = {}
        if gfs_attributes is None:
            gfs_attributes = {}
        # Make sure output_file exists
        if os.path.isfile(output_file):
            # Make sure that all control_sample_accessions & signal_sample_accessions are valid.
            # REMEMBER, these are accessions for control & signal SAMPLES.
            valid_controls = [self.is_valid_sample(cid) for cid in control_sample_accessions]
            valid_signals = [self.is_valid_sample(sid) for sid in signal_sample_accessions]
            if all(valid_controls) and all(valid_signals):
                gfs_attributes["file_type"] = result_type
                # First, we load the output file into gfs
                with open(output_file, "rb") as rh:
                    # Calling put returns the gfs id
                    gridfs_id = self.gfs.put(rh, filename=os.path.basename(output_file), **gfs_attributes)
                # Now, we create the actual result entry by combining all necessary info
                result_entry = {
                    "gridfs_id": gridfs_id,
                    "control_sample_accessions": control_sample_accessions,
                    "signal_sample_accessions": signal_sample_accessions,
                    "result_type": result_type,
                    "file_name": output_file,
                    "timestamp": time.time()
                }
                # Add additional attributes into the result_entry
                result_entry.update(additional_data)
                # Insert the entry into the database, and return the id
                result = self.db.results.insert_one(result_entry)
                return (True, "Result created successfully.", result.inserted_id)
            else:
                msg = "Not all input ids are valid.  The following are invalid: "
                for id_list, valid_list in zip([control_sample_accessions, signal_sample_accessions], [valid_controls, valid_signals]):
                    msg += ", ".join([id_list[i] for i, valid in enumerate(valid_list) if not valid])
        else:
            msg = "Specified output_file %s does not exist." % (output_file,)
        return (False, msg, None)
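
    # Minimal save_result sketch (path, accessions, and metadata hypothetical).
    # The return contract is (success, message, inserted_id_or_None):
    #
    #   ok, msg, result_id = db.save_result(
    #       "/tmp/peaks.bed",
    #       ["ENCSR000XYZ"],
    #       ["ENCSR000ABC"],
    #       "bed",
    #       additional_data={"peak_call": {"qvalue": 0.05}}
    #   )
    #   if not ok:
    #       print(msg)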

    def is_valid_sample(self, sample_accession):
        """
        :param sample_accession: The accession number to check.
        :type sample_accession: str

        Ensures that a sample with the accession specified actually exists.
        """
        try:
            cursor = self.db.samples.find({
                "accession": sample_accession
            })
            if cursor.count() == 1:
                return True
        except pymongo.errors.OperationFailure as e:
            print "Error with sample_accession %s: %s" % (sample_accession, e)
        return False

    def is_valid_experiment(self, experiment_accession):
        """
        :param experiment_accession: The accession number to check.
        :type experiment_accession: str

        Ensures that an experiment with the accession specified actually exists.
        """
        try:
            cursor = self.db.experiments.find({
                "target": {"$exists": True},
                "@id": "/experiments/%s/" % (experiment_accession,)
            })
            if cursor.count() == 1:
                return True
        except pymongo.errors.OperationFailure as e:
            print("Error with experiment_accession %s: %s" % (experiment_accession, e))
        return False

    def fetch_from_gridfs(self, gridfs_id, filename, checkmd5=True):
        """
        :param gridfs_id: GridFS _id of file to get.
        :type gridfs_id: bson.objectid.ObjectId
        :param filename: Filename to save file to.
        :type filename: str
        :param checkmd5: Whether or not to validate the md5 of the result
        :type checkmd5: bool

        Fetch the file with the corresponding id and save it under the
        specified 'filename'.  If checkmd5 is True, validate that the saved
        file has a correct md5 value.
        """
        try:
            gridfs_file = self.gfs.get(gridfs_id)
            gridfs_md5 = gridfs_file.md5
        except gridfs.errors.NoFile as e:
            print("Error fetching file from GridFS!\nNo file with ID '%s'" % (gridfs_id,))
            print(e)
            sys.exit(1)

        try:
            output_fh = open(filename, 'wb')
        except IOError as e:
            print("Error creating GridFS output file '%s':" % (filename,))
            print("%s: %s" % (e.errno, e.strerror))
            sys.exit(1)

        # Stream the file to disk chunk by chunk, computing the md5 as we go.
        hash_md5 = hashlib.md5()
        for chunk in gridfs_file:
            output_fh.write(chunk)
            hash_md5.update(chunk)

        output_fh.close()
        gridfs_file.close()

        if checkmd5:
            if gridfs_md5 == hash_md5.hexdigest():
                return True
            else:
                print("MD5 mismatch saving file from GridFS to '%s'" % (filename,))
                return False
        else:
            return True
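
    # Usage sketch (the gridfs id would come from a result document; note the
    # method exits the process if the id is missing or the file can't be made):
    #
    #   ok = db.fetch_from_gridfs(doc["gridfs_id"], "peaks.bed", checkmd5=True)
    #   if not ok:
    #       print("Downloaded file failed md5 validation.")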

    def get_sample(self, accession, file_type):
        """
        :param accession: The accession number of the target sample
        :type accession: string
        :param file_type: The file type of the target sample should be [fastq|bam]
        :type file_type: string

        Gets the associated sample based on accession number and file_type
        """
        valid = True
        msg = ""
        data = {}
        # Cache on (accession, file_type) so different file types don't collide.
        check_cache = self.get_cache("get_sample", (accession, file_type))
        if check_cache is not None:
            msg = "Retrieved data from cache."
            data = check_cache
        else:
            cursor = self.db.samples.find({
                "accession": accession,
                "file_type": file_type
            })
            if cursor.count() == 1:
                data = cursor.next()
                self.add_cache("get_sample", (accession, file_type), data)
            else:
                valid = False
                msg = "Found %s files with accession: %s, file_type: %s. Should only be 1." % (
                    cursor.count(),
                    accession,
                    file_type
                )
        return (valid, msg, data)
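
    # Sketch of the (valid, msg, data) contract (accession hypothetical). A
    # repeated call with the same arguments is served from the in-memory cache:
    #
    #   valid, msg, sample = db.get_sample("ENCSR000ABC", "fastq")
    #   valid, msg, sample = db.get_sample("ENCSR000ABC", "fastq")  # cached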

    def get_samples(self, experiment_accession, file_type):
        """
        :param experiment_accession: Accession number of the experiment to grab samples from.
        :type experiment_accession: str
        :param file_type: File type of samples to grab, usually fastq or bam.
        :type file_type: str

        Validates and gets samples for the given experiment.  Experiments must
        have control and signal samples of the provided file_type to be
        considered valid.  Returns a tuple with three values: (valid, msg, data)
        valid -- Whether or not the accession / file_type combo is a valid exp
        msg -- Why it is or is not valid
        data -- A dictionary containing lists of all control / signal sample documents.

        The data dictionary has two keys, "control" and "signal", each one containing
        a list of all metadata related to the experiment samples.  The sample metadata
        is taken directly from Mongo.
        """
        valid = True
        msg = ""
        data = {}
        # First, check to make sure the target experiment is valid
        if self.is_valid_experiment(experiment_accession):
            # Next, we check that there is at least 1 possible control
            control_check = self.db.experiments.find({
                "target": {"$exists": True},
                "possible_controls.0": {"$exists": True},
                "@id": "/experiments/%s/" % (experiment_accession,)
            })
            if control_check.count() == 1:
                # Complicated aggregation pipeline does the following steps:
                # 1. Find the experiment that matches the given id
                # 2. Join samples into the collection by exp_id
                # 3. Iterate through possible_controls
                # 4. Join possible_control data into control_exps
                # 5. Iterate through control_exps
                # 6. Join samples into the control_exps by exp_id
                # 7. Re-aggregate all data into arrays
                pipeline = [
                    {
                        "$match": {
                            "target": {"$exists": True},
                            "possible_controls.0": {"$exists": True},
                            "@id": "/experiments/%s/" % (experiment_accession,)
                        }
                    },
                    {
                        "$lookup": {
                            "from": "samples",
                            "localField": "uuid",
                            "foreignField": "experiment_id",
                            "as": "samples"
                        }
                    },
                    {
                        "$unwind": "$possible_controls"
                    },
                    {
                        "$lookup": {
                            "from": "samples",
                            "localField": "possible_controls.uuid",
                            "foreignField": "experiment_id",
                            "as": "possible_controls.samples"
                        }
                    },
                    {
                        "$group": {
                            "_id": "$_id",
                            "possible_controls": {"$push": "$possible_controls"},
                            "samples": {"$push": "$samples"}
                        }
                    }
                ]
                cursor = self.db.experiments.aggregate(pipeline)
                # We should have only 1 document
                document = cursor.next()
                control_inputs = [sample for control in document["possible_controls"] for sample in control["samples"] if ("file_type" in sample and sample["file_type"] == file_type)]
                signal_inputs = [sample for sample in document["samples"][0] if ("file_type" in sample and sample["file_type"] == file_type)]
                if (len(control_inputs) > 0 and len(signal_inputs) > 0):
                    msg = "Successfully retrieved input files for experiment with id '%s'.\n" % (experiment_accession,)
                    data = {
                        "control": control_inputs,
                        "signal": signal_inputs
                    }
                else:
                    valid = False
                    msg = "Experiment with id '%s' has %s possible control inputs, and %s possible signal inputs.\n" % (experiment_accession, len(control_inputs), len(signal_inputs))
            else:
                valid = False
                msg = "Experiment with id '%s' does not have possible_controls.\n" % (experiment_accession,)
        else:
            valid = False
            msg = "Experiment with id '%s' is not valid!  It may not exist, or it may be missing required metadata.\n" % (experiment_accession,)
        return (valid, msg, data)
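
    # End-to-end sketch of get_samples (experiment accession hypothetical):
    #
    #   valid, msg, data = db.get_samples("ENCSR000EUA", "fastq")
    #   if valid:
    #       for sample in data["control"] + data["signal"]:
    #           print(sample["accession"])
    #   else:
    #       print(msg)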