import argparse
import ruamel_yaml as yaml
from pathlib import Path
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import StudentT
from torch import einsum
from einops import repeat, rearrange
import warnings
warnings.filterwarnings('ignore')


def custom_nll_loss(y_hat, y):
    """Return the mean negative log-likelihood of ``y`` under a Student's t distribution.

    Args:
        y_hat: tensor whose last axis holds the distribution parameters in the
            order (loc, scale, df); only indices 0, 1 and 2 of that axis are read.
        y: targets, broadcastable against ``y_hat[..., 0]``.

    Returns:
        Scalar tensor ``-mean(log p(y))``.

    Reduced-precision parameters (float16 and bfloat16) are upcast to float32
    before the distribution is built, so the log / log-gamma terms of the
    density are not evaluated in half precision. Targets are always cast to
    float32.
    """
    if y_hat.dtype in (torch.float16, torch.bfloat16):
        y_hat = y_hat.to(dtype=torch.float32)
    y = y.to(dtype=torch.float32)

    mu, sigma, nu = y_hat[..., 0], y_hat[..., 1], y_hat[..., 2]

    dist = StudentT(df=nu, loc=mu, scale=sigma)
    log_likelihood = dist.log_prob(y)

    return -torch.mean(log_likelihood)


def exists(val):
    """Return True when ``val`` is anything other than ``None`` (falsy values count as existing)."""
    return not (val is None)


def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if exists(val):
        return val
    return d


class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with a learnable gain and a fixed zero bias.

    ``gamma`` is a trainable parameter; ``beta`` is registered as a buffer,
    so it is stored in the state dict but never updated by the optimiser.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        normalized_shape = x.shape[-1:]
        return F.layer_norm(x, normalized_shape, weight=self.gamma, bias=self.beta)
    

