import numpy as np
import torch
import time
from decision_transformer.models.s4_muj import *
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import Dataset
import logging
logger = logging.getLogger(__name__)

def discount_cumsum(x, gamma):
    """Return the discounted reverse cumulative sum of ``x`` along axis 0.

    ``out[t] = x[t] + gamma * x[t+1] + gamma**2 * x[t+2] + ...``

    Floating-point inputs keep their dtype; integer/bool inputs produce float64 so that
    the discounted sums are not truncated. An empty input returns an empty array.

    Args:
        x: array-like of per-step values (first axis is time).
        gamma: discount factor.

    Returns:
        numpy array with the same shape as ``x``.
    """
    x = np.asarray(x)
    # zeros_like would inherit an integer dtype and truncate fractional discounted sums.
    out_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    discount_cumsum = np.zeros(x.shape, dtype=out_dtype)
    if x.shape[0] == 0:
        return discount_cumsum
    discount_cumsum[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        discount_cumsum[t] = x[t] + gamma * discount_cumsum[t + 1]
    return discount_cumsum

class OnlineTrainer:
    """Online actor-critic fine-tuning of an S4 decision-transformer-style policy.

    The actor (``model``) is rolled out in ``env`` with Gaussian exploration noise and
    every transition is stored, together with the recurrent S4 states, in a replay
    buffer. Once enough data is collected the critic is regressed on the TD target
    ``r + gamma * Q_target(s', actor_target(s'))`` and the actor is updated to maximise
    ``Q(s, actor(s))``. Target networks are refreshed either by hard copy
    (``online_soft_update <= 0``) or by Polyak averaging with ``tau = online_soft_update``.
    """

    def __init__(self,
                 env,
                 model,
                 model_target,
                 critic,
                 critic_target,
                 batch_size=100,
                 max_ep_len=1000,
                 scale=1000.,
                 state_mean=0.,
                 state_std=1.,
                 device='cuda',
                 eval_fns = None,
                 optimizer = None,
                 scheduler = None,
                 critic_scheduler = None,
                 critic_optimizer = None,
                 target_return=None,
                 steps_between_model_swap=2000,
                 steps_between_trains=10,
                 min_amount_to_train=5000,
                 online_model_export_freq=10000,
                 online_exploration_type="0.75,0.3,18000,60000,90000",
                 s4_load_model="none",
                 trains_per_step=4,
                 mode='normal',
                 game_name="def",
                 online_savepostifx="latest",
                 rtg_variation=0,
                 base_target_reward=3200,
                 online_soft_update=-1,
                 fine_tune_critic_steps=0,
                 online_step_partial_advance = 'none',
                 fine_tune_critic_steps_speedup = 10,
                 episodes_per_iteration=200,
                 ):
        """Store models/optimizers, parse the schedule strings and create the replay buffer.

        Non-obvious arguments:
            scale: divisor applied to environment rewards / target returns.
            state_mean, state_std: observation normalisation statistics (numpy arrays or
                scalars); observations are fed to the actor as ``(s - mean) / std``.
            online_exploration_type: comma-separated
                ``"start_var,end_var,breakstart,breakend,breakend_simple"`` used by
                ``get_variation``.
            online_step_partial_advance: ``"none"`` or ``"a_b_c_d"``. The actor is updated
                when ``all_steps % b < a`` and the critic when ``all_steps % d < c`` (the
                critic is always updated during the critic fine-tuning phase).
            online_soft_update: Polyak factor ``tau``; ``<= 0`` means hard target copies.
            fine_tune_critic_steps: for ``all_steps`` up to this value only the critic is
                trained and no exploration noise is added.
            s4_load_model: ``"none"`` or a checkpoint name whose third-from-last
                ``_``-separated token is the step count to resume from.
            target_return: accepted for interface compatibility; not used here.
        """
        self.env = env
        self.game_name = game_name
        self.base_target_reward = base_target_reward
        self.device = device
        self.state_dim = env.observation_space.shape[0]
        self.act_dim = env.action_space.shape[0]
        self.model = model
        self.model_target = model_target
        self.model_target.eval()
        self.critic = critic
        self.critic_target = critic_target
        logger.info(f"Critic type {str(type(self.critic_target))}")
        self.critic_target.eval()
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.critic_scheduler = critic_scheduler
        self.critic_optimizer = critic_optimizer

        self.batch_size = batch_size
        self.max_ep_len = max_ep_len
        self.reward_scale = scale
        # as_tensor(asarray(...)) accepts both numpy arrays and the scalar defaults
        # (torch.from_numpy rejects plain floats).
        self.state_mean = torch.as_tensor(np.asarray(state_mean)).to(device=device)
        self.state_std = torch.as_tensor(np.asarray(state_std)).to(device=device)
        self.mode = mode
        self.eval_fns = eval_fns
        self.diagnostics = dict()
        self.start_time = time.time()

        self.replay_memory_max_size = 80000
        self.episodes_per_iteration = episodes_per_iteration
        self.steps_between_model_swap = steps_between_model_swap
        self.steps_between_trains = steps_between_trains
        self.min_amount_to_train = min_amount_to_train
        self.online_model_export_freq = online_model_export_freq
        self.gamma = 0.99
        self.online_savepostifx = online_savepostifx
        self.rtg_variation = rtg_variation
        self.tau = online_soft_update
        self.fine_tune_critic_steps = fine_tune_critic_steps
        self.fine_tune_critic_steps_speedup = fine_tune_critic_steps_speedup
        self.online_step_partial_advance = online_step_partial_advance
        self.step_partial_dat = [1, 1, 1, 1]
        if self.online_step_partial_advance != "none":
            self.step_partial_dat = [int(x) for x in self.online_step_partial_advance.split("_")]

        self.s4_load_model = s4_load_model
        # Resume the global step counter from the checkpoint name,
        # e.g. "mujoco_hopp_107500_online_latest" -> 107500.
        self.all_steps = 0 if s4_load_model == "none" else int(self.s4_load_model.split("_")[-3])
        self.train_dataset = StateActionReturnDataset_Online_Qlearn(self.replay_memory_max_size, "mujoco",
                                                                    123, self.device)

        temp = online_exploration_type.split(",")
        self.start_variation = float(temp[0])
        self.end_variation = float(temp[1])
        self.explore_breakstart = int(temp[2])
        self.explore_breakend = int(temp[3])
        self.explore_breakend_simple = int(temp[4])
        self.trains_per_step = trains_per_step
        # NOTE(review): anomaly detection is a process-wide debugging switch that slows
        # every backward pass considerably; consider disabling it for real runs.
        torch.autograd.set_detect_anomaly(True)
        return

    @staticmethod
    def _soft_update(target, source, tau):
        """In-place Polyak update: ``target <- (1 - tau) * target + tau * source``."""
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)

    def update_module_param(self):
        """Refresh the target actor and critic from the online networks.

        With ``tau <= 0`` both targets are overwritten with the online weights. Otherwise
        both are Polyak-averaged, except that the target actor is left untouched while
        ``all_steps <= fine_tune_critic_steps`` (critic-only fine-tuning phase).
        """
        if self.tau <= 0:
            self.model_target.load_state_dict(self.model.state_dict())
            self.critic_target.load_state_dict(self.critic.state_dict())
            return
        if self.all_steps > self.fine_tune_critic_steps:
            self._soft_update(self.model_target, self.model, self.tau)
        self._soft_update(self.critic_target, self.critic, self.tau)

    def get_variation(self):
        """Return the std of the Gaussian exploration noise for ``self.all_steps``.

        Piecewise-linear schedule:
          * before ``explore_breakstart``: ``start_variation``;
          * ``explore_breakstart`` -> ``explore_breakend``: linear from
            ``start_variation`` to ``end_variation``;
          * ``explore_breakend`` -> ``explore_breakend_simple``: linear from
            ``end_variation`` to 0.005;
          * afterwards: 0.005.
        The schedule is continuous at every breakpoint.
        """
        final_variation = 0.005
        step = self.all_steps
        if step < self.explore_breakstart:
            return self.start_variation
        if step < self.explore_breakend:
            frac = (step - self.explore_breakstart) / (self.explore_breakend - self.explore_breakstart)
            return self.start_variation + frac * (self.end_variation - self.start_variation)
        if step < self.explore_breakend_simple:
            frac = (step - self.explore_breakend) / (self.explore_breakend_simple - self.explore_breakend)
            return self.end_variation + frac * (final_variation - self.end_variation)
        return final_variation

    def train_iteration(self, num_steps=100, episodes_to_run=200, mode="partial", iter_num=0, print_logs=True):
        """Collect ``self.episodes_per_iteration`` online episodes, then evaluate.

        Gradient updates happen inside ``run_train_episode``. ``num_steps``, ``mode`` and
        the ``episodes_to_run`` argument are accepted for interface compatibility but
        ignored (the episode count always comes from ``self.episodes_per_iteration``).

        Returns:
            dict of evaluation outputs (``evaluation/<key>``), timing entries and the
            contents of ``self.diagnostics``.
        """
        variation = self.get_variation()
        print(f"LOG variation: {variation}")
        logger.info(f"LOG variation: {variation}")
        episodes_to_run = self.episodes_per_iteration
        if self.rtg_variation > 0:
            # Jitter each episode's target return by a relative factor in [-rtg_variation, rtg_variation).
            run_target_return = self.base_target_reward * (np.ones((episodes_to_run)) + self.rtg_variation * (2 * np.random.random((episodes_to_run)) - 1))
        else:
            run_target_return = self.base_target_reward * np.ones((episodes_to_run))
        for i in range(episodes_to_run):
            episode_return, episode_length, _ = self.run_train_episode(float(run_target_return[i] / self.reward_scale))
            logger.info(f"Iternum,innerepisode {iter_num:4} {i:4} length,rewards :: {episode_length:6d} {episode_return:.3f}")
            print(f"Average evaluation return: {episode_return}, Length: {episode_length}")

        logs = dict()
        eval_start = time.time()

        self.model.eval()
        self.model.pre_val_setup()

        for eval_fn in (self.eval_fns or []):
            outputs = eval_fn(self.model)
            for k, v in outputs.items():
                logs[f'evaluation/{k}'] = v

        logs['time/total'] = time.time() - self.start_time
        logs['time/evaluation'] = time.time() - eval_start

        for k in self.diagnostics:
            logs[k] = self.diagnostics[k]

        if print_logs:
            print('=' * 80)
            print(f'Iteration {iter_num+1}')
            for k, v in logs.items():
                print(f'{k}: {v}')
        return logs

    def run_train_episode(self, target_return):
        """Roll out one exploration episode with ``self.model`` and learn online from it.

        Every environment transition is appended to ``self.train_dataset``. Depending on
        ``self.all_steps`` the method also runs gradient updates, refreshes the target
        networks and exports checkpoints.

        Args:
            target_return: initial return-to-go, already divided by ``self.reward_scale``.

        Returns:
            ``(episode_return, episode_length, None)``, where ``episode_return`` is the
            sum of the raw (unscaled) environment rewards.
        """
        self.model.eval()
        self.model.pre_val_setup()
        s4_states = [r.detach() for r in self.model.get_initial_state((1), self.device)]
        self.model_target.pre_val_setup()
        state = self.env.reset()
        if self.mode == 'noise':
            state = state + np.random.normal(0, 0.1, size=state.shape)

        # Histories are kept on the device; the last action/reward row is padding that is
        # filled in once the environment step has been taken.
        states = torch.from_numpy(state).reshape(1, self.state_dim).to(device=self.device, dtype=torch.float32)
        actions = torch.zeros((2, self.act_dim), device=self.device, dtype=torch.float32)
        rewards = torch.zeros(2, device=self.device, dtype=torch.float32)
        target_return = torch.tensor(target_return, device=self.device, dtype=torch.float32).reshape(1, 1)
        timesteps = torch.tensor(0, device=self.device, dtype=torch.long).reshape(1, 1)

        episode_return, episode_length = 0, 0
        for t in range(self.max_ep_len):
            action, next_s4_states = self.model.get_action(
                (states.to(dtype=torch.float32) - self.state_mean) / self.state_std,
                actions.to(dtype=torch.float32),
                rewards.to(dtype=torch.float32),
                target_return.to(dtype=torch.float32),
                timesteps.to(dtype=torch.long),
                s4_states=s4_states,
            )
            if t > 0:
                actions = torch.cat([actions, torch.zeros((1, self.act_dim), device=self.device)], dim=0)
                rewards = torch.cat([rewards, torch.zeros(1, device=self.device)])
            # Gaussian exploration noise; disabled during the critic-only fine-tuning phase.
            if self.all_steps > self.fine_tune_critic_steps:
                action = action + self.get_variation() * torch.randn(action.shape, dtype=action.dtype, device=action.device)
            action = torch.clamp(action, -1.0, 1.0)
            actions[-1] = action
            action = action.detach().cpu().numpy()

            state, reward, done, _ = self.env.step(action)

            cur_state = torch.from_numpy(state).to(device=self.device).reshape(1, self.state_dim)
            states = torch.cat([states, cur_state], dim=0)
            rewards[-1] = reward
            pred_return = target_return[0, -1] - (reward / self.reward_scale)
            target_return = torch.cat([target_return, pred_return.reshape(1, 1)], dim=1)
            timesteps = torch.cat(
                [timesteps, torch.ones((1, 1), device=self.device, dtype=torch.long) * (t + 1)], dim=1)

            # Store the (previous step -> current step) transition with both recurrent states.
            self.train_dataset.add_observation(states[-2, ...].cpu(), actions[-2, ...].cpu(), target_return[0, -2].cpu(),
                                               [x.cpu() for x in s4_states],
                                               states[-1, ...].cpu(), actions[-1, ...].cpu(), target_return[0, -1].cpu(),
                                               [x.cpu() for x in next_s4_states], int(done))
            episode_return += reward
            episode_length += 1
            s4_states = next_s4_states

            self.all_steps += 1
            if self.all_steps % 1000 == 0:
                print(f"Steps {self.all_steps:8d} :: Var {self.get_variation():.3f}")
                logger.info(f"Steps {self.all_steps:8d} :: Var {self.get_variation():.3f}")

            past_warmup = self.all_steps >= self.min_amount_to_train
            if (past_warmup and self.all_steps % self.steps_between_trains == 0
                    and len(self.train_dataset) > self.batch_size * 100):
                self._train_on_buffer()

            if past_warmup and self.all_steps % self.steps_between_model_swap == 0:
                print(f"updated Tar_model {self.all_steps:10}")
                self.update_module_param()

            if past_warmup and self.all_steps % self.online_model_export_freq == 0:
                self._export_targets()

            if done:
                break
        return episode_return, episode_length, None

    def _train_on_buffer(self):
        """Run a burst of ``train_step`` calls, step the LR schedulers and log mean losses."""
        loss1, loss2 = 0, 0
        self.model.train()
        self.model_target.eval()
        self.critic_target.eval()
        self.critic.train()
        to_run_steps = self.trains_per_step
        # Critic-only fine-tuning runs more updates per trigger.
        if self.all_steps <= self.fine_tune_critic_steps:
            to_run_steps = to_run_steps * self.fine_tune_critic_steps_speedup
        for _ in range(to_run_steps):
            train_loss1, train_loss2 = self.train_step()
            loss1 += train_loss1
            loss2 += train_loss2
        if self.scheduler is not None:
            self.scheduler.step()
        if self.critic_scheduler is not None:
            self.critic_scheduler.step()
        print(f"training log: {self.all_steps:10} losses: {loss1/to_run_steps:4f} {loss2/to_run_steps:4f}")
        logger.info(f"training log: {self.all_steps:10} losses: {loss1/to_run_steps:4f} {loss2/to_run_steps:4f}")
        self.model.eval()
        self.critic.eval()

    def _export_targets(self):
        """Save the target actor and target critic state dicts to the working directory."""
        fileoutname1 = f"mujoco_{self.game_name}_{self.all_steps}_online_{self.online_savepostifx}_actor.pkl"
        fileoutname2 = f"mujoco_{self.game_name}_{self.all_steps}_online_{self.online_savepostifx}_critic.pkl"
        torch.save(self.model_target.state_dict(), fileoutname1)
        torch.save(self.critic_target.state_dict(), fileoutname2)
        print(f"Saved latest dict : {fileoutname1}")
        logger.info(f"Saved latest dict : {fileoutname1}")

    def _sample_batch(self):
        """Draw up to ``batch_size`` distinct transitions uniformly from the replay buffer.

        Equivalent to the first batch of a shuffled ``DataLoader`` with default collation,
        but done in-process, so no worker processes are spawned on every gradient step.

        Returns:
            list with one stacked tensor per dataset field (batch on dim 0).
        """
        dataset = self.train_dataset
        n_items = min(self.batch_size, len(dataset))
        indices = torch.randperm(len(dataset))[:n_items].tolist()
        samples = [dataset[i] for i in indices]
        return [torch.stack(field, dim=0) for field in zip(*samples)]

    def train_step(self):
        """Run one critic and/or actor gradient update on a sampled replay minibatch.

        Which updates run is governed by ``self.step_partial_dat`` and
        ``self.fine_tune_critic_steps`` (see ``__init__``).

        Returns:
            ``[critic_loss, actor_loss]`` as Python floats; an entry stays 0 when the
            corresponding update was skipped. ``actor_loss`` is ``-mean(Q(s, actor(s)))``.
        """
        (out_state, out_action, out_rtg, out_s4state,
         out_next_state, out_next_action, out_next_rtg, out_next_s4state,
         rewards, dones) = [q.to(self.device) for q in self._sample_batch()]
        # The per-layer S4 states are stacked along dim 1 of the stored tensors.
        forward_s4_base = [out_s4state[:, x, :, :].squeeze(1) for x in range(self.model.s4_amount)]
        forward_s4_next = [out_next_s4state[:, x, :, :].squeeze(1) for x in range(self.model.s4_amount)]
        out_state = (out_state - self.state_mean) / self.state_std
        out_next_state = (out_next_state - self.state_mean) / self.state_std

        loss1, loss2 = 0, 0
        update_critic = (self.all_steps % self.step_partial_dat[3] < self.step_partial_dat[2]
                         or self.all_steps <= self.fine_tune_critic_steps)
        if update_critic:
            self.critic_optimizer.zero_grad()
            self.optimizer.zero_grad()
            with torch.no_grad():
                actor_actions_next, _ = self.model_target.step_forward(out_next_state, out_next_action, None, out_next_rtg, None, s4_states=forward_s4_next)
            # Flatten Q outputs: a (B, 1) critic output would otherwise broadcast against
            # the (B,) rewards/dones into a (B, B) target.
            Q_critic = self.critic(out_state, out_action).reshape(-1)
            with torch.no_grad():
                # The TD target is a constant w.r.t. the online critic.
                Q_critic_next = self.critic_target(out_next_state, actor_actions_next).reshape(-1)
                y = rewards.reshape(-1) + self.gamma * Q_critic_next * (1 - dones.reshape(-1))
            losscritic = torch.pow(Q_critic - y, 2).mean()
            losscritic.backward()
            torch.nn.utils.clip_grad_norm_(self.critic.parameters(), .5)
            self.critic_optimizer.step()
            loss1 = losscritic.detach().cpu().item()

        update_actor = (self.all_steps > self.fine_tune_critic_steps
                        and self.all_steps % self.step_partial_dat[1] < self.step_partial_dat[0])
        if update_actor:
            # Clear actor grads here as well: when the critic branch above is skipped,
            # nothing else resets them and they would accumulate across calls.
            self.optimizer.zero_grad()
            actor_actions, _ = self.model.step_forward(out_state, out_action, None, out_rtg, None,
                                                       s4_states=forward_s4_base)
            Q_critic_now = -self.critic(out_state, actor_actions).mean()
            Q_critic_now.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), .5)
            self.optimizer.step()
            self.model.pre_val_setup()
            self.model_target.pre_val_setup()
            loss2 = Q_critic_now.detach().cpu().item()

        return [loss1, loss2]

