import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import scipy.io as sio
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
import platform

from argparse import ArgumentParser

parser = ArgumentParser(description='ISTA-Net-plus')

# Training schedule, model size and device options: (flag, type, default, help).
# Registration order is kept so that the --help listing is unchanged.
for _flag, _type, _default, _help in (
    ('--start_epoch', int, 0, 'epoch number of start training'),
    ('--end_epoch', int, 200, 'epoch number of end training'),
    ('--layer_num', int, 9, 'phase number of ISTA-Net-plus'),
    ('--learning_rate', float, 1e-4, 'learning rate'),
    ('--group_num', int, 1, 'group number for training'),
    ('--cs_ratio', int, 4, 'from {1, 4, 10, 25, 40, 50}'),
    ('--gpu_list', str, '0', 'gpu index'),
):
    parser.add_argument(_flag, type=_type, default=_default, help=_help)

# Boolean switch: present -> True, absent -> False.
parser.add_argument('--sam', action='store_true', default=False)

# Directory options, all plain strings: (flag, default, help).
for _flag, _default, _help in (
    ('--matrix_dir', 'sampling_matrix', 'sampling matrix directory'),
    ('--model_dir', 'model', 'trained or pre-trained model directory'),
    ('--data_dir', 'data', 'training data directory'),
    ('--log_dir', 'log', 'log directory'),
):
    parser.add_argument(_flag, type=str, default=_default, help=_help)

# Do not leave the loop temporaries behind in the module namespace.
del _flag, _type, _default, _help

# Disable TF32 for CUDA matmul so float32 matrix products keep full precision.
# (PyTorch enabled this flag by default in 1.7-1.11 and disables it by default
# from 1.12 onward; setting it explicitly makes the behaviour version-independent.)
torch.backends.cuda.matmul.allow_tf32 = False
# Disable TF32 inside cuDNN (convolutions) as well. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = False


class GradientModule(nn.Module):
    """Scaled data-fidelity gradient with a learnable step size.

    Holds a single parameter ``lambda_step`` (initialised to 0.5) and returns
    ``lambda_step * (x @ PhiTPhi - PhiTb)``.
    """

    def __init__(self):
        super().__init__()
        self.lambda_step = nn.Parameter(torch.tensor([0.5]))

    def forward(self, x, PhiTPhi, PhiTb):
        """Return the step-scaled residual ``x @ PhiTPhi - PhiTb``."""
        residual = x.mm(PhiTPhi) - PhiTb
        return self.lambda_step * residual


# class ConvBlock(nn.Module):
#     def __init__(self, forward=True):
#         super().__init__()
#         kernel_size = (3, 3)
#         if forward:
#             channel = [(32, 1), (32, 32)]
#         else:
#             channel = [(32, 32), (1, 32)]
#         # self.conv_1 = nn.Conv2d(
#         #     in_channels=channel[0][1], out_channels=channel[0][0],
#         #     kernel_size=kernel_size, bias=False, padding='same')
#         # self.relu = nn.modules.ReLU()
#         # self.conv_2 = nn.Conv2d(
#         #     in_channels=channel[1][1], out_channels=channel[1][0],
#         #     kernel_size=kernel_size, bias=False, padding='same')

#         self.conv_1_weight = nn.Parameter(self.get_kernel(channel[0] + kernel_size))
#         self.conv_2_weight = nn.Parameter(self.get_kernel(channel[1] + kernel_size))

#     def get_kernel(self, size):
#         return init.xavier_normal_(torch.Tensor(*size))

#     def forward(self, x):
#         x = F.conv2d(x, self.conv_1_weight, padding=1)
#         x = F.relu(x)
#         x = F.conv2d(x, self.conv_2_weight, padding=1)
#         return x


# class Denoiser(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.theta = nn.Parameter(torch.tensor([0.01]))
#         self.patch2image = lambda x: x.view(-1, 1, 33, 33)
#         self.image2patch = lambda x: x.view(-1, 1089)
#         self.fc = ConvBlock(forward=True)
#         self.bc = ConvBlock(forward=False)

#     def st(self, x):
#         return torch.sign(x) * torch.relu(x.abs() - self.theta)

