#from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import opt_einsum as oe
from math import prod
import torch.nn.functional as F
import math
from torch import optim
import random
import numpy as np
import copy
import time

class Quantum_Decomp(nn.Module):
    """Tensor-network parameterisation of one or more tensors with ``2 * Q`` modes.

    For every entry of ``qtn_layers`` the module creates one weight tensor per
    layer and pre-compiles an ``opt_einsum`` expression that contracts those
    weights into a single tensor.  The contraction result has ``Q`` "output"
    modes (sizes ``per_dim_out``) followed by ``Q`` "input" modes (sizes
    ``per_dim_in``).

    Args:
        Q: number of modes on each side of the network.
        R: size of every input mode; ``in_dim`` is ``R ** Q``.
        qtn_layers: per-tensor layer specifications, indexable by
            ``0 .. len(qtn_layers) - 1`` (a dict with integer keys or a list).
            Each specification is a sequence of layers and each layer is a
            sequence of mode indices the layer acts on.  Mode indices are used
            to index lists of length ``Q``; the built-in identity layers use
            the negative indices ``-Q .. -1``.
        out_dim: total output dimension; defaults to ``in_dim``.
        std: standard deviation used when (re-)initialising trainable weights.
        tensors: unused; kept so existing callers keep working.
    """

    def __init__(self, Q, R, qtn_layers, out_dim=None, std=1, tensors=None):
        super(Quantum_Decomp, self).__init__()

        self.d = Q
        self.in_dim = R ** Q
        # Deep copy: the specification is rewritten below and must not leak
        # back into the caller's object.
        self.qtn_layers = copy.deepcopy(qtn_layers)
        self.std = std
        self.tensor_num = len(self.qtn_layers)

        # An explicit out_dim used to be dropped, leaving self.out_dim unset.
        self.out_dim = self.in_dim if out_dim is None else out_dim

        # Exact integer ceiling of the d-th root; a float root followed by
        # math.ceil can overshoot by one on exact powers.
        self.per_dim_in = [self._ceil_root(self.in_dim, self.d)] * self.d
        self.per_dim_out = [self._ceil_root(self.out_dim, self.d)] * self.d

        weights_by_tensor = {}
        self.qtn_layer2 = {}
        self.einsum_eq_eval = {}
        self.einsum_expr_eval = {}

        # One fixed identity per mode, numbered 0 .. d-1 and acting on modes
        # -d .. -1, is prepended to every network.
        identity_layers = [(i, i - self.d) for i in range(self.d)]

        for idx in range(self.tensor_num):
            # Prefix every user layer with a running index (continuing after
            # the identity layers) so that weight keys are unique.
            user_layers = [(i + self.d,) + tuple(layer)
                           for i, layer in enumerate(self.qtn_layers[idx])]
            self.qtn_layers[idx] = identity_layers + user_layers

            weights = {}
            keys = []
            for layer in self.qtn_layers[idx]:
                dims = layer[1:]
                if not dims:
                    # A layer acting on no mode contributes nothing.
                    continue
                key = ' '.join(str(item) for item in layer)
                keys.append(key)
                if len(dims) == 1:
                    # Single-mode layers are frozen identities.
                    weights[key] = nn.Parameter(
                        torch.eye(self.per_dim_out[dims[0]], self.per_dim_in[dims[0]]),
                        requires_grad=False)
                else:
                    # Placeholder values; reset_parameters() fills them in.
                    weights[key] = nn.Parameter(torch.zeros(*self._layer_shape(dims)))

            weights_by_tensor[f'{idx}'] = nn.ParameterDict(weights)
            self.qtn_layer2[idx] = keys
            self.einsum_eq_eval[idx], self.einsum_expr_eval[idx] = \
                self.gen_einsum_expr_eval(self.qtn_layers[idx])

        # The values are modules (ParameterDicts), so ModuleDict is the
        # container that registers them.
        self.quanta_weights = nn.ModuleDict(weights_by_tensor)
        self.reset_parameters()

    @staticmethod
    def _ceil_root(value, degree):
        """Return the smallest integer ``r`` with ``r ** degree >= value``."""
        root = int(round(value ** (1.0 / degree)))
        while root ** degree < value:
            root += 1
        while root > 0 and (root - 1) ** degree >= value:
            root -= 1
        return root

    def _layer_shape(self, dims):
        """Weight shape for a layer acting on ``dims``.

        Output sizes come first, then input sizes, each listed in reverse
        order of ``dims`` to match the index order used in the einsum string.
        """
        rev = tuple(reversed(dims))
        return (tuple(self.per_dim_out[dim] for dim in rev)
                + tuple(self.per_dim_in[dim] for dim in rev))

    def reset_parameters(self):
        """Re-draw every weight with more than two modes from ``N(0, std**2)``.

        The update is done in place so the Parameter objects (and therefore
        any optimizer holding them), their device and their dtype are kept.
        Tensors with at most two modes (the frozen identities) are untouched.
        """
        with torch.no_grad():
            for idx in range(self.tensor_num):
                for weight in self.quanta_weights[f'{idx}'].values():
                    if weight.dim() > 2:
                        weight.copy_(torch.randn_like(weight) * self.std)

    def gen_einsum_expr_eval(self, qtn_layer):
        """Build the einsum equation and compiled contraction for one network.

        Args:
            qtn_layer: sequence of ``(index, dim_1, ..., dim_k)`` layers.

        Returns:
            ``(eq, expr)``: the einsum equation string and the matching
            ``opt_einsum`` contract expression.
        """
        d = self.d
        # current[m] is the symbol id currently attached to mode m.
        current = list(range(d))
        initial = list(current)

        terms = []
        shapes = []
        for layer in qtn_layer:
            dims = layer[1:]
            if not dims:
                continue
            # Read every incoming symbol before updating any of them.
            incoming = [current[dim] for dim in dims]
            outgoing = [symbol + d for symbol in incoming]
            # Index order is (out_k .. out_1, in_k .. in_1), matching _layer_shape.
            terms.append(
                ''.join(oe.get_symbol(s) for s in reversed(outgoing))
                + ''.join(oe.get_symbol(s) for s in reversed(incoming)))
            for dim, symbol in zip(dims, outgoing):
                current[dim] = symbol
            shapes.append(self._layer_shape(dims))

        # Final symbols first, initial symbols last: (fan_out, fan_in) order.
        eq = ','.join(terms) + '->' + ''.join(
            oe.get_symbol(s) for s in current + initial)

        # The contraction path is left to opt_einsum's default strategy.
        expr = oe.contract_expression(eq, *shapes)

        return eq, expr

    def _contract(self, idx):
        """Contract the weights of network ``idx`` into one tensor."""
        weights = self.quanta_weights[f'{idx}']
        return self.einsum_expr_eval[idx](*[weights[key] for key in self.qtn_layer2[idx]])

    def forward(self, idx):
        """Return the tensor represented by network ``idx``."""
        return self._contract(idx)

    def tensor_generation(self, idx):
        """Re-initialise all weights, then return the tensor of network ``idx``.

        Note that the weights of every network are re-drawn, not only those
        of ``idx``.
        """
        self.reset_parameters()
        return self._contract(idx)
    

