import os
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
import pandas as pd
import sys
import functools
import seaborn as sns
from torch.utils.data import DataLoader

class Sepi(nn.Module):  # myModel
    """Branch/trunk network computing ``<branch(x1), trunk(x2)> + param``.

    Args:
        input_size1: feature size of ``x1`` (input of the branch MLP).
        input_size2: feature size of ``x2`` (input of the trunk MLP).
        hidden_size: width of the hidden layers of both MLPs.
        m: output size of both MLPs; their outputs are dotted in this space.
        device: device on which the scalar ``param`` is created; later
            ``module.to(...)`` calls move it like any other parameter.
    """

    def __init__(self, input_size1, input_size2, hidden_size, m, device):
        super().__init__()
        # Built in the same order as before (branch, then trunk), so random
        # initialisation and state_dict keys are unchanged.
        self.branch = self._mlp(input_size1, hidden_size, m, final_bias=True)
        self.trunk = self._mlp(input_size2, hidden_size, m, final_bias=False)
        # Scalar additive bias. It must be an nn.Parameter: a plain tensor with
        # requires_grad=True is not returned by parameters(), so the optimizer
        # never updated it. It was also absent from state_dict() and ignored by
        # module.to(device).
        self.param = nn.Parameter(torch.ones(1, device=device))

    @staticmethod
    def _mlp(in_features, hidden_size, out_features, final_bias):
        """Build three Linear+ReLU hidden layers followed by a Linear output layer."""
        return nn.Sequential(
            nn.Linear(in_features, hidden_size, bias=True), nn.ReLU(),
            nn.Linear(hidden_size, hidden_size, bias=True), nn.ReLU(),
            nn.Linear(hidden_size, hidden_size, bias=True), nn.ReLU(),
            nn.Linear(hidden_size, out_features, bias=final_bias),
        )

    def forward(self, x1, x2):
        """Return the row-wise dot product of branch and trunk outputs plus ``param``.

        ``x1`` is ``(batch, input_size1)`` and ``x2`` is ``(batch, input_size2)``
        with the same ``batch``; the result has shape ``(batch, 1)``.
        """
        y_branch = self.branch(x1)
        y_trunk = self.trunk(x2)
        guy = (torch.einsum("bi,bi->b", y_branch, y_trunk) + self.param).unsqueeze(1)
        return guy

