import os
import numpy as np
import pickle
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

from utils.cityflow import CityFlowEnv
from utils.motsc.motsc import MOTSCFakeEnv
from agents.maxpressure_agent import MaxPressureAgent

class tester():
    def __init__(self, dic_agent_conf, dic_path, dic_traffic_env_conf, real_env=None, fake_env=None, agents=None, round=None):
        self.dic_agent_conf = dic_agent_conf
        self.dic_path = dic_path
        self.dic_traffic_env_conf = dic_traffic_env_conf
        self.round = round
        self.watching_intersections = []
        self.real_env = real_env if real_env is not None else CityFlowEnv(dic_path["PATH_TO_WORK_DIRECTORY"], dic_traffic_env_conf)
        self.fake_env = fake_env if fake_env is not None else MOTSCFakeEnv(os.path.join(self.dic_path["PATH_TO_OFFLINE_DATA"], "data"), dic_traffic_env_conf)
        if self.round is not None:
            self.save_path = os.path.join(os.getcwd(), self.dic_path["PATH_TO_WORK_DIRECTORY"], "test_round", "round_{}".format(self.round))
        else:
            self.save_path = os.path.join(os.getcwd(), self.dic_path["PATH_TO_WORK_DIRECTORY"])
        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)
        if agents is None:
            self.agents = []
            for _ in range(dic_traffic_env_conf["NUM_INTERSECTIONS"]):
                self.agents.append(MaxPressureAgent(dic_agent_conf, dic_traffic_env_conf, dic_path))
        else:
            self.agents = agents
            
    def _save_data(self, obj, filename):
        with open(os.path.join(self.save_path, filename), "wb") as f:
            pickle.dump(obj, f)
            
    def _load_data(self, filename):
        with open(os.path.join(self.save_path, filename), "rb") as f:
            obj = pickle.load(f)
        return obj
    
    def _linear_fit(self, x, y):
        x = np.array(x).reshape((len(x), 1))
        y = np.array(y).reshape((len(x), 1))
        linear_model = LinearRegression()
        linear_model.fit(x, y)
        return linear_model.coef_[0][0], linear_model.intercept_[0], linear_model.score(x, y)
    
    def test(self):
        test_start_step = 5
        
        real_env_state = self.real_env.reset()
        self.fake_env.reset()
        fake_env_state = real_env_state
        
        real_env_data_list = []
        fake_env_data_list = []
        
        for step in range(int(self.dic_traffic_env_conf["TEST_COUNTS"] / self.dic_traffic_env_conf["MIN_ACTION_TIME"])):
            action_list = []
            if len(self.agents) == 1:
                one_state = real_env_state
                action = self.agents[0].choose_action(step, one_state)
                action_list = action
            else:
                for i in range(self.dic_traffic_env_conf["NUM_INTERSECTIONS"]):
                    # use real env state to make choice
                    one_state = real_env_state[i]
                    action = self.agents[i].choose_action(0, one_state)
                    action_list.append(action)

            real_env_next_state, real_env_reward = self.real_env.step(action_list)
            if step < test_start_step:
                real_env_state = real_env_next_state
                continue
            elif step == test_start_step:
                self.fake_env.set_state(real_env_state)
                fake_env_state = real_env_state

            fake_env_next_state, fake_env_reward = self.fake_env.step(action_list)
            
            real_data = {"state" : np.array(real_env_state),
                         "action" : np.array(action_list),
                         "next_state" : np.array(real_env_next_state),
                         "reward" : np.array(real_env_reward),
                         "step" : step * self.dic_traffic_env_conf["MIN_ACTION_TIME"]}
            fake_data = {"state" : np.array(fake_env_state),
                         "action" : np.array(action_list),
                         "next_state" : np.array(fake_env_next_state),
                         "reward" : np.array(fake_env_reward),
                         "step" : step * self.dic_traffic_env_conf["MIN_ACTION_TIME"]}
            real_env_data_list.append(real_data)
            fake_env_data_list.append(fake_data)
            
            real_env_state = real_env_next_state
            fake_env_state = fake_env_next_state
            
        
        self._save_data(real_env_data_list, "real_env_data_iter.pkl")
        self._save_data(fake_env_data_list, "fake_env_data_iter.pkl")
        
    def test_single_step(self):
        real_env_state = self.real_env.reset()
        fake_env_state = real_env_state
        self.fake_env.reset()
        
        real_env_data_list = []
        fake_env_data_list = []
        
        for step in range(int(self.dic_traffic_env_conf["TEST_COUNTS"] / self.dic_traffic_env_conf["MIN_ACTION_TIME"])):
            action_list = []
            if len(self.agents) == 1:
                one_state = real_env_state
                action = self.agents[0].choose_action(step, one_state)
                action_list = action
            else:
                for i in range(self.dic_traffic_env_conf["NUM_INTERSECTIONS"]):
                    # use real env state to make choice
                    one_state = real_env_state[i]
                    action = self.agents[i].choose_action(0, one_state)
                    action_list.append(action)

                
            real_env_next_state, real_env_reward = self.real_env.step(action_list)
            real_data = {"state" : np.array(real_env_state),
                         "action" : np.array(action_list),
                         "next_state" : np.array(real_env_next_state),
                         "reward" : np.array(real_env_reward),
                         "step" : step * self.dic_traffic_env_conf["MIN_ACTION_TIME"]}
            fake_env_next_state, fake_env_reward = self.fake_env.test_single_step(real_data)
            fake_data = {"state" : np.array(fake_env_state),
                         "action" : np.array(action_list),
                         "next_state" : np.array(fake_env_next_state),
                         "reward" : np.array(fake_env_reward),
                         "step" : step * self.dic_traffic_env_conf["MIN_ACTION_TIME"]}
            
            real_env_data_list.append(real_data)
            fake_env_data_list.append(fake_data)
            
            real_env_state = real_env_next_state
            fake_env_state = real_env_state
            
        self._save_data(real_env_data_list, "real_env_data_step.pkl")
        self._save_data(fake_env_data_list, "fake_env_data_step.pkl")
            
    # waiting for update
    def draw_scatter(self, method, catagory, watching_intersections, feature=None):
        plt.close()
        real_env_list_all = self._load_data("real_env_data_{}.pkl".format(method))
        fake_env_list_all = self._load_data("fake_env_data_{}.pkl".format(method))
        
        for inter in watching_intersections:
            if catagory == "next_state":
                real_env_list = [np.mean(one_iter[catagory][inter][feature]) for one_iter in real_env_list_all]
                fake_env_list = [np.mean(one_iter[catagory][inter][feature]) for one_iter in fake_env_list_all]
            else:
                real_env_list = [one_iter[catagory][inter] for one_iter in real_env_list_all]
                fake_env_list = [one_iter[catagory][inter] for one_iter in fake_env_list_all]
            plt.xlabel("real {} {}".format(catagory, feature))
            plt.ylabel("predict {} {}".format(catagory, feature))
            k, b, score = self._linear_fit(real_env_list, fake_env_list)
            plt.plot(real_env_list, [k * x + b for x in real_env_list])
            plt.scatter(real_env_list, fake_env_list)
            plt.text(1.0, 0.1, "score : {}".format(score))
            fig_name = "{}_real_fake_{}_scatter_inter_{}.jpg".format(method, catagory, inter)
            plt.savefig(os.path.join(self.save_path, fig_name))
            plt.close()
    
    #waiting for update
    def draw_line(self, method, catagory, watching_intersections, feature=None):
        plt.close()
        real_env_list_all = self._load_data("real_env_data_{}.pkl".format(method))
        fake_env_list_all = self._load_data("fake_env_data_{}.pkl".format(method))
        
        for inter in watching_intersections:
            if catagory == "next_state":
                real_env_list = [np.mean(one_iter[catagory][inter][feature]) for one_iter in real_env_list_all]
                fake_env_list = [np.mean(one_iter[catagory][inter][feature]) for one_iter in fake_env_list_all]
            else:
                real_env_list = [one_iter[catagory][inter] for one_iter in real_env_list_all]
                fake_env_list = [one_iter[catagory][inter] for one_iter in fake_env_list_all]

            plt.xlabel("stem num")
            plt.ylabel("{} : {}".format(catagory, feature))
            plt.plot(np.linspace(0, len(real_env_list) * self.dic_traffic_env_conf["MIN_ACTION_TIME"], len(real_env_list)), real_env_list, label="real")
            plt.plot(np.linspace(0, len(fake_env_list) * self.dic_traffic_env_conf["MIN_ACTION_TIME"], len(fake_env_list)), fake_env_list, label="fake")
            plt.legend()
            fig_name = "{}_real_fake_{}_line_inter_{}.jpg".format(method, catagory, inter)
            plt.savefig(os.path.join(self.save_path, fig_name))
            plt.close()
        