import functools
import json
import logging
from io import BytesIO

import numpy as np
from google.protobuf.json_format import MessageToJson

from .summary import _clean_tag

try:
    import comet_ml
    comet_installed = True
    from PIL import Image
except ImportError:
    comet_installed = False
logger = logging.getLogger(__name__)


class CometLogger:
    def __init__(self, comet_config=None):
        """Store the Comet configuration and decide whether logging can run.

        Args:
            comet_config: Optional dict of keyword arguments that is later
                unpacked into ``comet_ml.Experiment``. Its ``"disabled"``
                entry drives this logger: ``True`` turns every decorated
                logging call into a no-op, ``False`` (also assumed when the
                key is absent) requires the optional imports to be
                available. When omitted, ``{"disabled": True}`` is used.

        Raises:
            Exception: If logging is enabled but the optional ``comet_ml`` /
                PIL imports failed at module import time.
        """
        if comet_config is None:
            # Build a fresh dict per instance: a dict literal used as the
            # default value would be shared (and kept by reference in
            # ``_comet_config``) by every logger created without a config.
            comet_config = {"disabled": True}
        # ``None`` means "not decided yet"; the experiment itself is created
        # lazily by the ``_requiresComet`` wrapper on the first logging call.
        self._logging = None
        self._comet_config = comet_config
        # A config without a "disabled" entry is treated as enabled instead
        # of failing with a bare KeyError.
        disabled = comet_config.get("disabled", False)
        if disabled is True:
            self._logging = False
        elif disabled is False and comet_installed is False:
            raise Exception("Comet and/or Python Image Library not installed. Run 'pip install comet-ml pillow'")

    def _requiresComet(method):
        @functools.wraps(method)
        def wrapper(*args, **kwargs):
            self = args[0]
            global comet_installed
            if self._logging is None and comet_installed:
                self._logging = False
                try:
                    if 'api_key' not in self._comet_config:
                        comet_ml.init()
                    if comet_ml.get_global_experiment() is not None:
                        logger.warning("You have already created a comet \
                                        experiment manually, which might \
                                        cause clashes")
                    self._experiment = comet_ml.Experiment(**self._comet_config)
                    self._logging = True
                    self._experiment.log_other("Created from", "tensorboardX")
                except Exception as e:
                    logger.warning(e)

            if self._logging is True:
                return method(*args, **kwargs)
        return wrapper

    @_requiresComet
    def end(self):
        """Finish the current experiment and clear comet_ml's stored reference.

        Calls ``end()`` on the experiment and then resets
        ``comet_ml.config.experiment`` to ``None``.
        """
        experiment = self._experiment
        experiment.end()
        comet_ml.config.experiment = None

    @_requiresComet
    def log_metric(self, tag, display_name, value, step=None, epoch=None,
                   include_context=True):
        """Logs a general metric (i.e accuracy, f1)..

        Args:
            tag: String - Data identifier
            display_name: The title of the plot. If empty string is passed,
              `tag` will be used.
            value: Float/Integer/Boolean/String
            step: Optional. Used as the X axis when plotting on comet.ml
            epoch: Optional. Used as the X axis when plotting on comet.ml
            include_context: Optional. If set to True (the default),
                the current context will be logged along the metric.
        """
        name = _clean_tag(tag) if display_name == "" else display_name
        self._experiment.log_metric(name, value, step, epoch,
                                    include_context)

    @_requiresComet
    def log_metrics(self, dic, prefix=None, step=None, epoch=None):
        """Logs a key,value dictionary of metrics.

        Args:
            dic: key,value dictionary of metrics
            prefix: prefix added to metric name
            step: Optional. Used as the X axis when plotting on comet.ml
            epoch: Optional. Used as the X axis when plotting on comet.ml
        """
        self._experiment.log_metrics(dic, prefix, step, epoch)

    @_requiresComet
    def log_parameters(self, parameters, prefix=None, step=None):
        """Logs a dictionary (or dictionary-like object) of multiple parameters.

        Args:
            parameters: key,value dictionary of parameters
            prefix: prefix added to metric name
            step: Optional. Used as the X axis when plotting on comet.ml
        """
        self._experiment.log_parameters(parameters, prefix, step)

    @_requiresComet
    def log_audio(self, audio_data, sample_rate=None, file_name=None,
                  metadata=None, overwrite=False, copy_to_tmp=True,
                  step=None):
        """Logs the audio Asset determined by audio data.

        Args:
        audio_data: String or a numpy array - either the file path
            of the file you want to log, or a numpy array given to
            scipy.io.wavfile.write for wav conversion.
        sample_rate: Integer - Optional. The sampling rate given to
            scipy.io.wavfile.write for creating the wav file.
        file_name: String - Optional. A custom file name to be displayed.
            If not provided, the filename from the audio_data argument
            will be used.
        metadata: Some additional data to attach to the the audio asset.
            Must be a JSON-encodable dict.
        overwrite: if True will overwrite all existing assets with the same name.
        copy_to_tmp: If audio_data is a numpy array, then this flag
            determines if the WAV file is first copied to a temporary
            file before upload. If copy_to_tmp is False, then it is sent
            directly to the cloud.
        step: Optional. Used to associate the audio asset to a specific step.
        """
        self._experiment.log_audio(audio_data, sample_rate, file_name,
                                   metadata, overwrite, copy_to_tmp,
                                   step)

    @_requiresComet
    def log_text(self, text, step=None, metadata=None):
        """Logs the text. These strings appear on the Text Tab in the Comet UI.

        Args:
        text: string to be stored
        step: Optional. Used to associate the asset to a specific step.
        metadata: Some additional data to attach to the the text. Must
            be a JSON-encodable dict.
        """
        self._experiment.log_text(text, step, metadata)

    @_requiresComet
    def log_histogram(self, values, name=None, step=None, epoch=None,
                      metadata=None, **kwargs):
        """Logs a histogram of values for a 3D chart as an asset for
           this experiment. Calling this method multiple times with the
           same name and incremented steps will add additional histograms
           to the 3D chart on Comet.ml.

        Args:
        values: a list, tuple, array (any shape) to summarize, or a
            Histogram object
        name: str (optional), name of summary
        step: Optional. Used as the Z axis when plotting on Comet.ml.
        epoch: Optional. Used as the Z axis when plotting on Comet.ml.
        metadata: Optional: Used for items like prefix for histogram name.
        kwargs: Optional. Additional keyword arguments for histogram.
        """
        self._experiment.log_histogram_3d(values, name, step,
                                          epoch, metadata,
                                          **kwargs)

    @_requiresComet
    def log_histogram_raw(self, tag, summary, step=None):
        """Log Raw Histogram Data to Comet as an Asset.

        The histogram message found at ``summary.value[0].histo`` is
        converted to a dict, the tag is stored under its ``'name'`` key, and
        the result is uploaded through ``log_asset_data``.

        Args:
            tag: Name given to the logged asset
            summary: TensorboardX Summary protocol buffer with histogram data
            step: The Global Step for this experiment run. Defaults to None.
        """
        histogram_proto = summary.value[0].histo
        # MessageToJson returns a JSON *string*; it has to be parsed into a
        # dict before the 'name' entry can be assigned (item assignment on
        # the string itself raises TypeError).
        histogram_raw_data = json.loads(MessageToJson(histogram_proto))
        histogram_raw_data['name'] = tag

        self.log_asset_data(data=histogram_raw_data, name=tag, step=step)

    @_requiresComet
    def log_curve(self, name, x, y, overwrite=False, step=None):
        """Log timeseries data.

        Args:
        name: (str) name of data
        x: array of x-axis values
        y: array of y-axis values
        overwrite: (optional, bool) if True, overwrite previous log
        step: (optional, int) the step value
        """
        self._experiment.log_curve(name, x.tolist(), y.tolist(), overwrite, step)

    @_requiresComet
    def log_image_encoded(self, encoded_image_string, tag, step=None):
        """Log an encoded image. Images are displayed on the Graphics tab on Comet.ml.

        The encoded bytes are wrapped in an in-memory buffer, opened as a
        PIL image, and logged under the cleaned tag.

        Args:
            encoded_image_string: Required. An encoded image string
            tag: String - Data identifier
            step: Optional. Used to associate the image asset to a
                specific step.
        """
        decoded = Image.open(BytesIO(encoded_image_string))
        image_name = _clean_tag(tag)
        experiment = self._experiment
        experiment.log_image(decoded, image_name, step=step)

    @_requiresComet
    def log_asset(self, file_data, file_name=None, overwrite=False,
                  copy_to_tmp=True, step=None, metadata=None):
        """Logs the Asset determined by file_data.

        Args:
        file_data: String or File-like - either the file path of the
            file you want to log, or a file-like asset.
        file_name: String - Optional. A custom file name to be displayed.
            If not provided the filename from the file_data argument will be used.
        overwrite: if True will overwrite all existing assets with
            the same name.
        copy_to_tmp: If file_data is a file-like object, then this flag
            determines if the file is first copied to a temporary file
            before upload. If copy_to_tmp is False, then it is sent
            directly to the cloud.
        step: Optional. Used to associate the asset to a specific step.
        """
        self._experiment.log_asset(file_data, file_name, overwrite,
                                   copy_to_tmp, step, metadata)

    @_requiresComet
    def log_asset_data(self, data, name=None, overwrite=False, step=None,
                       metadata=None, epoch=None):
        """Logs the data given (str, binary, or JSON).

        Args:
        data: data to be saved as asset
        name: String, optional. A custom file name to be displayed If
            not provided the filename from the temporary saved file
            will be used.
        overwrite: Boolean, optional. Default False. If True will
            overwrite all existing assets with the same name.
        step: Optional. Used to associate the asset to a specific step.
        epoch: Optional. Used to associate the asset to a specific epoch.
        metadata: Optional. Some additional data to attach to the
            asset data. Must be a JSON-encodable dict.
        """
        self._experiment.log_asset_data(data, name, overwrite, step,
                                        metadata, epoch)

    @_requiresComet
    def log_embedding(self, vectors, labels, image_data=None,
                      image_preprocess_function=None, image_transparent_color=None,
                      image_background_color_function=None, title="Comet Embedding",
                      template_filename=None,
                      group=None):
        """Log a multi-dimensional dataset and metadata for viewing
           with Comet's Embedding Projector (experimental).

        Args:
        vectors: the tensors to visualize in 3D
        labels: labels for each tensor
        image_data: (optional) list of arrays or Images
        image_preprocess_function: (optional) if image_data is an array,
            apply this function to each element first
        image_transparent_color: a (red, green, blue) tuple
        image_background_color_function: a function that takes an
            index, and returns a (red, green, blue) color tuple
        title: (optional) name of tensor
        template_filename: (optional) name of template JSON file
        """
        image_size = None
        if labels is None:
            return
        if image_data is not None:
            image_data = image_data.cpu().detach().numpy()
            image_size = image_data.shape[1:]
            if image_size[0] == 1:
                image_size = image_size[1:]
        if type(labels) == list:
            labels = np.array(labels)
        else:
            labels = labels.cpu().detach().numpy()
        self._experiment.log_embedding(vectors, labels, image_data,
                                       image_size, image_preprocess_function,
                                       image_transparent_color,
                                       image_background_color_function,
                                       title, template_filename,
                                       group)

    @_requiresComet
    def log_mesh(self, tag, vertices, colors, faces, config_dict, step, walltime):
        """Logs a mesh as an asset

        Args:
        tag: Data identifier
        vertices: List of the 3D coordinates of vertices.
        colors: Colors for each vertex
        faces: Indices of vertices within each triangle.
        config_dict: Dictionary with ThreeJS classes names and configuration.
        step: step value to record
        walltime: Optional override default walltime (time.time())
            seconds after epoch of event
        """
        mesh_json = {}
        mesh_json['tag'] = tag
        mesh_json['vertices'] = vertices.tolist()
        mesh_json['colors'] = colors.tolist()
        mesh_json['faces'] = faces.tolist()
        mesh_json['config_dict'] = config_dict
        mesh_json['walltime'] = walltime
        mesh_json['asset_type'] = 'mesh'
        mesh_json = json.dumps(mesh_json)
        self.log_asset_data(mesh_json, tag, step=step)

    @_requiresComet
    def log_raw_figure(self, tag, asset_type, step=None, **kwargs):
        """Logs a histogram as an asset.

        Args:
        tag: Data identifier
        asset_type: List of the 3D coordinates of vertices.
        step: step value to record
        """
        file_json = kwargs
        file_json['asset_type'] = asset_type
        self.log_asset_data(file_json, tag, step=step)

    @_requiresComet
    def log_pr_data(self, tag, summary, num_thresholds, step=None):
        """Logs a Precision-Recall Curve Data as an asset.

        Args:
        tag: An identifier for the PR curve
        summary: TensorboardX Summary protocol buffer.
        step: step value to record
        """
        tensor_proto = summary.value[0].tensor
        shape = [d.size for d in tensor_proto.tensor_shape.dim]

        values = np.fromiter(tensor_proto.float_val, dtype=np.float32).reshape(shape)
        thresholds = [1.0 / num_thresholds * i for i in range(num_thresholds)]
        tp, fp, tn, fn, precision, recall = map(lambda x: x.flatten().tolist(), np.vsplit(values, values.shape[0]))

        pr_data = {
            'TP': tp,
            'FP': fp,
            'TN': tn,
            'FN': fn,
            'precision': precision,
            'recall': recall,
            'thresholds': thresholds,
            'name': tag,
        }

        self.log_asset_data(pr_data, name=tag, step=step)

    @_requiresComet
    def log_pr_raw_data(self, tag, true_positive_counts,
                        false_positive_counts, true_negative_counts,
                        false_negative_counts, precision, recall,
                        num_thresholds, weights, step=None):
        """Logs a Precision-Recall Curve Data as an asset.

        Args:
        tag: An identifier for the PR curve
        summary: TensorboardX Summary protocol buffer.
        step: step value to record
        """
        thresholds = [1.0 / num_thresholds * i for i in range(num_thresholds)]
        tp, fp, tn, fn, precision, recall = map(lambda x: x.flatten().tolist(), [
            true_positive_counts,
            false_positive_counts,
            true_negative_counts,
            false_negative_counts,
            precision,
            recall])

        pr_data = {
            'TP': tp,
            'FP': fp,
            'TN': tn,
            'FN': fn,
            'precision': precision,
            'recall': recall,
            'thresholds': thresholds,
            'weights': weights,
            'name': tag,
        }

        self.log_asset_data(pr_data, name=tag, step=step)
