import os
import numpy as np
from config import parse_args, save_module
# from model import learning_model
from data_loader import OfflineMetaRLAlgorithm
import torch
from numpy.random import default_rng
from DNN import EnsembleTransition
from torch.utils.data import TensorDataset, DataLoader


def clone(dnn_from, dnn_to):
    """Copy every parameter value of ``dnn_from`` into ``dnn_to`` in place.

    Parameters are paired by their position in ``parameters()``; the tensors
    owned by ``dnn_to`` are kept and only their contents are overwritten.

    Args:
        dnn_from: Module providing the values to copy.
        dnn_to: Module whose parameters are overwritten.
    """
    source_params = list(dnn_from.parameters())
    target_params = list(dnn_to.parameters())
    # no_grad keeps the in-place writes out of the autograd graph.
    with torch.no_grad():
        for position, source in enumerate(source_params):
            target_params[position].copy_(source.clone())


def metaRL_algo(_args, iterator_, out_dnn, in_dnn_list, save_folder, test_dnn_list, test_iterator):
    """Meta-train ``out_dnn`` by adapting per-task copies and averaging them.

    Every outer iteration:
      1. samples up to ``_args.training_task_number`` task indices,
      2. resets each sampled task model to the current meta-model weights,
      3. runs ``_args.task_training_iterations`` calls of ``inner_training``
         on batches of that task,
      4. moves the meta-model towards the mean of the adapted weights:
         ``theta <- (1 - alpha) * theta + alpha * mean(phi)``.

    The mean is taken over *all* sampled tasks, and it is accumulated in
    separate tensors so that no task model is modified by the averaging.

    Args:
        _args: Configuration namespace. Reads ``alpha``,
            ``training_iteration_number``, ``training_task_number``,
            ``task_training_iterations``, ``save_interval`` and
            ``meta_test_interval`` (and is forwarded to ``save_module`` /
            ``meta_test``).
        iterator_: Sequence of per-task batch sources exposing
            ``get_next_batch()``; indexed in parallel with ``in_dnn_list``.
        out_dnn: The meta-model, updated in place.
        in_dnn_list: Per-task models, overwritten with adapted weights.
        save_folder: Directory path (string) where checkpoints are written.
        test_dnn_list: Per-test-task models handed to ``meta_test``.
        test_iterator: Per-test-task batch sources handed to ``meta_test``.

    Raises:
        ValueError: If an outer iteration samples no training task.
    """
    alpha = _args.alpha
    iter_nums = _args.training_iteration_number
    task_num = _args.training_task_number
    inner_iter_num = _args.task_training_iterations
    task_idx = np.arange(len(iterator_))

    for _i in range(iter_nums):
        print('Iteration: ' + str(_i))
        permuted_task_idx = np.random.permutation(task_idx)[:task_num]
        # May be smaller than task_num when fewer tasks are available.
        num_sampled = len(permuted_task_idx)
        if num_sampled == 0:
            raise ValueError('No training tasks were sampled; check the task list and '
                             'training_task_number.')

        sum_phi = None
        training_loss = 0
        for _j in permuted_task_idx:
            task_model = in_dnn_list[_j]
            # Perform inner loops, always starting from the meta-model weights.
            clone(out_dnn, task_model)
            cost_ = 0.0  # stays 0 when no inner iteration is configured
            for _k in range(inner_iter_num):
                task_obs, task_act, next_task_obs, reward = iterator_[_j].get_next_batch()
                cost_ = task_model.inner_training(task_obs, task_act, next_task_obs, reward, out_dnn)
                print('Outer iteration: ' + str(_i) + ' Task: ' + str(_j) + ' Task iteration: ' + str(_k)
                      + ' Training cost: ' + str(cost_))
            with torch.no_grad():
                if sum_phi is None:
                    # Detached copies: accumulating must not write into the
                    # task model's own parameter tensors.
                    sum_phi = [p.detach().clone() for p in task_model.parameters()]
                else:
                    for acc, p in zip(sum_phi, task_model.parameters()):
                        acc.add_(p)
            # Only the cost of the last inner iteration of each task is tracked.
            training_loss += cost_

        # outer update
        with torch.no_grad():
            for theta, phi_sum in zip(out_dnn.parameters(), sum_phi):
                theta.copy_((1 - alpha) * theta + alpha * (phi_sum / num_sampled))

        avg_training_loss = training_loss / num_sampled
        if _i % _args.save_interval == 0:
            print('Outer iteration: ' + str(_i) + ' Average loss over tasks: ' + str(avg_training_loss))
            save_module(_args, save_folder, _i, [avg_training_loss], ['training_loss'])
            check_file = save_folder + '/checkpoint_Iter_' + str(_i)
            torch.save({
                'training_task_models': [model.state_dict() for model in in_dnn_list],
                'meta_model': out_dnn.state_dict()
            }, check_file)

        # Meta-testing only makes sense when there are testing tasks.
        if len(test_dnn_list) > 0 and _i % _args.meta_test_interval == 0:
            avg_train_loss, avg_test_loss = meta_test(_args, test_dnn_list, out_dnn, test_iterator)
            print('Outer iteration: ' + str(_i) + ' Average meta-training loss over tasks: ' + str(avg_train_loss)
                  + ' Average meta-testing loss over tasks: ' + str(avg_test_loss))
            save_module(_args, save_folder, _i, [avg_train_loss], ['meta_train_loss'])
            save_module(_args, save_folder, _i, [avg_test_loss], ['meta_test_loss'])
            test_check_file = save_folder + '/test_checkpoint_Iter_' + str(_i)
            torch.save({
                'testing_trained_task_models': [model.state_dict() for model in test_dnn_list]
            }, test_check_file)

    print('Training is completed.')


