"""
Heavily based on https://github.com/kzl/decision-transformer/blob/master/gym/decision_transformer/training/seq_trainer.py
"""

import numpy as np
import torch
# import wandb

import time
from torch.utils.data import DataLoader, WeightedRandomSampler
from tqdm import tqdm
# import faiss
import numpy as np
import math


def calc_kernel(X, Y, kernel_type='power', kernel_num=10, kernel_alpha=2):
    """Compute pairwise kernel matrices between the last-dimension entries of X and Y.

    Differences are taken by broadcasting over the last dimension, so with
    ``X`` of shape ``(..., n_x)`` and ``Y`` of shape ``(..., n_y)`` the three
    results have shapes ``(..., n_y, n_y)``, ``(..., n_x, n_x)`` and
    ``(..., n_y, n_x)``; entry ``[..., i, j]`` compares element ``i`` of the
    first operand with element ``j`` of the second:

    * ``first``  -- ``Y`` vs ``Y``
    * ``second`` -- ``X`` vs ``X``
    * ``third``  -- ``Y`` vs ``X``

    Args:
        X, Y: tensors compared along their last dimension.
        kernel_type: ``'rbf'`` sums ``exp(-d**2 / h)`` over ``kernel_num``
            bandwidths ``h`` in ``np.linspace(1, 10, kernel_num)``;
            ``'power'`` returns ``-|d|**kernel_alpha``.
        kernel_num: number of RBF bandwidths (``'rbf'`` only).
        kernel_alpha: exponent of the power kernel (``'power'`` only).

    Returns:
        Tuple ``(first, second, third)`` as described above.

    Raises:
        ValueError: if ``kernel_type`` is neither ``'rbf'`` nor ``'power'``.
    """
    if kernel_type not in ('rbf', 'power'):
        # Previously fell through both branches and failed with UnboundLocalError.
        raise ValueError(f"Unknown kernel_type {kernel_type!r}; expected 'rbf' or 'power'")

    diff_yy = Y.unsqueeze(-1) - Y.unsqueeze(-2)
    diff_xx = X.unsqueeze(-1) - X.unsqueeze(-2)
    diff_yx = Y.unsqueeze(-1) - X.unsqueeze(-2)

    if kernel_type == 'rbf':
        sq_yy, sq_xx, sq_yx = diff_yy.pow(2), diff_xx.pow(2), diff_yx.pow(2)
        first_items = 0
        second_items = 0
        third_items = 0
        # Multi-bandwidth RBF: sum the kernel over a fixed grid of bandwidths.
        for h in np.linspace(1, 10, kernel_num):
            h = float(h)
            first_items = first_items + (-sq_yy / h).exp()
            second_items = second_items + (-sq_xx / h).exp()
            third_items = third_items + (-sq_yx / h).exp()
        return first_items, second_items, third_items

    # 'power': negated power distance.
    first_items = -diff_yy.abs().pow(kernel_alpha)
    second_items = -diff_xx.abs().pow(kernel_alpha)
    third_items = -diff_yx.abs().pow(kernel_alpha)
    return first_items, second_items, third_items