class OrthogonalRegularization(nn.Module):
    """MSE loss plus a soft orthogonality penalty on a model's weights.

    Args:
        model: module whose parameters are regularised.
        lambda_orth: weight of the penalty; values ``<= 0`` disable it.
    """

    def __init__(self, model, lambda_orth=0.1):
        super(OrthogonalRegularization, self).__init__()
        self.model = model
        self.lambda_orth = lambda_orth

    def forward(self, outputs, targets):
        """Return ``MSE(outputs, targets) + lambda_orth * penalty``.

        The penalty sums ``||W^T W - I||_F^2`` over every parameter whose name
        contains ``'weight'`` and that has at least two modes.  Each such
        parameter is flattened to a matrix by grouping its first
        ``ndim // 2`` modes into rows and the remaining modes into columns
        (for a 4-D tensor: ``(s0 * s1, s2 * s3)``).
        """
        loss = nn.functional.mse_loss(outputs, targets)

        orth_loss = 0.0
        if self.lambda_orth > 0:
            for name, param in self.model.named_parameters():
                if 'weight' not in name:
                    continue
                n_modes = param.dim()
                if n_modes < 2:
                    # Scalars and vectors have no matrix form to orthogonalise.
                    continue
                # The reshape used to hard-code four modes, so any other rank
                # raised instead of being regularised or skipped.
                split = n_modes // 2
                rows = math.prod(param.shape[:split])
                cols = math.prod(param.shape[split:])
                w = param.reshape(rows, cols)
                gram = torch.matmul(w.T, w)  # W^T W
                identity = torch.eye(cols, device=w.device, dtype=w.dtype)
                orth_loss = orth_loss + ((gram - identity) ** 2).sum()  # squared Frobenius norm

        return loss + self.lambda_orth * orth_loss
    

def _is_shared_layer_spec(qtn_layers):
    """Return True if ``qtn_layers`` is one flat list of layers.

    A flat specification is a sequence of layers whose items are plain mode
    indices; a per-tensor collection has sequences of layers as its items.
    """
    return all(not isinstance(item, (list, tuple))
               for layer in qtn_layers for item in layer)


