"""Module that allows for imposing a kinetically connected network
structure of weighted ensemble simulation data.
"""
# Standard Library
import gc
from collections import defaultdict
from copy import deepcopy
# Third Party Library
import networkx as nx
import numpy as np
# First Party Library
from wepy.analysis.transitions import counts_d_to_matrix, transition_counts
try:
    # Third Party Library
    import pandas as pd
except ModuleNotFoundError:
    # pandas is optional: only the dataframe-conversion methods need it
    print("Pandas is not installed, that functionality won't work")
[docs]
class MacroStateNetworkError(Exception):
    """Exception raised for violations of MacroStateNetwork requirements."""
[docs]
class BaseMacroStateNetwork:
"""A base class for the MacroStateNetwork which doesn't contain a
WepyHDF5 object. Useful for serialization of the object and can
then be reattached later to a WepyHDF5. For this functionality see
the 'MacroStateNetwork' class.
BaseMacroStateNetwork can also be thought of as just a way of
mapping macrostate properties to the underlying microstate data.
The network itself is a networkx directed graph.
Upon construction the nodes will be a value called the 'node_id'
which is the label/assignment for the node. This either comes from
an explicit labelling (the 'assignments' argument) or from the
labels/assignments from the contig tree (from the 'assg_field_key'
argument).
Nodes have the following attributes after construction:
- node_id :: Same as the actual node value
- node_idx :: An extra index that is used for 'internal' ordering
of the nodes in a consistent manner. Used for
example in any method which constructs matrices from
edges and ensures they are all the same.
- assignments :: An index trace over the contig_tree dataset used
to construct the network. This is how the
individual microstates are indexed for each node.
- num_samples :: A total of the number of microstates that a node
has. Is the length of the 'assignments' attribute.
Additionally, there are auxiliary node attributes that may be
added by various methods. All of these are prefixed with a single
underscore '_' and any user set values should avoid this.
These auxiliary attributes also make use of namespacing, where
namespaces are similar to file paths and are separated by '/'
characters.
Additionally the auxiliary groups are typically managed such that
they remain consistent across all of the nodes and have metadata
queryable from the BaseMacroStateNetwork object. In contrast user
defined node attributes are not restricted to this structure.
The auxiliary groups are:
- '_groups' :: used to mark nodes as belonging to a higher level group.
- '_observables' :: used for scalar values that are calculated
from the underlying microstate structures. As
opposed to more operational values describing
the network itself. By virtue of being scalar
these are also compatible with output to
tabular formats.
Edge values are simply 2-tuples of node_ids where the first value
is the source and the second value is the target. Edges have the
following attributes following initialization:
- 'weighted_counts' :: The weighted sum of all the transitions
for an edge. This is a floating point
number.
- 'unweighted_counts' :: The unweighted sum of all the
transitions for an edge, this is a
normal count and is a whole integer.
- 'all_transition' :: This is an array of floats of the weight
for every individual transition for an
edge. This is useful for doing more
advanced statistics for a given edge.
A network object can be used as a stateful container for
calculated values over the nodes and edges and has methods to
support this. However, there is no standard way to serialize this
data beyond the generic python techniques like pickle.
"""
ASSIGNMENTS = "assignments"
"""Key for the microstates that are assigned to a macrostate."""
def __init__(
self, contig_tree, assg_field_key=None, assignments=None, transition_lag_time=2
):
"""Create a network of macrostates from the simulation microstates
using a field in the trajectory data or precomputed assignments.
Either 'assg_field_key' or 'assignments' must be given, but not
both.
The 'transition_lag_time' is default set to 2, which is the natural connection
between microstates. The lag time can be increased to vary the
kinetic accuracy of transition probabilities generated through
Markov State Modelling.
The 'transition_lag_time' must be given as an integer greater
than 1.
Arguments
---------
contig_tree : ContigTree object
assg_field_key : str, conditionally optional on 'assignments'
The field in the WepyHDF5 dataset you want to generate macrostates for.
assignments : list of list of array_like of dim (n_traj_frames, observable_shape[0], ...),
conditionally optional on 'assg_field_key'
List of assignments for all frames in each run, where each
element of the outer list is for a run, the elements of
these lists are lists for each trajectory which are
arraylikes of shape (n_traj, observable_shape[0], ...).
See Also
"""
self._graph = nx.DiGraph()
assert not (
assg_field_key is None and assignments is None
), "either assg_field_key or assignments must be given"
assert (
assg_field_key is not None or assignments is not None
), "one of assg_field_key or assignments must be given"
self._base_contig_tree = contig_tree.base_contigtree
self._assg_field_key = assg_field_key
# initialize the groups dictionary
self._node_groups = {}
# initialize the list of the observables
self._observables = []
# initialize the list of available layouts
self._layouts = []
# initialize the lookup of the node_idxs from node_ids
self._node_idxs = {}
# initialize the reverse node lookups which is memoized if
# needed
self._node_idx_to_id_dict = None
# validate lag time input
if (transition_lag_time is not None) and (transition_lag_time < 2):
raise MacroStateNetworkError(
"transition_lag_time must be an integer value >= 2"
)
self._transition_lag_time = transition_lag_time
## Temporary variables for initialization only
# the temporary assignments dictionary
self._node_assignments = None
# and temporary raw assignments
self._assignments = None
## Code for creating nodes and edges
## Nodes
with contig_tree:
# map the keys to their lists of assignments, depending on
# whether or not we are using a field from the HDF5 traj or
# assignments provided separately
if assg_field_key is not None:
assert type(assg_field_key) == str, "assignment key must be a string"
self._key_init(contig_tree)
else:
self._assignments_init(assignments)
# once we have made the dictionary add the nodes to the network
# and reassign the assignments to the nodes
for node_idx, assg_item in enumerate(self._node_assignments.items()):
assg_key, assigs = assg_item
# count the number of samples (assigs) and use this as a field as well
num_samples = len(assigs)
# save the nodes with attributes, we save the node_id
# as the assg_key, because of certain formats only
# typing the attributes, and we want to avoid data
# loss, through these formats (which should be avoided
# as durable stores of them though)
self._graph.add_node(
assg_key,
node_id=assg_key,
node_idx=node_idx,
assignments=assigs,
num_samples=num_samples,
)
self._node_idxs[assg_key] = node_idx
## Edges
(
all_transitions_d,
weighted_counts_d,
unweighted_counts_d,
) = self._init_transition_counts(
contig_tree,
transition_lag_time,
)
# after calculating the transition counts set these as edge
# values make the edges with these attributes
for edge, all_trans in all_transitions_d.items():
weighted_counts = weighted_counts_d[edge]
unweighted_counts = unweighted_counts_d[edge]
# add the edge with all of the values
self._graph.add_edge(
*edge,
weighted_counts=weighted_counts,
unweighted_counts=unweighted_counts,
all_transitions=all_trans,
)
## Cleanup
# then get rid of the assignments dictionary, this information
# can be accessed from the network
del self._node_assignments
del self._assignments
[docs]
def _key_init(self, contig_tree):
"""Initialize the assignments structures given the field key to use.
Parameters
----------
"""
wepy_h5 = contig_tree.wepy_h5
# blank assignments
assignments = [
[[] for traj_idx in range(wepy_h5.num_run_trajs(run_idx))]
for run_idx in wepy_h5.run_idxs
]
test_field = wepy_h5.get_traj_field(
wepy_h5.run_idxs[0],
wepy_h5.run_traj_idxs(0)[0],
self.assg_field_key,
)
# WARN: assg_field shapes can come wrapped with an extra
# dimension. We handle both cases. Test the first traj and see
# how it is
unwrap = False
if len(test_field.shape) == 2 and test_field.shape[1] == 1:
# then we raise flag to unwrap them
unwrap = True
elif len(test_field.shape) == 1:
# then it is unwrapped and we don't need to do anything,
# just assert the flag to not unwrap
unwrap = False
else:
raise ValueError(
f"Wrong shape for an assignment type observable: {test_field.shape}"
)
# the raw assignments
curr_run_idx = -1
for idx_tup, fields_d in wepy_h5.iter_trajs_fields(
[self.assg_field_key], idxs=True
):
run_idx = idx_tup[0]
traj_idx = idx_tup[1]
assg_field = fields_d[self.assg_field_key]
# if we need to we unwrap the assignements scalar values
# if they need it
if unwrap:
assg_field = np.ravel(assg_field)
assignments[run_idx][traj_idx].extend(assg_field)
# then just call the assignments constructor to do it the same
# way
self._assignments_init(assignments)
[docs]
def _assignments_init(self, assignments):
"""Given the assignments structure sets up the other necessary
structures.
Parameters
----------
assignments : list of list of array_like of dim (n_traj_frames, observable_shape[0], ...),
conditionally optional on 'assg_field_key'
List of assignments for all frames in each run, where each
element of the outer list is for a run, the elements of
these lists are lists for each trajectory which are
arraylikes of shape (n_traj, observable_shape[0], ...).
"""
# set the type for the assignment field
self._assg_field_type = type(assignments[0])
# set the raw assignments to the temporary attribute
self._assignments = assignments
# this is the dictionary mapping node_id -> the (run_idx, traj_idx, cycle_idx) frames
self._node_assignments = defaultdict(list)
for run_idx, run in enumerate(assignments):
for traj_idx, traj in enumerate(run):
for frame_idx, assignment in enumerate(traj):
self._node_assignments[assignment].append(
(run_idx, traj_idx, frame_idx)
)
[docs]
    def _init_transition_counts(
        self,
        contig_tree,
        transition_lag_time,
    ):
        """Given the lag time get the transitions between microstates for the
        network using the sliding windows algorithm.

        This will create a directed edge between nodes that had at
        least one transition, no matter the weight.

        See the main class docstring for a description of the fields.

        Parameters
        ----------
        contig_tree : ContigTree object
            Should be unopened; it is entered as a context manager here.

        transition_lag_time : int
            The sliding window length used to generate transitions.

        Returns
        -------
        all_transitions_d : dict of edge_id : numpy.ndarray
            For each edge, the weights of every individual transition.

        weighted_counts_d : dict
            For each edge, the weight-summed transition counts.

        unweighted_counts_d : dict
            For each edge, the raw (integer) number of transitions.
        """
        # now count the transitions between the states and set those as
        # the edges between nodes. Once we set edges for a tree we don't
        # really want multiple sets of transitions on the same network,
        # so we don't provide a method to add different assignments.
        with contig_tree:
            # get the weights for the walkers so we can compute the
            # weighted transition counts
            weights = [[] for run_idx in contig_tree.wepy_h5.run_idxs]
            for idx_tup, traj_fields_d in contig_tree.wepy_h5.iter_trajs_fields(
                ["weights"], idxs=True
            ):
                run_idx, traj_idx = idx_tup
                weights[run_idx].append(np.ravel(traj_fields_d["weights"]))

            # get the transitions as trace idxs; only the endpoints of
            # each sliding window matter for a transition
            trace_transitions = []
            for window in contig_tree.sliding_windows(transition_lag_time):
                trace_transition = [window[0], window[-1]]
                # convert the window trace on the contig to a trace
                # over the runs
                trace_transitions.append(trace_transition)

        # ALERT: this is potentially a lot of data and might make the
        # object too large; be aware and maybe don't do this if things
        # are out of control.

        ## transition distributions
        # collect every individual transition weight per edge so more
        # advanced stats can be done on them later
        all_transitions_d = defaultdict(list)
        for trace_transition in trace_transitions:
            # the trace idxs of the transition's endpoints
            start = trace_transition[0]
            end = trace_transition[-1]

            # look up the node ids (assignment labels) of both endpoints
            start_assignment = self._assignments[start[0]][start[1]][start[2]]
            end_assignment = self._assignments[end[0]][end[1]][end[2]]

            edge_id = (start_assignment, end_assignment)

            # the weight of the walker that transitioned, via the trace
            # idxs of the starting frame
            weight = weights[start[0]][start[1]][start[2]]

            # record this transition weight under its edge_id
            all_transitions_d[edge_id].append(weight)

        # convert the per-edge lists in the transition dictionary to
        # numpy arrays
        all_transitions_d = {
            edge: np.array(transitions_l)
            for edge, transitions_l in all_transitions_d.items()
        }

        # free the intermediate per-edge lists eagerly
        gc.collect()

        ## sum of weighted counts
        # then get the weighted counts for those edges
        weighted_counts_d = transition_counts(
            self._assignments,
            trace_transitions,
            weights=weights,
        )

        ## Sum of unweighted counts
        # also get unweighted counts
        unweighted_counts_d = transition_counts(
            self._assignments,
            trace_transitions,
            weights=None,
        )

        return all_transitions_d, weighted_counts_d, unweighted_counts_d