#     def forward(self, x):
#         x_pt = self.patch2image(x)
#         x_fc = self.fc(x_pt)
#         x_st = self.st(x_fc)
#         bs = x.shape[0]
#         x_ct = torch.concat([x_st, x_fc], dim=0)
#         x_bc = self.bc(x_ct)
#         x = self.image2patch(x_bc[:bs, ...])
#         symloss = x_ct[bs:] - x_pt
#         return x, symloss

class Denoiser(torch.nn.Module):
    """Convolutional soft-thresholding stage operating on 33x33 blocks.

    A two-convolution forward transform is followed by soft-thresholding with
    the learnable ``soft_thr`` and a two-convolution backward transform. The
    backward transform is additionally applied to the un-thresholded features
    so the caller can penalise how far backward(forward(x)) is from x.
    """

    def __init__(self):
        super().__init__()
        self.soft_thr = nn.Parameter(torch.Tensor([0.01]))

        def xavier_kernel(out_channels, in_channels):
            # 3x3 kernel with Xavier-normal initialisation.
            weight = torch.empty(out_channels, in_channels, 3, 3)
            return nn.Parameter(init.xavier_normal_(weight))

        # Creation order is kept: it fixes both the parameter order and the
        # order in which random numbers are drawn.
        self.conv1_forward = xavier_kernel(32, 1)
        self.conv2_forward = xavier_kernel(32, 32)
        self.conv1_backward = xavier_kernel(32, 32)
        self.conv2_backward = xavier_kernel(1, 32)

    def _backward_transform(self, features):
        """Map 32-channel features back to a single-channel image."""
        hidden = F.relu(F.conv2d(features, self.conv1_backward, padding=1))
        return F.conv2d(hidden, self.conv2_backward, padding=1)

    def forward(self, x):
        """Denoise a batch of flattened blocks.

        Args:
            x: tensor viewable as ``(-1, 1, 33, 33)``.

        Returns:
            ``[x_pred, symloss]`` where ``x_pred`` has shape ``(-1, 1089)`` and
            ``symloss`` is ``backward(forward(x)) - x`` with shape
            ``(-1, 1, 33, 33)``.
        """
        x_input = x.view(-1, 1, 33, 33)

        # Forward transform.
        hidden = F.relu(F.conv2d(x_input, self.conv1_forward, padding=1))
        x_forward = F.conv2d(hidden, self.conv2_forward, padding=1)

        # Soft-thresholding: sign(z) * max(|z| - thr, 0).
        shrunk = torch.sign(x_forward) * F.relu(x_forward.abs() - self.soft_thr)

        x_pred = self._backward_transform(shrunk).view(-1, 1089)

        # Residual of the backward transform applied without thresholding.
        symloss = self._backward_transform(x_forward) - x_input

        return [x_pred, symloss]


class SharpModule(nn.Module):
    """Learnable perturbation computed from a sub-gradient.

    For a sub-gradient ``g`` with one sample per row, the module returns
    ``rho * ((1 - gamma) + gamma / ||g||) * g``, where ``||g||`` is the
    per-row L2 norm (treated as a constant for autograd). Rows for which
    ``gamma / ||g||`` is not finite (e.g. an all-zero row) drop the
    normalised term.

    Args:
        rho: initial value of the learnable scale ``rho``.
    """

    def __init__(self, rho=1e-2):
        super().__init__()
        # float() keeps the parameter floating-point even if an int is passed.
        self.rho = nn.Parameter(torch.tensor([float(rho)]))
        self.gamma = nn.Parameter(torch.tensor([0.5]))

    def forward(self, sub_grad):
        """Return the perturbation for ``sub_grad``, or ``0.`` if it is None."""
        if sub_grad is None:
            return 0.
        norm = torch.sqrt(torch.sum(torch.square(sub_grad), dim=1, keepdim=True)).detach()
        alpha = 1. - self.gamma
        # Rows where gamma / norm is inf or NaN must not contribute.
        usable = torch.isfinite(self.gamma.detach() / norm)
        # Masking only the *result* of gamma / norm is not enough: autograd
        # still differentiates the division for masked rows and produces
        # 0 / 0 = NaN for gamma.grad. Substituting a harmless denominator
        # first keeps the forward value identical and the gradient finite.
        safe_norm = torch.where(usable, norm, torch.ones_like(norm))
        beta = torch.where(usable, self.gamma / safe_norm, torch.zeros_like(norm))
        epsilon = self.rho * (alpha + beta) * sub_grad
        return epsilon


