/* Copyright (c) 2005 Stanford University and Christopher Bruns
 * 
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files (the
 * "Software"), to deal in the Software without restriction, including 
 * without limitation the rights to use, copy, modify, merge, publish, 
 * distribute, sublicense, and/or sell copies of the Software, and to
 * permit persons to whom the Software is furnished to do so, subject
 * to the following conditions:
 * 
 * The above copyright notice and this permission notice shall be included 
 * in all copies or substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

/*
 * Created on Feb 6, 2006
 * Original author: Christopher Bruns
 */
package org.simtk.isimsu;

import java.awt.*;
import java.awt.event.*;
import javax.swing.*;
import javax.swing.table.*;
import java.io.*;
import java.util.*;
import java.util.regex.*;
import java.text.*;

/**
 * Dialog that displays counts of individual ion types after a run of ISIM.
 *
 * <p>For each ion species the table shows its radius, the ion count averaged
 * over all steps recorded in the {@code NUMBERS} file, the bulk expectation
 * read from {@code isim.LOG}, and the excess of the observed over the bulk
 * count. Below the table, the macromolecule charge (from the PQR input file)
 * and the total charge including the averaged ions are reported.
 *
 * @author Christopher Bruns
 */
public class IonCountDialog extends JDialog implements ActionListener {
    JTable ionTable = new JTable();

    JButton dismissButton = new JButton("Close");
    JButton saveResultsButton = new JButton("Save ion counts...");
    JButton showDetailsButton = new JButton("Simulation details...");

    ISIMWrapper isimWrapper;
    JPanel tablePanel = new JPanel();

    protected JLabel moleculeChargeLabel = new JLabel("Macromolecule charge = (unknown)");
    protected JLabel totalChargeLabel = new JLabel("Total charge = (unknown)");

    /** Loose pattern for a signed decimal number in a PQR file. */
    private static final String PQR_NUMBER_REGEX = "[-0-9\\.]+";

    /**
     * PQR header line stating the net charge, e.g.
     * "REMARK   6 Total charge on this protein: -246.0000 e"; group 1 is the charge.
     */
    private static final Pattern CHARGE_REMARK_PATTERN = Pattern.compile(
            "^REMARK\\s.*\\sTotal charge on\\s.*:\\s+(" + PQR_NUMBER_REGEX + ")\\s*e");

    /**
     * PQR ATOM/HETATM record; group 1 is the partial charge that follows the
     * z coordinate, whose last digit is expected at (zero-based) column 53.
     */
    private static final Pattern ATOM_CHARGE_PATTERN = Pattern.compile(
            "^(?:ATOM  |HETATM).{47}\\d\\s+(" + PQR_NUMBER_REGEX + ")\\s");

    /**
     * isim.LOG line with an expected bulk ion count, e.g.
     * "Bulk expectation number of Mg2plus ions is 3."; group 2 is the count.
     */
    private static final Pattern BULK_ION_COUNT_PATTERN = Pattern.compile(
            "^Bulk expectation number of (\\S+) ions is ([-+0-9\\.]+)");

