import copy
import cPickle
import itertools
#import logging
import operator
import matplotlib
import networkx
import numpy
import pylab
import sys
class CommunityDictionary(dict):
    '''A dict subclass that adds no behaviour of its own.'''

def find_communities(graph):
    '''Split a graph into its connected components.

    Arguments:
    graph -- graph to split

    Return:
    communities -- a list with one entry per connected component, each entry
    being the node collection returned by networkx.node_connected_component
    '''
    remaining = set(graph.nodes())
    components = []
    while len(remaining) > 0:
        # Any still-unassigned node seeds the next component.
        seed = remaining.pop()
        component = networkx.node_connected_component(graph, seed)
        remaining -= set(component)
        components.append(component)
    return components

def _community_mapper(graph, communities):
    '''Return a dictionary of (node, community_key) pairs.

    Requires that communities is a list-of-lists.
    '''
    community_mapper = dict([(node, [node in com for com in 
                                     communities].index(True))
                             for node in graph.nodes()])
    return community_mapper

def _compute_edge_matrix(graph, communities):
    '''Count the edges inside and between communities.

    Returns a square numpy array indexed by community position.  An edge
    with both ends in community c adds 1 to [c, c]; an edge joining
    communities c and d adds 1 to both [c, d] and [d, c].
    '''
    n_com = len(communities)
    counts = numpy.zeros((n_com, n_com))
    owner = _community_mapper(graph, communities)
    for source, target in graph.edges():
        row, col = owner[source], owner[target]
        counts[row, col] += 1
        if row != col:
            # Keep the matrix symmetric for between-community edges.
            counts[col, row] += 1
    return counts

def _compute_weighted_edge_matrix(graph, communities):
    '''Sum edge weights inside and between communities.

    Like _compute_edge_matrix, but each edge contributes its 'weight'
    attribute instead of 1.  Edges inside community c add to [c, c]; edges
    joining c and d add to both [c, d] and [d, c].
    '''
    size = len(communities)
    totals = numpy.zeros((size, size))
    node_community = _community_mapper(graph, communities)
    for left, right in graph.edges():
        row = node_community[left]
        col = node_community[right]
        weight = graph.get_edge_data(left, right)['weight']
        totals[row, col] += weight
        if row != col:
            totals[col, row] += weight
    return totals

def _compute_weighted_betweenness_edge_matrix(graph, communities):
    '''Sum the edge betweenness of edges inside and between communities.

    Betweenness comes from networkx.edge_betweenness(normalized=True,
    weight='weight').  An edge inside community c adds its score to [c, c];
    an edge joining c and d adds it to both [c, d] and [d, c].
    '''
    size = len(communities)
    totals = numpy.zeros((size, size))
    scores = networkx.edge_betweenness(graph, normalized=True, weight='weight')
    node_community = _community_mapper(graph, communities)
    for (left, right), score in scores.items():
        row, col = node_community[left], node_community[right]
        totals[row, col] += score
        if row != col:
            totals[col, row] += score
    return totals

def compute_modularity(graph, communities):
    '''Compute the Newman-Girvan modularity Q of a community structure.

    Q = Tr(e) - ||e^2||.  Here e[i, i] is the fraction of the graph's edges
    inside community i, and e[i, j] (i != j) is *half* the fraction of edges
    joining communities i and j, so that the entries of e sum to 1.

    Arguments:
    graph -- graph whose edges are counted
    communities -- sequence of node collections partitioning graph's nodes

    Returns:
    Q -- the modularity; 0.0 for a graph without edges
    '''
    counts = _compute_edge_matrix(graph, communities)
    num_edges = graph.number_of_edges()
    if num_edges == 0:
        # No edges: nothing to normalise by, and no structure to score.
        return 0.0
    # _compute_edge_matrix stores each between-community edge at full count
    # in both [i, j] and [j, i].  Adding the diagonal once more and halving
    # everything gives internal/m on the diagonal and between/(2m) off it.
    e = (counts + numpy.diag(numpy.diag(counts))) / (2.0 * num_edges)
    Q = numpy.trace(e) - numpy.dot(e, e).sum()
    return Q