class Residual(nn.Module):
    """Wrap ``fn`` with a skip connection: the output is ``fn(x, *args, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        return branch + x


class EmbedToLatents(nn.Module):
    """Bias-free linear projection to ``dim_latents`` followed by L2 normalisation of the last axis."""

    def __init__(self, dim, dim_latents):
        super().__init__()
        self.to_latents = nn.Linear(dim, dim_latents, bias=False)

    def forward(self, x):
        projected = self.to_latents(x)
        return F.normalize(projected, p=2.0, dim=-1)


class RotaryEmbedding(nn.Module):
    """Table of rotary position-embedding angles.

    ``forward(n, device=...)`` returns a tensor whose row ``p`` is
    ``p * inv_freq`` concatenated with itself, with
    ``inv_freq[k] = 10000 ** (-2k / dim)``.
    """

    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def forward(self, max_seq_len, *, device):
        positions = torch.arange(max_seq_len, device=device, dtype=self.inv_freq.dtype)
        # outer product: angles[p, k] = positions[p] * inv_freq[k]
        angles = positions.unsqueeze(-1) * self.inv_freq
        return torch.cat([angles, angles], dim=-1)


def rotate_half(x):
    """Split the last axis into halves (x1, x2) and return ``cat(-x2, x1)`` along that axis."""
    halves = rearrange(x, "... (j d) -> ... j d", j=2)
    first, second = halves[..., 0, :], halves[..., 1, :]
    return torch.cat((-second, first), dim=-1)


def apply_rotary_pos_emb(pos, t):
    """Rotate ``t`` by the angles ``pos``: ``t * cos(pos) + rotate_half(t) * sin(pos)``."""
    cos_term = t * pos.cos()
    sin_term = rotate_half(t) * pos.sin()
    return cos_term + sin_term


class SwiGLU(nn.Module):
    """Gated linear unit: split the last axis into (value, gate) and return ``value * silu(gate)``."""

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        return value * F.silu(gate)


class ParallelTransformerBlock(nn.Module):
    """Pre-norm transformer block whose attention and feed-forward run in parallel.

    One fused projection of the normalised input yields multi-head queries,
    a single shared key head, a single shared value head and the SwiGLU
    feed-forward input. Attention is causal, with rotary position embeddings
    on queries and keys. The block returns ``attention_out + feedforward_out``;
    no residual is added here.
    """

    def __init__(self, dim, dim_head=64, heads=8, ff_mult=4, dropout_rate=None):
        super().__init__()
        self.norm = LayerNorm(dim)

        attn_inner_dim = heads * dim_head
        ff_inner_dim = ff_mult * dim
        # split sizes of the fused projection: queries (all heads), one key
        # head, one value head, and the doubled SwiGLU input
        self.fused_dims = (attn_inner_dim, dim_head, dim_head, 2 * ff_inner_dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.dropout_rate = dropout_rate
        if dropout_rate is not None:
            self.dropout_layer_attn = nn.Dropout(dropout_rate)
            self.dropout_layer_ffn = nn.Dropout(dropout_rate)

        self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
        self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
        self.ff_out = nn.Sequential(SwiGLU(), nn.Linear(ff_inner_dim, dim, bias=False))

        # lazily filled caches (plain attributes, not registered buffers)
        self.mask = None
        self.pos_emb = None

    def get_mask(self, n, device):
        """Return an (n, n) boolean mask that is True strictly above the diagonal (future keys)."""
        cached = self.mask
        if cached is None or cached.shape[-1] < n:
            cached = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
            self.mask = cached
            return cached
        return cached[:n, :n].to(device)

    def get_rotary_embedding(self, n, device):
        """Return the rotary angle table for the first ``n`` positions, reusing a cached larger table."""
        cached = self.pos_emb
        if cached is None or cached.shape[-2] < n:
            cached = self.rotary_emb(n, device=device)
            self.pos_emb = cached
            return cached
        return cached[:n].to(device)

    def forward(self, x, attn_mask=None):
        """Run causal parallel attention + feed-forward.

        Args:
            x: tokens of shape (batch, n, dim) — rank fixed by the rearranges below.
            attn_mask: optional (batch, n, n) mask, presumably boolean, where
                True means "may attend"; positions with False are filled with
                the most negative finite value before the softmax.
        """
        h = self.heads
        seq_len, device = x.shape[1], x.device

        normed = self.norm(x)

        # one matmul produces queries, the shared key/value and the feed-forward input
        q, k, v, ff = self.fused_attn_ff_proj(normed).split(self.fused_dims, dim=-1)
        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # rotary position embeddings on queries and (shared) keys
        rotary = self.get_rotary_embedding(seq_len, device)
        q = apply_rotary_pos_emb(rotary, q)
        k = apply_rotary_pos_emb(rotary, k)

        q = q * self.scale
        sim = einsum("b h i d, b j d -> b h i j", q, k)

        fill_value = -torch.finfo(sim.dtype).max
        sim = sim.masked_fill(self.get_mask(seq_len, device), fill_value)

        if exists(attn_mask):
            keep = rearrange(attn_mask, 'b i j -> b 1 i j')
            sim = sim.masked_fill(~keep, fill_value)

        # numerically stable softmax: shift by the (detached) row maximum first
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)
        if self.dropout_rate is not None:
            attn = self.dropout_layer_attn(attn)

        out = einsum("b h i j, b j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")

        attn_branch = self.attn_out(out)
        ff_branch = self.ff_out(ff)
        if self.dropout_rate is not None:
            ff_branch = self.dropout_layer_ffn(ff_branch)
        return attn_branch + ff_branch


class CrossAttention(nn.Module):
    """Cross-attention from query tokens ``x`` onto ``context`` tokens.

    Queries use ``heads`` heads of size ``dim_head``. Keys and values are a
    single shared head projected from ``context``. With ``parallel_ff``, a
    SwiGLU feed-forward applied to the normalised queries is added to the
    attention output. No residual is added here.
    """

    def __init__(
        self,
        dim,
        *,
        context_dim=None,
        dim_head=64,
        heads=8,
        parallel_ff=False,
        ff_mult=4,
        dropout_rate=None,
        norm_context=False,
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(context_dim) if norm_context else nn.Identity()

        self.dropout_rate = dropout_rate
        if self.dropout_rate is not None:
            self.dropout_layer_attn = nn.Dropout(self.dropout_rate)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # a single key head and a single value head, shared by all query heads
        self.to_kv = nn.Linear(context_dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # optional feed-forward computed in parallel with the attention
        ff = None
        if parallel_ff:
            ff_inner_dim = ff_mult * dim
            ff_layers = [
                nn.Linear(dim, ff_inner_dim * 2, bias=False),
                SwiGLU(),
                nn.Linear(ff_inner_dim, dim, bias=False),
            ]
            if self.dropout_rate is not None:
                ff_layers.append(nn.Dropout(self.dropout_rate))
            ff = nn.Sequential(*ff_layers)
        self.ff = ff

        self.mask = None

    def get_mask(self, n, device):
        """Return an (n, n) boolean upper-triangular (future) mask, cached across calls.

        Not used by ``forward``; kept for callers that rely on it.
        """
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n].to(device)

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.mask = mask
        return mask

    def forward(self, x, context, attn_mask=None):
        """Attend from ``x`` (batch, n_queries, dim) to ``context`` (batch, n_context, context_dim).

        Args:
            x: query tokens.
            context: tokens to attend to.
            attn_mask: optional mask where True / non-zero means "may attend".
                Either a key-padding mask of shape (batch, n_context), e.g. a
                tokenizer ``attention_mask``, or a full mask of shape
                (batch, n_queries, n_context). Masked positions are filled with
                the most negative finite value before the softmax.

        Returns:
            Tensor of shape (batch, n_queries, dim).
        """
        x = self.norm(x)
        context = self.context_norm(context)

        # queries: (batch, heads, n_queries, dim_head)
        q = self.to_q(x)
        q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
        q = q * self.scale

        # shared key / value: (batch, n_context, dim_head) each
        k, v = self.to_kv(context).chunk(2, dim=-1)

        sim = einsum('b h i d, b j d -> b h i j', q, k)

        if exists(attn_mask):
            # accept bool or 0/1 integer masks
            keep = attn_mask.bool()
            if keep.ndim == 2:
                keep = rearrange(keep, 'b j -> b 1 1 j')
            else:
                keep = rearrange(keep, 'b i j -> b 1 i j')
            sim = sim.masked_fill(~keep, -torch.finfo(sim.dtype).max)

        # numerically stable softmax (softmax is shift-invariant, so the max needs no gradient)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)
        if self.dropout_rate is not None:
            attn = self.dropout_layer_attn(attn)

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        # merge heads and project back to dim
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        # add the parallel feed-forward on the normalised queries, if configured
        if exists(self.ff):
            out = out + self.ff(x)

        return out


class DualForecaster(nn.Module):
    """Probabilistic time-series forecaster with optional text augmentation.

    The series is cut into non-overlapping patches of ``ts_token_size`` steps.
    Patches are embedded, a CLS token is appended, and the sequence runs
    through causal unimodal transformer layers.

    With ``textAug``, tokenized history texts are encoded and attention-pooled
    into ``num_text_queries`` tokens plus one CLS embedding. With
    ``addfuture``, future texts are encoded and pooled the same way. The
    pooled tokens are fused into the series tokens by cross-attention.

    The last series token is projected to Student's t parameters
    (loc, scale, df) for each of ``ts_output_size`` outputs. With
    ``textAug``, the series CLS and history-text CLS embeddings are also
    aligned with a symmetric cross-entropy contrastive loss.
    """

    def __init__(
        self,
        *,
        dim,
        ts_token_size,
        ts_output_size,
        unimodal_depth,
        multimodal_depth,
        dim_latents=None,
        text_dim=None,
        num_text_queries=None,
        dim_head=None,
        heads=None,
        ff_mult=None,
        dropout_rate_fcst=None,
        dropout_rate_cons=None,
        textAug=True,
        addfuture=True,
        text_encoder=None,
        text_encoder_type=None,
        text_encoder_frozen_flag=False,
        forecast_loss_weight=None,
        contrastive_loss_weight=None,
        pad_id=None,
        num_vars=None,
        temperature=None,
    ):
        super(DualForecaster, self).__init__()
        # None means "use the sub-block defaults" (ParallelTransformerBlock:
        # dim_head=64, heads=8, ff_mult=4). Forwarding None explicitly would
        # crash there (e.g. ``dim * None``).
        dim_head = default(dim_head, 64)
        heads = default(heads, 8)
        ff_mult = default(ff_mult, 4)
        # unweighted losses when no weight is given (``loss * None`` would raise)
        forecast_loss_weight = default(forecast_loss_weight, 1.0)
        contrastive_loss_weight = default(contrastive_loss_weight, 1.0)

        self.dim = dim
        self.num_vars = num_vars
        self.ts_token_size = ts_token_size
        self.ts_output_size = ts_output_size
        self.unimodal_depth = unimodal_depth
        self.multimodal_depth = multimodal_depth
        self.dim_latents = dim_latents
        self.text_dim = text_dim
        self.num_text_queries = num_text_queries
        self.dim_head = dim_head
        self.heads = heads
        self.ff_mult = ff_mult

        self.dropout_rate_fcst = dropout_rate_fcst
        self.dropout_rate_cons = dropout_rate_cons

        self.pad_id = pad_id
        self.forecast_loss_weight = forecast_loss_weight
        self.contrastive_loss_weight = contrastive_loss_weight

        # NOTE: submodules are created in a fixed order; changing it changes
        # the random initialisation under a fixed seed.

        # time series token embeddings
        self.ts_token_emb = nn.Linear(self.ts_token_size, self.dim)
        self.ts_cls_token = nn.Parameter(torch.randn(self.dim))

        self.textAug = textAug
        self.addfuture = addfuture
        if self.textAug:

            # text encoder
            self.text_encoder = text_encoder
            self.text_encoder_type = text_encoder_type
            self.text_encoder_frozen_flag = text_encoder_frozen_flag
            if self.text_encoder_frozen_flag:
                for param in self.text_encoder.parameters():  # frozen
                    param.requires_grad = False

            # learned queries for attention pooling of text tokens:
            # num_text_queries for the multimodal layers plus 1 CLS query for contrastive learning
            self.text_queries = nn.Parameter(torch.randn(self.num_text_queries + 1, self.dim))
            self.text_attn_pool = CrossAttention(dim=self.dim, context_dim=self.text_dim, dim_head=self.dim_head, heads=self.heads, norm_context=True, dropout_rate=self.dropout_rate_cons)
            self.text_attn_pool_norm = LayerNorm(self.dim)

            dim_latents = default(self.dim_latents, self.dim)

            if self.addfuture:
                # future-side cross-attention (series tokens attend to pooled future-text tokens)
                self.ts_text_tokens_f_attn = Residual(CrossAttention(dim=self.dim, context_dim=self.dim, dim_head=self.dim_head, heads=self.heads, parallel_ff=True, norm_context=True, dropout_rate=self.dropout_rate_fcst))

            # projections into the shared contrastive latent space
            self.ts_to_latents = EmbedToLatents(self.dim, dim_latents)
            self.text_to_latents = EmbedToLatents(self.dim, dim_latents)

        self.ts_cls_norm = LayerNorm(self.dim)

        # learnable contrastive temperature: logits are multiplied by exp(temperature).
        # 1.0 is used when no initial value is given (torch.Tensor([None]) would raise).
        self.temperature = nn.Parameter(torch.Tensor([default(temperature, 1.0)]))

        # unimodal layers
        self.unimodal_layers = nn.ModuleList([])
        for _ in range(self.unimodal_depth):
            self.unimodal_layers.append(
                Residual(ParallelTransformerBlock(dim=self.dim, dim_head=self.dim_head, heads=self.heads, ff_mult=self.ff_mult, dropout_rate=self.dropout_rate_fcst)),
            )

        # multimodal layers: self-attention + history-side cross-attention
        self.multimodal_layers = nn.ModuleList([])
        if self.textAug:
            for _ in range(self.multimodal_depth):
                self.multimodal_layers.append(nn.ModuleList([
                    Residual(ParallelTransformerBlock(dim=self.dim, dim_head=self.dim_head, heads=self.heads, ff_mult=self.ff_mult, dropout_rate=self.dropout_rate_fcst)),
                    Residual(CrossAttention(dim=self.dim, dim_head=self.dim_head, heads=self.heads, parallel_ff=True, ff_mult=self.ff_mult, dropout_rate=self.dropout_rate_fcst)),
                ]))

        # Student's t-distribution head: (loc, scale, df) per output step
        self.final_layer_prob = nn.Linear(self.dim, self.ts_output_size * 3)

    def embed_ts(self, ts):
        """Embed series patches and run them through the unimodal layers.

        Args:
            ts: patched series, (batch_size, num_ts_tokens, ts_token_size).

        Returns:
            ts_embeds: normalised CLS embedding, (batch_size, dim).
            ts_tokens: per-patch tokens without the CLS token, (batch_size, num_ts_tokens, dim).
        """
        batch = ts.shape[0]

        # patch embedding
        ts_tokens = self.ts_token_emb(ts)  # ts_tokens: (batch_size, num_ts_tokens, dim)

        # Key mask: a patch whose values sum to 0 is treated as padding (a real
        # patch summing to exactly 0 is masked too). The appended CLS position
        # is always kept, and the same key mask is repeated for every query row.
        cls_mask = rearrange(ts.sum(dim=-1) != 0, 'b j -> b 1 j')
        attn_mask = F.pad(cls_mask, (0, 1), value=True)
        attn_mask = repeat(attn_mask, 'b 1 j -> b i j', i=attn_mask.shape[2])

        # append ts cls tokens
        ts_cls_tokens = repeat(self.ts_cls_token, 'd -> b 1 d', b=batch)
        ts_tokens = torch.cat((ts_tokens, ts_cls_tokens), dim=-2)  # (batch_size, num_ts_tokens+1, dim)

        # go through unimodal layers
        for attn_ff in self.unimodal_layers:
            ts_tokens = attn_ff(ts_tokens, attn_mask=attn_mask)  # (batch_size, num_ts_tokens+1, dim)

        # split off the ts cls token
        ts_tokens, ts_cls_tokens = ts_tokens[:, :-1], ts_tokens[:, -1]
        ts_embeds = self.ts_cls_norm(ts_cls_tokens)

        return ts_embeds, ts_tokens

    def _encode_text(self, text):
        """Run ``self.text_encoder`` on a tokenized batch and return per-token embeddings."""
        assert exists(self.text_encoder), 'text_encoder must be passed in for automatic text encoding'
        if self.text_encoder_type == 'Roberta':
            encoded = self.text_encoder(text.input_ids, attention_mask=text.attention_mask, return_dict=True)
            return encoded.last_hidden_state  # (batch_size, num_text_tokens, text_dim)
        # any other encoder: only its input (token) embedding table is used
        return self.text_encoder.get_input_embeddings()(text.input_ids)

    def _pool_text(self, text_tokens):
        """Attention-pool text tokens into (batch_size, num_text_queries + 1, dim) normalised queries."""
        text_queries = repeat(self.text_queries, 'n d -> b n d', b=text_tokens.shape[0])
        pooled = self.text_attn_pool(text_queries, text_tokens)
        return self.text_attn_pool_norm(pooled)

    def embed_text(self, text_h=None, text_tokens_h=None, text_f=None, text_tokens_f=None):
        """Encode and attention-pool history texts (and future texts when ``addfuture``).

        Each text may be given raw (a tokenized batch with ``.input_ids`` /
        ``.attention_mask``, encoded with the text encoder passed at init) or
        as pre-encoded token embeddings (``text_tokens_*``), but not both.

        Returns:
            With addfuture: (cls_h, tokens_h, cls_f, tokens_f); otherwise (cls_h, tokens_h).
            ``cls_*`` is the pooled CLS query, (batch_size, dim). ``tokens_*`` are the
            remaining pooled queries, (batch_size, num_text_queries, dim).

        Raises:
            AssertionError: if a required text is missing, or given both raw and pre-encoded.
        """
        assert not (exists(text_h) and exists(text_tokens_h)), 'pass either text_h or text_tokens_h, not both'
        if self.addfuture:
            assert not (exists(text_f) and exists(text_tokens_f)), 'pass either text_f or text_tokens_f, not both'

        # encode raw texts (history first, then future)
        if exists(text_h):
            text_tokens_h = self._encode_text(text_h)
        if self.addfuture and exists(text_f):
            text_tokens_f = self._encode_text(text_f)

        assert exists(text_tokens_h), 'history text (raw or pre-encoded) is required'
        text_queries_h = self._pool_text(text_tokens_h)

        if not self.addfuture:
            return text_queries_h[:, 0], text_queries_h[:, 1:]

        assert exists(text_tokens_f), 'future text (raw or pre-encoded) is required when addfuture is set'
        text_queries_f = self._pool_text(text_tokens_f)

        return text_queries_h[:, 0], text_queries_h[:, 1:], text_queries_f[:, 0], text_queries_f[:, 1:]

    def forward(
        self,
        ts,
        text_h=None,
        text_f=None,
        text_tokens_h=None,
        text_tokens_f=None,
        y=None,
        return_loss=False,
        return_embeddings=False
    ):
        """Forecast and, when targets are given, compute the losses.

        Args:
            ts: input series. The permute/unfold/squeeze below assume
                (batch_size, seq_len, 1); with more than one channel the squeeze
                is a no-op and embed_ts will not receive a rank-3 tensor
                (TODO confirm multivariate support).
            text_h / text_f: tokenized history / future texts (objects with
                ``.to``, ``.input_ids`` and ``.attention_mask``). They are used
                when ``textAug`` is set, unless the matching ``text_tokens_*``
                are given.
            text_tokens_h / text_tokens_f: pre-encoded text token embeddings,
                used as given (not moved to the model device).
            y: targets (batch_size, n, c), flattened to (batch_size, n * c);
                presumably n * c == ts_output_size. Optional; when omitted the
                forecast loss is returned as None.
            return_loss / return_embeddings: embeddings are returned instead of
                losses when ``return_embeddings`` is set and ``return_loss`` is not.

        Returns:
            With textAug: (ts_latents, text_latents, last_token, forecast_prob)
            when returning embeddings, else (forecast_loss, contrastive_loss, forecast_prob).
            Without textAug: (last_token, forecast_prob) or (forecast_loss, forecast_prob).
            ``forecast_prob`` is (batch_size, ts_output_size, 3) holding (loc, scale, df).
        """
        batch, device = ts.shape[0], ts.device

        # split the series into non-overlapping patches of ts_token_size steps
        ts = ts.permute(0, 2, 1)
        ts = ts.unfold(dimension=-1, size=self.ts_token_size, step=self.ts_token_size)
        ts = ts.squeeze(1)  # (batch_size, num_ts_tokens, ts_token_size) for single-channel input

        # time series embedding
        ts_embeds, ts_tokens = self.embed_ts(ts)

        if self.textAug:
            # raw tokenized texts are moved to the model device; pre-encoded tokens are used as given
            if exists(text_h):
                text_h = text_h.to(device)

            if self.addfuture:
                if exists(text_f):
                    text_f = text_f.to(device)
                text_embeds_h, text_tokens_h, text_embeds_f, text_tokens_f = self.embed_text(
                    text_h=text_h,
                    text_tokens_h=text_tokens_h,
                    text_f=text_f,
                    text_tokens_f=text_tokens_f,
                )
            else:
                text_embeds_h, text_tokens_h = self.embed_text(text_h=text_h, text_tokens_h=text_tokens_h)

            # multimodal layers: self-attention, then cross-attention to pooled history-text tokens
            for attn_ff, cross_attn in self.multimodal_layers:
                ts_tokens = attn_ff(ts_tokens)
                ts_tokens = cross_attn(ts_tokens, text_tokens_h)

            if self.addfuture:
                # future-side cross-attention
                ts_tokens = self.ts_text_tokens_f_attn(ts_tokens, text_tokens_f)

        # Student's t head on the last series token; map raw outputs to valid
        # parameters: scale = exp(.) > 0, df = softplus(.) + 1 > 1
        forecast_prob = self.final_layer_prob(ts_tokens[:, -1, :])
        forecast_prob = forecast_prob.view(-1, self.ts_output_size, 3)
        scale = torch.exp(forecast_prob[..., 1])
        df = F.softplus(forecast_prob[..., 2]) + 1
        loc = forecast_prob[..., 0]
        forecast_prob = torch.stack((loc, scale, df), dim=-1)

        # forecast loss (Student's t NLL), only when targets are available
        forecast_loss = None
        if exists(y):
            y = rearrange(y, 'b n c -> b (n c)')
            forecast_loss = custom_nll_loss(forecast_prob, y) * self.forecast_loss_weight

        if not self.textAug:
            if return_embeddings and not return_loss:
                return ts_tokens[:, -1, :], forecast_prob
            return forecast_loss, forecast_prob

        # embeddings to the shared contrastive latent space
        ts_latents = self.ts_to_latents(ts_embeds)
        text_latents = self.text_to_latents(text_embeds_h)

        if return_embeddings and not return_loss:
            return ts_latents, text_latents, ts_tokens[:, -1, :], forecast_prob

        # symmetric contrastive loss (history text vs history series): the i-th
        # series in the batch is the positive for the i-th text
        sim = einsum('i d, j d -> i j', ts_latents, text_latents)
        sim = sim * self.temperature.exp()
        contrastive_labels = torch.arange(batch, device=device)

        contrastive_loss = (F.cross_entropy(sim, contrastive_labels) + F.cross_entropy(sim.t(), contrastive_labels)) * 0.5
        contrastive_loss = contrastive_loss * self.contrastive_loss_weight
        return forecast_loss, contrastive_loss, forecast_prob
    

if __name__ == '__main__':
    # Smoke test: build the model from a YAML config and run one forward pass on random data.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/forecast.yaml', help='Configuration of MMTSFM')
    parser.add_argument('--output_dir', default='./output/forecast')
    parser.add_argument('--device', default='cuda')
    args = parser.parse_args()

    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only load trusted config files.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # keep a copy of the config next to the outputs; `with` guarantees it is flushed and closed
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_copy:
        yaml.dump(config, config_copy)

    device = torch.device(args.device)

    # The text branch needs a text encoder module plus tokenized batches
    # (objects with .input_ids / .attention_mask). Neither is built here, so
    # this smoke test exercises the time-series path only.
    if config['textAug_flag']:
        print('textAug_flag is set, but this smoke test has no text encoder; running without text.')

    model = DualForecaster(
        dim=config['dim'],
        num_vars=config['num_vars'],
        ts_token_size=config['token_size'],
        ts_output_size=config['output_size'],
        unimodal_depth=config['unimodal_depth'],
        multimodal_depth=config['multimodal_depth'],
        text_dim=config['text_dim'],
        num_text_queries=config['num_text_queries'],
        heads=config['heads'],
        dim_head=config['dim_head'],
        ff_mult=config.get('ff_mult', 4),
        dropout_rate_fcst=config['dropout_rate_fcst'],
        dropout_rate_cons=config['dropout_rate_cons'],
        textAug=False,
        text_encoder_frozen_flag=config['text_encoder_frozen_flag'],
        forecast_loss_weight=config['forecast_loss_weight'],
        contrastive_loss_weight=config['contrastive_loss_weight'],
        temperature=config.get('temperature', 1.0),
    )

    model = model.to(device)

    ts_x = torch.rand(5, 200, 1).to(device)                     # (batch, seq_len, channel)
    ts_y = torch.rand(5, config['output_size'], 1).to(device)   # (batch, pred_len, channel)

    loss_forecast, forecast_prob = model(ts=ts_x, y=ts_y, return_loss=True, return_embeddings=False)
    print(loss_forecast.item(), tuple(forecast_prob.shape))