class SubGradientModule(nn.Module):
    """Return the difference ``u - v`` as the sub-gradient.

    The ``GradientModule`` passed to the constructor is stored (and therefore
    registered as a sub-module) but is not used by ``forward``; likewise
    ``PhiTPhi`` and ``PhiTb`` are accepted and ignored.
    """

    def __init__(self, gm: GradientModule):
        super().__init__()
        self.gradient_module = gm

    def forward(self, u, v, PhiTPhi, PhiTb):
        """Return ``u - v``."""
        return u - v


# Define ISTA-Net
class ISTANet(torch.nn.Module):
    """Unrolled network of ``LayerNo`` stages.

    Every stage subtracts the output of its own ``GradientModule`` from the
    current estimate and passes the result through its own ``Denoiser``.
    """

    def __init__(self, LayerNo):
        super().__init__()
        self.LayerNo = LayerNo

        self.gradient_modules = nn.ModuleList()
        self.denoisers = nn.ModuleList()
        for _ in range(LayerNo):
            self.gradient_modules.append(GradientModule())
        for _ in range(LayerNo):
            self.denoisers.append(Denoiser())

    def layer(self, x, PhiTPhi, PhiTb, t):
        """Run stage ``t`` and return ``(new_x, symmetry_residual)``."""
        stepped = x - self.gradient_modules[t](x, PhiTPhi, PhiTb)
        denoised, layer_sym = self.denoisers[t](stepped)
        return denoised, layer_sym

    def forward(self, Phix, Phi, Qinit):
        """Reconstruct from measurements.

        Args:
            Phix: measurements, multiplied on the right by ``Phi`` and ``Qinit^T``.
            Phi: sampling matrix.
            Qinit: matrix whose transpose maps ``Phix`` to the initial estimate.

        Returns:
            ``[x_final, layers_sym]``: the final estimate and the list of
            per-stage symmetry residuals (used for the constraint loss).
        """
        PhiTb = torch.mm(Phix, Phi)
        PhiTPhi = torch.mm(Phi.t(), Phi)

        # Linear initialisation from the measurements.
        estimate = torch.mm(Phix, Qinit.t())

        layers_sym = []
        for stage in range(self.LayerNo):
            estimate, stage_sym = self.layer(estimate, PhiTPhi, PhiTb, stage)
            layers_sym.append(stage_sym)

        return [estimate, layers_sym]


# Define ISTA-SAM-Net (ISTA-Net with a SharpModule perturbation in every layer)
class ISTASamNet(torch.nn.Module):
    """Unrolled network whose stages are perturbed by a ``SharpModule``.

    Each stage shifts the current estimate by a perturbation derived from the
    previous stage's sub-gradient, applies the gradient step and the denoiser
    at the shifted point, and then removes the shift again.
    """

    def __init__(self, LayerNo):
        super().__init__()
        self.LayerNo = LayerNo

        self.gradient_modules = nn.ModuleList()
        for _ in range(LayerNo):
            self.gradient_modules.append(GradientModule())

        self.denoisers = nn.ModuleList()
        for _ in range(LayerNo):
            self.denoisers.append(Denoiser())

        # Each sub-gradient module is handed the gradient module of its stage.
        self.sub_gradient_modules = nn.ModuleList()
        for gm in self.gradient_modules:
            self.sub_gradient_modules.append(SubGradientModule(gm))

        self.sam_modules = nn.ModuleList()
        for _ in range(LayerNo):
            self.sam_modules.append(SharpModule())

    def layer(self, sg, x, PhiTPhi, PhiTb, t):
        """Run stage ``t``.

        Args:
            sg: sub-gradient from the previous stage, or None for the first.
            x: current estimate.

        Returns:
            ``(new_sg, new_x, symmetry_residual)``.
        """
        epsilon = self.sam_modules[t](sg)
        shifted = x + epsilon
        u = shifted - self.gradient_modules[t](shifted, PhiTPhi, PhiTb)
        v, layer_sym = self.denoisers[t](u)
        next_sg = self.sub_gradient_modules[t](u, v, PhiTPhi, PhiTb)
        return next_sg, v - epsilon, layer_sym

    def forward(self, Phix, Phi, Qinit):
        """Reconstruct from measurements.

        Returns:
            ``[x_final, layers_sym]``: the final estimate and the list of
            per-stage symmetry residuals (used for the constraint loss).
        """
        PhiTb = torch.mm(Phix, Phi)
        PhiTPhi = torch.mm(Phi.t(), Phi)

        # Linear initialisation from the measurements.
        estimate = torch.mm(Phix, Qinit.t())

        layers_sym = []
        sub_grad = None  # no sub-gradient is available before the first stage
        for stage in range(self.LayerNo):
            sub_grad, estimate, stage_sym = self.layer(
                sub_grad, estimate, PhiTPhi, PhiTb, stage)
            layers_sym.append(stage_sym)

        return [estimate, layers_sym]


