from __future__ import absolute_import
import sys
import types

import numpy as np
import xarray as xr

from .. import util
from ..dimension import Dimension
from ..ndmapping import NdMapping, item_check, sorted_context
from ..element import Element
from .grid import GridInterface
from .interface import Interface

class XArrayInterface(GridInterface):
    """
    Grid-based data interface backed by xarray Datasets.

    NOTE(review): `xr` is imported unconditionally above, so the
    `if xr` guard below is only meaningful if the import is made
    optional elsewhere — confirm intent.
    """

    # Raw data types this interface wraps directly.
    types = (xr.Dataset if xr else None,)

    # Identifier used to select this interface by name.
    datatype = 'xarray'

    @classmethod
    def dimension_type(cls, dataset, dim):
        """Return the numpy scalar type of the named dimension's data."""
        name = dataset.get_dimension(dim).name
        # Original computed `name` but never returned anything; return
        # the element type of the underlying xarray variable.
        return dataset.data[name].dtype.type

    @classmethod
    def dtype(cls, dataset, dim):
        """Return the numpy dtype of the named dimension's data."""
        name = dataset.get_dimension(dim).name
        # Original computed `name` but never returned anything; return
        # the dtype of the underlying xarray variable.
        return dataset.data[name].dtype

    @classmethod
    def init(cls, eltype, data, kdims, vdims):
        """
        Coerce the supplied data to an xr.Dataset and resolve the key
        and value dimensions.

        Accepts an xr.DataArray, an xr.Dataset, a tuple of arrays, or a
        dict mapping dimension names to arrays. Returns the tuple
        (data, {'kdims': ..., 'vdims': ...}, {}) expected by Interface.
        """
        element_params = eltype.params()
        kdim_param = element_params['kdims']
        vdim_param = element_params['vdims']

        if isinstance(data, xr.DataArray):
            # Infer the value dimension from the array name, the
            # supplied vdims, or the element's declared default.
            if data.name:
                vdim = Dimension(data.name)
            elif vdims:
                vdim = vdims[0]
            elif len(vdim_param.default) == 1:
                vdim = vdim_param.default[0]
            vdims = [vdim]
            # xarray lists dims slowest-varying first; reverse to match
            # the interface's kdim ordering.
            kdims = [Dimension(d) for d in data.dims[::-1]]
            data = xr.Dataset({vdim.name: data})
        elif not isinstance(data, xr.Dataset):
            if kdims is None:
                kdims = kdim_param.default
            if vdims is None:
                vdims = vdim_param.default
            kdims = [kd if isinstance(kd, Dimension) else Dimension(kd)
                     for kd in kdims]
            vdims = [vd if isinstance(vd, Dimension) else Dimension(vd)
                     for vd in vdims]
            if isinstance(data, tuple):
                # Positional arrays correspond to kdims followed by vdims
                data = {d.name: vals for d, vals in zip(kdims + vdims, data)}
            if not isinstance(data, dict):
                raise TypeError('XArrayInterface could not interpret data type')
            coords = [(kd.name, data[kd.name]) for kd in kdims][::-1]
            arrays = {}
            for vdim in vdims:
                arr = data[vdim.name]
                if not isinstance(arr, xr.DataArray):
                    arr = xr.DataArray(arr, coords=coords)
                arrays[vdim.name] = arr
            data = xr.Dataset(arrays)
        else:
            # Already a Dataset: fill in any undeclared dimensions
            if vdims is None:
                vdims = list(data.data_vars.keys())
            if kdims is None:
                kdims = [name for name in data.dims
                         if isinstance(data[name].data, np.ndarray)]

        if not isinstance(data, xr.Dataset):
            raise TypeError('Data must be an xarray Dataset type.')
        return data, {'kdims': kdims, 'vdims': vdims}, {}

    @classmethod
    def range(cls, dataset, dimension):
        """Return the (min, max) of the dimension's values, or NaNs if absent."""
        dim = dataset.get_dimension(dimension).name
        if dim in dataset.data:
            data = dataset.data[dim]
            dmin, dmax = data.min().data, data.max().data
            # min/max return 0d arrays; unwrap them to Python scalars
            dmin = dmin if np.isscalar(dmin) else dmin.item()
            dmax = dmax if np.isscalar(dmax) else dmax.item()
            return dmin, dmax
        else:
            return np.NaN, np.NaN

    @classmethod
    def groupby(cls, dataset, dimensions, container_type, group_type, **kwargs):
        """
        Group the dataset along the supplied dimensions, wrapping each
        group in group_type and the collection in container_type.
        """
        index_dims = [dataset.get_dimension(d) for d in dimensions]
        element_dims = [kdim for kdim in dataset.kdims
                        if kdim not in index_dims]

        group_kwargs = {}
        if group_type != 'raw' and issubclass(group_type, Element):
            group_kwargs = dict(util.get_param_values(dataset),
                                kdims=element_dims)
        group_kwargs.update(kwargs)

        # XArray 0.7.2 does not support multi-dimensional groupby;
        # replace this custom implementation when multi-dimensional
        # groupby support is merged upstream.
        if len(dimensions) == 1:
            data = [(k, group_type(v, **group_kwargs)) for k, v in
                    dataset.data.groupby(dimensions[0])]
        else:
            unique_iters = [cls.values(dataset, d, False) for d in dimensions]
            indexes = zip(*[vals.flat for vals in util.cartesian_product(unique_iters)])
            data = [(k, group_type(dataset.data.sel(**dict(zip(dimensions, k))),
                                   **group_kwargs))
                    for k in indexes]

        if issubclass(container_type, NdMapping):
            with item_check(False), sorted_context(False):
                return container_type(data, kdims=index_dims)
        else:
            return container_type(data)

    @classmethod
    def coords(cls, dataset, dim, ordered=False, expanded=False):
        """
        Return the coordinate values along the given dimension,
        optionally expanded to the full grid or sorted ascending.
        """
        if expanded:
            return util.expand_grid_coords(dataset, dim)
        data = np.atleast_1d(dataset.data[dim].data)
        # Reverse descending coordinates so the result is ascending
        if ordered and data.shape and np.all(data[1:] < data[:-1]):
            data = data[::-1]
        return data

    @classmethod
    def values(cls, dataset, dim, expanded=True, flat=True):
        """
        Return the values along the given dimension; value dimensions
        are canonicalized, key dimensions optionally expanded/flattened.
        """
        data = dataset.data[dim].data
        if dim in dataset.vdims:
            coord_dims = dataset.data[dim].dims
            data = cls.canonicalize(dataset, data, coord_dims=coord_dims)
            return data.T.flatten() if flat else data
        elif expanded:
            data = cls.coords(dataset, dim, expanded=True)
            return data.flatten() if flat else data
        else:
            return cls.coords(dataset, dim, ordered=True)

    @classmethod
    def aggregate(cls, dataset, dimensions, function, **kwargs):
        """
        Aggregate the dataset over all but the supplied dimensions
        using the given reduction function.
        """
        if len(dimensions) > 1:
            raise NotImplementedError('Multi-dimensional aggregation not '
                                      'supported as of xarray <=0.7.2.')
        elif not dimensions:
            # No dimensions: reduce every variable to a scalar
            return dataset.data.apply(function)
        else:
            return dataset.data.groupby(dimensions[0]).apply(function)

    @classmethod
    def unpack_scalar(cls, dataset, data):
        """
        Given a dataset object and data in the appropriate format for
        the interface, return a simple scalar.
        """
        # Only unpack when there is a single variable holding a 0d value
        if (len(data.data_vars) == 1 and
            len(data[dataset.vdims[0].name].shape) == 0):
            return data[dataset.vdims[0].name].item()
        return data

    @classmethod
    def concat(cls, dataset_objs):
        """Concatenate the wrapped Datasets along a new 'concat_dim'."""
        # Reimplement concat to automatically add dimensions
        # once multi-dimensional concat has been added to xarray.
        return xr.concat([col.data for col in dataset_objs], dim='concat_dim')

    @classmethod
    def redim(cls, dataset, dimensions):
        """Rename variables/coords per a {name: Dimension} mapping."""
        renames = {k: v.name for k, v in dimensions.items()}
        return dataset.data.rename(renames)

    def reindex(cls, dataset, kdims=None, vdims=None):

    @classmethod
    def sort(cls, dataset, by=[]):
        # Grids are implicitly sorted by their coordinates; no-op.
        # NOTE: `by=[]` mutable default is safe here as it is never mutated.
        return dataset

    @classmethod
    def select(cls, dataset, selection_mask=None, **selection):
        """
        Apply the supplied selection (scalars, slices, sets, ranges or
        functions) to the Dataset, restoring any dimensions collapsed
        to constants and unpacking a lone scalar result.
        """
        validated = {}
        for k, v in selection.items():
            if isinstance(v, slice):
                v = (v.start, v.stop)
            if isinstance(v, set):
                validated[k] = list(v)
            elif isinstance(v, tuple):
                # Nudge the upper bound down so the range excludes it,
                # matching half-open slice semantics.
                upper = None if v[1] is None else v[1]-sys.float_info.epsilon*10
                validated[k] = slice(v[0], upper)
            elif isinstance(v, types.FunctionType):
                validated[k] = v(dataset[k])
            else:
                validated[k] = v
        data = dataset.data.sel(**validated)

        # Restore constant dimensions
        dropped = {d.name: np.atleast_1d(data[d.name])
                   for d in dataset.kdims
                   if not data[d.name].data.shape}
        if dropped:
            data = data.assign_coords(**dropped)

        # An indexed selection of a single 0d variable yields a scalar
        indexed = cls.indexed(dataset, selection)
        if (indexed and len(data.data_vars) == 1 and
            len(data[dataset.vdims[0].name].shape) == 0):
            return data[dataset.vdims[0].name].item()
        return data

    @classmethod
    def length(cls, dataset):
        # Dense grid length equals the number of samples in the value array.
        return np.product(dataset.data[dataset.vdims[0].name].shape)
    @classmethod
    def dframe(cls, dataset, dimensions):
        """Return the data as a pandas DataFrame, optionally restricted
        to the supplied dimensions."""
        if dimensions:
            return dataset.reindex(columns=dimensions)
        else:
            # Original fell through returning None; flatten the Dataset
            # into a DataFrame instead — TODO confirm against callers.
            return dataset.data.to_dataframe().reset_index()

    @classmethod
    def sample(cls, columns, samples=[]):
        """Sampling is not supported by this interface."""
        raise NotImplementedError

    @classmethod
    def add_dimension(cls, dataset, dimension, dim_pos, values, vdim):
        """
        Add a value dimension holding the supplied values; key
        dimensions cannot be added to a dense grid.
        """
        if not vdim:
            raise Exception("Cannot add key dimension to a dense representation.")
        dim = dimension.name if isinstance(dimension, Dimension) else dimension
        # NOTE(review): reconstructed coords/dims wiring — assumes the new
        # values share the existing grid's coordinates; confirm with callers.
        arr = xr.DataArray(values, coords=dataset.data.coords, name=dim,
                           dims=dataset.data.dims)
        return dataset.data.assign(**{dim: arr})