# Code backbone: Decision Transformer https://github.com/kzl/decision-transformer/
# Decision Transformer License: https://github.com/kzl/decision-transformer/blob/master/LICENSE.md

import numpy as np
import torch
import torch.nn.functional as F
import time
from wandb import env
#from .prompt_utils import flatten_prompt
import copy

def check_avoid_success_state(states, avoid_boxes, avoid_mask, obs_start_index = 0, prompt_dim=3):
    """
    Flag, for every timestep, whether the position lies outside every avoid box.

    states: B x seq_length x dim; the position is read from columns
        obs_start_index : obs_start_index + prompt_dim
    avoid_boxes: B x num_avoid x 2*prompt_dim; the first prompt_dim columns are
        the lower corner of a box, the remaining columns its upper corner
    avoid_mask: B x num_avoid (ignored for now because the padding is done with the initial avoid box in the prompt, not affecting the boolean final result)
    obs_start_index: first column of `states` holding the position
    prompt_dim: number of position coordinates

    returns: bool tensor B x seq_length, True where the position is inside
        none of the boxes (bounds are inclusive)
    """
    # B x seq_length x 1 x prompt_dim
    positions = states[:, :, obs_start_index : obs_start_index + prompt_dim].unsqueeze(2)
    # B x 1 x num_avoid x prompt_dim
    lower = avoid_boxes[:, :, :prompt_dim].unsqueeze(1)
    upper = avoid_boxes[:, :, prompt_dim:].unsqueeze(1)

    # B x seq_length x num_avoid x prompt_dim: per-coordinate containment test
    within = (lower <= positions) & (positions <= upper)

    # A collision requires ALL coordinates to fall inside the SAME box, so the
    # coordinate axis is reduced first and the box axis second. Reducing over
    # boxes first would wrongly report a hit when different coordinates are
    # matched by different boxes.
    inside_some_box = within.all(dim=3).any(dim=2)
    return ~inside_some_box # B x seq_length

