# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np


class GeLU(nn.Module):
    """Tanh approximation of the Gaussian Error Linear Unit.

    Computes ``0.5 * x * (1 + tanh(c * (x + 0.044715 * x**3)))`` with
    ``c = 0.7978845608`` (approximately ``sqrt(2 / pi)``), element-wise.
    """

    def forward(self, x):
        # c * x * (1 + a*x^2) == c * (x + a*x^3); kept in this factored form.
        inner = x * 0.7978845608 * (1. + 0.044715 * x * x)
        gate = 1. + torch.tanh(inner)
        return 0.5 * x * gate


class Conv1d(nn.Module):
    """Weight-normalised 1-D convolution, right zero-padded, then GeLU.

    The convolution itself uses PyTorch's defaults (stride 1, no padding), so
    it shortens the last axis by ``kernel_size - 1``. The ``ZeroPad2d`` layer
    then appends that many zeros on the right of the last axis, restoring the
    input length.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()

        # NOTE(review): nn.utils.weight_norm is deprecated in recent PyTorch in
        # favour of nn.utils.parametrizations.weight_norm; switching changes the
        # state_dict keys (weight_g/weight_v), so it is kept as-is here.
        stages = [
            nn.utils.weight_norm(nn.Conv1d(in_channels, out_channels, kernel_size)),
            # (left, right, top, bottom): only the last axis is padded, on the right.
            nn.ZeroPad2d(padding=(0, kernel_size - 1, 0, 0)),
            GeLU(),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Apply conv -> right zero-pad -> GeLU to ``x``."""
        out = self.model(x)
        return out


def init_tensor(shape, stds=1e-9):
    """Return a new tensor of ``shape`` with i.i.d. ``N(0, stds**2)`` entries.

    Args:
        shape: Sequence of ints (or ``torch.Size``) giving the tensor shape.
        stds: Standard deviation of the entries. The default ``1e-9`` yields
            an almost-zero tensor.

    Returns:
        A float tensor of the default dtype, on the default device.
    """
    # A single standard-normal draw, scaled. (The tensor used to be sampled
    # twice -- ``randn`` followed by ``nn.init.normal_`` -- with the first
    # draw discarded.)
    return stds * torch.randn(shape)

