#!/usr/bin/env python

import numpy as np

from tqdm import tqdm

from state import State

class System:
    '''
    Combine a classifier and a response function in a given setting.

    Evaluates both over a res x res grid of two-group states and stores the
    computed values (velocity field, acceptance rates, error rates, fitness)
    for later plotting.

    Parameters
    ----------
    name : label used for plotting.
    setting : object providing ``mu``, ``beta``, ``fpr``, ``fnr`` and
        ``avg_fitness`` (the members ``calculate`` reads).
    classifier_func : callable ``(setting, state, **cls_kwargs) -> phi``.
    response_func : callable ``(setting, state, phi, **resp_kwargs)``
        returning an object whose ``sg`` attribute unpacks into two
        velocity components.
    res : number of grid points along each axis.
    resp_kwargs, cls_kwargs : extra keyword arguments forwarded to
        ``response_func`` / ``classifier_func``. ``None`` (the default)
        means no extra arguments.
    '''
    def __init__(self, name, setting, classifier_func, response_func, res, resp_kwargs=None, cls_kwargs=None):

        self.name = name # for plotting
        self.setting = setting
        self.classifier_func = classifier_func
        self.response_func = response_func
        self.res = res
        # Fresh dict per instance: a shared ``{}`` default would let a
        # mutation made through one System leak into every other one.
        self.resp_kwargs = {} if resp_kwargs is None else resp_kwargs
        self.cls_kwargs = {} if cls_kwargs is None else cls_kwargs

    def calculate(self):
        '''
        Evaluate the system on a res x res grid of states (assumes two groups).

        The grid spans s_1 (x axis) and s_2 (y axis), each sampled with
        ``res`` evenly spaced points over [0.01, 0.99]. Every result array is
        indexed ``[iy, ix]`` (row = y, column = x), matching
        ``np.meshgrid(self.x, self.y)``.

        Sets the following attributes:
        x, y, xx, yy -- grid axes and their meshgrid;
        Vx, Vy -- response velocity components (group 1, group 2);
        A1, A2 -- acceptance rates from ``setting.beta``;
        fpr1, fpr2, fnr1, fnr2 -- false positive / false negative rates;
        f1, f2 -- average fitness, min-max normalized jointly over both
        groups and the whole grid.
        '''
        res = self.res
        setting = self.setting
        shape = (res, res)

        self.x = np.linspace(0.01, 0.99, res) # s_1
        self.y = np.linspace(0.01, 0.99, res) # s_2

        self.xx, self.yy = np.meshgrid(self.x, self.y)

        # Velocity vectors
        self.Vx = np.zeros(shape) # x-component (group 1)
        self.Vy = np.zeros(shape) # y-component (group 2)

        # Acceptance rate for each of the two groups
        self.A1 = np.zeros(shape)
        self.A2 = np.zeros(shape)

        # Outcomes (false positive rates, false negative rates)
        self.fpr1 = np.zeros(shape)
        self.fpr2 = np.zeros(shape)
        self.fnr1 = np.zeros(shape)
        self.fnr2 = np.zeros(shape)

        # Average fitness of each group
        self.f1 = np.zeros(shape)
        self.f2 = np.zeros(shape)

        # https://eli.thegreenplace.net/2014/meshgrids-and-disambiguating-rows-and-columns-from-cartesian-coordinates/
        # Arrays indexed by row = y, column = x.
        for ix, s1 in enumerate(tqdm(self.x)):
            for iy, s2 in enumerate(self.y):

                state = State(setting.mu, np.array([s1, s2]))

                # classifier
                phi = self.classifier_func(setting, state, **self.cls_kwargs)

                self.A1[iy, ix], self.A2[iy, ix] = setting.beta(phi, state)
                self.fpr1[iy, ix], self.fpr2[iy, ix] = setting.fpr(phi, state)
                self.fnr1[iy, ix], self.fnr2[iy, ix] = setting.fnr(phi, state)
                self.f1[iy, ix], self.f2[iy, ix] = setting.avg_fitness(phi, state)

                # response
                vel = self.response_func(setting, state, phi, **self.resp_kwargs)
                self.Vx[iy, ix], self.Vy[iy, ix] = vel.sg

        # normalize fitness over entire state space (shared scale for both groups)
        self.f1, self.f2 = self._normalize_jointly(self.f1, self.f2)

    @staticmethod
    def _normalize_jointly(a, b):
        '''
        Min-max scale two arrays onto [0, 1] using their combined min and max.

        If every value in both arrays is identical the span is zero. In that
        case both arrays map to zeros instead of the NaNs that 0/0 would give.
        '''
        lo = np.min((np.min(a), np.min(b)))
        hi = np.max((np.max(a), np.max(b)))
        span = hi - lo
        if span == 0:
            return np.zeros_like(a, dtype=float), np.zeros_like(b, dtype=float)
        return (a - lo) / span, (b - lo) / span