class Trainer:
    """Trainer for decision-transformer-style sequence models.

    ``train_iteration`` samples mini-batches (weighted by ``dataset.p_sample``),
    calls ``train_step`` on each, runs the evaluation callbacks and returns a
    dict of logs. ``config['model_type']`` selects the loss:

    * ``'dt'`` / ``'bc'``: masked action MSE. With ``config['is_stitch']``, the
      action loss is instead computed from extra forward passes conditioned on
      return-to-go sequences derived from the model's return predictions. An
      MMD loss on those predictions is minimised with ``mmd_optimizer``.
    * ``'mgdt'``: masked action + return + reward MSE. Stitch mode is not
      supported for this model type.
    """

    def __init__(self, model, optimizer, batch_size, dataset, writer, config, scheduler=None, eval_fns=None, mmd_optimizer=None):
        """
        Args:
            model: sequence model exposing ``forward`` and
                ``get_decision_transformer_parameters``. In stitch mode it must
                also expose ``get_mmd_parameters``, ``mmdnet`` and
                ``target_mmdnet``.
            optimizer: optimizer stepped with the action loss.
            batch_size: mini-batch size used by the data loader.
            dataset: dataset exposing ``states`` (array) and ``p_sample``
                (per-item sampling weights).
            writer: object with ``add_scalar(tag, value, step)``, or ``None``
                to disable scalar logging.
            config: dict with at least ``model_type``, ``reward_scale``,
                ``num_steps_per_iter``, ``max_iters`` and ``is_stitch``.
            scheduler: optional LR scheduler stepped once per training step.
            eval_fns: callables ``fn(model) -> dict`` run after each iteration.
            mmd_optimizer: optimizer for the MMD loss (stitch mode only).
        """
        self.model = model
        self.optimizer = optimizer
        self.mmd_optimizer = mmd_optimizer
        self.batch_size = batch_size
        self.dataset = dataset
        self.scheduler = scheduler
        self.eval_fns = [] if eval_fns is None else eval_fns
        self.diagnostics = dict()
        self.writer = writer
        self.model_type = config['model_type']
        self.reward_scale = config['reward_scale']
        self.config = config
        # L2-normalised copy of the dataset states along axis 1 (assumes 2-D
        # states; an all-zero row yields NaNs). Not used elsewhere in this class.
        self.norm_states = self.dataset.states / np.linalg.norm(self.dataset.states, axis=1, keepdims=True)

        # Global step counter and horizon of the cosine schedule for self.weight.
        self.train_count = 0
        self.max_count = self.config['num_steps_per_iter'] * self.config['max_iters']
        self.weight = 1.0
        self.final_weight = 0.0
        self.init_weight = 1.0

        self.start_time = time.time()

    def train_iteration(self, num_steps, iter_num=0, print_logs=False):
        """Train for ``num_steps`` mini-batches, then run the evaluation callbacks.

        Batches are drawn with replacement according to ``dataset.p_sample``.
        When a writer is set, each step's loss is logged to it. In stitch mode,
        ``model.target_mmdnet`` is refreshed from ``model.mmdnet`` every
        ``config['update_freq']`` steps, and the recorded training loss is the
        stitched action loss (``q_loss``).

        Args:
            num_steps: number of gradient steps (mini-batches) to run.
            iter_num: iteration index, used only in the printed header.
            print_logs: if true, print every log entry to stdout.

        Returns:
            dict of ``time/*``, ``evaluation/*`` and ``training/*`` entries plus
            the current diagnostics.
        """
        train_losses = []
        logs = dict()

        train_start = time.time()
        sampler = WeightedRandomSampler(self.dataset.p_sample, num_samples=num_steps * self.batch_size, replacement=True)
        dataloader = DataLoader(self.dataset, sampler=sampler, batch_size=self.batch_size)

        self.model.train()
        for states, actions, rewards, dones, rtg, timesteps, attention_mask in tqdm(dataloader):
            if self.config['is_stitch']:
                train_loss, mmd_loss = self.train_step(states, actions, rewards, dones, rtg, timesteps, attention_mask)
            else:
                train_loss = self.train_step(states, actions, rewards, dones, rtg, timesteps, attention_mask)
            train_losses.append(train_loss)
            if self.writer is not None:
                self.writer.add_scalar('train_loss', train_loss, self.train_count)
                if self.config['is_stitch']:
                    self.writer.add_scalar('mmd_loss', mmd_loss, self.train_count)
            self.train_count += 1
            if self.config['is_stitch']:
                if self.train_count % self.config['update_freq'] == 0:
                    # Periodic hard copy of the online MMD network into its target copy.
                    self.model.target_mmdnet.load_state_dict(self.model.mmdnet.state_dict())
            if self.scheduler is not None:
                self.scheduler.step()

        logs['time/training'] = time.time() - train_start

        eval_start = time.time()

        self.model.eval()
        for eval_fn in self.eval_fns:
            outputs = eval_fn(self.model)
            for k, v in outputs.items():
                logs[f'evaluation/{k}'] = v

        logs['time/total'] = time.time() - self.start_time
        logs['time/evaluation'] = time.time() - eval_start
        logs['training/train_loss_mean'] = np.mean(train_losses)
        logs['training/train_loss_std'] = np.std(train_losses)

        for k in self.diagnostics:
            logs[k] = self.diagnostics[k]

        if print_logs:
            print('=' * 80)
            print(f'Iteration {iter_num}')
            for k, v in logs.items():
                print(f'{k}: {v}')

        return logs

    def train_step(self, states, actions, rewards, dones, rtg, timesteps, attention_mask):
        """Run one optimisation step on a single mini-batch.

        Shapes are not checked here. The indexing below is consistent with
        ``actions``/``rewards`` shaped ``[B, T, ...]``, ``rtg`` with one extra
        step (``[B, T + 1, ...]``) and ``attention_mask``/``dones`` shaped
        ``[B, T]``. TODO: confirm against the dataset.

        Returns:
            ``(q_loss, mmd_loss)`` as Python floats when ``config['is_stitch']``
            is true, otherwise the training loss as a Python float.

        Raises:
            ValueError: for an unknown ``model_type``, or when stitch mode is
                requested for a model type other than ``'dt'``/``'bc'``.
        """
        is_stitch = self.config['is_stitch']
        if self.model_type not in ('dt', 'bc', 'mgdt'):
            raise ValueError(f"Unsupported model_type: {self.model_type!r}")
        if is_stitch and self.model_type not in ('dt', 'bc'):
            # q_loss / mmd_loss are only produced on the dt/bc path.
            raise ValueError(f"is_stitch is only supported for model_type 'dt' or 'bc', got {self.model_type!r}")

        rewards_target, action_target, rtg_target = torch.clone(rewards), torch.clone(actions), torch.clone(rtg)

        # The model is conditioned on rtg[:, :-1]; the final return-to-go step is not fed in.
        state_preds, action_preds, return_preds, reward_preds = self.model.forward(
            states, actions, rewards, rtg[:, :-1], timesteps, attention_mask=attention_mask,
        )

        act_dim = action_preds.shape[2]
        action_preds = action_preds.reshape(-1, act_dim)
        action_target = action_target.reshape(-1, act_dim)
        # Flattened [B*T] mask selecting non-padded positions.
        valid = attention_mask.reshape(-1) > 0

        if self.model_type in ('dt', 'bc'):
            if is_stitch:
                # --- Action loss conditioned on re-estimated return-to-go ---
                cumsum_rewards = torch.cumsum(rewards, dim=1)
                # Prefix the dataset return-to-go at t=0 (repeated to the size of the
                # last dimension of return_preds[1]) to return_preds[1] shifted by one step.
                tmp_return = torch.cat([rtg[:, 0].repeat(1, return_preds[1].size(-1)).unsqueeze(1), return_preds[1][:, :-1]], dim=1)
                # Mean of the `percentile` largest entries along the last dimension:
                # one start-return estimate per step, [B, T, 1].
                top_k = int(self.config['percentile'])
                start_return_est = torch.sort(tmp_return.flatten(0, 1).view(*return_preds[0].size()), dim=-1, descending=True)[0][:, :, :top_k].mean(dim=-1).unsqueeze(-1)
                T = rewards.shape[1]

                # Positions not yet consumed by a round; always a subset of the attention mask.
                alive_mask = attention_mask > 0  # bool [B, T]

                max_rounds = int(self.config.get('stitch_rounds', 2))
                decay = self.config.get('lambda', 0.7)
                total_q_loss = 0.0
                total_count = 0

                # Loop invariants.
                return_max_full = start_return_est + cumsum_rewards  # [B, T, 1]
                t_idx = torch.arange(T, device=rewards.device).unsqueeze(0)  # [1, T]

                for _ in range(max_rounds):
                    if not alive_mask.any():
                        break

                    masked_return = return_max_full.squeeze(-1).clone()  # [B, T]
                    # NOTE(review): finite sentinel keeps rows with no alive step finite in
                    # max_return_seq below; alive returns below -999 would lose to it.
                    masked_return[~alive_mask] = -999

                    start_max, step_max = masked_return.max(dim=1)  # [B], [B]
                    # Only true if the estimates themselves are infinite (sentinel is finite).
                    if torch.isinf(start_max).all():
                        break

                    # [start_max, start_max - cumsum[0], ..., start_max - cumsum[T-2]] per row,
                    # detached so the stitched action loss does not train the return estimate.
                    start_max_col = start_max.unsqueeze(1).unsqueeze(-1)  # [B, 1, 1]
                    max_return_seq = torch.cat([start_max_col, start_max_col - cumsum_rewards[:, :-1]], dim=1).detach()  # [B, T, 1]

                    # Alive steps up to and including each row's best start step.
                    selected_mask_round = (t_idx <= step_max.unsqueeze(1)) & alive_mask  # [B, T]

                    _, q_action_preds, _, _ = self.model.forward(
                        states, actions, rewards, max_return_seq, timesteps, attention_mask=attention_mask
                    )
                    q_loss_all = (q_action_preds.reshape(-1, act_dim) - action_target) ** 2  # [B*T, A]
                    pick = selected_mask_round.reshape(-1)  # [B*T]
                    if pick.any():
                        # Each successful round is down-weighted geometrically by `lambda`.
                        total_q_loss = total_q_loss + q_loss_all[pick].mean() * (decay ** total_count)
                        total_count += 1

                    alive_mask[selected_mask_round] = False

                if total_count > 0:
                    q_loss = total_q_loss
                else:
                    # No step was selected in any round: condition on the per-step estimate.
                    # Detached for the same reason as max_return_seq above; otherwise this
                    # loss would backpropagate into return_preds after mmd_loss.backward().
                    _, q_action_preds, _, _ = self.model.forward(
                        states, actions, rewards, start_return_est.detach(), timesteps, attention_mask=attention_mask
                    )
                    q_action_preds = q_action_preds.reshape(-1, act_dim)
                    q_loss = ((q_action_preds - action_target) ** 2)[valid].mean()

                # --- MMD loss between return_preds[0] and a bootstrapped target ---
                # Target = rewards + (1 - dones) * return_preds[1], taken one step ahead and
                # detached, so only return_preds[0] receives gradient from this loss.
                return_target = (1 - dones).unsqueeze(-1) * return_preds[1] + rewards
                first_item, second_item, third_item = calc_kernel(
                    return_preds[0][:, :-1], return_target[:, 1:].detach(),
                    self.config['kernel_type'], self.config['kernel_num'], self.config['kernel_alpha'],
                )
                pair_count = self.config['particle_num'] ** 2
                first_item = first_item.sum(-1).sum(-1) / pair_count
                second_item = second_item.sum(-1).sum(-1) / pair_count
                third_item = third_item.sum(-1).sum(-1) / pair_count
                # `[:, 1:]` and `attention_mask[:, 1:-1]` both cover (0-based) steps 1 .. T-2.
                mmd_loss = ((first_item + second_item - 2 * third_item)[:, 1:].reshape(-1)[attention_mask[:, 1:-1].reshape(-1) > 0]).mean()

                self.mmd_optimizer.zero_grad()
                mmd_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.get_mmd_parameters(), .25)
                self.mmd_optimizer.step()

                # Cosine-decayed weight from init_weight to final_weight over max_count steps
                # (stored on self; not used by the losses computed in this method).
                ratio = min(self.train_count / self.max_count, 1.0)
                cosine_decay = 0.5 * (1 + math.cos(math.pi * ratio))
                self.weight = self.final_weight + (self.init_weight - self.final_weight) * cosine_decay
                loss = q_loss
            else:
                loss = torch.mean(((action_preds - action_target) ** 2)[valid])
        else:  # 'mgdt'
            if self.config['sample_return'] == True:
                # Reparameterised sample treating return_preds[0] as the mean and
                # return_preds[1] as the log-variance.
                eps = torch.randn_like(return_preds[1])
                return_preds = return_preds[0] + eps * torch.exp(0.5 * return_preds[1])
            return_preds = return_preds.reshape(-1, 1)[valid]
            return_target = rtg_target[:, :-1].reshape(-1, 1)[valid]
            reward_preds = reward_preds.reshape(-1, 1)[valid]
            reward_target = rewards_target.reshape(-1, 1)[valid]
            # Padded positions are excluded from the action term as well, matching the
            # return/reward terms here and the dt/bc action loss.
            loss = torch.mean((action_preds[valid] - action_target[valid]) ** 2) \
                + torch.mean((return_preds - return_target) ** 2) \
                + torch.mean((reward_preds - reward_target) ** 2)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.get_decision_transformer_parameters(), .25)
        self.optimizer.step()

        with torch.no_grad():
            # Unmasked action error over all B*T positions (includes padding).
            self.diagnostics['training/action_error'] = torch.mean((action_preds - action_target) ** 2).detach().cpu().item()
            if is_stitch:
                self.diagnostics['training/mmd_loss'] = mmd_loss.detach().cpu().item()

        if is_stitch:
            return q_loss.detach().cpu().item(), mmd_loss.detach().cpu().item()
        return loss.detach().cpu().item()

    def debug_param_update(self, tag=""):
        """Print the mean and norm of selected trainable parameters for debugging.

        Walks ``model.named_parameters()`` in order and prints each trainable
        parameter whose name contains ``"proj"`` or ``"linear"``. The walk stops
        right after printing the first such parameter whose position in that
        order is 2 or later.

        Args:
            tag: label prefixed to every printed line.
        """
        with torch.no_grad():
            for i, (name, param) in enumerate(self.model.named_parameters()):
                if param.requires_grad and ("proj" in name or "linear" in name):
                    print(f"[{tag}] {name} mean: {param.data.mean().item():.6f}, "
                          f"norm: {param.data.norm().item():.6f}")
                    # NOTE(review): `i` counts all parameters, not printed ones;
                    # confirm whether "print at most 3" was the intent.
                    if i >= 2:
                        break
