"""
A class for converting a PySB model to a set of ordinary differential
equations for integration in MATLAB.

Note that for use in MATLAB, the name of the ``.m`` file must match the name of
the exported MATLAB class (e.g., ``robertson.m`` for the example below).

For information on how to use the model exporters, see the documentation
for :py:mod:`pysb.export`.

Output for the Robertson example model
======================================

Information on the form and usage of the generated MATLAB class is contained in
the documentation for the MATLAB model, as shown in the following example for
``pysb.examples.robertson``::

    classdef robertson
        % A simple three-species chemical kinetics system known as "Robertson's
        % example", as presented in:
        % 
        % H. H. Robertson, The solution of a set of reaction rate equations, in Numerical
        % Analysis: An Introduction, J. Walsh, ed., Academic Press, 1966, pp. 178-182.
        % 
        % A class implementing the ordinary differential equations
        % for the robertson model.
        %
        % Save as robertson.m.
        %
        % Generated by pysb.export.matlab.MatlabExporter.
        %
        % Properties
        % ----------
        % observables : struct
        %     A struct containing the names of the observables from the
        %     PySB model as field names. Each field in the struct
        %     maps the observable name to a matrix with two rows:
        %     the first row specifies the indices of the species
        %     associated with the observable, and the second row
        %     specifies the coefficients associated with the species.
        %     For any given timecourse of model species resulting from
        %     integration, the timecourse for an observable can be
        %     retrieved using the get_observables method, described
        %     below.
        %
        % parameters : struct
        %     A struct containing the names of the parameters from the
        %     PySB model as field names. The nominal values are set by
        %     the constructor and their values can be overridden
        %     explicitly once an instance has been created.
        %
        % Methods
        % -------
        % robertson.odes(tspan, y0)
        %     The right-hand side function for the ODEs of the model,
        %     for use with MATLAB ODE solvers (see Examples).
        %
        % robertson.get_initial_values()
        %     Returns a vector of initial values for all species,
        %     specified in the order that they occur in the original
        %     PySB model (i.e., in the order found in model.species).
        %     Non-zero initial conditions are specified using the
        %     named parameters included as properties of the instance.
        %     Hence initial conditions other than the defaults can be
        %     used by assigning a value to the named parameter and then
        %     calling this method. The vector returned by the method
        %     is used for integration by passing it to the MATLAB
        %     solver as the y0 argument.
        %
        % robertson.get_observables(y)
        %     Given a matrix of timecourses for all model species
        %     (i.e., resulting from an integration of the model),
        %     get the trajectories corresponding to the observables.
        %     Timecourses are returned as a struct which can be
        %     indexed by observable name.
        %
        % Examples
        % --------
        % Example integration using default initial and parameter
        % values:
        %
        % >> m = robertson();
        % >> tspan = [0 100];
        % >> [t y] = ode15s(@m.odes, tspan, m.get_initial_values());
        %
        % Retrieving the observables:
        %
        % >> y_obs = m.get_observables(y)
        %
        properties
            observables
            parameters
        end

        methods
            function self = robertson()
                % Assign default parameter values
                self.parameters = struct( ...
                    'k1', 0.040000000000000001, ...
                    'k2', 30000000, ...
                    'k3', 10000, ...
                    'A_0', 1, ...
                    'B_0', 0, ...
                    'C_0', 0);

                % Define species indices (first row) and coefficients
                % (second row) of named observables
                self.observables = struct( ...
                    'A_total', [1; 1], ...
                    'B_total', [2; 1], ...
                    'C_total', [3; 1]);
            end

            function initial_values = get_initial_values(self)
                % Return the vector of initial conditions for all
                % species based on the values of the parameters
                % as currently defined in the instance.

                initial_values = zeros(1,3);
                initial_values(1) = self.parameters.A_0; % A()
                initial_values(2) = self.parameters.B_0; % B()
                initial_values(3) = self.parameters.C_0; % C()
            end

            function y = odes(self, tspan, y0)
                % Right hand side function for the ODEs

                % Shorthand for the struct of model parameters
                p = self.parameters;

                % A();
                y(1,1) = -p.k1*y0(1) + p.k3*y0(2)*y0(3);
                % B();
                y(2,1) = p.k1*y0(1) - p.k2*power(y0(2), 2) - p.k3*y0(2)*y0(3);
                % C();
                y(3,1) = p.k2*power(y0(2), 2);
            end

            function y_obs = get_observables(self, y)
                % Retrieve the trajectories for the model observables
                % from a matrix of the trajectories of all model
                % species.

                % Initialize the struct of observable timecourses
                % that we will return
                y_obs = struct();

                % Iterate over the observables;
                observable_names = fieldnames(self.observables);
                for i = 1:numel(observable_names)
                    obs_matrix = self.observables.(observable_names{i});
                    species = obs_matrix(1, :);
                    coefficients = obs_matrix(2, :);
                    y_obs.(observable_names{i}) = ...
                                    y(:, species) * coefficients';
                end
            end
        end
    end
"""

