# -*- coding: utf-8 -*-
"""my_simple_transformer

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1nScp_X8P6NlDpPgc-BZBY5QmXgd2TAy3
"""

import numpy as np
import  os
#os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
import torch
import torch.nn as nn
import torch.nn.functional as F
#import wget
import sys
from utils.dataProcess import normalize , denormalize

import math
from dataclasses import dataclass

import os
import pickle
import sys

########################### hyperparameters
# Seed the global RNG once at import time so weight initialisation is reproducible.
torch.manual_seed(1337)

# --- training-script hyperparameters (read by the __main__ smoke test) ---
batch_size = 64        # how many independent sequences are processed in parallel
max_iters = 50         # optimisation steps (commented-out alternative: 5000)
eval_interval = 3      # run a loss estimate every this many steps (alternative: 500)
eval_iters = 2         # batches averaged per loss estimate (alternative: 200)
learning_rate = 3e-4
# Single-device setup: first CUDA device when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model-size settings (block_size, n_embed, n_head, n_layer, dropout) are
# carried by GPTConfig rather than by module-level globals.
#
# NOTE: an unconditional `sys.exit(3)` used to follow these constants. It
# terminated the interpreter as soon as the module was imported or run, which
# made every definition below unreachable, so it has been removed.
def new_gelu(x):
  """
  Tanh approximation of the Gaussian Error Linear Unit, applied element-wise.

  Evaluates 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))).
  Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415

  Args:
    x: input tensor.

  Returns:
    Tensor with the same shape as ``x``.
  """
  cubic_term = 0.044715 * torch.pow(x, 3.0)
  tanh_arg = math.sqrt(2.0 / math.pi) * (x + cubic_term)
  gate = 1.0 + torch.tanh(tanh_arg)
  return 0.5 * x * gate
"""my transformer"""

# url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
# filename = wget.download(url)
#print(filename)


@dataclass
class GPTConfig:
    """Hyperparameters shared by every module of the transformer in this file.

    Fields:
        block_size: maximum sequence length; sizes the positional embedding
            table and the causal attention mask.
        vocab_size: number of features per time step. The model's input and
            output layers are plain linear layers of this width (there is no
            token vocabulary).
        n_layer: number of stacked transformer blocks.
        n_head: number of attention heads per block.
        n_embed: model width C.
        dropout: dropout probability used throughout the model.

    ``head_size`` is not a constructor argument: it is derived as
    ``n_embed // n_head`` for each instance.
    """

    block_size: int = 128   # original note: "0.0614 seconds context"
    vocab_size: int = 25    # e.g. 25 input features per time step
    n_layer: int = 6
    n_head: int = 6
    n_embed: int = 384      # C; with the default n_head each head is 384 / 6 = 64 wide
    # Class-level default (64 for the defaults above), kept so that
    # `GPTConfig.head_size` still resolves on the class itself.
    head_size = n_embed // n_head
    dropout: float = 0.1

    # Smaller presets tried during experiments (block_size=8, vocab_size=25, dropout=0.1):
    #   n_layer=6, n_head=6, n_embed=96
    #   n_layer=6, n_head=4, n_embed=64
    #   n_layer=4, n_head=4, n_embed=64

    def __post_init__(self):
        # The class attribute above is evaluated once from the default values,
        # so it would stay at 64 when n_embed / n_head are overridden in the
        # constructor. Recompute it per instance so the concatenated heads
        # always add up to n_embed.
        self.head_size = self.n_embed // self.n_head

#--------


