
import argparse
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import Dataset
import torch.nn.functional as F
import random

# Shared mean-squared-error criterion: averages the squared error over every element.
mse_loss = nn.MSELoss(reduction='mean')

def find_pattern(train_data):
    """Return a column ordering that chains strongly covarying dimensions together.

    The absolute covariance between every pair of columns of ``train_data``
    (variables are columns, ``rowvar=False``) is used as a link strength.
    The chain starts with the most strongly linked pair and is then extended
    greedily: from the most recently added column, the strongest link to a
    column not yet in the chain is followed. The diagonal and every link that
    has been examined are overwritten with a negative sentinel so they are never
    picked again.

    Prints ``"successfully reordered."`` when the result has no repeated
    indices, ``"ERROR."`` otherwise.

    Args:
        train_data: 2-D array with samples along rows and dimensions along
            columns; at least two columns are required.

    Returns:
        1-D integer ``np.ndarray`` holding a permutation of the column indices.
    """
    masked = -10000  # below any |cov| value, so a masked entry never beats a live one
    link = np.abs(np.cov(train_data, rowvar=False))
    n_dims = link.shape[0]
    np.fill_diagonal(link, masked)  # a column never links to itself

    # Seed the chain with the single strongest off-diagonal pair.
    head, tail = np.unravel_index(np.argmax(link), link.shape)
    chain = [head, tail]
    link[head, tail] = link[tail, head] = masked

    while len(chain) != n_dims:
        last = chain[-1]
        best = np.argmax(link[last])
        if best not in chain:
            chain.append(best)
        # Retire this link either way; if `best` was already chained, the next
        # pass falls through to the next-strongest link from `last`.
        link[last, best] = link[best, last] = masked

    reorder_idx = np.array(chain)
    if len(set(reorder_idx)) != len(reorder_idx):
        print("ERROR.")
    else:
        print("successfully reordered.")
    return reorder_idx




# reorder dataset
class muz_dataset(Dataset):
    """Dataset of per-sample latent-mean vectors, each served as a (D, 1) sequence.

    Args:
        muz_data: 2-D array-like with one row per sample (in this file it is an
            array loaded from ``.npy``). It is copied into a tensor with a trailing
            singleton feature axis.
        reorder_idx: Integer index array applied to each sample's sequence axis
            when ``whether_reorder`` is true.
        whether_reorder: If true, ``__getitem__`` returns the sample permuted by
            ``reorder_idx``; otherwise it returns the sample unchanged.
    """

    def __init__(self, muz_data, reorder_idx, whether_reorder):
        super().__init__()
        # Add a size-1 feature axis so each sample is a sequence of scalars.
        self.muz_set = torch.tensor(muz_data)[..., None]
        self.reorder_idx = reorder_idx
        self.whether_reorder = whether_reorder

    def __getitem__(self, index):
        sample = self.muz_set[index]
        return sample[self.reorder_idx] if self.whether_reorder else sample

    def __len__(self):
        return self.muz_set.shape[0]