# DEBUG: remove this, but account for the 'Weight' field when
# doing gexf stuff elsewhere
# # then we also want to get the transition probabilities so
# # we get the counts matrix and compute the probabilities
# # we first have to replace the keys of the counts of the
# # node_ids with the node_idxs
# node_id_to_idx_dict = self.node_id_to_idx_dict()
# self._countsmat = counts_d_to_matrix(
# {(node_id_to_idx_dict[edge[0]],
# node_id_to_idx_dict[edge[1]]) : counts
# for edge, counts in counts_d.items()})
# self._probmat = normalize_counts(self._countsmat)
# # then we add these attributes to the edges in the network
# node_idx_to_id_dict = self.node_id_to_idx_dict()
# for i_id, j_id in self._graph.edges:
# # i and j are the node idxs so we need to get the
# # actual node_ids of them
# i_idx = node_idx_to_id_dict[i_id]
# j_idx = node_idx_to_id_dict[j_id]
# # convert to a normal float and set it as an explicitly named attribute
# self._graph.edges[i_id, j_id]['transition_probability'] = \
# float(self._probmat[i_idx, j_idx])
# # we also set the general purpose default weight of
# # the edge to be this.
# self._graph.edges[i_id, j_id]['Weight'] = \
# float(self._probmat[i_idx, j_idx])
[docs]
def node_id_to_idx(self, assg_key):
"""Convert a node_id (which is the assignment value) to a canonical index.
Parameters
----------
assg_key : node_id
Returns
-------
node_idx : int
"""
return self.node_id_to_idx_dict()[assg_key]
[docs]
def node_idx_to_id(self, node_idx):
"""Convert a node index to its node id.
Parameters
----------
node_idx : int
Returns
-------
node_id : node_id
"""
return self.node_idx_to_id_dict()[node_idx]
[docs]
    def node_id_to_idx_dict(self):
        """Generate a full mapping of node_ids to node_idxs.

        Returns the internal mapping directly (not a copy); callers
        should treat it as read-only.
        """
        return self._node_idxs