class PromptSequenceTrainer:
    """
    Trainer that fits a prompt-conditioned sequence model to the actions of
    sampled batches.

    Training data comes from `get_prompt_batch`. When `get_avoid_prompt` is
    truthy the sampler's result is indexed as (prompt, avoid_prompt, batch) with
    an optional fourth `success_list` entry, and those extras are forwarded to
    the model; otherwise the result is unpacked as (prompt, batch).
    """

    def __init__(self, model, optimizer, batch_size,
                 scheduler=None, eval_fns=None, get_prompt=None, get_avoid_prompt=None, get_prompt_batch=None):
        """
        model: network exposing `discrete_action`, `forward`, `train`, `eval`,
            `parameters` and `state_dict`
        optimizer: optimizer stepped once per `train_step`
        batch_size: stored on the trainer (not read by the methods of this class)
        scheduler: optional scheduler, stepped after every training step
        eval_fns: optional list of evaluation callables (stored only)
        get_prompt, get_avoid_prompt: passed on to `eval_episodes` at evaluation
        get_prompt_batch: zero-argument callable returning the training data
        """
        def mean_squared_error(a_hat, a):
            return torch.mean((a_hat - a) ** 2)

        self.model = model
        self.optimizer = optimizer
        self.batch_size = batch_size
        # Discrete action heads use cross-entropy, everything else uses MSE.
        self.loss_fn = F.cross_entropy if model.discrete_action else mean_squared_error
        self.scheduler = scheduler
        self.eval_fns = eval_fns if eval_fns is not None else []
        self.diagnostics = {}
        self.get_prompt = get_prompt
        self.get_avoid_prompt = get_avoid_prompt
        self.get_prompt_batch = get_prompt_batch

        self.start_time = time.time()

    def train(self, num_steps, no_prompt=False):
        """
        Run `num_steps` optimisation steps and return a dict with the elapsed
        time, mean/std of the step losses and the current diagnostics.
        """
        step_losses = []
        began = time.time()

        self.model.train()
        for _ in range(num_steps):
            step_losses.append(self.train_step(no_prompt))
            if self.scheduler is not None:
                self.scheduler.step()

        logs = {
            'time/training': time.time() - began,
            'training/train_loss_mean': np.mean(step_losses),
            'training/train_loss_std': np.std(step_losses),
        }
        logs.update(self.diagnostics)
        return logs

    def train_step(self, no_prompt=False):
        """
        Sample one batch, take one gradient step on the action loss and return
        that loss as a Python float.

        Raises NotImplementedError when `no_prompt` is true (after the batch
        has been sampled).
        """
        sampled = self.get_prompt_batch()
        model_kwargs = {}
        if self.get_avoid_prompt:
            prompt, batch = sampled[0], sampled[2]
            model_kwargs['avoid_prompt'] = sampled[1]
            # The success list is optional and only present in 4-tuples.
            if len(sampled) == 4 and sampled[3] is not None:
                model_kwargs['success_list'] = sampled[3]
        else:
            prompt, batch = sampled

        states, actions, _rewards, _dones, _rtg, timesteps, attention_mask = batch
        action_target = actions.clone()

        if no_prompt:
            raise NotImplementedError

        _state_preds, action_preds = self.model.forward(
            states, actions, timesteps,
            attention_mask=attention_mask, prompt=prompt, **model_kwargs
        )

        # Only timesteps with a positive attention mask contribute to the loss.
        keep = attention_mask.reshape(-1) > 0
        act_dim = action_preds.shape[2]
        action_preds = action_preds.reshape(-1, act_dim)[keep]
        target_shape = (-1,) if self.model.discrete_action else (-1, act_dim)
        action_target = action_target.reshape(*target_shape)[keep]

        loss = self.loss_fn(action_preds, action_target)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), .25)
        self.optimizer.step()

        loss_value = loss.detach().cpu().item()
        self.diagnostics['training/action_error'] = loss_value
        return loss_value

    def eval_iteration_multienv(self, trajectories_list, eval_episodes, info, variant, 
                                env_list, iter_num=0, print_logs=False, no_prompt=False, group='test'):
        """
        Call `eval_episodes` once for every prompt length from 1 to
        variant['max_prompt_len'] and collect its outputs under
        '<group>-evaluation/<key>'. Outputs sharing a key across prompt lengths
        overwrite each other, so the last prompt length wins for such keys.

        Returns the log dict, which also holds the evaluation time and the
        current training diagnostics.
        """
        logs = {}
        print('start evaluating at {} tasks: {}'.format(len(env_list), group))
        self.model.eval()

        began = time.time()
        # The avoid-prompt sampler is appended as an extra positional argument
        # only when one is configured.
        trailing_args = (self.get_avoid_prompt,) if self.get_avoid_prompt else ()
        for prompt_len in range(1, variant['max_prompt_len'] + 1):
            outputs = eval_episodes(info[0], variant, env_list, self.model, prompt_len,
                                    self.get_prompt, trajectories_list, *trailing_args)
            logs.update({f'{group}-evaluation/{k}': v for k, v in outputs.items()})

        logs['time/evaluation'] = time.time() - began
        logs.update(self.diagnostics)

        if print_logs:
            print('=' * 80)
            print(f'Iteration {iter_num}')
            for key, value in logs.items():
                print(f'{key}: {value}')

        return logs

    def save_model(self, env_name, postfix, folder):
        """Write the model's state dict to folder + '/prompt_model_' + env_name + postfix."""
        destination = folder + ('/prompt_model_' + env_name + postfix)
        torch.save(self.model.state_dict(), destination)  # model save
        print('model saved to ', destination)