    /**
     * Builds the (initially empty) dialog; call {@link #loadAllIonCounts()}
     * to fill it in.
     *
     * @param isimWrapper owning frame, also the source of run parameters and
     *        the working directory
     */
    IonCountDialog(ISIMWrapper isimWrapper) {
        super(isimWrapper); // preserves title bar icon from parent

        this.isimWrapper = isimWrapper;
        setTitle("Ion Counts (" + isimWrapper.programName + ")");

        dismissButton.addActionListener(this);
        saveResultsButton.addActionListener(this);
        showDetailsButton.addActionListener(this);

        Container pane = getContentPane();
        pane.setLayout(new BoxLayout(pane, BoxLayout.Y_AXIS)); // vertical

        pane.add(new JLabel("Ion Counts:"));

        // Placeholder; loadAllIonCounts() replaces it with a scrollable table
        tablePanel.add(ionTable);
        pane.add(tablePanel);

        JPanel runInfoPanel = new JPanel();
        runInfoPanel.setLayout(new BoxLayout(runInfoPanel, BoxLayout.Y_AXIS));
        runInfoPanel.setBorder(BorderFactory.createEmptyBorder(12, 12, 12, 12));
        runInfoPanel.add(moleculeChargeLabel);
        runInfoPanel.add(totalChargeLabel);

        // Left justify run information
        JPanel outerRunInfoPanel = new JPanel();
        outerRunInfoPanel.setLayout(new BoxLayout(outerRunInfoPanel, BoxLayout.X_AXIS));
        outerRunInfoPanel.add(runInfoPanel);
        outerRunInfoPanel.add(Box.createHorizontalGlue());

        pane.add(outerRunInfoPanel);

        // Right-aligned row of buttons
        JPanel buttonPanel = new JPanel();
        buttonPanel.setBorder(BorderFactory.createEmptyBorder(10, 10, 10, 10));
        buttonPanel.setLayout(new BoxLayout(buttonPanel, BoxLayout.X_AXIS));
        buttonPanel.add(Box.createHorizontalGlue());
        buttonPanel.add(showDetailsButton);
        buttonPanel.add(Box.createHorizontalStrut(5));
        buttonPanel.add(saveResultsButton);
        buttonPanel.add(Box.createHorizontalStrut(5));
        buttonPanel.add(dismissButton);
        pane.add(buttonPanel);

        // pack() first: setLocationRelativeTo() centers the window using its
        // current size, which is still zero before the dialog has been packed.
        pack();
        setLocationRelativeTo(isimWrapper);
    }

    /**
     * Reads the net macromolecule charge from the PQR input file.
     *
     * <p>If the file contains a "Total charge on ..." REMARK, that value is
     * used (the last one wins if there are several). Otherwise the partial
     * charges of all ATOM/HETATM records are summed.
     *
     * @return the macromolecule charge, in the units written by the PQR file
     * @throws IOException if the file cannot be read or a charge is malformed
     */
    protected double loadMoleculeCharge() throws IOException {
        File pqrFile = isimWrapper.apbsParameters.getPqrInputFile();

        // 0000000000111111111122222222223333333333444444444455555555556666666666
        // 0123456789012345678901234567890123456789012345678901234567890123456789
        // REMARK   6 Total charge on this protein: -246.0000 e
        // REMARK   6 Total charge on this protein: 3.0000 e
        // ATOM      1  O5' G      96      44.945   9.378  47.182 -0.6600 1.7700
        // HETATM 4801  O   WAT   401       2.067  16.277   2.415 -0.8340 1.7682

        int atomCount = 0;
        double remarkCharge = Double.NaN;
        double atomCharge = 0.0;

        BufferedReader pqrReader = new BufferedReader(new FileReader(pqrFile));
        try {
            String line;
            while ((line = pqrReader.readLine()) != null) {
                Matcher remarkMatcher = CHARGE_REMARK_PATTERN.matcher(line);
                if (remarkMatcher.find()) {
                    String chargeString = remarkMatcher.group(1);
                    System.out.println("Charge = #" + chargeString + "#");
                    remarkCharge = parseDecimal(chargeString, line);
                    continue;
                }

                Matcher atomMatcher = ATOM_CHARGE_PATTERN.matcher(line);
                if (atomMatcher.find()) {
                    atomCount++;
                    atomCharge += parseDecimal(atomMatcher.group(1), line);
                }
            }
        } finally {
            pqrReader.close();
        }

        System.out.println("" + atomCount + " atoms found");
        System.out.println("Charge from remark = " + remarkCharge);
        System.out.println("Charge from atoms = " + atomCharge);

        // Prefer the charge stated in the REMARK, if available
        return Double.isNaN(remarkCharge) ? atomCharge : remarkCharge;
    }

