import sys, os
import torch
import torch.nn as nn
import argparse

# Seed torch's random number generator so repeated runs start from the same
# random state.
torch.manual_seed(0)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Project directory, read from the environment; a KeyError is raised at
# import time if the variable is not set.
MULTITASK_COMPRESSION_DIR = os.environ['MULTITASK_COMPRESSION_DIR'] 
# Extend the module search path so the project-local modules imported below
# (from the project directory and its utils/ and multitask/ subdirectories)
# can be resolved.
sys.path.append(MULTITASK_COMPRESSION_DIR)
sys.path.append(MULTITASK_COMPRESSION_DIR + '/utils/')
sys.path.append(MULTITASK_COMPRESSION_DIR + '/multitask/')

# Directory under which per-model datasets and results are stored.
SCRATCH_DIR = MULTITASK_COMPRESSION_DIR + '/scratch/'

from utils import *

from textfile_utils import *
from plotting_utils import *

from collections import OrderedDict

from numpy.linalg import eig


def _task_loss(K, K_x, x_hat, loss_fn):
    """Return ``loss_fn(K @ x_hat, K_x)`` for a batch of reconstructions."""
    # Append a singleton axis so every sample is a column vector for matmul.
    x_hat_3d = x_hat.reshape(x_hat.shape[0], x_hat.shape[1], 1)
    return loss_fn(torch.matmul(K, x_hat_3d), K_x)


def train(weight, encoder, decoder_1, decoder_2, data_dict, train_options):
    """Alternately train a shared encoder and two task-specific decoders.

    Every epoch consists of two phases of 100 full-batch steps each:

    1. Encoder phase: only the encoder is updated, on the combined objective
       ``weight * loss_1 - (1 - weight) * loss_2``.
    2. Decoder phase: each decoder is updated on its own task loss while the
       encoder is held fixed.

    ``loss_i`` is the MSE between ``K_i @ x_hat_i`` and ``K_i @ x``, where
    ``x_hat_i = decoder_i(encoder(x))``.

    NOTE(review): the task-2 term enters the encoder objective with a minus
    sign, so the encoder step increases loss_2 while decoder_2 decreases it.
    This looks like a deliberate adversarial objective -- confirm.

    Args:
        weight: Scalar applied to the task-1 loss; the task-2 loss gets
            ``1 - weight``.
        encoder: Module producing the latent code from the data.
        decoder_1: Module reconstructing the data for task 1.
        decoder_2: Module reconstructing the data for task 2.
        data_dict: Mapping providing 'train_dataset', 'K1' and 'K2'. The
            dataset is transposed before use, so it is presumably stored as
            (data_dim, num_samples) -- TODO confirm against the generator.
        train_options: Mapping providing "learning_rate", "num_epochs" and
            "output_freq".

    Returns:
        A tuple ``(train_losses, train_losses_1, train_losses_2)`` of lists
        holding one float per epoch: the combined loss after the encoder
        phase, and the task-1 / task-2 losses after the decoder phase.
    """
    inner_steps = 100  # optimisation steps per phase within one epoch

    # The data is placed on `device` below, so the modules have to live there
    # as well. Module.to() works in place and is a no-op when the module is
    # already on that device; it is done before the optimizers are created.
    for module in (encoder, decoder_1, decoder_2):
        module.to(device)

    x = torch.tensor(data_dict['train_dataset'], dtype=torch.float32).to(device)
    # Put samples on the leading axis, then add a trailing singleton axis so
    # K @ x broadcasts over the samples.
    x = x.transpose(0, 1)
    x_3d = x.reshape(x.shape[0], x.shape[1], 1)

    learning_rate = train_options["learning_rate"]
    encoder_optimizer = torch.optim.Adam(
        encoder.parameters(), lr=learning_rate, amsgrad=True)
    decoder_1_optimizer = torch.optim.Adam(
        decoder_1.parameters(), lr=learning_rate, amsgrad=True)
    decoder_2_optimizer = torch.optim.Adam(
        decoder_2.parameters(), lr=learning_rate, amsgrad=True)

    loss_fn = torch.nn.MSELoss()

    K1 = torch.tensor(data_dict['K1'], dtype=torch.float32).to(device)
    K2 = torch.tensor(data_dict['K2'], dtype=torch.float32).to(device)

    w1 = weight
    w2 = 1 - weight

    # Targets are constant throughout training, so compute them once.
    K1_x = torch.matmul(K1, x_3d)
    K2_x = torch.matmul(K2, x_3d)

    train_losses = []
    train_losses_1 = []
    train_losses_2 = []

    for i in range(train_options["num_epochs"]):

        # Encoder phase: both decoders take part in the forward pass, but
        # only the encoder's optimizer steps.
        for _ in range(inner_steps):
            phi = encoder(x)
            train_loss_1 = _task_loss(K1, K1_x, decoder_1(phi), loss_fn)
            train_loss_2 = _task_loss(K2, K2_x, decoder_2(phi), loss_fn)
            train_loss = w1 * train_loss_1 - w2 * train_loss_2

            encoder_optimizer.zero_grad()
            train_loss.backward()
            encoder_optimizer.step()

        train_losses.append(train_loss.item())

        # Decoder phase: the encoder is not updated here, so its forward pass
        # runs without autograd tracking. The decoder gradients are the same
        # as when back-propagating through the encoder, without the wasted
        # work. The encoder is still evaluated once per decoder so the number
        # of forward passes matches the original schedule.
        for _ in range(inner_steps):
            with torch.no_grad():
                phi = encoder(x)
            train_loss_1 = _task_loss(K1, K1_x, decoder_1(phi), loss_fn)

            decoder_1_optimizer.zero_grad()
            train_loss_1.backward()
            decoder_1_optimizer.step()

            with torch.no_grad():
                phi = encoder(x)
            train_loss_2 = _task_loss(K2, K2_x, decoder_2(phi), loss_fn)

            decoder_2_optimizer.zero_grad()
            train_loss_2.backward()
            decoder_2_optimizer.step()

        train_losses_1.append(train_loss_1.item())
        train_losses_2.append(train_loss_2.item())

        if (i + 1) % train_options["output_freq"] == 0:
            print("Epoch: {} train_loss: {}, train_loss_1: {}, train_loss_2: {}\n"
                  .format(i+1, train_losses[-1], train_losses_1[-1], train_losses_2[-1]))

    return train_losses, train_losses_1, train_losses_2



