import matplotlib.pyplot as pp
from scipy.stats import gaussian_kde
import numpy as np


def plotpmf_hist(datas, n_bins=10, labels=None):
    '''Plot a histogram-based PMF curve for each data set in datas.

    Each entry in datas (datas[i]) is a separate data set. Every data set is
    binned into n_bins bins with np.histogram, and the plotted value for a
    bin is -log(count) + log(1 / (n_bins + 1)). The curve is drawn as a
    staircase: a flat step across each bin, joined by vertical lines at the
    interior bin edges.

    Draws on the current matplotlib figure (title, one line per data set,
    legend) and returns nothing.

    datas  -- sequence of data sets, each accepted by np.histogram
    n_bins -- integer number of histogram bins
    labels -- legend labels, one per data set (same length as datas);
              defaults to the data set indices '0', '1', ...
    '''

    pp.title('Histogram PMF: %d bins' % n_bins)
    if labels is None:
        labels = [str(i) for i in range(len(datas))]
    colors = pp.cm.jet(np.linspace(0, 0.9, len(datas)))

    # Constant offset added to every bin; it does not depend on the data, so
    # compute it once rather than per data set.
    logq = np.log(1.0 / (n_bins + 1))

    for i, data in enumerate(datas):
        hist, bin_edges = np.histogram(data, bins=n_bins)

        # Empty bins give -log(0) = +inf. That is expected here, so silence
        # the divide-by-zero warning; matplotlib leaves non-finite points
        # undrawn, which shows up as a gap in the curve.
        with np.errstate(divide='ignore'):
            pmf = -np.log(hist) + logq

        # Staircase outline as one polyline:
        #   x = e0, e1, e1, e2, ..., e(n-1), en
        #   y = p0, p0, p1, p1, ..., p(n-1), p(n-1)
        # The label sits on the real curve, so no dummy artist is needed for
        # the legend.
        x = np.repeat(bin_edges, 2)[1:-1]
        y = np.repeat(pmf, 2)
        pp.plot(x, y, color=colors[i], label=labels[i])
    pp.legend()

def plotpmf_gkde(datas, labels=None):
    '''Plot a kernel-density-based PMF curve for each data set in datas.

    Each entry in datas (datas[i]) is a separate data set. A
    scipy.stats.gaussian_kde is fitted to every data set and -log(density)
    is plotted on 1000 evenly spaced points covering the data range, extended
    by 10% of that range on each side.

    Draws on the current matplotlib figure (title, one line per data set,
    legend) and returns nothing.

    datas  -- sequence of data sets, each accepted by gaussian_kde
    labels -- legend labels, one per data set (same length as datas);
              defaults to the data set indices '0', '1', ...
    '''

    pp.title('Gaussian Kernel Density Estimator PMF')
    if labels is None:
        labels = [str(i) for i in range(len(datas))]

    for i, data in enumerate(datas):
        gkde = gaussian_kde(data)
        # Named lo/hi so the min/max builtins are not shadowed.
        lo, hi = np.min(data), np.max(data)
        # Extend the evaluation grid a tenth of the data range past each end.
        pad = (hi - lo) / 10.0
        x = np.linspace(lo - pad, hi + pad, 1000)
        pp.plot(x, -np.log(gkde(x)), label=labels[i])
    pp.legend()