[docs]
def node_idx_to_id_dict(self):
"""Generate a full mapping of node_idxs to node_ids."""
if self._node_idx_to_id_dict is None:
rev = {node_idx: node_id for node_id, node_idx in self._node_idxs.items()}
self._node_idx_to_id_dict = rev
else:
rev = self._node_idx_to_id_dict
# just reverse the dictionary and return
return rev
    @property
    def graph(self):
        """The networkx.DiGraph of the macrostate network."""
        return self._graph

    @property
    def num_states(self):
        """The number of states (nodes) in the network."""
        return len(self.graph)

    @property
    def node_ids(self):
        """A list of the node_ids."""
        return list(self.graph.nodes)

    @property
    def contig_tree(self):
        """The underlying ContigTree.

        NOTE: this is the base contig tree captured at construction
        ('contig_tree.base_contigtree'), not an HDF5-backed ContigTree.
        """
        return self._base_contig_tree

    @property
    def assg_field_key(self):
        """The string key of the field used to make macro states from the WepyHDF5 dataset.

        Raises
        ------
        MacroStateNetworkError
            If this wasn't used to construct the MacroStateNetwork.
        """
        if self._assg_field_key is None:
            raise MacroStateNetworkError("Assignments were manually defined, no key.")
        else:
            return self._assg_field_key