def compute_and_save_result(
    weight, test_data_dir, encoder_name, decoder_name, data_dict, train_options):
    """Train one model per latent dimension and pickle each loss history.

    For every latent dimension from 1 up to ``data_dict['task_dim']``
    (inclusive), a new encoder and two new decoders are built, trained with
    ``train`` using the given ``weight``, and the per-epoch losses are
    written to ``<test_data_dir>weight_<weight>/z_<z_dim>.pkl``.

    Args:
        weight: Loss weight forwarded to ``train``; also names the output
            sub-directory.
        test_data_dir: Directory prefix the sub-directory name is appended to.
        encoder_name: Identifier passed to ``init_encoder``.
        decoder_name: Identifier passed to ``init_decoder``.
        data_dict: Mapping providing 'data_dim' and 'task_dim' plus whatever
            ``train`` reads from it.
        train_options: Options forwarded unchanged to ``train``.
    """
    out_dir = test_data_dir + 'weight_' + str(weight)
    # Hand the output directory to the project helper before writing results.
    remove_and_create_dir(out_dir)

    input_dim = data_dict['data_dim']
    max_latent_dim = data_dict['task_dim']

    for z_dim in range(1, max_latent_dim + 1):
        print('################')
        print('latent_dim: ', z_dim)

        # Both decoders share the same configuration but are separate models.
        encoder_config = {"input_dim": input_dim, "z_dim": z_dim}
        encoder = init_encoder(encoder_name, encoder_config)
        decoder_1 = init_decoder(
            decoder_name, {"z_dim": z_dim, "output_dim": input_dim})
        decoder_2 = init_decoder(
            decoder_name, {"z_dim": z_dim, "output_dim": input_dim})

        losses = train(
            weight, encoder, decoder_1, decoder_2, data_dict, train_options)
        combined_history, task_1_history, task_2_history = losses

        # Keep the loss histories on disk so they can be plotted later.
        result_dict = OrderedDict([
            ('z_dim', z_dim),
            ('train_losses', combined_history),
            ('train_losses_1', task_1_history),
            ('train_losses_2', task_2_history),
        ])

        pkl_path = out_dir + '/z_{}.pkl'.format(z_dim)
        write_pkl(fname=pkl_path, input_dict=result_dict)


if __name__ == "__main__":
    # Script entry point: load the dataset of the requested model and run
    # compute_and_save_result once per requested loss weight.

    parser = argparse.ArgumentParser()
    # --model_name and --weights are dereferenced unconditionally below, so
    # let argparse reject a missing value with a usage message rather than
    # failing later with a TypeError / AttributeError on None.
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--encoder_name', type=str)
    parser.add_argument('--decoder_name', type=str)
    parser.add_argument('--weights', type=str, required=True)
    parser.add_argument('--num_epochs', type=int, default=500)
    args = parser.parse_args()

    model_name = args.model_name
    encoder_name = args.encoder_name
    decoder_name = args.decoder_name
    # --weights is a comma-separated list of floats; blank entries (e.g. from
    # a trailing comma) are skipped instead of crashing float('').
    weights = [float(w) for w in args.weights.split(',') if w.strip()]
    if not weights:
        parser.error('--weights must contain at least one number')
    num_epochs = args.num_epochs

    BASE_DIR = SCRATCH_DIR + model_name
    DATA_DIR = BASE_DIR + '/dataset/'
    data_dict = load_pkl(DATA_DIR + '/data.pkl')

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if data_dict['task_dim'] > data_dict['data_dim']:
        raise ValueError(
            "task_dim ({}) must not exceed data_dim ({})".format(
                data_dict['task_dim'], data_dict['data_dim']))

    TEST_DATA_DIR = BASE_DIR + '/joint_compression/'
    remove_and_create_dir(TEST_DATA_DIR)

    train_options = {
        "num_epochs": num_epochs,
        "learning_rate": 1e-3,
        "output_freq": 10,
        "save_model": True,
    }

    for weight in weights:
        compute_and_save_result(weight, TEST_DATA_DIR,
                                encoder_name, decoder_name, data_dict, train_options)
                