    /**
     * Populates the ion table and the charge labels from the output of a run.
     *
     * <p>Reads per-step ion counts from {@code NUMBERS} and bulk expectations
     * from {@code isim.LOG} in the working directory, plus the macromolecule
     * charge from the PQR input file, then re-packs the dialog.
     *
     * @throws IOException if any of those files cannot be read or parsed
     */
    public void loadAllIonCounts() throws IOException {
        int nIons = isimWrapper.apbsParameters.getIonCount();
        IonTableModel ionTableModel = new IonTableModel(nIons);

        ionTable.setModel(ionTableModel);

        JScrollPane scrollPane = new JScrollPane(ionTable);
        ionTable.setPreferredScrollableViewportSize(new Dimension(500, 70));
        tablePanel.removeAll();
        tablePanel.add(scrollPane);

        // Sort the ions in natural order; this is assumed to match the order
        // of the per-ion columns/lines written by the run.
        java.util.List<IonConcentration> sortedIons =
            new ArrayList<IonConcentration>(isimWrapper.apbsParameters.getIons());
        Collections.sort(sortedIons);

        // Ion labels and radii
        int ionIndex = 0;
        for (IonConcentration concentration : sortedIons) {
            IonSpecies ion = concentration.getIonSpecies();
            ionTableModel.setName(ionIndex, ionLabel(ion));
            ionTableModel.setRadius(ionIndex, ion.getRadius());
            ionIndex++;
        }

        // Observed counts from the simulation, then expected bulk counts
        // for the same volume
        loadAverageIonCounts(nIons, ionTableModel);
        loadBulkIonCounts(nIons, ionTableModel);

        // Charge on the macromolecule
        double moleculeCharge = loadMoleculeCharge();
        DecimalFormat chargeFormat = new DecimalFormat("0.0");
        moleculeChargeLabel.setText(
                "Macromolecule charge = " + chargeFormat.format(moleculeCharge) + " e");

        // Charge contributed by the averaged ion populations
        ionIndex = 0;
        double totalIonCharge = 0.0;
        for (IonConcentration concentration : sortedIons) {
            double ionCount = ionTableModel.getAverageCount(ionIndex).doubleValue();
            double deltaCharge = ionCount * concentration.getCharge();
            totalIonCharge += deltaCharge;
            System.out.println("Ion charge = " + deltaCharge);
            ionIndex++;
        }
        double totalCharge = moleculeCharge + totalIonCharge;
        totalChargeLabel.setText(
                "Total charge (macromolecule + ions) = " + chargeFormat.format(totalCharge) + " e");

        ionTable.revalidate();
        ionTable.repaint();
        pack();
    }

    /**
     * Builds a table label such as "Na (+1)" or "Cl (-1)" from the ion id
     * and its integer-formatted charge.
     */
    private String ionLabel(IonSpecies ion) {
        float charge = ion.getCharge();
        String chargeString = new DecimalFormat("0").format(charge);
        if (charge > 0) chargeString = "+" + chargeString;
        return "" + ion.getIonId() + " (" + chargeString + ")";
    }

    /**
     * Loads the expected bulk ion counts from {@code isim.LOG} in the working
     * directory into the "Background Ions" column, in file order.
     *
     * @param nIons number of table rows; more matching log lines than this is an error
     * @param ionTableModel table receiving the bulk counts
     * @throws IOException if the log cannot be read, has too many bulk-count
     *         lines, or contains a malformed count
     */
    private void loadBulkIonCounts(int nIons, IonTableModel ionTableModel) throws IOException {
        // Capitalization on Linux is "isim.LOG"; other platforms are assumed
        // to use the same name or a case-insensitive file system.
        File isimLogFile = new File(isimWrapper.workingDirectory, "isim.LOG");

        BufferedReader isimLogReader = new BufferedReader(new FileReader(isimLogFile));
        try {
            String line;
            int ionIndex = 0;
            while ((line = isimLogReader.readLine()) != null) {
                Matcher matcher = BULK_ION_COUNT_PATTERN.matcher(line);
                if (!matcher.matches()) continue;

                // Guard before indexing so a surplus line is reported clearly
                if (ionIndex >= nIons) {
                    throw new IOException("More bulk ion counts than the " + nIons
                            + " expected ion types in " + isimLogFile + ": " + line);
                }
                double bulkCount = parseDecimal(matcher.group(2), line);
                ionTableModel.setBulkCount(ionIndex, Double.valueOf(bulkCount));
                ionIndex++;
            }
        } finally {
            isimLogReader.close();
        }
    }