def Qtn_Decomp(tensors, qtn_layers, Q, R, epochs=10000, std=1, lr=1e-2, repeat_time=1):
    """Fit a ``Quantum_Decomp`` model to ``tensors`` and return the best error.

    Args:
        tensors: dict mapping an index (convertible with ``int``) to a target
            tensor.
        qtn_layers: either a dict with one layer specification per tensor, or
            a single flat specification that is shared by every tensor.
        Q, R: forwarded to ``Quantum_Decomp``.
        epochs: number of Adam steps per run (must be >= 1).
        std: initialisation scale forwarded to ``Quantum_Decomp``.
        lr: Adam learning rate.
        repeat_time: number of independent runs (must be >= 1).

    Returns:
        The smallest relative error over all runs.  The error of a run is
        ``sqrt(total_loss * numel) / sum_of_norms``, where ``total_loss`` is
        the summed MSE computed in the last epoch (before the last optimizer
        step) and ``numel`` is the element count of the last tensor in
        ``tensors``.

    Raises:
        ValueError: if ``tensors`` is empty or ``epochs`` / ``repeat_time``
            is smaller than one.
    """
    if not tensors:
        raise ValueError("tensors must contain at least one tensor")
    if epochs < 1 or repeat_time < 1:
        raise ValueError("epochs and repeat_time must both be >= 1")

    if isinstance(qtn_layers, dict):
        qtn_layers_all = copy.deepcopy(qtn_layers)
    elif len(tensors) > 1 or _is_shared_layer_spec(qtn_layers):
        # One shared specification: give every tensor its own copy.  With a
        # single tensor this case used to be passed through unreplicated.
        qtn_layers_all = {i: copy.deepcopy(qtn_layers) for i in range(len(tensors))}
    else:
        qtn_layers_all = copy.deepcopy(qtn_layers)

    norms = 0
    for idx, tensor in tensors.items():
        norms += torch.linalg.vector_norm(tensor.reshape(-1))
    # Element count of the last tensor; assumes all targets have the same
    # number of elements.
    data_num = tensor.numel()

    # Use the GPU when present, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Move the targets once instead of on every epoch.
    targets = {idx: tensor.to(device) for idx, tensor in tensors.items()}

    rse_all = []

    print(f"Repeat time: {repeat_time}")

    for t in range(repeat_time):

        start_time = time.time()

        model = Quantum_Decomp(Q=Q, R=R, qtn_layers=qtn_layers_all, std=std, tensors=tensors)
        model.to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        loss_fn = nn.MSELoss()

        for epoch in range(epochs):
            optimizer.zero_grad()

            # Sum the reconstruction losses of all tensors into one objective.
            total_loss = 0
            for idx, target in targets.items():
                total_loss = total_loss + loss_fn(model(int(idx)), target)

            total_loss.backward()
            optimizer.step()

        end_time = time.time()
        rse = math.sqrt(total_loss.item() * data_num) / norms
        rse_all.append(rse)

        print(f'Epoch [{epoch + 1}/{epochs}], Total rse: {rse:.4f}, time: {end_time - start_time:.4f}')

    min_rse = min(rse_all)

    print(f'Minimum rse: {min_rse:.4f}')

    return min_rse

if __name__ == '__main__':
    import os

    # cuBLAS needs a fixed workspace configuration for deterministic results;
    # it has to be in the environment before the first CUDA call.  A value
    # already set by the user is respected.
    os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')

    # Seed every random number generator used below.
    seed = 31415
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # This is a function: assigning True to it (as was done before) replaces
    # the function and enables nothing.  warn_only keeps the script running
    # if an operation without a deterministic implementation is hit.
    torch.use_deterministic_algorithms(True, warn_only=True)

    qtn_layer = [(-1, -4), (-2, -3), (-2, -4), (-3, -4), (-1, -3)]
    mask = [0, 1, 2]
    Q_qubit = 4
    R_input = 2
    epochs = 3000

    # One network per mask entry: the full layer list with that layer removed.
    qtn_layers = {}
    for i in range(len(mask)):
        qtn_layers[i] = copy.deepcopy(qtn_layer)
        qtn_layers[i].pop(mask[i])
        print(qtn_layers[i])

    # Generate the target tensors from randomly initialised networks.
    model1 = Quantum_Decomp(Q_qubit, R_input, qtn_layers, std=1)
    tensors = {}
    for i in range(len(mask)):
        tensors[f'{i}'] = model1.tensor_generation(i).detach().clone()

    # Alternative structure with the full layer list for every tensor; only
    # printed here, see the commented-out call below.
    qtn_layers2 = {}
    for i in range(len(mask)):
        qtn_layers2[i] = copy.deepcopy(qtn_layer)
        print(qtn_layers2[i])

    # Tensor decomposition with the same structures that generated the data.
    min_rse = Qtn_Decomp(tensors, qtn_layers, Q_qubit, R_input, epochs, std=0.2, lr=1e-3)

    #model2 = Qtn_Decomp(tensors, qtn_layers2, Q_qubit, R_input, epochs, std=0.2)

