# %% Part 0 import package and Global Parameters
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.distributions import Normal
from torch.optim.lr_scheduler import CosineAnnealingLR
import wandb

import numpy as np
import random
import copy
import math
from loguru import logger
import itertools
import einops
from einops.layers.torch import Rearrange

from diffusion_predictor.Predictor_model import Diffusion
from diffusion_predictor.render_img import MuJoCoRenderer

from tqdm import tqdm

# Module-level constants. None of them is referenced in the code visible here.
# The names suggest log-sigma clamp bounds and a numerical-stability epsilon,
# probably left over from a Gaussian-policy implementation.
# TODO(review): confirm whether anything else imports them before removing.
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
epsilon = 1e-6

import time

# %% Part 1 Global Function Definition
def setup_seed(seed=1024):
    """Seed every random generator used by this project so runs are repeatable.

    Seeds PyTorch (CPU and all CUDA devices), NumPy and Python's ``random``.
    It also puts cuDNN in deterministic mode with benchmark autotuning disabled.

    Args:
        seed: Integer seed applied to every generator.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes convolution algorithms per input shape and may
    # choose a different kernel on each run. Disable it so the deterministic
    # flag above actually yields repeatable results.
    torch.backends.cudnn.benchmark = False
    logger.info(f"Seed {seed} has been set for all modules!")


# Initialize Policy weights
def weights_init_(m):
    """Initialise ``nn.Linear`` layers in place; leave every other module untouched.

    Weights get Xavier-uniform initialisation (gain 1). Biases are zeroed when
    the layer has one. The function takes a single module, so it can be passed
    to ``module.apply``.

    Args:
        m: Any ``nn.Module``.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight, gain=1)
        # nn.Linear(..., bias=False) stores ``bias = None``; constant_ would raise on it.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)


def soft_update(target, source, tau):
    """Blend ``source`` into ``target`` in place: ``target = (1 - tau) * target + tau * source``.

    Parameters are paired positionally. Only ``target`` is modified.
    """
    keep = 1.0 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = dst.data * keep + src.data * tau
        dst.data.copy_(blended)


def hard_update(target, source):
    """Copy every parameter of ``source`` into the matching parameter of ``target``.

    Parameters are paired positionally. Only ``target`` is modified.
    """
    paired = zip(target.parameters(), source.parameters())
    for dst, src in paired:
        dst.data.copy_(src.data)


# %% Part 2 Network Definition
class EMA:
    """Exponential moving average of model parameters with decay ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        """Move each parameter of ``ma_model`` towards the matching one of ``current_model``.

        Parameters are paired positionally and buffers are not touched.
        ``current_model`` is only read.
        """
        for live, averaged in zip(current_model.parameters(), ma_model.parameters()):
            averaged.data = self.update_average(averaged.data, live.data)

    def update_average(self, old, new):
        """Return ``old * beta + (1 - beta) * new``, or ``new`` itself when ``old`` is None."""
        return new if old is None else old * self.beta + (1 - self.beta) * new


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of positions / diffusion steps.

    The output has ``2 * (dim // 2)`` features: the sines of every frequency,
    followed by the cosines.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        # Geometric frequency ladder running from 1 down to 1/10000.
        step = math.log(10000) / (half - 1)
        freqs = torch.exp(-step * torch.arange(half, device=x.device))
        angles = x.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)


class Downsample1d(nn.Module):
    """Halve the sequence length with a strided convolution (kernel 3, stride 2, padding 1).

    An input of length L comes out with length ``(L - 1) // 2 + 1``; the channel count is unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv1d(dim, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        out = self.conv(x)
        return out