### Node attributes & methods
[docs]
def get_node_attributes(self, node_id):
"""Returns the node attributes of the macrostate.
Parameters
----------
node_id : node_id
Returns
-------
macrostate_attrs : dict
"""
return self.graph.nodes[node_id]
[docs]
def get_node_attribute(self, node_id, attribute_key):
"""Return the value for a specific node and attribute.
Parameters
----------
node_id : node_id
attribute_key : str
Returns
-------
node_attribute
"""
return self.get_node_attributes(node_id)[attribute_key]
[docs]
def get_nodes_attribute(self, attribute_key):
"""Get a dictionary mapping nodes to a specific attribute."""
nodes_attr = {}
for node_id in self.graph.nodes:
nodes_attr[node_id] = self.graph.nodes[node_id][attribute_key]
return nodes_attr
[docs]
def node_assignments(self, node_id):
"""Return the microstates assigned to this macrostate as a run trace.
Parameters
----------
node_id : node_id
Returns
-------
node_assignments : list of tuples of ints (run_idx, traj_idx, cycle_idx)
Run trace of the nodes assigned to this macrostate.
"""
return self.get_node_attribute(node_id, self.ASSIGNMENTS)
[docs]
def set_nodes_attribute(self, key, values_dict):
"""Set node attributes for the key and values for each node.
Parameters
----------
key : str
values_dict : dict of node_id: values
"""
for node_id, value in values_dict.items():
self.graph.nodes[node_id][key] = value
    @property
    def node_groups(self):
        """Mapping of group name -> list of member node_ids."""
        return self._node_groups
[docs]
    def set_node_group(self, group_name, node_ids):
        """Create or overwrite a named group of nodes.

        Every node gets a boolean '_groups/<group_name>' attribute
        marking whether it belongs to the group.

        Parameters
        ----------
        group_name : str
        node_ids : list of node_id
            The nodes that are members of the group.
        """
        # push these values to the nodes themselves, overwriting if
        # necessary
        self._set_group_nodes_attribute(group_name, node_ids)

        # then update the group mapping with this
        self._node_groups[group_name] = node_ids
[docs]
def _set_group_nodes_attribute(self, group_name, group_node_ids):
# the key for the attribute of the group goes in a little
# namespace prefixed with _group
group_key = "_groups/{}".format(group_name)
# make the mapping
values_map = {
node_id: True if node_id in group_node_ids else False
for node_id in self.graph.nodes
}
# then set them
self.set_nodes_attribute(group_key, values_map)
    @property
    def observables(self):
        """The list of available observable names.

        Returns the internal list directly; register new observables
        via 'set_nodes_observable'.
        """
        return self._observables
[docs]
def node_observables(self, node_id):
"""Dictionary of observables for each node_id."""
node_obs = {}
for obs_name in self.observables:
obs_key = "_observables/{}".format(obs_name)
node_obs[obs_name] = self.get_nodes_attributes(node_id, obs_key)
return node_obs
[docs]
def set_nodes_observable(self, observable_name, node_values):
# the key for the attribute of the observable goes in a little
# namespace prefixed with _observable
observable_key = "_observables/{}".format(observable_name)
self.set_nodes_attribute(observable_key, node_values)
# then add to the list of available observables
self._observables.append(observable_name)
### Edge methods
[docs]
def get_edge_attributes(self, edge_id):
"""Returns the edge attributes of the macrostate.
Parameters
----------
edge_id : edge_id
Returns
-------
edge_attrs : dict
"""
return self.graph.edges[edge_id]
[docs]
def get_edge_attribute(self, edge_id, attribute_key):
"""Return the value for a specific edge and attribute.
Parameters
----------
edge_id : edge_id
attribute_key : str
Returns
-------
edge_attribute
"""
return self.get_edge_attributes(edge_id)[attribute_key]
[docs]
def get_edges_attribute(self, attribute_key):
"""Get a dictionary mapping edges to a specific attribute."""
edges_attr = {}
for edge_id in self.graph.edges:
edges_attr[edge_id] = self.graph.edges[edge_id][attribute_key]
return edges_attr
### Layout stuff
    @property
    def layouts(self):
        """The list of available layout names."""
        return self._layouts