class StateActionReturnDataset_Online_Qlearn(Dataset):
    """FIFO replay buffer of one-step transitions for online Q-learning.

    Entry ``i`` holds the current (state, action, return-to-go, stacked S4 states), the
    same four items for the next step, the reward ``rtg - next_rtg`` and a done flag.
    Items are returned as detached tensor copies, so callers may modify them freely.
    """

    def __init__(self, max_buffer_size, game_name, seed, device):
        Dataset.__init__(self)
        self.max_buffer_size = max_buffer_size
        self.state_stack = []
        self.actions_stack = []
        self.rtg_stack = []
        self.s4_state_stack = []

        self.next_state_stack = []
        self.next_actions_stack = []
        self.next_rtg_stack = []
        self.s4_next_state_stack = []

        self.done_marker_stack = []
        self.rewards_stack = []
        self.game_name = game_name
        self.seed = seed  # stored only; sampling in get_batch uses the global numpy RNG
        self.device = device  # destination device of get_batch outputs

        return

    def __len__(self):
        return len(self.state_stack)

    def _buffers(self):
        """Return every per-field list so bulk operations keep them aligned."""
        return (self.state_stack, self.actions_stack, self.rtg_stack, self.s4_state_stack,
                self.next_state_stack, self.next_actions_stack, self.next_rtg_stack,
                self.s4_next_state_stack, self.rewards_stack, self.done_marker_stack)

    @staticmethod
    def _tensor_copy(value, dtype):
        """Return a detached tensor copy of ``value`` (tensor or Python scalar) as ``dtype``."""
        if torch.is_tensor(value):
            return value.detach().to(dtype=dtype, copy=True)
        return torch.tensor(value, dtype=dtype)

    def __getitem__(self, idx):
        """Return the 10 fields of transition ``idx`` as a list of tensors.

        Return-to-go and reward values can be fractional (e.g. returns divided by a
        reward scale), so they are returned as float32; an integer cast would truncate
        them (small scaled rewards would all become 0).
        """
        s4_dtype = self.s4_state_stack[idx].dtype
        return [self._tensor_copy(self.state_stack[idx], torch.float32),
                self._tensor_copy(self.actions_stack[idx], torch.float32),
                self._tensor_copy(self.rtg_stack[idx], torch.float32),
                self._tensor_copy(self.s4_state_stack[idx], s4_dtype),
                self._tensor_copy(self.next_state_stack[idx], torch.float32),
                self._tensor_copy(self.next_actions_stack[idx], torch.float32),
                self._tensor_copy(self.next_rtg_stack[idx], torch.float32),
                self._tensor_copy(self.s4_next_state_stack[idx], s4_dtype),
                self._tensor_copy(self.rewards_stack[idx], torch.float32),
                self._tensor_copy(self.done_marker_stack[idx], torch.float)]

    def clean_buffer(self, amount=500):
        """Drop the ``amount`` oldest transitions (all of them if fewer are stored).

        One slice deletion per list replaces repeated ``list.pop(0)``, each of which
        shifts the entire list. Non-positive ``amount`` is a no-op.
        """
        if amount <= 0:
            return
        for buffer in self._buffers():
            del buffer[:amount]
        return

    def add_observation(self, state, action, rtg, s4_state,
                              next_state, next_action, next_rtg, s4_next_state, done):
        """Append one transition, evicting the oldest chunk once the buffer is over capacity.

        ``s4_state`` / ``s4_next_state`` are sequences of per-layer tensors; they are
        stacked along a new leading dimension. The stored reward is ``rtg - next_rtg``.
        """
        if len(self.state_stack) > self.max_buffer_size:
            self.clean_buffer()
        self.state_stack.append(state)
        self.actions_stack.append(action)
        self.rtg_stack.append(rtg)
        self.s4_state_stack.append(torch.cat([x.unsqueeze(0) for x in s4_state], dim=0))

        self.next_state_stack.append(next_state)
        self.next_actions_stack.append(next_action)
        self.next_rtg_stack.append(next_rtg)
        self.s4_next_state_stack.append(torch.cat([x.unsqueeze(0) for x in s4_next_state], dim=0))

        self.rewards_stack.append(rtg - next_rtg)
        self.done_marker_stack.append(done)

        return

    def get_batch(self, batch_size):
        """Sample ``batch_size`` transitions uniformly with replacement.

        Returns:
            tuple of 10 tensors (same field order as ``__getitem__``), each stacked along
            a new batch dimension and moved to ``self.device``.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        batch_inds = np.random.choice(
            np.arange(len(self)),
            size=batch_size,
            replace=True,
        )
        samples = [self.__getitem__(x) for x in batch_inds]
        return tuple(torch.stack(field, dim=0).to(self.device) for field in zip(*samples))