class RandomDataset(Dataset):
    """Dataset that serves rows of a 2-D array as float tensors.

    The reported size is the ``length`` given at construction, not the number
    of rows in ``data``.
    """

    def __init__(self, data, length):
        self.len = length
        self.data = data

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        row = self.data[index, :]
        return torch.Tensor(row).float()


# %-format templates used by main().
# Checkpoint directory; fields: base dir, model name, layer count, group
# number, CS ratio, learning rate (rendered with 4 decimals).
MODEL_DIR_PATTERN = "./%s/CS_%s_layer_%d_group_%d_ratio_%d_lr_%.4f"
# Log file path; same fields as MODEL_DIR_PATTERN, with the log directory first.
LOG_DIR_PATTERN = "./%s/Log_CS_%s_layer_%d_group_%d_ratio_%d_lr_%.4f.txt"
# Progress line; fields: current epoch, last epoch, total loss, discrepancy
# loss, constraint loss.
OUTPUT_PATTERN = "[%02d/%02d] Total Loss: %.4f, Discrepancy Loss: %.4f,  Constraint Loss: %.4f"


def main(
    start_epoch=0, end_epoch=200,
    layer_num=9, learning_rate=1e-4, group_num=1,
    cs_ratio=25, gpu_id=0, sam=False,
    matrix_dir='sampling_matrix',
    model_dir='model_dir',
    data_dir='data',
    log_dir='log'
):
    """Train ISTANet (or ISTASamNet when ``sam`` is true).

    Args:
        start_epoch: last completed epoch; when > 0 the checkpoint
            ``net_params_<start_epoch>.pkl`` is loaded before training.
        end_epoch: last epoch to run (inclusive).
        layer_num: number of unrolled layers.
        learning_rate: Adam learning rate.
        group_num: only used in the checkpoint directory and log file names.
        cs_ratio: sampling ratio; selects ``phi_0_<cs_ratio>_1089.mat`` and
            ``Initialization_Matrix_<cs_ratio>.mat`` in ``matrix_dir``.
        gpu_id: CUDA device index; ignored when CUDA is not available.
        sam: train ISTASamNet instead of ISTANet.
        matrix_dir: directory holding the sampling / initialisation matrices.
        model_dir: base directory for checkpoints.
        data_dir: directory holding ``Training_Data.mat``.
        log_dir: directory for the per-epoch log file.

    Side effects:
        Writes the initialisation matrix if it does not exist yet, prints one
        line per batch, appends one line per epoch to the log file and saves
        a checkpoint every 5 epochs.
    """
    use_cuda = torch.cuda.is_available()
    device = f'cuda:{gpu_id}' if use_cuda else 'cpu'
    if use_cuda:
        # set_device only accepts CUDA devices; calling it with 'cpu' raises.
        torch.cuda.set_device(device)

    nrtrain = 88912   # number of training blocks
    batch_size = 64

    # Load CS Sampling Matrix: phi
    Phi_data_Name = './%s/phi_0_%d_1089.mat' % (matrix_dir, cs_ratio)
    Phi_input = sio.loadmat(Phi_data_Name)['phi']

    Training_data = sio.loadmat('./%s/%s' % (data_dir, 'Training_Data.mat'))
    Training_labels = Training_data['labels']

    Qinit_Name = './%s/Initialization_Matrix_%d.mat' % (matrix_dir, cs_ratio)

    # Initialization matrix: reuse the cached file, otherwise compute
    # Qinit = X Y^T (Y Y^T)^-1 with Y = Phi X and cache it for later runs.
    if os.path.exists(Qinit_Name):
        Qinit = sio.loadmat(Qinit_Name)['Qinit']
    else:
        X_data = Training_labels.transpose()
        Y_data = np.dot(Phi_input, X_data)
        Y_YT = np.dot(Y_data, Y_data.transpose())
        X_YT = np.dot(X_data, Y_data.transpose())
        Qinit = np.dot(X_YT, np.linalg.inv(Y_YT))
        del X_data, Y_data, X_YT, Y_YT
        sio.savemat(Qinit_Name, {'Qinit': Qinit})

    if sam:
        model = ISTASamNet(layer_num)
        model_name = "ISTA_SAM_NET"
    else:
        model = ISTANet(layer_num)
        model_name = "ISTA_NET"
    model = model.to(device)

    print_flag = False   # print parameter number

    if print_flag:
        for num_count, para in enumerate(model.parameters(), start=1):
            print('Layer %d' % num_count)
            print(para.size())

    # Worker processes are disabled on Windows, as in the original script.
    num_workers = 0 if platform.system() == "Windows" else 4
    rand_loader = DataLoader(
        dataset=RandomDataset(Training_labels, nrtrain),
        batch_size=batch_size, num_workers=num_workers, shuffle=True
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    model_dir = MODEL_DIR_PATTERN % (model_dir, model_name, layer_num,
                                     group_num, cs_ratio, learning_rate)
    log_file_name = LOG_DIR_PATTERN % (log_dir, model_name, layer_num,
                                       group_num, cs_ratio, learning_rate)
    os.makedirs(model_dir, exist_ok=True)
    # The log file is opened in append mode below, which does not create
    # missing parent directories.
    os.makedirs(os.path.dirname(log_file_name), exist_ok=True)

    if start_epoch > 0:
        checkpoint = os.path.join(model_dir, 'net_params_%d.pkl' % start_epoch)
        # map_location lets a checkpoint saved on one device load on another.
        model.load_state_dict(torch.load(checkpoint, map_location=device))

    Phi = torch.from_numpy(Phi_input).type(torch.FloatTensor).to(device)
    Qinit = torch.from_numpy(Qinit).type(torch.FloatTensor).to(device)

    # Weight of the symmetry-constraint term; constant, so built once.
    gamma = torch.Tensor([0.01]).to(device)

    # Training loop
    for epoch_i in range(start_epoch + 1, end_epoch + 1):
        output_data = None
        for batch_x in rand_loader:
            batch_x = batch_x.to(device)

            Phix = torch.mm(batch_x, torch.transpose(Phi, 0, 1))

            x_output, loss_layers_sym = model(Phix, Phi, Qinit)

            # Reconstruction error plus the summed per-layer symmetry residuals.
            loss_discrepancy = torch.mean(torch.pow(x_output - batch_x, 2))
            loss_constraint = torch.mean(torch.pow(loss_layers_sym[0], 2))
            for k in range(layer_num - 1):
                loss_constraint += torch.mean(torch.pow(loss_layers_sym[k + 1], 2))
            loss_all = loss_discrepancy + torch.mul(gamma, loss_constraint)

            # Zero gradients, perform a backward pass, and update the weights.
            optimizer.zero_grad()
            loss_all.backward()
            optimizer.step()

            output_data = OUTPUT_PATTERN % (epoch_i, end_epoch,
                                            loss_all.item(),
                                            loss_discrepancy.item(),
                                            loss_constraint.item())
            print(output_data)

        # Log the last batch of the epoch.
        if output_data is not None:
            with open(log_file_name, 'a') as output_file:
                output_file.write(output_data + '\n')

        if epoch_i % 5 == 0:
            torch.save(
                model.state_dict(),
                os.path.join(model_dir, 'net_params_%d.pkl' % epoch_i),
            )


if __name__ == '__main__':
    args = parser.parse_args()
    # Pass every option by keyword. The directory options were previously
    # parsed but never forwarded, so --matrix_dir / --model_dir / --data_dir /
    # --log_dir had no effect and main() silently used its own defaults.
    main(
        start_epoch=args.start_epoch,
        end_epoch=args.end_epoch,
        layer_num=args.layer_num,
        learning_rate=args.learning_rate,
        group_num=args.group_num,
        cs_ratio=args.cs_ratio,
        gpu_id=args.gpu_list,
        sam=args.sam,
        matrix_dir=args.matrix_dir,
        model_dir=args.model_dir,
        data_dir=args.data_dir,
        log_dir=args.log_dir,
    )