[docs]
def node_layouts(self, node_id):
"""Dictionary of layouts for each node_id."""
node_layouts = {}
for layout_name in self.layouts:
layout_key = "_layouts/{}".format(layout_name)
node_layouts[obs_name] = self.get_nodes_attributes(node_id, layout_key)
return node_layouts
[docs]
def set_nodes_layout(self, layout_name, node_values):
# the key for the attribute of the observable goes in a little
# namespace prefixed with _observable
layout_key = "_layouts/{}".format(layout_name)
self.set_nodes_attribute(layout_key, node_values)
# then add to the list of available observables
if layout_name not in self._layouts:
self._layouts.append(layout_name)
[docs]
def write_gexf(
self,
filepath,
exclude_node_fields=None,
exclude_edge_fields=None,
layout=None,
):
"""Writes a graph file in the gexf format of the network.
Parameters
----------
filepath : str
"""
layout_key = None
if layout is not None:
layout_key = "_layouts/{}".format(layout)
if layout not in self.layouts:
raise ValueError("Layout not found, use None for no layout")
### filter the node and edge attributes
# to do this we need to get rid of the assignments in the
# nodes though since this is not really supported or good to
# store in a gexf file which is more for visualization as an
# XML format, so we copy and modify then write the copy
gexf_graph = deepcopy(self._graph)
## Nodes
if exclude_node_fields is None:
exclude_node_fields = [self.ASSIGNMENTS]
else:
exclude_node_fields.append(self.ASSIGNMENTS)
exclude_node_fields = list(set(exclude_node_fields))
# exclude the layouts, we will set the viz manually for the layout
exclude_node_fields.extend(
["_layouts/{}".format(layout_name) for layout_name in self.layouts]
)
for node in gexf_graph:
# remove requested fields
for field in exclude_node_fields:
del gexf_graph.nodes[node][field]
# also remove the fields which are not valid gexf types
fields = list(gexf_graph.nodes[node].keys())
for field in fields:
if (
type(gexf_graph.nodes[node][field])
not in nx.readwrite.gexf.GEXF.xml_type
):
del gexf_graph.nodes[node][field]
if layout_key is not None:
# set the layout as viz attributes to this
gexf_graph.nodes[node]["viz"] = self._graph.nodes[node][layout_key]
## Edges
if exclude_edge_fields is None:
exclude_edge_fields = ["all_transitions"]
else:
exclude_edge_fields.append("all_transitions")
exclude_edge_fields = list(set(exclude_edge_fields))
# TODO: viz and layouts not supported for edges currently
#
# exclude the layouts, we will set the viz manually for the layout
# exclude_edge_fields.extend(['_layouts/{}'.format(layout_name)
# for layout_name in self.layouts])
for edge in gexf_graph.edges:
# remove requested fields
for field in exclude_edge_fields:
del gexf_graph.edges[edge][field]
# also remove the fields which are not valid gexf types
fields = list(gexf_graph.edges[edge].keys())
for field in fields:
if (
type(gexf_graph.edges[edge][field])
not in nx.readwrite.gexf.GEXF.xml_type
):
del gexf_graph.edges[edge][field]
# TODO,SNIPPET: we don't support layouts for the edges,
# but maybe we could
# if layout_key is not None:
# # set the layout as viz attributes to this
# gexf_graph.nodes[node]['viz'] = self._graph.nodes[node][layout_key]
# then write this filtered gexf to file
nx.write_gexf(gexf_graph, filepath)
[docs]
def nodes_to_records(
self,
extra_attributes=("_observables/total_weight",),
):
if extra_attributes is None:
extra_attributes = []
# keys which always go into the records
keys = [
"num_samples",
"node_idx",
]
# add all the groups to the keys
keys.extend(["_groups/{}".format(key) for key in self.node_groups.keys()])
# add the observables
keys.extend(["_observables/{}".format(obs) for obs in self.observables])
recs = []
for node_id in self.graph.nodes:
rec = {"node_id": node_id}
# the keys which are always there
for key in keys:
rec[key] = self.get_node_attribute(node_id, key)
# the user defined ones
for extra_key in extra_attributes:
rec[key] = self.get_node_attribute(node_id, extra_key)
recs.append(rec)
return recs
[docs]
def nodes_to_dataframe(
self,
extra_attributes=("_observables/total_weight",),
):
"""Make a dataframe of the nodes and their attributes.
Not all attributes will be added as they are not relevant to a
table style representation anyhow.
The columns will be:
- node_id
- node_idx
- num samples
- groups (as booleans) which is anything in the '_groups' namespace
- observables : anything in the '_observables' namespace and
will assume to be scalars
And anything in the 'extra_attributes' argument.
"""
# TODO: set the column order
# col_order = []
return pd.DataFrame(self.nodes_to_records(extra_attributes=extra_attributes))
[docs]
def edges_to_records(
self,
extra_attributes=None,
):
"""Make a dataframe of the nodes and their attributes.
Not all attributes will be added as they are not relevant to a
table style representation anyhow.
The columns will be:
- edge_id
- source
- target
- weighted_counts
- unweighted_counts
"""
if extra_attributes is None:
extra_attributes = []
keys = [
"weighted_counts",
"unweighted_counts",
]
recs = []
for edge_id in self.graph.edges:
rec = {
"edge_id": edge_id,
"source": edge_id[0],
"target": edge_id[1],
}
for key in keys:
rec[key] = self.graph.edges[edge_id][key]
# the user defined ones
for extra_key in extra_attributes:
rec[key] = self.get_node_attribute(node_id, extra_key)
recs.append(rec)
return recs
[docs]
def edges_to_dataframe(
self,
extra_attributes=None,
):
"""Make a dataframe of the nodes and their attributes.
Not all attributes will be added as they are not relevant to a
table style representation anyhow.
The columns will be:
- edge_id
- source
- target
- weighted_counts
- unweighted_counts
"""
return pd.DataFrame(self.edges_to_records(extra_attributes=extra_attributes))
[docs]
def node_map(self, func, map_func=map):
"""Map a function over the nodes.
The function should take as its first argument a node_id and
the second argument a dictionary of the node attributes. This
will not give access to the underlying trajectory data in the
HDF5, to do this use the 'node_fields_map' function.
Extra args not supported use 'functools.partial' to make
functions with arguments for all data.
Parameters
----------
func : callable
The function to map over the nodes.
map_func : callable
The mapping function, implementing the `map` interface
Returns
-------
node_values : dict of node_id : values
The mapping of node_ids to the values computed by the mapped func.
"""
# wrap the function so that we can pass through the node_id
def func_wrapper(args):
node_id, node_attrs = args
return node_id, func(node_attrs)
# zip the node_ids with the node attributes as an iterator
node_attr_it = (
(node_id, {**self.get_node_attributes(node_id), "node_id": node_id})
for node_id in self.graph.nodes
)
return {
node_id: value for node_id, value in map_func(func_wrapper, node_attr_it)
}
[docs]
    def edge_attribute_to_matrix(
        self,
        attribute_key,
        fill_value=np.nan,
    ):
        """Convert scalar edge attributes to an assymetric matrix.

        This will always return matrices of size (num_nodes,
        num_nodes).

        Additionally, matrices for the same network will always have
        the same indexing, which is according to the 'node_idx'
        attribute of each node.

        For example if you have a matrix like:

        >>> msn = MacroStateNetwork(...)
        >>> mat = msn.edge_attribute_to_matrix('unweighted_counts')

        Then, for example, the node with node_id of '10' having a
        'node_idx' of 0 will always be the first element for each
        dimension. Using this example the self edge '10'->'10' can be
        accessed from the matrix like:

        >>> mat[0,0]

        For another node ('node_id' '25') having 'node_idx' 4, we can
        get the edge from '10'->'25' like:

        >>> mat[0,4]

        This is because 'node_id' does not necessarily have to be an
        integer, and even if they are integers they don't necessarily
        have to be a contiguous range from 0 to N.

        To get the 'node_id' for a 'node_idx' use the method
        'node_idx_to_id'.

        >>> msn.node_idx_to_id(0)
        === 10

        Parameters
        ----------
        attribute_key : str
            The key of the edge attribute the matrix should be made of.

        fill_value : Any
            The value to put in the array for non-existent edges. Must
            be a numpy dtype compatible with the dtype of the
            attribute value.

        Returns
        -------
        edge_matrix : numpy.ndarray
            Assymetric matrix of dim (n_macrostates,
            n_macrostates). The 0-th axis corresponds to the 'source'
            node and the 1-st axis corresponds to the 'target' nodes,
            i.e. the dimensions mean: (source, target).
        """
        # probe one edge to learn the attribute's dtype.
        # NOTE(review): raises IndexError on a network with no edges --
        # confirm whether empty networks need support
        test_edge_id = list(self.graph.edges.keys())[0]
        test_attr_value = self.get_edge_attribute(
            test_edge_id,
            attribute_key,
        )

        # "duck type" check: np.dtype raises if the attribute's type
        # cannot be used as an array element type
        dt = np.dtype(type(test_attr_value))

        # allocate the matrix, initializing every element to fill_value
        mat = np.full(
            (self.num_states, self.num_states),
            fill_value,
            dtype=dt,
        )

        # get a dictionary of (node_id, node_id) -> value
        edges_attr_d = self.get_edges_attribute(attribute_key)

        # the mapping id->idx used to place values in the matrix
        node_id_to_idx_dict = self.node_id_to_idx_dict()

        # re-key the edge attribute dict by (source_idx, target_idx)
        edges_idx_attr_d = {}
        for edge, value in edges_attr_d.items():
            idx_edge = (node_id_to_idx_dict[edge[0]], node_id_to_idx_dict[edge[1]])
            edges_idx_attr_d[idx_edge] = value

        # assign to the array: rows are sources, columns are targets
        for trans, value in edges_idx_attr_d.items():
            source = trans[0]
            target = trans[1]
            mat[source, target] = value

        return mat