def Girvan_Newman_algorithm(graph, num_iterations, community_graph=None,
                           target_num_communities=None, logfile=None):
    '''Remove up to num_iterations edges, Girvan-Newman style.

    Each iteration:
    - removes the edge with the highest weighted edge betweenness from a
      working copy of the community graph;
    - splits the copy into connected components;
    - scores that split against the original graph with compute_modularity.

    Arguments:
    graph -- fine-grained graph; not modified
    num_iterations -- maximum number of algorithm iterations to run
    community_graph -- community graph to start with, allowing the algorithm
    to be restarted; it is copied, so the caller's object is not modified
    target_num_communities -- terminate the algorithm as soon as the working
    graph has exactly this many connected components
    logfile -- optional writable file object; one progress line is written
    per iteration

    Returns:
    max_community_graph -- copy of the community graph with the highest
    modularity (a copy of the starting graph if no iteration ran)
    last_community_graph -- community graph from the last iteration
    modularity -- list of modularity values, one per completed iteration

    Iteration also stops early once the working graph has no edges left.
    '''
    source = graph if community_graph is None else community_graph
    community_graph = source.copy()
    # Must be a copy: community_graph keeps losing edges after this point.
    max_community_graph = community_graph.copy()
    best_modularity = None
    modularity = []
    for idx in range(num_iterations):
        if community_graph.number_of_edges() == 0:
            break
        message = 'Girvan-Newman iteration %d of %d' % (idx + 1, num_iterations)
        print(message)
        if logfile is not None:
            logfile.write(message + ' \n')
        sys.stdout.flush()
        betweenness = networkx.edge_betweenness(community_graph, weight='weight')
        max_edge = max(betweenness.items(), key=operator.itemgetter(1))[0]
        community_graph.remove_edge(*max_edge)
        communities = find_communities(community_graph)
        current_modularity = compute_modularity(graph, communities)
        modularity.append(current_modularity)
        if best_modularity is None or current_modularity > best_modularity:
            best_modularity = current_modularity
            max_community_graph = community_graph.copy()
        # find_communities returns one entry per connected component.
        if target_num_communities and len(communities) == target_num_communities:
            break
    return (max_community_graph, community_graph, modularity)

def draw_communities(graph, communities, ax=None, **kwargs):
    '''Draw the graph with each node colored by its community.

    Arguments:
    graph -- graph to draw
    communities -- list of node collections (keyed by position) or dict of
    key -> node collection, e.g. a remapped map from map_communities
    ax -- axes to draw into; a new pylab.axes() is created when None
    kwargs -- forwarded to networkx.draw

    Raises:
    ValueError -- if a node of graph is in none of the communities
    '''
    # Default is a full figure
    if ax is None:
        ax = pylab.axes()
    # Normalise to key -> community so dicts work as well as lists.  The old
    # code iterated a dict's keys, e.g. testing `node in 'c00'`.
    if not isinstance(communities, dict):
        communities = dict(enumerate(communities))
    # The first community (in iteration order) that lists a node wins.
    node_to_key = {}
    for key, community in communities.items():
        for node in community:
            node_to_key.setdefault(node, key)
    community_colors = _community_colors(communities)
    node_colors = []
    for node in graph.nodes():
        if node not in node_to_key:
            raise ValueError('node %r is not in any community' % (node,))
        node_colors.append(community_colors[node_to_key[node]])
    networkx.draw(graph, node_color=node_colors, ax=ax, **kwargs)