def meta_test(args_, all_test_models, init_model, test_iterator):
    """Measure how well ``init_model`` adapts to each test task.

    For every test task and every replication, the task model is reset to
    ``init_model``'s weights, adapted with ``args_.test_task_iter_num`` calls
    of ``inner_training`` on the task's adaptation batches, and then scored
    with ``test_training(..., training=False)`` on one held-out batch.

    Args:
        args_: Configuration namespace; reads ``meta_test_number``
            (replications per task) and ``test_task_iter_num``.
        all_test_models: Per-task models, overwritten during adaptation.
        init_model: Model whose weights are the starting point of every
            replication; also passed through to the task models' methods.
        test_iterator: Per-task batch sources exposing
            ``get_next_test_tr_batch()`` and ``get_next_test_batch()``;
            indexed in parallel with ``all_test_models``.

    Returns:
        Tuple ``(avg_adaptation_loss, avg_test_loss)``, each averaged over
        replications and then over tasks. The adaptation loss of a
        replication is the cost of its last adaptation batch (0 when no
        adaptation step is configured). ``(0.0, 0.0)`` is returned when there
        is no test model or no replication to run.
    """
    num_models = len(all_test_models)
    num_replications = args_.meta_test_number
    if num_models == 0 or num_replications <= 0:
        # Nothing to evaluate; avoids dividing by zero below.
        return 0.0, 0.0

    avg_task_training_loss = 0
    avg_task_testing_loss = 0
    for ii, task_model in enumerate(all_test_models):
        total_task_training_loss = 0
        total_task_testing_loss = 0
        for t_iter in range(num_replications):
            # Create an identical model to initial model
            clone(init_model, task_model)
            cost_ = 0.0  # stays 0 when no adaptation step is configured
            # Perform gradient (or OMD) iterations
            for i in range(args_.test_task_iter_num):
                task_obs, task_act, next_task_obs, reward = test_iterator[ii].get_next_test_tr_batch()
                cost_ = task_model.inner_training(task_obs, task_act, next_task_obs, reward, init_model)
                print('Task: ' + str(ii) + ' Replication: ' + str(t_iter) + ' Iteration: ' + str(i)
                      + ' Testing training cost: ' + str(cost_))
            task_obs, task_act, next_task_obs, reward = test_iterator[ii].get_next_test_batch()
            test_loss = task_model.test_training(task_obs, task_act, next_task_obs, reward,
                                                 init_model, training=False)
            print('Task: ' + str(ii) + ' Replication: ' + str(t_iter) + ' Testing training cost: ' + str(test_loss))

            total_task_training_loss += cost_
            total_task_testing_loss += test_loss
        avg_task_training_loss += total_task_training_loss / num_replications
        avg_task_testing_loss += total_task_testing_loss / num_replications
    return avg_task_training_loss / num_models, avg_task_testing_loss / num_models


def get_infinite_batches(iterator_obj):
    """Yield the items of ``iterator_obj`` over and over again.

    ``iterator_obj`` is re-iterated from the start each time it is exhausted,
    so a re-iterable source produces an endless stream.

    If a full pass yields nothing (an empty source, or a one-shot iterator
    that is already consumed), the generator stops instead of looping forever
    without ever producing a value.

    Args:
        iterator_obj: Any iterable; it should be re-iterable for the stream
            to be endless.

    Yields:
        The items of ``iterator_obj``, in iteration order, repeatedly.
    """
    while True:
        produced_any = False
        for samples_ in iterator_obj:
            produced_any = True
            yield samples_
        if not produced_any:
            return


class pipeClass(object):
    """Endless stream of shuffled training batches for a single task.

    The task's ``_observations``, ``_actions``, ``_next_obs`` and ``_rewards``
    are converted to tensors and served in batches of
    ``arg_.task_batch_size``.
    """

    def __init__(self, arg_, buff, task_ind):
        """Build the batch stream for ``buff[task_ind]``.

        Args:
            arg_: Configuration namespace; reads ``task_batch_size``.
            buff: Container of per-task buffers, indexed by ``task_ind``.
            task_ind: Key/index of the task inside ``buff``.
        """
        self.task_batch_size = arg_.task_batch_size
        self.task_idx = task_ind
        task_data = buff[task_ind]

        columns = (
            task_data._observations,
            task_data._actions,
            task_data._next_obs,
            task_data._rewards,
        )
        dataset = TensorDataset(*(torch.Tensor(column) for column in columns))
        loader = DataLoader(dataset,
                            batch_size=self.task_batch_size,
                            shuffle=True,
                            pin_memory=True)
        self.train_ds = get_infinite_batches(loader)

    def get_next_batch(self):
        """Return the next (obs, action, next_obs, reward) batch."""
        return next(self.train_ds)