[docs]
class MacroStateNetwork:
"""Provides an abstraction over weighted ensemble data in the form of
a kinetically connected network.
The MacroStateNetwork refers to any grouping of the so called
"micro" states that were observed during simulation,
i.e. trajectory frames, and not necessarily in the usual sense
used in statistical mechanics. Although it is the perfect vehicle
for working with such macrostates.
Because of the way walker trajectories are generated in weighted
ensemble, there is a natural way to generate the edges between the
macrostate nodes in the network. These edges are determined
automatically and a lag
time can also be specified, which is useful in the creation of
Markov State Models.
This class provides transparent access to an underlying 'WepyHDF5'
dataset. If you wish to have a simple serializable network that
does not reference it, see the 'BaseMacroStateNetwork' class, which
you can construct standalone or access the instance attached as
the 'base_network' attribute of an object of this class.
For a description of all of the default node and edge attributes
which are set after construction see the docstring for the
'BaseMacroStateNetwork' class docstring.
Warnings
--------
This class is not serializable as it references a 'WepyHDF5'
object. Either construct a 'BaseMacroStateNetwork' or use the
attached instance in the 'base_network' attribute.
"""
    def __init__(
        self,
        contig_tree,
        base_network=None,
        assg_field_key=None,
        assignments=None,
        transition_lag_time=2,
    ):
        """For documentation of the following arguments see the constructor
        docstring of the 'BaseMacroStateNetwork' class:

        - contig_tree
        - assg_field_key
        - assignments
        - transition_lag_time

        The other arguments are documented here. This is primarily the
        optional 'base_network' argument. This is a
        'BaseMacroStateNetwork' instance, which allows you to
        associate it with a 'WepyHDF5' dataset for access to the
        microstate data etc.

        Parameters
        ----------
        base_network : BaseMacroStateNetwork object
            An already constructed network, which will avoid
            recomputing all in-memory network values again for this
            object.
        """
        # NOTE(review): 'closed' is set but no open/close management is
        # visible in this chunk -- presumably used by resource-handling
        # methods elsewhere; confirm
        self.closed = True
        self._contig_tree = contig_tree
        self._wepy_h5 = self._contig_tree.wepy_h5

        # either adopt an already-built base network, or build one from
        # the constructor arguments
        if base_network is not None:
            assert isinstance(base_network, BaseMacroStateNetwork)

            self._set_base_network_to_self(base_network)

        else:
            new_network = BaseMacroStateNetwork(
                contig_tree,
                assg_field_key=assg_field_key,
                assignments=assignments,
                transition_lag_time=transition_lag_time,
            )

            self._set_base_network_to_self(new_network)