    /**
     * Loads the per-step ion counts from the {@code NUMBERS} file in the
     * working directory and stores their mean and standard deviation.
     *
     * <p>Each non-blank line holds a step number followed by one integer
     * count per ion type. The deviation is the population standard deviation
     * (sum of squared differences divided by N). If the file has no data
     * lines, the averages are NaN and no deviations are set.
     *
     * @param nIons expected number of count columns per line
     * @param ionTableModel table receiving the averages and deviations
     * @throws IOException if the file cannot be read, a line has the wrong
     *         number of columns, or a count is not an integer
     */
    private void loadAverageIonCounts(int nIons, IonTableModel ionTableModel) throws IOException {
        File numbersFile = new File(isimWrapper.workingDirectory, "NUMBERS");

        // Observed counts per ion type, kept for the second (deviation) pass
        java.util.List<java.util.List<Integer>> ionCounts =
                new ArrayList<java.util.List<Integer>>(nIons);
        for (int ionIndex = 0; ionIndex < nIons; ionIndex++) {
            ionCounts.add(new ArrayList<Integer>());
        }
        long[] ionCountSum = new long[nIons];

        BufferedReader numbersReader = new BufferedReader(new FileReader(numbersFile));
        try {
            String line;
            while ((line = numbersReader.readLine()) != null) {
                StringTokenizer tokens = new StringTokenizer(line);
                if (!tokens.hasMoreTokens()) continue; // blank line
                tokens.nextToken(); // step number, not needed

                // Read one ion count from each remaining column
                int ionIndex = 0;
                while (tokens.hasMoreTokens()) {
                    // Check before indexing: extra columns are a format error
                    if (ionIndex >= nIons) {
                        throw new IOException("Unexpected number of columns in NUMBERS file: " + line);
                    }
                    String token = tokens.nextToken();
                    int count;
                    try {
                        count = Integer.parseInt(token);
                    } catch (NumberFormatException exc) {
                        throw malformedNumber(token, line, exc);
                    }
                    ionCounts.get(ionIndex).add(Integer.valueOf(count));
                    ionCountSum[ionIndex] += count;
                    ionIndex++;
                }
                if (ionIndex != nIons) {
                    throw new IOException("Unexpected number of columns in NUMBERS file: " + line);
                }
            }
        } finally {
            numbersReader.close();
        }

        // Mean count for each ion (0/0 yields NaN when no steps were recorded)
        double[] averageCount = new double[nIons];
        for (int ionIndex = 0; ionIndex < nIons; ionIndex++) {
            int steps = ionCounts.get(ionIndex).size();
            averageCount[ionIndex] = ((double) ionCountSum[ionIndex]) / ((double) steps);
            ionTableModel.setAverageCount(ionIndex, Double.valueOf(averageCount[ionIndex]));
        }

        // Standard deviation for each ion count
        for (int ionIndex = 0; ionIndex < nIons; ionIndex++) {
            java.util.List<Integer> counts = ionCounts.get(ionIndex);
            if (counts.isEmpty()) continue; // ignore ions with no observations

            // Divide by N in two sqrt(N) steps to keep intermediate sums small
            double sqrtStepCount = Math.sqrt((double) counts.size());
            double variance = 0;
            for (Integer count : counts) {
                double difference = count.doubleValue() - averageCount[ionIndex];
                variance += ((difference * difference) / sqrtStepCount);
            }
            variance = variance / sqrtStepCount;

            ionTableModel.setDeviation(ionIndex, Double.valueOf(Math.sqrt(variance)));
        }
    }