class Model(nn.Module):
    """Attention encoder followed by a recurrent tensor-network (uMPS) readout.

    Pipeline (see ``forward``):
      1. ``conv1`` maps the ``m`` input series to ``hidC`` channels.
      2. Causal self-attention across time steps and self-attention across
         channels are concatenated, projected by ``w_1`` and blended with the
         convolution output through two learned sigmoid gates
         (``WaTen`` / ``WbTen``).
      3. The L2-normalised encoder states are folded, one time step at a time,
         into a hidden vector of size ``bond`` by contracting with the shared
         core tensor ``tensors`` (Runge-Kutta-style update, see
         ``calculate_y_simple``); ``linear1`` maps it to ``m`` outputs.
      4. Two linear skip paths -- over the last ``hw1`` raw input steps and
         the last ``hw2`` encoder steps -- are added, then the optional output
         non-linearity is applied.

    ``x`` is expected as ``(batch, window, m)``; the time axis must equal
    ``args.window`` because ``v_q``/``v_k``/``v_v`` are ``Linear(P, P)``
    layers applied along time.

    Device placement follows the module: move it with ``model.cuda()`` or
    ``model.to(device)`` as for any ``nn.Module``.
    """

    # Integration schemes understood by ``calculate_y_simple``.
    _RK_ORDERS = ("RK1", "RK2", "RK4")

    def __init__(self, args, data):
        super(Model, self).__init__()
        # NOTE: the creation order of parameterised layers below is kept
        # stable so seeded initialisation stays reproducible.
        self.use_cuda = args.cuda
        self.P = args.window
        self.m = data.m
        self.bond = args.bond_dim
        self.hidC = args.hidCNN
        self.Ck = args.CNN_kernel
        self.hw1 = args.input_channel
        self.hw2 = args.skip_channel
        self.wl = args.wlocal
        self.wg = args.wglobal
        self.t1 = self.hidC ** 0.5   # temperature for attention over time
        self.t2 = self.P ** 0.5      # temperature for attention over channels
        self.conv1 = Conv1d(self.m, self.hidC, self.Ck)
        self.order = args.order
        if self.order not in self._RK_ORDERS:
            raise ValueError(
                "Please choose a correct order! Expected one of "
                f"{self._RK_ORDERS}, got {self.order!r}"
            )

        # Projections for attention across time steps (feature dim hidC).
        self.w_q = nn.Linear(self.hidC, self.hidC, bias=False)
        self.w_k = nn.Linear(self.hidC, self.hidC, bias=False)
        self.w_v = nn.Linear(self.hidC, self.hidC, bias=False)

        # Projections for attention across channels (feature dim = window P).
        self.v_q = nn.Linear(self.P, self.P, bias=False)
        self.v_k = nn.Linear(self.P, self.P, bias=False)
        self.v_v = nn.Linear(self.P, self.P, bias=False)

        self.w_1 = nn.Linear(2 * self.hidC, self.hidC)
        # float(...) so an integer wlocal/wglobal still yields a float tensor;
        # nn.Parameter cannot require grad on an integer tensor.
        self.WaTen = nn.Parameter(torch.full([1], float(self.wl)))
        self.WbTen = nn.Parameter(torch.full([1], float(self.wg)))

        # Shared uMPS core of shape (bond, bond, hidC). No .cuda() here: the
        # parameter moves with the module.
        tensor = init_tensor(shape=[self.bond, self.bond, self.hidC], stds=args.std)
        self.register_parameter(name='tensors', param=nn.Parameter(tensor.contiguous()))

        self.dropout = nn.Dropout(p=args.dropout)

        self.linear1 = nn.Linear(self.bond, self.m)
        self.linear2 = nn.Linear(self.hidC, self.m)

        self.highway1 = nn.Linear(self.hw1, 1)
        self.highway2 = nn.Linear(self.hw2, 1)

        self.output = None
        if args.output_fun == 'sigmoid':
            self.output = torch.sigmoid
        if args.output_fun == 'tanh':
            self.output = torch.tanh

    def calculate_y_simple(self, y, w_vec):
        """Return the Runge-Kutta increment of state ``y`` under map ``w_vec``.

        The vector field is ``f(v) = einsum('ab,abc->ac', v, w_vec)``, i.e. a
        batched ``v @ w_vec[a]`` with ``y`` of shape ``(batch, bond)`` and
        ``w_vec`` of shape ``(batch, bond, bond)``. The step size is
        implicitly 1; the caller adds the returned increment to ``y``.

        Raises:
            ValueError: if ``self.order`` is not one of ``RK1``, ``RK2``,
                ``RK4``.
        """
        def f(state):
            return torch.einsum('ab,abc->ac', state, w_vec)

        if self.order == "RK4":
            k1 = f(y)
            k2 = f(y + 0.5 * k1)
            k3 = f(y + 0.5 * k2)
            k4 = f(y + k3)
            return (1.0 / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        if self.order == "RK2":
            k1 = f(y)
            k2 = f(y + k1)
            return (1.0 / 2.0) * (k1 + k2)
        if self.order == "RK1":
            return f(y)
        raise ValueError(
            "Please choose a correct order! Expected one of "
            f"{self._RK_ORDERS}, got {self.order!r}"
        )

    def _temporal_attention(self, local):
        """Causal scaled self-attention across time steps.

        ``local`` is ``(batch, steps, hidC)``; the result has the same shape.
        Step ``t`` only attends to steps ``<= t``.
        """
        q = self.w_q(local) / self.t1
        k = self.w_k(local).permute(0, 2, 1)
        v = self.w_v(local)
        scores = torch.bmm(q, k)  # (batch, steps, steps): rows=query, cols=key
        steps = scores.size(-1)
        future = torch.ones(steps, steps, dtype=torch.bool,
                            device=scores.device).triu(1)
        # Mask with -inf, not 0: a zeroed score still gets weight exp(0) after
        # the softmax and would leak future steps. The diagonal stays unmasked,
        # so no row is fully -inf.
        scores = scores.masked_fill(future, float('-inf'))
        attn = self.dropout(F.softmax(scores, dim=2))
        return torch.bmm(attn, v)

    def _variable_attention(self, local):
        """Scaled self-attention across channels (each channel's series of P
        steps is one token). Input and output are ``(batch, P, hidC)``."""
        feats = local.permute(0, 2, 1)  # (batch, hidC, P)
        q = self.v_q(feats) / self.t2
        k = self.v_k(feats).permute(0, 2, 1)
        v = self.v_v(feats)
        attn = self.dropout(F.softmax(torch.bmm(q, k), dim=2))
        return torch.bmm(attn, v).permute(0, 2, 1)

    def _encode(self, x):
        """Conv + two attentions, gated with the conv output.

        Returns the encoder states, shape ``(batch, P, hidC)``.
        """
        conv = self.dropout(self.conv1(x.permute(0, 2, 1)))  # (batch, hidC, P)
        local = conv.permute(0, 2, 1)                        # (batch, P, hidC)
        attended = torch.cat(
            (self._temporal_attention(local), self._variable_attention(local)), 2)
        attended = self.dropout(torch.relu(self.w_1(attended)))
        return torch.sigmoid(self.WaTen) * local + torch.sigmoid(self.WbTen) * attended

    def _umps(self, enc):
        """Fold normalised encoder states into a ``bond``-sized hidden vector
        with the shared core tensor and map it to ``m`` outputs."""
        norm = F.normalize(enc, p=2, dim=2)
        batch = norm.size(0)
        # Allocate on the input's device/dtype instead of a hard-coded .cuda().
        h = torch.ones(batch, self.bond, device=norm.device, dtype=norm.dtype)
        for vec in norm.unbind(1):  # vec: (batch, hidC) for one time step
            w_vec = torch.einsum('bcd,ad->abc', self.tensors, vec)
            h = h + self.calculate_y_simple(h, w_vec)
        h = self.dropout(h)
        return self.linear1(h)

    def _skip(self, x, enc):
        """Linear skip paths over the last ``hw1`` input steps and the last
        ``hw2`` encoder steps; returns two ``(batch, m)`` tensors."""
        z1 = x[:, -self.hw1:, :].permute(0, 2, 1).contiguous().view(-1, self.hw1)
        z1 = self.highway1(z1).view(-1, self.m)

        z2 = enc[:, -self.hw2:, :].permute(0, 2, 1).contiguous().view(-1, self.hw2)
        z2 = self.linear2(self.highway2(z2).view(-1, self.hidC))
        return z1, z2

    def forward(self, x):
        """Map ``x`` of shape ``(batch, P, m)`` to predictions ``(batch, m)``."""
        enc = self._encode(x)
        res = self._umps(enc)
        z1, z2 = self._skip(x, enc)
        res = res + z1 + z2
        if self.output is not None:
            res = self.output(res)
        return res