class MixtureDensityRNN(nn.Module):
    """Autoregressive RNN that predicts per-step mixture parameters.

    An LSTM or GRU (``batch_first=True``) encodes the input sequence. At every
    step an MLP head maps the RNN output to ``num_mixtures`` triples of raw
    (mean, variance, weight) parameters. With ``add_time``, the head input is
    first concatenated with a step-index feature.

    Args:
        input_size: Feature size of each input step.
        hidden_size: RNN hidden size.
        num_mixtures: Number of mixture components predicted per step.
        time_range: Maximum sequence length supported when ``add_time`` is True.
            Step indices ``0..time_range-1`` form the time feature.
        add_time: Concatenate the step-index feature (``hidden_size`` channels)
            to the RNN output before the head.
        decode_hn: Also build a linear decoder from the final hidden state back
            to ``input_size``. It is used for the reconstruction term of
            :meth:`mse_loss`.
        backbone: ``'LSTM'`` or ``'GRU'``.
        device: Device on which the step-index buffer is created. Because it is a
            registered buffer, it then follows the module on ``.to()``,
            ``.cuda()`` and ``.cpu()``.

    Raises:
        ValueError: If ``backbone`` is neither ``'LSTM'`` nor ``'GRU'``.
    """

    def __init__(self, input_size=1, hidden_size=64, num_mixtures=1, time_range=199, add_time=True, decode_hn=False, backbone='LSTM', device='cuda:0'):
        super(MixtureDensityRNN, self).__init__()
        if backbone not in ('GRU', 'LSTM'):
            raise ValueError(f"backbone must be 'LSTM' or 'GRU', got {backbone!r}")
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_mixtures = num_mixtures
        self.backbone = backbone
        if backbone == 'GRU':
            self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)
        else:
            self.rnn = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)
        # The time feature has hidden_size channels, so the head input only
        # doubles in width when that feature is actually concatenated.
        head_in = hidden_size * 2 if add_time else hidden_size
        self.fc1 = nn.Linear(head_in, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, num_mixtures * 3)  # 3 parameters per mixture: mean, variance, weight

        if decode_hn:
            self.h_decoder = nn.Linear(hidden_size, input_size)

        # (1, time_range, hidden_size) tensor whose [0, t, :] entries all equal t.
        # It is registered as a non-persistent buffer: it moves with the module but
        # stays out of state_dict, so checkpoint keys are unchanged.
        steps = torch.arange(time_range, dtype=torch.float32, device=device)
        time = steps.view(1, time_range, 1).repeat(1, 1, hidden_size)
        self.register_buffer('time', time, persistent=False)

        self.add_time = add_time
        self.decode_hn = decode_hn

    def forward(self, x, h0=None):
        """Run the RNN and the MLP head over ``x``.

        Args:
            x: Input of shape (batch, seq, input_size). When ``add_time`` is set,
                ``seq`` must not exceed ``time_range``.
            h0: Optional initial RNN state. For the LSTM backbone this is the
                ``(h0, c0)`` tuple.

        Returns:
            ``(output, hn)``. ``output`` has shape (batch, seq, num_mixtures, 3)
            and holds the raw (mean, variance, weight) parameters. ``hn`` is the
            final hidden state; the LSTM cell state is discarded.
        """
        if self.backbone == 'GRU':
            rnn_output, hn = self.rnn(x, h0)
        else:
            rnn_output, (hn, _) = self.rnn(x, h0)

        if self.add_time:
            # Take the first seq_len step indices so shorter sequences also work.
            t_rep = self.time[:, :x.shape[1]].expand(x.shape[0], -1, -1)
            fc_input = torch.cat([rnn_output, t_rep], dim=-1)
        else:
            fc_input = rnn_output

        fc_x = F.relu(self.fc1(fc_input))
        fc_x = F.relu(self.fc2(fc_x))
        output = self.fc3(fc_x)
        # Split the last axis into per-mixture (mean, variance, weight) triples.
        output = output.view(x.shape[0], x.shape[1], self.num_mixtures, 3)
        return output, hn

    @staticmethod
    def _mixture_loss(output, y):
        """Return the mixture-weighted squared error of raw head ``output`` against ``y``."""
        mixture_means = output[:, :, :, 0]
        # Softmax over the mixture axis gives non-negative weights that sum to 1.
        mixture_weights = F.softmax(output[:, :, :, 2], dim=-1)
        # Assumes y is (batch, seq, 1), so it broadcasts over the mixture axis.
        return torch.sum(mixture_weights * (y - mixture_means) ** 2, dim=-1).mean()

    def nll_loss(self, x, y):
        """Return the mixture-weighted squared error of the prediction for ``x`` against ``y``.

        Despite its name, this is not a Gaussian negative log-likelihood. The
        predicted variances are ignored. Each component's squared error is
        weighted by the softmax mixture weights, summed over components, and
        averaged over batch and time.
        """
        output, _ = self.forward(x)
        return self._mixture_loss(output, y)

    def mse_loss(self, x, y):
        """Return ``(pred_loss, recon_loss)`` for inputs ``x`` and targets ``y``.

        ``pred_loss`` is the MSE of the predicted mean when
        ``num_mixtures == 1``, and the mixture-weighted squared error
        otherwise. ``recon_loss`` is the MSE between ``x`` and the
        reconstruction decoded from the final hidden state when
        ``decode_hn`` is set, and ``0.`` otherwise.
        """
        # A single forward pass supplies both the predictions and hn, so the
        # reconstruction branch also works when num_mixtures > 1.
        output, hn = self.forward(x)
        if self.num_mixtures == 1:
            pred_loss = F.mse_loss(output[:, :, :, 0], y)
        else:
            pred_loss = self._mixture_loss(output, y)

        if self.decode_hn:
            # hn: (1, batch, hidden) -> (batch, seq, hidden), one copy per step.
            hn_rep = hn.permute(1, 0, 2).repeat(1, x.shape[1], 1)
            recon_loss = F.mse_loss(self.h_decoder(hn_rep), x)
        else:
            recon_loss = 0.
        return pred_loss, recon_loss

