import numpy as np
import torch
import numpy as np
import random
import torch.nn.functional as F
import time
import sys 
# Timestep cutoff used by TransformSamplingSubTraj: timesteps at or beyond this
# value are clipped to MAX_EPISODE_LEN - 1 and their ordering index is capped.
MAX_EPISODE_LEN = 1000
class Cross_attetion_Trainer:
    """Trainer with two phases on one model: a policy (DT) phase and a subgoal
    encoder/decoder phase.

    The model must expose ``mlp``, ``decoder`` and ``transformer`` sub-modules
    (their ``requires_grad`` flags are toggled per phase) and the methods
    ``forward``, ``forward_pre`` and ``forward_pre_e``.
    """

    def __init__(self, model, optimizer, optimizer_pre, batch_size, get_batch, get_batch_new, device, loss_fn, num_subgoal, scheduler=None, eval_fns=None):
        """Store the training components.

        Args:
            model: Network being trained (see the class docstring).
            optimizer: Optimizer stepped in the policy phase
                (``train_step`` / ``train_step_stochastic``).
            optimizer_pre: Optimizer stepped in the subgoal phase (``train_pre``).
            batch_size: Batch size passed to the batch samplers.
            get_batch: Policy batch sampler. ``train_iteration`` calls it as
                ``get_batch(batch_size, num_subgoal)``; ``train_step`` calls it
                as ``get_batch(batch_size)``.
            get_batch_new: Subgoal batch sampler, called as
                ``get_batch_new(batch_size, num_subgoal)``.
            device: Stored as ``self.device``; not read elsewhere in this class.
            loss_fn: Callable ``(state_preds, action_preds, reward_preds,
                state_target, action_target, reward_target) -> loss``.
            num_subgoal: Forwarded to the batch samplers.
            scheduler: Optional LR scheduler, stepped once per policy step.
            eval_fns: Optional list of callables ``model -> dict`` run after
                each training iteration.
        """
        self.model = model
        self.optimizer = optimizer
        self.optimizer_pre = optimizer_pre
        self.batch_size = batch_size
        self.num_subgoal = num_subgoal
        self.get_batch = get_batch
        self.get_batch_new = get_batch_new
        self.loss_fn = loss_fn
        self.device = device
        self.scheduler = scheduler
        self.eval_fns = [] if eval_fns is None else eval_fns
        self.diagnostics = dict()
        self.start_time = time.time()

    def train_iteration(self, num_steps, iter_num=0, print_logs=False):
        """Run ``num_steps`` policy updates, then every evaluation callback.

        Args:
            num_steps: Number of calls to ``train_step_stochastic``.
            iter_num: Iteration index, only used in the printed header.
            print_logs: When true, print every log entry.

        Returns:
            dict with timing entries, training-loss mean/std (NaN when
            ``num_steps == 0``), ``evaluation/*`` outputs and the current
            ``self.diagnostics`` entries.
        """
        # enable_dt_training already freezes ``mlp`` and ``decoder``.
        self.enable_dt_training()
        train_losses = []
        logs = dict()

        train_start = time.time()
        self.model.train()
        for _ in range(num_steps):
            states, actions, rewards, dones, rtg, timesteps, attention_mask, target_rtg, subgoal = self.get_batch(self.batch_size, self.num_subgoal)
            train_loss = self.train_step_stochastic(states, actions, rewards, dones, rtg, timesteps, attention_mask, target_rtg, subgoal)
            train_losses.append(train_loss)
            if self.scheduler is not None:
                self.scheduler.step()
        logs['time/training'] = time.time() - train_start

        eval_start = time.time()
        # NOTE(review): the model is left in eval() mode after this method;
        # train_encoder_decoder does not switch it back — confirm intended.
        self.model.eval()
        for eval_fn in self.eval_fns:
            outputs = eval_fn(self.model)
            print(outputs)
            for k, v in outputs.items():
                logs[f'evaluation/{k}'] = v

        logs['time/total'] = time.time() - self.start_time
        logs['time/evaluation'] = time.time() - eval_start
        if train_losses:
            logs['training/train_loss_mean'] = np.mean(train_losses)
            logs['training/train_loss_std'] = np.std(train_losses)
        else:
            # np.mean([]) would emit a RuntimeWarning; report NaN explicitly.
            logs['training/train_loss_mean'] = float('nan')
            logs['training/train_loss_std'] = float('nan')

        logs.update(self.diagnostics)

        if print_logs:
            print('=' * 80)
            print(f'Iteration {iter_num}')
            for k, v in logs.items():
                print(f'{k}: {v}')

        return logs

    def train_step(self, dataloader):
        """Run one update using a batch from ``self.get_batch(self.batch_size)``.

        Args:
            dataloader: Unused; kept for interface compatibility.

        Returns:
            The loss as a Python float.
        """
        states, actions, rewards, dones, attention_mask, states_plan, returns = self.get_batch(self.batch_size)

        state_target, action_target, reward_target = torch.clone(states), torch.clone(actions), torch.clone(rewards)

        state_preds, action_preds, reward_preds = self.model.forward(
            states, actions, rewards, states_plan, masks=None, attention_mask=attention_mask, target_return=returns
        )

        # NOTE: currently indexing & masking is not fully correct
        loss = self.loss_fn(
            state_preds, action_preds, reward_preds,
            state_target[:, 1:], action_target, reward_target[:, 1:],
        )
        # (A leftover sys.exit() used to terminate the process here, before
        # the optimizer step could run.)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.detach().cpu().item()

    def contrastive_loss(self, z1, z2, temperature):
        """NT-Xent style contrastive loss between two batches of embeddings.

        Row ``i`` of ``z1`` and row ``i`` of ``z2`` form a positive pair. Every
        other row of the concatenated ``2N`` batch is a negative.

        Args:
            z1: Embeddings, assumed ``(N, D)`` (concatenated on dim 0 with ``z2``).
            z2: Embeddings with the same shape as ``z1``.
            temperature: Scalar dividing the dot-product logits.

        Returns:
            Scalar cross-entropy loss averaged over the ``2N`` anchors.
        """
        n = z1.size(0)
        z = torch.cat([z1, z2], dim=0)  # (2N, D)
        sim = torch.mm(z, z.T) / temperature  # (2N, 2N) dot-product logits

        # Exclude self-similarity. The dtype's most negative finite value also
        # works in fp16, where the previous -1e9 constant overflows.
        self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
        sim = sim.masked_fill(self_mask, torch.finfo(sim.dtype).min)

        # Anchor i in the first half is paired with i + N, and vice versa.
        labels = torch.cat([torch.arange(n) + n, torch.arange(n)], dim=0).to(z.device)
        return F.cross_entropy(sim, labels)

    def train_pre(self, s, target_rtg, subgoal, tgt_mask, all_rtg, temperature):
        """Take one ``optimizer_pre`` step on contrastive + subgoal reconstruction loss.

        Returns:
            The combined loss as a Python float.
        """
        subgoal = torch.clone(subgoal)
        pre_subgoal, z1, z2 = self.model.forward_pre(
            s,
            target_rtg,
            subgoal,
            tgt_mask,
            all_rtg
        )
        c_loss = self.contrastive_loss(z1, z2, temperature)
        subgoal_loss = self.loss_decoder(subgoal, pre_subgoal)
        loss_all = c_loss + subgoal_loss
        self.optimizer_pre.zero_grad()
        loss_all.backward()
        self.optimizer_pre.step()
        return loss_all.detach().cpu().item()

    def train_pre_e(self, s, target_rtg, subgoal, tgt_mask, all_rtg, temperature):
        """Print the model's subgoal prediction for one batch (no parameter update).

        ``temperature`` is unused; it mirrors ``train_pre``'s signature.
        """
        subgoal = torch.clone(subgoal)
        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            pre_subgoal = self.model.forward_pre_e(
                s,
                target_rtg,
                subgoal,
                tgt_mask,
                all_rtg
            )
        print(f'subgoal:{pre_subgoal}')
        return

    def loss_decoder(self, subgoal, pre_subgoal):
        """Mean squared error between the target and predicted subgoals."""
        subgoal_loss = torch.mean((subgoal - pre_subgoal) ** 2)
        return subgoal_loss

    def train_encoder_decoder_e(self, num_steps, temprature):
        """Print subgoal predictions for ``num_steps`` evaluation batches.

        Requires a ``get_batch_new_evalu(batch_size, num_subgoal)`` callable on
        the trainer. ``__init__`` does not set it, so it must be attached
        before calling this method.

        Returns:
            An empty logs dict.
        """
        self.enable_subgoal_training()
        logs = dict()
        for _ in range(num_steps):
            s, target_rtg, subgoal, tgt_mask, all_rtg, tra = self.get_batch_new_evalu(self.batch_size, self.num_subgoal)
            print(f'tra:{tra}')
            self.train_pre_e(s, target_rtg, subgoal, tgt_mask, all_rtg, temprature)
        return logs

    def train_encoder_decoder(self, num_steps, temprature):
        """Run ``num_steps`` subgoal-phase updates via ``train_pre``.

        Returns:
            dict with the mean/std of the per-step losses (NaN when
            ``num_steps == 0``).
        """
        self.enable_subgoal_training()
        subgoal_losses = []
        logs = dict()
        for _ in range(num_steps):
            s, target_rtg, subgoal, tgt_mask, all_rtg = self.get_batch_new(self.batch_size, self.num_subgoal)
            subgoal_loss = self.train_pre(s, target_rtg, subgoal, tgt_mask, all_rtg, temprature)
            subgoal_losses.append(subgoal_loss)
        if subgoal_losses:
            logs['training/z_loss_mean'] = np.mean(subgoal_losses)
            logs['training/z_loss_std'] = np.std(subgoal_losses)
        else:
            logs['training/z_loss_mean'] = float('nan')
            logs['training/z_loss_std'] = float('nan')
        return logs

    @staticmethod
    def _set_requires_grad(module, flag):
        """Set ``requires_grad = flag`` on every parameter of ``module``."""
        for p in module.parameters():
            p.requires_grad = flag

    def enable_subgoal_training(self):
        """Unfreeze ``mlp`` and ``decoder``; freeze ``transformer``."""
        self._set_requires_grad(self.model.mlp, True)
        self._set_requires_grad(self.model.decoder, True)
        self._set_requires_grad(self.model.transformer, False)

    def enable_dt_training(self):
        """Freeze ``mlp`` and ``decoder``; unfreeze ``transformer``."""
        self._set_requires_grad(self.model.mlp, False)
        self._set_requires_grad(self.model.decoder, False)
        self._set_requires_grad(self.model.transformer, True)

    def train_step_stochastic(self, states, actions, rewards, dones, rtg, timesteps, attention_mask, target_rtg, subgoal):
        """Run one policy update on the non-padded positions of a batch.

        Positions where ``attention_mask <= 0`` are dropped before the loss.
        Gradients are clipped to norm 0.25. The masked action MSE is stored in
        ``self.diagnostics['training/action_error']``.

        Returns:
            The loss as a Python float.
        """
        action_target = torch.clone(actions)
        action_preds = self.model.forward(
            states,
            actions,
            rewards,
            dones,
            rtg,
            timesteps,
            attention_mask,
            target_rtg,
            subgoal
        )
        act_dim = action_preds.shape[2]
        valid = attention_mask.reshape(-1) > 0
        action_preds = action_preds.reshape(-1, act_dim)[valid]
        action_target = action_target.reshape(-1, act_dim)[valid]
        total_loss = self.loss_fn(
            None, action_preds, None,
            None, action_target, None,
        )
        self.optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), .25)
        self.optimizer.step()
        with torch.no_grad():
            self.diagnostics['training/action_error'] = torch.mean((action_preds - action_target) ** 2).detach().cpu().item()
        return total_loss.detach().cpu().item()