class Model():
    """Load one dataset, fit a ``Sepi`` network to it, and evaluate it.

    ``args`` is an argparse-style namespace. Attributes read: ``filename``,
    ``sigma_sd``, ``num``, ``hidden_size``, ``m``, ``device``, ``epochs``,
    ``batch_size``, ``lr``, ``weight_decay``, ``outime`` and, optionally,
    ``lt`` (time-subsampling stride, default 10).
    """

    def __init__(self, args):
        self.args = args

        self.X1, self.p1, self.X2, self.p2, self.x, self.t = self.read_data()
        xtr1, xtr2, y_tr = self.Tdata(self.X1, self.p1)
        xte1, xte2, y_te = self.Tdata(self.X2, self.p2)

        self.X_train = [xtr1, xtr2]
        self.y_train = y_tr
        self.X_test = [xte1, xte2]
        self.y_test = y_te
        print('X_train', self.X_train[0].shape, self.X_train[1].shape)
        print('y_train', self.y_train.shape)

        # Branch input width comes from the data; the trunk always sees (x, t).
        input_size1 = xtr1.shape[1]
        self.sepi = Sepi(input_size1, 2, args.hidden_size, args.m, args.device).to(args.device)

    def _model_path(self):
        """Return the checkpoint path shared by ``train`` (save) and ``test`` (load)."""
        args = self.args
        return os.path.join('.', 'model', args.filename + '_DON_' + str(args.num) + '.pkl')

    def read_data(self):
        """Read the CSV dataset, add noise to the training s-data, and subsample.

        Files are read from ``./dataset/data_<filename>/``. ``X1``/``X2`` come
        from ``s_train.csv``/``s_test.csv``, ``p1``/``p2`` from
        ``u_train.csv``/``u_test.csv``, and the grids from ``x.csv``/``t.csv``.
        The first CSV column is dropped from every file. ``X1``/``X2`` are
        reshaped to ``(samples, len(x), len(t))``.

        Gaussian noise with std ``sigma_sd * mean(|X1|)`` is added to ``X1``
        only. All arrays are then cut to the first ``args.num`` samples and to
        every ``lt``-th column along their last axis; ``t`` is subsampled the
        same way.

        Returns:
            ``(X1, p1, X2, p2, x, t)``.
        """
        args = self.args
        lt = getattr(args, 'lt', 10)  # time-subsampling stride
        data_dir = os.path.join('.', 'dataset', 'data_' + args.filename)

        def load(name):
            # Drop the first column (presumably the index written by pandas).
            return pd.read_csv(os.path.join(data_dir, name)).values[:, 1:]

        p1 = load('u_train.csv')
        p2 = load('u_test.csv')
        X1 = load('s_train.csv')
        X2 = load('s_test.csv')
        x = load('x.csv')[:, 0]
        t = load('t.csv')[:, 0]

        # Noise is applied to the training targets only, scaled by their mean magnitude.
        noise = args.sigma_sd * np.abs(X1).mean() * np.random.randn(*X1.shape)
        X1 = X1 + noise
        X1 = X1.reshape(-1, len(x), len(t))
        X2 = X2.reshape(-1, len(x), len(t))

        num = args.num
        X1 = X1[:num, :, ::lt]
        X2 = X2[:num, :, ::lt]
        p1 = p1[:num, ::lt]
        p2 = p2[:num, ::lt]
        t = t[::lt]
        return X1, p1, X2, p2, x, t

    def Tdata(self, X1, p1):
        """Flatten one split into point-wise (branch input, trunk input, target) rows.

        ``X1`` has shape ``(samples, len(self.x), len(self.t))``. Each entry of
        ``X1`` becomes one row, in C order (sample, then x, then t):

        * branch input: that sample's row of ``p1``;
        * trunk input: the entry's ``(x, t)`` grid coordinate;
        * target: the entry itself.

        Returns:
            float32 tensors ``(xtr1, xtr2, y_tr)`` of shapes ``(N, p1.shape[1])``,
            ``(N, 2)`` and ``(N, 1)`` where ``N = X1.size``.
        """
        _, n_x, n_t = X1.shape
        xtr1 = torch.tensor(np.repeat(p1, n_x * n_t, axis=0)).float()
        y_tr = torch.tensor(np.ravel(X1)).float().unsqueeze(1)
        # Broadcast both grids to X1's shape so their C-order flattening lines
        # up row-for-row with y_tr.
        xx = np.broadcast_to(self.x[np.newaxis, :, np.newaxis], X1.shape).reshape(-1, 1)
        tt = np.broadcast_to(self.t[np.newaxis, np.newaxis, :], X1.shape).reshape(-1, 1)
        xtr2 = torch.tensor(np.concatenate((xx, tt), axis=1)).float()
        return xtr1, xtr2, y_tr

    def train(self):
        """Fit ``self.sepi`` with Adam on shuffled mini-batches of summed squared error.

        Every ``args.outime`` epochs, and after the final epoch, this prints two
        losses and saves the whole module to the path ``test`` reloads. The
        losses are the epoch's accumulated training loss and the full test
        loss, each divided by the number of samples in its split.
        """
        args = self.args
        device = args.device
        batch_size = args.batch_size
        n_train_samples = self.X1.shape[0]
        n_test_samples = self.X2.shape[0]
        x1_train = self.X_train[0].to(device)
        x2_train = self.X_train[1].to(device)
        y_train = self.y_train.to(device)
        x1_test = self.X_test[0].to(device)
        x2_test = self.X_test[1].to(device)
        y_test = self.y_test.to(device)
        sepi = self.sepi
        optimizer = optim.Adam(sepi.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        n_points = y_train.shape[0]
        save_path = self._model_path()
        # torch.save does not create missing directories.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        for epoch in range(args.epochs):
            sepi.train()
            # Shuffle once per epoch, then gather only each batch's rows. The
            # old code re-gathered the whole training set for every batch.
            idx = torch.randperm(n_points).to(device)
            tr_loss = torch.zeros((), device=device)
            # Consecutive slices of the permutation. When n_points is not a
            # multiple of batch_size, the last slice holds the remainder.
            for start in range(0, n_points, batch_size):
                batch = idx[start:start + batch_size]
                optimizer.zero_grad()
                guy = sepi(x1_train[batch], x2_train[batch])
                loss = torch.sum((guy - y_train[batch]) ** 2)
                loss.backward()
                optimizer.step()
                # detach() so the running total does not chain every batch's
                # loss into one ever-growing autograd graph.
                tr_loss = tr_loss + loss.detach()

            # Also checkpoint the last epoch so test() sees the final weights.
            if epoch % args.outime == 0 or epoch == args.epochs - 1:
                sepi.eval()
                with torch.no_grad():
                    guy = sepi(x1_test, x2_test)
                    te_loss = torch.sum((guy - y_test) ** 2)
                print("####################")
                print("epoch:", epoch, ", Train_loss:",
                      np.round(tr_loss.item() / n_train_samples, 4),
                      ", Test_loss:", np.round(te_loss.item() / n_test_samples, 4))
                torch.save(sepi, save_path)
                print("####################")

    def _load_model(self, device):
        """Load the module saved by ``train`` onto ``device``."""
        path = self._model_path()
        try:
            # The checkpoint is a whole pickled module. torch >= 2.6 refuses to
            # load that under its weights_only=True default. Only load files
            # this script wrote itself: unpickling can execute arbitrary code.
            return torch.load(path, map_location=device, weights_only=False)
        except TypeError:
            # torch < 1.13 has no weights_only argument.
            return torch.load(path, map_location=device)

    def test(self):
        """Reload the saved model and print the summed squared test error per sample."""
        args = self.args
        device = args.device
        X2 = self.X2
        sepi = self._load_model(device)
        sepi.eval()
        with torch.no_grad():
            guy_test = sepi(self.X_test[0].to(device), self.X_test[1].to(device)).cpu()
        pred_test = guy_test.reshape(X2.shape[0], X2.shape[1], -1).numpy()
        print("test_loss:", np.sum((pred_test - X2) ** 2) / X2.shape[0])




