import numpy as np
import Model_Based.utils as utils
from Model_Based.parser import Parser
from Model_Based.config import Config
from Model_Based.model import TrueModel_land, EstimatedModel_land, TrueModel_air, EstimatedModel_air
from Model_Based.controller import Controller_land, Controller_air
from time import time
import torch
import scipy.io as sp
from torch import tensor

from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

class Solver:
    def __init__(self, config):
        """Create the controller and the dynamics model for the chosen domain.

        Args:
            config: Experiment configuration. This constructor reads its
                ``domain`` ('land' or 'air'), ``hidden_nodes`` and
                ``true_model`` attributes.

        Raises:
            ValueError: If ``config.domain`` is neither 'land' nor 'air'.
        """
        self.config = config

        # Reject unsupported domains before building anything
        domain = self.config.domain
        if domain not in ('land', 'air'):
            raise ValueError('Unknown domain')

        # Pick the domain-specific implementations
        if domain == 'land':
            controller_cls = Controller_land
            true_model_cls = TrueModel_land
            estimated_model_cls = EstimatedModel_land
        else:
            controller_cls = Controller_air
            true_model_cls = TrueModel_air
            estimated_model_cls = EstimatedModel_air

        # Initialize the controller
        self.controller = controller_cls(config, hidden=self.config.hidden_nodes)

        # Initialize the model: either the true dynamics or a learnable estimate
        if self.config.true_model:
            self.model = true_model_cls()
        else:
            self.model = estimated_model_cls(config, hidden=self.config.hidden_nodes)

    def matlab_saver(self, model_point, controller_point):
        if not self.config.true_model:
            # Save dynamics model in matlab format
            # Note that models quality is independent of the controllers quality
            path_string = "{}model_{}_{}".format(self.config.paths['ckpt'], self.config.domain, model_point)
            self.model.load(path_string)

            all_weights = {}
            for key, val in self.model.state_dict().items():
                key = key.replace(".", "_")
                all_weights[key] =  val.data.numpy()

            path_string = "{}model_matlab_{}_{}".format(self.config.paths['ckpt'], self.config.domain, model_point)
            sp.savemat(path_string, all_weights)


        # Save controller in matlab format
        # Note that the quality of the controller depends on model's quality as well
        path_string = "{}controller_{}_{}_{}".format(self.config.paths['ckpt'], self.config.domain, model_point, controller_point)
        self.controller.load(path_string)

        all_weights = {}
        for key, val in self.controller.state_dict().items():
            key = key.replace(".", "_")
            all_weights[key] =  val.data.numpy()

        path_string = "{}controller_matlab_{}_{}_{}".format(self.config.paths['ckpt'], self.config.domain, model_point, controller_point)
        sp.savemat(path_string, all_weights)


    def train(self, model_point, controller_points):
        # Train the model if system dynamics need to be estimated
        if not self.config.true_model:
            # We are using an estimated model,
            # hence we need to train the model to learn system dynamics
            self.train_model(model_point)

        # Train the controller using model predictive control methods
        self.train_controller(model_point, controller_points)
        # self.model.load()


    def train_model(self, model_point):
        """Fit the estimated dynamics model on ``model_point`` dataset samples.

        Indices ``0 .. model_point - 1`` of the dataset are shuffled and split
        into a training and a validation subset according to
        ``config.valid_fraction``. The model is trained on the mean squared
        error between its prediction and the target. Every
        ``config.save_after`` epochs, and at epoch
        ``config.max_model_iterations - 1``, the validation loss is computed,
        progress is printed, the loss/error arrays are written under
        ``config.paths['results']`` and the model is checkpointed under
        ``config.paths['ckpt']``.

        Args:
            model_point: Number of dataset samples to use.

        Raises:
            ValueError: If the dataset holds fewer than ``model_point`` samples.
        """
        # Load the dataset
        dataset = utils.DataBuffer()
        dataset.load(self.config.paths['dataset'], self.config.domain)

        if len(dataset) < model_point:
            raise ValueError("Dataset size smaller than needed")

        # Randomly split the first `model_point` indices for training and validating
        num_train = model_point
        indices = list(range(num_train))
        split = int(np.floor(self.config.valid_fraction * num_train))
        np.random.shuffle(indices)
        train_idx, valid_idx = indices[split:], indices[:split]

        # Multi-worker data loaders that draw only from the selected indices
        train_iter = DataLoader(dataset=dataset,
                                batch_size=self.config.batch_size,
                                num_workers=4,
                                sampler=SubsetRandomSampler(train_idx))
        valid_iter = DataLoader(dataset=dataset,
                                batch_size=self.config.batch_size,
                                num_workers=4,
                                sampler=SubsetRandomSampler(valid_idx))

        print("Model training phase started...")

        train_losses = []
        valid_losses = []
        final_train_errors = []

        # NOTE(review): the loop below runs through `max_model_iterations`
        # inclusive, so this "final" epoch is the second-to-last scheduled one.
        # Confirm whether that is intended.
        final_epoch = self.config.max_model_iterations - 1

        time_ckpt = time()
        for counter in range(self.config.max_model_iterations + 1):
            # Compute training loss and update model parameters
            batch_losses = []
            for x, y in train_iter:              # x = action, y = displacement
                y_pred = self.model.forward(x)   # predict the next state using the learned model
                errors = (y - y_pred) ** 2
                mean_error = torch.mean(errors)

                self.model.update(mean_error)           # Minimize the loss
                batch_losses.append(mean_error.item())  # Keep track of the mean batch loss

                if counter == final_epoch:
                    final_train_errors.extend(errors.tolist())

            train_losses.append(np.mean(batch_losses))  # mean loss in one epoch

            # Compute validation loss periodically after a fixed interval
            if counter % self.config.save_after == 0 or counter == final_epoch:
                final_valid_errors = []
                final_valid_inputs = []
                batch_losses = []

                # No parameter update happens here, so skip autograd bookkeeping
                with torch.no_grad():
                    for x, y in valid_iter:
                        y_pred = self.model.forward(x)
                        errors = (y - y_pred) ** 2
                        batch_losses.append(torch.mean(errors).item())  # mean loss in a batch

                        final_valid_errors.extend(errors.tolist())
                        final_valid_inputs.extend(x.tolist())

                valid_losses.append(np.mean(batch_losses))      # mean loss in one epoch

                print("Epoch {} | train loss:: {} | val_loss:: {} | time:: {:.3f}".format(
                    counter, np.mean(train_losses[-10:]), np.mean(valid_losses[-10:]), time() - time_ckpt))

                path_string = "{}model_{}_{}_".format(self.config.paths['results'], self.config.domain, model_point)
                np.save(path_string + "train_loss", train_losses)
                np.save(path_string + "valid_loss", valid_losses)
                np.save(path_string + "final_train_error", final_train_errors)
                np.save(path_string + "final_valid_error", final_valid_errors)
                np.save(path_string + "final_valid_inputs", final_valid_inputs)

                path_string = "{}model_{}_{}".format(self.config.paths['ckpt'], self.config.domain, model_point)
                self.model.save(path_string)
                time_ckpt = time()

            # Terminate model training once validation loss starts increasing.
            # Both windows index `valid_losses`, which only grows at validation
            # epochs. The older window is empty until more than 10 entries
            # exist, so the comparison is skipped until then instead of
            # averaging an empty slice (NaN plus a RuntimeWarning every epoch).
            if (len(train_losses) >= 30 and len(valid_losses) > 10
                    and np.mean(valid_losses[-20:]) > np.mean(valid_losses[-30:-10])):
                print("Converged...")
                break


    def train_controller(self, model_point, controller_points):
        """Train the controller by rolling it out through ``self.model``.

        Each episode starts from ``self.model.start_state()`` and runs for at
        most ``config.max_steps`` steps, accumulating the reward returned by
        the model. Gaussian noise scaled by ``config.noise_var`` is added to
        every next state. After the episode the controller is updated with
        the negated return. Whenever the episode number is a multiple of the
        current checkpoint value, the returns, the controller, plots and the
        MATLAB export are saved and a 1000-trajectory evaluation is run; the
        next value of ``controller_points`` then becomes the checkpoint.
        Training stops after the last checkpoint.

        Args:
            model_point: Identifier of the model setting, used in file names.
            controller_points: Episode counts at which to checkpoint. The
                modulo test assumes these are in increasing order. The
                caller's list is not modified.

        Raises:
            ValueError: If ``controller_points`` is empty or its largest
                value exceeds ``config.max_episodes``.
        """
        if not controller_points:
            raise ValueError("No controller checkpoints requested")
        if self.config.max_episodes < max(controller_points):
            raise ValueError("Increase maximum episode limit")

        # Walk a private copy so the caller's list is left untouched
        schedule = iter(list(controller_points))
        ckpt = next(schedule)

        returns, rm = [], 0
        t0, last_ckpt_episode = time(), 0

        for episode in range(1, self.config.max_episodes + 1):
            total_return = 0
            self.model.reset_trajectory()
            state = self.model.start_state()
            state_dim = state.shape[-1]

            for _ in range(self.config.max_steps):
                next_action = self.controller.get_action(state)  # Use the controller to get action for the state
                next_state, reward, info = self.model.get_next_state_and_reward(state, next_action)  # Use the model to obtain the next state
                total_return = total_return + reward      # Accumulate the rewards

                # Add noise to state transition to make training robust
                noise = torch.randn(state_dim) * self.config.noise_var
                next_state = next_state + noise

                # Update state for next time step
                state = next_state

                # Check whether the trajectory should be stopped or not
                if self.model.check_termination(next_state):
                    break

            # Maximize total return = minimize negative of total return
            self.controller.update(- total_return)

            # Track the progress of training with an exponential moving average
            rm = 0.99*rm + 0.01*total_return.item()
            returns.append(total_return.item())

            # Create checkpoint
            if episode % ckpt == 0:
                # Average over the episodes actually run since the timer was reset
                seconds_per_episode = (time() - t0) / (episode - last_ckpt_episode)
                print("{} :: Rewards {:.3f} :: Time {:.3f} :: Grads : {}".
                      format(episode, rm, seconds_per_episode, self.controller.get_grads()))

                path_string = "_{}_{}_{}".format(self.config.domain, model_point, ckpt)
                # Use the instance's config: a module-level `config` only
                # exists when this file is executed as a script.
                np.save(self.config.paths['results'] + "returns" + path_string, returns)
                self.controller.save(self.config.paths['ckpt']+'controller' + path_string)
                utils.save_plots(returns, self.config, path_string)
                self.matlab_saver(model_point, ckpt)

                # Do monte-carlo evaluation of the controller to figure out ground truth safe/unsafe states
                self.eval(1000, model_point, ckpt)

                ckpt = next(schedule, None)
                if ckpt is None:
                    break

                # Restart the timer after evaluation so it measures training only
                t0, last_ckpt_episode = time(), episode


    def eval(self, count, model_point, controller_point):
        eval_trajectories = []
        safe_starts = []

        if self.config.domain == 'land':
            true_model = TrueModel_land()
        elif self.config.domain == 'air':
            true_model = TrueModel_air()
        else:
            raise ValueError("Unknown domain")

        for idx in range(count):
            true_model.reset_trajectory()
            flag = False
            # state = tensor([0.0, -0.3 ])
            state = true_model.start_state()
            start_state = state.tolist()

            trajectory = [state.tolist()]
            for _ in range(self.config.max_steps):
                next_action = self.controller.get_action(state)  # Use the controller to get action for the state
                next_state, reward, info = true_model.get_next_state_and_reward(state, next_action)  # Use the model to obtain the next state

                # Update state and action for next time step
                state = next_state

                # track the trajectory
                trajectory.append(state.tolist())

                if info["unsafe"]:
                    flag = True

                # Check whether the trajectory should be stopped or not
                if true_model.check_termination(next_state):
                    break

            true_model.render()

            eval_trajectories.append(trajectory)
            start_state.append(int(flag))
            safe_starts.append(start_state)

        path_string = "_{}_{}_{}".format(self.config.domain, model_point, controller_point)
        np.save(config.paths['results'] + "eval_trajectories" + path_string, eval_trajectories)
        np.save(config.paths['results'] + "safe_starts" + path_string, safe_starts)
        # print(eval_trajectories)



if __name__ == "__main__":
    # Wall-clock reference for the whole sweep
    t = time()

    parser = Parser().get_parser()
    args = parser.parse_args()
    config = Config(args)

    # Dataset sizes for the dynamics model and episode counts at which the
    # controller is checkpointed (a longer sweep would be [300, 600, 1000]
    # model points with the same controller points).
    model_points = [600, 1000]
    controller_points = [100, 300, 600]

    for model_point in model_points:
        # Reinitialize solver for each setting
        solver = Solver(config=config)

        # Training mode; hand over a fresh list for every setting
        solver.train(model_point, list(controller_points))

        # Evaluation mode (disabled):
        # solver.eval(1000, model_point, controller_points[0])

        # Save stuff in matlab readable format (disabled):
        # solver.matlab_saver(model_point, controller_point)

    print("Total time taken: ", time() - t)