import pysb
import pysb.bng
import sympy
import re
try:
    from cStringIO import StringIO
except ImportError:
    from io import StringIO
from pysb.export import Exporter, pad

class MatlabExporter(Exporter):
    """A class for returning the ODEs for a given PySB model for use in
    MATLAB.

    Inherits from :py:class:`pysb.export.Exporter`, which implements
    basic functionality for all exporters.
    """

    # Indentation (after ``pad``) of statements inside the generated methods
    _BODY_INDENT = ' ' * 12
    # Indentation of continuation lines within a multi-line struct declaration
    _CONTINUATION_INDENT = ' ' * 16

    # An identifier in the C code emitted by ``sympy.ccode``, optionally
    # followed by the opening parenthesis of a function call.
    _TOKEN_RE = re.compile(r'\b([A-Za-z_]\w*)(\s*\()?')
    # Species symbols in the model ODEs are named ``__sN`` (zero-based N).
    _SPECIES_RE = re.compile(r'__s(\d+)$')

    def export(self):
        """Generate a MATLAB class definition containing the ODEs for the PySB
        model associated with the exporter.

        Returns
        -------
        string
            String containing the MATLAB code for an implementation of the
            model's ODEs.
        """
        # Populates model.species, model.odes and the observables' species
        # lists, all of which the helpers below read.
        pysb.bng.generate_equations(self.model)

        # Continue a multi-line docstring as MATLAB comment lines at the
        # class-body indentation of the generated file.
        docstring = ''
        if self.docstring:
            docstring = self.docstring.replace('\n', '\n    % ')

        # Substitute underscores for any dots in the model name
        model_name = self.model.name.replace('.', '_')

        return pad(r"""
            classdef %(model_name)s
                %% %(docstring)s
                %% A class implementing the ordinary differential equations
                %% for the %(model_name)s model.
                %%
                %% Save as %(model_name)s.m.
                %%
                %% Generated by pysb.export.matlab.MatlabExporter.
                %%
                %% Properties
                %% ----------
                %% observables : struct
                %%     A struct containing the names of the observables from the
                %%     PySB model as field names. Each field in the struct
                %%     maps the observable name to a matrix with two rows:
                %%     the first row specifies the indices of the species
                %%     associated with the observable, and the second row
                %%     specifies the coefficients associated with the species.
                %%     For any given timecourse of model species resulting from
                %%     integration, the timecourse for an observable can be
                %%     retrieved using the get_observables method, described
                %%     below.
                %%
                %% parameters : struct
                %%     A struct containing the names of the parameters from the
                %%     PySB model as field names. The nominal values are set by
                %%     the constructor and their values can be overridden
                %%     explicitly once an instance has been created.
                %%
                %% Methods
                %% -------
                %% %(model_name)s.odes(tspan, y0)
                %%     The right-hand side function for the ODEs of the model,
                %%     for use with MATLAB ODE solvers (see Examples).
                %%
                %% %(model_name)s.get_initial_values()
                %%     Returns a vector of initial values for all species,
                %%     specified in the order that they occur in the original
                %%     PySB model (i.e., in the order found in model.species).
                %%     Non-zero initial conditions are specified using the
                %%     named parameters included as properties of the instance.
                %%     Hence initial conditions other than the defaults can be
                %%     used by assigning a value to the named parameter and then
                %%     calling this method. The vector returned by the method
                %%     is used for integration by passing it to the MATLAB
                %%     solver as the y0 argument.
                %%
                %% %(model_name)s.get_observables(y)
                %%     Given a matrix of timecourses for all model species
                %%     (i.e., resulting from an integration of the model),
                %%     get the trajectories corresponding to the observables.
                %%     Timecourses are returned as a struct which can be
                %%     indexed by observable name.
                %%
                %% Examples
                %% --------
                %% Example integration using default initial and parameter
                %% values:
                %%
                %% >> m = %(model_name)s();
                %% >> tspan = [0 100];
                %% >> [t y] = ode15s(@m.odes, tspan, m.get_initial_values());
                %%
                %% Retrieving the observables:
                %%
                %% >> y_obs = m.get_observables(y)
                %%
                properties
                    observables
                    parameters
                end

                methods
                    function self = %(model_name)s()
                        %% Assign default parameter values
                        %(params_str)s

                        %% Define species indices (first row) and coefficients
                        %% (second row) of named observables
                        %(observables_str)s
                    end

                    function initial_values = get_initial_values(self)
                        %% Return the vector of initial conditions for all
                        %% species based on the values of the parameters
                        %% as currently defined in the instance.

                        %(initial_values_str)s
                    end

                    function y = odes(self, tspan, y0)
                        %% Right hand side function for the ODEs

                        %% Shorthand for the struct of model parameters
                        p = self.parameters;

                        %(odes_str)s
                    end

                    function y_obs = get_observables(self, y)
                        %% Retrieve the trajectories for the model observables
                        %% from a matrix of the trajectories of all model
                        %% species.

                        %% Initialize the struct of observable timecourses
                        %% that we will return
                        y_obs = struct();

                        %% Iterate over the observables;
                        observable_names = fieldnames(self.observables);
                        for i = 1:numel(observable_names)
                            obs_matrix = self.observables.(observable_names{i});
                            species = obs_matrix(1, :);
                            coefficients = obs_matrix(2, :);
                            y_obs.(observable_names{i}) = ...
                                            y(:, species) * coefficients';
                        end
                    end
                end
            end
            """, 0) % {
                'docstring': docstring,
                'model_name': model_name,
                'params_str': self._params_str(),
                'initial_values_str': self._initial_values_str(),
                'observables_str': self._observables_str(),
                'odes_str': self._odes_str(),
            }

    @classmethod
    def _struct_declaration(cls, target, fields):
        """Return MATLAB code assigning a struct literal to ``target``.

        ``fields`` is a list of ``(field_name, value_code)`` string pairs.
        Each pair goes on its own continuation line. An empty list yields
        ``struct()``, so the statement is always terminated and valid.
        """
        if not fields:
            return '%s = struct();' % target
        separator = '\n' + cls._CONTINUATION_INDENT
        entries = ["'%s', %s" % field for field in fields]
        return '%s = struct( ...%s%s);' % (
            target, separator, (', ...' + separator).join(entries))

    def _params_str(self):
        """Return the constructor code that assigns nominal parameter values."""
        return self._struct_declaration(
            'self.parameters',
            [(_fix_underscores(p.name), '%.17g' % p.value)
             for p in self.model.parameters])

    def _observables_str(self):
        """Return the constructor code that declares the observables struct.

        Each observable maps to a two-row matrix: one-based species indices
        (first row) and their coefficients (second row).
        """
        fields = []
        for obs in self.model.observables:
            if obs.species:
                # Change from zero- to one-based species indexing
                matrix = '[%s; %s]' % (
                    ' '.join(str(sp + 1) for sp in obs.species),
                    ' '.join(str(c) for c in obs.coefficients))
            else:
                # '[; ]' would be an empty matrix that get_observables cannot
                # index by row. A 2x0 matrix instead yields an all-zero
                # timecourse.
                matrix = 'zeros(2, 0)'
            fields.append((_fix_underscores(obs.name), matrix))
        return self._struct_declaration('self.observables', fields)

    def _initial_values_str(self):
        """Return the body of get_initial_values.

        The body fills a zero vector from the named initial-condition
        parameters.
        """
        lines = ['initial_values = zeros(1,%d);' % len(self.model.species)]
        # NOTE(review): position i of model.initial_conditions is assumed to
        # be species i, i.e. that the seed species come first in
        # model.species, in the same order -- confirm.
        for i, ic in enumerate(self.model.initial_conditions):
            lines.append('initial_values(%d) = self.parameters.%s; %% %s'
                         % (i + 1, _fix_underscores(ic[1].name), ic[0]))
        return ('\n' + self._BODY_INDENT).join(lines)

    def _odes_str(self):
        """Return the body of the odes method.

        The body has one commented assignment per species.
        """
        param_map = dict((p.name, 'p.%s' % _fix_underscores(p.name))
                         for p in self.model.parameters)
        lines = []
        for i, (species, ode) in enumerate(zip(self.model.species,
                                               self.model.odes)):
            lines.append('%% %s;' % species)
            lines.append('y(%d,1) = %s;'
                         % (i + 1,
                            self._to_matlab(sympy.ccode(ode), param_map)))
        return ('\n' + self._BODY_INDENT).join(lines)

    @classmethod
    def _to_matlab(cls, c_expr, param_map):
        """Translate a C expression from ``sympy.ccode`` into MATLAB syntax.

        Every identifier is rewritten in a single pass, so text produced by
        one substitution is never matched again by another (e.g. a parameter
        named ``p`` or ``y0``):

        * species symbols ``__sN`` become ``y0(N+1)`` (one-based indexing);
        * calls to C's ``pow`` become calls to MATLAB's ``power``;
        * parameter names are mapped via ``param_map`` to fields of ``p``,
          the shorthand for ``self.parameters`` in the generated odes method.
        """
        def translate(match):
            name, call = match.group(1), match.group(2)
            if call:
                # Function name; only pow needs renaming for MATLAB
                return ('power' if name == 'pow' else name) + call
            species = cls._SPECIES_RE.match(name)
            if species:
                return 'y0(%d)' % (int(species.group(1)) + 1)
            return param_map.get(name, name)

        return cls._TOKEN_RE.sub(translate, c_expr)

def _fix_underscores(name):
    if name.startswith('_'):
        return 'X' + name
    else:
        return name