class SubTrajectory(torch.utils.data.Dataset):
    """Map-style dataset over a selected subset of trajectories.

    Item ``i`` is ``trajectories[sampling_ind[i]]``, passed through
    ``transform`` when one is set.
    """

    def __init__(
        self,
        trajectories,
        sampling_ind,
        transform=None,
    ):
        super().__init__()
        self.trajs = trajectories
        self.sampling_ind = sampling_ind
        self.transform = transform

    def __getitem__(self, index):
        """Return the (optionally transformed) trajectory at position ``index``.

        Args:
            index (int): Position in ``sampling_ind``.

        Returns:
            ``transform(traj)`` when a transform is set, otherwise the raw
            trajectory object.
        """
        selected = self.trajs[self.sampling_ind[index]]
        return self.transform(selected) if self.transform else selected

    def __len__(self):
        """Number of selected trajectories, i.e. ``len(sampling_ind)``."""
        return len(self.sampling_ind)


class TransformSamplingSubTraj:
    """Callable transform that samples one training example from a trajectory.

    Each call picks a random start index ``si`` and returns padded, normalized
    tensors for two windows:

    - a current window of up to ``max_len`` steps starting at ``si``;
    - a "future" window of up to ``max_future_len`` steps starting at
      ``si + 1`` (or at ``end_idx`` when given).

    Windows shorter than their target length are left-padded. Padded steps get
    a done-flag of 2 and a padding-mask value of 0.

    ``traj`` is a dict with ``"observations"``, ``"actions"``, ``"rewards"``
    and either ``"terminals"`` or ``"dones"``, all indexed by time on their
    first axis.
    """

    def __init__(
        self,
        max_len,
        state_dim,
        act_dim,
        state_mean,
        state_std,
        reward_scale,
        action_range,
        max_future_len
    ):
        super().__init__()
        self.max_len = max_len
        self.state_dim = state_dim
        self.act_dim = act_dim
        self.state_mean = state_mean
        self.state_std = state_std
        self.reward_scale = reward_scale
        self.max_future_len = max_future_len

        # Some datasets contain actions of exactly 1.0/-1.0, which is problematic
        # for the SquashedNormal distribution: the inverse tanh transformation
        # produces NaN when computing the log-likelihood. Actions are therefore
        # clamped to the user-defined action range.
        self.action_range = action_range

    @staticmethod
    def _timesteps_and_ordering(start, length):
        """Return ``(timesteps, ordering)`` for ``length`` steps starting at ``start``.

        Positions whose timestep reaches ``MAX_EPISODE_LEN`` have their
        ordering replaced by the largest remaining ordering and their timestep
        clipped to ``MAX_EPISODE_LEN - 1``.
        """
        timesteps = np.arange(start, start + length)
        ordering = np.arange(length)
        if length:  # ordering.max() is undefined on an empty array
            ordering[timesteps >= MAX_EPISODE_LEN] = -1
            ordering[ordering == -1] = ordering.max()
        timesteps[timesteps >= MAX_EPISODE_LEN] = MAX_EPISODE_LEN - 1  # padding cutoff
        return timesteps, ordering

    @staticmethod
    def _returns_to_go(rewards, length):
        """Undiscounted return-to-go of ``rewards`` as a column of ``length + 1`` rows.

        A zero row is appended when the window reaches the end of ``rewards``.
        The row count holds when ``len(rewards) >= length``. Assumes
        ``rewards`` is 1-D or ``(T, 1)`` — TODO confirm.
        """
        rtg = discount_cumsum(rewards, gamma=1.0)[: length + 1].reshape(-1, 1)
        if rtg.shape[0] <= length:
            rtg = np.concatenate([rtg, np.zeros((1, 1))])
        return rtg

    def __call__(self, traj, end_idx=None):
        """Sample one example from ``traj``.

        Args:
            traj: Trajectory dict (see the class docstring).
            end_idx: Optional start index of the future window. Defaults to
                ``si + 1``. NOTE(review): all future arrays (states, dones,
                return-to-go, timesteps) share this one start index — confirm
                it matches the intended meaning of ``end_idx``.

        Returns:
            A 14-tuple ``(s, a, r, d, rtg, timesteps, ordering, padding_mask,
            fss, fdd, frtg, ftimesteps, fordering, fpadding_mask)`` of tensors.
            Current-window tensors have ``max_len`` rows (``rtg`` has
            ``max_len + 1``). Future-window tensors have ``max_future_len``
            rows (``frtg`` has ``max_future_len + 1``). ``d`` is a column
            ``(max_len, 1)``; ``fdd`` is flat ``(max_future_len,)``.
        """
        traj_len = traj["rewards"].shape[0]
        # Start so that a full max_len window fits whenever possible;
        # trajectories shorter than max_len start at 0 and are left-padded.
        si = random.randint(0, max(0, traj_len - self.max_len))
        done_key = "terminals" if "terminals" in traj else "dones"

        # ---- current window: [si, si + max_len) ----
        window = slice(si, si + self.max_len)
        s = traj["observations"][window].reshape(-1, self.state_dim)
        a = traj["actions"][window].reshape(-1, self.act_dim)
        r = traj["rewards"][window].reshape(-1, 1)
        d = traj[done_key][window].reshape(-1, 1)
        tlen = s.shape[0]
        pad = self.max_len - tlen

        timesteps, ordering = self._timesteps_and_ordering(si, tlen)
        rtg = self._returns_to_go(traj["rewards"][si:], tlen)

        # Left-pad to max_len; normalize states and scale returns.
        s = np.concatenate([np.zeros((pad, self.state_dim)), s])
        s = (s - self.state_mean) / self.state_std
        a = np.concatenate([np.zeros((pad, self.act_dim)), a])
        r = np.concatenate([np.zeros((pad, 1)), r])
        d = np.concatenate([np.ones((pad, 1)) * 2, d])  # padded steps get done-flag 2
        rtg = np.concatenate([np.zeros((pad, 1)), rtg]) * self.reward_scale
        timesteps = np.concatenate([np.zeros(pad), timesteps])
        ordering = np.concatenate([np.zeros(pad), ordering])
        padding_mask = np.concatenate([np.zeros(pad), np.ones(tlen)])

        s = torch.from_numpy(s).to(dtype=torch.float32)
        a = torch.from_numpy(a).to(dtype=torch.float32).clamp(*self.action_range)
        r = torch.from_numpy(r).to(dtype=torch.float32)
        d = torch.from_numpy(d).to(dtype=torch.long)
        rtg = torch.from_numpy(rtg).to(dtype=torch.float32)
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.long)
        ordering = torch.from_numpy(ordering).to(dtype=torch.long)
        padding_mask = torch.from_numpy(padding_mask)

        # ---- future window: [fstart, fstart + max_future_len) ----
        fstart = si + 1 if end_idx is None else end_idx
        fwindow = slice(fstart, fstart + self.max_future_len)
        fss = traj["observations"][fwindow].reshape(-1, self.state_dim)
        fdd = traj[done_key][fwindow].reshape(-1)
        ftlen = fss.shape[0]
        fpad = self.max_future_len - ftlen

        ftimesteps, fordering = self._timesteps_and_ordering(fstart, ftlen)
        frtg = self._returns_to_go(traj["rewards"][fstart:], ftlen)

        fss = np.concatenate([np.zeros((fpad, self.state_dim)), fss])
        fss = (fss - self.state_mean) / self.state_std
        fdd = np.concatenate([np.ones(fpad) * 2, fdd])
        frtg = np.concatenate([np.zeros((fpad, 1)), frtg]) * self.reward_scale
        ftimesteps = np.concatenate([np.zeros(fpad), ftimesteps])
        fordering = np.concatenate([np.zeros(fpad), fordering])
        fpadding_mask = np.concatenate([np.zeros(fpad), np.ones(ftlen)])

        fss = torch.from_numpy(fss).to(dtype=torch.float32)
        fdd = torch.from_numpy(fdd).to(dtype=torch.long)
        frtg = torch.from_numpy(frtg).to(dtype=torch.float32)
        ftimesteps = torch.from_numpy(ftimesteps).to(dtype=torch.long)
        fordering = torch.from_numpy(fordering).to(dtype=torch.long)
        fpadding_mask = torch.from_numpy(fpadding_mask)

        return (
            s, a, r, d, rtg, timesteps, ordering, padding_mask,
            fss, fdd, frtg, ftimesteps, fordering, fpadding_mask
        )

def discount_cumsum(x, gamma):
    """Return the discounted reverse cumulative sum of ``x`` along axis 0.

    ``out[-1] = x[-1]`` and ``out[t] = x[t] + gamma * out[t + 1]``. With
    ``gamma == 1.0`` this is the undiscounted return-to-go.

    Args:
        x: NumPy array; the recursion runs over its first axis.
        gamma: Per-step discount factor.

    Returns:
        Array with the same shape and dtype as ``x``. When ``x`` has no
        elements, an array built from ``x.shape[0]`` empty rows is returned
        instead.
    """
    if x.size == 0:
        return np.array([[] for _ in range(x.shape[0])])
    out = np.zeros_like(x)
    out[-1] = x[-1]
    for idx in range(x.shape[0] - 2, -1, -1):
        out[idx] = x[idx] + gamma * out[idx + 1]
    return out
