from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from time import localtime, strftime

from dotmap import DotMap
from scipy.io import savemat, loadmat
from tqdm import trange
import torch
from Agent_dtwil import Agent
from DotmapUtils import get_required_argument
from dtaidistance import dtw_ndim
import matplotlib.pyplot as plt
import numpy as np
import random



# Default torch device: the first CUDA GPU when one is available, otherwise the CPU.
TORCH_DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
class MBExperiment_dtwil:
    """Runs a DTW-based imitation experiment against expert demonstrations.

    The policy is first trained on the expert demonstrations. Then, for each
    training iteration, the agent rolls out the policy with one expert
    trajectory as its alignment target (cycling through the experts), the
    rollouts are logged to ``<logdir>/logs.mat`` and the policy is retrained on
    them. Finally, for every expert trajectory the logged rollout with the
    smallest DTW distance is saved as a transition dataset.
    """

    # Environment id -> (expert dataset path relative to the working directory,
    # segmented flag). segmented=True: the file concatenates variable-length
    # episodes, split wherever obs[i + 1] != next_obs[i]. segmented=False: the
    # file holds 5 back-to-back trajectories of 1000 steps each.
    EXPERT_DATASETS = {
        'MBRLHalfCheetah-v0': ('expert_datasets/halfcheetah_5trajs_processed.pt', False),
        'maze2d-umaze-blackroom-v0': ('expert_datasets/v2_maze2d-umaze-blackroom-v0_100.pt', True),
        'maze2d-medium-v1': ('expert_datasets/maze2d_100.pt', True),
        'MBRLHopper-v0': ('expert_datasets/hopper_5000_valuedice.pth', False),
    }

    def __init__(self, params):
        """Initializes class instance.

        Argument:
            params (DotMap): A DotMap containing the following:
                .sim_cfg:
                    .env (gym.env): Environment for this experiment. Its ``spec.id`` must be
                        a key of ``EXPERT_DATASETS`` when ``run_experiment`` is called.
                    .task_hor (int): Task horizon
                    .stochastic (bool): (optional) Must be falsy; stochastic (noisy-action)
                        agents are not supported. Defaults to False.

                .exp_cfg:
                    .ntrain_iters (int): Number of training iterations to be performed.
                    .nrollouts_per_iter (int): (optional) Number of rollouts per iteration that
                        are logged and used for retraining. Defaults to 1.
                    .ninit_rollouts (int): (optional) Number of initial rollouts. Defaults to 0.
                        Stored but not used by ``run_experiment``.
                    .policy (controller): Policy that will be trained.

                .log_cfg:
                    .logdir (str): Parent of directory path where experiment data will be saved.
                        Experiment will be saved in logdir/<date+time of experiment start>
                    .nrecord (int): (optional) Number of rollouts to record for every iteration.
                        Defaults to 0. Stored but not used by ``run_experiment``.
                    .neval (int): (optional) Number of rollouts per iteration whose DTW distance
                        and return are reported. Defaults to 1.

        Raises:
            NotImplementedError: if ``sim_cfg.stochastic`` is truthy.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if params.sim_cfg.get("stochastic", False):
            raise NotImplementedError("Stochastic agents are not supported.")

        self.env = get_required_argument(params.sim_cfg, "env", "Must provide environment.")
        self.task_hor = get_required_argument(params.sim_cfg, "task_hor", "Must provide task horizon.")
        self.agent = Agent(DotMap(env=self.env, noisy_actions=False))

        self.ntrain_iters = get_required_argument(
            params.exp_cfg, "ntrain_iters", "Must provide number of training iterations."
        )
        self.nrollouts_per_iter = params.exp_cfg.get("nrollouts_per_iter", 1)
        self.ninit_rollouts = params.exp_cfg.get("ninit_rollouts", 0)
        self.policy = get_required_argument(params.exp_cfg, "policy", "Must provide a policy.")

        self.logdir = os.path.join(
            get_required_argument(params.log_cfg, "logdir", "Must provide log parent directory."),
            strftime("%Y-%m-%d--%H:%M:%S", localtime())
        )
        self.nrecord = params.log_cfg.get("nrecord", 0)
        self.neval = params.log_cfg.get("neval", 1)

    def _load_expert_dataset(self):
        """Load the expert demonstrations registered for ``self.env.spec.id``.

        Returns:
            tuple: ``(obs, next_obs, actions, segmented)`` where the first three are
            the dataset's ``'obs'``, ``'next_obs'`` and ``'actions'`` entries cast
            to float64, and ``segmented`` is the flag from ``EXPERT_DATASETS``.

        Raises:
            ValueError: if no dataset is registered for the environment.
        """
        env_id = self.env.spec.id
        if env_id not in self.EXPERT_DATASETS:
            raise ValueError("No expert dataset registered for environment %r." % env_id)
        path, segmented = self.EXPERT_DATASETS[env_id]
        # NOTE: torch.load unpickles the file; only load trusted datasets.
        expert = torch.load(path)
        return (
            expert['obs'].to(torch.float64),
            expert['next_obs'].to(torch.float64),
            expert['actions'].to(torch.float64),
            segmented,
        )

    @staticmethod
    def _split_on_discontinuities(obs, next_obs, actions):
        """Split concatenated transitions into episodes.

        A new episode starts at step ``i + 1`` whenever ``obs[i + 1]`` differs
        from ``next_obs[i]``. The last episode is always kept, even when it is a
        single step.

        Returns:
            tuple: three lists ``(obs_segments, next_obs_segments, action_segments)``
            of slices of the inputs; all three are empty when ``actions`` is empty.
        """
        n = len(actions)
        if n == 0:
            return [], [], []
        if n > 1:
            # Entry i is True when step i + 1 does not continue from step i.
            breaks = (obs[1:n] != next_obs[:n - 1]).reshape(n - 1, -1).any(dim=1)
            cuts = (torch.nonzero(breaks).flatten() + 1).tolist()
        else:
            cuts = []
        # None as the final bound keeps the open-ended tail slice `[start:]`.
        bounds = [0] + cuts + [None]
        obs_segments, next_obs_segments, action_segments = [], [], []
        for start, stop in zip(bounds[:-1], bounds[1:]):
            obs_segments.append(obs[start:stop])
            next_obs_segments.append(next_obs[start:stop])
            action_segments.append(actions[start:stop])
        return obs_segments, next_obs_segments, action_segments

    @staticmethod
    def _split_fixed_length(obs, next_obs, actions, num_trajs=5, traj_len=1000):
        """Split transitions into ``num_trajs`` back-to-back windows of ``traj_len`` steps.

        Returns:
            tuple: three lists ``(obs_segments, next_obs_segments, action_segments)``.
        """
        obs_segments, next_obs_segments, action_segments = [], [], []
        for k in range(num_trajs):
            window = slice(k * traj_len, (k + 1) * traj_len)
            obs_segments.append(obs[window])
            next_obs_segments.append(next_obs[window])
            action_segments.append(actions[window])
        return obs_segments, next_obs_segments, action_segments

    def run_experiment(self):
        """Perform experiment.

        Loads the expert demonstrations for ``self.env.spec.id``, warm-starts the
        policy on them, then runs ``self.ntrain_iters`` iterations of
        rollout / log / retrain. Finally saves the best-aligned rollouts to
        ``<logdir>/<env id>_<policy.constraint_type>.pth``.

        Raises:
            ValueError: if the environment has no registered expert dataset or the
                dataset contains no trajectories.
        """
        os.makedirs(self.logdir, exist_ok=True)

        expert_obs, expert_next_obs, expert_actions, segmented = self._load_expert_dataset()
        if segmented:
            expert_s0, expert_s1, expert_a = self._split_on_discontinuities(
                expert_obs, expert_next_obs, expert_actions)
        else:
            expert_s0, expert_s1, expert_a = self._split_fixed_length(
                expert_obs, expert_next_obs, expert_actions)
        if not expert_s0:
            raise ValueError("Expert dataset for %s contains no trajectories." % self.env.spec.id)
        # The datasets carry no seed, so the agent always receives None as the expert seed.
        expert_seed = [None] * len(expert_s0)

        # Per-dimension range of the expert states, handed to the policy
        # (presumably for observation normalisation); 1e-8 avoids a zero gap.
        self.policy.expert_max = expert_obs.max(dim=0)[0].to(TORCH_DEVICE)
        self.policy.expert_min = expert_obs.min(dim=0)[0].to(TORCH_DEVICE)
        self.policy.expert_maxmin_gap = self.policy.expert_max - self.policy.expert_min + 1e-8

        # Warm-start the policy on the expert demonstrations. Each observation
        # sequence gets its final next-state appended (one more row than its
        # actions); no rewards are available, hence None.
        self.policy.train(
            [torch.cat((s0, s1[-1:]), dim=0).cpu().numpy() for s0, s1 in zip(expert_s0, expert_s1)],
            list(expert_a),
            None,
        )

        traj_obs, traj_acs, traj_rets, traj_rews, traj_dtws = [], [], [], [], []
        # The first `neval` rollouts are reported; the first `nrollouts_per_iter`
        # are logged and used for retraining.
        rollouts_per_iter = max(self.neval, self.nrollouts_per_iter)

        # Training loop
        for i in trange(self.ntrain_iters):
            print("####################################################################")
            print("Starting training iteration %d." % (i + 1))

            iter_dir = os.path.join(self.logdir, "train_iter%d" % (i + 1))
            os.makedirs(iter_dir, exist_ok=True)

            # Cycle through the expert trajectories, one alignment target per iteration.
            expert_idx = i % len(expert_s0)
            expert_traj = torch.cat((expert_s0[expert_idx], expert_s1[expert_idx][-1:]))
            self.policy.expert_traj = expert_traj
            self.policy.expert_actions = expert_a[expert_idx]
            video_path = os.path.join(self.logdir, self.env.spec.id + '_train_' + str(i) + '.mp4')
            samples = [
                self.agent.sample(self.task_hor, self.policy, expert_traj, expert_seed[expert_idx], video_path)
                for _ in range(rollouts_per_iter)
            ]
            eval_samples = samples[:self.neval]
            samples = samples[:self.nrollouts_per_iter]

            print("dtw_distance:", [sample["dtw_distance"] for sample in eval_samples])
            print("Rewards obtained:", [sample["reward_sum"] for sample in eval_samples])
            traj_obs.extend(sample["obs"] for sample in samples)
            traj_acs.extend(sample["ac"] for sample in samples)
            traj_rets.extend(sample["reward_sum"] for sample in eval_samples)
            traj_rews.extend(sample["rewards"] for sample in samples)
            traj_dtws.extend(sample["dtw_distance"] for sample in eval_samples)
            # Running mean over every reported return so far, not just this iteration.
            print('avg return:', np.array(traj_rets).mean())

            self.policy.dump_logs(self.logdir, iter_dir)
            savemat(
                os.path.join(self.logdir, "logs.mat"),
                {
                    "observations": traj_obs,
                    "actions": traj_acs,
                    "returns": traj_rets,
                    "rewards": traj_rews,
                    "dtws": traj_dtws,
                }
            )
            # Delete iteration directory if not used
            if not os.listdir(iter_dir):
                os.rmdir(iter_dir)

            # The last iteration's rollouts are only logged, not trained on.
            if i < self.ntrain_iters - 1:
                self.policy.train(
                    [sample["obs"] for sample in samples],
                    [sample["ac"] for sample in samples],
                    [sample["rewards"] for sample in samples],
                )

        self._save_best_trajectories(traj_obs, traj_acs, traj_dtws, len(expert_s0))

    def _save_best_trajectories(self, traj_obs, traj_acs, traj_dtws, num_experts):
        """Save, per expert trajectory, the logged rollout with the smallest DTW distance.

        ``traj_dtws`` holds ``self.neval`` entries per iteration and ``traj_obs`` /
        ``traj_acs`` hold ``self.nrollouts_per_iter`` entries per iteration; both
        are prefixes of the same per-iteration rollout list, and iteration ``k``
        used expert ``k % num_experts``. Only rollouts present in both lists are
        candidates. Experts that never received a candidate are skipped.

        The chosen rollouts are concatenated into a dict with keys ``'obs'``,
        ``'next_obs'``, ``'actions'`` and ``'done'`` (1 on each rollout's last
        transition) and written with ``torch.save`` to
        ``<logdir>/<env id>_<policy.constraint_type>.pth``.
        """
        best_dis = [float("inf")] * num_experts
        best_idx = [None] * num_experts
        for dtw_idx, dtw in enumerate(traj_dtws):
            iteration, rollout = divmod(dtw_idx, self.neval)
            if rollout >= self.nrollouts_per_iter:
                continue  # this evaluation rollout's trajectory was not logged
            expert_idx = iteration % num_experts
            dtw = float(dtw)
            if dtw < best_dis[expert_idx]:
                best_dis[expert_idx] = dtw
                best_idx[expert_idx] = iteration * self.nrollouts_per_iter + rollout

        chosen = [idx for idx in best_idx if idx is not None]
        if not chosen:
            print("No logged rollouts to save.")
            return

        obs_parts, next_obs_parts, act_parts, done_parts = [], [], [], []
        for idx in chosen:
            obs = torch.as_tensor(np.asarray(traj_obs[idx]))
            acts = torch.as_tensor(np.asarray(traj_acs[idx]))
            steps = obs.shape[0] - 1
            obs_parts.append(obs[:-1])
            next_obs_parts.append(obs[1:])
            act_parts.append(acts)
            done = torch.zeros(steps)
            if steps > 0:
                done[-1] = 1
            done_parts.append(done)

        saved_traj = {
            'obs': torch.cat(obs_parts),
            'next_obs': torch.cat(next_obs_parts),
            'actions': torch.cat(act_parts),
            'done': torch.cat(done_parts),
        }
        torch.save(saved_traj, os.path.join(self.logdir, self.env.spec.id + '_' + self.policy.constraint_type) + '.pth')