collections.defaultdict

Here are examples of the Python API collections.defaultdict, taken from open source projects.

200 Examples
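
Before working through the examples, here is a minimal sketch of the behaviour they all rely on (the data below is made up for illustration): collections.defaultdict is a dict subclass that calls a zero-argument default_factory to create the value for a missing key, which removes the explicit key-existence checks normally needed when grouping or counting.

import collections

# Group words by first letter: defaultdict(list) creates an empty list
# the first time a key is seen, so .append() always works.
words = ["apple", "avocado", "banana", "cherry"]   # illustrative data
by_letter = collections.defaultdict(list)
for word in words:
    by_letter[word[0]].append(word)
# {'a': ['apple', 'avocado'], 'b': ['banana'], 'c': ['cherry']}

# Count occurrences: defaultdict(int) starts every missing key at 0.
counts = collections.defaultdict(int)
for word in words:
    counts[word[0]] += 1
# {'a': 2, 'b': 1, 'c': 1}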

Example 1

Project: AGOUTI
Source File: agouti_denoise.py
def denoise_joining_pairs(dContigPairs, dGFFs, vertex2Name,
						  outDir, prefix, minSupport,
						  debug=0):

	moduleName = os.path.basename(__file__).split('.')[0].upper()
	moduleOutDir = os.path.join(outDir, "agouti_denoise")
	if not os.path.exists(moduleOutDir):
		os.makedirs(moduleOutDir)

	progressLogFile = os.path.join(moduleOutDir, "%s.agouti_denoise.progressMeter" %(prefix))
	agDENOISEProgress = agLOG.PROGRESS_METER(moduleName)
	agDENOISEProgress.add_file_handler(progressLogFile)

	debugLogFile = ""
	if debug:
		debugLogFile = os.path.join(moduleOutDir, "%s.agouti_denoise.debug" %(prefix))
		global agDENOISEDebug
		agDENOISEDebug = agLOG.DEBUG(moduleName, debugLogFile)

	agDENOISEProgress.logger.info("[BEGIN] Denoising joining pairs")
	startTime = time.clock()
	dCtgPair2GenePair = collections.defaultdict()
	dCtgPairDenoise = collections.defaultdict()
	dMappedPos = collections.defaultdict()
	daddedModels = collections.defaultdict(list)
	nFail4Combination = 0
	nFailGeneModel = 0
	nFailK = 0
	outDenoiseJPFile = os.path.join(moduleOutDir, "%s.agouti.join_pairs.noise_free.txt" %(prefix))
	fOUT = open(outDenoiseJPFile, 'w')
	for ctgPair, pairInfo in dContigPairs.items():
		if len(pairInfo) < minSupport:
			nFailK += 1
			del dContigPairs[ctgPair]
			continue
		ctgA = ctgPair[0]
		ctgB = ctgPair[1]
		if debug:
			agDENOISEDebug.debugger.debug("DENOISE_MAIN\t>contigA - %s - contigB - %s"
										  %(ctgA, ctgB))
		pairToRemove = []
		mapIntervalsA = []
		mapIntervalsB = []
		pairs = []
		senses = []
		keep = 0
		for i in xrange(len(pairInfo)):
			startA, startB, stopA, stopB, senseA, senseB, readID = pairInfo[i]
			mapIntervalsA += [(startA, stopA)]
			mapIntervalsB += [(startB, stopB)]
			pairs += [(startA, stopA, startB, stopB)]
			senses += [(senseA, senseB)]
		genePair = get_genePair_for_contigPair(dGFFs, ctgA, ctgB,
											   mapIntervalsA,
											   mapIntervalsB, senses,
											   debug)
		geneModelsA = dGFFs[ctgA]
		geneModelsB = dGFFs[ctgB]
		if genePair is None:
			nFailGeneModel += 1
			if debug:
				agDENOISEDebug.debugger.debug("DENOISE_MAIN\tFail to find a pair of gene models")
				agDENOISEDebug.debugger.debug("DENOISE_MAIN\t----------------------------------")
		else:
			geneIndexA, geneIndexB, endA, endB, intervalsA, intervalsB, senses = genePair
			sensesCounter = collections.Counter(senses)
			if debug:
				agDENOISEDebug.debugger.debug("DENOISE_MAIN\tsensesCounter: %s" %(str(sensesCounter)))
			if geneIndexB != 0:
				# create gene model according to endB using intervalsB
				if geneIndexB == -1 and (endB == 5 or endB == 0):
					dGFFs[ctgB] = create_fake_genes(geneModelsB, 0, ctgB, intervalsB, debug)
					geneIndexB = 0
					endB = 5
				elif geneIndexB == 1 and (endB == 3 or endB == 0):
					dGFFs[ctgB] = create_fake_genes(geneModelsB, len(geneModelsB), ctgB, intervalsB, debug)
					geneIndexB = len(dGFFs[ctgB]) - 1
					endB = 3
			else:
				if endB == 0:
					endB = 5
				elif endB == 3:
					geneIndexB = len(dGFFs[ctgB])-1
			if geneIndexA != 0:
				# create gene model according to endA using intervalsA
				if geneIndexA == -1 and (endA == 5 or endA == 0):
					dGFFs[ctgA] = create_fake_genes(geneModelsA, 0, ctgA, intervalsA, debug)
					geneIndexA = 0
					endA = 5
				elif geneIndexA == 1 and (endA == 3 or endA == 0):
					dGFFs[ctgA] = create_fake_genes(geneModelsA, len(geneModelsA), ctgA, intervalsA, debug)
					geneIndexA = len(dGFFs[ctgA]) - 1
					endA = 3
			else:
				if endA == 0:
					endA = 3
				elif endA == 3:
					geneIndexA = len(dGFFs[ctgA])-1
			if debug:
				agDENOISEDebug.debugger.debug("DENOISE_MAIN\tgenePair: %s" %(str(genePair)))
				agDENOISEDebug.debugger.debug("DENOISE_MAIN\t# models on ctgA - %d - # models on ctgB - %d"
											  %(len(dGFFs[ctgA]), len(dGFFs[ctgB])))
				agDENOISEDebug.debugger.debug("DENOISE_MAIN\tgeneIndexA - %d - endA - %d - geneIndexB - %d - endB - %d"
											  %(geneIndexA, endA, geneIndexB, endB))
			sense = sorted(sensesCounter.items(), key=operator.itemgetter(1), reverse=True)[0][0]
			if debug:
				agDENOISEDebug.debugger.debug("DENOISE_MAIN\tsensePair - %s" %(str(sense)))
			if (geneIndexA == len(dGFFs[ctgA])-1 and endA == 3) and \
			   (geneIndexB == 0 and endB == 5) and sense == ('+', '-'):
					# FR + 3'-5'
					keep = 1
			elif (geneIndexA == 0 and endA == 5) and \
				 (geneIndexB == 0 and endB == 5) and sense == ('-', '-'):
					# RR + 5'-5'
					keep = 1
			elif (geneIndexA == len(dGFFs[ctgA])-1 and endA == 3) and \
				 (geneIndexB == len(dGFFs[ctgB])-1 and endB == 3) and \
				 sense == ('+', '+'):
					# FF + 3'-3'
					keep = 1
			elif (geneIndexA == 0 and endA == 5) and \
				 (geneIndexB == len(dGFFs[ctgB])-1 and endB == 3) and \
				 sense == ('-', '+'):
					# RF + 5'-3'
					keep = 1
			elif (geneIndexA == 0 and (endA == 0 or endA == 3)) and \
				 (geneIndexB == 0 and (endB == 0 or endB == 5)) and \
				 sense == ('+', '-'):
					# only one gene on the contig
					# it doesn't matter which end
					keep = 1
			if keep:
				geneA = dGFFs[ctgA][geneIndexA]
				geneB = dGFFs[ctgB][geneIndexB]
				dCtgPair2GenePair[vertex2Name.index(ctgA), vertex2Name.index(ctgB)] = [geneA, geneB]
				if debug:
					agDENOISEDebug.debugger.debug("DENOISE_MAIN\tNOISE-FREE")
					agDENOISEDebug.debugger.debug("DENOISE_MAIN\tgeneA ID - %s - startA - %d - stopA = %d"
												  %(geneA.geneID, geneA.geneStart, geneA.geneStop))
					agDENOISEDebug.debugger.debug("DENOISE_MAIN\tgeneB ID - %s - startB - %d - stopB = %d"
												  %(geneB.geneID, geneB.geneStart, geneB.geneStop))
					agDENOISEDebug.debugger.debug("DENOISE_MAIN\t----------------------------------")
				senseA = sense[0]
				senseB = sense[1]
				weight = 0
				for i in xrange(len(pairInfo)):
					startA, startB, stopA, stopB, _, _, readID = pairInfo[i]
					intervalA = (startA, stopA)
					intervalB = (startB, stopB)
					#print "intervalA", intervalA, "intervalB", intervalB
					if len(intervalsA) == 0:
						if len(intervalsB) == 0:
							#print "use all"
							fOUT.write("%s\t%s\t%s\t%s\t%s\t%s\t%s\n" %(readID, ctgA, startA, senseA, ctgB, startB, senseB))
							weight += 1
						else:
							#print "use all A, not all B"
							overlap = find_overlap(intervalB, (geneB.geneStart, geneB.geneStop))
							if overlap == 0:
								fOUT.write("%s\t%s\t%s\t%s\t%s\t%s\t%s\n" %(readID, ctgA, startA, senseA, ctgB, startB, senseB))
								weight += 1
					else:
						if len(intervalsB) == 0:
							#print "use all B, not all A"
							overlap = find_overlap(intervalA, (geneA.geneStart, geneA.geneStop))
							if overlap == 0:
								fOUT.write("%s\t%s\t%s\t%s\t%s\t%s\t%s\n" %(readID, ctgA, startA, senseA, ctgB, startB, senseB))
								weight += 1
						else:
							#print "not all Both"
							overlapA = find_overlap(intervalA, (geneA.geneStart, geneA.geneStop))
							overlapB = find_overlap(intervalB, (geneB.geneStart, geneB.geneStop))
							if overlapA == 0 and overlapB == 0:
								fOUT.write("%s\t%s\t%s\t%s\t%s\t%s\t%s\n" %(readID, ctgA, startA, senseA, ctgB, startB, senseB))
								weight += 1
				dCtgPairDenoise[vertex2Name.index(ctgA), vertex2Name.index(ctgB)] = [weight, (senseA, senseB)]
			else:
				nFail4Combination += 1
#			if len(sensesCounter) == 1:
#				sense = sensesCounter.keys()[0]
#			else:
#				print "multiple sense pairs"
#				senses = sorted(sensesCounter.items(), key=operator.itemgetter(1), reverse=True)[0:2]
#				print "senses", senses
#				ratio = float(senses[0][1])/(senses[0][1]+senses[1][1])
#				print "ratio", ratio
	fOUT.close()
	agDENOISEProgress.logger.info("Succeeded")
	agDENOISEProgress.logger.info("Denoise took in %.2f min CPU time" %((time.clock()-startTime)/60))
	agDENOISEProgress.logger.info("%d contig pairs filtered for spanning across >1 gene models"
								  %(nFailGeneModel))
	agDENOISEProgress.logger.info("%d contig pairs filtered for not being one of the four combinations"
								  %(nFail4Combination))
	agDENOISEProgress.logger.info("%d contig pairs filtered for less support"
								  %(nFailK))
	agDENOISEProgress.logger.info("%d contig pairs for scaffolding"
								  %(len(dCtgPairDenoise)))
	return dCtgPair2GenePair, dCtgPairDenoise
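
A note on Example 1: dCtgPair2GenePair, dCtgPairDenoise and dMappedPos are built with collections.defaultdict() and no default_factory, which behaves like a plain dict and still raises KeyError on a missing key; only daddedModels, created with defaultdict(list), gets the automatic empty-list default. A minimal sketch of that difference (the keys and values below are illustrative, not taken from AGOUTI):

import collections

no_factory = collections.defaultdict()       # default_factory is None
with_list = collections.defaultdict(list)    # missing keys become []

with_list[("ctgA", "ctgB")].append("gene_1")   # [] is created on first access

try:
    no_factory[("ctgA", "ctgB")]               # no factory to call
except KeyError:
    print("defaultdict() without a factory raises KeyError, like a plain dict")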

Example 2

Project: AGOUTI
Source File: agouti_update.py
def agouti_update(agoutiPaths, dSeqs, seqNames,
				  dSenses, dGFFs,
				  dCtgPair2GenePair, outDir, prefix,
				  nFills=1000, debug=0, no_update_gff=0):

	moduleName = os.path.basename(__file__).split('.')[0].upper()
	moduleOutDir = os.path.join(outDir, "agouti_update")
	if not os.path.exists(moduleOutDir):
		os.makedirs(moduleOutDir)

	progressLogFile = os.path.join(moduleOutDir, "%s.agouti_update.progressMeter" %(prefix))
	global agUPDATEProgress
	agUPDATEProgress = agLOG.PROGRESS_METER(moduleName)
	agUPDATEProgress.add_file_handler(progressLogFile)
	if debug:
		debugLogFile = os.path.join(moduleOutDir, "%s.agouti_update.debug" %(prefix))
		global agUPDATEDebug
		agUPDATEDebug = agLOG.DEBUG(moduleName, debugLogFile)

	if not no_update_gff:
		agUPDATEProgress.logger.info("[BEGIN] Updating gene models")

	outFasta = os.path.join(outDir, "%s.agouti.fasta" %(prefix))
	fFASTA = open(outFasta, 'w')
	dUpdateGFFs = collections.defaultdict(list)
	dMergedGene2Ctgs = collections.defaultdict(list)
	dMergedGene2Genes = collections.defaultdict(list)
	numMergedGene = 0
	nCtgScaffolded = 0
	scaffoldedCtgs = {}
	seqLens = []
	dScafGaps = {}
	dScafStats = {}
	scafID = 0
	mergedGenes = []
	for i in range(len(agoutiPaths)):
		path = agoutiPaths[i]
		scafID += 1
		scafName = prefix + "_scaf_%d" %(scafID)
		dScafStats[scafName] = 0
		dScafGaps[scafName] = []

		curVertex = path[0]
		sequence = dSeqs[curVertex]
		curSense = "+"
		curCtg = seqNames[curVertex]
		preCtg = ""
		scafPath = []
		preGeneID, curGeneID = "", ""
		mergedGene = agGFF.AGOUTI_GFF()
		preMergedGene = agGFF.AGOUTI_GFF()
		gapStart, gapStop = 0, 0
		offset = 0
		orientation = ""
		updatedGeneIDs = []
		mergedGenesPerPath = []
		excludeGeneIDs = []
		for nextVertex in path[1:]:
			nextCtg = seqNames[nextVertex]

			if preCtg == "":
				if debug:
					agUPDATEDebug.debugger.debug("UPDATE_MAIN\t>scaf_%d - path - %s"
												 %(scafID,
												  str([seqNames[vertex] for vertex in path])))
			if debug:
				agUPDATEDebug.debugger.debug("UPDATE_MAIN\t\tcurVertex - %d - %s - nextVertex - %d - %s"
											 %(curVertex, curCtg, nextVertex, nextCtg))

			if not no_update_gff:
				#curGene, nextGene = ctgpair2genepair(dCtgPair2GenePair, curCtg, nextCtg)
				curGene, nextGene = ctgpair2genepair(dCtgPair2GenePair, curVertex, nextVertex)
				#!!! I should not break here, should continue#
				if curGene is None and nextGene is None:
					agUPDATEProgress.logger.error("%s - %s found no gene models joining them"
										   %(curCtg, nextCtg))
					agUPDATEProgress.logger.error("This is NOT EXPECTED, REPORT!")
					sys.exit(1)
				curGeneID = curGene.geneID
				excludeGeneIDs = [preGeneID] + [curGeneID]
				if debug:
					agUPDATEDebug.debugger.debug("UPDATE_MAIN\t\tpreGene - %s - curGene - %s - nextGene - %s"
												 %(preGeneID, curGene.geneID, nextGene.geneID))

			if debug:
				agUPDATEDebug.debugger.debug("UPDATE_MAIN\t\tscafName - %s" %(scafName))
			FR, FF, RR, RF = get_orientation_counts(curVertex, nextVertex, dSenses)
			if curSense == "-":
				temp1 = FR
				temp2 = FF
				FR = RR
				FF = RF
				RR = temp1
				RF = temp2
			orientation = decide_orientation(FR, FF, RR, RF)

			gapStart = gapStop + len(dSeqs[curVertex])
			gapStop = gapStart + nFills - 1
			dScafGaps[scafName].append((gapStart+1, gapStop+1))
			if debug:
				agUPDATEDebug.debugger.debug("UPDATE_MAIN\t\tcurSense=%s" %(curSense))
				agUPDATEDebug.debugger.debug("UPDATE_MAIN\t\tFR=%d - FF=%d - RF=%d - RR=%d"
											 %(FR, FF, RF, RR))
				agUPDATEDebug.debugger.debug("UPDATE_MAIN\t\toffset - %d - curCtgLen - %d"
											 %(offset, len(dSeqs[curVertex])))
				agUPDATEDebug.debugger.debug("UPDATE_MAIN\t\tgapstart - %d - gapstop - %d"
											 %(gapStart, gapStop+1))
			valid = 0
			if orientation == "FR":
				if not no_update_gff:
					if curGeneID !=  preGeneID:
						numMergedGene += 1
						mergedGene = merge_gene_model(curGene, nextGene, scafName,
													  numMergedGene, offset, gapStop,
													  debug)
						dMergedGene2Ctgs[mergedGene.geneID] += [curCtg, nextCtg]
						#if curGene.geneStop != 0:
						#	dMergedGene2Genes[mergedGene.geneID] += [curGeneID]
						#if nextGene.geneStop != 0:
						#	dMergedGene2Genes[mergedGene.geneID] += [nextGene.geneID]
						if mergedGene.geneStop != 0:
							dMergedGene2Genes[mergedGene.geneID] += [curGeneID, nextGene.geneID]
						dUpdateGFFs[scafName], updatedGeneIDs = update_gene_model(dGFFs[curCtg], dUpdateGFFs[scafName],
																				  scafName, offset, excludeGeneIDs,
																				  debug, mergedGene)
					else:
						mergedGene = merge_gene_model(preMergedGene, nextGene, scafName,
													  numMergedGene, 0, gapStop, debug)
						dMergedGene2Ctgs[mergedGene.geneID] += [nextCtg]
						if nextGene.geneStop != 0:
							dMergedGene2Genes[mergedGene.geneID] += [nextGene.geneID]
						indexMerged = updatedGeneIDs.index(mergedGene.geneID)
						dUpdateGFFs[scafName][indexMerged] = mergedGene
					preMergedGene = mergedGene
				sequence += 'N'*nFills + dSeqs[nextVertex]
				curSense = "+"
			elif orientation == "FF":
				if not no_update_gff:
					#nextGene = reverse_gene_model(nextGene, len(dSeqs[nextVertex]), debug)
					dGFFs[nextCtg] = reverse_gene_models(dGFFs[nextCtg], len(dSeqs[nextVertex]), debug)
					if curGeneID !=  preGeneID:
						numMergedGene += 1
						mergedGene = merge_gene_model(curGene, nextGene, scafName,
													  numMergedGene, offset, gapStop,
													  debug)
						dMergedGene2Ctgs[mergedGene.geneID] += [curCtg, nextCtg]
						if mergedGene.geneStop != 0:
							dMergedGene2Genes[mergedGene.geneID] += [curGeneID, nextGene.geneID]
						dUpdateGFFs[scafName], updatedGeneIDs = update_gene_model(dGFFs[curCtg], dUpdateGFFs[scafName],
																				  scafName, offset, excludeGeneIDs,
																				  debug, mergedGene)
					else:
						mergedGene = merge_gene_model(preMergedGene, nextGene, scafName,
													  numMergedGene, 0, gapStop, debug)
						dMergedGene2Ctgs[mergedGene.geneID] += [nextCtg]
						dMergedGene2Genes[mergedGene.geneID] += [nextGene.geneID]
						indexMerged = updatedGeneIDs.index(mergedGene.geneID)
						dUpdateGFFs[scafName][indexMerged] = mergedGene
					preMergedGene = mergedGene
				sequence += 'N'*nFills + agSeq.rc_seq(dSeqs[nextVertex])
				curSense = "-"
			elif orientation == "RR":
				if not no_update_gff:
					if curGene.geneID != preGeneID:
						dGFFs[curCtg] = reverse_gene_models(dGFFs[curCtg], len(dSeqs[curVertex]), debug)
						numMergedGene += 1
						mergedGene = merge_gene_model(curGene, nextGene, scafName,
													  numMergedGene, offset, gapStop,
													  debug)
						dMergedGene2Ctgs[mergedGene.geneID] += [curCtg, nextCtg]
						if mergedGene.geneStop != 0:
							dMergedGene2Genes[mergedGene.geneID] += [curGeneID, nextGene.geneID]
						dUpdateGFFs[scafName], updatedGeneIDs = update_gene_model(dGFFs[curCtg], dUpdateGFFs[scafName],
																				  scafName, offset, excludeGeneIDs,
																				  debug, mergedGene)
					else:
						dUpdateGFFs[scafName], updatedGeneIDs = update_gene_model(dGFFs[curCtg], dUpdateGFFs[scafName],
																				  scafName, offset, excludeGeneIDs,
																				  debug)
						dUpdateGFFs[scafName] = reverse_gene_models(dUpdateGFFs[scafName], gapStart-1, debug)
						mergedGene = merge_gene_model(preMergedGene, nextGene, scafName,
													  numMergedGene, 0, gapStop,
													  debug)
						dMergedGene2Ctgs[mergedGene.geneID] += [nextCtg]
						dMergedGene2Genes[mergedGene.geneID] += [nextGene.geneID]
						indexMerged = updatedGeneIDs.index(mergedGene.geneID)
						dUpdateGFFs[scafName][indexMerged] = mergedGene
					preMergedGene = mergedGene
				sequence = agSeq.rc_seq(sequence) + \
						   'N'*nFills + dSeqs[nextVertex]
				curSense = "+"
			elif orientation == "RF":
				if not no_update_gff:
					dGFFs[nextCtg] = reverse_gene_models(dGFFs[nextCtg], len(dSeqs[nextVertex]), debug)
					if curGene.geneID != preGeneID:
						dGFFs[curCtg] = reverse_gene_models(dGFFs[curCtg], len(dSeqs[curVertex]), debug)
						#nextGene = reverse_gene_model(nextGene, len(dSeqs[nextVertex]), debug)
						numMergedGene += 1
						mergedGene = merge_gene_model(curGene, nextGene, scafName,
													  numMergedGene, offset, gapStop,
													  debug)
						dMergedGene2Ctgs[mergedGene.geneID] += [curCtg, nextCtg]
						if mergedGene.geneStop != 0:
							dMergedGene2Genes[mergedGene.geneID] += [curGeneID, nextGene.geneID]
						dUpdateGFFs[scafName], updatedGeneIDs = update_gene_model(dGFFs[curCtg], dUpdateGFFs[scafName],
																				  scafName, offset, excludeGeneIDs,
																				  debug, mergedGene)
					else:
						dUpdateGFFs[scafName], updatedGeneIDs = update_gene_model(dGFFs[curCtg], dUpdateGFFs[scafName],
																				  scafName, offset, excludeGeneIDs,
																				  debug)
						dUpdateGFFs[scafName] = reverse_gene_models(dUpdateGFFs[scafName],
																	gapStop+len(dSeqs[curVertex]),
																	debug)
						mergedGene = merge_gene_model(preMergedGene, nextGene, scafName,
													  numMergedGene, 0, gapStop,
													  debug)
						dMergedGene2Ctgs[mergedGene.geneID] += [nextCtg]
						dMergedGene2Genes[mergedGene.geneID] += [nextGene.geneID]
						indexMerged = updatedGeneIDs.index(mergedGene.geneID)
						dUpdateGFFs[scafName][indexMerged] = mergedGene
					preMergedGene = mergedGene
				sequence = agSeq.rc_seq(sequence) + \
						   'N'*nFills + \
						   agSeq.rc_seq(dSeqs[nextVertex])
				curSense = "-"
			scafPath.append(curCtg)
			if debug:
				agUPDATEDebug.debugger.debug("UPDATE_MAIN\t\tscafPath updates- %s"
											 %(str(scafPath)))
				agUPDATEDebug.debugger.debug("UPDATE_MAIN\t\tdMergedGene2Gene - %s"
											 %(str(dMergedGene2Genes[mergedGene.geneID])))
			if not no_update_gff:
				mergedGenesPerPath.append(mergedGene.geneID)
				preGeneID = nextGene.geneID
			offset = gapStop
			preCtg = curCtg
			curVertex = nextVertex
			curCtg = seqNames[curVertex]

		scafPath.append(curCtg)
		if debug:
			agUPDATEDebug.debugger.debug("UPDATE_MAIN\t\tappend last curCtg - %s" %(curCtg))
		if debug:
			agUPDATEDebug.debugger.debug("UPDATE_MAIN\t\tscafPath - %s"
										 %(str(scafPath)))
		if not no_update_gff:
			excludeGeneIDs = [preGeneID]
			mergedGenes.append(mergedGenesPerPath)
			dUpdateGFFs[scafName], updatedGeneIDs = update_gene_model(dGFFs[curCtg],
																	  dUpdateGFFs[scafName],
																	  scafName, offset,
																	  excludeGeneIDs, debug)
		fFASTA.write(">%s |%dbp |%s\n%s\n"
						%(scafName, len(sequence), ",".join(scafPath), sequence))
		dScafStats[scafName] = len(sequence)
		seqLens.append(len(sequence))
		#agPaths.append(scafPath)
		nCtgScaffolded += len(scafPath)
		scaffoldedCtgs.update(dict((contig, 1) for contig in scafPath))
		if debug:
			agUPDATEDebug.debugger.debug("UPDATE_MAIN\t\tmergedGenesPerPath - %s"
										 %(str(mergedGenesPerPath)))
			agUPDATEDebug.debugger.debug("UPDATE_MAIN\t-------------------------------------")

	# other contigs need to be output
	agUPDATEProgress.logger.info("Finalizing sequences")
	numLeft = 0
	for vertex in dSeqs:
		if seqNames[vertex] in scaffoldedCtgs:
			continue
		numLeft += 1
		fFASTA.write(">%s\n%s\n" % (seqNames[vertex], dSeqs[vertex]))
		dScafStats[seqNames[vertex]] = len(dSeqs[vertex])
		seqLens.append(len(dSeqs[vertex]))
	fFASTA.close()
	n50 = agSeq.get_assembly_NXX(seqLens)

	agUPDATEProgress.logger.info("Outputting updated Gene Moddels")
	for vertex in dSeqs:
		if seqNames[vertex] in scaffoldedCtgs:
			if seqNames[vertex] in dGFFs:
				del dGFFs[seqNames[vertex]]
	if not no_update_gff:
		dFinalGFFs = dict(dGFFs, **dUpdateGFFs)
		numGenes = output_gff(dFinalGFFs, dMergedGene2Ctgs, dMergedGene2Genes,
							  dScafStats, dScafGaps, outDir, prefix)
		agUPDATEProgress.logger.info("Summarizing AGOUTI gene paths")
		summarize_gene_path(dMergedGene2Genes, dMergedGene2Ctgs,
							outDir, prefix)

	agUPDATEProgress.logger.info("-----------Summary-----------")
	agUPDATEProgress.logger.info("number of contigs scaffoled: %d" %(nCtgScaffolded))
	agUPDATEProgress.logger.info("number of scaffolds: %d" %(scafID))
#	agUPDATEProgress.logger.info("number of contigs found no links: %d" %(numLeft))
	agUPDATEProgress.logger.info("number of contigs in the final assembly: %d" %(len(seqLens)))
	agUPDATEProgress.logger.info("Final assembly N50: %d" %(n50))
	if not no_update_gff:
		agUPDATEProgress.logger.info("Final number of genes: %d" %(numGenes))
	agUPDATEProgress.logger.info("Succeeded")

Example 3

Project: nesoni
Source File: clip.py
    def run(self):
        log = self.log
        
        #quality_cutoff, args = grace.get_option_value(args, '--quality', int, 10)
        #qoffset, args = grace.get_option_value(args, '--qoffset', int, None)
        #clip_ambiguous, args = grace.get_option_value(args, '--clip-ambiguous', grace.as_bool, True)
        #length_cutoff, args = grace.get_option_value(args, '--length', int, 24)
        #adaptor_cutoff, args = grace.get_option_value(args, '--match', int, 10)
        #max_error, args = grace.get_option_value(args, '--max-errors', int, 1)
        #adaptor_set, args = grace.get_option_value(args, '--adaptors', str, 'truseq-adapter,truseq-srna,genomic,multiplexing,pe,srna')
        #disallow_homopolymers, args = grace.get_option_value(args, '--homopolymers', grace.as_bool, False)
        #reverse_complement, args = grace.get_option_value(args, '--revcom', grace.as_bool, False)
        #trim_start, args = grace.get_option_value(args, '--trim-start', int, 0)
        #trim_end, args = grace.get_option_value(args, '--trim-end', int, 0)
        #output_fasta, args = grace.get_option_value(args, '--fasta', grace.as_bool, False)
        #use_gzip, args = grace.get_option_value(args, '--gzip', grace.as_bool, True)
        #output_rejects, args = grace.get_option_value(args, '--rejects', grace.as_bool, False)
        #grace.expect_no_further_options(args)
        
        prefix = self.prefix
        log_name = os.path.split(prefix)[1]
        
        quality_cutoff = self.quality
        qoffset = self.qoffset
        clip_ambiguous = self.clip_ambiguous
        length_cutoff = self.length
        adaptor_cutoff = self.match
        max_error = self.max_errors
        disallow_homopolymers = self.homopolymers
        reverse_complement = self.revcom
        trim_start = self.trim_start
        trim_end = self.trim_end
        output_fasta = self.fasta
        use_gzip = self.gzip
        output_rejects = self.rejects
    
        iterators = [ ]        
        filenames = [ ]
        any_paired = False
        
        for filename in self.reads:
            filenames.append(filename)
            iterators.append(itertools.izip(
                 io.read_sequences(filename, qualities=True)
            ))
        
        for pair_filenames in self.pairs:
            assert len(pair_filenames) == 2, 'Expected a pair of files for "pairs" section.'
            filenames.extend(pair_filenames)
            any_paired = True
            iterators.append(itertools.izip(
                io.read_sequences(pair_filenames[0], qualities=True),
                io.read_sequences(pair_filenames[1], qualities=True)
            ))
        
        for filename in self.interleaved:
            filenames.append(filename)
            any_paired = True
            iterators.append(deinterleave(
                io.read_sequences(filename, qualities=True)
            ))
        
        fragment_reads = (2 if any_paired else 1)
        read_in_fragment_names = [ 'read-1', 'read-2' ] if any_paired else [ 'read' ]
        
        assert iterators, 'Nothing to clip'
        
        io.check_name_uniqueness(self.reads, self.pairs, self.interleaved)
    
        if qoffset is None:
            #guesses = [ io.guess_quality_offset(filename) for filename in filenames ]
            #assert len(set(guesses)) == 1, 'Conflicting quality offset guesses, please specify manually.'
            #qoffset = guesses[0]
            qoffset = io.guess_quality_offset(*filenames)
            log.log('FASTQ offset seems to be %d\n' % qoffset)    
    
        quality_cutoff_char = chr(qoffset + quality_cutoff)
        
        #log.log('Minimum quality:        %d (%s)\n' % (quality_cutoff, quality_cutoff_char))
        #log.log('Clip ambiguous bases:   %s\n' % (grace.describe_bool(clip_ambiguous)))
        #log.log('Minimum adaptor match:  %d bases, %d errors\n' % (adaptor_cutoff, max_error))
        #log.log('Minimum length:         %d bases\n' % length_cutoff)
        
        adaptor_seqs = [ ]
        adaptor_names = [ ]
        if self.adaptor_clip:
            if self.adaptor_file:
                adaptor_iter = io.read_sequences(self.adaptor_file)
            else:
                adaptor_iter = ADAPTORS
            for name, seq in adaptor_iter:
                seq = seq.upper().replace('U','T')
                adaptor_seqs.append(seq)
                adaptor_names.append(name)
                adaptor_seqs.append(bio.reverse_complement(seq))
                adaptor_names.append(name)

        matcher = Matcher(adaptor_seqs, adaptor_names, max_error)
        
        start_clips = [ collections.defaultdict(list) for i in xrange(fragment_reads) ]
        end_clips = [ collections.defaultdict(list) for i in xrange(fragment_reads) ]
    
        if output_fasta:
            write_sequence = io.write_fasta_single_line
        else:
            write_sequence = io.write_fastq
    
        f_single = io.open_possibly_compressed_writer(self.reads_output_filenames()[0])
        if fragment_reads == 2:
            names = self.pairs_output_filenames()[0] if self.out_separate else self.interleaved_output_filenames()
            f_paired = map(io.open_possibly_compressed_writer, names)
        if output_rejects:
            f_reject = io.open_possibly_compressed_writer(self.rejects_output_filenames()[0])
        
        n_single = 0
        n_paired = 0
        
        n_in_single = 0
        n_in_paired = 0
        total_in_length = [ 0 ] * fragment_reads
        
        n_out = [ 0 ] * fragment_reads
        n_q_clipped = [ 0 ] * fragment_reads
        n_a_clipped = [ 0 ] * fragment_reads
        n_homopolymers = [ 0 ] * fragment_reads
        total_out_length = [ 0 ] * fragment_reads
        
        #log.attach(open(prefix + '_log.txt', 'wb'))
        
        for iterator in iterators:
          for fragment in iterator:
            if (n_in_single+n_in_paired) % 10000 == 0:
                grace.status('Clipping fragment %s' % grace.pretty_number(n_in_single+n_in_paired))
        
            if len(fragment) == 1:
                n_in_single += 1
            else:
                n_in_paired += 1
            
            graduates = [ ]
            rejects = [ ]
            for i, (name, seq, qual) in enumerate(fragment):
                seq = seq.upper()
                total_in_length[i] += len(seq)
                
                if self.trim_to:
                    seq = seq[:self.trim_to]
                    qual = qual[:self.trim_to]
                
                start = trim_start
                best_start = 0
                best_len = 0
                for j in xrange(len(seq)-trim_end):
                    if qual[j] < quality_cutoff_char or \
                       (clip_ambiguous and seq[j] not in 'ACGT'):
                        if best_len < j-start:
                            best_start = start
                            best_len = j-start
                        start = j + 1
                j = len(seq)-trim_end
                if best_len < j-start:
                    best_start = start
                    best_len = j-start
        
                clipped_seq = seq[best_start:best_start+best_len]
                clipped_qual = qual[best_start:best_start+best_len]
                if len(clipped_seq) < length_cutoff:
                    n_q_clipped[i] += 1
                    rejects.append( (name,seq,qual,'quality') ) 
                    continue
        
                match = matcher.match(clipped_seq)
                if match and match[0] >= adaptor_cutoff:
                    clipped_seq = clipped_seq[match[0]:]
                    clipped_qual = clipped_qual[match[0]:]
                    start_clips[i][match[0]].append( match[1][0] )
                    if len(clipped_seq) < length_cutoff:
                        n_a_clipped[i] += 1 
                        rejects.append( (name,seq,qual,'adaptor') ) 
                        continue
            
                match = matcher.match(bio.reverse_complement(clipped_seq))
                if match and match[0] >= adaptor_cutoff:
                    clipped_seq = clipped_seq[: len(clipped_seq)-match[0] ]    
                    clipped_qual = clipped_qual[: len(clipped_qual)-match[0] ]    
                    end_clips[i][match[0]].append( match[1][0] )
                    if len(clipped_seq) < length_cutoff:
                        n_a_clipped[i] += 1 
                        rejects.append( (name,seq,qual,'adaptor') ) 
                        continue
    
                if disallow_homopolymers and len(set(clipped_seq)) <= 1:
                    n_homopolymers[i] += 1
                    rejects.append( (name,seq,qual,'homopolymer') ) 
                    continue
        
                graduates.append( (name, clipped_seq, clipped_qual) )
                n_out[i] += 1
                total_out_length[i] += len(clipped_seq)
    
            if output_rejects:
                for name,seq,qual,reason in rejects:
                    write_sequence(f_reject, name + ' ' + reason, seq, qual)
             
            if graduates:
                if reverse_complement:
                    graduates = [
                        (name, bio.reverse_complement(seq), qual[::-1])
                        for name, seq, qual in graduates
                    ]
            
                if len(graduates) == 1:
                    n_single += 1

                    (name, seq, qual) = graduates[0]
                    write_sequence(f_single, name, seq, qual)
                else:
                    assert len(graduates) == 2
                    n_paired += 1

                    # Write the pair to an interleaved file or separate l/r files
                    for (lr,(name, seq, qual)) in enumerate(graduates):
                        write_sequence(f_paired[lr%len(f_paired)], name, seq, qual)
                
        
        grace.status('')
        
        if output_rejects:
            f_reject.close()
        if fragment_reads == 2:
            map(lambda f: f.close(), f_paired)
        f_single.close()
        
        def summarize_clips(name, location, clips):
            total = 0
            for i in clips:
                total += len(clips[i])
            log.datum(log_name, name + ' adaptors clipped at ' + location, total) 
            
            if not clips:
                return
    
            for i in xrange(min(clips), max(clips)+1):
                item = clips[i]
                log.quietly_log('%3d bases: %10d ' % (i, len(item)))
                if item:
                    avg_errors = float(sum( item2[0] for item2 in item )) / len(item)
                    log.quietly_log(' avg errors: %5.2f  ' % avg_errors)
                    
                    counts = collections.defaultdict(int)
                    for item2 in item: counts[item2[1]] += 1
                    #print counts
                    for no in sorted(counts,key=lambda item2:counts[item2],reverse=True)[:2]:
                        log.quietly_log('%dx%s ' % (counts[no], matcher.names[no]))
                    if len(counts) > 2: log.quietly_log('...')
                    
                log.quietly_log('\n')
            log.quietly_log('\n')


        if n_in_paired:
            log.datum(log_name,'read-pairs', n_in_paired)                      
        if n_in_single:
            log.datum(log_name,'single reads', n_in_single)                      
        
        for i in xrange(fragment_reads):
            if start_clips:
                summarize_clips(read_in_fragment_names[i], 'start', start_clips[i])
        
            if end_clips:
                summarize_clips(read_in_fragment_names[i], 'end', end_clips[i])

                prefix = read_in_fragment_names[i]
                
            log.datum(log_name, prefix + ' too short after quality clip', n_q_clipped[i])
            log.datum(log_name, prefix + ' too short after adaptor clip', n_a_clipped[i])
            if disallow_homopolymers:
                log.datum(log_name, prefix + ' homopolymers', n_homopolymers[i])
            if fragment_reads > 1:
                log.datum(log_name, prefix + ' kept', n_out[i])
            log.datum(log_name, prefix + ' average input length',  float(total_in_length[i]) / (n_in_single+n_in_paired))                     
            if n_out[i]:
                log.datum(log_name, prefix + ' average output length', float(total_out_length[i]) / n_out[i])                     
        
        if fragment_reads == 2:
            log.datum(log_name,'pairs kept after clipping', n_paired)                      
        log.datum(log_name, 'reads kept after clipping', n_single)
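
Two defaultdict idioms appear in Example 3: a list comprehension builds one independent defaultdict(list) per read in a fragment (start_clips and end_clips), and summarize_clips tallies adaptor hits with defaultdict(int). A compact sketch of both, using made-up clip data:

import collections

fragment_reads = 2

# One defaultdict per read; the comprehension keeps read-1 and read-2
# entries separate instead of sharing a single dict.
start_clips = [collections.defaultdict(list) for _ in range(fragment_reads)]
start_clips[0][10].append((1, 3))   # read-1: clip length 10 -> (errors, adaptor id), illustrative
start_clips[1][12].append((0, 5))   # read-2 keeps its own entries

# Tally which adaptor was matched most often, as summarize_clips does.
counts = collections.defaultdict(int)
for errors, adaptor_id in start_clips[0][10]:
    counts[adaptor_id] += 1
print(sorted(counts, key=lambda k: counts[k], reverse=True)[:2])   # [3]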

Example 4

Project: nesoni
Source File: fill_scaffolds.py
def fill_scaffolds(args):
    max_filler_length, args = grace.get_option_value(args, '--max-filler', int, 4000)
    
    if len(args) < 2:
        print USAGE
        return 1
    
    (output_dir, graph_dir), args = args[:2], args[2:]

    scaffolds = [ ]
    
    def scaffold(args):
        circular, args = grace.get_option_value(args, '--circular', grace.as_bool, False)
        
        scaffold = [ ]
        for item in args:
            scaffold.append( ('contig', int(item)) )
            scaffold.append( ('gap', None) )
        
        if not circular: scaffold = scaffold[:-1]
        
        name = 'custom_scaffold_%d' % (len(scaffolds)+1)
        scaffolds.append( (name, scaffold) )
            
    grace.execute(args, [scaffold])
    
    custom_scaffolds = (len(scaffolds) != 0)    
    
    sequences = dict( 
        (a.split()[0], b.upper()) 
          for a,b in 
            io.read_sequences(os.path.join(
              graph_dir, '454AllContigs.fna')))
    
    sequence_names = sorted(sequences)
    sequence_ids = dict(zip(sequence_names, xrange(1,len(sequence_names)+1)))
    
    contexts = { }
    context_names = { }
    context_depths = { }
    for i in xrange(1,len(sequence_names)+1):
        seq = sequences[sequence_names[i-1]]
        contexts[ i ] = seq
        context_names[ i ] = sequence_names[i-1]+'-fwd'
        contexts[ -i ] = bio.reverse_complement(seq)
        context_names[ -i ] = sequence_names[i-1]+'-rev'
    
    links = collections.defaultdict(list)
    
    for line in open(
      os.path.join(graph_dir, '454ContigGraph.txt'),
      'rU'):
        parts = line.rstrip('\n').split('\t')
        
        if parts[0].isdigit():
            seq = sequence_ids[parts[1]]
            context_depths[ seq] = float(parts[3])
            context_depths[-seq] = float(parts[3])
        
        if parts[0] == 'C':    
            name1 = 'contig%05d' % int(parts[1])
            dir1 = {"3'" : 1, "5'" : -1 }[parts[2]]
            name2 = 'contig%05d' % int(parts[3])
            dir2 = {"5'" : 1, "3'" : -1 }[parts[4]]
            depth = int(parts[5])
            #print name1, dir1, name2, dir2, depth
            
            links[ sequence_ids[name1] * dir1 ].append( (depth, sequence_ids[name2] * dir2) )
            links[ sequence_ids[name2] * -dir2 ].append( (depth, sequence_ids[name1] * -dir1) )
    
        if parts[0] == 'S' and not custom_scaffolds:  
            name = 'scaffold%05d' % int(parts[2])  
            components = parts[3].split(';')
            scaffold = [ ]
            for component in components:
                a,b = component.split(':')
                if a == 'gap':
                    scaffold.append( ('gap',int(b)) )
                else:
                    strand = { '+': +1, '-': -1 }[ b ]
                    scaffold.append( ('contig', sequence_ids['contig%05d'%int(a)] * strand) )
            scaffolds.append( (name, scaffold) )
    
    
    
    #paths = { }
    #
    #todo = [ ]
    #for i in contexts:
    #    for depth_left, neg_left in links[-i]:
    #        left = -neg_left
    #        for depth_right, right in links[i]:
    #            todo.append( ( max(-depth_left,-depth_right,-context_depths[i]), left, right, (i,)) )
    #
    #heapq.heapify(todo)
    #while todo:
    #    score, source, dest, path = heapq.heappop(todo)
    #    if (source,dest) in paths: continue
    #    
    #    paths[(source,dest)] = path
    #    
    #    if len(contexts[dest]) > max_filler_length: continue
    #    
    #    for depth, next in links[dest]:
    #        heapq.heappush(todo,
    #            ( max(score,-depth,-context_depths[dest]), source, next, path+(dest,))
    #        )
    
    
    path_source_dest = collections.defaultdict(dict) # source -> dest -> next
    path_dest_source = collections.defaultdict(dict) # dest -> source -> next
    
    
    # Use links, in order to depth of coverage, to construct paths between contigs
    # Thus: paths have maximum minimum depth
    #       subsections of paths also have this property
    
    todo = [ ]
    for i in contexts:    
        for depth_link, right in links[i]:
            todo.append( ( depth_link, i, right) )
    todo.sort(reverse=True)
    for score, left, right in todo:
        if right in path_source_dest[left]: continue
        
        sources = [(left,right)]
        if len(contexts[left]) <= max_filler_length:
            sources += path_dest_source[left].items()
        destinations = [right]
        if len(contexts[right]) <= max_filler_length:
            destinations += path_source_dest[right].keys()
        
        for source, next in sources:
            for dest in destinations:
                if dest in path_source_dest[source]: continue
                path_source_dest[source][dest] = next
                path_dest_source[dest][source] = next
    
    
    workspace = io.Workspace(output_dir)
    scaffold_f = workspace.open('scaffolds.fa','wb')
    
    #comments = [ ]
    features = [ ]
    
    used = set()
    previous_total = 0
    
    for i, (name, scaffold) in enumerate(scaffolds):
        result = '' # Inefficient. Meh.
        n_filled = 0
        n_failed = 0
        for j, item in enumerate(scaffold):
            if item[0] == 'contig':
                result += contexts[item[1]]
                used.add(abs(item[1]))
            else:
                left = scaffold[j-1]
                right = scaffold[ (j+1) % len(scaffold) ] #If gap at end, assume circular
                assert left[0] == 'contig'
                assert right[0] == 'contig'
                
                gap_start = len(result)
    
                can_fill = right[1] in path_source_dest[left[1]]
                if can_fill:
                    n = 0
                    k = path_source_dest[left[1]][right[1]]
                    while k != right[1]:
                        n += len(contexts[k])
                        result += contexts[k].lower()
                        used.add(abs(k))
                        
                        k = path_source_dest[k][right[1]]
                    
                    n_filled += 1
                        
                    if item[1] is not None and max(n,item[1]) > min(n,item[1])*4:
                        print >> sys.stderr, 'Warning: gap size changed from %d to %d in scaffold %d' % (item[1],n,i+1)
                else:
                    n_failed += 1
                    
                    #print >> sys.stderr, 'Warning: No path to fill a gap in scaffold %d' % (i+1)
                    result += 'n' * (9 if item[1] is None else item[1])
    
                gap_end = len(result)
                
                #features.append( '%s\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' % (
                #    'all-scaffolds',
                #    'fill-scaffolds',
                #    'gap',
                #    previous_total + gap_start+1,
                #    previous_total + max(gap_end, gap_start+1), #Allow for zeroed out gaps. Hmm.
                #    '.', #score
                #    '+', #strand
                #    '.', #frame
                #    '' #properties
                #))
                features.append( '%s\t%s\t%s\t%d\t%d\t%s\t%s\t%s\t%s' % (
                    name,
                    'fill-scaffolds',
                    'gap',
                    gap_start+1,
                    max(gap_end, gap_start+1), #Allow for zeroed out gaps. Hmm.
                    '.', #score
                    '+', #strand
                    '.', #frame
                    '' #properties
                ))
                    
    
        io.write_fasta(scaffold_f, name, result)
        previous_total += len(result)
        #comments.append('##sequence-region    %s %d %d' % (name, 1, len(result)))
        print >> sys.stderr, 'Scaffold%05d: %d gaps filled, %d could not be filled' % (i+1, n_filled, n_failed)
    
    scaffold_f.close()
    
    gff_f = workspace.open('scaffolds.gff', 'wb')
    #print >>gff_f, '##gff-version    3'
    #for comment in comments:
    #    print >>gff_f, comment
    for feature in features:
        print >>gff_f, feature
    gff_f.close()
    
    
    leftovers_f = workspace.open('leftovers.fa', 'wb')
    for name in sequence_names:
        if sequence_ids[name] not in used:
            io.write_fasta(leftovers_f, name, sequences[name])
    leftovers_f.close()
    
    ends = { }
    for i, (name, scaffold) in enumerate(scaffolds):
        if scaffold[-1][0] == 'gap': continue
        ends[ '%s start' % name ] = scaffold[-1][1]
        ends[ '%s end  ' % name ] = -scaffold[0][1] 
    
    for end1 in sorted(ends):
        options = [ end2 for end2 in ends if -ends[end2] in path_source_dest[ends[end1]] ]
        if len(options) == 1:
            print >> sys.stderr, 'Note: from', end1, 'only', options[0], 'is reachable'
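
Example 4 pairs defaultdict(list) for the contig link graph with defaultdict(dict) for the source -> dest -> next path tables, so a first-level key can be indexed without first checking that its inner dict exists. A minimal sketch of the nested mapping (the node ids are invented):

import collections

# source -> dest -> next hop; the inner dict appears on first use.
path_source_dest = collections.defaultdict(dict)
path_source_dest[3][7] = 5      # no need to initialise path_source_dest[3] first
path_source_dest[3][9] = 4

print(7 in path_source_dest[3])   # True
print(path_source_dest[2])        # {} -- note: reading a missing key also creates it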

Example 5

Project: frescobaldi
Source File: vocal.py
    def build(self, data, builder):
        # normalize voicing
        staves = self.voicing.currentText().upper()
        # remove unwanted characters
        staves = re.sub(r'[^SATB-]+', '', staves)
        # remove double hyphens, and from begin and end
        staves = re.sub('-+', '-', staves).strip('-')
        if not staves:
            return
        
        splitStaves = staves.split('-')
        numStaves = len(splitStaves)
        staffCIDs = collections.defaultdict(int)    # number same-name staff Context-IDs
        voiceCounter = collections.defaultdict(int) # dict to number same voice types
        maxNumVoices = max(map(len, splitStaves))   # largest number of voices
        numStanzas = self.stanzas.value()
        lyrics = collections.defaultdict(list)      # lyrics grouped by stanza number
        pianoReduction = collections.defaultdict(list)
        rehearsalMidis = []
        
        p = ly.dom.ChoirStaff()
        choir = ly.dom.Sim(p)
        data.nodes.append(p)
        
        # print main instrumentName if there are more choirs, and we
        # have more than one staff.
        if numStaves > 1 and data.num:
            builder.setInstrumentNames(p,
                builder.instrumentName(lambda _: _("Choir"), data.num),
                builder.instrumentName(lambda _: _("abbreviation for Choir", "Ch."), data.num))
        
        # get the preferred way of adding lyrics
        lyrAllSame, lyrEachSame, lyrEachDiff, lyrSpread = (
            self.lyrics.currentIndex() == i for i in range(4))
        lyrEach = lyrEachSame or lyrEachDiff
        
        # stanzas to print (0 = don't print stanza number):
        if numStanzas == 1:
            allStanzas = [0]
        else:
            allStanzas = list(range(1, numStanzas + 1))
        
        # Which stanzas to print where:
        if lyrSpread and numStanzas > 1 and numStaves > 2:
            spaces = numStaves - 1
            count, rest = divmod(max(numStanzas, spaces), spaces)
            stanzaSource = itertools.cycle(allStanzas)
            stanzaGroups = (itertools.islice(stanzaSource, num)
                            for num in itertools.chain(
                                itertools.repeat(count + 1, rest),
                                itertools.repeat(count, numStaves - rest)))
        else:
            stanzaGroups = itertools.repeat(allStanzas, numStaves)
        
        # a function to set staff affinity (in LilyPond 2.13.4 and above):
        if builder.lyVersion >= (2, 13, 4):
            def setStaffAffinity(context, affinity):
                ly.dom.Line("\\override VerticalAxisGroup "
                     "#'staff-affinity = #" + affinity, context.getWith())
        else:
            def setStaffAffinity(lyricsContext, affinity):
                pass
        
        # a function to make a column markup:
        if builder.lyVersion >= (2, 11, 57):
            columnCommand = 'center-column'
        else:
            columnCommand = 'center-align'
        def makeColumnMarkup(names):
            node = ly.dom.Markup()
            column = ly.dom.MarkupEnclosed(columnCommand, node)
            for name in names:
                ly.dom.QuotedString(name, column)
            return node
        
        stavesLeft = numStaves
        for staff, stanzas in zip(splitStaves, stanzaGroups):
            # are we in the last staff?
            stavesLeft -= 1
            # the number of voices in this staff
            numVoices = len(staff)
            # sort the letters in order SATB
            staff = ''.join(i * staff.count(i) for i in 'SATB')
            # Create the staff for the voices
            s = ly.dom.Staff(parent=choir)
            builder.setMidiInstrument(s, self.midiInstrument)
            
            # Build a list of the voices in this staff.
            # Each entry is a tuple(name, num).
            # name is one of 'S', 'A', 'T', or 'B'
            # num is an integer: 0 when a voice occurs only once, or >= 1 when
            # there are more voices of the same type (e.g. Soprano I and II)
            voices = []
            for voice in staff:
                if staves.count(voice) > 1:
                    voiceCounter[voice] += 1
                voices.append((voice, voiceCounter[voice]))
            
            # Add the instrument names to the staff:
            if numVoices == 1:
                voice, num = voices[0]
                longName = builder.instrumentName(voice2Voice[voice].title, num)
                shortName = builder.instrumentName(voice2Voice[voice].short, num)
                builder.setInstrumentNames(s, longName, shortName)
            else:
                # stack instrument names (long and short) in a markup column.
                # long names
                longNames = makeColumnMarkup(
                    builder.instrumentName(voice2Voice[voice].title, num) for voice, num in voices)
                shortNames = makeColumnMarkup(
                    builder.instrumentName(voice2Voice[voice].short, num) for voice, num in voices)
                builder.setInstrumentNames(s, longNames, shortNames)
            
            # Make the { } or << >> holder for this staff's children.
            # If *all* staves have only one voice, addlyrics is used.
            # In that case, don't remove the braces.
            staffMusic = (ly.dom.Seq if lyrEach and maxNumVoices == 1 else
                          ly.dom.Seqr if numVoices == 1 else ly.dom.Simr)(s)
            
            # Set the clef for this staff:
            if 'B' in staff:
                ly.dom.Clef('bass', staffMusic)
            elif 'T' in staff:
                ly.dom.Clef('treble_8', staffMusic)

            # Determine voice order (\voiceOne, \voiceTwo etc.)
            if numVoices == 1:
                order = (0,)
            elif numVoices == 2:
                order = 1, 2
            elif staff in ('SSA', 'TTB'):
                order = 1, 3, 2
            elif staff in ('SAA', 'TBB'):
                order = 1, 2, 4
            elif staff in ('SSAA', 'TTBB'):
                order = 1, 3, 2, 4
            else:
                order = range(1, numVoices + 1)
            
            # What name would the staff get if we need to refer to it?
            # If a name (like 's' or 'sa') is already in use in this part,
            # just add a number ('ss2' or 'sa2', etc.)
            staffCIDs[staff] += 1
            cid = ly.dom.Reference(staff.lower() +
                str(staffCIDs[staff] if staffCIDs[staff] > 1 else ""))
            
            # Create voices and their lyrics:
            for (voice, num), voiceNum in zip(voices, order):
                name = voice2id[voice]
                if num:
                    name += ly.util.int2text(num)
                a = data.assignMusic(name, voice2Voice[voice].octave)
                lyrName = name + 'Verse' if lyrEachDiff else 'verse'
            
                # Use \addlyrics if all staves have exactly one voice.
                if lyrEach and maxNumVoices == 1:
                    for verse in stanzas:
                        lyrics[verse].append((ly.dom.AddLyrics(s), lyrName))
                    ly.dom.Identifier(a.name, staffMusic)
                else:
                    voiceName = voice2id[voice] + str(num or '')
                    v = ly.dom.Voice(voiceName, parent=staffMusic)
                    voiceMusic = ly.dom.Seqr(v)
                    if voiceNum:
                        ly.dom.Text('\\voice' + ly.util.int2text(voiceNum), voiceMusic)
                    ly.dom.Identifier(a.name, voiceMusic)
                    
                    if stanzas and (lyrEach or (voiceNum <= 1 and
                                    (stavesLeft or numStaves == 1))):
                        # Create the lyrics. If they should be above the staff,
                        # give the staff a suitable name, and use alignAbove-
                        # Context to align the Lyrics above the staff.
                        above = voiceNum & 1 if lyrEach else False
                        if above and s.cid is None:
                            s.cid = cid

                        for verse in stanzas:
                            l = ly.dom.Lyrics(parent=choir)
                            if above:
                                l.getWith()['alignAboveContext'] = cid
                                setStaffAffinity(l, "DOWN")
                            elif not lyrEach and stavesLeft:
                                setStaffAffinity(l, "CENTER")
                            lyrics[verse].append((ly.dom.LyricsTo(voiceName, l), lyrName))

                # Add ambitus:
                if self.ambitus.isChecked():
                    ambitusContext = (s if numVoices == 1 else v).getWith()
                    ly.dom.Line('\\consists "Ambitus_engraver"', ambitusContext)
                    if voiceNum > 1:
                        ly.dom.Line("\\override Ambitus #'X-offset = #{0}".format(
                                 (voiceNum - 1) * 2.0), ambitusContext)
            
                pianoReduction[voice].append(a.name)
                rehearsalMidis.append((voice, num, a.name, lyrName))
            
        # Assign the lyrics, so their definitions come after the note defs.
        # (These refs are used again below in the midi rehearsal routine.)
        refs = {}
        for verse in allStanzas:
            for node, name in lyrics[verse]:
                if (name, verse) not in refs:
                    refs[(name, verse)] = self.assignLyrics(data, name, verse).name
                ly.dom.Identifier(refs[(name, verse)], node)

        # Create the piano reduction if desired
        if self.pianoReduction.isChecked():
            a = data.assign('pianoReduction')
            data.nodes.append(ly.dom.Identifier(a.name))
            piano = ly.dom.PianoStaff(parent=a)
            
            sim = ly.dom.Sim(piano)
            rightStaff = ly.dom.Staff(parent=sim)
            leftStaff = ly.dom.Staff(parent=sim)
            right = ly.dom.Seq(rightStaff)
            left = ly.dom.Seq(leftStaff)
            
            # Determine the ordering of voices in the staves
            upper = pianoReduction['S'] + pianoReduction['A']
            lower = pianoReduction['T'] + pianoReduction['B']
            
            preferUpper = 1
            if not upper:
                # Male choir
                upper = pianoReduction['T']
                lower = pianoReduction['B']
                ly.dom.Clef("treble_8", right)
                ly.dom.Clef("bass", left)
                preferUpper = 0
            elif not lower:
                # Female choir
                upper = pianoReduction['S']
                lower = pianoReduction['A']
            else:
                ly.dom.Clef("bass", left)

            # Otherwise accidentals can be confusing
            ly.dom.Line("#(set-accidental-style 'piano)", right)
            ly.dom.Line("#(set-accidental-style 'piano)", left)
            
            # Move voices if unevenly spread
            if abs(len(upper) - len(lower)) > 1:
                voices = upper + lower
                half = (len(voices) + preferUpper) // 2
                upper = voices[:half]
                lower = voices[half:]
            
            for staff, voices in (ly.dom.Simr(right), upper), (ly.dom.Simr(left), lower):
                if voices:
                    for v in voices[:-1]:
                        ly.dom.Identifier(v, staff)
                        ly.dom.VoiceSeparator(staff).after = 1
                    ly.dom.Identifier(voices[-1], staff)

            # Make the piano part somewhat smaller
            ly.dom.Line("fontSize = #-1", piano.getWith())
            ly.dom.Line("\\override StaffSymbol #'staff-space = #(magstep -1)",
                piano.getWith())
            
            # Nice to add Mark engravers
            ly.dom.Line('\\consists "Mark_engraver"', rightStaff.getWith())
            ly.dom.Line('\\consists "Metronome_mark_engraver"', rightStaff.getWith())
            
            # Keep piano reduction out of the MIDI output
            if builder.midi:
                ly.dom.Line('\\remove "Staff_performer"', rightStaff.getWith())
                ly.dom.Line('\\remove "Staff_performer"', leftStaff.getWith())
        
        # Create MIDI files if desired
        if self.rehearsalMidi.isChecked():
            a = data.assign('rehearsalMidi')
            rehearsalMidi = a.name
            
            func = ly.dom.SchemeList(a)
            func.pre = '#\n(' # hack
            ly.dom.Text('define-music-function', func)
            ly.dom.Line('(parser location name midiInstrument lyrics) '
                 '(string? string? ly:music?)', func)
            choir = ly.dom.Sim(ly.dom.Command('unfoldRepeats', ly.dom.SchemeLily(func)))
            
            data.afterblocks.append(ly.dom.Comment(_("Rehearsal MIDI files:")))
            
            for voice, num, ref, lyrName in rehearsalMidis:
                # Append voice to the rehearsalMidi function
                name = voice2id[voice] + str(num or '')
                seq = ly.dom.Seq(ly.dom.Voice(name, parent=ly.dom.Staff(name, parent=choir)))
                if builder.lyVersion < (2, 18, 0):
                    ly.dom.Text('<>\\f', seq) # add one dynamic
                ly.dom.Identifier(ref, seq) # add the reference to the voice
                
                book = ly.dom.Book()
                
                # Append score to the aftermath (stuff put below the main score)
                suffix = "choir{0}-{1}".format(data.num, name) if data.num else name
                if builder.lyVersion < (2, 12, 0):
                    data.afterblocks.append(
                        ly.dom.Line('#(define output-suffix "{0}")'.format(suffix)))
                else:
                    ly.dom.Line('\\bookOutputSuffix "{0}"'.format(suffix), book)
                data.afterblocks.append(book)
                data.afterblocks.append(ly.dom.BlankLine())
                score = ly.dom.Score(book)
                
                # TODO: make configurable
                midiInstrument = voice2Midi[voice]

                cmd = ly.dom.Command(rehearsalMidi, score)
                ly.dom.QuotedString(name, cmd)
                ly.dom.QuotedString(midiInstrument, cmd)
                ly.dom.Identifier(refs[(lyrName, allStanzas[0])], cmd)
                ly.dom.Midi(score)
            
            ly.dom.Text("\\context Staff = $name", choir)
            seq = ly.dom.Seq(choir)
            ly.dom.Line("\\set Score.midiMinimumVolume = #0.5", seq)
            ly.dom.Line("\\set Score.midiMaximumVolume = #0.5", seq)
            ly.dom.Line("\\set Score.tempoWholesPerMinute = #" + data.scoreProperties.schemeMidiTempo(), seq)
            ly.dom.Line("\\set Staff.midiMinimumVolume = #0.8", seq)
            ly.dom.Line("\\set Staff.midiMaximumVolume = #1.0", seq)
            ly.dom.Line("\\set Staff.midiInstrument = $midiInstrument", seq)
            lyr = ly.dom.Lyrics(parent=choir)
            lyr.getWith()['alignBelowContext'] = ly.dom.Text('$name')
            ly.dom.Text("\\lyricsto $name $lyrics", lyr)

Example 6

View license
    @save
    def onSuccess(self, results, config):
        data = self.new_data()

        datasource_by_pid = {}
        metrics_by_component = collections.defaultdict(
            lambda: collections.defaultdict(list))
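        # Both levels are created lazily: metrics_by_component[component][datapoint]
        # is a list that samples can be appended to without any key checks.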

        # Used for process restart checking.
        if not hasattr(self, 'previous_pids_by_component'):
            self.previous_pids_by_component = collections.defaultdict(set)

        pids_by_component = collections.defaultdict(set)

        sorted_datasource = sorted(
            config.datasources,
            key=lambda x: x.params.get('sequence', 0))

        # Win32_Process: Counts and correlation to performance table.
        process_key = [x for x in results if 'Win32_Process' in x.wql][0]
        for item in results[process_key]:
            processText = get_processText(item)

            for datasource in sorted_datasource:
                regex = re.compile(datasource.params['regex'])

                # Zenoss 4.2 2013-10-15 RPS style.
                if 'replacement' in datasource.params:
                    matcher = OSProcessDataMatcher(
                        includeRegex=datasource.params['includeRegex'],
                        excludeRegex=datasource.params['excludeRegex'],
                        replaceRegex=datasource.params['replaceRegex'],
                        replacement=datasource.params['replacement'],
                        primaryUrlPath=datasource.params['primaryUrlPath'],
                        generatedId=datasource.params['generatedId'])

                    if not matcher.matches(processText):
                        continue

                # Zenoss 4.2 intermediate style
                elif hasattr(OSProcess, 'matchRegex'):
                    excludeRegex = re.compile(
                        datasource.params['excludeRegex'])

                    basic_match = OSProcess.matchRegex(
                        regex, excludeRegex, processText)

                    if not basic_match:
                        continue

                    capture_match = OSProcess.matchNameCaptureGroups(
                        regex, processText, datasource.component)

                    if not capture_match:
                        continue

                # Zenoss 4.1-4.2 style.
                else:
                    if datasource.params['ignoreParameters']:
                        processText = item.ExecutablePath or item.Name

                    name, args = get_processNameAndArgs(item)
                    if datasource.params['ignoreParameters']:
                        proc_id = getProcessIdentifier(name, None)
                    else:
                        proc_id = getProcessIdentifier(name, args)

                    if datasource.component != prepId(proc_id):
                        continue

                datasource_by_pid[item.ProcessId] = datasource
                pids_by_component[datasource.component].add(item.ProcessId)

                # Track process count. Append 1 each time we find a
                # match because the generic aggregator below will sum
                # them up to the total count.
                metrics_by_component[datasource.component][COUNT_DATAPOINT].append(1)

                # Don't continue matching once a match is found.
                break

        # Send process status events.
        for datasource in config.datasources:
            component = datasource.component

            if COUNT_DATAPOINT in metrics_by_component[component]:
                severity = 0
                summary = 'matching processes running'

                # Process restart checking.
                previous_pids = self.previous_pids_by_component.get(component)
                current_pids = pids_by_component.get(component)

                # No restart if there are no current or previous PIDs.
                if previous_pids and current_pids:

                    # Only consider PID changes a restart if all PIDs
                    # matching the process changed.
                    if current_pids.isdisjoint(previous_pids):
                        summary = 'matching processes restarted'

                        # If the process is configured to alert on
                        # restart, the first "up" won't be a clear.
                        if datasource.params['alertOnRestart']:
                            severity = datasource.params['severity']

            else:
                severity = datasource.params['severity']
                summary = 'no matching processes running'

                # Add a 0 count for process that aren't running.
                metrics_by_component[component][COUNT_DATAPOINT].append(0)

            data['events'].append({
                'device': datasource.device,
                'component': component,
                'eventClass': datasource.eventClass,
                'eventGroup': 'Process',
                'summary': summary,
                'severity': severity,
                })

        # Prepare for next cycle's restart check by merging current
        # process PIDs with previous. This is to catch restarts that
        # stretch across more than subsequent cycles.
        self.previous_pids_by_component.update(
            (c, p) for c, p in pids_by_component.iteritems() if p)

        # Win32_PerfFormattedData_PerfProc_Process: Datapoints.
        perf_keys = [x for x in results if 'Win32_Perf' in x.wql]
        if perf_keys:
            for item in results[perf_keys[0]]:
                if item.IDProcess not in datasource_by_pid:
                    continue
                datasource = datasource_by_pid[item.IDProcess]
                for point in datasource.points:
                    if point.id == COUNT_DATAPOINT:
                        continue

                    try:
                        raw_value = getattr(item, point.id)
                        value = int(raw_value)
                    except (TypeError, ValueError):
                        LOG.warn(
                            "%s %s %s: Couldn't convert %r to integer",
                            datasource.device, datasource.component, point.id,
                            raw_value)
                    except AttributeError:
                        LOG.warn(
                            "%s %s: %s not in result",
                            datasource.device, datasource.component, point.id)
                    else:
                        metrics_by_component[datasource.component][point.id].append(value)

        # Aggregate and store datapoint values.
        for component, points in metrics_by_component.iteritems():
            for point, values in points.iteritems():
                if point in NON_AGGREGATED_DATAPOINTS:
                    value = values[0]
                else:
                    value = sum(values)

                data['values'][component][point] = (value, 'N')

        # Send overall clear.
        data['events'].append({
            'device': config.id,
            'severity': Event.Clear,
            'eventClass': Status_OSProcess,
            'summary': 'process scan successful',
            })

        return data
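
The core defaultdict pattern in this example is the two-level structure built at the top of onSuccess: matches are recorded with bare appends and summed afterwards. A minimal standalone sketch of that pattern, with component and datapoint names made up for illustration:

import collections

metrics = collections.defaultdict(lambda: collections.defaultdict(list))

# Record samples with bare appends; both key levels appear on first access.
metrics["sshd"]["count"].append(1)
metrics["sshd"]["count"].append(1)
metrics["sshd"]["cpu_pct"].append(12)

# Aggregate the way the example does: sum each datapoint's samples.
totals = {component: {point: sum(values) for point, values in points.items()}
          for component, points in metrics.items()}
print(totals)  # {'sshd': {'count': 2, 'cpu_pct': 12}}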

Example 7

Project: cgat
Source File: annotator_distance.py
View license
def main(argv=sys.argv):

    parser = E.OptionParser(
        version="%prog version: $Id: annotator_distance.py 2861 2010-02-23 17:36:32Z andreas $", usage=globals()["__doc__"])

    parser.add_option("-a", "--annotations-tsv-file", dest="filename_annotations", type="string",
                      help="filename mapping gene ids to annotations (a tab-separated table with two-columns) [default=%default].")

    parser.add_option("-r", "--resolution", dest="resolution", type="int",
                      help="resolution of count vector [default=%default].")

    parser.add_option("-b", "--num-bins", dest="num_bins", type="int",
                      help="number of bins in count vector [default=%default].")

    parser.add_option("-i", "--num-samples", dest="num_samples", type="int",
                      help="sample size to compute [default=%default].")

    parser.add_option("-w", "--workspace-bed-file", dest="filename_workspace", type="string",
                      help="filename with workspace information [default=%default].")

    parser.add_option("--workspace-builder", dest="workspace_builder", type="choice",
                      choices=(
                          "gff", "gtf-intergenic", "gtf-intronic", "gtf-genic"),
                      help="given a gff/gtf file build a workspace [default=%default].")

    parser.add_option("--workspace-labels", dest="workspace_labels", type="choice",
                      choices=("none", "direction", "annotation"),
                      help="labels to use for the workspace workspace [default=%default].")

    parser.add_option("--sampler", dest="sampler", type="choice",
                      choices=("permutation", "gaps"),
                      help="sampler to use. The sampler determines the null model of how segments are distributed in the workspace  [default=%default]")

    parser.add_option("--counter", dest="counters", type="choice", action="append",
                      choices=(
                          "transcription", "closest-distance", "all-distances"),
                      help="counter to use. The counter computes the quantity of interest [default=%default]")

    parser.add_option("--analysis", dest="analysis", type="choice", action="append",
                      choices=("proximity", "area-under-curve"),
                      help="analysis to perform [default=%default]")

    parser.add_option("--transform-counts", dest="transform_counts", type="choice",
                      choices=("raw", "cumulative"),
                      help="cumulate counts [default=%default].")

    parser.add_option("-s", "--segments", dest="filename_segments", type="string",
                      help="filename with segment information [default=%default].")

    parser.add_option("--xrange", dest="xrange", type="string",
                      help="xrange to plot [default=%default]")

    parser.add_option("-o", "--logscale", dest="logscale", type="string",
                      help="use logscale on x, y or xy [default=%default]")

    parser.add_option("-p", "--plot", dest="plot", action="store_true",
                      help="output plots [default=%default]")

    parser.add_option("--hardcopy", dest="hardcopy", type="string",
                      help="output hardcopies to file [default=%default]")

    parser.add_option("--no-fdr", dest="do_fdr", action="store_false",
                      help="do not compute FDR rates [default=%default]")

    parser.add_option("--segments-format", dest="segments_format", type="choice",
                      choices=("gtf", "bed"),
                      help="format of segments file [default=%default].")

    parser.add_option("--truncate", dest="truncate", action="store_true",
                      help="truncate segments extending beyond a workspace [default=%default]")

    parser.add_option("--remove-overhangs", dest="remove_overhangs", action="store_true",
                      help="remove segments extending beyond a workspace[default=%default]")

    parser.add_option("--keep-ambiguous", dest="keep_ambiguous", action="store_true",
                      help="keep segments extending to more than one workspace [default=%default]")

    parser.set_defaults(
        filename_annotations=None,
        filename_workspace="workspace.gff",
        filename_segments="FastDown.gtf",
        filename_annotations_gtf="../data/tg1_territories.gff",
        workspace_builder="gff",
        workspace_labels="none",
        sampler="permutation",
        truncate=False,
        num_bins=10000,
        num_samples=10,
        resolution=100,
        plot_samples=False,
        plot_envelope=True,
        counters=[],
        transform_counts="raw",
        xrange=None,
        plot=False,
        logscale=None,
        output_all=False,
        do_test=False,
        analysis=[],
        do_fdr=True,
        hardcopy="%s.png",
        segments_format="gtf",
        remove_overhangs=False,
    )

    (options, args) = E.Start(parser, argv=argv, add_output_options=True)

    ###########################################
    # setup options
    if options.sampler == "permutation":
        sampler = SamplerPermutation
    elif options.sampler == "gaps":
        sampler = SamplerGaps

    if options.xrange:
        options.xrange = list(map(float, options.xrange.split(",")))

    if len(options.counters) == 0:
        raise ValueError("please specify at least one counter.")

    if len(options.analysis) == 0:
        raise ValueError("please specify at least one analysis.")

    if options.workspace_labels == "annotation" and not options.filename_annotations:
        raise ValueError(
            "please specify --annotations-tsv-file is --workspace-labels=annotations.")

    ###########################################
    # read data
    if options.workspace_labels == "annotation":
        def constant_factory(value):
            return itertools.repeat(value).__next__

        def dicttype():
            return collections.defaultdict(constant_factory(("unknown",)))
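        # With this factory, looking up a gene id that has no annotation
        # yields the constant ("unknown",) instead of raising KeyError.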

        map_id2annotations = IOTools.readMultiMap(open(options.filename_annotations, "r"),
                                                  dtype=dicttype)
    else:
        map_id2annotations = {}

    workspace = readWorkspace(open(options.filename_workspace, "r"),
                              options.workspace_builder,
                              options.workspace_labels,
                              map_id2annotations)

    E.info("read workspace for %i contigs" % (len(workspace)))

    indexed_workspace = indexIntervals(workspace, with_values=True)
    segments = readSegments(open(options.filename_segments, "r"), indexed_workspace,
                            format=options.segments_format,
                            keep_ambiguous=options.keep_ambiguous,
                            truncate=options.truncate,
                            remove_overhangs=options.remove_overhangs)

    nsegments = 0
    for contig, vv in segments.items():
        nsegments += len(vv)

    E.info("read %i segments for %i contigs" % (nsegments, len(workspace)))
    indexed_segments = indexIntervals(segments, with_values=False)

    if nsegments == 0:
        E.warn("no segments read - no computation done.")
        E.Stop()
        return

    # build labels
    labels = collections.defaultdict(int)
    for contig, vv in workspace.items():
        for start, end, v in vv:
            for l in v[0]:
                labels[l] += 1
            for l in v[1]:
                labels[l] += 1

    E.info("found %i workspace labels" % len(labels))

    ###########################################
    # setup counting containers
    counters = []
    for cc in options.counters:

        if cc == "transcription":
            counter = CounterTranscription
        elif cc == "closest-distance":
            counter = CounterClosestDistance
        elif cc == "all-distances":
            counter = CounterAllDistances

        if nsegments < 256:
            dtype = numpy.uint8
        elif nsegments < 65536:
            dtype = numpy.uint16
        elif nsegments < 4294967296:
            dtype = numpy.uint32
        else:
            dtype = numpy.int

        E.debug("choosen dtype %s" % str(dtype))

        E.info("samples space is %i bases: %i bins at %i resolution" %
               (options.num_bins * options.resolution,
                options.num_bins,
                options.resolution,
                ))

        E.info("allocating counts: %i bytes (%i labels, %i samples, %i bins)" %
               (options.num_bins * len(labels) * dtype().itemsize * (options.num_samples + 1),
                len(labels),
                options.num_samples,
                options.num_bins,
                ))

        c = CountingResults(labels)
        c.mObservedCounts = counter(
            labels, options.num_bins, options.resolution, dtype=dtype)

        simulated_counts = []
        for x in range(options.num_samples):
            simulated_counts.append(
                counter(labels, options.num_bins, options.resolution, dtype=dtype))
        c.mSimulatedCounts = simulated_counts
        c.mName = c.mObservedCounts.mName

        counters.append(c)

        E.info("allocated memory successfully")

    segments_per_workspace = []
    segment_sizes = []
    segments_per_label = collections.defaultdict(int)
    workspaces_per_label = collections.defaultdict(int)

    ############################################
    # get observed and simulated counts
    nworkspaces, nempty_workspaces, nempty_contigs, nmiddle = 0, 0, 0, 0
    iteration2 = 0
    for contig, vv in workspace.items():

        iteration2 += 1
        E.info("counting %i/%i: %s %i segments" %
               (iteration2,
                len(workspace),
                contig,
                len(vv)))

        if len(vv) == 0:
            continue

        iteration1 = 0
        for work_start, work_end, v in vv:

            left_labels, right_labels = v[0], v[1]

            iteration1 += 1

            # ignore empty segments
            if contig not in indexed_segments:
                nempty_contigs += 1
                continue

            r = indexed_segments[contig].find(work_start, work_end)
            segments_per_workspace.append(len(r))

            if not r:
                nempty_workspaces += 1
                continue

            # collect segments and stats
            nworkspaces += 1
            observed = [(x.start, x.end) for x in r]
            observed.sort()
            segments_per_workspace.append(len(observed))
            segment_sizes.extend([x[1] - x[0] for x in observed])

            # collect basic counts
            for label in list(left_labels) + list(right_labels):
                workspaces_per_label[label] += 1
                segments_per_label[label] += len(observed)

            # add observed counts
            for counter in counters:
                counter.mObservedCounts.addCounts(
                    observed, work_start, work_end, left_labels, right_labels)

            # create sampler
            s = sampler(observed, work_start, work_end)

            # add simulated counts
            for iteration in range(options.num_samples):
                simulated = s.sample()
                for counter in counters:
                    counter.mSimulatedCounts[iteration].addCounts(
                        simulated, work_start, work_end, left_labels, right_labels)

    E.info("counting finished")
    E.info("nworkspaces=%i, nmiddle=%i, nempty_workspaces=%i, nempty_contigs=%i" %
           (nworkspaces, nmiddle, nempty_workspaces, nempty_contigs))

    ######################################################
    # transform counts

    if options.transform_counts == "cumulative":
        transform = cumulative_transform
    elif options.transform_counts == "raw":
        transform = normalize_transform

    ####################################################
    # analysis

    if "proximity" in options.analysis:
        outfile_proximity = E.openOutputFile("proximity")
        outfile_proximity.write("\t".join(("label", "observed", "pvalue",
                                           "expected", "CIlower", "CIupper", "qvalue", "segments", "workspaces")) + "\n")
    else:
        outfile_proximity = None

    if "area-under-curve" in options.analysis:
        outfile_auc = E.openOutputFile("auc")
        outfile_auc.write("label\tobserved\texpected\tCIlower\tCIupper\n")
    else:
        outfile_auc = None

    # qvalue: expected false positives at a given p-value threshold
    if options.do_fdr:
        E.info("computing pvalues for fdr")
        # collect all P-values of simulated results to compute FDR
        sim_pvalues = []
        for counter in counters:
            for label in labels:
                E.info("working on counter:%s label:%s" % (counter, label))

                medians = counter.getMedians(label)
                for median in medians:
                    pvalue = float(
                        scipy.stats.percentileofscore(medians, median)) / 100.0
                    sim_pvalues.append(pvalue)

        sim_pvalues.sort()
    else:
        sim_pvalues = []

    # compute observed p-values
    for counter in counters:
        counter.update()

    obs_pvalues = []
    for counter in counters:
        for label in labels:
            obs_pvalues.append(counter.mStats[label].pvalue)
        obs_pvalues.sort()

    # compute qvalues / FDR from observed and simulated p-values
    if options.do_fdr:
        for counter in counters:
            counter.updateFDR(obs_pvalues, sim_pvalues)

    for counter in counters:

        outofbounds_sim, totals_sim = 0, 0
        outofbounds_obs, totals_obs = 0, 0
        for label in labels:
            for sample in range(options.num_samples):
                if counter.mSimulatedCounts[sample].mOutOfBounds[label]:
                    E.debug("out of bounds: sample %i, label %s, counts=%i" %
                            (sample, label, counter.mSimulatedCounts[sample].mOutOfBounds[label]))
                    outofbounds_sim += counter.mSimulatedCounts[
                        sample].mOutOfBounds[label]
                totals_sim += counter.mSimulatedCounts[sample].mTotals[label]

            outofbounds_obs += counter.mObservedCounts.mOutOfBounds[label]
            totals_obs += counter.mObservedCounts.mTotals[label]

        E.info("out of bounds observations: observed=%i/%i (%5.2f%%), simulations=%i/%i (%5.2f%%)" %
               (outofbounds_obs, totals_obs,
                100.0 * outofbounds_obs / totals_obs,
                outofbounds_sim, totals_sim,
                100.0 * outofbounds_sim / totals_sim,
                ))

        for label in labels:

            if outfile_auc:
                mmin, mmax, mmean = counter.getEnvelope(
                    label, transform=normalize_transform)
                obs = normalize_transform(
                    counter.mObservedCounts[label], counter.mObservedCounts.mOutOfBounds[label])

                def block_iterator(a1, a2, a3, num_bins):
                    x = 0
                    while x < num_bins:
                        while x < num_bins and a1[x] <= a2[x]:
                            x += 1
                        start = x
                        while x < num_bins and a1[x] > a2[x]:
                            x += 1
                        end = x
                        total_a1 = a1[start:end].sum()
                        total_a3 = a3[start:end].sum()
                        if total_a1 > total_a3:
                            yield (total_a1 - total_a3, start, end, total_a1, total_a3)

                blocks = list(
                    block_iterator(obs, mmax, mmean, options.num_bins))

                if options.output_all:
                    for delta, start, end, total_obs, total_mean in blocks:
                        if end - start <= 1:
                            continue
                        outfile_auc.write("%s\t%i\t%i\t%i\t%f\t%f\t%f\t%f\t%f\n" %
                                          (label,
                                           start * options.resolution,
                                           end * options.resolution,
                                           (end - start) * options.resolution,
                                           total_obs,
                                           total_mean,
                                           delta,
                                           total_obs / total_mean,
                                           100.0 * (total_obs / total_mean - 1.0)))

                # output best block
                blocks.sort()
                delta, start, end, total_obs, total_mean = blocks[-1]

                outfile_auc.write("%s\t%i\t%i\t%i\t%f\t%f\t%f\t%f\t%f\n" %
                                  (label,
                                   start * options.resolution,
                                   end * options.resolution,
                                   (end - start) * options.resolution,
                                   total_obs,
                                   total_mean,
                                   delta,
                                   total_obs / total_mean,
                                   100.0 * (total_obs / total_mean - 1.0)))

            if outfile_proximity:

                # find error bars at median
                st = counter.mStats[label]
                outfile_proximity.write("%s\t%i\t%f\t%i\t%i\t%i\t%s\t%i\t%i\n" %
                                        (label,
                                         st.observed *
                                         options.resolution,
                                         st.pvalue,
                                         st.expected *
                                         options.resolution,
                                         st.ci95lower *
                                         options.resolution,
                                         st.ci95upper *
                                         options.resolution,
                                         IOTools.val2str(st.qvalue),
                                         segments_per_label[label],
                                         workspaces_per_label[label],
                                         ))

    if options.plot:

        for counter in counters:
            plotCounts(counter, options, transform)

        # plot summary stats
        plt.figure()
        plt.title("distribution of workspace length")
        data = []
        for contig, segs in workspace.items():
            if len(segs) == 0:
                continue
            data.extend([x[1] - x[0] for x in segs])

        vals, bins = numpy.histogram(
            data, bins=numpy.arange(0, max(data), 100), new=True)

        t = float(sum(vals))
        plt.plot(bins[:-1], numpy.cumsum(vals) / t)
        plt.gca().set_xscale('log')
        plt.legend()
        t = float(sum(vals))
        plt.xlabel("size of workspace")
        plt.ylabel("cumulative relative frequency")
        if options.hardcopy:
            plt.savefig(
                os.path.expanduser(options.hardcopy % "workspace_size"))

        plt.figure()
        plt.title("segments per block")
        vals, bins = numpy.histogram(segments_per_workspace, bins=numpy.arange(
            0, max(segments_per_workspace), 1), new=True)
        plt.plot(bins[:-1], vals)
        plt.xlabel("segments per block")
        plt.ylabel("absolute frequency")
        if options.hardcopy:
            plt.savefig(
                os.path.expanduser(options.hardcopy % "segments_per_block"))

        plt.figure()
        plt.title("workspaces per label")
        plt.barh(
            list(range(0, len(labels))), [workspaces_per_label[x] for x in labels], height=0.5)
        plt.yticks(list(range(0, len(labels))), labels)
        plt.ylabel("workspaces per label")
        plt.xlabel("absolute frequency")
        plt.gca().set_xscale('log')

        if options.hardcopy:
            plt.savefig(
                os.path.expanduser(options.hardcopy % "workspaces_per_label"))

        plt.figure()
        plt.title("segments per label")
        plt.barh(list(range(0, len(labels))), [segments_per_label[x]
                                               for x in labels], height=0.5)
        plt.yticks(list(range(0, len(labels))), labels)
        plt.ylabel("segments per label")
        plt.xlabel("absolute frequency")
        plt.xticks(list(range(0, len(labels))), labels)
        if options.hardcopy:
            plt.savefig(
                os.path.expanduser(options.hardcopy % "segments_per_label"))

        if not options.hardcopy:
            plt.show()

    E.Stop()
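
The constant_factory/dicttype pair in this example is a less common defaultdict idiom: missing keys map to a fixed constant instead of an empty container. A minimal sketch of just that idiom (the gene ids are invented):

import collections
import itertools

def constant_factory(value):
    return itertools.repeat(value).__next__

annotations = collections.defaultdict(constant_factory(("unknown",)))
annotations["gene_a"] = ("promoter",)

print(annotations["gene_a"])  # ('promoter',)
print(annotations["gene_b"])  # ('unknown',) -- every missing key gets the constant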

Example 8

Project: cgat
Source File: gff2gff.py
View license
def main(argv=None):

    if argv is None:
        argv = sys.argv

    parser = E.OptionParser(version="%prog version: $Id: gff2gff.py$",
                            usage=globals()["__doc__"])

    parser.add_option("-m", "--method", dest="method", type="choice",
                      choices=(
                          "add-flank",
                          "add-upstream-flank",
                          "add-downstream-flank",
                          "crop",
                          "crop-unique",
                          "complement-groups",
                          "combine-groups",
                          "filter-range",
                          "join-features",
                          "merge-features",
                          "sanitize",
                          "to-forward-coordinates",
                          "to-forward-strand"),
                      help="method to apply [%default]")

    parser.add_option(
        "--ignore-strand", dest="ignore_strand",
        help="ignore strand information.", action="store_true")

    parser.add_option("--is-gtf", dest="is_gtf", action="store_true",
                      help="input will be treated as gtf [default=%default].")

    parser.add_option(
        "-c", "--contigs-tsv-file", dest="input_filename_contigs",
        type="string",
        help="filename with contig lengths.")

    parser.add_option(
        "--agp-file", dest="input_filename_agp", type="string",
        help="agp file to map coordinates from contigs to scaffolds.")

    parser.add_option(
        "-g", "--genome-file", dest="genome_file", type="string",
        help="filename with genome.")

    parser.add_option(
        "--crop-gff-file", dest="filename_crop_gff", type="string",
        help="GFF/GTF file to crop against.")

    parser.add_option(
        "--group-field", dest="group_field", type="string",
        help="""gff field/attribute to group by such as gene_id, "
        "transcript_id, ... [%default].""")

    parser.add_option(
        "--filter-range", dest="filter_range", type="string",
        help="extract all elements overlapping a range. A range is "
        "specified by eithor 'contig:from..to', 'contig:+:from..to', "
        "or 'from,to' .")

    parser.add_option(
        "--sanitize-method", dest="sanitize_method", type="choice",
        choices=("ucsc", "ensembl", "genome"),
        help="method to use for sanitizing chromosome names. "
        "[%default].")

    parser.add_option(
        "--flank-method", dest="flank_method", type="choice",
        choices=("add", "extend"),
        help="method to use for adding flanks. ``extend`` will "
        "extend existing features, while ``add`` will add new features. "
        "[%default].")

    parser.add_option(
        "--skip-missing", dest="skip_missing", action="store_true",
        help="skip entries on missing contigs. Otherwise an "
        "exception is raised [%default].")

    parser.add_option(
        "--contig-pattern", dest="contig_pattern", type="string",
        help="a comma separated list of regular expressions specifying "
        "contigs to be removed when running method sanitize [%default].")

    parser.add_option(
        "--extension-upstream", dest="extension_upstream", type="float",
        help="extension for upstream end [%default].")

    parser.add_option(
        "--extension-downstream", dest="extension_downstream", type="float",
        help="extension for downstream end [%default].")

    parser.add_option(
        "--min-distance", dest="min_distance", type="int",
        help="minimum distance of features to merge/join [%default].")

    parser.add_option(
        "--max-distance", dest="max_distance", type="int",
        help="maximum distance of features to merge/join [%default].")

    parser.add_option(
        "--min-features", dest="min_features", type="int",
        help="minimum number of features to merge/join [%default].")

    parser.add_option(
        "--max-features", dest="max_features", type="int",
        help="maximum number of features to merge/join [%default].")

    parser.set_defaults(
        input_filename_contigs=False,
        filename_crop_gff=None,
        input_filename_agp=False,
        genome_file=None,
        add_up_flank=None,
        add_down_flank=None,
        complement_groups=False,
        crop=None,
        crop_unique=False,
        ignore_strand=False,
        filter_range=None,
        min_distance=0,
        max_distance=0,
        min_features=1,
        max_features=0,
        extension_upstream=1000,
        extension_downstream=1000,
        sanitize_method="ucsc",
        flank_method="add",
        output_format="%06i",
        skip_missing=False,
        is_gtf=False,
        group_field=None,
        contig_pattern=None,
    )

    (options, args) = E.Start(parser, argv=argv)

    contigs = None
    genome_fasta = None
    if options.input_filename_contigs:
        contigs = Genomics.readContigSizes(
            IOTools.openFile(options.input_filename_contigs, "r"))

    if options.genome_file:
        genome_fasta = IndexedFasta.IndexedFasta(options.genome_file)
        contigs = genome_fasta.getContigSizes()

    if options.method in ("forward_coordinates", "forward_strand",
                          "add-flank", "add-upstream-flank",
                          "add-downstream-flank") \
       and not contigs:
        raise ValueError("inverting coordinates requires genome file")

    if options.input_filename_agp:
        agp = AGP.AGP()
        agp.readFromFile(IOTools.openFile(options.input_filename_agp, "r"))
    else:
        agp = None

    gffs = GTF.iterator(options.stdin)

    if options.method in ("add-upstream-flank",
                          "add-downstream-flank",
                          "add-flank"):

        add_upstream_flank = "add-upstream-flank" == options.method
        add_downstream_flank = "add-downstream-flank" == options.method
        if options.method == "add-flank":
            add_upstream_flank = add_downstream_flank = True

        upstream_flank = int(options.extension_upstream)
        downstream_flank = int(options.extension_downstream)
        extend_flank = options.flank_method == "extend"

        if options.is_gtf:
            iterator = GTF.flat_gene_iterator(gffs)
        else:
            iterator = GTF.joined_iterator(gffs, options.group_field)

        for chunk in iterator:
            is_positive = Genomics.IsPositiveStrand(chunk[0].strand)
            chunk.sort(key=lambda x: (x.contig, x.start))
            lcontig = contigs[chunk[0].contig]

            if extend_flank:
                if add_upstream_flank:
                    if is_positive:
                        chunk[0].start = max(
                            0, chunk[0].start - upstream_flank)
                    else:
                        chunk[-1].end = min(
                            lcontig,
                            chunk[-1].end + upstream_flank)
                if add_downstream_flank:
                    if is_positive:
                        chunk[-1].end = min(lcontig,
                                            chunk[-1].end + downstream_flank)
                    else:
                        chunk[0].start = max(
                            0, chunk[0].start - downstream_flank)
            else:
                if add_upstream_flank:
                    gff = GTF.Entry()
                    if is_positive:
                        gff.copy(chunk[0])
                        gff.end = gff.start
                        gff.start = max(0, gff.start - upstream_flank)
                        chunk.insert(0, gff)
                    else:
                        gff.copy(chunk[-1])
                        gff.start = gff.end
                        gff.end = min(lcontig, gff.end + upstream_flank)
                        chunk.append(gff)
                    gff.feature = "5-Flank"
                    gff.mMethod = "gff2gff"
                if add_downstream_flank:
                    gff = GTF.Entry()
                    if is_positive:
                        gff.copy(chunk[-1])
                        gff.start = gff.end
                        gff.end = min(lcontig, gff.end + downstream_flank)
                        chunk.append(gff)
                    else:
                        gff.copy(chunk[0])
                        gff.end = gff.start
                        gff.start = max(0, gff.start - downstream_flank)
                        chunk.insert(0, gff)
                    gff.feature = "3-Flank"
                    gff.mMethod = "gff2gff"

            if not is_positive:
                chunk.reverse()

            for gff in chunk:
                options.stdout.write(str(gff) + "\n")

    elif options.method == "complement-groups":

        iterator = GTF.joined_iterator(gffs,
                                       group_field=options.group_field)

        for chunk in iterator:
            if options.is_gtf:
                chunk = [x for x in chunk if x.feature == "exon"]
                if len(chunk) == 0:
                    continue
            chunk.sort(key=lambda x: (x.contig, x.start))
            x = GTF.Entry()
            x.copy(chunk[0])
            x.start = x.end
            x.feature = "intron"
            for c in chunk[1:]:
                x.end = c.start
                options.stdout.write(str(x) + "\n")
                x.start = c.end

    elif options.method == "combine-groups":

        iterator = GTF.joined_iterator(gffs,
                                       group_field=options.group_field)

        for chunk in iterator:
            chunk.sort(key=lambda x: (x.contig, x.start))
            x = GTF.Entry()
            x.copy(chunk[0])
            x.end = chunk[-1].end
            x.feature = "segment"
            options.stdout.write(str(x) + "\n")

    elif options.method == "join-features":
        for gff in combineGFF(gffs,
                              min_distance=options.min_distance,
                              max_distance=options.max_distance,
                              min_features=options.min_features,
                              max_features=options.max_features,
                              merge=False,
                              output_format=options.output_format):
            options.stdout.write(str(gff) + "\n")

    elif options.method == "merge-features":
        for gff in combineGFF(gffs,
                              min_distance=options.min_distance,
                              max_distance=options.max_distance,
                              min_features=options.min_features,
                              max_features=options.max_features,
                              merge=True,
                              output_format=options.output_format):
            options.stdout.write(str(gff) + "\n")

    elif options.method == "crop":
        for gff in cropGFF(gffs, options.filename_crop_gff):
            options.stdout.write(str(gff) + "\n")

    elif options.method == "crop-unique":
        for gff in cropGFFUnique(gffs):
            options.stdout.write(str(gff) + "\n")

    elif options.method == "filter-range":

        contig, strand, interval = None, None, None
        try:
            contig, strand, start, sep, end = re.match(
                "(\S+):(\S+):(\d+)(\.\.|-)(\d+)",
                options.filter_range).groups()
        except AttributeError:
            pass

        if not contig:
            try:
                contig, start, sep, end = re.match(
                    "(\S+):(\d+)(\.\.|-)(\d+)", options.filter_range).groups()
                strand = None
            except AttributeError:
                pass

        if not contig:
            try:
                start, sep, end = re.match(
                    "(\d+)(\.\.|\,|\-)(\d+)", options.filter_range).groups()
            except AttributeError:
                raise "can not parse range %s" % options.filter_range
            contig = None
            strand = None

        if start:
            interval = (int(start), int(end))
        else:
            interval = None

        E.debug("filter: contig=%s, strand=%s, interval=%s" %
                (str(contig), str(strand), str(interval)))

        for gff in GTF.iterator_filtered(gffs, contig=contig,
                                         strand=strand,
                                         interval=interval):
            options.stdout.write(str(gff) + "\n")

    elif options.method == "sanitize":

        def toUCSC(id):
            if not id.startswith("contig") and not id.startswith("chr"):
                id = "chr%s" % id
            return id

        def toEnsembl(id):
            if id.startswith("contig"):
                return id[len("contig"):]
            if id.startswith("chr"):
                return id[len("chr"):]
            return id

        if options.sanitize_method == "genome":
            if genome_fasta is None:
                raise ValueError(
                    "please specify --genome-file= when using "
                    "--sanitize-method=genome")
            f = genome_fasta.getToken
        elif options.sanitize_method == "ucsc":
            f = toUCSC
        elif options.sanitize_method == "ensembl":
            f = toEnsembl

        skipped_contigs = collections.defaultdict(int)
        outofrange_contigs = collections.defaultdict(int)
        filtered_contigs = collections.defaultdict(int)
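        # Each counter starts unseen contigs at 0, so the "+= 1" updates
        # below need no explicit initialisation.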

        for gff in gffs:
            try:
                gff.contig = f(gff.contig)
            except KeyError:
                if options.skip_missing:
                    skipped_contigs[gff.contig] += 1
                    continue
                else:
                    raise
                    
            if genome_fasta:
                lcontig = genome_fasta.getLength(gff.contig)
                if lcontig < gff.end:
                    outofrange_contigs[gff.contig] += 1
                    continue

            if options.contig_pattern:
                to_remove = [re.compile(x)
                             for x in options.contig_pattern.split(",")]
                if any([x.search(gff.contig) for x in to_remove]):
                    filtered_contigs[gff.contig] += 1
                    continue

            options.stdout.write(str(gff) + "\n")

        if skipped_contigs:
            E.info("skipped %i entries on %i contigs: %s" %
                   (sum(skipped_contigs.values()),
                    len(list(skipped_contigs.keys(
                    ))),
                    str(skipped_contigs)))

        if outofrange_contigs:
            E.warn("skipped %i entries on %i contigs because they are out of range: %s" %
                   (sum(outofrange_contigs.values()),
                    len(list(outofrange_contigs.keys())),
                    str(outofrange_contigs)))

        if filtered_contigs:
            E.info("filtered out %i entries on %i contigs: %s" %
                   (sum(filtered_contigs.values()),
                    len(list(filtered_contigs.keys())),
                    str(filtered_contigs)))

    else:

        for gff in gffs:

            if options.method == "forward_coordinates":
                gff.invert(contigs[gff.contig])

            if options.method == "forward_strand":
                gff.invert(contigs[gff.contig])
                gff.strand = "+"

            if agp:
                # note: this works only with forward coordinates
                gff.contig, gff.start, gff.end = agp.mapLocation(
                    gff.contig, gff.start, gff.end)

            options.stdout.write(str(gff) + "\n")

    E.Stop()
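
The sanitize branch above leans on collections.defaultdict(int), so skipped_contigs[gff.contig] += 1 works for contigs seen for the first time. The same counting pattern in isolation, with placeholder contig names:

import collections

skipped_contigs = collections.defaultdict(int)
for contig in ["chrUn_1", "chr1", "chrUn_1", "chrUn_2"]:
    if contig.startswith("chrUn"):
        skipped_contigs[contig] += 1  # unseen keys start at 0

print(sum(skipped_contigs.values()), len(skipped_contigs), dict(skipped_contigs))
# 3 2 {'chrUn_1': 2, 'chrUn_2': 1}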

Example 9

Project: cgat
Source File: gff2annotator.py
View license
def main(argv=None):
    """script main.

    parses command line options in sys.argv, unless *argv* is given.
    """

    if argv is None:
        argv = sys.argv

    parser = E.OptionParser(
        version="%prog version: $Id: gff2annotator2tsv.py 2861 2010-02-23 17:36:32Z andreas $", usage=globals()["__doc__"])

    parser.add_option("-g", "--genome-file", dest="genome_file", type="string",
                      help="filename with genome.")

    parser.add_option("-f", "--features", dest="features", type="string",
                      help="feature to collect [default=None].")

    parser.add_option("-i", "--files", dest="files", action="append",
                      help="use multiple annotations [default=None].")

    parser.add_option("-a", "--annotations", dest="annotations", type="string",
                      help="aggregate name for annotations if only single file is provided from STDIN [default=None].")

    parser.add_option("--input-filename-map", dest="input_filename_map", type="string",
                      help="filename with a map of gene_ids to categories [default=None].")

    parser.add_option("--output-filename-synonyms", dest="output_filename_synonyms", type="string",
                      help="output filename for synonyms. For workspace building, the gff source will be used as the id (instead of the contig) [default=None].")

    parser.add_option("-m", "--max-length", dest="max_length", type="string",
                      help="maximum segment length [default=None].")

    parser.add_option("-s", "--section", dest="section", type="choice",
                      choices=("segments", "annotations", "annotations-genes",
                               "annotations-go", "workspace", "annotations-gff"),
                      help="annotator section [default=None].")

    parser.add_option("--subset", dest="subsets", type="string", action="append",
                      help="add filenames to delimit subsets within the gff files. The syntax is filename.gff,label,filename.ids [default=None].")

    parser.add_option("--remove-regex", dest="remove_regex", type="string",
                      help="regular expression of contigs to remove [default=None].")

    parser.set_defaults(
        genome_file=None,
        feature=None,
        section="segments",
        annotations="annotations",
        max_length=100000,
        files=[],
        subsets=[],
        input_filename_map=None,
        output_filename_synonyms=None,
        input_format="gff",
        remove_regex=None,
    )

    (options, args) = E.Start(parser)

    options.files += args
    if len(options.files) == 0:
        options.files.append("-")
    options.files = list(
        itertools.chain(*[re.split("[,; ]+", x) for x in options.files]))

    if options.subsets:
        subsets = collections.defaultdict(list)
        for s in options.subsets:
            filename_gff, label, filename_ids = s.split(",")
            subsets[filename_gff].append((label, filename_ids))
        options.subsets = subsets

    if options.genome_file:
        fasta = IndexedFasta.IndexedFasta(options.genome_file)
    else:
        fasta = None

    if options.section == "segments":
        prefix = "##Segs"
    elif options.section.startswith("annotations"):
        prefix = "##Id"
    elif options.section == "workspace":
        prefix = "##Work"
    else:
        raise ValueError("unknown section %s" % options.section)

    ninput, ncontigs, nsegments, ndiscarded = 0, 0, 0, 0

    if options.remove_regex:
        options.remove_regex = re.compile(options.remove_regex)

    if options.section in ("segments", "workspace"):

        iterator = GTF.iterator_filtered(GFF.iterator(options.stdin),
                                         feature=options.feature)

        if options.output_filename_synonyms:
            outfile_synonyms = open(options.output_filename_synonyms, "w")
            with_records = True
        else:
            outfile_synonyms = None
            with_records = False

        intervals = GTF.readAsIntervals(iterator, with_records=with_records)
        ninput, nsegments, ndiscarded, ncontigs = \
            PipelineEnrichment.outputSegments(options.stdout,
                                              intervals,
                                              options.section,
                                              outfile_synonyms=outfile_synonyms,
                                              max_length=options.max_length,
                                              remove_regex=options.remove_regex)

        if outfile_synonyms:
            outfile_synonyms.close()

    elif options.section == "annotations-go":

        assert options.input_filename_map, "please supply option --input-filename-map"

        iterator = GTF.iterator_filtered(GTF.iterator(options.stdin),
                                         feature=options.feature)

        geneid2categories = IOTools.readMultiMap(
            open(options.input_filename_map, "r"))

        category2segments = collections.defaultdict(list)
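        # Each new category key materialises an empty list on first access,
        # so segment ids can be appended without any setup.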

        for contig, gffs in GTF.readAsIntervals(iterator, with_gene_id=True).items():
            if options.remove_regex and options.remove_regex.search(contig):
                continue

            for start, end, geneid in gffs:
                if geneid not in geneid2categories:
                    continue
                for category in geneid2categories[geneid]:
                    category2segments[category].append(nsegments)

                options.stdout.write(
                    "%s\t%i\t%s\t(%i,%i)\n" % (prefix, nsegments, contig, start, end))
                nsegments += 1

        for category, segments in category2segments.iteritems():
            options.stdout.write(
                "##Ann\t%s\t%s\n" % (category, "\t".join(["%i" % x for x in segments])))
            E.info("set %s annotated with %i segments" %
                   (category, len(segments)))

    elif options.section == "annotations":

        for filename in options.files:

            E.info("adding filename %s" % filename)

            start = nsegments
            is_gtf = False

            if filename == "-":
                iterator = GTF.iterator_filtered(GFF.iterator(sys.stdin),
                                                 feature=options.feature)
                filename = options.annotations
            elif filename.endswith(".gtf"):
                is_gtf = True
                with open(filename, "r") as infile:
                    iterator = GTF.iterator_filtered(GTF.iterator(infile),
                                                     feature=options.feature)

            else:
                with open(filename, "r") as infile:
                    iterator = GTF.iterator_filtered(GFF.iterator(infile),
                                                     feature=options.feature)

            E.debug("processing %s" % (filename))

            if not options.subsets or filename not in options.subsets:
                for contig, gffs in GTF.readAsIntervals(iterator).items():
                    if options.remove_regex and options.remove_regex.search(contig):
                        continue

                    for x in gffs:
                        options.stdout.write(
                            "%s\t%i\t%s\t(%i,%i)\n" % (prefix, nsegments, contig, x[0], x[1]))
                        nsegments += 1

                options.stdout.write("##Ann\t%s\t%s\n" % (
                    filename, "\t".join(["%i" % x for x in range(start, nsegments)])))
                E.info("set %s annotated with %i segments" %
                       (filename, nsegments - start))

            else:
                raise ValueError("don't know how to filter %s" % filename)

    elif options.section == "annotations-gff":

        for filename in options.files:
            if filename == "-":
                iterator = GTF.iterator(sys.stdin)
            else:
                iterator = GTF.iterator_filtered(
                    GFF.iterator(open(filename, "r")))

            segments = collections.defaultdict(list)
            for gff in iterator:
                segments[":".join((gff.source, gff.feature))].append(
                    (gff.contig, gff.start, gff.end))

            feature2segments = {}

            for feature, s in segments.iteritems():
                s.sort()

                s1 = nsegments

                for contig, start, end in s:
                    if options.remove_regex and options.remove_regex.search(contig):
                        continue

                    options.stdout.write(
                        "%s\t%i\t%s\t(%i,%i)\n" % (prefix, nsegments, contig, start, end))
                    nsegments += 1

                feature2segments[feature] = (s1, nsegments)

        for feature, id_range in feature2segments.iteritems():
            start, end = id_range
            options.stdout.write("##Ann\t%s\t%s\n" % (
                feature, "\t".join(["%i" % x for x in xrange(start, end)])))
            E.info("set %s annotated with %i segments" %
                   (feature, end - start))

    elif options.section == "annotations-genes":

        for filename in options.files:

            E.info("adding filename %s" % filename)

            start = nsegments

            assert filename.endswith(".gtf") or filename.endswith(".gtf.gz"), \
                "requiring .gtf files for gene list filtering, received %s" % filename

            infile = IOTools.openFile(filename)
            iterator = GTF.iterator_filtered(GTF.iterator(infile),
                                             feature=options.feature)

            E.debug("processing %s" % (filename))

            if not options.subsets or filename not in options.subsets:
                # output all
                for contig, gffs in GTF.readAsIntervals(iterator).items():
                    if options.remove_regex and options.remove_regex.search(contig):
                        continue

                    for x in gffs:
                        options.stdout.write(
                            "%s\t%i\t%s\t(%i,%i)\n" % (prefix, nsegments, contig, x[0], x[1]))
                        nsegments += 1

                options.stdout.write("##Ann\t%s\t%s\n" % (
                    filename, "\t".join(["%i" % x for x in range(start, nsegments)])))
                E.info("set %s annotated with %i segments" %
                       (filename, nsegments - start))

            else:
                # create subsets
                E.debug("applying subsets for %s" % filename)
                geneid2label, label2segments = collections.defaultdict(
                    list), {}
                for label, filename_ids in options.subsets[filename]:
                    gene_ids = IOTools.readList(open(filename_ids, "r"))
                    for gene_id in gene_ids:
                        geneid2label[gene_id].append(label)
                    label2segments[label] = []

                for contig, gffs in GTF.readAsIntervals(iterator, with_gene_id=True).items():

                    if options.remove_regex and options.remove_regex.search(contig):
                        continue

                    for start, end, gene_id in gffs:
                        if gene_id not in geneid2label:
                            continue
                        for label in geneid2label[gene_id]:
                            label2segments[label].append(nsegments)

                        options.stdout.write(
                            "%s\t%i\t%s\t(%i,%i)\n" % (prefix, nsegments, contig, start, end))
                        nsegments += 1

                for label, segments in label2segments.iteritems():
                    options.stdout.write(
                        "##Ann\t%s\t%s\n" % (label, "\t".join(["%i" % x for x in segments])))
                    E.info("set %s (%s) annotated with %i segments" %
                           (label, filename, len(segments)))

    E.info("ninput=%i, ncontigs=%i, nsegments=%i, ndiscarded=%i" %
           (ninput, ncontigs, nsegments, ndiscarded))

    E.Stop()

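This script reaches for collections.defaultdict(list) twice: once to bucket GFF intervals under a source:feature key, and once (in the subset branch) to map each gene id to the labels it belongs to. A minimal standalone sketch of that grouping pattern, with made-up record tuples standing in for parsed GFF lines:

import collections

# Hypothetical (contig, start, end, feature) records in place of real GFF input.
records = [
    ("chr1", 100, 200, "exon"),
    ("chr1", 300, 400, "exon"),
    ("chr2", 50, 80, "CDS"),
]

segments = collections.defaultdict(list)
for contig, start, end, feature in records:
    # A missing key is created with an empty list automatically,
    # so no "if feature not in segments" guard is needed.
    segments[feature].append((contig, start, end))

for feature, intervals in sorted(segments.items()):
    print("{0}: {1}".format(feature, intervals))
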
Example 10

Project: UMI-tools
Source File: dedup.py
View license
def get_bundles(insam, ignore_umi=False, subset=None, quality_threshold=0,
                paired=False, chrom=None, spliced=False, soft_clip_threshold=0,
                per_contig=False, whole_contig=False, read_length=False,
                detection_method="MAPQ"):
    ''' Returns a dictionary of dictionaries, representing the unique reads at
    a position/spliced/strand combination. The key to the top level dictionary
    is a umi. Each dictionary contains a "read" entry with the best read, and a
    count entry with the number of reads with that position/spliced/strand/umi
    combination'''

    last_pos = 0
    last_chr = ""
    reads_dict = collections.defaultdict(
        lambda: collections.defaultdict(
            lambda: collections.defaultdict(dict)))
    read_counts = collections.defaultdict(
        lambda: collections.defaultdict(dict))

    if chrom:
        inreads = insam.fetch(reference=chrom)
    else:
        inreads = insam.fetch()

    for read in inreads:

        if subset:
            if random.random() >= subset:
                continue

        if quality_threshold:
            if read.mapq < quality_threshold:
                continue

        if read.is_unmapped:
            continue

        if read.mate_is_unmapped and paired:
            continue

        if read.is_read2:
            continue

        # TS - some methods require deduping on a per contig
        # (gene for transcriptome) basis, e.g Soumillon et al 2014
        # to fit in with current workflow, simply assign pos and key as contig
        if per_contig:

            pos = read.tid
            key = read.tid
            if not read.tid == last_chr:

                out_keys = reads_dict.keys()

                for p in out_keys:
                    for bundle in reads_dict[p].values():
                        yield bundle
                    del reads_dict[p]
                    del read_counts[p]

                last_chr = read.tid

        else:

            is_spliced = False

            if read.is_reverse:
                pos = read.aend
                if read.cigar[-1][0] == 4:
                    pos = pos + read.cigar[-1][1]
                start = read.pos

                if ('N' in read.cigarstring or
                    (read.cigar[0][0] == 4 and
                     read.cigar[0][1] > soft_clip_threshold)):
                    is_spliced = True
            else:
                pos = read.pos
                if read.cigar[0][0] == 4:
                    pos = pos - read.cigar[0][1]
                start = pos

                if ('N' in read.cigarstring or
                    (read.cigar[-1][0] == 4 and
                     read.cigar[-1][1] > soft_clip_threshold)):
                    is_spliced = True

            if whole_contig:
                do_output = not read.tid == last_chr
            else:
                do_output = start > (last_pos+1000) or not read.tid == last_chr

            if do_output:

                out_keys = [x for x in reads_dict.keys() if x <= start-1000]

                for p in out_keys:
                    for bundle in reads_dict[p].values():
                        yield bundle
                    del reads_dict[p]
                    del read_counts[p]

                last_pos = start
                last_chr = read.tid

            if read_length:
                r_length = read.query_length
            else:
                r_length = 0

            key = (read.is_reverse, spliced & is_spliced,
                   paired*read.tlen, r_length)

        if ignore_umi:
            umi = ""
        else:
            umi = read.qname.split("_")[-1]

        try:
            reads_dict[pos][key][umi]["count"] += 1
        except KeyError:
            reads_dict[pos][key][umi]["read"] = read
            reads_dict[pos][key][umi]["count"] = 1
            read_counts[pos][key][umi] = 0
        else:
            if reads_dict[pos][key][umi]["read"].mapq > read.mapq:
                continue

            if reads_dict[pos][key][umi]["read"].mapq < read.mapq:
                reads_dict[pos][key][umi]["read"] = read
                read_counts[pos][key][umi] = 0
                continue

            # TS: implemented different checks for multimapping here
            if detection_method in ["NH", "X0"]:
                tag = detection_method
                if reads_dict[pos][key][umi]["read"].opt(tag) < read.opt(tag):
                    continue
                elif reads_dict[pos][key][umi]["read"].opt(tag) > read.opt(tag):
                    reads_dict[pos][key][umi]["read"] = read
                    read_counts[pos][key][umi] = 0

            elif detection_method == "XT":
                if reads_dict[pos][key][umi]["read"].opt("XT") == "U":
                    continue
                elif read.opt("XT") == "U":
                    reads_dict[pos][key][umi]["read"] = read
                    read_counts[pos][key][umi] = 0

            read_counts[pos][key][umi] += 1
            prob = 1.0/read_counts[pos][key][umi]

            if random.random() < prob:
                reads_dict[pos][key][umi]["read"] = read

    # yield remaining bundles
    for p in reads_dict:
        for bundle in reads_dict[p].values():
            yield bundle

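get_bundles builds reads_dict as a three-level defaultdict by nesting lambda factories, so reads_dict[pos][key][umi] can be written to without creating any of the intermediate dictionaries first. A minimal sketch of that nesting, with invented position, key and UMI values:

import collections

reads_dict = collections.defaultdict(
    lambda: collections.defaultdict(
        lambda: collections.defaultdict(dict)))

# Each lookup of a missing key calls the next factory in the chain, so all
# three levels spring into existence on the first assignment.
reads_dict[1000][("fwd", False)]["ACGT"]["count"] = 1
reads_dict[1000][("fwd", False)]["ACGT"]["count"] += 1

print(reads_dict[1000][("fwd", False)]["ACGT"]["count"])  # 2

The innermost factory is a plain dict, so the "read" and "count" fields still raise KeyError until they are assigned, which is exactly what the try/except KeyError block in the example relies on.
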
Example 11

Project: networking-brocade
Source File: config.py
View license
def connect_setup_commands(vrouter, iface, svc, conn, resources):
        LOG.info(_LI('Vyatta vRouter: _connect_setup_commands:  '))
        SCmd = vyatta_client.SetCmd
        batch = list()

        if resources.iface_alloc(conn, iface):
            batch.append(
                SCmd('vpn/ipsec/ipsec-interfaces/interface/{0}'.format(iface)))

        ike_name, need_create = resources.ike_group_alloc(conn)
        if need_create:
            batch.extend(ike_setup_commands(conn, ike_name))

        esp_name, need_create = resources.esp_group_alloc(conn)
        if need_create:
            batch.extend(esp_setup_commands(conn, esp_name))

        link_id = [uuid.UUID(x).get_hex() for x in svc['id'], conn['id']]
        link_id.insert(0, 'os-id')
        link_id = ':'.join(link_id)
        remote_peer = conn['peer_address']
        p = 'vpn/ipsec/site-to-site/peer/{0}'.format(remote_peer)
        batch.extend([
            SCmd('{0}/description/{1}'.format(p, link_id)),
            SCmd('{0}/authentication/mode/pre-shared-secret'.format(p)),
            SCmd('{0}/authentication/pre-shared-secret/{1}'.format(
                p, conn['psk'])),
            SCmd('{0}/ike-group/{1}'.format(p, ike_name)),
            SCmd('{0}/default-esp-group/{1}'.format(p, esp_name)),
            SCmd('{0}/local-address/{1}'.format(
                p, urllib.quote_plus(svc['external_ip'])))])
        for remote_cidr in conn['peer_cidrs']:
            idx = resources.tunnel_alloc(conn, remote_peer)
            batch.append(SCmd(
                '{0}/tunnel/{1}/allow-public-networks/enable'.format(p, idx)))
            batch.append(SCmd('{0}/tunnel/{1}/local/prefix/{2}'.format(
                p, idx, urllib.quote_plus(svc['subnet']['cidr']))))
            batch.append(SCmd('{0}/tunnel/{1}/remote/prefix/{2}'.format(
                p, idx, urllib.quote_plus(remote_cidr))))
            LOG.info(_LI('Vyatta vRouter: _connect_setup_commands: '
                         'add snat-exclude for remote_cidr %s'), remote_cidr)
            vrouter.add_snat_exclude_rule(batch, iface,
                svc['subnet']['cidr'], remote_cidr)

        # TODO(dbogun): static routing for remote networks
        return batch


def ike_setup_commands(conn, name):
        policy = conn[_KEY_IKEPOLICY]

        SCmd = vyatta_client.SetCmd

        ike_prefix = 'vpn/ipsec/ike-group/{0}'.format(urllib.quote_plus(name))
        return [
            SCmd('{0}/proposal/1'.format(ike_prefix)),
            SCmd('{0}/proposal/1/encryption/{1}'.format(
                ike_prefix, policy['encryption_algorithm'])),
            SCmd('{0}/proposal/1/hash/{1}'.format(
                ike_prefix, policy['auth_algorithm'])),
            SCmd('{0}/lifetime/{1}'.format(
                ike_prefix, policy['lifetime_value'])),
            SCmd('{0}/dead-peer-detection/action/{1}'.format(
                ike_prefix, conn['dpd_action'])),
            SCmd('{0}/dead-peer-detection/interval/{1}'.format(
                ike_prefix, conn['dpd_interval'])),
            SCmd('{0}/dead-peer-detection/timeout/{1}'.format(
                ike_prefix, conn['dpd_timeout']))
        ]


def esp_setup_commands(conn, name):
        policy = conn[_KEY_ESPPOLICY]

        SCmd = vyatta_client.SetCmd

        esp_prefix = 'vpn/ipsec/esp-group/{0}'.format(urllib.quote_plus(name))
        return [
            SCmd('{0}/proposal/1'.format(esp_prefix)),
            SCmd('{0}/proposal/1/encryption/{1}'.format(
                esp_prefix, policy['encryption_algorithm'])),
            SCmd('{0}/proposal/1/hash/{1}'.format(
                esp_prefix, policy['auth_algorithm'])),
            SCmd('{0}/lifetime/{1}'.format(
                esp_prefix, policy['lifetime_value'])),
            SCmd('{0}/pfs/{1}'.format(
                esp_prefix, policy['pfs'])),
            SCmd('{0}/mode/{1}'.format(
                esp_prefix, policy['encapsulation_mode']))
        ]


def connect_remove_commands(vrouter, iface, svc, conn, resources):
        LOG.info(_LI('Vyatta vRouter: _connect_setup_commands:  '))

        DCmd = vyatta_client.DeleteCmd

        batch = [
            DCmd('vpn/ipsec/site-to-site/peer/{0}'.format(
                conn['peer_address']))]

        name, need_remove = resources.ike_group_release(conn)
        if need_remove:
            batch.append(
                DCmd('vpn/ipsec/ike-group/{0}'.format(
                    urllib.quote_plus(name))))
        name, need_remove = resources.esp_group_release(conn)
        if need_remove:
            batch.append(
                DCmd('vpn/ipsec/esp-group/{0}'.format(
                    urllib.quote_plus(name))))

        if resources.iface_release(conn, iface):
            # FIXME(dbogun): vrouter failed to complete this command
            # batch.append(
            #     DCmd('vpn/ipsec/ipsec-interfaces/interface/{0}'.format(
            #         iface)))
            pass

        for remote_cidr in conn['peer_cidrs']:
            LOG.info(_LI('Vyatta vRouter: _connect_setup_commands: '
                         'delete snat-exclude for remote_cidr %s'),
                     remote_cidr)
            vrouter.delete_snat_exclude_rule(batch, iface,
                svc['subnet']['cidr'], remote_cidr)

        resources.tunnel_release(conn, svc['external_ip'])
        return batch


def compare_vpn_services(vrouter, gw_iface, old, new):
    if old.get('_reversed', False):
        old_conn_by_id = dict((x['id'], x) for x in old[_KEY_CONNECTIONS])
        for conn_new in new[_KEY_CONNECTIONS]:
            try:
                old_conn = old_conn_by_id[conn_new['id']]
            except KeyError:
                continue
            old_conn['psk'] = conn_new['psk']

    batch_old = set()
    batch_new = set()
    for svc, batch in (
            (old, batch_old), (new, batch_new)):
        for conn in svc[_KEY_CONNECTIONS]:
            patch = RouterResources(
                svc['id']).make_patch()
            commands = connect_setup_commands(
                vrouter, gw_iface, svc, conn, patch)
            commands = tuple((x.make_url('_dummy_') for x in commands))
            batch.add(commands)
    return batch_old == batch_new


VPN_STATE_MAP = {
    'up': True,
    'down': False
}


def parse_vpn_connections(ipsec_sa, resources):
    parser = v_parsers.TableParser()
    parser(ipsec_sa)
    assert not (len(parser) % 2)

    conn_status = collections.defaultdict(list)
    parser_iter = iter(parser)
    for endpoints, tunnels in itertools.izip(parser_iter, parser_iter):
        endpoints = iter(endpoints)
        try:
            peer = next(endpoints)
            peer = peer.cell_by_idx(0)
        except StopIteration:
            raise ValueError('Invalid VPN IPSec status report.')

        for tunn in tunnels:
            tunn_idx = tunn.cell_by_name('Tunnel')
            try:
                tunn_idx = int(tunn_idx)
            except ValueError:
                raise v_exc.InvalidResponseFormat(
                    details=('expect integer on place of tunnel index '
                             'got {!r}').format(tunn_idx))

            try:
                conn_id = resources.get_connect_by_tunnel(
                    peer, tunn_idx)
            except v_exc.ResourceNotFound:
                continue

            state = tunn.cell_by_name('State')
            state = state.lower()
            try:
                state = VPN_STATE_MAP[state]
            except KeyError:
                raise v_exc.InvalidResponseFormat(
                    details='unsupported tunnel state {!r}'.format(
                        state))
            conn_status[conn_id].append(state)

    conn_ok = set(key for key, value in six.iteritems(conn_status)
                  if all(value))

    return conn_ok


def validate_svc_connection(conn):
    # NOTE(asaprykin): Maybe it's better to move code into
    # separate function to avoid long try..except blocks
    try:
        map_encryption = {
            '3des': '3des',
            'aes-128': 'aes128',
            'aes-256': 'aes256'}
        map_pfs = {
            'group2': 'dh-group2',
            'group5': 'dh-group5'}

        allowed_pfs = map_pfs.values() + ['enable', 'disable']

        if conn['dpd_action'] not in ('hold', 'clear', 'restart'):
            raise ValueError('invalid dpd_action {0}'.format(
                conn['dpd_action']))

        ike_policy = conn[_KEY_IKEPOLICY]
        ike_policy['encryption_algorithm'] = \
            map_encryption[ike_policy['encryption_algorithm']]
        if ike_policy['lifetime_units'] != 'seconds':
            raise ValueError(
                'invalid "lifetime_units"=="{}" in ike_policy'.format(
                    ike_policy['lifetime_units']))

        esp_policy = conn[_KEY_ESPPOLICY]
        esp_policy['encryption_algorithm'] = \
            map_encryption[esp_policy['encryption_algorithm']]
        if esp_policy['lifetime_units'] != 'seconds':
            raise ValueError(
                'invalid "lifetime_units"=="{}" in esp_policy'.format(
                    esp_policy['lifetime_units']))
        if esp_policy['transform_protocol'] != 'esp':
            raise ValueError(
                'invalid "transform_protocol"=="{}" in esp_policy'.format(
                    esp_policy['transform_protocol']))
        pfs = esp_policy['pfs']
        esp_policy['pfs'] = map_pfs.get(pfs, pfs)
        if esp_policy['pfs'] not in allowed_pfs:
            raise ValueError(
                'invalid "pfs"=="{}" in esp_policy'.format(pfs))
        if esp_policy['encapsulation_mode'] not in ('tunnel', 'transport'):
            raise ValueError(
                'invalid "encapsulation_mode"=="{}" in esp_policy'.format(
                    esp_policy['encapsulation_mode']))
    except (ValueError, KeyError):
        with excutils.save_and_reraise_exception():
            raise v_exc.InvalidVPNServiceError()


# -- Tools ------------------------
class RouterResources(object):
    def __init__(self, router_id):
        self.router_id = router_id

        self.iface_to_conn = collections.defaultdict(set)
        self.ike_to_conn = collections.defaultdict(set)
        self.esp_to_conn = collections.defaultdict(set)
        self.conn_to_tunn = collections.defaultdict(set)
        self.peer_to_conn = dict()
        self.tunn_idx_factory = collections.defaultdict(itertools.count)

    def make_patch(self):
        return _RouteResourcePatch(self)

    def key_for_conn(self, conn):
        return ('connect', conn['id'])

    def key_for_tunnel(self, conn, remote_peer):
        key = list(self.key_for_conn(conn))
        key[0] = 'tunn'
        key.append(remote_peer)
        return tuple(key)

    def get_connect_by_tunnel(self, remote_peer, idx):
        key = (remote_peer, idx)
        try:
            conn_id = self.peer_to_conn[key]
        except KeyError:
            raise v_exc.ResourceNotFound
        return conn_id


class _RouteResourcePatch(object):
    NAME_LENGTH_LIMIT = 251

    def __init__(self, owner):
        self._owner = owner
        self._actions = list()
        self._tunn_factory_overlay = dict()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_val is not None:
            return

        self._apply()

    def iface_alloc(self, conn, name):
        key = self._owner.key_for_conn(conn)
        self._actions.append(_PatchSetAdd(
            self._owner.iface_to_conn[key], conn['id']))
        return not len(self._owner.iface_to_conn[key])

    def iface_release(self, conn, name):
        key = self._owner.key_for_conn(conn)
        self._actions.append(_PatchSetDel(
            self._owner.iface_to_conn[key], conn['id']))
        return self._owner.iface_to_conn[key] == {conn['id']}

    def ike_group_alloc(self, conn):
        return self._group_alloc(self._owner.ike_to_conn, conn, _KEY_IKEPOLICY)

    def esp_group_alloc(self, conn):
        return self._group_alloc(self._owner.esp_to_conn, conn, _KEY_ESPPOLICY)

    def _group_alloc(self, target, conn, policy_key):
        name = self._make_entity_name(conn[policy_key])
        self._actions.append(_PatchSetAdd(target[name], conn['id']))
        return name, not len(target[name])

    def ike_group_release(self, conn):
        return self._group_release(
            self._owner.ike_to_conn, conn, _KEY_IKEPOLICY)

    def esp_group_release(self, conn):
        return self._group_release(
            self._owner.esp_to_conn, conn, _KEY_ESPPOLICY)

    def _group_release(self, target, conn, policy_key):
        name = self._make_entity_name(conn[policy_key])
        self._actions.append(_PatchSetDel(target[name], conn['id']))
        return name, target[name] == {conn['id']}

    def tunnel_alloc(self, conn, remote_peer, idx=None):
        conn_key = self._owner.key_for_conn(conn)
        tunn_key = self._owner.key_for_tunnel(conn, remote_peer)
        if idx is None:
            idx = next(self._owner.tunn_idx_factory[tunn_key])
        else:
            try:
                max_idx, overlay = self._tunn_factory_overlay[tunn_key]
                max_idx = max(idx, max_idx)
                overlay.value = itertools.count(max_idx + 1)
            except KeyError:
                max_idx = idx
                overlay = _PatchDictAdd(
                    self._owner.tunn_idx_factory, tunn_key,
                    itertools.count(max_idx + 1))
                self._actions.append(overlay)
            self._tunn_factory_overlay[tunn_key] = max_idx, overlay

        self._actions.append(_PatchDictAdd(
            self._owner.peer_to_conn, (remote_peer, idx), conn['id']))
        self._actions.append(_PatchSetAdd(
            self._owner.conn_to_tunn[conn_key], idx))
        return idx

    def tunnel_release(self, conn, remote_peer, idx=None):
        key = self._owner.key_for_conn(conn)

        if idx is None:
            idx_seq = self._owner.conn_to_tunn[key]
        else:
            idx_seq = (idx, )

        for idx in idx_seq:
            self._actions.append(_PatchDictDel(
                self._owner.peer_to_conn, (remote_peer, idx)))
            self._actions.append(_PatchSetDel(
                self._owner.conn_to_tunn[key], idx))

    def _apply(self):
        for action in self._actions:
            action()

    def _make_entity_name(self, data):
        idnr = uuid.UUID(data['id'])
        name = data['name'].lower()
        name = ''.join((x if x.isalnum() else '') for x in name)
        name = '{0}-{1}'.format(name, idnr.get_hex())
        if self.NAME_LENGTH_LIMIT < len(name):
            raise v_exc.InvalidParameter(
                cause=('Can\'t make vyatta resource identifier, result exceed '
                       'length limit'))
        return name


@six.add_metaclass(abc.ABCMeta)
class _PatchActionAbstract(object):
    def __init__(self, target):
        self.target = target

    @abc.abstractmethod
    def __call__(self):
        pass


class _PatchDictAdd(_PatchActionAbstract):
    def __init__(self, target, key, value):
        _PatchActionAbstract.__init__(self, target)
        self.key = key
        self.value = value

    def __call__(self):
        self.target[self.key] = self.value


class _PatchDictDel(_PatchActionAbstract):
    def __init__(self, target, key, allow_missing=True):
        _PatchActionAbstract.__init__(self, target)
        self.key = key
        self.allow_missing = allow_missing

    def __call__(self):
        try:
            del self.target[self.key]
        except KeyError:
            if not self.allow_missing:
                raise v_exc.ResourceNotFound


class _PatchSetAdd(_PatchActionAbstract):
    def __init__(self, target, value):
        _PatchActionAbstract.__init__(self, target)
        self.value = value

    def __call__(self):
        self.target.add(self.value)


class _PatchSetDel(_PatchActionAbstract):
    def __init__(self, target, value, allow_missing=True):
        _PatchActionAbstract.__init__(self, target)
        self.value = value
        self.allow_missing = allow_missing

    def __call__(self):
        try:
            self.target.remove(self.value)
        except KeyError:
            if not self.allow_missing:
                raise v_exc.ResourceNotFound


def parse_vrouter_config(config, resources):
    try:
        config = config['vpn']
        config = config['ipsec']
        interfaces = config['ipsec-interfaces']
        site_to_site = config['site-to-site']
    except KeyError:
        return tuple()

    interfaces = interfaces.values()
    svc_set = collections.defaultdict(dict)

    key_prefix = 'peer '
    for peer in site_to_site:
        conn = site_to_site[peer]
        if not peer.startswith(key_prefix):
            continue
        peer = peer[len(key_prefix):]

        try:
            conn_data, svc_data = _parse_ipsec_site_to_site(
                peer, conn, config, resources)
        except v_exc.InvalidResponseFormat as e:
            LOG.error(
                _LE('process vRouter ipsec configuration: {0}').format(e))
            continue

        svc_id = svc_data['id']
        svc = svc_set[svc_id]
        svc.setdefault('_reversed', True)
        svc.update(svc_data)
        svc.setdefault('ipsec_site_connections', list()).append(conn_data)

        resources.ike_group_alloc(conn_data)
        resources.esp_group_alloc(conn_data)
        for iface in interfaces:
            resources.iface_alloc(conn_data, iface)
        for idx, remote in conn_data.pop('_tunn_and_cidr'):
            resources.tunnel_alloc(conn_data, remote, idx)

    return svc_set.values()


def _parse_ipsec_site_to_site(peer, conn, config, resources):
    result = dict()
    svc_upd = dict()

    try:
        idnr_set = conn['description'].split(':')
        if idnr_set.pop(0) != 'os-id':
            raise ValueError

        svc_id, conn_id = [str(uuid.UUID(x)) for x in idnr_set]
    except (KeyError, TypeError, ValueError):
        raise v_exc.InvalidResponseFormat(
            details='vpn connection does not contain neutron connection id')

    try:
        auth = conn['authentication']

        result['id'] = conn_id
        result['peer_address'] = peer
        result['psk'] = auth['pre-shared-secret']

        svc_upd['id'] = svc_id
        svc_upd['external_ip'] = conn['local-address']
    except KeyError:
        raise v_exc.InvalidResponseFormat(
            details='incomplete connection config')

    svc_upd['subnet'] = subnet = dict()

    key_prefix = 'tunnel '
    peer_cidr = list()
    for key in (x for x in conn if x.startswith(key_prefix)):
        idx, local, remote = _parse_ipsec_tunnel(key, conn)
        peer_cidr.append((idx, remote))
        subnet['cidr'] = local
    peer_cidr.sort(key=lambda x: x[0])
    result['_tunn_and_cidr'] = peer_cidr
    result['peer_cidrs'] = [x[1] for x in peer_cidr]

    ike, upd = _parse_ipsec_ike_group(conn, config, resources)
    result['ikepolicy'] = ike
    result.update(upd)

    esp, upd = _parse_ipsec_esp_group(conn, config, resources)
    result['ipsecpolicy'] = esp
    result.update(upd)

    return result, svc_upd


def _parse_ipsec_ike_group(conn, config, resources):
    result = dict()
    conn_upd = dict()

    try:
        name = conn['ike-group']
        policy = config['ike-group ' + name]
    except KeyError:
        raise v_exc.InvalidResponseFormat(details='ike group missing')

    name, idnr = _unpack_ipsec_group_name(name)

    try:
        dpd = policy['dead-peer-detection']
        proposal = policy['proposal 1']

        result['id'] = idnr
        result['name'] = name
        result['encryption_algorithm'] = proposal['encryption']
        result['auth_algorithm'] = proposal['hash']
        result['lifetime_value'] = policy['lifetime']

        conn_upd['dpd_action'] = dpd['action']
        conn_upd['dpd_interval'] = dpd['interval']
        conn_upd['dpd_timeout'] = dpd['timeout']
    except KeyError:
        raise v_exc.InvalidResponseFormat(
            details='incomplete IKE group config')

    return result, conn_upd


def _parse_ipsec_esp_group(conn, config, resources):
    result = dict()
    try:
        name = conn['default-esp-group']
        policy = config['esp-group ' + name]
    except KeyError:
        raise v_exc.InvalidResponseFormat(details='eps group missing')

    name, idnr = _unpack_ipsec_group_name(name)

    try:
        proposal = policy['proposal 1']

        result['id'] = idnr
        result['name'] = name
        result['encryption_algorithm'] = proposal['encryption']
        result['auth_algorithm'] = proposal['hash']
        result['lifetime_value'] = policy['lifetime']
        result['pfs'] = policy['pfs']
        result['encapsulation_mode'] = policy['mode']
    except KeyError:
        raise v_exc.InvalidResponseFormat(
            details='incomplete ESP group config')

    return result, dict()


def _parse_ipsec_tunnel(key, conn):
    try:
        peer = conn[key]
        idx = key.rsplit(' ', 1)[1]
        idx = int(idx)
    except (KeyError, IndexError, ValueError):
        raise v_exc.InvalidResponseFormat(
            details='invalid tunnel section "{0}"'.format(key))

    try:
        local = peer['local']
        local = local['prefix']

        remote = peer['remote']
        remote = remote['prefix']
    except KeyError:
        raise v_exc.InvalidResponseFormat(
            details='incomplete peer config')

    return idx, local, remote


def _unpack_ipsec_group_name(raw):
    try:
        name, idnr = raw.rsplit('-', 1)
        idnr = uuid.UUID(idnr)
        idnr = str(idnr)
    except (TypeError, ValueError):
        raise v_exc.InvalidResponseFormat(
            details='can\'t parse group name "{0}"'.format(raw))
    return name, idnr

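This module uses several different default factories: set for the reference-counting maps in RouterResources, dict for svc_set, list for conn_status, and itertools.count for tunn_idx_factory, which hands out an independent 0, 1, 2, ... sequence per tunnel key. A minimal sketch of that counter-per-key idea, with invented key values:

import collections
import itertools

# Every new key gets its own itertools.count() iterator.
tunn_idx_factory = collections.defaultdict(itertools.count)

print(next(tunn_idx_factory[("conn-a", "203.0.113.1")]))  # 0
print(next(tunn_idx_factory[("conn-a", "203.0.113.1")]))  # 1
print(next(tunn_idx_factory[("conn-b", "198.51.100.7")]))  # 0 -- a separate counter
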
Example 12

Project: oppia
Source File: stats_jobs_continuous.py
View license
    @staticmethod
    def reduce(key, stringified_values):
        """Updates statistics for the given (exploration, version) and list of
        events and creates batch model(ExplorationAnnotationsModel) for storing
        this output.

        Args:
            key: str. The exploration id and version of the exploration in the
                form 'exploration_id.version'.
            stringified_values: list(str). A list of stringified values
                associated with the given key. It includes information depending
                on the type of event. If type of event is 'event',
                an element of stringified_values would be:
                '{
                    'type': 'event',
                    'event_type': Type of event triggered,
                    'session_id': ID of current student's session,
                    'state_name': Name of current state,
                    'created_on': How many milliseconds ago the exploration was
                        created,
                    'exploration_id': The ID of the exploration,
                    'version': Version of exploration}'
                If type is 'counter', then an element of stringified_values
                would be of the form:
                '{
                    'type': 'counter',
                    'exploration_id': The ID of the exploration,
                    'version': Version of the exploration,
                    'state_name': Name of current state,
                    'first_entry_count': Number of times the state was entered
                        for the first time in a reader session,
                    'subsequent_entries_count': Number of times the state was
                        entered for the second time or later in a reader
                        session,
                    'resolved_answer_count': Number of times an answer
                        submitted for this state was subsequently resolved by
                        an exploration admin and removed from the answer logs,
                    'active_answer_count': Number of times an answer was entered
                        for this state and was not subsequently resolved by an
                        exploration admin}'
        """
        exploration = None
        exp_id, version = key.split(':')
        try:
            if version == VERSION_NONE:
                exploration = exp_services.get_exploration_by_id(exp_id)

                # Rewind to the last commit before the transition from
                # StateCounterModel.
                current_version = exploration.version
                while (exploration.last_updated > _STATE_COUNTER_CUTOFF_DATE
                       and current_version > 1):
                    current_version -= 1
                    exploration = exp_models.ExplorationModel.get_version(
                        exp_id, current_version)
            elif version == VERSION_ALL:
                exploration = exp_services.get_exploration_by_id(exp_id)
            else:
                exploration = exp_services.get_exploration_by_id(
                    exp_id, version=version)

        except base_models.BaseModel.EntityNotFoundError:
            return

        # Number of times exploration was started
        new_models_start_count = 0
        # Number of times exploration was completed
        new_models_complete_count = 0
        # Session ids that have completed this state
        new_models_end_sessions = set()
        # {session_id: (created-on timestamp of last known maybe leave event,
        # state_name)}
        session_id_to_latest_leave_evt = collections.defaultdict(
            lambda: (0, ''))
        old_models_start_count = 0
        old_models_complete_count = 0

        # {state_name: {'total_entry_count': ...,
        #               'first_entry_count': ...,
        #               'no_answer_count': ...}}
        state_hit_counts = collections.defaultdict(
            lambda: collections.defaultdict(int))
        for state_name in exploration.states:
            state_hit_counts[state_name] = {
                'total_entry_count': 0,
                'first_entry_count': 0,
                'no_answer_count': 0,
            }

        # {state_name: set(session ids that have reached this state)}
        state_session_ids = collections.defaultdict(set)
        for state_name in exploration.states:
            state_session_ids[state_name] = set([])

        # Iterate over and process each event for this exploration.
        for value_str in stringified_values:
            value = ast.literal_eval(value_str)

            state_name = value['state_name']

            # Convert the state name to unicode, if necessary.
            # Note: sometimes, item.state_name is None for
            # StateHitEventLogEntryModel.
            # TODO(sll): Track down the reason for this, and fix it.
            if (state_name is not None and
                    not isinstance(state_name, unicode)):
                state_name = state_name.decode('utf-8')

            if (value['type'] ==
                    StatisticsMRJobManager._TYPE_STATE_COUNTER_STRING):
                if state_name == exploration.init_state_name:
                    old_models_start_count = value['first_entry_count']
                if state_name == OLD_END_DEST:
                    old_models_complete_count = value['first_entry_count']
                else:
                    state_hit_counts[state_name]['no_answer_count'] += (
                        value['first_entry_count']
                        + value['subsequent_entries_count']
                        - value['resolved_answer_count']
                        - value['active_answer_count'])
                    state_hit_counts[state_name]['first_entry_count'] += (
                        value['first_entry_count'])
                    state_hit_counts[state_name]['total_entry_count'] += (
                        value['first_entry_count']
                        + value['subsequent_entries_count'])
                continue

            event_type = value['event_type']
            created_on = value['created_on']
            session_id = value['session_id']

            # If this is a start event, increment start count.
            if event_type == feconf.EVENT_TYPE_START_EXPLORATION:
                new_models_start_count += 1
            elif event_type == feconf.EVENT_TYPE_COMPLETE_EXPLORATION:
                new_models_complete_count += 1
                # Track that we have seen a 'real' end for this session id
                new_models_end_sessions.add(session_id)
            elif event_type == feconf.EVENT_TYPE_MAYBE_LEAVE_EXPLORATION:
                # Identify the last learner event for this session.
                latest_timestamp_so_far, _ = (
                    session_id_to_latest_leave_evt[session_id])
                if latest_timestamp_so_far < created_on:
                    latest_timestamp_so_far = created_on
                    session_id_to_latest_leave_evt[session_id] = (
                        created_on, state_name)
            # If this is a state hit, increment the total count and record that
            # we have seen this session id.
            elif event_type == feconf.EVENT_TYPE_STATE_HIT:
                state_hit_counts[state_name]['total_entry_count'] += 1
                state_session_ids[state_name].add(session_id)

        # After iterating through all events, take the size of the set of
        # session ids as the first entry count.
        for state_name in state_session_ids:
            state_hit_counts[state_name]['first_entry_count'] += len(
                state_session_ids[state_name])

        # Get the set of session ids that left without completing. This is
        # determined as the set of session ids with maybe-leave events at
        # intermediate states, minus the ones that have a maybe-leave event
        # at the END state.
        leave_states = set(session_id_to_latest_leave_evt.keys()).difference(
            new_models_end_sessions)
        for session_id in leave_states:
            # Grab the state name of the state they left on and count that as a
            # 'no answer' for that state.
            (_, state_name) = session_id_to_latest_leave_evt[session_id]
            state_hit_counts[state_name]['no_answer_count'] += 1

        num_starts = (
            old_models_start_count + new_models_start_count)
        num_completions = (
            old_models_complete_count + new_models_complete_count)

        stats_models.ExplorationAnnotationsModel.create(
            exp_id, str(version), num_starts, num_completions,
            state_hit_counts)

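The reducer gives every unseen session id a sentinel default of (0, '') by passing a lambda factory, so the "latest leave event so far" comparison needs no explicit membership test; defaultdict(lambda: defaultdict(int)) does the same job for the per-state counters. A minimal sketch of the sentinel-default pattern, with fabricated event tuples:

import collections

# (timestamp of the last known leave event, state name); unseen sessions start at (0, '').
latest_leave = collections.defaultdict(lambda: (0, ''))

events = [("sess-1", 120, "Intro"), ("sess-1", 90, "Intro"), ("sess-2", 40, "Quiz")]
for session_id, created_on, state_name in events:
    latest_so_far, _ = latest_leave[session_id]
    if latest_so_far < created_on:
        latest_leave[session_id] = (created_on, state_name)

print(dict(latest_leave))  # {'sess-1': (120, 'Intro'), 'sess-2': (40, 'Quiz')}
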
Example 13

Project: python-netsnmpagent
Source File: netsnmpagent.py
View license
	def __init__(self, **args):
		"""Initializes a new netsnmpAgent instance.
		
		"args" is a dictionary that can contain the following
		optional parameters:
		
		- AgentName     : The agent's name used for registration with net-snmp.
		- MasterSocket  : The transport specification of the AgentX socket of
		                  the running snmpd instance to connect to (see the
		                  "LISTENING ADDRESSES" section in the snmpd(8) manpage).
		                  Change this if you want to use eg. a TCP transport or
		                  access a custom snmpd instance, eg. as shown in
		                  run_simple_agent.sh, or for automatic testing.
		- PersistenceDir: The directory to use to store persistence information.
		                  Change this if you want to use a custom snmpd
		                  instance, eg. for automatic testing.
		- MIBFiles      : A list of filenames of MIBs to be loaded. Required if
		                  the OIDs, for which variables will be registered, do
		                  not belong to standard MIBs and the custom MIBs are not
		                  located in net-snmp's default MIB path
		                  (/usr/share/snmp/mibs).
		- UseMIBFiles   : Whether to use MIB files at all or not. When False,
		                  the parser for MIB files will not be initialized, so
		                  neither system-wide MIB files nor the ones provided
		                  in the MIBFiles argument will be in use.
		- LogHandler    : An optional Python function that will be registered
		                  with net-snmp as a custom log handler. If specified,
		                  this function will be called for every log message
		                  net-snmp itself generates, with parameters as follows:
		                  1. a string indicating the message's priority: one of
		                  "Emergency", "Alert", "Critical", "Error", "Warning",
		                  "Notice", "Info" or "Debug".
		                  2. the actual log message. Note that heading strings
		                  such as "Warning: " and "Error: " will be stripped off
		                  since the priority level is explicitly known and can
		                  be used to prefix the log message, if desired.
		                  Trailing linefeeds will also have been stripped off.
		                  If undefined, log messages will be written to stderr
		                  instead. """

		# Default settings
		defaults = {
			"AgentName"     : os.path.splitext(os.path.basename(sys.argv[0]))[0],
			"MasterSocket"  : None,
			"PersistenceDir": None,
			"UseMIBFiles"   : True,
			"MIBFiles"      : None,
			"LogHandler"    : None,
		}
		for key in defaults:
			setattr(self, key, args.get(key, defaults[key]))
		if self.UseMIBFiles and self.MIBFiles is not None and type(self.MIBFiles) not in (list, tuple):
			self.MIBFiles = (self.MIBFiles,)

		# Initialize status attribute -- until start() is called we will accept
		# SNMP object registrations
		self._status = netsnmpAgentStatus.REGISTRATION

		# Unfortunately net-snmp does not give callers of init_snmp() (used
		# in the start() method) any feedback about success or failure of
		# connection establishment. But for AgentX clients this information is
		# quite essential, thus we need to implement some more or less ugly
		# workarounds.

		# For net-snmp 5.7.x, we can derive success and failure from the log
		# messages it generates. Normally these go to stderr, in the absence
		# of other so-called log handlers. Alas we define a callback function
		# that we will register with net-snmp as a custom log handler later on,
		# hereby effectively gaining access to the desired information.
		def _py_log_handler(majorID, minorID, serverarg, clientarg):
			# "majorID" and "minorID" are the callback IDs with which this
			# callback function was registered. They are useful if the same
			# callback was registered multiple times.
			# Both "serverarg" and "clientarg" are pointers that can be used to
			# convey information from the calling context to the callback
			# function: "serverarg" gets passed individually to every call of
			# snmp_call_callbacks() while "clientarg" was initially passed to
			# snmp_register_callback().

			# In this case, "majorID" and "minorID" are always the same (see the
			# registration code below). "serverarg" needs to be cast back to
			# become a pointer to a "snmp_log_message" C structure (passed by
			# net-snmp's log_handler_callback() in snmplib/snmp_logging.c) while
			# "clientarg" will be None (see the registration code below).
			logmsg = ctypes.cast(serverarg, snmp_log_message_p)

			# Generate textual description of priority level
			priorities = {
				LOG_EMERG: "Emergency",
				LOG_ALERT: "Alert",
				LOG_CRIT: "Critical",
				LOG_ERR: "Error",
				LOG_WARNING: "Warning",
				LOG_NOTICE: "Notice",
				LOG_INFO: "Info",
				LOG_DEBUG: "Debug"
			}
			msgprio = priorities[logmsg.contents.priority]

			# Strip trailing linefeeds and in addition "Warning: " and "Error: "
			# from msgtext as these conditions are already indicated through
			# msgprio
			msgtext = re.sub(
				"^(Warning|Error): *",
				"",
				u(logmsg.contents.msg.rstrip(b"\n"))
			)

			# Intercept log messages related to connection establishment and
			# failure to update the status of this netsnmpAgent object. This is
			# really an ugly hack, introducing a dependency on the particular
			# text of log messages -- hopefully the net-snmp guys won't
			# translate them one day.
			if  msgprio == "Warning" \
			or  msgprio == "Error" \
			and re.match("Failed to .* the agentx master agent.*", msgtext):
				# If this was the first connection attempt, we consider the
				# condition fatal: it is more likely that an invalid
				# "MasterSocket" was specified than that we've got concurrency
				# issues with our agent being erroneously started before snmpd.
				if self._status == netsnmpAgentStatus.FIRSTCONNECT:
					self._status = netsnmpAgentStatus.CONNECTFAILED

					# No need to log this message -- we'll generate our own when
					# throwing a netsnmpAgentException as consequence of the
					# ECONNECT
					return 0

				# Otherwise we'll stay at status RECONNECTING and log net-snmp's
				# message like any other. net-snmp code will keep retrying to
				# connect.
			elif msgprio == "Info" \
			and  re.match("AgentX subagent connected", msgtext):
				self._status = netsnmpAgentStatus.CONNECTED
			elif msgprio == "Info" \
			and  re.match("AgentX master disconnected us.*", msgtext):
				self._status = netsnmpAgentStatus.RECONNECTING

			# If "LogHandler" was defined, call it to take care of logging.
			# Otherwise print all log messages to stderr to resemble net-snmp
			# standard behavior (but add log message's associated priority in
			# plain text as well)
			if self.LogHandler:
				self.LogHandler(msgprio, msgtext)
			else:
				print("[{0}] {1}".format(msgprio, msgtext))

			return 0

		# We defined a Python function that needs a ctypes conversion so it can
		# be called by C code such as net-snmp. That's what SNMPCallback() is
		# used for. However we also need to store the reference in "self" as it
		# will otherwise be lost at the exit of this function so that net-snmp's
		# attempt to call it would end in nirvana...
		self._log_handler = SNMPCallback(_py_log_handler)

		# Now register our custom log handler with majorID SNMP_CALLBACK_LIBRARY
		# and minorID SNMP_CALLBACK_LOGGING.
		if libnsa.snmp_register_callback(
			SNMP_CALLBACK_LIBRARY,
			SNMP_CALLBACK_LOGGING,
			self._log_handler,
			None
		) != SNMPERR_SUCCESS:
			raise netsnmpAgentException(
				"snmp_register_callback() failed for _netsnmp_log_handler!"
			)

		# Finally the net-snmp logging system needs to be told to enable
		# logging through callback functions. This will actually register a
		# NETSNMP_LOGHANDLER_CALLBACK log handler that will call out to any
		# callback functions with the majorID and minorID shown above, such as
		# ours.
		libnsa.snmp_enable_calllog()

		# Unfortunately our custom log handler above is still not enough: in
		# net-snmp 5.4.x there were no "AgentX master disconnected" log
		# messages yet. So we need another workaround to be able to detect
		# disconnects for this release. Both net-snmp 5.4.x and 5.7.x support
		# a callback mechanism using the "majorID" SNMP_CALLBACK_APPLICATION and
		# the "minorID" SNMPD_CALLBACK_INDEX_STOP, which we can abuse for our
		# purposes. Again, we start by defining a callback function.
		def _py_index_stop_callback(majorID, minorID, serverarg, clientarg):
			# For "majorID" and "minorID" see our log handler above.
			# "serverarg" is a disguised pointer to a "netsnmp_session"
			# structure (passed by net-snmp's subagent_open_master_session() and
			# agentx_check_session() in agent/mibgroup/agentx/subagent.c). We
			# can ignore it here since we have a single session only anyway.
			# "clientarg" will be None again (see the registration code below).

			# We only care about SNMPD_CALLBACK_INDEX_STOP as our custom log
			# handler above already took care of all other events.
			if minorID == SNMPD_CALLBACK_INDEX_STOP:
				self._status = netsnmpAgentStatus.RECONNECTING

			return 0

		# Convert it to a C callable function and store its reference
		self._index_stop_callback = SNMPCallback(_py_index_stop_callback)

		# Register it with net-snmp
		if libnsa.snmp_register_callback(
			SNMP_CALLBACK_APPLICATION,
			SNMPD_CALLBACK_INDEX_STOP,
			self._index_stop_callback,
			None
		) != SNMPERR_SUCCESS:
			raise netsnmpAgentException(
				"snmp_register_callback() failed for _netsnmp_index_callback!"
			)

		# No enabling necessary here

		# Make us an AgentX client
		if libnsa.netsnmp_ds_set_boolean(
			NETSNMP_DS_APPLICATION_ID,
			NETSNMP_DS_AGENT_ROLE,
			1
		) != SNMPERR_SUCCESS:
			raise netsnmpAgentException(
				"netsnmp_ds_set_boolean() failed for NETSNMP_DS_AGENT_ROLE!"
			)

		# Use an alternative transport specification to connect to the master?
		# Defaults to "/var/run/agentx/master".
		# (See the "LISTENING ADDRESSES" section in the snmpd(8) manpage)
		if self.MasterSocket:
			if libnsa.netsnmp_ds_set_string(
				NETSNMP_DS_APPLICATION_ID,
				NETSNMP_DS_AGENT_X_SOCKET,
				b(self.MasterSocket)
			) != SNMPERR_SUCCESS:
				raise netsnmpAgentException(
					"netsnmp_ds_set_string() failed for NETSNMP_DS_AGENT_X_SOCKET!"
				)

		# Use an alternative persistence directory?
		if self.PersistenceDir:
			if libnsa.netsnmp_ds_set_string(
				NETSNMP_DS_LIBRARY_ID,
				NETSNMP_DS_LIB_PERSISTENT_DIR,
				b(self.PersistenceDir)
			) != SNMPERR_SUCCESS:
				raise netsnmpAgentException(
					"netsnmp_ds_set_string() failed for NETSNMP_DS_LIB_PERSISTENT_DIR!"
				)

		# Initialize net-snmp library (see netsnmp_agent_api(3))
		if libnsa.init_agent(b(self.AgentName)) != 0:
			raise netsnmpAgentException("init_agent() failed!")

		# Initialize MIB parser
		if self.UseMIBFiles:
			libnsa.netsnmp_init_mib()

		# If MIBFiles were specified (ie. MIBs that can not be found in
		# net-snmp's default MIB directory /usr/share/snmp/mibs), read
		# them in so we can translate OID strings to net-snmp's internal OID
		# format.
		if self.UseMIBFiles and self.MIBFiles:
			for mib in self.MIBFiles:
				if libnsa.read_mib(b(mib)) == 0:
					raise netsnmpAgentException("netsnmp_read_module({0}) " +
					                            "failed!".format(mib))

		# Initialize our SNMP object registry
		self._objs = defaultdict(dict)

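The defaultdict here is the final line: self._objs = defaultdict(dict) sets up the SNMP object registry, presumably so that entries can later be filed under a per-type inner dict without creating it first. A minimal sketch of that registry shape, with invented class and object names:

from collections import defaultdict

# First key: an object class or type; inner dict: name -> registered object.
_objs = defaultdict(dict)

_objs["Integer32"]["ifCount"] = object()       # inner dicts appear on first use
_objs["OctetString"]["sysContact"] = object()

for cls, entries in sorted(_objs.items()):
    print("{0}: {1}".format(cls, sorted(entries)))
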
Example 14

View license
def _fetch_distribution(container_root,  # pylint:disable=R0913
                        proot_distro,
                        details):
    """Lazy-initialize distribution and return it."""
    path_to_distro_folder = get_dir_for_distro(container_root,
                                               details)

    def _download_distro(details, path_to_distro_folder):
        """Download distribution and untar it in container root."""
        distro_arch = details["arch"]
        download_url = details["url"].format(arch=distro_arch)
        with tempdir.TempDir() as download_dir:
            with directory.Navigation(download_dir):
                with TemporarilyDownloadedFile(download_url) as archive_file:
                    _extract_distro_archive(archive_file,
                                            path_to_distro_folder)

    def _minimize_ubuntu(cont):
        """Reduce the install footprint of ubuntu as much as possible."""
        required_packages = {
            "precise": set([
                "apt",
                "base-files",
                "base-passwd",
                "bash",
                "bsdutils",
                "coreutils",
                "dash",
                "debconf",
                "debianutils",
                "diffutils",
                "dpkg",
                "findutils",
                "gcc-4.6-base",
                "gnupg",
                "gpgv",
                "grep",
                "gzip",
                "libacl1",
                "libapt-pkg4.12",
                "libattr1",
                "libbz2-1.0",
                "libc-bin",
                "libc6",
                "libdb5.1",
                "libffi6",
                "libgcc1",
                "liblzma5",
                "libpam-modules",
                "libpam-modules-bin",
                "libpam-runtime",
                "libpam0g",
                "libreadline6",
                "libselinux1",
                "libstdc++6",
                "libtinfo5",
                "libusb-0.1-4",
                "makedev",
                "mawk",
                "multiarch-support",
                "perl-base",
                "readline-common",
                "sed",
                "sensible-utils",
                "tar",
                "tzdata",
                "ubuntu-keyring",
                "xz-utils",
                "zlib1g"
            ]),
            "trusty": set([
                "apt",
                "base-files",
                "base-passwd",
                "bash",
                "bsdutils",
                "coreutils",
                "dash",
                "debconf",
                "debianutils",
                "diffutils",
                "dh-python",
                "dpkg",
                "findutils",
                "gcc-4.8-base",
                "gcc-4.9-base",
                "gnupg",
                "gpgv",
                "grep",
                "gzip",
                "libacl1",
                "libapt-pkg4.12",
                "libaudit1",
                "libaudit-common",
                "libattr1",
                "libbz2-1.0",
                "libc-bin",
                "libc6",
                "libcap2",
                "libdb5.3",
                "libdebconfclient0",
                "libexpat1",
                "libmpdec2",
                "libffi6",
                "libgcc1",
                "liblzma5",
                "libncursesw5",
                "libpcre3",
                "libpam-modules",
                "libpam-modules-bin",
                "libpam-runtime",
                "libpam0g",
                "libpython3-stdlib",
                "libpython3.4-stdlib",
                "libpython3",
                "libpython3-minimal",
                "libpython3.4",
                "libpython3.4-minimal",
                "libreadline6",
                "libselinux1",
                "libssl1.0.0",
                "libstdc++6",
                "libsqlite3-0",
                "libtinfo5",
                "libusb-0.1-4",
                "lsb-release",
                "makedev",
                "mawk",
                "mime-support",
                "multiarch-support",
                "perl-base",
                "python3",
                "python3-minimal",
                "python3.4",
                "python3.4-minimal",
                "readline-common",
                "sed",
                "sensible-utils",
                "tar",
                "tzdata",
                "ubuntu-keyring",
                "xz-utils",
                "zlib1g"
            ])
        }

        os.environ["SUDO_FORCE_REMOVE"] = "yes"
        os.environ["DEBIAN_FRONTEND"] = "noninteractive"

        pkgs = set(cont.execute(["dpkg-query",
                                 "-Wf",
                                 "${Package}\n"])[1].split("\n"))
        release = details["release"]
        remove = [l for l in list(pkgs ^ required_packages[release]) if len(l)]

        if len(remove):
            cont.execute_success(["dpkg",
                                  "--purge",
                                  "--force-all"] + remove,
                                 minimal_bind=True)

        with open(os.path.join(get_dir_for_distro(container_root,
                                                  details),
                               "etc",
                               "apt",
                               "apt.conf.d",
                               "99container"), "w") as apt_config:
            apt_config.write("\n".join([
                "APT::Install-Recommends \"0\";",
                "APT::Install-Suggests \"0\";"
            ]))

    # Container isn't safe to use until we've either verified that the
    # path to the distro folder exists or we've downloaded a distro into it
    linux_cont = LinuxContainer(proot_distro,
                                path_to_distro_folder,
                                details["release"],
                                details["arch"],
                                details["pkgsys"])

    try:
        os.stat(path_to_distro_folder)
        use_existing_msg = ("""\N{check mark} Using existing folder for """
                            """proot distro """
                            """{distro} {release} {arch}\n""")
        printer.unicode_safe(colored.green(use_existing_msg.format(**details),
                                           bold=True))
    except OSError:
        # Download the distribution tarball in the distro dir
        _download_distro(details, path_to_distro_folder)

        # Minimize the installed distribution, but only when it
        # was just initially downloaded
        minimize_actions = defaultdict(lambda: lambda c: None,
                                       Ubuntu=_minimize_ubuntu)
        minimize_actions[details["distro"]](linux_cont)

    return linux_cont

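The closing lines use defaultdict as a dispatch table: the factory lambda: lambda c: None means an unknown distribution maps to a no-op callable, while Ubuntu is seeded through the keyword-argument form of the constructor (extra keyword arguments initialise the mapping just as they do for dict()). A minimal sketch with invented handler and container names:

from collections import defaultdict

def minimize_ubuntu(container):
    print("minimizing " + container)

# Unknown distro names fall through to a callable that does nothing.
minimize_actions = defaultdict(lambda: (lambda container: None),
                               Ubuntu=minimize_ubuntu)

minimize_actions["Ubuntu"]("precise-container")   # runs minimize_ubuntu
minimize_actions["Fedora"]("fedora-container")    # silently does nothing
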
Example 15

Project: powerline
Source File: __init__.py
View license
def check(paths=None, debug=False, echoerr=echoerr, require_ext=None):
	'''Check configuration sanity

	:param list paths:
		Paths from which configuration should be loaded.
	:param bool debug:
		Determines whether some information useful for debugging linter should 
		be output.
	:param function echoerr:
		Function that will be used to echo the error(s). Should accept four 
		optional keyword parameters: ``problem`` and ``problem_mark``, and 
		``context`` and ``context_mark``.
	:param str require_ext:
		Require configuration for some extension to be present.

	:return:
		``False`` if user configuration seems to be completely sane and ``True`` 
		if some problems were found.
	'''
	hadproblem = False

	register_common_names()
	search_paths = paths or get_config_paths()
	find_config_files = generate_config_finder(lambda: search_paths)

	logger = logging.getLogger('powerline-lint')
	logger.setLevel(logging.DEBUG if debug else logging.ERROR)
	logger.addHandler(logging.StreamHandler())

	ee = EchoErr(echoerr, logger)

	if require_ext:
		used_main_spec = main_spec.copy()
		try:
			used_main_spec['ext'][require_ext].required()
		except KeyError:
			used_main_spec['ext'][require_ext] = ext_spec()
	else:
		used_main_spec = main_spec

	lhadproblem = [False]
	load_json_config = generate_json_config_loader(lhadproblem)

	config_loader = ConfigLoader(run_once=True, load=load_json_config)

	lists = {
		'colorschemes': set(),
		'themes': set(),
		'exts': set(),
	}
	found_dir = {
		'themes': False,
		'colorschemes': False,
	}
	config_paths = defaultdict(lambda: defaultdict(dict))
	loaded_configs = defaultdict(lambda: defaultdict(dict))
	for d in chain(
		find_all_ext_config_files(search_paths, 'colorschemes'),
		find_all_ext_config_files(search_paths, 'themes'),
	):
		if d['error']:
			hadproblem = True
			ee(problem=d['error'])
			continue
		if d['hadproblem']:
			hadproblem = True
		if d['ext']:
			found_dir[d['type']] = True
			lists['exts'].add(d['ext'])
			if d['name'] == '__main__':
				pass
			elif d['name'].startswith('__') or d['name'].endswith('__'):
				hadproblem = True
				ee(problem='File name is not supposed to start or end with “__”: {0}'.format(
					d['path']))
			else:
				lists[d['type']].add(d['name'])
			config_paths[d['type']][d['ext']][d['name']] = d['path']
			loaded_configs[d['type']][d['ext']][d['name']] = d['config']
		else:
			config_paths[d['type']][d['name']] = d['path']
			loaded_configs[d['type']][d['name']] = d['config']

	for typ in ('themes', 'colorschemes'):
		if not found_dir[typ]:
			hadproblem = True
			ee(problem='Subdirectory {0} was not found in paths {1}'.format(typ, ', '.join(search_paths)))

	diff = set(config_paths['colorschemes']) - set(config_paths['themes'])
	if diff:
		hadproblem = True
		for ext in diff:
			typ = 'colorschemes' if ext in config_paths['themes'] else 'themes'
			if not config_paths['top_' + typ] or typ == 'themes':
				ee(problem='{0} extension {1} not present in {2}'.format(
					ext,
					'configuration' if (
						ext in loaded_configs['themes'] and ext in loaded_configs['colorschemes']
					) else 'directory',
					typ,
				))

	try:
		main_config = load_config('config', find_config_files, config_loader)
	except IOError:
		main_config = {}
		ee(problem='Configuration file not found: config.json')
		hadproblem = True
	except MarkedError as e:
		main_config = {}
		ee(problem=str(e))
		hadproblem = True
	else:
		if used_main_spec.match(
			main_config,
			data={'configs': config_paths, 'lists': lists},
			context=Context(main_config),
			echoerr=ee
		)[1]:
			hadproblem = True

	import_paths = [os.path.expanduser(path) for path in main_config.get('common', {}).get('paths', [])]

	try:
		colors_config = load_config('colors', find_config_files, config_loader)
	except IOError:
		colors_config = {}
		ee(problem='Configuration file not found: colors.json')
		hadproblem = True
	except MarkedError as e:
		colors_config = {}
		ee(problem=str(e))
		hadproblem = True
	else:
		if colors_spec.match(colors_config, context=Context(colors_config), echoerr=ee)[1]:
			hadproblem = True

	if lhadproblem[0]:
		hadproblem = True

	top_colorscheme_configs = dict(loaded_configs['top_colorschemes'])
	data = {
		'ext': None,
		'top_colorscheme_configs': top_colorscheme_configs,
		'ext_colorscheme_configs': {},
		'colors_config': colors_config
	}
	for colorscheme, config in loaded_configs['top_colorschemes'].items():
		data['colorscheme'] = colorscheme
		if top_colorscheme_spec.match(config, context=Context(config), data=data, echoerr=ee)[1]:
			hadproblem = True

	ext_colorscheme_configs = dict2(loaded_configs['colorschemes'])
	for ext, econfigs in ext_colorscheme_configs.items():
		data = {
			'ext': ext,
			'top_colorscheme_configs': top_colorscheme_configs,
			'ext_colorscheme_configs': ext_colorscheme_configs,
			'colors_config': colors_config,
		}
		for colorscheme, config in econfigs.items():
			data['colorscheme'] = colorscheme
			if ext == 'vim':
				spec = vim_colorscheme_spec
			elif ext == 'shell':
				spec = shell_colorscheme_spec
			else:
				spec = colorscheme_spec
			if spec.match(config, context=Context(config), data=data, echoerr=ee)[1]:
				hadproblem = True

	colorscheme_configs = {}
	for ext in lists['exts']:
		colorscheme_configs[ext] = {}
		for colorscheme in lists['colorschemes']:
			econfigs = ext_colorscheme_configs[ext]
			ecconfigs = econfigs.get(colorscheme)
			mconfigs = (
				top_colorscheme_configs.get(colorscheme),
				econfigs.get('__main__'),
				ecconfigs,
			)
			if not (mconfigs[0] or mconfigs[2]):
				continue
			config = None
			for mconfig in mconfigs:
				if not mconfig:
					continue
				if config:
					config = mergedicts_copy(config, mconfig)
				else:
					config = mconfig
			colorscheme_configs[ext][colorscheme] = config

	theme_configs = dict2(loaded_configs['themes'])
	top_theme_configs = dict(loaded_configs['top_themes'])
	for ext, configs in theme_configs.items():
		data = {
			'ext': ext,
			'colorscheme_configs': colorscheme_configs,
			'import_paths': import_paths,
			'main_config': main_config,
			'top_themes': top_theme_configs,
			'ext_theme_configs': configs,
			'colors_config': colors_config
		}
		for theme, config in configs.items():
			data['theme'] = theme
			if theme == '__main__':
				data['theme_type'] = 'main'
				spec = main_theme_spec
			else:
				data['theme_type'] = 'regular'
				spec = theme_spec
			if spec.match(config, context=Context(config), data=data, echoerr=ee)[1]:
				hadproblem = True

	for top_theme, config in top_theme_configs.items():
		data = {
			'ext': None,
			'colorscheme_configs': colorscheme_configs,
			'import_paths': import_paths,
			'main_config': main_config,
			'theme_configs': theme_configs,
			'ext_theme_configs': None,
			'colors_config': colors_config
		}
		data['theme_type'] = 'top'
		data['theme'] = top_theme
		if top_theme_spec.match(config, context=Context(config), data=data, echoerr=ee)[1]:
			hadproblem = True

	return hadproblem
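
config_paths and loaded_configs are two-level defaultdicts (defaultdict(lambda: defaultdict(dict))), which lets the loop assign config_paths[type][ext][name] without ever checking whether the intermediate dictionaries exist. A stripped-down sketch of that indexing pattern, using made-up paths:

from collections import defaultdict

config_paths = defaultdict(lambda: defaultdict(dict))

# Both intermediate levels appear automatically on first access.
config_paths['themes']['vim']['default'] = '~/.config/powerline/themes/vim/default.json'
config_paths['colorschemes']['shell']['solarized'] = '~/.config/powerline/colorschemes/shell/solarized.json'

print(sorted(config_paths))           # ['colorschemes', 'themes']
print(config_paths['themes']['vim'])  # {'default': '~/.config/powerline/themes/vim/default.json'}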

Example 16

Project: pretix
Source File: invoices.py
View license
def _invoice_generate_german(invoice, f):
    _invoice_register_fonts()
    styles = _invoice_get_stylesheet()
    pagesize = pagesizes.A4

    def on_page(canvas, doc):
        canvas.saveState()
        canvas.setFont('OpenSans', 8)
        canvas.drawRightString(pagesize[0] - 20 * mm, 10 * mm, _("Page %d") % (doc.page,))

        for i, line in enumerate(invoice.footer_text.split('\n')[::-1]):
            canvas.drawCentredString(pagesize[0] / 2, 25 + (3.5 * i) * mm, line.strip())

        canvas.restoreState()

    def on_first_page(canvas, doc):
        canvas.setCreator('pretix.eu')
        canvas.setTitle(pgettext('invoice', 'Invoice {num}').format(num=invoice.number))

        canvas.saveState()
        canvas.setFont('OpenSans', 8)
        canvas.drawRightString(pagesize[0] - 20 * mm, 10 * mm, _("Page %d") % (doc.page,))

        for i, line in enumerate(invoice.footer_text.split('\n')[::-1]):
            canvas.drawCentredString(pagesize[0] / 2, 25 + (3.5 * i) * mm, line.strip())

        textobject = canvas.beginText(25 * mm, (297 - 15) * mm)
        textobject.setFont('OpenSansBd', 8)
        textobject.textLine(pgettext('invoice', 'Invoice from').upper())
        textobject.moveCursor(0, 5)
        textobject.setFont('OpenSans', 10)
        textobject.textLines(invoice.invoice_from.strip())
        canvas.drawText(textobject)

        textobject = canvas.beginText(25 * mm, (297 - 50) * mm)
        textobject.setFont('OpenSansBd', 8)
        textobject.textLine(pgettext('invoice', 'Invoice to').upper())
        textobject.moveCursor(0, 5)
        textobject.setFont('OpenSans', 10)
        textobject.textLines(invoice.invoice_to.strip())
        canvas.drawText(textobject)

        textobject = canvas.beginText(125 * mm, (297 - 50) * mm)
        textobject.setFont('OpenSansBd', 8)
        if invoice.is_cancellation:
            textobject.textLine(pgettext('invoice', 'Cancellation number').upper())
            textobject.moveCursor(0, 5)
            textobject.setFont('OpenSans', 10)
            textobject.textLine(invoice.number)
            textobject.moveCursor(0, 5)
            textobject.setFont('OpenSansBd', 8)
            textobject.textLine(pgettext('invoice', 'Original invoice').upper())
            textobject.moveCursor(0, 5)
            textobject.setFont('OpenSans', 10)
            textobject.textLine(invoice.refers.number)
        else:
            textobject.textLine(pgettext('invoice', 'Invoice number').upper())
            textobject.moveCursor(0, 5)
            textobject.setFont('OpenSans', 10)
            textobject.textLine(invoice.number)
        textobject.moveCursor(0, 5)

        if invoice.is_cancellation:
            textobject.setFont('OpenSansBd', 8)
            textobject.textLine(pgettext('invoice', 'Cancellation date').upper())
            textobject.moveCursor(0, 5)
            textobject.setFont('OpenSans', 10)
            textobject.textLine(date_format(invoice.date, "DATE_FORMAT"))
            textobject.moveCursor(0, 5)
            textobject.setFont('OpenSansBd', 8)
            textobject.textLine(pgettext('invoice', 'Original invoice date').upper())
            textobject.moveCursor(0, 5)
            textobject.setFont('OpenSans', 10)
            textobject.textLine(date_format(invoice.refers.date, "DATE_FORMAT"))
            textobject.moveCursor(0, 5)
        else:
            textobject.setFont('OpenSansBd', 8)
            textobject.textLine(pgettext('invoice', 'Invoice date').upper())
            textobject.moveCursor(0, 5)
            textobject.setFont('OpenSans', 10)
            textobject.textLine(date_format(invoice.date, "DATE_FORMAT"))
            textobject.moveCursor(0, 5)

        canvas.drawText(textobject)

        textobject = canvas.beginText(165 * mm, (297 - 50) * mm)
        textobject.setFont('OpenSansBd', 8)
        textobject.textLine(_('Order code').upper())
        textobject.moveCursor(0, 5)
        textobject.setFont('OpenSans', 10)
        textobject.textLine(invoice.order.full_code)
        textobject.moveCursor(0, 5)
        textobject.setFont('OpenSansBd', 8)
        textobject.textLine(_('Order date').upper())
        textobject.moveCursor(0, 5)
        textobject.setFont('OpenSans', 10)
        textobject.textLine(date_format(invoice.order.datetime, "DATE_FORMAT"))
        canvas.drawText(textobject)

        textobject = canvas.beginText(125 * mm, (297 - 15) * mm)
        textobject.setFont('OpenSansBd', 8)
        textobject.textLine(_('Event').upper())
        textobject.moveCursor(0, 5)
        textobject.setFont('OpenSans', 10)
        textobject.textLine(str(invoice.event.name))
        if invoice.event.settings.show_date_to:
            textobject.textLines(
                _('{from_date}\nuntil {to_date}').format(from_date=invoice.event.get_date_from_display(),
                                                         to_date=invoice.event.get_date_to_display()))
        else:
            textobject.textLine(invoice.event.get_date_from_display())
        canvas.drawText(textobject)

        canvas.restoreState()

    doc = BaseDocTemplate(f.name, pagesize=pagesizes.A4,
                          leftMargin=25 * mm, rightMargin=20 * mm,
                          topMargin=20 * mm, bottomMargin=15 * mm)

    footer_length = 3.5 * len(invoice.footer_text.split('\n')) * mm
    frames_p1 = [
        Frame(doc.leftMargin, doc.bottomMargin, doc.width, doc.height - 75 * mm,
              leftPadding=0, rightPadding=0, topPadding=0, bottomPadding=footer_length,
              id='normal')
    ]
    frames = [
        Frame(doc.leftMargin, doc.bottomMargin, doc.width, doc.height,
              leftPadding=0, rightPadding=0, topPadding=0, bottomPadding=footer_length,
              id='normal')
    ]
    doc.addPageTemplates([
        PageTemplate(id='FirstPage', frames=frames_p1, onPage=on_first_page, pagesize=pagesize),
        PageTemplate(id='OtherPages', frames=frames, onPage=on_page, pagesize=pagesize)
    ])
    story = [
        NextPageTemplate('FirstPage'),
        Paragraph(pgettext('invoice', 'Invoice')
                  if not invoice.is_cancellation
                  else pgettext('invoice', 'Cancellation'),
                  styles['Heading1']),
        Spacer(1, 5 * mm),
        NextPageTemplate('OtherPages'),
    ]

    if invoice.introductory_text:
        story.append(Paragraph(invoice.introductory_text, styles['Normal']))
        story.append(Spacer(1, 10 * mm))

    taxvalue_map = defaultdict(Decimal)
    grossvalue_map = defaultdict(Decimal)
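    # Unseen tax rates start at Decimal('0'), so the += accumulation below needs no setdefault.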

    tstyledata = [
        ('ALIGN', (1, 0), (-1, -1), 'RIGHT'),
        ('FONTNAME', (0, 0), (-1, 0), 'OpenSansBd'),
        ('FONTNAME', (0, -1), (-1, -1), 'OpenSansBd'),
        ('LEFTPADDING', (0, 0), (0, -1), 0),
        ('RIGHTPADDING', (-1, 0), (-1, -1), 0),
    ]
    tdata = [(
        pgettext('invoice', 'Description'),
        pgettext('invoice', 'Tax rate'),
        pgettext('invoice', 'Net'),
        pgettext('invoice', 'Gross'),
    )]
    total = Decimal('0.00')
    for line in invoice.lines.all():
        tdata.append((
            line.description,
            lformat("%.2f", line.tax_rate) + " %",
            lformat("%.2f", line.net_value) + " " + invoice.event.currency,
            lformat("%.2f", line.gross_value) + " " + invoice.event.currency,
        ))
        taxvalue_map[line.tax_rate] += line.tax_value
        grossvalue_map[line.tax_rate] += line.gross_value
        total += line.gross_value

    tdata.append([pgettext('invoice', 'Invoice total'), '', '', lformat("%.2f", total) + " " + invoice.event.currency])
    colwidths = [a * doc.width for a in (.55, .15, .15, .15)]
    table = Table(tdata, colWidths=colwidths, repeatRows=1)
    table.setStyle(TableStyle(tstyledata))
    story.append(table)

    story.append(Spacer(1, 15 * mm))

    if invoice.payment_provider_text:
        story.append(Paragraph(invoice.payment_provider_text, styles['Normal']))

    if invoice.additional_text:
        story.append(Paragraph(invoice.additional_text, styles['Normal']))
        story.append(Spacer(1, 15 * mm))

    tstyledata = [
        ('SPAN', (1, 0), (-1, 0)),
        ('ALIGN', (2, 1), (-1, -1), 'RIGHT'),
        ('LEFTPADDING', (0, 0), (0, -1), 0),
        ('RIGHTPADDING', (-1, 0), (-1, -1), 0),
        ('FONTSIZE', (0, 0), (-1, -1), 8),
    ]
    tdata = [('', pgettext('invoice', 'Included taxes'), '', '', ''),
             ('', pgettext('invoice', 'Tax rate'),
              pgettext('invoice', 'Net value'), pgettext('invoice', 'Gross value'), pgettext('invoice', 'Tax'))]

    for rate, gross in grossvalue_map.items():
        if rate == 0:  # skip the zero-rate row in the tax summary
            continue
        tax = taxvalue_map[rate]
        tdata.append((
            '',
            lformat("%.2f", rate) + " %",
            lformat("%.2f", (gross - tax)) + " " + invoice.event.currency,
            lformat("%.2f", gross) + " " + invoice.event.currency,
            lformat("%.2f", tax) + " " + invoice.event.currency,
        ))

    if len(tdata) > 2:
        colwidths = [a * doc.width for a in (.45, .10, .15, .15, .15)]
        table = Table(tdata, colWidths=colwidths, repeatRows=2)
        table.setStyle(TableStyle(tstyledata))
        story.append(table)

    doc.build(story)
    return doc
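
The two defaultdict(Decimal) maps give every previously unseen tax rate an implicit starting value of Decimal('0'), so the per-rate totals are accumulated with a single += per line. A self-contained sketch of that accumulation, with invented line items:

from collections import defaultdict
from decimal import Decimal

# (tax_rate, tax_value, gross_value): fabricated sample data
lines = [
    (Decimal('19.00'), Decimal('1.60'), Decimal('10.00')),
    (Decimal('19.00'), Decimal('3.19'), Decimal('20.00')),
    (Decimal('7.00'), Decimal('0.65'), Decimal('10.00')),
]

taxvalue_map = defaultdict(Decimal)    # missing rates start at Decimal('0')
grossvalue_map = defaultdict(Decimal)

for tax_rate, tax_value, gross_value in lines:
    taxvalue_map[tax_rate] += tax_value
    grossvalue_map[tax_rate] += gross_value

for rate in sorted(grossvalue_map):
    print(rate, grossvalue_map[rate], taxvalue_map[rate])
# 7.00 10.00 0.65
# 19.00 30.00 4.79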

Example 17

Project: bayeslite
Source File: cgpm_metamodel.py
View license
def _create_schema(bdb, generator_id, schema_ast, **kwargs):
    # Get some parameters.
    population_id = core.bayesdb_generator_population(bdb, generator_id)
    table = core.bayesdb_population_table(bdb, population_id)

    # State.
    variables = []
    variable_dist = {}
    latents = {}
    cgpm_composition = []
    modelled = set()
    default_modelled = set()
    subsample = None
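    # deferred_input maps each input variable to the list of (stattypes, statargs, index)
    # slots to fill in later; the list default makes .append safe for unseen keys.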
    deferred_input = defaultdict(lambda: [])
    deferred_output = dict()

    # Error-reporting state.
    duplicate = set()
    unknown = set()
    needed = set()
    existing_latent = set()
    must_exist = []
    unknown_stattype = {}

    # XXX Convert all Foreign.exposed lists to Latent clauses.
    # Retrieve Foreign clauses with exposed variables.
    foreign_clauses = [
        c for c in schema_ast
        if isinstance(c, cgpm_schema.parse.Foreign) and len(c.exposed) > 0
    ]
    # Add the exposed variables to Foreign.outputs
    # Note that this assumes that if there are K exposed variables, they are
    # necessarily the last K entries of fc.outputs.
    for fc in foreign_clauses:
        fc.outputs.extend([e[0] for e in fc.exposed])

    # Convert exposed entries into Latent clauses.
    latent_vars = list(itertools.chain.from_iterable(
        c.exposed for c in foreign_clauses))
    latent_clauses = [cgpm_schema.parse.Latent(v,s) for (v,s) in latent_vars]
    # Append the Latent clauses to the ast.
    schema_ast.extend(latent_clauses)

    # XXX Convert the baseline to a Foreign clause.
    # Currently the baselines do not accept a schema, and will fail if
    # `schema_ast` has any entries.
    baseline = kwargs.get('baseline', None)
    if baseline is not None and casefold(baseline.name) != 'crosscat':
        if schema_ast:
            raise BQLError(bdb,
                'Cannot accept schema with baseline: %s.' % schema_ast)
        # Retrieve all variable names in the population
        outputs = core.bayesdb_variable_names(bdb, population_id, None)
        # Convert the LITERAL namedtuples to their raw values.
        ps, vs = zip(*baseline.params)
        vs_new = [v.value for v in vs]
        params = zip(ps, vs_new)
        # Create the clause.
        clause = cgpm_schema.parse.Foreign(
            outputs, [], [], baseline.name, params)
        # And append it to the schema_ast.
        schema_ast.append(clause)

    # Process each clause one by one.
    for clause in schema_ast:

        if isinstance(clause, cgpm_schema.parse.Basic):
            # Basic Crosscat component model: one variable to be put
            # into Crosscat views.
            var = clause.var
            dist = clause.dist
            params = dict(clause.params) # XXX error checking

            # Reject if the variable does not exist.
            if not core.bayesdb_has_variable(bdb, population_id, None, var):
                unknown.add(var)
                continue

            # Reject if the variable has already been modelled.
            if var in modelled:
                duplicate.add(var)
                continue

            # Reject if the variable is latent.
            if core.bayesdb_has_latent(bdb, population_id, var):
                existing_latent.add(var)
                continue

            # Get the column number.
            colno = core.bayesdb_variable_number(bdb, population_id, None, var)
            assert 0 <= colno

            # Add it to the list and mark it modelled by default.
            stattype = core.bayesdb_variable_stattype(
                bdb, population_id, colno)
            variables.append([var, stattype, dist, params])
            assert var not in variable_dist
            variable_dist[var] = (stattype, dist, params)
            modelled.add(var)
            default_modelled.add(var)

        elif isinstance(clause, cgpm_schema.parse.Latent):
            var = clause.name
            stattype = clause.stattype

            # Reject if the variable has already been modelled by the
            # default model.
            if var in default_modelled:
                duplicate.add(var)
                continue

            # Reject if the variable even *exists* in the population
            # at all yet.
            if core.bayesdb_has_variable(bdb, population_id, None, var):
                duplicate.add(var)
                continue

            # Reject if the variable is already latent, from another
            # generator.
            if core.bayesdb_has_latent(bdb, population_id, var):
                existing_latent.add(var)
                continue

            # Reject if we've already processed it.
            if var in latents:
                duplicate.add(var)
                continue

            # Add it to the set of latent variables.
            latents[var] = stattype

        elif isinstance(clause, cgpm_schema.parse.Foreign):
            # Foreign model: some set of output variables is to be
            # modelled by foreign logic, possibly conditional on some
            # set of input variables.
            #
            # Gather up the state for a cgpm_composition record, which
            # we may have to do incrementally because it must refer to
            # the distribution types of variables we may not have
            # seen.
            name = clause.name
            outputs = clause.outputs
            inputs = clause.inputs

            output_stattypes = []
            output_statargs = []
            input_stattypes = []
            input_statargs = []
            distargs = {
                'inputs': {
                    'stattypes': input_stattypes,
                    'statargs': input_statargs
                },
                'outputs': {
                    'stattypes': output_stattypes,
                    'statargs': output_statargs,
                }
            }
            kwds = {'distargs': distargs}
            kwds.update(clause.params)

            # First make sure all the output variables exist and have
            # not yet been modelled.
            for var in outputs:
                must_exist.append(var)
                if var in modelled:
                    duplicate.add(var)
                    continue
                modelled.add(var)
                # Add the output statistical type and its parameters.
                i = len(output_stattypes)
                assert i == len(output_statargs)
                output_stattypes.append(None)
                output_statargs.append(None)
                deferred_output[var] = (output_stattypes, output_statargs, i)

            # Next make sure all the input variables exist, mark them
            # needed, and record where to put their distribution type
            # and parameters.
            for var in inputs:
                must_exist.append(var)
                needed.add(var)
                i = len(input_stattypes)
                assert i == len(input_statargs)
                input_stattypes.append(None)
                input_statargs.append(None)
                deferred_input[var].append((input_stattypes, input_statargs, i))

            # Finally, add a cgpm_composition record.
            cgpm_composition.append({
                'name': name,
                'inputs': inputs,
                'outputs': outputs,
                'kwds': kwds,
            })

        elif isinstance(clause, cgpm_schema.parse.Subsample):
            if subsample is not None:
                raise BQLError(bdb, 'Duplicate subsample: %r' % (clause.n,))
            subsample = clause.n

        else:
            raise BQLError(bdb, 'Unknown clause: %r' % (clause,))

    # Make sure all the outputs and inputs exist, either in the
    # population or as latents in this generator.
    for var in must_exist:
        if core.bayesdb_has_variable(bdb, population_id, None, var):
            continue
        if var in latents:
            continue
        unknown.add(var)

    # Raise an exception if there were duplicates or unknown
    # variables.
    if duplicate:
        raise BQLError(bdb,
            'Duplicate model variables: %r' % (sorted(duplicate),))
    if existing_latent:
        raise BQLError(bdb,
            'Latent variables already defined: %r' % (sorted(existing_latent),))
    if unknown:
        raise BQLError(bdb,
            'Unknown model variables: %r' % (sorted(unknown),))

    def default_dist(var, stattype):
        stattype = casefold(stattype)
        if stattype not in _DEFAULT_DIST:
            if var in unknown_stattype:
                assert unknown_stattype[var] == stattype
            else:
                unknown_stattype[var] = stattype
            return None
        dist, params = _DEFAULT_DIST[stattype](bdb, generator_id, var)
        return dist, params

    # Use the default distribution for any variables that remain to be
    # modelled, excluding any that are latent or that have statistical
    # types we don't know about.
    for var in core.bayesdb_variable_names(bdb, population_id, None):
        if var in modelled:
            continue
        colno = core.bayesdb_variable_number(bdb, population_id, None, var)
        assert 0 <= colno
        stattype = core.bayesdb_variable_stattype(bdb, population_id, colno)
        distparams = default_dist(var, stattype)
        if distparams is None:
            continue
        dist, params = distparams
        variables.append([var, stattype, dist, params])
        assert var not in variable_dist
        variable_dist[var] = (stattype, dist, params)
        modelled.add(var)

    # Fill in the deferred_input statistical type assignments.
    for var in sorted(deferred_input.iterkeys()):
        # Check whether the variable is modelled.  If not, skip -- we
        # will fail later because this variable is guaranteed to also
        # be in needed.
        if var not in modelled:
            assert var in needed
            continue

        # Determine (possibly fictitious) distribution and parameters.
        if var in default_modelled:
            # Manifest variable modelled by default Crosscat model.
            assert var in variable_dist
            stattype, dist, params = variable_dist[var]
        else:
            # Modelled by a foreign model.  Assign a fictitious
            # default distribution because the 27B/6 of CGPM requires
            # this.
            if var in latents:
                # Latent variable modelled by a foreign model.  Use
                # the statistical type specified for it.
                stattype = latents[var]
            else:
                # Manifest variable modelled by a foreign model.  Use
                # the statistical type in the population.
                assert core.bayesdb_has_variable(bdb, population_id, None, var)
                colno = core.bayesdb_variable_number(
                    bdb, population_id, None, var)
                stattype = core.bayesdb_variable_stattype(
                    bdb, population_id, colno)
            distparams = default_dist(var, stattype)
            if distparams is None:
                continue
            dist, params = distparams

        # Assign the distribution and parameters.
        for cctypes, ccargs, i in deferred_input[var]:
            assert cctypes[i] is None
            assert ccargs[i] is None
            cctypes[i] = dist
            ccargs[i] = params

    # Fill in the deferred_output statistical type assignments. They need to be
    # in the form NUMERICAL or CATEGORICAL.
    for var in deferred_output:
        if var in latents:
            # Latent variable modelled by a foreign model.  Use
            # the statistical type specified for it.
            var_stattype = casefold(latents[var])
            if var_stattype not in _DEFAULT_DIST:
                if var in unknown_stattype:
                    assert unknown_stattype[var] == var_stattype
                else:
                    unknown_stattype[var] = var_stattype
            # XXX Cannot specify statargs for a latent variable. Trying to use
            # default_dist might look up the counts for unique values of the
            # categorical in the base table, causing a failure.
            var_statargs = {}
        else:
            # Manifest variable modelled by a foreign model.  Use
            # the statistical type and arguments from the population.
            assert core.bayesdb_has_variable(bdb, population_id, None, var)
            colno = core.bayesdb_variable_number(bdb, population_id, None, var)
            var_stattype = core.bayesdb_variable_stattype(
                bdb, population_id, colno)
            distparams = default_dist(var, var_stattype)
            if distparams is None:
                continue
            _, var_statargs = distparams

        stattypes, statargs, i = deferred_output[var]
        assert stattypes[i] is None
        assert statargs[i] is None
        stattypes[i] = var_stattype
        statargs[i] = var_statargs

    if unknown_stattype:
        raise BQLError(bdb,
            'Unknown statistical types for variables: %r' %
            (sorted(unknown_stattype.iteritems(),)))

    # If there remain any variables that we needed to model, because
    # others are conditional on them, fail.
    needed -= modelled
    if needed:
        raise BQLError(bdb, 'Unmodellable variables: %r' % (needed,))

    # Finally, create a CGPM schema.
    return {
        'variables': variables,
        'cgpm_composition': cgpm_composition,
        'subsample': subsample,
        'latents': latents,
    }
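
deferred_input above is a defaultdict(lambda: []) (equivalent to defaultdict(list)): each input variable accumulates any number of (stattypes, statargs, index) slots without an explicit membership check, and the slots are filled in once the variable's distribution is known. A reduced sketch of that deferred-fill bookkeeping, with placeholder names and values:

from collections import defaultdict

deferred_input = defaultdict(lambda: [])

# Reserve two slots for a hypothetical input variable 'age'.
input_stattypes = [None, None]
input_statargs = [None, None]
deferred_input['age'].append((input_stattypes, input_statargs, 0))
deferred_input['age'].append((input_stattypes, input_statargs, 1))

# Later, once the distribution for 'age' is known, fill the reserved slots.
for cctypes, ccargs, i in deferred_input['age']:
    cctypes[i] = 'normal'
    ccargs[i] = {}

print(input_stattypes)   # ['normal', 'normal']
print(input_statargs)    # [{}, {}]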

Example 18

View license
def retrieve_matadata(validation_directory, metric, configuration_space,
                      cutoff=0, num_runs=1, only_best=False):
    # This looks weird! The dictionaries contain the following information:
    # {dataset_1: (configuration: best_value), dataset_2: (configuration: best_value)}
    outputs = defaultdict(list)
    configurations = dict()
    configurations_to_ids = dict()
    possible_experiment_directories = os.listdir(validation_directory)

    for ped in possible_experiment_directories:
        dataset_name = ped

        # This is hacky, replace with pySMAC!
        ped = os.path.join(validation_directory, ped, ped)

        if not os.path.exists(ped) or not os.path.isdir(ped):
            continue

        print("Going through directory %s" % ped)

        smac_output_dir = ped
        validation_files = []
        validation_configuration_files = []
        validation_run_results_files = []

        # Configurations from smac-validate from a trajectory file
        for seed in [num_run * 1000 for num_run in range(num_runs)]:
            validation_file = os.path.join(smac_output_dir,
                                           'validationResults-detailed-traj-run-%d-walltime.csv' % seed)
            validation_configuration_file = os.path.join(smac_output_dir,
                                                         'validationCallStrings-detailed-traj-run-%d-walltime.csv' % seed)
            validation_run_results_file = os.path.join(smac_output_dir,
                                                       'validationRunResultLineMatrix-detailed-traj-run-%d-walltime.csv'
                                                       % seed)

            if os.path.exists(validation_file) and os.path.exists(
                    validation_configuration_file) and os.path.exists(
                    validation_run_results_file):
                validation_files.append(validation_file)
                validation_configuration_files.append(
                    validation_configuration_file)
                validation_run_results_files.append(
                    validation_run_results_file)

        # Configurations from smac-validate from a configurations file
        validation_file = os.path.join(smac_output_dir,
                                           'validationResults-configurations-walltime.csv')
        validation_configuration_file = os.path.join(smac_output_dir,
                                                     'validationCallStrings-configurations-walltime.csv')
        validation_run_results_file = os.path.join(smac_output_dir,
                                                   'validationRunResultLineMatrix-configurations-walltime.csv')
        if os.path.exists(validation_file) and os.path.exists(
                validation_configuration_file) and os.path.exists(
                validation_run_results_file):
            validation_files.append(validation_file)
            validation_configuration_files.append(
                validation_configuration_file)
            validation_run_results_files.append(
                validation_run_results_file)

        for validation_file, validation_configuration_file, validation_run_results_file in \
                zip(validation_files, validation_configuration_files,
                    validation_run_results_files):

            print("\t%s" % validation_file)

            configuration_to_time = dict()
            with open(validation_file) as fh:
                reader = csv.reader(fh)
                reader.next()
                for row in reader:
                    current_time = float(row[0])
                    validation_configuration_id = int(row[4])
                    configuration_to_time[
                        validation_configuration_id] = current_time

            best = []
            with open(validation_run_results_file) as fh:
                reader = csv.reader(fh)
                reader.next()
                for row in reader:
                    seed = int(float(row[1]))
                    results = row[2:]
                    for i, result in enumerate(results):
                        result = result.split(",")[-1]
                        if not ";" in result:
                            continue
                        result = result.split(";")
                        for result_ in result:
                            metric_, value = result_.split(":")
                            metric_ = metric_.replace(":", "").strip()
                            value = value.strip()

                            try:
                                if int(metric_) == STRING_TO_METRIC[metric]:
                                    value = float(value)
                                    best.append((value, i + 1))
                            except ValueError:
                                pass

            best.sort()
            for test_performance, validation_configuration_id in best:
                if cutoff > 0 and \
                        configuration_to_time[validation_configuration_id] > \
                        cutoff:
                    continue
                stop = False
                with open(validation_configuration_file) as fh:
                    reader = csv.reader(fh)
                    reader.next()
                    for row in reader:
                        if int(row[0]) == validation_configuration_id:
                            configuration = row[1]
                            configuration = configuration.split()
                            configuration = {configuration[i]:
                                                 configuration[i + 1]
                                             for i in
                                             range(0, len(configuration),
                                                   2)}
                            for key in configuration.keys():
                                value = configuration[key]
                                hp_name = key[1:]
                                try:
                                    hyperparameter = \
                                        configuration_space.get_hyperparameter(
                                            hp_name)
                                    del configuration[key]
                                except KeyError:
                                    break

                                value = value.strip("'")

                                if isinstance(hyperparameter,
                                              IntegerHyperparameter):
                                    value = int(float(value))
                                elif isinstance(hyperparameter,
                                                FloatHyperparameter):
                                    value = float(value)
                                elif isinstance(hyperparameter,
                                                CategoricalHyperparameter):
                                    # Implementation tailored to the PCS
                                    # parser
                                    value = str(value)
                                elif isinstance(hyperparameter, Constant):
                                    if isinstance(hyperparameter.value, float):
                                        value = float(value)
                                    elif isinstance(hyperparameter.value, int):
                                        value = int(value)
                                    else:
                                        value = value
                                elif hyperparameter is None:
                                    value = ''
                                else:
                                    raise ValueError((hp_name, value,
                                                      hyperparameter,
                                                      type(hyperparameter),
                                                      configuration, configuration_space))

                                configuration[hp_name] = value

                            try:
                                configuration = Configuration(
                                    configuration_space, configuration)
                            except Exception as e:
                                print("Configuration %s not applicable " \
                                      "because of %s!" \
                                      % (configuration, e))
                                break

                            if str(configuration) in \
                                    configurations_to_ids:
                                global_configuration_id = \
                                    configurations_to_ids[
                                        str(configuration)]
                            else:
                                global_configuration_id = len(configurations)
                                configurations[
                                    global_configuration_id] = configuration
                                configurations_to_ids[str(configuration)] = \
                                    global_configuration_id

                            if global_configuration_id is not None:
                                outputs[dataset_name].append(
                                    (global_configuration_id, test_performance))

                            if only_best:
                                stop = True
                                break
                            else:
                                pass

                if stop is True:
                    break

    return outputs, configurations
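
outputs = defaultdict(list) lets the nested loops append (configuration_id, test_performance) pairs keyed by dataset name without initialising a list per dataset first. A minimal sketch of that grouping step, with fabricated results:

from collections import defaultdict

outputs = defaultdict(list)

# (dataset_name, global_configuration_id, test_performance): made-up values
results = [
    ('iris', 0, 0.05),
    ('iris', 3, 0.07),
    ('digits', 1, 0.12),
]

for dataset_name, configuration_id, test_performance in results:
    outputs[dataset_name].append((configuration_id, test_performance))

print(dict(outputs))
# {'iris': [(0, 0.05), (3, 0.07)], 'digits': [(1, 0.12)]}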

Example 19

Project: pycalphad
Source File: tdb.py
View license
def write_tdb(dbf, fd, groupby='subsystem'):
    """
    Write a TDB file from a pycalphad Database object.

    Parameters
    ----------
    dbf : Database
        A pycalphad Database.
    fd : file-like
        File descriptor.
    groupby : ['subsystem', 'phase'], optional
        Desired grouping of parameters in the file.
    """
    writetime = datetime.datetime.now()
    maxlen = 78
    output = ""
    # Comment header block
    # Import here to prevent circular imports
    from pycalphad import __version__
    output += ("$" * maxlen) + "\n"
    output += "$ Date: {}\n".format(writetime.strftime("%Y-%m-%d %H:%M"))
    output += "$ Components: {}\n".format(', '.join(sorted(dbf.elements)))
    output += "$ Phases: {}\n".format(', '.join(sorted(dbf.phases.keys())))
    output += "$ Generated by {} (pycalphad {})\n".format(getpass.getuser(), __version__)
    output += ("$" * maxlen) + "\n\n"
    for element in sorted(dbf.elements):
        output += "ELEMENT {0} BLANK 0 0 0 !\n".format(element.upper())
    if len(dbf.elements) > 0:
        output += "\n"
    for species in sorted(dbf.species):
        output += "SPECIES {0} !\n".format(species.upper())
    if len(dbf.species) > 0:
        output += "\n"
    # Write FUNCTION block
    for name, expr in sorted(dbf.symbols.items()):
        if not isinstance(expr, Piecewise):
            # Non-piecewise exprs need to be wrapped to print
            # Otherwise TC's TDB parser will complain
            expr = Piecewise((expr, And(v.T >= 1, v.T < 10000)))
        expr = TCPrinter().doprint(expr).upper()
        if ';' not in expr:
            expr += '; N'
        output += "FUNCTION {0} {1} !\n".format(name.upper(), expr)
    output += "\n"
    # Boilerplate code
    output += "TYPE_DEFINITION % SEQ * !\n"
    output += "DEFINE_SYSTEM_DEFAULT ELEMENT 2 !\n"
    default_elements = [i.upper() for i in sorted(dbf.elements) if i.upper() == 'VA' or i.upper() == '/-']
    if len(default_elements) > 0:
        output += 'DEFAULT_COMMAND DEFINE_SYSTEM_ELEMENT {} !\n'.format(' '.join(default_elements))
    output += "\n"
    typedef_chars = list("^&*()'ABCDEFGHIJKLMNOPQSRTUVWXYZ")[::-1]
    #  Write necessary TYPE_DEF based on model hints
    typedefs = defaultdict(lambda: ["%"])
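    # Every phase starts with the default '%' typedef; extra characters are appended
    # only when a model hint (ordering, magnetism) requires one.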
    for name, phase_obj in sorted(dbf.phases.items()):
        model_hints = phase_obj.model_hints.copy()
        if ('ordered_phase' in model_hints.keys()) and (model_hints['ordered_phase'] == name):
            new_char = typedef_chars.pop()
            typedefs[name].append(new_char)
            typedefs[model_hints['disordered_phase']].append(new_char)
            output += 'TYPE_DEFINITION {} GES AMEND_PHASE_DESCRIPTION {} DISORDERED_PART {} !\n'\
                .format(new_char, model_hints['ordered_phase'].upper(),
                        model_hints['disordered_phase'].upper())
            del model_hints['ordered_phase']
            del model_hints['disordered_phase']
        if ('disordered_phase' in model_hints.keys()) and (model_hints['disordered_phase'] == name):
            # We handle adding the correct typedef when we write the ordered phase
            del model_hints['ordered_phase']
            del model_hints['disordered_phase']
        if 'ihj_magnetic_afm_factor' in model_hints.keys():
            new_char = typedef_chars.pop()
            typedefs[name].append(new_char)
            output += 'TYPE_DEFINITION {} GES AMEND_PHASE_DESCRIPTION {} MAGNETIC {} {} !\n'\
                .format(new_char, name.upper(), model_hints['ihj_magnetic_afm_factor'],
                        model_hints['ihj_magnetic_structure_factor'])
            del model_hints['ihj_magnetic_afm_factor']
            del model_hints['ihj_magnetic_structure_factor']
        if len(model_hints) > 0:
            # Some model hints were not properly consumed
            raise ValueError('Not all model hints are supported: {}'.format(model_hints))
    # Perform a second loop now that all typedefs / model hints are consistent
    for name, phase_obj in sorted(dbf.phases.items()):
        output += "PHASE {0} {1}  {2} {3} !\n".format(name.upper(), ''.join(typedefs[name]),
                                                      len(phase_obj.sublattices),
                                                      ' '.join([str(i) for i in phase_obj.sublattices]))
        constituents = ':'.join([','.join(sorted(subl)) for subl in phase_obj.constituents])
        output += "CONSTITUENT {0} :{1}: !\n".format(name.upper(), constituents)
        output += "\n"

    # PARAMETERs by subsystem
    param_sorted = defaultdict(lambda: list())
    paramtuple = namedtuple('ParamTuple', ['phase_name', 'parameter_type', 'complexity', 'constituent_array',
                                           'parameter_order', 'diffusing_species', 'parameter', 'reference'])
    for param in dbf._parameters.all():
        if groupby == 'subsystem':
            components = set()
            for subl in param['constituent_array']:
                components |= set(subl)
            if param['diffusing_species'] is not None:
                components |= {param['diffusing_species']}
            # Wildcard operator is not a component
            components -= {'*'}
            # Remove vacancy if it's not the only component (pure vacancy endmember)
            if len(components) > 1:
                components -= {'VA'}
            components = tuple(sorted([c.upper() for c in components]))
            grouping = components
        elif groupby == 'phase':
            grouping = param['phase_name'].upper()
        else:
            raise ValueError('Unknown groupby attribute \'{}\''.format(groupby))
        # We use the complexity parameter to help with sorting the parameters logically
        param_sorted[grouping].append(paramtuple(param['phase_name'], param['parameter_type'],
                                                 sum([len(i) for i in param['constituent_array']]),
                                                 param['constituent_array'], param['parameter_order'],
                                                 param['diffusing_species'], param['parameter'],
                                                 param['reference']))

    def write_parameter(param_to_write):
        constituents = ':'.join([','.join(sorted([i.upper() for i in subl]))
                         for subl in param_to_write.constituent_array])
        # TODO: Handle references
        paramx = param_to_write.parameter
        if not isinstance(paramx, Piecewise):
            # Non-piecewise parameters need to be wrapped to print correctly
            # Otherwise TC's TDB parser will fail
            paramx = Piecewise((paramx, And(v.T >= 1, v.T < 10000)))
        exprx = TCPrinter().doprint(paramx).upper()
        if ';' not in exprx:
            exprx += '; N'
        if param_to_write.diffusing_species is not None:
            ds = "&" + param_to_write.diffusing_species
        else:
            ds = ""
        return "PARAMETER {}({}{},{};{}) {} !\n".format(param_to_write.parameter_type.upper(),
                                                        param_to_write.phase_name.upper(),
                                                        ds,
                                                        constituents,
                                                        param_to_write.parameter_order,
                                                        exprx)
    if groupby == 'subsystem':
        for num_elements in range(1, 5):
            subsystems = list(itertools.combinations(sorted([i.upper() for i in dbf.elements]), num_elements))
            for subsystem in subsystems:
                parameters = sorted(param_sorted[subsystem])
                if len(parameters) > 0:
                    output += "\n\n"
                    output += "$" * maxlen + "\n"
                    output += "$ {}".format('-'.join(sorted(subsystem)).center(maxlen, " ")[2:-1]) + "$\n"
                    output += "$" * maxlen + "\n"
                    output += "\n"
                    for parameter in parameters:
                        output += write_parameter(parameter)
        # Don't generate combinatorics for multi-component subsystems or we'll run out of memory
        if len(dbf.elements) > 4:
            subsystems = [k for k in param_sorted.keys() if len(k) > 4]
            for subsystem in subsystems:
                parameters = sorted(param_sorted[subsystem])
                for parameter in parameters:
                    output += write_parameter(parameter)
    elif groupby == 'phase':
        for phase_name in sorted(dbf.phases.keys()):
            parameters = sorted(param_sorted[phase_name])
            if len(parameters) > 0:
                output += "\n\n"
                output += "$" * maxlen + "\n"
                output += "$ {}".format(phase_name.upper().center(maxlen, " ")[2:-1]) + "$\n"
                output += "$" * maxlen + "\n"
                output += "\n"
                for parameter in parameters:
                    output += write_parameter(parameter)
    else:
        raise ValueError('Unknown groupby attribute {}'.format(groupby))
    # Reflow text to respect character limit per line
    fd.write(reflow_text(output, linewidth=maxlen))
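
Two defaultdicts carry the bookkeeping here: typedefs = defaultdict(lambda: ["%"]) seeds every phase with the standard '%' type definition and only appends extra characters when a model hint demands one, while param_sorted = defaultdict(lambda: list()) groups parameter tuples by subsystem or phase name. A compact sketch of the typedef half, with invented phase names:

from collections import defaultdict

typedefs = defaultdict(lambda: ["%"])   # every phase starts with the '%' typedef

# Only phases with special model hints get an extra typedef character.
typedefs["FCC_A1"].append("A")          # e.g. a magnetic contribution

for phase in ("FCC_A1", "BCC_A2", "LIQUID"):
    # Phases never touched above still produce the default ['%'] here.
    print("PHASE", phase, "".join(typedefs[phase]), "!")
# PHASE FCC_A1 %A !
# PHASE BCC_A2 % !
# PHASE LIQUID % !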

Example 20

Project: pychess
Source File: PgnImport.py
View license
    def do_import(self, filename, info=None, progressbar=None):
        DB_MAXINT_SHIFT = get_maxint_shift(self.engine)
        self.progressbar = progressbar

        orig_filename = filename
        count_source = self.conn.execute(self.count_source.where(source.c.name == orig_filename)).scalar()
        if count_source > 0:
            print("%s is already imported" % filename)
            return

        # collect new names not in the dict yet
        self.event_data = []
        self.site_data = []
        self.player_data = []
        self.annotator_data = []
        self.source_data = []

        # collect new games and commit them in big chunks for speed
        self.game_data = []
        self.bitboard_data = []
        self.stat_ins_data = []
        self.stat_upd_data = []
        self.tag_game_data = []

        if filename.startswith("http"):
            filename = download_file(filename, progressbar=progressbar)
            if filename is None:
                return
        else:
            if not os.path.isfile(filename):
                print("Can't open %s" % filename)
                return

        if filename.lower().endswith(".zip") and zipfile.is_zipfile(filename):
            zf = zipfile.ZipFile(filename, "r")
            files = [f for f in zf.namelist() if f.lower().endswith(".pgn")]
        else:
            zf = None
            files = [filename]

        for pgnfile in files:
            basename = os.path.basename(pgnfile)
            if progressbar is not None:
                GLib.idle_add(progressbar.set_text, "Reading %s ..." % basename)
            else:
                print("Reading %s ..." % pgnfile)

            if zf is None:
                size = os.path.getsize(pgnfile)
                handle = protoopen(pgnfile)
            else:
                size = zf.getinfo(pgnfile).file_size
                handle = io.TextIOWrapper(zf.open(pgnfile), encoding=PGN_ENCODING, newline='')

            cf = PgnBase(handle, [])

            # estimated game count
            all_games = max(size / 840, 1)
            self.CHUNK = 1000 if all_games > 5000 else 100

            get_id = self.get_id
            # use transaction to avoid autocommit slowness
            trans = self.conn.begin()
            try:
                i = 0
                for tagtext, movetext in read_games(handle):
                    tags = defaultdict(str, tagre.findall(tagtext))
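                    # defaultdict(str): indexing a missing tag would give '' instead of a
                    # KeyError; the .get() calls below still return None for absent tags.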
                    if not tags:
                        print("Empty game #%s" % (i + 1))
                        continue

                    if self.cancel:
                        trans.rollback()
                        return

                    fenstr = tags.get("FEN")

                    variant = tags.get("Variant")
                    if variant:
                        if "fischer" in variant.lower() or "960" in variant:
                            variant = "Fischerandom"
                        else:
                            variant = variant.lower().capitalize()

                    # Fixes for some non-standard Chess960 .pgn files
                    if fenstr and variant == "Fischerandom":
                        parts = fenstr.split()
                        parts[0] = parts[0].replace(".", "/").replace("0", "")
                        if len(parts) == 1:
                            parts.append("w")
                            parts.append("-")
                            parts.append("-")
                        fenstr = " ".join(parts)

                    if variant:
                        if variant not in name2variant:
                            print("Unknown variant: %s" % variant)
                            continue
                        variant = name2variant[variant].variant
                        if variant == NORMALCHESS:
                            # lichess uses tag [Variant "Standard"]
                            variant = 0
                            board = START_BOARD.clone()
                        else:
                            board = LBoard(variant)
                    elif fenstr:
                        variant = 0
                        board = LBoard()
                    else:
                        variant = 0
                        board = START_BOARD.clone()

                    if fenstr:
                        try:
                            board.applyFen(fenstr)
                        except SyntaxError as e:
                            print(_(
                                "The game #%s can't be loaded, because of an error parsing FEN")
                                % (i + 1), e.args[0])
                            continue
                    elif variant:
                        board.applyFen(FEN_START)

                    movelist = array("H")
                    comments = []
                    cf.error = None

                    # First we try to use simple_parse_movetext(),
                    # assuming most games in .pgn contain only moves
                    # without any comments/variations
                    simple = False
                    if not fenstr and not variant:
                        bitboards = []
                        simple = cf.simple_parse_movetext(movetext, board, movelist, bitboards)

                        if cf.error is not None:
                            print("ERROR in %s game #%s" % (pgnfile, i + 1), cf.error.args[0])
                            continue

                    # If simple_parse_movetext() finds any comments/variations,
                    # we restart parsing with the full-featured parse_movetext()
                    if not simple:
                        movelist = array("H")
                        bitboards = None

                        # In case simple_parse_movetext() failed, we have to reset our lboard
                        if not fenstr and not variant:
                            board = START_BOARD.clone()

                        # parse movetext to create boards tree structure
                        boards = [board]
                        boards = cf.parse_movetext(movetext, boards[0], -1, pgn_import=True)

                        if cf.error is not None:
                            print("ERROR in %s game #%s" % (pgnfile, i + 1), cf.error.args[0])
                            continue

                        # create movelist and comments from boards tree
                        walk(boards[0], movelist, comments)

                    white = tags.get('White')
                    black = tags.get('Black')

                    if not movelist:
                        if (not comments) and (not white) and (not black):
                            print("Empty game #%s" % (i + 1))
                            continue

                    event_id = get_id(tags.get('Event'), event, EVENT)

                    site_id = get_id(tags.get('Site'), site, SITE)

                    game_date = tags.get('Date').strip()
                    try:
                        if game_date and '?' not in game_date:
                            ymd = game_date.split('.')
                            if len(ymd) == 3:
                                game_year, game_month, game_day = map(int, ymd)
                            else:
                                game_year, game_month, game_day = int(game_date[:4]), None, None
                        elif game_date and '?' not in game_date[:4]:
                            game_year, game_month, game_day = int(game_date[:4]), None, None
                        else:
                            game_year, game_month, game_day = None, None, None
                    except:
                        game_year, game_month, game_day = None, None, None

                    game_round = tags.get('Round')

                    white_fide_id = tags.get('WhiteFideId')
                    black_fide_id = tags.get('BlackFideId')

                    white_id = get_id(unicode(white), player, PLAYER, fide_id=white_fide_id)
                    black_id = get_id(unicode(black), player, PLAYER, fide_id=black_fide_id)

                    result = tags.get("Result")
                    if result in pgn2Const:
                        result = pgn2Const[result]
                    else:
                        print("Invalid Result tag in game #%s: %s" % (i + 1, result))
                        continue

                    white_elo = tags.get('WhiteElo')
                    white_elo = int(white_elo) if white_elo and white_elo.isdigit() else None

                    black_elo = tags.get('BlackElo')
                    black_elo = int(black_elo) if black_elo and black_elo.isdigit() else None

                    time_control = tags.get("TimeControl")

                    eco = tags.get("ECO")
                    eco = eco[:3] if eco else None

                    fen = tags.get("FEN")

                    board_tag = tags.get("Board")

                    annotator_id = get_id(tags.get("Annotator"), annotator, ANNOTATOR)

                    source_id = get_id(unicode(orig_filename), source, SOURCE, info=info)

                    game_id = self.next_id[GAME]
                    self.next_id[GAME] += 1

                    # annotated game
                    if bitboards is None:
                        for ply, board in enumerate(boards):
                            if ply == 0:
                                continue
                            bb = board.friends[0] | board.friends[1]
                            # Avoid including mate-in-x .pgn collections and similar in the opening tree
                            if fen and "/pppppppp/8/8/8/8/PPPPPPPP/" not in fen:
                                ply = -1
                            self.bitboard_data.append({
                                'game_id': game_id,
                                'ply': ply,
                                'bitboard': bb - DB_MAXINT_SHIFT,
                            })

                            if ply <= STAT_PLY_MAX:
                                self.stat_ins_data.append({
                                    'ply': ply,
                                    'bitboard': bb - DB_MAXINT_SHIFT,
                                    'count': 0,
                                    'whitewon': 0,
                                    'blackwon': 0,
                                    'draw': 0,
                                    'white_elo_count': 0,
                                    'black_elo_count': 0,
                                    'white_elo': 0,
                                    'black_elo': 0,
                                })
                                self.stat_upd_data.append({
                                    '_ply': ply,
                                    '_bitboard': bb - DB_MAXINT_SHIFT,
                                    '_count': 1,
                                    '_whitewon': 1 if result == WHITEWON else 0,
                                    '_blackwon': 1 if result == BLACKWON else 0,
                                    '_draw': 1 if result == DRAW else 0,
                                    '_white_elo_count': 1 if white_elo is not None else 0,
                                    '_black_elo_count': 1 if black_elo is not None else 0,
                                    '_white_elo': white_elo if white_elo is not None else 0,
                                    '_black_elo': black_elo if black_elo is not None else 0,
                                })

                    # simple game
                    else:
                        for ply, bb in enumerate(bitboards):
                            if ply == 0:
                                continue
                            self.bitboard_data.append({
                                'game_id': game_id,
                                'ply': ply,
                                'bitboard': bb - DB_MAXINT_SHIFT,
                            })

                            if ply <= STAT_PLY_MAX:
                                self.stat_ins_data.append({
                                    'ply': ply,
                                    'bitboard': bb - DB_MAXINT_SHIFT,
                                    'count': 0,
                                    'whitewon': 0,
                                    'blackwon': 0,
                                    'draw': 0,
                                    'white_elo_count': 0,
                                    'black_elo_count': 0,
                                    'white_elo': 0,
                                    'black_elo': 0,
                                })
                                self.stat_upd_data.append({
                                    '_ply': ply,
                                    '_bitboard': bb - DB_MAXINT_SHIFT,
                                    '_count': 1,
                                    '_whitewon': 1 if result == WHITEWON else 0,
                                    '_blackwon': 1 if result == BLACKWON else 0,
                                    '_draw': 1 if result == DRAW else 0,
                                    '_white_elo_count': 1 if white_elo is not None else 0,
                                    '_black_elo_count': 1 if black_elo is not None else 0,
                                    '_white_elo': white_elo if white_elo is not None else 0,
                                    '_black_elo': black_elo if black_elo is not None else 0,
                                })

                    ply_count = tags.get("PlyCount")
                    if not ply_count and not fen:
                        ply_count = len(bitboards) if bitboards is not None else len(boards)

                    self.game_data.append({
                        'event_id': event_id,
                        'site_id': site_id,
                        'date_year': game_year,
                        'date_month': game_month,
                        'date_day': game_day,
                        'round': game_round,
                        'white_id': white_id,
                        'black_id': black_id,
                        'result': result,
                        'white_elo': white_elo,
                        'black_elo': black_elo,
                        'ply_count': ply_count,
                        'eco': eco,
                        'fen': fen,
                        'variant': variant,
                        'board': board_tag,
                        'time_control': time_control,
                        'annotator_id': annotator_id,
                        'source_id': source_id,
                        'movelist': movelist.tostring(),
                        'comments': unicode("|".join(comments)),
                    })

                    i += 1

                    if len(self.game_data) >= self.CHUNK:
                        if self.event_data:
                            self.conn.execute(self.ins_event, self.event_data)
                            self.event_data = []

                        if self.site_data:
                            self.conn.execute(self.ins_site, self.site_data)
                            self.site_data = []

                        if self.player_data:
                            self.conn.execute(self.ins_player,
                                              self.player_data)
                            self.player_data = []

                        if self.annotator_data:
                            self.conn.execute(self.ins_annotator,
                                              self.annotator_data)
                            self.annotator_data = []

                        if self.source_data:
                            self.conn.execute(self.ins_source, self.source_data)
                            self.source_data = []

                        self.conn.execute(self.ins_game, self.game_data)
                        self.game_data = []

                        if self.bitboard_data:
                            self.conn.execute(self.ins_bitboard, self.bitboard_data)
                            self.bitboard_data = []

                            self.conn.execute(self.ins_stat, self.stat_ins_data)
                            self.conn.execute(self.upd_stat, self.stat_upd_data)
                            self.stat_ins_data = []
                            self.stat_upd_data = []

                        if progressbar is not None:
                            GLib.idle_add(progressbar.set_fraction, i / float(all_games))
                            GLib.idle_add(progressbar.set_text, "%s games from %s imported" % (i, basename))
                        else:
                            print(pgnfile, i)

                if self.event_data:
                    self.conn.execute(self.ins_event, self.event_data)
                    self.event_data = []

                if self.site_data:
                    self.conn.execute(self.ins_site, self.site_data)
                    self.site_data = []

                if self.player_data:
                    self.conn.execute(self.ins_player, self.player_data)
                    self.player_data = []

                if self.annotator_data:
                    self.conn.execute(self.ins_annotator, self.annotator_data)
                    self.annotator_data = []

                if self.source_data:
                    self.conn.execute(self.ins_source, self.source_data)
                    self.source_data = []

                if self.game_data:
                    self.conn.execute(self.ins_game, self.game_data)
                    self.game_data = []

                if self.bitboard_data:
                    self.conn.execute(self.ins_bitboard, self.bitboard_data)
                    self.bitboard_data = []

                    self.conn.execute(self.ins_stat, self.stat_ins_data)
                    self.conn.execute(self.upd_stat, self.stat_upd_data)
                    self.stat_ins_data = []
                    self.stat_upd_data = []

                if progressbar is not None:
                    GLib.idle_add(progressbar.set_fraction, i / float(all_games))
                    GLib.idle_add(progressbar.set_text, "%s games from %s imported" % (i, basename))
                else:
                    print(pgnfile, i)
                trans.commit()

            except SQLAlchemyError as e:
                trans.rollback()
                print("Importing %s failed! \n%s" % (pgnfile, e))

Example 21

Project: rapidpro
Source File: models.py
View license
    @classmethod
    def get_filtered_value_summary(cls, ruleset=None, contact_field=None, filters=None, return_contacts=False, filter_contacts=None):
        """
        Return summary results for the passed in values, optionally filtering by a passed in filter on the contact.

        This will try to aggregate results based on the values found.

        Filters expected in the following formats:
            { ruleset: rulesetId, categories: ["Red", "Blue", "Yellow"] }
            { groups: 12,124,15 }
            { location: 1515, boundary: "f1551" }
            { contact_field: fieldId, values: ["UK", "RW"] }
        """
        from temba.flows.models import RuleSet, FlowStep
        from temba.contacts.models import Contact

        start = time.time()

        # caller may identify either a ruleset or contact field to summarize
        if (not ruleset and not contact_field) or (ruleset and contact_field):
            raise ValueError("Must define either a RuleSet or ContactField to summarize values for")

        if ruleset:
            (categories, uuid_to_category) = ruleset.build_uuid_to_category_map()

        org = ruleset.flow.org if ruleset else contact_field.org

        # this is for the case when we are filtering across our own categories, we build up the category uuids we will
        # pay attention then filter before we grab the actual values
        self_filter_uuids = []

        org_contacts = Contact.objects.filter(org=org, is_test=False, is_active=True)

        if filters:
            if filter_contacts is None:
                contacts = org_contacts
            else:
                contacts = Contact.objects.filter(pk__in=filter_contacts)

            for contact_filter in filters:
                # empty filters are no-ops
                if not contact_filter:
                    continue

                # we are filtering by another rule
                if 'ruleset' in contact_filter:
                    # load the ruleset for this filter
                    filter_ruleset = RuleSet.objects.get(pk=contact_filter['ruleset'])
                    (filter_cats, filter_uuids) = filter_ruleset.build_uuid_to_category_map()

                    uuids = []
                    for (uuid, category) in filter_uuids.items():
                        if category in contact_filter['categories']:
                            uuids.append(uuid)

                    contacts = contacts.filter(values__rule_uuid__in=uuids)

                    # this is a self filter, save the uuids for later filtering
                    if ruleset and ruleset.pk == filter_ruleset.pk:
                        self_filter_uuids = uuids

                # we are filtering by one or more groups
                elif 'groups' in contact_filter:
                    # filter our contacts by that group
                    for group_id in contact_filter['groups']:
                        contacts = contacts.filter(all_groups__pk=group_id)

                # we are filtering by one or more admin boundaries
                elif 'boundary' in contact_filter:
                    boundaries = contact_filter['boundary']
                    if not isinstance(boundaries, list):
                        boundaries = [boundaries]

                    # filter our contacts by those that are in that location boundary
                    contacts = contacts.filter(values__contact_field__id=contact_filter['location'],
                                               values__location_value__osm_id__in=boundaries)

                # we are filtering by a contact field
                elif 'contact_field' in contact_filter:
                    contact_query = Q()

                    # we can't use __in as we want case insensitive matching
                    for value in contact_filter['values']:
                        contact_query |= Q(values__contact_field__id=contact_filter['contact_field'],
                                           values__string_value__iexact=value)

                    contacts = contacts.filter(contact_query)

                else:
                    raise ValueError("Invalid filter definition, must include 'group', 'ruleset', 'contact_field' or 'boundary'")

            contacts = set([c['id'] for c in contacts.values('id')])

        else:
            # no filter, default either to all contacts or our filter contacts
            if filter_contacts:
                contacts = filter_contacts
            else:
                contacts = set([c['id'] for c in org_contacts.values('id')])

        # we are summarizing a flow ruleset
        if ruleset:
            filter_uuids = set(self_filter_uuids)

            # grab all the flow steps for this ruleset, this gets us the most recent run for each contact
            steps = [fs for fs in FlowStep.objects.filter(step_uuid=ruleset.uuid)
                                                  .values('arrived_on', 'rule_uuid', 'contact')
                                                  .order_by('-arrived_on')]

            # this will build up sets of contacts for each rule uuid
            seen_contacts = set()
            value_contacts = defaultdict(set)
            for step in steps:
                contact = step['contact']
                if contact in contacts:
                    if contact not in seen_contacts:
                        value_contacts[step['rule_uuid']].add(contact)
                        seen_contacts.add(contact)

            results = defaultdict(set)
            for uuid, contacts in value_contacts.items():
                if uuid and (not filter_uuids or uuid in filter_uuids):
                    category = uuid_to_category.get(uuid, None)
                    if category:
                        results[category] |= contacts

            # now create an ordered array of our results
            set_contacts = set()
            for category in categories:
                contacts = results.get(category['label'], set())
                if return_contacts:
                    category['contacts'] = contacts

                category['count'] = len(contacts)
                set_contacts |= contacts

            # how many runs actually entered a response?
            set_contacts = set_contacts
            unset_contacts = value_contacts[None]

        # we are summarizing based on contact field
        else:
            values = Value.objects.filter(contact_field=contact_field)

            if contact_field.value_type == Value.TYPE_TEXT:
                values = values.values('string_value', 'contact')
                categories, set_contacts = cls._filtered_values_to_categories(contacts, values, 'string_value',
                                                                              return_contacts=return_contacts)

            elif contact_field.value_type == Value.TYPE_DECIMAL:
                values = values.values('decimal_value', 'contact')
                categories, set_contacts = cls._filtered_values_to_categories(contacts, values, 'decimal_value',
                                                                              formatter=format_decimal,
                                                                              return_contacts=return_contacts)

            elif contact_field.value_type == Value.TYPE_DATETIME:
                values = values.extra({'date_value': "date_trunc('day', datetime_value)"}).values('date_value', 'contact')
                categories, set_contacts = cls._filtered_values_to_categories(contacts, values, 'date_value',
                                                                              return_contacts=return_contacts)

            elif contact_field.value_type in [Value.TYPE_STATE, Value.TYPE_DISTRICT, Value.TYPE_WARD]:
                values = values.values('location_value__osm_id', 'contact')
                categories, set_contacts = cls._filtered_values_to_categories(contacts, values, 'location_value__osm_id',
                                                                              return_contacts=return_contacts)

            else:
                raise ValueError(_("Summary of contact fields with value type of %s is not supported" % contact_field.get_value_type_display()))

            set_contacts = contacts & set_contacts
            unset_contacts = contacts - set_contacts

        print "RulesetSummary [%f]: %s contact_field: %s with filters: %s" % (time.time() - start, ruleset, contact_field, filters)

        if return_contacts:
            return (set_contacts, unset_contacts, categories)
        else:
            return (len(set_contacts), len(unset_contacts), categories)
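
A minimal sketch of the grouping idiom used above: steps are assumed to arrive newest-first, and only the first step seen per contact is credited to its rule, via a defaultdict(set) keyed by rule uuid. The step dicts below are hypothetical.

from collections import defaultdict

# assumed to be pre-sorted newest-first, as in the order_by('-arrived_on') query above
steps = [
    {'contact': 1, 'rule_uuid': 'red'},
    {'contact': 1, 'rule_uuid': 'blue'},   # older step for contact 1, ignored
    {'contact': 2, 'rule_uuid': 'blue'},
]

seen_contacts = set()
value_contacts = defaultdict(set)  # rule uuid -> set of contact ids

for step in steps:
    contact = step['contact']
    if contact not in seen_contacts:
        value_contacts[step['rule_uuid']].add(contact)
        seen_contacts.add(contact)

print(dict(value_contacts))  # {'red': {1}, 'blue': {2}}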

Example 22

Project: headphones
Source File: bluelet.py
View license
def run(root_coro):
    """Schedules a coroutine, running it to completion. This
    encapsulates the Bluelet scheduler, which the root coroutine can
    add to by spawning new coroutines.
    """
    # The "threads" dictionary keeps track of all the currently-
    # executing and suspended coroutines. It maps coroutines to their
    # currently "blocking" event. The event value may be SUSPENDED if
    # the coroutine is waiting on some other condition: namely, a
    # delegated coroutine or a joined coroutine. In this case, the
    # coroutine should *also* appear as a value in one of the below
    # dictionaries `delegators` or `joiners`.
    threads = {root_coro: ValueEvent(None)}

    # Maps child coroutines to delegating parents.
    delegators = {}

    # Maps child coroutines to joining (exit-waiting) parents.
    joiners = collections.defaultdict(list)

    def complete_thread(coro, return_value):
        """Remove a coroutine from the scheduling pool, awaking
        delegators and joiners as necessary and returning the specified
        value to any delegating parent.
        """
        del threads[coro]

        # Resume delegator.
        if coro in delegators:
            threads[delegators[coro]] = ValueEvent(return_value)
            del delegators[coro]

        # Resume joiners.
        if coro in joiners:
            for parent in joiners[coro]:
                threads[parent] = ValueEvent(None)
            del joiners[coro]

    def advance_thread(coro, value, is_exc=False):
        """After an event is fired, run a given coroutine associated with
        it in the threads dict until it yields again. If the coroutine
        exits, then the thread is removed from the pool. If the coroutine
        raises an exception, it is reraised in a ThreadException. If
        is_exc is True, then the value must be an exc_info tuple and the
        exception is thrown into the coroutine.
        """
        try:
            if is_exc:
                next_event = coro.throw(*value)
            else:
                next_event = coro.send(value)
        except StopIteration:
            # Thread is done.
            complete_thread(coro, None)
        except:
            # Thread raised some other exception.
            del threads[coro]
            raise ThreadException(coro, sys.exc_info())
        else:
            if isinstance(next_event, types.GeneratorType):
                # Automatically invoke sub-coroutines. (Shorthand for
                # explicit bluelet.call().)
                next_event = DelegationEvent(next_event)
            threads[coro] = next_event

    def kill_thread(coro):
        """Unschedule this thread and its (recursive) delegates.
        """
        # Collect all coroutines in the delegation stack.
        coros = [coro]
        while isinstance(threads[coro], Delegated):
            coro = threads[coro].child
            coros.append(coro)

        # Complete each coroutine from the top to the bottom of the
        # stack.
        for coro in reversed(coros):
            complete_thread(coro, None)

    # Continue advancing threads until root thread exits.
    exit_te = None
    while threads:
        try:
            # Look for events that can be run immediately. Continue
            # running immediate events until nothing is ready.
            while True:
                have_ready = False
                for coro, event in list(threads.items()):
                    if isinstance(event, SpawnEvent):
                        threads[event.spawned] = ValueEvent(None)  # Spawn.
                        advance_thread(coro, None)
                        have_ready = True
                    elif isinstance(event, ValueEvent):
                        advance_thread(coro, event.value)
                        have_ready = True
                    elif isinstance(event, ExceptionEvent):
                        advance_thread(coro, event.exc_info, True)
                        have_ready = True
                    elif isinstance(event, DelegationEvent):
                        threads[coro] = Delegated(event.spawned)  # Suspend.
                        threads[event.spawned] = ValueEvent(None)  # Spawn.
                        delegators[event.spawned] = coro
                        have_ready = True
                    elif isinstance(event, ReturnEvent):
                        # Thread is done.
                        complete_thread(coro, event.value)
                        have_ready = True
                    elif isinstance(event, JoinEvent):
                        threads[coro] = SUSPENDED  # Suspend.
                        joiners[event.child].append(coro)
                        have_ready = True
                    elif isinstance(event, KillEvent):
                        threads[coro] = ValueEvent(None)
                        kill_thread(event.child)
                        have_ready = True

                # Only start the select when nothing else is ready.
                if not have_ready:
                    break

            # Wait and fire.
            event2coro = dict((v, k) for k, v in threads.items())
            for event in _event_select(threads.values()):
                # Run the IO operation, but catch socket errors.
                try:
                    value = event.fire()
                except socket.error as exc:
                    if isinstance(exc.args, tuple) and \
                            exc.args[0] == errno.EPIPE:
                        # Broken pipe. Remote host disconnected.
                        pass
                    else:
                        traceback.print_exc()
                    # Abort the coroutine.
                    threads[event2coro[event]] = ReturnEvent(None)
                else:
                    advance_thread(event2coro[event], value)

        except ThreadException as te:
            # Exception raised from inside a thread.
            event = ExceptionEvent(te.exc_info)
            if te.coro in delegators:
                # The thread is a delegate. Raise exception in its
                # delegator.
                threads[delegators[te.coro]] = event
                del delegators[te.coro]
            else:
                # The thread is root-level. Raise in client code.
                exit_te = te
                break

        except:
            # For instance, KeyboardInterrupt during select(). Raise
            # into root thread and terminate others.
            threads = {root_coro: ExceptionEvent(sys.exc_info())}

    # If any threads still remain, kill them.
    for coro in threads:
        coro.close()

    # If we're exiting with an exception, raise it in the client.
    if exit_te:
        exit_te.reraise()
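
The joiners mapping above is the classic waiters-per-key use of collections.defaultdict(list): appending under a missing key creates the list on demand. A tiny sketch with hypothetical coroutine names:

import collections

# key -> list of parties waiting on that key; no setdefault() dance needed
joiners = collections.defaultdict(list)

def join(child, parent):
    # registering a waiter never has to check whether the key exists yet
    joiners[child].append(parent)

def complete(child):
    # pop() avoids leaving an empty list behind for finished children
    for parent in joiners.pop(child, []):
        print("waking", parent)

join("worker-1", "main")
join("worker-1", "logger")
complete("worker-1")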

Example 23

Project: reviewboard
Source File: detail.py
View license
    def query_data_post_etag(self):
        """Perform remaining queries for the page.

        This method will populate everything else needed for the display of the
        review request page other than that which was required to compute the
        ETag.
        """
        self.reviews_by_id = self._build_id_map(self.reviews)

        self.body_top_replies = defaultdict(list)
        self.body_bottom_replies = defaultdict(list)
        self.latest_timestamps_by_review_id = defaultdict(lambda: 0)

        for r in self.reviews:
            r._body_top_replies = []
            r._body_bottom_replies = []

            if r.body_top_reply_to_id is not None:
                self.body_top_replies[r.body_top_reply_to_id].append(r)

            if r.body_bottom_reply_to_id is not None:
                self.body_bottom_replies[r.body_bottom_reply_to_id].append(r)

            # Find the latest reply timestamp for each top-level review.
            parent_id = r.base_reply_to_id

            if parent_id is not None:
                self.latest_timestamps_by_review_id[parent_id] = max(
                    r.timestamp.replace(tzinfo=utc).ctime(),
                    self.latest_timestamps_by_review_id[parent_id])

        # Link up all the review body replies.
        for reply_id, replies in six.iteritems(self.body_top_replies):
            self.reviews_by_id[reply_id]._body_top_replies = reversed(replies)

        for reply_id, replies in six.iteritems(self.body_bottom_replies):
            self.reviews_by_id[reply_id]._body_bottom_replies = \
                reversed(replies)

        self.review_request_details = self.draft or self.review_request

        # Get all the file attachments and screenshots.
        #
        # Note that we fetch both active and inactive file attachments and
        # screenshots. We do this because even though they've been removed,
        # they still will be rendered in change descriptions.
        self.active_file_attachments = \
            list(self.review_request_details.get_file_attachments())
        self.all_file_attachments = (
            self.active_file_attachments +
            list(self.review_request_details.get_inactive_file_attachments()))
        self.file_attachments_by_id = \
            self._build_id_map(self.all_file_attachments)

        for attachment in self.all_file_attachments:
            attachment._comments = []

        self.active_screenshots = \
            list(self.review_request_details.get_screenshots())
        self.all_screenshots = (
            self.active_screenshots +
            list(self.review_request_details.get_inactive_screenshots()))
        self.screenshots_by_id = self._build_id_map(self.all_screenshots)

        for screenshot in self.all_screenshots:
            screenshot._comments = []

        review_ids = self.reviews_by_id.keys()

        # Get all status updates.
        if status_updates_feature.is_enabled(request=self.request):
            self.status_updates = list(
                self.review_request.status_updates.all()
                .select_related('review'))

        self.comments = []
        self.issues = []
        self.issue_counts = {
            'total': 0,
            'open': 0,
            'resolved': 0,
            'dropped': 0,
        }

        for model, key, ordering in (
            (Comment, 'diff_comments', ('comment__filediff',
                                        'comment__first_line',
                                        'comment__timestamp')),
            (ScreenshotComment, 'screenshot_comments', None),
            (FileAttachmentComment, 'file_attachment_comments', None),
            (GeneralComment, 'general_comments', None)):
            # Due to mistakes in how we initially made the schema, we have a
            # ManyToManyField in between comments and reviews, instead of
            # comments having a ForeignKey to the review. This makes it
            # difficult to easily go from a comment to a review ID.
            #
            # The solution to this is to not query the comment objects, but
            # rather the through table. This will let us grab the review and
            # comment in one go, using select_related.
            related_field = model.review.related.field
            comment_field_name = related_field.m2m_reverse_field_name()
            through = related_field.rel.through
            q = through.objects.filter(review__in=review_ids).select_related()

            if ordering:
                q = q.order_by(*ordering)

            objs = list(q)

            # We do two passes. One to build a mapping, and one to actually
            # process comments.
            comment_map = {}

            for obj in objs:
                comment = getattr(obj, comment_field_name)
                comment._type = key
                comment._replies = []
                comment_map[comment.pk] = comment

            for obj in objs:
                comment = getattr(obj, comment_field_name)

                self.comments.append(comment)

                # Short-circuit some object fetches for the comment by setting
                # some internal state on them.
                assert obj.review_id in self.reviews_by_id
                review = self.reviews_by_id[obj.review_id]
                comment.review_obj = review
                comment._review_request = self.review_request

                # If the comment has an associated object (such as a file
                # attachment) that we've already fetched, attach it to prevent
                # future queries.
                if isinstance(comment, FileAttachmentComment):
                    attachment_id = comment.file_attachment_id
                    f = self.file_attachments_by_id[attachment_id]
                    comment.file_attachment = f
                    f._comments.append(comment)

                    diff_against_id = comment.diff_against_file_attachment_id

                    if diff_against_id is not None:
                        f = self.file_attachments_by_id[diff_against_id]
                        comment.diff_against_file_attachment = f
                elif isinstance(comment, ScreenshotComment):
                    screenshot = self.screenshots_by_id[comment.screenshot_id]
                    comment.screenshot = screenshot
                    screenshot._comments.append(comment)

                # We've hit legacy database cases where there were entries that
                # weren't a reply, and were just orphaned. Ignore them.
                if review.is_reply() and comment.is_reply():
                    replied_comment = comment_map[comment.reply_to_id]
                    replied_comment._replies.append(comment)

                if review.public and comment.issue_opened:
                    status_key = \
                        comment.issue_status_to_string(comment.issue_status)
                    self.issue_counts[status_key] += 1
                    self.issue_counts['total'] += 1
                    self.issues.append(comment)
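
Two defaultdict flavours appear above: defaultdict(list) groups replies under their parent review id, and defaultdict(lambda: 0) tracks a running maximum with a harmless default. A minimal sketch over hypothetical (review_id, parent_id, timestamp) tuples:

from collections import defaultdict

# (review_id, parent_id, timestamp); parent_id is None for top-level reviews
reviews = [
    (1, None, 10),
    (2, 1, 12),
    (3, 1, 15),
]

replies_by_parent = defaultdict(list)
latest_reply_ts = defaultdict(lambda: 0)  # parent id -> newest reply timestamp

for review_id, parent_id, timestamp in reviews:
    if parent_id is not None:
        replies_by_parent[parent_id].append(review_id)
        latest_reply_ts[parent_id] = max(timestamp, latest_reply_ts[parent_id])

print(dict(replies_by_parent))  # {1: [2, 3]}
print(dict(latest_reply_ts))    # {1: 15}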

Example 24

Project: django-adminactions
Source File: mass_update.py
View license
def mass_update(modeladmin, request, queryset):  # noqa
    """
        mass update queryset
    """

    def not_required(field, **kwargs):
        """ force all fields as not required"""
        kwargs['required'] = False
        return field.formfield(**kwargs)

    def _doit():
        errors = {}
        updated = 0
        for record in queryset:
            for field_name, value_or_func in list(form.cleaned_data.items()):
                if callable(value_or_func):
                    old_value = getattr(record, field_name)
                    setattr(record, field_name, value_or_func(old_value))
                else:
                    setattr(record, field_name, value_or_func)
            if clean:
                record.clean()
            record.save()
            updated += 1
        if updated:
            messages.info(request, _("Updated %s records") % updated)

        if len(errors):
            messages.error(request, "%s records not updated due errors" % len(errors))
        adminaction_end.send(sender=modeladmin.model,
                             action='mass_update',
                             request=request,
                             queryset=queryset,
                             modeladmin=modeladmin,
                             form=form,
                             errors=errors,
                             updated=updated)

    opts = modeladmin.model._meta
    perm = "{0}.{1}".format(opts.app_label, get_permission_codename('adminactions_massupdate', opts))
    if not request.user.has_perm(perm):
        messages.error(request, _('Sorry you do not have rights to execute this action'))
        return

    try:
        adminaction_requested.send(sender=modeladmin.model,
                                   action='mass_update',
                                   request=request,
                                   queryset=queryset,
                                   modeladmin=modeladmin)
    except ActionInterrupted as e:
        messages.error(request, str(e))
        return

    # Allows specifying a custom mass update Form in the ModelAdmin
    mass_update_form = getattr(modeladmin, 'mass_update_form', MassUpdateForm)

    MForm = modelform_factory(modeladmin.model, form=mass_update_form,
                              exclude=('pk',),
                              formfield_callback=not_required)
    grouped = defaultdict(lambda: [])
    selected_fields = []
    initial = {'_selected_action': request.POST.getlist(helpers.ACTION_CHECKBOX_NAME),
               'select_across': request.POST.get('select_across') == '1',
               'action': 'mass_update'}

    if 'apply' in request.POST:
        form = MForm(request.POST)
        if form.is_valid():
            try:
                adminaction_start.send(sender=modeladmin.model,
                                       action='mass_update',
                                       request=request,
                                       queryset=queryset,
                                       modeladmin=modeladmin,
                                       form=form)
            except ActionInterrupted as e:
                messages.error(request, str(e))
                return HttpResponseRedirect(request.get_full_path())

            # need_transaction = form.cleaned_data.get('_unique_transaction', False)
            validate = form.cleaned_data.get('_validate', False)
            clean = form.cleaned_data.get('_clean', False)

            if validate:
                with compat.atomic():
                    _doit()

            else:
                values = {}
                for field_name, value in list(form.cleaned_data.items()):
                    if isinstance(form.fields[field_name], ModelMultipleChoiceField):
                        messages.error(request, "Unable no mass update ManyToManyField without 'validate'")
                        return HttpResponseRedirect(request.get_full_path())
                    elif callable(value):
                        messages.error(request, "Unable no mass update using operators without 'validate'")
                        return HttpResponseRedirect(request.get_full_path())
                    elif field_name not in ['_selected_action', '_validate', 'select_across', 'action',
                                            '_unique_transaction', '_clean']:
                        values[field_name] = value
                queryset.update(**values)

            return HttpResponseRedirect(request.get_full_path())
    else:
        initial.update({'action': 'mass_update', '_validate': 1})
        # form = MForm(initial=initial)
        prefill_with = request.POST.get('prefill-with', None)
        prefill_instance = None
        try:
            # Gets the instance directly from the queryset for data security
            prefill_instance = queryset.get(pk=prefill_with)
        except ObjectDoesNotExist:
            pass

        form = MForm(initial=initial, instance=prefill_instance)

    for el in queryset.all()[:10]:
        for f in modeladmin.model._meta.fields:
            if f.name not in form._no_sample_for:
                if hasattr(f, 'flatchoices') and f.flatchoices:
                    grouped[f.name] = list(dict(getattr(f, 'flatchoices')).values())
                elif hasattr(f, 'choices') and f.choices:
                    grouped[f.name] = list(dict(getattr(f, 'choices')).values())
                elif isinstance(f, df.BooleanField):
                    grouped[f.name] = [True, False]
                else:
                    value = getattr(el, f.name)
                    if value is not None and value not in grouped[f.name]:
                        grouped[f.name].append(value)
                    initial[f.name] = initial.get(f.name, value)

    adminForm = helpers.AdminForm(form, modeladmin.get_fieldsets(request), {}, [], model_admin=modeladmin)
    media = modeladmin.media + adminForm.media
    dthandler = lambda obj: obj.isoformat() if isinstance(obj, datetime.date) else str(obj)
    tpl = 'adminactions/mass_update.html'
    ctx = {'adminform': adminForm,
           'form': form,
           'action_short_description': mass_update.short_description,
           'title': u"%s (%s)" % (
               mass_update.short_description.capitalize(),
               smart_text(modeladmin.opts.verbose_name_plural),
           ),
           'grouped': grouped,
           'fieldvalues': json.dumps(grouped, default=dthandler),
           'change': True,
           'selected_fields': selected_fields,
           'is_popup': False,
           'save_as': False,
           'has_delete_permission': False,
           'has_add_permission': False,
           'has_change_permission': True,
           'opts': modeladmin.model._meta,
           'app_label': modeladmin.model._meta.app_label,
           # 'action': 'mass_update',
           # 'select_across': request.POST.get('select_across')=='1',
           'media': mark_safe(media),
           'selection': queryset}
    if django.VERSION[:2] > (1, 7):
        ctx.update(modeladmin.admin_site.each_context(request))
    else:
        ctx.update(modeladmin.admin_site.each_context())

    if django.VERSION[:2] > (1, 8):
        return render(request, tpl, context=ctx)
    else:
        return render_to_response(tpl, RequestContext(request, ctx))
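
The grouped mapping above uses defaultdict(lambda: []), which behaves the same as defaultdict(list): reading a missing field name yields a fresh list, so sample values can be appended without an existence check. A small sketch with hypothetical records:

from collections import defaultdict

records = [
    {'status': 'open', 'priority': 1},
    {'status': 'closed', 'priority': 1},
    {'status': 'open', 'priority': 2},
]

grouped = defaultdict(list)  # field name -> distinct sample values

for record in records:
    for name, value in record.items():
        if value not in grouped[name]:  # the lookup itself creates the list
            grouped[name].append(value)

print(dict(grouped))  # {'status': ['open', 'closed'], 'priority': [1, 2]}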

Example 25

Project: rex
Source File: fuzzing_type_2.py
View license
    def analyze_bytes(self, byte_indices):
        if any(i in self.skip_bytes for i in byte_indices):
            return False
        if frozenset(set(byte_indices)) in self.skip_sets:
            return False
        if len(byte_indices) == 1:
            l.info("fuzzing byte %d", byte_indices[0])
        else:
            l.info("fuzzing bytes %s", byte_indices)
        bytes_to_regs = dict()

        bytes_that_change_crash = set()
        bytes_that_dont_crash = set()
        bytes_that_dont_affect_regs = set()
        bytes_that_affect_regs = set()

        # run on the prefilter
        binary_input_bytes = []
        for i in _PREFILTER_BYTES:
            test_input = self._replace_indices(self.crash, chr(i), byte_indices)
            binary_input_bytes.append((self.binary, test_input, chr(i)))
        it = self.pool.imap_unordered(_get_reg_vals, binary_input_bytes)
        for c, reg_vals in it:
            if reg_vals is not None:
                reg_vals = self._fix_reg_vals(reg_vals)
                bytes_to_regs[c] = reg_vals
            else:
                bytes_that_dont_crash.add(c)

        possible_sets = defaultdict(set)
        for c in sorted(bytes_to_regs.keys()):
            reg_vals = bytes_to_regs[c]
            num_diff = 0
            for r in reg_vals.keys():
                if r not in self.reg_deps:
                    continue
                possible_sets[r].add(reg_vals[r])
                if reg_vals[r] != self.orig_regs[r]:
                    num_diff += 1

            if num_diff == 0:
                bytes_that_dont_affect_regs.add(c)
            elif reg_vals["eip"] != self.orig_regs["eip"]:
                bytes_that_change_crash.add(c)
            else:
                bytes_that_affect_regs.add(c)
        if len(bytes_that_affect_regs) == 0:
            return False
        if all(len(possible_sets[r]) <= 2 for r in possible_sets):
            return False

        for i in xrange(256):
            if i in _PREFILTER_BYTES:
                continue
            test_input = self._replace_indices(self.crash, chr(i), byte_indices)
            binary_input_bytes.append((self.binary, test_input, chr(i)))
        it = self.pool.imap_unordered(_get_reg_vals, binary_input_bytes, chunksize=4)
        for c, reg_vals in it:
            if reg_vals is not None:
                reg_vals = self._fix_reg_vals(reg_vals)
                bytes_to_regs[c] = reg_vals
            else:
                bytes_that_dont_crash.add(c)

        ip_counts = defaultdict(int)
        for reg_vals in bytes_to_regs.values():
            ip_counts[reg_vals["eip"]] += 1

        # if multiple registers change we might've found a different crash
        for c in sorted(bytes_to_regs.keys()):
            reg_vals = bytes_to_regs[c]
            num_diff = 0
            for r in reg_vals.keys():
                if reg_vals[r] != self.orig_regs[r]:
                    num_diff += 1

            if num_diff == 0:
                bytes_that_dont_affect_regs.add(c)
            elif reg_vals["eip"] != self.orig_regs["eip"]:
                bytes_that_change_crash.add(c)
            else:
                bytes_that_affect_regs.add(c)

        l.debug("%d bytes don't crash, %d bytes don't affect regs",
                len(bytes_that_dont_crash), len(bytes_that_dont_affect_regs))

        # the goal here is to find which bits of the registers are controlled
        all_reg_vals = defaultdict(set)
        for c in bytes_that_affect_regs:
            reg_vals = bytes_to_regs[c]
            for reg in reg_vals.keys():
                all_reg_vals[reg].add(reg_vals[reg])

        byte_analysis = ByteAnalysis()
        for i in byte_indices:
            self.byte_analysis[i] = byte_analysis
        byte_analysis.valid_bytes = set(bytes_to_regs.keys())

        found_interesting = False

        for reg in all_reg_vals.keys():
            if reg not in self.reg_deps:
                continue
            possible_vals = all_reg_vals[reg]
            bits_that_can_be_set = 0
            bits_that_can_be_unset = 0
            for val in possible_vals:
                bits_that_can_be_set |= val
                bits_that_can_be_unset |= ((~val) & 0xffffffff)
            controlled_bits = bits_that_can_be_set & bits_that_can_be_unset
            while controlled_bits != 0:
                number_bits = bin(controlled_bits).count("1")
                bit_indices = []
                for i, c in enumerate(bin(controlled_bits).replace("0b", "").rjust(32, "0")):
                    if c == "1":
                        bit_indices.append(31-i)
                if number_bits > 8:
                    if self.analyze_complex(byte_indices, reg, bytes_to_regs):
                        return True
                    else:
                        return False

                # might want to check for impossible bit patterns
                if controlled_bits != 0:
                    # check that all bitmasks are possible for those bits

                    # now map out the patterns that were not possible
                    all_patterns = self._get_bit_patterns(number_bits, bit_indices)

                    byte_analysis.register_pattern_maps[reg] = dict()

                    impossible_patterns = set(all_patterns)
                    for c in bytes_to_regs.keys():
                        reg_val = bytes_to_regs[c][reg]
                        pattern = reg_val & controlled_bits
                        byte_analysis.register_pattern_maps[reg][pattern] = c
                        impossible_patterns.discard(pattern)

                    # now we want to find a minimum set of bits
                    if len(impossible_patterns) > 0:
                        l.warning("not all patterns viable, decreasing bit patterns")
                        # remove a bit with the least variety
                        possible_patterns = all_patterns - impossible_patterns
                        bit_counts = dict()
                        for bit in bit_indices:
                            bit_counts[bit] = 0
                        for pattern in possible_patterns:
                            for bit in bit_indices:
                                if pattern & (1 << bit) != 0:
                                    bit_counts[bit] += 1
                                else:
                                    bit_counts[bit] -= 1
                        bit_to_remove = max(bit_counts.items(), key=lambda x: abs(x[1]))[0]
                        l.info("removing bit %d", bit_to_remove)
                        controlled_bits &= (~(1 << bit_to_remove))
                    else:
                        break

            if controlled_bits != 0:
                l.info("Register %s has the following bitmask %s for bytes %s of the input",
                       reg, hex(controlled_bits), byte_indices)
                byte_analysis.register_bitmasks[reg] = controlled_bits
                found_interesting = True
                byte_analysis.reg_vals = bytes_to_regs

            # todo remove conflicts

        return found_interesting
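
This example combines defaultdict(set), for accumulating the distinct values seen per register, with defaultdict(int), for counting occurrences. A compact sketch over hypothetical register snapshots:

from collections import defaultdict

# hypothetical register snapshots keyed by the byte value that produced them
bytes_to_regs = {
    'A': {'eip': 0x1000, 'eax': 1},
    'B': {'eip': 0x1000, 'eax': 2},
    'C': {'eip': 0x2000, 'eax': 1},
}

possible_vals = defaultdict(set)  # register -> set of values observed
ip_counts = defaultdict(int)      # instruction pointer -> how often it was hit

for reg_vals in bytes_to_regs.values():
    ip_counts[reg_vals['eip']] += 1
    for reg, val in reg_vals.items():
        possible_vals[reg].add(val)

print(dict(ip_counts))               # {4096: 2, 8192: 1}
print(sorted(possible_vals['eax']))  # [1, 2]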

Example 26

Project: SickGear
Source File: base.py
View license
def _as_declarative(cls, classname, dict_):
    from .api import declared_attr

    # dict_ will be a dictproxy, which we can't write to, and we need to!
    dict_ = dict(dict_)

    column_copies = {}
    potential_columns = {}

    mapper_args_fn = None
    table_args = inherited_table_args = None
    tablename = None

    declarative_props = (declared_attr, util.classproperty)

    for base in cls.__mro__:
        _is_declarative_inherits = hasattr(base, '_decl_class_registry')

        if '__declare_last__' in base.__dict__:
            @event.listens_for(mapper, "after_configured")
            def go():
                cls.__declare_last__()
        if '__declare_first__' in base.__dict__:
            @event.listens_for(mapper, "before_configured")
            def go():
                cls.__declare_first__()
        if '__abstract__' in base.__dict__:
            if (base is cls or
                (base in cls.__bases__ and not _is_declarative_inherits)
            ):
                return

        class_mapped = _declared_mapping_info(base) is not None

        for name, obj in vars(base).items():
            if name == '__mapper_args__':
                if not mapper_args_fn and (
                                        not class_mapped or
                                        isinstance(obj, declarative_props)
                                    ):
                    # don't even invoke __mapper_args__ until
                    # after we've determined everything about the
                    # mapped table.
                    mapper_args_fn = lambda: cls.__mapper_args__
            elif name == '__tablename__':
                if not tablename and (
                                        not class_mapped or
                                        isinstance(obj, declarative_props)
                                    ):
                    tablename = cls.__tablename__
            elif name == '__table_args__':
                if not table_args and (
                                        not class_mapped or
                                        isinstance(obj, declarative_props)
                                    ):
                    table_args = cls.__table_args__
                    if not isinstance(table_args, (tuple, dict, type(None))):
                        raise exc.ArgumentError(
                                "__table_args__ value must be a tuple, "
                                "dict, or None")
                    if base is not cls:
                        inherited_table_args = True
            elif class_mapped:
                if isinstance(obj, declarative_props):
                    util.warn("Regular (i.e. not __special__) "
                            "attribute '%s.%s' uses @declared_attr, "
                            "but owning class %s is mapped - "
                            "not applying to subclass %s."
                            % (base.__name__, name, base, cls))
                continue
            elif base is not cls:
                # we're a mixin.
                if isinstance(obj, Column):
                    if getattr(cls, name) is not obj:
                        # if column has been overridden
                        # (like by the InstrumentedAttribute of the
                        # superclass), skip
                        continue
                    if obj.foreign_keys:
                        raise exc.InvalidRequestError(
                        "Columns with foreign keys to other columns "
                        "must be declared as @declared_attr callables "
                        "on declarative mixin classes. ")
                    if name not in dict_ and not (
                            '__table__' in dict_ and
                            (obj.name or name) in dict_['__table__'].c
                            ) and name not in potential_columns:
                        potential_columns[name] = \
                                column_copies[obj] = \
                                obj.copy()
                        column_copies[obj]._creation_order = \
                                obj._creation_order
                elif isinstance(obj, MapperProperty):
                    raise exc.InvalidRequestError(
                        "Mapper properties (i.e. deferred,"
                        "column_property(), relationship(), etc.) must "
                        "be declared as @declared_attr callables "
                        "on declarative mixin classes.")
                elif isinstance(obj, declarative_props):
                    dict_[name] = ret = \
                            column_copies[obj] = getattr(cls, name)
                    if isinstance(ret, (Column, MapperProperty)) and \
                        ret.doc is None:
                        ret.doc = obj.__doc__

    # apply inherited columns as we should
    for k, v in potential_columns.items():
        dict_[k] = v

    if inherited_table_args and not tablename:
        table_args = None

    clsregistry.add_class(classname, cls)
    our_stuff = util.OrderedDict()

    for k in list(dict_):

        # TODO: improve this ?  all dunders ?
        if k in ('__table__', '__tablename__', '__mapper_args__'):
            continue

        value = dict_[k]
        if isinstance(value, declarative_props):
            value = getattr(cls, k)

        elif isinstance(value, QueryableAttribute) and \
                value.class_ is not cls and \
                value.key != k:
            # detect a QueryableAttribute that's already mapped being
            # assigned elsewhere in userland, turn into a synonym()
            value = synonym(value.key)
            setattr(cls, k, value)


        if (isinstance(value, tuple) and len(value) == 1 and
            isinstance(value[0], (Column, MapperProperty))):
            util.warn("Ignoring declarative-like tuple value of attribute "
                      "%s: possibly a copy-and-paste error with a comma "
                      "left at the end of the line?" % k)
            continue
        if not isinstance(value, (Column, MapperProperty)):
            if not k.startswith('__'):
                dict_.pop(k)
                setattr(cls, k, value)
            continue
        if k == 'metadata':
            raise exc.InvalidRequestError(
                "Attribute name 'metadata' is reserved "
                "for the MetaData instance when using a "
                "declarative base class."
            )
        prop = clsregistry._deferred_relationship(cls, value)
        our_stuff[k] = prop

    # set up attributes in the order they were created
    our_stuff.sort(key=lambda key: our_stuff[key]._creation_order)

    # extract columns from the class dict
    declared_columns = set()
    name_to_prop_key = collections.defaultdict(set)
    for key, c in list(our_stuff.items()):
        if isinstance(c, (ColumnProperty, CompositeProperty)):
            for col in c.columns:
                if isinstance(col, Column) and \
                    col.table is None:
                    _undefer_column_name(key, col)
                    if not isinstance(c, CompositeProperty):
                        name_to_prop_key[col.name].add(key)
                    declared_columns.add(col)
        elif isinstance(c, Column):
            _undefer_column_name(key, c)
            name_to_prop_key[c.name].add(key)
            declared_columns.add(c)
            # if the column is the same name as the key,
            # remove it from the explicit properties dict.
            # the normal rules for assigning column-based properties
            # will take over, including precedence of columns
            # in multi-column ColumnProperties.
            if key == c.key:
                del our_stuff[key]

    for name, keys in name_to_prop_key.items():
        if len(keys) > 1:
            util.warn(
                "On class %r, Column object %r named directly multiple times, "
                "only one will be used: %s" %
                (classname, name, (", ".join(sorted(keys))))
            )

    declared_columns = sorted(
        declared_columns, key=lambda c: c._creation_order)
    table = None

    if hasattr(cls, '__table_cls__'):
        table_cls = util.unbound_method_to_callable(cls.__table_cls__)
    else:
        table_cls = Table

    if '__table__' not in dict_:
        if tablename is not None:

            args, table_kw = (), {}
            if table_args:
                if isinstance(table_args, dict):
                    table_kw = table_args
                elif isinstance(table_args, tuple):
                    if isinstance(table_args[-1], dict):
                        args, table_kw = table_args[0:-1], table_args[-1]
                    else:
                        args = table_args

            autoload = dict_.get('__autoload__')
            if autoload:
                table_kw['autoload'] = True

            cls.__table__ = table = table_cls(
                tablename, cls.metadata,
                *(tuple(declared_columns) + tuple(args)),
                **table_kw)
    else:
        table = cls.__table__
        if declared_columns:
            for c in declared_columns:
                if not table.c.contains_column(c):
                    raise exc.ArgumentError(
                        "Can't add additional column %r when "
                        "specifying __table__" % c.key
                    )

    if hasattr(cls, '__mapper_cls__'):
        mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__)
    else:
        mapper_cls = mapper

    for c in cls.__bases__:
        if _declared_mapping_info(c) is not None:
            inherits = c
            break
    else:
        inherits = None

    if table is None and inherits is None:
        raise exc.InvalidRequestError(
            "Class %r does not have a __table__ or __tablename__ "
            "specified and does not inherit from an existing "
            "table-mapped class." % cls
            )
    elif inherits:
        inherited_mapper = _declared_mapping_info(inherits)
        inherited_table = inherited_mapper.local_table
        inherited_mapped_table = inherited_mapper.mapped_table

        if table is None:
            # single table inheritance.
            # ensure no table args
            if table_args:
                raise exc.ArgumentError(
                    "Can't place __table_args__ on an inherited class "
                    "with no table."
                    )
            # add any columns declared here to the inherited table.
            for c in declared_columns:
                if c.primary_key:
                    raise exc.ArgumentError(
                        "Can't place primary key columns on an inherited "
                        "class with no table."
                        )
                if c.name in inherited_table.c:
                    if inherited_table.c[c.name] is c:
                        continue
                    raise exc.ArgumentError(
                        "Column '%s' on class %s conflicts with "
                        "existing column '%s'" %
                        (c, cls, inherited_table.c[c.name])
                    )
                inherited_table.append_column(c)
                if inherited_mapped_table is not None and \
                    inherited_mapped_table is not inherited_table:
                    inherited_mapped_table._refresh_for_new_column(c)

    defer_map = hasattr(cls, '_sa_decl_prepare')
    if defer_map:
        cfg_cls = _DeferredMapperConfig
    else:
        cfg_cls = _MapperConfig
    mt = cfg_cls(mapper_cls,
                       cls, table,
                       inherits,
                       declared_columns,
                       column_copies,
                       our_stuff,
                       mapper_args_fn)
    if not defer_map:
        mt.map()

Example 27

Project: SickRage
Source File: base.py
View license
def _as_declarative(cls, classname, dict_):
    from .api import declared_attr

    # dict_ will be a dictproxy, which we can't write to, and we need to!
    dict_ = dict(dict_)

    column_copies = {}
    potential_columns = {}

    mapper_args_fn = None
    table_args = inherited_table_args = None
    tablename = None

    declarative_props = (declared_attr, util.classproperty)

    for base in cls.__mro__:
        _is_declarative_inherits = hasattr(base, '_decl_class_registry')

        if '__declare_last__' in base.__dict__:
            @event.listens_for(mapper, "after_configured")
            def go():
                cls.__declare_last__()
        if '__declare_first__' in base.__dict__:
            @event.listens_for(mapper, "before_configured")
            def go():
                cls.__declare_first__()
        if '__abstract__' in base.__dict__:
            if (base is cls or
                (base in cls.__bases__ and not _is_declarative_inherits)
            ):
                return

        class_mapped = _declared_mapping_info(base) is not None

        for name, obj in vars(base).items():
            if name == '__mapper_args__':
                if not mapper_args_fn and (
                                        not class_mapped or
                                        isinstance(obj, declarative_props)
                                    ):
                    # don't even invoke __mapper_args__ until
                    # after we've determined everything about the
                    # mapped table.
                    mapper_args_fn = lambda: cls.__mapper_args__
            elif name == '__tablename__':
                if not tablename and (
                                        not class_mapped or
                                        isinstance(obj, declarative_props)
                                    ):
                    tablename = cls.__tablename__
            elif name == '__table_args__':
                if not table_args and (
                                        not class_mapped or
                                        isinstance(obj, declarative_props)
                                    ):
                    table_args = cls.__table_args__
                    if not isinstance(table_args, (tuple, dict, type(None))):
                        raise exc.ArgumentError(
                                "__table_args__ value must be a tuple, "
                                "dict, or None")
                    if base is not cls:
                        inherited_table_args = True
            elif class_mapped:
                if isinstance(obj, declarative_props):
                    util.warn("Regular (i.e. not __special__) "
                            "attribute '%s.%s' uses @declared_attr, "
                            "but owning class %s is mapped - "
                            "not applying to subclass %s."
                            % (base.__name__, name, base, cls))
                continue
            elif base is not cls:
                # we're a mixin.
                if isinstance(obj, Column):
                    if getattr(cls, name) is not obj:
                        # if column has been overridden
                        # (like by the InstrumentedAttribute of the
                        # superclass), skip
                        continue
                    if obj.foreign_keys:
                        raise exc.InvalidRequestError(
                        "Columns with foreign keys to other columns "
                        "must be declared as @declared_attr callables "
                        "on declarative mixin classes. ")
                    if name not in dict_ and not (
                            '__table__' in dict_ and
                            (obj.name or name) in dict_['__table__'].c
                            ) and name not in potential_columns:
                        potential_columns[name] = \
                                column_copies[obj] = \
                                obj.copy()
                        column_copies[obj]._creation_order = \
                                obj._creation_order
                elif isinstance(obj, MapperProperty):
                    raise exc.InvalidRequestError(
                        "Mapper properties (i.e. deferred,"
                        "column_property(), relationship(), etc.) must "
                        "be declared as @declared_attr callables "
                        "on declarative mixin classes.")
                elif isinstance(obj, declarative_props):
                    dict_[name] = ret = \
                            column_copies[obj] = getattr(cls, name)
                    if isinstance(ret, (Column, MapperProperty)) and \
                        ret.doc is None:
                        ret.doc = obj.__doc__

    # apply inherited columns as we should
    for k, v in potential_columns.items():
        dict_[k] = v

    if inherited_table_args and not tablename:
        table_args = None

    clsregistry.add_class(classname, cls)
    our_stuff = util.OrderedDict()

    for k in list(dict_):

        # TODO: improve this ?  all dunders ?
        if k in ('__table__', '__tablename__', '__mapper_args__'):
            continue

        value = dict_[k]
        if isinstance(value, declarative_props):
            value = getattr(cls, k)

        elif isinstance(value, QueryableAttribute) and \
                value.class_ is not cls and \
                value.key != k:
            # detect a QueryableAttribute that's already mapped being
            # assigned elsewhere in userland, turn into a synonym()
            value = synonym(value.key)
            setattr(cls, k, value)


        if (isinstance(value, tuple) and len(value) == 1 and
            isinstance(value[0], (Column, MapperProperty))):
            util.warn("Ignoring declarative-like tuple value of attribute "
                      "%s: possibly a copy-and-paste error with a comma "
                      "left at the end of the line?" % k)
            continue
        if not isinstance(value, (Column, MapperProperty)):
            if not k.startswith('__'):
                dict_.pop(k)
                setattr(cls, k, value)
            continue
        if k == 'metadata':
            raise exc.InvalidRequestError(
                "Attribute name 'metadata' is reserved "
                "for the MetaData instance when using a "
                "declarative base class."
            )
        prop = clsregistry._deferred_relationship(cls, value)
        our_stuff[k] = prop

    # set up attributes in the order they were created
    our_stuff.sort(key=lambda key: our_stuff[key]._creation_order)

    # extract columns from the class dict
    declared_columns = set()
    name_to_prop_key = collections.defaultdict(set)
    for key, c in list(our_stuff.items()):
        if isinstance(c, (ColumnProperty, CompositeProperty)):
            for col in c.columns:
                if isinstance(col, Column) and \
                    col.table is None:
                    _undefer_column_name(key, col)
                    if not isinstance(c, CompositeProperty):
                        name_to_prop_key[col.name].add(key)
                    declared_columns.add(col)
        elif isinstance(c, Column):
            _undefer_column_name(key, c)
            name_to_prop_key[c.name].add(key)
            declared_columns.add(c)
            # if the column is the same name as the key,
            # remove it from the explicit properties dict.
            # the normal rules for assigning column-based properties
            # will take over, including precedence of columns
            # in multi-column ColumnProperties.
            if key == c.key:
                del our_stuff[key]

    for name, keys in name_to_prop_key.items():
        if len(keys) > 1:
            util.warn(
                "On class %r, Column object %r named directly multiple times, "
                "only one will be used: %s" %
                (classname, name, (", ".join(sorted(keys))))
            )

    declared_columns = sorted(
        declared_columns, key=lambda c: c._creation_order)
    table = None

    if hasattr(cls, '__table_cls__'):
        table_cls = util.unbound_method_to_callable(cls.__table_cls__)
    else:
        table_cls = Table

    if '__table__' not in dict_:
        if tablename is not None:

            args, table_kw = (), {}
            if table_args:
                if isinstance(table_args, dict):
                    table_kw = table_args
                elif isinstance(table_args, tuple):
                    if isinstance(table_args[-1], dict):
                        args, table_kw = table_args[0:-1], table_args[-1]
                    else:
                        args = table_args

            autoload = dict_.get('__autoload__')
            if autoload:
                table_kw['autoload'] = True

            cls.__table__ = table = table_cls(
                tablename, cls.metadata,
                *(tuple(declared_columns) + tuple(args)),
                **table_kw)
    else:
        table = cls.__table__
        if declared_columns:
            for c in declared_columns:
                if not table.c.contains_column(c):
                    raise exc.ArgumentError(
                        "Can't add additional column %r when "
                        "specifying __table__" % c.key
                    )

    if hasattr(cls, '__mapper_cls__'):
        mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__)
    else:
        mapper_cls = mapper

    for c in cls.__bases__:
        if _declared_mapping_info(c) is not None:
            inherits = c
            break
    else:
        inherits = None

    if table is None and inherits is None:
        raise exc.InvalidRequestError(
            "Class %r does not have a __table__ or __tablename__ "
            "specified and does not inherit from an existing "
            "table-mapped class." % cls
            )
    elif inherits:
        inherited_mapper = _declared_mapping_info(inherits)
        inherited_table = inherited_mapper.local_table
        inherited_mapped_table = inherited_mapper.mapped_table

        if table is None:
            # single table inheritance.
            # ensure no table args
            if table_args:
                raise exc.ArgumentError(
                    "Can't place __table_args__ on an inherited class "
                    "with no table."
                    )
            # add any columns declared here to the inherited table.
            for c in declared_columns:
                if c.primary_key:
                    raise exc.ArgumentError(
                        "Can't place primary key columns on an inherited "
                        "class with no table."
                        )
                if c.name in inherited_table.c:
                    if inherited_table.c[c.name] is c:
                        continue
                    raise exc.ArgumentError(
                        "Column '%s' on class %s conflicts with "
                        "existing column '%s'" %
                        (c, cls, inherited_table.c[c.name])
                    )
                inherited_table.append_column(c)
                if inherited_mapped_table is not None and \
                    inherited_mapped_table is not inherited_table:
                    inherited_mapped_table._refresh_for_new_column(c)

    defer_map = hasattr(cls, '_sa_decl_prepare')
    if defer_map:
        cfg_cls = _DeferredMapperConfig
    else:
        cfg_cls = _MapperConfig
    mt = cfg_cls(mapper_cls,
                       cls, table,
                       inherits,
                       declared_columns,
                       column_copies,
                       our_stuff,
                       mapper_args_fn)
    if not defer_map:
        mt.map()
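
In Example 27 above, name_to_prop_key = collections.defaultdict(set) maps each column name to the set of attribute keys that name it, so a warning can be issued when the same Column is referenced more than once. A minimal standalone sketch of that pattern, with invented attribute names rather than SQLAlchemy objects:

import collections

# map each column name to every attribute key that declared it
name_to_prop_key = collections.defaultdict(set)
declared = [('id', 'id'), ('email_address', 'email'), ('email', 'email')]
for key, col_name in declared:
    name_to_prop_key[col_name].add(key)

for name, keys in name_to_prop_key.items():
    if len(keys) > 1:
        print("Column %r named directly multiple times: %s"
              % (name, ", ".join(sorted(keys))))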

Example 28

Project: SickRage
Source File: cli.py
View license
@subliminal.command()
@click.option('-l', '--language', type=LANGUAGE, required=True, multiple=True, help='Language as IETF code, '
              'e.g. en, pt-BR (can be used multiple times).')
@click.option('-p', '--provider', type=PROVIDER, multiple=True, help='Provider to use (can be used multiple times).')
@click.option('-r', '--refiner', type=REFINER, multiple=True, help='Refiner to use (can be used multiple times).')
@click.option('-a', '--age', type=AGE, help='Filter videos newer than AGE, e.g. 12h, 1w2d.')
@click.option('-d', '--directory', type=click.STRING, metavar='DIR', help='Directory where to save subtitles, '
              'default is next to the video file.')
@click.option('-e', '--encoding', type=click.STRING, metavar='ENC', help='Subtitle file encoding, default is to '
              'preserve original encoding.')
@click.option('-s', '--single', is_flag=True, default=False, help='Save subtitle without language code in the file '
              'name, i.e. use .srt extension. Do not use this unless your media player requires it.')
@click.option('-f', '--force', is_flag=True, default=False, help='Force download even if a subtitle already exists.')
@click.option('-hi', '--hearing-impaired', is_flag=True, default=False, help='Prefer hearing impaired subtitles.')
@click.option('-m', '--min-score', type=click.IntRange(0, 100), default=0, help='Minimum score for a subtitle '
              'to be downloaded (0 to 100).')
@click.option('-w', '--max-workers', type=click.IntRange(1, 50), default=None, help='Maximum number of threads to use.')
@click.option('-z/-Z', '--archives/--no-archives', default=True, show_default=True, help='Scan archives for videos '
              '(supported extensions: %s).' % ', '.join(ARCHIVE_EXTENSIONS))
@click.option('-v', '--verbose', count=True, help='Increase verbosity.')
@click.argument('path', type=click.Path(), required=True, nargs=-1)
@click.pass_obj
def download(obj, provider, refiner, language, age, directory, encoding, single, force, hearing_impaired, min_score,
             max_workers, archives, verbose, path):
    """Download best subtitles.

    PATH can be a directory containing videos, a video file path or a video file name. It can be used multiple times.

    If an existing subtitle is detected (external or embedded) in the correct language, the download is skipped for
    the associated video.

    """
    # process parameters
    language = set(language)

    # scan videos
    videos = []
    ignored_videos = []
    errored_paths = []
    with click.progressbar(path, label='Collecting videos', item_show_func=lambda p: p or '') as bar:
        for p in bar:
            logger.debug('Collecting path %s', p)

            # non-existing
            if not os.path.exists(p):
                try:
                    video = Video.fromname(p)
                except:
                    logger.exception('Unexpected error while collecting non-existing path %s', p)
                    errored_paths.append(p)
                    continue
                if not force:
                    video.subtitle_languages |= set(search_external_subtitles(video.name, directory=directory).values())
                refine(video, episode_refiners=refiner, movie_refiners=refiner, embedded_subtitles=not force)
                videos.append(video)
                continue

            # directories
            if os.path.isdir(p):
                try:
                    scanned_videos = scan_videos(p, age=age, archives=archives)
                except:
                    logger.exception('Unexpected error while collecting directory path %s', p)
                    errored_paths.append(p)
                    continue
                for video in scanned_videos:
                    if not force:
                        video.subtitle_languages |= set(search_external_subtitles(video.name,
                                                                                  directory=directory).values())
                    if check_video(video, languages=language, age=age, undefined=single):
                        refine(video, episode_refiners=refiner, movie_refiners=refiner, embedded_subtitles=not force)
                        videos.append(video)
                    else:
                        ignored_videos.append(video)
                continue

            # other inputs
            try:
                video = scan_video(p)
            except:
                logger.exception('Unexpected error while collecting path %s', p)
                errored_paths.append(p)
                continue
            if not force:
                video.subtitle_languages |= set(search_external_subtitles(video.name, directory=directory).values())
            if check_video(video, languages=language, age=age, undefined=single):
                refine(video, episode_refiners=refiner, movie_refiners=refiner, embedded_subtitles=not force)
                videos.append(video)
            else:
                ignored_videos.append(video)

    # output errored paths
    if verbose > 0:
        for p in errored_paths:
            click.secho('%s errored' % p, fg='red')

    # output ignored videos
    if verbose > 1:
        for video in ignored_videos:
            click.secho('%s ignored - subtitles: %s / age: %d day%s' % (
                os.path.split(video.name)[1],
                ', '.join(str(s) for s in video.subtitle_languages) or 'none',
                video.age.days,
                's' if video.age.days > 1 else ''
            ), fg='yellow')

    # report collected videos
    click.echo('%s video%s collected / %s video%s ignored / %s error%s' % (
        click.style(str(len(videos)), bold=True, fg='green' if videos else None),
        's' if len(videos) > 1 else '',
        click.style(str(len(ignored_videos)), bold=True, fg='yellow' if ignored_videos else None),
        's' if len(ignored_videos) > 1 else '',
        click.style(str(len(errored_paths)), bold=True, fg='red' if errored_paths else None),
        's' if len(errored_paths) > 1 else '',
    ))

    # exit if no video collected
    if not videos:
        return

    # download best subtitles
    downloaded_subtitles = defaultdict(list)
    with AsyncProviderPool(max_workers=max_workers, providers=provider, provider_configs=obj['provider_configs']) as p:
        with click.progressbar(videos, label='Downloading subtitles',
                               item_show_func=lambda v: os.path.split(v.name)[1] if v is not None else '') as bar:
            for v in bar:
                scores = get_scores(v)
                subtitles = p.download_best_subtitles(p.list_subtitles(v, language - v.subtitle_languages),
                                                      v, language, min_score=scores['hash'] * min_score / 100,
                                                      hearing_impaired=hearing_impaired, only_one=single)
                downloaded_subtitles[v] = subtitles

        if p.discarded_providers:
            click.secho('Some providers have been discarded due to unexpected errors: %s' %
                        ', '.join(p.discarded_providers), fg='yellow')

    # save subtitles
    total_subtitles = 0
    for v, subtitles in downloaded_subtitles.items():
        saved_subtitles = save_subtitles(v, subtitles, single=single, directory=directory, encoding=encoding)
        total_subtitles += len(saved_subtitles)

        if verbose > 0:
            click.echo('%s subtitle%s downloaded for %s' % (click.style(str(len(saved_subtitles)), bold=True),
                                                            's' if len(saved_subtitles) > 1 else '',
                                                            os.path.split(v.name)[1]))

        if verbose > 1:
            for s in saved_subtitles:
                matches = s.get_matches(v)
                score = compute_score(s, v)

                # score color
                score_color = None
                scores = get_scores(v)
                if isinstance(v, Movie):
                    if score < scores['title']:
                        score_color = 'red'
                    elif score < scores['title'] + scores['year'] + scores['release_group']:
                        score_color = 'yellow'
                    else:
                        score_color = 'green'
                elif isinstance(v, Episode):
                    if score < scores['series'] + scores['season'] + scores['episode']:
                        score_color = 'red'
                    elif score < scores['series'] + scores['season'] + scores['episode'] + scores['release_group']:
                        score_color = 'yellow'
                    else:
                        score_color = 'green'

                # scale score from 0 to 100 taking out preferences
                scaled_score = score
                if s.hearing_impaired == hearing_impaired:
                    scaled_score -= scores['hearing_impaired']
                scaled_score *= 100 / scores['hash']

                # echo some nice colored output
                click.echo('  - [{score}] {language} subtitle from {provider_name} (match on {matches})'.format(
                    score=click.style('{:5.1f}'.format(scaled_score), fg=score_color, bold=score >= scores['hash']),
                    language=s.language.name if s.language.country is None else '%s (%s)' % (s.language.name,
                                                                                             s.language.country.name),
                    provider_name=s.provider_name,
                    matches=', '.join(sorted(matches, key=scores.get, reverse=True))
                ))

    if verbose == 0:
        click.echo('Downloaded %s subtitle%s' % (click.style(str(total_subtitles), bold=True),
                                                 's' if total_subtitles > 1 else ''))
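
Example 28 declares downloaded_subtitles = defaultdict(list) to hold the subtitles fetched per video; because each video is assigned exactly once inside the loop, the default factory only matters if a missing video were ever looked up. The accumulation idiom the type is usually chosen for looks like this (a minimal sketch with made-up file names, not subliminal's API):

from collections import defaultdict

# group downloaded items per key without pre-creating the lists
downloaded = defaultdict(list)
results = [('video_a.mkv', 'en.srt'), ('video_a.mkv', 'fr.srt'),
           ('video_b.mkv', 'en.srt')]
for video, subtitle in results:
    downloaded[video].append(subtitle)

print(dict(downloaded))
# {'video_a.mkv': ['en.srt', 'fr.srt'], 'video_b.mkv': ['en.srt']}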

Example 29

Project: maestro-ng
Source File: entities.py
View license
    def __init__(self, name, ship, service, config=None, schema=None):
        """Create a new Container object.

        Args:
            name (string): the instance name (should be unique).
            ship (Ship): the Ship object representing the host this container
                is expected to be executed on.
            service (Service): the Service this container is an instance of.
            config (dict): the YAML-parsed dictionary containing this
                instance's configuration (ports, environment, volumes, etc.)
            schema (dict): Maestro schema versioning information.
        """
        Entity.__init__(self, name)
        config = config or {}

        self._status = None  # The container's status, cached.
        self._ship = ship
        self._service = service
        self._image = config.get('image', service.image)
        self._schema = schema

        # Register this instance container as being part of its parent service.
        self._service.register_container(self)

        # Get command
        # TODO(mpetazzoni): remove deprecated 'cmd' support
        self.command = config.get('command', config.get('cmd'))

        # Parse the port specs.
        self.ports = self._parse_ports(dict(self.service.ports, **config.get('ports', {})))

        # Gather environment variables.
        self.env = dict(service.env)
        self.env.update(config.get('env', {}))
        # Seed the service name, container name and host address as part of the
        # container's environment.
        self.env.update({
            'CONTAINER_NAME': self.name,
            'CONTAINER_HOST_ADDRESS': self.ship.ip,
            'DOCKER_IMAGE': self.image,
            'DOCKER_TAG': self.get_image_details()['tag'],
        })

        def env_list_expand(elt):
            # recursively flatten nested lists into a space-separated string
            return type(elt) != list and elt \
                or ' '.join(map(env_list_expand, elt))

        for k, v in self.env.items():
            if type(v) == list:
                self.env[k] = env_list_expand(v)

        self.volumes = self._parse_volumes(config.get('volumes', {}))
        self.container_volumes = config.get('container_volumes', [])
        if type(self.container_volumes) != list:
            self.container_volumes = [self.container_volumes]
        self.container_volumes = set(self.container_volumes)

        # Check for conflicts
        for volume in self.volumes.values():
            if volume['bind'] in self.container_volumes:
                raise exceptions.InvalidVolumeConfigurationException(
                        'Conflict in {} between bind-mounted volume '
                        'and container-only volume on {}'
                        .format(self.name, volume['bind']))

        # Contains the list of containers from which volumes should be mounted
        # in this container. Host-locality and volume conflicts are checked by
        # the conductor.
        self.volumes_from = config.get('volumes_from', [])
        if type(self.volumes_from) != list:
            self.volumes_from = [self.volumes_from]
        self.volumes_from = set(self.volumes_from)

        # Get links
        self.links = dict(
            (name, alias) for name, alias in
            config.get('links', {}).items())

        # Should this container run with -privileged?
        self.privileged = config.get('privileged', False)

        # Add or drop privileges
        self.cap_add = config.get('cap_add', None)
        self.cap_drop = config.get('cap_drop', None)

        # Add extra hosts
        self.extra_hosts = config.get('extra_hosts', None)

        # Network mode
        self.network_mode = config.get('net')

        # Restart policy
        self.restart_policy = self._parse_restart_policy(config.get('restart'))

        # DNS settings for the container, always as a list
        self.dns = config.get('dns')
        if isinstance(self.dns, six.string_types):
            self.dns = [self.dns]

        # Stop timeout
        self.stop_timeout = config.get('stop_timeout', 10)

        # Get limits
        limits = dict(self.service.limits, **config.get('limits', {}))
        self.cpu_shares = limits.get('cpu')
        self.mem_limit = self._parse_bytes(limits.get('memory'))
        self.memswap_limit = self._parse_bytes(limits.get('swap'))

        # Get logging config.
        self.log_config = self._parse_log_config(
            config.get('log_driver'), config.get('log_opt'))

        # Additional LXC configuration options. See the LXC documentation for a
        # reference of the available settings. Those are only supported if the
        # remote Docker daemon uses the lxc execution driver.
        self.lxc_conf = config.get('lxc_conf', {})

        # Work directory for the container
        self.workdir = config.get('workdir')

        # Reformat port structure
        ports = collections.defaultdict(list) if self.ports else None
        if ports is not None:
            for port in self.ports.values():
                ports[port['exposed']].append(
                    (port['external'][0], port['external'][1].split('/')[0]))

        # Security options
        self.security_opt = config.get('security_opt')

        # Ulimits options
        self.ulimits = self._parse_ulimits(config.get('ulimits', None))

        # host_config now contains all settings previously passed in container
        # start().
        self.host_config = self._ship.backend.create_host_config(
            log_config=self.log_config,
            mem_limit=self.mem_limit,
            memswap_limit=self.memswap_limit,
            binds=self.volumes,
            port_bindings=ports,
            lxc_conf=self.lxc_conf,
            privileged=self.privileged,
            cap_add=self.cap_add,
            cap_drop=self.cap_drop,
            extra_hosts=self.extra_hosts,
            network_mode=self.network_mode,
            restart_policy=self.restart_policy,
            dns=self.dns,
            links=self.links,
            ulimits=self.ulimits,
            volumes_from=list(self.volumes_from),
            security_opt=self.security_opt)

        # With everything defined, build lifecycle state helpers as configured
        lifecycle = dict(self.service.lifecycle)
        for state, checks in config.get('lifecycle', {}).items():
            if state not in lifecycle:
                lifecycle[state] = []
            lifecycle[state].extend(checks)
        self._lifecycle = self._parse_lifecycle(lifecycle)
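
The port-reformatting step in Example 29 uses collections.defaultdict(list) so each exposed port can collect its external bindings without first checking whether the key exists. A reduced sketch of that step, using invented port specs rather than Maestro's configuration format:

import collections

port_specs = [{'exposed': '80/tcp', 'external': ('0.0.0.0', '8080/tcp')},
              {'exposed': '80/tcp', 'external': ('0.0.0.0', '8081/tcp')}]

# map each exposed port to the list of (address, port) bindings for it
bindings = collections.defaultdict(list)
for port in port_specs:
    bindings[port['exposed']].append(
        (port['external'][0], port['external'][1].split('/')[0]))

print(dict(bindings))
# {'80/tcp': [('0.0.0.0', '8080'), ('0.0.0.0', '8081')]}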

Example 30

Project: autonetkit
Source File: graph.py
View license
    def add_edges_from(self, ebunch, bidirectional=False, retain=None,
                       warn=True, **kwargs):
        """Add edges. Unlike NetworkX, can only add an edge if both
        src and dst are already in the graph.
        If they are not, the edge is silently ignored.


        Retains interface mappings if they are present (this is why ANK
            stores the interface reference on the edges: it simplifies
            cross-layer access, and lets split, aggregate, etc. retain the
            interface bindings).

        Bidirectional will add the edge in both directions. Useful when going
        from an undirected graph to a directed one, e.g. G_in to G_bgp.
        #TODO: explain "retain" and ["retain"] logic

        If adding edges from another overlay, first call g_x.edges()
        and then add from the result.

        allow (src, dst, ekey), (src, dst, ekey, data) for the ank utils
        """

        if not retain:
            retain = []
        try:
            retain.lower()
            retain = [retain]  # was a string, put into list
        except AttributeError:
            pass  # already a list

        if self.is_multigraph():
            #used_keys = self._graph.adj[u][v]
            from collections import defaultdict
            used_keys = defaultdict(dict)

        all_edges = []
        for in_edge in ebunch:
            """Edge could be one of:
            - NmEdge
            - (NmNode, NmNode)
            - (NmPort, NmPort)
            - (NmNode, NmPort)
            - (NmPort, NmNode)
            - (string, string)
            """
            # This is less efficient than nx add_edges_from, but cleaner logic
            # TODO: could put the interface data into retain?
            data = {'_ports': {}}  # to retain
            ekey = None  # default is None (nx auto-allocates next int)

            # convert input to a NmEdge
            src = dst = None
            if isinstance(in_edge, NmEdge):
                edge = in_edge  # simple case
                ekey = edge.ekey  # explictly set ekey
                src = edge.src.node_id
                dst = edge.dst.node_id

                # and copy retain data
                data = dict((key, edge.get(key)) for key in retain)
                ports = {k: v for k, v in edge.raw_interfaces.items()
                         if k in self._graph}  # only if exists in this overlay
                # TODO: debug log if skipping a binding?
                data['_ports'] = ports

                # this is the only case where copy across data
                # but want to copy attributes for all cases

            elif len(in_edge) == 2:
                in_a, in_b = in_edge[0], in_edge[1]

                if isinstance(in_a, NmNode) and isinstance(in_b, NmNode):
                    src = in_a.node_id
                    dst = in_b.node_id

                elif isinstance(in_a, NmPort) and isinstance(in_b, NmPort):
                    src = in_a.node.node_id
                    dst = in_b.node.node_id
                    ports = {}
                    if src in self:
                        ports[src] = in_a.interface_id
                    if dst in self:
                        ports[dst] = in_b.interface_id
                    data['_ports'] = ports

                elif isinstance(in_a, NmNode) and isinstance(in_b, NmPort):
                    src = in_a.node_id
                    dst = in_b.node.node_id
                    ports = {}
                    if dst in self:
                        ports[dst] = in_b.interface_id
                    data['_ports'] = ports

                elif isinstance(in_a, NmPort) and isinstance(in_b, NmNode):
                    src = in_a.node.node_id
                    dst = in_b.node_id
                    ports = {}
                    if src in self:
                        ports[src] = in_a.interface_id
                    data['_ports'] = ports

                elif in_a in self and in_b in self:
                    src = in_a
                    dst = in_b

            elif len(in_edge) == 3:
                # (src, dst, ekey) format
                # or (src, dst, data) format
                in_a, in_b, in_c = in_edge[0], in_edge[1], in_edge[2]
                if in_a in self and in_b in self:
                    src = in_a
                    dst = in_b
                    # TODO: document the following logic
                    if self.is_multigraph() and not isinstance(in_c, dict):
                        ekey = in_c
                    else:
                        data = in_c

            elif len(in_edge) == 4:
                # (src, dst, ekey, data) format
                in_a, in_b = in_edge[0], in_edge[1]
                if in_a in self and in_b in self:
                    src = in_a
                    dst = in_b
                    ekey = in_edge[2]
                    data = in_edge[3]

            # TODO: if edge not set at this point, give error/warn

            # TODO: add check that edge.src and edge.dst exist
            if (src is None or dst is None) and warn:
                log.warning("Unsupported edge %s" % str(in_edge))
            if not(src in self and dst in self):
                if warn:
                    self.log.debug("Not adding edge %s, src/dst not in overlay"
                                   % str(in_edge))
                continue

            # TODO: warn if not multigraph and edge already exists - don't
            # add/clobber
            #TODO: double check this logic + add test case
            data.update(**kwargs)
            if self.is_multigraph() and ekey is None:
                # specifically allocate a key
                if src in used_keys and dst in used_keys[src]:
                    pass # already established
                else:
                    try:
                        used_keys[src][dst] = self._graph.adj[src][dst].keys()
                    except KeyError:
                        # no edges exist
                        used_keys[src][dst] = []

                # now have the keys mapping
                ekey = len(used_keys[src][dst])
                while ekey in used_keys[src][dst]:
                    ekey += 1

                used_keys[src][dst].append(ekey)

            edges_to_add = []
            if self.is_multigraph():
                edges_to_add.append((src, dst, ekey, dict(data)))
                if bidirectional:
                    edges_to_add.append((dst, src, ekey, dict(data)))
            else:
                edges_to_add.append((src, dst, dict(data)))
                if bidirectional:
                    edges_to_add.append((dst, src, dict(data)))


            #TODO: warn if not multigraph

            self._graph.add_edges_from(edges_to_add)
            all_edges += edges_to_add

        if self.is_multigraph():
            return [NmEdge(self.anm, self._overlay_id, src, dst, ekey) if ekey
                    else NmEdge(self.anm, self._overlay_id, src, dst)  # default: no ekey set
                    for src, dst, ekey, _ in all_edges]
        else:
            return [NmEdge(self.anm, self._overlay_id, src, dst)
                    for src, dst, _ in all_edges]
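
Example 30 uses defaultdict(dict) so that used_keys[src] springs into existence as an empty dict the first time a source node is seen, giving a two-level src -> dst -> allocated-keys map without setdefault calls. A small self-contained sketch of that key-allocation idea (the nodes and pre-existing keys are invented):

from collections import defaultdict

# per (src, dst) pair, track which edge keys are already taken
used_keys = defaultdict(dict)
existing = {('r1', 'r2'): [0, 1]}  # edge keys already present in the graph

def next_key(src, dst):
    if dst not in used_keys[src]:
        used_keys[src][dst] = list(existing.get((src, dst), []))
    ekey = len(used_keys[src][dst])
    while ekey in used_keys[src][dst]:
        ekey += 1
    used_keys[src][dst].append(ekey)
    return ekey

print(next_key('r1', 'r2'))  # 2
print(next_key('r1', 'r3'))  # 0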

Example 31

Project: autonetkit
Source File: graphml.py
View license
def load_graphml(input_data, defaults=True):

    # TODO: allow default properties to be passed in as dicts

    try:
        graph = nx.read_graphml(input_data)
    except IOError, e:
        acceptable_errors = set([2, 36, 63])  # no such file or directory,
                                              # or input string too long
                                              # to be a filename
        if e.errno in acceptable_errors:
            from xml.etree.cElementTree import ParseError

            # try as data string rather than filename string

            try:
                input_pseudo_fh = StringIO(input_data)  # load into a filehandle for networkx
                graph = nx.read_graphml(input_pseudo_fh)
            except IOError:
                raise autonetkit.exception.AnkIncorrectFileFormat
            except IndexError:
                raise autonetkit.exception.AnkIncorrectFileFormat
            except ParseError:
                raise autonetkit.exception.AnkIncorrectFileFormat
        else:
            raise e

    if graph.is_multigraph():
        log.info('Input graph is multigraph. Converting to single-edge graph')
        if graph.is_directed():
            graph = nx.DiGraph(graph)
        else:
            graph = nx.Graph(graph)

    # TODO: need to support edge index keying for multi graphs
    graph.remove_edges_from(edge for edge in graph.selfloop_edges())

# TODO: if selfloops are present, log that we are removing them

    letters_single = (c for c in string.lowercase)  # a, b, c, ... z
    letters_double = ('%s%s' % (a, b) for (a, b) in
        itertools.product(string.lowercase, string.lowercase))  # aa, ab, ... zz
    letters = itertools.chain(letters_single, letters_double)  # a, b, c, .. z, aa, ab, ac, ... zz

# TODO: need to get set of current labels, and only return if not in this set

    # TODO: add cloud, host, etc
    # prefixes for unlabelled devices, ie router -> r_a

    label_prefixes = {'router': 'r', 'switch': 'sw', 'server': 'se'}

    current_labels = set(graph.node[node].get('label') for node in
                         graph.nodes_iter())
    unique_label = (letter for letter in letters if letter
        not in current_labels)

# TODO: make sure device label set

    ank_graph_defaults = settings['Graphml']['Graph Defaults']
    for (key, val) in ank_graph_defaults.items():
        if key not in graph.graph:
            graph.graph[key] = val

    # handle yEd exported booleans: if a boolean is set, only the nodes marked
    # true have the attribute; map the remainder to False to allow ANK logic
    # for node in graph.nodes(data=True):
        # print node

    all_labels = dict((n, d.get('label')) for (n, d) in
                      graph.nodes(data=True))
    label_counts = defaultdict(list)
    for (node, label) in all_labels.items():
        label_counts[label].append(node)

    # set default name for blank labels to ensure unique

    try:
        blank_labels = [v for (k, v) in label_counts.items()
                        if not k].pop()  # strip outer list
    except IndexError:
        blank_labels = []  # no blank labels
    for (index, node) in enumerate(blank_labels):

        # TODO: log message that no label set, so setting default

        graph.node[node]['label'] = 'none___%s' % index

    duplicates = [(k, v) for (k, v) in label_counts.items() if k
                  and len(v) > 1]
    for (label, nodes) in duplicates:
        for node in nodes:

            # TODO: need to check they don't all have same ASN... if so then warn

            try:
                graph.node[node]['label'] = '%s_%s' \
                    % (graph.node[node]['label'], graph.node[node]['asn'])
            except KeyError:
                log.warning('Unable to set new label for duplicate node %s: %s'
                             % (node, graph.node[node].get('label')))

    boolean_attributes = set(k for (n, d) in graph.nodes(data=True)
                             for (k, v) in d.items() if isinstance(v,
                             bool))

    for node in graph:
        for attr in boolean_attributes:
            if attr not in graph.node[node]:
                graph.node[node][attr] = False

    boolean_attributes = set(k for (n1, d1) in graph.edge.items()
                             for (n2, d2) in d1.items() for (k, v) in
                             d2.items() if isinstance(v, bool))
    for (n1, d1) in graph.edge.items():
        for (n2, d2) in d1.items():
            for attr in boolean_attributes:
                if attr not in graph.edge[n1][n2]:
                    graph.edge[n1][n2][attr] = False

# TODO: store these in config file

    if defaults:
        ank_node_defaults = settings['Graphml']['Node Defaults']
        node_defaults = graph.graph['node_default']  # update with defaults from graphml
        for (key, val) in node_defaults.items():
            if val == 'False':
                node_defaults[key] = False

    # TODO: do a dict update before applying so only need to iterate nodes once

        for (key, val) in ank_node_defaults.items():
            if key not in node_defaults or node_defaults[key] == 'None':
                node_defaults[key] = val

        for node in graph:
            for (key, val) in node_defaults.items():
                if key not in graph.node[node]:
                    graph.node[node][key] = val

    # set address family

        graph.graph['address_family'] = 'v4'
        graph.graph['enable_routing'] = True

    # map lat/lon from zoo to crude x/y approximation

    if graph.graph.get('Creator') == 'Topology Zoo Toolset':
        all_lat = [graph.node[n].get('Latitude') for n in graph
                   if graph.node[n].get('Latitude')]
        all_lon = [graph.node[n].get('Longitude') for n in graph
                   if graph.node[n].get('Longitude')]

        lat_min = min(all_lat)
        lon_min = min(all_lon)
        lat_max = max(all_lat)
        lon_max = max(all_lon)
        lat_mean = (lat_max - lat_min) / 2
        lon_mean = (lon_max - lon_min) / 2
        lat_scale = 500 / (lat_max - lat_min)
        lon_scale = 500 / (lon_max - lon_min)
        for node in graph:
            lat = graph.node[node].get('Latitude') or lat_mean  # set default to be mean of min/max
            lon = graph.node[node].get('Longitude') or lon_mean  # set default to be mean of min/max
            graph.node[node]['y'] = -1 * lat * lat_scale
            graph.node[node]['x'] = lon * lon_scale

    if not (any(graph.node[n].get('x') for n in graph)
            and any(graph.node[n].get('y') for n in graph)):

# No x, y set, layout in a grid

        grid_length = int(math.ceil(math.sqrt(len(graph))))
        co_ords = [(x * 100, y * 100) for y in range(grid_length)
                   for x in range(grid_length)]

        # (0,0), (100, 0), (200, 0), (0, 100), (100, 100) ....

        for node in sorted(graph):
            (x, y) = co_ords.pop(0)
            graph.node[node]['x'] = x
            graph.node[node]['y'] = y

    # and ensure asn is integer, x and y are floats

    for node in sorted(graph):
        graph.node[node]['asn'] = int(graph.node[node]['asn'])
        if graph.node[node]['asn'] == 0:
            log.debug('Node %s has ASN set to 0. Setting to 1'
                      % graph.node[node]['label'])
            graph.node[node]['asn'] = 1
        try:
            x = float(graph.node[node]['x'])
        except KeyError:
            x = 0
        graph.node[node]['x'] = x
        try:
            y = float(graph.node[node]['y'])
        except KeyError:
            y = 0
        graph.node[node]['y'] = y
        try:
            graph.node[node]['label']
        except KeyError:
            device_type = graph.node[node]['device_type']
            graph.node[node]['label'] = '%s_%s' \
                % (label_prefixes[device_type], unique_label.next())

    if defaults:
        ank_edge_defaults = settings['Graphml']['Edge Defaults']
        edge_defaults = graph.graph['edge_default']
        for (key, val) in ank_edge_defaults.items():
            if key not in edge_defaults or edge_defaults[key] == 'None':
                edge_defaults[key] = val

        for (src, dst) in graph.edges():
            for (key, val) in edge_defaults.items():
                if key not in graph[src][dst]:
                    graph[src][dst][key] = val

# apply defaults
# relabel nodes
# other handling... split this into a separate module!
# relabel based on label: assume unique by now!
    # if graph.graph.get("Network") == "European NRENs":
        # TODO: test if non-unique labels, if so then warn and proceed with this logic
        # we need to map node ids to contain network to ensure unique labels
        # mapping = dict( (n, "%s__%s" % (d['label'], d['asn'])) for n, d in graph.nodes(data=True))


    mapping = dict((n, d['label']) for (n, d) in graph.nodes(data=True))  # TODO: use dict comprehension
    if not all(key == val for (key, val) in mapping.items()):
        nx.relabel_nodes(graph, mapping, copy=False)  # Networkx wipes data if remap with same labels

    graph.graph['file_type'] = 'graphml'

    selfloop_count = graph.number_of_selfloops()
    if selfloop_count > 0:
        log.warning("Self loops present: do multiple nodes have the same label?")
        selfloops = ", ".join(str(e) for e in graph.selfloop_edges())
        log.warning("Removing selfloops: %s" % selfloops)
        graph.remove_edges_from(edge for edge in graph.selfloop_edges())


    return graph
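
In Example 31, label_counts = defaultdict(list) inverts the node-to-label mapping so that blank and duplicate labels fall out directly as entries with an empty key or with more than one node. The same inversion in isolation, with made-up labels:

from collections import defaultdict

all_labels = {1: 'r1', 2: 'r2', 3: 'r2', 4: None}

# invert node -> label into label -> [nodes]
label_counts = defaultdict(list)
for node, label in all_labels.items():
    label_counts[label].append(node)

blank = [nodes for label, nodes in label_counts.items() if not label]
duplicates = [(label, nodes) for label, nodes in label_counts.items()
              if label and len(nodes) > 1]
print(blank)       # [[4]]
print(duplicates)  # [('r2', [2, 3])]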

Example 32

Project: autonetkit
Source File: ipv4.py
View license
    def build(self, group_attr='asn'):
        """Builds tree from unallocated_nodes,
        groupby is the attribute to build subtrees from"""

        subgraphs = []

# if the network's final octet is .0, e.g. 10.0.0.0 or 192.168.0.0, then add an
# extra "dummy" node, so we don't end up with a loopback of 10.0.0.0
# Change strategy: if just hosts (ie loopbacks), then allocate as a large
# collision domain

        if not len(self.unallocated_nodes):

            # no nodes to allocate - eg could be no collision domains

            return

        unallocated_nodes = self.unallocated_nodes
        key_func = lambda x: x.get(group_attr)
        if all(isinstance(item, autonetkit.anm.NmPort)
               and item.is_loopback for item in unallocated_nodes):
            # interface, map key function to be the interface's node
            key_func = lambda x: x.node.get(group_attr)

        unallocated_nodes = sorted(unallocated_nodes, key=key_func)
        groupings = itertools.groupby(unallocated_nodes, key=key_func)
        prefixes_by_attr = {}

        for (attr_value, items) in groupings:

            # make subtree for each attr

            items = sorted(list(items))
            subgraph = nx.DiGraph()

            if all(isinstance(item, autonetkit.anm.NmPort)
                   for item in items):

                # interface

                if all(item.is_loopback for item in items):
                    parent_id = self.next_node_id
                    # group all loopbacks into single subnet
                    prefixlen = 32 - subnet_size(len(items))
                    subgraph.add_node(parent_id, prefixlen=prefixlen,
                                      loopback_group=True)
                    for item in sorted(items):

                        # subgraph.add_edge(node, child_a)

                        item_id = self.next_node_id
                        subgraph.add_node(item_id, prefixlen=32,
                                          host=item)
                        subgraph.add_edge(parent_id, item_id)

                    root_node = parent_id
                    subgraphs.append(subgraph)
                    subgraph.graph['root'] = root_node
                    subgraph.node[root_node]['group_attr'] = attr_value
                    subgraph.node[root_node]['prefixlen'] = 24
                    # finished for loopbacks, continue only for collision
                    # domains
                    continue

            if all(item.is_l3device() for item in items):

                # Note: only l3 devices are added for loopbacks: cds allocate
                # to edges not devices (for now) - will be fixed when move to
                # proper interface model

                parent_id = self.next_node_id
                # group all loopbacks into single subnet
                prefixlen = 32 - subnet_size(len(items))
                subgraph.add_node(parent_id, prefixlen=prefixlen,
                                  loopback_group=True)
                for item in sorted(items):

                    # subgraph.add_edge(node, child_a)

                    item_id = self.next_node_id
                    subgraph.add_node(item_id, prefixlen=32, host=item)
                    subgraph.add_edge(parent_id, item_id)

                root_node = parent_id
                subgraphs.append(subgraph)
                subgraph.graph['root'] = root_node
                subgraph.node[root_node]['group_attr'] = attr_value
                # finished for loopbacks, continue only for collision domains
                continue

            for item in sorted(items):
                if item.broadcast_domain:
                    subgraph.add_node(self.next_node_id, prefixlen=32
                                      - subnet_size(item.degree()), host=item)
                if item.is_l3device():
                    subgraph.add_node(self.next_node_id, prefixlen=32,
                                      host=item)

            # now group by levels

            level_counts = defaultdict(int)

            nodes_by_level = defaultdict(list)
            for node in subgraph.nodes():
                prefixlen = subgraph.node[node]['prefixlen']
                nodes_by_level[prefixlen].append(node)

            log.debug('Building IP subtree for %s %s' % (group_attr,
                                                         attr_value))

            for (level, nodes) in nodes_by_level.items():
                level_counts[level] = len(nodes)

            self.add_parent_nodes(subgraph, level_counts)

# test if min_level node is bound, if so then add a parent, so root for AS
# isn't a cd

            min_level = min(level_counts)
            min_level_nodes = [n for n in subgraph
                               if subgraph.node[n]['prefixlen']
                               == min_level]

            # test if bound

            if len(min_level_nodes) == 2:
                subgraph.add_node(self.next_node_id,
                                  {'prefixlen': min_level - 2})
                subgraph.add_node(self.next_node_id,
                                  {'prefixlen': min_level - 2})
                subgraph.add_node(self.next_node_id,
                                  {'prefixlen': min_level - 1})
            if len(min_level_nodes) == 1:
                subgraph.add_node(self.next_node_id,
                                  {'prefixlen': min_level - 1})

            # rebuild with parent nodes

            nodes_by_level = defaultdict(list)
            for node in sorted(subgraph.nodes()):
                prefixlen = subgraph.node[node]['prefixlen']
                nodes_by_level[prefixlen].append(node)

            root_node = self.build_tree(subgraph, level_counts,
                                        nodes_by_level)
            subgraphs.append(subgraph)

            subgraph.graph['root'] = root_node

# Force to be a /16 block
# TODO: document this

            subgraph.node[root_node]['prefixlen'] = 16
            subgraph.node[root_node]['group_attr'] = attr_value
            prefixes_by_attr[attr_value] = subgraph.node[
                root_node]['prefixlen']

        global_graph = nx.DiGraph()
        subgraphs = sorted(subgraphs, key=lambda x:
                           x.node[x.graph['root']]['group_attr'])
        root_nodes = []
        for subgraph in subgraphs:
            root_node = subgraph.graph['root']
            root_nodes.append(root_node)
            global_graph.add_node(root_node, subgraph.node[root_node])

        nodes_by_level = defaultdict(list)
        for node in root_nodes:
            prefixlen = global_graph.node[node]['prefixlen']
            nodes_by_level[prefixlen].append(node)

        level_counts = defaultdict(int)
        for (level, nodes) in nodes_by_level.items():
            level_counts[level] = len(nodes)

        self.add_parent_nodes(global_graph, level_counts)

# rebuild nodes by level
# TODO: make this a function

        nodes_by_level = defaultdict(list)
        for node in global_graph:
            prefixlen = global_graph.node[node]['prefixlen']
            nodes_by_level[prefixlen].append(node)

        global_root = self.build_tree(global_graph, level_counts,
                                      nodes_by_level)
        global_root = TreeNode(global_graph, global_root)

        for subgraph in subgraphs:
            global_graph = nx.compose(global_graph, subgraph)

        # now allocate the IPs

        global_prefix_len = global_root.prefixlen

        # TODO: try/catch if the block is too small for prefix

        try:
            global_ip_block = \
                self.root_ip_block.subnet(global_prefix_len).next()
        except StopIteration:
            #message = ("Unable to allocate IPv4 subnets. ")
            formatted_prefixes = ", ".join(
                "AS%s: /%s" % (k, v) for k, v in sorted(prefixes_by_attr.items()))
            message = ("Cannot create requested number of /%s subnets from root block %s. Please specify a larger root IP block. (Requested subnet allocations are: %s)"
                       % (global_prefix_len, self.root_ip_block, formatted_prefixes))
            log.error(message)
            # TODO: throw ANK specific exception here
            raise AutoNetkitException(message)
        self.graph = global_graph

# add children of collision domains

        cd_nodes = [n for n in self if n.is_broadcast_domain()]
        for cd in sorted(cd_nodes):
            for edge in sorted(cd.host.edges()):

                # TODO: sort these

                child_id = self.next_node_id
                cd_id = cd.node
                global_graph.add_node(child_id, prefixlen=32,
                                      host=edge.dst_int)
                # cd -> neigh (cd is parent)
                global_graph.add_edge(cd_id, child_id)

# TODO: make allocate a separate step

        def allocate(node):

            # children = graph.successors(node)

            children = sorted(node.children())
            prefixlen = node.prefixlen + 1

            # workaround for clobbering attr subgraph root node with /16 if was
            # a /28

            subnet = node.subnet.subnet(prefixlen)

# handle case where children subnet

            # special case of single AS -> root is loopback_group
            if node.is_loopback_group() or node.is_broadcast_domain():

                # TODO: generalise this rather than repeated code with below
                # node.subnet = subnet.next() # Note: don't break into smaller
                # subnets if single-AS

                # ensures start at .1 rather than .0
                iterhosts = node.subnet.iter_hosts()
                sub_children = node.children()
                for sub_child in sorted(sub_children):

                    # TODO: tidy up this allocation to always record the subnet

                    if sub_child.is_interface() \
                            and sub_child.host.is_loopback:
                        if sub_child.host.is_loopback_zero:

                            # loopback zero, just store the ip address

                            sub_child.ip_address = iterhosts.next()
                        else:

                            # secondary loopback

                            sub_child.ip_address = iterhosts.next()
                            sub_child.subnet = node.subnet
                    elif sub_child.is_interface() \
                            and sub_child.host.is_physical:

                        # physical interface

                        sub_child.ip_address = iterhosts.next()
                        sub_child.subnet = node.subnet
                    else:
                        sub_child.subnet = iterhosts.next()

                return

            for child in sorted(children):

                # traverse the tree

                if child.is_broadcast_domain():
                    subnet = subnet.next()
                    child.subnet = subnet
                    # ensures start at .1 rather than .0
                    iterhosts = child.subnet.iter_hosts()
                    sub_children = child.children()
                    for sub_child in sorted(sub_children):
                        if sub_child.is_interface():
                            interface = sub_child.host
                            if interface.is_physical:

                                # physical interface

                                sub_child.ip_address = iterhosts.next()
                                sub_child.subnet = subnet
                            elif interface.is_loopback \
                                    and not interface.is_loopback_zero:

                                # secondary loopback interface

                                sub_child.ip_address = iterhosts.next()
                                sub_child.subnet = subnet
                        else:
                            sub_child.subnet = iterhosts.next()

                        #log.debug('Allocate sub_child to %s %s'% (sub_child, sub_child.subnet))
                elif child.is_host():
                    child.subnet = subnet.next()
                elif child.is_loopback_group():
                    child.subnet = subnet.next()
                    # ensures start at .1 rather than .0
                    iterhosts = child.subnet.iter_hosts()
                    sub_children = child.children()
                    for sub_child in sorted(sub_children):
                        if sub_child.is_interface() \
                                and not sub_child.host.is_loopback_zero:

                            # secondary loopback

                            sub_child.ip_address = iterhosts.next()
                            sub_child.subnet = child.subnet
                        else:
                            sub_child.subnet = iterhosts.next()
                else:
                    child.subnet = subnet.next()
                    allocate(child)  # continue down the tree

        global_root.subnet = global_ip_block

# TODO: fix this workaround where referring to the wrong graph

        global_root_id = global_root.node
        global_root = TreeNode(global_graph, global_root_id)
        allocate(global_root)

# check for parentless nodes

        self.graph = global_graph
        self.root_node = global_root
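
The two defaultdict uses above (nodes_by_level and level_counts) follow a common group-then-count pattern. A minimal, self-contained sketch with made-up (name, prefixlen) records rather than the real graph nodes:

from collections import defaultdict

# Hypothetical (name, prefixlen) records standing in for the subgraph root nodes above.
nodes = [('r1', 16), ('r2', 16), ('r3', 24), ('cd1', 28)]

# Group node names by prefix length, as nodes_by_level does above.
nodes_by_level = defaultdict(list)
for name, prefixlen in nodes:
    nodes_by_level[prefixlen].append(name)

# Count how many nodes sit at each level, as level_counts does above.
level_counts = defaultdict(int)
for level, members in nodes_by_level.items():
    level_counts[level] = len(members)

print(dict(nodes_by_level))  # {16: ['r1', 'r2'], 24: ['r3'], 28: ['cd1']}
print(dict(level_counts))    # {16: 2, 24: 1, 28: 1}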

Example 33

Project: qgisSpaceSyntaxToolkit
Source File: gml.py
View license
def parse_gml_lines(lines, label, destringizer):
    """Parse GML into a graph.
    """
    def tokenize():
        patterns = [
            r'[A-Za-z][0-9A-Za-z]*\s+',  # keys
            r'[+-]?(?:[0-9]*\.[0-9]+|[0-9]+\.[0-9]*)(?:[Ee][+-]?[0-9]+)?',  # reals
            r'[+-]?[0-9]+',   # ints
            r'".*?"',         # strings
            r'\[',            # dict start
            r'\]',            # dict end
            r'#.*$|\s+'       # comments and whitespaces
            ]
        tokens = re.compile(
            '|'.join('(' + pattern + ')' for pattern in patterns))
        lineno = 0
        for line in lines:
            length = len(line)
            pos = 0
            while pos < length:
                match = tokens.match(line, pos)
                if match is not None:
                    for i in range(len(patterns)):
                        group = match.group(i + 1)
                        if group is not None:
                            if i == 0:    # keys
                                value = group.rstrip()
                            elif i == 1:  # reals
                                value = float(group)
                            elif i == 2:  # ints
                                value = int(group)
                            else:
                                value = group
                            if i != 6:    # comments and whitespaces
                                yield (i, value, lineno + 1, pos + 1)
                            pos += len(group)
                            break
                else:
                    raise NetworkXError('cannot tokenize %r at (%d, %d)' %
                                        (line[pos:], lineno + 1, pos + 1))
            lineno += 1
        yield (None, None, lineno + 1, 1)  # EOF

    def unexpected(curr_token, expected):
        type, value, lineno, pos = curr_token
        raise NetworkXError(
            'expected %s, found %s at (%d, %d)' %
            (expected, repr(value) if value is not None else 'EOF', lineno,
             pos))

    def consume(curr_token, type, expected):
        if curr_token[0] == type:
            return next(tokens)
        unexpected(curr_token, expected)

    def parse_kv(curr_token):
        dct = defaultdict(list)
        while curr_token[0] == 0:  # keys
            key = curr_token[1]
            curr_token = next(tokens)
            type = curr_token[0]
            if type == 1 or type == 2:  # reals or ints
                value = curr_token[1]
                curr_token = next(tokens)
            elif type == 3:  # strings
                value = unescape(curr_token[1][1:-1])
                if destringizer:
                    try:
                        value = destringizer(value)
                    except ValueError:
                        pass
                curr_token = next(tokens)
            elif type == 4:  # dict start
                curr_token, value = parse_dict(curr_token)
            else:
                unexpected(curr_token, "an int, float, string or '['")
            dct[key].append(value)
        dct = {key: (value if not isinstance(value, list) or len(value) != 1
                     else value[0]) for key, value in dct.items()}
        return curr_token, dct

    def parse_dict(curr_token):
        curr_token = consume(curr_token, 4, "'['")    # dict start
        curr_token, dct = parse_kv(curr_token)
        curr_token = consume(curr_token, 5, "']'")  # dict end
        return curr_token, dct

    def parse_graph():
        curr_token, dct = parse_kv(next(tokens))
        if curr_token[0] is not None:  # EOF
            unexpected(curr_token, 'EOF')
        if 'graph' not in dct:
            raise NetworkXError('input contains no graph')
        graph = dct['graph']
        if isinstance(graph, list):
            raise NetworkXError('input contains more than one graph')
        return graph

    tokens = tokenize()
    graph = parse_graph()

    directed = graph.pop('directed', False)
    multigraph = graph.pop('multigraph', False)
    if not multigraph:
        G = nx.DiGraph() if directed else nx.Graph()
    else:
        G = nx.MultiDiGraph() if directed else nx.MultiGraph()
    G.graph.update((key, value) for key, value in graph.items()
                   if key != 'node' and key != 'edge')

    def pop_attr(dct, type, attr, i):
        try:
            return dct.pop(attr)
        except KeyError:
            raise NetworkXError(
                "%s #%d has no '%s' attribute" % (type, i, attr))

    nodes = graph.get('node', [])
    mapping = {}
    labels = set()
    for i, node in enumerate(nodes if isinstance(nodes, list) else [nodes]):
        id = pop_attr(node, 'node', 'id', i)
        if id in G:
            raise NetworkXError('node id %r is duplicated' % (id,))
        if label != 'id':
            label = pop_attr(node, 'node', 'label', i)
            if label in labels:
                raise NetworkXError('node label %r is duplicated' % (label,))
            labels.add(label)
            mapping[id] = label
        G.add_node(id, node)

    edges = graph.get('edge', [])
    for i, edge in enumerate(edges if isinstance(edges, list) else [edges]):
        source = pop_attr(edge, 'edge', 'source', i)
        target = pop_attr(edge, 'edge', 'target', i)
        if source not in G:
            raise NetworkXError(
                'edge #%d has an undefined source %r' % (i, source))
        if target not in G:
            raise NetworkXError(
                'edge #%d has an undefined target %r' % (i, target))
        if not multigraph:
            if not G.has_edge(source, target):
                G.add_edge(source, target, edge)
            else:
                raise nx.NetworkXError(
                    'edge #%d (%r%s%r) is duplicated' %
                    (i, source, '->' if directed else '--', target))
        else:
            key = edge.pop('key', None)
            if key is not None and G.has_edge(source, target, key):
                raise nx.NetworkXError(
                    'edge #%d (%r%s%r, %r) is duplicated' %
                    (i, source, '->' if directed else '--', target, key))
            G.add_edge(source, target, key, edge)

    if label != 'id':
        G = nx.relabel_nodes(G, mapping)
        if 'name' in graph:
            G.graph['name'] = graph['name']
        else:
            del G.graph['name']
    return G
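
parse_kv above relies on defaultdict(list) to keep every occurrence of a repeated GML key (such as multiple node entries) and then collapses single-element lists back to scalars. A stripped-down sketch of that pattern with an invented key/value stream instead of the real tokenizer:

from collections import defaultdict

# Hypothetical key/value stream; repeated keys (like GML 'node' entries) must all be kept.
pairs = [('directed', 1), ('node', {'id': 0}), ('node', {'id': 1}), ('label', 'g')]

dct = defaultdict(list)
for key, value in pairs:
    dct[key].append(value)

# Collapse single-element lists to scalars, as parse_kv does above.
dct = {key: (value if len(value) != 1 else value[0]) for key, value in dct.items()}

print(dct)  # {'directed': 1, 'node': [{'id': 0}, {'id': 1}], 'label': 'g'}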

Example 34

Project: pyNastran
Source File: attributes.py
View license
    def __init_attributes(self):
        """
        Creates storage objects for the BDF object.
        This would be in the init but doing it this way allows for better
        inheritance

        References:
          1.  http://www.mscsoftware.com/support/library/conf/wuc87/p02387.pdf
        """
        self.bdf_filename = None
        self.punch = None
        self._encoding = None

        #: list of executive control deck lines
        self.executive_control_lines = []

        #: list of case control deck lines
        self.case_control_lines = []

        self._auto_reject = False
        self._solmap_to_value = {
            'NONLIN': 101,  # 66 -> 101 per Reference 1
            'SESTATIC': 101,
            'SESTATICS': 101,
            'SEMODES': 103,
            'BUCKLING': 105,
            'SEBUCKL': 105,
            'NLSTATIC': 106,
            'SEDCEIG': 107,
            'SEDFREQ': 108,
            'SEDTRAN': 109,
            'SEMCEIG': 110,
            'SEMFREQ': 111,
            'SEMTRAN': 112,
            'CYCSTATX': 114,
            'CYCMODE': 115,
            'CYCBUCKL': 116,
            'CYCFREQ': 118,
            'NLTRAN': 129,
            'AESTAT': 144,
            'FLUTTR': 145,
            'SEAERO': 146,
            'NLSCSH': 153,
            'NLTCSH': 159,
            'DBTRANS': 190,
            'DESOPT': 200,

            # guessing
            #'CTRAN' : 115,
            'CFREQ' : 118,

            # solution 200 names
            'STATICS': 101,
            'MODES': 103,
            'BUCK': 105,
            'DFREQ': 108,
            'MFREQ': 111,
            'MTRAN': 112,
            'DCEIG': 107,
            'MCEIG': 110,
            #'HEAT'     : None,
            #'STRUCTURE': None,
            #'DIVERGE'  : None,
            'FLUTTER': 145,
            'SAERO': 146,
        }

        self.rsolmap_to_str = {
            66: 'NONLIN',
            101: 'SESTATIC',  # linear static
            103: 'SEMODES',  # modal
            105: 'BUCKLING',  # buckling
            106: 'NLSTATIC',  # non-linear static
            107: 'SEDCEIG',  # direct complex frequency response
            108: 'SEDFREQ',  # direct frequency response
            109: 'SEDTRAN',  # direct transient response
            110: 'SEMCEIG',  # modal complex eigenvalue
            111: 'SEMFREQ',  # modal frequency response
            112: 'SEMTRAN',  # modal transient response
            114: 'CYCSTATX',
            115: 'CYCMODE',
            116: 'CYCBUCKL',
            118: 'CYCFREQ',
            129: 'NLTRAN',  # nonlinear transient
            144: 'AESTAT',  # static aeroelastic
            145: 'FLUTTR',  # flutter/aeroservoelastic
            146: 'SEAERO',  # dynamic aeroelastic
            153: 'NLSCSH',  # nonlinear static thermal
            159: 'NLTCSH',  # nonlinear transient thermal
            190: 'DBTRANS',
            200: 'DESOPT',  # optimization
        }

        # ------------------------ bad duplicates ----------------------------
        self._iparse_errors = 0
        self._nparse_errors = 0
        self._stop_on_parsing_error = True
        self._stop_on_duplicate_error = True
        self._stored_parse_errors = []

        self._duplicate_nodes = []
        self._duplicate_elements = []
        self._duplicate_properties = []
        self._duplicate_materials = []
        self._duplicate_masses = []
        self._duplicate_thermal_materials = []
        self._duplicate_coords = []
        self.values_to_skip = {}

        # ------------------------ structural defaults -----------------------
        #: the analysis type
        self._sol = None
        #: used in solution 600, method
        self.sol_method = None
        #: the line with SOL on it, marks ???
        self.sol_iline = None
        self.case_control_deck = None

        #: store the PARAM cards
        self.params = {}
        # ------------------------------- nodes -------------------------------
        # main structural block
        #: stores SPOINT, GRID cards
        self.nodes = {}
        #: stores POINT cards
        self.points = {}
        #self.grids = {}
        self.spoints = None
        self.epoints = None
        #: stores GRIDSET card
        self.gridSet = None

        #: stores elements (CQUAD4, CTRIA3, CHEXA8, CTETRA4, CROD, CONROD,
        #: etc.)
        self.elements = {}

        #: stores rigid elements (RBE2, RBE3, RJOINT, etc.)
        self.rigid_elements = {}
        #: stores PLOTELs
        self.plotels = {}

        #: store CONM1, CONM2, CMASS1,CMASS2, CMASS3, CMASS4, CMASS5
        self.masses = {}
        self.properties_mass = {} # PMASS

        #: stores LOTS of properties (PBAR, PBEAM, PSHELL, PCOMP, etc.)
        self.properties = {}

        #: stores MAT1, MAT2, MAT3, MAT8, MAT10, MAT11
        self.materials = {}

        #: defines the MAT4, MAT5
        self.thermal_materials = {}

        #: defines the MATHE, MATHP
        self.hyperelastic_materials = {}

        #: stores MATSx
        self.MATS1 = {}
        self.MATS3 = {}
        self.MATS8 = {}

        #: stores MATTx
        self.MATT1 = {}
        self.MATT2 = {}
        self.MATT3 = {}
        self.MATT4 = {}
        self.MATT5 = {}
        self.MATT8 = {}
        self.MATT9 = {}

        #: stores the CREEP card
        self.creep_materials = {}

        # loads
        #: stores LOAD, FORCE, FORCE1, FORCE2, MOMENT, MOMENT1, MOMENT2,
        #: PLOAD, PLOAD2, PLOAD4, SLOAD
        #: GMLOAD, SPCD,
        #: QVOL
        self.loads = {}
        self.tics = {}

        # stores DLOAD entries.
        self.dloads = {}
        # stores ACSRCE, RLOAD1, RLOAD2, TLOAD1, TLOAD2, and ACSRCE entries.
        self.dload_entries = {}

        #self.gusts  = {} # Case Control GUST = 100
        #self.random = {} # Case Control RANDOM = 100

        #: stores coordinate systems
        origin = array([0., 0., 0.])
        zaxis = array([0., 0., 1.])
        xzplane = array([1., 0., 0.])
        coord = CORD2R(cid=0, rid=0, origin=origin, zaxis=zaxis, xzplane=xzplane)
        self.coords = {0 : coord}

        # --------------------------- constraints ----------------------------
        #: stores SUPORT1s
        #self.constraints = {} # suport1, anything else???
        self.suport = []
        self.suport1 = {}
        self.se_suport = []

        #: stores SPCADD, SPC, SPC1, SPCAX, GMSPC
        #self.spcObject = ConstraintObject()
        #: stores MPCADD,MPC
        #self.mpcObject = ConstraintObject()

        self.spcs = {}
        self.spcadds = {}

        self.mpcs = {}
        self.mpcadds = {}

        # --------------------------- dynamic ----------------------------
        #: stores DAREA
        self.dareas = {}
        self.dphases = {}

        self.pbusht = {}
        self.pdampt = {}
        self.pelast = {}

        #: frequencies
        self.frequencies = {}

        # ----------------------------------------------------------------
        #: direct matrix input - DMIG
        self.dmis = {}
        self.dmigs = {}
        self.dmijs = {}
        self.dmijis = {}
        self.dmiks = {}
        self._dmig_temp = defaultdict(list)

        # ----------------------------------------------------------------
        #: SETy
        self.sets = {}
        self.asets = []
        self.bsets = []
        self.csets = []
        self.qsets = []
        self.usets = {}

        #: SExSETy
        self.se_bsets = []
        self.se_csets = []
        self.se_qsets = []
        self.se_usets = {}
        self.se_sets = {}

        # ----------------------------------------------------------------
        #: tables
        self.tables = {}
        #: random_tables
        self.random_tables = {}
        #: TABDMP1
        self.tables_sdamping = {}

        # ----------------------------------------------------------------
        #: EIGB, EIGR, EIGRL methods
        self.methods = {}
        # EIGC, EIGP methods
        self.cMethods = {}

        # ---------------------------- optimization --------------------------
        # optimization
        self.dconadds = {}
        self.dconstrs = {}
        self.desvars = {}
        self.ddvals = {}
        self.dlinks = {}
        self.dresps = {}

        self.dtable = None
        self.dequations = {}

        #: stores DVPREL1, DVPREL2...might change to DVxRel
        self.dvprels = {}
        self.dvmrels = {}
        self.dvcrels = {}
        self.dvgrids = {}
        self.doptprm = None
        self.dscreen = {}

        # ------------------------- nonlinear defaults -----------------------
        #: stores NLPCI
        self.nlpcis = {}
        #: stores NLPARM
        self.nlparms = {}
        #: stores TSTEPs
        self.tsteps = {}
        #: stores TSTEPNL
        self.tstepnls = {}
        #: stores TF
        self.transfer_functions = {}
        #: stores DELAY
        self.delays = {}

        # --------------------------- aero defaults --------------------------
        # aero cards
        #: stores CAEROx
        self.caeros = {}
        #: stores PAEROx
        self.paeros = {}
        # stores MONPNT1
        self.monitor_points = []

        #: stores AECOMP
        self.aecomps = {}
        #: stores AEFACT
        self.aefacts = {}
        #: stores AELINK
        self.aelinks = {}
        #: stores AELIST
        self.aelists = {}
        #: stores AEPARAM
        self.aeparams = {}
        #: stores AESURF
        self.aesurf = {}
        #: stores AESURFS
        self.aesurfs = {}
        #: stores AESTAT
        self.aestats = {}
        #: stores CSSCHD
        self.csschds = {}

        #: store SPLINE1,SPLINE2,SPLINE4,SPLINE5
        self.splines = {}

        # ------ SOL 144 ------
        #: stores AEROS
        self.aeros = None

        #: stores TRIM
        self.trims = {}

        #: stores DIVERG
        self.divergs = {}

        # ------ SOL 145 ------
        #: stores AERO
        self.aero = None

        #: stores FLFACT
        self.flfacts = {}  #: .. todo:: can this be simplified ???
        #: stores FLUTTER
        self.flutters = {}
        #: mkaeros
        self.mkaeros = []

        # ------ SOL 146 ------
        #: stores GUST cards
        self.gusts = {}

        # ------------------------- thermal defaults -------------------------
        # BCs
        #: stores thermal boundary conditions - CONV,RADBC
        self.bcs = {}  # e.g. RADBC

        #: stores PHBDY
        self.phbdys = {}
        #: stores convection properties - PCONV, PCONVM ???
        self.convection_properties = {}
        #: stores TEMPD
        self.tempds = {}

        # -------------------------contact cards-------------------------------
        self.bcrparas = {}
        self.bctadds = {}
        self.bctparas = {}
        self.bctsets = {}
        self.bsurf = {}
        self.bsurfs = {}

        # ---------------------------------------------------------------------
        self._type_to_id_map = defaultdict(list)
        self._slot_to_type_map = {
            'params' : ['PARAM'],
            'nodes' : ['GRID', 'SPOINT', 'EPOINT'], # 'RINGAX',
            'points' : ['POINT'],
            'gridSet' : ['GRDSET'],
            #'POINT', 'POINTAX', 'RINGAX',

            # CMASS4 lies in the QRG
            'masses' : ['CONM1', 'CONM2', 'CMASS1', 'CMASS2', 'CMASS3', 'CMASS4'],

            'elements' : [
                'CELAS1', 'CELAS2', 'CELAS3', 'CELAS4',
                # 'CELAS5',
                'CBUSH', 'CBUSH1D', 'CBUSH2D',

                'CDAMP1', 'CDAMP2', 'CDAMP3', 'CDAMP4', 'CDAMP5',
                'CFAST',

                'CBAR', 'CROD', 'CTUBE', 'CBEAM', 'CBEAM3', 'CONROD', 'CBEND',
                'CTRIA3', 'CTRIA6', 'CTRIAR', 'CTRIAX', 'CTRIAX6',
                'CQUAD4', 'CQUAD8', 'CQUADR', 'CQUADX', 'CQUAD',
                'CPLSTN3', 'CPLSTN6', 'CPLSTN4', 'CPLSTN8',
                'CPLSTS3', 'CPLSTS6', 'CPLSTS4', 'CPLSTS8',

                'CTETRA', 'CPYRAM', 'CPENTA', 'CHEXA', 'CIHEX1',
                'CSHEAR', 'CVISC', 'CRAC2D', 'CRAC3D',
                'CGAP',

                # thermal
                'CHBDYE', 'CHBDYG', 'CHBDYP',
            ],
            'rigid_elements' : ['RBAR', 'RBAR1', 'RBE1', 'RBE2', 'RBE3', 'RROD', 'RSPLINE'],
            'plotels' : ['PLOTEL',],

            'properties_mass' : ['PMASS'],
            'properties' : [
                'PELAS', 'PGAP', 'PFAST', 'PLPLANE', 'PPLANE',
                'PBUSH', 'PBUSH1D',
                'PDAMP', 'PDAMP5',
                'PROD', 'PBAR', 'PBARL', 'PBEAM', 'PTUBE', 'PBEND', 'PBCOMP', 'PBRSECT', 'PBMSECT',
                'PBEAML',  # not fully supported
                # 'PBEAM3',

                'PSHELL', 'PCOMP', 'PCOMPG', 'PSHEAR',
                'PSOLID', 'PLSOLID', 'PVISC', 'PRAC2D', 'PRAC3D',
                'PIHEX', 'PCOMPS',
            ],
            'pdampt' : ['PDAMPT',],
            'pelast' : ['PELAST',],
            'pbusht' : ['PBUSHT',],

            # materials
            'materials' : ['MAT1', 'MAT2', 'MAT3', 'MAT8', 'MAT9', 'MAT10', 'MAT11'],
            'hyperelastic_materials' : ['MATHE', 'MATHP',],
            'creep_materials' : ['CREEP'],
            'MATT1' : ['MATT1'],
            'MATT2' : ['MATT2'],
            'MATT3' : ['MATT3'],
            'MATT4' : ['MATT4'], # thermal
            'MATT5' : ['MATT5'], # thermal
            'MATT8' : ['MATT8'],
            'MATT9' : ['MATT9'],
            'MATS1' : ['MATS1'],
            'MATS3' : ['MATS3'],
            'MATS8' : ['MATS8'],

            # 'MATHE'
            #'EQUIV', # testing only, should never be activated...

            # thermal materials
            'thermal_materials' : ['MAT4', 'MAT5',],

            # spc/mpc constraints - TODO: is this correct?
            'spcs' : ['SPC', 'SPC1', 'SPCAX', 'SPCADD', 'GMSPC'],
            #'spcadds' : ['SPCADD'],
            #'mpcadds' : ['MPCADD'],
            'mpcs' : ['MPC', 'MPCADD'],
            'suport' : ['SUPORT'],
            'suport1' : ['SUPORT1'],
            'se_suport' : ['SESUP'],

            # loads
            'loads' : [
                'LOAD', 'LSEQ', 'RANDPS',
                'FORCE', 'FORCE1', 'FORCE2',
                'MOMENT', 'MOMENT1', 'MOMENT2',
                'GRAV', 'ACCEL', 'ACCEL1',
                'PLOAD', 'PLOAD1', 'PLOAD2', 'PLOAD4',
                'PLOADX1', 'RFORCE', 'RFORCE1', 'SLOAD',
                'GMLOAD', 'SPCD', 'LOADCYN',

                # thermal
                'TEMP', 'QBDY1', 'QBDY2', 'QBDY3', 'QHBDY',
                'QVOL',
                ],
            'dloads' : ['DLOAD', ],
            # stores RLOAD1, RLOAD2, TLOAD1, TLOAD2, and ACSRCE entries.
            'dload_entries' : ['ACSRCE', 'TLOAD1', 'TLOAD2', 'RLOAD1', 'RLOAD2',],

            # aero cards
            'aero' : ['AERO'],
            'aeros' : ['AEROS'],
            'gusts' : ['GUST'],
            'flutters' : ['FLUTTER'],
            'flfacts' : ['FLFACT'],
            'mkaeros' : ['MKAERO1', 'MKAERO2'],
            'aecomps' : ['AECOMP'],
            'aefacts' : ['AEFACT'],
            'aelinks' : ['AELINK'],
            'aelists' : ['AELIST'],
            'aeparams' : ['AEPARM'],
            'aesurf' : ['AESURF'],
            'aesurfs' : ['AESURFS'],
            'aestats' : ['AESTAT'],
            'caeros' : ['CAERO1', 'CAERO2', 'CAERO3', 'CAERO4', 'CAERO5'],
            'paeros' : ['PAERO1', 'PAERO2', 'PAERO3', 'PAERO4', 'PAERO5'],
            'monitor_points' : ['MONPNT1'],
            'splines' : ['SPLINE1', 'SPLINE2', 'SPLINE4', 'SPLINE5',],
            'csschds' : ['CSSCHD',],
            #'SPLINE3', 'SPLINE6', 'SPLINE7',
            'trims' : ['TRIM',],
            'divergs' : ['DIVERG'],

            # coords
            'coords' : ['CORD1R', 'CORD1C', 'CORD1S',
                        'CORD2R', 'CORD2C', 'CORD2S',
                        'GMCORD'],

            # temperature cards
            'tempds' : ['TEMPD'],

            'phbdys' : ['PHBDY'],
            'convection_properties' : ['PCONV', 'PCONVM'],

            # stores thermal boundary conditions
            'bcs' : ['CONV', 'RADBC', 'RADM'],


            # dynamic cards
            'dareas' : ['DAREA'],
            'dphases' : ['DPHASE'],
            'nlparms' : ['NLPARM'],
            'nlpcis' : ['NLPCI'],
            'tsteps' : ['TSTEP'],
            'tstepnls' : ['TSTEPNL'],
            'transfer_functions' : ['TF'],
            'delays' : ['DELAY'],

            'frequencies' : ['FREQ', 'FREQ1', 'FREQ2', 'FREQ4'],

            # direct matrix input cards
            'dmigs' : ['DMIG'],
            'dmijs' : ['DMIJ'],
            'dmijis' : ['DMIJI'],
            'dmiks' : ['DMIK'],
            'dmis' : ['DMI'],

            # optimzation
            'dequations' : ['DEQATN'],
            'dtable' : ['DTABLE'],
            'dconstrs' : ['DCONSTR', 'DCONADD'],
            'desvars' : ['DESVAR'],
            'ddvals' : ['DDVAL'],
            'dlinks' : ['DLINK'],
            'dresps' : ['DRESP1', 'DRESP2', 'DRESP3',],
            'dvprels' : ['DVPREL1', 'DVPREL2'],
            'dvmrels' : ['DVMREL1', 'DVMREL2'],
            'dvcrels' : ['DVCREL1', 'DVCREL2'],
            'dvgrids' : ['DVGRID'],
            'doptprm' : ['DOPTPRM'],
            'dscreen' : ['DSCREEN'],


            # sets
            'asets' : ['ASET', 'ASET1'],
            'bsets' : ['BSET', 'BSET1',],
            'qsets' : ['QSET', 'QSET1'],
            'csets' : ['CSET', 'CSET1',],
            'usets' : ['USET', 'USET1',],
            'sets' : ['SET1', 'SET3',],

            # super-element sets
            'se_bsets' : ['SEBSET', 'SEBSET1'],
            'se_csets' : ['SECSET', 'SECSET1',],
            'se_qsets' : ['SEQSET', 'SEQSET1'],
            'se_usets' : ['SEUSET', 'SEQSET1'],
            'se_sets' : ['SESET'],
            # SEBSEP

            'tables' : [
                'TABLEHT', 'TABRNDG',
                'TABLED1', 'TABLED2', 'TABLED3', 'TABLED4',
                'TABLEM1', 'TABLEM2', 'TABLEM3', 'TABLEM4',
                'TABLES1', 'TABLEST',
                ],
            'tables_sdamping' : ['TABDMP1'],
            'random_tables' : ['TABRND1', 'TABRNDG',],

            # initial conditions - sid (set ID)
            ##'TIC',  (in bdf_tables.py)

            # methods
            'methods' : ['EIGB', 'EIGR', 'EIGRL',],

            # cMethods
            'cMethods' : ['EIGC', 'EIGP',],

            # contact
            'bctparas' : ['BCTPARA'],
            'bcrparas' : ['BCRPARA'],
            'bctadds' : ['BCTADD'],
            'bctsets' : ['BCTSET'],
            'bsurf' : ['BSURF'],
            'bsurfs' : ['BSURFS'],

            ## other
            #'INCLUDE',  # '='
            #'ENDDATA',
        }
        self._type_to_slot_map = self.get_rslot_map()
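
_type_to_id_map and _dmig_temp above are defaultdict(list) registries, so a new card type or matrix name can be appended to without first checking whether the key exists. A small sketch of the same shape with hypothetical (card_type, card_id) pairs:

from collections import defaultdict

# Hypothetical (card_type, card_id) pairs standing in for parsed BDF cards.
cards = [('GRID', 1), ('GRID', 2), ('CQUAD4', 10), ('MAT1', 100)]

# Same shape as _type_to_id_map above: card type -> list of ids.
type_to_id_map = defaultdict(list)
for card_type, card_id in cards:
    type_to_id_map[card_type].append(card_id)

print(dict(type_to_id_map))  # {'GRID': [1, 2], 'CQUAD4': [10], 'MAT1': [100]}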

Example 35

Project: pyNastran
Source File: attributes.py
View license
    def __init__(self):
        """creates the attributes for the BDF"""
        self._nastran_format = 'msc'
        self.is_nx = False
        self.is_msc = True
        self.max_int = 100000000

        #----------------------------------------
        self._is_cards_dict = True
        self.punch = None
        self.reject_lines = []

        self.set_precision()
        #----------------------------------------
        self.grid = GRID(self)
        self.grdset = GRDSET(self)
        self.point = POINT(self)
        self.spoint = SPOINT(self)
        self.epoint = EPOINT(self)
        self.pointax = POINTAX(self)
        self.coords = Coord(self)
        #----------------------------------------

        # springs
        self.pelas = PELAS(self)
        self.celas1 = CELAS1(self)
        self.celas2 = CELAS2(self)
        self.celas3 = CELAS3(self)
        self.celas4 = CELAS4(self)
        self.elements_spring = ElementsSpring(self)

        # rods/tubes
        self.prod = PROD(self)
        self.crod = CROD(self)
        self.conrod = CONROD(self)
        self.ptube = PTUBE(self)
        self.ctube = CTUBE(self)

        # bars
        self.cbar = CBAR(self)
        #self.cbaror = CBAROR(self)
        self.pbar = PBAR(self)
        self.pbarl = PBARL(self)
        self.properties_bar = PropertiesBar(self)

        # beams
        self.cbeam = CBEAM(self)
        self.pbeam = PBEAM(self)
        self.pbeaml = PBEAML(self)
        #: stores PBEAM, PBEAML
        self.properties_beam = PropertiesBeam(self)

        # shear
        #: stores CSHEAR
        self.cshear = CSHEAR(self)
        #: stores PSHEAR
        self.pshear = PSHEAR(self)

        # shells
        self.pshell = PSHELL(self)
        self.pcomp = PCOMP(self)
        self.pcompg = PCOMPG(self)
        self.cquad4 = CQUAD4(self)
        self.ctria3 = CTRIA3(self)
        self.ctria6 = CTRIA6(self)
        self.cquad8 = CQUAD8(self)
        self.ctriax6 = CTRIAX6(self)
        #self.cquad = CQUAD(self)
        #: stores PSHELL, PCOMP, PCOMPG
        self.properties_shell = PropertiesShell(self)
        #: stores CTRIA3, CTRIA6, CQUAD4, CQUAD8
        self.elements_shell = ElementsShell(self)

        # solids
        self.psolid = PSOLID(self)
        self.plsolid = PLSOLID(self)
        self.ctetra4 = CTETRA4(self)
        self.ctetra10 = CTETRA10(self)
        self.cpyram5 = None
        self.cpyram13 = None
        self.cpenta6 = CPENTA6(self)
        self.cpenta15 = CPENTA15(self)
        self.chexa8 = CHEXA8(self)
        self.chexa20 = CHEXA20(self)
        #: stores CTETRA4, CPENTA6, CHEXA8, CTETRA10, CPENTA15, CHEXA20
        self.elements_solid = ElementsSolid(self)
        #: stores PSOLID, PLSOLID
        self.properties_solid = PropertiesSolid(self)

        #----------------------------------------
        # mass
        self.conm1 = CONM1(self)
        self.conm2 = CONM2(self)
        #self.pmass = PMASS(self)
        #self.cmass1 = CMASS1(self)
        #self.cmass2 = CMASS2(self)
        #self.cmass3 = CMASS3(self)
        #self.cmass4 = CMASS4(self)
        #self.cmass5 = CMASS5(self)
        self.pmass = None
        self.cmass1 = None
        self.cmass2 = None
        self.cmass3 = None
        self.cmass4 = None
        self.cmass5 = None
        self.mass = Mass(self)
        #----------------------------------------
        # b-list elements
        #self.rbe2 = None
        #self.rbe3 = None
        self.cbush = CBUSH(self)
        self.pbush = PBUSH(self)
        self.cbush1d = None
        self.pbush1d = None
        self.cbush2d = None
        self.pbush2d = None

        #----------------------------------------
        # control structure
        self.elements = Elements(self)
        self.properties = Properties(self)
        #----------------------------------------

        self.mat1 = MAT1(self)
        self.mats1 = MATS1(self)
        #self.mat2 = MAT2(self)
        #self.mat2 = MAT2(self)
        #self.mat4 = MAT4(self)
        #self.mat5 = MAT5(self)
        self.mat8 = MAT8(self)
        #self.mat10 = MAT10(self)
        #self.mat11 = MAT11(self)
        self.mathp = MATHP(self)

        self.materials = Materials(self)

        # ----------------------------------------------------------------

        self.load = LOADs(self)
        self.dload = LOADs(self)
        #self.dload = defaultdict(list)
        #self.loadset = LOADSET(model)

        self.force = FORCE(self)
        self.force1 = FORCE1(self)
        self.force2 = FORCE2(self)
        self.moment = MOMENT(self)
        self.moment1 = MOMENT1(self)
        self.moment2 = MOMENT2(self)
        self.grav = GRAV(self)
        self.rforce = RFORCE(self)

        self.pload = PLOAD(self)
        self.pload1 = PLOAD1(self)
        self.pload2 = PLOAD2(self)
        #self.pload3 = PLOAD3(self)
        self.pload4 = PLOAD4(self)
        self.ploadx1 = PLOADX1(self)

        self.tload1 = TLOAD1(self)
        self.tload2 = TLOAD2(self)
        self.delay = DELAY(self)

        self.rload1 = RLOAD1(self)
        #self.rload2 = RLOAD2(self)
        self.dphase = DPHASE(self)

        self.darea = DAREA(self)

        #: stores LOAD, FORCE, MOMENT, etc.
        self.loads = Loads(self)
        # ----------------------------------------------------------------
        self.tempp1 = TEMPP1(self)
        self.temp = TEMP(self)
        self.temps = TEMPs(self)

        # ----------------------------------------------------------------
        #self.spc1 = SPC1(self)
        #self.spcadd = SPCADD(self)
        self.spc = {} #class_obj_defaultdict(SPC, model)
        self.spcd = {} #class_obj_defaultdict(SPCD, model)
        self.spc1 = {} #class_obj_defaultdict(SPC1, model)
        self.spcadd = {}
        self.mpc = {}  # the new form, not added...
        self.mpcadd = {}

        # ----------------------------------------------------------------
        #: stores PARAMs
        self.params = {}

        #: stores rigid elements (RBE2, RBE3, RJOINT, etc.)
        self.rigid_elements = {}
        #: stores PLOTELs
        self.plotels = {}

        # --------------------------- dynamic ----------------------------
        #: stores DAREA
        #self.dareas = {}
        #self.dphases = {}

        self.pbusht = {}
        self.pdampt = {}
        self.pelast = {}

        #: frequencies
        self.frequencies = {}

        # ----------------------------------------------------------------
        #: direct matrix input - DMIG
        self.dmis = {}
        self.dmigs = {}
        self.dmijs = {}
        self.dmijis = {}
        self.dmiks = {}
        self._dmig_temp = defaultdict(list)

        # ----------------------------------------------------------------
        #: SETy
        self.sets = {} # SET1, SET3
        self.asets = []
        self.bsets = []
        self.csets = []
        self.qsets = []
        self.usets = {}

        #: SExSETy
        self.se_bsets = []
        self.se_csets = []
        self.se_qsets = []
        self.se_usets = {}
        self.se_sets = {}

        # ----------------------------------------------------------------
        #: tables
        self.tables = {}
        #: random_tables
        self.random_tables = {}
        #: TABDMP1
        self.tables_sdamping = {}

        # ----------------------------------------------------------------
        #: EIGB, EIGR, EIGRL methods
        self.methods = {}
        # EIGC, EIGP methods
        self.cMethods = {}

        # ---------------------------- optimization --------------------------
        # optimization
        self.dconadds = {}
        self.dconstrs = {}
        self.desvars = {}
        self.ddvals = {}
        self.dlinks = {}
        self.dresps = {}

        self.dtable = None
        self.dequations = {}

        #: stores DVPREL1, DVPREL2...might change to DVxRel
        self.dvprels = {}
        self.dvmrels = {}
        self.dvcrels = {}
        self.dvgrids = {}
        self.doptprm = None
        self.dscreen = {}

        # ------------------------- nonlinear defaults -----------------------
        #: stores NLPCI
        self.nlpcis = {}
        #: stores NLPARM
        self.nlparms = {}
        #: stores TSTEPs
        self.tsteps = {}
        #: stores TSTEPNL
        self.tstepnls = {}
        #: stores TF
        self.transfer_functions = {}
        #: stores DELAY
        self.delays = {}

        # --------------------------- aero defaults --------------------------
        # aero cards
        #: stores CAEROx
        self.caeros = {}
        #: stores PAEROx
        self.paeros = {}
        # stores MONPNT1
        self.monitor_points = []

        #: stores AECOMP
        self.aecomps = {}
        #: stores AEFACT
        self.aefacts = {}
        #: stores AELINK
        self.aelinks = {}
        #: stores AELIST
        self.aelists = {}
        #: stores AEPARAM
        self.aeparams = {}
        #: stores AESURF
        self.aesurf = {}
        #: stores AESURFS
        self.aesurfs = {}
        #: stores AESTAT
        self.aestats = {}
        #: stores CSSCHD
        self.csschds = {}

        #: store SPLINE1,SPLINE2,SPLINE4,SPLINE5
        self.splines = {}

        # ------ SOL 144 ------
        #: stores AEROS
        self.aeros = None

        #: stores TRIM
        self.trims = {}

        #: stores DIVERG
        self.divergs = {}

        # ------ SOL 145 ------
        #: stores AERO
        self.aero = None

        #: stores FLFACT
        self.flfacts = {}  #: .. todo:: can this be simplified ???
        #: stores FLUTTER
        self.flutters = {}
        #: mkaeros
        self.mkaeros = []

        # ------ SOL 146 ------
        #: stores GUST cards
        self.gusts = {}
        # ------------------------- thermal defaults -------------------------
        # BCs
        #: stores thermal boundary conditions - CONV,RADBC
        self.bcs = {}  # e.g. RADBC

        #: stores PHBDY
        self.phbdys = {}
        #: stores convection properties - PCONV, PCONVM ???
        self.convection_properties = {}
        #: stores TEMPD
        self.tempds = {}

        # -------------------------contact cards-------------------------------
        self.bcrparas = {}
        self.bctadds = {}
        self.bctparas = {}
        self.bctsets = {}
        self.bsurf = {}
        self.bsurfs = {}

        # ---------------------------------------------------------------------

        self._type_to_id_map = defaultdict(list)
        self._solmap_to_value = {
            'NONLIN': 101,  # 66 -> 101 per Reference 1
            'SESTATIC': 101,
            'SESTATICS': 101,
            'SEMODES': 103,
            'BUCKLING': 105,
            'SEBUCKL': 105,
            'NLSTATIC': 106,
            'SEDCEIG': 107,
            'SEDFREQ': 108,
            'SEDTRAN': 109,
            'SEMCEIG': 110,
            'SEMFREQ': 111,
            'SEMTRAN': 112,
            'CYCSTATX': 114,
            'CYCMODE': 115,
            'CYCBUCKL': 116,
            'CYCFREQ': 118,
            'NLTRAN': 129,
            'AESTAT': 144,
            'FLUTTR': 145,
            'SEAERO': 146,
            'NLSCSH': 153,
            'NLTCSH': 159,
            'DBTRANS': 190,
            'DESOPT': 200,

            # guessing
            #'CTRAN' : 115,
            'CFREQ' : 118,

            # solution 200 names
            'STATICS': 101,
            'MODES': 103,
            'BUCK': 105,
            'DFREQ': 108,
            'MFREQ': 111,
            'MTRAN': 112,
            'DCEIG': 107,
            'MCEIG': 110,
            #'HEAT'     : None,
            #'STRUCTURE': None,
            #'DIVERGE'  : None,
            'FLUTTER': 145,
            'SAERO': 146,
        }

        self.rsolmap_to_str = {
            66: 'NONLIN',
            101: 'SESTATIC',  # linear static
            103: 'SEMODES',  # modal
            105: 'BUCKLING',  # buckling
            106: 'NLSTATIC',  # non-linear static
            107: 'SEDCEIG',  # direct complex frequency response
            108: 'SEDFREQ',  # direct frequency response
            109: 'SEDTRAN',  # direct transient response
            110: 'SEMCEIG',  # modal complex eigenvalue
            111: 'SEMFREQ',  # modal frequency response
            112: 'SEMTRAN',  # modal transient response
            114: 'CYCSTATX',
            115: 'CYCMODE',
            116: 'CYCBUCKL',
            118: 'CYCFREQ',
            129: 'NLTRAN',  # nonlinear transient
            144: 'AESTAT',  # static aeroelastic
            145: 'FLUTTR',  # flutter/aeroservoelastic
            146: 'SEAERO',  # dynamic aeroelastic
            153: 'NLSCSH',  # nonlinear static thermal
            159: 'NLTCSH',  # nonlinear transient thermal
            190: 'DBTRANS',
            200: 'DESOPT',  # optimization
        }
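
The commented-out class_obj_defaultdict(SPC, model) calls above suggest a defaultdict whose values should be constructed from the model, but default_factory takes no arguments, so binding them with functools.partial (or a closure) is the usual workaround. A generic sketch, assuming a hypothetical Bucket container in place of the real SPC class:

from collections import defaultdict
from functools import partial

class Bucket(object):
    """Hypothetical stand-in for a per-set card container such as SPC."""
    def __init__(self, model):
        self.model = model
        self.cards = []

model = 'bdf_model'  # placeholder for the BDF object

# default_factory takes no arguments, so bind `model` with functools.partial.
spc = defaultdict(partial(Bucket, model))

spc[1].cards.append('SPC card A')   # set id 1 is created on first access
spc[1].cards.append('SPC card B')
print(spc[1].cards, spc[1].model)   # ['SPC card A', 'SPC card B'] bdf_model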

Example 36

Project: pyNastran
Source File: ugrid_reader.py
View license
    def _write_faces(self, faces_filename):
        """writes an OpenFOAM faces file"""
        nhexas = self.hexas.shape[0]
        npenta6s = self.penta6s.shape[0]
        npenta5s = self.penta5s.shape[0]
        ntets = self.tets.shape[0]

        nquad_faces = nhexas * 6 + npenta5s + npenta6s * 3
        ntri_faces = ntets * 4 + npenta5s * 4 + npenta6s * 2
        nfaces = ntri_faces + nquad_faces
        assert nfaces > 0, nfaces

        #tri_face_to_eids = ones((nt, 2), dtype='int32')
        tri_face_to_eids = defaultdict(list)

        #quad_face_to_eids = ones((nq, 2), dtype='int32')
        quad_face_to_eids = defaultdict(list)

        tri_faces = zeros((ntri_faces, 3), dtype='int32')
        quad_faces = zeros((nquad_faces, 4), dtype='int32')

        with open(faces_filename, 'wb') as faces_file:
            faces_file.write('\n\n')
            #faces_file.write('%i\n' % (nnodes))
            faces_file.write('(\n')

            it_start = {}
            iq_start = {}
            min_eids = {}
            it = 0
            iq = 0
            eid = 1
            it_start[1] = it
            iq_start[1] = iq
            min_eids[eid] = self.tets
            for element in self.tets - 1:
                (n1, n2, n3, n4) = element
                face1 = [n3, n2, n1]
                face2 = [n1, n2, n4]
                face3 = [n4, n3, n1]
                face4 = [n2, n3, n4]

                tri_faces[it, :] = face1
                tri_faces[it+1, :] = face2
                tri_faces[it+2, :] = face3
                tri_faces[it+3, :] = face4

                face1.sort()
                face2.sort()
                face3.sort()
                face4.sort()
                tri_face_to_eids[tuple(face1)].append(eid)
                tri_face_to_eids[tuple(face2)].append(eid)
                tri_face_to_eids[tuple(face3)].append(eid)
                tri_face_to_eids[tuple(face4)].append(eid)
                it += 4
                eid += 1

            it_start[2] = it
            iq_start[2] = iq
            min_eids[eid] = self.hexas
            print('HEXA it=%s iq=%s' % (it, iq))
            for element in self.hexas-1:
                (n1, n2, n3, n4, n5, n6, n7, n8) = element

                face1 = [n1, n2, n3, n4]
                face2 = [n2, n6, n7, n3]
                face3 = [n6, n5, n8, n7]
                face4 = [n5, n1, n4, n8]
                face5 = [n4, n3, n7, n8]
                face6 = [n5, n6, n2, n1]

                quad_faces[iq, :] = face1
                quad_faces[iq+1, :] = face2
                quad_faces[iq+2, :] = face3
                quad_faces[iq+3, :] = face4
                quad_faces[iq+4, :] = face5
                quad_faces[iq+5, :] = face6

                face1.sort()
                face2.sort()
                face3.sort()
                face4.sort()
                face5.sort()
                face6.sort()

                quad_face_to_eids[tuple(face1)].append(eid)
                quad_face_to_eids[tuple(face2)].append(eid)
                quad_face_to_eids[tuple(face3)].append(eid)
                quad_face_to_eids[tuple(face4)].append(eid)
                quad_face_to_eids[tuple(face5)].append(eid)
                quad_face_to_eids[tuple(face6)].append(eid)
                iq += 6
                eid += 1

            it_start[3] = it
            iq_start[3] = iq
            min_eids[eid] = self.penta5s
            print('PENTA5 it=%s iq=%s' % (it, iq))
            for element in self.penta5s-1:
                (n1, n2, n3, n4, n5) = element

                face1 = [n2, n3, n5]
                face2 = [n1, n2, n5]
                face3 = [n4, n1, n5]
                face4 = [n5, n3, n4]
                face5 = [n4, n3, n2, n1]

                tri_faces[it, :] = face1
                tri_faces[it+1, :] = face2
                tri_faces[it+2, :] = face3
                tri_faces[it+3, :] = face4
                quad_faces[iq, :] = face5

                face1.sort()
                face2.sort()
                face3.sort()
                face4.sort()
                face5.sort()

                tri_face_to_eids[tuple(face1)].append(eid)
                tri_face_to_eids[tuple(face2)].append(eid)
                tri_face_to_eids[tuple(face3)].append(eid)
                tri_face_to_eids[tuple(face4)].append(eid)
                quad_face_to_eids[tuple(face5)].append(eid)

                it += 4
                iq += 1
                eid += 1

            it_start[4] = it
            iq_start[4] = iq
            min_eids[eid] = self.penta6s
            print('PENTA6 it=%s iq=%s' % (it, iq))
            for element in self.penta6s-1:
                (n1, n2, n3, n4, n5, n6) = element

                face1 = [n1, n2, n3]
                face2 = [n5, n4, n6]
                face3 = [n2, n5, n6, n3]
                face4 = [n4, n1, n3, n6]
                face5 = [n4, n5, n2, n1]

                tri_faces[it, :] = face1
                tri_faces[it+1, :] = face2
                quad_faces[iq, :] = face3
                quad_faces[iq+1, :] = face4
                quad_faces[iq+2, :] = face5

                face1.sort()
                face2.sort()
                face3.sort()
                face4.sort()
                face5.sort()

                tri_face_to_eids[tuple(face1)].append(eid)
                tri_face_to_eids[tuple(face2)].append(eid)
                quad_face_to_eids[tuple(face3)].append(eid)
                quad_face_to_eids[tuple(face4)].append(eid)
                quad_face_to_eids[tuple(face5)].append(eid)
                it += 2
                iq += 3
                eid += 1

            # find the unique faces
            tri_faces_sort = deepcopy(tri_faces)
            quad_faces_sort = deepcopy(quad_faces)
            #print('t0', tri_faces_sort[0, :])
            #print('t1', tri_faces_sort[1, :])

            print('nt=%s nq=%s' % (ntri_faces, nquad_faces))
            tri_faces_sort.sort(axis=1)
            #for i, tri in enumerate(tri_faces):
                #assert tri[2] > tri[0], 'i=%s tri=%s' % (i, tri)
            #print('*t0', tri_faces_sort[0, :])
            #print('*t1', tri_faces_sort[1, :])

            quad_faces_sort.sort(axis=1)
            #for i, quad in enumerate(quad_faces):
                #assert quad[3] > quad[0], 'i=%s quad=%s' % (i, quad)


            #iq_start_keys = iq_start.keys()
            #it_start_keys = it_start.keys()
            #iq_start_keys.sort()
            #it_start_keys.sort()

            face_to_eid = []

            eid_keys = min_eids.keys()
            eid_keys.sort()

            type_mapper = {
                1 : 'tets',
                2 : 'hexas',
                3 : 'penta5s',
                4 : 'penta6s',
            }
            print("eid_keys =", eid_keys)
            for face, eids in iteritems(tri_face_to_eids):
                if len(eids) == 1:
                    #if it's a boundary face, we're fine, otherwise, error...
                    #print('*face=%s eids=%s' % (face, eids))
                    #pid = lookup from quads/tris
                    eid = eids[0]
                    owner = eid
                    neighbor = -1
                    continue
                    #raise RuntimeError()

                e1, e2 = eids
                i1 = searchsorted(eid_keys, e1)
                i2 = searchsorted(eid_keys, e2)

                if i1 == 1: # tet
                    it1 = (e1-1) * 4
                    it2 = (e1-1) * 4 + 4
                    faces1_sort = tri_faces_sort[it1:it2, :]
                    faces1_unsorted = tri_faces[it1:it2, :]

                    #print("faces1 = \n", faces1_sort, '\n')

                    # figure out irow; 3 for the test case
                    face = array(face, dtype='int32')

                    #print('face  = %s' % face)
                    #print('face3 = %s' % faces1_sort[3, :])

                    if allclose(face, faces1_sort[0, :]):
                        n1 = 0
                    elif allclose(face, faces1_sort[1, :]):
                        n1 = 1
                    elif allclose(face, faces1_sort[2, :]):
                        n1 = 2
                    elif allclose(face, faces1_sort[3, :]):
                        n1 = 3
                    else:
                        raise RuntimeError('cant find face=%s in faces for eid1=%s' % (face, e1))

                    if allclose(face, faces1_unsorted[n1, :]):
                        owner = e1
                        neighbor = e2
                    else:
                        owner = e2
                        neighbor = e1
                    face_new = faces1_unsorted[n1, :]

                elif i1 == 2:  # CHEXA
                    iq1 = iq_start[2]
                    iq2 = iq1 + 6

                elif i1 == 3:  # CPENTA5
                    #e1_new = e1 - eid_keys[2]
                    iq1 = iq_start[3]
                    iq2 = iq1 + 1
                    it1 = it_start[3]
                    it2 = it1 + 4
                elif i1 == 4:  # CPENTA6
                    iq1 = iq_start[4]
                    iq2 = iq1 + 3
                    it1 = it_start[4]
                    it2 = it1 + 2
                else:
                    raise NotImplementedError('This is a %s and is not supported' % type_mapper[i1])

                # do we need to check this???
                if 0:
                    if i2 == 1: # tet

                        it1 = it_start_keys[i1]
                        it2 = it1 + 4
                        faces2 = tri_faces_sort[it1:it2, :]
                        #print('face=%s eids=%s' % (face, eids))
                        #print("faces2 = \n", faces2)
                        # spits out 3
                    else:
                        asdf
                #type1 = type_mapper[i1]
                #type2 = type_mapper[i2]
                #if type1:
            faces_file.write(')\n')
        return
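
The core trick above is keying defaultdict(list) by the sorted node tuple of each face, so the same face reached from two neighbouring elements lands in one bucket (two eids) while boundary faces end up with a single eid. A tiny sketch with two made-up tets sharing one face:

from collections import defaultdict

# Two hypothetical tets sharing the face (2, 3, 4).
tets = {1: (1, 2, 3, 4), 2: (2, 3, 4, 5)}

tri_face_to_eids = defaultdict(list)
for eid, (n1, n2, n3, n4) in tets.items():
    for face in ([n3, n2, n1], [n1, n2, n4], [n4, n3, n1], [n2, n3, n4]):
        # Sort so the same face gets the same key regardless of winding.
        tri_face_to_eids[tuple(sorted(face))].append(eid)

shared = [f for f, eids in tri_face_to_eids.items() if len(eids) == 2]
boundary = [f for f, eids in tri_face_to_eids.items() if len(eids) == 1]
print(shared)         # [(2, 3, 4)]
print(len(boundary))  # 6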

Example 37

Project: pyNastran
Source File: op2_results.py
View license
def get_nodal_averaged_stress(model, eid_to_nid_map, isubcase, options=None):
    """
    Supports:
    - plateStress
    - solidStress
    - compositePlateStress (NA)

    options = {
        'mode': 'derive/avg',  # derive/avg, avg/derive
        'layers' : 'max',      # max, min, avg
       #'ilayers' : None,      # None or [1, 2, ... N]
        'location' : 'node',   # node, centroid
    }

    TODO: this isn't part of OP2() because it's not done
    TODO: doesn't support transient, frequency, real/imaginary data
    TODO: add 'sum', 'difference' for 'layers'?
    TODO: hasn't been tested
    """
    raise NotImplementedError()
    assert options['mode'] in ['derive/avg', 'avg/derive'], options['mode']
    assert options['layers'] in ['max', 'min', 'avg'], options['layers']
    assert options['location'] in ['node', 'centroid'], options['location']
    #assert options['mode'] in ['derive/avg', 'avg/derive'], options['mode']
    #assert options['mode'] in ['derive/avg', 'avg/derive'], options['mode']

    layer_map = {
        'max': amax,
        'min': amin,
        'avg': mean,
        #'sum': sum,
    }
    mode = options['mode']
    layer_func = layer_map[options['layers']]
    location = options['location']

    results = {
        'x': defaultdict(list),
        'y': defaultdict(list),
        'z': defaultdict(list),
        'xy': defaultdict(list),
        'yz': defaultdict(list),
        'xz': defaultdict(list),
        'maxP': defaultdict(list),
        'minP': defaultdict(list),
        'vonMises': defaultdict(list),  # 3D von mises
        'vonMises2D': defaultdict(list),  # 2D von mises
    }
    if isubcase in model.solidStress:
        case = model.solidStress[isubcase]
        if case.is_von_mises():
            vmWord = 'vonMises'
        else:
            vmWord = 'maxShear'
        assert vmWord == 'vonMises', vmWord

        if location == 'node':  # derive/avg
            for eid in case.ovmShear:
                node_ids = eid_to_nid_map[eid]
                for nid in node_ids:
                    results['x'   ][nid].append(case.oxx[eid][nid])
                    results['y'   ][nid].append(case.oyy[eid][nid])
                    results['z'   ][nid].append(case.ozz[eid][nid])
                    results['xy'  ][nid].append(case.txy[eid][nid])
                    results['yz'  ][nid].append(case.tyz[eid][nid])
                    results['xz'  ][nid].append(case.txz[eid][nid])
                    results['maxP'][nid].append(case.o1[eid][nid])
                    results['minP'][nid].append(case.o3[eid][nid])
                    results['vonMises'][nid].append(case.ovmShear[eid][nid])
        elif location == 'centroid':
            for eid in case.ovmShear:  # derive/avg
                node_ids = eid_to_nid_map[eid]
                for nid in node_ids:
                    results['x'   ][nid].append(case.oxx[eid]['CENTER'])
                    results['y'   ][nid].append(case.oyy[eid]['CENTER'])
                    results['z'   ][nid].append(case.ozz[eid]['CENTER'])
                    results['xy'  ][nid].append(case.txy[eid]['CENTER'])
                    results['yz'  ][nid].append(case.tyz[eid]['CENTER'])
                    results['xz'  ][nid].append(case.txz[eid]['CENTER'])
                    results['maxP'][nid].append(case.o1[eid]['CENTER'])
                    results['minP'][nid].append(case.o3[eid]['CENTER'])
                    results['vonMises'][nid].append(case.ovmShear[eid]['CENTER'])
        else:
            raise RuntimeError('location=%r' % location)

    if isubcase in model.plateStress:
        case = model.plateStress[isubcase]
        if case.nonlinear_factor is not None: # transient
            return
        if case.is_von_mises():
            vmWord = 'vonMises'
        else:
            vmWord = 'maxShear'

        assert vmWord == 'vonMises', vmWord
        if location == 'node':
            for eid in case.ovmShear:
                node_ids = eid_to_nid_map[eid]
                eType = case.eType[eid]
                if eType in ['CQUAD4', 'CQUAD8']:
                    #cen = 'CEN/%s' % eType[-1]
                    assert len(node_ids[:4]) == 4, len(node_ids[:4])
                    if node_ids[0] in case.oxx[eid]:
                        # bilinear
                        for nid in node_ids[:4]:
                            results['x'   ][nid].append(layer_func(case.oxx[eid][nid]))
                            results['y'   ][nid].append(layer_func(case.oyy[eid][nid]))
                            results['xy'  ][nid].append(layer_func(case.txy[eid][nid]))
                            results['maxP'][nid].append(layer_func(case.majorP[eid][nid]))
                            results['minP'][nid].append(layer_func(case.minorP[eid][nid]))
                            results['vonMises'][nid].append(layer_func(case.ovmShear[eid][nid]))
                    else:
                        #cen = 'CEN/%s' % eType[-1]
                        cen = 0
                        for nid in node_ids[:4]:
                            results['x'   ][nid].append(layer_func(case.oxx[eid][cen]))
                            results['y'   ][nid].append(layer_func(case.oyy[eid][cen]))
                            results['xy'  ][nid].append(layer_func(case.txy[eid][cen]))
                            results['maxP'][nid].append(layer_func(case.majorP[eid][cen]))
                            results['minP'][nid].append(layer_func(case.minorP[eid][cen]))
                            results['vonMises'][nid].append(layer_func(case.ovmShear[eid][cen]))
                elif eType in ['CTRIA3', 'CTRIA6']:
                    #cen = 'CEN/%s' % eType[-1]
                    cen = 0
                    assert len(node_ids[:3]) == 3, len(node_ids[:3])
                    for nid in node_ids[:3]:
                        results['x'   ][nid].append(layer_func(case.oxx[eid][cen]))
                        results['y'   ][nid].append(layer_func(case.oyy[eid][cen]))
                        results['xy'  ][nid].append(layer_func(case.txy[eid][cen]))
                        results['maxP'][nid].append(layer_func(case.majorP[eid][cen]))
                        results['minP'][nid].append(layer_func(case.minorP[eid][cen]))
                        results['vonMises'][nid].append(layer_func(case.ovmShear[eid][cen]))
                else:
                    raise NotImplementedError(eType)
        elif location == 'centroid':
            for eid in case.ovmShear:
                node_ids = eid_to_nid_map[eid]
                eType = case.eType[eid]
                if eType in ['CQUAD4', 'CQUAD8', 'CTRIA3', 'CTRIA6']:
                    #cen = 'CEN/%s' % eType[-1]
                    cen = 0
                else:
                    raise NotImplementedError(eType)
                for nid in node_ids:
                    results['x'   ][nid].append(layer_func(case.oxx[eid][cen]))
                    results['y'   ][nid].append(layer_func(case.oyy[eid][cen]))
                    results['xy'  ][nid].append(layer_func(case.txy[eid][cen]))
                    results['maxP'][nid].append(layer_func(case.majorP[eid][cen]))
                    results['minP'][nid].append(layer_func(case.minorP[eid][cen]))
                    results['vonMises'][nid].append(layer_func(case.ovmShear[eid][cen]))
        else:
            raise RuntimeError('location=%r' % location)

    if isubcase in model.compositePlateStress:
        case = model.compositePlateStress[isubcase]
        if case.nonlinear_factor is not None: # transient
            return
        if case.is_von_mises():
            vmWord = 'vonMises'
        else:
            vmWord = 'maxShear'

        assert vmWord == 'vonMises', vmWord
        if location == 'node':
            for eid in case.ovmShear:
                node_ids = eid_to_nid_map[eid]
                eType = case.eType[eid]
                if eType in ['CQUAD4', 'CQUAD8']:
                    assert len(node_ids[:4]) == 4, len(node_ids[:4])
                    for nid in node_ids[:4]:
                        results['x'   ][nid].append(layer_func(case.o11[eid]))
                        results['y'   ][nid].append(layer_func(case.o22[eid]))
                        results['xy'  ][nid].append(layer_func(case.t12[eid]))
                        results['maxP'][nid].append(layer_func(case.majorP[eid]))
                        results['minP'][nid].append(layer_func(case.minorP[eid]))
                        results['vonMises'][nid].append(layer_func(case.ovmShear[eid]))
                elif eType in ['CTRIA3', 'CTRIA6']:
                    #cen = 'CEN/%s' % eType[-1]  # not used for composite results
                    assert len(node_ids[:3]) == 3, len(node_ids[:3])
                    for nid in node_ids[:3]:
                        results['x'   ][nid].append(layer_func(case.o11[eid]))
                        results['y'   ][nid].append(layer_func(case.o22[eid]))
                        results['xy'  ][nid].append(layer_func(case.t12[eid]))
                        results['maxP'][nid].append(layer_func(case.majorP[eid]))
                        results['minP'][nid].append(layer_func(case.minorP[eid]))
                        results['vonMises'][nid].append(layer_func(case.ovmShear[eid]))
                else:
                    raise NotImplementedError(eType)
        elif location == 'centroid':
            for eid in case.ovmShear:
                node_ids = eid_to_nid_map[eid]
                eType = case.eType[eid]
                for nid in node_ids:
                    results['x'   ][nid].append(layer_func(case.o11[eid]))
                    results['y'   ][nid].append(layer_func(case.o22[eid]))
                    results['xy'  ][nid].append(layer_func(case.t12[eid]))
                    results['maxP'][nid].append(layer_func(case.majorP[eid]))
                    results['minP'][nid].append(layer_func(case.minorP[eid]))
                    results['vonMises'][nid].append(layer_func(case.ovmShear[eid]))
        else:
            raise RuntimeError('location=%r' % location)

    if mode == 'derive/avg':
        for result_name, result in iteritems(results):
            for nid, datai in iteritems(result):
                results[result_name][nid] = mean(datai)
    elif mode == 'avg/derive':
        for result_name in ['x', 'y', 'z', 'xy', 'yz', 'xz']:
            for nid, datai in iteritems(results[result_name]):
                results[result_name][nid] = mean(datai)

        for nid in results['maxP']:
            oxx = results['x'][nid]
            oyy = results['y'][nid]
            ozz = results['z'][nid]

            txy = results['xy'][nid]
            tyz = results['yz'][nid]
            txz = results['xz'][nid]

            if not isinstance(ozz, float):
                ozz = 0.
            if not isinstance(txy, float):
                txy = 0.
            if not isinstance(tyz, float):
                tyz = 0.
            if not isinstance(txz, float):
                txz = 0.

            # 3D
            A = array([
                [oxx, txy, txz],
                [ 0., oyy, tyz],
                [ 0.,  0., ozz],
            ])
            eigs = eigvalsh(A, UPLO='U')
            maxP = eigs.max()
            minP = eigs.min()
            results['maxP'][nid] = maxP
            results['minP'][nid] = minP

            # 2D
            A2 = array([
                [oxx, txy],
                [ 0., oyy],
            ])
            eigs2 = eigvalsh(A2, UPLO='U')
            #maxP2 = eigs2.max()
            #minP2 = eigs2.min()

            results['vonMises'][nid] = vonMises3D(*eigs)
            results['vonMises2D'][nid] = vonMises2D(*eigs2)
    else:
        raise RuntimeError('mode=%r' % mode)
    return results
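
The nodal results above are accumulated in defaultdict(list) buckets keyed by node id, so every adjacent element can contribute a value without any key-existence checks, and the lists are only collapsed once all element types have been visited. A minimal, hypothetical sketch of the 'derive/avg' path (the element map and stresses below are invented, not pyNastran data):

from collections import defaultdict

# hypothetical element-to-node map and per-element nodal oxx values
eid_to_nid_map = {1: [10, 11], 2: [11, 12]}
oxx = {1: {10: 1.0, 11: 3.0}, 2: {11: 5.0, 12: 7.0}}

acc = defaultdict(list)  # nid -> contributions from every adjacent element
for eid, node_ids in eid_to_nid_map.items():
    for nid in node_ids:
        acc[nid].append(oxx[eid][nid])

# 'derive/avg': average the already-derived elemental values at each node
averaged = dict((nid, sum(vals) / float(len(vals))) for nid, vals in acc.items())
assert averaged[11] == 4.0  # node 11 is shared, so 3.0 and 5.0 are averaged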

Example 38

Project: regulations-scraper
Source File: aggregates.py
View license
def reducefn(key, documents):
    from collections import defaultdict
    import datetime

    def min_date(*args):
        a = [arg for arg in args if arg is not None]
        if not a:
            return None
        else:
            return min(a)

    def max_date(*args):
        a = [arg for arg in args if arg is not None]
        if not a:
            return None
        else:
            return max(a)
    
    # Mongo loses track of the tuple-vs-list distinction, so fix that if necessary so dict() doesn't break
    def tf_dict(l):
        if not l:
            return {}
        try:
            return dict(l)
        except TypeError:
            return dict([(tuple(k) if type(k) == list else k, v) for k, v in l])

    ### COLLECTION: dockets ###
    if key[0] == 'dockets':
        out = {
            'count': 0,
            'expanded_comment_count': 0,
            'type_breakdown': defaultdict(int),
            'doc_info': {
                'fr_docs': [],
                'supporting_material': [],
                'other': []
            },
            'weeks': defaultdict(int),
            'date_range': [None, None],
            'text_entities': defaultdict(int),
            'submitter_entities': defaultdict(int)
        }
        if documents:
            out['date_range'] = documents[0]['date_range']

        for value in documents:
            out['count'] += value['count']
            out['expanded_comment_count'] += value.get('expanded_comment_count', 0)
            
            for doc_type, count in value['type_breakdown'].iteritems():
                out['type_breakdown'][doc_type] += count
            
            for doc_type in ['fr_docs', 'supporting_material', 'other']:
                out['doc_info'][doc_type].extend(value['doc_info'][doc_type])
            
            for week, count in tf_dict(value['weeks']).iteritems():
                out['weeks'][week] += count

            for entity, count in value['text_entities'].iteritems():
                out['text_entities'][entity] += count

            for entity, count in value['submitter_entities'].iteritems():
                out['submitter_entities'][entity] += count

            out['date_range'][0] = min_date(out['date_range'][0], value['date_range'][0])
            out['date_range'][1] = max_date(out['date_range'][1], value['date_range'][1])

        out['doc_info']['fr_docs'] = sorted(out['doc_info']['fr_docs'], key=lambda x: x['date'], reverse=True)
        out['doc_info']['supporting_material'] = sorted(out['doc_info']['supporting_material'], key=lambda x: x['date'], reverse=True)[:3]
        out['doc_info']['other'] = sorted(out['doc_info']['other'], key=lambda x: x['date'], reverse=True)[:3]

        out['weeks'] = sorted(out['weeks'].items(), key=lambda x: x[0][0] if x[0] else datetime.date.min.isoformat())
        return out

    ### COLLECTION: agencies ###
    if key[0] == 'agencies':
        out = {
            'count': 0,
            'expanded_comment_count': 0,
            'type_breakdown': defaultdict(int),
            'months': defaultdict(int),
            'date_range': [None, None],
            'text_entities': defaultdict(int),
            'submitter_entities': defaultdict(int)
        }
        if documents:
            out['date_range'] = documents[0]['date_range']

        for value in documents:
            out['count'] += value['count']
            out['expanded_comment_count'] += value.get('expanded_comment_count', 0)
            
            for doc_type, count in value['type_breakdown'].iteritems():
                out['type_breakdown'][doc_type] += count
                        
            for month, count in tf_dict(value['months']).iteritems():
                out['months'][month] += count

            for entity, count in value['text_entities'].iteritems():
                out['text_entities'][entity] += count

            for entity, count in value['submitter_entities'].iteritems():
                out['submitter_entities'][entity] += count

            out['date_range'][0] = min_date(out['date_range'][0], value['date_range'][0])
            out['date_range'][1] = max_date(out['date_range'][1], value['date_range'][1])

        out['months'] = sorted(out['months'].items(), key=lambda x: x[0] if x[0] else datetime.date.min.isoformat())
        return out

    ### COLLECTION: docs ###
    if key[0] == 'docs':
        out = {
            'count': 0,
            'expanded_comment_count': 0,
            'weeks': defaultdict(int),
            'date_range': [None, None],
            'text_entities': defaultdict(int),
            'submitter_entities': defaultdict(int),
            'recent_comments': []
        }
        if documents:
            out['date_range'] = documents[0]['date_range']

        for value in documents:
            out['count'] += value['count']
            out['expanded_comment_count'] += value.get('expanded_comment_count', 0)
            
            for week, count in tf_dict(value['weeks']).iteritems():
                out['weeks'][week] += count

            for entity, count in value['text_entities'].iteritems():
                out['text_entities'][entity] += count

            for entity, count in value['submitter_entities'].iteritems():
                out['submitter_entities'][entity] += count

            out['recent_comments'].extend(value['recent_comments'])

            out['date_range'][0] = min_date(out['date_range'][0], value['date_range'][0])
            out['date_range'][1] = max_date(out['date_range'][1], value['date_range'][1])

        out['recent_comments'] = sorted(out['recent_comments'], key=lambda x: x['date'], reverse=True)[:5]
        out['weeks'] = sorted(out['weeks'].items(), key=lambda x: x[0][0] if x[0] else datetime.date.min.isoformat())
        return out

    ### COLLECTION: entities ###
    if key[0] == 'entities':
        out = {
            'text_mentions': {
                'count': 0,
                'agencies': defaultdict(int),
                'dockets': defaultdict(int),
                'months': defaultdict(int),
                'agencies_by_month': defaultdict(lambda: defaultdict(int)),
                'date_range': [None, None]
            },
            'submitter_mentions': {
                'count': 0,
                'agencies': defaultdict(int),
                'dockets': defaultdict(int),
                'months': defaultdict(int),
                'agencies_by_month': defaultdict(lambda: defaultdict(int)),
                'date_range': [None, None],
                'recent_comments': []
            }
        }
        for value in documents:
            for mention_type in ['text_mentions', 'submitter_mentions']:
                out[mention_type]['count'] += value[mention_type]['count']
                for agency, count in value[mention_type]['agencies'].iteritems():
                    if value[mention_type]['agencies'][agency]:
                        out[mention_type]['agencies'][agency] += value[mention_type]['agencies'][agency]
                for docket, count in value[mention_type]['dockets'].iteritems():
                    if value[mention_type]['dockets'][docket]:
                        out[mention_type]['dockets'][docket] += value[mention_type]['dockets'][docket]
                months_dict = tf_dict(value[mention_type]['months'])
                for month, count in months_dict.iteritems():
                    if months_dict[month]:
                        out[mention_type]['months'][month] += months_dict[month]
                for agency, agency_months in value[mention_type]['agencies_by_month'].items():
                    agency_months_dict = tf_dict(value[mention_type]['agencies_by_month'][agency])
                    for month, count in agency_months_dict.iteritems():
                        if agency_months_dict[month]:
                            out[mention_type]['agencies_by_month'][agency][month] += agency_months_dict[month]
                out[mention_type]['date_range'][0] = min_date(out[mention_type]['date_range'][0], value[mention_type]['date_range'][0])
                out[mention_type]['date_range'][1] = max_date(out[mention_type]['date_range'][1], value[mention_type]['date_range'][1])

            out['submitter_mentions']['recent_comments'].extend(value['submitter_mentions']['recent_comments'])

        for mention_type in ['text_mentions', 'submitter_mentions']:
            out[mention_type]['months'] = sorted(out[mention_type]['months'].items(), key=lambda x: x[0] if x[0] else datetime.date.min.isoformat())
            for agency in out[mention_type]['agencies_by_month'].keys():
                out[mention_type]['agencies_by_month'][agency] = sorted(out[mention_type]['agencies_by_month'][agency].items(), key=lambda x: x[0] if x[0] else datetime.date.min.isoformat())
            # hack to make this defaultdict picklable
            out[mention_type]['agencies_by_month'] = tf_dict(out[mention_type]['agencies_by_month'])

        out['submitter_mentions']['recent_comments'] = sorted(out['submitter_mentions']['recent_comments'], key=lambda x: x['date'], reverse=True)[:5]
        out['count'] = out['text_mentions']['count'] + out['submitter_mentions']['count']
        
        return out
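
The reducer above tallies nested counts with defaultdict(lambda: defaultdict(int)) and then converts the outer structure back to a plain dict (the "hack to make this defaultdict picklable") because a defaultdict whose factory is a lambda cannot be pickled when the result is shipped between map/reduce workers. A minimal sketch of that merge-and-convert pattern with made-up shard data:

from collections import defaultdict
import pickle

# hypothetical per-shard counts: agency -> month -> count
shards = [
    {'EPA': {'2011-01': 2}},
    {'EPA': {'2011-01': 1, '2011-02': 4}, 'SEC': {'2011-03': 5}},
]

merged = defaultdict(lambda: defaultdict(int))
for shard in shards:
    for agency, months in shard.items():
        for month, count in months.items():
            merged[agency][month] += count

# convert both levels to plain dicts before returning from a worker;
# pickling `merged` directly would fail because the lambda factory
# cannot be pickled
plain = dict((agency, dict(months)) for agency, months in merged.items())
pickle.dumps(plain)
assert plain['EPA']['2011-01'] == 3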

Example 39

View license
def parse_year(year, doctype):
    fr_docs = []
    files = defaultdict(dict)

    rows = year('td[rowspan] table[cellpadding="4"] tr')
    for row in rows.items():
        if len(row.find('b.blue')):
            # this is an FR document, not a header or whatever
            cells = row.find('td')
            doc_link = cells.eq(0).find('a[href]')
            details = cells.eq(2)

            detail_titles = details.find('b i')

            comment_link = details.find('i a')

            all_links = details.find('a')

            doc = {
                'id': doc_link.html(),
                'date': cells.eq(1).html(),
                'title': details('b.blue').text(),
                'doctype': doctype,
                'details': {},
                'attachments': [],
                'file_info': OrderedDict()
            }

            for t in detail_titles.items():
                label = t.html()
                if not label:
                    continue
                label = label.strip().rstrip(':')
                
                if label == "See also":
                    links = [pq(x) for x in until_break_or_end(t.parent()) if getattr(x, 'tag', None) == 'a']
                    see_also = []
                    for l in links:
                        if l.text().lower() in ('comments', 'comments received'):
                            # this is a link to a comments page
                            purl = l.attr('href')
                            number = get_file_number(purl)
                            if number:
                                if number not in doc['file_info']:
                                    doc['file_info'][number] = {}
                                doc['file_info'][number]['url'] = canonicalize_url(purl)
                        else:
                            see_also.append({'label': l.text(), 'url': canonicalize_url(l.attr('href'))})
                    doc['details'][label] = see_also
                elif label == "Additional Materials":
                    materials = []
                    for el in t.parent().nextAll():
                        tag = getattr(el, 'tag', None)
                        if tag in ('i', 'b'):
                            # we're done
                            break
                        elif tag == 'a':
                            mlink = pq(el)
                            mlabel = mlink.text()

                            if mlabel == "Federal Register version":
                                # we're done again
                                break

                            materials.append({
                                'url': canonicalize_url(mlink.attr('href')),
                                'label': mlabel
                            })
                    doc['details'][label] = materials
                else:
                    bold = t.parent()
                    if bold.hasClass('blue'):
                        # this isn't a detail label at all, but rather part of the title; skip it
                        continue
                    
                    next = next_node(bold)[0]
                    if hasattr(next, "strip"):
                        text = next.strip()

                        if label.startswith("File No"):
                            text = next_node(t.parent())[0].strip()
                            numbers = re.findall(r"((S\d+-)?\d+-\d+)", text)
                            for number in numbers:
                                doc['file_info'][number[0]] = {}
                        else:
                            doc['details'][label] = text
                    else:
                        # this is a weird one that we're not going to try and handle
                        pass

            if len(comment_link):
                for cl in comment_link:
                    pcl = pq(cl)
                    if pcl.html() == "are available":
                        purl = pcl.attr('href')
                        # find the file number in the URL
                        number = get_file_number(purl)
                        if number:
                            if number not in doc['file_info']:
                                doc['file_info'][number] = {}
                            doc['file_info'][number]['url'] = canonicalize_url(purl)

            for link in all_links:
                plink = pq(link)
                link_label = plink.text().strip()
                if link_label == "HTML":
                    # found a new-style one
                    # walk backwards until we get to an italic
                    i = prev_x_or_first(plink, set(['b', 'i']))
                    # now walk forwards until the end
                    line = pq([i] + [el for el in until_break_or_end(i) if str(el).strip()])
                    
                    # sanity check
                    if line[0].text().rstrip(":") == "Federal Register":
                        # we're good

                        fr_cite = pq(line[1]).text()
                        if "FR" in fr_cite:
                            # the second block generally should have the FR number
                            doc['details']['Federal Register Citation'] = re.sub(r"(^[\s\(]+)|([\s\):]+$)", "", fr_cite)
                        else:
                            fr_cite = None

                        # the FR number is in the HTML URL
                        fr_match = re.findall(r".*federalregister.gov.*/(\d{4}-\d+)/.+", plink.attr('href'))
                        if fr_match:
                            doc['details']['Federal Register Number'] = fr_match[0]
                        
                        attachment = {
                            'title': ('Federal Register (%s)' % doc['details']['Federal Register Citation']) if fr_cite else "Federal Register version",
                            'views': []
                        }
                        fr_links = [tag for tag in line if getattr(tag, 'tag', None) == 'a']
                        for fl in fr_links:
                            pfl = pq(fl)
                            ltype = pfl.text().strip().lower()
                            attachment['views'].append({
                                'url': canonicalize_url(pfl.attr('href')),
                                'type': LINK_TYPES[ltype] if ltype in LINK_TYPES else pfl.attr('href').split('.')[-1]
                            })

                        doc['attachments'].append(attachment)
                    else:
                        print line[0].text()
                        assert False, "What strange sorcery is this? Expected FR link"
                elif link_label in ("Federal Register version", "Federal Register PDF"):
                    # found an old-style one
                    doc['attachments'].append({
                        'title': 'Federal Register version',
                        'views': [{'url': canonicalize_url(plink.attr('href')), 'type': 'pdf'}]
                    })

            if len(doc_link):
                doc['url'] = canonicalize_url(doc_link.attr('href'))
            else:
                # there's some weird old stuff that has things split into multiple pages
                attachment = {
                    'title': 'Document pages',
                    'views': []
                }
                for bold in details.find('b'):
                    pbold = pq(bold)
                    if pbold.text().strip() == "File names:":
                        all_names = until_break_or_end(pbold)
                        for el in all_names:
                            if getattr(el, 'tag', None) == 'a':
                                url = pq(el).attr('href')
                                attachment['views'].append({
                                    'url': canonicalize_url(url),
                                    'type': url.split(".")[-1]
                                })
                doc['attachments'].append(attachment)
                doc['id'] = cells.eq(0).text()


            file_list = []
            for key, value in doc['file_info'].iteritems():
                value['id'] = key
                file_list.append(value)

                files[key].update(value)

            doc['file_info'] = file_list

            print "Parsed %s %s..." % (doc['doctype'], doc['id'])
            fr_docs.append(doc)

    return {'fr_docs': fr_docs, 'files': files}
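
The parser keeps files as a defaultdict(dict) so that file-number records scattered across many entries can be merged with update() without first checking whether the key exists. A minimal sketch with invented file numbers and URLs:

from collections import defaultdict

# hypothetical file_info fragments collected from several documents
fragments = [
    {'S7-01-11': {'url': 'http://example.com/comments1'}},
    {'S7-01-11': {'pages': 12}, 'S7-02-11': {'url': 'http://example.com/comments2'}},
]

files = defaultdict(dict)  # file number -> merged info from every document
for file_info in fragments:
    for number, info in file_info.items():
        files[number].update(info)

assert files['S7-01-11'] == {'url': 'http://example.com/comments1', 'pages': 12}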

Example 40

Project: sympy
Source File: test_residue.py
View license
def test_residue():
    assert n_order(2, 13) == 12
    assert [n_order(a, 7) for a in range(1, 7)] == \
           [1, 3, 6, 3, 6, 2]
    assert n_order(5, 17) == 16
    assert n_order(17, 11) == n_order(6, 11)
    assert n_order(101, 119) == 6
    assert n_order(11, (10**50 + 151)**2) == 10000000000000000000000000000000000000000000000030100000000000000000000000000000000000000000000022650
    raises(ValueError, lambda: n_order(6, 9))

    assert is_primitive_root(2, 7) is False
    assert is_primitive_root(3, 8) is False
    assert is_primitive_root(11, 14) is False
    assert is_primitive_root(12, 17) == is_primitive_root(29, 17)
    raises(ValueError, lambda: is_primitive_root(3, 6))

    assert [primitive_root(i) for i in range(2, 31)] == [1, 2, 3, 2, 5, 3, \
       None, 2, 3, 2, None, 2, 3, None, None, 3, 5, 2, None, None, 7, 5, \
       None, 2, 7, 2, None, 2, None]

    for p in primerange(3, 100):
        it = _primitive_root_prime_iter(p)
        assert len(list(it)) == totient(totient(p))
    assert primitive_root(97) == 5
    assert primitive_root(97**2) == 5
    assert primitive_root(40487) == 5
    # note that primitive_root(40487) + 40487 = 40492 is a primitive root
    # of 40487**2, but it is not the smallest
    assert primitive_root(40487**2) == 10
    assert primitive_root(82) == 7
    p = 10**50 + 151
    assert primitive_root(p) == 11
    assert primitive_root(2*p) == 11
    assert primitive_root(p**2) == 11
    raises(ValueError, lambda: primitive_root(-3))

    assert is_quad_residue(3, 7) is False
    assert is_quad_residue(10, 13) is True
    assert is_quad_residue(12364, 139) == is_quad_residue(12364 % 139, 139)
    assert is_quad_residue(207, 251) is True
    assert is_quad_residue(0, 1) is True
    assert is_quad_residue(1, 1) is True
    assert is_quad_residue(0, 2) == is_quad_residue(1, 2) is True
    assert is_quad_residue(1, 4) is True
    assert is_quad_residue(2, 27) is False
    assert is_quad_residue(13122380800, 13604889600) is True
    assert [j for j in range(14) if is_quad_residue(j, 14)] == \
           [0, 1, 2, 4, 7, 8, 9, 11]
    raises(ValueError, lambda: is_quad_residue(1.1, 2))
    raises(ValueError, lambda: is_quad_residue(2, 0))


    assert quadratic_residues(12) == [0, 1, 4, 9]
    assert quadratic_residues(13) == [0, 1, 3, 4, 9, 10, 12]
    assert [len(quadratic_residues(i)) for i in range(1, 20)] == \
      [1, 2, 2, 2, 3, 4, 4, 3, 4, 6, 6, 4, 7, 8, 6, 4, 9, 8, 10]

    assert list(sqrt_mod_iter(6, 2)) == [0]
    assert sqrt_mod(3, 13) == 4
    assert sqrt_mod(3, -13) == 4
    assert sqrt_mod(6, 23) == 11
    assert sqrt_mod(345, 690) == 345

    for p in range(3, 100):
        d = defaultdict(list)
        for i in range(p):
            d[pow(i, 2, p)].append(i)
        for i in range(1, p):
            it = sqrt_mod_iter(i, p)
            v = sqrt_mod(i, p, True)
            if v:
                v = sorted(v)
                assert d[i] == v
            else:
                assert not d[i]

    assert sqrt_mod(9, 27, True) == [3, 6, 12, 15, 21, 24]
    assert sqrt_mod(9, 81, True) == [3, 24, 30, 51, 57, 78]
    assert sqrt_mod(9, 3**5, True) == [3, 78, 84, 159, 165, 240]
    assert sqrt_mod(81, 3**4, True) == [0, 9, 18, 27, 36, 45, 54, 63, 72]
    assert sqrt_mod(81, 3**5, True) == [9, 18, 36, 45, 63, 72, 90, 99, 117,\
            126, 144, 153, 171, 180, 198, 207, 225, 234]
    assert sqrt_mod(81, 3**6, True) == [9, 72, 90, 153, 171, 234, 252, 315,\
            333, 396, 414, 477, 495, 558, 576, 639, 657, 720]
    assert sqrt_mod(81, 3**7, True) == [9, 234, 252, 477, 495, 720, 738, 963,\
            981, 1206, 1224, 1449, 1467, 1692, 1710, 1935, 1953, 2178]

    for a, p in [(26214400, 32768000000), (26214400, 16384000000),
        (262144, 1048576), (87169610025, 163443018796875),
        (22315420166400, 167365651248000000)]:
        assert pow(sqrt_mod(a, p), 2, p) == a

    n = 70
    a, p = 5**2*3**n*2**n, 5**6*3**(n+1)*2**(n+2)
    it = sqrt_mod_iter(a, p)
    for i in range(10):
        assert pow(next(it), 2, p) == a
    a, p = 5**2*3**n*2**n, 5**6*3**(n+1)*2**(n+3)
    it = sqrt_mod_iter(a, p)
    for i in range(2):
        assert pow(next(it), 2, p) == a
    n = 100
    a, p = 5**2*3**n*2**n, 5**6*3**(n+1)*2**(n+1)
    it = sqrt_mod_iter(a, p)
    for i in range(2):
        assert pow(next(it), 2, p) == a

    assert type(next(sqrt_mod_iter(9, 27))) is int
    assert type(next(sqrt_mod_iter(9, 27, ZZ))) is type(ZZ(1))
    assert type(next(sqrt_mod_iter(1, 7, ZZ))) is type(ZZ(1))

    assert is_nthpow_residue(2, 1, 5)

    #issue 10816
    assert is_nthpow_residue(1, 0, 1) is False
    assert is_nthpow_residue(1, 0, 2) is True
    assert is_nthpow_residue(3, 0, 2) is False
    assert is_nthpow_residue(0, 1, 8) is True
    assert is_nthpow_residue(2, 3, 2) is False
    assert is_nthpow_residue(2, 3, 9) is False
    assert is_nthpow_residue(3, 5, 30) is True
    assert is_nthpow_residue(21, 11, 20) is True
    assert is_nthpow_residue(7, 10, 20) is False
    assert is_nthpow_residue(5, 10, 20) is True
    assert is_nthpow_residue(3, 10, 48) is False
    assert is_nthpow_residue(1, 10, 40) is True
    assert is_nthpow_residue(3, 10, 24) is False
    assert is_nthpow_residue(1, 10, 24) is True
    assert is_nthpow_residue(3, 10, 24) is False
    assert is_nthpow_residue(2, 10, 48) is False
    assert is_nthpow_residue(81, 3, 972) is False
    assert is_nthpow_residue(243, 5, 5103) is True
    assert is_nthpow_residue(243, 3, 1240029) is False
    x = set([pow(i, 56, 1024) for i in range(1024)])
    assert set([a for a in range(1024) if is_nthpow_residue(a, 56, 1024)]) == x
    x = set([ pow(i, 256, 2048) for i in range(2048)])
    assert set([a for a in range(2048) if is_nthpow_residue(a, 256, 2048)]) == x
    x = set([ pow(i, 11, 324000) for i in range(1000)])
    assert [ is_nthpow_residue(a, 11, 324000) for a in x]
    x = set([ pow(i, 17, 22217575536) for i in range(1000)])
    assert [ is_nthpow_residue(a, 17, 22217575536) for a in x]
    assert is_nthpow_residue(676, 3, 5364)
    assert is_nthpow_residue(9, 12, 36)
    assert is_nthpow_residue(32, 10, 41)
    assert is_nthpow_residue(4, 2, 64)
    assert is_nthpow_residue(31, 4, 41)
    assert not is_nthpow_residue(2, 2, 5)
    assert is_nthpow_residue(8547, 12, 10007)
    assert nthroot_mod(1801, 11, 2663) == 44
    for a, q, p in [(51922, 2, 203017), (43, 3, 109), (1801, 11, 2663),
          (26118163, 1303, 33333347), (1499, 7, 2663), (595, 6, 2663),
          (1714, 12, 2663), (28477, 9, 33343)]:
        r = nthroot_mod(a, q, p)
        assert pow(r, q, p) == a
    assert nthroot_mod(11, 3, 109) is None
    raises(NotImplementedError, lambda: nthroot_mod(16, 5, 36))
    raises(NotImplementedError, lambda: nthroot_mod(9, 16, 36))

    for p in primerange(5, 100):
        qv = range(3, p, 4)
        for q in qv:
            d = defaultdict(list)
            for i in range(p):
                d[pow(i, q, p)].append(i)
            for a in range(1, p - 1):
                res = nthroot_mod(a, q, p, True)
                if d[a]:
                    assert d[a] == res
                else:
                    assert res is None

    assert legendre_symbol(5, 11) == 1
    assert legendre_symbol(25, 41) == 1
    assert legendre_symbol(67, 101) == -1
    assert legendre_symbol(0, 13) == 0
    assert legendre_symbol(9, 3) == 0
    raises(ValueError, lambda: legendre_symbol(2, 4))

    assert jacobi_symbol(25, 41) == 1
    assert jacobi_symbol(-23, 83) == -1
    assert jacobi_symbol(3, 9) == 0
    assert jacobi_symbol(42, 97) == -1
    assert jacobi_symbol(3, 5) == -1
    assert jacobi_symbol(7, 9) == 1
    assert jacobi_symbol(0, 3) == 0
    assert jacobi_symbol(0, 1) == 1
    assert jacobi_symbol(2, 1) == 1
    assert jacobi_symbol(1, 3) == 1
    raises(ValueError, lambda: jacobi_symbol(3, 8))

    assert mobius(13*7) == 1
    assert mobius(1) == 1
    assert mobius(13*7*5) == -1
    assert mobius(13**2) == 0
    raises(ValueError, lambda: mobius(-3))

    p = Symbol('p', integer=True, positive=True, prime=True)
    x = Symbol('x', positive=True)
    i = Symbol('i', integer=True)
    assert mobius(p) == -1
    raises(TypeError, lambda: mobius(x))
    raises(ValueError, lambda: mobius(i))

    assert _discrete_log_trial_mul(587, 2**7, 2) == 7
    assert _discrete_log_trial_mul(941, 7**18, 7) == 18
    assert _discrete_log_trial_mul(389, 3**81, 3) == 81
    assert _discrete_log_trial_mul(191, 19**123, 19) == 123
    assert _discrete_log_shanks_steps(442879, 7**2, 7) == 2
    assert _discrete_log_shanks_steps(874323, 5**19, 5) == 19
    assert _discrete_log_shanks_steps(6876342, 7**71, 7) == 71
    assert _discrete_log_shanks_steps(2456747, 3**321, 3) == 321
    assert _discrete_log_pollard_rho(6013199, 2**6, 2, rseed=0) == 6
    assert _discrete_log_pollard_rho(6138719, 2**19, 2, rseed=0) == 19
    assert _discrete_log_pollard_rho(36721943, 2**40, 2, rseed=0) == 40
    assert _discrete_log_pollard_rho(24567899, 3**333, 3, rseed=0) == 333
    assert _discrete_log_pohlig_hellman(98376431, 11**9, 11) == 9
    assert _discrete_log_pohlig_hellman(78723213, 11**31, 11) == 31
    assert _discrete_log_pohlig_hellman(32942478, 11**98, 11) == 98
    assert _discrete_log_pohlig_hellman(14789363, 11**444, 11) == 444
    assert discrete_log(587, 2**9, 2) == 9
    assert discrete_log(2456747, 3**51, 3) == 51
    assert discrete_log(32942478, 11**127, 11) == 127
    assert discrete_log(432751500361, 7**324, 7) == 324
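
In the loops above, defaultdict(list) builds a brute-force inverse table of modular squares (and later of q-th powers) against which sqrt_mod and nthroot_mod are checked; a residue with no roots simply yields an empty list instead of raising a KeyError. A small standalone version of that table for p = 13:

from collections import defaultdict

p = 13
d = defaultdict(list)  # residue -> every i with i**2 == residue (mod p)
for i in range(p):
    d[pow(i, 2, p)].append(i)

assert d[3] == [4, 9]  # 4**2 == 9**2 == 3 (mod 13)
assert d[5] == []      # 5 is a non-residue mod 13; the lookup just returns []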

Example 41

Project: sympy
Source File: powsimp.py
View license
def powsimp(expr, deep=False, combine='all', force=False, measure=count_ops):
    """
    Reduces an expression by combining powers with similar bases and exponents.

    Notes
    =====

    If deep is True then powsimp() will also simplify arguments of
    functions. By default deep is set to False.

    If force is True then bases will be combined without checking for
    assumptions, e.g. sqrt(x)*sqrt(y) -> sqrt(x*y) which is not true
    if x and y are both negative.

    You can make powsimp() only combine bases or only combine exponents by
    changing combine='base' or combine='exp'.  By default, combine='all',
    which does both.  combine='base' will only combine::

         a   a          a                          2x      x
        x * y  =>  (x*y)   as well as things like 2   =>  4

    and combine='exp' will only combine
    ::

         a   b      (a + b)
        x * x  =>  x

    combine='exp' will strictly only combine exponents in the way that used
    to be automatic.  Also use deep=True if you need the old behavior.

    When combine='all', 'exp' is evaluated first.  Consider the first
    example below for when there could be an ambiguity relating to this.
    This is done so things like the second example can be completely
    combined.  If you want 'base' combined first, do something like
    powsimp(powsimp(expr, combine='base'), combine='exp').

    Examples
    ========

    >>> from sympy import powsimp, exp, log, symbols
    >>> from sympy.abc import x, y, z, n
    >>> powsimp(x**y*x**z*y**z, combine='all')
    x**(y + z)*y**z
    >>> powsimp(x**y*x**z*y**z, combine='exp')
    x**(y + z)*y**z
    >>> powsimp(x**y*x**z*y**z, combine='base', force=True)
    x**y*(x*y)**z

    >>> powsimp(x**z*x**y*n**z*n**y, combine='all', force=True)
    (n*x)**(y + z)
    >>> powsimp(x**z*x**y*n**z*n**y, combine='exp')
    n**(y + z)*x**(y + z)
    >>> powsimp(x**z*x**y*n**z*n**y, combine='base', force=True)
    (n*x)**y*(n*x)**z

    >>> x, y = symbols('x y', positive=True)
    >>> powsimp(log(exp(x)*exp(y)))
    log(exp(x)*exp(y))
    >>> powsimp(log(exp(x)*exp(y)), deep=True)
    x + y

    Radicals with Mul bases will be combined if combine='exp'

    >>> from sympy import sqrt, Mul
    >>> x, y = symbols('x y')

    Two radicals are automatically joined through Mul:

    >>> a=sqrt(x*sqrt(y))
    >>> a*a**3 == a**4
    True

    But if an integer power of that radical has been
    autoexpanded then Mul does not join the resulting factors:

    >>> a**4 # auto expands to a Mul, no longer a Pow
    x**2*y
    >>> _*a # so Mul doesn't combine them
    x**2*y*sqrt(x*sqrt(y))
    >>> powsimp(_) # but powsimp will
    (x*sqrt(y))**(5/2)
    >>> powsimp(x*y*a) # but won't when doing so would violate assumptions
    x*y*sqrt(x*sqrt(y))

    """
    from sympy.matrices.expressions.matexpr import MatrixSymbol

    def recurse(arg, **kwargs):
        _deep = kwargs.get('deep', deep)
        _combine = kwargs.get('combine', combine)
        _force = kwargs.get('force', force)
        _measure = kwargs.get('measure', measure)
        return powsimp(arg, _deep, _combine, _force, _measure)

    expr = sympify(expr)

    if (not isinstance(expr, Basic) or isinstance(expr, MatrixSymbol) or (
            expr.is_Atom or expr in (exp_polar(0), exp_polar(1)))):
        return expr

    if deep or expr.is_Add or expr.is_Mul and _y not in expr.args:
        expr = expr.func(*[recurse(w) for w in expr.args])

    if expr.is_Pow:
        return recurse(expr*_y, deep=False)/_y

    if not expr.is_Mul:
        return expr

    # handle the Mul
    if combine in ('exp', 'all'):
        # Collect base/exp data, while maintaining order in the
        # non-commutative parts of the product
        c_powers = defaultdict(list)
        nc_part = []
        newexpr = []
        coeff = S.One
        for term in expr.args:
            if term.is_Rational:
                coeff *= term
                continue
            if term.is_Pow:
                term = _denest_pow(term)
            if term.is_commutative:
                b, e = term.as_base_exp()
                if deep:
                    b, e = [recurse(i) for i in [b, e]]
                if b.is_Pow or b.func is exp:
                    # don't let something like sqrt(x**a) split into x**a, 1/2
                    # or else it will be joined as x**(a/2) later
                    b, e = b**e, S.One
                c_powers[b].append(e)
            else:
                # This is the logic that combines exponents for equal,
                # but non-commutative bases: A**x*A**y == A**(x+y).
                if nc_part:
                    b1, e1 = nc_part[-1].as_base_exp()
                    b2, e2 = term.as_base_exp()
                    if (b1 == b2 and
                            e1.is_commutative and e2.is_commutative):
                        nc_part[-1] = Pow(b1, Add(e1, e2))
                        continue
                nc_part.append(term)

        # add up exponents of common bases
        for b, e in ordered(iter(c_powers.items())):
            # allow 2**x/4 -> 2**(x - 2); don't do this when b and e are
            # Numbers since autoevaluation will undo it, e.g.
            # 2**(1/3)/4 -> 2**(1/3 - 2) -> 2**(1/3)/4
            if (b and b.is_Number and not all(ei.is_Number for ei in e) and \
                    coeff is not S.One and
                    b not in (S.One, S.NegativeOne)):
                m = multiplicity(abs(b), abs(coeff))
                if m:
                    e.append(m)
                    coeff /= b**m
            c_powers[b] = Add(*e)
        if coeff is not S.One:
            if coeff in c_powers:
                c_powers[coeff] += S.One
            else:
                c_powers[coeff] = S.One

        # convert to plain dictionary
        c_powers = dict(c_powers)

        # check for base and inverted base pairs
        be = list(c_powers.items())
        skip = set()  # skip if we already saw them
        for b, e in be:
            if b in skip:
                continue
            bpos = b.is_positive or b.is_polar
            if bpos:
                binv = 1/b
                if b != binv and binv in c_powers:
                    if b.as_numer_denom()[0] is S.One:
                        c_powers.pop(b)
                        c_powers[binv] -= e
                    else:
                        skip.add(binv)
                        e = c_powers.pop(binv)
                        c_powers[b] -= e

        # check for base and negated base pairs
        be = list(c_powers.items())
        _n = S.NegativeOne
        for i, (b, e) in enumerate(be):
            if ((-b).is_Symbol or b.is_Add) and -b in c_powers:
                if (b.is_positive in (0, 1) or e.is_integer):
                    c_powers[-b] += c_powers.pop(b)
                    if _n in c_powers:
                        c_powers[_n] += e
                    else:
                        c_powers[_n] = e

        # filter c_powers and convert to a list
        c_powers = [(b, e) for b, e in c_powers.items() if e]

        # ==============================================================
        # check for Mul bases of Rational powers that can be combined with
        # separated bases, e.g. x*sqrt(x*y)*sqrt(x*sqrt(x*y)) ->
        # (x*sqrt(x*y))**(3/2)
        # ---------------- helper functions

        def ratq(x):
            '''Return Rational part of x's exponent as it appears in the bkey.
            '''
            return bkey(x)[0][1]

        def bkey(b, e=None):
            '''Return (b**s, c.q), c.p where e -> c*s. If e is not given then
            it will be taken by using as_base_exp() on the input b.
            e.g.
                x**3/2 -> (x, 2), 3
                x**y -> (x**y, 1), 1
                x**(2*y/3) -> (x**y, 3), 2
                exp(x/2) -> (exp(a), 2), 1

            '''
            if e is not None:  # coming from c_powers or from below
                if e.is_Integer:
                    return (b, S.One), e
                elif e.is_Rational:
                    return (b, Integer(e.q)), Integer(e.p)
                else:
                    c, m = e.as_coeff_Mul(rational=True)
                    if c is not S.One:
                        if m.is_integer:
                            return (b, Integer(c.q)), m*Integer(c.p)
                        return (b**m, Integer(c.q)), Integer(c.p)
                    else:
                        return (b**e, S.One), S.One
            else:
                return bkey(*b.as_base_exp())

        def update(b):
            '''Decide what to do with base, b. If its exponent is now an
            integer multiple of the Rational denominator, then remove it
            and put the factors of its base in the common_b dictionary or
            update the existing bases if necessary. If it has been zeroed
            out, simply remove the base.
            '''
            newe, r = divmod(common_b[b], b[1])
            if not r:
                common_b.pop(b)
                if newe:
                    for m in Mul.make_args(b[0]**newe):
                        b, e = bkey(m)
                        if b not in common_b:
                            common_b[b] = 0
                        common_b[b] += e
                        if b[1] != 1:
                            bases.append(b)
        # ---------------- end of helper functions

        # assemble a dictionary of the factors having a Rational power
        common_b = {}
        done = []
        bases = []
        for b, e in c_powers:
            b, e = bkey(b, e)
            if b in common_b.keys():
                common_b[b] = common_b[b] + e
            else:
                common_b[b] = e
            if b[1] != 1 and b[0].is_Mul:
                bases.append(b)
        c_powers = [(b, e) for b, e in common_b.items() if e]
        bases.sort(key=default_sort_key)  # this makes tie-breaking canonical
        bases.sort(key=measure, reverse=True)  # handle longest first
        for base in bases:
            if base not in common_b:  # it may have been removed already
                continue
            b, exponent = base
            last = False  # True when no factor of base is a radical
            qlcm = 1  # the lcm of the radical denominators
            while True:
                bstart = b
                qstart = qlcm

                bb = []  # list of factors
                ee = []  # (factor's exponent and its current value in common_b)
                for bi in Mul.make_args(b):
                    bib, bie = bkey(bi)
                    if bib not in common_b or common_b[bib] < bie:
                        ee = bb = []  # failed
                        break
                    ee.append([bie, common_b[bib]])
                    bb.append(bib)
                if ee:
                    # find the number of extractions possible
                    # e.g. [(1, 2), (2, 2)] -> min(2/1, 2/2) -> 1
                    min1 = ee[0][1]/ee[0][0]
                    for i in range(len(ee)):
                        rat = ee[i][1]/ee[i][0]
                        if rat < 1:
                            break
                        min1 = min(min1, rat)
                    else:
                        # update base factor counts
                        # e.g. if ee = [(2, 5), (3, 6)] then min1 = 2
                        # and the new base counts will be 5-2*2 and 6-2*3
                        for i in range(len(bb)):
                            common_b[bb[i]] -= min1*ee[i][0]
                            update(bb[i])
                        # update the count of the base
                        # e.g. x**2*y*sqrt(x*sqrt(y)) the count of x*sqrt(y)
                        # will increase by 4 to give bkey (x*sqrt(y), 2, 5)
                        common_b[base] += min1*qstart*exponent
                if (last  # no more radicals in base
                    or len(common_b) == 1  # nothing left to join with
                    or all(k[1] == 1 for k in common_b)  # no rad's in common_b
                        ):
                    break
                # see what we can exponentiate base by to remove any radicals
                # so we know what to search for
                # e.g. if base were x**(1/2)*y**(1/3) then we should
                # exponentiate by 6 and look for powers of x and y in the ratio
                # of 2 to 3
                qlcm = lcm([ratq(bi) for bi in Mul.make_args(bstart)])
                if qlcm == 1:
                    break  # we are done
                b = bstart**qlcm
                qlcm *= qstart
                if all(ratq(bi) == 1 for bi in Mul.make_args(b)):
                    last = True  # we are going to be done after this next pass
            # this base no longer can find anything to join with and
            # since it was longer than any other we are done with it
            b, q = base
            done.append((b, common_b.pop(base)*Rational(1, q)))

        # update c_powers and get ready to continue with powsimp
        c_powers = done
        # there may be terms still in common_b that were bases that were
        # identified as needing processing, so remove those, too
        for (b, q), e in common_b.items():
            if (b.is_Pow or b.func is exp) and \
                    q is not S.One and not b.exp.is_Rational:
                b, be = b.as_base_exp()
                b = b**(be/q)
            else:
                b = root(b, q)
            c_powers.append((b, e))
        check = len(c_powers)
        c_powers = dict(c_powers)
        assert len(c_powers) == check  # there should have been no duplicates
        # ==============================================================

        # rebuild the expression
        newexpr = expr.func(*(newexpr + [Pow(b, e) for b, e in c_powers.items()]))
        if combine == 'exp':
            return expr.func(newexpr, expr.func(*nc_part))
        else:
            return recurse(expr.func(*nc_part), combine='base') * \
                recurse(newexpr, combine='base')

    elif combine == 'base':

        # Build c_powers and nc_part.  These must both be lists not
        # dicts because exp's are not combined.
        c_powers = []
        nc_part = []
        for term in expr.args:
            if term.is_commutative:
                c_powers.append(list(term.as_base_exp()))
            else:
                # This is the logic that combines bases that are
                # different and non-commutative, but with equal and
                # commutative exponents: A**x*B**x == (A*B)**x.
                if nc_part:
                    b1, e1 = nc_part[-1].as_base_exp()
                    b2, e2 = term.as_base_exp()
                    if (e1 == e2 and e2.is_commutative):
                        nc_part[-1] = Pow(b1*b2, e1)
                        continue
                nc_part.append(term)

        # Pull out numerical coefficients from exponent if assumptions allow
        # e.g., 2**(2*x) => 4**x
        for i in range(len(c_powers)):
            b, e = c_powers[i]
            if not (all(x.is_nonnegative for x in b.as_numer_denom()) or e.is_integer or force or b.is_polar):
                continue
            exp_c, exp_t = e.as_coeff_Mul(rational=True)
            if exp_c is not S.One and exp_t is not S.One:
                c_powers[i] = [Pow(b, exp_c), exp_t]

        # Combine bases whenever they have the same exponent and
        # assumptions allow
        # first gather the potential bases under the common exponent
        c_exp = defaultdict(list)
        for b, e in c_powers:
            if deep:
                e = recurse(e)
            c_exp[e].append(b)
        del c_powers

        # Merge back in the results of the above to form a new product
        c_powers = defaultdict(list)
        for e in c_exp:
            bases = c_exp[e]

            # calculate the new base for e

            if len(bases) == 1:
                new_base = bases[0]
            elif e.is_integer or force:
                new_base = expr.func(*bases)
            else:
                # see which ones can be joined
                unk = []
                nonneg = []
                neg = []
                for bi in bases:
                    if bi.is_negative:
                        neg.append(bi)
                    elif bi.is_nonnegative:
                        nonneg.append(bi)
                    elif bi.is_polar:
                        nonneg.append(
                            bi)  # polar can be treated like non-negative
                    else:
                        unk.append(bi)
                if len(unk) == 1 and not neg or len(neg) == 1 and not unk:
                    # a single neg or a single unk can join the rest
                    nonneg.extend(unk + neg)
                    unk = neg = []
                elif neg:
                    # their negative signs cancel in groups of 2*q if we know
                    # that e = p/q else we have to treat them as unknown
                    israt = False
                    if e.is_Rational:
                        israt = True
                    else:
                        p, d = e.as_numer_denom()
                        if p.is_integer and d.is_integer:
                            israt = True
                    if israt:
                        neg = [-w for w in neg]
                        unk.extend([S.NegativeOne]*len(neg))
                    else:
                        unk.extend(neg)
                        neg = []
                    del israt

                # these shouldn't be joined
                for b in unk:
                    c_powers[b].append(e)
                # here is a new joined base
                new_base = expr.func(*(nonneg + neg))
                # if there are positive parts they will just get separated
                # again unless some change is made

                def _terms(e):
                    # return the number of terms of this expression
                    # when multiplied out -- assuming no joining of terms
                    if e.is_Add:
                        return sum([_terms(ai) for ai in e.args])
                    if e.is_Mul:
                        return prod([_terms(mi) for mi in e.args])
                    return 1
                xnew_base = expand_mul(new_base, deep=False)
                if len(Add.make_args(xnew_base)) < _terms(new_base):
                    new_base = factor_terms(xnew_base)

            c_powers[new_base].append(e)

        # break out the powers from c_powers now
        c_part = [Pow(b, ei) for b, e in c_powers.items() for ei in e]

        # we're done
        return expr.func(*(c_part + nc_part))

    else:
        raise ValueError("combine must be one of ('all', 'exp', 'base').")
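
Both the combine='exp' and combine='base' branches above lean on defaultdict(list): exponents are appended per base (or bases per exponent) and only combined once the whole product has been scanned. A toy, SymPy-free sketch of that grouping step, with plain strings standing in for symbolic bases:

from collections import defaultdict

# factors of a product written as (base, exponent) pairs, e.g. x**2*y**3*x**5
factors = [('x', 2), ('y', 3), ('x', 5)]

c_powers = defaultdict(list)  # base -> list of exponents seen so far
for base, exp in factors:
    c_powers[base].append(exp)

# add up the exponents of each base: x**2*y**3*x**5 -> x**7*y**3
combined = dict((base, sum(exps)) for base, exps in c_powers.items())
assert combined == {'x': 7, 'y': 3}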

Example 42

Project: sympy
Source File: radsimp.py
View license
def collect(expr, syms, func=None, evaluate=None, exact=False, distribute_order_term=True):
    """
    Collect additive terms of an expression.

    This function collects additive terms of an expression with respect
    to a list of expressions, up to powers with rational exponents. By
    "symbol" here we mean an arbitrary expression, which can contain
    powers, products, sums, etc. In other words, a symbol is a pattern
    that is searched for in the expression's terms.

    The input expression is not expanded by :func:`collect`, so the user is
    expected to provide an expression in an appropriate form. This makes
    :func:`collect` more predictable, as there is no magic happening behind
    the scenes. However, it is important to note that powers of products are
    converted to products of powers using the :func:`expand_power_base`
    function.

    There are two possible types of output. First, if the ``evaluate`` flag
    is set, this function returns an expression with collected terms;
    otherwise it returns a dictionary with expressions up to rational powers
    as keys and the collected coefficients as values.

    Examples
    ========

    >>> from sympy import S, collect, expand, factor, Wild
    >>> from sympy.abc import a, b, c, x, y, z

    This function can collect symbolic coefficients in polynomials or
    rational expressions. It will manage to find all integer or rational
    powers of collection variable::

        >>> collect(a*x**2 + b*x**2 + a*x - b*x + c, x)
        c + x**2*(a + b) + x*(a - b)

    The same result can be achieved in dictionary form::

        >>> d = collect(a*x**2 + b*x**2 + a*x - b*x + c, x, evaluate=False)
        >>> d[x**2]
        a + b
        >>> d[x]
        a - b
        >>> d[S.One]
        c

    You can also work with multivariate polynomials. However, remember that
    this function is greedy, so it will care only about a single symbol at a
    time, in specification order::

        >>> collect(x**2 + y*x**2 + x*y + y + a*y, [x, y])
        x**2*(y + 1) + x*y + y*(a + 1)

    Also more complicated expressions can be used as patterns::

        >>> from sympy import sin, log
        >>> collect(a*sin(2*x) + b*sin(2*x), sin(2*x))
        (a + b)*sin(2*x)

        >>> collect(a*x*log(x) + b*(x*log(x)), x*log(x))
        x*(a + b)*log(x)

    You can use wildcards in the pattern::

        >>> w = Wild('w1')
        >>> collect(a*x**y - b*x**y, w**y)
        x**y*(a - b)

    It is also possible to work with symbolic powers, although it has more
    complicated behavior, because in this case the power's base and the
    symbolic part of the exponent are treated as a single symbol::

        >>> collect(a*x**c + b*x**c, x)
        a*x**c + b*x**c
        >>> collect(a*x**c + b*x**c, x**c)
        x**c*(a + b)

    However, if you incorporate rationals into the exponents, then you will get
    the well-known behavior::

        >>> collect(a*x**(2*c) + b*x**(2*c), x**c)
        x**(2*c)*(a + b)

    Note also that all previously stated facts about :func:`collect` function
    apply to the exponential function, so you can get::

        >>> from sympy import exp
        >>> collect(a*exp(2*x) + b*exp(2*x), exp(x))
        (a + b)*exp(2*x)

    If you are interested only in collecting specific powers of some symbols,
    then set the ``exact`` flag in the arguments::

        >>> collect(a*x**7 + b*x**7, x, exact=True)
        a*x**7 + b*x**7
        >>> collect(a*x**7 + b*x**7, x**7, exact=True)
        x**7*(a + b)

    You can also apply this function to differential equations, where
    derivatives of arbitrary order can be collected. Note that if you
    collect with respect to a function or a derivative of a function, all
    derivatives of that function will also be collected. Use
    ``exact=True`` to prevent this from happening::

        >>> from sympy import Derivative as D, collect, Function
        >>> f = Function('f')(x)

        >>> collect(a*D(f,x) + b*D(f,x), D(f,x))
        (a + b)*Derivative(f(x), x)

        >>> collect(a*D(D(f,x),x) + b*D(D(f,x),x), f)
        (a + b)*Derivative(f(x), x, x)

        >>> collect(a*D(D(f,x),x) + b*D(D(f,x),x), D(f,x), exact=True)
        a*Derivative(f(x), x, x) + b*Derivative(f(x), x, x)

        >>> collect(a*D(f,x) + b*D(f,x) + a*f + b*f, f)
        (a + b)*f(x) + (a + b)*Derivative(f(x), x)

    Or you can even match both derivative order and exponent at the same time::

        >>> collect(a*D(D(f,x),x)**2 + b*D(D(f,x),x)**2, D(f,x))
        (a + b)*Derivative(f(x), x, x)**2

    Finally, you can apply a function to each of the collected coefficients.
    For example, you can factor the symbolic coefficients of a polynomial::

        >>> f = expand((x + a + 1)**3)

        >>> collect(f, x, factor)
        x**3 + 3*x**2*(a + 1) + 3*x*(a + 1)**2 + (a + 1)**3

    .. note:: Arguments are expected to be in expanded form, so you might have
              to call :func:`expand` prior to calling this function.

    See Also
    ========
    collect_const, collect_sqrt, rcollect
    """
    if evaluate is None:
        evaluate = global_evaluate[0]

    def make_expression(terms):
        product = []

        for term, rat, sym, deriv in terms:
            if deriv is not None:
                var, order = deriv

                while order > 0:
                    term, order = Derivative(term, var), order - 1

            if sym is None:
                if rat is S.One:
                    product.append(term)
                else:
                    product.append(Pow(term, rat))
            else:
                product.append(Pow(term, rat*sym))

        return Mul(*product)

    def parse_derivative(deriv):
        # scan the derivative tower in the input expression and return
        # the underlying function and the maximal differentiation order
        expr, sym, order = deriv.expr, deriv.variables[0], 1

        for s in deriv.variables[1:]:
            if s == sym:
                order += 1
            else:
                raise NotImplementedError(
                    'Improve MV Derivative support in collect')

        while isinstance(expr, Derivative):
            s0 = expr.variables[0]

            for s in expr.variables:
                if s != s0:
                    raise NotImplementedError(
                        'Improve MV Derivative support in collect')

            if s0 == sym:
                expr, order = expr.expr, order + len(expr.variables)
            else:
                break

        return expr, (sym, Rational(order))

    def parse_term(expr):
        """Parses expression expr and outputs tuple (sexpr, rat_expo,
        sym_expo, deriv)
        where:
         - sexpr is the base expression
         - rat_expo is the rational exponent that sexpr is raised to
         - sym_expo is the symbolic exponent that sexpr is raised to
         - deriv contains the derivatives of the expression

         for example, the output of x would be (x, 1, None, None)
         the output of 2**x would be (2, 1, x, None)
        """
        rat_expo, sym_expo = S.One, None
        sexpr, deriv = expr, None

        if expr.is_Pow:
            if isinstance(expr.base, Derivative):
                sexpr, deriv = parse_derivative(expr.base)
            else:
                sexpr = expr.base

            if expr.exp.is_Number:
                rat_expo = expr.exp
            else:
                coeff, tail = expr.exp.as_coeff_Mul()

                if coeff.is_Number:
                    rat_expo, sym_expo = coeff, tail
                else:
                    sym_expo = expr.exp
        elif expr.func is exp:
            arg = expr.args[0]
            if arg.is_Rational:
                sexpr, rat_expo = S.Exp1, arg
            elif arg.is_Mul:
                coeff, tail = arg.as_coeff_Mul(rational=True)
                sexpr, rat_expo = exp(tail), coeff
        elif isinstance(expr, Derivative):
            sexpr, deriv = parse_derivative(expr)

        return sexpr, rat_expo, sym_expo, deriv

    def parse_expression(terms, pattern):
        """Parse terms searching for a pattern.
        terms is a list of tuples as returned by parse_term;
        pattern is an expression treated as a product of factors
        """
        pattern = Mul.make_args(pattern)

        if len(terms) < len(pattern):
            # pattern is longer than matched product
            # so no chance for positive parsing result
            return None
        else:
            pattern = [parse_term(elem) for elem in pattern]

            terms = terms[:]  # need a copy
            elems, common_expo, has_deriv = [], None, False

            for elem, e_rat, e_sym, e_ord in pattern:

                if elem.is_Number and e_rat == 1 and e_sym is None:
                    # a constant is a match for everything
                    continue

                for j in range(len(terms)):
                    if terms[j] is None:
                        continue

                    term, t_rat, t_sym, t_ord = terms[j]

                    # keeping track of whether one of the terms had
                    # a derivative or not as this will require rebuilding
                    # the expression later
                    if t_ord is not None:
                        has_deriv = True

                    if (term.match(elem) is not None and
                            (t_sym == e_sym or t_sym is not None and
                            e_sym is not None and
                            t_sym.match(e_sym) is not None)):
                        if exact is False:
                            # we don't have to be exact so find common exponent
                            # for both expression's term and pattern's element
                            expo = t_rat / e_rat

                            if common_expo is None:
                                # first time
                                common_expo = expo
                            else:
                                # common exponent was negotiated before so
                                # there is no chance for a pattern match unless
                                # common and current exponents are equal
                                if common_expo != expo:
                                    common_expo = 1
                        else:
                            # we ought to be exact so all fields of
                            # interest must match in every detail
                            if e_rat != t_rat or e_ord != t_ord:
                                continue

                        # found common term so remove it from the expression
                        # and try to match next element in the pattern
                        elems.append(terms[j])
                        terms[j] = None

                        break

                else:
                    # pattern element not found
                    return None

            return [_f for _f in terms if _f], elems, common_expo, has_deriv

    if evaluate:
        if expr.is_Mul:
            return expr.func(*[
                collect(term, syms, func, True, exact, distribute_order_term)
                for term in expr.args])
        elif expr.is_Pow:
            b = collect(
                expr.base, syms, func, True, exact, distribute_order_term)
            return Pow(b, expr.exp)

    if iterable(syms):
        syms = [expand_power_base(i, deep=False) for i in syms]
    else:
        syms = [expand_power_base(syms, deep=False)]

    expr = sympify(expr)
    order_term = None

    if distribute_order_term:
        order_term = expr.getO()

        if order_term is not None:
            if order_term.has(*syms):
                order_term = None
            else:
                expr = expr.removeO()

    summa = [expand_power_base(i, deep=False) for i in Add.make_args(expr)]

    collected, disliked = defaultdict(list), S.Zero
    for product in summa:
        terms = [parse_term(i) for i in Mul.make_args(product)]

        for symbol in syms:
            if SYMPY_DEBUG:
                print("DEBUG: parsing of expression %s with symbol %s " % (
                    str(terms), str(symbol))
                )

            result = parse_expression(terms, symbol)

            if SYMPY_DEBUG:
                print("DEBUG: returned %s" % str(result))

            if result is not None:
                terms, elems, common_expo, has_deriv = result

                # when there was a derivative in the current pattern we
                # will need to rebuild its expression from scratch
                if not has_deriv:
                    index = 1
                    for elem in elems:
                        e = elem[1]
                        if elem[2] is not None:
                            e *= elem[2]
                        index *= Pow(elem[0], e)
                else:
                    index = make_expression(elems)
                terms = expand_power_base(make_expression(terms), deep=False)
                index = expand_power_base(index, deep=False)
                collected[index].append(terms)
                break
        else:
            # none of the patterns matched
            disliked += product
    # add terms now for each key
    collected = {k: Add(*v) for k, v in collected.items()}

    if disliked is not S.Zero:
        collected[S.One] = disliked

    if order_term is not None:
        for key, val in collected.items():
            collected[key] = val + order_term

    if func is not None:
        collected = dict(
            [(key, func(val)) for key, val in collected.items()])

    if evaluate:
        return Add(*[key*val for key, val in collected.items()])
    else:
        return collected
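
The bookkeeping above hinges on collected = defaultdict(list): each matched pattern becomes a key, and the leftover factors of every matching term are appended to its list, which is then summed with Add(*v). Below is a minimal standalone sketch of that grouping idiom, using plain strings instead of SymPy objects (illustrative data only):

from collections import defaultdict

# terms of a pretend expression: (pattern key, leftover coefficient)
terms = [("x**2", "a"), ("x**2", "b"), ("x", "a"), ("x", "-b"), ("1", "c")]

collected = defaultdict(list)
for key, coeff in terms:
    collected[key].append(coeff)   # a missing key silently starts as []

# join the collected coefficients per key, mirroring Add(*v) above
result = {k: " + ".join(v) for k, v in collected.items()}
print(result)  # {'x**2': 'a + b', 'x': 'a + -b', '1': 'c'}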

Example 43

Project: sympy
Source File: radsimp.py
View license
def radsimp(expr, symbolic=True, max_terms=4):
    """
    Rationalize the denominator by removing square roots.

    Note: the expression returned from radsimp must be used with caution
    since if the denominator contains symbols, it will be possible to make
    substitutions that violate the assumptions of the simplification process:
    that for a denominator matching a + b*sqrt(c), a != +/-b*sqrt(c). (If
    there are no symbols, this assumption is made valid by collecting terms
    of sqrt(c) so the match variable ``a`` does not contain ``sqrt(c)``.) If
    you do not want the simplification to occur for symbolic denominators, set
    ``symbolic`` to False.

    If there are more than ``max_terms`` radical terms then the expression is
    returned unchanged.

    Examples
    ========

    >>> from sympy import radsimp, sqrt, Symbol, denom, pprint, I
    >>> from sympy import factor_terms, fraction, signsimp
    >>> from sympy.simplify.radsimp import collect_sqrt
    >>> from sympy.abc import a, b, c

    >>> radsimp(1/(I + 1))
    (1 - I)/2
    >>> radsimp(1/(2 + sqrt(2)))
    (-sqrt(2) + 2)/2
    >>> x,y = map(Symbol, 'xy')
    >>> e = ((2 + 2*sqrt(2))*x + (2 + sqrt(8))*y)/(2 + sqrt(2))
    >>> radsimp(e)
    sqrt(2)*(x + y)

    No simplification beyond removal of the gcd is done. One might
    want to polish the result a little, however, by collecting
    square root terms:

    >>> r2 = sqrt(2)
    >>> r5 = sqrt(5)
    >>> ans = radsimp(1/(y*r2 + x*r2 + a*r5 + b*r5)); pprint(ans)
        ___       ___       ___       ___
      \/ 5 *a + \/ 5 *b - \/ 2 *x - \/ 2 *y
    ------------------------------------------
       2               2      2              2
    5*a  + 10*a*b + 5*b  - 2*x  - 4*x*y - 2*y

    >>> n, d = fraction(ans)
    >>> pprint(factor_terms(signsimp(collect_sqrt(n))/d, radical=True))
            ___             ___
          \/ 5 *(a + b) - \/ 2 *(x + y)
    ------------------------------------------
       2               2      2              2
    5*a  + 10*a*b + 5*b  - 2*x  - 4*x*y - 2*y

    If radicals in the denominator cannot be removed or there is no denominator,
    the original expression will be returned.

    >>> radsimp(sqrt(2)*x + sqrt(2))
    sqrt(2)*x + sqrt(2)

    Results with symbols will not always be valid for all substitutions:

    >>> eq = 1/(a + b*sqrt(c))
    >>> eq.subs(a, b*sqrt(c))
    1/(2*b*sqrt(c))
    >>> radsimp(eq).subs(a, b*sqrt(c))
    nan

    If symbolic=False, symbolic denominators will not be transformed (but
    numeric denominators will still be processed):

    >>> radsimp(eq, symbolic=False)
    1/(a + b*sqrt(c))

    """
    from sympy.simplify.simplify import signsimp

    syms = symbols("a:d A:D")
    def _num(rterms):
        # return the multiplier that will simplify the expression described
        # by rterms [(sqrt arg, coeff), ... ]
        a, b, c, d, A, B, C, D = syms
        if len(rterms) == 2:
            reps = dict(list(zip([A, a, B, b], [j for i in rterms for j in i])))
            return (
            sqrt(A)*a - sqrt(B)*b).xreplace(reps)
        if len(rterms) == 3:
            reps = dict(list(zip([A, a, B, b, C, c], [j for i in rterms for j in i])))
            return (
            (sqrt(A)*a + sqrt(B)*b - sqrt(C)*c)*(2*sqrt(A)*sqrt(B)*a*b - A*a**2 -
            B*b**2 + C*c**2)).xreplace(reps)
        elif len(rterms) == 4:
            reps = dict(list(zip([A, a, B, b, C, c, D, d], [j for i in rterms for j in i])))
            return ((sqrt(A)*a + sqrt(B)*b - sqrt(C)*c - sqrt(D)*d)*(2*sqrt(A)*sqrt(B)*a*b
                - A*a**2 - B*b**2 - 2*sqrt(C)*sqrt(D)*c*d + C*c**2 +
                D*d**2)*(-8*sqrt(A)*sqrt(B)*sqrt(C)*sqrt(D)*a*b*c*d + A**2*a**4 -
                2*A*B*a**2*b**2 - 2*A*C*a**2*c**2 - 2*A*D*a**2*d**2 + B**2*b**4 -
                2*B*C*b**2*c**2 - 2*B*D*b**2*d**2 + C**2*c**4 - 2*C*D*c**2*d**2 +
                D**2*d**4)).xreplace(reps)
        elif len(rterms) == 1:
            return sqrt(rterms[0][0])
        else:
            raise NotImplementedError

    def ispow2(d, log2=False):
        if not d.is_Pow:
            return False
        e = d.exp
        if e.is_Rational and e.q == 2 or symbolic and fraction(e)[1] == 2:
            return True
        if log2:
            q = 1
            if e.is_Rational:
                q = e.q
            elif symbolic:
                d = fraction(e)[1]
                if d.is_Integer:
                    q = d
            if q != 1 and log(q, 2).is_Integer:
                return True
        return False

    def handle(expr):
        # handle() first reduces to the case
        # expr = 1/d, where d is an Add or d is base**(p/2).
        # We do this by recursively calling handle on each piece.
        from sympy.simplify.simplify import nsimplify

        n, d = fraction(expr)

        if expr.is_Atom or (d.is_Atom and n.is_Atom):
            return expr
        elif not n.is_Atom:
            n = n.func(*[handle(a) for a in n.args])
            return _unevaluated_Mul(n, handle(1/d))
        elif n is not S.One:
            return _unevaluated_Mul(n, handle(1/d))
        elif d.is_Mul:
            return _unevaluated_Mul(*[handle(1/d) for d in d.args])

        # By this step, expr is 1/d, and d is not a mul.
        if not symbolic and d.free_symbols:
            return expr

        if ispow2(d):
            d2 = sqrtdenest(sqrt(d.base))**fraction(d.exp)[0]
            if d2 != d:
                return handle(1/d2)
        elif d.is_Pow and (d.exp.is_integer or d.base.is_positive):
            # (1/d**i) = (1/d)**i
            return handle(1/d.base)**d.exp

        if not (d.is_Add or ispow2(d)):
            return 1/d.func(*[handle(a) for a in d.args])

        # handle 1/d treating d as an Add (though it may not be)

        keep = True  # keep changes that are made

        # flatten it and collect radicals after checking for special
        # conditions
        d = _mexpand(d)

        # did it change?
        if d.is_Atom:
            return 1/d

        # is it a number that might be handled easily?
        if d.is_number:
            _d = nsimplify(d)
            if _d.is_Number and _d.equals(d):
                return 1/_d

        while True:
            # collect similar terms
            collected = defaultdict(list)
            for m in Add.make_args(d):  # d might have become non-Add
                p2 = []
                other = []
                for i in Mul.make_args(m):
                    if ispow2(i, log2=True):
                        p2.append(i.base if i.exp is S.Half else i.base**(2*i.exp))
                    elif i is S.ImaginaryUnit:
                        p2.append(S.NegativeOne)
                    else:
                        other.append(i)
                collected[tuple(ordered(p2))].append(Mul(*other))
            rterms = list(ordered(list(collected.items())))
            rterms = [(Mul(*i), Add(*j)) for i, j in rterms]
            nrad = len(rterms) - (1 if rterms[0][0] is S.One else 0)
            if nrad < 1:
                break
            elif nrad > max_terms:
                # there may have been invalid operations leading to this point,
                # so don't keep the changes; e.g. this expression is troublesome
                # when collecting terms, as noted in issue 2834:
                # r = sqrt(sqrt(5) + 5)
                # eq = 1/(sqrt(5)*r + 2*sqrt(5)*sqrt(-sqrt(5) + 5) + 5*r)
                keep = False
                break
            if len(rterms) > 4:
                # in general, only 4 terms can be removed with repeated squaring
                # but other considerations can guide selection of radical terms
                # so that radicals are removed
                if all([x.is_Integer and (y**2).is_Rational for x, y in rterms]):
                    nd, d = rad_rationalize(S.One, Add._from_args(
                        [sqrt(x)*y for x, y in rterms]))
                    n *= nd
                else:
                    # is there anything else that might be attempted?
                    keep = False
                break
            from sympy.simplify.powsimp import powsimp, powdenest

            num = powsimp(_num(rterms))
            n *= num
            d *= num
            d = powdenest(_mexpand(d), force=symbolic)
            if d.is_Atom:
                break

        if not keep:
            return expr
        return _unevaluated_Mul(n, 1/d)

    coeff, expr = expr.as_coeff_Add()
    expr = expr.normal()
    old = fraction(expr)
    n, d = fraction(handle(expr))
    if old != (n, d):
        if not d.is_Atom:
            was = (n, d)
            n = signsimp(n, evaluate=False)
            d = signsimp(d, evaluate=False)
            u = Factors(_unevaluated_Mul(n, 1/d))
            u = _unevaluated_Mul(*[k**v for k, v in u.factors.items()])
            n, d = fraction(u)
            if old == (n, d):
                n, d = was
        n = expand_mul(n)
        if d.is_Number or d.is_Add:
            n2, d2 = fraction(gcd_terms(_unevaluated_Mul(n, 1/d)))
            if d2.is_Number or (d2.count_ops() <= d.count_ops()):
                n, d = [signsimp(i) for i in (n2, d2)]
                if n.is_Mul and n.args[0].is_Number:
                    n = n.func(*n.args)

    return coeff + _unevaluated_Mul(n, 1/d)
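
Inside handle(), radsimp groups the additive terms of the denominator with collected = defaultdict(list), keyed by the ordered tuple of square-root bases found in each term. A rough standalone sketch of that keyed-by-tuple grouping, with made-up numeric data in place of SymPy expressions:

from collections import defaultdict

# each term of a pretend denominator: (radical bases, rational coefficient)
terms = [((2,), 3), ((2,), 1), ((5,), 4), ((), 7)]

collected = defaultdict(list)
for radicals, coeff in terms:
    collected[tuple(sorted(radicals))].append(coeff)

# sum the coefficients that share the same radical part
rterms = [(rads, sum(coeffs)) for rads, coeffs in sorted(collected.items())]
print(rterms)  # [((), 7), ((2,), 4), ((5,), 4)]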

Example 44

Project: jcvi
Source File: blast.py
View license
def covfilter(args):
    """
    %prog covfilter blastfile fastafile

    Fastafile is used to get the sizes of the queries. Two filters can be
    applied: the id% and the cov%.
    """
    from jcvi.algorithms.supermap import supermap
    from jcvi.utils.range import range_union

    allowed_iterby = ("query", "query_sbjct")

    p = OptionParser(covfilter.__doc__)
    p.set_align(pctid=95, pctcov=50)
    p.add_option("--scov", default=False, action="store_true",
            help="Subject coverage instead of query [default: %default]")
    p.add_option("--supermap", action="store_true",
            help="Use supermap instead of union")
    p.add_option("--ids", dest="ids", default=None,
            help="Print out the ids that satisfy [default: %default]")
    p.add_option("--list", dest="list", default=False, action="store_true",
            help="List the id% and cov% per gene [default: %default]")
    p.add_option("--iterby", dest="iterby", default="query", choices=allowed_iterby,
            help="Choose how to iterate through BLAST [default: %default]")
    p.set_outfile(outfile=None)

    opts, args = p.parse_args(args)

    if len(args) != 2:
        sys.exit(not p.print_help())

    blastfile, fastafile = args
    pctid = opts.pctid
    pctcov = opts.pctcov
    union = not opts.supermap
    scov = opts.scov
    sz = Sizes(fastafile)
    sizes = sz.mapping
    iterby = opts.iterby
    qspair = iterby == "query_sbjct"

    if not union:
        querysupermap = blastfile + ".query.supermap"
        if not op.exists(querysupermap):
            supermap(blastfile, filter="query")

        blastfile = querysupermap

    assert op.exists(blastfile)

    covered = 0
    mismatches = 0
    gaps = 0
    alignlen = 0
    queries = set()
    valid = set()
    blast = BlastSlow(blastfile)
    iterator = blast.iter_hits_pair if qspair else blast.iter_hits

    covidstore = {}
    for query, blines in iterator():
        blines = list(blines)
        queries.add(query)

        # per gene report
        this_covered = 0
        this_alignlen = 0
        this_mismatches = 0
        this_gaps = 0
        this_identity = 0

        ranges = []
        for b in blines:
            if scov:
                s, start, stop = b.subject, b.sstart, b.sstop
            else:
                s, start, stop = b.query, b.qstart, b.qstop
            cov_id = s

            if b.pctid < pctid:
                continue

            if start > stop:
                start, stop = stop, start
            this_covered += stop - start + 1
            this_alignlen += b.hitlen
            this_mismatches += b.nmismatch
            this_gaps += b.ngaps
            ranges.append(("1", start, stop))

        if ranges:
            this_identity = 100. - (this_mismatches + this_gaps) * 100. / this_alignlen

        if union:
            this_covered = range_union(ranges)

        this_coverage = this_covered * 100. / sizes[cov_id]
        covidstore[query] = (this_identity, this_coverage)
        if this_identity >= pctid and this_coverage >= pctcov:
            valid.add(query)

        covered += this_covered
        mismatches += this_mismatches
        gaps += this_gaps
        alignlen += this_alignlen

    if opts.list:
        if qspair:
            allpairs = defaultdict(list)
            for (q, s) in covidstore:
                allpairs[q].append((q, s))
                allpairs[s].append((q, s))

            for id, size in sz.iter_sizes():
                if id not in allpairs:
                    print "\t".join((id, "na", "0", "0"))
                else:
                    for qs in allpairs[id]:
                        this_identity, this_coverage = covidstore[qs]
                        print "{0}\t{1:.1f}\t{2:.1f}".format("\t".join(qs), this_identity, this_coverage)
        else:
            for query, size in sz.iter_sizes():
                this_identity, this_coverage = covidstore.get(query, (0, 0))
                print "{0}\t{1:.1f}\t{2:.1f}".format(query, this_identity, this_coverage)

    mapped_count = len(queries)
    valid_count = len(valid)
    cutoff_message = "(id={0.pctid}% cov={0.pctcov}%)".format(opts)

    m = "Identity: {0} mismatches, {1} gaps, {2} alignlen\n".\
            format(mismatches, gaps, alignlen)
    total = len(sizes.keys())
    m += "Total mapped: {0} ({1:.1f}% of {2})\n".\
            format(mapped_count, mapped_count * 100. / total, total)
    m += "Total valid {0}: {1} ({2:.1f}% of {3})\n".\
            format(cutoff_message, valid_count, valid_count * 100. / total, total)
    m += "Average id = {0:.2f}%\n".\
            format(100 - (mismatches + gaps) * 100. / alignlen)

    queries_combined = sz.totalsize
    m += "Coverage: {0} covered, {1} total\n".\
            format(covered, queries_combined)
    m += "Average coverage = {0:.2f}%".\
            format(covered * 100. / queries_combined)

    logfile = blastfile + ".covfilter.log"
    fw = open(logfile, "w")
    for f in (sys.stderr, fw):
        print >> f, m
    fw.close()

    if opts.ids:
        filename = opts.ids
        fw = must_open(filename, "w")
        for id in valid:
            print >> fw, id
        logging.debug("Queries beyond cutoffs {0} written to `{1}`.".\
                format(cutoff_message, filename))

    outfile = opts.outfile
    if not outfile:
        return

    fw = must_open(outfile, "w")
    blast = Blast(blastfile)
    for b in blast:
        query = (b.query, b.subject) if qspair else b.query
        if query in valid:
            print >> fw, b

Example 45

Project: beets
Source File: bluelet.py
View license
def run(root_coro):
    """Schedules a coroutine, running it to completion. This
    encapsulates the Bluelet scheduler, which the root coroutine can
    add to by spawning new coroutines.
    """
    # The "threads" dictionary keeps track of all the currently-
    # executing and suspended coroutines. It maps coroutines to their
    # currently "blocking" event. The event value may be SUSPENDED if
    # the coroutine is waiting on some other condition: namely, a
    # delegated coroutine or a joined coroutine. In this case, the
    # coroutine should *also* appear as a value in one of the below
    # dictionaries `delegators` or `joiners`.
    threads = {root_coro: ValueEvent(None)}

    # Maps child coroutines to delegating parents.
    delegators = {}

    # Maps child coroutines to joining (exit-waiting) parents.
    joiners = collections.defaultdict(list)

    def complete_thread(coro, return_value):
        """Remove a coroutine from the scheduling pool, awaking
        delegators and joiners as necessary and returning the specified
        value to any delegating parent.
        """
        del threads[coro]

        # Resume delegator.
        if coro in delegators:
            threads[delegators[coro]] = ValueEvent(return_value)
            del delegators[coro]

        # Resume joiners.
        if coro in joiners:
            for parent in joiners[coro]:
                threads[parent] = ValueEvent(None)
            del joiners[coro]

    def advance_thread(coro, value, is_exc=False):
        """After an event is fired, run a given coroutine associated with
        it in the threads dict until it yields again. If the coroutine
        exits, then the thread is removed from the pool. If the coroutine
        raises an exception, it is reraised in a ThreadException. If
        is_exc is True, then the value must be an exc_info tuple and the
        exception is thrown into the coroutine.
        """
        try:
            if is_exc:
                next_event = coro.throw(*value)
            else:
                next_event = coro.send(value)
        except StopIteration:
            # Thread is done.
            complete_thread(coro, None)
        except:
            # Thread raised some other exception.
            del threads[coro]
            raise ThreadException(coro, sys.exc_info())
        else:
            if isinstance(next_event, types.GeneratorType):
                # Automatically invoke sub-coroutines. (Shorthand for
                # explicit bluelet.call().)
                next_event = DelegationEvent(next_event)
            threads[coro] = next_event

    def kill_thread(coro):
        """Unschedule this thread and its (recursive) delegates.
        """
        # Collect all coroutines in the delegation stack.
        coros = [coro]
        while isinstance(threads[coro], Delegated):
            coro = threads[coro].child
            coros.append(coro)

        # Complete each coroutine from the top to the bottom of the
        # stack.
        for coro in reversed(coros):
            complete_thread(coro, None)

    # Continue advancing threads until root thread exits.
    exit_te = None
    while threads:
        try:
            # Look for events that can be run immediately. Continue
            # running immediate events until nothing is ready.
            while True:
                have_ready = False
                for coro, event in list(threads.items()):
                    if isinstance(event, SpawnEvent):
                        threads[event.spawned] = ValueEvent(None)  # Spawn.
                        advance_thread(coro, None)
                        have_ready = True
                    elif isinstance(event, ValueEvent):
                        advance_thread(coro, event.value)
                        have_ready = True
                    elif isinstance(event, ExceptionEvent):
                        advance_thread(coro, event.exc_info, True)
                        have_ready = True
                    elif isinstance(event, DelegationEvent):
                        threads[coro] = Delegated(event.spawned)  # Suspend.
                        threads[event.spawned] = ValueEvent(None)  # Spawn.
                        delegators[event.spawned] = coro
                        have_ready = True
                    elif isinstance(event, ReturnEvent):
                        # Thread is done.
                        complete_thread(coro, event.value)
                        have_ready = True
                    elif isinstance(event, JoinEvent):
                        threads[coro] = SUSPENDED  # Suspend.
                        joiners[event.child].append(coro)
                        have_ready = True
                    elif isinstance(event, KillEvent):
                        threads[coro] = ValueEvent(None)
                        kill_thread(event.child)
                        have_ready = True

                # Only start the select when nothing else is ready.
                if not have_ready:
                    break

            # Wait and fire.
            event2coro = dict((v, k) for k, v in threads.items())
            for event in _event_select(threads.values()):
                # Run the IO operation, but catch socket errors.
                try:
                    value = event.fire()
                except socket.error as exc:
                    if isinstance(exc.args, tuple) and \
                            exc.args[0] == errno.EPIPE:
                        # Broken pipe. Remote host disconnected.
                        pass
                    else:
                        traceback.print_exc()
                    # Abort the coroutine.
                    threads[event2coro[event]] = ReturnEvent(None)
                else:
                    advance_thread(event2coro[event], value)

        except ThreadException as te:
            # Exception raised from inside a thread.
            event = ExceptionEvent(te.exc_info)
            if te.coro in delegators:
                # The thread is a delegate. Raise exception in its
                # delegator.
                threads[delegators[te.coro]] = event
                del delegators[te.coro]
            else:
                # The thread is root-level. Raise in client code.
                exit_te = te
                break

        except:
            # For instance, KeyboardInterrupt during select(). Raise
            # into root thread and terminate others.
            threads = {root_coro: ExceptionEvent(sys.exc_info())}

    # If any threads still remain, kill them.
    for coro in threads:
        coro.close()

    # If we're exiting with an exception, raise it in the client.
    if exit_te:
        exit_te.reraise()
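
The scheduler's joiners mapping is a defaultdict(list): a coroutine that joins a child is simply appended under that child's key, and completing the child wakes every waiter with one lookup. A tiny standalone sketch of that wait-list pattern, using plain strings instead of coroutines:

from collections import defaultdict

joiners = defaultdict(list)   # child -> list of parents waiting on it

def join(parent, child):
    joiners[child].append(parent)   # no need to check whether child is present

def complete(child):
    # pop and return every parent waiting on this child
    return joiners.pop(child, [])

join("main", "worker-1")
join("logger", "worker-1")
print(complete("worker-1"))  # ['main', 'logger']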

Example 46

Project: daywatch
Source File: views.py
View license
@login_required
@catch_error
@permission_required
@log_activity
def history_comparison_div(request):
    context = get_status(request)
    form = HistoryComparisonForm(user=request.user, data=request.GET)

    if form.is_valid():
        request.session['form_session'] = form.cleaned_data
        style = get_style()

        period = form.cleaned_data['period']
        concept = form.cleaned_data['concept']

        today = date.today()

        if period == 'last_3_m':
            months = 3
        elif period == 'last_4_m':
            months = 4
        elif period == 'last_6_m':
            months = 6

        player_ids = form.cleaned_data['players']
        player_ids = [int(p_id) for p_id in player_ids]

        country = form.cleaned_data['country']
        context['local_currency'] = CURRENCY_DICT[country]
        context['use_local_currency'] = country in LOCAL_CURRENCY_COUNTRIES

        if form.cleaned_data['all_categories']:
            categories = Category.objects.all()
            category_ids = []
            for category in categories:
                if category.name != 'root':
                    category_ids.append(int(category.id))
        else:
            category_ids = form.cleaned_data['categories']
            category_ids = [int(c_id) for c_id in category_ids]

        histo = defaultdict()
        player_list = []
        total_sales_player = defaultdict()
        company_names = defaultdict()

        total_sales_period = {}

        for p_id in player_ids:
            company_name = DayWatchSite.objects.get(id=int(p_id)).name
            company_names[int(p_id)] = company_name
            player_list.append((int(p_id), company_name))
            total_sales_player[company_name] = 0
            histo[company_name] = defaultdict()

        category_choices = []
        category_legends = []
        for c_id, name in CATEGORY_CHOICES:
            if c_id in category_ids:
                category_choices.append((c_id, name))
                category_legends.append(name)

        company_choices = []
        for c_id, name in COMPANY_CHOICES:
            if c_id in player_ids:
                company_choices.append((c_id, name))

        if concept == 'sales':
            if context['use_local_currency']:
                context['total_title'] = _('Total Sales')
                context['trend_y_legend'] = context['local_currency'] + ' %.0f'
            else:
                context['total_title'] = _('Total Sales U$S')
                context['trend_y_legend'] = 'U$S %.0f'
        elif concept == 'deals':
            context['total_title'] = _('Total # Deals Offered')
            context['trend_y_legend'] = '%d'
        elif concept == 'coupons_sold':
            context['trend_title'] = _('# Coupons Sold')
            context['total_title'] = _('Total # Coupons Sold')
            context['trend_y_legend'] = '%d'

        for company_id, company_name in player_list:
            total_sales_period[company_name] = []
            category_values = {}

            for category_id, category_name in category_choices:
                category_values[category_name] = []

            for i in range(months - 1, -1, -1):
                start_month = today.month - i
                if start_month <= 0:
                    start_month += 12
                    start_year = today.year - 1
                else:
                    start_year = today.year
                start_date_range = datetime(start_year, start_month, 1)
                # unpack into a named throwaway so the gettext alias _ is not clobbered
                (first_weekday, last) = calendar.monthrange(start_year, start_month)
                end_date_range = datetime(start_year, start_month, last)

                items = DayWatchItem.objects.filter(country=country)
                if concept == 'sales':
                    context['trend_title'] = _('Sales U$S')
                    context['trend_y_legend'] = _('U$S %.0f')

                    items = items.filter(
                        company__id=company_id,
                        category__id__in=category_ids,
                        start_date_time__gte=start_date_range,
                        start_date_time__lte=end_date_range,
                        total_sales_usd__gt=0
                    )
                    items = items.values('category__name').annotate(
                        subtotal=Sum('total_sales_usd')
                    )

                elif concept == 'deals':
                    context['trend_title'] = _('# Deals Offered')
                    context['trend_y_legend'] = '%d'

                    items = items.filter(
                        company__id=company_id,
                        category__id__in=category_ids,
                        start_date_time__gte=start_date_range,
                        start_date_time__lte=end_date_range
                    )
                    items = items.values('category__name').annotate(
                        subtotal=Count('offer_id')
                    )

                elif concept == 'coupons_sold':
                    context['trend_title'] = _('# Coupons Sold')
                    context['trend_y_legend'] = '%d'

                    items = items.filter(
                        company__id=company_id,
                        category__id__in=category_ids,
                        start_date_time__gte=start_date_range,
                        start_date_time__lte=end_date_range,
                        sold_count__gt=0
                    )
                    items = items.values('category__name').annotate(
                        subtotal=Sum('sold_count')
                    )

                total_category = {}
                for _category_id, category_name in category_choices:
                    total_category[category_name] = 0

                for item in items:
                    total_category[item['category__name']] += item['subtotal']

                for _category_id, category_name in category_choices:
                    category_values[category_name].append(
                        total_category[category_name]
                    )

            for _category_id, category_name in category_choices:
                total_sales_period[company_name].append(
                    category_values[category_name]
                )

        interval_values = []
        graphs = []
        for company_name in total_sales_period:
            max_val = 0
            for array in total_sales_period[company_name]:
                for value in array:
                    if value > max_val:
                        max_val = value
            graphs.append({
                'company_name': company_name,
                'interval_value': get_interval(max_val),
                'arrays': total_sales_period[company_name]
            })

        month_legends = []
        for i in range(months - 1, -1, -1):
            start_month = today.month - i
            if start_month <= 0:
                start_month += 12
            month_legends.append(calendar.month_name[start_month])

        legends = [style[category]['label'] for category in category_legends]

        context['graphs'] = graphs
        context['concept'] = concept
        context['category_legends'] = legends
        context['month_legends'] = month_legends
        context['interval_values'] = interval_values

        html = render_to_string(
            'history_comparison_div.html',
            context,
            context_instance=RequestContext(request)
            )
        return JsonResponse(result(Status.OK, data={'html': html}))
    else:
        return JsonResponse(result(Status.ERROR, 'Invalid form.'))
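
Note that histo, total_sales_player and company_names above are created with a bare defaultdict() call; with no default_factory they behave exactly like a plain dict and still raise KeyError for missing keys. A quick illustration of the difference:

from collections import defaultdict

plain = defaultdict()        # default_factory is None
counted = defaultdict(int)   # missing keys default to 0

counted["deals"] += 1        # fine: the count starts at 0
try:
    plain["deals"] += 1      # KeyError, just like a regular dict
except KeyError as exc:
    print("no default_factory ->", exc)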

Example 47

Project: django-multi-gtfs
Source File: base.py
View license
    @classmethod
    def import_txt(cls, txt_file, feed, filter_func=None):
        '''Import from the GTFS text file'''

        # Setup the conversion from GTFS to Django Format
        # Conversion functions
        def no_convert(value): return value

        def date_convert(value): return datetime.strptime(value, '%Y%m%d')

        def bool_convert(value): return (value == '1')

        def char_convert(value): return (value or '')

        def null_convert(value): return (value or None)

        def point_convert(value): return (value or 0.0)

        cache = {}

        def default_convert(field):
            def get_value_or_default(value):
                if value == '' or value is None:
                    return field.get_default()
                else:
                    return value
            return get_value_or_default

        def instance_convert(field, feed, rel_name):
            def get_instance(value):
                if value.strip():
                    key1 = "{}:{}".format(field.rel.to.__name__, rel_name)
                    key2 = text_type(value)

                    # Load existing objects
                    if key1 not in cache:
                        pairs = field.rel.to.objects.filter(
                            **{field.rel.to._rel_to_feed: feed}).values_list(
                            rel_name, 'id')
                        cache[key1] = dict((text_type(x), i) for x, i in pairs)

                    # Create new?
                    if key2 not in cache[key1]:
                        kwargs = {
                            field.rel.to._rel_to_feed: feed,
                            rel_name: value}
                        cache[key1][key2] = field.rel.to.objects.create(
                            **kwargs).id
                    return cache[key1][key2]
                else:
                    return None
            return get_instance

        # Check unique fields
        column_names = [c for c, _ in cls._column_map]
        for unique_field in cls._unique_fields:
            assert unique_field in column_names, \
                '{} not in {}'.format(unique_field, column_names)

        # Map of field_name to converters from GTFS to Django format
        val_map = dict()
        name_map = dict()
        point_map = dict()
        for csv_name, field_pattern in cls._column_map:
            # Separate the local field name from foreign columns
            if '__' in field_pattern:
                field_base, rel_name = field_pattern.split('__', 1)
                field_name = field_base + '_id'
            else:
                field_name = field_base = field_pattern
            # Use the field name in the name mapping
            name_map[csv_name] = field_name

            # Is it a point field?
            point_match = re_point.match(field_name)
            if point_match:
                field = None
            else:
                field = cls._meta.get_field(field_base)

            # Pick a conversion function for the field
            if point_match:
                converter = point_convert
            elif isinstance(field, models.DateField):
                converter = date_convert
            elif isinstance(field, models.BooleanField):
                converter = bool_convert
            elif isinstance(field, models.CharField):
                converter = char_convert
            elif field.rel:
                converter = instance_convert(field, feed, rel_name)
                assert not isinstance(field, models.ManyToManyField)
            elif field.null:
                converter = null_convert
            elif field.has_default():
                converter = default_convert(field)
            else:
                converter = no_convert

            if point_match:
                index = int(point_match.group('index'))
                point_map[csv_name] = (index, converter)
            else:
                val_map[csv_name] = converter

        # Read and convert the source txt
        csv_reader = reader(txt_file)
        unique_line = dict()
        count = 0
        first = True
        extra_counts = defaultdict(int)
        new_objects = []
        for row in csv_reader:
            if first:
                # Read the columns
                columns = row
                if columns[0].startswith(CSV_BOM):
                    columns[0] = columns[0][len(CSV_BOM):]
                first = False
                continue

            if filter_func and not filter_func(zip(columns, row)):
                continue

            # Read a data row
            fields = dict()
            point_coords = [None, None]
            ukey_values = {}
            if cls._rel_to_feed == 'feed':
                fields['feed'] = feed
            for column_name, value in zip(columns, row):
                if column_name not in name_map:
                    val = null_convert(value)
                    if val is not None:
                        fields.setdefault('extra_data', {})[column_name] = val
                        extra_counts[column_name] += 1
                elif column_name in val_map:
                    fields[name_map[column_name]] = val_map[column_name](value)
                else:
                    assert column_name in point_map
                    pos, converter = point_map[column_name]
                    point_coords[pos] = converter(value)

                # Is it part of the unique key?
                if column_name in cls._unique_fields:
                    ukey_values[column_name] = value

            # Join the lat/long into a point
            if point_map:
                assert point_coords[0] and point_coords[1]
                fields['point'] = "POINT(%s)" % (' '.join(point_coords))

            # Is the item unique?
            ukey = tuple(ukey_values.get(u) for u in cls._unique_fields)
            if ukey in unique_line:
                logger.warning(
                    '%s line %d is a duplicate of line %d, not imported.',
                    cls._filename, csv_reader.line_num, unique_line[ukey])
                continue
            else:
                unique_line[ukey] = csv_reader.line_num

            # Create after accumulating a batch
            new_objects.append(cls(**fields))
            if len(new_objects) % batch_size == 0:  # pragma: no cover
                cls.objects.bulk_create(new_objects)
                count += len(new_objects)
                logger.info(
                    "Imported %d %s",
                    count, cls._meta.verbose_name_plural)
                new_objects = []

        # Create remaining objects
        if new_objects:
            cls.objects.bulk_create(new_objects)

        # Take note of extra fields
        if extra_counts:
            extra_columns = feed.meta.setdefault(
                'extra_columns', {}).setdefault(cls.__name__, [])
            for column in columns:
                if column in extra_counts and column not in extra_columns:
                    extra_columns.append(column)
            feed.save()
        return len(unique_line)
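
import_txt tallies unknown GTFS columns with extra_counts = defaultdict(int), the usual counting idiom where a missing key starts at zero. The same idiom on its own, with made-up column names:

from collections import defaultdict

rows = [{"stop_id": "1", "platform_code": "A"},
        {"stop_id": "2", "platform_code": "B", "zone_name": "Z1"}]
known = {"stop_id"}

extra_counts = defaultdict(int)
for row in rows:
    for column in row:
        if column not in known:
            extra_counts[column] += 1   # first occurrence starts from 0

print(dict(extra_counts))  # {'platform_code': 2, 'zone_name': 1}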

Example 48

Project: plenario
Source File: worker.py
View license
    def worker(worker_id):

        birthtime = time.time()

        log("Hello! I'm ready for anything.", worker_id)
        register_worker(birthtime, worker_id)

        with app.app_context():

            while worker_boss.do_work:
                # Report to the worker boss.
                worker_boss.workers[worker_id] = "alive"
                # Report to the status page.
                check_in(birthtime, worker_id)
                # Poll Amazon SQS for messages containing a job.
                response = job_queue.receive_messages(MessageAttributeNames=["ticket"])
                # Grab the first message as the job to be considered.
                job = response[0] if len(response) > 0 else None

                # Run checks on the validity of the message. If any of the
                # checks fail, do a tiny amount of work then skip the loop.
                if not job:
                    time.sleep(wait_interval)
                    continue
                elif not job.body == "plenario_job":
                    log("Message is not a Plenario Job. Skipping.", worker_id)
                    continue
                elif not has_valid_ticket(job):
                    log("ERROR: Job does not contain a valid ticket! Removing.", worker_id)
                    job.delete()
                    continue

                # All checks passed, this is a valid ticket.
                ticket = str(job.message_attributes["ticket"]["StringValue"])
                status = get_status(ticket)
                log("Received job with ticket {}.".format(ticket), worker_id)

                # Keep track of the last 1000 tickets we've seen; these counts
                # help to determine how many times to retry jobs if they fail.
                if len(worker_boss.tickets) > 1000:
                    worker_boss.tickets = defaultdict(int)
                worker_boss.tickets[ticket] += 1

                # Now we have to determine if the job itself is a valid job to
                # perform, taking into consideration whether it has been
                # orphaned or has undergone too many retries.
                is_processing = status["status"] == "processing"
                is_orphaned = is_job_status_orphaned(status, job_timeout)

                # Check if the job was an orphan (meaning that the parent worker
                # process died and failed to complete it). If it was orphaned,
                # give it another try.
                if is_processing and is_orphaned:
                    log("Retrying orphan ticket {}".format(ticket), worker_id)
                    status["meta"]["lastDeferredTime"] = str(datetime.now())
                    set_status(ticket, status)
                # If the job isn't an orphan, check its status to make sure no
                # other worker has started work on it.
                elif status["status"] != "queued":
                    log("Job has already been started. Skipping.", worker_id)
                    continue

                register_worker_job_status(ticket, birthtime, worker_id)
                worker_boss.active_worker_count += 1
                # Once we have established that both the ticket and the job
                # are valid and able to be worked upon, give the EC2 instance
                # that contains this worker scale-in protection.
                update_instance_protection(worker_boss, autoscaling_client)
                # There is a chance that the do_work switch is False due to
                # an imminent instance termination.
                if not worker_boss.do_work:
                    deregister_worker_job_status(birthtime, worker_id)
                    worker_boss.active_worker_count -= 1
                    continue

                status["status"] = "processing"

                if "lastDeferredTime" in status["meta"]:
                    status["meta"]["lastResumeTime"] = str(datetime.now())
                else:
                    status["meta"]["startTime"] = str(datetime.now())

                if "workers" not in status["meta"]:
                    status["meta"]["workers"] = []
                status["meta"]["workers"].append(worker_id)

                set_status(ticket, status)

                try:
                    log("Starting work on ticket {}.".format(ticket), worker_id)

                    req = get_request(ticket)
                    endpoint = req['endpoint']
                    query_args = req['query']

                    worker_boss.workers[worker_id] = "working on {}".format(endpoint)

                    if endpoint in endpoint_logic:

                        # Add worker metadata
                        query_args["jobsframework_ticket"] = ticket
                        query_args["jobsframework_workerid"] = worker_id
                        query_args["jobsframework_workerbirthtime"] = birthtime

                        # Because we're getting serialized arguments from Redis,
                        # we need to convert them back into a validated form.
                        convert(query_args)
                        query_args = ValidatorProxy(query_args)

                        log("Ticket {}: endpoint {}.".format(ticket, endpoint), worker_id)

                        result = endpoint_logic[endpoint](query_args)

                        # Metacommands enable workers to modify a job's priority through a variety
                        # of methods (deferral, set_timeout, resubmission). Metacommands are received
                        # from work done in the endpoint logic.
                        metacommand = process_metacommands(result, job, ticket, worker_id, req, job_queue)
                        if metacommand == "STOP":
                            job.delete()
                            continue
                        if metacommand == "DEFER":
                            continue

                    elif endpoint in shape_logic:
                        convert(query_args)
                        query_args = ValidatorProxy(query_args)
                        result = shape_logic[endpoint](query_args)
                        if endpoint == 'aggregate-point-data' and query_args.data.get('data_type') != 'csv':
                            result = convert_result_geoms(result)

                    elif endpoint in etl_logic:
                        if endpoint in ('update_weather', 'update_metar'):
                            result = etl_logic[endpoint]()
                        else:
                            result = etl_logic[endpoint](query_args)

                    else:
                        raise ValueError("Attempting to send a job to an "
                                         "invalid endpoint ->> {}"
                                         .format(endpoint))

                    # By this point, we have successfully completed a task.
                    # Update the status meta information to indicate so.
                    set_ticket_success(ticket, result)
                    # Cleanup the leftover SQS message.
                    job.delete()

                    log("Finished work on ticket {}.".format(ticket), worker_id)

                except Exception as e:
                    traceback.print_exc()

                    times_ticket_was_seen = worker_boss.tickets[ticket]
                    if times_ticket_was_seen < 3:
                        set_ticket_queued(status, ticket, str(e), worker_id)
                    else:
                        error_msg = "{} errored with {}.".format(ticket, e)
                        set_ticket_error(status, ticket, error_msg, worker_id)
                        job.delete()

                finally:
                    worker_boss.workers[worker_id] = "alive"
                    worker_boss.active_worker_count -= 1
                    update_instance_protection(worker_boss, autoscaling_client)
                    deregister_worker_job_status(birthtime, worker_id)
                    time.sleep(wait_interval)

        log("Exited run loop. Goodbye!", worker_id)
        deregister_worker(worker_id)
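
The worker keeps a bounded retry counter as worker_boss.tickets = defaultdict(int): every sighting of a ticket increments its count, and the whole mapping is replaced with a fresh defaultdict(int) once it grows past 1000 entries. A standalone sketch of that count-and-reset pattern, with hypothetical ticket ids:

from collections import defaultdict

MAX_TRACKED = 1000
tickets = defaultdict(int)

def saw_ticket(ticket_id):
    global tickets
    if len(tickets) > MAX_TRACKED:
        tickets = defaultdict(int)   # drop history instead of evicting entries
    tickets[ticket_id] += 1
    return tickets[ticket_id]        # how many times this ticket has been seen

print(saw_ticket("abc"), saw_ticket("abc"))  # 1 2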

Example 49

Project: umis
Source File: umis.py
View license
@click.command()
@click.argument('sam')
@click.argument('out')
@click.option('--genemap', required=False, default=None)
@click.option('--output_evidence_table', default=None)
@click.option('--positional', default=False, is_flag=True)
@click.option('--minevidence', required=False, default=1.0, type=float)
@click.option('--cb_histogram', default=None)
@click.option('--cb_cutoff', default=None,
              help=("Number of counts to filter cellular barcodes. Set to "
                    "'auto' to calculate a cutoff automatically."))
@click.option('--no_scale_evidence', default=False, is_flag=True)
@click.option('--subsample', required=False, default=None, type=int)
# @profile
def tagcount(sam, out, genemap, output_evidence_table, positional, minevidence,
             cb_histogram, cb_cutoff, no_scale_evidence, subsample):
    ''' Count up evidence for tagged molecules
    '''
    from pysam import AlignmentFile

    from io import StringIO
    import pandas as pd

    from utils import weigh_evidence

    logger.info('Reading optional files')

    gene_map = None
    if genemap:
        with open(genemap) as fh:
            try:
                gene_map = dict(p.strip().split() for p in fh)
            except ValueError:
                logger.error('Incorrectly formatted gene_map, need to be tsv.')
                sys.exit()

    if positional:
        tuple_template = '{0},{1},{2},{3}'
    else:
        tuple_template = '{0},{1},{3}'

    if not cb_cutoff:
        cb_cutoff = 0

    if cb_histogram and cb_cutoff == "auto":
        cb_cutoff = guess_depth_cutoff(cb_histogram)

    cb_cutoff = int(cb_cutoff)

    cb_hist = None
    filter_cb = False
    if cb_histogram:
        cb_hist = pd.read_table(cb_histogram, index_col=0, header=-1, squeeze=True)
        total_num_cbs = cb_hist.shape[0]
        cb_hist = cb_hist[cb_hist > cb_cutoff]
        logger.info('Keeping {} out of {} cellular barcodes.'.format(cb_hist.shape[0], total_num_cbs))
        filter_cb = True

    parser_re = re.compile('.*:CELL_(?P<CB>.*):UMI_(?P<MB>.*)')

    if subsample:
        logger.info('Creating reservoir of subsampled reads ({} per cell)'.format(subsample))
        start_sampling  = time.time()

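        # Each cell barcode maps to a list of kept read indices; unseen
        # barcodes start with an empty list, so no setdefault calls are needed.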
        reservoir = collections.defaultdict(list)
        cb_hist_sampled = 0 * cb_hist
        cb_obs = 0 * cb_hist

        sam_mode = 'r' if sam.endswith(".sam") else 'rb'
        sam_file = AlignmentFile(sam, mode=sam_mode)
        track = sam_file.fetch(until_eof=True)
        current_read = 'none_observed_yet'
        for i, aln in enumerate(track):
            if aln.qname == current_read:
                continue

            current_read = aln.qname
            match = parser_re.match(aln.qname)
            CB = match.group('CB')

            if CB not in cb_hist.index:
                continue

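            # Reservoir sampling per barcode: keep the first `subsample` read
            # indices outright, then replace a random slot with probability
            # subsample / reads_seen_so_far, so every read is equally likely
            # to be retained.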
            cb_obs[CB] += 1
            if len(reservoir[CB]) < subsample:
                reservoir[CB].append(i)
                cb_hist_sampled[CB] += 1
            else:
                s = pd.np.random.randint(0, cb_obs[CB])
                if s < subsample:
                    reservoir[CB][s] = i

        index_filter = set(itertools.chain.from_iterable(reservoir.values()))
        sam_file.close()
        sampling_time = time.time() - start_sampling
        logger.info('Sampling done - {:.3}s'.format(sampling_time))

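    # Missing evidence keys default to 0, so counts can be accumulated
    # without initialising each (cell, gene[, pos], umi) tuple first.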
    evidence = collections.defaultdict(int)

    logger.info('Tallying evidence')
    start_tally = time.time()

    sam_mode = 'r' if sam.endswith(".sam") else 'rb'
    sam_file = AlignmentFile(sam, mode=sam_mode)
    track = sam_file.fetch(until_eof=True)
    count = 0
    unmapped = 0
    kept = 0
    nomatchcb = 0
    current_read = 'none_observed_yet'
    count_this_read = True
    for i, aln in enumerate(track):
        count += 1
        if not count % 100000:
            logger.info("Processed %d alignments, kept %d." % (count, kept))
            logger.info("%d were filtered for being unmapped." % unmapped)
            if filter_cb:
                logger.info("%d were filtered for not matching known barcodes."
                            % nomatchcb)

        if aln.is_unmapped:
            unmapped += 1
            continue

        if aln.qname != current_read:
            current_read = aln.qname
            if subsample and i not in index_filter:
                count_this_read = False
                continue
            else:
                count_this_read = True
        else:
            if not count_this_read:
                continue

        match = parser_re.match(aln.qname)
        CB = match.group('CB')
        if filter_cb:
            if CB not in cb_hist.index:
                nomatchcb += 1
                continue

        MB = match.group('MB')

        txid = sam_file.getrname(aln.reference_id)
        if gene_map:
            target_name = gene_map[txid]

        else:
            target_name = txid

        e_tuple = tuple_template.format(CB, target_name, aln.pos, MB)

        # Scale evidence by number of hits
        if no_scale_evidence:
            evidence[e_tuple] += 1.0
        else:
            evidence[e_tuple] += weigh_evidence(aln.tags)
        kept += 1

    tally_time = time.time() - start_tally
    logger.info('Tally done - {:.3}s, {:,} alns/min'.format(tally_time, int(60. * count / tally_time)))
    logger.info('Collapsing evidence')

    buf = StringIO()
    for key in evidence:
        line = '{},{}\n'.format(key, evidence[key])
        buf.write(unicode(line, "utf-8"))

    buf.seek(0)
    evidence_table = pd.read_csv(buf)
    evidence_query = 'evidence >= %f' % minevidence
    if positional:
        evidence_table.columns=['cell', 'gene', 'umi', 'pos', 'evidence']
        collapsed = evidence_table.query(evidence_query).groupby(['cell', 'gene'])['umi', 'pos'].size()

    else:
        evidence_table.columns=['cell', 'gene', 'umi', 'evidence']
        collapsed = evidence_table.query(evidence_query).groupby(['cell', 'gene'])['umi'].size()

    expanded = collapsed.unstack().T

    if gene_map:
        # This Series is just for sorting the index
        genes = pd.Series(index=set(gene_map.values()))
        genes = genes.sort_index()
        # Now genes is assigned to a DataFrame
        genes = expanded.ix[genes.index]

    else:
        genes = expanded

    genes.replace(pd.np.nan, 0, inplace=True)

    logger.info('Output results')

    if subsample:
        cb_hist_sampled.to_csv('ss_{}_'.format(subsample) + os.path.basename(cb_histogram), sep='\t')

    if output_evidence_table:
        import shutil
        buf.seek(0)
        with open(output_evidence_table, 'w') as etab_fh:
            shutil.copyfileobj(buf, etab_fh)

    genes.to_csv(out)
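
Both defaultdict uses in this example follow the same shape: accumulate under a key that may not exist yet, with defaultdict(list) for the per-barcode reservoir and defaultdict(int) for the evidence tally. A minimal, self-contained sketch of that pattern, with made-up read tuples in place of a parsed SAM file:

import collections

# Hypothetical (cell_barcode, gene, umi) triples standing in for parsed alignments.
reads = [
    ("AAAC", "GeneA", "TTG"),
    ("AAAC", "GeneA", "TTG"),
    ("AAAC", "GeneB", "CCA"),
    ("GGTT", "GeneA", "AAA"),
]

# defaultdict(int): tally evidence per (cell, gene, umi) with no initialisation.
evidence = collections.defaultdict(int)
for cell, gene, umi in reads:
    evidence[(cell, gene, umi)] += 1

# defaultdict(list): group the distinct UMIs seen for each (cell, gene) pair.
umis_per_gene = collections.defaultdict(list)
for cell, gene, umi in evidence:
    umis_per_gene[(cell, gene)].append(umi)

for key in sorted(umis_per_gene):
    print("{}\t{}".format(key, len(umis_per_gene[key])))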

Example 50

Project: spline-pokedex
Source File: pokedex_conquest.py
View license
    def pokemon(self, name=None):
        try:
            c.pokemon = db.pokemon_query(name, None).one()
        except NoResultFound:
            return self._not_found()

        c.semiform_pokemon = c.pokemon
        c.pokemon = c.pokemon.species

        # This Pokémon might exist, but not appear in Conquest
        if c.pokemon.conquest_order is None:
            return self._not_found()

        ### Previous and next for the header
        c.prev_pokemon, c.next_pokemon = self._prev_next_id(
            c.pokemon, t.PokemonSpecies, 'conquest_order')

        ### Type efficacy
        c.type_efficacies = defaultdict(lambda: 100)
        for target_type in c.semiform_pokemon.types:
            for type_efficacy in target_type.target_efficacies:
                c.type_efficacies[type_efficacy.damage_type] *= \
                    type_efficacy.damage_factor

                # The defaultdict starts at 100, and every damage factor is
                # a percentage.  Dividing by 100 with every iteration turns the
                # damage factor into a decimal percentage taken of the starting
                # 100, without using floats and regardless of number of types
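                # Worked example: two double-damage factors give
                # 100 * 200 // 100 = 200, then 200 * 200 // 100 = 400,
                # i.e. a 4x weakness, in pure integer arithmetic.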
                c.type_efficacies[type_efficacy.damage_type] //= 100


        ### Evolution
        # Shamelessly lifted from the main controller and tweaked.
        #
        # Format is a matrix as follows:
        # [
        #   [ None, Eevee, Vaporeon, None ]
        #   [ None, None, Jolteon, None ]
        #   [ None, None, Flareon, None ]
        #   ... etc ...
        # ]
        # That is, each row is a physical row in the resulting table, and each
        # contains four elements, one per column: Baby, Base, Stage 1, Stage 2.
        # The Pokémon are actually dictionaries with 'species' and 'span' keys,
        # where the span is used as the HTML cell's rowspan -- e.g., Eevee has a
        # total of seven descendants, so it would need to span 7 rows.
        c.evolution_table = []
        # Prefetch the evolution details
        family = (db.pokedex_session.query(t.PokemonSpecies)
            .filter(t.PokemonSpecies.evolution_chain_id ==
                    c.pokemon.evolution_chain_id)
            .options(
                sqla.orm.subqueryload('conquest_evolution'),
                sqla.orm.joinedload('conquest_evolution.stat'),
                sqla.orm.joinedload('conquest_evolution.kingdom'),
                sqla.orm.joinedload('conquest_evolution.gender'),
                sqla.orm.joinedload('conquest_evolution.item'),
            )
            .all())
        # Strategy: build this table going backwards.
        # Find a leaf, build the path going back up to its root.  Remember all
        # of the nodes seen along the way.  Find another leaf not seen so far.
        # Build its path backwards, sticking it to a seen node if one exists.
        # Repeat until there are no unseen nodes.
        seen_nodes = {}
        while True:
            # First, find some unseen nodes
            unseen_leaves = []
            for species in family:
                if species in seen_nodes:
                    continue

                children = []
                # A Pokémon is a leaf if it has no evolutionary children, so...
                for possible_child in family:
                    if possible_child in seen_nodes:
                        continue
                    if possible_child.parent_species == species:
                        children.append(possible_child)
                if len(children) == 0:
                    unseen_leaves.append(species)

            # If there are none, we're done!  Bail.
            # Note that it is impossible to have any unseen non-leaves if there
            # are no unseen leaves; every leaf's ancestors become seen when we
            # build a path to it.
            if len(unseen_leaves) == 0:
                break

            unseen_leaves.sort(key=lambda x: x.id)
            leaf = unseen_leaves[0]

            # root, parent_n, ... parent2, parent1, leaf
            current_path = []

            # Finally, go back up the tree to the root
            current_species = leaf
            while current_species:
                # The loop ends once current_species becomes None, i.e. just
                # after the root has been processed, so root_pokemon holds the
                # root when the loop finishes; we need to know whether it's a
                # baby to decide whether to indent the entire table below
                root_pokemon = current_species

                if current_species in seen_nodes:
                    current_node = seen_nodes[current_species]
                    # Don't need to repeat this node; the first instance will
                    # have a rowspan
                    current_path.insert(0, None)
                else:
                    current_node = {
                        'species': current_species,
                        'span':    0,
                    }
                    current_path.insert(0, current_node)
                    seen_nodes[current_species] = current_node

                # This node has one more row to span: our current leaf
                current_node['span'] += 1

                current_species = current_species.parent_species

            # We want every path to have four nodes: baby, basic, stage 1 and 2.
            # Every root node is basic, unless it's defined as being a baby.
            # So first, add an empty baby node at the beginning if this is not
            # a baby.
            # We use an empty string to indicate an empty cell, as opposed to a
            # complete lack of cell due to a tall cell from an earlier row.
            if not root_pokemon.is_baby:
                current_path.insert(0, '')
            # Now pad to four if necessary.
            while len(current_path) < 4:
                current_path.append('')

            c.evolution_table.append(current_path)


        ### Stats
        # Conquest has a nonstandard stat, Range, which shouldn't be included
        # in the total, so we have to do things a bit differently.
        # XXX actually do things differently instead of just fudging the same
        #     thing to work
        c.stats = {}  # stat => { border, background, percentile }
        stat_total = 0
        total_stat_rows = db.pokedex_session.query(t.ConquestPokemonStat) \
            .filter_by(stat=c.pokemon.conquest_stats[0].stat) \
            .count()
        for pokemon_stat in c.pokemon.conquest_stats:
            stat_info = c.stats[pokemon_stat.stat.identifier] = {}

            stat_info['value'] = pokemon_stat.base_stat

            if pokemon_stat.stat.is_base:
                stat_total += pokemon_stat.base_stat

            q = db.pokedex_session.query(t.ConquestPokemonStat) \
                               .filter_by(stat=pokemon_stat.stat)
            less = q.filter(t.ConquestPokemonStat.base_stat <
                            pokemon_stat.base_stat).count()
            equal = q.filter(t.ConquestPokemonStat.base_stat ==
                             pokemon_stat.base_stat).count()
            percentile = (less + equal * 0.5) / total_stat_rows
            stat_info['percentile'] = percentile

            # Colors for the stat bars, based on percentile
            stat_info['background'] = bar_color(percentile, 0.9)
            stat_info['border'] = bar_color(percentile, 0.8)

        # Percentile for the total
        # Need to make a derived table that fakes pokemon_id, total_stats
        stat_sum_tbl = db.pokedex_session.query(
                sqla.sql.func.sum(t.ConquestPokemonStat.base_stat)
                .label('stat_total')
            ) \
            .filter(t.ConquestPokemonStat.conquest_stat_id <= 4) \
            .group_by(t.ConquestPokemonStat.pokemon_species_id) \
            .subquery()

        q = db.pokedex_session.query(stat_sum_tbl)
        less = q.filter(stat_sum_tbl.c.stat_total < stat_total).count()
        equal = q.filter(stat_sum_tbl.c.stat_total == stat_total).count()
        percentile = (less + equal * 0.5) / total_stat_rows
        c.stats['total'] = {
            'percentile': percentile,
            'value': stat_total,
            'background': bar_color(percentile, 0.9),
            'border': bar_color(percentile, 0.8),
        }

        ### Max links
        # We only want to show warriors who have a max link above a certain
        # threshold, because there are 200 warriors and most of them won't
        # have very good links.
        default_link = 70
        c.link_form = LinkThresholdForm(request.params, link=default_link)
        if request.params and c.link_form.validate():
            link_threshold = c.link_form.link.data
        else:
            link_threshold = default_link

        # However, some warriors will only be above this threshold at later
        # ranks.  In these cases, we may as well show all ranks' links.
        # No link ever goes down when a warrior ranks up, so we just need to
        # check their final rank.

        # First, craft a clause to filter out non-final warrior ranks.
        ranks_sub = sqla.orm.aliased(t.ConquestWarriorRank)
        higher_ranks_exist = (sqla.sql.exists([1])
            .where(sqla.and_(
                ranks_sub.warrior_id == t.ConquestWarriorRank.warrior_id,
                ranks_sub.rank > t.ConquestWarriorRank.rank))
        )
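        # Roughly EXISTS (SELECT 1 FROM warrior_ranks sub WHERE
        # sub.warrior_id = warrior_rank.warrior_id AND sub.rank > warrior_rank.rank);
        # negating it with ~ below keeps only each warrior's final rank.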

        # Next, find final-rank warriors with a max link high enough.
        worthy_warriors = (db.pokedex_session.query(t.ConquestWarrior.id)
            .join(t.ConquestWarriorRank)
            .filter(~higher_ranks_exist)
            .join(t.ConquestMaxLink)
            .filter(t.ConquestMaxLink.pokemon_species_id == c.pokemon.id)
            .filter(t.ConquestMaxLink.max_link >= link_threshold))

        # For Froslass and Gallade, we want to filter out male and female
        # warriors, respectively.
        # XXX Eventually we want to figure out all impossible evolutions, and
        #     show them, but sort them to the bottom and grey them out.
        if (c.pokemon.conquest_evolution is not None and
          c.pokemon.conquest_evolution.warrior_gender_id is not None):
            worthy_warriors = worthy_warriors.filter(
                t.ConquestWarrior.gender_id ==
                    c.pokemon.conquest_evolution.warrior_gender_id)

        # Finally, find ALL the max links for these warriors!
        links_q = (c.pokemon.conquest_max_links
            .join(ranks_sub)
            .filter(ranks_sub.warrior_id.in_(worthy_warriors))
            .options(
                sqla.orm.joinedload('warrior_rank'),
                sqla.orm.subqueryload('warrior_rank.stats'),
                sqla.orm.joinedload('warrior_rank.warrior'),
                sqla.orm.joinedload('warrior_rank.warrior.archetype'),
                sqla.orm.subqueryload('warrior_rank.warrior.types'),
            ))

        c.max_links = links_q.all()


        return render('/pokedex/conquest/pokemon.mako')
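
The defaultdict(lambda: 100) above is the less common form, where the default factory is an arbitrary callable rather than a type, so every attacking type starts at a neutral 100% before the damage factors are folded in. A minimal sketch of the same pattern with hypothetical type names and factors (the real values come from the pokedex database):

from collections import defaultdict

# Hypothetical damage factors stored as integer percentages, the way the
# pokedex database does: 200 = double damage, 50 = half, 0 = immune.
damage_factors = {
    ("Fire", "Grass"): 200,
    ("Water", "Grass"): 50,
    ("Electric", "Ground"): 0,
}

def efficacies_against(defending_types):
    # Every attacking type starts at a neutral 100%.
    efficacies = defaultdict(lambda: 100)
    for (attacking, defending), factor in damage_factors.items():
        if defending not in defending_types:
            continue
        efficacies[attacking] *= factor
        # Floor-divide back down so the running value stays an integer
        # percentage of the original 100.
        efficacies[attacking] //= 100
    return efficacies

result = efficacies_against({"Grass"})
print(dict(result))  # Fire ends at 200, Water at 50; unseen types default to 100.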