def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    With ``argparse`` and ``type=bool``, every non-empty string (including
    ``"False"``) becomes True, so the usual spellings are matched explicitly.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {value!r}')


def _load_muz_splits(args, code):
    """Load the (train, id, ood) latent-mean ``.npy`` arrays for the configuration in ``args``."""
    if args.id == 'CIFAR_HVAE':
        prefix = f'../ConvVAE/vae_logs/hvae_dc/{code}_Zdim_{args.z_dim}_beta_1E-1_HVAE_DC_16-32-728_'
    else:
        prefix = (f'vae_logs/results_id_{args.id}/'
                  f'Zdim_{args.z_dim}_beta_{args.beta}_{args.model}_{args.add_noise}_klqq_{args.kl_q_q}_')
    return tuple(np.load(prefix + f'{split}_muz.npy') for split in ('train', 'id', 'ood'))


def _split_batch(batch, device):
    """Move a batch to ``device`` and split it into (inputs, next-step targets) along axis 1."""
    batch = batch.to(device)
    return batch[:, :-1, :], batch[:, 1:, :]


def _plot_prediction(model, inputs, targets, title):
    """Plot the ground truth against the predicted mean of component 0 for the first sample of a batch."""
    with torch.no_grad():
        pred, _ = model.forward(inputs)
    plt.plot(targets[0].squeeze(-1).cpu().numpy(), label='gt')
    plt.plot(pred[0, :, :, 0].squeeze().cpu().numpy(), label='pred')
    plt.legend()
    plt.title(title)
    plt.show()
    plt.clf()


def _evaluate(model, loader, device):
    """Return (mean per-batch prediction loss, last inputs, last targets) over ``loader``."""
    model.eval()
    total = 0.
    inputs = targets = None
    with torch.no_grad():
        for batch in loader:
            inputs, targets = _split_batch(batch, device)
            pred_loss, _ = model.mse_loss(inputs, targets)
            total += pred_loss.item()
    return total / len(loader), inputs, targets