def coarsegrain_communities(graph, communities):
    '''Coarse grain the graph into one node per non-empty community.

    Arguments:
    graph -- fine-grained graph
    communities -- list of node collections (keyed by position) or dict of
    key -> node collection

    Returns:
    cg_graph -- networkx.Graph with:
    - one node per non-empty community, with attributes size (member count)
      and members;
    - an edge between two communities when the fine-grained edges joining
      them have non-zero summed normalized betweenness, that sum stored as
      'weight'.
    '''
    cg_graph = networkx.Graph()
    # Create community nodes
    if not isinstance(communities, dict):
        communities = dict(enumerate(communities))
    for key, community in communities.items():
        if community:
            cg_graph.add_node(key, size=len(community), members=community)
    # Snapshot keys and values once so that row/column idx of the matrix
    # corresponds to keys[idx].  Avoids rebuilding keys() inside the
    # double loop.
    keys = list(communities.keys())
    members = list(communities.values())
    # Alternative edge weightings: _compute_edge_matrix (edge counts) or
    # _compute_weighted_edge_matrix (summed 'weight' attributes).
    e = _compute_weighted_betweenness_edge_matrix(graph, members)
    for idx_i, key_i in enumerate(keys):
        for idx_j in range(idx_i + 1, len(keys)):
            eij = e[idx_i, idx_j]
            if eij:
                # Keyword attribute: accepted by both networkx 1.x and 2.x.
                cg_graph.add_edge(key_i, keys[idx_j], weight=eij)
    return cg_graph

def draw_coarsegrain(graph, node_factor=5.0, edge_factor=200.0, colors=None,
                     **kwargs):
    '''Draw the coarse-grained graph.

    Each node's 'size' attribute times node_factor sets its size.  Each
    edge's 'weight' attribute times edge_factor sets its width.

    Arguments:
    graph -- graph whose nodes carry 'size' and whose edges carry 'weight'
    node_factor -- multiplier applied to the node sizes
    edge_factor -- multiplier applied to the edge weights
    colors -- node colors in graph.nodes() order; when None every node gets
    its own evenly spaced color from _community_colors
    kwargs -- forwarded to networkx.draw
    '''
    # dtype=float: scaling an integer array in place by a float factor either
    # truncates or is rejected by numpy's same_kind casting rule.
    node_sizes = numpy.array([attrs['size'] for (_, attrs)
                              in graph.nodes(data=True)], dtype=float)
    node_sizes = node_sizes * node_factor
    edge_widths = numpy.array([attrs['weight'] for (_, _, attrs)
                               in graph.edges(data=True)], dtype=float)
    edge_widths = edge_widths * edge_factor
    if colors is None:
        nodes = list(graph.nodes())
        # Key the palette by node label, so non-integer labels (such as the
        # string keys produced by map_communities) work too.
        palette = _community_colors(dict.fromkeys(nodes))
        colors = [palette[node] for node in nodes]
    networkx.draw(graph, node_color=colors, node_size=node_sizes,
                  width=edge_widths, **kwargs)

def replace_weights_and_draw_coarsegrain(graph, mutinf_between_communities_matrix,
                                         node_factor=5.0, edge_factor=5.0, colors=None,
                                         **kwargs):
    '''Overwrite the edge weights from a matrix, then draw the coarse-grained graph.

    Side effect: every edge (i, j) of graph gets its 'weight' attribute set to
    mutinf_between_communities_matrix[i][j].  Node labels must therefore be
    valid indices into the matrix.

    Drawing is delegated to draw_coarsegrain with the same node_factor,
    edge_factor, colors and kwargs.  The previous inline copy of that code
    computed and scaled the node sizes twice.
    '''
    for (i, j) in graph.edges():
        graph[i][j]['weight'] = mutinf_between_communities_matrix[i][j]
    draw_coarsegrain(graph, node_factor=node_factor, edge_factor=edge_factor,
                     colors=colors, **kwargs)

def print_coarsegrain(graph, myfilename="cg_edgeweights.txt", **kwargs):
    '''Write the coarse-grained edge weights to a text file.

    One line per edge, in graph.edges() order: "i,j,weight", each field
    formatted with str().

    Arguments:
    graph -- graph whose edges carry a 'weight' attribute
    myfilename -- output path; overwritten if it already exists
    kwargs -- accepted but unused
    '''
    # `with` guarantees the file is flushed and closed; the old version
    # never closed it.
    with open(myfilename, 'w') as myfile:
        for (i, j) in graph.edges():
            myfile.write(str(i) + "," + str(j) + "," +
                         str(graph[i][j]['weight']) + "\n")