class PromptSequenceReachAvoidTrainer:
    """
    Trainer that optimises an action loss plus a per-timestep "avoid success"
    classification loss.

    The classification target is computed from the batch states and the avoid
    boxes with `check_avoid_success_state`; the model must return a third
    output (`success_preds`) with two logits per timestep.
    """

    def __init__(self, model, optimizer, batch_size,
                 scheduler=None, 
                 eval_fns=None, get_prompt=None, 
                 get_avoid_prompt=None, 
                 get_prompt_batch=None, 
                 alpha1=1, 
                 alpha2=1, 
                 buffer_size=0.03, 
                 obs_start_index=0,
                 prompt_dim=3):
        """
        model: network exposing `discrete_action`, `forward`, `train`, `eval`,
            `parameters` and `state_dict`
        optimizer: optimizer stepped once per `train_step`
        batch_size: stored on the trainer (not read by the methods of this class)
        scheduler: optional scheduler, stepped after every training step
        eval_fns: optional list of evaluation callables (stored only)
        get_prompt, get_avoid_prompt: passed on to `eval_episodes` at evaluation;
            `get_avoid_prompt` also selects the 3/4-tuple batch format in training
        get_prompt_batch: zero-argument callable returning the training data
        alpha1: weight of the success loss in the total loss
        alpha2: weight of the negative-class term inside the success loss
        buffer_size: forwarded to `eval_episodes` when an avoid prompt is used
        obs_start_index, prompt_dim: columns of the state holding the position,
            forwarded to `check_avoid_success_state`
        """
        self.model = model
        self.optimizer = optimizer
        self.batch_size = batch_size
        if model.discrete_action:
            self.loss_fn = F.cross_entropy
        else:
            self.loss_fn = lambda a_hat, a: torch.mean((a_hat - a) ** 2)
        self.scheduler = scheduler
        self.eval_fns = [] if eval_fns is None else eval_fns
        self.diagnostics = dict()
        self.get_prompt = get_prompt
        self.get_avoid_prompt = get_avoid_prompt
        self.get_prompt_batch = get_prompt_batch
        self.alpha1 = alpha1
        self.alpha2 = alpha2
        self.buffer_size = buffer_size
        self.obs_start_index = obs_start_index
        self.prompt_dim = prompt_dim

        self.start_time = time.time()


    def train(self, num_steps, no_prompt=False):
        """
        Run `num_steps` optimisation steps and return a dict with the elapsed
        time, mean/std of the step losses, the current learning rate and the
        latest diagnostics.
        """
        train_losses = []
        logs = dict()

        train_start = time.time()

        self.model.train()
        for _ in range(num_steps):
            train_loss = self.train_step(no_prompt)
            train_losses.append(train_loss)
            if self.scheduler is not None:
                self.scheduler.step()

        logs['time/training'] = time.time() - train_start
        logs['training/train_loss_mean'] = np.mean(train_losses)
        logs['training/train_loss_std'] = np.std(train_losses)
        logs['training/lr'] = self._current_lr()

        for k in self.diagnostics:
            logs[k] = self.diagnostics[k]

        return logs

    def _current_lr(self):
        """Return the learning rate of the first parameter group."""
        # The scheduler is optional, so fall back to the optimizer itself.
        if self.scheduler is None:
            return self.optimizer.param_groups[0]['lr']
        # get_last_lr() is the supported way to read the rate outside of
        # scheduler.step(); get_lr() is kept for schedulers that lack it.
        if hasattr(self.scheduler, 'get_last_lr'):
            return self.scheduler.get_last_lr()[0]
        return self.scheduler.get_lr()[0]

    @staticmethod
    def _subset_loss_and_accuracy(preds, targets, rows):
        """
        Cross-entropy and accuracy restricted to the rows selected by the
        boolean mask `rows`.

        When no row is selected the loss is a zero that stays attached to the
        graph (cross-entropy over an empty batch would be NaN and poison the
        gradients); the accuracy is then NaN because it is undefined.
        """
        picked_preds = preds[rows]
        picked_targets = targets[rows]
        accuracy = (picked_preds.argmax(dim=1) == picked_targets.argmax(dim=1)).float().mean()
        if picked_preds.shape[0] == 0:
            return preds.sum() * 0.0, accuracy
        return F.cross_entropy(picked_preds, picked_targets), accuracy

    def train_step(self, no_prompt=False):
        """
        Sample one batch, take one gradient step on
        action_loss + alpha1 * (success_loss_pos + alpha2 * success_loss_neg)
        and return the total loss as a Python float.

        Raises NotImplementedError when `no_prompt` is true, and ValueError
        when the sampler did not provide the avoid prompt and success list that
        the success loss depends on.
        """
        success_list = None
        avoid_prompt = None
        if self.get_avoid_prompt:
            prompt_batch_tuple = self.get_prompt_batch()
            prompt, avoid_prompt, batch = prompt_batch_tuple[0], prompt_batch_tuple[1], prompt_batch_tuple[2]
            if len(prompt_batch_tuple) == 4:
                success_list = prompt_batch_tuple[3]
        else:
            prompt, batch = self.get_prompt_batch()
        states, actions, _rewards, _dones, _rtg, timesteps, attention_mask = batch
        if no_prompt:
            raise NotImplementedError
        # The success head is only queried when both pieces are present, and
        # the loss below cannot be computed without it; fail before the
        # forward pass with an explicit message instead of a late NameError.
        if avoid_prompt is None or success_list is None:
            raise ValueError(
                'PromptSequenceReachAvoidTrainer.train_step needs get_avoid_prompt to be set and '
                'get_prompt_batch to return (prompt, avoid_prompt, batch, success_list)'
            )

        action_target = torch.clone(actions)
        _state_preds, action_preds, success_preds = self.model.forward(
            states, actions, timesteps, attention_mask=attention_mask,
            prompt=prompt, avoid_prompt=avoid_prompt, success_list=success_list
        )

        # Only timesteps with a positive attention mask contribute to the losses.
        valid = attention_mask.reshape(-1) > 0

        act_dim = action_preds.shape[2]
        action_preds = action_preds.reshape(-1, act_dim)[valid]
        if self.model.discrete_action:
            action_target = action_target.reshape(-1)[valid]
        else:
            action_target = action_target.reshape(-1, act_dim)[valid]
        action_loss = self.loss_fn(action_preds, action_target)

        avoid_prompt_states, avoid_prompt_attention_mask = avoid_prompt
        # B x seq_length booleans, True where the state is outside the avoid boxes.
        step_success = check_avoid_success_state(states, avoid_prompt_states, avoid_prompt_attention_mask, self.obs_start_index, prompt_dim = self.prompt_dim)

        # NOTE(review): these two trajectory-level statistics are computed over
        # every timestep, padded ones included -- confirm that is intended.
        trajectory_success = step_success.all(dim=1)
        batch_success_rate = trajectory_success.float().mean()
        # Debug check that the recomputed labels agree with the sampler's
        # list; assumes success_list is a tensor comparable with `==`.
        match = (trajectory_success == success_list).all()

        # B x seq_length x 2 one-hot: column 0 = success, column 1 = failure.
        success_target = torch.stack([step_success.float(), (~step_success).float()], dim=2)
        success_target = success_target.reshape(-1, 2)[valid]
        success_preds = success_preds.reshape(-1, 2)[valid]

        success_accuracy = (success_preds.argmax(dim=1) == success_target.argmax(dim=1)).float().mean()
        pos_rows = success_target[:, 0] == 1
        neg_rows = success_target[:, 1] == 1
        success_loss_pos, success_accuracy_pos = self._subset_loss_and_accuracy(success_preds, success_target, pos_rows)
        success_loss_neg, success_accuracy_neg = self._subset_loss_and_accuracy(success_preds, success_target, neg_rows)
        num_valid = success_target.shape[0]
        prop_pos = int(pos_rows.sum().item()) / num_valid if num_valid else 0.0

        # The two classes are averaged separately so that alpha2 can re-weight
        # the failure class independently of how rare it is in the batch.
        success_loss = success_loss_pos + self.alpha2 * success_loss_neg
        loss = action_loss + self.alpha1 * success_loss

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), .25)
        self.optimizer.step()

        self.diagnostics['training/total_error'] = loss.detach().cpu().item()
        self.diagnostics['training/action_error'] = action_loss.detach().cpu().item()
        self.diagnostics['training/success_error'] = success_loss.detach().cpu().item()
        self.diagnostics['training/success_accuracy'] = success_accuracy.detach().cpu().item()
        self.diagnostics['training/success_accuracy_neg'] = success_accuracy_neg.detach().cpu().item()
        self.diagnostics['training/success_accuracy_pos'] = success_accuracy_pos.detach().cpu().item()
        self.diagnostics['training/prop_pos'] = prop_pos
        # Stored as a plain float like every other diagnostic.
        self.diagnostics['training/batch_success_rate'] = batch_success_rate.detach().cpu().item()
        self.diagnostics['training/match'] = int(match.detach().cpu().item())

        return loss.detach().cpu().item()


    def eval_iteration_multienv(self, trajectories_list, eval_episodes, info, variant, 
                                env_list, iter_num=0, print_logs=False, no_prompt=False, group='test', recording_prefix=None):
        """
        Call `eval_episodes` with prompt length 1 and collect its outputs under
        '<group>-evaluation/<key>', together with the evaluation time and the
        current training diagnostics.

        `recording_prefix` is forwarded as `name_prefix` when an avoid prompt
        sampler is configured.
        """
        logs = dict()
        print('start evaluating at {} tasks: {}'.format(len(env_list), group))
        self.model.eval()

        eval_start = time.time()
        eval_prompt_lens = [1]
        for prompt_len in eval_prompt_lens:
            if self.get_avoid_prompt:
                outputs = eval_episodes(info[0], variant, env_list, self.model, prompt_len, 
                                    self.get_prompt, trajectories_list, self.get_avoid_prompt, self.buffer_size, prompt_dim = self.prompt_dim, name_prefix = recording_prefix)
            else:
                outputs = eval_episodes(info[0], variant, env_list, self.model, prompt_len, 
                                    self.get_prompt, trajectories_list)
            for k, v in outputs.items():
                logs[f'{group}-evaluation/{k}'] = v

        logs['time/evaluation'] = time.time() - eval_start

        for k in self.diagnostics:
            logs[k] = self.diagnostics[k]

        if print_logs:
            print('=' * 80)
            print(f'Iteration {iter_num}')
            for k, v in logs.items():
                print(f'{k}: {v}')

        return logs

 
    def save_model(self, env_name, postfix, folder):
        """Write the model's state dict to folder + '/prompt_model_' + env_name + postfix."""
        model_name = '/prompt_model_' + env_name + postfix
        torch.save(self.model.state_dict(),folder+model_name)  # model save
        print('model saved to ', folder+model_name)


