"""
Heavily based on https://github.com/kzl/decision-transformer/blob/master/gym/decision_transformer/training/seq_trainer.py
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# import wandb

import time
from torch.utils.data import DataLoader, WeightedRandomSampler
from tqdm import tqdm
import math


def calc_kernel(X, Y, kernel_type='power', kernel_num=10, kernel_alpha=2):
    """Compute the three pairwise kernel terms between ``Y``/``Y``, ``X``/``X`` and ``Y``/``X``.

    Pairwise differences are formed along the last axis of the inputs
    (``A.unsqueeze(-1) - B.unsqueeze(-2)``), so each returned term has one
    more trailing axis than the inputs.

    Args:
        X: tensor of samples; the last axis is the one paired up.
        Y: tensor of samples, broadcast-compatible with ``X``.
        kernel_type: ``'rbf'`` sums ``exp(-d**2 / h)`` over ``kernel_num``
            bandwidths ``h`` evenly spaced in ``[1, 10]``; ``'power'`` returns
            ``-|d|**kernel_alpha``.
        kernel_num: number of bandwidths used by the ``'rbf'`` kernel.
        kernel_alpha: exponent used by the ``'power'`` kernel.

    Returns:
        A tuple ``(yy_term, xx_term, yx_term)``.

    Raises:
        ValueError: if ``kernel_type`` is neither ``'rbf'`` nor ``'power'``.
    """
    if kernel_type not in ('rbf', 'power'):
        # Previously this fell through and failed with UnboundLocalError.
        raise ValueError(
            f"unsupported kernel_type {kernel_type!r}; expected 'rbf' or 'power'"
        )

    diffs = (
        Y.unsqueeze(-1) - Y.unsqueeze(-2),
        X.unsqueeze(-1) - X.unsqueeze(-2),
        Y.unsqueeze(-1) - X.unsqueeze(-2),
    )

    if kernel_type == 'rbf':
        squared = [d.pow(2) for d in diffs]
        items = [0, 0, 0]
        for h in np.linspace(1, 10, kernel_num):
            for i, sq in enumerate(squared):
                items[i] = items[i] + (-sq / h).exp()
        first_items, second_items, third_items = items
    else:  # 'power'
        first_items, second_items, third_items = (
            -d.abs().pow(kernel_alpha) for d in diffs
        )
    return first_items, second_items, third_items


class Trainer:
    """Runs training iterations and evaluation for a sequence model.

    The trainer expects:

    * ``model`` to expose ``forward(images, missions, mission_masks, actions,
      rewards, returns, timesteps, attention_mask=...)`` returning
      ``(state_preds, action_preds, return_preds, reward_preds)``, plus
      ``train()``, ``eval()`` and ``parameters()``.  When
      ``config['is_stitch']`` is set it must also expose ``mmdnet``,
      ``target_mmdnet`` and ``get_mmd_parameters()``.
    * ``dataset`` to expose ``p_sample`` (sampling weights) and yield 9-tuples
      ``(images, missions, mission_masks, actions, rewards, dones, rtg,
      timesteps, attention_mask)``.
    * ``config`` to be a dict with at least ``'model_type'`` and
      ``'is_stitch'``.  The stitch path additionally reads ``'update_freq'``,
      ``'percentile'``, ``'kernel_type'``, ``'kernel_num'``,
      ``'kernel_alpha'``, ``'particle_num'`` and optionally
      ``'stitch_rounds'`` and ``'lambda'``; the ``'mgdt'`` path reads
      ``'sample_return'``.
    """

    def __init__(self, model, optimizer, batch_size, dataset, writer, config, scheduler=None, eval_fns=None, mmd_optimizer=None):
        """Store the training components.

        Args:
            model: the model to train.
            optimizer: optimizer stepped on the main loss.
            batch_size: batch size used to build the data loader.
            dataset: dataset to sample batches from.
            writer: object with ``add_scalar(tag, value, step)``, or ``None``
                to disable scalar logging.
            config: configuration dict (see class docstring).
            scheduler: optional LR scheduler stepped once per batch.
            eval_fns: optional callables ``fn(model) -> dict`` run after each
                training iteration.
            mmd_optimizer: optimizer stepped on the MMD loss; required when
                ``config['is_stitch']`` is set.
        """
        self.model = model
        self.optimizer = optimizer
        self.mmd_optimizer = mmd_optimizer
        self.batch_size = batch_size
        self.dataset = dataset
        self.scheduler = scheduler
        self.eval_fns = [] if eval_fns is None else eval_fns
        self.diagnostics = dict()
        self.writer = writer
        self.model_type = config['model_type']
        self.config = config

        # Global batch counter across iterations; used as the logging step and
        # to schedule target-network syncs.
        self.train_count = 0

        self.start_time = time.time()

    def train_iteration(self, num_steps, iter_num=0, print_logs=False):
        """Train for ``num_steps`` batches, then run every evaluation function.

        Args:
            num_steps: number of batches to draw (with replacement, weighted by
                ``dataset.p_sample``).
            iter_num: iteration index, only used when printing logs.
            print_logs: when true, print the collected logs.

        Returns:
            A dict with timing entries (``time/training``, ``time/evaluation``,
            ``time/total``), ``training/train_loss_mean`` / ``_std``,
            ``evaluation/<key>`` entries from the evaluation functions, and the
            latest diagnostics recorded by :meth:`train_step`.
        """
        train_losses = []
        logs = dict()
        is_stitch = self.config['is_stitch']

        train_start = time.time()
        sampler = WeightedRandomSampler(self.dataset.p_sample, num_samples=num_steps * self.batch_size, replacement=True)
        dataloader = DataLoader(self.dataset, sampler=sampler, batch_size=self.batch_size)

        self.model.train()
        for images, missions, mission_masks, actions, rewards, dones, rtg, timesteps, attention_mask in tqdm(dataloader):
            mmd_loss = None
            if is_stitch:
                # In stitch mode train_step returns (q_loss, mmd_loss).
                train_loss, mmd_loss = self.train_step(images, missions, mission_masks, actions, rewards, dones, rtg, timesteps, attention_mask)
            else:
                train_loss = self.train_step(images, missions, mission_masks, actions, rewards, dones, rtg, timesteps, attention_mask)
            train_losses.append(train_loss)

            if self.writer is not None:
                self.writer.add_scalar('train_loss', train_loss, self.train_count)
                if is_stitch:
                    self.writer.add_scalar('mmd_loss', mmd_loss, self.train_count)

            self.train_count += 1
            # Periodically copy the online MMD network into the target network.
            if is_stitch and self.train_count % self.config['update_freq'] == 0:
                self.model.target_mmdnet.load_state_dict(self.model.mmdnet.state_dict())
            if self.scheduler is not None:
                self.scheduler.step()

        logs['time/training'] = time.time() - train_start

        eval_start = time.time()

        self.model.eval()
        for eval_fn in self.eval_fns:
            outputs = eval_fn(self.model)
            for k, v in outputs.items():
                logs[f'evaluation/{k}'] = v

        logs['time/total'] = time.time() - self.start_time
        logs['time/evaluation'] = time.time() - eval_start
        logs['training/train_loss_mean'] = np.mean(train_losses)
        logs['training/train_loss_std'] = np.std(train_losses)

        for k in self.diagnostics:
            logs[k] = self.diagnostics[k]

        if print_logs:
            print('=' * 80)
            print(f'Iteration {iter_num}')
            for k, v in logs.items():
                print(f'{k}: {v}')

        return logs

    def _stitch_action_loss(self, images, missions, mission_masks, actions, rewards, rtg, timesteps, attention_mask, return_preds, action_target):
        """Return the action cross-entropy loss conditioned on estimated best returns.

        For up to ``config['stitch_rounds']`` rounds (default 2) this picks, per
        sequence, the still-unassigned step with the highest
        ``start_return_est + cumsum(rewards)``, re-runs the model conditioned on
        the return sequence implied by that maximum, and accumulates the
        cross-entropy over the steps up to and including it.  Round ``k`` is
        weighted by ``config['lambda'] ** k`` (default 0.7).  If no round
        contributes, a single pass conditioned on ``start_return_est`` over all
        valid steps is used instead.

        Shape comments below follow the original author's annotations and are
        not verified here.
        """
        act_dim = action_target.shape[-1]
        cumsum_rewards = torch.cumsum(rewards, dim=1)

        # Prepend the dataset's initial return-to-go to the shifted predicted
        # return particles, then average the top ``percentile`` particles.
        tmp_return = torch.cat([rtg[:, 0].repeat(1, return_preds[1].size(-1)).unsqueeze(1), return_preds[1][:, :-1]], dim=1)
        sorted_returns = torch.sort(tmp_return.flatten(0, 1).view(*return_preds[0].size()), dim=-1, descending=True)[0]
        top_k = int(self.config['percentile'])
        start_return_est = sorted_returns[:, :, :top_k].mean(dim=-1).unsqueeze(-1)

        T = rewards.shape[1]
        t_idx = torch.arange(T, device=rewards.device).unsqueeze(0)  # [1, T]
        alive_mask = (attention_mask > 0)  # bool [B, T]; steps not yet assigned to a round
        return_max_full = (start_return_est + cumsum_rewards).squeeze(-1)  # [B, T]

        max_rounds = int(self.config.get('stitch_rounds', 2))
        round_weight = self.config.get('lambda', 0.7)
        total_q_loss = 0.0
        total_count = 0

        for _ in range(max_rounds):
            if not alive_mask.any():
                break

            masked_return = return_max_full.clone()
            masked_return[~alive_mask] = -999  # sentinel for already-assigned / padded steps

            start_max, step_max = masked_return.max(dim=1)  # [B], [B]
            # NOTE(review): with a finite -999 sentinel this only triggers when
            # the returns themselves are infinite — confirm that is intended.
            if torch.isinf(start_max).all():
                break

            start_max_col = start_max.unsqueeze(1)  # [B, 1]
            max_return_seq = start_max_col.unsqueeze(-1) - cumsum_rewards  # [B, T, 1]
            max_return_seq = torch.cat([start_max_col.unsqueeze(-1), max_return_seq[:, :-1]], dim=1)  # [B, T, 1]
            max_return_seq = max_return_seq.detach()

            # Steps up to and including the arg-max that are still alive.
            selected_mask_round = (t_idx <= step_max.unsqueeze(1)) & alive_mask  # [B, T] bool

            _, q_action_preds, _, _ = self.model.forward(
                images, missions, mission_masks, actions, rewards, max_return_seq, timesteps, attention_mask=attention_mask,
            )
            q_action_preds = q_action_preds.reshape(-1, act_dim)  # [B*T, A]
            # alive_mask is always a subset of attention_mask > 0, so no extra masking is needed.
            pick = selected_mask_round.reshape(-1)  # [B*T]
            if pick.any():
                round_target = action_target.reshape(-1, act_dim)[pick]
                q_loss_round = F.cross_entropy(q_action_preds[pick], round_target.max(-1)[1])
                total_q_loss = total_q_loss + q_loss_round * (round_weight ** total_count)
                total_count += 1

            alive_mask[selected_mask_round] = False

        if total_count > 0:
            return total_q_loss

        # Fallback: no round contributed.  Condition on the start-return
        # estimate (detached, like the main path) and score the resulting
        # predictions over every valid step.
        valid = attention_mask.reshape(-1) > 0
        _, q_action_preds, _, _ = self.model.forward(
            images, missions, mission_masks, actions, rewards, start_return_est.detach(), timesteps, attention_mask=attention_mask,
        )
        q_action_preds = q_action_preds.reshape(-1, act_dim)[valid]
        fallback_target = action_target.reshape(-1, act_dim)[valid]
        return F.cross_entropy(q_action_preds, fallback_target.max(-1)[1])

    def _mmd_loss(self, return_preds, rewards, dones, attention_mask):
        """Return the masked mean MMD-style loss between predicted and bootstrapped returns.

        The target is ``(1 - dones) * return_preds[1] + rewards`` (detached),
        compared against ``return_preds[0]`` shifted by one step through
        ``calc_kernel``.
        """
        return_target = (1 - dones).unsqueeze(-1) * return_preds[1] + rewards
        first_item, second_item, third_item = calc_kernel(
            return_preds[0][:, :-1],
            return_target[:, 1:].detach(),
            self.config['kernel_type'],
            self.config['kernel_num'],
            self.config['kernel_alpha'],
        )
        norm = self.config['particle_num'] ** 2
        first_item = first_item.sum(-1).sum(-1) / norm
        second_item = second_item.sum(-1).sum(-1) / norm
        third_item = third_item.sum(-1).sum(-1) / norm
        mmd = first_item + second_item - 2 * third_item
        return mmd[:, 1:].reshape(-1)[attention_mask[:, 1:-1].reshape(-1) > 0].mean()

    def train_step(self, images, missions, mission_masks, actions, rewards, dones, rtg, timesteps, attention_mask):
        """Run one optimisation step on a batch.

        Args:
            images, missions, mission_masks, actions, rewards, rtg, timesteps:
                batch tensors passed through to ``model.forward``; ``rtg`` is
                passed without its last step.
            dones: episode-termination flags, used by the stitch MMD target.
            attention_mask: entries ``> 0`` mark the steps that contribute to
                the losses.

        Returns:
            ``(q_loss, mmd_loss)`` as floats when ``config['is_stitch']`` is
            set, otherwise the scalar loss as a float.

        Raises:
            ValueError: if ``model_type`` is not ``'dt'``, ``'bc'`` or
                ``'mgdt'``, or if ``is_stitch`` is combined with ``'mgdt'``.
        """
        is_stitch = self.config['is_stitch']
        if self.model_type not in ('dt', 'bc', 'mgdt'):
            raise ValueError(f"unsupported model_type: {self.model_type!r}")
        if is_stitch and self.model_type == 'mgdt':
            # This combination never produced q_loss / mmd_loss and always failed late.
            raise ValueError("config['is_stitch'] is only supported for model_type 'dt' or 'bc'")

        rewards_target, action_target, rtg_target = torch.clone(rewards), torch.clone(actions), torch.clone(rtg)

        _, action_preds, return_preds, reward_preds = self.model.forward(
            images, missions, mission_masks, actions, rewards, rtg[:, :-1], timesteps, attention_mask=attention_mask,
        )

        act_dim = action_preds.shape[2]
        valid = attention_mask.reshape(-1) > 0
        action_preds = action_preds.reshape(-1, act_dim)[valid]
        # Mask the targets exactly once; every branch and the diagnostics reuse this.
        action_target_valid = action_target.reshape(-1, act_dim)[valid]

        if self.model_type in ['dt', 'bc']:
            if is_stitch:
                q_loss = self._stitch_action_loss(
                    images, missions, mission_masks, actions, rewards, rtg, timesteps,
                    attention_mask, return_preds, action_target,
                )
                mmd_loss = self._mmd_loss(return_preds, rewards, dones, attention_mask)

                # The MMD loss has its own optimizer and is stepped first.
                self.mmd_optimizer.zero_grad()
                mmd_loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.get_mmd_parameters(), .25)
                self.mmd_optimizer.step()

                loss = q_loss
            else:
                loss = F.cross_entropy(action_preds, action_target_valid.max(-1)[1])
        else:  # 'mgdt'
            if self.config['sample_return'] == True:  # noqa: E712 - keep exact comparison semantics
                # Reparameterised sample: return_preds[0] + eps * exp(0.5 * return_preds[1]).
                eps = torch.randn_like(return_preds[1])
                return_preds = return_preds[0] + eps * torch.exp(0.5 * return_preds[1])
            return_preds = return_preds.reshape(-1, 1)[valid]
            return_target = rtg_target[:, :-1].reshape(-1, 1)[valid]
            reward_preds = reward_preds.reshape(-1, 1)[valid]
            reward_target = rewards_target.reshape(-1, 1)[valid]
            loss = F.cross_entropy(action_preds, action_target_valid.max(-1)[1]) \
                + torch.mean((return_preds - return_target) ** 2) \
                + torch.mean((reward_preds - reward_target) ** 2)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), .25)
        self.optimizer.step()

        with torch.no_grad():
            self.diagnostics['training/action_error'] = torch.mean((action_preds - action_target_valid) ** 2).detach().cpu().item()
            if is_stitch:
                self.diagnostics['training/mmd_loss'] = mmd_loss.detach().cpu().item()

        if is_stitch:
            return q_loss.detach().cpu().item(), mmd_loss.detach().cpu().item()
        return loss.detach().cpu().item()