def _community_colors(communities, cmap=pylab.cm.gist_ncar):
    '''Assign every community an evenly spaced color from cmap.

    Arguments:
    communities -- dict of key -> community, or any other iterable of
    communities (keyed by position 0..n-1)
    cmap -- callable mapping a float in [0, 1) to a color

    Returns:
    dict mapping each community key to cmap(position / number_of_communities),
    where position follows the dict's iteration order
    '''
    if isinstance(communities, dict):
        keyed = communities
    else:
        keyed = dict(enumerate(communities))
    count = len(keyed)
    palette = {}
    for position, key in enumerate(keyed):
        palette[key] = cmap(float(position) / count)
    return palette

def pickle_data(graph, communities, pickle_path='communities.pickle',
                cmap=pylab.cm.gist_ncar):
    '''Pickle (graph, communities, colors) to pickle_path for visualization.

    colors is _community_colors(communities, cmap=cmap).  The file can be read
    back with unpickle_data.
    '''
    with open(pickle_path, 'w') as out:
        colors = _community_colors(communities, cmap=cmap)
        payload = (graph, communities, colors)
        cPickle.dump(payload, out)

def unpickle_data(pickle_path):
    '''Load the (graph, communities, colors) triple written by pickle_data.

    NOTE: unpickling can execute arbitrary code; only load trusted files.
    '''
    with open(pickle_path, 'r') as stream:
        graph, communities, colors = cPickle.load(stream)
    return graph, communities, colors

def compare_communities(com_a, com_b, plot=False):
    '''Compare two sets of communities by fractional intersection.

    Entry [a, b] of the result is len(A & B) / ((len(A) + len(B)) / 2.0).
    A is the a-th community of com_a and B is the b-th community of com_b.
    Identical communities score 1.0, disjoint ones 0.0, and two empty
    communities 0.0.

    Arguments:
    com_a -- first communities, length M: a sequence (rows in order) or a
             dict (rows in the dict's iteration order)
    com_b -- second communities, length N, same conventions as com_a
    plot -- if True, also draw the matrix with pylab.pcolor on the current
            axes, with the community keys as tick labels

    Returns:
    frac_intersection -- MxN numpy array of fractional overlaps
    '''
    if not isinstance(com_a, dict):
        com_a = dict(enumerate(com_a))
    if not isinstance(com_b, dict):
        com_b = dict(enumerate(com_b))
    # Build each set once (rather than once per pair).  Keep the original
    # lengths so the mean size is computed exactly as before.
    entries_a = [(set(community), len(community)) for community in com_a.values()]
    entries_b = [(set(community), len(community)) for community in com_b.values()]
    frac_intersection = numpy.zeros((len(entries_a), len(entries_b)), dtype=float)
    for idx_a, (set_a, len_a) in enumerate(entries_a):
        for idx_b, (set_b, len_b) in enumerate(entries_b):
            mean_size = (len_a + len_b) / 2.0
            # Empty-vs-empty comparisons stay 0.0 instead of dividing by zero
            if mean_size:
                frac_intersection[idx_a, idx_b] = len(set_a & set_b) / mean_size
    if plot:
        cmap = pylab.cm.bone_r
        ca = pylab.gca()
        pylab.pcolor(frac_intersection, cmap=cmap)
        ca.set_xticks(numpy.arange(len(com_b)) + 0.5)
        ca.set_xticklabels(list(com_b.keys()))
        ca.set_yticks(numpy.arange(len(com_a)) + 0.5)
        ca.set_yticklabels(list(com_a.keys()))
        ca.set_xlim(0, len(com_b))
        ca.set_ylim(0, len(com_a))
        # Minor ticks on the cell boundaries provide the grid lines.  A
        # matplotlib locator should not be shared between axes, so each axis
        # gets its own instance.
        ca.yaxis.set_minor_locator(matplotlib.ticker.MultipleLocator(1))
        ca.xaxis.set_minor_locator(matplotlib.ticker.MultipleLocator(1))
        ca.grid(which='minor')
    return frac_intersection