class PromptSequenceTrainerLossDebug:
    """
    Debug variant of the reach-avoid trainer that optimises ONLY the
    per-timestep "avoid success" classification loss; the action loss is still
    computed and reported but does not contribute to the gradient.
    """

    def __init__(self, model, optimizer, batch_size,
                 scheduler=None, eval_fns=None, get_prompt=None, get_avoid_prompt=None, get_prompt_batch=None, alpha2=1,
                 obs_start_index=0, prompt_dim=3):
        """
        model: network exposing `discrete_action`, `forward`, `train`, `eval`,
            `parameters` and `state_dict`
        optimizer: optimizer stepped once per `train_step`
        batch_size: stored on the trainer (not read by the methods of this class)
        scheduler: optional scheduler, stepped after every training step
        eval_fns: optional list of evaluation callables (stored only)
        get_prompt, get_avoid_prompt: passed on to `eval_episodes` at evaluation;
            `get_avoid_prompt` also selects the 3/4-tuple batch format in training
        get_prompt_batch: zero-argument callable returning the training data
        alpha2: weight of the negative-class term inside the success loss
        obs_start_index, prompt_dim: columns of the state holding the position,
            forwarded to `check_avoid_success_state`
        """
        self.model = model
        self.optimizer = optimizer
        self.batch_size = batch_size
        if model.discrete_action:
            self.loss_fn = F.cross_entropy
        else:
            self.loss_fn = lambda a_hat, a: torch.mean((a_hat - a) ** 2)
        self.scheduler = scheduler
        self.eval_fns = [] if eval_fns is None else eval_fns
        self.diagnostics = dict()
        self.get_prompt = get_prompt
        self.get_avoid_prompt = get_avoid_prompt
        self.get_prompt_batch = get_prompt_batch
        self.alpha2 = alpha2
        # train_step reads these; they were previously never initialised.
        self.obs_start_index = obs_start_index
        self.prompt_dim = prompt_dim

        self.start_time = time.time()


    def train(self, num_steps, no_prompt=False):
        """
        Run `num_steps` optimisation steps and return a dict with the elapsed
        time, mean/std of the step losses, the current learning rate and the
        latest diagnostics.
        """
        train_losses = []
        logs = dict()

        train_start = time.time()

        self.model.train()
        for _ in range(num_steps):
            train_loss = self.train_step(no_prompt)
            train_losses.append(train_loss)
            if self.scheduler is not None:
                self.scheduler.step()

        logs['time/training'] = time.time() - train_start
        logs['training/train_loss_mean'] = np.mean(train_losses)
        logs['training/train_loss_std'] = np.std(train_losses)
        logs['training/lr'] = self._current_lr()

        for k in self.diagnostics:
            logs[k] = self.diagnostics[k]

        return logs

    def _current_lr(self):
        """Return the learning rate of the first parameter group."""
        # The scheduler is optional, so fall back to the optimizer itself.
        if self.scheduler is None:
            return self.optimizer.param_groups[0]['lr']
        # get_last_lr() is the supported way to read the rate outside of
        # scheduler.step(); get_lr() is kept for schedulers that lack it.
        if hasattr(self.scheduler, 'get_last_lr'):
            return self.scheduler.get_last_lr()[0]
        return self.scheduler.get_lr()[0]

    @staticmethod
    def _subset_loss_and_accuracy(preds, targets, rows):
        """
        Cross-entropy and accuracy restricted to the rows selected by the
        boolean mask `rows`.

        When no row is selected the loss is a zero that stays attached to the
        graph (cross-entropy over an empty batch would be NaN and poison the
        gradients); the accuracy is then NaN because it is undefined.
        """
        picked_preds = preds[rows]
        picked_targets = targets[rows]
        accuracy = (picked_preds.argmax(dim=1) == picked_targets.argmax(dim=1)).float().mean()
        if picked_preds.shape[0] == 0:
            return preds.sum() * 0.0, accuracy
        return F.cross_entropy(picked_preds, picked_targets), accuracy

    def train_step(self, no_prompt=False):
        """
        Sample one batch, take one gradient step on
        success_loss_pos + alpha2 * success_loss_neg
        and return that loss as a Python float.

        Raises NotImplementedError when `no_prompt` is true, and ValueError
        when the sampler did not provide the avoid prompt and success list that
        the success loss depends on.
        """
        success_list = None
        avoid_prompt = None
        if self.get_avoid_prompt:
            prompt_batch_tuple = self.get_prompt_batch()
            prompt, avoid_prompt, batch = prompt_batch_tuple[0], prompt_batch_tuple[1], prompt_batch_tuple[2]
            if len(prompt_batch_tuple) == 4:
                success_list = prompt_batch_tuple[3]
        else:
            prompt, batch = self.get_prompt_batch()
        states, actions, _rewards, _dones, _rtg, timesteps, attention_mask = batch
        if no_prompt:
            raise NotImplementedError
        # The success head is only queried when both pieces are present, and
        # the loss below cannot be computed without it; fail before the
        # forward pass with an explicit message instead of a late NameError.
        if avoid_prompt is None or success_list is None:
            raise ValueError(
                'PromptSequenceTrainerLossDebug.train_step needs get_avoid_prompt to be set and '
                'get_prompt_batch to return (prompt, avoid_prompt, batch, success_list)'
            )

        action_target = torch.clone(actions)
        _state_preds, action_preds, success_preds = self.model.forward(
            states, actions, timesteps, attention_mask=attention_mask,
            prompt=prompt, avoid_prompt=avoid_prompt, success_list=success_list
        )

        # Only timesteps with a positive attention mask contribute to the losses.
        valid = attention_mask.reshape(-1) > 0

        act_dim = action_preds.shape[2]
        action_preds = action_preds.reshape(-1, act_dim)[valid]
        if self.model.discrete_action:
            action_target = action_target.reshape(-1)[valid]
        else:
            action_target = action_target.reshape(-1, act_dim)[valid]
        # Reported for reference only; not part of the optimised loss.
        action_loss = self.loss_fn(action_preds, action_target)

        avoid_prompt_states, avoid_prompt_attention_mask = avoid_prompt
        # B x seq_length booleans, True where the state is outside the avoid boxes.
        step_success = check_avoid_success_state(states, avoid_prompt_states, avoid_prompt_attention_mask, self.obs_start_index, prompt_dim = self.prompt_dim)

        # NOTE(review): these two trajectory-level statistics are computed over
        # every timestep, padded ones included -- confirm that is intended.
        trajectory_success = step_success.all(dim=1)
        batch_success_rate = trajectory_success.float().mean()
        # Debug check that the recomputed labels agree with the sampler's
        # list; assumes success_list is a tensor comparable with `==`.
        match = (trajectory_success == success_list).all()

        # B x seq_length x 2 one-hot: column 0 = success, column 1 = failure.
        success_target = torch.stack([step_success.float(), (~step_success).float()], dim=2)
        success_target = success_target.reshape(-1, 2)[valid]
        success_preds = success_preds.reshape(-1, 2)[valid]

        success_accuracy = (success_preds.argmax(dim=1) == success_target.argmax(dim=1)).float().mean()
        pos_rows = success_target[:, 0] == 1
        neg_rows = success_target[:, 1] == 1
        success_loss_pos, success_accuracy_pos = self._subset_loss_and_accuracy(success_preds, success_target, pos_rows)
        success_loss_neg, success_accuracy_neg = self._subset_loss_and_accuracy(success_preds, success_target, neg_rows)
        num_valid = success_target.shape[0]
        prop_pos = int(pos_rows.sum().item()) / num_valid if num_valid else 0.0

        # The two classes are averaged separately so that alpha2 can re-weight
        # the failure class independently of how rare it is in the batch.
        success_loss = success_loss_pos + self.alpha2 * success_loss_neg
        loss = success_loss

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), .25)
        self.optimizer.step()

        self.diagnostics['training/total_error'] = loss.detach().cpu().item()
        self.diagnostics['training/action_error'] = action_loss.detach().cpu().item()
        self.diagnostics['training/success_error'] = success_loss.detach().cpu().item()
        self.diagnostics['training/success_accuracy'] = success_accuracy.detach().cpu().item()
        self.diagnostics['training/success_accuracy_neg'] = success_accuracy_neg.detach().cpu().item()
        self.diagnostics['training/success_accuracy_pos'] = success_accuracy_pos.detach().cpu().item()
        self.diagnostics['training/prop_pos'] = prop_pos
        # Stored as a plain float like every other diagnostic.
        self.diagnostics['training/batch_success_rate'] = batch_success_rate.detach().cpu().item()
        self.diagnostics['training/match'] = int(match.detach().cpu().item())

        return loss.detach().cpu().item()


    def eval_iteration_multienv(self, trajectories_list, eval_episodes, info, variant, 
                                env_list, iter_num=0, print_logs=False, no_prompt=False, group='test', recording_prefix=None):
        """
        Call `eval_episodes` once for every prompt length from 1 to
        variant['max_prompt_len'] and collect its outputs under
        '<group>-evaluation/<key>'. Outputs sharing a key across prompt lengths
        overwrite each other, so the last prompt length wins for such keys.

        `recording_prefix` is forwarded as `name_prefix` when an avoid prompt
        sampler is configured. Returns the log dict, which also holds the
        evaluation time and the current training diagnostics.
        """
        logs = dict()
        print('start evaluating at {} tasks: {}'.format(len(env_list), group))
        self.model.eval()

        eval_start = time.time()
        eval_prompt_lens = [i for i in range(1, variant['max_prompt_len']+1)]
        for prompt_len in eval_prompt_lens:
            if self.get_avoid_prompt:
                outputs = eval_episodes(info[0], variant, env_list, self.model, prompt_len, 
                                    self.get_prompt, trajectories_list, self.get_avoid_prompt, name_prefix=recording_prefix)
            else:
                outputs = eval_episodes(info[0], variant, env_list, self.model, prompt_len, 
                                    self.get_prompt, trajectories_list)
            for k, v in outputs.items():
                logs[f'{group}-evaluation/{k}'] = v

        logs['time/evaluation'] = time.time() - eval_start

        for k in self.diagnostics:
            logs[k] = self.diagnostics[k]

        if print_logs:
            print('=' * 80)
            print(f'Iteration {iter_num}')
            for k, v in logs.items():
                print(f'{k}: {v}')

        return logs

 
    def save_model(self, env_name, postfix, folder):
        """Write the model's state dict to folder + '/prompt_model_' + env_name + postfix."""
        model_name = '/prompt_model_' + env_name + postfix
        torch.save(self.model.state_dict(),folder+model_name)  # model save
        print('model saved to ', folder+model_name)