class pipeTestClass(object):
    """Endless adaptation and evaluation batch streams for one test task.

    The task's transitions are randomly permuted once; the first
    ``arg_.task_test_train_sample_number`` of them feed the adaptation stream
    and the remainder feed the evaluation stream.
    """

    def __init__(self, arg_, buff, task_ind):
        """Split ``buff[task_ind]`` and build both batch streams.

        Args:
            arg_: Configuration namespace; reads ``task_batch_size`` and
                ``task_test_train_sample_number``.
            buff: Container of per-task buffers, indexed by ``task_ind``.
            task_ind: Key/index of the task inside ``buff``.
        """
        self.task_batch_size = arg_.task_batch_size
        task_data = buff[task_ind]

        shuffled_rows = np.random.permutation(np.arange(len(task_data._observations)))
        split_at = arg_.task_test_train_sample_number

        # Adaptation stream first, then the held-out evaluation stream.
        self.test_tr_ds = self._make_stream(task_data, shuffled_rows[:split_at])
        self.test_ds = self._make_stream(task_data, shuffled_rows[split_at:])

    def _make_stream(self, task_data, rows):
        """Return an endless shuffled batch stream over the selected rows."""
        columns = (
            task_data._observations,
            task_data._actions,
            task_data._next_obs,
            task_data._rewards,
        )
        dataset = TensorDataset(*(torch.Tensor(column[rows]) for column in columns))
        loader = DataLoader(dataset,
                            batch_size=self.task_batch_size,
                            shuffle=True,
                            pin_memory=True)
        return get_infinite_batches(loader)

    def get_next_test_tr_batch(self):
        """Return the next adaptation (obs, action, next_obs, reward) batch."""
        return next(self.test_tr_ds)

    def get_next_test_batch(self):
        """Return the next evaluation (obs, action, next_obs, reward) batch."""
        return next(self.test_ds)


def main(args, sim_folder_):
    """Load the task data, build all models and start meta-training.

    Args:
        args: Parsed configuration namespace (task counts, dimensions,
            network sizes, device, data directory, ...).
        sim_folder_: Directory where ``metaRL_algo`` writes its results.
    """
    rng = default_rng()
    tasks = list(range(2))  # MODIFY
    # Random train/eval split: eval tasks are the ones not drawn for training.
    train_tasks = rng.choice(len(tasks), size=args.train_task_number, replace=False)
    eval_tasks = np.setxor1d(range(len(tasks)), train_tasks)
    dimensions = [args.observation_dims, args.action_dims]

    loader_obj = OfflineMetaRLAlgorithm(args, dimensions, train_tasks, eval_tasks, eval_deterministic=True,
                                        render=False, render_eval_paths=False, plotter=None,
                                        data_dir=args.data_dir)
    train_buffers = loader_obj.train_buffer.task_buffers
    eval_buffers = loader_obj.eval_buffer.task_buffers

    task_pipeline = [pipeClass(args, train_buffers, loader_obj.train_tasks[pos])
                     for pos in range(len(train_buffers))]
    task_test_pipeline = [pipeTestClass(args, eval_buffers, loader_obj.eval_tasks[pos])
                          for pos in range(len(eval_buffers))]
    print('Trajectories are loaded. Creating Neural Networks...')

    hidden_features = args.hidden_feature_number
    hidden_layers = args.hidden_layer_number

    def new_transition_model(**extra):
        # All models share the same architecture; only the meta-model is
        # given an explicit cuda_index.
        return EnsembleTransition(args, dimensions[0], dimensions[1], hidden_features, hidden_layers,
                                  device=torch.device(args.device), **extra)

    # One meta-model, then one model per training task and per eval task.
    outer_dnn = new_transition_model(cuda_index=args.cuda_index)
    inner_dnns = [new_transition_model() for _ in range(len(train_buffers))]
    test_dnns = [new_transition_model() for _ in range(len(eval_buffers))]
    print('Neural Networks are created.')

    metaRL_algo(args, task_pipeline, outer_dnn, inner_dnns, sim_folder_, test_dnns, task_test_pipeline)


if __name__ == '__main__':
    argument = parse_args()
    print('Is cuda enabled: ' + str(argument.cuda))

    # Claim the first unused "<save_folder>_<trial>" directory. Creating it
    # directly (instead of checking for existence first) means two runs
    # started at the same time cannot end up sharing a folder.
    trial = 0
    while True:
        sim_folder = argument.save_folder + '_' + str(trial)
        try:
            os.mkdir(sim_folder)
            break
        except FileExistsError:
            trial += 1

    # Record the configuration of this run next to its results; the context
    # manager closes the file even if the write fails.
    setup_file = sim_folder + '/setup.txt'
    with open(setup_file, "w") as f:
        f.write(str(argument))

    main(argument, sim_folder)
    print('Finished')