def main():
    """Fit the next-step RNN on reordered latent means and report train, ID-test and OOD-test losses."""
    import os  # used only to create the checkpoint directory

    # ================================================================================= #
    # ----------------------------- hyper parameters ---------------------------------- #
    whether_reorder = True
    add_time = True
    decode_hn = False
    backbone = 'LSTM'  # 'LSTM' 'GRU'
    # NOTE: --batch_size / --epochs are parsed below, but these hard-coded values are used.
    batch_size = 100
    n_epoch = 100
    num_mixtures = 1
    device = 'cuda:0'
    lr = 1e-4

    print(f'============== Hyper parameters ==============\n '
          f'decode_hn: {decode_hn}\n'
          f'whether_reorder: {whether_reorder}, add_time: {add_time},\n'
          f'batch size: {batch_size}, n_epoch: {n_epoch},\n'
          f'num_mixtures: {num_mixtures}, lr: {lr},\n'
          f'backbone: {backbone}')

    # ================================================================================= #
    # ------------------------- Arguments for loading the npy ------------------------- #
    parser = argparse.ArgumentParser(description='VAE MNIST Example')
    parser.add_argument('--analyze_mode', type=_str2bool, default=False, help='whether analyze the model.')
    parser.add_argument('--add_noise', type=_str2bool, default=False, help='')
    parser.add_argument('--entropy', type=_str2bool, default=False, help='verify the entropy')
    parser.add_argument('--kl_q_q', type=_str2bool, default=True, help='verify the q_id(z)')
    parser.add_argument('--neigh', type=_str2bool, default=False, help='verify the q_id(z)')
    parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--id', type=str, default='FMNIST',
                        help='FMNIST, MNIST, CIFAR')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--z_dim', type=int, default=200, metavar='N',
                        help='dimension of the latent variable')
    parser.add_argument('--beta', type=float, default=1.0,
                        help='beta * KL')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--model', type=str, default='discrete_logistic', metavar='N',
                        help='which model to use: bce_vae, mse_vae,  gaussian_vae, or sigma_vae or optimal_sigma_vae, discrete_logistic')
    parser.add_argument('--log_dir', type=str, default='./results', metavar='N')
    args = parser.parse_args()
    if args.no_cuda or not torch.cuda.is_available():
        device = 'cpu'

    train_data, id_test_data, ood_test_data = _load_muz_splits(args, code='316201')

    # ================================================================================= #
    # The column permutation is computed on the training split only and reused for both test splits.
    reorder_idx = find_pattern(train_data)

    train_set = muz_dataset(train_data, reorder_idx, whether_reorder)
    id_test_set = muz_dataset(id_test_data, reorder_idx, whether_reorder)
    ood_test_set = muz_dataset(ood_test_data, reorder_idx, whether_reorder)

    train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=32)
    id_test_loader = data.DataLoader(id_test_set, batch_size=batch_size, shuffle=True, num_workers=32)
    ood_test_loader = data.DataLoader(ood_test_set, batch_size=batch_size, shuffle=True, num_workers=32)

    # The model sees z_dim - 1 input steps (each sample minus its last element).
    qz_fitter = MixtureDensityRNN(1, 64, num_mixtures, time_range=args.z_dim - 1, add_time=add_time,
                                  decode_hn=decode_hn, backbone=backbone, device=device).to(device)

    optimizer = optim.Adam(qz_fitter.parameters(), lr=lr)
    for epoch in range(n_epoch):
        qz_fitter.train()
        train_loss = 0.
        inputs = targets = None
        for batch in train_loader:
            inputs, targets = _split_batch(batch, device)
            optimizer.zero_grad()
            pred_loss, recon_loss = qz_fitter.mse_loss(inputs, targets)
            loss = pred_loss + recon_loss
            loss.backward()
            optimizer.step()
            # Only the prediction term is reported, matching the test-set numbers.
            train_loss += pred_loss.item()

        train_loss /= len(train_loader)
        print(f'====> ID {args.id} Epoch: {epoch} Average loss: {train_loss:.4f}')

        if epoch % 20 != 0:
            continue
        make_plots = epoch % 100 == 0
        if make_plots:
            # Plots the last training batch of this epoch.
            _plot_prediction(qz_fitter, inputs, targets,
                             f'epoch{epoch} lr {lr} reorder_{whether_reorder} add_time_{add_time} dec_hn_{decode_hn} backbone {backbone} results')

        id_test_loss, inputs, targets = _evaluate(qz_fitter, id_test_loader, device)
        if make_plots:
            _plot_prediction(qz_fitter, inputs, targets,
                             f'id test epoch{epoch} lr {lr} reorder_{whether_reorder} add_time_{add_time} results')
        print(f'====> ID test Epoch: {epoch} Average loss: {id_test_loss:.4f}')

        ood_test_loss, inputs, targets = _evaluate(qz_fitter, ood_test_loader, device)
        if make_plots:
            _plot_prediction(qz_fitter, inputs, targets,
                             f'ood test epoch{epoch} lr {lr} reorder_{whether_reorder} add_time_{add_time} results')
        print(f'====> OOD Epoch: {epoch} Average loss: {ood_test_loss:.4f}\n\n')

    # Name the checkpoint with the parsed kl_q_q flag (not a hard-coded True) so it
    # matches the input files, and make sure the target directory exists.
    save_dir = f'./qz_fitting_results/{args.id}'
    os.makedirs(save_dir, exist_ok=True)
    torch.save(qz_fitter.state_dict(),
               f'{save_dir}/Zdim_{args.z_dim}_beta_{args.beta}_{args.model}_{args.add_noise}_klqq_{args.kl_q_q}_train_muz_lstm.pth')


if __name__ == '__main__':
    main()