class head(nn.Module):
  """A single causal self-attention head.

  Projects the input to keys, queries and values of width ``config.head_size``,
  computes scaled dot-product scores, masks out future positions with a
  lower-triangular mask, and returns the attention-weighted sum of the values.
  """

  def __init__(self, config):
    """
    Args:
      config: object exposing ``n_embed``, ``n_head``, ``head_size``,
        ``block_size`` and ``dropout``.
    """
    super().__init__()
    assert config.n_embed % config.n_head == 0
    self.key = nn.Linear(config.n_embed, config.head_size, bias=False)
    self.query = nn.Linear(config.n_embed, config.head_size, bias=False)
    self.value = nn.Linear(config.n_embed, config.head_size, bias=False)
    # Causal mask (1 on and below the diagonal). Registered as a
    # non-persistent buffer: it follows the module through .to()/.cuda()
    # instead of being copied to a global device on every forward pass, and it
    # stays out of state_dict so existing checkpoints keep loading unchanged.
    self.register_buffer(
      'tril_prime',
      torch.tril(torch.ones(config.block_size, config.block_size)),
      persistent=False,
    )
    self.dropout = nn.Dropout(config.dropout)

  def forward(self, x):
    """
    Args:
      x: tensor of shape (B, T, C) with T <= block_size.

    Returns:
      Tensor of shape (B, T, head_size).
    """
    B, T, C = x.shape
    k = self.key(x)    # (B, T, head_size) - what information each position contains
    q = self.query(x)  # (B, T, head_size) - what information each position looks for
    v = self.value(x)  # (B, T, head_size) - what is communicated on a match

    # Pairwise query/key dot products: (B, T, head_size) x (B, T, head_size) -> (B, T, T).
    # NOTE(review): the scores are scaled by the input width C, not by
    # head_size as in standard scaled dot-product attention. Left unchanged
    # because altering it would change the output of already-trained weights.
    wei = torch.einsum('bij,bkj->bik', q, k) * C**-0.5

    # Hide future positions. Only the top-left (T, T) corner of the mask is
    # used because T may be shorter than block_size (e.g. early in generation).
    # The masking must stay inside the autograd graph: doing it under
    # torch.no_grad() detaches the scores, so the key and query projections
    # would never receive a gradient.
    wei = wei.masked_fill(self.tril_prime[:T, :T] == 0, float('-inf'))

    wei = F.softmax(wei, dim=-1)  # (B, T, T)
    wei = self.dropout(wei)
    # Weighted aggregation: (B, T, T) x (B, T, head_size) -> (B, T, head_size)
    out = torch.matmul(wei, v)
    return out

class MultiHeadAttention(nn.Module):    #communication
  """Runs ``config.n_head`` attention heads side by side and mixes their outputs.

  Each head receives the same input; the per-head results are concatenated
  along the last dimension, passed through a linear projection of width
  ``n_embed`` and then through dropout. The projection therefore expects the
  concatenated width (n_head * head_size) to equal n_embed.
  """

  def __init__(self, config):
    """
    Args:
      config: object exposing ``n_head``, ``n_embed`` and ``dropout`` (plus
        whatever ``head`` itself reads).
    """
    super().__init__()
    self.n_head = config.n_head
    self.n_embed = config.n_embed

    # Build the heads first, then the output projection, then dropout.
    attention_heads = []
    for _ in range(self.n_head):
      attention_heads.append(head(config))
    self.heads = nn.ModuleList(attention_heads)
    self.proj = nn.Linear(self.n_embed, self.n_embed)  # (B,T,C) --> (B,T,C)
    self.dropout = nn.Dropout(config.dropout)

  def forward(self, x):
    """Apply every head to ``x``, concatenate, project, and apply dropout."""
    per_head = [attn(x) for attn in self.heads]
    merged = torch.cat(per_head, dim=-1)
    projected = self.proj(merged)
    return self.dropout(projected)

class FeedForward(nn.Module):      ##computations
  """Position-wise MLP: expand to 4 * n_embed, apply GELU, project back, dropout."""

  def __init__(self, config):
    """
    Args:
      config: object exposing ``n_embed`` and ``dropout``.
    """
    super().__init__()
    width = config.n_embed
    expanded = 4 * width
    self.c_fc = nn.Linear(width, expanded)      # expansion layer
    self.c_proj = nn.Linear(expanded, width)    # projection back to model width
    self.dropout = nn.Dropout(config.dropout)

  def forward(self, x):
    """Return a tensor with the same trailing width (n_embed) as ``x``."""
    hidden = new_gelu(self.c_fc(x))
    return self.dropout(self.c_proj(hidden))

class Block(nn.Module):
  """One transformer block with pre-layer-norm residual connections.

  Computes ``x = x + sa(ln1(x))`` followed by ``x + feedforward(ln2(x))``.
  Both LayerNorm modules are built with ``config.n_embed`` as the normalised
  shape, i.e. they normalise over the last (feature) dimension.
  """

  def __init__(self, config):
    """
    Args:
      config: object exposing ``n_embed`` (plus whatever the attention and
        feed-forward sub-modules read).
    """
    super().__init__()
    self.sa = MultiHeadAttention(config)      # communication between positions
    self.feedforward = FeedForward(config)    # per-position computation
    self.ln1 = nn.LayerNorm(config.n_embed)
    self.ln2 = nn.LayerNorm(config.n_embed)

  def forward(self, x):
    """Apply the attention and feed-forward sub-layers, each with a residual add."""
    attended = self.sa(self.ln1(x))
    x = x + attended
    transformed = self.feedforward(self.ln2(x))
    return x + transformed