class Upsample1d(nn.Module):
    """Double the sequence length with a transposed convolution (kernel 4, stride 2, padding 1).

    An input of length L comes out with length ``2 * L``; the channel count is unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.ConvTranspose1d(dim, dim, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        out = self.conv(x)
        return out


class Conv1dBlock(nn.Module):
    '''
        Conv1d --> GroupNorm --> Mish on a (batch, channels, horizon) tensor.

        Padding of ``kernel_size // 2`` keeps the horizon unchanged for odd kernel sizes.
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        same_pad = kernel_size // 2
        layers = [
            nn.Conv1d(inp_channels, out_channels, kernel_size, padding=same_pad),
            # GroupNorm runs on a temporary 4-D view; the extra axis is dropped again right after.
            Rearrange('batch channels horizon -> batch channels 1 horizon'),
            nn.GroupNorm(n_groups, out_channels),
            Rearrange('batch channels 1 horizon -> batch channels horizon'),
            nn.Mish(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out


class Residual(nn.Module):
    """Wrap ``fn`` with an identity skip: ``forward(x, ...) = fn(x, ...) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        transformed = self.fn(x, *args, **kwargs)
        return transformed + x


class LayerNorm(nn.Module):
    """Normalise over the channel axis (dim 1), then apply a learnable per-channel scale ``g`` and shift ``b``.

    ``g`` and ``b`` are shaped ``(1, dim, 1)``, i.e. suited to ``(batch, dim, length)`` inputs.
    The variance is the biased (population) variance.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1))

    def forward(self, x):
        mu = x.mean(dim=1, keepdim=True)
        sigma_sq = x.var(dim=1, unbiased=False, keepdim=True)
        normed = (x - mu) / torch.sqrt(sigma_sq + self.eps)
        return normed * self.g + self.b


class PreNorm(nn.Module):
    """Apply the channel-wise ``LayerNorm`` first, then the wrapped module ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))