class BasisSet(dict):
    '''Dictionary of reference communities that also hands out new keys.

    Key formats:
    - common communities: 'cNN'
    - unique communities: 'uNN'
    - splits of an existing key K: 'K_NN'
    NN is a zero-padded counter kept per kind (or per parent key for splits).
    Every key handed out is registered so that it can itself be split later.
    '''

    def __init__(self, *args):
        super(BasisSet, self).__init__(*args)
        self._common_count = 0
        self._unique_count = 0
        # key -> number of split keys derived from it so far
        self._split_counts = {}

    def next_common_key(self):
        '''Return the next 'cNN' key and register it for splitting.'''
        new_key = 'c%02d' % (self._common_count,)
        self._common_count = self._common_count + 1
        self._split_counts[new_key] = 0
        return new_key

    def next_split_key(self, original_key):
        '''Return the next '<original_key>_NN' key derived from original_key.

        Raises KeyError if original_key was not handed out by this object.
        '''
        suffix = self._split_counts[original_key]
        self._split_counts[original_key] = suffix + 1
        new_key = '%s_%02d' % (original_key, suffix)
        self._split_counts[new_key] = 0
        return new_key

    def next_unique_key(self):
        '''Return the next 'uNN' key and register it for splitting.'''
        new_key = 'u%02d' % (self._unique_count,)
        self._unique_count = self._unique_count + 1
        self._split_counts[new_key] = 0
        return new_key

