import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import argparse
from .SDMTE import SparseDeepMTE
from timeit import default_timer as timer
import utils
from .utils import XYNet
from scipy.io import loadmat
import scanpy as sc
import time

def train_step(net, funcs_opt, gates_opt, x, y, L):
    """Perform one gradient update on ``net`` and return the loss tensor.

    Args:
        net: Callable model; invoked as ``net(x, L)``. Its output is reduced
            with ``.mean()`` to obtain the scalar loss.
        funcs_opt: Optimizer that is zeroed and stepped first.
        gates_opt: Optimizer that is zeroed and stepped second.
        x: First positional input forwarded to ``net``.
        y: Not used by this function; kept so the call signature stays stable.
        L: Second positional input forwarded to ``net``.

    Returns:
        The mean loss tensor computed before the optimizers were stepped
        (still attached to the autograd graph).
    """
    optimizers = (funcs_opt, gates_opt)

    # Clear gradients left over from the previous step.
    for optimizer in optimizers:
        optimizer.zero_grad()

    batch_loss = net(x, L).mean()
    batch_loss.backward()

    # Apply the freshly computed gradients, functions first, then gates.
    for optimizer in optimizers:
        optimizer.step()

    return batch_loss


def plot_gates(net, name):
    """Plot the x and y gate vectors of ``net`` and save them as PNG files.

    The figures are written to ``x_gates/x_gates_{name}.png`` and
    ``y_gates/y_gates_{name}.png`` relative to the current working directory.
    The output directories are created if they do not exist yet.

    Args:
        net: Wrapped model exposing ``net.module.get_gates()``, which must
            return a pair ``(g_x, g_y)`` of tensors.
        name: Label inserted into both the plot titles and the file names.
    """
    import os  # local import: only needed to create the output directories

    g_x, g_y = net.module.get_gates()
    for axis_label, gates in (('x', g_x), ('y', g_y)):
        out_dir = f'{axis_label}_gates'
        # savefig does not create missing directories, so make sure it exists.
        os.makedirs(out_dir, exist_ok=True)
        try:
            plt.plot(gates.cpu().detach().numpy())
            plt.title(f'{axis_label} gates,  {name}')
            plt.savefig(f'{out_dir}/{axis_label}_gates_{name}.png')
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close()


def load_data(adata_path,graph_path):
    """Load the data array and the graph array from disk with ``np.load``.

    Args:
        adata_path: Path of the file holding the data array.
        graph_path: Path of the file holding the graph array.

    Returns:
        A tuple ``(x, y, L)`` where ``x`` is the array read from
        ``adata_path``, ``y`` is an independent copy of ``x`` and ``L`` is
        the array read from ``graph_path``.
    """
    features = np.load(adata_path)
    # The second view starts out as an independent copy of the first one.
    targets = features.copy()
    graph = np.load(graph_path)
    return features, targets, graph


def _initial_gate_mask(x, L, out_dim):
    """Compute the 0/1 vector used to initialise the model's gates.

    Args:
        x: 2-D numpy data array.
        L: Square numpy graph matrix (presumably a graph Laplacian — confirm
            against the caller); ``diag(L) - L`` is used as the weight matrix.
        out_dim: Number of leading eigenvector columns to keep.

    Returns:
        Integer numpy vector with one entry per row of ``pinv(x) @ W @ x``;
        an entry is 1 where its ``A / (A + B)`` score exceeds the median score.
    """
    W = np.diag(np.diag(L)) - L
    K = np.linalg.pinv(x) @ W @ x
    # NOTE(review): np.linalg.eigh assumes a symmetric/Hermitian input and
    # only reads one triangle of K; pinv(x) @ W @ x is not symmetric in
    # general. Confirm this is the intended decomposition.
    _, V = np.linalg.eigh(K)
    b = V[:, 0:out_dim]
    a = K @ b
    A = np.max(np.abs(a), axis=1)
    B = np.max(np.abs(b), axis=1)
    A = A / np.median(A)
    B = B / np.median(B)
    # The small constant guards against division by zero.
    ratio = A / (A + B + 1e-10)
    return (ratio > np.quantile(ratio, 0.5)) * 1


