Révision ae097639 modules/screening.py

b/modules/screening.py
112 112
        return np.array([np.average(array[:, i]) for i in range(num_vects)])
113 113

  
114 114

  
115
def get_atom_coords(atoms: ase.Atoms, ctrs_list=None):
116
    """Gets the coordinates of the specified indices from a ase.Atoms object.
115
def get_atom_coords(atoms: ase.Atoms, center=None):
116
    """Gets the coordinates of the specified center for an ase.Atoms object.
117 117

  
118
    Given an ase.Atoms object and a list of atom indices specified in ctrs_list
119
    it gets the coordinates of the specified atoms. If the element in the
120
    ctrs_list is not an index but yet a list of indices, it computes the
118
    If center is not an index but a list of indices, it computes the
121 119
    element-wise mean of the coordinates of the atoms specified in the inner
122 120
    list.
123 121
    @param atoms: ase.Atoms object for which to obtain the coordinates of.
124
    @param ctrs_list: list of (indices/list of indices) of the atoms for which
125
                      the coordinates should be extracted.
122
    @param center: index/list of indices of the atoms for which the coordinates
123
                   should be extracted.
126 124
    @return: np.ndarray of atomic coordinates.
127 125
    """
128
    coords = []
129
    err = "'ctrs_list' argument must be an integer, a list of integers or a " \
130
          "list of lists of integers. Every integer must be in the range " \
131
          "[0, num_atoms)"
132
    if ctrs_list is None:
133
        ctrs_list = range(len(atoms))
134
    elif isinstance(ctrs_list, int):
135
        if ctrs_list not in range(len(atoms)):
136
            logger.error(err)
137
            raise ValueError(err)
138
        return atoms[ctrs_list].position
139
    for elem in ctrs_list:
140
        if isinstance(elem, list):
141
            coords.append(vect_avg([atoms[c].position for c in elem]))
142
        elif isinstance(elem, int):
143
            coords.append(atoms[elem].position)
144
        else:
145
            logger.error(err)
146
            raise ValueError
147
    return np.array(coords)
126
    err_msg = "Argument 'ctr' must be an integer or a list of integers. "\
127
              "Every integer must be in the range [0, num_atoms)"
128
    if center is None:
129
        center = list(range(len(atoms)))
130
    if isinstance(center, int):
131
        if center not in list(range(len(atoms))):
132
            logger.error(err_msg)
133
            raise ValueError(err_msg)
134
        return atoms[center].position
135
    elif isinstance(center, list):
136
        for elm in center:
137
            if elm not in list(range(len(atoms))):
138
                logger.error(err_msg)
139
                raise ValueError(err_msg)
140
        return vect_avg([atoms[idx].position for idx in center])
141
    else:
142
        logger.error(err_msg)
143
        raise ValueError(err_msg)
148 144

  
149 145

  
150 146
def compute_norm_vect(atoms, idxs, cell):
......
959 955
        surf.set_cell(inp_vars['pbc_cell'])
960 956

  
961 957
    surf_ads_list = []
962
    sites_coords = get_atom_coords(surf, sites)
958
    sites_coords = [get_atom_coords(surf, site) for site in sites]
963 959
    if coll_coeff is not False:
964 960
        surf_cutoffs = natural_cutoffs(surf, mult=coll_coeff)
965 961
        surf_nghbs = len(neighbor_list("i", surf, surf_cutoffs))
966 962
    else:
967 963
        surf_nghbs = 0
968 964
    for i, conf in enumerate(conf_list):
969
        molec_ctr_coords = get_atom_coords(conf, molec_ctrs)
965
        molec_ctr_coords = [get_atom_coords(conf, ctr) for ctr in molec_ctrs]
970 966
        if inp_vars['pbc_cell'] is not False:
971 967
            conf.set_pbc(True)
972 968
            conf.set_cell(inp_vars['pbc_cell'])

Formats disponibles : Unified diff