    /** Parses a decimal number, reporting failures as an IOException naming the line. */
    private static double parseDecimal(String text, String line) throws IOException {
        try {
            return Double.parseDouble(text);
        } catch (NumberFormatException exc) {
            throw malformedNumber(text, line, exc);
        }
    }

    /** Builds the IOException used for unparseable numbers in input files. */
    private static IOException malformedNumber(String text, String line, NumberFormatException cause) {
        IOException exc = new IOException("Malformed number \"" + text + "\" in line: " + line);
        exc.initCause(cause);
        return exc;
    }

    /**
     * Responds to the Close, Simulation details and Save buttons.
     * @param event the button press
     */
    public void actionPerformed(ActionEvent event) {
        Object source = event.getSource();
        if (source == dismissButton) {
            setVisible(false);
        }
        else if (source == showDetailsButton) {
            SimulationDetailsDialog details = new SimulationDetailsDialog(isimWrapper);
            details.setVisible(true);
        }
        else if (source == saveResultsButton) {
            saveIonCounts();
        }
    }

    /** Asks the user for a destination and writes the ion table there as CSV. */
    private void saveIonCounts() {
        JFileChooser ionCountSaveFileChooser = new JFileChooser();
        ionCountSaveFileChooser.addChoosableFileFilter(new CsvFilter());
        ionCountSaveFileChooser.setFileSelectionMode(JFileChooser.FILES_AND_DIRECTORIES);
        try {
            File defaultFile = new File(new File("ionCounts.csv").getCanonicalPath());
            ionCountSaveFileChooser.setSelectedFile(defaultFile);
        } catch (IOException ignored) {
            // Best effort: without a default the chooser starts with no selection
        }

        int result = ionCountSaveFileChooser.showSaveDialog(isimWrapper);
        if (result != JFileChooser.APPROVE_OPTION) {
            return; // cancelled, dialog closed, or chooser error
        }

        File selFile = ionCountSaveFileChooser.getSelectedFile();
        if (selFile.exists() && !confirmOverwrite(selFile)) {
            return;
        }

        try {
            writeCsv(selFile);
        } catch (IOException e) {
            JOptionPane.showMessageDialog(isimWrapper,
                    "Error: Save file failed:\n" + e,
                    "Error: Save File Failed",
                    JOptionPane.WARNING_MESSAGE);
        }
    }

    /**
     * Asks whether an existing file may be overwritten.
     * @return true only if "Overwrite" was chosen; closing the prompt counts as Cancel
     */
    private boolean confirmOverwrite(File file) {
        Object[] options = {"Overwrite", "Cancel"};
        int response = JOptionPane.showOptionDialog(isimWrapper,
                "Overwrite existing file?\n" + file,
                "Confirm Overwrite",
                JOptionPane.OK_CANCEL_OPTION,
                JOptionPane.QUESTION_MESSAGE,
                null,
                options,
                options[1]); // Cancel is the default button
        return response == 0;
    }

    /**
     * Writes the column headers and every row of the ion table to a file as
     * comma-separated values.
     * @throws IOException if the file cannot be opened or a write fails
     */
    private void writeCsv(File file) throws IOException {
        PrintWriter csvOut = new PrintWriter(new BufferedWriter(new FileWriter(file)));
        try {
            int columnCount = ionTable.getColumnCount();

            // Header labels
            for (int col = 0; col < columnCount; col++) {
                if (col > 0) csvOut.print(",");
                csvOut.print(csvField(ionTable.getColumnName(col)));
            }
            csvOut.print("\n");

            // Table data
            for (int row = 0; row < ionTable.getRowCount(); row++) {
                for (int col = 0; col < columnCount; col++) {
                    if (col > 0) csvOut.print(",");
                    csvOut.print(csvField(ionTable.getValueAt(row, col)));
                }
                csvOut.print("\n");
            }

            // PrintWriter never throws on write; it only records the failure
            if (csvOut.checkError()) {
                throw new IOException("Error writing " + file);
            }
        } finally {
            csvOut.close();
        }
    }