# class LayerNorm(nn.Module):
#   def __init__(self,dim, eps=1e-5, momentum=0.1):
#     super().__init__()
#     self.eps = eps
#     self.gamma = torch.ones(dim)
#     self.beta = torch.zeros(dim)
#
#   def forward(self,x):
#     #forward pass
#     xmean = x.mean(dim=1,keepdim=True)
#     xvar  = x.var(dim=1, keepdim=True)
#     xhat  = (x - xmean) / torch.sqrt(xvar + self.eps)
#     self.out = self.gamma * xhat + self.beta
#     return self.out
#
#   def parameters(self):
#     return [self.gamma, self.beta]





# GPT-style decoder-only transformer that regresses continuous feature vectors.
class TransformerModel(nn.Module):
  """Causal transformer over sequences of real-valued feature vectors.

  Compared with a language-model GPT, the token embedding is a plain linear
  layer (``vocab_size`` features -> ``n_embed``) and the output head maps back
  to ``vocab_size`` features, trained with a mean-squared-error loss.
  """

  def __init__(self, config):
    """
    Args:
      config: object exposing ``vocab_size``, ``block_size``, ``n_embed``,
        ``n_layer`` and ``dropout`` (plus whatever ``Block`` reads).
    """
    super().__init__()
    assert config.vocab_size is not None
    assert config.block_size is not None
    self.config = config

    self.transformer = nn.ModuleDict(dict(
      token_embedding=nn.Linear(config.vocab_size, config.n_embed),    # (B,T,#f) ---> (B,T,C)
      # Learned absolute positions; only indices < block_size exist, so inputs
      # must be cropped to the last block_size steps before calling forward().
      position_embedding_table=nn.Embedding(config.block_size, config.n_embed),  # (1,T) ---> (1,T,C)
      drop=nn.Dropout(config.dropout),
      h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
      ln_f=nn.LayerNorm(config.n_embed),  # final layer norm
    ))

    # Output head: (B,T,C) ---> (B,T,vocab_size). Not tied to the input layer.
    self.lm_head = nn.Linear(config.n_embed, config.vocab_size)

  def forward(self, inp, target=None):
    """Run the model on a batch of sequences.

    Args:
      inp: tensor of shape (B, T, vocab_size) with T <= block_size.
      target: optional tensor the predictions are compared against with MSE;
        it must be compatible with the (B, T, vocab_size) prediction shape.

    Returns:
      ``(predictions, loss)`` where ``predictions`` has shape
      (B, T, vocab_size) and ``loss`` is the mean-squared error against
      ``target``, or ``None`` when no target is given.

    Raises:
      AssertionError: if T exceeds ``config.block_size``.
    """
    _, T, _ = inp.shape
    assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
    # Build the position indices on the input's own device so the model works
    # wherever it has been moved, instead of relying on a module-level global.
    pos = torch.arange(0, T, dtype=torch.long, device=inp.device).unsqueeze(0)  # shape (1, T)

    pos_emb = self.transformer.position_embedding_table(pos)  # (1, T, C), broadcast over the batch
    tok_emb = self.transformer.token_embedding(inp)           # (B, T, C)
    x = self.transformer.drop(tok_emb + pos_emb)
    for block in self.transformer.h:
      x = block(x)
    x = self.transformer.ln_f(x)  # (B, T, C)

    logits = self.lm_head(x)  # (B, T, vocab_size), predictions for every position
    if target is None:
      return logits, None

    criterion = nn.MSELoss()
    loss = criterion(logits, target)
    return logits, loss

  def configure_optimizers(self, weight_decay, learning_rate, betas):
    """Build an AdamW optimizer with weight decay applied selectively.

    Parameters are split into two groups: weights of ``nn.Linear`` modules are
    decayed; all biases and the weights of ``nn.LayerNorm`` / ``nn.Embedding``
    modules are not.

    Args:
      weight_decay: decay coefficient for the decayed group.
      learning_rate: learning rate passed to AdamW.
      betas: AdamW beta coefficients.

    Returns:
      A ``torch.optim.AdamW`` instance whose first parameter group is the
      decayed one and whose second group has ``weight_decay=0.0``.

    Raises:
      AssertionError: if a parameter lands in both groups or in neither.
    """
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear,)
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in self.named_modules():
      for pn, p in m.named_parameters():
        fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
        # named_modules and named_parameters are both recursive, so the same
        # tensor is visited several times; the weight branches only fire on
        # the visit where `m` is the module that directly owns it.
        if pn.endswith('bias'):
          # all biases will not be decayed
          no_decay.add(fpn)
        elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
          # weights of whitelist modules will be weight decayed
          decay.add(fpn)
        elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
          # weights of blacklist modules will NOT be weight decayed
          no_decay.add(fpn)

    # named_parameters() reports each tensor once, so if weights are ever tied
    # the alias names collected above do not exist in param_dict. Drop only
    # those aliases. The previous unconditional removal of 'lm_head.weight'
    # discarded a real, untied parameter and made the coverage check below
    # fail on every call.
    param_dict = {pn: p for pn, p in self.named_parameters()}
    decay = {pn for pn in decay if pn in param_dict}
    no_decay = {pn for pn in no_decay if pn in param_dict}

    # validate that we considered every parameter
    inter_params = decay & no_decay
    union_params = decay | no_decay
    assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),)
    assert len(
      param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                              % (str(param_dict.keys() - union_params),)

    # create the pytorch optimizer object
    optim_groups = [
      {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": weight_decay},
      {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
    ]
    optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
    return optimizer

  def generate(self, inp_idx, multiStep, normalizer=None, tar_burn_in=None, top_k=None):
    """Autoregressively extend a sequence by ``multiStep`` predicted steps.

    At every step the model is run on (at most) the last ``block_size`` time
    steps, the prediction for the final position is taken as the next sample,
    and that sample is appended to the sequence and fed back in as is.

    Gradient tracking is not disabled here; wrap the call in
    ``torch.no_grad()`` and switch the model to ``eval()`` for pure inference.

    Args:
      inp_idx: conditioning tensor of shape (B, T, vocab_size).
      multiStep: number of new time steps to append.
      normalizer: currently unused (predictions are fed back un-normalised).
      tar_burn_in: currently unused.
      top_k: currently unused (there is no sampling; outputs are continuous).

    Returns:
      Tensor of shape (B, T + multiStep, vocab_size): the input followed by
      the generated steps.
    """
    block_size = self.config.block_size
    for _ in range(multiStep):
      # The positional embedding only covers block_size positions, so crop the context.
      idx_cond = inp_idx if inp_idx.size(1) <= block_size else inp_idx[:, -block_size:, :]
      logits, _ = self(idx_cond)
      next_sample = logits[:, -1:, :]  # (B, 1, vocab_size); slice keeps the time dimension
      inp_idx = torch.cat([inp_idx, next_sample], dim=1)  # (B, T+1, vocab_size)
    return inp_idx


if __name__ =="__main__":
  # Smoke test: train the model for a few steps on synthetic Gaussian data,
  # report train/val loss estimates, then roll out a short generation.

  torch.manual_seed(0)

  config = GPTConfig(block_size=1024)
  model = TransformerModel(config).to(device)

  def get_batch(txt):
    """Return a synthetic (x, y) batch on `device`.

    `txt` ('train' or 'val') is accepted for interface symmetry only: both
    splits draw independent standard-normal tensors of shape
    (batch_size, block_size, vocab_size). No dataset is loaded in this file,
    so the earlier lookups of `train_data` / `val_data` (which raised
    NameError) have been dropped.
    """
    shape = (batch_size, config.block_size, config.vocab_size)
    x = torch.randn(shape)
    y = torch.randn(shape)
    return x.to(device), y.to(device)

  @torch.no_grad()
  def estimate_loss():
    """Average the loss over `eval_iters` batches per split, in eval mode."""
    out = {}
    model.eval()  # disable dropout while measuring
    for split in ['train', 'val']:
      losses = torch.zeros(eval_iters)
      for k in range(eval_iters):
        x, y = get_batch(split)
        _, batch_loss = model(x, y)
        losses[k] = batch_loss.item()
      out[split] = losses.mean()
    model.train()
    return out

  # create a pytorch optimizer
  my_optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

  # train loop (`step` rather than `iter`, which would shadow the builtin)
  for step in range(max_iters):
    if step % eval_interval == 1:
      print('iter=', step)
      losses = estimate_loss()
      print('iter:{}/{}  tain-loss:{:.4f}  val-loss:{:.4f}'.format(step, max_iters, losses['train'], losses['val']))

    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    my_optimizer.zero_grad()
    loss.backward()
    my_optimizer.step()
  print(loss.item())

  # generate from the model
  # The input layer is nn.Linear, so the seed sequence must be floating point;
  # a torch.long tensor fails with a dtype mismatch inside the matmul.
  temp = torch.zeros((1, 10, config.vocab_size), dtype=torch.float, device=device)
  model.eval()
  with torch.no_grad():  # pure inference: no need to keep the autograd graph
    gen = model.generate(temp, 100)
  print(gen)