[docs]
def _set_base_network_to_self(self, base_network):
self._base_network = base_network
# then make references to this for the attributes we need
# attributes
self._graph = self._base_network._graph
self._assg_field_key = self._base_network._assg_field_key
self._node_idxs = self._base_network._node_idxs
self._node_idx_to_id_dict = self._base_network._node_idx_to_id_dict
self._transition_lag_time = self._base_network._transition_lag_time
# DEBUG: remove once tested
# self._probmat = self._base_network._probmat
# self._countsmat = self._base_network._countsmat
# functions
self.node_id_to_idx = self._base_network.node_id_to_idx
self.node_idx_to_id = self._base_network.node_idx_to_id
self.node_id_to_idx_dict = self._base_network.node_id_to_idx_dict
self.node_idx_to_id_dict = self._base_network.node_idx_to_id_dict
self.get_node_attributes = self._base_network.get_node_attributes
self.get_node_attribute = self._base_network.get_node_attribute
self.get_nodes_attribute = self._base_network.get_nodes_attribute
self.node_assignments = self._base_network.node_assignments
self.set_nodes_attribute = self._base_network.set_nodes_attribute
self.get_edge_attributes = self._base_network.get_edge_attributes
self.get_edge_attribute = self._base_network.get_edge_attribute
self.get_edges_attribute = self._base_network.get_edges_attribute
self.node_groups = self._base_network.node_groups
self.set_node_group = self._base_network.set_node_group
self._set_group_nodes_attribute = self._base_network._set_group_nodes_attribute
self.observables = self._base_network.observables
self.node_observables = self._base_network.node_observables
self.set_nodes_observable = self._base_network.set_nodes_observable
self.nodes_to_records = self._base_network.nodes_to_records
self.nodes_to_dataframe = self._base_network.nodes_to_dataframe
self.edges_to_records = self._base_network.edges_to_records
self.edges_to_dataframe = self._base_network.edges_to_dataframe
self.node_map = self._base_network.node_map
self.edge_attribute_to_matrix = self._base_network.edge_attribute_to_matrix
self.write_gexf = self._base_network.write_gexf
[docs]
def open(self, mode=None):
if self.closed:
self.wepy_h5.open(mode=mode)
self.closed = False
else:
raise IOError("This file is already open")
[docs]
def close(self):
self.wepy_h5.close()
self.closed = True
def __enter__(self):
self.wepy_h5.__enter__()
self.closed = False
return self
def __exit__(self, exc_type, exc_value, exc_tb):
self.wepy_h5.__exit__(exc_type, exc_value, exc_tb)
self.close()
# from the Base class
@property
def graph(self):
"""The networkx.DiGraph of the macrostate network."""
return self._graph
@property
def num_states(self):
"""The number of states in the network."""
return len(self.graph)
@property
def node_ids(self):
"""A list of the node_ids."""
return list(self.graph.nodes)
@property
def assg_field_key(self):
"""The string key of the field used to make macro states from the WepyHDF5 dataset.
Raises
------
MacroStateNetworkError
If this wasn't used to construct the MacroStateNetwork.
"""
if self._assg_field_key is None:
raise MacroStateNetworkError("Assignments were manually defined, no key.")
else:
return self._assg_field_key
# @property
# def countsmat(self):
# """Return the transition counts matrix of the network.
# Raises
# ------
# MacroStateNetworkError
# If no lag time was given.
# """
# if self._countsmat is None:
# raise MacroStateNetworkError("transition counts matrix not calculated")
# else:
# return self._countsmat
# @property
# def probmat(self):
# """Return the transition probability matrix of the network.
# Raises
# ------
# MacroStateNetworkError
# If no lag time was given.
# """
# if self._probmat is None:
# raise MacroStateNetworkError("transition probability matrix not set")
# else:
# return self._probmat
# unique to the HDF5 holding one
@property
def base_network(self):
return self._base_network
@property
def wepy_h5(self):
"""The WepyHDF5 source object for which the contig tree is being constructed."""
return self._wepy_h5
[docs]
def state_to_mdtraj(self, node_id, alt_rep=None):
"""Generate an mdtraj.Trajectory object from a macrostate.
By default uses the "main_rep" in the WepyHDF5
object. Alternative representations of the topology can be
specified.
Parameters
----------
node_id : node_id
alt_rep : str
(Default value = None)
Returns
-------
traj : mdtraj.Trajectory
"""
return self.wepy_h5.trace_to_mdtraj(
self.base_network.node_assignments(node_id), alt_rep=alt_rep
)
[docs]
def state_to_traj_fields(self, node_id, alt_rep=None):
return self.states_to_traj_fields([node_id], alt_rep=alt_rep)
[docs]
def states_to_traj_fields(self, node_ids, alt_rep=None):
node_assignments = []
for node_id in node_ids:
node_assignments.extend(self.base_network.node_assignments(node_id))
# get the right fields
rep_path = self.wepy_h5._choose_rep_path(alt_rep)
fields = [rep_path, "box_vectors"]
return self.wepy_h5.get_trace_fields(node_assignments, fields)
[docs]
def get_node_fields(self, node_id, fields):
"""Return the trajectory fields for all the microstates in the
specified macrostate.
Parameters
----------
node_id : node_id
fields : list of str
Field name to retrieve.
Returns
-------
fields : dict of str: array_like
A dictionary mapping the names of the fields to an array of the field.
Like fields of a trace.
"""
node_trace = self.base_network.node_assignments(node_id)
# use the node_trace to get the weights from the HDF5
fields_d = self.wepy_h5.get_trace_fields(node_trace, fields)
return fields_d
[docs]
def iter_nodes_fields(self, fields):
"""Iterate over all nodes and return the field values for all the
microstates for each.
Parameters
----------
fields : list of str
Returns
-------
nodes_fields : dict of node_id: (dict of field: array_like)
A dictionary with an entry for each node.
Each node has it's own dictionary of node fields for each microstate.
"""
nodes_d = {}
for node_id in self.graph.nodes:
fields_d = self.base_network.get_node_fields(node_id, fields)
nodes_d[node_id] = fields_d
return nodes_d
[docs]
def microstate_weights(self):
"""Returns the weights of each microstate on the basis of macrostates.
Returns
-------
microstate_weights : dict of node_id: ndarray
"""
node_weights = {}
for node_id in self.graph.nodes:
# get the trace of the frames in the node
node_trace = self.base_network.node_assignments(node_id)
# use the node_trace to get the weights from the HDF5
trace_weights = self.wepy_h5.get_trace_fields(node_trace, ["weights"])[
"weights"
]
node_weights[node_id] = trace_weights
return node_weights
[docs]
def macrostate_weights(self):
"""Compute the total weight of each macrostate.
Returns
-------
macrostate_weights : dict of node_id: float
"""
macrostate_weights = {}
microstate_weights = self.microstate_weights()
for node_id, weights in microstate_weights.items():
macrostate_weights[node_id] = float(sum(weights)[0])
return macrostate_weights
[docs]
def set_macrostate_weights(self):
"""Compute the macrostate weights and set them as node attributes
'total_weight'."""
self.base_network.set_nodes_observable(
"total_weight",
self.macrostate_weights(),
)
[docs]
def node_fields_map(self, func, fields, map_func=map):
"""Map a function over the nodes and microstate fields.
The function should take as its arguments:
1. node_id
2. dictionary of all the node attributes
3. fields dictionary mapping traj field names. (The output of
`MacroStateNetwork.get_node_fields`)
This *will* give access to the underlying trajectory data in
the HDF5 which can be requested with the `fields`
argument. The behaviour is very similar to the
`WepyHDF5.compute_observable` function with the added input
data to the mapped function being all of the macrostate node
attributes.
Extra args not supported use 'functools.partial' to make
functions with arguments for all data.
Parameters
----------
func : callable
The function to map over the nodes.
fields : iterable of str
The microstate (trajectory) fields to provide to the mapped function.
map_func : callable
The mapping function, implementing the `map` interface
Returns
-------
node_values : dict of node_id : values
The mapping of node_ids to the values computed by the mapped func.
Returns
-------
node_values : dict of node_id : values
Dictionary mapping nodes to the computed values from the
mapped function.
"""
# wrap the function so that we can pass through the node_id
def func_wrapper(args):
node_id, node_attrs, node_fields = args
# evaluate the wrapped function
result = func(
node_id,
node_attrs,
node_fields,
)
return node_id, result
# zip the node_ids with the node attributes as an iterator
node_attr_fields_it = (
(
node_id,
{**self.get_node_attributes(node_id), "node_id": node_id},
self.get_node_fields(node_id, fields),
)
for node_id in self.graph.nodes
)
# map the inputs to the wrapped function and return as a
# dictionary for the nodes
return {
node_id: value
for node_id, value in map_func(func_wrapper, node_attr_fields_it)
}