import os

import matplotlib.pyplot as plt
import numpy as np


class DataRecord:
    """Named, numbered channels of ``[stamp, value]`` samples persisted with ``np.save``.

    Each channel is a 2-D ``np.ndarray`` whose first row is a ``[0, 0]`` placeholder
    created by :meth:`init`; real samples are appended after it by :meth:`update`.
    The plotting methods skip that placeholder row.

    Errors are reported with ``print`` and a ``None`` return rather than by raising,
    which is the convention used by every method of this class.
    """

    def __init__(self, filename):
        """Create an empty record bound to ``filename``.

        Args:
            filename: Target for :meth:`save` and source for :meth:`load`. ``np.save``
                appends ``.npy`` when it is missing; :meth:`load` accounts for that.
        """
        self._filename = filename
        # The three lists are parallel: position i describes one channel.
        self._channel_num_list = []  # channel identifiers
        self._name_list = []         # channel display names
        self._data_list = []         # one (N, 2) array of [stamp, value] rows per channel

    def init(self, channel_name='Value', channel_num=None):
        """Register a new, empty channel.

        Args:
            channel_name: Display name; must be a ``str``.
            channel_num: Identifier for the channel. When ``None``, the next free
                number is used (0 for the first channel, otherwise ``max + 1``).

        Returns:
            The channel number assigned, or ``None`` if the name is not a string or
            the requested number is already in use.
        """
        if not isinstance(channel_name, str):
            print("Channel name is not string")
            return None

        if channel_num is None:
            channel_num = max(self._channel_num_list) + 1 if self._channel_num_list else 0
        elif channel_num in self._channel_num_list:
            print("Channel {} already exists".format(channel_num))
            return None

        self._channel_num_list.append(channel_num)
        self._name_list.append(channel_name)
        self._data_list.append(np.array([[0, 0]]))  # placeholder row, skipped when plotting
        return channel_num

    def update(self, value, stamp, channel_num):
        """Append one ``[stamp, value]`` sample to ``channel_num``.

        Each call copies the whole channel array (``np.concatenate`` builds a new
        array), so appending N samples costs O(N^2) in total.
        """
        try:
            idx = self._channel_num_list.index(channel_num)
        except ValueError:
            print("Fail to update channel {}, please initiate first".format(channel_num))
            return
        self._data_list[idx] = np.concatenate(
            (self._data_list[idx], np.array([[stamp, value]])), axis=0)

    def report(self):
        """Print the channel count followed by each channel's number and name."""
        print("{} channels in total:".format(len(self._channel_num_list)))
        for channel_num, name in zip(self._channel_num_list, self._name_list):
            print("Channel number: {}; Channel name: {}".format(channel_num, name))

    def set_channel(self, channel_num, new_channel_num=None, new_channel_name=None):
        """Renumber and/or rename an existing channel.

        Nothing changes if ``channel_num`` is unknown, or if ``new_channel_num``
        already belongs to a different channel (which would create duplicates).
        """
        try:
            idx = self._channel_num_list.index(channel_num)
        except ValueError:
            print("No data for channel {}, please initiate first".format(channel_num))
            return None

        if (new_channel_num is not None and new_channel_num != channel_num
                and new_channel_num in self._channel_num_list):
            print("Channel {} already exists".format(new_channel_num))
            return None

        if new_channel_num is not None:
            self._channel_num_list[idx] = new_channel_num
        if new_channel_name is not None:
            self._name_list[idx] = new_channel_name

    @staticmethod
    def _load_dict(filename):
        """Load the dict written by :meth:`save` / :meth:`copy` from ``filename``.

        ``np.save`` appends ``.npy`` to str/path-like names that lack it, but
        ``np.load`` does not. So when the bare name does not exist and the
        ``.npy`` variant does, the ``.npy`` variant is read instead.
        """
        if isinstance(filename, (str, os.PathLike)):
            path = os.fspath(filename)
            if (isinstance(path, str) and not os.path.exists(path)
                    and os.path.exists(path + '.npy')):
                filename = path + '.npy'
        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load trusted files.
        return np.load(filename, allow_pickle=True).item()

    def _to_dict(self):
        """Return the record in the dict layout read back by :meth:`load` / :meth:`merge`."""
        return {'channel_num_list': self._channel_num_list, 'name_list': self._name_list,
                'data_list': self._data_list}

    def load(self):
        """Replace all in-memory channels with those saved at this record's filename.

        Returns:
            ``self``, so ``DataRecord(path).load()`` can be chained.
        """
        print("Warning: Load will override existing data")
        saved_dict = self._load_dict(self._filename)

        self._channel_num_list = list(saved_dict['channel_num_list'])
        self._name_list = list(saved_dict['name_list'])
        self._data_list = list(saved_dict['data_list'])

        return self

    def save(self):
        """Write all channels to this record's filename via ``np.save``."""
        np.save(self._filename, self._to_dict())

    def copy(self, filename):
        """Write all channels to ``filename`` without changing this record's own filename."""
        np.save(filename, self._to_dict())

    def remove(self, channel_num):
        """Delete a channel with its name and data; prints a message if it does not exist."""
        try:
            idx = self._channel_num_list.index(channel_num)
        except ValueError:
            print("No data for channel {}".format(channel_num))
            return None
        del self._channel_num_list[idx]
        del self._name_list[idx]
        del self._data_list[idx]

    def merge(self, filename):
        """Append every channel saved in ``filename`` to this record.

        Nothing is merged if any channel number in the file is already present here.
        """
        merge_dict = self._load_dict(filename)

        merge_channel_num_list = list(merge_dict['channel_num_list'])
        if any(channel_num in merge_channel_num_list for channel_num in self._channel_num_list):
            print("Repeated channel number, cannot merge {}".format(filename))
            return

        self._channel_num_list = self._channel_num_list + merge_channel_num_list
        self._name_list = self._name_list + list(merge_dict['name_list'])
        self._data_list = self._data_list + list(merge_dict['data_list'])

    def average(self, channel_num_list, data_name=""):
        """Compute the per-stamp mean and standard deviation across several channels.

        The first channel's stamp column is the reference. Any other channel whose
        stamp column differs (in length or values) is reported and left out. Two
        new channels are created: ``data_name + "mean_array"`` and
        ``data_name + "std_array"``. Each is stored as ``[stamp, statistic]`` rows.

        Args:
            channel_num_list: Two or more existing channel numbers.
            data_name: Prefix for the names of the two new channels.

        Returns:
            ``(mean_channel_num, std_channel_num)`` on success. Returns ``None`` if the
            list is empty, has a single entry, or names an unknown channel; in that
            case no channels are created.
        """
        if channel_num_list is None or len(channel_num_list) == 0:
            print("channel list is empty")
            return None
        if len(channel_num_list) == 1:
            print("channel list must contain multiple channels")
            return None

        # Resolve every channel first so an unknown number aborts before anything is created.
        datasets = []
        for channel_num in channel_num_list:
            try:
                datasets.append((channel_num,
                                 self._data_list[self._channel_num_list.index(channel_num)]))
            except ValueError:
                print("No data for channel {}, please initiate first".format(channel_num))
                return None

        idx_array = datasets[0][1][:, 0]
        columns = []
        for channel_num, data in datasets:
            if not np.array_equal(idx_array, data[:, 0]):
                print("Time stamp of channel {} does not match, skipped".format(channel_num))
                continue
            columns.append(data[:, 1])

        values = np.column_stack(columns)  # one column per averaged channel
        mean_array = np.mean(values, axis=1)
        std_array = np.std(values, axis=1)

        mean_channel_num = self.init(channel_name=data_name + "mean_array")
        std_channel_num = self.init(channel_name=data_name + "std_array")

        mean_channel_idx = self._channel_num_list.index(mean_channel_num)
        std_channel_idx = self._channel_num_list.index(std_channel_num)

        self._data_list[mean_channel_idx] = np.column_stack((idx_array, mean_array))
        self._data_list[std_channel_idx] = np.column_stack((idx_array, std_array))

        return mean_channel_num, std_channel_num

    def _plot_channel(self, idx, sort, add_label, y_fun, x_fun):
        """Line-plot the channel at list position ``idx``, skipping the placeholder row."""
        data = np.copy(self._data_list[idx][1:, :])
        if np.issubdtype(data.dtype, np.integer):
            # Writing x_fun/y_fun results into an integer array would silently truncate them.
            data = data.astype(float)
        if x_fun is not None:
            data[:, 0] = x_fun(data[:, 0])
        if y_fun is not None:
            data[:, 1] = y_fun(data[:, 1])
        if sort:
            data = data[np.argsort(data[:, 0])]  # sort rows by (transformed) stamp
        return plt.plot(data[:, 0], data[:, 1], label=add_label + self._name_list[idx])

    def visualize_plot(self, channel_nums=None, sort=True, add_label='', y_log=False, y_fun=None,
                       x_fun=None):
        """Line-plot one or more channels on the current matplotlib axes.

        Args:
            channel_nums: A channel number, a list/tuple of them, or ``None`` for all channels.
            sort: Sort points by stamp (after ``x_fun``) before plotting.
            add_label: Prefix added to each channel name in the legend label.
            y_log: Switch the y axis to log scale.
            y_fun, x_fun: Optional element-wise transforms applied to values / stamps.

        Returns:
            A list with the return value of ``plt.plot`` for each plotted channel.
            Unknown channel numbers are reported and skipped.
        """
        if y_log:
            plt.yscale('log')
        if channel_nums is None:
            return [self._plot_channel(idx, sort, add_label, y_fun, x_fun)
                    for idx in range(len(self._channel_num_list))]
        if not isinstance(channel_nums, (list, tuple)):
            channel_nums = [channel_nums]
        plot_list = []
        for channel_num in channel_nums:
            try:
                idx = self._channel_num_list.index(channel_num)
            except ValueError:
                print("No data for channel {}, please initiate first".format(channel_num))
                continue
            plot_list.append(self._plot_channel(idx, sort, add_label, y_fun, x_fun))

        return plot_list

    def visualize_scatter(self, channel_num):
        """Scatter-plot one channel (placeholder row skipped); returns ``None`` if it is unknown."""
        try:
            idx = self._channel_num_list.index(channel_num)
        except ValueError:
            print("No data for channel {}, please initiate first".format(channel_num))
            return None
        data = self._data_list[idx][1:, :]
        return plt.scatter(data[:, 0], data[:, 1], label=self._name_list[idx])

    def visualize_errorbar(self, channel_num, err_channel_num):
        """Plot ``channel_num`` with y error bars taken from ``err_channel_num``.

        If the error channel is missing, or its stamps differ from the data
        channel's, zero-size error bars are drawn and a message is printed.
        Returns ``None`` if ``channel_num`` itself is unknown.
        """
        try:
            idx = self._channel_num_list.index(channel_num)
        except ValueError:
            print("No data for channel {}, please initiate first".format(channel_num))
            return None
        name = self._name_list[idx]
        data = self._data_list[idx][1:, :]

        try:
            err_idx = self._channel_num_list.index(err_channel_num)
        except ValueError:
            print("Could not find channel {} for error data".format(err_channel_num))
            err = np.zeros(len(data))
        else:
            err_data = self._data_list[err_idx][1:, :]
            if np.array_equal(err_data[:, 0], data[:, 0]):
                err = err_data[:, 1]
            else:
                print("Time stamp does not match")
                err = np.zeros(len(data))

        return plt.errorbar(data[:, 0], data[:, 1], err, label=name)