    /**
     * Renders one CSV field, quoting it when it contains a comma, quote or
     * line break (e.g. numbers formatted with a locale's decimal comma).
     */
    private static String csvField(Object value) {
        String text = String.valueOf(value);
        if (text.indexOf(',') < 0 && text.indexOf('"') < 0
                && text.indexOf('\n') < 0 && text.indexOf('\r') < 0) {
            return text;
        }
        return "\"" + text.replace("\"", "\"\"") + "\"";
    }

    /** File chooser filter showing directories and files ending in ".csv". */
    class CsvFilter extends javax.swing.filechooser.FileFilter {
        public boolean accept(File file) {
            // Directories must be accepted, or the user cannot navigate into them
            if (file.isDirectory()) return true;
            return file.getName().toLowerCase(Locale.ENGLISH).endsWith(".csv");
        }
        public String getDescription() {
            return "Comma separated values (.csv) for Excel";
        }
    }

    /**
     * Read-only table model with one {@link IonCountStats} row per ion type.
     * Columns: name, radius, average count (with deviation), bulk count,
     * excess count.
     */
    class IonTableModel extends AbstractTableModel {
        String columnNames[] = {"Ion", "Ion Radius (\u212B)", "Total Ions", "Background Ions", "Excess Ions"};
        int columnCount = columnNames.length;
        private java.util.List<IonCountStats> rows = new ArrayList<IonCountStats>();

        /** @param rowCount number of ion types (rows) */
        public IonTableModel(int rowCount) {
            for (int i = 0; i < rowCount; i++) {
                rows.add(new IonCountStats());
            }
        }

        public int getRowCount() {
            return rows.size();
        }
        public int getColumnCount() {
            return columnCount;
        }
        public Object getValueAt(int i, int j) {
            IonCountStats row = rows.get(i);
            switch (j) {
                case 0: return row.getName();
                case 1: return row.getRadius();
                case 2: return row.getAverageCount();
                case 3: return row.getBulkCount();
                case 4: return row.getExcessCount();
                default: return null;
            }
        }

        /** Column 4 (excess) is derived and therefore not settable. */
        public void setValueAt(Object o, int i, int j) {
            switch (j) {
                case 0: { setName(i, o.toString()); break; }
                case 1: { setRadius(i, (Number) o); break; }
                case 2: { setAverageCount(i, (Number) o); break; }
                case 3: { setBulkCount(i, (Number) o); break; }
            }
        }

        public void setName(int i, String name) {
            rows.get(i).setName(name);
            fireTableDataChanged();
        }

        public void setRadius(int i, Number radius) {
            rows.get(i).setRadius(new FormattedTableNumber(radius));
            fireTableDataChanged();
        }

        public void setAverageCount(int i, Number count) {
            rows.get(i).setAverageCount(new FormattedTableNumber(count));
            fireTableDataChanged();
        }

        public Number getAverageCount(int i) {
            return rows.get(i).getAverageCount();
        }

        public void setDeviation(int i, Number deviation) {
            rows.get(i).setAverageCountDeviation(new FormattedTableNumber(deviation));
            fireTableDataChanged();
        }

        public void setBulkCount(int i, Number count) {
            rows.get(i).setBulkCount(count);
            fireTableDataChanged();
        }

        public String getColumnName(int j) {
            return columnNames[j];
        }

        public Class<?> getColumnClass(int c) {
            return (c == 0) ? String.class : Number.class;
        }