def _to_numpy(tensor):
    """Return ``tensor`` as a numpy array detached from autograd, on the CPU."""
    return tensor.cpu().detach().numpy()


def main(adata_path, graph_path, out_dim=5, lam_x=0.05, lam_y=0.05,
         lam_cov=0.0, n_epoch=3000, return_comps=False):
    """Train a SparseDeepMTE model on the given data and return its gates.

    The run is seeded with ``torch.manual_seed(42)`` and executes on the CPU.
    A progress line is printed every 100 epochs and the total wall-clock time
    is printed at the end.

    Args:
        adata_path: Path of the data array, loaded with ``load_data``; the
            same data (as a copy) is used for both ``x`` and ``y``.
        graph_path: Path of the graph array, loaded with ``load_data``.
        out_dim: Output dimension passed to ``XYNet`` and ``SparseDeepMTE``;
            also the number of eigenvector columns used for gate
            initialisation.
        lam_x: Passed through to ``SparseDeepMTE`` (presumably the x-gate
            regularisation weight — confirm against that class).
        lam_y: Passed through to ``SparseDeepMTE`` (presumably the y-gate
            regularisation weight — confirm against that class).
        lam_cov: Passed through to ``SparseDeepMTE``.
        n_epoch: Number of training steps to run.
        return_comps: When true, the components from
            ``net.module.get_components(x, L)`` are computed and returned too.

    Returns:
        ``(g_x, g_y)`` as numpy arrays, or
        ``(g_x, g_y, x_com, y_com, Wy_com)`` when ``return_comps`` is true.
    """
    tm = time.time()
    device = torch.device("cpu")
    torch.manual_seed(42)

    x, y, L = load_data(adata_path, graph_path)
    gate_init = _initial_gate_mask(x, L, out_dim)

    x = torch.Tensor(x).to(device)
    y = torch.Tensor(y).to(device)
    L = torch.Tensor(L).to(device)

    xynet = XYNet(x.shape[1], out_dim)
    net = SparseDeepMTE(x.shape[1], y.shape[1], xynet, lam_x, lam_y, lam_cov,
                        out_dim, None, gate_init)
    # NOTE(review): `utils` here is the absolute top-level import, while XYNet
    # comes from the package-relative `.utils` — confirm both resolve to the
    # intended module.
    utils.print_parameters(net)
    # Wrapped because the rest of this function reaches the model via `.module`.
    net = nn.DataParallel(net)
    net = net.to(device)
    net.train()

    # Network weights and gates are optimised with separate learning rates.
    funcs_opt = optim.Adam(net.module.get_function_parameters(), lr=1e-4)
    gates_opt = optim.Adam(net.module.get_gates_parameters(), lr=1e-3)

    loss = []
    start = timer()
    for epoch in range(n_epoch):
        loss.append(train_step(net, funcs_opt, gates_opt, x, y, L).item())
        if (epoch + 1) % 100 == 0:
            end = timer()
            print(f'epoch: {epoch + 1}    '
                  f'loss: {loss[-1]:.4f}    '
                  f'lam: {lam_x}, {lam_y}, {lam_cov}   '
                  f'time: {end-start:.2f}')
            # Report the time per 100-epoch window, not cumulative time.
            start = end
    elapsed = time.time() - tm
    print('total time =',elapsed,'s')

    g_x, g_y = net.module.get_gates()
    if not return_comps:
        return _to_numpy(g_x), _to_numpy(g_y)

    # Only compute the components when the caller actually asked for them.
    x_com, y_com, Wy_com = net.module.get_components(x, L)
    return (_to_numpy(g_x), _to_numpy(g_y),
            _to_numpy(x_com), _to_numpy(y_com), _to_numpy(Wy_com))
    