class LinearAttention(nn.Module):
    """Multi-head linear attention over the last (sequence) axis of a ``(batch, dim, n)`` tensor.

    Keys are softmax-normalised over the sequence and combined with the values
    into a per-head ``dim_head x dim_head`` context. That context is then applied
    to the scaled queries, so the cost grows linearly with sequence length.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head
        self.to_qkv = nn.Conv1d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = nn.Conv1d(inner_dim, dim, 1)

    def forward(self, x):
        def split_heads(tensor):
            return einops.rearrange(tensor, 'b (h c) d -> b h c d', h=self.heads)

        q, k, v = (split_heads(chunk) for chunk in self.to_qkv(x).chunk(3, dim=1))
        q = q * self.scale
        k = k.softmax(dim=-1)

        # Per-head summary of the values, weighted by the key distribution.
        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        attended = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        merged = einops.rearrange(attended, 'b h c d -> b (h c) d')
        return self.to_out(merged)


class MLP(nn.Module):
    """Fully connected predictor of a ``state_dim`` vector.

    The prediction is conditioned on time, action, state and a noisy state.

    Args:
        state_dim: Size of the state vectors (both ``state`` and ``noise_state``) and of the output.
        action_dim: Size of the action vector.
        device: Stored on the instance; not used by this module itself.
        t_dim: Sinusoidal time-embedding width; the time branch outputs ``2 * t_dim`` features.
        embed_dim: Output width of the state and action encoders.
    """

    def __init__(self, state_dim, action_dim, device, t_dim=32, embed_dim=64):
        super(MLP, self).__init__()
        self.device = device
        self.t_dim = t_dim
        self.embed_dim = embed_dim
        hidden = 256

        def encoder(in_features):
            # Linear -> Mish -> Linear, projecting ``in_features`` to ``embed_dim``.
            return nn.Sequential(
                nn.Linear(in_features, embed_dim),
                nn.Mish(),
                nn.Linear(embed_dim, embed_dim),
            )

        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(t_dim),
            nn.Linear(t_dim, t_dim * 2),
            nn.Mish(),
            nn.Linear(t_dim * 2, 2 * t_dim),
        )
        self.state_encoder = encoder(state_dim)
        self.action_encoder = encoder(action_dim)

        # Time features + action encoding + state encoding + noisy-state encoding.
        fused_dim = 2 * t_dim + 3 * embed_dim
        trunk = [nn.Linear(fused_dim, hidden), nn.Mish()]
        for _ in range(2):
            trunk += [nn.Linear(hidden, hidden), nn.Mish()]
        self.mid_layer = nn.Sequential(*trunk)
        self.dropout = nn.Dropout(0.1)
        self.final_layer = nn.Linear(hidden, state_dim)

    def forward(self, noise_state, time, action, state):
        parts = (
            self.time_mlp(time),
            self.action_encoder(action),
            self.state_encoder(state),
            # The clean and noisy states share one encoder.
            self.state_encoder(noise_state),
        )
        features = self.mid_layer(torch.cat(parts, dim=1))
        return self.final_layer(self.dropout(features))


class ResidualTemporalBlock(nn.Module):
    """Two ``Conv1dBlock``s with a time-embedding bias in between, plus a residual skip.

    The skip is a 1x1 convolution when the channel count changes, otherwise the identity.
    ``horizon`` is accepted for interface compatibility but not used.
    """

    def __init__(self, inp_channels, out_channels, embed_dim, horizon, kernel_size=5):
        super().__init__()

        first = Conv1dBlock(inp_channels, out_channels, kernel_size)
        second = Conv1dBlock(out_channels, out_channels, kernel_size)
        self.blocks = nn.ModuleList([first, second])

        # Map the time embedding to one bias per output channel, broadcast over the horizon.
        self.time_mlp = nn.Sequential(
            nn.Mish(),
            nn.Linear(embed_dim, out_channels),
            Rearrange('batch t -> batch t 1'),
        )

        if inp_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1)

    def forward(self, x, t):
        '''
            x : [ batch_size x inp_channels x horizon ]
            t : [ batch_size x embed_dim ]
            returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        conv_in, conv_out = self.blocks
        hidden = conv_in(x) + self.time_mlp(t)
        hidden = conv_out(hidden)
        return hidden + self.residual_conv(x)


class TemporalUnet(nn.Module):
    """1-D temporal U-Net followed by an MLP head that predicts a window of states.

    Each input step is encoded, the sequence goes through a Conv1d U-Net
    conditioned on the diffusion step, and the resulting feature map is
    flattened and concatenated with the time embedding. An MLP then maps this to
    ``2 * cond_dim + 1`` steps of ``state_dim`` features.

    Args:
        state_dim: Per-step state size.
        action_dim: Action size. It is only used to build ``action_encoder``,
            which the forward pass does not call.
        device: Stored on the instance; not used for placement here.
        cond_dim: Condition length. The MLP head is sized for inputs of
            ``2 * cond_dim + 1`` steps, the same length the output is reshaped to.
        embed_dim: Base channel width and time-embedding size.
        dim_mults: Channel multipliers (relative to ``embed_dim``) of the down path.
        attention: If True, residual linear attention is inserted at every resolution.
    """

    def __init__(self, state_dim, action_dim, device, cond_dim=8,
                 embed_dim=256, dim_mults=(2, 4), attention=False):
        super(TemporalUnet, self).__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = device
        self.cond_dim = cond_dim
        # ``horizon`` / ``horizon_history`` start from cond_dim and only decide
        # where Upsample1d layers are placed. They are kept unchanged so the
        # module layout (and saved checkpoints) stay the same.
        horizon = self.cond_dim
        self.embed_dim = embed_dim
        # Actual temporal length of the feature maps, starting from the
        # 2 * cond_dim + 1 input steps. Used to size the MLP head.
        seq_len = 2 * cond_dim + 1

        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, 2 * embed_dim),
            nn.Mish(),
            nn.Linear(2 * embed_dim, embed_dim // 2)
        )

        self.action_encoder = nn.Sequential(
            nn.Linear(action_dim, 2 * embed_dim),
            nn.Mish(),
            nn.Linear(2 * embed_dim, embed_dim // 2)
        )

        dims = [embed_dim // 2, *map(lambda m: embed_dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        logger.info(f'Models Channel dimensions: {in_out}')

        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(embed_dim),
            nn.Linear(embed_dim, embed_dim * 2),
            nn.Mish(),
            nn.Linear(embed_dim * 2, embed_dim),
        )

        time_dim = embed_dim
        horizon_history = []
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            horizon_history.append(horizon)
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ResidualTemporalBlock(dim_in, dim_out, kernel_size=3, embed_dim=time_dim, horizon=horizon),
                ResidualTemporalBlock(dim_out, dim_out, kernel_size=3, embed_dim=time_dim, horizon=horizon),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))) if attention else nn.Identity(),
                Downsample1d(dim_out) if not is_last else nn.Identity()
            ]))

            if not is_last:
                horizon = horizon // 2
                # Conv1d(kernel 3, stride 2, padding 1): L -> (L - 1) // 2 + 1
                seq_len = (seq_len - 1) // 2 + 1

        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)
        self.mid_attn = Residual(PreNorm(mid_dim, LinearAttention(mid_dim))) if attention else nn.Identity()
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim=time_dim, horizon=horizon)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)
            upsample = not is_last and horizon_history[-(ind + 1)] != horizon_history[-(ind + 2)]

            self.ups.append(nn.ModuleList([
                ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim=time_dim, horizon=horizon),
                ResidualTemporalBlock(dim_in, dim_in, embed_dim=time_dim, horizon=horizon),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))) if attention else nn.Identity(),
                Upsample1d(dim_in) if upsample else nn.Identity()
            ]))

            if upsample:
                # ConvTranspose1d(kernel 4, stride 2, padding 1): L -> 2 * L
                seq_len *= 2
            if not is_last:
                horizon = horizon_history[-(ind + 2)]

        # Channels entering the final conv: the output of the last up block, or
        # the mid blocks when there is no up path. This is dims[1] in both cases
        # (== 2 * embed_dim for the default dim_mults).
        final_dim = dims[1]
        head_channels = embed_dim // 4
        self.final_conv = nn.Sequential(
            Conv1dBlock(final_dim, final_dim, kernel_size=3),
            nn.Conv1d(final_dim, head_channels, 1),
        )

        # Flattened conv features plus the time embedding. For cond_dim=5 and
        # the default dim_mults (11 -> 6 -> 12 steps) this equals the previously
        # hard-coded 8 * embed_dim // 2; other condition lengths now work too.
        flat_dim = head_channels * seq_len + time_dim

        self.mid_layer = nn.Sequential(nn.Linear(flat_dim, 512),
                                       nn.Mish(),
                                       nn.Linear(512, 512),
                                       nn.Mish(),
                                       nn.Linear(512, 512),
                                       nn.Mish())

        self.final_layer = torch.nn.Linear(512, self.state_dim * (2 * self.cond_dim + 1))

    def forward(self, x, time, action, state_condition, mask=None):
        '''
            x : [ batch x horizon x state_dim ]
            time : [ batch ] diffusion step indices
            action, state_condition, mask : accepted for interface compatibility; unused here.
            returns : [ batch x (2 * cond_dim + 1) x state_dim ]

            In short, the U-Net can be read as "a learned attn(x, t)" followed by an MLP head.
        '''
        batch_size = x.shape[0]

        # Per-step encoding, then channels-first for the Conv1d stack.
        x = self.state_encoder(x)
        x = einops.rearrange(x, 'b h t -> b t h')
        t = self.time_mlp(time)

        h = []
        for resnet, resnet2, attn, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for resnet, resnet2, attn, upsample in self.ups:
            # Skip connection: concatenate the most recently stored down-path feature map.
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = attn(x)
            x = upsample(x)

        x = self.final_conv(x)
        x = einops.rearrange(x, 'b t h -> b h t')

        # Flatten the conv features and append the time embedding for the MLP head.
        info = x.reshape(batch_size, -1)
        output = self.mid_layer(torch.cat([info, t], dim=1))
        output = self.final_layer(output)
        output = output.reshape(output.shape[0], 2 * self.cond_dim + 1, self.state_dim)

        return output

class Diffusion_instance(object):
    """Training / inference wrapper around a ``Diffusion`` predictor with a ``TemporalUnet`` backbone.

    Holds the online ``predictor``, its Adam optimizer and an exponential
    moving-average copy ``ema_model``. ``train`` updates ``predictor``;
    ``denoise_state`` calls ``ema_model``.
    """

    def __init__(self, state_dim, action_dim, device, config, df_mod="naive", log_writer=False):
        # ``log_writer`` is accepted for interface compatibility but not used.
        self.model = TemporalUnet(state_dim=state_dim, action_dim=action_dim, device=device,
                                  cond_dim=config['condition_length'], embed_dim=config['embed_dim']).to(device)
        self.predictor = Diffusion(config=config, state_dim=state_dim, action_dim=action_dim, model=self.model, df_mod=df_mod,
                                   beta_schedule=config['beta_schedule'], beta_mode=config["beta_training_mode"],
                                   n_timesteps=config['T'], predict_epsilon=config['predict_epsilon']).to(device)
        self.predictor_optimizer = torch.optim.Adam(self.predictor.parameters(), lr=config['lr'])
        self.lr_decay = config['lr_decay']
        self.grad_norm = config['gn']
        self.n_timestep = config['T']
        self.step = 0
        self.step_start_ema = config['step_start_ema']
        self.ema = EMA(config['ema_decay'])
        self.ema_model = copy.deepcopy(self.predictor)
        self.update_ema_every = config['update_ema_every']
        if self.lr_decay:
            # NOTE(review): this scheduler is created but never stepped in ``train``.
            self.predictor_lr_scheduler = CosineAnnealingLR(self.predictor_optimizer,
                                                            T_max=config['max_timestep'], eta_min=0.)
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.discount = config['gamma']
        self.tau = config['tau']
        self.eta = config['eta']  # q_learning weight
        self.device = device
        self.max_q_backup = config['max_q_backup']
        self.NonM_step = config['non_markovian_step']
        self.condition_step = config['condition_length']
        self.buffer_sample_length = self.NonM_step + self.condition_step
        self.T_scheme = config['T-scheme']
        self.config = config

    def step_ema(self):
        """Fold the predictor's parameters into ``ema_model``, once ``step`` reaches ``step_start_ema``."""
        if self.step < self.step_start_ema:
            return
        self.ema.update_model_average(self.ema_model, self.predictor)

    def filt(self, sn, mask, topk):
        """Flag positions by the magnitude of the noise the backbone predicts for them.

        The backbone is evaluated without gradients at the fixed diffusion step
        ``config["detect_denoise_steps"]``. Each position's score is the mean
        squared model output over the last axis. Positions whose score is
        strictly below the ``(1 - topk)`` quantile of all scores in the batch
        are flagged.

        Returns:
            Boolean tensor shaped like the model output without its last axis.
        """
        detect_denoise_step = self.config["detect_denoise_steps"]
        with torch.no_grad():
            noise_p = self.predictor.model(
                sn,
                detect_denoise_step * torch.ones((sn.shape[0],), device=self.device).long(),
                None,
                None,
                mask)
            all_mos = (noise_p ** 2).mean(-1)
            ther = torch.quantile(all_mos, 1 - topk)
        # E.g. topk=0.1 flags roughly the lowest-scoring 90% of positions;
        # topk=1.0 makes the threshold the minimum, so nothing is flagged.
        detected = all_mos < ther
        return detected

    def train(self, replay_buffer, iterations, batch_size, pretrained_detector=None, topk=1.0):
        """Run ``iterations`` optimisation steps of the predictor on replay-buffer batches.

        For each batch, ``filt`` (of ``pretrained_detector`` if given, else of
        this instance) builds a mask, which is passed to ``predictor.loss``. The
        gradient is optionally norm-clipped (``config['gn'] > 0``), and the EMA
        model is refreshed every ``update_ema_every`` steps.

        Returns:
            None.
        """
        detector = self if pretrained_detector is None else pretrained_detector
        for _ in tqdm(range(iterations)):
            _s, sn, _a, _r, _d, _rtg, _timesteps, mask = replay_buffer.sample(batch_size, self.condition_step)
            filtered_mask = detector.filt(sn, mask, topk).unsqueeze(-1)
            total_loss = self.predictor.loss(sn, None, None, filtered_mask, None, weights=1.0)
            self.predictor_optimizer.zero_grad()
            total_loss.backward()
            if self.grad_norm > 0:
                nn.utils.clip_grad_norm_(self.predictor.parameters(), max_norm=self.grad_norm, norm_type=2)
            self.predictor_optimizer.step()
            if self.step % self.update_ema_every == 0:
                self.step_ema()
            self.step += 1
        return None

    def denoise_state(self, noise_next_state, timestep):
        """Denoise ``noise_next_state`` with the EMA model and average ``config["repeat_times"]`` runs.

        Args:
            noise_next_state: Batch of noisy windows with the batch on the leading axis.
                With more than one sample it must reshape to ``(batch, 2 * condition_step + 1, state_dim)``.
            timestep: Diffusion step passed to the EMA model. When ``>= 1``, the
                input is first passed through ``ema_model.recover(..., timestep - 1)``.

        Returns:
            Mean over the repeated runs, shaped ``(batch, 2 * condition_step + 1, state_dim)``.
        """
        repeat_times = self.config["repeat_times"]
        input_dim = noise_next_state.shape[0]
        horizon = self.condition_step * 2 + 1

        if timestep >= 1:
            noise_next_state = self.ema_model.recover(noise_next_state, timestep - 1)

        if input_dim == 1:
            # Single sample: stack ``repeat_times`` copies along the batch axis.
            repeated = torch.repeat_interleave(noise_next_state, repeats=repeat_times, dim=0)
        else:
            # Repeat-major layout (all of copy 0, then all of copy 1, ...), so the
            # reshape below can separate repeats from the batch again.
            repeated = torch.repeat_interleave(
                noise_next_state.unsqueeze(0), repeats=repeat_times, dim=0
            ).reshape(repeat_times * input_dim, horizon, self.state_dim)

        with torch.no_grad():
            return_state = self.ema_model(
                repeated,
                None,
                None,
                timestep).reshape(repeat_times, input_dim, horizon, self.state_dim)
        return torch.mean(return_state, dim=0)

    def save_model(self, file_name):
        """Save predictor, EMA model and optimizer state dicts to ``file_name``."""
        logger.info('Saving models to {}'.format(file_name))
        torch.save({'actor_state_dict': self.predictor.state_dict(),
                    'ema_state_dict': self.ema_model.state_dict(),
                    'actor_optimizer_state_dict': self.predictor_optimizer.state_dict()}, file_name)

    def save_checkpoint(self, file_name):
        """Save only the EMA model's state dict to ``file_name``."""
        logger.info('Saving Checkpoint model to {}'.format(file_name))
        torch.save({'ema_state_dict': self.ema_model.state_dict()}, file_name)

    @staticmethod
    def _map_location(device_idx):
        """Return the ``torch.load`` map_location: ``cuda:<device_idx>`` when CUDA is available, else CPU."""
        return f'cuda:{device_idx}' if torch.cuda.is_available() else 'cpu'

    def load_model(self, file_name, device_idx=0):
        """Restore predictor, EMA model and optimizer from a ``save_model`` file; no-op for None."""
        if file_name is None:
            return
        logger.info(f'Loading models from {file_name}')
        checkpoint = torch.load(file_name, map_location=self._map_location(device_idx))
        self.predictor.load_state_dict(checkpoint['actor_state_dict'])
        self.ema_model.load_state_dict(checkpoint['ema_state_dict'])
        self.predictor_optimizer.load_state_dict(checkpoint['actor_optimizer_state_dict'])

    def load_checkpoint(self, file_name, device_idx=0):
        """Load EMA weights from a ``save_checkpoint`` file and copy them into the predictor; no-op for None."""
        if file_name is not None:
            checkpoint = torch.load(file_name, map_location=self._map_location(device_idx))
            self.ema_model.load_state_dict(checkpoint['ema_state_dict'])
            # Copy in place instead of replacing ``self.predictor``. This way
            # ``predictor_optimizer`` and ``self.model`` keep referencing the live
            # parameters, and later training still updates the predictor.
            self.predictor.load_state_dict(self.ema_model.state_dict())