        /** Returns the table as tab-separated text, for debugging. */
        public String dumpDebug() {
            StringBuilder answer = new StringBuilder();

            // Column labels
            for (int col = 0; col < getColumnCount(); col++) {
                if (col > 0) answer.append("\t");
                answer.append(getColumnName(col));
            }
            answer.append("\n");

            // Data rows
            for (int row = 0; row < getRowCount(); row++) {
                for (int col = 0; col < getColumnCount(); col++) {
                    if (col > 0) answer.append("\t");
                    Object value = getValueAt(row, col);
                    answer.append(value == null ? "(null)" : value.toString());
                }
                answer.append("\n");
            }

            return answer.toString();
        }

        static final long serialVersionUID = 01L;
    }

    /**
     * Ion count statistics for one ion type; one row of {@link IonTableModel}.
     *
     * @author Christopher Bruns
     */
    class IonCountStats {
        private String name = "";
        private NumberWithError averageCount = new NumberWithError();
        private Number bulkCount = Integer.valueOf(0);
        private Number radius = new FormattedTableNumber(0);

        public IonCountStats() {
        }

        public String getName() { return name; }
        public Number getRadius() { return radius; }
        public Number getAverageCount() { return averageCount; }
        public Number getAverageCountDeviation() { return averageCount.getError(); }
        public Number getBulkCount() { return bulkCount; }

        /** Observed average minus expected bulk count, formatted to one decimal place. */
        public Number getExcessCount() {
            return new FormattedTableNumber(
                    Double.valueOf(averageCount.doubleValue() - bulkCount.doubleValue()));
        }

        public void setRadius(Number radius) { this.radius = radius; }
        public void setName(String name) { this.name = name; }
        public void setAverageCount(Number averageCount) { this.averageCount.setValue(averageCount); }
        public void setAverageCountDeviation(Number averageCountDeviation) { this.averageCount.setError(averageCountDeviation); }
        public void setBulkCount(Number bulkCount) { this.bulkCount = bulkCount; }
    }

    static final long serialVersionUID = 02L;
}

/**
 * A {@link Number} that carries an associated uncertainty.
 *
 * <p>All {@code Number} conversions delegate to the central value; the error
 * is only exposed through {@link #getError()} and {@link #toString()}, which
 * renders "value \u00B1 error". Both value and error start at {@code 0.0}.
 */
class NumberWithError extends Number {
    private Number number = Double.valueOf(0.0);
    private Number error = Double.valueOf(0.0);

    @Override
    public float floatValue() {return number.floatValue();}
    @Override
    public int intValue() {return number.intValue();}
    @Override
    public double doubleValue() {return number.doubleValue();}
    @Override
    public long longValue() {return number.longValue();}

    /** Replaces the central value; its own toString() is used for display. */
    public void setValue(Number value) {
        number = value;
    }
    /** Replaces the uncertainty; its own toString() is used for display. */
    public void setError(Number error) {
        this.error = error;
    }

    public Number getValue() {return number;}
    public Number getError() {return error;}

    /** Returns "value \u00B1 error" using each component's own string form. */
    @Override
    public String toString() {
        String plusOrMinus = " \u00B1 ";
        return "" + number + plusOrMinus + error;
    }

    // Number is Serializable; declare the UID like the file's other Number subclass
    static final long serialVersionUID = 01L;
}

/**
 * Wraps another {@link Number} so that it prints with exactly one decimal
 * place (pattern "#0.0", default locale) while every numeric conversion
 * still returns the wrapped value unchanged.
 */
class FormattedTableNumber extends Number {
    DecimalFormat oneDecimalPlace = new DecimalFormat("#0.0");
    private Number wrapped;

    /** @param number the value to display; conversions delegate to it */
    public FormattedTableNumber(Number number) {
        wrapped = number;
    }

    public int intValue() {
        return wrapped.intValue();
    }

    public long longValue() {
        return wrapped.longValue();
    }

    public float floatValue() {
        return wrapped.floatValue();
    }

    public double doubleValue() {
        return wrapped.doubleValue();
    }

    /** Formats the wrapped value's double form with one decimal place. */
    public String toString() {
        double asDouble = wrapped.doubleValue();
        return oneDecimalPlace.format(asDouble);
    }

    static final long serialVersionUID = 1L;
}