def map_communities(community_maps, common_cutoff=0.7, 
                    similarity_cutoff=0.7, splitting_cutoff=0.7):
    '''Map communities to a common numbering scheme using fractional
       intersection.

    Works in three stages:

    1. Common communities.  For every combination that takes one community
       from each map, compute the size of their joint intersection divided
       by their mean size.  Each combination scoring >= common_cutoff adds
       its joint intersection to the basis set under a 'cNN' key.  The
       communities involved are then excluded from stage 2.
    2. Remaining non-empty communities, largest first.  Each is intersected
       with the current basis set:
       - if some basis member gives
         len(intersection) / mean(len(community), len(member))
         >= similarity_cutoff, the community counts as already represented
         and nothing is added;
       - otherwise, if an overlap (checked from largest down) covers
         >= splitting_cutoff of the community, that overlap is added under a
         split key '<member key>_NN';
       - otherwise the community itself is added under a 'uNN' key.
    3. Every community of every map is labelled with the basis key it has
       the highest fractional intersection with (see compare_communities).

    Arguments:
    community_maps -- Tuple of community maps to remap; each map is a
    sequence of node collections (converted to sets here)
    common_cutoff -- Fractional intersection cutoff for stage 1
    similarity_cutoff -- Fractional intersection cutoff for the stage 2
    similarity test
    splitting_cutoff -- Fraction of a community that must fall inside one
    basis member for the stage 2 split test

    Returns:
    remapped_communities -- list, in the same order as community_maps, of
    dicts mapping basis key -> community (as a set).  NOTE: if two
    communities of one map pick the same basis key, the later one
    overwrites the earlier in that dict.
    basis_set -- the BasisSet of reference communities built above
    '''
    # Make sure that all of the community maps are composed of sets
    
    community_maps = [[set(community) for community in com_map] for com_map in community_maps]
    basis_set = BasisSet()
    # Find the common communities: common_intersection has one axis per map,
    # and each cell scores one combination (one community from each map).
    community_lengths = [len(com) for com in community_maps]
    common_intersection = numpy.zeros(community_lengths, dtype=float)
    for indices in itertools.product(*[range(com_len) for com_len
                                       in community_lengths]):
        current_maps = [community_maps[map_idx][com_idx] for (map_idx, com_idx)
                        in enumerate(indices)]
        complete_intersection = current_maps[0].intersection(*current_maps[1:])
        average_size = numpy.mean([len(community) for community in current_maps])
        common_intersection[indices] = len(complete_intersection) / average_size
    # numpy.where gives a tuple with one index array per map (per axis).
    # It is used again below to blank out the common communities.
    common_indices = numpy.where(common_intersection >= common_cutoff)
    # NOTE(review): the original comment here read "Make sure there are no
    # duplicates!!!", but nothing below checks that a community takes part
    # in only one common combination -- confirm this cannot happen.
    common_keys = []
    for indices in zip(*numpy.where(common_intersection >= common_cutoff)):
        # Always true given the numpy.where filter above; kept as a guard.
        if common_intersection[indices] >= common_cutoff:
            key = basis_set.next_common_key()
            current_maps = [community_maps[map_idx][com_idx] for (map_idx, com_idx)
                            in enumerate(indices)]
            basis_set[key] = current_maps[0].intersection(*current_maps[1:])
            # NOTE(review): common_keys is collected but never used below.
            common_keys.append(key)
    # Clear the common communities from the maps.  The loop variable shadows
    # the common_indices tuple: each iteration receives one map's index array.
    remaining_communities = copy.deepcopy(community_maps)
    for (map_idx, common_indices) in enumerate(common_indices):
        for index in common_indices:
            remaining_communities[map_idx][index] = None
    # Now process through the communities from largest to smallest.
    # Flattening across maps also drops the cleared (None) and empty entries.
    remaining_communities = [community for community in 
                             itertools.chain(*remaining_communities) if community]
    remaining_communities.sort(key=operator.methodcaller('__len__'), reverse=True)
    for community in remaining_communities:
        # Prepare the intersections with the current basis_set
        # (only basis members that actually overlap this community)
        basis_intersection = [(key, community.intersection(basis))
                              for (key, basis) in basis_set.items()
                              if community.intersection(basis)]
        # Sort key: number of nodes in the intersection (element 1 of a pair)
        intersection_length = lambda x: operator.methodcaller('__len__')(operator.itemgetter(1)(x))
        basis_intersection.sort(key=intersection_length, reverse=True)
        # First, check for similarity with ALL of the basis_set...
        # using just the largest intersection lead to some strange behavior
        similar = False
        for (key, intersecting_residues) in basis_intersection:
            max_size = float(len(intersecting_residues))
            average_size = (len(community) + len(basis_set[key])) / 2.0
            if max_size / average_size >= similarity_cutoff:
                similar = True
            if similar: break
        # Already represented by an existing basis member: add nothing.
        if similar: continue
        # Second, check for a split community: most of this community lies
        # inside one basis member, so record that overlap as a split of it.
        split = False
        for (key, intersecting_residues) in basis_intersection:
            max_size = float(len(intersecting_residues))
            if max_size / len(community) >= splitting_cutoff:
                split = True
                split_key = basis_set.next_split_key(key)
                basis_set[split_key] = intersecting_residues
            if split: break
        if split: continue
        # The community must be unique
        key = basis_set.next_unique_key()
        basis_set[key] = community
    # And now remap the communities onto the common basis_set
    remapped_communities = []
    for community_map in community_maps:
        remapped_community = {}
        frac_intersection = compare_communities(community_map, basis_set.values())
        # For each community, the column (basis entry) with the best overlap
        mapping = frac_intersection.argmax(axis=1)
        for orig_idx, community in enumerate(community_map):
            basis_idx = mapping[orig_idx]
            # Relies on basis_set.keys() listing entries in the same order as
            # the basis_set.values() passed to compare_communities above.
            # Indexing keys() like this only works on Python 2.
            remapped_key = basis_set.keys()[basis_idx]
            remapped_community[remapped_key] = community
        remapped_communities.append(remapped_community)
    return (remapped_communities, basis_set)

def _create_test_data(number_of_nodes=100, number_of_communities=4):
    '''Build a random test graph with planted communities.

    The graph is built as follows:
    - community_size = number_of_nodes // number_of_communities; any
      remainder nodes are dropped;
    - the nodes 0 .. community_size * number_of_communities - 1 are split
      into that many contiguous blocks;
    - each block is fully connected internally;
    - every pair of distinct, non-empty blocks is joined by one edge between
      members chosen with the random module.

    Arguments:
    number_of_nodes -- node budget for the whole graph
    number_of_communities -- number of planted communities (must be > 0)

    Returns:
    test_graph -- networkx.Graph
    '''
    import random
    # Create the test graph
    test_graph = networkx.Graph()
    # Separate the nodes into non-overlapping, contiguous communities.
    # Integer division keeps range() arguments integral on Python 3 too.
    community_size = number_of_nodes // number_of_communities
    actual_communities = [list(range(com_idx * community_size,
                                     (com_idx + 1) * community_size))
                          for com_idx in range(number_of_communities)]
    # Add the internal community connections (complete subgraphs)
    for community in actual_communities:
        test_graph.add_nodes_from(community)
        for idx_i, node_i in enumerate(community):
            for node_j in community[idx_i + 1:]:
                test_graph.add_edge(node_i, node_j)
    # Add one connection between each pair of *distinct* communities.  The
    # old slice started at idx_i, which also paired a community with itself
    # and could create self-loops.
    for idx_i, com_i in enumerate(actual_communities):
        for com_j in actual_communities[idx_i + 1:]:
            if com_i and com_j:
                test_graph.add_edge(random.choice(com_i), random.choice(com_j))
    return test_graph

if __name__ == "__main__":
    # Demo: plant communities in a random graph, recover them with
    # Girvan-Newman and plot the intermediate results.
    test_graph = _create_test_data()
    # Show the original test graph
    pylab.figure()
    ax = pylab.gca()
    networkx.draw(test_graph, ax=ax)
    pylab.title('Original network')
    # Repeatedly remove the edge with the highest betweenness
    (best_community, last_community, modularity) = Girvan_Newman_algorithm(test_graph, 50)
    pylab.figure()
    ax = pylab.gca()
    networkx.draw(best_community, ax=ax)
    pylab.title('Highest modularity community network')
    # Show the modularity
    pylab.figure()
    pylab.plot(modularity)
    pylab.xlabel('Iteration number')
    pylab.ylabel('Modularity Q')
    pylab.title('Modularity')
    # Compare the best and last communities
    pylab.figure()
    communities = find_communities(best_community)
    last_communities = find_communities(last_community)
    intersection = compare_communities(communities, last_communities, plot=True)
    print("Fractional intersection of the best modularity communities versus the final community classification...")
    print(intersection)
    ([new_a, new_b], common_communities) = map_communities([communities, last_communities])
    # Show the community split
    pylab.figure()
    draw_communities(test_graph, new_a)
    pylab.title('Original network with community classification')
    # Create the coarsegrain
    pylab.figure()
    cg_graph = coarsegrain_communities(test_graph, new_a)
    last_cg_graph = coarsegrain_communities(test_graph, new_b)
    # Lay out the last coarse graph.  Then lay out the best one with the
    # communities both graphs share pinned to the same positions.  'fixed'
    # is a spring_layout option (networkx.draw has none), and drawing
    # looks up every node of the drawn graph in pos.
    last_pos = networkx.layout.spring_layout(last_cg_graph)
    shared_nodes = [node for node in cg_graph.nodes() if node in last_pos]
    if shared_nodes:
        best_pos = networkx.layout.spring_layout(
            cg_graph,
            pos=dict((node, last_pos[node]) for node in shared_nodes),
            fixed=shared_nodes)
    else:
        best_pos = networkx.layout.spring_layout(cg_graph)
    draw_coarsegrain(cg_graph, pos=best_pos)
    pylab.title('Coarse-grained network for best modularity')
    pylab.figure()
    draw_coarsegrain(last_cg_graph, pos=last_pos)
    pylab.title('Coarse-grained network for last step')
    # Show all plots
    pylab.show()
